home *** CD-ROM | disk | FTP | other *** search
/ ARM Club 3 / TheARMClub_PDCD3.iso / hensa / misc / b186_1 / Source / c / bp < prev    next >
Text File  |  1990-04-14  |  21KB  |  847 lines

  1. /*
  2.  
  3.        This file is part of the PDP software package.
  4.          
  5.        Copyright 1987 by James L. McClelland and David E. Rumelhart.
  6.        
  7.        Please refer to licensing information in the file license.txt,
  8.        which is in the same directory with this source file and is
  9.        included here by reference.
  10. */
  11.  
  12.  
  13. /* file: bp.c
  14.  
  15.     Do the actual work for the bp program.
  16.     
  17.     First version implemented by Elliot Jaffe
  18.     
  19.     Date of last revision:  8-12-87/JLM
  20. */
  21.  
  22. #include "general.h"
  23. #include "bp.h"
  24. #include "variable.h"
  25. #include "weights.h"
  26. #include "patterns.h"
  27. #include "command.h"
  28.  
  29.  
  30. char   *Prompt = "bp: ";
  31. char   *Default_step_string = "epoch";
  32. char    grain_string[20] = "pattern";
  33. boolean System_Defined = FALSE;
  34. boolean lflag = 1;
  35. boolean cascade = 0;
  36. int     epochno = 0;
  37. int     cycleno = 0;
  38. int     nepochs = 500;
  39. int    ncycles = 50;
  40. int     patno = 0;
  41. float   tss = 0.0;
  42. float   pss = 0.0;
  43. float   ecrit = 0.0;
  44. float    crate = .05;
  45. float    drate = .95;
  46. float    gcor = 0.0;
  47. int    follow = 0;
  48. float  *netinput = NULL;
  49. float  *activation = NULL;
  50. float  *error = NULL;
  51. float  *target = NULL;
  52. float  *delta = NULL;
  53. float  **dweight = NULL;
  54. float  **pwed = NULL;
  55. float  *dbias = NULL;
  56. float  *pbed = NULL;
  57. float   tmax = 1.0;
  58. float   momentum = 0.9;
  59. float    mu = .5;
  60. int    tallflag = 0;
  61.  
  62. extern int read_weights();
  63. extern int write_weights();
  64.  
  65. init_system() {
  66.     int     strain (), ptrain (), tall (), test_pattern (), reset_weights();
  67.     int        get_unames(), set_lgrain(), cycle(), newstart();
  68.     int change_lrate(), change_crate(), set_follow_mode();
  69.  
  70.     epsilon_menu = SETCONFMENU;
  71.  
  72.     init_weights();
  73.  
  74.     (void) install_command("strain", strain, BASEMENU,(int *) NULL);
  75.     (void) install_command("ptrain", ptrain, BASEMENU,(int *) NULL);
  76.     (void) install_command("tall", tall, BASEMENU,(int *) NULL);
  77.     (void) install_command("test", test_pattern, BASEMENU,(int *) NULL);
  78.     (void) install_command("cycle", cycle, BASEMENU,(int *) NULL);
  79.     (void) install_command("reset",reset_weights,BASEMENU,(int *)NULL);
  80.     (void) install_command("newstart",newstart,BASEMENU,(int *)NULL);
  81.     (void) install_command("unames", get_unames, GETMENU,(int *) NULL);
  82.     (void) install_command("patterns", get_pattern_pairs, 
  83.                            GETMENU,(int *) NULL);
  84.     (void) install_var("lflag", Int,(int *) & lflag, 0, 0, SETPCMENU);
  85.     (void) install_var("lgrain", String, (int *) grain_string,0, 0,NOMENU);
  86.     (void) install_command("lgrain",set_lgrain,SETMODEMENU,(int *) NULL);
  87.     (void) install_var("follow", Int, (int *) & follow,0, 0,NOMENU);
  88.     (void) install_command("follow",set_follow_mode,SETMODEMENU,(int *) NULL);
  89.     (void) install_var("cascade", Int,(int *) & cascade, 0, 0, SETMODEMENU);
  90.     (void) install_var("nepochs", Int,(int *) & nepochs, 0, 0, SETPCMENU);
  91.     (void) install_var("ncycles", Int,(int *) & ncycles, 0, 0, SETPCMENU);
  92.     (void) install_var("epochno", Int,(int *) & epochno, 0, 0, SETSVMENU);
  93.     (void) install_var("patno", Int,(int *) & patno, 0, 0, SETSVMENU);
  94.     (void) install_var("cycleno", Int,(int *) & cycleno, 0, 0, SETSVMENU);
  95.     init_pattern_pairs();
  96.     (void) install_var("tss", Float,(int *) & tss, 0, 0, SETSVMENU);
  97.     (void) install_var("pss", Float,(int *) & pss, 0, 0, SETSVMENU);
  98.     (void) install_var("gcor", Float,(int *) & gcor, 0, 0, SETSVMENU);
  99.     (void) install_var("momentum", Float,(int *) &momentum,0,0,SETPARAMMENU);
  100.     (void) install_var("mu", Float,(int *) &mu,0,0,SETPARAMMENU);
  101.     (void) install_command("lrate", change_lrate, SETPARAMMENU, (int *) NULL);
  102.     (void) install_command("crate", change_crate, SETPARAMMENU, (int *) NULL);
  103.     (void) install_var("lrate", Float,(int *) & lrate, 0, 0, NOMENU);
  104.     (void) install_var("crate", Float,(int *) & crate, 0, 0, NOMENU);
  105.     (void) install_var("ecrit", Float,(int *) & ecrit, 0, 0, SETPCMENU);
  106.     (void) install_var("tmax", Float,(int *) & tmax, 0, 0, SETPARAMMENU);
  107. }
  108.  
  109.  
  110. define_system() {
  111.     register int    i,
  112.                     j;
  113.     float *tmp;
  114.  
  115.     if (!nunits) {
  116.     put_error("cannot init bp system, nunits not defined");
  117.     return(FALSE);
  118.     }
  119.     else
  120.     if (!noutputs) {
  121.         put_error("cannot init bp system, noutputs not defined");
  122.         return(FALSE);
  123.     }
  124.     else
  125.     if (!ninputs) {
  126.         put_error("cannot init bp system, ninputs not defined");
  127.         return(FALSE);
  128.     }
  129.     else
  130.     if (!(nunits && noutputs && ninputs)) {
  131.         put_error("cannot run bp system, nunits not defined");
  132.         return(FALSE);
  133.     }
  134.  
  135.     netinput = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  136.     (void) install_var("netinput", Vfloat,(int *) netinput, nunits, 0, SETSVMENU);
  137.     for (i = 0; i < nunits; i++)
  138.     netinput[i] = 0.0;
  139.  
  140.     activation = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  141.     (void) install_var("activation",Vfloat,(int *)activation,nunits,0,SETSVMENU);
  142.     for (i = 0; i < nunits; i++)
  143.     activation[i] = 0.0;
  144.  
  145.     delta = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  146.     (void) install_var("delta", Vfloat,(int *) delta, nunits, 0, SETSVMENU);
  147.     for (i = 0; i < nunits; i++)
  148.     delta[i] = 0.0;
  149.  
  150.     error = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  151.     (void) install_var("error", Vfloat,(int *) error, nunits, 0, SETSVMENU);
  152.     for (i = 0; i < nunits; i++)
  153.     error[i] = 0.0;
  154.  
  155.     target = (float *) emalloc((unsigned)(sizeof(float) * noutputs));
  156.     (void) install_var("target", Vfloat,(int *) target, noutputs,
  157.         0, SETSVMENU);
  158.     for (i = 0; i < noutputs; i++)
  159.     target[i] = 0.0;
  160.  
  161.     dweight = ((float **)
  162.        emalloc((unsigned)(sizeof(float *)*nunits)));
  163.     (void) install_var("dweight", PVweight,(int *) dweight, nunits,
  164.                     nunits, SETSVMENU);
  165.     
  166.     for (i = 0; i < nunits; i++) {
  167.     dweight[i] = ((float *)
  168.          emalloc((unsigned)(sizeof(float)*num_weights_to[i])));
  169.     for (j = 0; j < num_weights_to[i]; j++){
  170.         dweight[i][j] = 0.0;
  171.     }
  172.     }
  173.     dbias = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  174.     (void) install_var("dbias", Vfloat,(int *) dbias,
  175.                     nunits, 0, SETSVMENU);
  176.     for (i = 0; i < nunits; i++)
  177.     dbias[i] = 0.0;
  178. /*  now being done in weights.c
  179.     wed = ((float **) emalloc((unsigned)(sizeof(float *)*nunits)));
  180.     (void) install_var("wed", PVweight,(int *) wed, nunits,
  181.                     nunits, SETSVMENU);
  182.     for (i = 0; i < nunits; i++) {
  183.     wed[i] = ((float *)
  184.           emalloc((unsigned)(sizeof(float)*num_weights_to[i])));
  185.     for (j = 0; j < num_weights_to[i]; j++) {
  186.         wed[i][j] = 0.0;
  187.     }
  188.     }
  189.  
  190.     bed = (float *) emalloc((unsigned)(sizeof(float) * nunits));
  191.     (void) install_var("bed", Vfloat,(int *) bed,
  192.                     nunits, 0, SETSVMENU);
  193.     for (i = 0; i < nunits; i++)
  194.     bed[i] = 0.0;
  195. */
  196.     System_Defined = TRUE;
  197.     return(TRUE);
  198. }
  199.  
  200.  
  201. float  logistic (x)
  202. float  x;
  203. {
  204.     double  exp ();
  205.  
  206. #ifdef MSDOS    
  207. /* we are conservative under msdos to avoid potential underflow
  208.    problems that may arise from returning extremal values -- jlm */
  209.     if (x > 11.5129)
  210.     return(.99999);
  211.       else
  212.     if (x < -11.5129)
  213.         return(.00001);
  214. #else
  215. /* .99999988 is very close to the largest single precis value
  216.    that is resolvably less than 1.0 -- jlm */
  217.       if (x > 15.935773)
  218.     return(.99999988);
  219.       else
  220.       if (x < -15.935773)
  221.         return(.00000012);
  222. #endif
  223.     else
  224.        return(1.0 / (1.0 + (float) exp( (double) ((-1.0) * x))));
  225. }
  226.  
  227. init_output() {
  228.     register int i,j;
  229.     register float *sender, *wt, *end;
  230.     float net;
  231.     
  232.     /* initializes the network to asymptotic outputs given 0 input */
  233.     cycleno = 0;
  234.     
  235.     for (i = ninputs; i < nunits; i++) {/* to this unit */
  236.     net = bias[i];
  237.     sender = &activation[first_weight_to[i]];
  238.     wt = weight[i];
  239.     end = sender + num_weights_to[i];
  240.     for (j = first_weight_to[i]; j < ninputs; j++) {
  241.         sender++; wt++; /* step over input units to 
  242.                    initialize to all-zero input case */
  243.     }
  244.     for (; sender < end ; ){/* from this unit */
  245.         net += (*sender++) * (*wt++);
  246.     }
  247.     netinput[i] = net;
  248.     activation[i] = (float) logistic(net);
  249.     }
  250.     if (step_size < PATTERN) {
  251.     update_display();
  252.     if (single_flag) {
  253.         if (contin_test() == BREAK) return (BREAK);
  254.     }
  255.     }
  256.     if (Interrupt) {
  257.         Interrupt_flag = 0;
  258.         update_display();
  259.         if (contin_test() == BREAK) return (BREAK);
  260.     }
  261.     return(CONTINUE);
  262. }
  263.  
  264. cycle() {
  265.     register int i,cy;
  266.     register float *sender,*wt,*end;
  267.     float newinput;
  268.  
  269.     for (cy = 0; cy < ncycles; cy++) {
  270.     cycleno++;
  271.     for (i = ninputs; i < nunits; i++) {/* to this unit */
  272.         newinput = bias[i];
  273.         sender = &activation[first_weight_to[i]];
  274.         end = sender + num_weights_to[i];
  275.         wt = weight[i];
  276.         for (;sender<end;) {/* from this unit */
  277.         newinput += (*sender++) * (*wt++);
  278.         }
  279.         netinput[i] = crate * newinput + drate * netinput[i];
  280.         activation[i] = (float) logistic(netinput[i]);
  281.     }
  282.     if (step_size == CYCLE) {
  283.         update_display();
  284.         if (single_flag) {
  285.         if (contin_test() == BREAK) return (BREAK);
  286.         }
  287.     }
  288.     if (Interrupt) {
  289.         update_display();
  290.         Interrupt_flag = 0;
  291.         if (contin_test() == BREAK) return (BREAK);
  292.     }
  293.     }
  294.     if (step_size == NCYCLES) {
  295.     update_display();
  296.     }
  297.     return(CONTINUE);
  298. }
  299.     
  300.  
  301. compute_output() {
  302.     register int    i,j;
  303.     float *sender, *wt, *end;
  304.     float net;
  305.  
  306.     for (i = ninputs; i < nunits; i++) {/* to this unit */
  307.     net = bias[i];
  308.     sender = &activation[first_weight_to[i]];
  309.         end = sender + num_weights_to[i];
  310.     wt = weight[i];
  311.     for (; sender < end ;){/* from this unit */
  312.         net += (*sender++)*(*wt++);
  313.     }
  314.     netinput[i] = net;
  315.     activation[i] = (float) logistic(net);
  316.     }
  317. }
  318.  
  319. compute_error() {
  320.     register int i,j,first,num;
  321.     float *wt, *sender, *end;
  322.     float del;
  323.  
  324.     for (i = ninputs; i < nunits - noutputs; i++) {
  325.     error[i] = 0.0;
  326.     }
  327.  
  328.     for (j = 0; i < nunits; j++, i++) {
  329.     if(target[j] >= 0)   /* We care about this one */
  330.          error[i] = target[j] - activation[i];
  331.     else
  332.          error[i] = 0.0;
  333.     }
  334.  
  335.     for (i = nunits - 1; i >= ninputs; i--) {
  336.     del = delta[i] = error[i] * activation[i] * (1.0 - activation[i]);
  337.     if (first_weight_to[i] + num_weights_to[i] < ninputs) continue;
  338.     /* no point in propagating error back to input units */
  339.     sender = &error[first_weight_to[i]];
  340.     end = sender +     num_weights_to[i];
  341.     wt = weight[i];
  342.     for (;sender < end;) {
  343.         *sender++ += del * (*wt++);
  344.     }
  345.     }
  346. }
  347.  
  348. compute_wed() {
  349.     register int   i,j;
  350.     float *wi, *sender, *end;
  351.     float del;
  352.  
  353.     for (i = ninputs; i < nunits; i++) {
  354.     bed[i] += delta[i];
  355.     sender = &activation[first_weight_to[i]];
  356.     end = sender + num_weights_to[i];
  357.     del = delta[i];
  358.     wi = wed[i];
  359.     for (;sender < end;) {
  360.         *wi++ += del * (*sender++);
  361.     }
  362.     }
  363. }
  364.  
  365. clear_wed() {
  366.     register int   i,j,num;
  367.     register float *wi, *end;
  368.  
  369.     for (i = ninputs; i < nunits; i++) {
  370.     bed[i] = 0.0;
  371.     wi = wed[i];
  372.     end = wi + num_weights_to[i];
  373.     for (; wi < end;) {
  374.         *wi++ = 0.0;
  375.     }
  376.     }
  377. }
  378.  
  379. change_weights() {
  380.     register int    i;
  381.     register float *wt, *dwt, *epi, *wi, *end;
  382.     
  383.     link_sum();
  384.     
  385.     for (i = ninputs; i < nunits; i++) {
  386.     dbias[i] = bepsilon[i]*bed[i] + momentum * dbias[i];
  387.     bias[i] += dbias[i];
  388.     bed[i] = 0.0;
  389.     wt = weight[i];
  390.     dwt= dweight[i];
  391.     wi = wed[i];
  392.     epi = epsilon[i];
  393.     end = wt + num_weights_to[i];
  394.     for (; wt < end; ) {
  395.         *dwt = (*epi++)*(*wi) + momentum * (*dwt);
  396.         *wt++ += *dwt++;
  397.         *wi++ = 0.0;
  398.     }
  399.     }
  400.     pos_neg_constraints();
  401. }
  402.  
  403. float p_css = (float) 0.0;
  404. float css = (float) 0.0;
  405.  
  406. change_weights_follow() {
  407.     register int    i;
  408.     register float *wt, *dwt, *epi, *wi, *end, *pwi;
  409.     float tb, dp, den;
  410.  
  411.     p_css = css;
  412.     css = 0.0;
  413.     dp = 0.0;
  414.     
  415.     link_sum();
  416.  
  417.     for (i = ninputs; i < nunits; i++) {
  418.         tb = bed[i];
  419.     dbias[i] = tb*bepsilon[i] + momentum * dbias[i];
  420.     bias[i] += dbias[i];
  421.     css += ((double) tb)*((double) tb); 
  422.     dp += ((double) tb)*((double) pbed[i]); 
  423.     pbed[i] = tb;
  424.     bed[i] = 0.0;
  425.     wt = weight[i];
  426.     dwt= dweight[i];
  427.     wi = wed[i];
  428.     pwi = pwed[i];
  429.     epi = epsilon[i];
  430.     end = wt + num_weights_to[i];
  431.     for (; wt < end; ) {
  432.         *dwt = (*epi++)*(*wi) + momentum * (*dwt);
  433.         *wt++ += *dwt++;
  434.         css += ((double) (*wi))*((double) (*wi)); 
  435.          dp += ((double) (*wi))*((double) (*pwi)); 
  436.         *pwi++ = *wi;
  437.         *wi++ = 0.0;
  438.     }
  439.     }
  440.     
  441.     den = p_css * css;
  442.     if (den > 0.0) gcor = dp/(sqrt(den));
  443.     else gcor = 0.0;
  444.  
  445.     pos_neg_constraints();
  446. }
  447.  
  448. constrain_weights() {
  449.     pos_neg_constraints();
  450.     link_constraints();
  451. }
  452.  
  453. pos_neg_constraints() {
  454.     float **fpt;
  455.  
  456.     for (fpt = positive_constraints; fpt && *fpt; fpt++)
  457.     if (**fpt < 0.0)
  458.         **fpt = 0.0;
  459.  
  460.     for (fpt = negative_constraints; fpt && *fpt; fpt++)
  461.     if (**fpt > 0.0)
  462.         **fpt = 0.0;
  463. }
  464.  
  465. link_constraints() {
  466.     register int    i,j;
  467.     float   t;
  468.  
  469.     for (i = 0; i < nlinks; i++) {
  470.     t = *constraints[i].cvec[0];
  471.     for (j = 1; j < constraints[i].num; j++) {
  472.         *constraints[i].cvec[j] = t;
  473.     }
  474.     }
  475. }
  476.  
  477. link_sum() {
  478.     register int    i,j;
  479.     float   ss;
  480.  
  481.     for (i = 0; i < nlinks; i++) {
  482.     ss = 0.0;
  483.     for (j = 0; j < constraints[i].num; j++) {
  484.         ss += *constraints[i].ivec[j];
  485.     }
  486.     for (j = 0; j < constraints[i].num; j++) {
  487.         *constraints[i].ivec[j] = ss;
  488.     }
  489.     }
  490. }
  491.  
  492. setinput() {
  493.     register int    i,prev_index;
  494.     register float  *pp;
  495.  
  496.     for (i = 0, pp = ipattern[patno]; i < ninputs; i++, pp++) {
  497.     if ( *pp < 0.0) {
  498.         prev_index = ((int) (-(*pp)));
  499.         activation[i] = mu * activation[i] + activation[prev_index];
  500.         /* user must be careful that prev_index >= i */
  501.     }
  502.     else {
  503.         activation[i] = *pp;
  504.     }
  505.     }
  506.     
  507.     strcpy(cpname,pname[patno]);
  508. }
  509.  
  510. settarget() {
  511.     register int    i;
  512.     register float *pp;
  513.  
  514.     for (i = 0, pp = tpattern[patno]; i < noutputs; i++, pp++) {
  515.     target[i] = *pp;
  516.     if (target[i] == 1.0) {
  517.         target[i] = tmax;
  518.     }
  519.     else if(target[i] == 0.0) {
  520.             target[i] = 1 - tmax;
  521.     }
  522.     }
  523. }
  524.  
  525. setup_pattern() {
  526.     setinput();
  527.     settarget();
  528. }
  529.  
  530. trial() {
  531.     setup_pattern();
  532.     if (cascade) {
  533.     if (init_output() == BREAK) return (BREAK);
  534.     if (cycle() == BREAK) return (BREAK);
  535.     }
  536.     else  {    
  537.         compute_output();
  538.         if (step_size < PATTERN) {
  539.           update_display();
  540.           if (single_flag) {
  541.               if (contin_test() == BREAK) return(BREAK);
  542.           }
  543.     }
  544.     }
  545.     compute_error();
  546.     sumstats();
  547.     return (CONTINUE);
  548. }
  549.  
  550. sumstats() {
  551.     register int    i,j;
  552.     register float t;
  553.     pss = 0.0;
  554.     
  555.  
  556.     for (j = 0,i = nunits - noutputs; i < nunits; i++,j++) {
  557.       if (target[j] >= 0) {
  558.           t = error[i];
  559.     pss += t*t;
  560.       }
  561.     }
  562.     tss += pss;
  563. }
  564.  
  565. ptrain() {
  566.   return(train('p'));
  567. }
  568.  
  569. strain() {
  570.   return(train('s'));
  571. }
  572.  
  573. train(c) char c; {
  574.     int     t,i,old,npat;
  575.     char    *str;
  576.  
  577.     if (!System_Defined)
  578.     if (!define_system())
  579.         return(BREAK);
  580.  
  581.     /* in case prev epoch was terminated early we clear the weds and beds */
  582.     if (!tallflag) clear_wed();
  583.     cycleno = 0;
  584.     for (t = 0; t < nepochs; t++) {
  585.     if (!tallflag) epochno++;
  586.     for (i = 0; i < npatterns; i++)
  587.         used[i] = i;
  588.     if (c == 'p') {
  589.       for (i = 0; i < npatterns; i++) {
  590.         npat = rnd() * (npatterns - i) + i;
  591.         old = used[i];
  592.         used[i] = used[npat];
  593.         used[npat] = old;
  594.       }
  595.     }
  596.     tss = 0.0;
  597.     for (i = 0; i < npatterns; i++) {
  598.         patno = used[i];
  599.         if (trial() == BREAK) return (BREAK);
  600.         if (lflag) {
  601.           compute_wed();
  602.           if (grain_string[0] == 'p') {
  603.             if (follow) change_weights_follow();
  604.             else change_weights();
  605.           }
  606.         }
  607.         if (step_size == PATTERN) {
  608.           update_display();
  609.           if (single_flag) {
  610.               if (contin_test() == BREAK) return(BREAK);
  611.           }
  612.         }
  613.         if (Interrupt) {
  614.         Interrupt_flag = 0;
  615.         update_display();
  616.         if (contin_test() == BREAK) return(BREAK);
  617.         }
  618.     }
  619.     if (lflag && grain_string[0] == 'e') {
  620.       if (follow) change_weights_follow();
  621.       else change_weights();
  622.     }
  623.     if (step_size == EPOCH) {
  624.      update_display();
  625.      if (single_flag) {
  626.               if (contin_test() == BREAK) return(BREAK);
  627.      }
  628.         }
  629.     if (tss < ecrit) break;
  630.     }
  631.     if (step_size == NEPOCHS) {
  632.       update_display();
  633.     }
  634.     return(CONTINUE);
  635. }
  636.  
  637. tall() {
  638.   int save_lflag;
  639.   int save_single_flag;
  640.   int save_nepochs;
  641.   int save_step_size;
  642.   
  643.   save_lflag = lflag;  lflag = 0;
  644.   save_single_flag = single_flag; 
  645.   if (in_stream == stdin) single_flag = 1;
  646.   save_step_size = step_size; 
  647.   if (step_size > PATTERN) step_size = PATTERN;
  648.   save_nepochs = nepochs;  nepochs = 1;
  649.   tallflag = 1;
  650.   train('s');
  651.   tallflag = 0;
  652.   lflag = save_lflag;
  653.   nepochs = save_nepochs;
  654.   single_flag = save_single_flag;
  655.   step_size = save_step_size;
  656.   return(CONTINUE);
  657. }
  658.   
  659. test_pattern() {
  660.     char   *str;
  661.     int save_single_flag;
  662.     int save_step_size;
  663.  
  664.     if (!System_Defined)
  665.     if (!define_system())
  666.         return(BREAK);
  667.  
  668.     tss = 0.0;
  669.  
  670.     str = get_command("Test which pattern? ");
  671.     if(str == NULL) return(CONTINUE);
  672.     if ((patno = get_pattern_number(str)) < 0) {
  673.        return(put_error("Invalid pattern specification."));
  674.     }
  675.     if (cascade) {
  676.         save_single_flag = single_flag; single_flag = 1;
  677.     save_step_size = step_size; step_size = CYCLE;
  678.     }
  679.     trial();
  680.     update_display();
  681.     if (cascade) {
  682.     single_flag = save_single_flag;
  683.     step_size = save_step_size;
  684.     }
  685.     return(CONTINUE);
  686. }
  687.  
  688. newstart() {
  689.     random_seed = rand();
  690.     reset_weights();
  691. }
  692.  
  693. reset_weights() {
  694.     register int    i,j,first,num;
  695.     char ch;
  696.     
  697.     epochno = 0;
  698.     pss = tss = gcor = 0.0;
  699.     cpname[0] = '\0';
  700.     srand(random_seed);
  701.  
  702.     if (!System_Defined)
  703.     if (!define_system())
  704.         return(BREAK);
  705.  
  706.     for (j = 0; j < nunits; j++) {
  707.     first = first_weight_to[j];
  708.     num = num_weights_to[j];
  709.       for (i = 0; i < num; i++) {
  710.     wed[j][i] = dweight[j][i] = 0.0;
  711.     if (pwed) pwed[j][i] = 0.0;
  712.     ch = wchar[j][i];
  713.     if (isupper(ch)) ch = tolower(ch);
  714.     if (ch == '.') {
  715.         weight[j][i] = 0.0;        
  716.     }
  717.     else {
  718.         if (constants[ch - 'a'].random) {
  719.             if (constants[ch - 'a'].positive) {
  720.             weight[j][i] = wrange * rnd();
  721.             }
  722.             else
  723.             if (constants[ch - 'a'].negative) {
  724.                 weight[j][i] = wrange * (rnd() - 1);
  725.             }
  726.             else
  727.             weight[j][i] = wrange * (rnd() -.5);
  728.         }
  729.         else {
  730.             weight[j][i] = constants[ch - 'a'].value;
  731.         }
  732.     }
  733.       }
  734.       bed[j] = dbias[j] = 0.0;
  735.       if (pbed) pbed[j] = 0.0;
  736.       ch = bchar[j];
  737.       if (isupper(ch)) ch = tolower(ch);
  738.       if (ch == '.') {
  739.         bias[j] = 0;
  740.       }
  741.       else {
  742.         if (constants[ch - 'a'].random) {
  743.             if (constants[ch - 'a'].positive) {
  744.             bias[j] = wrange * rnd();
  745.             }
  746.             else
  747.             if (constants[ch - 'a'].negative) {
  748.                 bias[j] = wrange * (rnd() - 1);
  749.             }
  750.             else
  751.             bias[j] = wrange * (rnd() -.5);
  752.         }
  753.         else {
  754.             bias[j] = constants[ch - 'a'].value;
  755.         }
  756.       }
  757.     }
  758.     constrain_weights();
  759.     for (i = 0; i < noutputs; i++) {
  760.       target[i] = 0.0;
  761.     }
  762.     for (i = 0; i < nunits; i++) {
  763.       netinput[i] = activation[i] = delta[i] = error[i] = 0.0;
  764.     }
  765.     update_display();
  766.     return(CONTINUE);
  767. }
  768.  
  769. set_lgrain() {
  770.     char old_grain_string[STRINGLENGTH];
  771.     struct Variable *vp, *lookup_var();
  772.     
  773.     strcpy(old_grain_string,grain_string);
  774.  
  775.     vp = lookup_var("lgrain");
  776.     change_variable("lgrain",vp);
  777.     
  778.     if (startsame(grain_string,"epoch")) strcpy(grain_string,"epoch");
  779.     else if (startsame(grain_string,"pattern"))
  780.                         strcpy(grain_string,"pattern");
  781.     else {
  782.     strcpy(grain_string,old_grain_string);
  783.         return(put_error("urecognized grain -- not changed."));
  784.     }
  785.     return(CONTINUE);
  786. }
  787.  
  788. set_follow_mode() {
  789.     struct Variable *vp, *lookup_var();
  790.     int pv, i, j;
  791.     pv = follow;
  792.     
  793.     vp = lookup_var("follow");
  794.     change_variable("follow",vp);
  795.     
  796.     if (follow == 0) return (CONTINUE);
  797.     if (pwed == NULL) {
  798.       pwed = ((float **) emalloc((unsigned)(sizeof(float *)*nunits)));
  799.       (void) install_var("pwed", PVweight,(int *) pwed, nunits,
  800.                     nunits, NOMENU);
  801.       for (i = 0; i < nunits; i++) {
  802.     pwed[i] = ((float *)
  803.           emalloc((unsigned)(sizeof(float)*num_weights_to[i])));
  804.       }
  805.  
  806.       pbed = ((float *) emalloc((unsigned)(sizeof(float) * nunits)));
  807.       (void) install_var("pbed", Vfloat,(int *) pbed,
  808.                     nunits, 0, NOMENU);
  809.     }
  810.     if (pv == 0) {
  811.       for (i = 0; i < nunits; i++) {
  812.     for (j = 0; j < num_weights_to[i]; j++) {
  813.         pwed[i][j] = 0.0;
  814.     }
  815.       }
  816.       for (i = 0; i < nunits; i++)
  817.     pbed[i] = 0.0;
  818.     }
  819.     gcor = css = 0.0;
  820.     return(CONTINUE);
  821. }
  822.  
  823. change_crate() {
  824.     struct Variable *varp;
  825.  
  826.     if ((varp = lookup_var("crate")) != NULL) {
  827.     change_variable("crate",(int *) varp);
  828.     }
  829.     else {
  830.     return(put_error("crate is not defined"));
  831.     }
  832.     drate = 1 - crate;
  833.     return(CONTINUE);
  834. }
  835.  
  836. init_weights() {
  837.     int define_bp_network();
  838.     (void) install_command("network", define_bp_network,GETMENU,(int *) NULL);
  839.     (void) install_command("weights", read_weights, GETMENU,(int *) NULL);
  840.     (void) install_command("weights", write_weights, SAVEMENU,(int *) NULL);
  841.     (void) install_var("nunits", Int,(int *) & nunits, 0, 0, SETCONFMENU);
  842.     (void) install_var("ninputs", Int,(int *) & ninputs, 0, 0, SETCONFMENU);
  843.     (void) install_var("noutputs", Int,(int *) & noutputs, 0, 0, SETCONFMENU);
  844.     (void) install_var("wrange",Float,(int *) &wrange,0,0, SETPARAMMENU);
  845. }
  846.  
  847.