home *** CD-ROM | disk | FTP | other *** search
/ vis-ftp.cs.umass.edu / vis-ftp.cs.umass.edu.tar / vis-ftp.cs.umass.edu / pub / Software / ASCENDER / umass_foa.tar / mdt_NEW / mdtree / mdtree.c < prev    next >
C/C++ Source or Header  |  1995-01-27  |  16KB  |  590 lines

  1.  
  2. /* =============================================================
  3.    mdtree:  Creates a multivatraite decision tree from the training 
  4.               data in a feature file.  The decision tree is then
  5.           converted to a lookup table, to be later used for 
  6.           classification in the focus-of-attention routine.
  7.  
  8.             Usage is as follows:
  9.               mdtree feature_file(input) lookup_table_file(output)
  10.  
  11.               'feature_file' is the training data gathered by use of the
  12.           'mdtrain' routine.
  13.  
  14.  
  15.    Shashi Buluswar
  16.    Computer Vision Laboratory
  17.    Dept of Computer Science
  18.    University of Massachusetts
  19.    Amherst, MA  01003
  20.  
  21.    Copyright 1995, University of Massachusetts - Amherst
  22.  
  23.    =============================================================== */
  24.  
  25.  
  26. #define MAIN
  27. #include "mdtree.h"
  28.  
  29. /* inner_product returns the inner product of to vectors */
  30. double inner_product (v1, v2)
  31.      LTU_vect v1, v2;
  32. {
  33.   int i;
  34.   double x=0.0;
  35.   for (i=0; i<num_feat; i++)  x = x + v1[i] * v2[i];
  36.   return x;
  37.  
  38. }
  39.  
  40. /* ----------------------------------------------------------------- 
  41.    Linear Threshold Unit (LTU) is a respresentation of the 
  42.    Multivariate test.  It is a binary test of the form
  43.    Transpose(inst)*wt >= 0, where inst is an instance description
  44.    (a pattern vector), consisting of a constant 1 and the n
  45.    features and wt is a vector of n+1 coefficients (weights).
  46.    if Transpose(inst)*wt >= 0 then the LTU infers that inst belongs 
  47.    to class 1; otherwise inst belongs to class2.
  48.    ----------------------------------------------------------------- */
  49.  
  50. int LTU (inst, wt) 
  51.      LTU_vect inst, wt;
  52. {
  53.   double lin_comb = 0.0;
  54.   
  55.   lin_comb = inner_product (inst, wt);
  56.   if (lin_comb < 0) return 0;
  57.   else return 1;
  58. }
  59.  
  60.  
  61. /* ----------------------------------------------------------------- 
  62.    Xk: instance array
  63.    Wk_1: previous weight array
  64.    Wk: current weight array -- to be determined by RLS_train
  65.    Pk: covariance matrix
  66.    Pk_1: previous covariance matrix
  67.  
  68.    ----------------------------------------------------------------- */
  69.  
  70. void RLS_train (Xk, Wk_1, Wk, Yk) 
  71.      LTU_vect Xk, Wk_1, Wk;
  72.      int Yk;
  73. {
  74.   LTU_vect_sqr Pk, Pk_1;
  75.   LTU_vect Pk_1Xk, XkTPk_1, tmp_vect;
  76.   int i, j;
  77.   double scalar_value, sum_of_products;
  78.  
  79.   if (loop == 1) {
  80.     /* initialize diagonal of covariance matrix */
  81.     for (i=0; i< num_feat; i++)
  82.       for (j=0; j< num_feat; j++) 
  83.     if (i==j) Pk_1[i][j] = large_num;
  84.     else Pk_1[i][j] = 0.0;
  85.     
  86.   }
  87.  
  88.  
  89.   /* RLS equations:
  90.                   Pk = Pk_1 - Pk_1Xk [1 + XkTPk_1 * Xk] ^-1 XkTPk_1
  91.           Kk = PkXk
  92.           Wk = Wk_1 - Kk (XkTWk - Yk)                        */
  93.  
  94.  
  95.  
  96.   /* Update Pk -- this is the first part of the RLS equation */
  97.  
  98.   /* Multiply Pk_1 (the array p at time k-1) by Xk (X at time k) */
  99.   for (i=0; i<num_feat; i++) {
  100.       Pk_1Xk [i] = 0.0;
  101.       for (j=0; j<num_feat; j++)
  102.     Pk_1Xk [i] += Pk_1[i][j] * Xk[j];
  103.     }
  104.  
  105.   /* Calculate the scalar value [1 + XkT_Pk_1Xk] ^ -1 */
  106.   scalar_value = 1.0;
  107.   for (i=0; i<num_feat; i++)
  108.     scalar_value += Pk_1Xk[i] * Xk[i];
  109.   scalar_value = 1.0/scalar_value;
  110.  
  111.  
  112.   /* Fold scalar_value into Pk_1Xk */
  113.   for (i=0; i<num_feat; i++) Pk_1Xk[i] *= scalar_value;
  114.  
  115.   /* Multiply XkT (Xk transposed) by Pk_1 */
  116.   for (i=0; i<num_feat; i++)
  117.     {
  118.     XkTPk_1[i] = 0;
  119.     for (j=0; j<num_feat; j++)
  120.       XkTPk_1 [i] += Xk[j] * Pk_1[j][i];
  121.   }
  122.   
  123.  
  124.   /* Multiply Pk_1Xk by XkTPk_1 */
  125.   for (i=0; i<num_feat; i++)
  126.     for (j=0; j<num_feat; j++)
  127.       Pk[i][j] = Pk_1Xk[i] * XkTPk_1[j];
  128.  
  129.   for (i=0; i<num_feat; i++)
  130.     for (j=0; j<num_feat; j++)
  131.       Pk[i][j] = Pk_1[i][j] - Pk[i][j];
  132.  
  133.   /* second part of RLS equation */
  134.   
  135.   /* Multiply XkT (Xktransposed) by Wk_1.  The result is a 
  136.      scalar_value.  Subtract Yk.                            */
  137.   scalar_value = 0.0 - Yk;
  138.   for (i=0; i<num_feat; i++) scalar_value += Xk[i] * Wk_1[i];
  139.  
  140.   /* Multiply Pk by Xk (= Kk) */
  141.   for (i=0; i<num_feat; i++) {
  142.     sum_of_products = 0.0;
  143.  
  144.     for (j=0; j<num_feat; j++) {
  145.       sum_of_products += Pk[i][j] * Xk[j];
  146.     }
  147.     tmp_vect[i] = sum_of_products * scalar_value;
  148.     
  149.   }
  150.  
  151.   for (i=0; i<num_feat; i++) Wk[i] = Wk_1[i] - tmp_vect[i];
  152.  
  153. }
  154.  
  155.  
  156. /* converged determines if the weight-training has converged, */
  157. /*       - returns 1 if converged                             */
  158. /*       -         0 otherwise                                */
  159. /* note: convergence is tested by (estimate == actual +- fudge_factor */
  160.  
  161. int converged (vp, vc)
  162.      LTU_vect vp, vc;
  163. {
  164.   int i, x=1;
  165.   double diff1, diff2;
  166.   for (i=0; i<num_feat; i++) {
  167.     diff1 = vc[i]-vp[i];
  168.     diff2 = vp[i]-vc[i];
  169.     if (diff1 > diff2) {
  170.       if (diff1 > small_num) x=0;
  171.     }
  172.     else 
  173.       if (diff2 > small_num) x=0;
  174.   }
  175.   return x;
  176. }
  177.  
  178. /* homogenous determines if a given set is homogeneous according to */
  179. /* the known classification of the training instances */
  180.  
  181. int homogenous (set, size, class)
  182.      struct instance *set;
  183.      int size, *class;
  184. {
  185.   int i=0, x=1;
  186.  
  187.   
  188.   *class = set[0].known_class;
  189.   for (i=0; i<size; i++)
  190.     if (!(set[i].known_class == *class)) {
  191.       x=0;
  192.     }
  193.  
  194.   return x;
  195.  
  196. }
  197.  
  198.  
  199. /* divide_set divides a given set according to the LTU-decided classification,
  200.    and creates two subsets */
  201.  
  202. void divide_set (super, size_super, sub0, size_sub0, sub1, size_sub1)
  203.      struct instance *super, *sub0, *sub1;
  204.      int size_super, *size_sub0, *size_sub1;
  205. {
  206.   int i;
  207.   
  208.   *size_sub0=0; *size_sub1=0;
  209.   for (i=0; i<size_super; i++) {
  210.     if (super[i].LTU_class == 0) {
  211.       sub0[*size_sub0] = super[i];
  212.       (*size_sub0)++;
  213.     }
  214.     else {
  215.       sub1[*size_sub1] = super[i];
  216.       (*size_sub1)++;
  217.     }      
  218.  
  219.   }
  220. }
  221.  
  222.  
  223. /* traverse-tree traverses any given binary tree in preorder */
  224.  
  225. void traverse_tree (tree_ptr)
  226.      struct tree_node *tree_ptr;
  227. {
  228.   int i;
  229.  
  230.   printf ("%d: ", tot_count);
  231.   for (i=0; i<num_feat; i++) printf ("%f, ", tree_ptr->weight_vect[i]); 
  232.   printf ("class: %d\n", tree_ptr->node_class);
  233.   printf (" %d\n", tree_ptr->node_class);
  234.   tot_count++;
  235.   
  236.   if (tree_ptr->neg != NULL) {
  237.     printf ("Negative: 0\n");
  238.     traverse_tree (tree_ptr->neg);
  239.   }
  240.   printf ("done with negative...\n");
  241.   
  242.   if (tree_ptr->pos != NULL) {
  243.     printf ("Positive: 1\n");
  244.     traverse_tree (tree_ptr->pos);
  245.   }
  246.   printf ("done with positive...\n");
  247. }
  248.  
  249.  
  250.  
  251. /* build_tree builds the decision tree recusively */
  252.  
  253. void build_tree (inst_set, set_size, tree_ptr) 
  254.      struct instance *inst_set; /* the set of training instances */
  255.      int set_size;               /* number of training instances */
  256.      struct tree_node *tree_ptr;
  257. {
  258.   struct instance *sub0, 
  259.                   *sub1; /* subsets according to 0/1 class */
  260.   int size_sub0, size_sub1;       /* size of sub_sets classified as 0 or 1 */
  261.   int class=0;                    /* classification 0/1 */
  262.   int i, j, k, x;                 /* temporary loop variables */
  263.   int conv=0;                     /* flag for convergence */
  264.   LTU_vect learned_weight_vec;    /* learned weight vector after RLS-training */
  265.   LTU_vect prev_weight_vec;       /* previos weight vector for RLS updating */
  266.   float tmp_float;                /* tmp float for random # */
  267.   int num_zero, num_one;
  268.   
  269.   /* -------------------------------------------------------------------------
  270.      find out if the set is homogenous according to known_class 
  271.      if (set is homogenous), or (num_instances < 2*num_features) then do nothing 
  272.      else (i.e. if the set is not homogenous then):
  273.         reset the covariance matrix (to large_num diagonal)
  274.         for each instance in the set
  275.        RLS_train the weight_vector (using the same covariance matrix)
  276.     LTU_classify the instances according to the learned weight vector 
  277.     divide the set into two subsets according to the LTU_classification
  278.     call build_tree on each of the sub_sets   
  279.      ------------------------------------------------------------------------- */
  280.  
  281.   x = homogenous (inst_set, set_size, &class);     /* determine if set is homogenous */
  282.  
  283.   if ((x == 1) || (set_size < min_set_size)) {  /* if so, then do nothing */
  284.  
  285.     tree_ptr->node_class = class;
  286.     
  287.   }
  288.   else {                                       /* if not, then ... */
  289.  
  290.     printf ("set size: %d\n", set_size);
  291.  
  292.     loop = 1;                                  /* flag for resetting cov matrix */
  293.  
  294.     /* init weight to random number */
  295.     for (k=0; k<num_feat; k++) {
  296.       tmp_float = random ();
  297.       prev_weight_vec[k] = tmp_float; 
  298.     }
  299.  
  300.     for (i=0; i<set_size; i++) {               /* train for each instance */
  301.       conv=0;
  302.       while (conv == 0) {
  303.     RLS_train (inst_set[i].feat, prev_weight_vec, learned_weight_vec, 
  304.            inst_set[i].known_class);
  305.     conv = converged (prev_weight_vec, learned_weight_vec);
  306.     loop ++;                                 /* flag for *not* resetting cov mat */
  307.     for (k=0; k<num_feat; k++) prev_weight_vec[k] = learned_weight_vec[k]; 
  308.                                           /* pass on learned vector to next update */      
  309.       }
  310.     }
  311.  
  312.     tree_ptr->node_class = non_class;
  313.     for (i=0; i<num_feat; i++) (tree_ptr->weight_vect)[i] = learned_weight_vec[i];
  314.  
  315.     for (i=0; i<set_size; i++) {           /* classify each instance according to LTU */
  316.       inst_set[i].LTU_class = LTU (inst_set[i].feat, learned_weight_vec);
  317.     }    
  318.     
  319.     if ((sub0 = (struct instance *) malloc (set_size * sizeof (struct instance))) == NULL) 
  320.       printf ("out of memory...\n");
  321.     
  322.     if ((sub1 = (struct instance *) malloc (set_size * sizeof (struct instance))) == NULL) 
  323.       printf ("out of memory...\n");
  324.     
  325.     divide_set (inst_set, set_size, sub0, &size_sub0, sub1, &size_sub1);
  326.  
  327.     if ((set_size==size_sub0) || (set_size==size_sub1)) {
  328.       num_zero=0; num_one=0;
  329.       for (j=0; j<set_size; j++){
  330.     if (inst_set[j].known_class==0) num_zero++;
  331.     else num_one++;
  332.       }
  333.       if (num_one > (num_zero/t_f_ratio)) tree_ptr->node_class = 1;
  334.       else tree_ptr->node_class = 0;
  335.       printf ("non-partitionable: 0=%d, 1=%d: %d\n", num_zero, num_one, tree_ptr->node_class);
  336.     }
  337.     else  {
  338.     
  339.       tree_ptr->pos = malloc (sizeof (struct tree_node));
  340.       tree_ptr->neg = malloc (sizeof (struct tree_node));
  341.  
  342.       /* make sure the pointers in th struct is assigned NULL */
  343.       tree_ptr->pos->pos = NULL;
  344.       tree_ptr->pos->neg = NULL;
  345.       tree_ptr->neg->pos = NULL;
  346.       tree_ptr->neg->neg = NULL;
  347.  
  348.       tree_ptr->pos->node_class = non_class;
  349.       for (i=0; i<num_feat; i++) tree_ptr->pos->weight_vect[i] = 0.0;
  350.       tree_ptr->neg->node_class = non_class;
  351.       for (i=0; i<num_feat; i++) tree_ptr->neg->weight_vect[i] = 0.0;
  352.       
  353.       build_tree (sub0, size_sub0, tree_ptr->neg);
  354.       build_tree (sub1, size_sub1, tree_ptr->pos);
  355.     }
  356.   }
  357. }
  358.  
  359.  
  360. /* classify a given actual instance, given the instance's feature values, 
  361.    and the decision-tree */
  362.  
  363. void classify (tree_ptr, inst_vect)
  364.      struct tree_node *tree_ptr;
  365.      LTU_vect inst_vect;
  366. {
  367.   int x;
  368.  
  369.   if (tree_ptr->node_class == non_class) {
  370.     x = LTU (inst_vect, tree_ptr->weight_vect);
  371.     if (x==1) {
  372.       classify (tree_ptr->pos, inst_vect);
  373.     }
  374.     else {
  375.       classify (tree_ptr->neg, inst_vect);
  376.     }
  377.   }
  378.   else {
  379.     LUT_ret_class = (tree_ptr->node_class);
  380.     /*return (tmpClass);*/
  381.   }
  382. }
  383.  
  384.  
  385. /* read data from input ascii file */
  386.  
  387. void read_actual_infile (file_name, size, data_array, idx)
  388.      char *file_name;
  389.      int size, idx;
  390.      short int data_array [num_feat][num_actual];
  391. {
  392.   FILE *file_ptr;
  393.   int i=0, x;
  394.   double y;
  395.  
  396.   printf ("%s %d\n", file_name, size);
  397.  
  398.   if ((file_ptr = fopen (file_name, "r")) == NULL)
  399.     printf ("fopen error...\n");
  400.   else {
  401.     while ( (!feof (file_ptr)) && (i<size) ) {
  402.       fscanf (file_ptr, "%d", &x);
  403.       if (!feof(file_ptr)) {
  404.     y = x;
  405.     data_array [idx][i] = x;
  406.     i++;
  407.       }
  408.     }
  409.   }
  410.   printf ("# words: %d\n", i);
  411. }
  412.  
  413.  
  414.  
  415. /*********************************************************************
  416. * Count_instances(filename)                                          *
  417. *   Counts the number of lines in filename                           *
  418. *********************************************************************/
  419.  
  420. int count_instances(fname)
  421. char *fname;
  422. {
  423.   int count = 0;
  424.   char *dummy_string;
  425.   FILE *fp;
  426.  
  427.   dummy_string = malloc(sizeof(char) * num_feat * 4);
  428.   
  429.   if ((fp = fopen(fname,"r")) == NULL) {
  430.     fprintf(stderr, "Unable to open %s\n", fname);
  431.     return(-1);
  432.   }
  433.  
  434.   while (fgets(dummy_string, (num_feat * 4), fp) != NULL)
  435.     count++;
  436.  
  437.   free((void *) dummy_string);
  438.   fclose(fp);
  439.  
  440.   return(count);
  441. }
  442.  
  443.  
  444. int bin2dec (bit_arr)
  445.      char bit_arr[CHAR_SIZE];
  446. {
  447.   int dec_num, i;
  448.   
  449.   dec_num=0;
  450.   for (i=(CHAR_SIZE-1); i>=0; i--) {
  451.     if (bit_arr[i] == '1') {
  452.       dec_num+= (pow(2,((CHAR_SIZE-1)-i)));
  453.     }
  454.   }
  455.   
  456.   return(dec_num);
  457. }
  458.  
  459.  
  460. /*--------------------------------------------------------- 
  461.   Go through every combination of RGB (from RGB_min to RGB_max), 
  462.   and classify the pixel at that value -- then store the 
  463.   classification in the LUT.  Finally, write the LUT out
  464.   to a file 
  465. ------------------------------------------------------------*/
  466.  
  467. void write_LUT_to_file (tree_ptr, file_name) 
  468.      struct tree_node *tree_ptr;
  469.      char *file_name;
  470. {
  471.   int iR, iG, iB, bit_count=0;
  472.   FILE *LUT_file_ptr;
  473.   int ASCII_dec;              /* decimal for ASCII char represented by 8 bits */
  474.   unsigned char ASCII_char;   /* ASCII char represented by 8 bits */
  475.   double v[num_feat];
  476.   char tmp_c[1], bit_arr[CHAR_SIZE];
  477.  
  478.   if ((LUT_file_ptr = fopen (file_name, "wb")) == NULL) {
  479.     printf ("fopen error...\n");
  480.     exit(-1);
  481.   }
  482.  
  483.   else {
  484.     for (iR=RGB_min; iR<RGB_max; iR++) {           /* for each RGB pixel */
  485.       for (iG=RGB_min; iG<RGB_max; iG++) {
  486.     for (iB=RGB_min; iB<RGB_max; iB++) {
  487.       v[0]=1.0;                                /* assign instance vector */
  488.       v[1]=iR;
  489.       v[2]=iG;
  490.       v[3]=iB;
  491.       classify (tree_ptr, v);                  /* classify the vector as 1 or 0 */
  492.  
  493.       sprintf (tmp_c, "%d", LUT_ret_class);    /* convert a string of 8 bits */
  494.       bit_arr[bit_count]=tmp_c[0];             /* to a char */
  495.       bit_count++;                             
  496.       if ((bit_count%(CHAR_SIZE))==0) {
  497.         bit_count=0;
  498.         ASCII_dec = bin2dec(bit_arr);
  499.         ASCII_char = (unsigned char) ASCII_dec;
  500.         fprintf (LUT_file_ptr, "%c", ASCII_char);
  501.       }
  502.       /*fprintf (LUT_file_ptr, "%c", tmp_c[0]);*/
  503.     }
  504.       }
  505.     }
  506.   }
  507.   fclose (LUT_file_ptr);
  508. }
  509.  
  510.  
  511. void main (argc, argv)
  512. int argc;
  513. char *argv[];
  514. {
  515.   int num_instances;
  516.   int i, j, temp;
  517.   struct instance *inst;
  518.   struct tree_node *tree_ptr;
  519.   FILE *input;
  520.  
  521.   if (argc != 3) {
  522.     fprintf(stdout, "\nusage: mdt feature_file LUT_file\n");
  523.     exit (-1);
  524.   }
  525.  
  526.   num_instances = count_instances(argv[1]);
  527.   if (num_instances  == -1) {
  528.     fprintf(stderr, "\naborting training\n");
  529.     exit (-1);
  530.   }
  531.  
  532.   inst = malloc(sizeof(struct instance) * num_instances);
  533.   if (inst == NULL) {
  534.     fprintf(stderr, "\nunable to allocate memory for %d instances\n", num_instances);
  535.     exit (-1);
  536.   }
  537.  
  538.   input = fopen(argv[1],"r");
  539.   if (input == NULL) {
  540.     fprintf(stderr, "\nunable to open %s (second time)\n", argv[1]);
  541.     exit (-1);
  542.   }
  543.  
  544.   for (i = 0; i < num_instances; i++) {
  545.     inst[i].feat[0] = 1;
  546.     if (fscanf(input, "%d", &temp) == EOF) {
  547.       fprintf(stderr, "\nNot enough instances?!\n");
  548.       exit (-1);
  549.     }
  550.     inst[i].known_class = temp;
  551.     for (j = 1; j < num_feat; j++) { 
  552.       if (fscanf(input, "%d", &temp) == EOF) {
  553.     fprintf(stderr, "\nRan out of data?!\n");
  554.     exit (-1);
  555.       }
  556.       inst[i].feat[j] = temp;
  557.     }
  558.   }
  559.  
  560.  
  561.   printf ("building tree...\n");
  562.   tree_ptr = malloc (sizeof (struct tree_node));  /* build decision tree */
  563.   /* Make sure the pointers in the node are initialzed */
  564.  
  565.   tree_ptr->pos = NULL;
  566.   tree_ptr->neg = NULL;
  567.   build_tree (inst, num_instances, tree_ptr);
  568.  
  569.   printf ("\n\ndecision tree:\n\n");              /* print the decision tree */
  570.  
  571.   traverse_tree (tree_ptr); 
  572.  
  573.  
  574.   /* clasify each RGB combination and store in LUT;
  575.      write LUT to file LUT_file */
  576.  
  577.   printf ("building LUT... %s\n", argv[2]);  
  578.   write_LUT_to_file (tree_ptr, argv[2]) ;
  579.  
  580.   printf ("done building LUT...\n"); 
  581.  
  582. }
  583.   
  584.  
  585.  
  586.  
  587.  
  588.  
  589.  
  590.