home *** CD-ROM | disk | FTP | other *** search
/ PC Welt 2006 November (DVD) / PCWELT_11_2006.ISO / casper / filesystem.squashfs / usr / lib / python2.4 / site-packages / Numeric / Numeric.pyc (.txt) < prev    next >
Encoding:
Python Compiled Bytecode  |  2006-08-31  |  31.2 KB  |  942 lines

  1. # Source Generated with Decompyle++
  2. # File: in.pyc (Python 2.4)
  3.  
  4. '''Numeric module defining a multi-dimensional array and useful procedures for
  5.    Numerical computation.
  6.  
  7. Functions
  8.  
  9. -   array                      - NumPy Array construction
  10. -   zeros                      - Return an array of all zeros
  11. -   empty                      - Return an uninitialized array (200x faster than zeros)
  12. -   shape                      - Return shape of sequence or array
  13. -   rank                       - Return number of dimensions
  14. -   size                       - Return number of elements in entire array or a
  15.                                  certain dimension
  16. -   fromstring                 - Construct array from (byte) string
  17. -   take                       - Select sub-arrays using sequence of indices
  18. -   put                        - Set sub-arrays using sequence of 1-D indices
  19. -   putmask                    - Set portion of arrays using a mask
  20. -   reshape                    - Return array with new shape
  21. -   repeat                     - Repeat elements of array
  22. -   choose                     - Construct new array from indexed array tuple
  23. -   cross_correlate            - Correlate two 1-d arrays
  24. -   searchsorted               - Search for element in 1-d array
  25. -   sum                        - Total sum over a specified dimension
  26. -   average                    - Average, possibly weighted, over axis or array.
  27. -   cumsum                     - Cumulative sum over a specified dimension
  28. -   product                    - Total product over a specified dimension
  29. -   cumproduct                 - Cumulative product over a specified dimension
  30. -   alltrue                    - Logical and over an entire axis
  31. -   sometrue                   - Logical or over an entire axis
  32. -   allclose\t\t       - Tests if sequences are essentially equal
  33.  
  34. More Functions:
  35.  
  36. -   arrayrange (arange)        - Return regularly spaced array
  37. -   asarray                    - Guarantee NumPy array
  38. -   sarray                     - Guarantee a NumPy array that keeps precision
  39. -   convolve                   - Convolve two 1-d arrays
  40. -   swapaxes                   - Exchange axes
  41. -   concatenate                - Join arrays together
  42. -   transpose                  - Permute axes
  43. -   sort                       - Sort elements of array
  44. -   argsort                    - Indices of sorted array
  45. -   argmax                     - Index of largest value
  46. -   argmin                     - Index of smallest value
  47. -   innerproduct               - Innerproduct of two arrays
  48. -   dot                        - Dot product (matrix multiplication)
  49. -   outerproduct               - Outerproduct of two arrays
  50. -   resize                     - Return array with arbitrary new shape
  51. -   indices                    - Tuple of indices
  52. -   fromfunction               - Construct array from universal function
  53. -   diagonal                   - Return diagonal array
  54. -   trace                      - Trace of array
  55. -   dump                       - Dump array to file object (pickle)
  56. -   dumps                      - Return pickled string representing data
  57. -   load                       - Return array stored in file object
  58. -   loads                      - Return array from pickled string
  59. -   ravel                      - Return array as 1-D
  60. -   nonzero                    - Indices of nonzero elements for 1-D array
  61. -   shape                      - Shape of array
  62. -   where                      - Construct array from binary result
  63. -   compress                   - Elements of array where condition is true
  64. -   clip                       - Clip array between two values
  65. -   ones                       - Array of all ones
  66. -   identity                   - 2-D identity array (matrix)
  67.  
  68. (Universal) Math Functions
  69.  
  70.        add                    logical_or             exp
  71.        subtract               logical_xor            log
  72.        multiply               logical_not            log10
  73.        divide                 maximum                sin
  74.        divide_safe            minimum                sinh
  75.        conjugate              bitwise_and            sqrt
  76.        power                  bitwise_or             tan
  77.        absolute               bitwise_xor            tanh
  78.        negative               invert                 ceil
  79.        greater                left_shift             fabs
  80.        greater_equal          right_shift            floor
  81.        less                   arccos                 arctan2
  82.        less_equal             arcsin                 fmod
  83.        equal                  arctan                 hypot
  84.        not_equal              cos                    around
  85.        logical_and            cosh                   sign
  86.        arccosh                arcsinh                arctanh
  87.  
  88. '''
  89. import numeric_version
  90. __version__ = numeric_version.version
  91. del numeric_version
  92. import multiarray
  93. from umath import *
  94. from Precision import *
  95. import _numpy
  96. import string
  97. import types
  98. import math
  99. NewAxis = None
  100. arrayrange = multiarray.arange
  101. array = multiarray.array
  102. zeros = multiarray.zeros
  103. empty = multiarray.empty
  104.  
  105. def asarray(a, typecode = None, savespace = 0):
  106.     '''asarray(a,typecode=None) returns a as a NumPy array.  Unlike array(),
  107.     no copy is performed if a is already an array.
  108.     '''
  109.     return multiarray.array(a, typecode, copy = 0, savespace = savespace)
  110.  
  111.  
  112. def sarray(a, typecode = None, copy = 0):
  113.     '''sarray(a, typecode=None, copy=0) calls array with savespace=1.'''
  114.     return multiarray.array(a, typecode, copy, savespace = 1)
  115.  
  116. fromstring = multiarray.fromstring
  117. take = multiarray.take
  118. reshape = multiarray.reshape
  119. choose = multiarray.choose
  120. cross_correlate = multiarray.cross_correlate
  121.  
  122. def repeat(a, repeats, axis = 0):
  123.     '''repeat elements of a repeats times along axis
  124.        repeats is a sequence of length a.shape[axis]
  125.        telling how many times to repeat each element.
  126.        If repeats is an integer, it is interpreted as
  127.        a tuple of length a.shape[axis] containing repeats.
  128.        The argument a can be anything array(a) will accept.
  129.     '''
  130.     a = array(a, copy = 0)
  131.     s = a.shape
  132.     if isinstance(repeats, types.IntType):
  133.         repeats = tuple([
  134.             repeats] * s[axis])
  135.     
  136.     if len(repeats) != s[axis]:
  137.         raise ValueError, 'repeat requires second argument integer or of length of a.shape[axis].'
  138.     
  139.     d = multiarray.repeat(a, repeats, axis)
  140.     return d
  141.  
  142.  
  143. def put(a, ind, v):
  144.     '''put(a, ind, v) results in a[n] = v[n] for all n in ind
  145.        If v is shorter than mask it will be repeated as necessary.
  146.        In particular v can be a scalar or length 1 array.
  147.        The routine put is the equivalent of the following (although the loop
  148.        is in C for speed):
  149.  
  150.            ind = array(indices, copy=0)
  151.            v = array(values, copy=0).astype(a, typecode())
  152.            for i in ind: a.flat[i] = v[i]
  153.        a must be a contiguous Numeric array.
  154.     '''
  155.     multiarray.put(a, ind, array(v, copy = 0).astype(a.typecode()))
  156.  
  157.  
  158. def putmask(a, mask, v):
  159.     '''putmask(a, mask, v) results in a = v for all places mask is true.
  160.        If v is shorter than mask it will be repeated as necessary.
  161.        In particular v can be a scalar or length 1 array.
  162.     '''
  163.     tc = a.typecode()
  164.     mask = asarray(mask).astype(Int)
  165.     v = array(v, copy = 0).astype(tc)
  166.     if tc == PyObject:
  167.         if v.shape == ():
  168.             v.shape = (1,)
  169.         
  170.         ax = ravel(a)
  171.         mx = ravel(mask)
  172.         vx = ravel(v)
  173.         vx = resize(vx, ax.shape)
  174.         for i in range(len(ax)):
  175.             if mx[i]:
  176.                 ax[i] = vx[i]
  177.                 continue
  178.         
  179.     else:
  180.         multiarray.putmask(a, mask, v)
  181.  
  182.  
  183. def convolve(a, v, mode = 2):
  184.     '''Returns the discrete, linear convolution of 1-D
  185.     sequences a and v; mode can be 0 (valid), 1 (same), or 2 (full)
  186.     to specify size of the resulting sequence.
  187.     '''
  188.     if len(v) > len(a):
  189.         temp = a
  190.         a = v
  191.         v = temp
  192.         del temp
  193.     
  194.     return cross_correlate(a, asarray(v)[::-1], mode)
  195.  
  196. ArrayType = multiarray.arraytype
  197. UfuncType = type(sin)
  198.  
  199. def swapaxes(a, axis1, axis2):
  200.     '''swapaxes(a, axis1, axis2) returns array a with axis1 and axis2
  201.     interchanged.
  202.     '''
  203.     a = array(a, copy = 0)
  204.     n = len(a.shape)
  205.     if n <= 1:
  206.         return a
  207.     
  208.     if axis1 < 0:
  209.         axis1 += n
  210.     
  211.     if axis2 < 0:
  212.         axis2 += n
  213.     
  214.     if axis1 < 0 or axis1 >= n:
  215.         raise ValueError, 'Bad axis1 argument to swapaxes.'
  216.     
  217.     if axis2 < 0 or axis2 >= n:
  218.         raise ValueError, 'Bad axis2 argument to swapaxes.'
  219.     
  220.     new_axes = arange(n)
  221.     new_axes[axis1] = axis2
  222.     new_axes[axis2] = axis1
  223.     return multiarray.transpose(a, new_axes)
  224.  
  225. arraytype = multiarray.arraytype
  226.  
  227. def concatenate(a, axis = 0):
  228.     '''concatenate(a, axis=0) joins the tuple of sequences in a into a single
  229.     NumPy array.
  230.     '''
  231.     if axis == 0:
  232.         return multiarray.concatenate(a)
  233.     else:
  234.         new_list = []
  235.         for m in a:
  236.             new_list.append(swapaxes(m, axis, 0))
  237.         
  238.     return swapaxes(multiarray.concatenate(new_list), axis, 0)
  239.  
  240.  
  241. def transpose(a, axes = None):
  242.     '''transpose(a, axes=None) returns array with dimensions permuted
  243.     according to axes.  If axes is None (default) returns array with
  244.     dimensions reversed.
  245.     '''
  246.     return multiarray.transpose(a, axes)
  247.  
  248.  
  249. def sort(a, axis = -1):
  250.     '''sort(a,axis=-1) returns array with elements sorted along given axis.
  251.     '''
  252.     a = array(a, copy = 0)
  253.     n = len(a.shape)
  254.     if axis < 0:
  255.         axis += n
  256.     
  257.     if axis < 0 or axis >= n:
  258.         raise ValueError, 'sort axis argument out of bounds'
  259.     
  260.     if axis != n - 1:
  261.         a = swapaxes(a, axis, n - 1)
  262.     
  263.     s = multiarray.sort(a)
  264.     if axis != n - 1:
  265.         s = swapaxes(s, axis, -1)
  266.     
  267.     return s
  268.  
  269.  
  270. def argsort(a, axis = -1):
  271.     '''argsort(a,axis=-1) return the indices into a of the sorted array
  272.     along the given axis, so that take(a,result,axis) is the sorted array.
  273.     '''
  274.     a = array(a, copy = 0)
  275.     n = len(a.shape)
  276.     if axis < 0:
  277.         axis += n
  278.     
  279.     if axis < 0 or axis >= n:
  280.         raise ValueError, 'argsort axis argument out of bounds'
  281.     
  282.     if axis != n - 1:
  283.         a = swapaxes(a, axis, n - 1)
  284.     
  285.     s = multiarray.argsort(a)
  286.     if axis != n - 1:
  287.         s = swapaxes(s, axis, -1)
  288.     
  289.     return s
  290.  
  291.  
  292. def argmax(a, axis = -1):
  293.     '''argmax(a,axis=-1) returns the indices to the maximum value of the
  294.     1-D arrays along the given axis.
  295.     '''
  296.     a = array(a, copy = 0)
  297.     n = len(a.shape)
  298.     if axis < 0:
  299.         axis += n
  300.     
  301.     if axis < 0 or axis >= n:
  302.         raise ValueError, 'argmax axis argument out of bounds'
  303.     
  304.     if axis != n - 1:
  305.         a = swapaxes(a, axis, n - 1)
  306.     
  307.     s = multiarray.argmax(a)
  308.     if axis != n - 1:
  309.         s = swapaxes(s, axis, -1)
  310.     
  311.     return s
  312.  
  313.  
  314. def argmin(a, axis = -1):
  315.     '''argmin(a,axis=-1) returns the indices to the minimum value of the
  316.     1-D arrays along the given axis.
  317.     '''
  318.     arra = array(a, copy = 0)
  319.     type = arra.typecode()
  320.     num = array(0, type)
  321.     if type in [
  322.         'bwu']:
  323.         num = -array(1, type)
  324.     
  325.     a = num - arra
  326.     n = len(a.shape)
  327.     if axis < 0:
  328.         axis += n
  329.     
  330.     if axis < 0 or axis >= n:
  331.         raise ValueError, 'argmin axis argument out of bounds'
  332.     
  333.     if axis != n - 1:
  334.         a = swapaxes(a, axis, n - 1)
  335.     
  336.     s = multiarray.argmax(a)
  337.     if axis != n - 1:
  338.         s = swapaxes(s, axis, -1)
  339.     
  340.     return s
  341.  
  342. searchsorted = multiarray.binarysearch
  343.  
  344. def innerproduct(a, b):
  345.     '''innerproduct(a,b) returns the dot product of two arrays, which has
  346.     shape a.shape[:-1] + b.shape[:-1] with elements computed by summing the
  347.     product of the elements from the last dimensions of a and b.
  348.     '''
  349.     
  350.     try:
  351.         return multiarray.innerproduct(a, b)
  352.     except TypeError:
  353.         detail = None
  354.         if array(a).shape == () or array(b).shape == ():
  355.             return a * b
  356.         elif not detail:
  357.             pass
  358.         raise TypeError, 'invalid types for dot'
  359.  
  360.  
  361.  
  362. def outerproduct(a, b):
  363.     '''outerproduct(a,b) returns the outer product of two vectors.
  364.       result(i,j) = a(i)*b(j) when a and b are vectors
  365.       Will accept any arguments that can be made into vectors.
  366.    '''
  367.     return array(a).flat[(:, NewAxis)] * array(b).flat[(NewAxis, :)]
  368.  
  369.  
  370. def dot(a, b):
  371.     '''dot(a,b) returns matrix-multiplication between a and b.  The product-sum
  372.     is over the last dimension of a and the second-to-last dimension of b.
  373.     '''
  374.     
  375.     try:
  376.         return multiarray.matrixproduct(a, b)
  377.     except TypeError:
  378.         detail = None
  379.         if array(a).shape == () or array(b).shape == ():
  380.             return a * b
  381.         elif not detail:
  382.             pass
  383.         raise TypeError, 'invalid types for dot'
  384.  
  385.  
  386.  
  387. def vdot(a, b):
  388.     '''Returns the dot product of 2 vectors (or anything that can be made into
  389.        a vector). NB: this is not the same as `dot`, as it takes the conjugate
  390.        of its first argument if complex and always returns a scalar.'''
  391.     return multiarray.matrixproduct(conjugate(ravel(a)), ravel(b))
  392.  
  393.  
  394. try:
  395.     from dotblas import dot, innerproduct, vdot
  396. except ImportError:
  397.     pass
  398.  
  399. matrixmultiply = dot
  400.  
  401. def _move_axis_to_0(a, axis):
  402.     if axis == 0:
  403.         return a
  404.     
  405.     n = len(a.shape)
  406.     if axis < 0:
  407.         axis += n
  408.     
  409.     axes = range(1, axis + 1) + [
  410.         0] + range(axis + 1, n)
  411.     return multiarray.transpose(a, axes)
  412.  
  413.  
  414. def cross_product(a, b, axis1 = -1, axis2 = -1):
  415.     '''Return the cross product of two vectors.
  416.  
  417.     The cross product is performed over the last axes of a and b by default,
  418.     and can handle axes with dimensions 2 and 3. For a dimension of 2,
  419.     the z-component of the equivalent three-dimensional cross product is
  420.     returned.
  421.     '''
  422.     a = _move_axis_to_0(asarray(a), axis1)
  423.     b = _move_axis_to_0(asarray(b), axis2)
  424.     if a.shape[0] != b.shape[0]:
  425.         raise ValueError('incompatible dimensions for cross product')
  426.     elif a.shape[0] == 2:
  427.         return a[0] * b[1] - a[1] * b[0]
  428.     elif a.shape[0] == 3:
  429.         x = a[1] * b[2] - a[2] * b[1]
  430.         y = a[2] * b[0] - a[0] * b[2]
  431.         z = a[0] * b[1] - a[1] * b[0]
  432.         cp = array([
  433.             x,
  434.             y,
  435.             z])
  436.         if len(cp.shape) == 1:
  437.             return cp
  438.         
  439.         axes = range(1, len(cp.shape)) + [
  440.             0]
  441.         return multiarray.transpose(cp, axes)
  442.     else:
  443.         raise ValueError('can only do cross product for axes with dimensions 2 or 3')
  444.  
  445. from ArrayPrinter import array2string
  446.  
  447. def array_repr(a, max_line_width = None, precision = None, suppress_small = None):
  448.     return array2string(a, max_line_width, precision, suppress_small, ', ', 1)
  449.  
  450.  
  451. def array_str(a, max_line_width = None, precision = None, suppress_small = None):
  452.     return array2string(a, max_line_width, precision, suppress_small, ' ', 0)
  453.  
  454. multiarray.set_string_function(array_str, 0)
  455. multiarray.set_string_function(array_repr, 1)
  456. LittleEndian = fromstring('\x01' + '\x00' * 7, 'i')[0] == 1
  457.  
  458. def resize(a, new_shape):
  459.     """resize(a,new_shape) returns a new array with the specified shape.
  460.     The original array's total size can be any size.
  461.     """
  462.     a = ravel(a)
  463.     if not len(a):
  464.         return zeros(new_shape, a.typecode())
  465.     
  466.     total_size = multiply.reduce(new_shape)
  467.     n_copies = int(total_size / len(a))
  468.     extra = total_size % len(a)
  469.     if extra != 0:
  470.         n_copies = n_copies + 1
  471.         extra = len(a) - extra
  472.     
  473.     a = concatenate((a,) * n_copies)
  474.     if extra > 0:
  475.         a = a[:-extra]
  476.     
  477.     return reshape(a, new_shape)
  478.  
  479.  
  480. def indices(dimensions, typecode = None):
  481.     '''indices(dimensions,typecode=None) returns an array representing a grid
  482.     of indices with row-only, and column-only variation.
  483.     '''
  484.     tmp = ones(dimensions, typecode)
  485.     lst = []
  486.     for i in range(len(dimensions)):
  487.         lst.append(add.accumulate(tmp, i) - 1)
  488.     
  489.     return array(lst)
  490.  
  491.  
  492. def fromfunction(function, dimensions):
  493.     '''fromfunction(function, dimensions) returns an array constructed by
  494.     calling function on a tuple of number grids.  The function should
  495.     accept as many arguments as there are dimensions which is a list of
  496.     numbers indicating the length of the desired output for each axis.
  497.     '''
  498.     return apply(function, tuple(indices(dimensions)))
  499.  
  500.  
  501. def diagonal(a, offset = 0, axis1 = 0, axis2 = 1):
  502.     '''diagonal(a, offset=0, axis1=0, axis2=1) returns all offset diagonals
  503.     defined by the given dimensions of the array.
  504.     '''
  505.     a = asarray(a)
  506.     nd = len(a.shape)
  507.     new_axes = range(nd)
  508.     if axis1 < 0:
  509.         axis1 += nd
  510.     
  511.     if axis2 < 0:
  512.         axis2 += nd
  513.     
  514.     
  515.     try:
  516.         new_axes.remove(axis1)
  517.         new_axes.remove(axis2)
  518.     except ValueError:
  519.         raise ValueError, 'axis1(=%d) and axis2(=%d) must be different and within range (nd=%d).' % (axis1, axis2, nd)
  520.  
  521.     new_axes = new_axes + [
  522.         axis1,
  523.         axis2]
  524.     a = transpose(a, new_axes)
  525.     s = a.shape
  526.     if len(s) == 2:
  527.         n1 = s[0]
  528.         n2 = s[1]
  529.         n = n1 * n2
  530.         s = (n,)
  531.         a = reshape(a, s)
  532.         if offset < 0:
  533.             return take(a, range(-n2 * offset, min(n2, n1 + offset) * (n2 + 1) - n2 * offset, n2 + 1), 0)
  534.         else:
  535.             return take(a, range(offset, min(n1, n2 - offset) * (n2 + 1) + offset, n2 + 1), 0)
  536.     else:
  537.         my_diagonal = []
  538.         for i in range(s[0]):
  539.             my_diagonal.append(diagonal(a[i], offset))
  540.         
  541.         return array(my_diagonal)
  542.  
  543.  
  544. def trace(a, offset = 0, axis1 = 0, axis2 = 1):
  545.     '''trace(a,offset=0, axis1=0, axis2=1) returns the sum along diagonals
  546.     (defined by the last two dimensions) of the array.
  547.     '''
  548.     return add.reduce(diagonal(a, offset, axis1, axis2), -1)
  549.  
  550.  
  551. def DumpArray(m, fp):
  552.     if m.typecode() == 'O':
  553.         raise TypeError, "Numeric Pickler can't pickle arrays of Objects"
  554.     
  555.     s = m.shape
  556.     if LittleEndian:
  557.         endian = 'L'
  558.     else:
  559.         endian = 'B'
  560.     fp.write('A%s%s%d ' % (m.typecode(), endian, m.itemsize()))
  561.     for d in s:
  562.         fp.write('%d ' % d)
  563.     
  564.     fp.write('\n')
  565.     fp.write(m.tostring())
  566.  
  567.  
  568. def LoadArray(fp):
  569.     ln = string.split(fp.readline())
  570.     if ln[0][0] == 'A':
  571.         ln[0] = ln[0][1:]
  572.     
  573.     typecode = ln[0][0]
  574.     endian = ln[0][1]
  575.     shape = map((lambda x: string.atoi(x)), ln[1:])
  576.     itemsize = string.atoi(ln[0][2:])
  577.     sz = reduce(multiply, shape) * itemsize
  578.     data = fp.read(sz)
  579.     m = fromstring(data, typecode)
  580.     m = reshape(m, shape)
  581.     if (LittleEndian or endian == 'B' or not LittleEndian) and endian == 'L':
  582.         return m.byteswapped()
  583.     else:
  584.         return m
  585.  
  586. import pickle
  587. import copy
  588.  
  589. class Unpickler(pickle.Unpickler):
  590.     
  591.     def load_array(self):
  592.         self.stack.append(LoadArray(self))
  593.  
  594.     dispatch = copy.copy(pickle.Unpickler.dispatch)
  595.     dispatch['A'] = load_array
  596.  
  597.  
  598. class Pickler(pickle.Pickler):
  599.     
  600.     def save_array(self, object):
  601.         DumpArray(object, self)
  602.  
  603.     dispatch = copy.copy(pickle.Pickler.dispatch)
  604.     dispatch[ArrayType] = save_array
  605.  
  606. from StringIO import StringIO
  607. from pickle import load, loads, dump, dumps
  608. import copy_reg
  609.  
  610. def array_constructor(shape, typecode, thestr, Endian = LittleEndian):
  611.     if typecode == 'O':
  612.         x = array(thestr, 'O')
  613.     else:
  614.         x = fromstring(thestr, typecode)
  615.     x.shape = shape
  616.     if LittleEndian != Endian:
  617.         return x.byteswapped()
  618.     else:
  619.         return x
  620.  
  621.  
  622. def pickle_array(a):
  623.     if a.typecode() == 'O':
  624.         return (array_constructor, (a.shape, a.typecode(), a.tolist(), LittleEndian))
  625.     else:
  626.         return (array_constructor, (a.shape, a.typecode(), a.tostring(), LittleEndian))
  627.  
  628. copy_reg.pickle(ArrayType, pickle_array, array_constructor)
  629.  
  630. def ravel(m):
  631.     """ravel(m) returns a 1d array corresponding to all the elements of it's
  632.     argument.
  633.     """
  634.     return reshape(m, (-1,))
  635.  
  636.  
  637. def nonzero(a):
  638.     '''nonzero(a) returns the indices of the elements of a which are not zero,
  639.     a must be 1d
  640.     '''
  641.     return repeat(arange(len(a)), not_equal(a, 0))
  642.  
  643.  
  644. def shape(a):
  645.     '''shape(a) returns the shape of a (as a function call which
  646.        also works on nested sequences).
  647.     '''
  648.     return asarray(a).shape
  649.  
  650.  
  651. def where(condition, x, y):
  652.     '''where(condition,x,y) is shaped like condition and has elements of x and
  653.     y where condition is respectively true or false.
  654.     '''
  655.     return choose(not_equal(condition, 0), (y, x))
  656.  
  657.  
  658. def compress(condition, m, axis = -1):
  659.     '''compress(condition, x, axis=-1) = those elements of x corresponding
  660.     to those elements of condition that are "true".  condition must be the
  661.     same size as the given dimension of x.'''
  662.     return take(m, nonzero(condition), axis)
  663.  
  664.  
  665. def clip(m, m_min, m_max):
  666.     '''clip(m, m_min, m_max) = every entry in m that is less than m_min is
  667.     replaced by m_min, and every entry greater than m_max is replaced by
  668.     m_max.
  669.     '''
  670.     selector = less(m, m_min) + 2 * greater(m, m_max)
  671.     return choose(selector, (m, m_min, m_max))
  672.  
  673.  
  674. def ones(shape, typecode = 'l', savespace = 0):
  675.     '''ones(shape, typecode=Int, savespace=0) returns an array of the given
  676.     dimensions which is initialized to all ones.
  677.     '''
  678.     a = zeros(shape, typecode, savespace)
  679.     a[...] = 1
  680.     return a
  681.  
  682.  
  683. def identity(n, typecode = 'l'):
  684.     '''identity(n) returns the identity matrix of shape n x n.
  685.     '''
  686.     return resize(array([
  687.         1] + n * [
  688.         0], typecode = typecode), (n, n))
  689.  
  690.  
  691. def sum(x, axis = 0):
  692.     '''Sum the array over the given axis.
  693.     '''
  694.     x = array(x, copy = 0)
  695.     n = len(x.shape)
  696.     if axis < 0:
  697.         axis += n
  698.     
  699.     if n == 0 and axis in [
  700.         0,
  701.         -1]:
  702.         return x[0]
  703.     
  704.     if axis < 0 or axis >= n:
  705.         raise ValueError, 'Improper axis argument to sum.'
  706.     
  707.     return add.reduce(x, axis)
  708.  
  709.  
  710. def product(x, axis = 0):
  711.     '''Product of the array elements over the given axis.'''
  712.     x = array(x, copy = 0)
  713.     n = len(x.shape)
  714.     if axis < 0:
  715.         axis += n
  716.     
  717.     if n == 0 and axis in [
  718.         0,
  719.         -1]:
  720.         return x[0]
  721.     
  722.     if axis < 0 or axis >= n:
  723.         return (ValueError, 'Improper axis argument to product.')
  724.     
  725.     return multiply.reduce(x, axis)
  726.  
  727.  
  728. def sometrue(x, axis = 0):
  729.     '''Perform a logical_or over the given axis.'''
  730.     x = array(x, copy = 0)
  731.     n = len(x.shape)
  732.     if axis < 0:
  733.         axis += n
  734.     
  735.     if n == 0 and axis in [
  736.         0,
  737.         -1]:
  738.         return x[0] != 0
  739.     
  740.     if axis < 0 or axis >= n:
  741.         return (ValueError, 'Improper axis argument to sometrue.')
  742.     
  743.     return logical_or.reduce(x, axis)
  744.  
  745.  
  746. def alltrue(x, axis = 0):
  747.     '''Perform a logical_and over the given axis.'''
  748.     x = array(x, copy = 0)
  749.     n = len(x.shape)
  750.     if axis < 0:
  751.         axis += n
  752.     
  753.     if n == 0 and axis in [
  754.         0,
  755.         -1]:
  756.         return x[0] != 0
  757.     
  758.     if axis < 0 or axis >= n:
  759.         return (ValueError, 'Improper axis argument to product.')
  760.     
  761.     return logical_and.reduce(x, axis)
  762.  
  763.  
  764. def cumsum(x, axis = 0):
  765.     '''Sum the array over the given axis.'''
  766.     x = array(x, copy = 0)
  767.     n = len(x.shape)
  768.     if axis < 0:
  769.         axis += n
  770.     
  771.     if n == 0 and axis in [
  772.         0,
  773.         -1]:
  774.         return x[0]
  775.     
  776.     if axis < 0 or axis >= n:
  777.         return (ValueError, 'Improper axis argument to cumsum.')
  778.     
  779.     return add.accumulate(x, axis)
  780.  
  781.  
  782. def cumproduct(x, axis = 0):
  783.     '''Sum the array over the given axis.'''
  784.     x = array(x, copy = 0)
  785.     n = len(x.shape)
  786.     if axis < 0:
  787.         axis += n
  788.     
  789.     if n == 0 and axis in [
  790.         0,
  791.         -1]:
  792.         return x[0]
  793.     
  794.     if axis < 0 or axis >= n:
  795.         return (ValueError, 'Improper axis argument to cumproduct.')
  796.     
  797.     return multiply.accumulate(x, axis)
  798.  
  799. arange = multiarray.arange
  800.  
  801. def around(m, decimals = 0):
  802.     '''around(m, decimals=0)     Round in the same way as standard python performs rounding. Returns
  803.     always a float.
  804.     '''
  805.     m = asarray(m)
  806.     s = sign(m)
  807.     if decimals:
  808.         m = absolute(m * 10.0 ** decimals)
  809.     else:
  810.         m = absolute(m)
  811.     rem = m - asarray(m).astype(Int)
  812.     m = where(less(rem, 0.5), floor(m), ceil(m))
  813.     if decimals:
  814.         m = m * s / 10.0 ** decimals
  815.     else:
  816.         m = m * s
  817.     return m
  818.  
  819.  
  820. def sign(m):
  821.     '''sign(m) gives an array with shape of m with elements defined by sign
  822.     function:  where m is less than 0 return -1, where m greater than 0, a=1,
  823.     elsewhere a=0.
  824.     '''
  825.     m = asarray(m)
  826.     return (zeros(shape(m)) - less(m, 0)) + greater(m, 0)
  827.  
  828.  
  829. def allclose(a, b, rtol = 1.0000000000000001e-05, atol = 1e-08):
  830.     ''' allclose(a,b,rtol=1.e-5,atol=1.e-8)
  831.         Returns true if all components of a and b are equal
  832.         subject to given tolerances.
  833.         The relative error rtol must be positive and << 1.0
  834.         The absolute error atol comes into play for those elements
  835.         of y that are very small or zero; it says how small x must be also.
  836.     '''
  837.     x = array(a, copy = 0)
  838.     y = array(b, copy = 0)
  839.     d = less(absolute(x - y), atol + rtol * absolute(y))
  840.     return alltrue(ravel(d))
  841.  
  842.  
  843. def rank(a):
  844.     '''Get the rank of sequence a (the number of dimensions, not a matrix rank)
  845.        The rank of a scalar is zero.
  846.     '''
  847.     return len(shape(a))
  848.  
  849.  
  850. def shape(a):
  851.     '''Get the shape of sequence a'''
  852.     
  853.     try:
  854.         return a.shape
  855.     except:
  856.         return array(a).shape
  857.  
  858.  
  859.  
  860. def size(a, axis = None):
  861.     '''Get the number of elements in sequence a, or along a certain axis.'''
  862.     s = shape(a)
  863.     if axis is None:
  864.         if len(s) == 0:
  865.             return 1
  866.         else:
  867.             return reduce((lambda x, y: x * y), s)
  868.     else:
  869.         return s[axis]
  870.  
  871.  
  872. def average(a, axis = 0, weights = None, returned = 0):
  873.     """average(a, axis=0, weights=None)
  874.        Computes average along indicated axis.
  875.        If axis is None, average over the entire array.
  876.        Inputs can be integer or floating types; result is type Float.
  877.  
  878.        If weights are given, result is:
  879.            sum(a*weights)/(sum(weights))
  880.        weights must have a's shape or be the 1-d with length the size
  881.        of a in the given axis. Integer weights are converted to Float.
  882.  
  883.        Not supplying weights is equivalent to supply weights that are
  884.        all 1.
  885.  
  886.        If returned, return a tuple: the result and the sum of the weights
  887.        or count of values. The shape of these two results will be the same.
  888.  
  889.        raises ZeroDivisionError if appropriate when result is scalar.
  890.        (The version in MA does not -- it returns masked values).
  891.     """
  892.     if axis is None:
  893.         a = array(a).flat
  894.         if weights is None:
  895.             n = add.reduce(a)
  896.             d = len(a) * 1.0
  897.         else:
  898.             w = array(weights, typecode = Float, copy = 0).flat
  899.             n = add.reduce(a * w)
  900.             d = add.reduce(w)
  901.     else:
  902.         a = array(a)
  903.         ash = a.shape
  904.         if ash == ():
  905.             a.shape = (1,)
  906.         
  907.         if weights is None:
  908.             n = add.reduce(a, axis)
  909.             d = ash[axis] * 1.0
  910.             if returned:
  911.                 d = ones(shape(n)) * d
  912.             
  913.         else:
  914.             w = array(weights, copy = 0) * 1.0
  915.             wsh = w.shape
  916.             if wsh == ():
  917.                 wsh = (1,)
  918.             
  919.             if wsh == ash:
  920.                 n = add.reduce(a * w, axis)
  921.                 d = add.reduce(w, axis)
  922.             elif wsh == (ash[axis],):
  923.                 ni = ash[axis]
  924.                 r = [
  925.                     NewAxis] * ni
  926.                 r[axis] = slice(None, None, 1)
  927.                 w1 = eval('w[' + repr(tuple(r)) + ']*ones(ash, Float)')
  928.                 n = add.reduce(a * w1, axis)
  929.                 d = add.reduce(w1, axis)
  930.             else:
  931.                 raise ValueError, 'average: weights wrong shape.'
  932.     if not isinstance(d, ArrayType):
  933.         if d == 0.0:
  934.             raise ZeroDivisionError, 'Numeric.average, zero denominator'
  935.         
  936.     
  937.     if returned:
  938.         return (n / d, d)
  939.     else:
  940.         return n / d
  941.  
  942.