home *** CD-ROM | disk | FTP | other *** search
/ The C Users' Group Library 1994 August / wc-cdrom-cusersgrouplibrary-1994-08.iso / vol_200 / 299_01 / bp.c < prev    next >
Text File  |  1989-12-30  |  29KB  |  863 lines

  1. /****************************************************************************/
  2. /*  file name: bp.c                                                         */
  3. /*  (c) by Ronald Michaels. This program may be freely copied, modified,    */
  4. /*  transmitted, or used for any non-commercial purpose.                    */
  5. /*  this is the main file of the back propagation program                   */
  6. /*  This program compiles under the Zortech c compiler v. 1.07 using their  */
  7. /*  graphics library or under Ecosoft v4.07 (set GRAPH 0)                   */
  8. /****************************************************************************/
  9.  
  10. #include<stdio.h>
  11. #include<stdlib.h>
  12. #include<math.h>
  13. #include"error.h"
  14. #include"random.h"
  15.  
  16. #define GRAPH 0  /* GRAPH 1 if it is desired to link in the graphics */
  17.                       /* GRAPH 0 if no graph is desired */ 
  18.  
  19. #if GRAPH==1
  20. #include"plot.h"
  21. #endif
  22.  
  23. #ifdef ECO
  24. #include<malloc.h>             /*  required for eco-c compiler  */
  25. #endif
  26.  
  27. #define U(x) (unsigned int)(x)    /*  type conversion */
  28. #define SQ(x) ((x)*(x))           /*  square macro  */
  29.  
  30. /* function prototypes */
  31. void getdata         (FILE *bp1,FILE *bp2);
  32. void getpattern      (FILE *bp1,int,int,double *);
  33. void allocate_memory (void);
  34. void init_weights    (int,int,double *);
  35. void learn           (int);
  36. void foreward        (int,int,double *,double *,double *);
  37. void recognise       (void);
  38. void calc_delta_o    (int,int,double *,double *,double *);
  39. void calc_delta_h    (int,int,double *,double *,double *,double *);
  40. void calc_descent    (int,int,double,double,double *,double *,double *);
  41. void correct_weight  (int,int,double *,double *);
  42. double activate      (double);
  43. double pattern_error (int,int,double *,double *);
  44. void print_scale     (void);
  45. void get_seed(void);
  46. void get_limits(void);
  47.  
  48. void dump            (int);   /* function to dump intermediate results */
  49.  
  50. /* external variable declarations  */
  51. double *input;       /* pointer to input matrix */
  52. double *output;      /* pointer to output unit output vector */
  53. double *target;      /* pointer to target matrix */
  54. double *weight_h;    /* pointer to weight matrix to hidden units */
  55. double *weight_o;    /* pointer to weight matrix to output units */
  56. double *hidden;      /* pointer to hidden unit output vector */
  57. double *delta_o;     /* pointer to output unit delta vector */
  58. double *delta_h;     /* pointer to hidden unit delta vector */
  59. double *descent_h;   /* pointer to weight change matrix for weights to
  60.                             hidden units */
  61. double *descent_o;   /* pointer to weight change matrix for weights to
  62.                                output units */
  63.  
  64. int n_pattern;       /* number of training patterns to be used */
  65. int n_input;         /* number of input units in one pattern (dimensionality) */
  66. int n_hidden;        /* number of hidden units */
  67. int n_output;        /* number of output units in one target (dimensionality) */
  68.  
  69. double learning_rate;       /* learning rate parameter */
  70. double momentum;            /* proportion of previous weight change */
  71.  
  72. FILE *bp3;                  /*  pointer to output file bp3.dat */
  73.  
  74. /****************************************************************************/
  75. int main()                  /* some compilers want main to be void */
  76. {
  77.  
  78.    FILE *bp1;               /*  pointer to input file bp1.dat */
  79.    FILE *bp2;               /*  pointer to input file bp2.dat */
  80.  
  81.    char buff[10];           /*  buffer to hold number of cycles  */
  82.  
  83.    int choice;              /* program control choice */
  84.    int p;                   /* pattern counter  */
  85.     int cycles;              /* number of cycles for learning algorithm */
  86.     
  87.    if((bp1=fopen("bp1.dat","r"))==NULL){  /* open data input file */
  88.       error(0,FATAL);
  89.    }
  90.    if((bp2=fopen("bp2.dat","r"))==NULL){  /* open configuration file */
  91.       error(1,FATAL);
  92.    }
  93.    if((bp3=fopen("bp3.dat","w"))==NULL){  /* open output file */
  94.       error(2,FATAL);
  95.    }
  96.  
  97.    /* get training pattern size from input file bp1.dat */
  98.    getdata(bp1,bp2);
  99.  
  100.    /* allocate space for input vectors  */
  101.    allocate_memory();
  102.  
  103.    /* load input patterns into memory */
  104.    getpattern(bp1,n_pattern,n_input,input);
  105.  
  106.    /* load target patterns into memory */
  107.    getpattern(bp1,n_pattern,n_output,target);
  108.  
  109.     get_seed();    /* seed random number generator */
  110.     get_limits();  /* set range of random numbers */
  111.  
  112.    /* initialise weight matrices with random weights */
  113.    init_weights(n_input,n_hidden,weight_h);
  114.    init_weights(n_hidden,n_output,weight_o);
  115.  
  116.    /* enter program control loop */
  117.    for(;;){
  118.  
  119.       printf("\nBack Propagation Generalised Delta Rule Learning Program\n");
  120.       printf("          Learn\n          Recognise\n");
  121.       printf("          Dump\n          Quit\n");
  122.       printf("choice:");
  123.       choice = getch();
  124.       putchar(choice);
  125.  
  126.       switch(choice){
  127.          case 'l':
  128.          case 'L':
  129.             printf("\nHow Many Cycles?\n");
  130.                 cycles=atoi(gets(buff));
  131.                 if(cycles<1)cycles=1;
  132.             learn(cycles);
  133.             break;
  134.          case 'r':
  135.          case 'R':
  136.             recognise();
  137.             break;
  138.          case 'd':
  139.          case 'D':
  140.             for(p=0;p<n_pattern;p++)dump(p);
  141.             printf("\nNetwork variables dumped into file bp3.dat");
  142.             break;
  143.          case 'q':
  144.          case 'Q':
  145.             exit(0);
  146.          default:
  147.             break;
  148.       }
  149.    }
  150.    fclose(bp1);
  151.    fclose(bp2);
  152.    fclose(bp3);
  153. }
  154.  
  155. /****************************************************************************/
  156. /* getdata                                                                  */
  157. /* this function gets data from the data file regarding the size and number */
  158. /* of patterns and the configuration file                                   */
  159. /****************************************************************************/
  160.  
  161. void getdata(
  162.    FILE *bp1,              /*  pointer to input file bp1.dat */
  163.    FILE *bp2               /*  pointer to input file bp2.dat */
  164. )
  165. {
  166.    if(fscanf(bp1,"%d",&n_pattern)==EOF){  /* get the number */
  167.       error(3,FATAL);             /* of pattern vectors */
  168.    }
  169.    if(fscanf(bp1,"%d",&n_input)==EOF){    /* get the dimensionality */
  170.       error(3,FATAL);             /* of input vectors */
  171.    }
  172.    if(fscanf(bp1,"%d",&n_output)==EOF){   /* get the dimensionality */
  173.       error(3,FATAL);             /* of target vectors */
  174.    }
  175.    if(fscanf(bp1,"%d",&n_hidden)==EOF){   /* get the number */
  176.       error(3,FATAL);             /* of hidden units */
  177.    }
  178.    if(fscanf(bp2,"%lf",&learning_rate)==EOF){  /* get learning rate */
  179.       error(4,FATAL);
  180.    }
  181.    if(fscanf(bp2,"%lf",&momentum)==EOF){  /* get learning momoentum */
  182.       error(4,FATAL);
  183.    }
  184. }
  185.  
  186. /****************************************************************************/
  187. /* allocate_memory                                                          */
  188. /* this function allocates memory for the network                           */
  189. /****************************************************************************/
  190.  
  191. void allocate_memory()
  192. {
  193.    /* allocate space for input vectors */
  194.    if((input=(double *)calloc(U(n_pattern*n_input),sizeof(double)))==NULL){
  195.       error(6,FATAL);
  196.    }
  197.    /* allocate space for target vectors  */
  198.    if((target=(double *)calloc(U(n_pattern*n_output),sizeof(double)))==NULL){
  199.       error(6,FATAL);
  200.    }
  201.    /* allocate space for output vectors  */
  202.    if((output=(double *)calloc(U(n_pattern*n_output),sizeof(double)))==NULL){
  203.       error(6,FATAL);
  204.    }
  205.    /* allocate space for hidden unit vector */
  206.    if((hidden=(double *)calloc(U(n_hidden),sizeof(double)))==NULL){
  207.       error(6,FATAL);
  208.    }
  209.    /* allocate space for hidden unit delta vector */
  210.    if((delta_h=(double *)calloc(U(n_hidden),sizeof(double)))==NULL){
  211.       error(6,FATAL);
  212.    }
  213.    /* allocate space for output unit delta vector */
  214.    if((delta_o=(double *)calloc(U(n_output),sizeof(double)))==NULL){
  215.       error(6,FATAL);
  216.    }
  217.    /* allocate space for hidden weights */
  218.    if((weight_h=(double *)calloc(U((n_input+1)*n_hidden),sizeof(double)))==NULL){
  219.       error(6,FATAL);
  220.    }
  221.    /* allocate space for output weights */
  222.    if((weight_o=(double *)calloc(U((n_hidden+1)*n_output),sizeof(double)))==NULL){
  223.       error(6,FATAL);
  224.    }
  225.    /* allocate space for weight changes to hidden weights */
  226.    if((descent_h=(double *)calloc(U((n_input+1)*n_hidden),sizeof(double)))==NULL){
  227.       error(6,FATAL);
  228.    }
  229.    /* allocate space for weight changes to output weights */
  230.    if((descent_o=(double *)calloc(U((n_hidden+1)*n_output),sizeof(double)))==NULL){
  231.       error(6,FATAL);
  232.    }
  233. }
  234.  
  235. /****************************************************************************/
  236. /* getpattern                                                               */
  237. /* this function loads values for patterns and targets into memory          */
  238. /****************************************************************************/
  239.  
  240. void getpattern(
  241.    FILE *data,             /* pointer to input data file */
  242.    int n_pattern_vector,   /* number of patterns to be read */
  243.    int n_units,            /* dimensionality of pattern */ 
  244.    double *matrix          /* pointer to matrix to hold values */
  245. )
  246. {
  247.    int p;
  248.    int i;
  249.  
  250.    for(p=0;p<n_pattern_vector;p++){
  251.       for(i=0;i<n_units;i++){
  252.          if(fscanf(data,"%lf",matrix+(p*n_units+i))<=NULL){
  253.             error(3,FATAL);
  254.          }
  255.       }
  256.    }
  257. }
  258.  
  259. /****************************************************************************/
  260. /* init_weights                                                             */
  261. /* this function initialises weight matrices with random numbers            */
  262. /****************************************************************************/
  263.  
  264. void init_weights(
  265.    int n_input_units,
  266.    int n_output_units,
  267.    double *weight
  268. )
  269. {
  270.    int i;
  271.    int j;
  272.  
  273.    for(j=0;j<n_output_units;j++){
  274.       /* note additional weights for bias unit */
  275.       *(weight+j*(n_input_units+1)+n_input_units) = d_rand();
  276.       for(i=0;i<n_input_units;i++){
  277.          *(weight+j*(n_input_units+1)+i) = d_rand();
  278.       }
  279.    }
  280. }
  281.  
  282. /****************************************************************************/
  283. /* learn                                                                    */
  284. /* this function implements the generalised delta rule with back propagation*/
  285. /****************************************************************************/
  286.  
  287. void learn(
  288.    int n_cycle
  289. )
  290. {
  291.    int learning_cycle;
  292.    int p;
  293.  
  294.     #if GRAPH==1
  295.     double x_scale;
  296.     double y_scale;
  297.     #endif
  298.  
  299.    /* dump(0);  */  /* for a complete history of the learning process */
  300.  
  301.     #if GRAPH==1
  302.         init_graph();
  303.     #else
  304.        print_scale();  /* prints scale to measure number of cycles */
  305.     #endif
  306.     
  307.    for(learning_cycle=0;learning_cycle<n_cycle;learning_cycle++){
  308.     
  309.         #if GRAPH!=1
  310.       if(learning_cycle%10==0)printf(".");
  311.         #endif
  312.         
  313.       for(p=0;p<n_pattern;p++){
  314.  
  315.          /* compute hidden layer output */
  316.          foreward(n_input,n_hidden,weight_h,input+p*n_input,hidden);
  317.  
  318.          /* compute output layer output */
  319.          foreward(n_hidden,n_output,weight_o,hidden,output+p*n_output);
  320.  
  321.          /* calculate delta for output units */
  322.          calc_delta_o(p,n_output,delta_o,target,output+p*n_output);
  323.  
  324.          /* calculate delta for hidden units */
  325.          calc_delta_h(n_output,n_hidden,delta_o,delta_h,hidden,weight_o);
  326.  
  327.          /* calculate descent for output weights */
  328.          calc_descent(n_hidden,n_output,learning_rate,momentum,
  329.             descent_o,delta_o,hidden);
  330.  
  331.          /* calculate descent for hidden weights */
  332.          calc_descent(n_input,n_hidden,learning_rate,momentum,
  333.             descent_h,delta_h,(input+p*n_input));
  334.  
  335.          /* correct output weights */
  336.          correct_weight(n_hidden,n_output,weight_o,descent_o);
  337.  
  338.          /* correct hidden weights */
  339.          correct_weight(n_input,n_hidden,weight_h,descent_h);
  340.  
  341.          /* dump(p); */   /* for a complete history of the learning process */
  342.             /* caution this slows program down and creates a large file */
  343.       }
  344.         #if GRAPH==1
  345.         if(learning_cycle==0){
  346.          set_scales(pattern_error(n_pattern,n_output,target,output),
  347.             n_cycle,&x_scale,&y_scale);
  348.       }
  349.       if((learning_cycle+1)%10==0){
  350.          point(pattern_error(n_pattern,n_output,target,output),
  351.             learning_cycle+1,x_scale,y_scale);
  352.         }
  353.     #endif  
  354.    }
  355.       #if GRAPH==1
  356.    point(pattern_error(n_pattern,n_output,target,output),
  357.    learning_cycle,x_scale,y_scale);
  358.     close_graph();
  359.     #endif
  360. }
  361.  
  362. /****************************************************************************/
  363. /* recognise                                                                */
  364. /* this function presents the input patterns and compares the output        */
  365. /* to the target patterns                                                   */
  366. /****************************************************************************/
  367. void recognise ()
  368. {
  369.    int p;
  370.    int i;
  371.    int k;
  372.  
  373.    for(p=0;p<n_pattern;p++){
  374.  
  375.       /* compute hidden layer output */
  376.       foreward(n_input,n_hidden,weight_h,(input+p*n_input),hidden);
  377.  
  378.       /* compute output layer output */
  379.       foreward(n_hidden,n_output,weight_o,hidden,output+p*n_output);
  380.  
  381.       /* print input pattern */
  382.       printf("\ninput  ");
  383.       for(i=0;i<n_input;i++){
  384.          printf("%3.1f  ",*(input+(p*n_input+i)));
  385.       }
  386.       printf("\n");
  387.  
  388.       /* print output pattern */
  389.       printf("output  ");
  390.       for(k=0;k<n_output;k++){
  391.          printf("%f  ",*(output+p*n_output+k));
  392.       }
  393.       printf("\n");
  394.  
  395.       /* print target pattern */
  396.       printf("target  ");
  397.       for(k=0;k<n_output;k++){
  398.          printf("%f  ",*(target+(p*n_output+k)));
  399.       }
  400.       printf("\n");
  401.    }
  402.  
  403.    /* print error */
  404.    printf("RMS error = %f ",pattern_error(n_pattern,n_output,target,output));
  405.  
  406.    printf("          press any key to proceed\n");
  407.    getch();
  408. }
  409.  
  410. /****************************************************************************/
  411. /* foreward                                                                 */
  412. /* this function calculates the output of a unit given input and weight     */
  413. /* the below diagram is meant to illustrate the operation of this function
  414.      
  415.           n_input_units       i is index
  416.  
  417.           0      1      2      3      4
  418. unit_in-> o      o      o      o      bias  (set to 1.0)
  419.           |      |      |      |      o     
  420.           |      |      |      |      |     
  421.           |0     |1     |2     |3     |4    
  422.  weight-> X------X------X------X------X---->activate(sum)---->o 0 <- unit_out
  423.           |      |      |      |      |     
  424.           |5     |6     |7     |8     |9                       
  425.           X------X------X------X------X---->activate(sum)---->o 1
  426.           |      |      |      |      |     
  427.           |10    |11    |12    |13    |14          
  428.           X------X------X------X------X---->activate(sum)---->o 2
  429.           |      |      |      |      |     
  430.           |15    |16    |17    |18    |19   
  431.           X------X------X------X------X---->activate(sum)---->o 3
  432.                                                                 n_output_units     
  433.                                                                 j is index 
  434.  
  435.      weight matrix is  n_output_units X (n_input_units+1)
  436. */
  437. /****************************************************************************/
  438.  
  439. void foreward (
  440.    int n_input_units,
  441.    int n_output_units,
  442.    double *weight,
  443.    double *unit_in,
  444.    double *unit_out
  445. )
  446. {
  447.    int i;
  448.    int j;
  449.    double sum;
  450.  
  451.    for(j=0;j<n_output_units;j++){
  452.       sum = 0.0;
  453.       for(i=0;i<n_input_units;i++){
  454.          sum = sum + (*(unit_in+i))*(*(weight+(j*(n_input_units+1)+i)));
  455.          }
  456.       sum = sum + (*(weight+(j*(n_input_units+1)+n_input_units)));
  457.       *(unit_out+j) = activate(sum);
  458.    }
  459. }
  460.  
  461. /****************************************************************************/
  462. /* activate                                                                 */
  463. /* this function calculates the output of a unit using a linear             */
  464. /* approximation to the sigmoid function                                    */
  465. /****************************************************************************/
  466. /*
  467. double activate(
  468.    double sum
  469. )
  470. {
  471.    double activation;
  472.  
  473.    if(sum < -1.8) activation = 0.1;
  474.  
  475.    else if(sum > 1.8) activation = 0.9;
  476.  
  477.    else{
  478.       activation = (sum+2.0)/4.0;
  479.    }
  480.    return activation;
  481. }
  482. */
  483. /****************************************************************************/
  484. /* activate                                                                 */
  485. /* this function calculates the output of a unit using a linear             */
  486. /* approximation to the sigmoid function                                    */
  487. /****************************************************************************/
  488. /*
  489. double activate(
  490.    double sum
  491. )
  492. {
  493.    double activation;
  494.  
  495.    if(sum < -1.3) activation = 0.1;
  496.  
  497.    else if(sum > 1.3) activation = 0.9;
  498.  
  499.    else if(sum < -1.0){
  500.       activation = (2.0*sum+3.0)/4.0;
  501.    }
  502.  
  503.    else if(sum > 1.0){
  504.       activation = (2.0*sum+1.0)/4.0;
  505.    }
  506.  
  507.    else{
  508.       activation = (sum+2.0)/4.0;
  509.    }
  510.    return activation;
  511. }
  512. */
  513. /****************************************************************************/
  514. /* activate                                                                 */
  515. /* this function calculates the output of a unit using the sigmoid function */
  516. /****************************************************************************/
  517.  
  518. double activate(
  519.    double sum
  520. )
  521. {
  522.    double activation;
  523.  
  524.    activation = 1.0/(1.0+exp(-sum));
  525.    return activation;
  526. }
  527.  
  528. /****************************************************************************/
  529. /* activate                                                                 */
  530. /* this function calculates the output of a unit using the step function    */
  531. /* it does not seem to work with the delta rule                             */
  532. /****************************************************************************/
  533. /*
  534. double activate(
  535.    double sum
  536. )
  537. {
  538.    double activation;
  539.  
  540.    if(sum>=0.0)activation = 0.9;
  541.    else activation = 0.1;
  542.    return activation;
  543. }
  544. */
  545. /****************************************************************************/
  546. /* calc_delta_o                                                             */
  547. /* this function calculates delta for output units                          */
  548. /****************************************************************************/
  549.  
  550. void calc_delta_o(
  551.    int p,                    /* number of the pattern under consideration */
  552.    int n_output_units,    /* number of output units (dimensionality) */
  553.    double *delta,            /* pointer to delta matrix */
  554.    double *unit_target,    /* pointer to target matrix */
  555.    double *unit_out     /* pointer to output matrix */
  556. )
  557. {
  558.    int j;
  559.    double temp;
  560.  
  561.    for(j=0;j<n_output_units;j++){
  562.        *(delta+j) = ((*(unit_target+(p*n_output_units)+j))-(*(unit_out+j)))*
  563.          (*(unit_out+j))*(1-(*(unit_out+j)));
  564.    }
  565. }
  566.  
  567. /****************************************************************************/
  568. /* calc_delta_h                                                             */
  569. /* this function calculates delta for hidden units                          */
  570. /****************************************************************************/
  571.  
  572. void calc_delta_h(
  573.    int n_output_units,        /* number of output units */
  574.    int n_hidden_units,        /* number of hidden units */
  575.    double *delta_out,        /* delta for output units */
  576.    double *delta_hid,        /* delta for hidden units */
  577.    double *unit_hid,            /* pointer to hidden units */
  578.    double *weight                /* pointer to weight matrix */
  579. )
  580. {
  581.    int j;
  582.    int k;
  583.    double sum;
  584.  
  585.    for(j=0;j<n_hidden_units;j++){
  586.       sum = 0.0;
  587.       for(k=0;k<n_output_units;k++){
  588.          sum = sum+(*(delta_out+k))*(*(weight+k*(n_hidden_units+1)+j));
  589.       }
  590.       *(delta_hid+j) = (*(unit_hid+j)) * (1-(*(unit_hid+j))) * sum;
  591.    }
  592. }
  593.  
  594. /****************************************************************************/
  595. /* calc_descent                                                             */
  596. /* this function calculates values for the descent matrix using the deltas  */
  597. /****************************************************************************/
  598.  
  599. void calc_descent(
  600.    int n_input_units,
  601.    int n_output_units,
  602.    double rate,
  603.    double moment,
  604.    double *descent,
  605.    double *delta,
  606.    double *unit_in
  607. )
  608. {
  609.    int i;
  610.    int j;
  611.  
  612.    for(j=0;j<n_output_units;j++){
  613.       for(i=0;i<n_input_units;i++){
  614.          *(descent+j*(n_input_units+1)+i) = rate*(*(delta+j))*(*(unit_in+i))+
  615.             moment*(*(descent+j*(n_input_units+1)+i));
  616.       }
  617.    /*note additional descents for bias weight*/
  618.    *(descent+j*(n_input_units+1)+n_input_units) = rate*(*(delta+j))+
  619.       moment*(*(descent+j*(n_input_units+1)+n_input_units));
  620.    }
  621. }
  622.  
  623. /****************************************************************************/
  624. /* correct_weight                                                           */
  625. /* this function updates the weight matrix using the latest descent values  */
  626. /****************************************************************************/
  627.  
  628. void correct_weight(
  629.    int n_input_units,
  630.    int n_output_units,
  631.    double *weight,
  632.    double *descent
  633. )
  634. {
  635.    int i;
  636.    int j;
  637.  
  638.    for(i=0;i<n_input_units+1;i++){ /*note additional descents for bias weight*/
  639.       for(j=0;j<n_output_units;j++){
  640.          *(weight+j*(n_input_units+1)+i) = *(weight+j*(n_input_units+1)+i) +
  641.          *(descent+j*(n_input_units+1)+i);
  642.       }
  643.    }
  644. }
  645.  
  646. /****************************************************************************/
  647. /* pattern_error                                                            */
  648. /* this function calculates error of outputs vs targets for output units    */
  649. /* error is output to a data file                                           */
  650. /****************************************************************************/
  651.  
  652. double pattern_error(
  653.    int n_pattern_vectors,
  654.    int n_output_units,
  655.    double *unit_target,
  656.    double *unit_out
  657. )
  658. {
  659.    int p;
  660.    int j;
  661.    double temp;
  662.  
  663.    temp = 0.0;
  664.    for (p=0;p<n_pattern_vectors;p++){
  665.       for(j=0;j<n_output_units;j++){
  666.          temp = temp + SQ((*(unit_target+(p*n_output_units)+j))-
  667.             (*(unit_out+p*n_output_units+j)));
  668.       }
  669.    }
  670.    return sqrt(temp);
  671. }
  672.  
  673.  
  674. /****************************************************************************/
  675. /* print_scale                                                              */
  676. /* this function prints a scale for displaying number of iterations of      */
  677. /* learning cycle if graphics mode is not selected                          */
  678. /****************************************************************************/
  679.  
  680. void print_scale()
  681. {
  682.    printf("\n        100       200       300       400");
  683.    printf("       500       600       700      800");
  684.    printf("---------+---------+---------+---------+");
  685.    printf("---------+---------+---------+---------+");
  686. }
  687.  
  688. /****************************************************************************/
  689. /* dump                                                                     */
  690. /* this function prints intermediate results to file bp3.dat                */
  691. /* call dump() from function learn()                                        */
  692. /****************************************************************************/
  693.  
  694. void dump(
  695.    int p
  696. )
  697. {
  698.    int i;
  699.    int j;
  700.    int k;
  701.  
  702.    if(fprintf(bp3,"\n")==NULL)
  703.       error(5,WARN);
  704.    if(fprintf(bp3,"input pattern no. %d\n",p)==NULL)
  705.       error(5,WARN);
  706.    if(fprintf(bp3,"\n")==NULL)
  707.       error(5,WARN);
  708.  
  709.    /* print input pattern */
  710.    if(fprintf(bp3,"input pattern\n")==NULL)
  711.       error(5,WARN);
  712.    for(i=0;i<n_input;i++){
  713.       if(fprintf(bp3,"%f  ",*(input+(p*n_input+i)))==NULL)
  714.          error(5,WARN);
  715.    }
  716.    if(fprintf(bp3,"\n")==NULL)
  717.       error(5,WARN);
  718.  
  719.    /* print hidden pattern */
  720.    if(fprintf(bp3,"hidden pattern\n")==NULL)
  721.       error(5,WARN);
  722.    for(j=0;j<n_hidden;j++){
  723.       if(fprintf(bp3,"%f  ",*(hidden+j))==NULL)
  724.          error(5,WARN);
  725.    }
  726.    if(fprintf(bp3,"\n")==NULL)
  727.       error(5,WARN);
  728.  
  729.    /* print deltas for hidden pattern */
  730.    if(fprintf(bp3,"deltas for hidden pattern\n")==NULL)
  731.       error(5,WARN);
  732.    for(j=0;j<n_hidden;j++){
  733.       if(fprintf(bp3,"%f  ",*(delta_h+j))==NULL)
  734.          error(5,WARN);
  735.    }
  736.    if(fprintf(bp3,"\n")==NULL)
  737.       error(5,WARN);
  738.  
  739.    /* print output pattern */
  740.    if(fprintf(bp3,"output pattern\n")==NULL)
  741.       error(5,WARN);
  742.    for(k=0;k<n_output;k++){
  743.       if(fprintf(bp3,"%f  ",*(output+p*n_output+k))==NULL)
  744.          error(5,WARN);
  745.    }
  746.    if(fprintf(bp3,"\n")==NULL)
  747.       error(5,WARN);
  748.  
  749.    /* print deltas for output pattern */
  750.    if(fprintf(bp3,"deltas for output pattern\n")==NULL)
  751.       error(5,WARN);
  752.    for(k=0;k<n_output;k++){
  753.       if(fprintf(bp3,"%f  ",*(delta_o+k))==NULL)
  754.          error(5,WARN);
  755.    }
  756.    if(fprintf(bp3,"\n")==NULL)
  757.       error(5,WARN);
  758.  
  759.    /* print target pattern */
  760.    if(fprintf(bp3,"target pattern\n")==NULL)
  761.       error(5,WARN);
  762.    for(k=0;k<n_output;k++){
  763.       if(fprintf(bp3,"%f  ",*(target+(p*n_output+k)))==NULL)
  764.          error(5,WARN);
  765.    }
  766.    if(fprintf(bp3,"\n")==NULL)
  767.       error(5,WARN);
  768.  
  769.    /* print weights to hidden units */
  770.    if(fprintf(bp3,"weights to hidden units\n")==NULL)
  771.       error(5,WARN);
  772.    for(j=0;j<n_hidden;j++){
  773.       for(i=0;i<n_input+1;i++){   /* note additional weights for bias unit */
  774.          if(fprintf(bp3,"%f  ",*(weight_h+j*(n_input+1)+i))==NULL)
  775.             error(5,WARN);
  776.       }
  777.    if(fprintf(bp3,"\n")==NULL)
  778.       error(5,WARN);
  779.    }
  780.  
  781.    /* print descents for weights to hidden units */
  782.    if(fprintf(bp3,"descents to weights to hidden units\n")==NULL)
  783.       error(5,WARN);
  784.    for(j=0;j<n_hidden;j++){
  785.       for(i=0;i<n_input+1;i++){   /* note additional descent for bias unit */
  786.          if(fprintf(bp3,"%f  ",*(descent_h+j*(n_input+1)+i))==NULL)
  787.             error(5,WARN);
  788.       }
  789.    if(fprintf(bp3,"\n")==NULL)
  790.       error(5,WARN);
  791.    }
  792.  
  793.    /* print weights to output units */
  794.    if(fprintf(bp3,"weights to output units\n")==NULL)
  795.       error(5,WARN);
  796.    for(k=0;k<n_output;k++){
  797.       for(j=0;j<n_hidden+1;j++){   /* note additional weights for bias unit */
  798.          if(fprintf(bp3,"%f  ",*(weight_o+k*(n_hidden+1)+j))==NULL)
  799.             error(5,WARN);
  800.       }
  801.    if(fprintf(bp3,"\n")==NULL)
  802.       error(5,WARN);
  803.    }
  804.  
  805.    /* print descents for weights to output units */
  806.    if(fprintf(bp3,"descents for weights to output units\n")==NULL)
  807.       error(5,WARN);
  808.    for(k=0;k<n_output;k++){
  809.       for(j=0;j<n_hidden+1;j++){  /* note additional descents for bias unit */
  810.          if(fprintf(bp3,"%f  ",*(descent_o+k*(n_hidden+1)+j))==NULL)
  811.             error(5,WARN);
  812.       }
  813.    if(fprintf(bp3,"\n")==NULL)
  814.       error(5,WARN);
  815.    }
  816. }
  817.  
  818. /****************************************************************************/
  819. /* get_seed                                                                 */
  820. /* this function asks for a number and uses it to seed rand()               */
  821. /****************************************************************************/
  822.  
  823. void get_seed()
  824. {
  825.    char buff[10];
  826.    long int s;
  827.  
  828.    printf("\n\rBack Propagation Generalised Delta Rule Learning Program\n\r");
  829.    printf("Enter seed for pseudorandom number generator.\n\rDefault = 1\n\rseed: ");
  830.    s = atol(gets(buff));
  831.    if(s<1) s = 1;
  832.    s_seed(s);
  833.    printf("\n\rseed = %ld",s);
  834. }
  835.  
  836. /****************************************************************************/
  837. /* get_limits                                                               */
  838. /* this function gets limits for d_rand                                     */
  839. /****************************************************************************/
  840.  
  841. void get_limits()
  842. {
  843.     double upper,lower;     /* limits for random number generator */
  844.  
  845.    char buff[10];
  846.  
  847.    printf("\n\rEnter range for weights.\n\r");
  848.    printf("defaults:\n\r   upper limit = 1.0\n\r   lower limit = -1.0\n\r");
  849.    printf("enter upper limit: ");
  850.    upper = atof(gets(buff));
  851.    printf("enter lower limit: ");
  852.    lower = atof(gets(buff));
  853.    if(lower>=upper){
  854.       printf("upper limit must be greater than lower limit; ");
  855.       printf("default selected\n\r");
  856.       upper = 1.0;
  857.       lower = -1.0;
  858.    }
  859.    printf("upper limit = %4.2f\n\rlower limit = %4.2f",upper,lower);
  860.     s_limits(upper,lower);
  861. }
  862.  
  863.