home *** CD-ROM | disk | FTP | other *** search
/ Fresh Fish 9 / FreshFishVol9-CD2.bin / bbs / gnu / octave-1.1.1-src.lha / octave-1.1.1 / liboctave / dDiagMatrix.cc < prev    next >
Encoding:
C/C++ Source or Header  |  1995-01-04  |  18.5 KB  |  940 lines

  1. // DiagMatrix manipulations.                             -*- C++ -*-
  2. /*
  3.  
  4. Copyright (C) 1992, 1993, 1994, 1995 John W. Eaton
  5.  
  6. This file is part of Octave.
  7.  
  8. Octave is free software; you can redistribute it and/or modify it
  9. under the terms of the GNU General Public License as published by the
  10. Free Software Foundation; either version 2, or (at your option) any
  11. later version.
  12.  
  13. Octave is distributed in the hope that it will be useful, but WITHOUT
  14. ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  15. FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  16. for more details.
  17.  
  18. You should have received a copy of the GNU General Public License
  19. along with Octave; see the file COPYING.  If not, write to the Free
  20. Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.
  21.  
  22. */
  23.  
  24. #ifdef HAVE_CONFIG_H
  25. #include "config.h"
  26. #endif
  27.  
  28. #include <iostream.h>
  29.  
  30. #include <Complex.h>
  31.  
  32. #include "mx-base.h"
  33. #include "mx-inlines.cc"
  34. #include "lo-error.h"
  35.  
  36. /*
  37.  * Diagonal Matrix class.
  38.  */
  39.  
  40. #define KLUDGE_DIAG_MATRICES
  41. #define TYPE double
  42. #define KL_DMAT_TYPE DiagMatrix
  43. #include "mx-kludge.cc"
  44. #undef KLUDGE_DIAG_MATRICES
  45. #undef TYPE
  46. #undef KL_DMAT_TYPE
  47.  
  48. #if 0
  49. DiagMatrix&
  50. DiagMatrix::resize (int r, int c)
  51. {
  52.   if (r < 0 || c < 0)
  53.     {
  54.       (*current_liboctave_error_handler)
  55.     ("can't resize to negative dimensions");
  56.       return *this;
  57.     }
  58.  
  59.   int new_len = r < c ? r : c;
  60.   double *new_data = 0;
  61.   if (new_len > 0)
  62.     {
  63.       new_data = new double [new_len];
  64.  
  65.       int min_len = new_len < len ? new_len : len;
  66.  
  67.       for (int i = 0; i < min_len; i++)
  68.     new_data[i] = data[i];
  69.     }
  70.  
  71.   delete [] data;
  72.   nr = r;
  73.   nc = c;
  74.   len = new_len;
  75.   data = new_data;
  76.  
  77.   return *this;
  78. }
  79.  
  80. DiagMatrix&
  81. DiagMatrix::resize (int r, int c, double val)
  82. {
  83.   if (r < 0 || c < 0)
  84.     {
  85.       (*current_liboctave_error_handler)
  86.     ("can't resize to negative dimensions");
  87.       return *this;
  88.     }
  89.  
  90.   int new_len = r < c ? r : c;
  91.   double *new_data = 0;
  92.   if (new_len > 0)
  93.     {
  94.       new_data = new double [new_len];
  95.  
  96.       int min_len = new_len < len ? new_len : len;
  97.  
  98.       for (int i = 0; i < min_len; i++)
  99.     new_data[i] = data[i];
  100.  
  101.       for (i = min_len; i < new_len; i++)
  102.     new_data[i] = val;
  103.     }
  104.  
  105.   delete [] data;
  106.   nr = r;
  107.   nc = c;
  108.   len = new_len;
  109.   data = new_data;
  110.  
  111.   return *this;
  112. }
  113. #endif
  114.  
  115. int
  116. DiagMatrix::operator == (const DiagMatrix& a) const
  117. {
  118.   if (rows () != a.rows () || cols () != a.cols ())
  119.     return 0;
  120.  
  121.   return equal (data (), a.data (), length ());
  122. }
  123.  
  124. int
  125. DiagMatrix::operator != (const DiagMatrix& a) const
  126. {
  127.   return !(*this == a);
  128. }
  129.  
  130. DiagMatrix&
  131. DiagMatrix::fill (double val)
  132. {
  133.   for (int i = 0; i < length (); i++)
  134.     elem (i, i) = val;
  135.   return *this;
  136. }
  137.  
  138. DiagMatrix&
  139. DiagMatrix::fill (double val, int beg, int end)
  140. {
  141.   if (beg < 0 || end >= length () || end < beg)
  142.     {
  143.       (*current_liboctave_error_handler) ("range error for fill");
  144.       return *this;
  145.     }
  146.  
  147.   for (int i = beg; i < end; i++)
  148.     elem (i, i) = val;
  149.  
  150.   return *this;
  151. }
  152.  
  153. DiagMatrix&
  154. DiagMatrix::fill (const ColumnVector& a)
  155. {
  156.   int len = length ();
  157.   if (a.length () != len)
  158.     {
  159.       (*current_liboctave_error_handler) ("range error for fill");
  160.       return *this;
  161.     }
  162.  
  163.   for (int i = 0; i < len; i++)
  164.     elem (i, i) = a.elem (i);
  165.  
  166.   return *this;
  167. }
  168.  
  169. DiagMatrix&
  170. DiagMatrix::fill (const RowVector& a)
  171. {
  172.   int len = length ();
  173.   if (a.length () != len)
  174.     {
  175.       (*current_liboctave_error_handler) ("range error for fill");
  176.       return *this;
  177.     }
  178.  
  179.   for (int i = 0; i < len; i++)
  180.     elem (i, i) = a.elem (i);
  181.  
  182.   return *this;
  183. }
  184.  
  185. DiagMatrix&
  186. DiagMatrix::fill (const ColumnVector& a, int beg)
  187. {
  188.   int a_len = a.length ();
  189.   if (beg < 0 || beg + a_len >= length ())
  190.     {
  191.       (*current_liboctave_error_handler) ("range error for fill");
  192.       return *this;
  193.     }
  194.  
  195.   for (int i = 0; i < a_len; i++)
  196.     elem (i+beg, i+beg) = a.elem (i);
  197.  
  198.   return *this;
  199. }
  200.  
  201. DiagMatrix&
  202. DiagMatrix::fill (const RowVector& a, int beg)
  203. {
  204.   int a_len = a.length ();
  205.   if (beg < 0 || beg + a_len >= length ())
  206.     {
  207.       (*current_liboctave_error_handler) ("range error for fill");
  208.       return *this;
  209.     }
  210.  
  211.   for (int i = 0; i < a_len; i++)
  212.     elem (i+beg, i+beg) = a.elem (i);
  213.  
  214.   return *this;
  215. }
  216.  
  217. DiagMatrix
  218. DiagMatrix::transpose (void) const
  219. {
  220.   return DiagMatrix (dup (data (), length ()), cols (), rows ());
  221. }
  222.  
  223. Matrix
  224. DiagMatrix::extract (int r1, int c1, int r2, int c2) const
  225. {
  226.   if (r1 > r2) { int tmp = r1; r1 = r2; r2 = tmp; }
  227.   if (c1 > c2) { int tmp = c1; c1 = c2; c2 = tmp; }
  228.  
  229.   int new_r = r2 - r1 + 1;
  230.   int new_c = c2 - c1 + 1;
  231.  
  232.   Matrix result (new_r, new_c);
  233.  
  234.   for (int j = 0; j < new_c; j++)
  235.     for (int i = 0; i < new_r; i++)
  236.       result.elem (i, j) = elem (r1+i, c1+j);
  237.  
  238.   return result;
  239. }
  240.  
  241. // extract row or column i.
  242.  
  243. RowVector
  244. DiagMatrix::row (int i) const
  245. {
  246.   int nr = rows ();
  247.   int nc = cols ();
  248.   if (i < 0 || i >= nr)
  249.     {
  250.       (*current_liboctave_error_handler) ("invalid row selection");
  251.       return RowVector (); 
  252.     }
  253.  
  254.   RowVector retval (nc, 0.0);
  255.   if (nr <= nc || (nr > nc && i < nc))
  256.     retval.elem (i) = elem (i, i);
  257.  
  258.   return retval;
  259. }
  260.  
  261. RowVector
  262. DiagMatrix::row (char *s) const
  263. {
  264.   if (! s)
  265.     {
  266.       (*current_liboctave_error_handler) ("invalid row selection");
  267.       return RowVector (); 
  268.     }
  269.  
  270.   char c = *s;
  271.   if (c == 'f' || c == 'F')
  272.     return row (0);
  273.   else if (c == 'l' || c == 'L')
  274.     return row (rows () - 1);
  275.   else
  276.     {
  277.       (*current_liboctave_error_handler) ("invalid row selection");
  278.       return RowVector (); 
  279.     }
  280. }
  281.  
  282. ColumnVector
  283. DiagMatrix::column (int i) const
  284. {
  285.   int nr = rows ();
  286.   int nc = cols ();
  287.   if (i < 0 || i >= nc)
  288.     {
  289.       (*current_liboctave_error_handler) ("invalid column selection");
  290.       return ColumnVector (); 
  291.     }
  292.  
  293.   ColumnVector retval (nr, 0.0);
  294.   if (nr >= nc || (nr < nc && i < nr))
  295.     retval.elem (i) = elem (i, i);
  296.  
  297.   return retval;
  298. }
  299.  
  300. ColumnVector
  301. DiagMatrix::column (char *s) const
  302. {
  303.   if (! s)
  304.     {
  305.       (*current_liboctave_error_handler) ("invalid column selection");
  306.       return ColumnVector (); 
  307.     }
  308.  
  309.   char c = *s;
  310.   if (c == 'f' || c == 'F')
  311.     return column (0);
  312.   else if (c == 'l' || c == 'L')
  313.     return column (cols () - 1);
  314.   else
  315.     {
  316.       (*current_liboctave_error_handler) ("invalid column selection");
  317.       return ColumnVector (); 
  318.     }
  319. }
  320.  
  321. DiagMatrix
  322. DiagMatrix::inverse (void) const
  323. {
  324.   int info;
  325.   return inverse (info);
  326. }
  327.  
  328. DiagMatrix
  329. DiagMatrix::inverse (int &info) const
  330. {
  331.   int nr = rows ();
  332.   int nc = cols ();
  333.   int len = length ();
  334.   if (nr != nc)
  335.     {
  336.       (*current_liboctave_error_handler) ("inverse requires square matrix");
  337.       return DiagMatrix ();
  338.     }
  339.  
  340.   info = 0;
  341.   double *tmp_data = dup (data (), len);
  342.   for (int i = 0; i < len; i++)
  343.     {
  344.       if (elem (i, i) == 0.0)
  345.     {
  346.       info = -1;
  347.       copy (tmp_data, data (), len); // Restore contents.
  348.       break;
  349.     }
  350.       else
  351.     {
  352.       tmp_data[i] = 1.0 / elem (i, i);
  353.     }
  354.     }
  355.  
  356.   return DiagMatrix (tmp_data, nr, nc);
  357. }
  358.  
  359. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  360.  
  361. DiagMatrix&
  362. DiagMatrix::operator += (const DiagMatrix& a)
  363. {
  364.   int nr = rows ();
  365.   int nc = cols ();
  366.   if (nr != a.rows () || nc != a.cols ())
  367.     {
  368.       (*current_liboctave_error_handler)
  369.     ("nonconformant matrix += operation attempted");
  370.       return *this;
  371.     }
  372.  
  373.   if (nc == 0 || nr == 0)
  374.     return *this;
  375.  
  376.   double *d = fortran_vec (); // Ensures only one reference to my privates!
  377.  
  378.   add2 (d, a.data (), length ());
  379.   return *this;
  380. }
  381.  
  382. DiagMatrix&
  383. DiagMatrix::operator -= (const DiagMatrix& a)
  384. {
  385.   int nr = rows ();
  386.   int nc = cols ();
  387.   if (nr != a.rows () || nc != a.cols ())
  388.     {
  389.       (*current_liboctave_error_handler)
  390.     ("nonconformant matrix -= operation attempted");
  391.       return *this;
  392.     }
  393.  
  394.   if (nr == 0 || nc == 0)
  395.     return *this;
  396.  
  397.   double *d = fortran_vec (); // Ensures only one reference to my privates!
  398.  
  399.   subtract2 (d, a.data (), length ());
  400.   return *this;
  401. }
  402.  
  403. // diagonal matrix by scalar -> matrix operations
  404.  
  405. Matrix
  406. operator + (const DiagMatrix& a, double s)
  407. {
  408.   Matrix tmp (a.rows (), a.cols (), s);
  409.   return a + tmp;
  410. }
  411.  
  412. Matrix
  413. operator - (const DiagMatrix& a, double s)
  414. {
  415.   Matrix tmp (a.rows (), a.cols (), -s);
  416.   return a + tmp;
  417. }
  418.  
  419. ComplexMatrix
  420. operator + (const DiagMatrix& a, const Complex& s)
  421. {
  422.   ComplexMatrix tmp (a.rows (), a.cols (), s);
  423.   return a + tmp;
  424. }
  425.  
  426. ComplexMatrix
  427. operator - (const DiagMatrix& a, const Complex& s)
  428. {
  429.   ComplexMatrix tmp (a.rows (), a.cols (), -s);
  430.   return a + tmp;
  431. }
  432.  
  433. // diagonal matrix by scalar -> diagonal matrix operations
  434.  
  435. ComplexDiagMatrix
  436. operator * (const DiagMatrix& a, const Complex& s)
  437. {
  438.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  439.                 a.rows (), a.cols ());
  440. }
  441.  
  442. ComplexDiagMatrix
  443. operator / (const DiagMatrix& a, const Complex& s)
  444. {
  445.   return ComplexDiagMatrix (divide (a.data (), a.length (), s),
  446.                 a.rows (), a.cols ());
  447. }
  448.  
  449. // scalar by diagonal matrix -> matrix operations
  450.  
  451. Matrix
  452. operator + (double s, const DiagMatrix& a)
  453. {
  454.   Matrix tmp (a.rows (), a.cols (), s);
  455.   return tmp + a;
  456. }
  457.  
  458. Matrix
  459. operator - (double s, const DiagMatrix& a)
  460. {
  461.   Matrix tmp (a.rows (), a.cols (), s);
  462.   return tmp - a;
  463. }
  464.  
  465. ComplexMatrix
  466. operator + (const Complex& s, const DiagMatrix& a)
  467. {
  468.   ComplexMatrix tmp (a.rows (), a.cols (), s);
  469.   return tmp + a;
  470. }
  471.  
  472. ComplexMatrix
  473. operator - (const Complex& s, const DiagMatrix& a)
  474. {
  475.   ComplexMatrix tmp (a.rows (), a.cols (), s);
  476.   return tmp - a;
  477. }
  478.  
  479. // scalar by diagonal matrix -> diagonal matrix operations
  480.  
  481. ComplexDiagMatrix
  482. operator * (const Complex& s, const DiagMatrix& a)
  483. {
  484.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  485.                 a.rows (), a.cols ());
  486. }
  487.  
  488. // diagonal matrix by column vector -> column vector operations
  489.  
  490. ColumnVector
  491. operator * (const DiagMatrix& m, const ColumnVector& a)
  492. {
  493.   int nr = m.rows ();
  494.   int nc = m.cols ();
  495.   int a_len = a.length ();
  496.   if (nc != a_len)
  497.     {
  498.       (*current_liboctave_error_handler)
  499.     ("nonconformant matrix multiplication attempted");
  500.       return ColumnVector ();
  501.     }
  502.  
  503.   if (nc == 0 || nr == 0)
  504.     return ColumnVector (0);
  505.  
  506.   ColumnVector result (nr);
  507.  
  508.   for (int i = 0; i < a_len; i++)
  509.     result.elem (i) = a.elem (i) * m.elem (i, i);
  510.  
  511.   for (i = a_len; i < nr; i++)
  512.     result.elem (i) = 0.0;
  513.  
  514.   return result;
  515. }
  516.  
  517. ComplexColumnVector
  518. operator * (const DiagMatrix& m, const ComplexColumnVector& a)
  519. {
  520.   int nr = m.rows ();
  521.   int nc = m.cols ();
  522.   int a_len = a.length ();
  523.   if (nc != a_len)
  524.     {
  525.       (*current_liboctave_error_handler)
  526.     ("nonconformant matrix multiplication attempted");
  527.       return ColumnVector ();
  528.     }
  529.  
  530.   if (nc == 0 || nr == 0)
  531.     return ComplexColumnVector (0);
  532.  
  533.   ComplexColumnVector result (nr);
  534.  
  535.   for (int i = 0; i < a_len; i++)
  536.     result.elem (i) = a.elem (i) * m.elem (i, i);
  537.  
  538.   for (i = a_len; i < nr; i++)
  539.     result.elem (i) = 0.0;
  540.  
  541.   return result;
  542. }
  543.  
  544. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  545.  
  546. DiagMatrix
  547. operator * (const DiagMatrix& a, const DiagMatrix& b)
  548. {
  549.   int nr_a = a.rows ();
  550.   int nc_a = a.cols ();
  551.   int nr_b = b.rows ();
  552.   int nc_b = b.cols ();
  553.   if (nc_a != nr_b)
  554.     {
  555.       (*current_liboctave_error_handler)
  556.         ("nonconformant matrix multiplication attempted");
  557.       return DiagMatrix ();
  558.     }
  559.  
  560.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  561.     return DiagMatrix (nr_a, nc_a, 0.0);
  562.  
  563.   DiagMatrix c (nr_a, nc_b);
  564.  
  565.   int len = nr_a < nc_b ? nr_a : nc_b;
  566.  
  567.   for (int i = 0; i < len; i++)
  568.     {
  569.       double a_element = a.elem (i, i);
  570.       double b_element = b.elem (i, i);
  571.  
  572.       if (a_element == 0.0 || b_element == 0.0)
  573.         c.elem (i, i) = 0.0;
  574.       else if (a_element == 1.0)
  575.         c.elem (i, i) = b_element;
  576.       else if (b_element == 1.0)
  577.         c.elem (i, i) = a_element;
  578.       else
  579.         c.elem (i, i) = a_element * b_element;
  580.     }
  581.  
  582.   return c;
  583. }
  584.  
  585. ComplexDiagMatrix
  586. operator + (const DiagMatrix& m, const ComplexDiagMatrix& a)
  587. {
  588.   int nr = m.rows ();
  589.   int nc = m.cols ();
  590.   if (nr != a.rows () || nc != a.cols ())
  591.     {
  592.       (*current_liboctave_error_handler)
  593.     ("nonconformant matrix addition attempted");
  594.       return ComplexDiagMatrix ();
  595.     }
  596.  
  597.   if (nc == 0 || nr == 0)
  598.     return ComplexDiagMatrix (nr, nc);
  599.  
  600.   return ComplexDiagMatrix (add (m.data (), a.data (), m.length ()),  nr, nc);
  601. }
  602.  
  603. ComplexDiagMatrix
  604. operator - (const DiagMatrix& m, const ComplexDiagMatrix& a)
  605. {
  606.   int nr = m.rows ();
  607.   int nc = m.cols ();
  608.   if (nr != a.rows () || nc != a.cols ())
  609.     {
  610.       (*current_liboctave_error_handler)
  611.     ("nonconformant matrix subtraction attempted");
  612.       return ComplexDiagMatrix ();
  613.     }
  614.  
  615.   if (nc == 0 || nr == 0)
  616.     return ComplexDiagMatrix (nr, nc);
  617.  
  618.   return ComplexDiagMatrix (subtract (m.data (), a.data (), m.length ()),
  619.                 nr, nc);
  620. }
  621.  
  622. ComplexDiagMatrix
  623. operator * (const DiagMatrix& a, const ComplexDiagMatrix& b)
  624. {
  625.   int nr_a = a.rows ();
  626.   int nc_a = a.cols ();
  627.   int nr_b = b.rows ();
  628.   int nc_b = b.cols ();
  629.   if (nc_a != nr_b)
  630.     {
  631.       (*current_liboctave_error_handler)
  632.         ("nonconformant matrix multiplication attempted");
  633.       return ComplexDiagMatrix ();
  634.     }
  635.  
  636.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  637.     return ComplexDiagMatrix (nr_a, nc_a, 0.0);
  638.  
  639.   ComplexDiagMatrix c (nr_a, nc_b);
  640.  
  641.   int len = nr_a < nc_b ? nr_a : nc_b;
  642.  
  643.   for (int i = 0; i < len; i++)
  644.     {
  645.       double a_element = a.elem (i, i);
  646.       Complex b_element = b.elem (i, i);
  647.  
  648.       if (a_element == 0.0 || b_element == 0.0)
  649.         c.elem (i, i) = 0.0;
  650.       else if (a_element == 1.0)
  651.         c.elem (i, i) = b_element;
  652.       else if (b_element == 1.0)
  653.         c.elem (i, i) = a_element;
  654.       else
  655.         c.elem (i, i) = a_element * b_element;
  656.     }
  657.  
  658.   return c;
  659. }
  660.  
  661. ComplexDiagMatrix
  662. product (const DiagMatrix& m, const ComplexDiagMatrix& a)
  663. {
  664.   int nr = m.rows ();
  665.   int nc = m.cols ();
  666.   if (nr != a.rows () || nc != a.cols ())
  667.     {
  668.       (*current_liboctave_error_handler)
  669.     ("nonconformant matrix product attempted");
  670.       return ComplexDiagMatrix ();
  671.     }
  672.  
  673.   if (nc == 0 || nr == 0)
  674.     return ComplexDiagMatrix (nr, nc);
  675.  
  676.   return ComplexDiagMatrix (multiply (m.data (), a.data (), m.length ()),
  677.                 nr, nc);
  678. }
  679.  
  680. // diagonal matrix by matrix -> matrix operations
  681.  
  682. Matrix
  683. operator + (const DiagMatrix& m, const Matrix& a)
  684. {
  685.   int nr = m.rows ();
  686.   int nc = m.cols ();
  687.   if (nr != a.rows () || nc != a.cols ())
  688.     {
  689.       (*current_liboctave_error_handler)
  690.     ("nonconformant matrix addition attempted");
  691.       return Matrix ();
  692.     }
  693.  
  694.   if (nr == 0 || nc == 0)
  695.     return Matrix (nr, nc);
  696.  
  697.   Matrix result (a);
  698.   for (int i = 0; i < m.length (); i++)
  699.     result.elem (i, i) += m.elem (i, i);
  700.  
  701.   return result;
  702. }
  703.  
  704. Matrix
  705. operator - (const DiagMatrix& m, const Matrix& a)
  706. {
  707.   int nr = m.rows ();
  708.   int nc = m.cols ();
  709.   if (nr != a.rows () || nc != a.cols ())
  710.     {
  711.       (*current_liboctave_error_handler)
  712.     ("nonconformant matrix subtraction attempted");
  713.       return Matrix ();
  714.     }
  715.  
  716.   if (nr == 0 || nc == 0)
  717.     return Matrix (nr, nc);
  718.  
  719.   Matrix result (-a);
  720.   for (int i = 0; i < m.length (); i++)
  721.     result.elem (i, i) += m.elem (i, i);
  722.  
  723.   return result;
  724. }
  725.  
  726. Matrix
  727. operator * (const DiagMatrix& m, const Matrix& a)
  728. {
  729.   int nr = m.rows ();
  730.   int nc = m.cols ();
  731.   int a_nr = a.rows ();
  732.   int a_nc = a.cols ();
  733.   if (nc != a_nr)
  734.     {
  735.       (*current_liboctave_error_handler)
  736.     ("nonconformant matrix multiplication attempted");
  737.       return Matrix ();
  738.     }
  739.  
  740.   if (nr == 0 || nc == 0 || a_nc == 0)
  741.     return Matrix (nr, a_nc, 0.0);
  742.  
  743.   Matrix c (nr, a_nc);
  744.  
  745.   for (int i = 0; i < m.length (); i++)
  746.     {
  747.       if (m.elem (i, i) == 1.0)
  748.     {
  749.       for (int j = 0; j < a_nc; j++)
  750.         c.elem (i, j) = a.elem (i, j);
  751.     }
  752.       else if (m.elem (i, i) == 0.0)
  753.     {
  754.       for (int j = 0; j < a_nc; j++)
  755.         c.elem (i, j) = 0.0;
  756.     }
  757.       else
  758.     {
  759.       for (int j = 0; j < a_nc; j++)
  760.         c.elem (i, j) = m.elem (i, i) * a.elem (i, j);
  761.     }
  762.     }
  763.  
  764.   if (nr > nc)
  765.     {
  766.       for (int j = 0; j < a_nc; j++)
  767.     for (int i = a_nr; i < nr; i++)
  768.       c.elem (i, j) = 0.0;
  769.     }
  770.  
  771.   return c;
  772. }
  773.  
  774. ComplexMatrix
  775. operator + (const DiagMatrix& m, const ComplexMatrix& a)
  776. {
  777.   int nr = m.rows ();
  778.   int nc = m.cols ();
  779.   if (nr != a.rows () || nc != a.cols ())
  780.     {
  781.       (*current_liboctave_error_handler)
  782.     ("nonconformant matrix addition attempted");
  783.       return ComplexMatrix ();
  784.     }
  785.  
  786.   if (nr == 0 || nc == 0)
  787.     return ComplexMatrix (nr, nc);
  788.  
  789.   ComplexMatrix result (a);
  790.   for (int i = 0; i < m.length (); i++)
  791.     result.elem (i, i) += m.elem (i, i);
  792.  
  793.   return result;
  794. }
  795.  
  796. ComplexMatrix
  797. operator - (const DiagMatrix& m, const ComplexMatrix& a)
  798. {
  799.   int nr = m.rows ();
  800.   int nc = m.cols ();
  801.   if (nr != a.rows () || nc != a.cols ())
  802.     {
  803.       (*current_liboctave_error_handler)
  804.     ("nonconformant matrix subtraction attempted");
  805.       return ComplexMatrix ();
  806.     }
  807.  
  808.   if (nr == 0 || nc == 0)
  809.     return ComplexMatrix (nr, nc);
  810.  
  811.   ComplexMatrix result (-a);
  812.   for (int i = 0; i < m.length (); i++)
  813.     result.elem (i, i) += m.elem (i, i);
  814.  
  815.   return result;
  816. }
  817.  
  818. ComplexMatrix
  819. operator * (const DiagMatrix& m, const ComplexMatrix& a)
  820. {
  821.   int nr = m.rows ();
  822.   int nc = m.cols ();
  823.   int a_nr = a.rows ();
  824.   int a_nc = a.cols ();
  825.   if (nc != a_nr)
  826.     {
  827.       (*current_liboctave_error_handler)
  828.     ("nonconformant matrix multiplication attempted");
  829.       return ComplexMatrix ();
  830.     }
  831.  
  832.   if (nr == 0 || nc == 0 || a_nc == 0)
  833.     return ComplexMatrix (nr, nc, 0.0);
  834.  
  835.   ComplexMatrix c (nr, a_nc);
  836.  
  837.   for (int i = 0; i < m.length (); i++)
  838.     {
  839.       if (m.elem (i, i) == 1.0)
  840.     {
  841.       for (int j = 0; j < a_nc; j++)
  842.         c.elem (i, j) = a.elem (i, j);
  843.     }
  844.       else if (m.elem (i, i) == 0.0)
  845.     {
  846.       for (int j = 0; j < a_nc; j++)
  847.         c.elem (i, j) = 0.0;
  848.     }
  849.       else
  850.     {
  851.       for (int j = 0; j < a_nc; j++)
  852.         c.elem (i, j) = m.elem (i, i) * a.elem (i, j);
  853.     }
  854.     }
  855.  
  856.   if (nr > nc)
  857.     {
  858.       for (int j = 0; j < a_nc; j++)
  859.     for (int i = a_nr; i < nr; i++)
  860.       c.elem (i, j) = 0.0;
  861.     }
  862.  
  863.   return c;
  864. }
  865.  
  866. // other operations
  867.  
  868. ColumnVector
  869. DiagMatrix::diag (void) const
  870. {
  871.   return diag (0);
  872. }
  873.  
  874. // Could be optimized...
  875.  
  876. ColumnVector
  877. DiagMatrix::diag (int k) const
  878. {
  879.   int nnr = rows ();
  880.   int nnc = cols ();
  881.   if (k > 0)
  882.     nnc -= k;
  883.   else if (k < 0)
  884.     nnr += k;
  885.  
  886.   ColumnVector d;
  887.  
  888.   if (nnr > 0 && nnc > 0)
  889.     {
  890.       int ndiag = (nnr < nnc) ? nnr : nnc;
  891.  
  892.       d.resize (ndiag);
  893.  
  894.       if (k > 0)
  895.     {
  896.       for (int i = 0; i < ndiag; i++)
  897.         d.elem (i) = elem (i, i+k);
  898.     }
  899.       else if ( k < 0)
  900.     {
  901.       for (int i = 0; i < ndiag; i++)
  902.         d.elem (i) = elem (i-k, i);
  903.     }
  904.       else
  905.     {
  906.       for (int i = 0; i < ndiag; i++)
  907.         d.elem (i) = elem (i, i);
  908.     }
  909.     }
  910.   else
  911.     cerr << "diag: requested diagonal out of range\n";
  912.  
  913.   return d;
  914. }
  915.  
  916. ostream&
  917. operator << (ostream& os, const DiagMatrix& a)
  918. {
  919. //  int field_width = os.precision () + 7;
  920.   for (int i = 0; i < a.rows (); i++)
  921.     {
  922.       for (int j = 0; j < a.cols (); j++)
  923.     {
  924.       if (i == j)
  925.         os << " " /* setw (field_width) */ << a.elem (i, i);
  926.       else
  927.         os << " " /* setw (field_width) */ << 0.0;
  928.     }
  929.       os << "\n";
  930.     }
  931.   return os;
  932. }
  933.  
  934. /*
  935. ;;; Local Variables: ***
  936. ;;; mode: C++ ***
  937. ;;; page-delimiter: "^/\\*" ***
  938. ;;; End: ***
  939. */
  940.