home *** CD-ROM | disk | FTP | other *** search
/ ProfitPress Mega CDROM2 …eeware (MSDOS)(1992)(Eng) / ProfitPress-MegaCDROM2.B6I / MAGAZINE / MICROCOR / ISSUE_51.ZIP / BATCHNET.C < prev    next >
Encoding:
Text File  |  1989-11-13  |  16.8 KB  |  486 lines

  1. /*  This code supports an article in issue #51 of:
  2.  
  3.     Micro Cornucopia Magazine
  4.     P.O. Box 223
  5.     Bend, OR 97709
  6. */
  7. /* Figure 3 -- batchnet.c  */
  8.  
  9.    Generic back-propagation neural network
  10.  
  11.    Copyright (c)  1988, 1989  R.W.Dobbins and R.C.Eberhart
  12.    All Rights Reserved
  13.  
  14.    $Revision:   1.1  $            $Date:   21 Sep 1989 11:35:06  $
  15. */
  16.  
  17. #include <stdio.h>
  18. #include <stdlib.h>
  19. #include <math.h>
  20. #include <conio.h>
  21. #include <ctype.h>
  22. #include <string.h>
  23.  
  24. #define   ESC          27
  25. #define   ERRORLEVEL   0.04
  26. #define   ITEMS        8
  27.  
  28. /* typedefs and prototypes for dynamic storage of arrays */
  29. typedef float *PFLOAT;
  30. typedef PFLOAT VECTOR;
  31. typedef PFLOAT *MATRIX;
  32.  
  33. void  VectorAllocate(VECTOR *vector, int nCols);
  34. void  AllocateCols(PFLOAT matrix[], int nRows, int nCols);
  35. void  MatrixAllocate(MATRIX *pmatrix, int nRows, int nCols);
  36. void  MatrixFree(MATRIX matrix,  int nRows);
  37.  
  38. /* define storage for net layers */
  39. /* Arrays for inputs, outputs, deltas, weights & targets */
  40. MATRIX   out0;        /* input layer  */
  41. MATRIX   out1;        /* hidden layer */
  42. MATRIX   delta1;      /* delta at hidden layer  */
  43. MATRIX   delw1;       /* change in weights input:hidden */
  44. MATRIX   w1;          /* weights input:hidden */
  45. MATRIX   out2;        /* output layer */
  46. MATRIX   delta2;      /* delta at output layer  */
  47. MATRIX   delw2;       /* change in weights hidden:output */
  48. MATRIX   w2;          /* weights hidden:output */
  49. MATRIX   target;      /* target output */
  50. VECTOR   PatternID;   /* identifier for each stored pattern */
  51.  
  52.  
  53. void  main(int argc, char *argv[])
  54. {
  55.    float eta   =  0.15,        /* default learning rate             */
  56.          alpha =  0.075;       /* default momentum factor           */
  57.    int   nReportErrors = 100;  /* error reporting frequency         */
  58.    float ErrorLevel = ERRORLEVEL; /* satisfactory error level       */
  59.    char  MonitorError = 0;     /* true when monitor error display   */
  60.    float error;                /* latest sum squared error value    */
  61.    register int   h;           /* index hidden layer                */
  62.    register int   i;           /* index input layer                 */
  63.    register int   j;           /* index output layer                */
  64.    int   p,                    /* index pattern number              */
  65.          q,                    /* index iteration number            */
  66.          r,                    /* index run number                  */
  67.          nPatterns,            /* number of patterns desired        */
  68.          nInputNodes,          /* number of input nodes             */
  69.          nHiddenNodes,         /* number of hidden nodes            */
  70.          nOutputNodes,         /* number of output nodes            */
  71.          nIterations,          /* number of iterations desired      */
  72.          nRuns;                /* number of runs (or input lines)   */
  73.    FILE  *fpRun,               /* run file                          */
  74.          *fpPattern,           /* source pattern input file         */
  75.          *fpWeights,           /* initial weight file               */
  76.          *fpWeightsOut,        /* final weight output file          */
  77.          *fpResults,           /* results output file               */
  78.          *fpError;             /* error output file                 */
  79.    char  szResults[66];        /* various filenames (pathnames)     */
  80.    char  szError[66];
  81.    char  szPattern[66];
  82.    char  szWeights[66];
  83.    char  szWeightsOut[66];
  84.    char  *progname  =  *argv;  /* name of executable DOS 3.x only  */
  85.  
  86.    /* read optional - arguments */
  87.    for (; argc > 1;  argc--)
  88.    {
  89.       char *arg = *++argv;
  90.  
  91.       if (*arg  !=  '-')
  92.          break;
  93.  
  94.       switch (*++arg)
  95.       {
  96.          case 'e':   sscanf(++arg,  "%d",  &nReportErrors);   break;
  97.          case 'd':   sscanf(++arg,  "%f",  &ErrorLevel);      break;
  98.          default:    break;
  99.       }
  100.    }
  101.  
  102.    if (argc < 2)
  103.    {
  104.       fprintf(stderr, "Usage:  %s {-en -df} runfilename\n",  progname);
  105.       fprintf(stderr, "   -en   =>  report error every n iterations\n");
  106.       fprintf(stderr, "   -df   =>  done if sum squared error < f\n");
  107.       exit(1);
  108.    }
  109.  
  110.    /* Open run file for reading */
  111.    if ((fpRun = fopen(*argv, "r"))   ==   NULL)
  112.    {
  113.       fprintf(stderr, "%s: can't open file %s\n", progname, *argv);
  114.       exit(1);
  115.    }
  116.  
  117.    /* Read first line: no. of runs (lines to read from run file) */
  118.    fscanf(fpRun,  "%d",  &nRuns);
  119.  
  120.    /*--------------------- beginning of work loop -------------------------*/
  121.    for (r = 0;   r < nRuns;   r++)
  122.    {
  123.       /* read and parse the run specification line; */
  124.       fscanf(fpRun,
  125.           "%s %s %s %s %s %d %d %d %d %d %f %f",
  126.           szResults,          /* output results file */
  127.           szError,            /* error output file */
  128.           szPattern,          /* pattern input file */
  129.           szWeights,          /* initial weights file */
  130.           szWeightsOut,       /* final weights output file */
  131.           &nPatterns,         /* number of patterns to learn */
  132.           &nIterations,       /* number of iterations through the data */
  133.           &nInputNodes,       /* number of input nodes  */
  134.           &nHiddenNodes,      /* number of hidden nodes */
  135.           &nOutputNodes,      /* number of output nodes */
  136.           &eta,               /* learning rate */
  137.           &alpha);            /* momentum factor */
  138.  
  139.       /*----------allocate dynamic storage for all data ---------------*/
  140.       MatrixAllocate(&out0,      nPatterns,    nInputNodes);
  141.       MatrixAllocate(&out1,      nPatterns,    nHiddenNodes);
  142.       MatrixAllocate(&out2,      nPatterns,    nOutputNodes);
  143.       MatrixAllocate(&delta2,    nPatterns,    nOutputNodes);
  144.       MatrixAllocate(&delw2,     nOutputNodes, nHiddenNodes + 1);
  145.       MatrixAllocate(&w2,        nOutputNodes, nHiddenNodes + 1);
  146.       MatrixAllocate(&delta1,    nPatterns,    nHiddenNodes);
  147.       MatrixAllocate(&delw1,     nHiddenNodes, nInputNodes + 1);
  148.       MatrixAllocate(&w1,        nHiddenNodes, nInputNodes + 1);
  149.       MatrixAllocate(&target,    nPatterns,    nOutputNodes);
  150.       VectorAllocate(&PatternID, nPatterns);
  151.  
  152.       /*--------- Read the initial weight matrices: -------------------*/
  153.       if ((fpWeights = fopen(szWeights,"r"))  ==  NULL)
  154.       {
  155.          fprintf(stderr,  "%s: can't open file %s\n",  progname, szWeights);
  156.          exit(1);
  157.       }
  158.  
  159.       /* read input:hidden weights */
  160.       for (h = 0;  h < nHiddenNodes;  h++)
  161.          for (i = 0;  i <= nInputNodes;  i++)
  162.          {
  163.             fscanf(fpWeights,  "%f",      &w1[h][i]);
  164.             delw1[h][i] = 0.0;
  165.          }
  166.  
  167.       /* read hidden:out weights */
  168.       for (j = 0;  j < nOutputNodes;  j++)
  169.          for (h = 0;  h <= nHiddenNodes;  h++)
  170.          {
  171.             fscanf(fpWeights,  "%f",      &w2[j][h]);
  172.             delw2[j][h] = 0.0;
  173.          }
  174.  
  175.       fclose(fpWeights);
  176.  
  177.       /*------------ Read in all patterns to be learned:----------------*/
  178.       if ((fpPattern = fopen(szPattern, "r"))  ==  NULL)
  179.       {
  180.          fprintf(stderr,  "%s: can't open file %s\n",  progname, szPattern);
  181.          exit(1);
  182.       }
  183.  
  184.       for (p = 0;  p < nPatterns;  p++)
  185.       {
  186.          for (i = 0;   i < nInputNodes;   i++)
  187.             if (fscanf(fpPattern,  "%f",   &out0[p][i])  != 1)
  188.                goto  ALLPATTERNSREAD;
  189.  
  190.  
  191.          /* read in target outputs for input patterns read */
  192.          for (j = 0;  j < nOutputNodes;  j++)
  193.             fscanf(fpPattern,  "%f",   &target[p][j]);
  194.  
  195.          /* read in identifier for each pattern */
  196.          fscanf(fpPattern,  "%f ",   &PatternID[p]);
  197.       }
  198.  
  199.       ALLPATTERNSREAD:
  200.       fclose(fpPattern);
  201.  
  202.       if (p < nPatterns)
  203.       {
  204.          fprintf(stderr, "%s:  %d out of %d patterns read\n",
  205.                  progname,  p,  nPatterns);
  206.          nPatterns = p;
  207.       }
  208.  
  209.       /* open error output file */
  210.       if ((fpError = fopen(szError, "w"))  ==  NULL)
  211.       {
  212.          fprintf(stderr,  "%s: can't open file %s\n",  progname, szError);
  213.          exit(1);
  214.       }
  215.  
  216.       fprintf(stderr,  nIterations > 1  ?  "Training...\n"  :  "Testing\n");
  217.  
  218.       /*--------------------- begin iteration loop ------------------------*/
  219.       for (q = 0;  q < nIterations;  q++)
  220.       {
  221.          for (p = 0;  p < nPatterns;  p++)
  222.          {
  223.             /*-------------------- hidden layer --------------------------*/
  224.             /* Sum input to hidden layer over all
  225.                 input-weight combinations */
  226.             for (h = 0;  h < nHiddenNodes;  h++)
  227.             {
  228.                float sum = w1[h][nInputNodes];  /* begin with bias  */
  229.  
  230.                for (i = 0;  i < nInputNodes;  i++)
  231.                   sum   +=   w1[h][i]  *  out0[p][i];
  232.  
  233.                /* Compute output (use sigmoid) */
  234.                out1[p][h]   =   1.0  /  (1.0  +  exp(-sum));
  235.             }
  236.  
  237.             /*-------------------- output layer --------------------------*/
  238.             for (j = 0;  j < nOutputNodes;  j++)
  239.             {
  240.                float  sum = w2[j][nHiddenNodes];
  241.  
  242.                for (h = 0;  h < nHiddenNodes;  h++)
  243.                   sum  +=   w2[j][h]  *  out1[p][h];
  244.  
  245.                out2[p][j]  =  1.0  /  (1.0  +  exp(-sum));
  246.             }
  247.  
  248.             /*-------------------- delta output --------------------------*/
  249.             /* Compute deltas for each output unit for a given pattern */
  250.             for (j = 0;  j < nOutputNodes;  j++)
  251.                delta2[p][j]   =   (target[p][j] - out2[p][j])  *
  252.                                 out2[p][j]   *   (1.0 - out2[p][j]);
  253.  
  254.  
  255.             /*-------------------- delta hidden --------------------------*/
  256.             for (h = 0;  h < nHiddenNodes;  h++)
  257.             {
  258.                float  sum = 0.0;
  259.  
  260.                for (j = 0;  j < nOutputNodes;  j++)
  261.                   sum  +=  delta2[p][j] * w2[j][h];
  262.  
  263.                delta1[p][h]  =  sum  *  out1[p][h]  *  (1.0 - out1[p][h]);
  264.             }
  265.          }
  266.  
  267.          /*-------------- adapt weights hidden:output ---------------------*/
  268.          for (j = 0;  j < nOutputNodes;  j++)
  269.          {
  270.             float  dw;                  /* delta weight */
  271.             float  sum = 0.0;
  272.  
  273.             /* grand sum of deltas for each output node for one epoch */
  274.             for (p = 0;  p < nPatterns;  p++)
  275.                sum  +=  delta2[p][j];
  276.  
  277.             /* Calculate new bias weight for each output unit */
  278.             dw   =   eta * sum  +  alpha * delw2[j][nHiddenNodes];
  279.             w2[j][nHiddenNodes]   +=   dw;
  280.             delw2[j][nHiddenNodes] =   dw;     /* delta for bias */
  281.  
  282.             /* Calculate new weights */
  283.             for (h = 0;  h < nHiddenNodes;  h++)
  284.             {
  285.                float  sum = 0.0;
  286.  
  287.                for (p = 0;  p < nPatterns;  p++)
  288.                   sum  +=  delta2[p][j] * out1[p][h];
  289.  
  290.                dw           =   eta * sum  +  alpha * delw2[j][h];
  291.                w2[j][h]     +=  dw;
  292.                delw2[j][h]  =   dw;
  293.             }
  294.          }
  295.  
  296.          /*-------------------- adapt weights input:hidden -----------------*/
  297.          for (h = 0;  h < nHiddenNodes;  h++)
  298.          {
  299.             float  dw;                  /* delta weight */
  300.             float  sum = 0.0;
  301.  
  302.             for (p = 0;  p < nPatterns;  p++)
  303.                sum  +=  delta1[p][h];
  304.  
  305.             /* Calculate new bias weight for each hidden unit */
  306.             dw   =   eta * sum  +  alpha * delw1[h][nInputNodes];
  307.             w1[h][nInputNodes]   +=   dw;
  308.             delw1[h][nInputNodes] =   dw;
  309.  
  310.             /* Calculate new weights */
  311.             for (i = 0;  i < nInputNodes;  i++)
  312.             {
  313.                float  sum = 0.0;
  314.  
  315.                for (p = 0;  p < nPatterns;  p++)
  316.                   sum  +=  delta1[p][h] * out0[p][i];
  317.  
  318.                dw           =   eta * sum  +  alpha * delw1[h][i];
  319.                w1[h][i]     +=  dw;
  320.                delw1[h][i]  =   dw;
  321.             }
  322.          }
  323.  
  324.          /* -------------- monitor keyboard requests ---------------------*/
  325.          if (kbhit())
  326.          {
  327.              int    c = getch();
  328.  
  329.              if ((c = toupper(c))  == 'E')
  330.                 MonitorError++;
  331.              else if (c == ESC)
  332.                 break;              /* Terminate gracefully on quit key */
  333.          }
  334.  
  335.          /*-------------------- Sum Squared Error ------------------------*/
  336.          if (MonitorError  ||  (q % nReportErrors   ==   0))
  337.          {
  338.             for (p = 0, error = 0.0;   p < nPatterns;   p++)
  339.             {
  340.                for (j = 0;  j < nOutputNodes;  j++)
  341.                {
  342.                   float  temp   =   target[p][j] - out2[p][j];
  343.  
  344.                   error += temp * temp;
  345.                }
  346.             }
  347.  
  348.             /* Average error over all patterns */
  349.             error  /=  nPatterns;
  350.  
  351.             /* Print iteration number and  error value */
  352.             fprintf(stderr,"Iteration %5d/%-5d  Error %f\r",
  353.                     q, nIterations, error);               /* to console */
  354.             MonitorError = 0;
  355.  
  356.             if (q % nReportErrors   ==   0)
  357.                fprintf(fpError, "%d  %f\n",  q,  error);  /* to file */
  358.  
  359.             /* Terminate when error satisfactory */
  360.             if (error < ErrorLevel)
  361.                break;
  362.          }
  363.       }
  364.       /* --------------- end of iteration loop -----------------------------*/
  365.  
  366.       for (p = 0, error = 0.0;  p < nPatterns;  p++)
  367.       {
  368.          for (j = 0;  j < nOutputNodes;  j++)
  369.          {
  370.                float  temp   =   target[p][j] - out2[p][j];
  371.  
  372.                error += temp * temp;
  373.          }
  374.       }
  375.  
  376.       /* Average error over all patterns */
  377.       error  /=  nPatterns;
  378.  
  379.       /* Print final iteration number and error value */
  380.       fprintf(stderr, "Iteration %5d/%-5d  Error %f\n", q, nIterations, error); /* to console */
  381.       fprintf(fpError, "\n%d  %f\n",  q,  error);        /* to file */
  382.       fclose(fpError);
  383.  
  384.       /*---------------- print final weights -------------------------------*/
  385.       if ((fpWeightsOut = fopen(szWeightsOut,"w"))  ==  NULL)
  386.       {
  387.          fprintf(stderr,  "%s: can't write file %s\n",  progname, szWeightsOut);
  388.          exit(1);
  389.       }
  390.  
  391.       for (h = 0;  h < nHiddenNodes;  h++)
  392.          for (i = 0;  i <= nInputNodes;  i++)
  393.             fprintf(fpWeightsOut,  "%g%c", w1[h][i], i%ITEMS==ITEMS-1 ? '\n':' ');
  394.  
  395.       for (j = 0;  j < nOutputNodes;  j++)
  396.          for (h = 0;  h <= nHiddenNodes;  h++)
  397.             fprintf(fpWeightsOut,  "%g%c", w2[j][h], j%ITEMS==ITEMS-1 ? '\n':' ');
  398.  
  399.       fclose(fpWeightsOut);
  400.  
  401.       /*----------------- Print final activation values-------------------- */
  402.       if ((fpResults = fopen(szResults,"w"))  ==  NULL)
  403.       {
  404.          fprintf(stderr,  "%s: can't write file %s\n",  progname, szResults);
  405.          fpResults = stderr;
  406.       }
  407.  
  408.       /* Print final output vector */
  409.       for (p = 0;  p < nPatterns;  p++)
  410.       {
  411.          fprintf(fpResults, "%d   ",  p);
  412.  
  413.          for (j = 0;  j < nOutputNodes;  j++)
  414.             fprintf(fpResults, " %f",  out2[p][j]);
  415.  
  416.          fprintf(fpResults, "  %-6.0f\n", PatternID[p]);
  417.       }
  418.  
  419.       fclose(fpResults);
  420.  
  421.       /*---------------- free dynamic storage for data ---------------------*/
  422.       MatrixFree(out0,      nPatterns);
  423.       MatrixFree(out1,      nPatterns);
  424.       MatrixFree(delta1,    nPatterns);
  425.       MatrixFree(delw1,     nHiddenNodes);
  426.       MatrixFree(w1,        nHiddenNodes);
  427.       MatrixFree(out2,      nPatterns);
  428.       MatrixFree(delta2,    nPatterns);
  429.       MatrixFree(delw2,     nOutputNodes);
  430.       MatrixFree(w2,        nOutputNodes);
  431.       MatrixFree(target,    nPatterns);
  432.       free(PatternID);
  433.    }
  434.  
  435.    fclose(fpRun);                     /* close run file */
  436. }
  437.  
  438.  
  439. /*----------------- Array storage allocation routines ---------------------*/
  440. /* Allocate space for vector of float cells for
  441.    one dimensional dynamic vector[cols]
  442. */
  443. void VectorAllocate(VECTOR *vector, int nCols)
  444. {
  445.    if ((*vector = (VECTOR) calloc(nCols, sizeof(float))) == NULL)
  446.    {
  447.       fprintf(stderr, "Sorry! Not enough memory for nodes\n");
  448.       exit(1);
  449.    }
  450. }
  451.  
  452.  
  453. /* Allocate space for columns (float cells) for
  454.    dynamic two dimensional matrix[rows][cols]
  455. */
  456. void AllocateCols(PFLOAT matrix[], int nRows, int nCols)
  457. {
  458.    int  i;
  459.  
  460.    for (i = 0;  i < nRows;  i++)
  461.       VectorAllocate(&matrix[i], nCols);
  462. }
  463.  
  464. /* Allocate space for a two dimensional dynamic matrix [rows] [cols]
  465. */
  466.  
  467. void MatrixAllocate(MATRIX *pmatrix, int nRows, int nCols)
  468. {
  469.    if ( (*pmatrix  =  (MATRIX) calloc(nRows,  sizeof(PFLOAT) ) )   ==  NULL)
  470.    {
  471.       fprintf(stderr, "Sorry! Not enough memory for nodes\n");
  472.       exit(1);
  473.    }
  474.  
  475.    AllocateCols(*pmatrix, nRows, nCols);
  476. }
  477.  
  478. /* free space for two dimensional dynamic array */
  479. void MatrixFree(MATRIX matrix,  int nRows)
  480. {
  481.    int   i;
  482.    for (i = 0;  i < nRows;  i++)
  483.       free(matrix[i]);
  484.    free(matrix);
  485. }
  486.