home *** CD-ROM | disk | FTP | other *** search
/ Nebula 1995 August / NEBULA.mdf / SourceCode / Classes / Neural-Network / BackPropEngine.m < prev    next >
Encoding:
Text File  |  1992-07-29  |  3.0 KB  |  133 lines

  1. /* =======================================================
  2.     Neural Network Classes for the NeXT Computer
  3.     Written by: Ralph Zazula
  4.                     University of Arizona - Fall 1991
  5.                     zazula@pri.com (NeXT Mail)
  6. ==========================================================*/
  7. #import "BackPropEngine.h"
  8.  
  9.    double *error;
  10.    
  11. @implementation BackPropEngine
  12.  
  13. - inputs { return inputs; }
  14. - hidden { return hidden; }
  15. - outputs { return outputs; }
  16. - (double)getEta { return ETA; }
  17. - setEta:(double)newEta { ETA = newEta; return self;}
  18. - (double)getError { return Error; }
  19.  
  20. - init
  21. {
  22.    [super init];
  23.    
  24.    inputs = [[List alloc] init];
  25.    hidden = [[List alloc] init];
  26.    outputs = [[List alloc] init];
  27.    random = [[Random alloc] init];
  28.    
  29.    ETA = 0.9;
  30.    ALPHA = 0.9;
  31.    Error = 0.0;
  32.    
  33.    return self;
  34. }
  35.  
  36. - initWithInputs:(int)Nin hidden:(int)Nhid outputs:(int)Nout
  37. {
  38.    int i,j;
  39.    
  40.    [self init];
  41.    
  42.    //
  43.    // create the nodes
  44.    //
  45.    for(i=0; i<Nin; i++)
  46.       [inputs addObject:[[[Neuron alloc] init] 
  47.                         setRandom:random]];
  48.    for(i=0; i<Nhid; i++)
  49.       [hidden addObject:[[[Neuron alloc] init] 
  50.                         setRandom:random]];   
  51.    for(i=0; i<Nout; i++)
  52.       [outputs addObject:[[[Neuron alloc] init] 
  53.                         setRandom:random]];
  54.       
  55.    //
  56.    // make the connections
  57.    //
  58.    for(j=0; j<Nhid; j++)
  59.       for(i=0; i<Nin; i++) 
  60.          [[hidden objectAt:j] connect:[inputs objectAt:i]];
  61.    for(j=0; j<Nout; j++)
  62.       for(i=0; i<Nhid; i++)   
  63.          [[outputs objectAt:j] connect:[hidden objectAt:i]];
  64.    //
  65.    // allocate other variables
  66.    //
  67.    error = (double *)malloc((Nout+1)*sizeof(double));
  68.    
  69.    return self;
  70. }
  71.  
  72. - applyInput:(double *)input
  73. {
  74.    int   i;
  75.    
  76.    for(i=0; i<[inputs count]; i++) 
  77.       [[inputs objectAt:i] setOutput:(double)input[i]];
  78.    
  79.    for(i=0; i<[hidden count]; i++)
  80.       [[hidden objectAt:i] step];
  81.       
  82.    for(i=0; i<[outputs count]; i++)
  83.       [[outputs objectAt:i] step];
  84.  
  85.    return self;
  86. }
  87.  
  88. - correctWithTarget:(double *)target
  89. {
  90.    int i,j,k;
  91.    id oNode, hNode, iNode;
  92.    double O, delta;
  93.    
  94.    //
  95.    // correct the hidden->output weights
  96.    // and update the error
  97.    //
  98.    Error=0.0;
  99.    for(i=0; i<[outputs count]; i++) {
  100.       oNode = [outputs objectAt:i];
  101.       O = [oNode lastOutput];
  102.       error[i] = target[i] - O;
  103.       Error += error[i]*error[i]/2.0;
  104.       for(j=0; j<[hidden count]; j++) {
  105.          hNode = [hidden objectAt:j];
  106.          [oNode changeWeightFor:hNode 
  107.                 by:error[i]*[hNode lastOutput]*ETA*O*(1-O)];
  108.       }
  109.    }
  110.    //
  111.    // correct the input->hidden weights
  112.    //
  113.    for(k=0; k<[inputs count]; k++) {
  114.       iNode = [inputs objectAt:k];
  115.       for(j=0; j<[hidden count]; j++) {
  116.          hNode = [hidden objectAt:j];
  117.          delta = 0.0;
  118.          for(i=0; i<[outputs count]; i++) {
  119.             oNode = [outputs objectAt:i];
  120.             O = [oNode lastOutput];
  121.             delta +=O*(1-O)*error[i]*[oNode getWeightFor:hNode];
  122.          }
  123.          [hNode changeWeightFor:iNode 
  124.                 by:delta*ETA*[iNode lastOutput]*
  125.                 [hNode lastOutput]*(1-[hNode lastOutput])];
  126.       }
  127.    }
  128.    
  129.    return self;
  130. }
  131.  
  132. @end
  133.