home *** CD-ROM | disk | FTP | other *** search
/ ARM Club 3 / TheARMClub_PDCD3.iso / hensa / misc / b186_1 / Source / c / cs < prev    next >
Text File  |  1987-12-23  |  16KB  |  672 lines

  1. /*
  2.  
  3.        This file is part of the PDP software package.
  4.          
  5.        Copyright 1987 by James L. McClelland and David E. Rumelhart.
  6.        
  7.        Please refer to licensing information in the file license.txt,
  8.        which is in the same directory with this source file and is
  9.        included here by reference.
  10. */
  11.  
  12.  
  13. /* file: cs.c
  14.  
  15.     Do the actual work for the cs program.
  16.     
  17.     First version implemented by Elliot Jaffe.
  18.     
  19.     Date of last revision:  8-12-87/JLM.
  20. */
  21.  
  22. #include "general.h"
  23. #include "cs.h"
  24. #include "variable.h"
  25. #include "command.h"
  26. #include "patterns.h"
  27. #include "weights.h"
  28.  
  29. #define     MAXTIMES    20
  30. #define  FMIN (1.0e-37)
  31. #define  fcheck(x) (fabs(x) > FMIN ? (float) x : (float) 0.0)
  32.  
  33. char   *Prompt = "cs: ";
  34. char   *Default_step_string = "cycle";
  35. boolean System_Defined = FALSE;
  36.  
  37. boolean     clamp = 0;
  38. boolean        boltzmann = 0;
  39. boolean        harmony = 0;
  40.  
  41. float        temperature,coolrate;
  42. float        goodness;
  43.  
  44. float  *activation;
  45. float  *netinput;
  46. float  *intinput;
  47. float  *extinput;
  48.  
  49. float   estr = 1.0;
  50. float   istr = 1.0;
  51.  
  52. float    kappa;
  53.  
  54. int    epochno = 0; /* not used in cs */
  55. int    patno = 0;
  56. int     ncycles = 10;
  57. int     nupdates = 100;
  58. int     cycleno = 0;
  59. int     updateno = 0;
  60. int    unitno = 0;
  61. char    cuname[40];
  62.  
  63. int    ntimes = 0;
  64.  
  65. struct anneal_schedule {
  66.     int    time;
  67.     float    temp;
  68. }  *anneal_schedule;
  69.  
  70. int maxtimes = MAXTIMES;
  71.  
  72. struct anneal_schedule *last_temp;
  73. struct anneal_schedule *next_temp;
  74. struct anneal_schedule *current_temp;
  75.  
  76. define_system() {
  77.     int     i,j;
  78.     float   *tmp;
  79.  
  80.     activation = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  81.     (void)install_var("activation",Vfloat,(int *)activation,nunits,0,
  82.                                 SETSVMENU);
  83.     for (i = 0; i < nunits; i++)
  84.     activation[i] = 0.0;
  85.  
  86.     netinput = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  87.     (void)install_var("netinput",Vfloat,(int *)netinput,nunits,0,SETSVMENU);
  88.     for (i = 0; i < nunits; i++)
  89.     netinput[i] = 0.0;
  90.  
  91.     intinput = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  92.     (void)install_var("intinput",Vfloat,(int *)intinput,nunits,0,SETSVMENU);
  93.     for (i = 0; i < nunits; i++)
  94.     intinput[i] = 0.0;
  95.  
  96.     extinput = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  97.     (void)install_var("extinput",Vfloat,(int *)extinput,nunits,0,SETSVMENU);
  98.     for (i = 0; i < nunits; i++)
  99.     extinput[i] = 0.0;
  100.  
  101.     anneal_schedule = ((struct anneal_schedule *)
  102.            emalloc((unsigned)maxtimes*sizeof(struct anneal_schedule)));
  103.                
  104.     next_temp = anneal_schedule;
  105.     last_temp = anneal_schedule;
  106.     last_temp->time = 0;
  107.     last_temp->temp = 0.0;
  108.     
  109.     constrain_weights();
  110.  
  111.     System_Defined = TRUE;
  112.     reset_system();
  113.     return(TRUE);
  114. }
  115.  
  116.  
  117. double  logistic (i) float   i; {
  118.     double  val;
  119.     double  ret_val;
  120.     double  exp ();
  121.  
  122.     if( temperature <= 0.0)
  123.     return(i > 0);
  124.     else
  125.         val = i / temperature;
  126.  
  127.     if (val > 11.5129)
  128.     return(.99999);
  129.     else
  130.     if (val < -11.5129)
  131.         return(.00001);
  132.     else
  133.         ret_val = 1.0 / (1.0 + exp(-1.0 * val));
  134.     if (ret_val > FMIN) return(ret_val);
  135.     return(0.0);
  136. }
  137.  
  138. probability(val) double  val; {
  139.     return((rnd() < val) ? 1 : 0);
  140. }
  141.  
  142. float   
  143. annealing ( iter) int     iter; {
  144.  /* compute the current temperature given the the last landmark, iteration
  145.            number and the current coolrate We use a simple linear function. */
  146.  
  147.     double tmp;
  148.  
  149.     if(iter >= last_temp->time) return(last_temp->temp);
  150.     if(iter >= next_temp->time) {
  151.     tmp = next_temp->temp;
  152.     current_temp = next_temp++;
  153.     coolrate = (current_temp->temp-next_temp->temp)/
  154.            (float)(next_temp->time - current_temp->time);
  155.     }
  156.     else
  157.         tmp = current_temp->temp -
  158.                (coolrate*(float)(iter - current_temp->time));
  159.  
  160.     return((tmp < FMIN ? 0.0 : tmp));
  161. }
  162.  
  163. get_schedule() {
  164.     int  cnt;
  165.     char  *str;
  166.     char   string[40];
  167.     struct anneal_schedule *asptr;
  168.  
  169.     if(!System_Defined)
  170.        if(!define_system())
  171.       return(BREAK);
  172.  
  173.     asptr = anneal_schedule;
  174.     next_temp = anneal_schedule;
  175.  
  176. restart:
  177.     str = get_command("Setting annealing schedule, initial temperature : ");
  178.     if( str == NULL) return(CONTINUE);
  179.     if(sscanf(str,"%f",&(asptr->temp)) == 0) {
  180.        if (put_error("Invalid initial temperature specification.") == BREAK) {
  181.            return(BREAK);
  182.        }
  183.        goto restart;
  184.     }
  185.     if(asptr->temp < 0) {
  186.        if (put_error("Temperatures must be positive.") == BREAK) {
  187.         return(BREAK);
  188.        }
  189.        goto restart;
  190.     }
  191.     cnt = 1;
  192.     last_temp = asptr++;
  193.     sprintf(string,"time for first milestone:  ");
  194.     while((str = get_command(string)) != NULL) {
  195.        if(strcmp(str,"end") == 0) return(CONTINUE);
  196.        if (cnt >= maxtimes) {
  197.       maxtimes += 10;
  198.           anneal_schedule = ((struct anneal_schedule *) 
  199.                erealloc((char *)anneal_schedule,
  200.                      (unsigned)((maxtimes-10)*sizeof(struct anneal_schedule)),
  201.             (unsigned)maxtimes*sizeof(struct anneal_schedule)));
  202.       next_temp = anneal_schedule;
  203.       asptr = anneal_schedule + cnt;
  204.       last_temp = asptr;
  205.        }
  206.        if (sscanf(str,"%d",&asptr->time) == 0) {
  207.           if (put_error("Non_numeric time. ") == BREAK) {
  208.        return(BREAK);
  209.       }
  210.       continue;
  211.        }
  212.        if(asptr->time <= last_temp->time) {
  213.       if (put_error("Times must increase.") == BREAK) {
  214.        return(BREAK);
  215.       }
  216.       continue;
  217.        }
  218.        sprintf(string,"at time %d the temp should be: ",asptr->time);
  219.        if((str = get_command(string)) == NULL) {
  220.              if (put_error("Nothing set at this milestone.") == BREAK) {
  221.        return(BREAK);
  222.       }
  223.           goto retry;
  224.        }
  225.        if(sscanf(str,"%f",&(asptr->temp)) == 0) {
  226.              if(put_error("Non_numberic temperature.") == BREAK) {
  227.        return(BREAK);
  228.       }
  229.           goto retry;
  230.        }
  231.        if(asptr->temp < 0) {
  232.           if(put_error("Temperatures must be positive.") == BREAK) {
  233.        return(BREAK);
  234.       }
  235.           goto retry;
  236.        }
  237.        last_temp = asptr++;
  238.        cnt++;
  239. retry:
  240.        sprintf(string,"time for milestone %d: ",cnt);
  241.     }
  242.     return(CONTINUE);
  243. }
  244.  
  245.  
  246. cycle() {
  247.     int     iter;
  248.     char    *str;
  249.  
  250.     if (!System_Defined)
  251.     if (!define_system())
  252.         return(BREAK);
  253.  
  254.  
  255.     for (iter = 0; iter < ncycles; iter++) {
  256.     cycleno++;
  257.     if(boltzmann || harmony)
  258.         temperature = annealing(cycleno);
  259.     if (rupdate() == BREAK) return(BREAK);
  260.     if(step_size == CYCLE) {
  261.        get_goodness();
  262.        cs_update_display();
  263.        if(single_flag) {
  264.           if (contin_test() == BREAK) return(BREAK);
  265.        }
  266.     }
  267.     if (Interrupt) {
  268.         get_goodness();
  269.         cs_update_display();
  270.         Interrupt_flag = 0;
  271.         if (contin_test() == BREAK) return(BREAK);
  272.  
  273.     }
  274.     }
  275.     if (step_size == NCYCLES) {
  276.        get_goodness();
  277.        cs_update_display();
  278.     }
  279.     return(CONTINUE);
  280. }
  281.  
  282. get_goodness() {
  283.  
  284.     int i,j;
  285.     int num, sender, fs, ls; /* fs is first sender, ls is last */
  286.     double dg;
  287.  
  288.     dg = 0.0;
  289.     
  290.     if(harmony) {
  291.         for(i=ninputs; i < nunits; i++) {
  292.            sender = first_weight_to[i];
  293.            num = num_weights_to[i];
  294.            for(j = 0; j < num && sender < ninputs; j++,sender++) {
  295.                dg += 
  296.                weight[i][j]*activation[i]*activation[sender];
  297.            }
  298.            if (activation[i]) dg -= kappa*sigma[i];
  299.         }
  300.         goto ret_goodness;
  301.     }
  302.     for(i=0; i < nunits; i++) {
  303.        fs = first_weight_to[i];
  304.        ls = num_weights_to[i] + fs -1;
  305.        for(j = i+1; j < nunits; j++) {
  306.         if ( j < fs ) continue;
  307.         if ( j > ls ) break;
  308.         dg += weight[i][j-fs]*activation[i]*activation[j];
  309.        }
  310.        dg += bias[i]*activation[i];
  311.     }
  312. /* >> dont we want to let goodness be affected by istr whether or not
  313.    clamp is 0? Boltz is always clamped, but not schema */
  314.     if(clamp == 0) {
  315.         dg *= istr;
  316.         for(i=0; i < nunits; i++) {
  317.         dg += activation[i]*extinput[i]*estr;
  318.         }
  319.     }
  320. ret_goodness:    
  321.     goodness = dg;
  322.     return;
  323. }
  324.  
  325. constrain_weights() {
  326.     int    *nconnections;
  327.     int     i,j,num;
  328.     float   value;
  329.  
  330.     if(!harmony) return;
  331.  
  332.     nconnections = (int *) emalloc((unsigned)(sizeof(int) * nunits));
  333.     for (i = 0; i < nunits; i++) {
  334.     nconnections[i] = 0;
  335.     }
  336.  
  337.     for (j = ninputs; j < nunits; j++) {
  338.     num = num_weights_to[j];
  339.     for (i = 0; i < num; i++) {
  340.         if (weight[j][i])
  341.         nconnections[j]++;
  342.     }
  343.     }
  344.  
  345.     for (j = ninputs; j < nunits; j++) {
  346.     if (!nconnections[j])
  347.         continue;
  348.     value = sigma[j] / (float) nconnections[j];
  349.     num = num_weights_to[j];
  350.     for (i = 0; i < num; i++) {
  351.         if (weight[j][i]) {
  352.         weight[j][i] *= value;
  353.         }
  354.     }
  355.     }
  356.     free((char *)nconnections);
  357. }
  358.  
  359. zarrays() {
  360.     register int    i;
  361.  
  362.     if (!System_Defined)
  363.     if (!define_system())
  364.         return(BREAK);
  365.  
  366.     cycleno = 0;
  367.  
  368.     next_temp = anneal_schedule;
  369.     if(last_temp != next_temp) {
  370.     current_temp = next_temp++;
  371.     coolrate = 
  372.            (current_temp->temp-next_temp->temp)/(float)next_temp->time;
  373.     }
  374.     temperature = annealing (cycleno);
  375.  
  376.     goodness = 0;
  377.     updateno = 0;
  378.  
  379.     for (i = 0; i < nunits; i++) {
  380.     intinput[i] = netinput[i] = activation[i] = 0;
  381.     }
  382.     if (clamp) {
  383.         init_activations();
  384.     }
  385.     return(CONTINUE);
  386. }
  387.  
  388. init_activations() {
  389.     register int i;
  390.     for (i = 0; i < nunits; i++) {
  391.     if (extinput[i] == 1.0) {
  392.         activation[i] = 1.0;
  393.         continue;
  394.     }
  395.     if (extinput[i] == -1.0) {
  396.         activation[i] == 0.0;
  397.         continue;
  398.     }
  399.     }
  400. }
  401.  
  402. rupdate() {
  403.     register int    j,wi,sender,num,*fwp,*nwp,i,n;
  404.     char *str;
  405.     double dt, inti,neti,acti;
  406.  
  407.     for (updateno = 0,n = 0; n < nupdates; n++) {
  408.     updateno++;
  409.     unitno = i = randint(0, nunits - 1);
  410.     inti = 0.0;
  411.     if (harmony) {
  412.          neti = 0.0;
  413.          if (i < ninputs) {
  414.         if (extinput[i] == 0.0) {
  415.            for (j = ninputs,fwp = &first_weight_to[ninputs],
  416.                             nwp = &num_weights_to[ninputs];
  417.                     j < nunits; j++) {
  418.             wi = i - *fwp++;
  419.             if ( (wi >= *nwp++) || (wi < 0) ) continue;
  420.             neti += activation[j]*weight[j][wi];
  421.            }
  422.            neti = 2 * neti;
  423.            if (probability(logistic(neti)))
  424.             activation[i] = 1;
  425.            else
  426.             activation[i] = -1;
  427.         }
  428.         else {
  429.            if(extinput[i] < 0.0) activation[i] = -1;
  430.            if(extinput[i] > 0.0) activation[i] = 1;
  431.         }
  432.          }
  433.          else {
  434.            sender = first_weight_to[i];
  435.            num = num_weights_to[i];
  436.            for (j = 0; j < num && sender < ninputs; j++,sender++) {
  437.               neti += activation[sender]*weight[i][j];
  438.         }
  439.         neti -=  sigma[i]*kappa;
  440.         activation[i] = probability(logistic(neti));
  441.         netinput[i] = neti;
  442.         }
  443.     }
  444.     else {
  445.       if (clamp) {
  446.         if (extinput[i] > 0.0) {
  447.         activation[i] = 1.0;
  448.          goto end_of_rupdate;
  449.         }
  450.         if (extinput[i] < 0.0) {
  451.         activation[i] = 0.0;
  452.         goto end_of_rupdate;
  453.         }
  454.       }
  455.       sender = first_weight_to[i];
  456.       num = num_weights_to[i];
  457.       for (j = 0; j < num; j++) {
  458.         inti += activation[sender++] * weight[i][j];
  459.       }
  460.       inti  += bias[i];
  461.       if (clamp == 0) {
  462.         neti = istr * inti + estr * extinput[i];
  463.       }
  464.       else {
  465.         neti = istr * inti;
  466.       }
  467.       netinput[i] = neti;
  468.       intinput[i] = inti;
  469.       if (boltzmann) {
  470.         if (probability(logistic(neti)))
  471.         activation[i] = 1.0;
  472.         else
  473.         activation[i] = 0.0;
  474.       }
  475.       else {
  476.         if (neti > 0.0) {
  477.           if (activation[i] < 1.0) {
  478.         acti = activation[i];
  479.         dt = acti + neti*(1.0 - acti);
  480.         if (dt > 1.0) {
  481.             activation[i] = (float) 1.0;
  482.                 }
  483.         else activation[i] = (float) dt;
  484.               }
  485.             }
  486.         else {
  487.           if (activation[i] > (float) 0.0) {
  488.         acti = activation[i];
  489.             dt = acti + neti * acti;
  490.         if (dt < FMIN) {
  491.             activation[i] = (float) 0.0;
  492.             }
  493.         else activation[i] = (float) dt;
  494.           }
  495.             }
  496.       }
  497.     }
  498. end_of_rupdate:    
  499.         if(step_size == UPDATE)  {
  500.        get_goodness();
  501.        cs_update_display();
  502.        if (single_flag) {
  503.           if (contin_test() == BREAK) {
  504.               return(BREAK);
  505.           }
  506.        }
  507.     }
  508.         if(Interrupt)  {
  509.        Interrupt_flag = 0;
  510.        get_goodness();
  511.        cs_update_display();
  512.        if (contin_test() == BREAK) {
  513.               return(BREAK);
  514.        }
  515.     }
  516.     }
  517.     return(CONTINUE);
  518. }
  519.  
  520. input() {
  521.     int     i;
  522.     char   *str,tstr[100];
  523.  
  524.     if (!System_Defined)
  525.     if (!define_system())
  526.         return(BREAK);
  527.     if (!nunames) {
  528.         return(put_error("Must provide unit names. "));
  529.     }
  530. again:
  531.     str = get_command("Do you want to reset all inputs?: (y or n)");
  532.     if (str == NULL) goto again;
  533.     if (str[0] == 'y') {
  534.         for (i = 0; i < nunits; i++)
  535.         extinput[i] = 0;
  536.     }
  537.     else if (str[0] != 'n') {
  538.         put_error ("Must enter y or n!");
  539.     goto again;
  540.     }
  541.  
  542. gcname: 
  543.     str = get_command("give unit name or number: ");
  544.     if (str == NULL || strcmp(str,"end") == 0) {
  545.         if (clamp) init_activations();
  546.     cs_update_display();
  547.     return(CONTINUE);
  548.     }
  549.     if (sscanf(str,"%d",&i) == 0) {
  550.         for (i = 0; i < nunames; i++) {
  551.         if (startsame(str, uname[i])) break;
  552.         }
  553.     }
  554.     if (i >= nunames) {
  555.     if (put_error("invalid name or number -- try again.") == BREAK) {
  556.         return(BREAK);
  557.     }
  558.     goto gcname;
  559.     }
  560. gcval: 
  561.     sprintf(tstr,"enter input strength of %s:  ",uname[i]);
  562.     str = get_command(tstr);
  563.     if (str == NULL) {
  564.         sprintf(err_string,"No strength specified for %s",uname[i]);
  565.     if (put_error(err_string) == BREAK) {
  566.      return(BREAK);
  567.     }
  568.     goto gcname;
  569.     }
  570.     if (sscanf(str, "%f", &extinput[i]) != 1) {
  571.     if (put_error("unrecognized value -- try again.") == BREAK) {
  572.      return(BREAK);
  573.     }
  574.     goto gcval;
  575.     }
  576.     goto gcname;
  577. }
  578.  
  579. setinput() {
  580.     register int    i;
  581.     register float  *pp;
  582.  
  583.     for (i = 0, pp = ipattern[patno]; i < nunits; i++, pp++) {
  584.             extinput[i] = *pp;
  585.     }
  586.     strcpy(cpname,pname[patno]);
  587. }
  588.  
  589. test_pattern() {
  590.     char   *str;
  591.  
  592.     if (!System_Defined)
  593.         if (!define_system())
  594.             return(BREAK);
  595.  
  596.     if(ipattern[0] == NULL) {
  597.        return(put_error("No file of test patterns has been read in."));
  598.     }
  599. again:
  600.     str = get_command("Test which pattern? (name or number): ");
  601.     if(str == NULL) return(CONTINUE);
  602.     if ((patno = get_pattern_number(str)) < 0) {
  603.         if (put_error("Invalid pattern specification") == BREAK) {
  604.      return(BREAK);
  605.     }
  606.         goto again;
  607.     }
  608.     setinput();
  609.     zarrays();
  610.  
  611.     cycle();
  612.     return(CONTINUE);
  613. }
  614.  
  615.  
  616. newstart() {
  617.     random_seed = rand();
  618.     reset_system();
  619. }
  620.  
  621. reset_system() {
  622.     srand(random_seed);
  623.     clear_display();
  624.     zarrays();
  625.     cs_update_display();
  626.     return(CONTINUE);
  627. }
  628.  
  629. init_system() {
  630.     int get_unames(),test_pattern(),read_weights(),write_weights();
  631.  
  632.     epsilon_menu = NOMENU;
  633.  
  634.     (void) install_command("network", define_network, GETMENU,(int *) NULL);
  635.     (void) install_command("weights", read_weights, GETMENU,(int *) NULL);
  636.     (void) install_command("cycle", cycle, BASEMENU,(int *) NULL);
  637.     (void) install_command("input", input, BASEMENU,(int *) NULL);
  638.     (void) install_command("test", test_pattern, BASEMENU,(int *) NULL);
  639.     (void) install_command("unames", get_unames, GETMENU,(int *) NULL);
  640.     (void) install_command("patterns", get_patterns, GETMENU,(int *) NULL);
  641.     (void) install_command("reset", reset_system, BASEMENU,(int *) NULL);
  642.     (void) install_command("newstart", newstart, BASEMENU,(int *) NULL);
  643.     (void) install_command("weights", write_weights, SAVEMENU,(int *) NULL);
  644.     (void) install_command("annealing", get_schedule, GETMENU,(int *) NULL);
  645.  
  646.     (void) install_var("patno", Int,(int *) & patno, 0, 0, SETSVMENU);
  647.     init_patterns();
  648.     (void) install_var("cycleno", Int,(int *) & cycleno, 0, 0, SETSVMENU);
  649.     (void) install_var("updateno", Int,(int *) & updateno, 0, 0, SETSVMENU);
  650.     (void) install_var("unitno", Int,(int *) & unitno, 0, 0, SETSVMENU);
  651.     (void) install_var("cuname", String,(int *) cuname, 0, 0, SETSVMENU);
  652.     (void) install_var("clamp", Int,(int *) & clamp, 0, 0, SETMODEMENU);
  653.     (void) install_var("nunits", Int,(int *) & nunits, 0, 0, SETCONFMENU);
  654.     (void) install_var("ninputs", Int,(int *) & ninputs, 0, 0, SETCONFMENU);
  655.     (void) install_var("estr", Float,(int *) & estr, 0, 0, SETPARAMMENU);
  656.     (void) install_var("istr", Float,(int *) & istr, 0, 0, SETPARAMMENU);
  657.     (void) install_var("kappa", Float,(int *) & kappa, 0, 0, SETPARAMMENU);
  658.     (void) install_var("boltzmann", Int, (int *) & boltzmann, 0, 0, 
  659.                                 SETMODEMENU);
  660.     (void) install_var("harmony", Int, (int *) & harmony, 0, 0, SETMODEMENU);
  661.     (void) install_var("temperature",Float, (int *) & temperature, 0, 0, 
  662.                                 SETSVMENU);
  663.     (void) install_var("goodness",Float, (int *) & goodness, 0, 0, SETSVMENU);
  664.     (void) install_var("ncycles", Int,(int *) & ncycles, 0, 0, SETPCMENU);
  665.     (void) install_var("nupdates", Int,(int *) & nupdates, 0, 0, SETPCMENU);
  666. }
  667.  
  668. cs_update_display() {
  669.     if (unitno < nunames) strcpy(cuname,uname[unitno]);
  670.     update_display();
  671. }
  672.