home *** CD-ROM | disk | FTP | other *** search
/ ARM Club 3 / TheARMClub_PDCD3.iso / hensa / misc / b186_1 / Source / c / cl < prev    next >
Text File  |  1994-02-15  |  10KB  |  426 lines

  1. /*
  2.  
  3.        This file is part of the PDP software package.
  4.          
  5.        Copyright 1987 by James L. McClelland and David E. Rumelhart.
  6.        
  7.        Please refer to licensing information in the file license.txt,
  8.        which is in the same directory with this source file and is
  9.        included here by reference.
  10. */
  11.  
  12.  
  13. /* file: cl.c
  14.  
  15.     System file for competitive learning.
  16.     
  17.     First version implemented by Elliot Jaffe
  18.     
  19.     Date of last revision:  8-12-87/JLM
  20. */
  21.  
  22. #include "general.h"
  23. #include "variable.h"
  24. #include "patterns.h"
  25. #include "command.h"
  26. #include "cl.h"
  27.  
  28. char   *Prompt = "cl: ";
  29. char   *Default_step_string = "epoch";
  30. boolean System_Defined = FALSE;
  31.  
  32. boolean lflag = 1;
  33. int     nepochs = 20;
  34. int     epochno = 0;
  35. float   lrate = 0.2;
  36.  
  37. int    nunits = 0;
  38. int     noutputs = 0;        /* number of units in competitive pool */
  39. int     ninputs = 0;        /* number of units in input layer */
  40. float  **weight;        /* pointers to vectors of weights */
  41. int    *activation = NULL;    /* activations for all units */
  42. float  *netinput = NULL;    /* sum of inputs for pool units */
  43.  
  44. int    tallflag = 0;
  45.  
  46. int     patno = 0;
  47. int    winner = 0;
  48.  
  49. define_system() {
  50.     int     i,j;
  51.  
  52.     if (noutputs <= 0) {
  53.     put_error("cannot initialize weights without noutputs");
  54.     return(FALSE);
  55.     }
  56.  
  57.     if (ninputs <= 0) {
  58.     put_error("cannot initialize weights without ninputs");
  59.     return(FALSE);
  60.     }
  61.  
  62.     nunits = ninputs + noutputs;
  63.  
  64.     if (activation != NULL) {
  65.     free((char *) activation);
  66.     }
  67.  
  68.     if (netinput != NULL) {
  69.     free((char *) netinput);
  70.     }
  71.  
  72.     activation = (int *) emalloc((unsigned)(nunits * sizeof(int)));
  73.     (void) install_var("activation", Vint,(int *) activation, nunits,0,SETSVMENU);
  74.     netinput = (float *) emalloc((unsigned)(nunits * sizeof(float)));
  75.     (void) install_var("netinput", Vfloat,(int *) netinput, nunits,0,SETSVMENU);
  76.  
  77.     weight = ((float **) emalloc((unsigned)(nunits*sizeof(float *))));
  78.     
  79.     for (i = ninputs; i < nunits; i++) {
  80.     weight[i] = ( (float *) emalloc ( (unsigned) ninputs * sizeof (float) ));
  81.     for (j = 0; j < ninputs; j++) {
  82.         weight[i][j] = 0.0;
  83.     }
  84.     }
  85.     (void) install_var("weight", PVweight,(int *) weight, nunits, nunits,SETWTMENU);
  86.        
  87.     first_weight_to = (int *) emalloc((unsigned)(sizeof(int) * nunits));
  88.     num_weights_to = (int *) emalloc((unsigned)(sizeof(int) * nunits));
  89.  
  90.     for (i = 0; i < ninputs; i++) {
  91.     first_weight_to[i] = 0;
  92.     num_weights_to[i] = 0;
  93.     }
  94.     for ( ; i < nunits; i++) {
  95.     first_weight_to[i] = 0;
  96.     num_weights_to[i] = ninputs;
  97.     }
  98.     
  99.     for (i = 0; i < nunits; i++) {
  100.     activation[i] = 0;
  101.     }
  102.  
  103.     for (i = ninputs; i < nunits; i++) {
  104.     netinput[i] = 0.0;
  105.         for (j = 0; j < ninputs; j++) {
  106.         netinput[i] += weight[i][j] = rnd();
  107.     }
  108.     for (j = 0; j < ninputs; j++) {
  109.         weight[i][j] *= 1.0/netinput[i];
  110.     }
  111.     activation[i] = 0;
  112.     netinput[i] = 0.0;
  113.     }
  114.  
  115.     System_Defined = TRUE;
  116.  
  117.     return(TRUE);
  118. }
  119.  
  120. get_weights() {
  121.     register int    i,
  122.                     j;
  123.     char   *str;
  124.     FILE * iop;
  125.  
  126.     if(! System_Defined)
  127.       if(! define_system())
  128.        return(BREAK);
  129.  
  130.     str = get_command("fname: ");
  131.     if (str == NULL) return(CONTINUE);
  132.     if ((iop = fopen(str, "r")) == NULL) {
  133.     return(put_error("Cannot open file"));
  134.     }
  135.     for (i = ninputs; i < nunits; i++) {
  136.     for (j = 0; j < ninputs; j++) {
  137.         (void) fscanf(iop, "%f", &weight[i][j]);
  138.     }
  139.     (void) fscanf(iop, "\n");
  140.     }
  141.  
  142.     epochno = 0;
  143.  
  144.     for (i = ninputs; i < nunits; i++) {
  145.       netinput[i]=activation[i] = 0.0;
  146.     }
  147.     for (i = 0; i < ninputs; i++) {
  148.       activation[i] = 0.0;
  149.     }
  150.     update_display();
  151.     fclose(iop);
  152.     return(CONTINUE);
  153. }
  154.  
  155.  
  156. save_weights() {
  157.     register int    i,
  158.                     j;
  159.     float  *fp;
  160.     FILE * iop;
  161.     char   *str;
  162.     char tstr[40];
  163.     char fname[BUFSIZ];
  164.     char *star_ptr;
  165.  
  166.     if(! System_Defined)
  167.       if(! define_system())
  168.        return(BREAK);
  169.  
  170. nameagain:
  171.     str = get_command("file name: ");
  172.     if (str == NULL) return(CONTINUE);
  173.     if ( (star_ptr = strchr(str,'*')) != NULL) {
  174.         strcpy(tstr,star_ptr+1);
  175.         sprintf(star_ptr,"%d",epochno);
  176.     strcat(str,tstr);
  177.     }
  178.     strcpy(fname,str);
  179.     if ((iop = fopen(fname, "r")) != NULL) {
  180.         fclose(iop);
  181.         get_command("file exists -- clobber? ");
  182.     if (str == NULL || str[0] != 'y') {
  183.        goto nameagain;
  184.     }
  185.     }
  186.     if ((iop = fopen(fname, "w")) == NULL) {
  187.     return(put_error("cannot open file for weights"));
  188.     }
  189.     for (i = ninputs; i < nunits; i++) {
  190.     for (j = 0; j < ninputs; j++) {
  191.         fprintf(iop, "%6.3f", weight[i][j]);
  192.     }
  193.     fprintf(iop, "\n");
  194.     }
  195.     (void) fclose(iop);
  196.     return(CONTINUE);
  197. }
  198.  
  199. compute_output() {
  200.     int     i, j;
  201.  
  202.     for (i = ninputs; i < nunits; i++) {
  203.     netinput[i] = 0.0;
  204.     activation[i] = 0;
  205.     }
  206.  
  207.     for (i = 0; i < ninputs; i++) {
  208.     if (activation[i]) {
  209.         for (j = ninputs; j < nunits; j++) {
  210.         netinput[j] += weight[j][i];
  211.         }
  212.     }
  213.     }
  214.  
  215.     for (winner = ninputs, i = ninputs; i < nunits; i++) {
  216.     if (netinput[winner] < netinput[i]) {
  217.         winner = i;
  218.     }
  219.     }
  220.  
  221.     activation[winner] = 1;
  222. }
  223.  
  224. change_weights()
  225. {
  226.     int     i;
  227.     float   nactive = 0;
  228.  
  229.     for (i = 0; i < ninputs; i++) {
  230.     if (activation[i])
  231.         nactive += 1;
  232.     }
  233.  
  234.     if(nactive == 0) return;
  235.  
  236.     for (i = 0; i < ninputs; i++) {
  237.     weight[winner][i] += 
  238.       lrate * ((activation[i] / nactive) 
  239.                - weight[winner][i]);
  240.     }
  241. }
  242.  
  243. setinput() {
  244.     register int    i;
  245.     register float  *pp;
  246.  
  247.     for (i = 0, pp = ipattern[patno]; i < ninputs; i++, pp++) {
  248.         activation[i] = *pp;
  249.     }
  250.     strcpy(cpname,pname[patno]);
  251. }
  252.  
  253.  
  254. setup_pattern() {
  255.     setinput();
  256. }
  257.  
  258. trial() {
  259.     setup_pattern();
  260.     compute_output();
  261. }
  262.  
  263.  
  264. ptrain() {
  265.   return(train('p'));
  266. }
  267.  
  268. strain() {
  269.   return(train('s'));
  270. }
  271.  
  272. train(c) char c; {
  273.     int     t,i,old,npat;
  274.     char    *str;
  275.  
  276.     if (!System_Defined)
  277.     if (!define_system())
  278.         return(BREAK);
  279.  
  280.     for (t = 0; t < nepochs; t++) {
  281.     if (!tallflag) epochno++;
  282.     for (i = 0; i < npatterns; i++)
  283.         used[i] = i;
  284.     if (c == 'p') {
  285.       for (i = 0; i < npatterns; i++) {
  286.         npat = rnd() * (npatterns - i) + i;
  287.         old = used[i];
  288.         used[i] = used[npat];
  289.         used[npat] = old;
  290.       }
  291.     }
  292.     for (i = 0; i < npatterns; i++) {
  293.         if (Interrupt) {
  294.         Interrupt_flag = 0;
  295.         update_display();
  296.             if (contin_test() == BREAK) return(BREAK);
  297.         }
  298.         patno = used[i];
  299.         trial();
  300.         if (lflag) change_weights();
  301.         if (step_size == PATTERN) {
  302.           update_display();
  303.           if (single_flag) {
  304.              if (contin_test() == BREAK) return(BREAK);
  305.           }
  306.         }
  307.     }
  308.     if (step_size == EPOCH) {
  309.      update_display();
  310.      if (single_flag) {
  311.              if (contin_test() == BREAK) return(BREAK);
  312.      }
  313.         }
  314.     }
  315.     if (step_size == NEPOCHS) {
  316.         update_display();
  317.     }
  318.     return(CONTINUE);
  319. }
  320.  
  321. tall() {
  322.   int save_lflag;
  323.   int save_single_flag;
  324.   int save_nepochs;
  325.   int save_step_size;
  326.   
  327.   save_lflag = lflag;  lflag = 0;
  328.   save_single_flag = single_flag; 
  329.   if (in_stream == stdin) single_flag = 1;
  330.   save_nepochs = nepochs;  nepochs = 1;
  331.   save_step_size = step_size;
  332.   if (step_size > PATTERN) step_size = PATTERN;
  333.   tallflag = 1;
  334.   train('s');
  335.   tallflag = 0;
  336.   lflag = save_lflag;
  337.   nepochs = save_nepochs;
  338.   single_flag = save_single_flag;
  339.   step_size = save_step_size;
  340.   return(CONTINUE);
  341. }
  342.   
  343. test_pattern() {
  344.     char   *str;
  345.  
  346.     if (!System_Defined)
  347.     if (!define_system())
  348.         return(BREAK);
  349.  
  350.     str = get_command("Test which pattern? ");
  351.     if(str == NULL) return(CONTINUE);
  352.     if ((patno = get_pattern_number(str)) < 0 ) {
  353.     return(put_error("Invalid pattern specification."));
  354.     }
  355.     trial();
  356.     update_display();
  357.     return(CONTINUE);
  358. }
  359.  
  360. newstart() {
  361.     random_seed = rand();
  362.     reset_weights();
  363. }
  364.  
  365. reset_weights() {
  366.     register int    i,j;
  367.     
  368.     epochno = 0;
  369.     cpname[0] = '\0';
  370.  
  371.     if (!System_Defined)
  372.     if (!define_system())
  373.         return(BREAK);
  374.  
  375.  
  376.     srand(random_seed);
  377.     
  378.     for (j = ninputs; j < nunits; j++) {
  379.     netinput[j] = 0;
  380.     for (i = 0; i < ninputs; i++) {
  381.         netinput[j] += weight[j][i] = rnd();
  382.     }
  383.     }
  384.  
  385.     for (j = ninputs; j < nunits; j++) {
  386.     for (i = 0; i < ninputs; i++) {
  387.         weight[j][i] *= 1.0/netinput[j];
  388.     }
  389.     }
  390.  
  391.     for (i = ninputs; i < nunits; i++) {
  392.       netinput[i]=activation[i] = 0.0;
  393.     }
  394.     for (i = 0; i < ninputs; i++) {
  395.       activation[i] = 0.0;
  396.     }
  397.     update_display();
  398.     return(CONTINUE);
  399. }
  400.  
  401. init_system() {
  402.     int     strain(), ptrain(), tall(), test_pattern(),reset_weights();
  403.     int        get_unames();
  404.  
  405.     (void) install_command("strain", strain, BASEMENU,(int *) NULL);
  406.     (void) install_command("ptrain", ptrain, BASEMENU,(int *) NULL);
  407.     (void) install_command("tall", tall, BASEMENU,(int *) NULL);
  408.     (void) install_command("test", test_pattern, BASEMENU,(int *) NULL);
  409.     (void) install_command("newstart",newstart,BASEMENU,(int *)NULL);
  410.     (void) install_command("reset",reset_weights,BASEMENU,(int *)NULL);
  411.     (void) install_command("weights", get_weights, GETMENU,(int *) NULL);
  412.     (void) install_command("weights", save_weights, SAVEMENU,(int *) NULL);
  413.     (void) install_command("patterns", get_patterns, GETMENU,(int *) NULL);
  414.     (void) install_command("unames", get_unames, GETMENU,(int *) NULL);
  415.     (void) install_var("noutputs", Int,(int *) & noutputs, 0, 0, SETCONFMENU);
  416.     (void) install_var("ninputs", Int,(int *) & ninputs, 0, 0, SETCONFMENU);
  417.     (void) install_var("nunits", Int,(int *) & nunits, 0, 0, SETCONFMENU);
  418.     (void) install_var("lrate", Float,(int *) & lrate, 0, 0, SETPARAMMENU);
  419.     (void) install_var("lflag", Int,(int *) & lflag, 0, 0, SETPCMENU);
  420.     (void) install_var("nepochs", Int,(int *) & nepochs, 0, 0, SETPCMENU);
  421.     (void) install_var("epochno", Int,(int *) & epochno, 0, 0, SETSVMENU);
  422.     (void) install_var("patno", Int,(int *) & patno, 0, 0, SETSVMENU);
  423.     
  424.     init_patterns();
  425. }
  426.