home *** CD-ROM | disk | FTP | other *** search
/ Maximum CD 2011 July / maximum-cd-2011-07.iso / DiscContents / LibO_3.3.2_Win_x86_install_multi.exe / libreoffice1.cab / test_heapq.py < prev    next >
Encoding:
Python Source  |  2011-03-15  |  12.9 KB  |  389 lines

  1. """Unittests for heapq."""
  2.  
  3. import random
  4. import unittest
  5. from test import test_support
  6. import sys
  7.  
  8. # We do a bit of trickery here to be able to test both the C implementation
  9. # and the Python implementation of the module.
  10.  
  11. # Make it impossible to import the C implementation anymore.
  12. sys.modules['_heapq'] = 0
  13. # We must also handle the case that heapq was imported before.
  14. if 'heapq' in sys.modules:
  15.     del sys.modules['heapq']
  16.  
  17. # Now we can import the module and get the pure Python implementation.
  18. import heapq as py_heapq
  19.  
  20. # Restore everything to normal.
  21. del sys.modules['_heapq']
  22. del sys.modules['heapq']
  23.  
  24. # This is now the module with the C implementation.
  25. import heapq as c_heapq
  26.  
  27.  
  28. class TestHeap(unittest.TestCase):
  29.     module = None
  30.  
  31.     def test_push_pop(self):
  32.         # 1) Push 256 random numbers and pop them off, verifying all's OK.
  33.         heap = []
  34.         data = []
  35.         self.check_invariant(heap)
  36.         for i in range(256):
  37.             item = random.random()
  38.             data.append(item)
  39.             self.module.heappush(heap, item)
  40.             self.check_invariant(heap)
  41.         results = []
  42.         while heap:
  43.             item = self.module.heappop(heap)
  44.             self.check_invariant(heap)
  45.             results.append(item)
  46.         data_sorted = data[:]
  47.         data_sorted.sort()
  48.         self.assertEqual(data_sorted, results)
  49.         # 2) Check that the invariant holds for a sorted array
  50.         self.check_invariant(results)
  51.  
  52.         self.assertRaises(TypeError, self.module.heappush, [])
  53.         try:
  54.             self.assertRaises(TypeError, self.module.heappush, None, None)
  55.             self.assertRaises(TypeError, self.module.heappop, None)
  56.         except AttributeError:
  57.             pass
  58.  
  59.     def check_invariant(self, heap):
  60.         # Check the heap invariant.
  61.         for pos, item in enumerate(heap):
  62.             if pos: # pos 0 has no parent
  63.                 parentpos = (pos-1) >> 1
  64.                 self.assert_(heap[parentpos] <= item)
  65.  
  66.     def test_heapify(self):
  67.         for size in range(30):
  68.             heap = [random.random() for dummy in range(size)]
  69.             self.module.heapify(heap)
  70.             self.check_invariant(heap)
  71.  
  72.         self.assertRaises(TypeError, self.module.heapify, None)
  73.  
  74.     def test_naive_nbest(self):
  75.         data = [random.randrange(2000) for i in range(1000)]
  76.         heap = []
  77.         for item in data:
  78.             self.module.heappush(heap, item)
  79.             if len(heap) > 10:
  80.                 self.module.heappop(heap)
  81.         heap.sort()
  82.         self.assertEqual(heap, sorted(data)[-10:])
  83.  
  84.     def heapiter(self, heap):
  85.         # An iterator returning a heap's elements, smallest-first.
  86.         try:
  87.             while 1:
  88.                 yield self.module.heappop(heap)
  89.         except IndexError:
  90.             pass
  91.  
  92.     def test_nbest(self):
  93.         # Less-naive "N-best" algorithm, much faster (if len(data) is big
  94.         # enough <wink>) than sorting all of data.  However, if we had a max
  95.         # heap instead of a min heap, it could go faster still via
  96.         # heapify'ing all of data (linear time), then doing 10 heappops
  97.         # (10 log-time steps).
  98.         data = [random.randrange(2000) for i in range(1000)]
  99.         heap = data[:10]
  100.         self.module.heapify(heap)
  101.         for item in data[10:]:
  102.             if item > heap[0]:  # this gets rarer the longer we run
  103.                 self.module.heapreplace(heap, item)
  104.         self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
  105.  
  106.         self.assertRaises(TypeError, self.module.heapreplace, None)
  107.         self.assertRaises(TypeError, self.module.heapreplace, None, None)
  108.         self.assertRaises(IndexError, self.module.heapreplace, [], None)
  109.  
  110.     def test_nbest_with_pushpop(self):
  111.         data = [random.randrange(2000) for i in range(1000)]
  112.         heap = data[:10]
  113.         self.module.heapify(heap)
  114.         for item in data[10:]:
  115.             self.module.heappushpop(heap, item)
  116.         self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
  117.         self.assertEqual(self.module.heappushpop([], 'x'), 'x')
  118.  
  119.     def test_heappushpop(self):
  120.         h = []
  121.         x = self.module.heappushpop(h, 10)
  122.         self.assertEqual((h, x), ([], 10))
  123.  
  124.         h = [10]
  125.         x = self.module.heappushpop(h, 10.0)
  126.         self.assertEqual((h, x), ([10], 10.0))
  127.         self.assertEqual(type(h[0]), int)
  128.         self.assertEqual(type(x), float)
  129.  
  130.         h = [10];
  131.         x = self.module.heappushpop(h, 9)
  132.         self.assertEqual((h, x), ([10], 9))
  133.  
  134.         h = [10];
  135.         x = self.module.heappushpop(h, 11)
  136.         self.assertEqual((h, x), ([11], 10))
  137.  
  138.     def test_heapsort(self):
  139.         # Exercise everything with repeated heapsort checks
  140.         for trial in xrange(100):
  141.             size = random.randrange(50)
  142.             data = [random.randrange(25) for i in range(size)]
  143.             if trial & 1:     # Half of the time, use heapify
  144.                 heap = data[:]
  145.                 self.module.heapify(heap)
  146.             else:             # The rest of the time, use heappush
  147.                 heap = []
  148.                 for item in data:
  149.                     self.module.heappush(heap, item)
  150.             heap_sorted = [self.module.heappop(heap) for i in range(size)]
  151.             self.assertEqual(heap_sorted, sorted(data))
  152.  
  153.     def test_merge(self):
  154.         inputs = []
  155.         for i in xrange(random.randrange(5)):
  156.             row = sorted(random.randrange(1000) for j in range(random.randrange(10)))
  157.             inputs.append(row)
  158.         self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs)))
  159.         self.assertEqual(list(self.module.merge()), [])
  160.  
  161.     def test_merge_stability(self):
  162.         class Int(int):
  163.             pass
  164.         inputs = [[], [], [], []]
  165.         for i in range(20000):
  166.             stream = random.randrange(4)
  167.             x = random.randrange(500)
  168.             obj = Int(x)
  169.             obj.pair = (x, stream)
  170.             inputs[stream].append(obj)
  171.         for stream in inputs:
  172.             stream.sort()
  173.         result = [i.pair for i in self.module.merge(*inputs)]
  174.         self.assertEqual(result, sorted(result))
  175.  
  176.     def test_nsmallest(self):
  177.         data = [(random.randrange(2000), i) for i in range(1000)]
  178.         for f in (None, lambda x:  x[0] * 547 % 2000):
  179.             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
  180.                 self.assertEqual(self.module.nsmallest(n, data), sorted(data)[:n])
  181.                 self.assertEqual(self.module.nsmallest(n, data, key=f),
  182.                                  sorted(data, key=f)[:n])
  183.  
  184.     def test_nlargest(self):
  185.         data = [(random.randrange(2000), i) for i in range(1000)]
  186.         for f in (None, lambda x:  x[0] * 547 % 2000):
  187.             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
  188.                 self.assertEqual(self.module.nlargest(n, data),
  189.                                  sorted(data, reverse=True)[:n])
  190.                 self.assertEqual(self.module.nlargest(n, data, key=f),
  191.                                  sorted(data, key=f, reverse=True)[:n])
  192.  
  193. class TestHeapPython(TestHeap):
  194.     module = py_heapq
  195.  
  196. class TestHeapC(TestHeap):
  197.     module = c_heapq
  198.  
  199.     def test_comparison_operator(self):
  200.         # Issue 3501: Make sure heapq works with both __lt__ and __le__
  201.         def hsort(data, comp):
  202.             data = map(comp, data)
  203.             self.module.heapify(data)
  204.             return [self.module.heappop(data).x for i in range(len(data))]
  205.         class LT:
  206.             def __init__(self, x):
  207.                 self.x = x
  208.             def __lt__(self, other):
  209.                 return self.x > other.x
  210.         class LE:
  211.             def __init__(self, x):
  212.                 self.x = x
  213.             def __le__(self, other):
  214.                 return self.x >= other.x
  215.         data = [random.random() for i in range(100)]
  216.         target = sorted(data, reverse=True)
  217.         self.assertEqual(hsort(data, LT), target)
  218.         self.assertEqual(hsort(data, LE), target)
  219.  
  220.  
  221. #==============================================================================
  222.  
  223. class LenOnly:
  224.     "Dummy sequence class defining __len__ but not __getitem__."
  225.     def __len__(self):
  226.         return 10
  227.  
  228. class GetOnly:
  229.     "Dummy sequence class defining __getitem__ but not __len__."
  230.     def __getitem__(self, ndx):
  231.         return 10
  232.  
  233. class CmpErr:
  234.     "Dummy element that always raises an error during comparison"
  235.     def __cmp__(self, other):
  236.         raise ZeroDivisionError
  237.  
  238. def R(seqn):
  239.     'Regular generator'
  240.     for i in seqn:
  241.         yield i
  242.  
  243. class G:
  244.     'Sequence using __getitem__'
  245.     def __init__(self, seqn):
  246.         self.seqn = seqn
  247.     def __getitem__(self, i):
  248.         return self.seqn[i]
  249.  
  250. class I:
  251.     'Sequence using iterator protocol'
  252.     def __init__(self, seqn):
  253.         self.seqn = seqn
  254.         self.i = 0
  255.     def __iter__(self):
  256.         return self
  257.     def next(self):
  258.         if self.i >= len(self.seqn): raise StopIteration
  259.         v = self.seqn[self.i]
  260.         self.i += 1
  261.         return v
  262.  
  263. class Ig:
  264.     'Sequence using iterator protocol defined with a generator'
  265.     def __init__(self, seqn):
  266.         self.seqn = seqn
  267.         self.i = 0
  268.     def __iter__(self):
  269.         for val in self.seqn:
  270.             yield val
  271.  
  272. class X:
  273.     'Missing __getitem__ and __iter__'
  274.     def __init__(self, seqn):
  275.         self.seqn = seqn
  276.         self.i = 0
  277.     def next(self):
  278.         if self.i >= len(self.seqn): raise StopIteration
  279.         v = self.seqn[self.i]
  280.         self.i += 1
  281.         return v
  282.  
  283. class N:
  284.     'Iterator missing next()'
  285.     def __init__(self, seqn):
  286.         self.seqn = seqn
  287.         self.i = 0
  288.     def __iter__(self):
  289.         return self
  290.  
  291. class E:
  292.     'Test propagation of exceptions'
  293.     def __init__(self, seqn):
  294.         self.seqn = seqn
  295.         self.i = 0
  296.     def __iter__(self):
  297.         return self
  298.     def next(self):
  299.         3 // 0
  300.  
  301. class S:
  302.     'Test immediate stop'
  303.     def __init__(self, seqn):
  304.         pass
  305.     def __iter__(self):
  306.         return self
  307.     def next(self):
  308.         raise StopIteration
  309.  
  310. from itertools import chain, imap
  311. def L(seqn):
  312.     'Test multiple tiers of iterators'
  313.     return chain(imap(lambda x:x, R(Ig(G(seqn)))))
  314.  
  315. class TestErrorHandling(unittest.TestCase):
  316.     # only for C implementation
  317.     module = c_heapq
  318.  
  319.     def test_non_sequence(self):
  320.         for f in (self.module.heapify, self.module.heappop):
  321.             self.assertRaises(TypeError, f, 10)
  322.         for f in (self.module.heappush, self.module.heapreplace,
  323.                   self.module.nlargest, self.module.nsmallest):
  324.             self.assertRaises(TypeError, f, 10, 10)
  325.  
  326.     def test_len_only(self):
  327.         for f in (self.module.heapify, self.module.heappop):
  328.             self.assertRaises(TypeError, f, LenOnly())
  329.         for f in (self.module.heappush, self.module.heapreplace):
  330.             self.assertRaises(TypeError, f, LenOnly(), 10)
  331.         for f in (self.module.nlargest, self.module.nsmallest):
  332.             self.assertRaises(TypeError, f, 2, LenOnly())
  333.  
  334.     def test_get_only(self):
  335.         for f in (self.module.heapify, self.module.heappop):
  336.             self.assertRaises(TypeError, f, GetOnly())
  337.         for f in (self.module.heappush, self.module.heapreplace):
  338.             self.assertRaises(TypeError, f, GetOnly(), 10)
  339.         for f in (self.module.nlargest, self.module.nsmallest):
  340.             self.assertRaises(TypeError, f, 2, GetOnly())
  341.  
  342.     def test_get_only(self):
  343.         seq = [CmpErr(), CmpErr(), CmpErr()]
  344.         for f in (self.module.heapify, self.module.heappop):
  345.             self.assertRaises(ZeroDivisionError, f, seq)
  346.         for f in (self.module.heappush, self.module.heapreplace):
  347.             self.assertRaises(ZeroDivisionError, f, seq, 10)
  348.         for f in (self.module.nlargest, self.module.nsmallest):
  349.             self.assertRaises(ZeroDivisionError, f, 2, seq)
  350.  
  351.     def test_arg_parsing(self):
  352.         for f in (self.module.heapify, self.module.heappop,
  353.                   self.module.heappush, self.module.heapreplace,
  354.                   self.module.nlargest, self.module.nsmallest):
  355.             self.assertRaises(TypeError, f, 10)
  356.  
  357.     def test_iterable_args(self):
  358.         for f in (self.module.nlargest, self.module.nsmallest):
  359.             for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
  360.                 for g in (G, I, Ig, L, R):
  361.                     self.assertEqual(f(2, g(s)), f(2,s))
  362.                 self.assertEqual(f(2, S(s)), [])
  363.                 self.assertRaises(TypeError, f, 2, X(s))
  364.                 self.assertRaises(TypeError, f, 2, N(s))
  365.                 self.assertRaises(ZeroDivisionError, f, 2, E(s))
  366.  
  367.  
  368. #==============================================================================
  369.  
  370.  
  371. def test_main(verbose=None):
  372.     from types import BuiltinFunctionType
  373.  
  374.     test_classes = [TestHeapPython, TestHeapC, TestErrorHandling]
  375.     test_support.run_unittest(*test_classes)
  376.  
  377.     # verify reference counting
  378.     if verbose and hasattr(sys, "gettotalrefcount"):
  379.         import gc
  380.         counts = [None] * 5
  381.         for i in xrange(len(counts)):
  382.             test_support.run_unittest(*test_classes)
  383.             gc.collect()
  384.             counts[i] = sys.gettotalrefcount()
  385.         print counts
  386.  
  387. if __name__ == "__main__":
  388.     test_main(verbose=True)
  389.