home *** CD-ROM | disk | FTP | other *** search
/ ARM Club 3 / TheARMClub_PDCD3.iso / hensa / misc / b186_1 / Source / c / pa < prev    next >
Text File  |  1987-12-23  |  14KB  |  578 lines

  1. /*
  2.  
  3.        This file is part of the PDP software package.
  4.          
  5.        Copyright 1987 by James L. McClelland and David E. Rumelhart.
  6.        
  7.        Please refer to licensing information in the file license.txt,
  8.        which is in the same directory with this source file and is
  9.        included here by reference.
  10. */
  11.  
  12.  
  13. /* file: pa.c
  14.  
  15.     Do the actual work for the pa program.
  16.     
  17.     First version implemented by Elliot Jaffe.
  18.     
  19.     Date of last revision:  8-12-87/JLM.
  20. */
  21.  
  22. #include "general.h"
  23. #include "pa.h"
  24. #include "variable.h"
  25. #include "weights.h"
  26. #include "patterns.h"
  27. #include "command.h"
  28. #include <math.h>
  29.  
  30. char   *Prompt = "pa: ";
  31. boolean System_Defined = FALSE;
  32. char   *Default_step_string = "epoch";
  33. boolean lflag = 1;
  34. boolean linear = 0;
  35. boolean    lt = 0;
  36. boolean cs = 0;
  37. boolean hebb = 0;
  38. int     epochno = 0;
  39. int     nepochs = 500;
  40. int     patno = 0;
  41. float    ndp = 0;
  42. float    nvl = 0;
  43. float    vcor = 0;
  44. float   tss = 0.0;
  45. float   pss = 0.0;
  46. float   ecrit = 0.0;
  47. float  *netinput = NULL;
  48. float  *output = NULL;
  49. float  *error = NULL;
  50. float  *input = NULL;
  51. float  *target = NULL;
  52. float    noise = 0;
  53. float   temp = 15.0;
  54. int    tallflag = 0;
  55.  
  56.  
  57. extern int read_weights();
  58. extern int write_weights();
  59.  
  60. float *
  61. readvec(pstr,len) char *pstr; int len; {
  62.     int j;
  63.     float *tvec;
  64.     char *str;
  65.     char tstr[60];
  66.     
  67.     if (pstr == NULL) {
  68.         tvec = (float *) emalloc((unsigned)(sizeof(float)*len));
  69.     for (j = 0; j < len; j++) {
  70.         tvec[j] = 0.0;
  71.     }
  72.     return(tvec);
  73.     }
  74.     sprintf(tstr,"give %selements:  ",pstr);
  75.     tvec = (float *) emalloc((unsigned)(sizeof(float)*len));
  76.     for (j = 0; j < len; j++) {
  77.     tvec[j] = 0.0;
  78.     }
  79.     for (j = 0; j <= len; j++) {
  80.         str = get_command(tstr);
  81.     if (str == NULL || strcmp(str,"end") == 0) {
  82.         if (j) return(tvec); else return (NULL);
  83.     }
  84.     if (strcmp("+",str) == 0) tvec[j] = 1.0;
  85.     else if (strcmp("-",str) == 0) tvec[j] = -1.0;
  86.     else if (strcmp(".",str) == 0) tvec[j] = 0.0;
  87.     else sscanf(str,"%f",&tvec[j]);
  88.     }
  89.     return(tvec);
  90. }
  91.  
  92. float *
  93. get_vec() {
  94.     char * str;
  95.     int j;
  96.     str = 
  97.       get_command("vector (iN for ipattern, tN for tpattern, E for enter): ");
  98.     if (str == NULL) return(NULL);
  99.     if(*str == 'i') {
  100.     if((patno = get_pattern_number(++str)) < 0) {
  101.         put_error("Invalid pattern specification.");
  102.         return(NULL);
  103.     }
  104.         return(ipattern[patno]);
  105.     }
  106.     else if(*str == 't') {
  107.     if((patno = get_pattern_number(++str)) < 0) {
  108.         put_error("Invalid pattern specification.");
  109.         return(NULL);
  110.     }
  111.         return(tpattern[patno]);
  112.     }
  113.     else return(readvec(" ",nunits));
  114. }
  115.  
  116. float
  117. dotprod(v1,v2,len) float *v1, *v2; int len; {
  118.     register int i;
  119.     double dp = 0;
  120.     double denom;
  121.     denom = (double) len;
  122.     if (denom == 0) return(0.0);
  123.     for (i = 0; i < len; i++,v1++,v2++) {
  124.         dp += (double) ((*v1)*(*v2));
  125.     }
  126.     dp /= denom;
  127.     return(dp);
  128. }
  129.  
  130. float
  131. sumsquares(v1,v2,len) float *v1, *v2; int len; {
  132.     register int i;
  133.     double ss = 0;
  134.  
  135.     for (i = 0; i < len; i++,v1++,v2++) {
  136.         ss += (double)((*v1 - *v2) * (*v1 - *v2));
  137.     }
  138.     return(ss);
  139. }
  140.  
  141. /* the following function computes the vector correlation, or the
  142.    cosine of the angle between v1 and v2 */
  143.  
  144. float
  145. veccor(v1,v2,len) float *v1, *v2; int len; {
  146.     register int i;
  147.     double denom;
  148.     double dp = 0.0;
  149.     double l1 = 0.0;
  150.     double l2 = 0.0;
  151.  
  152.     for (i = 0; i < len; i++,v1++,v2++) {
  153.         dp += (double) (*v1)*(*v2);
  154.         l1 += (double) (*v1)*(*v1);
  155.         l2 += (double) (*v2)*(*v2);
  156.     }
  157.     if (l1 == 0.0 || l2 == 0.0) return (0.0);
  158.     dp /= sqrt(l1*l2);
  159.     return(dp);
  160. }
  161.  
  162. float
  163. veclen(v,len) float *v; int len; {
  164.     int i;
  165.     double denom;
  166.     double vl = 0;
  167.     denom = (double) len;
  168.     if (denom == 0) {
  169.         return(0.0);
  170.     }
  171.     for (i = 0; i < len; i++,v++) {
  172.         vl += (*v)*(*v)/denom;
  173.     }
  174.     vl = sqrt((vl));
  175.     return(vl);
  176. }
  177.  
  178. distort(vect,pattern,len,amount) 
  179. float *vect;
  180. float *pattern;
  181. int len;
  182. float   amount;
  183. {
  184.     int    i;
  185.     float   rval,val;
  186.  
  187.     for (i = 0; i < len; i++) {
  188.     rval = (float) (1.0 - 2.0*rnd());
  189.     *vect++ = *pattern++ + rval*amount;
  190.     }
  191. }
  192.  
  193. init_system() {
  194.     int     strain (), ptrain (), tall (), get_unames(),
  195.             test_pattern (), reset_weights(),newstart();
  196.     int change_lrate();
  197.  
  198.     lrate = 2.0;
  199.     epsilon_menu = NOMENU;
  200.     (void) install_var("lflag", Int,(int *) & lflag, 0, 0, SETPCMENU);
  201.  
  202.     (void) install_command("strain", strain, BASEMENU,(int *) NULL);
  203.     (void) install_command("ptrain", ptrain, BASEMENU,(int *) NULL);
  204.     (void) install_command("tall", tall, BASEMENU,(int *) NULL);
  205.     (void) install_command("test", test_pattern, BASEMENU,(int *) NULL);
  206.     (void) install_command("reset",reset_weights,BASEMENU,(int *)NULL);
  207.     (void) install_command("newstart",newstart,BASEMENU,(int *)NULL);
  208.     (void) install_command("patterns", get_pattern_pairs, 
  209.                            GETMENU,(int *) NULL);
  210.     (void) install_command("unames", get_unames, GETMENU,(int *) NULL);
  211.     (void) install_var("nepochs", Int,(int *) & nepochs, 0, 0, SETPCMENU);
  212.     (void) install_command("lrate", change_lrate, SETPARAMMENU, (int *) NULL);
  213.     (void) install_var("lrate", Float,(int *) & lrate, 0, 0, NOMENU);
  214.     (void) install_var("ecrit", Float, (int *)& ecrit,0,0,SETPCMENU);
  215.     (void) install_var("noise", Float, (int *)&noise,0,0,SETPARAMMENU);
  216.     (void) install_var("linear", Int,(int *) &linear,0,0,SETMODEMENU);
  217.     (void) install_var("temp", Float, (int *)&temp,0,0,SETPARAMMENU);
  218.     (void) install_var("lt", Int,(int *) <,0,0,SETMODEMENU);
  219.     (void) install_var("cs", Int,(int *) &cs,0,0,SETMODEMENU);
  220.     (void) install_var("hebb", Int,(int *) &hebb,0,0,SETMODEMENU);
  221.     (void) install_var("epochno", Int,(int *) & epochno, 0, 0, SETSVMENU);
  222.     (void) install_var("patno", Int,(int *) & patno, 0, 0, SETSVMENU);
  223.     init_pattern_pairs();
  224.     (void) install_var("tss", Float,(int *) & tss, 0, 0, SETSVMENU);
  225.     (void) install_var("pss", Float,(int *) & pss, 0, 0, SETSVMENU);
  226.     (void) install_var("ndp", Float,(int *) & ndp, 0, 0, SETSVMENU);
  227.     (void) install_var("vcor", Float,(int *) & vcor, 0, 0, SETSVMENU);
  228.     (void) install_var("nvl", Float,(int *) & nvl, 0, 0, SETSVMENU);
  229.     init_weights();
  230. }
  231.  
  232. define_system() {
  233.     register int    i,j;
  234.  
  235.     if (!nunits) {
  236.     put_error("cannot init pa system, nunits not defined");
  237.     return(FALSE);
  238.     }
  239.     else
  240.     if (!noutputs) {
  241.         put_error("cannot init pa system, noutputs not defined");
  242.         return(FALSE);
  243.     }
  244.     else
  245.     if (!ninputs) {
  246.         put_error("cannot init pa system, ninputs not defined");
  247.         return(FALSE);
  248.     }
  249.     netinput = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  250.     (void) install_var("netinput", Vfloat,(int *) netinput,
  251.         nunits, 0, SETSVMENU);
  252.     for (i = 0; i < nunits; i++)
  253.     netinput[i] = 0.0;
  254.  
  255.     output = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  256.     (void) install_var("output", Vfloat,(int *) & output[ninputs],
  257.         noutputs, 0, SETSVMENU);
  258.     for (i = 0; i < nunits; i++)
  259.     output[i] = 0.0;
  260.  
  261.     error = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  262.     (void) install_var("error", Vfloat,(int *) & error[ninputs], 
  263.                 noutputs, 0, SETSVMENU);
  264.     for (i = 0; i < nunits; i++)
  265.     error[i] = 0.0;
  266.  
  267.     target = (float *) emalloc((unsigned)(sizeof(float) * noutputs));
  268.     (void) install_var("target", Vfloat,(int *) target, noutputs, 0,
  269.                SETSVMENU);
  270.     for (i = 0; i < noutputs; i++)
  271.     target[i] = 0.0;
  272.  
  273.     input = (float *) emalloc((unsigned)(sizeof(float) * ninputs));
  274.     (void) install_var("input", Vfloat,(int *) input, ninputs, 0, SETSVMENU);
  275.     
  276.     for (i = 0; i < ninputs; i++)
  277.     input[i] = 0.0;
  278.  
  279.     System_Defined = TRUE;
  280.     return(TRUE);
  281. }
  282.  
  283.  
  284. float  logistic (x)
  285. float  x;
  286. {
  287.     x /= temp;
  288.     if (x > 11.5129)
  289.     return(.99999);
  290.       else
  291.     if (x < -11.5129)
  292.         return(.00001);
  293.     else
  294.        return(1.0 / (1.0 + (float) exp( (double) ((-1.0) * x))));
  295. }
  296.  
  297. probability(val)
  298. float  val;
  299. {
  300.     return((rnd() < val) ? 1 : 0);
  301. }
  302.  
  303.  
  304. compute_output() {
  305.     register int    i,j,sender,num;
  306.  
  307.     for (i = ninputs; i < nunits; i++) {/* ranges over output units */
  308.     netinput[i] = bias[i];
  309.     sender = first_weight_to[i];
  310.     num = num_weights_to[i];
  311.     for (j = 0; j < num; j++) { /* ranges over input units */
  312.         netinput[i] += output[sender++]*weight[i][j];
  313.     }
  314.     if (linear) {
  315.       output[i] = netinput[i];
  316.     }
  317.     else if (lt) {
  318.       output[i] = (float) (netinput[i] > 0 ? 1.0 : 0.0 );
  319.     }
  320.     else if    (cs) {
  321.       output[i] =  logistic(netinput[i]);
  322.     }
  323.     else { /* default, stochastic mode */
  324.       output[i] = (float)probability((float)logistic(netinput[i]));
  325.     }
  326.     }
  327. }
  328.  
  329. compute_error() {
  330.     register int    i,j;
  331.  
  332.     for (i = ninputs, j = 0; i < nunits; j++, i++) {
  333.     error[i] = target[j] - output[i];
  334.     }
  335. }
  336.  
  337. change_weights() {
  338.     register int    i,j,ti,sender,num;
  339.  
  340.     if (hebb) {
  341.       for (i = ninputs,ti = 0; i < nunits; i++,ti++) {
  342.         output[i] = target[ti];
  343.     sender = first_weight_to[i];
  344.     num = num_weights_to[i];
  345.     for (j = 0; j < num; j++) {
  346.          weight[i][j] +=
  347.             epsilon[i][j]*output[i]*output[sender++];
  348.     }
  349.     bias[i] += bepsilon[i]*output[i];
  350.       }
  351.     }
  352.     else { /* delta rule, by default */
  353.       for (i = ninputs; i < nunits; i++) {
  354.     sender = first_weight_to[i];
  355.     num = num_weights_to[i];
  356.     for (j = 0; j < num; j++) {
  357.          weight[i][j] +=
  358.             epsilon[i][j]*error[i]*output[sender++];
  359.     }
  360.     bias[i] += bepsilon[i]*error[i];
  361.       }
  362.     }
  363. }
  364.  
  365. constrain_weights() {
  366. }
  367.  
  368. setinput() {
  369.     register int    i;
  370.  
  371.     for (i = 0; i < ninputs; i++) {
  372.         output[i] = input[i];
  373.     }
  374.     if (patno < 0) cpname[0] = '\0';
  375.     else strcpy(cpname,pname[patno]);
  376. }
  377.  
  378. trial() {
  379.     setinput();
  380.     compute_output();
  381.     compute_error();
  382.     sumstats();
  383. }
  384.  
  385. sumstats() {
  386.  
  387.     pss  =  (float) sumsquares(target,&output[ninputs],noutputs);
  388.     vcor =  (float) veccor(target,&output[ninputs],noutputs);
  389.     nvl  =  (float) veclen(&output[ninputs],noutputs);
  390.     ndp  =  (float) dotprod(target,&output[ninputs],noutputs);
  391.     tss += pss;
  392. }
  393.  
  394. ptrain() {
  395.   train('p');
  396. }
  397.  
  398. strain() {
  399.   train('s');
  400. }
  401.  
  402. train(c) char c; {
  403.     int     t,i,old,npat;
  404.     char    *str;
  405.  
  406.     if (!System_Defined)
  407.     if (!define_system())
  408.         return;
  409.  
  410.     for (t = 0; t < nepochs; t++) {
  411.     if (!tallflag) epochno++;
  412.     for (i = 0; i < npatterns; i++)
  413.         used[i] = i;
  414.     if (c == 'p') {
  415.       for (i = 0; i < npatterns; i++) {
  416.         npat = rnd() * (npatterns - i) + i;
  417.         old = used[i];
  418.         used[i] = used[npat];
  419.         used[npat] = old;
  420.       }
  421.     }
  422.     tss = 0.0;
  423.     for (i = 0; i < npatterns; i++) {
  424.         if (Interrupt) {
  425.         Interrupt_flag = 0;
  426.         update_display();
  427.         if (contin_test() == BREAK) return(BREAK);
  428.         }
  429.         patno = used[i];
  430.         distort(input,ipattern[patno],ninputs,noise);
  431.         distort(target,tpattern[patno],noutputs,noise);
  432.         trial();
  433.         /* the && lflag insures that we do not get a redundant
  434.            display update if change_weights is not going to be
  435.            called */
  436.         if (step_size == CYCLE && lflag) {
  437.         update_display();
  438.             if (single_flag) {
  439.            if (contin_test() == BREAK) return(BREAK);
  440.         }
  441.         }
  442.         if (lflag) change_weights();
  443.         if (step_size <= PATTERN) {
  444.           update_display();
  445.           if (single_flag) {
  446.         if (contin_test() == BREAK) return(BREAK);
  447.           }
  448.         }
  449.     }
  450.     if (step_size == EPOCH) {
  451.      update_display();
  452.      if (single_flag) {
  453.         if (contin_test() == BREAK) return(BREAK);
  454.      }
  455.         }
  456.     if (tss < ecrit)
  457.         break;
  458.     }
  459.     if (step_size == NEPOCHS) {
  460.     update_display();
  461.     }
  462.     return(CONTINUE);
  463. }
  464.  
  465. tall() {
  466.   int save_lflag;
  467.   int save_single_flag;
  468.   int save_nepochs;
  469.   int save_step_size;
  470.   
  471.   save_lflag = lflag;  lflag = 0;
  472.   save_single_flag = single_flag; 
  473.   if (in_stream == stdin) single_flag = 1;
  474.   save_nepochs = nepochs;  nepochs = 1;
  475.   save_step_size = step_size; if (step_size > PATTERN) step_size = PATTERN;
  476.   tallflag = 1;
  477.   train('s');
  478.   tallflag = 0;
  479.   lflag = save_lflag;
  480.   nepochs = save_nepochs;
  481.   single_flag = save_single_flag;
  482.   step_size = save_step_size;
  483. }
  484.   
  485. test_pattern() {
  486.     char   *str;
  487.     float *ivec, *tvec;
  488.     float tmp_noise;
  489.  
  490.     if(! System_Defined)
  491.       if(! define_system())
  492.        return(CONTINUE);
  493.  
  494.     str = get_command("input (#N, ?N, E for enter): ");
  495.     if (str == NULL) return(CONTINUE);
  496.     if(*str == '#' || *str == '?') {
  497.     if((patno = get_pattern_number(str+1)) < 0) {
  498.        return(put_error("Invalid pattern specification."));
  499.     }
  500.     tmp_noise = (float) (*str = '#' ? 0.0 : noise );
  501.         distort(input, ipattern[patno], ninputs, tmp_noise);
  502.     }
  503.     else {
  504.     patno = -1;
  505.     if ((ivec = readvec(" input ",ninputs)) == (float *) NULL) 
  506.         return(CONTINUE);
  507.         distort(input, ivec, ninputs, 0.0);
  508.     }
  509.     str = get_command("target (#N, ?N, E for enter): ");
  510.     if (str == NULL) {
  511.     tvec = readvec(" target ",noutputs);
  512.     }
  513.     else if(*str == '#' || *str == '?') {
  514.     if((patno = get_pattern_number(str+1)) < 0) {
  515.        return(put_error("Invalid pattern specification."));
  516.     }
  517.     tmp_noise = (float) (*str = '#' ? 0.0 : noise );
  518.         distort(target, tpattern[patno], noutputs, tmp_noise);
  519.     } 
  520.     else {
  521.     if ((tvec = readvec(" target ",noutputs)) == (float *) NULL) 
  522.         return(CONTINUE);
  523.         distort(target, tvec, noutputs, 0.0);
  524.     }
  525.     trial();
  526.     update_display();
  527.     return(CONTINUE);
  528. }
  529.  
  530. newstart() {
  531.     random_seed = rand();
  532.     reset_weights();
  533. }
  534.  
  535. reset_weights() {
  536.     register int    i,j,end;
  537.     
  538.     epochno = 0;
  539.     tss = 0.0;
  540.     pss = 0.0;
  541.     patno = 0;
  542.     ndp = vcor = nvl = 0.0;
  543.     cpname[0] = '\0';
  544.     
  545.     srand(random_seed);
  546.  
  547.     if (!System_Defined)
  548.     if (!define_system())
  549.         return;
  550.  
  551.     for (j = ninputs; j < nunits; j++) {
  552.     for (i = first_weight_to[j], end = i + num_weights_to[j];
  553.          i < end; i++) {
  554.         weight[j][i] = 0.0;
  555.     }
  556.     bias[j] = 0.0;
  557.     }
  558.     for (i = 0; i < ninputs; i++) {
  559.       input[i] = 0.0;
  560.     }
  561.     for (i = 0; i < noutputs; i++) {
  562.       target[i] = 0.0;
  563.     }
  564.     for (i = 0; i < nunits; i++) {
  565.       output[i] = error[i] = 0.0;
  566.     }
  567.     update_display();
  568. }
  569.  
  570. init_weights() {
  571.     (void) install_command("network", define_network, GETMENU,(int *) NULL);
  572.     (void) install_command("weights", read_weights, GETMENU,(int *) NULL);
  573.     (void) install_command("weights", write_weights, SAVEMENU,(int *) NULL);
  574.     (void) install_var("nunits", Int,(int *) & nunits, 0, 0, SETCONFMENU);
  575.     (void) install_var("ninputs", Int,(int *) & ninputs, 0, 0, SETCONFMENU);
  576.     (void) install_var("noutputs", Int,(int *) & noutputs, 0, 0, SETCONFMENU);
  577. }
  578.