home *** CD-ROM | disk | FTP | other *** search
/ Nebula 1995 August / NEBULA.mdf / SourceCode / Classes / Neural-Network / weighted_matching.m < prev    next >
Encoding:
Text File  |  1992-07-29  |  5.1 KB  |  217 lines

  1. /* =======================================================
  2.     Neural Network Classes for the NeXT Computer
  3.     Written by: Ralph Zazula
  4.                     University of Arizona - Fall 1991
  5.                     zazula@pri.com (NeXT Mail)
  6. ==========================================================*/
  7. #import "Neuron.h"
  8. #import "Random.h"
  9. #import <math.h>
  10.  
  11. #define    Np    16                    // the number of points in the matching problem
  12. #define     N    Np*(Np-1)/2        // the number of neurons required to represent Np
  13.  
  14. id            random;                // random number generator instance
  15. id            nList;                // list of the neurons
  16. double    gamma = 1.0;        // "gamma" factor
  17. double    alpha    = 0.995;        // temperature decay rate
  18. double    p[Np][2];            // array to store the positions the points
  19. double    d[Np][Np];            // array to store the distances between the points
  20. double    T = 1.0;               // the current temperature
  21.  
  22. //===========================================================
  23. // function to return the index (from 0->N) of the neuron
  24. // that represents the connection from point i->j
  25. //
  26. int index(i,j)
  27. int    i,j;
  28. {
  29.     int    k;
  30.     
  31.     if(i>j) {                // swap them if reversed
  32.         k = j;
  33.         j = i;
  34.         i = k;
  35.     }
  36.     return (int)(((float)Np - 3.0/2.0)*(float)i - (float)(i*i)/2.0 + j - 1);
  37. }
  38.  
  39. //==========================================================
  40. // function to return the distance between the points i and j
  41. //
  42. double dist(i,j)
  43. int    i,j;
  44. {
  45.     return    sqrt((p[i][0] - p[j][0])*(p[i][0] - p[j][0]) +
  46.                       (p[i][1] - p[j][1])*(p[i][1] - p[j][1]));
  47. }
  48.  
  49. //===========================================================
  50. // function to return the current value of the energy function
  51. //
  52. double H()
  53. {
  54.     double    h=0.0, x;
  55.     
  56.     int    i,j;
  57.  
  58.     
  59.     for(i=0; i<Np; i++) {
  60.         x = 0.0;
  61.         for(j=0; j<Np; j++)
  62.             if(i!=j)
  63.                 x += [[nList objectAt:index(i,j)] lastOutput];
  64.         h += (gamma/2.0)*(1 - x)*(1 - x);
  65.     }
  66.  
  67.     for(j=1; j<Np; j++)
  68.         for(i=0; i<j; i++)
  69.             h += d[i][j]*[[nList objectAt:index(i,j)] lastOutput];
  70.     
  71.     return h;
  72. }    
  73.  
  74. //==========================================================
  75. // function to return whether or not this node has flipped
  76. // given the current temperature and the effect of the 
  77. // flip on the energy funcion.
  78. //
  79. BOOL flip(n)
  80. int    n;
  81. {
  82.     double    H1,H2,dH;
  83.     
  84.     //
  85.     // get current energy-function, flip the node, get new energy-function
  86.     // calculate dH
  87.     //
  88.     H1 = H();
  89.     if([[nList objectAt:n] lastOutput])
  90.         [[nList objectAt:n] setOutput:0];
  91.     else 
  92.         [[nList objectAt:n] setOutput:1];
  93.                             
  94.     H2 = H();
  95.     dH = H2 - H1;
  96.     
  97.     // 
  98.     // flip back if the probablility was too low to flip this node
  99.     // keep it flipped otherwise
  100.     //
  101.     if([random percent] > 1/(1 + exp(dH/T)))
  102.         if([[nList objectAt:n] lastOutput])
  103.             [[nList objectAt:n] setOutput:0];
  104.         else
  105.             [[nList objectAt:n] setOutput:1];
  106.  
  107. }
  108.  
  109. //===========================================================
  110. // function to return the current length of all connections
  111. //
  112. double L()
  113. {
  114.     int    i,j;
  115.     double l = 0.0;
  116.     
  117.     for(j=1; j<Np; j++)
  118.         for(i=0; i<j; i++) 
  119.             l += d[i][j]*[[nList objectAt:index(i,j)] lastOutput];
  120.             
  121.     return l;
  122. }
  123.  
  124. //===========================================================
  125.  
  126. void main()
  127. {
  128.     int    i,j,count,n;
  129.     
  130.     nList = [[List alloc] init];                // list of neurons
  131.     random = [[Random alloc] init];            // get a random number generator
  132.  
  133.     [random setSeeds:5335:7777:32197];
  134.     //
  135.     // generate Np random points
  136.     //
  137.     printf("generating points\n");
  138.     for(i=0; i<Np; i++) {
  139.         p[i][0] = [random percent];
  140.         p[i][1] = [random percent];
  141.     }
  142.     // 
  143.     // fill-in the distance array
  144.     //
  145.     printf("calculating distances\n");
  146.     for(j=1; j<Np; j++)
  147.         for(i=0; i<j; i++)
  148.             d[i][j] = 
  149.             d[j][i] = 
  150.              sqrt((p[i][0] - p[j][0])*(p[i][0] - p[j][0]) +
  151.                     (p[i][1] - p[j][1])*(p[i][1] - p[j][1]));
  152.     //
  153.     // generate Neurons
  154.     //
  155.     printf("generating neurons\n");
  156.     for(i=0; i<N; i++) {
  157.         [nList addObject:[[[Neuron alloc] init] setType:Binary]];
  158.         [[nList lastObject] setOutput:[random randMax:1]];
  159. //        printf("%f\n",[[nList lastObject] lastOutput]);
  160.     }
  161.     //
  162.     // make initial connections
  163.     //
  164.     for(j=1; j<N; j++)
  165.         for(i=0; i<j; i++) {
  166.             [[nList objectAt:i] connect:[nList objectAt:j] withWeight:-gamma];
  167.             [[nList objectAt:j] connect:[nList objectAt:i] withWeight:-gamma];
  168. //            [[nList lastObject] step];
  169.         }
  170.  
  171. /*
  172.     for(i=0; i<Np; i++)
  173.         printf("%u %f %f\n",i,p[i][0],p[i][1]);
  174.     
  175.     for(i=0; i<Np; i++)
  176.         for(j=i+1; j<Np; j++)
  177.             printf("%u->%u: %e\n",i,j,d[i][j]);
  178. */
  179.  
  180.     // start the network
  181.     printf("starting total connection length: %f\n",L());
  182.     printf("T         L         Energy     #of connections\n");
  183.     for(count=0; count<200*N; count++)
  184.     {
  185.         //
  186.         // pick a random node to update
  187.         //
  188.         n = [random randMax:N];
  189.         [[nList objectAt:n] step];
  190.         flip(n);
  191.         if(!(count % N))    T = alpha*T;
  192.         if(!(count % 100)) {
  193.             n = 0;
  194.             for(i=0; i<N; i++)
  195.                 if([[nList objectAt:i] lastOutput]) n++;
  196. //            printf("count: %u  length:  %f  Temp: %f  Connections: %u\n",
  197. //                count,L(),T,n);
  198.             printf("%f %f %f %u\n",T,L(),H(),n);
  199.         }
  200.     }
  201.     printf("points\n-----\n");
  202.     for(i=0; i<Np; i++)
  203.         printf("%u %f %f\n",i,p[i][0],p[i][1]);
  204.  
  205.     printf("distances\n-----\n");
  206.     for(i=0; i<Np; i++)
  207.         for(j=i+1; j<Np; j++)
  208.                 printf("%u->%u: %e\n",i,j,d[i][j]);
  209.  
  210.     printf("connections\n-----\n");
  211.     for(i=0; i<Np; i++)
  212.         for(j=i+1; j<Np; j++)
  213.             if([[nList objectAt:index(i,j)] lastOutput])
  214.                 printf("%u->%u: %e\n",i,j,d[i][j]);
  215.  
  216. }
  217.