home *** CD-ROM | disk | FTP | other *** search
/ Nebula 1995 August / NEBULA.bin / SourceCode / Classes / Neural-Network / Neuron-HT.m < prev    next >
Encoding:
Text File  |  1992-01-14  |  5.0 KB  |  237 lines

  1. /*$Log:    Neuron.m,v $
  2. Revision 1.4  92/01/14  21:19:46  zazula
  3. Check in before starting HashTable mod
  4.  
  5. Revision 1.3  92/01/02  14:04:31  zazula
  6. Faster linked-list for connections
  7. No more Storage object
  8.  
  9. Revision 1.2  92/01/02  12:41:34  zazula
  10. Initial version - support for stochastic networks via temperature T
  11. */
  12. #import "Neuron.h"
  13. #import <appkit/nextstd.h>
  14. #import "math.h"
  15.  
  16.  
  17. //----------------------------------------------------------
  18.  
  19. @implementation Neuron
  20.  
  21. - inputs { return inputs; }
  22. - setType:(int)type { nodeType = type; return self; }
  23. - (int)getType { return nodeType; }
  24. - setTemp:(double)newT { T = newT; return self; }
  25. - (double)getTemp { return T; }
  26. - setRandom:theRandom { random = theRandom; return self; }
  27.  
  28. //----------------------------------------------------------
  29.  
  30. - (double)activation:(double)net
  31. {
  32.    double   temp;
  33.     
  34.     if(random == nil) random = [[Random alloc] init];
  35.    switch (nodeType) {
  36.    case Binary :
  37.       if(T > 0.0)
  38.          temp = ([random percent] <= 1.0/(1.0+exp(-2*net/T))) ? 1.0 : 0.0;
  39.       else
  40.          temp = (net > 0.5) ? 1.0 : 0.0;
  41.       break;
  42.    case Sigmoid : 
  43.       temp = 1.0/(1.0+exp(-net));
  44.       break;
  45.    case Sign : 
  46.       if(T > 0.0)
  47.          temp = ([random percent] <= 1.0/(1.0+exp(-2*net/T))) ? 1.0 : -1.0;
  48.       else
  49.          temp = (net > 0.0) ? 1.0 : -1.0;
  50.       break;
  51.    case Tanh :
  52.         if(T > 0.0)
  53.             temp = tanh(net/T);
  54.         else
  55.           temp = tanh(net);
  56.       break;
  57.    }
  58.    
  59.    return temp;
  60. }
  61.  
  62. //----------------------------------------------------------
  63.  
  64. - init
  65. {
  66.    [super init];
  67.    lastOutput = 0.0;
  68.    nodeType = Sigmoid;        // default node type
  69.    T = 0.0;                   // default temperature
  70.       head = tail = NULL;            // initialize the linked-list of connections
  71.     ht = [[HashTable alloc] initKeyDesc:"@" valueDesc:"i"];
  72.     
  73.    return self;
  74. }
  75.  
  76. //-----------------------------------------------------------
  77.  
  78. - step
  79. // update the output value based on our inputs
  80. {
  81.    int i = 0;
  82. //   connection *C;
  83.    double temp=0.0;     // use temp variable to allow for feedback
  84.     NXHashState state = [ht initState];
  85.     double *weight;
  86.     id         source;
  87.        
  88. //    C = head;
  89. //    while(C != NULL) {
  90. //        temp += C->weight*[C->source lastOutput];
  91. //        C = (connection *)C->next;
  92. //    }
  93.     while([ht nextState:&state key:&source value:&weight]) {
  94.         temp += *weight * [source lastOutput];
  95.     }
  96.         
  97.    lastOutput = [self activation:temp];
  98.    
  99.    return self;
  100. }
  101.  
  102. //-----------------------------------------------------------
  103.  
  104. - (double)lastOutput
  105. {
  106.    return lastOutput;
  107. }
  108.  
  109. //-----------------------------------------------------------
  110.  
  111. - connect:sender
  112. {
  113.     if(random == nil) random = [[Random alloc] init];
  114.    return [self connect:sender withWeight:[random percent]/10.0];
  115. }
  116.  
  117. //-----------------------------------------------------------
  118.  
  119. - connect:sender withWeight:(double)weight
  120. //
  121. // adds sender to the list of inputs
  122. // we should check to make sure sender is a Neruon
  123. // also need to check if it is already in the list
  124. //
  125. {
  126. /*
  127.    connection *C;
  128.    
  129.    C = (connection *)malloc(sizeof(connection));
  130.     if(head == NULL) {
  131.         head = C;
  132.     }
  133.     else {
  134.         tail->next = C;
  135.     }
  136.     tail = C;
  137.    C->source = sender;
  138.    C->weight = weight;
  139.     C->next   = NULL;
  140. */
  141.     double    *value = malloc(sizeof(double));
  142.     *value = weight;
  143.     [ht insertKey:sender value:value];
  144.           
  145.    return self;
  146. }
  147.  
  148. //-----------------------------------------------------------
  149.  
  150. - (double)getWeightFor:source
  151. {
  152. /*
  153.    int i=0;
  154.    connection *C;
  155.  
  156.     C = head;
  157.     while((C != NULL) && (C->source != source))
  158.         C = (connection *)C->next;
  159.         
  160.    if(C != NULL) {            // if C==NULL, source isn't an input
  161.       return C->weight;
  162.    }
  163.    else {
  164.       fprintf(stderr,"connection not found in getWeightFor:\n");
  165.       return NAN;
  166.    }
  167. */
  168.     return *(double *)[ht valueForKey:source];    
  169. }
  170.  
  171. //-----------------------------------------------------------
  172.  
  173. - setWeightFor:source to:(double)weight
  174. {
  175. /*
  176.    int i=0;
  177.    connection *C;
  178.    
  179.     C = head;
  180.     while((C != NULL) && (C->source != source))
  181.         C = (connection *)C->next;
  182.         
  183.    if(C != NULL) {            // if C==NULL, source isn't an input
  184.       C->weight = weight;
  185.       return self;
  186.    }
  187.    else {
  188.       fprintf(stderr,"connection not found in setWeightFor:to:\n");
  189.       return nil;
  190.    }
  191. */
  192.     *(double *)[ht valueForKey:source] = weight;
  193.     return self;
  194. }
  195.  
  196. //-----------------------------------------------------------
  197.  
  198. - setOutput:(double)output
  199. {
  200.    lastOutput = output;
  201.    
  202.    return self;
  203. }
  204. //-----------------------------------------------------------
  205.  
  206. - changeWeightFor:source by:(double)delta
  207. {
  208. /*
  209.    int i=0;
  210.    connection *C;
  211.    
  212.     C = head;
  213.     while((C != NULL) && (C->source != source))
  214.         C = (connection *)C->next;
  215.         
  216.    if(C != NULL) {            // if C==NULL, source isn't an input
  217.       C->weight += delta;
  218.         if(SYMMETRIC) // for symmetric connections
  219.            [source setWeightFor:self to:C->weight];
  220.       return self;
  221.    }
  222.    else {
  223.       fprintf(stderr,"connection not found in changeWeightfor:by:\n");
  224. //      printf("connection not found in changeWeightfor:by:\n");
  225.       return nil;
  226.    }
  227. */
  228.     *(double *)[ht valueForKey:source] += delta;
  229.     [source setWeightFor:self to:[self getWeightFor:source]];
  230.     
  231.     return self;
  232.     
  233. }
  234.  
  235.  
  236. @end
  237.