home *** CD-ROM | disk | FTP | other *** search
/ OS/2 Shareware BBS: 10 Tools / 10-Tools.zip / octa21fs.zip / octave / octave-2.1.23 / liboctave / CDiagMatrix.cc < prev    next >
C/C++ Source or Header  |  2000-01-15  |  15KB  |  737 lines

  1. // DiagMatrix manipulations.
  2. /*
  3.  
  4. Copyright (C) 1996, 1997 John W. Eaton
  5.  
  6. This file is part of Octave.
  7.  
  8. Octave is free software; you can redistribute it and/or modify it
  9. under the terms of the GNU General Public License as published by the
  10. Free Software Foundation; either version 2, or (at your option) any
  11. later version.
  12.  
  13. Octave is distributed in the hope that it will be useful, but WITHOUT
  14. ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  15. FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  16. for more details.
  17.  
  18. You should have received a copy of the GNU General Public License
  19. along with Octave; see the file COPYING.  If not, write to the Free
  20. Software Foundation, 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
  21.  
  22. */
  23.  
  24. #if defined (__GNUG__)
  25. #pragma implementation
  26. #endif
  27.  
  28. #ifdef HAVE_CONFIG_H
  29. #include <config.h>
  30. #endif
  31.  
  32. #include <iostream.h>
  33.  
  34. #include "lo-error.h"
  35. #include "mx-base.h"
  36. #include "mx-inlines.cc"
  37. #include "oct-cmplx.h"
  38.  
  39. // Complex Diagonal Matrix class
  40.  
  41. ComplexDiagMatrix::ComplexDiagMatrix (const DiagMatrix& a)
  42.   : MDiagArray2<Complex> (a.rows (), a.cols ())
  43. {
  44.   for (int i = 0; i < length (); i++)
  45.     elem (i, i) = a.elem (i, i);
  46. }
  47.  
  48. bool
  49. ComplexDiagMatrix::operator == (const ComplexDiagMatrix& a) const
  50. {
  51.   if (rows () != a.rows () || cols () != a.cols ())
  52.     return 0;
  53.  
  54.   return equal (data (), a.data (), length ());
  55. }
  56.  
  57. bool
  58. ComplexDiagMatrix::operator != (const ComplexDiagMatrix& a) const
  59. {
  60.   return !(*this == a);
  61. }
  62.  
  63. ComplexDiagMatrix&
  64. ComplexDiagMatrix::fill (double val)
  65. {
  66.   for (int i = 0; i < length (); i++)
  67.     elem (i, i) = val;
  68.   return *this;
  69. }
  70.  
  71. ComplexDiagMatrix&
  72. ComplexDiagMatrix::fill (const Complex& val)
  73. {
  74.   for (int i = 0; i < length (); i++)
  75.     elem (i, i) = val;
  76.   return *this;
  77. }
  78.  
  79. ComplexDiagMatrix&
  80. ComplexDiagMatrix::fill (double val, int beg, int end)
  81. {
  82.   if (beg < 0 || end >= length () || end < beg)
  83.     {
  84.       (*current_liboctave_error_handler) ("range error for fill");
  85.       return *this;
  86.     }
  87.  
  88.   for (int i = beg; i <= end; i++)
  89.     elem (i, i) = val;
  90.  
  91.   return *this;
  92. }
  93.  
  94. ComplexDiagMatrix&
  95. ComplexDiagMatrix::fill (const Complex& val, int beg, int end)
  96. {
  97.   if (beg < 0 || end >= length () || end < beg)
  98.     {
  99.       (*current_liboctave_error_handler) ("range error for fill");
  100.       return *this;
  101.     }
  102.  
  103.   for (int i = beg; i <= end; i++)
  104.     elem (i, i) = val;
  105.  
  106.   return *this;
  107. }
  108.  
  109. ComplexDiagMatrix&
  110. ComplexDiagMatrix::fill (const ColumnVector& a)
  111. {
  112.   int len = length ();
  113.   if (a.length () != len)
  114.     {
  115.       (*current_liboctave_error_handler) ("range error for fill");
  116.       return *this;
  117.     }
  118.  
  119.   for (int i = 0; i < len; i++)
  120.     elem (i, i) = a.elem (i);
  121.  
  122.   return *this;
  123. }
  124.  
  125. ComplexDiagMatrix&
  126. ComplexDiagMatrix::fill (const ComplexColumnVector& a)
  127. {
  128.   int len = length ();
  129.   if (a.length () != len)
  130.     {
  131.       (*current_liboctave_error_handler) ("range error for fill");
  132.       return *this;
  133.     }
  134.  
  135.   for (int i = 0; i < len; i++)
  136.     elem (i, i) = a.elem (i);
  137.  
  138.   return *this;
  139. }
  140.  
  141. ComplexDiagMatrix&
  142. ComplexDiagMatrix::fill (const RowVector& a)
  143. {
  144.   int len = length ();
  145.   if (a.length () != len)
  146.     {
  147.       (*current_liboctave_error_handler) ("range error for fill");
  148.       return *this;
  149.     }
  150.  
  151.   for (int i = 0; i < len; i++)
  152.     elem (i, i) = a.elem (i);
  153.  
  154.   return *this;
  155. }
  156.  
  157. ComplexDiagMatrix&
  158. ComplexDiagMatrix::fill (const ComplexRowVector& a)
  159. {
  160.   int len = length ();
  161.   if (a.length () != len)
  162.     {
  163.       (*current_liboctave_error_handler) ("range error for fill");
  164.       return *this;
  165.     }
  166.  
  167.   for (int i = 0; i < len; i++)
  168.     elem (i, i) = a.elem (i);
  169.  
  170.   return *this;
  171. }
  172.  
  173. ComplexDiagMatrix&
  174. ComplexDiagMatrix::fill (const ColumnVector& a, int beg)
  175. {
  176.   int a_len = a.length ();
  177.   if (beg < 0 || beg + a_len >= length ())
  178.     {
  179.       (*current_liboctave_error_handler) ("range error for fill");
  180.       return *this;
  181.     }
  182.  
  183.   for (int i = 0; i < a_len; i++)
  184.     elem (i+beg, i+beg) = a.elem (i);
  185.  
  186.   return *this;
  187. }
  188.  
  189. ComplexDiagMatrix&
  190. ComplexDiagMatrix::fill (const ComplexColumnVector& a, int beg)
  191. {
  192.   int a_len = a.length ();
  193.   if (beg < 0 || beg + a_len >= length ())
  194.     {
  195.       (*current_liboctave_error_handler) ("range error for fill");
  196.       return *this;
  197.     }
  198.  
  199.   for (int i = 0; i < a_len; i++)
  200.     elem (i+beg, i+beg) = a.elem (i);
  201.  
  202.   return *this;
  203. }
  204.  
  205. ComplexDiagMatrix&
  206. ComplexDiagMatrix::fill (const RowVector& a, int beg)
  207. {
  208.   int a_len = a.length ();
  209.   if (beg < 0 || beg + a_len >= length ())
  210.     {
  211.       (*current_liboctave_error_handler) ("range error for fill");
  212.       return *this;
  213.     }
  214.  
  215.   for (int i = 0; i < a_len; i++)
  216.     elem (i+beg, i+beg) = a.elem (i);
  217.  
  218.   return *this;
  219. }
  220.  
  221. ComplexDiagMatrix&
  222. ComplexDiagMatrix::fill (const ComplexRowVector& a, int beg)
  223. {
  224.   int a_len = a.length ();
  225.   if (beg < 0 || beg + a_len >= length ())
  226.     {
  227.       (*current_liboctave_error_handler) ("range error for fill");
  228.       return *this;
  229.     }
  230.  
  231.   for (int i = 0; i < a_len; i++)
  232.     elem (i+beg, i+beg) = a.elem (i);
  233.  
  234.   return *this;
  235. }
  236.  
  237. ComplexDiagMatrix
  238. ComplexDiagMatrix::hermitian (void) const
  239. {
  240.   return ComplexDiagMatrix (conj_dup (data (), length ()), cols (), rows ());
  241. }
  242.  
  243. ComplexDiagMatrix
  244. ComplexDiagMatrix::transpose (void) const
  245. {
  246.   return ComplexDiagMatrix (dup (data (), length ()), cols (), rows ());
  247. }
  248.  
  249. ComplexDiagMatrix
  250. conj (const ComplexDiagMatrix& a)
  251. {
  252.   ComplexDiagMatrix retval;
  253.   int a_len = a.length ();
  254.   if (a_len > 0)
  255.     retval = ComplexDiagMatrix (conj_dup (a.data (), a_len),
  256.                 a.rows (), a.cols ());
  257.   return retval;
  258. }
  259.  
  260. // resize is the destructive analog for this one
  261.  
  262. ComplexMatrix
  263. ComplexDiagMatrix::extract (int r1, int c1, int r2, int c2) const
  264. {
  265.   if (r1 > r2) { int tmp = r1; r1 = r2; r2 = tmp; }
  266.   if (c1 > c2) { int tmp = c1; c1 = c2; c2 = tmp; }
  267.  
  268.   int new_r = r2 - r1 + 1;
  269.   int new_c = c2 - c1 + 1;
  270.  
  271.   ComplexMatrix result (new_r, new_c);
  272.  
  273.   for (int j = 0; j < new_c; j++)
  274.     for (int i = 0; i < new_r; i++)
  275.       result.elem (i, j) = elem (r1+i, c1+j);
  276.  
  277.   return result;
  278. }
  279.  
  280. // extract row or column i.
  281.  
  282. ComplexRowVector
  283. ComplexDiagMatrix::row (int i) const
  284. {
  285.   int nr = rows ();
  286.   int nc = cols ();
  287.   if (i < 0 || i >= nr)
  288.     {
  289.       (*current_liboctave_error_handler) ("invalid row selection");
  290.       return RowVector (); 
  291.     }
  292.  
  293.   ComplexRowVector retval (nc, 0.0);
  294.   if (nr <= nc || (nr > nc && i < nc))
  295.     retval.elem (i) = elem (i, i);
  296.  
  297.   return retval;
  298. }
  299.  
  300. ComplexRowVector
  301. ComplexDiagMatrix::row (char *s) const
  302. {
  303.   if (! s)
  304.     {
  305.       (*current_liboctave_error_handler) ("invalid row selection");
  306.       return ComplexRowVector (); 
  307.     }
  308.  
  309.   char c = *s;
  310.   if (c == 'f' || c == 'F')
  311.     return row (0);
  312.   else if (c == 'l' || c == 'L')
  313.     return row (rows () - 1);
  314.   else
  315.     {
  316.       (*current_liboctave_error_handler) ("invalid row selection");
  317.       return ComplexRowVector ();
  318.     }
  319. }
  320.  
  321. ComplexColumnVector
  322. ComplexDiagMatrix::column (int i) const
  323. {
  324.   int nr = rows ();
  325.   int nc = cols ();
  326.   if (i < 0 || i >= nc)
  327.     {
  328.       (*current_liboctave_error_handler) ("invalid column selection");
  329.       return ColumnVector (); 
  330.     }
  331.  
  332.   ComplexColumnVector retval (nr, 0.0);
  333.   if (nr >= nc || (nr < nc && i < nr))
  334.     retval.elem (i) = elem (i, i);
  335.  
  336.   return retval;
  337. }
  338.  
  339. ComplexColumnVector
  340. ComplexDiagMatrix::column (char *s) const
  341. {
  342.   if (! s)
  343.     {
  344.       (*current_liboctave_error_handler) ("invalid column selection");
  345.       return ColumnVector (); 
  346.     }
  347.  
  348.   char c = *s;
  349.   if (c == 'f' || c == 'F')
  350.     return column (0);
  351.   else if (c == 'l' || c == 'L')
  352.     return column (cols () - 1);
  353.   else
  354.     {
  355.       (*current_liboctave_error_handler) ("invalid column selection");
  356.       return ColumnVector (); 
  357.     }
  358. }
  359.  
  360. ComplexDiagMatrix
  361. ComplexDiagMatrix::inverse (void) const
  362. {
  363.   int info;
  364.   return inverse (info);
  365. }
  366.  
  367. ComplexDiagMatrix
  368. ComplexDiagMatrix::inverse (int& info) const
  369. {
  370.   int nr = rows ();
  371.   int nc = cols ();
  372.   if (nr != nc)
  373.     {
  374.       (*current_liboctave_error_handler) ("inverse requires square matrix");
  375.       return DiagMatrix ();
  376.     }
  377.  
  378.   ComplexDiagMatrix retval (nr, nc);
  379.  
  380.   info = 0;
  381.   for (int i = 0; i < length (); i++)
  382.     {
  383.       if (elem (i, i) == 0.0)
  384.     {
  385.       info = -1;
  386.       return *this;
  387.     }
  388.       else
  389.     retval.elem (i, i) = 1.0 / elem (i, i);
  390.     }
  391.  
  392.   return retval;
  393. }
  394.  
  395. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  396.  
  397. ComplexDiagMatrix&
  398. ComplexDiagMatrix::operator += (const DiagMatrix& a)
  399. {
  400.   int nr = rows ();
  401.   int nc = cols ();
  402.  
  403.   int a_nr = a.rows ();
  404.   int a_nc = a.cols ();
  405.  
  406.   if (nr != a_nr || nc != a_nc)
  407.     {
  408.       gripe_nonconformant ("operator +=", nr, nc, a_nr, a_nc);
  409.       return *this;
  410.     }
  411.  
  412.   if (nr == 0 || nc == 0)
  413.     return *this;
  414.  
  415.   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
  416.  
  417.   add2 (d, a.data (), length ());
  418.   return *this;
  419. }
  420.  
  421. ComplexDiagMatrix&
  422. ComplexDiagMatrix::operator -= (const DiagMatrix& a)
  423. {
  424.   int nr = rows ();
  425.   int nc = cols ();
  426.  
  427.   int a_nr = a.rows ();
  428.   int a_nc = a.cols ();
  429.  
  430.   if (nr != a_nr || nc != a_nc)
  431.     {
  432.       gripe_nonconformant ("operator -=", nr, nc, a_nr, a_nc);
  433.       return *this;
  434.     }
  435.  
  436.   if (nr == 0 || nc == 0)
  437.     return *this;
  438.  
  439.   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
  440.  
  441.   subtract2 (d, a.data (), length ());
  442.   return *this;
  443. }
  444.  
  445. ComplexDiagMatrix&
  446. ComplexDiagMatrix::operator += (const ComplexDiagMatrix& a)
  447. {
  448.   int nr = rows ();
  449.   int nc = cols ();
  450.  
  451.   int a_nr = a.rows ();
  452.   int a_nc = a.cols ();
  453.  
  454.   if (nr != a_nr || nc != a_nc)
  455.     {
  456.       gripe_nonconformant ("operator +=", nr, nc, a_nr, a_nc);
  457.       return *this;
  458.     }
  459.  
  460.   if (nr == 0 || nc == 0)
  461.     return *this;
  462.  
  463.   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
  464.  
  465.   add2 (d, a.data (), length ());
  466.   return *this;
  467. }
  468.  
  469. ComplexDiagMatrix&
  470. ComplexDiagMatrix::operator -= (const ComplexDiagMatrix& a)
  471. {
  472.   int nr = rows ();
  473.   int nc = cols ();
  474.  
  475.   int a_nr = a.rows ();
  476.   int a_nc = a.cols ();
  477.  
  478.   if (nr != a_nr || nc != a_nc)
  479.     {
  480.       gripe_nonconformant ("operator -=", nr, nc, a_nr, a_nc);
  481.       return *this;
  482.     }
  483.  
  484.   if (nr == 0 || nc == 0)
  485.     return *this;
  486.  
  487.   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
  488.  
  489.   subtract2 (d, a.data (), length ());
  490.   return *this;
  491. }
  492.  
  493. // diagonal matrix by scalar -> diagonal matrix operations
  494.  
  495. ComplexDiagMatrix
  496. operator * (const ComplexDiagMatrix& a, double s)
  497. {
  498.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  499.                 a.rows (), a.cols ());
  500. }
  501.  
  502. ComplexDiagMatrix
  503. operator / (const ComplexDiagMatrix& a, double s)
  504. {
  505.   return ComplexDiagMatrix (divide (a.data (), a.length (), s),
  506.                 a.rows (), a.cols ());
  507. }
  508.  
  509. ComplexDiagMatrix
  510. operator * (const DiagMatrix& a, const Complex& s)
  511. {
  512.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  513.                 a.rows (), a.cols ());
  514. }
  515.  
  516. ComplexDiagMatrix
  517. operator / (const DiagMatrix& a, const Complex& s)
  518. {
  519.   return ComplexDiagMatrix (divide (a.data (), a.length (), s),
  520.                 a.rows (), a.cols ());
  521. }
  522.  
  523. // scalar by diagonal matrix -> diagonal matrix operations
  524.  
  525. ComplexDiagMatrix
  526. operator * (double s, const ComplexDiagMatrix& a)
  527. {
  528.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  529.                 a.rows (), a.cols ());
  530. }
  531.  
  532. ComplexDiagMatrix
  533. operator * (const Complex& s, const DiagMatrix& a)
  534. {
  535.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  536.                 a.rows (), a.cols ());
  537. }
  538.  
  539. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  540.  
  541. ComplexDiagMatrix
  542. operator * (const ComplexDiagMatrix& a, const ComplexDiagMatrix& b)
  543. {
  544.   int nr_a = a.rows ();
  545.   int nc_a = a.cols ();
  546.  
  547.   int nr_b = b.rows ();
  548.   int nc_b = b.cols ();
  549.  
  550.   if (nc_a != nr_b)
  551.     {
  552.       gripe_nonconformant ("operator *", nr_a, nc_a, nr_b, nc_b);
  553.       return ComplexDiagMatrix ();
  554.     }
  555.  
  556.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  557.     return ComplexDiagMatrix (nr_a, nc_a, 0.0);
  558.  
  559.   ComplexDiagMatrix c (nr_a, nc_b);
  560.  
  561.   int len = nr_a < nc_b ? nr_a : nc_b;
  562.  
  563.   for (int i = 0; i < len; i++)
  564.     {
  565.       Complex a_element = a.elem (i, i);
  566.       Complex b_element = b.elem (i, i);
  567.  
  568.       if (a_element == 0.0 || b_element == 0.0)
  569.         c.elem (i, i) = 0.0;
  570.       else if (a_element == 1.0)
  571.         c.elem (i, i) = b_element;
  572.       else if (b_element == 1.0)
  573.         c.elem (i, i) = a_element;
  574.       else
  575.         c.elem (i, i) = a_element * b_element;
  576.     }
  577.  
  578.   return c;
  579. }
  580.  
  581. ComplexDiagMatrix
  582. operator * (const ComplexDiagMatrix& a, const DiagMatrix& b)
  583. {
  584.   int nr_a = a.rows ();
  585.   int nc_a = a.cols ();
  586.  
  587.   int nr_b = b.rows ();
  588.   int nc_b = b.cols ();
  589.  
  590.   if (nc_a != nr_b)
  591.     {
  592.       gripe_nonconformant ("operator *", nr_a, nc_a, nr_b, nc_b);
  593.       return ComplexDiagMatrix ();
  594.     }
  595.  
  596.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  597.     return ComplexDiagMatrix (nr_a, nc_a, 0.0);
  598.  
  599.   ComplexDiagMatrix c (nr_a, nc_b);
  600.  
  601.   int len = nr_a < nc_b ? nr_a : nc_b;
  602.  
  603.   for (int i = 0; i < len; i++)
  604.     {
  605.       Complex a_element = a.elem (i, i);
  606.       double b_element = b.elem (i, i);
  607.  
  608.       if (a_element == 0.0 || b_element == 0.0)
  609.         c.elem (i, i) = 0.0;
  610.       else if (a_element == 1.0)
  611.         c.elem (i, i) = b_element;
  612.       else if (b_element == 1.0)
  613.         c.elem (i, i) = a_element;
  614.       else
  615.         c.elem (i, i) = a_element * b_element;
  616.     }
  617.  
  618.   return c;
  619. }
  620.  
  621. ComplexDiagMatrix
  622. operator * (const DiagMatrix& a, const ComplexDiagMatrix& b)
  623. {
  624.   int nr_a = a.rows ();
  625.   int nc_a = a.cols ();
  626.  
  627.   int nr_b = b.rows ();
  628.   int nc_b = b.cols ();
  629.  
  630.   if (nc_a != nr_b)
  631.     {
  632.       gripe_nonconformant ("operator *", nr_a, nc_a, nr_b, nc_b);
  633.       return ComplexDiagMatrix ();
  634.     }
  635.  
  636.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  637.     return ComplexDiagMatrix (nr_a, nc_a, 0.0);
  638.  
  639.   ComplexDiagMatrix c (nr_a, nc_b);
  640.  
  641.   int len = nr_a < nc_b ? nr_a : nc_b;
  642.  
  643.   for (int i = 0; i < len; i++)
  644.     {
  645.       double a_element = a.elem (i, i);
  646.       Complex b_element = b.elem (i, i);
  647.  
  648.       if (a_element == 0.0 || b_element == 0.0)
  649.         c.elem (i, i) = 0.0;
  650.       else if (a_element == 1.0)
  651.         c.elem (i, i) = b_element;
  652.       else if (b_element == 1.0)
  653.         c.elem (i, i) = a_element;
  654.       else
  655.         c.elem (i, i) = a_element * b_element;
  656.     }
  657.  
  658.   return c;
  659. }
  660.  
  661. // other operations
  662.  
  663. ComplexColumnVector
  664. ComplexDiagMatrix::diag (void) const
  665. {
  666.   return diag (0);
  667. }
  668.  
  669. // Could be optimized...
  670.  
  671. ComplexColumnVector
  672. ComplexDiagMatrix::diag (int k) const
  673. {
  674.   int nnr = rows ();
  675.   int nnc = cols ();
  676.   if (k > 0)
  677.     nnc -= k;
  678.   else if (k < 0)
  679.     nnr += k;
  680.  
  681.   ComplexColumnVector d;
  682.  
  683.   if (nnr > 0 && nnc > 0)
  684.     {
  685.       int ndiag = (nnr < nnc) ? nnr : nnc;
  686.  
  687.       d.resize (ndiag);
  688.  
  689.       if (k > 0)
  690.     {
  691.       for (int i = 0; i < ndiag; i++)
  692.         d.elem (i) = elem (i, i+k);
  693.     }
  694.       else if ( k < 0)
  695.     {
  696.       for (int i = 0; i < ndiag; i++)
  697.         d.elem (i) = elem (i-k, i);
  698.     }
  699.       else
  700.     {
  701.       for (int i = 0; i < ndiag; i++)
  702.         d.elem (i) = elem (i, i);
  703.     }
  704.     }
  705.   else
  706.     cerr << "diag: requested diagonal out of range\n";
  707.  
  708.   return d;
  709. }
  710.  
  711. // i/o
  712.  
  713. ostream&
  714. operator << (ostream& os, const ComplexDiagMatrix& a)
  715. {
  716.   Complex ZERO (0.0);
  717. //  int field_width = os.precision () + 7;
  718.   for (int i = 0; i < a.rows (); i++)
  719.     {
  720.       for (int j = 0; j < a.cols (); j++)
  721.     {
  722.       if (i == j)
  723.         os << " " /* setw (field_width) */ << a.elem (i, i);
  724.       else
  725.         os << " " /* setw (field_width) */ << ZERO;
  726.     }
  727.       os << "\n";
  728.     }
  729.   return os;
  730. }
  731.  
  732. /*
  733. ;;; Local Variables: ***
  734. ;;; mode: C++ ***
  735. ;;; End: ***
  736. */
  737.