2022-11-29 01:26:47 -05:00

617 lines
19 KiB
Python

from doctest import DocTestSuite
from unittest import TestCase
from itertools import combinations
from six.moves import range
import more_itertools as mi
def load_tests(loader, tests, ignore):
    """Implement the ``load_tests`` hook so doctests run with the unit tests.

    The doctests found in ``more_itertools.recipes`` are appended to the
    suite passed in as *tests*, and that same suite is returned. *loader*
    and *ignore* are accepted but not used here.
    """
    recipe_doctests = DocTestSuite('more_itertools.recipes')
    tests.addTests(recipe_doctests)
    return tests
class AccumulateTests(TestCase):
    """Tests for ``accumulate()``"""

    def test_empty(self):
        """An empty iterable yields nothing"""
        result = list(mi.accumulate([]))
        self.assertEqual(result, [])

    def test_default(self):
        """Without ``func``, running sums are produced"""
        result = list(mi.accumulate([1, 2, 3]))
        self.assertEqual(result, [1, 3, 6])

    def test_bogus_function(self):
        """A one-argument ``func`` cannot combine two values"""
        def identity(x):
            return x

        with self.assertRaises(TypeError):
            list(mi.accumulate([1, 2, 3], func=identity))

    def test_custom_function(self):
        """``func=max`` yields the running maximum"""
        result = list(mi.accumulate((1, 2, 3, 2, 1), func=max))
        self.assertEqual(result, [1, 2, 3, 3, 3])
class TakeTests(TestCase):
    """Tests for ``take()``"""

    def test_simple_take(self):
        """Taking 5 of 10 items returns the first five as a list"""
        self.assertEqual(mi.take(5, range(10)), [0, 1, 2, 3, 4])

    def test_null_take(self):
        """Taking zero items returns an empty list"""
        self.assertEqual(mi.take(0, range(10)), [])

    def test_negative_take(self):
        """A negative count is rejected with ValueError"""
        with self.assertRaises(ValueError):
            mi.take(-3, range(10))

    def test_take_too_much(self):
        """Asking for more items than are available returns only the
        items that exist.
        """
        self.assertEqual(mi.take(10, range(5)), [0, 1, 2, 3, 4])
class TabulateTests(TestCase):
    """Tests for ``tabulate()``"""

    def test_simple_tabulate(self):
        """With no start given, the function is applied to 0, 1, 2, ..."""
        results = mi.tabulate(lambda x: x)
        first_three = (next(results), next(results), next(results))
        self.assertEqual(first_three, (0, 1, 2))

    def test_count(self):
        """An explicit start value is honoured"""
        results = mi.tabulate(lambda x: 2 * x, -1)
        first_three = tuple([next(results) for _ in range(3)])
        self.assertEqual(first_three, (-2, 0, 2))
class TailTests(TestCase):
    """Tests for ``tail()``"""

    def test_greater(self):
        """The iterable is longer than the requested tail"""
        result = list(mi.tail(3, 'ABCDEFG'))
        self.assertEqual(result, ['E', 'F', 'G'])

    def test_equal(self):
        """The iterable is exactly as long as the requested tail"""
        result = list(mi.tail(7, 'ABCDEFG'))
        self.assertEqual(result, ['A', 'B', 'C', 'D', 'E', 'F', 'G'])

    def test_less(self):
        """The iterable is shorter than the requested tail"""
        result = list(mi.tail(8, 'ABCDEFG'))
        self.assertEqual(result, ['A', 'B', 'C', 'D', 'E', 'F', 'G'])
class ConsumeTests(TestCase):
    """Tests for ``consume()``"""

    def test_sanity(self):
        """Consuming three items leaves the fourth as the next value"""
        gen = (x for x in range(10))
        mi.consume(gen, 3)
        self.assertEqual(3, next(gen))

    def test_null_consume(self):
        """Consuming zero items leaves the generator untouched"""
        gen = (x for x in range(10))
        mi.consume(gen, 0)
        self.assertEqual(0, next(gen))

    def test_negative_consume(self):
        """Negative consumption raises ValueError"""
        gen = (x for x in range(10))
        with self.assertRaises(ValueError):
            mi.consume(gen, -1)

    def test_total_consume(self):
        """With no count, the generator is exhausted"""
        gen = (x for x in range(10))
        mi.consume(gen)
        with self.assertRaises(StopIteration):
            next(gen)
class NthTests(TestCase):
    """Tests for ``nth()``"""

    def test_basic(self):
        """Each index yields the item stored at that position"""
        values = range(10)
        for index, expected in enumerate(values):
            actual = mi.nth(values, index)
            self.assertEqual(actual, expected)

    def test_default(self):
        """An out-of-range index returns the supplied default"""
        values = range(3)
        self.assertEqual(mi.nth(values, 100, "zebra"), "zebra")

    def test_negative_item_raises(self):
        """A negative index raises ValueError"""
        with self.assertRaises(ValueError):
            mi.nth(range(10), -3)
class AllEqualTests(TestCase):
    """Tests for ``all_equal()``"""

    def test_true(self):
        """Every element is the same"""
        for items in ('aaaaaa', [0, 0, 0, 0]):
            self.assertTrue(mi.all_equal(items))

    def test_false(self):
        """One element differs from the rest"""
        for items in ('aaaaab', [0, 0, 0, 1]):
            self.assertFalse(mi.all_equal(items))

    def test_tricky(self):
        """Elements of different types that still compare equal"""
        self.assertTrue(mi.all_equal([1, complex(1, 0), 1.0]))

    def test_empty(self):
        """An empty iterable counts as all-equal"""
        for items in ('', []):
            self.assertTrue(mi.all_equal(items))

    def test_one(self):
        """A single-element iterable counts as all-equal"""
        for items in ('0', [0]):
            self.assertTrue(mi.all_equal(items))
class QuantifyTests(TestCase):
    """Tests for ``quantify()``"""

    def test_happy_path(self):
        """The number of truthy items is returned"""
        count = mi.quantify([True, False, True])
        self.assertEqual(count, 2)

    def test_custom_predicate(self):
        """A custom predicate decides which items are counted"""
        def is_even(x):
            return x % 2 == 0

        count = mi.quantify(range(10), is_even)
        self.assertEqual(count, 5)
class PadnoneTests(TestCase):
    """Tests for ``padnone()``"""

    def test_happy_path(self):
        """Once the source items run out, ``None`` keeps being yielded"""
        padded = mi.padnone(range(2))
        first_four = [next(padded) for _ in range(4)]
        self.assertEqual([0, 1, None, None], first_four)
class NcyclesTests(TestCase):
    """Tests for ``ncycles()``"""

    def test_happy_path(self):
        """Three cycles repeat the sequence three times in order"""
        cycled = mi.ncycles(["a", "b", "c"], 3)
        expected = ["a", "b", "c", "a", "b", "c", "a", "b", "c"]
        self.assertEqual(expected, list(cycled))

    def test_null_case(self):
        """Zero cycles gives an iterator with nothing in it"""
        cycled = mi.ncycles(range(100), 0)
        with self.assertRaises(StopIteration):
            next(cycled)

    def test_pathalogical_case(self):
        """A negative cycle count also gives an empty iterator"""
        cycled = mi.ncycles(range(100), -10)
        with self.assertRaises(StopIteration):
            next(cycled)
class DotproductTests(TestCase):
    """Tests for ``dotproduct()``"""

    def test_happy_path(self):
        """Two 2-element vectors: 10 * 20 + 10 * 20 == 400"""
        product = mi.dotproduct([10, 10], [20, 20])
        self.assertEqual(400, product)
class FlattenTests(TestCase):
    """Tests for ``flatten()``"""

    def test_basic_usage(self):
        """A list of lists becomes one flat sequence"""
        nested = [[0, 1, 2], [3, 4, 5]]
        flattened = list(mi.flatten(nested))
        self.assertEqual(list(range(6)), flattened)

    def test_single_level(self):
        """Only the outermost level of nesting is removed"""
        nested = [[0, [1, 2]], [[3, 4], 5]]
        flattened = list(mi.flatten(nested))
        self.assertEqual([0, [1, 2], [3, 4], 5], flattened)
class RepeatfuncTests(TestCase):
    """Tests for ``repeatfunc()``"""

    def test_simple_repeat(self):
        """Without ``times`` the function can be called again and again"""
        results = mi.repeatfunc(lambda: 5)
        first_five = [next(results) for _ in range(5)]
        self.assertEqual([5, 5, 5, 5, 5], first_five)

    def test_finite_repeat(self):
        """``times`` limits how many results are produced"""
        results = mi.repeatfunc(lambda: 5, times=5)
        self.assertEqual([5, 5, 5, 5, 5], list(results))

    def test_added_arguments(self):
        """Extra positional arguments are passed on to the function"""
        results = mi.repeatfunc(lambda x: x, 2, 3)
        self.assertEqual([3, 3], list(results))

    def test_null_times(self):
        """Zero repetitions gives an empty iterator"""
        results = mi.repeatfunc(range, 0, 3)
        with self.assertRaises(StopIteration):
            next(results)
class PairwiseTests(TestCase):
    """Tests for ``pairwise()``"""

    def test_base_case(self):
        """Adjacent items are returned as overlapping pairs"""
        pairs = list(mi.pairwise([1, 2, 3]))
        self.assertEqual([(1, 2), (2, 3)], pairs)

    def test_short_case(self):
        """A single item is not enough to form a pair"""
        pairs = mi.pairwise("a")
        with self.assertRaises(StopIteration):
            next(pairs)
class GrouperTests(TestCase):
    """Tests for ``grouper()``"""

    def test_even(self):
        """The iterable's length is a multiple of the group size"""
        groups = list(mi.grouper(3, 'ABCDEF'))
        self.assertEqual(groups, [('A', 'B', 'C'), ('D', 'E', 'F')])

    def test_odd(self):
        """The iterable's length is not a multiple of the group size, so
        the last group is padded with ``None``.
        """
        groups = list(mi.grouper(3, 'ABCDE'))
        self.assertEqual(groups, [('A', 'B', 'C'), ('D', 'E', None)])

    def test_fill_value(self):
        """An explicit fill value pads the last group"""
        groups = list(mi.grouper(3, 'ABCDE', 'x'))
        self.assertEqual(groups, [('A', 'B', 'C'), ('D', 'E', 'x')])
class RoundrobinTests(TestCase):
    """Tests for ``roundrobin()``"""

    def test_even_groups(self):
        """Equal-length iterables are interleaved one item at a time"""
        interleaved = list(mi.roundrobin('ABC', [1, 2, 3], range(3)))
        expected = ['A', 1, 0, 'B', 2, 1, 'C', 3, 2]
        self.assertEqual(interleaved, expected)

    def test_uneven_groups(self):
        """Exhausted iterables drop out while the others carry on"""
        interleaved = list(mi.roundrobin('ABCD', [1, 2], range(0)))
        expected = ['A', 1, 'B', 2, 'C', 'D']
        self.assertEqual(interleaved, expected)
class PartitionTests(TestCase):
    """Tests for ``partition()``"""

    def test_bool(self):
        """The predicate returns a real boolean"""
        def above_five(x):
            return x > 5

        falsy_items, truthy_items = mi.partition(above_five, range(10))
        self.assertEqual(list(falsy_items), [0, 1, 2, 3, 4, 5])
        self.assertEqual(list(truthy_items), [6, 7, 8, 9])

    def test_arbitrary(self):
        """The predicate returns an integer that is judged by truthiness"""
        def mod_three(x):
            return x % 3

        falsy_items, truthy_items = mi.partition(mod_three, range(10))
        self.assertEqual(list(falsy_items), [0, 3, 6, 9])
        self.assertEqual(list(truthy_items), [1, 2, 4, 5, 7, 8])
class PowersetTests(TestCase):
    """Tests for ``powerset()``"""

    def test_combinatorics(self):
        """All subsets are produced, shortest first"""
        expected = [
            (),
            (1,), (2,), (3,),
            (1, 2), (1, 3), (2, 3),
            (1, 2, 3),
        ]
        subsets = list(mi.powerset([1, 2, 3]))
        self.assertEqual(subsets, expected)
class UniqueEverseenTests(TestCase):
    """Tests for ``unique_everseen()``"""

    def test_everseen(self):
        """An element is yielded only the first time it appears"""
        uniques = list(mi.unique_everseen('AAAABBBBCCDAABBB'))
        self.assertEqual(['A', 'B', 'C', 'D'], uniques)

    def test_custom_key(self):
        """Uniqueness is decided by the result of ``key``"""
        uniques = list(mi.unique_everseen('aAbACCc', key=str.lower))
        self.assertEqual(list('abC'), uniques)

    def test_unhashable(self):
        """Unhashable items such as lists are deduplicated too"""
        source = ['a', [1, 2, 3], [1, 2, 3], 'a']
        uniques = list(mi.unique_everseen(source))
        self.assertEqual(uniques, ['a', [1, 2, 3]])

    def test_unhashable_key(self):
        """Unhashable items are deduplicated when a key is supplied"""
        source = ['a', [1, 2, 3], [1, 2, 3], 'a']
        uniques = list(mi.unique_everseen(source, key=lambda x: x))
        self.assertEqual(uniques, ['a', [1, 2, 3]])
class UniqueJustseenTests(TestCase):
    """Tests for ``unique_justseen()``"""

    def test_justseen(self):
        """Only consecutive duplicates are collapsed"""
        collapsed = list(mi.unique_justseen('AAAABBBCCDABB'))
        self.assertEqual(list('ABCDAB'), collapsed)

    def test_custom_key(self):
        """Consecutive duplicates are detected through the key function"""
        collapsed = list(mi.unique_justseen('AABCcAD', str.lower))
        self.assertEqual(list('ABCAD'), collapsed)
class IterExceptTests(TestCase):
    """Tests for ``iter_except()``"""

    def test_exact_exception(self):
        """Iteration ends when the named exception is raised"""
        stack = [1, 2, 3]
        popped = mi.iter_except(stack.pop, IndexError)
        self.assertEqual(list(popped), [3, 2, 1])

    def test_generic_exception(self):
        """A base exception class also ends the iteration"""
        stack = [1, 2]
        popped = mi.iter_except(stack.pop, Exception)
        self.assertEqual(list(popped), [2, 1])

    def test_uncaught_exception_is_raised(self):
        """An exception other than the named one propagates"""
        stack = [1, 2, 3]
        popped = mi.iter_except(stack.pop, KeyError)
        with self.assertRaises(IndexError):
            list(popped)

    def test_first(self):
        """The result of ``first`` comes before the function's results"""
        stack = [1, 2, 3]
        popped = mi.iter_except(stack.pop, IndexError, lambda: 25)
        self.assertEqual(list(popped), [25, 3, 2, 1])
class FirstTrueTests(TestCase):
    """Tests for ``first_true()``"""

    def test_something_true(self):
        """With no keywords, the first truthy item is returned"""
        found = mi.first_true(range(10))
        self.assertEqual(found, 1)

    def test_nothing_true(self):
        """With no truthy item and no default, the result is ``None``"""
        found = mi.first_true([0, 0, 0])
        self.assertIsNone(found)

    def test_default(self):
        """With no truthy item, the ``default`` keyword is returned"""
        found = mi.first_true([0, 0, 0], default='!')
        self.assertEqual(found, '!')

    def test_pred(self):
        """A custom predicate selects the item"""
        def divisible_by_three(x):
            return x % 3 == 0

        found = mi.first_true([2, 4, 6], pred=divisible_by_three)
        self.assertEqual(found, 6)
class RandomProductTests(TestCase):
    """Tests for ``random_product()``

    No random seed is set here. The tests instead assert outcomes that
    are overwhelmingly likely after many samplings, so that they do not
    depend on the exact random sequence of a given Python version.
    """

    def test_simple_lists(self):
        """Ensure that one item is chosen from each list in each pair.

        Also ensure that every item of each list shows up at least once
        across 100 samplings. Assuming uniform choice, a given item is
        missed with probability (2/3)**100, so a spurious failure is
        extremely unlikely.
        """
        nums = [1, 2, 3]
        lets = ['a', 'b', 'c']
        samples = [mi.random_product(nums, lets) for _ in range(100)]
        picked_nums, picked_lets = zip(*samples)
        seen_nums = set(picked_nums)
        seen_lets = set(picked_lets)
        self.assertEqual(seen_nums, set(nums))
        self.assertEqual(seen_lets, set(lets))
        self.assertEqual(len(seen_nums), len(nums))
        self.assertEqual(len(seen_lets), len(lets))

    def test_list_with_repeat(self):
        """Ensure ``repeat`` yields that many picks per list, alternating
        between the first list and the second in order.
        """
        nums = [1, 2, 3]
        lets = ['a', 'b', 'c']
        picks = list(mi.random_product(nums, lets, repeat=100))
        self.assertEqual(2 * 100, len(picks))
        # Even positions come from the first list, odd from the second.
        seen_nums = set(picks[::2])
        seen_lets = set(picks[1::2])
        self.assertEqual(seen_nums, set(nums))
        self.assertEqual(seen_lets, set(lets))
        self.assertEqual(len(seen_nums), len(nums))
        self.assertEqual(len(seen_lets), len(lets))
class RandomPermutationTests(TestCase):
    """Tests for ``random_permutation()``"""

    def test_full_permutation(self):
        """Ensure every item from the iterable is returned in a new ordering.

        No random seed is set. Assuming a uniform shuffle, 15 elements
        come back in their original order about once in 15! (roughly
        1.3e12) runs, so a spurious failure is extremely unlikely.
        """
        i = range(15)
        r = mi.random_permutation(i)
        self.assertEqual(set(i), set(r))
        # Compare element-wise via tuples. A ``range`` object never compares
        # equal to a tuple or list, so ``i == r`` could not detect an
        # unpermuted result and made this check vacuous.
        if tuple(i) == tuple(r):
            raise AssertionError("Values were not permuted")

    def test_partial_permutation(self):
        """Ensure all returned items come from the iterable, that each
        permutation has the requested length, and that every item is
        eventually returned.

        No random seed is set. Each sampling picks 5 of 15 items, so a
        given item is missed by all 100 samplings with probability
        (2/3)**100 (assuming uniform sampling); a spurious failure is
        extremely unlikely.
        """
        items = range(15)
        item_set = set(items)
        all_items = set()
        for _ in range(100):
            permutation = mi.random_permutation(items, 5)
            self.assertEqual(len(permutation), 5)
            permutation_set = set(permutation)
            # Subset check: nothing outside the source may appear.
            self.assertLessEqual(permutation_set, item_set)
            all_items |= permutation_set
        self.assertEqual(all_items, item_set)
class RandomCombinationTests(TestCase):
    """Tests for ``random_combination()``"""

    def test_pseudorandomness(self):
        """Over many samplings, every item of the source should appear in
        at least one combination.
        """
        items = range(15)
        seen = set()
        for _ in range(50):
            seen.update(mi.random_combination(items, 5))
        self.assertEqual(seen, set(items))

    def test_no_replacement(self):
        """Items are drawn without replacement"""
        items = range(15)
        size = len(items)
        for _ in range(50):
            combination = mi.random_combination(items, size)
            # A full-size draw without replacement has no repeats.
            self.assertEqual(len(combination), len(set(combination)))
        # Asking for more items than exist must fail.
        with self.assertRaises(ValueError):
            mi.random_combination(items, size + 1)
class RandomCombinationWithReplacementTests(TestCase):
    """Tests for ``random_combination_with_replacement()``"""

    def test_replacement(self):
        """Items are drawn with replacement"""
        items = range(5)
        wanted = len(items) * 2
        combo = mi.random_combination_with_replacement(items, wanted)
        self.assertEqual(2 * len(items), len(combo))
        # Ten draws from five items must repeat at least one of them.
        distinct = set(combo)
        if len(distinct) == len(combo):
            raise AssertionError("Combination contained no duplicates")

    def test_pseudorandomness(self):
        """Over many samplings, every item of the source should appear in
        at least one combination.
        """
        items = range(15)
        seen = set()
        for _ in range(50):
            seen.update(mi.random_combination_with_replacement(items, 5))
        self.assertEqual(seen, set(items))
class NthCombinationTests(TestCase):
    """Tests for ``nth_combination()``"""

    def test_basic(self):
        """Every index matches the order of ``itertools.combinations``"""
        source = 'abcdefg'
        size = 4
        reference = combinations(source, size)
        for position, expected in enumerate(reference):
            self.assertEqual(
                mi.nth_combination(source, size, position), expected
            )

    def test_long(self):
        """A large index into a big source gives the known combination"""
        self.assertEqual(
            mi.nth_combination(range(180), 4, 2000000), (2, 12, 35, 126)
        )

    def test_invalid_r(self):
        """A negative or too-large ``r`` raises ValueError"""
        for bad_r in (-1, 3):
            with self.assertRaises(ValueError):
                mi.nth_combination([], bad_r, 0)

    def test_invalid_index(self):
        """An index outside the valid range raises IndexError"""
        with self.assertRaises(IndexError):
            mi.nth_combination('abcdefg', 3, -36)
class PrependTests(TestCase):
    """Tests for ``prepend()``"""

    def test_basic(self):
        """A single-character value is placed ahead of the iterator"""
        rest = iter('bcdefg')
        combined = list(mi.prepend('a', rest))
        self.assertEqual(combined, list('abcdefg'))

    def test_multiple(self):
        """A multi-character string is prepended as one item, not
        character by character.
        """
        rest = iter('cdefg')
        combined = tuple(mi.prepend('ab', rest))
        self.assertEqual(combined, ('ab',) + tuple('cdefg'))