home *** CD-ROM | disk | FTP | other *** search
/ PC Welt 2006 November (DVD) / PCWELT_11_2006.ISO / casper / filesystem.squashfs / usr / share / pycentral / python-numeric / site-packages / Numeric / Numeric.py < prev    next >
Encoding:
Python Source  |  2005-09-08  |  28.0 KB  |  826 lines

  1. """Numeric module defining a multi-dimensional array and useful procedures for
  2.    Numerical computation.
  3.  
  4. Functions
  5.  
  6. -   array                      - NumPy Array construction
  7. -   zeros                      - Return an array of all zeros
  8. -   empty                      - Return an uninitialized array (200x faster than zeros)
  9. -   shape                      - Return shape of sequence or array
  10. -   rank                       - Return number of dimensions
  11. -   size                       - Return number of elements in entire array or a
  12.                                  certain dimension
  13. -   fromstring                 - Construct array from (byte) string
  14. -   take                       - Select sub-arrays using sequence of indices
  15. -   put                        - Set sub-arrays using sequence of 1-D indices
  16. -   putmask                    - Set portion of arrays using a mask
  17. -   reshape                    - Return array with new shape
  18. -   repeat                     - Repeat elements of array
  19. -   choose                     - Construct new array from indexed array tuple
  20. -   cross_correlate            - Correlate two 1-d arrays
  21. -   searchsorted               - Search for element in 1-d array
  22. -   sum                        - Total sum over a specified dimension
  23. -   average                    - Average, possibly weighted, over axis or array.
  24. -   cumsum                     - Cumulative sum over a specified dimension
  25. -   product                    - Total product over a specified dimension
  26. -   cumproduct                 - Cumulative product over a specified dimension
  27. -   alltrue                    - Logical and over an entire axis
  28. -   sometrue                   - Logical or over an entire axis
  29. -   allclose               - Tests if sequences are essentially equal
  30.  
  31. More Functions:
  32.  
  33. -   arrayrange (arange)        - Return regularly spaced array
  34. -   asarray                    - Guarantee NumPy array
  35. -   sarray                     - Guarantee a NumPy array that keeps precision
  36. -   convolve                   - Convolve two 1-d arrays
  37. -   swapaxes                   - Exchange axes
  38. -   concatenate                - Join arrays together
  39. -   transpose                  - Permute axes
  40. -   sort                       - Sort elements of array
  41. -   argsort                    - Indices of sorted array
  42. -   argmax                     - Index of largest value
  43. -   argmin                     - Index of smallest value
  44. -   innerproduct               - Innerproduct of two arrays
  45. -   dot                        - Dot product (matrix multiplication)
  46. -   outerproduct               - Outerproduct of two arrays
  47. -   resize                     - Return array with arbitrary new shape
  48. -   indices                    - Tuple of indices
  49. -   fromfunction               - Construct array from universal function
  50. -   diagonal                   - Return diagonal array
  51. -   trace                      - Trace of array
  52. -   dump                       - Dump array to file object (pickle)
  53. -   dumps                      - Return pickled string representing data
  54. -   load                       - Return array stored in file object
  55. -   loads                      - Return array from pickled string
  56. -   ravel                      - Return array as 1-D
  57. -   nonzero                    - Indices of nonzero elements for 1-D array
  58. -   shape                      - Shape of array
  59. -   where                      - Construct array from binary result
  60. -   compress                   - Elements of array where condition is true
  61. -   clip                       - Clip array between two values
  62. -   ones                       - Array of all ones
  63. -   identity                   - 2-D identity array (matrix)
  64.  
  65. (Universal) Math Functions
  66.  
  67.        add                    logical_or             exp
  68.        subtract               logical_xor            log
  69.        multiply               logical_not            log10
  70.        divide                 maximum                sin
  71.        divide_safe            minimum                sinh
  72.        conjugate              bitwise_and            sqrt
  73.        power                  bitwise_or             tan
  74.        absolute               bitwise_xor            tanh
  75.        negative               invert                 ceil
  76.        greater                left_shift             fabs
  77.        greater_equal          right_shift            floor
  78.        less                   arccos                 arctan2
  79.        less_equal             arcsin                 fmod
  80.        equal                  arctan                 hypot
  81.        not_equal              cos                    around
  82.        logical_and            cosh                   sign
  83.        arccosh                arcsinh                arctanh
  84.  
  85. """
  86.  
  87. import numeric_version
  88. __version__ = numeric_version.version
  89. del numeric_version
  90.  
  91. import multiarray
  92. from umath import *
  93. from Precision import *
  94.  
  95. import _numpy # for freeze dependency resolution (at least on Mac)
  96.  
  97. import string, types, math
  98.  
  99. #Use this to add a new axis to an array
  100. NewAxis = None
  101.  
  102. #The following functions are considered builtin, they all might be
  103. #in C some day
  104.  
  105. ##def range2(start, stop=None, step=1, typecode=None):
  106. ##    """Just like range() except it returns a array whose type can be specified
  107. ##    by the keyword argument typecode.
  108. ##    """
  109. ##
  110. ##    if (stop is None):
  111. ##        stop = start
  112. ##        start = 0
  113. ##    n = int(math.ceil(float(stop-start)/step))
  114. ##    if n <= 0:
  115. ##        m = zeros( (0,) )+(step+start+stop)
  116. ##    else:
  117. ##        m = (add.accumulate(ones((n,), Int))-1)*step +(start+(stop-stop))
  118. ##        # the last bit is to deal with e.g. Longs -- 3L-3L==0L
  119. ##   if typecode is not None and m.typecode() != typecode:
  120. ##        return m.astype(typecode)
  121. ##    else:
  122. ##        return m
  123.  
  124. arrayrange = multiarray.arange
  125.  
  126. array = multiarray.array
  127. zeros = multiarray.zeros
  128. empty = multiarray.empty
  129.  
  130. def asarray(a, typecode=None, savespace=0):
  131.     """asarray(a,typecode=None) returns a as a NumPy array.  Unlike array(),
  132.     no copy is performed if a is already an array.
  133.     """
  134.     return multiarray.array(a, typecode, copy=0, savespace=savespace)
  135.  
  136. def sarray(a, typecode=None, copy=0):
  137.     """sarray(a, typecode=None, copy=0) calls array with savespace=1."""
  138.     return multiarray.array(a, typecode, copy, savespace=1)
  139.  
  140. fromstring = multiarray.fromstring
  141. take = multiarray.take
  142. reshape = multiarray.reshape
  143. choose = multiarray.choose
  144. cross_correlate = multiarray.cross_correlate
  145.  
  146. def repeat(a, repeats, axis=0):
  147.     """repeat elements of a repeats times along axis
  148.        repeats is a sequence of length a.shape[axis]
  149.        telling how many times to repeat each element.
  150.        If repeats is an integer, it is interpreted as
  151.        a tuple of length a.shape[axis] containing repeats.
  152.        The argument a can be anything array(a) will accept.
  153.     """
  154.     a = array(a, copy=0)
  155.     s = a.shape
  156.     if isinstance(repeats, types.IntType):
  157.         repeats = tuple([repeats]*(s[axis]))
  158.     if len(repeats) != s[axis]:
  159.         raise ValueError, "repeat requires second argument integer or of length of a.shape[axis]."
  160.     d = multiarray.repeat(a, repeats, axis)
  161.     return d
  162.  
  163. def put (a, ind, v):
  164.     """put(a, ind, v) results in a[n] = v[n] for all n in ind
  165.        If v is shorter than mask it will be repeated as necessary.
  166.        In particular v can be a scalar or length 1 array.
  167.        The routine put is the equivalent of the following (although the loop
  168.        is in C for speed):
  169.  
  170.            ind = array(indices, copy=0)
  171.            v = array(values, copy=0).astype(a, typecode())
  172.            for i in ind: a.flat[i] = v[i]
  173.        a must be a contiguous Numeric array.
  174.     """
  175.     multiarray.put (a, ind, array(v, copy=0).astype(a.typecode()))
  176.  
  177. def putmask (a, mask, v):
  178.     """putmask(a, mask, v) results in a = v for all places mask is true.
  179.        If v is shorter than mask it will be repeated as necessary.
  180.        In particular v can be a scalar or length 1 array.
  181.     """
  182.     tc = a.typecode()
  183.     mask = asarray(mask).astype(Int)
  184.     v = array(v, copy=0).astype(tc)
  185.     if tc == PyObject:
  186.         if v.shape == (): v.shape=(1,)
  187.         ax = ravel(a)
  188.         mx = ravel(mask)
  189.         vx = ravel(v)
  190.         vx = resize(vx, ax.shape)
  191.         for i in range(len(ax)):
  192.             if mx[i]: ax[i] = vx[i]
  193.     else:
  194.         multiarray.putmask (a, mask, v)
  195.  
  196. def convolve(a,v,mode=2):
  197.     """Returns the discrete, linear convolution of 1-D
  198.     sequences a and v; mode can be 0 (valid), 1 (same), or 2 (full)
  199.     to specify size of the resulting sequence.
  200.     """
  201.     if (len(v) > len(a)):
  202.         temp = a
  203.         a = v
  204.         v = temp
  205.         del temp
  206.     return cross_correlate(a,asarray(v)[::-1],mode)
  207.  
  208. ArrayType = multiarray.arraytype
  209. UfuncType = type(sin)
  210.  
  211. def swapaxes(a, axis1, axis2):
  212.     """swapaxes(a, axis1, axis2) returns array a with axis1 and axis2
  213.     interchanged.
  214.     """
  215.     a = array(a, copy=0)
  216.     n = len(a.shape)
  217.     if n <= 1: return a
  218.     if axis1 < 0: axis1 += n
  219.     if axis2 < 0: axis2 += n
  220.     if axis1 < 0 or axis1 >= n:
  221.         raise ValueError, "Bad axis1 argument to swapaxes."
  222.     if axis2 < 0 or axis2 >= n:
  223.         raise ValueError, "Bad axis2 argument to swapaxes."
  224.     new_axes = arange(n)
  225.     new_axes[axis1] = axis2
  226.     new_axes[axis2] = axis1
  227.     return multiarray.transpose(a, new_axes)
  228.  
  229. arraytype = multiarray.arraytype
  230. #add extra intelligence to the basic C functions
  231. def concatenate(a, axis=0):
  232.     """concatenate(a, axis=0) joins the tuple of sequences in a into a single
  233.     NumPy array.
  234.     """
  235.     if axis == 0:
  236.         return multiarray.concatenate(a)
  237.     else:
  238.         new_list = []
  239.         for m in a:
  240.             new_list.append(swapaxes(m, axis, 0))
  241.     return swapaxes(multiarray.concatenate(new_list), axis, 0)
  242.  
  243. def transpose(a, axes=None):
  244.     """transpose(a, axes=None) returns array with dimensions permuted
  245.     according to axes.  If axes is None (default) returns array with
  246.     dimensions reversed.
  247.     """
  248. #    if axes is None: # this test has been moved into multiarray.transpose
  249. #        axes = arange(len(array(a).shape))[::-1]
  250.     return multiarray.transpose(a, axes)
  251.  
  252. def sort(a, axis=-1):
  253.     """sort(a,axis=-1) returns array with elements sorted along given axis.
  254.     """
  255.     a = array(a, copy=0)
  256.     n = len(a.shape)
  257.     if axis < 0: axis += n
  258.     if axis < 0 or axis >= n:
  259.         raise ValueError, "sort axis argument out of bounds"
  260.     if axis != n-1: a = swapaxes(a, axis, n-1)
  261.     s = multiarray.sort(a)
  262.     if axis != n-1: s = swapaxes(s, axis, -1)
  263.     return s
  264.  
  265. def argsort(a, axis=-1):
  266.     """argsort(a,axis=-1) return the indices into a of the sorted array
  267.     along the given axis, so that take(a,result,axis) is the sorted array.
  268.     """
  269.     a = array(a, copy=0)
  270.     n = len(a.shape)
  271.     if axis < 0: axis += n
  272.     if axis < 0 or axis >= n:
  273.         raise ValueError, "argsort axis argument out of bounds"
  274.     if axis != n-1: a = swapaxes(a, axis, n-1)
  275.     s = multiarray.argsort(a)
  276.     if axis != n-1: s = swapaxes(s, axis, -1)
  277.     return s
  278.  
  279. def argmax(a, axis=-1):
  280.     """argmax(a,axis=-1) returns the indices to the maximum value of the
  281.     1-D arrays along the given axis.
  282.     """
  283.     a = array(a, copy=0)
  284.     n = len(a.shape)
  285.     if axis < 0: axis += n
  286.     if axis < 0 or axis >= n:
  287.         raise ValueError, "argmax axis argument out of bounds"
  288.     if axis != n-1: a = swapaxes(a, axis, n-1)
  289.     s = multiarray.argmax(a)
  290.     if axis != n-1: s = swapaxes(s, axis, -1)
  291.     return s
  292.  
  293. def argmin(a, axis=-1):
  294.     """argmin(a,axis=-1) returns the indices to the minimum value of the
  295.     1-D arrays along the given axis.
  296.     """
  297.     arra = array(a,copy=0)
  298.     type = arra.typecode()
  299.     num = array(0,type)
  300.     if type in ['bwu']:
  301.         num = -array(1,type)
  302.     a = num-arra
  303.     n = len(a.shape)
  304.     if axis < 0: axis += n
  305.     if axis < 0 or axis >= n:
  306.         raise ValueError, "argmin axis argument out of bounds"
  307.     if axis != n-1: a = swapaxes(a, axis, n-1)
  308.     s = multiarray.argmax(a)
  309.     if axis != n-1: s = swapaxes(s, axis, -1)
  310.     return s
  311.  
  312.  
  313. searchsorted = multiarray.binarysearch
  314.  
  315. def innerproduct(a,b):
  316.     """innerproduct(a,b) returns the dot product of two arrays, which has
  317.     shape a.shape[:-1] + b.shape[:-1] with elements computed by summing the
  318.     product of the elements from the last dimensions of a and b.
  319.     """
  320.     try:
  321.         return multiarray.innerproduct(a,b)
  322.     except TypeError,detail:
  323.         if array(a).shape == () or array(b).shape == ():
  324.             return a*b
  325.         else:
  326.             raise TypeError, detail or "invalid types for dot"
  327.  
  328. def outerproduct(a,b):
  329.    """outerproduct(a,b) returns the outer product of two vectors.
  330.       result(i,j) = a(i)*b(j) when a and b are vectors
  331.       Will accept any arguments that can be made into vectors.
  332.    """
  333.    return array(a).flat[:,NewAxis]*array(b).flat[NewAxis,:]
  334.  
  335. #dot = multiarray.matrixproduct
  336. # how can I associate this doc string with the function dot?
  337. def dot(a, b):
  338.     """dot(a,b) returns matrix-multiplication between a and b.  The product-sum
  339.     is over the last dimension of a and the second-to-last dimension of b.
  340.     """
  341.     try:
  342.         return multiarray.matrixproduct(a, b)
  343.     except TypeError,detail:
  344.         if array(a).shape == () or array(b).shape == ():
  345.             return a*b
  346.         else:
  347.             raise TypeError, detail or "invalid types for dot"
  348.  
  349. def vdot(a, b):
  350.     """Returns the dot product of 2 vectors (or anything that can be made into
  351.        a vector). NB: this is not the same as `dot`, as it takes the conjugate
  352.        of its first argument if complex and always returns a scalar."""
  353.     return multiarray.matrixproduct(conjugate(ravel(a)),
  354.                                     ravel(b))
  355.  
  356. # try to import blas optimized dot, innerproduct and vdot, if available
  357. try:
  358.     from dotblas import dot, innerproduct, vdot
  359. except ImportError: pass
  360.  
  361. #This is obsolete, don't use in new code
  362. matrixmultiply = dot
  363.  
  364. def _move_axis_to_0(a, axis):
  365.     if axis == 0:
  366.         return a
  367.     n = len(a.shape)
  368.     if axis < 0:
  369.         axis += n
  370.     axes = range(1, axis+1) + [0,] + range(axis+1, n)
  371.     return multiarray.transpose(a, axes)
  372.  
  373. def cross_product(a, b, axis1=-1, axis2=-1):
  374.     """Return the cross product of two vectors.
  375.  
  376.     The cross product is performed over the last axes of a and b by default,
  377.     and can handle axes with dimensions 2 and 3. For a dimension of 2,
  378.     the z-component of the equivalent three-dimensional cross product is
  379.     returned.
  380.     """
  381.     a = _move_axis_to_0(asarray(a), axis1)
  382.     b = _move_axis_to_0(asarray(b), axis2)
  383.     if a.shape[0] != b.shape[0]:
  384.         raise ValueError("incompatible dimensions for cross product")
  385.     elif a.shape[0] == 2:
  386.         return a[0]*b[1] - a[1]*b[0]
  387.     elif a.shape[0] == 3:
  388.         x = a[1]*b[2] - a[2]*b[1]
  389.         y = a[2]*b[0] - a[0]*b[2]
  390.         z = a[0]*b[1] - a[1]*b[0]
  391.         cp = array([x,y,z])
  392.         if len(cp.shape) == 1:
  393.             return cp
  394.         # we put the result (in the first axis) as the last axis
  395.         axes = range(1, len(cp.shape)) + [0]
  396.         return multiarray.transpose(cp, axes)
  397.     else:
  398.         raise ValueError("can only do cross product for axes with dimensions 2 or 3")
  399.  
  400. #Use Konrad's printing function (modified for both str and repr now)
  401. from ArrayPrinter import array2string
  402. def array_repr(a, max_line_width = None, precision = None, suppress_small = None):
  403.     return array2string(a, max_line_width, precision, suppress_small, ', ', 1)
  404.  
  405. def array_str(a, max_line_width = None, precision = None, suppress_small = None):
  406.     return array2string(a, max_line_width, precision, suppress_small, ' ', 0)
  407.  
  408. multiarray.set_string_function(array_str, 0)
  409. multiarray.set_string_function(array_repr, 1)
  410.  
  411. #This is a nice value to have around
  412. #Maybe in sys some day
  413. LittleEndian = fromstring("\001"+"\000"*7, 'i')[0] == 1
  414.  
  415. def resize(a, new_shape):
  416.     """resize(a,new_shape) returns a new array with the specified shape.
  417.     The original array's total size can be any size.
  418.     """
  419.  
  420.     a = ravel(a)
  421.     if not len(a): return zeros(new_shape, a.typecode())
  422.     total_size = multiply.reduce(new_shape)
  423.     n_copies = int(total_size / len(a))
  424.     extra = total_size % len(a)
  425.  
  426.     if extra != 0:
  427.         n_copies = n_copies+1
  428.         extra = len(a)-extra
  429.  
  430.     a = concatenate( (a,)*n_copies)
  431.     if extra > 0:
  432.         a = a[:-extra]
  433.  
  434.     return reshape(a, new_shape)
  435.  
  436. def indices(dimensions, typecode=None):
  437.     """indices(dimensions,typecode=None) returns an array representing a grid
  438.     of indices with row-only, and column-only variation.
  439.     """
  440.     tmp = ones(dimensions, typecode)
  441.     lst = []
  442.     for i in range(len(dimensions)):
  443.         lst.append( add.accumulate(tmp, i, )-1 )
  444.     return array(lst)
  445.  
  446. def fromfunction(function, dimensions):
  447.     """fromfunction(function, dimensions) returns an array constructed by
  448.     calling function on a tuple of number grids.  The function should
  449.     accept as many arguments as there are dimensions which is a list of
  450.     numbers indicating the length of the desired output for each axis.
  451.     """
  452.     return apply(function, tuple(indices(dimensions)))
  453.  
  454.  
  455. def diagonal(a, offset= 0, axis1=0, axis2=1):
  456.     """diagonal(a, offset=0, axis1=0, axis2=1) returns all offset diagonals
  457.     defined by the given dimensions of the array.
  458.     """
  459.     a = asarray (a)
  460.     nd = len(a.shape)
  461.     new_axes = range(nd)
  462.     if (axis1 < 0): axis1 += nd
  463.     if (axis2 < 0): axis2 += nd
  464.     try:
  465.         new_axes.remove(axis1)
  466.         new_axes.remove(axis2)
  467.     except ValueError:
  468.             raise ValueError, "axis1(=%d) and axis2(=%d) must be different "\
  469.                   "and within range (nd=%d)." % (axis1, axis2, nd)
  470.     new_axes = new_axes + [axis1, axis2]
  471.     a = transpose(a, new_axes)
  472.     s = a.shape
  473.     if len (s) == 2:
  474.         n1 = s [0]
  475.         n2 = s [1]
  476.         n = n1 * n2
  477.         s = (n,)
  478.         a = reshape (a, s)
  479.         if offset < 0:
  480.             return take (a, range (-n2*offset, min(n2, n1+offset)*(n2+1) -
  481.                                    n2*offset, n2+1), 0)
  482.         else:
  483.             return take (a, range (offset,  min(n1, n2-offset) * (n2+1) +
  484.                                    offset, n2+1), 0)
  485.     else :
  486.         my_diagonal = []
  487.         for i in range (s [0]) :
  488.             my_diagonal.append (diagonal (a [i], offset))
  489.         return array (my_diagonal)
  490.  
  491. def trace(a, offset=0, axis1=0, axis2=1):
  492.     """trace(a,offset=0, axis1=0, axis2=1) returns the sum along diagonals
  493.     (defined by the last two dimensions) of the array.
  494.     """
  495.     return add.reduce(diagonal(a, offset, axis1, axis2),-1)
  496.  
  497.  
  498. # These two functions are used in my modified pickle.py so that
  499. # matrices can be pickled.  Notice that matrices are written in
  500. # binary format for efficiency, but that they pay attention to
  501. # byte-order issues for  portability.
  502.  
  503. def DumpArray(m, fp):
  504.     if m.typecode() == 'O':
  505.         raise TypeError, "Numeric Pickler can't pickle arrays of Objects"
  506.     s = m.shape
  507.     if LittleEndian: endian = "L"
  508.     else: endian = "B"
  509.     fp.write("A%s%s%d " % (m.typecode(), endian, m.itemsize()))
  510.     for d in s:
  511.         fp.write("%d "% d)
  512.     fp.write('\n')
  513.     fp.write(m.tostring())
  514.  
  515. def LoadArray(fp):
  516.     ln = string.split(fp.readline())
  517.     if ln[0][0] == 'A': ln[0] = ln[0][1:] # Nasty hack showing my ignorance of pickle
  518.     typecode = ln[0][0]
  519.     endian = ln[0][1]
  520.  
  521.     shape = map(lambda x: string.atoi(x), ln[1:])
  522.     itemsize = string.atoi(ln[0][2:])
  523.  
  524.     sz = reduce(multiply, shape)*itemsize
  525.     data = fp.read(sz)
  526.  
  527.     m = fromstring(data, typecode)
  528.     m = reshape(m, shape)
  529.  
  530.     if (LittleEndian and endian == 'B') or (not LittleEndian and endian == 'L'):
  531.         return m.byteswapped()
  532.     else:
  533.         return m
  534.  
  535. import pickle, copy
  536. class Unpickler(pickle.Unpickler):
  537.     def load_array(self):
  538.         self.stack.append(LoadArray(self))
  539.  
  540.     dispatch = copy.copy(pickle.Unpickler.dispatch)
  541.     dispatch['A'] = load_array
  542.  
  543. class Pickler(pickle.Pickler):
  544.     def save_array(self, object):
  545.         DumpArray(object, self)
  546.  
  547.     dispatch = copy.copy(pickle.Pickler.dispatch)
  548.     dispatch[ArrayType] = save_array
  549.  
  550. #Convenience functions
  551. from StringIO import StringIO
  552.  
  553. from pickle import load, loads, dump, dumps 
  554.  
  555. # slightly different format uses the copy_reg mechanism 
  556. import copy_reg 
  557.  
  558. def array_constructor(shape, typecode, thestr, Endian=LittleEndian): 
  559.     if typecode == "O":
  560.         x = array(thestr,"O") 
  561.     else: 
  562.         x = fromstring(thestr, typecode) 
  563.     x.shape = shape 
  564.     if LittleEndian != Endian: 
  565.         return x.byteswapped() 
  566.     else: 
  567.         return x 
  568.  
  569. def pickle_array(a): 
  570.     if a.typecode() == "O": 
  571.         return (array_constructor,  
  572.                 (a.shape, a.typecode(), a.tolist(), LittleEndian)) 
  573.     else: 
  574.         return (array_constructor,  
  575.                 (a.shape, a.typecode(), a.tostring(), LittleEndian))
  576.           
  577. copy_reg.pickle(ArrayType, pickle_array, array_constructor) 
  578.  
  579.  
  580. # These are all essentially abbreviations
  581. # These might wind up in a special abbreviations module
  582.  
  583. def ravel(m):
  584.     """ravel(m) returns a 1d array corresponding to all the elements of it's
  585.     argument.
  586.     """
  587.     return reshape(m, (-1,))
  588.  
  589. def nonzero(a):
  590.     """nonzero(a) returns the indices of the elements of a which are not zero,
  591.     a must be 1d
  592.     """
  593.     return repeat(arange(len(a)), not_equal(a, 0))
  594.  
  595. def shape(a):
  596.     """shape(a) returns the shape of a (as a function call which
  597.        also works on nested sequences).
  598.     """
  599.     return asarray(a).shape
  600.  
  601. def where(condition, x, y):
  602.     """where(condition,x,y) is shaped like condition and has elements of x and
  603.     y where condition is respectively true or false.
  604.     """
  605.     return choose(not_equal(condition, 0), (y, x))
  606.  
  607. def compress(condition, m, axis=-1):
  608.     """compress(condition, x, axis=-1) = those elements of x corresponding
  609.     to those elements of condition that are "true".  condition must be the
  610.     same size as the given dimension of x."""
  611.     return take(m, nonzero(condition), axis)
  612.  
  613. def clip(m, m_min, m_max):
  614.     """clip(m, m_min, m_max) = every entry in m that is less than m_min is
  615.     replaced by m_min, and every entry greater than m_max is replaced by
  616.     m_max.
  617.     """
  618.     selector = less(m, m_min)+2*greater(m, m_max)
  619.     return choose(selector, (m, m_min, m_max))
  620.  
  621. def ones(shape, typecode='l', savespace=0):
  622.     """ones(shape, typecode=Int, savespace=0) returns an array of the given
  623.     dimensions which is initialized to all ones.
  624.     """
  625.     a=zeros(shape, typecode, savespace)
  626.     a[...]=1
  627.     return a
  628.  
  629. def identity(n,typecode='l'):
  630.     """identity(n) returns the identity matrix of shape n x n.
  631.     """
  632.     return resize(array([1]+n*[0],typecode=typecode), (n,n))
  633.  
  634. def sum (x, axis=0):
  635.     """Sum the array over the given axis.
  636.     """
  637.     x = array(x, copy=0)
  638.     n = len(x.shape)
  639.     if axis < 0: axis += n
  640.     if n == 0 and axis in [0,-1]: return x[0]
  641.     if axis < 0 or axis >= n:
  642.         raise ValueError, 'Improper axis argument to sum.'
  643.     return add.reduce(x, axis)
  644.  
  645. def product (x, axis=0):
  646.     """Product of the array elements over the given axis."""
  647.     x = array(x, copy=0)
  648.     n = len(x.shape)
  649.     if axis < 0: axis += n
  650.     if n == 0 and axis in [0,-1]: return x[0]
  651.     if axis < 0 or axis >= n:
  652.         return ValueError, 'Improper axis argument to product.'
  653.     return multiply.reduce(x, axis)
  654.  
  655. def sometrue (x, axis=0):
  656.     """Perform a logical_or over the given axis."""
  657.     x = array(x, copy=0)
  658.     n = len(x.shape)
  659.     if axis < 0: axis += n
  660.     if n == 0 and axis in [0,-1]: return x[0] != 0
  661.     if axis < 0 or axis >= n:
  662.         return ValueError, 'Improper axis argument to sometrue.'
  663.     return logical_or.reduce(x, axis)
  664.  
  665. def alltrue (x, axis=0):
  666.     """Perform a logical_and over the given axis."""
  667.     x = array(x, copy=0)
  668.     n = len(x.shape)
  669.     if axis < 0: axis += n
  670.     if n == 0 and axis in [0,-1]: return x[0] != 0
  671.     if axis < 0 or axis >= n:
  672.         return ValueError, 'Improper axis argument to product.'
  673.     return logical_and.reduce(x, axis)
  674.  
  675. def cumsum (x, axis=0):
  676.     """Sum the array over the given axis."""
  677.     x = array(x, copy=0)
  678.     n = len(x.shape)
  679.     if axis < 0: axis += n
  680.     if n == 0 and axis in [0,-1]: return x[0]
  681.     if axis < 0 or axis >= n:
  682.         return ValueError, 'Improper axis argument to cumsum.'
  683.     return add.accumulate(x, axis)
  684.  
  685. def cumproduct (x, axis=0):
  686.     """Sum the array over the given axis."""
  687.     x = array(x, copy=0)
  688.     n = len(x.shape)
  689.     if axis < 0: axis += n
  690.     if n == 0 and axis in [0,-1]: return x[0]
  691.     if axis < 0 or axis >= n:
  692.         return ValueError, 'Improper axis argument to cumproduct.'
  693.     return multiply.accumulate(x, axis)
  694.  
  695. arange = multiarray.arange
  696.  
  697. def around(m, decimals=0):
  698.     """around(m, decimals=0) \
  699.     Round in the same way as standard python performs rounding. Returns
  700.     always a float.
  701.     """
  702.     m = asarray(m)
  703.     s = sign(m)
  704.     if decimals:
  705.         m = absolute(m*10.**decimals)
  706.     else:
  707.         m = absolute(m)
  708.     rem = m-asarray(m).astype(Int)
  709.     m = where(less(rem,0.5), floor(m), ceil(m))
  710.     # convert back
  711.     if decimals:
  712.         m = m*s/(10.**decimals)
  713.     else:
  714.         m = m*s
  715.     return m
  716.  
  717. def sign(m):
  718.     """sign(m) gives an array with shape of m with elements defined by sign
  719.     function:  where m is less than 0 return -1, where m greater than 0, a=1,
  720.     elsewhere a=0.
  721.     """
  722.     m = asarray(m)
  723.     return zeros(shape(m))-less(m,0)+greater(m,0)
  724.  
  725. def allclose (a, b, rtol=1.e-5, atol=1.e-8):
  726.     """ allclose(a,b,rtol=1.e-5,atol=1.e-8)
  727.         Returns true if all components of a and b are equal
  728.         subject to given tolerances.
  729.         The relative error rtol must be positive and << 1.0
  730.         The absolute error atol comes into play for those elements
  731.         of y that are very small or zero; it says how small x must be also.
  732.     """
  733.     x = array(a, copy=0)
  734.     y = array(b, copy=0)
  735.     d = less(absolute(x-y), atol + rtol * absolute(y))
  736.     return alltrue(ravel(d))
  737.  
  738. def rank (a):
  739.     """Get the rank of sequence a (the number of dimensions, not a matrix rank)
  740.        The rank of a scalar is zero.
  741.     """
  742.     return len(shape(a))
  743.  
  744. def shape (a):
  745.     "Get the shape of sequence a"
  746.     try:
  747.         return a.shape
  748.     except:
  749.         return array(a).shape
  750.  
  751. def size (a, axis=None):
  752.     "Get the number of elements in sequence a, or along a certain axis."
  753.     s = shape(a)
  754.     if axis is None:
  755.         if len(s)==0:
  756.             return 1
  757.         else:
  758.             return reduce(lambda x,y:x*y, s)
  759.     else:
  760.         return s[axis]
  761.  
  762. def average (a, axis=0, weights=None, returned = 0):
  763.     """average(a, axis=0, weights=None)
  764.        Computes average along indicated axis.
  765.        If axis is None, average over the entire array.
  766.        Inputs can be integer or floating types; result is type Float.
  767.  
  768.        If weights are given, result is:
  769.            sum(a*weights)/(sum(weights))
  770.        weights must have a's shape or be the 1-d with length the size
  771.        of a in the given axis. Integer weights are converted to Float.
  772.  
  773.        Not supplying weights is equivalent to supply weights that are
  774.        all 1.
  775.  
  776.        If returned, return a tuple: the result and the sum of the weights
  777.        or count of values. The shape of these two results will be the same.
  778.  
  779.        raises ZeroDivisionError if appropriate when result is scalar.
  780.        (The version in MA does not -- it returns masked values).
  781.     """
  782.     if axis is None:
  783.         a = array(a).flat
  784.         if weights is None:
  785.             n = add.reduce(a)
  786.             d = len(a) * 1.0
  787.         else:
  788.             w = array(weights, typecode=Float, copy=0).flat
  789.             n = add.reduce(a*w)
  790.             d = add.reduce(w)
  791.     else:
  792.         a = array(a)
  793.         ash = a.shape
  794.         if ash == ():
  795.             a.shape = (1,)
  796.         if weights is None:
  797.             n = add.reduce(a, axis)
  798.             d = ash[axis] * 1.0
  799.             if returned:
  800.                 d = ones(shape(n)) * d
  801.         else:
  802.             w = array(weights, copy=0) * 1.0
  803.             wsh = w.shape
  804.             if wsh == ():
  805.                 wsh = (1,)
  806.             if wsh == ash:
  807.                 n = add.reduce(a*w, axis)
  808.                 d = add.reduce(w, axis)
  809.             elif wsh == (ash[axis],):
  810.                 ni = ash[axis]
  811.                 r = [NewAxis]*ni
  812.                 r[axis] = slice(None,None,1)
  813.                 w1 = eval("w["+repr(tuple(r))+"]*ones(ash, Float)")
  814.                 n = add.reduce(a*w1, axis)
  815.                 d = add.reduce(w1, axis)
  816.             else:
  817.                 raise ValueError, 'average: weights wrong shape.'
  818.  
  819.     if not isinstance(d, ArrayType):
  820.         if d == 0.0:
  821.             raise ZeroDivisionError, 'Numeric.average, zero denominator'
  822.     if returned:
  823.         return n/d, d
  824.     else:
  825.         return n/d
  826.