home *** CD-ROM | disk | FTP | other *** search
/ PC-Blue - MS DOS Public Domain Library / PC-Blue MS-DOS Public Domain Library - NYACC.iso / vol270 / solve.c < prev   
Encoding:
C/C++ Source or Header  |  1986-12-16  |  6.7 KB  |  318 lines

  1. /*    decomp & solve - solution of linear system
  2.  
  3.  
  4.     AUTHORS
  5.         Forsythe, Malcolm, and Moler, from "Computer Methods for
  6.         Mathematical Computations"
  7.  
  8.         translated to C by J. R. Van Zandt
  9. */
  10.  
  11. #include <stdio.h>
  12. #include <math.h>
  13.  
  14. #define aa(i,j) a[ix[i]+j]        /* like a[i][j] or a[i*ndim+j] but faster */
  15.  
  16. double decomp(ndim,n,a,ipvt)    /* returns estimate of condition number  */
  17.     int ndim,n,ipvt[]; double a[];
  18. {    double *work,ek,t,anorm,ynorm,znorm,cond;
  19.     int nm1,i,j,k,kp1,kb,km1,m,*ix;
  20.     work=malloc(n*sizeof(double));
  21.     ix=malloc(n*sizeof(int));
  22.     if(work==0 || ix==0)
  23.         {fprintf(stderr,"decomp: no workspace\n");
  24.         exit(1);
  25.         }
  26.     for (i=0; i<n; i++)
  27.         ix[i]=i*ndim;
  28.     nm1=n-1;
  29.     ipvt[nm1]=0;
  30.     if(n==1)
  31.         {free(work);
  32.         free(ix);
  33.         if(a[0]==0.)
  34.             return 1.e32;
  35.         else
  36.             return 1.;
  37.         }
  38.             /*    compute 1-norm of a    */
  39.     anorm=0.;
  40.     for (j=0; j<n; j++)
  41.         {t=0.;
  42.         for (i=0; i<n; i++)
  43.             t += fabs(aa(i,j));
  44.         if(t>anorm)
  45.             anorm=t;
  46.         }
  47.                     /*    Gaussian elimination with partial pivoting */
  48.     for (k=0; k<nm1; k++)
  49.         {kp1=k+1;
  50.                     /*    Find pivot    */
  51.         m=k;
  52.         for (i=kp1; i<n; i++)
  53.             if(fabs(aa(i,k)) > fabs(aa(m,k)))
  54.                 m=i;
  55.         ipvt[k]=m;
  56.         t=aa(m,k);
  57. /*        printf("pivoting on a[%d][%d] = %8.4f\n",m,k,t);    */
  58.         if(m!=k)
  59.             {ipvt[nm1]=-ipvt[nm1];
  60.             aa(m,k)=aa(k,k);
  61.             aa(k,k)=t;
  62.             }
  63.                     /*    skip step if pivot is zero */
  64.         if(t!=0.)
  65.             {        /*    compute multipliers    */
  66.             for (i=kp1; i<n; i++)
  67.                 aa(i,k) = -aa(i,k)/t;
  68.                     /*    interchange and eliminate by columns    */
  69.             for (j=kp1; j<n; j++)
  70.                 {t=aa(m,j);
  71.                 aa(m,j) = aa(k,j);
  72.                 aa(k,j) = t;
  73.                 if(t != 0.)
  74.                     for (i=kp1; i<n; i++)
  75.                         aa(i,j) += aa(i,k)*t;
  76.                 }
  77.             }
  78. /*        showm("After pivoting",a,ndim,n);    */
  79.         }
  80.                     /*    
  81.                         cond = (1-norm of a)*(an estimate of 1-norm of
  82.                         a-inverse) estimate obtained by one step of
  83.                         inverse iteration for the small singular
  84.                         vector.  This involves solving two systems of
  85.                         equations, (a-transpose)*y = e and a*z=y there
  86.                         e is a vector of +1 or -1 chosen to cause
  87.                         growth in y.  Estimate = (1-norm of z)/(1-norm
  88.                         of y)
  89.                     */
  90.             
  91.                     /*    solve (a-transpose)*y - e    */
  92.     for (k=0; k<n; k++)
  93.         {t=0.;
  94.         if(k)
  95.             for (i=0; i<k; i++)
  96.                 t += aa(i,k)*work[i];
  97.         if(t<0.) ek = -1.; else ek=1.;
  98.         if (aa(k,k) == 0.)
  99.             {free(work);
  100.             free(ix);
  101.             return 1.e32;
  102.             }
  103.         work[k] = -(ek+t)/aa(k,k);
  104.         }
  105. /*    showv("decompose: work1",work,n);    */
  106.     for (k=n-2; k>=0; k--)
  107.         {t=0.;
  108.         for (i=k+1; i<n; i++)
  109.             t += aa(i,k)*work[k];
  110.         work[k] = t;
  111.         }
  112. /*    showv("decompose: work2",work,n);    */
  113.     m=ipvt[k];
  114.     if (m!=k)
  115.         {t=work[m];
  116.         work[m] = work[k];
  117.         work[k]=t;
  118.         }
  119.     ynorm=0.;
  120.     for (i=0; i<n; i++)
  121.         ynorm += fabs(work[i]);
  122.             /*    solve a*z=y  */
  123.     solve(ndim, n, a, work, ipvt);
  124.     znorm=0.;
  125.     for (i=0; i<n; i++)
  126.         znorm += fabs(work[i]);
  127.             /*    estimate condition */
  128.     cond=anorm*znorm/ynorm;
  129.     if(cond<1.)
  130.         cond=1.;
  131.     free(work);
  132.     free(ix);
  133.     return cond;
  134. }
  135.  
  136. double invert(ndim,n,a,x)    /* returns estimate of condition number of matrix */
  137.     int ndim,            /*    declared row dimension of array containing a    */
  138.     n;                    /*    order of matrix a    */
  139.     double a[],            /*    matrix to be inverted    */
  140.     x[];                /*    resulting inverse    */
  141. {    int *ipvt, i, j;
  142.     double *work, cond, condp1;
  143.     ipvt=malloc(n*sizeof(int));
  144.     work=malloc(n*sizeof(double));
  145.     if(ipvt==0 || work==0)
  146.         {fprintf(stderr,"invert: not enough memory\n");
  147.         exit(1);
  148.         }
  149.     cond=decomp(ndim,n,a,ipvt);
  150. /*    printf("invert: cond = %f\n\n",cond);
  151.     printf("invert: ipvt\n");
  152.     for (i=0; i<n; i++)
  153.         printf("%8d \n",ipvt[i]);
  154. */
  155.     condp1=cond+1.;
  156.     if(condp1==cond)
  157.         {free(work);
  158.         free(ipvt);
  159.         return cond;
  160.         }
  161.     for (i=0; i<n; i++)
  162.         {for (j=0; j<n; j++)
  163.             work[j]=0.;
  164.         work[i]=1.;
  165. /*        showv("invert: RHS",work,n);    */
  166.         solve(ndim,n,a,work,ipvt);
  167.         for (j=0; j<n; j++)
  168.             x[j*ndim+i]=work[j];
  169.         }
  170.     free(work);
  171.     free(ipvt);
  172.     return cond;
  173. }
  174.  
  175. solve (ndim,n,a,b,ipvt)
  176.     int ndim,            /*    declared row dimension of array containing a    */
  177.     n,                    /*    order of matrix    */
  178.     ipvt[];                /*    pivot vector obtained from decomp    */
  179.     double a[],            /*    triangularized matrix obtained from decomp    */
  180.     b[];                /*    right hand side vector    */
  181. {    int kb,km1,nm1,kp1,i,k,m,*ix;
  182.     double t;
  183.     ix=malloc(n*sizeof(int));
  184.     if(ix==0)
  185.         {fprintf(stderr,"solve: no workspace\n");
  186.         exit(1);
  187.         }
  188.     for (i=0; i<n; i++)
  189.         ix[i]=i*ndim;
  190. /*    showm("solve: decomposed matrix",a,ndim,n);
  191.     showv("solve: RHS",b,n);                        */
  192.             /*    forward elimination    */
  193.     if(n!=1)
  194.         {nm1=n-1;
  195.         for (k=0; k<nm1; k++)
  196.             {kp1=k+1;
  197.             m=ipvt[k];
  198.             t=b[m];
  199.             b[m]=b[k];
  200.             b[k]=t;
  201.             for (i=kp1; i<n; i++)
  202.                 b[i] += aa(i,k)*t;
  203.             }
  204. /*        showv("\nafter forward elimination",b,n);    */
  205.                     /*    back substitution */
  206.         for (k=nm1; k; k--)
  207.             {b[k] /= aa(k,k);
  208.             t=-b[k];
  209.             for (i=0; i<k; i++)
  210.                 b[i] += aa(i,k)*t;
  211.             }
  212.         }
  213.     b[0] /= a[0];
  214. /*    showv("\nafter back substitution",b,n);        */
  215.     free(ix);
  216. }
  217.  
  218. #ifdef TEST
  219.  
  220. /*    sample program for decomp and solve */
  221.  
  222. main()
  223. {    double x[13][13], p[13][13], a[13][13], b[13], cond, condp1;
  224.     int    ipvt[13], i, j, k, n, ndim;
  225.  
  226.     ndim=13;
  227.     n=3;
  228.     a[0][0]=10.;
  229.     a[1][0]=-3.;
  230.     a[2][0]= 5.;
  231.     a[0][1]=-7.;
  232.     a[1][1]= 2.;
  233.     a[2][1]=-1.;
  234.     a[0][2]= 0.;
  235.     a[1][2]= 6.;
  236.     a[2][2]= 5.;
  237.     for (i=0; i<n; i++)
  238.         {for (j=0; j<n; j++)
  239.             printf("%8.4f ",a[i][j]);
  240.         printf("\n");
  241.         }
  242.     printf("\n");
  243.     cond=decomp(ndim,n,a,ipvt);
  244.     printf("cond = %f\n\n",cond);
  245.     printf("\nipvt\n");
  246.     for (i=0; i<n; i++)
  247.         printf("%8d \n",ipvt[i]);
  248.     condp1=cond+1.;
  249.     if(condp1==cond) exit();
  250.     b[0]=7.;
  251.     b[1]=4.;
  252.     b[2]=6.;
  253.     showv("RHS",b,n);
  254.     solve(ndim,n,a,b,ipvt);
  255.     showv("solution",b,n);
  256.     a[0][0]=10.;
  257.     a[1][0]=-3.;
  258.     a[2][0]= 5.;
  259.     a[0][1]=-7.;
  260.     a[1][1]= 2.;
  261.     a[2][1]=-1.;
  262.     a[0][2]= 0.;
  263.     a[1][2]= 6.;
  264.     a[2][2]= 5.;
  265.     showm("a, matrix to be inverted",a,ndim,n);
  266.     cond=invert(ndim,n,a,x);
  267.     printf("cond = %f\n\n",cond);
  268.     printf("ipvt\n");
  269.     for (i=0; i<n; i++)
  270.         printf("%8d \n",ipvt[i]);
  271.     condp1=cond+1.;
  272.     if(condp1==cond) exit();
  273.     showm("x, inverse",x,ndim,n);
  274.     a[0][0]=10.;
  275.     a[1][0]=-3.;
  276.     a[2][0]= 5.;
  277.     a[0][1]=-7.;
  278.     a[1][1]= 2.;
  279.     a[2][1]=-1.;
  280.     a[0][2]= 0.;
  281.     a[1][2]= 6.;
  282.     a[2][2]= 5.;
  283.     for (k=0; k<n; k++)
  284.         for (j=0; j<n; j++)
  285.             {p[k][j]=0.;
  286.             for (i=0; i<n; i++)
  287.                 p[k][j] += a[k][i]*x[i][j];
  288.             }
  289.     showm("a*x",p,ndim,n);
  290.     for (k=0; k<n; k++)
  291.         for (j=0; j<n; j++)
  292.             {p[k][j]=0.;
  293.             for (i=0; i<n; i++)
  294.                 p[k][j] += x[k][i]*a[i][j];
  295.             }
  296.     showm("x*a",p,ndim,n);
  297. }
  298.  
  299. showv(s,a,n) char *s; double a[]; int n;
  300. {    int i,j;
  301.     printf("%s\n",s);
  302.     for (i=0; i<n; i++)
  303.         printf("%8.4f \n",a[i]);
  304.     printf("\n");
  305. }
  306.  
  307. showm(s,a,ndim,n) char *s; double a[]; int ndim,n;
  308. {    int i,j;
  309.     printf("%s\n",s);
  310.     for (i=0; i<n; i++)
  311.         {for (j=0; j<n; j++)
  312.             printf("%8.4f ",a[i*ndim+j]);
  313.         printf("\n");
  314.         }
  315.     printf("\n");
  316. }
  317. #endif
  318.