home *** CD-ROM | disk | FTP | other *** search
/ Shareware Overload / ShartewareOverload.cdr / virus / ddj0491.zip / NEURLNET.ASC < prev    next >
Text File  |  1991-03-15  |  32KB  |  1,094 lines

  1. _NEURAL NETS TELL WHY_
  2. by Casimir C. "Casey" Klimasauskas
  3.  
  4. [LISTING ONE]
  5.  
  6. /* network.c --  Backprop network with explain function */
  7.  
  8. /************************************************************************
  9.  *                                    *
  10.  *    Explain How a Neural Network "thinks" using Sensitivity        *
  11.  *         Analysis                            *
  12.  *                                    *
  13.  ************************************************************************
  14.     This is a program designed to implement a back-propagation network
  15.     and show how sensitivity analysis works to "explain" the network's
  16.     reasoning.
  17.  */
  18.  
  19. #include <stdio.h>    /* file & printer support */
  20. #include <stdlib.h>    /* malloc, rand, RAND_MAX & other things */
  21. #include <math.h>    /* exp() */
  22. #include <dos.h>    /* FP_SEG(), FP_OFF() */
  23.  
  24. #define    MAX_LAYER    4        /* maximum number of layers */
  25. #define    MAX_PES        50        /* maximum number of PEs */
  26.  
  27. /* --- locally required defines --- */
  28.  
  29. #define    WORD(x)        (((short *)(&x))[0])
  30. #define    MAX(x,y)    ((x)>(y)?(x):(y))
  31. typedef float        WREAL;        /* work real */
  32.  
  33. /* --- Connection structure --- */
  34.  
  35. typedef struct _conn {            /* connection */
  36.     struct _pe    *SourceP;    /* source pointer */
  37.     WREAL         WtValR;    /* weight value */
  38.     WREAL         DWtValR;    /* delta weight value */
  39. } CONN;
  40.  
  41. typedef struct _pe {            /* processing element */
  42.     struct _pe    *NextPEP;    /* next PE in layer */
  43.     WREAL         OutputR;    /* output of PE */
  44.     WREAL         ErrorR;    /* work area for error */
  45.     WREAL         WorkR;        /* work area for explain */
  46.     int         PEIndexN;    /* PE index (ordinal) */
  47.     int         MaxConnsN;    /* maximum number of connections */
  48.     int         NConnsN;    /* # of connections used */
  49.     CONN         ConnsS[1];    /* connections to this PE */
  50. } PE;
  51.  
  52. PE    *LayerTP[MAX_LAYER+1]    = {0};        /* pointer to PEs in layer */
  53. int     LayerNI[MAX_LAYER+1]    = {0};        /* # of items in each layer */
  54.  
  55. PE    *PEIndexP[MAX_PES]    = {0};        /* index into PEs */
  56. int     NextPEXN        = {0};        /* index of next free PE */
  57.  
  58. PE    PEBias = { 0, 1.0, 0.0, 0, };        /* "Bias" PE */
  59.  
  60.  
  61. /************************************************************************
  62.  *                                    *
  63.  *    RRandR() - compute uniform random number over a range        *
  64.  *                                    *
  65.  ************************************************************************
  66.  */
  67.  
  68. double RRandR( vR )        /* random value over a range */
  69. double         vR;        /* range magnitude */
  70. {
  71.     double     rvR;        /* return value */
  72.  
  73.     /* compute random value in range 0..1 */
  74.     rvR = ((double)rand()) / (double)RAND_MAX;
  75.  
  76.     /* rescale to range -vR..vR */
  77.     rvR = vR * (rvR + rvR - 1.0);
  78.  
  79.     return( rvR );
  80. }
  81.  
  82. /************************************************************************
  83.  *                                    *
  84.  *    AllocPE() - Allocate a PE dynamically                *
  85.  *                                    *
  86.  ************************************************************************
  87.  */
  88.  
  89. PE *AllocPE( peXN, MaxConnsN )        /* allocate A PE dynamically */
  90. int         peXN;        /* index of PE (0=auto) */
  91. int         MaxConnsN;    /* max number of connections */
  92. {
  93.     PE        *peP;        /* pointer to PE allocated */
  94.     int         AlcSize;    /* size to allocate */
  95.  
  96.     if ( NextPEXN == 0 ) {
  97.     PEIndexP[0] = &PEBias;            /* bias PE */
  98.     NextPEXN++;
  99.     }
  100.  
  101.     if ( peXN == 0 )
  102.     peXN = NextPEXN++;
  103.     else if ( peXN >= NextPEXN )
  104.     NextPEXN = peXN+1;
  105.  
  106.     if ( peXN < 0 || MAX_PES <= peXN ) {
  107.     printf( "Illegal PE number to allocate: %d\n", peXN );
  108.     exit( 1 );
  109.     }
  110.  
  111.     if ( PEIndexP[peXN] != (PE *)0 ) {
  112.     printf( "PE number %d is already in use\n", peXN );
  113.     exit( 1 );
  114.     }
  115.  
  116.     AlcSize = sizeof(PE) + MaxConnsN*sizeof(CONN);
  117.     peP = (PE *)malloc( AlcSize );
  118.     if ( peP == (PE *)0 ) {
  119.     printf( "Could not allocate %d bytes for PE number %d\n",
  120.         AlcSize, peXN );
  121.     exit( 1 );
  122.     }
  123.  
  124.     memset( (char *)peP, 0, AlcSize );
  125.     peP->MaxConnsN = MaxConnsN+1;    /* max number of connections */
  126.     peP->PEIndexN  = peXN;        /* self index for load/save */
  127.     PEIndexP[peXN] = peP;        /* key for later */
  128.  
  129.     return( peP );
  130. }
  131.  
  132.  
  133. /************************************************************************
  134.  *                                    *
  135.  *    AllocLayer() - Dynamically allocate PEs in a layer        *
  136.  *                                    *
  137.  ************************************************************************
  138.  */
  139.  
  140. int AllocLayer( LayN, NPEsN, NConnPPEN )    /* allocate a layer */
  141. int         LayN;            /* layer number */
  142. int         NPEsN;            /* # of PEs in layer */
  143. int         NConnPPEN;        /* # of connections per PE */
  144. {
  145.     PE        *peP;            /* PE Pointer */
  146.     PE        *apeP;            /* alternate PE pointer */
  147.     int         wxN;            /* general counter */
  148.  
  149.     /* Sanity check */
  150.     if ( LayN < 0 || sizeof(LayerTP)/sizeof(LayerTP[0]) <= LayN ) {
  151.     printf( "Layer nubmer (%d) is out of range\n", LayN );
  152.     exit( 1 );
  153.     }
  154.  
  155.     /* Allocate PEs in the layer & link them together */
  156.     LayerNI[LayN] = NPEsN;
  157.     for( wxN = 0; wxN < NPEsN; wxN++, apeP = peP ) {
  158.     peP  = AllocPE( 0, NConnPPEN+1 );    /* allocate next PE */
  159.     if ( LayerTP[LayN] == (PE *)0 ) {
  160.         LayerTP[LayN] = peP;    /* insert table pionter */
  161.     } else {
  162.         apeP->NextPEP = peP;    /* link forward */
  163.     }
  164.     }
  165.  
  166.     return( 0 );
  167. }
  168.  
  169. /************************************************************************
  170.  *                                    *
  171.  *    FreeNet() - Free network memory                    *
  172.  *                                    *
  173.  ************************************************************************
  174.  */
  175.  
  176. int FreeNet()
  177. {
  178.     int         wxN;        /* work index */
  179.     char    *P;        /* work pointer */
  180.  
  181.     for( wxN = 1; wxN < MAX_PES; wxN++ ) {
  182.     if ( (P = (char *)PEIndexP[wxN]) != (char *)0 )
  183.         free( P );
  184.     PEIndexP[wxN] = (PE *)0;
  185.     }
  186.     NextPEXN = 0;
  187.  
  188.     for( wxN = 0; wxN < MAX_LAYER; wxN++ ) {
  189.     LayerTP[wxN] = (PE *)0;
  190.     LayerNI[wxN] = 0;
  191.     }
  192.  
  193.     return( 0 );
  194. }
  195.  
  196. /************************************************************************
  197.  *                                    *
  198.  *    SaveNet() - Save Network                    *
  199.  *                                    *
  200.  ************************************************************************
  201.  */
  202.  
  203. int SaveNet( fnP )        /* save network */
  204. char        *fnP;        /* name of file to save */
  205. {
  206.     int         wxN;        /* work index */
  207.     FILE    *fP;        /* file pointer */
  208.     PE        *peP;        /* PE pointer for save */
  209.     CONN    *cP;        /* connection pointer */
  210.     int         ConnXN;    /* connection index */
  211.  
  212.     if ( NextPEXN <= 1 )
  213.     return( 0 );            /* nothing to do */
  214.  
  215.     if ( (fP = fopen(fnP, "w")) == (FILE *)0 ) {
  216.     printf( "Could not open output file <%s>\n", fnP );
  217.     return( -1 );
  218.     }
  219.  
  220.     /* --- save all of the PEs --- */
  221.     for( wxN = 1; wxN < NextPEXN; wxN++ ) {
  222.     peP = PEIndexP[wxN];
  223.     if ( peP == (PE *)0 ) {
  224.         fprintf( fP, "%d 0 PE\n", wxN );
  225.         continue;
  226.     }
  227.     fprintf( fP, "%d %d  PE\n", wxN, peP->NConnsN );
  228.     for( ConnXN = 0; ConnXN < peP->NConnsN; ConnXN++ ) {
  229.         cP = &peP->ConnsS[ConnXN];
  230.         fprintf( fP, "%d %.6f %.6f\n",
  231.         cP->SourceP->PEIndexN, cP->WtValR, cP->DWtValR );
  232.     }
  233.     }
  234.     fprintf( fP, "%d %d END OF PES\n", -1, 0 );
  235.  
  236.     /* --- save information about how layers are assembled --- */
  237.     for( wxN = 0; wxN < MAX_LAYER; wxN++ ) {
  238.     if ( (peP = LayerTP[wxN]) == (PE *)0 )
  239.         continue;
  240.     fprintf( fP, "%d LAYER\n", wxN );
  241.     do {
  242.         fprintf( fP, "%d\n", peP->PEIndexN );
  243.         peP = peP->NextPEP;
  244.     } while( peP != (PE *)0 );
  245.     fprintf( fP, "-1 End Layer\n" );    /* end of layer */
  246.     }
  247.     fprintf( fP, "-1\n" );        /* no more layers */
  248.  
  249.     fclose( fP );
  250.  
  251.     return( 0 );
  252. }
  253.  
  254. /************************************************************************
  255.  *                                    *
  256.  *    LoadNet() - Load Network                    *
  257.  *                                    *
  258.  ************************************************************************
  259.  */
  260.  
  261. int LoadNet( fnP )        /* load a network file */
  262. char        *fnP;        /* file name pointer */
  263. {
  264.     int         wxN;        /* work index */
  265.     FILE    *fP;        /* file pointer */
  266.     PE        *peP;        /* PE pointer for save */
  267.     PE         *lpeP;        /* last pe in chain */
  268.     int         LayN;        /* layer number */
  269.     CONN    *cP;        /* connection pointer */
  270.     int         ConnXN;    /* connection index */
  271.     int         PEN;        /* PE number */
  272.     int         PENConnsN;    /* # of connections */
  273.     int         StateN;    /* current state 0=PEs, 1=Layers */
  274.     float     WtR, DWtR;    /* weight & delta weight */
  275.     char     BufC[80];    /* work buffer */
  276.  
  277.     fP = (FILE *)0;
  278.     if ( (fP = fopen( fnP, "r" )) == (FILE *)0 ) {
  279.     printf( "Could not open output file <%s>\n", fnP );
  280.     return( -1 );
  281.     }
  282.  
  283.     FreeNet();            /* release any existing network */
  284.  
  285.     StateN = 0;
  286.     while( fgets( BufC, sizeof(BufC)-1, fP ) != (char *)0 ) {
  287.     switch( StateN ) {
  288.     case 0:            /* PEs */
  289.         sscanf( BufC, "%d %d", &PEN, &PENConnsN );
  290.         if ( PEN < 0 ) {
  291.         StateN = 2;
  292.         break;
  293.         }
  294.  
  295.         peP = AllocPE( PEN, PENConnsN );    /* allocate PE */
  296.         cP  = &peP->ConnsS[0];        /* Pointer to Conns */
  297.         ConnXN = PENConnsN;
  298.         if ( ConnXN > 0 )
  299.         StateN = 1;            /* scanning for conns */
  300.         break;
  301.  
  302.     case 1:            /* PE Connections */
  303.         sscanf( BufC, "%d %f %f", &PEN, &WtR, &DWtR );
  304.         WORD(cP->SourceP)    = PEN;
  305.         cP->WtValR        = WtR;
  306.         cP->DWtValR        = DWtR;
  307.         cP++;                /* next connection area */
  308.         peP->NConnsN++;            /* count connections */
  309.         if ( --ConnXN <= 0 )
  310.         StateN = 0;            /* back to looking for PEs */
  311.         break;
  312.  
  313.     case 2:            /* Layer data */
  314.         sscanf( BufC, "%d", &LayN );
  315.         StateN = 3;
  316.         if ( LayN < 0 )
  317.         goto Done;
  318.         lpeP = (PE *)&LayerTP[LayN];
  319.         break;
  320.  
  321.     case 3:            /* layer items */
  322.         sscanf( BufC, "%d", &PEN );
  323.         if ( PEN < 0 ) {
  324.         StateN = 2;
  325.         break;
  326.         }
  327.  
  328.         LayerNI[LayN]++;            /* update # of PEs */
  329.         peP = PEIndexP[PEN];        /* point to PE */
  330.         lpeP->NextPEP = peP;        /* forward chain */
  331.         lpeP = peP;
  332.         break;
  333.     }
  334.     }
  335.  
  336. Done:
  337.     /* go through ALL PEs and convert PE index to pointers */
  338.     for( wxN = 1; wxN < MAX_PES; wxN++ ) {
  339.     if ( (peP = PEIndexP[wxN]) == (PE *)0 )
  340.         continue;
  341.  
  342.     for( ConnXN = peP->NConnsN, cP = &peP->ConnsS[0];
  343.          --ConnXN >= 0;
  344.         cP++ ) {
  345.         cP->SourceP = PEIndexP[ WORD(cP->SourceP) ];
  346.     }
  347.     }
  348.  
  349.     if ( fP ) fclose( fP );
  350.     return( 0 );
  351.  
  352. ErrExit:
  353.     if ( fP ) fclose( fP );
  354.     FreeNet();
  355.     return( -1 );
  356. }
  357.  
  358. /************************************************************************
  359.  *                                    *
  360.  *    PrintNet() - Print out Network                    *
  361.  *                                    *
  362.  ************************************************************************
  363.  */
  364.  
  365. int PrintNet( fnP )        /* print out network */
  366. char        *fnP;        /* file to print to (append) */
  367. {
  368.     FILE    *fP;        /* file pointer */
  369.     PE        *dpeP;        /* destination PE */
  370.     int         layerXN;    /* layer index */
  371.     CONN    *cP;        /* connection pointer */
  372.     int         ConnXN;    /* connection index */
  373.  
  374.     if ( *fnP == '\0' ) {
  375.     fP = stdout;
  376.     } else {
  377.     if ( (fP = fopen( fnP, "a" )) == (FILE *)0 ) {
  378.         printf( "Could not open print output file <%s>\n", fnP );
  379.         return( -1 );
  380.     }
  381.     }
  382.  
  383.     for( layerXN = 0; (dpeP = LayerTP[layerXN]) != (PE *)0; layerXN++ ) {
  384.     fprintf( fP, "\nLayer %d with %d PEs\n", layerXN, LayerNI[layerXN] );
  385.  
  386.     for(; dpeP != (PE *)0; dpeP = dpeP->NextPEP ) {
  387.         fprintf( fP,
  388.         "  %2d %04x:%04x PE Output=%6.3f Error=%6.3f WorkR=%6.3f NConns=%d\n",
  389.         dpeP->PEIndexN, FP_SEG(dpeP), FP_OFF(dpeP),
  390.         dpeP->OutputR, dpeP->ErrorR, dpeP->WorkR, dpeP->NConnsN );
  391.         for( ConnXN = 0; ConnXN < dpeP->NConnsN; ConnXN++ ) {
  392.         cP = &dpeP->ConnsS[ConnXN];
  393.         fprintf( fP,
  394.             "    Src=%2d %04x:%04x Weight=%7.3f Delta Wt=%6.3f\n",
  395.             cP->SourceP->PEIndexN, 
  396.             FP_SEG(cP->SourceP), FP_OFF(cP->SourceP),
  397.             cP->WtValR, cP->DWtValR );
  398.         }
  399.     }
  400.     }
  401.  
  402.     if ( fP != stdout )
  403.     fclose( fP );
  404.     return( 0 );
  405. }
  406.  
  407. /************************************************************************
  408.  *                                    *
  409.  *    FullyConn() - Fully connect a source layer to a destination    *
  410.  *                                    *
  411.  ************************************************************************
  412.  */
  413.  
  414. int FullyConn( DLayN, SLayN, RangeR )
  415. int         DLayN;        /* destination layer */
  416. int         SLayN;        /* source layer */
  417. double         RangeR;    /* range magnitude */
  418. {
  419.     CONN    *cP;        /* connection pointer */
  420.     PE        *speP;        /* source PE pointer */
  421.     PE        *dpeP;        /* destination PE pointer */
  422.  
  423.     /* loop through each of the PEs in the destination layer */
  424.     for( dpeP = LayerTP[DLayN]; dpeP != (PE *)0; dpeP = dpeP->NextPEP ) {
  425.     cP = &dpeP->ConnsS[dpeP->NConnsN];    /* start of connections */
  426.     if ( dpeP->NConnsN == 0 ) {
  427.         /* insert bias PE as first one */
  428.         cP->SourceP = &PEBias;        /* bias PE */
  429.         cP->WtValR  = RRandR( RangeR );    /* initial weight */
  430.         cP->DWtValR = 0.0;
  431.         cP++;                /* account for this conn */
  432.         dpeP->NConnsN++;
  433.     }
  434.  
  435.     /* loop through all PEs in source layer & make connections */
  436.     for( speP = LayerTP[SLayN]; speP != (PE *)0; speP = speP->NextPEP ) {
  437.         cP->SourceP = speP;            /* point to PE */
  438.         cP->WtValR  = RRandR( RangeR );    /* initial weight */
  439.         cP->DWtValR = 0.0;
  440.         cP++;                /* account for this conn */
  441.         dpeP->NConnsN++;
  442.     }
  443.     }
  444.  
  445.     return( 0 );        /* layers fully connected */
  446. }
  447.  
  448. /************************************************************************
  449.  *                                    *
  450.  *    BuildNet() - Build all data structures for back-prop network    *
  451.  *                                    *
  452.  ************************************************************************
  453.  */
  454.  
  455. int BuildNet( NInpN, NHid1N, NHid2N, NOutN, ConnPrevF )
  456. int         NInpN;        /* # of input PEs */
  457. int         NHid1N;    /* # of hidden 1 PEs (zero if none) */
  458. int         NHid2N;    /* # of hidden 2 PEs (zero if none) */
  459. int         NOutN;        /* # of output PEs */
  460. int         ConnPrevF;    /* 1=connect to all prior layers */
  461. {
  462.     int         ReqdPEsN;    /* # of required PEs */
  463.     int         LayerXN;    /* layer index */
  464.     int         SLayN, DLayN;    /* source / destination layer indicies */
  465.  
  466.     if ( NInpN <= 0 || NOutN <= 0 )
  467.     return( -1 );            /* could not build ! */
  468.  
  469.     FreeNet();                    /* kill existing net */
  470.     ReqdPEsN = NInpN + NHid1N + NHid2N + NOutN;
  471.  
  472.     LayerXN = 0;                /* layer index */
  473.     AllocLayer( LayerXN, NInpN, 0 );        /* input layer */
  474.     if ( NHid1N > 0 ) {
  475.     LayerXN++;                /* next layer */
  476.     AllocLayer( LayerXN, NHid1N, NInpN );
  477.     if ( NHid2N > 0 ) {
  478.         LayerXN++;
  479.         AllocLayer( LayerXN, NHid2N, NHid1N + (ConnPrevF?NInpN:0) );
  480.     }
  481.     }
  482.  
  483.     LayerXN++;
  484.     AllocLayer( LayerXN, NOutN, ConnPrevF?(NInpN+NHid1N+NHid2N):NHid2N );
  485.  
  486.     /* connect up the layers */
  487.     for( DLayN = 1; LayerTP[DLayN] != (PE *)0; DLayN++ ) {
  488.     for( SLayN = ConnPrevF?0:(DLayN-1); SLayN < DLayN; SLayN++ )
  489.         FullyConn( DLayN, SLayN, 0.2 );
  490.     }
  491.  
  492.     return( 0 );
  493. }
  494.  
  495. /************************************************************************
  496.  *                                    *
  497.  *    Recall() - Step network through one recall cycle        *
  498.  *                                    *
  499.  ************************************************************************
  500.  */
  501.  
  502. int Recall( ivRP, ovRP )    /* perform a recall */
  503. float        *ivRP;        /* input vector */
  504. float        *ovRP;        /* output vector */
  505. {
  506.     int         DLayN;        /* destination layer index */
  507.     PE        *peP;        /* work PE pointer */
  508.     CONN    *cP;        /* connection pointer */
  509.     int         ConnC;        /* connection counter */
  510.     double     SumR;        /* summation function */
  511.  
  512.     for( DLayN = 0; (peP = LayerTP[DLayN]) != (PE *)0; DLayN++ ) {
  513.     for( ; peP != (PE *)0; peP = peP->NextPEP ) {
  514.         if ( DLayN == 0 ) {
  515.         /* input layer, output is just input vector */
  516.         peP->OutputR = ivRP[0];        /* copy input values */
  517.         peP->ErrorR  = 0.0;        /* clear error */
  518.         ivRP++;
  519.         } else {
  520.         /* hidden or output layer, compute weighted sum & transform */
  521.         ConnC = peP->NConnsN;        /* # of connections */
  522.         cP    = &peP->ConnsS[0];        /* pointer to connections */
  523.         SumR  = 0.0;            /* no sum yet */
  524.         for( ; --ConnC >= 0; cP++ )
  525.             SumR += cP->SourceP->OutputR * cP->WtValR;
  526.         peP->OutputR = 1.0 / (1.0 + exp(-SumR) );
  527.         peP->ErrorR  = 0.0;
  528.         }
  529.  
  530.         if ( LayerTP[DLayN+1] == (PE *)0 ) {
  531.         /* this is output layer, copy result back to user */
  532.         ovRP[0] = peP->OutputR;        /* copy output value */
  533.         ovRP++;                /* next output value */
  534.         }
  535.     }
  536.     }
  537.  
  538.     return( 0 );
  539. }
  540.  
  541. /************************************************************************
  542.  *                                    *
  543.  *    Learn() - step network through one learn cycle            *
  544.  *                                    *
  545.  ************************************************************************
  546.     return:  squared error for this training example
  547.  */
  548.  
  549. double Learn( ivRP, ovRP, doRP, LearnR, MomR )    /* train network */
  550. float        *ivRP;        /* input vector */
  551. float        *ovRP;        /* output vector */
  552. float        *doRP;        /* desired output vector */
  553. double         LearnR;    /* learning rate */
  554. double         MomR;        /* momentum */
  555. {
  556.     double     ErrorR = 0.0;    /* squared error */
  557.     double     LErrorR;    /* local error */
  558.     PE        *speP;        /* source PE */
  559.     int         DLayN;        /* destination layer index */
  560.     PE        *peP;        /* work PE pointer */
  561.     CONN    *cP;        /* connection pointer */
  562.     int         ConnC;        /* connection counter */
  563.  
  564.     Recall( ivRP, ovRP );        /* perform recall */
  565.  
  566.     /* search for output layer */
  567.     for( DLayN = 0; (LayerTP[DLayN+1]) != (PE *)0; )
  568.     DLayN++;
  569.  
  570.     /* compute error, backpropagate error, update weights */
  571.     for( ; DLayN > 0; DLayN-- ) {
  572.     for( peP = LayerTP[DLayN]; peP != (PE *)0; peP = peP->NextPEP ) {
  573.         if ( LayerTP[DLayN+1] == (PE *)0 ) {
  574.         /* output layer, compute error specially */
  575.         peP->ErrorR = (doRP[0] - peP->OutputR);
  576.         ErrorR += (peP->ErrorR * peP->ErrorR);
  577.         doRP++;
  578.         }
  579.  
  580.         /* pass error back through transfer function */
  581.         peP->ErrorR *= peP->OutputR * (1.0 - peP->OutputR);
  582.  
  583.         /* back-propagate it through connections & update them */
  584.         ConnC = peP->NConnsN;        /* # of connections */
  585.         cP    = &peP->ConnsS[0];        /* pointer to connections */
  586.         LErrorR = peP->ErrorR;        /* local error */
  587.         for( ; --ConnC >= 0; cP++ ) {
  588.         speP = cP->SourceP;
  589.         speP->ErrorR += LErrorR * cP->WtValR;    /* propagate error */
  590.         cP->DWtValR =                /* compute new weight */
  591.             LearnR * LErrorR * speP->OutputR +
  592.             MomR * cP->DWtValR;
  593.         cP->WtValR += cP->DWtValR;        /* update weight */
  594.         }
  595.     }
  596.     }
  597.  
  598.     return( ErrorR );
  599. }
  600.  
  601. /************************************************************************
  602.  *                                    *
  603.  *    Explain() - compute the derivative of the output for changes    *
  604.  *      in the inputs                            *
  605.  *                                    *
  606.  ************************************************************************
  607.    Basic Procedure:
  608.     1) do a recall to find out what the nominal output values are
  609.     2) copy the nominal values to "WorkR" in the PE structure.
  610.         (We could have used the ErrorR field but WorkR was
  611.          used to reduce confusion.)
  612.     3) for each input:
  613.         a) Add a small amount to the input value (DitherR)
  614.         b) do a Recall & compute derivative of output
  615.         c) subtract samll amount from nominal in put value
  616.         d) do a Recall & compute derivative of outputs
  617.         e) Average two derivatives
  618.  */
  619.  
  620. int Explain( ivRP, ovRP, evRP, DitherR )
  621. float        *ivRP;        /* input vector */
  622. float        *ovRP;        /* output result vector */
  623. float        *evRP;        /* explain vector */
  624. double         DitherR;    /* dither */
  625. {
  626.     PE        *speP;        /* source PE (input) */
  627.     int         speXN;        /* source PE index */
  628.     PE        *dpeP;        /* destination PE (output) */
  629.     int         dpeXN;        /* destination PE index */
  630.     int         OutLXN;    /* output layer index */
  631.  
  632.     /* figure out index of output layer */
  633.     for( OutLXN = 0; LayerTP[OutLXN+1] != (PE *)0; )
  634.     OutLXN++;
  635.  
  636.     Recall( ivRP, ovRP );            /* set up initial recall */
  637.  
  638.     /* go through output layer and copy output to "WorkR" */
  639.     for( dpeP = LayerTP[OutLXN]; dpeP != (PE *)0; dpeP = dpeP->NextPEP )
  640.     dpeP->WorkR = dpeP->OutputR;
  641.  
  642.     /* for each input, compute its effects on the output */
  643.     for( speXN = 0, speP = LayerTP[0];
  644.      speP != (PE *)0;
  645.      speXN++, speP = speP->NextPEP ) {
  646.     /* dither in positive direction */
  647.     ivRP[speXN] += DitherR;            /* add dither */
  648.     Recall( ivRP, ovRP );            /* new output */
  649.  
  650.     /* set initial results to evRP */
  651.     for( dpeXN = 0, dpeP = LayerTP[OutLXN];
  652.          dpeP != (PE *)0;
  653.          dpeXN++, dpeP = dpeP->NextPEP )
  654.         evRP[dpeXN] = 0.5 * (dpeP->OutputR - dpeP->WorkR) / DitherR;
  655.  
  656.     /* dither in negative direction */
  657.     ivRP[speXN] -= (DitherR + DitherR);    /* subtract dither */
  658.     Recall( ivRP, ovRP );            /* new output */
  659.  
  660.     /* set final results to evRP */
  661.     for( dpeXN = 0, dpeP = LayerTP[OutLXN];
  662.          dpeP != (PE *)0;
  663.          dpeXN++, dpeP = dpeP->NextPEP )
  664.         evRP[dpeXN] -= 0.5 * (dpeP->OutputR - dpeP->WorkR) / DitherR;
  665.  
  666.     /* point to next row of explain vector */
  667.     evRP += dpeXN;
  668.  
  669.     /* restore current input to original value */
  670.     ivRP[speXN] += DitherR;
  671.     }
  672.  
  673.     return( 0 );
  674. }
  675.  
  676. /************************************************************************
  677.  *                                    *
  678.  *    Network Training Data                        *
  679.  *                                    *
  680.  ************************************************************************
  681.  
  682.     +--------------+-----------+ 1.0
  683.     |              |           |
  684.    2    |    zero      |           |
  685.     |              |   one     |
  686.    t    +--------------+           | 0.7
  687.    u    |                          |
  688.    p    |                          |
  689.    n    +--------------------------+ 0.3
  690.    I    |                          |
  691.     |            zero          |
  692.     |                          |
  693.     +--------------------------+ 0.0
  694.        0.0            0.5         1.0
  695.  
  696.         Input 1
  697.  
  698.     Input 3 is "noise".
  699.     Output 1 is shown above.
  700.     Output 2 is opposite of output 1.
  701.  */
  702.  
  703. typedef struct _example {        /* example */
  704.     float     InVecR[3];        /* input vector */
  705.     float     DoVecR[2];        /* desired output vector */
  706. } EXAMPLE;
  707.  
  708.  
  709. #define    NTEST    (sizeof(testE)/sizeof(testE[0]))    /* # of test items */
  710.  
  711. EXAMPLE testE[] = {
  712. /*       --- Inputs ---     --- Desired Outputs -- */
  713.     { 0.0, 0.0, 0.0,    0.0, 1.0 },
  714.     { 0.0, 0.2, 0.6,    0.0, 1.0 },
  715.     { 0.0, 0.4, 0.1,    1.0, 0.0 },
  716.     { 0.0, 0.6, 0.7,    1.0, 0.0 },
  717.     { 0.0, 0.8, 0.2,    0.0, 1.0 },
  718.     { 0.0, 1.0, 0.8,    0.0, 1.0 },
  719.  
  720.     { 0.2, 0.0, 0.9,    0.0, 1.0 },
  721.     { 0.2, 0.2, 0.3,    0.0, 1.0 },
  722.     { 0.2, 0.4, 0.8,    1.0, 0.0 },
  723.     { 0.2, 0.6, 0.2,    1.0, 0.0 },
  724.     { 0.2, 0.8, 0.7,    0.0, 1.0 },
  725.     { 0.2, 1.0, 0.1,    0.0, 1.0 },
  726.  
  727.     { 0.4, 0.0, 0.0,    0.0, 1.0 },
  728.     { 0.4, 0.2, 0.6,    0.0, 1.0 },
  729.     { 0.4, 0.4, 0.1,    1.0, 0.0 },
  730.     { 0.4, 0.6, 0.7,    1.0, 0.0 },
  731.     { 0.4, 0.8, 0.2,    0.0, 1.0 },
  732.     { 0.4, 1.0, 0.8,    0.0, 1.0 },
  733.  
  734.     { 0.6, 0.0, 0.9,    0.0, 1.0 },
  735.     { 0.6, 0.2, 0.3,    0.0, 1.0 },
  736.     { 0.6, 0.4, 0.8,    1.0, 0.0 },
  737.     { 0.6, 0.6, 0.2,    1.0, 0.0 },
  738.     { 0.6, 0.8, 0.7,    1.0, 0.0 },
  739.     { 0.6, 1.0, 0.1,    1.0, 0.0 },
  740.  
  741.     { 0.8, 0.0, 0.4,    0.0, 1.0 },
  742.     { 0.8, 0.2, 0.6,    0.0, 1.0 },
  743.     { 0.8, 0.4, 0.1,    1.0, 0.0 },
  744.     { 0.8, 0.6, 0.7,    1.0, 0.0 },
  745.     { 0.8, 0.8, 0.2,    1.0, 0.0 },
  746.     { 0.8, 1.0, 0.8,    1.0, 0.0 },
  747.  
  748.     { 1.0, 0.0, 1.0,    0.0, 1.0 },
  749.     { 1.0, 0.2, 0.3,    0.0, 1.0 },
  750.     { 1.0, 0.4, 0.8,    1.0, 0.0 },
  751.     { 1.0, 0.6, 0.2,    1.0, 0.0 },
  752.     { 1.0, 0.8, 0.7,    1.0, 0.0 },
  753.     { 1.0, 1.0, 0.0,    1.0, 0.0 }
  754. };
  755.  
  756. int TestShuffleN[ NTEST ]    = {0};        /* shuffle array */
  757. int TestSXN            = {NTEST+1};    /* current shuffle index */
  758.  
  759. /************************************************************************
  760.  *                                    *
  761.  *    NextTestN() - Do shuffle & deal randomization of training set    *
  762.  *                                    *
  763.  ************************************************************************
  764.  */
  765.  
  766. int NextTestN()        /* Get next Training example index */
  767. {
  768.     int        HitsN;    /* # of items we have added to list */
  769.     int        wxN;    /* work index into shuffle array */
  770.     int        xN,yN;    /* indicies of items to swap */
  771.  
  772.     if ( TestSXN >= NTEST ) {
  773.     /* reshuffle the array */
  774.     for( wxN = 0; wxN < NTEST; wxN++ )
  775.         TestShuffleN[wxN] = wxN;
  776.  
  777.     /* quick & dirty way to shuffle.  Much better ways exist. */
  778.     for( HitsN = 0; HitsN < NTEST+NTEST/2; HitsN++ ) {
  779.         xN = rand() % NTEST;
  780.         yN = rand() % NTEST;
  781.         wxN = TestShuffleN[xN];
  782.         TestShuffleN[xN] = TestShuffleN[yN];
  783.             TestShuffleN[yN] = wxN;
  784.     }
  785.  
  786.     TestSXN = 0;
  787.     }
  788.  
  789.     return( TestShuffleN[TestSXN++] );
  790. }
  791.  
  792. /************************************************************************
  793.  *                                    *
  794.  *    TrainNet() - Driver for training Network            *
  795.  *                                    *
  796.  ************************************************************************
  797.  */
  798.  
  799. int TrainNet( ErrLvlR, MaxPassN )        /* train network */
  800. double         ErrLvlR;    /* error level to achieve */
  801. long         MaxPassN;    /* max number of passes */
  802. {
  803.     float     rvR[MAX_PES];        /* result vector */
  804.     double     lsErrR;
  805.     int         CurTestN;        /* current test number */
  806.     int         HitsN;            /* # of times below threshold */
  807.     int         PassN;            /* pass through the data */
  808.     int         ExampleN;        /* example number */
  809.  
  810.     HitsN    = 0;
  811.     CurTestN    = 0;
  812.     lsErrR    = 0.0;
  813.     PassN    = 0;
  814.     for(;;) {
  815.     ExampleN = NextTestN();        /* next test number */
  816.     lsErrR += Learn( &testE[ExampleN].InVecR[0],
  817.              &rvR[0],
  818.              &testE[ExampleN].DoVecR[0], 0.9, 0.5 );
  819.     CurTestN++;
  820.     if ( CurTestN >= NTEST ) {
  821.         PassN++;
  822.         lsErrR = sqrt(lsErrR)/ (double)NTEST;
  823.         if ( lsErrR < ErrLvlR )
  824.             HitsN++;
  825.         else    HitsN = 0;
  826.  
  827.         printf( "Pass %3d Error = %.3f Hits = %d\n",
  828.         PassN, lsErrR, HitsN );
  829.  
  830.         if ( PassN > MaxPassN || HitsN > 3 )    /* exit criterial */
  831.         break;
  832.         CurTestN = 0;
  833.         lsErrR = 0.0;
  834.     }
  835.     }
  836.  
  837.     /* done training, start testing */
  838.  
  839.     return( 0 );
  840. }
  841.  
  842. /************************************************************************
  843.  *                                    *
  844.  *    ExplainNet() - do explain & print it out            *
  845.  *                                    *
  846.  ************************************************************************
  847.  */
  848.  
  849. int ExplainNet( fnP, DitherR )        /* explain & print */
  850. char        *fnP;        /* output file name */
  851. double         DitherR;    /* amount to dither */
  852. {
  853.     FILE    *fP;        /* file pointer */
  854.     int         wxN;        /* work index */
  855.     int         xN, yN;    /* x,y values */
  856.     int         axN;        /* alternate work index */
  857.     float    *wfP;        /* work float pointer */
  858.     static float ivR[MAX_PES] = {0};        /* input vector */
  859.     static float ovR[MAX_PES] = {0};        /* work area for output data */
  860.     static float evR[MAX_PES*MAX_PES] = {0};    /* explain vector */
  861.     /*    evR[0] = dY1 vs Input 1
  862.      *    evR[1] = dY2 vs Input 1
  863.      *  evR[2] = dY1 vs Input 2
  864.      *  evR[3] = dY2 vs Input 2
  865.      *    evR[4] = dY1 vs Input 3
  866.      *    evR[5] = dY2 vs Input 3
  867.      */
  868.  
  869.  
  870.     if ( *fnP == '\0' ) {
  871.     fP = stdout;
  872.     } else {
  873.     if ( (fP = fopen( fnP, "a" )) == (FILE *)0 ) {
  874.         printf( "Could not open explain output file <%s>\n", fnP );
  875.         return( -1 );
  876.     }
  877.     }
  878.  
  879.     fprintf( fP,
  880.     "\f\n*** Network Output as a function of inputs 1 & 2 ***\n\n" );
  881.  
  882.     ivR[2] = 0.5;
  883.     for( yN = 20; yN >= 0; yN-- ) {
  884.     if ( (yN % 2) == 0 )    fprintf( fP, "%6.2f | ", yN/20. );
  885.     else            fprintf( fP, "       | " );
  886.     for( xN = 0; xN <= 20; xN++ ) {
  887.         ivR[0] = xN / 20.;
  888.         ivR[1] = yN / 20.;
  889.         Recall( &ivR[0], &ovR[0] );
  890.  
  891.         /* --- ignore very small changes --- */
  892.         if ( fabs(ovR[0]) < .1 )    fprintf( fP, "   - " );
  893.         else            fprintf( fP, "%5.1f", ovR[0] );
  894.     }
  895.     fprintf( fP, "\n" );
  896.     }
  897.     fprintf( fP, "       +-" );
  898.     for( xN = 0; xN <= 20; xN++ )
  899.     fprintf( fP, "-----" );
  900.     fprintf( fP, "\n         " );
  901.  
  902.     for( xN = 0; xN <= 20; xN++ )
  903.     fprintf( fP, (xN % 2)==0?"%5.1f":"     ", xN/20. );
  904.     fprintf( fP, "\n" );
  905.  
  906.  
  907.  
  908.     fprintf( fP,
  909.     "\f\n*** Plot of Explain Function for Input 1 over input range ***\n\n" );
  910.  
  911.     ivR[2] = 0.5;
  912.     for( yN = 20; yN >= 0; yN-- ) {
  913.     if ( (yN % 2) == 0 )    fprintf( fP, "%6.2f | ", yN/20. );
  914.     else            fprintf( fP, "       | " );
  915.     for( xN = 0; xN <= 20; xN++ ) {
  916.         ivR[0] = xN / 20.;
  917.         ivR[1] = yN / 20.;
  918.         Explain( &ivR[0], &ovR[0], &evR[0], DitherR );
  919.  
  920.         /* --- ignore very small changes --- */
  921.         if ( fabs(evR[0]) < .1 )    fprintf( fP, "   - " );
  922.         else            fprintf( fP, "%5.1f", evR[0] );
  923.     }
  924.     fprintf( fP, "\n" );
  925.     }
  926.     fprintf( fP, "       +-" );
  927.     for( xN = 0; xN <= 20; xN++ )
  928.     fprintf( fP, "-----" );
  929.     fprintf( fP, "\n         " );
  930.  
  931.     for( xN = 0; xN <= 20; xN++ )
  932.     fprintf( fP, (xN % 2)==0?"%5.1f":"     ", xN/20. );
  933.     fprintf( fP, "\n" );
  934.  
  935.  
  936.  
  937.     fprintf( fP,
  938.     "\f\n*** Plot of Explain Function for Input 2 over input range ***\n\n" );
  939.  
  940.  
  941.     ivR[2] = 0.5;
  942.     for( yN = 20; yN >= 0; yN-- ) {
  943.     if ( (yN % 2) == 0 )    fprintf( fP, "%6.2f | ", yN/20. );
  944.     else            fprintf( fP, "       | " );
  945.     for( xN = 0; xN <= 20; xN++ ) {
  946.         ivR[0] = xN / 20.;
  947.         ivR[1] = yN / 20.;
  948.         Explain( &ivR[0], &ovR[0], &evR[0], DitherR );
  949.  
  950.         /* --- ignore very small changes --- */
  951.         if ( fabs(evR[2]) < .1 )    fprintf( fP, "   - " );
  952.         else            fprintf( fP, "%5.1f", evR[2] );
  953.     }
  954.     fprintf( fP, "\n" );
  955.     }
  956.     fprintf( fP, "       +-" );
  957.     for( xN = 0; xN <= 20; xN++ )
  958.     fprintf( fP, "-----" );
  959.     fprintf( fP, "\n         " );
  960.  
  961.     for( xN = 0; xN <= 20; xN++ )
  962.     fprintf( fP, (xN % 2)==0?"%5.1f":"     ", xN/20. );
  963.     fprintf( fP, "\n" );
  964.  
  965.  
  966.  
  967.     fprintf( fP,
  968.     "\f\n*** Plot of Explain Function for Input 3 over input range ***\n\n" );
  969.  
  970.  
  971.     ivR[2] = 0.5;
  972.     for( yN = 20; yN >= 0; yN-- ) {
  973.     if ( (yN % 2) == 0 )    fprintf( fP, "%6.2f | ", yN/20. );
  974.     else            fprintf( fP, "       | " );
  975.     for( xN = 0; xN <= 20; xN++ ) {
  976.         ivR[0] = xN / 20.;
  977.         ivR[1] = yN / 20.;
  978.         Explain( &ivR[0], &ovR[0], &evR[0], DitherR );
  979.  
  980.         /* --- ignore very small changes --- */
  981.         if ( fabs(evR[4]) < .1 )    fprintf( fP, "   - " );
  982.         else            fprintf( fP, "%5.1f", evR[4] );
  983.     }
  984.     fprintf( fP, "\n" );
  985.     }
  986.     fprintf( fP, "       +-" );
  987.     for( xN = 0; xN <= 20; xN++ )
  988.     fprintf( fP, "-----" );
  989.     fprintf( fP, "\n         " );
  990.  
  991.     for( xN = 0; xN <= 20; xN++ )
  992.     fprintf( fP, (xN % 2)==0?"%5.1f":"     ", xN/20. );
  993.     fprintf( fP, "\n" );
  994.  
  995.  
  996.     if ( fP != stdout )
  997.     fclose( fP );
  998.     return( 0 );
  999. }
  1000.  
  1001.  
  1002. /************************************************************************
  1003.  *                                    *
  1004.  *    main() - Driver for entre program                *
  1005.  *                                    *
  1006.  ************************************************************************
  1007.  */
  1008.  
  1009. main()
  1010. {
  1011.     int         ActionN;    /* action character */
  1012.     char    *sP;        /* string pointer */
  1013.     char    *aP;        /* alternate pointer */
  1014.     char     BufC[80];    /* work buffer */
  1015.  
  1016.     printf( "\nC-Program to Explain a Neural Network's Conclusions\n" );
  1017.     printf( "  Written by: Casimir C. 'Casey' Klimasauskas, 04-Jan-91\n" );
  1018.     for(;;) {
  1019.     printf( "\
  1020. C          - create a new network\n\
  1021. L [fname]  - load a trained network\n\
  1022. S [fname]  - save a network\n\
  1023. P [fname]  - print out network\n\
  1024. F          - free network from memory\n\
  1025. T          - Train network\n\
  1026. E [fname]  - Explain network\n\
  1027. X          - eXit from the program\n\
  1028. What do you want to do? " );
  1029.     fflush( stdout );
  1030.     sP = fgets( BufC, sizeof(BufC)-1, stdin );
  1031.     if ( sP == (char *)0 )
  1032.         break;
  1033.  
  1034.     while( *sP != 0 && *sP <= ' ' )
  1035.         sP++;
  1036.     ActionN = *sP;
  1037.     if ( 'A' <= ActionN && ActionN <= 'Z' )
  1038.         ActionN -= 'A'-'a';            /* convert to LC */
  1039.     sP++;
  1040.     while( *sP != 0 && *sP <= ' ' )
  1041.         sP++;                /* skip to argument */
  1042.     for( aP = sP; *aP > ' '; )
  1043.         aP++;                /* skip to end of argument */
  1044.     *aP = '\0';                /* null terminate it */
  1045.  
  1046.     switch( ActionN ) {
  1047.     case 'c':        /* create network */
  1048.         BuildNet( 3, 5, 0, 2, 1 );
  1049.         break;
  1050.  
  1051.     case 'l':        /* load network */
  1052.         if ( *sP == '\0' )
  1053.         sP = "network.net";
  1054.         LoadNet( sP );
  1055.         break;
  1056.  
  1057.     case 's':        /* save network */
  1058.         if ( *sP == '\0' )
  1059.         sP = "network.net";
  1060.         SaveNet( sP );
  1061.         break;
  1062.  
  1063.     case 'p':        /* print network */
  1064.         PrintNet( sP );
  1065.         break;
  1066.  
  1067.     case 'f':        /* free network */
  1068.         FreeNet();
  1069.         break;
  1070.  
  1071.     case 't':        /* train network */
  1072.         TrainNet( 0.001, 100000L );
  1073.         break;
  1074.  
  1075.     case 'e':        /* explain network */
  1076.         ExplainNet( sP, .01 );
  1077.         break;
  1078.  
  1079.     case 'x':        /* done */
  1080.         goto Done;
  1081.  
  1082.     default:
  1083.         break;
  1084.     }
  1085.     }
  1086.  
  1087. Done:
  1088.     return( 0 );
  1089. }
  1090.  
  1091.  
  1092.  
  1093.  
  1094.