/*
 * Generating Rules from Back-Propagation Weights
 *
 * This program demos how the rules may be extracted for a net with
 * several inputs and one output.  All the program needs are the weights
 * and threshold values.  The algorithm is to search the space of
 * statements that can be made about the input to the network.  For
 * example, "i1=1 and i4=0 and i7=1" is a statement of length three
 * which defines a class of inputs.  The statement space is searched in
 * shortest-to-longest and strongest-to-weakest order.  The data struct
 * representing a statement is an array of integers.  The size is also
 * passed with the array, and the array is always sorted least to
 * greatest.
 *
 * NOTE(review): this copy of the file looks damaged.  Wherever the
 * original presumably had a '<' immediately followed by a letter, the
 * text up to the next '>' is missing (header names, loop conditions and
 * whole statements).  The affected places are flagged below; the code
 * tokens themselves are left exactly as found and will not compile
 * until the missing text is restored from a clean copy.
 */

/* NOTE(review): the seven header names are missing from this copy. */
#include #include #include #include #include #include #include

/* Capacities of the fixed-size global tables declared below. */
#define maxwets 5000
#define maxth 1000
#define wmax 1000

FILE *netf, *rulef;   /* network file (opened "r") and rule file (opened "w") */
char netfn[80];       /* network file name, filled in by main */
char rulefn[80];      /* rule file name, filled in by main */
float wet[maxwets];   /* network weights are stored here */
float thresh[maxth];  /* thresholds; extract_rules reads thresh[n+firstnode] */
float bar;            /* cut-off compared against totmin; read by main */
int map[wmax];        /* map[k] is used as an index into wet (see print_rule) */
int vec[wmax];        /* initial statement handed to explore by extract_rules */
float wet2[wmax];     /* a sorted abs copy of wet */
int wsiz, limit;      /* weights per node; longest statement length explored */
int t;                /* used by explore */
float totmin;         /* running total computed in extract_rules */
int insiz,h1siz,h2siz,outsiz;    /* layer sizes read from the network file */
int numwets,numlayers,numnodes;  /* numnodes is read from the file; main switches on numlayers */
int pass,n,fw;        /* used by extract_rules and by print_rule */
char rhs[10];         /* name prefix printed after "then" (set in main) */
char lhs[10];         /* name prefix printed for each condition (set in main) */
long rnum;            /* number of rules generated */

/**************************************************************************/

/* Expands to nothing; no use of it is visible in this copy. */
#define visit

/*
 * print_rule: statement macro that appends one rule to rulef and bumps
 * rnum.  It relies on names from the expanding scope (at least `i` and
 * `vec`) and on the globals wet, map, pass, fw, lhs, rhs and n.
 *
 * Output: "If", then one term per loop iteration, then " then", an
 * optional " not" when pass != 0, and finally rhs followed by n.  Each
 * term is lhs followed by map[vec[i]]-fw, preceded by " not" when
 * exactly one of (wet[map[vec[i]]] > 0) and (pass == 0) holds.
 *
 * NOTE(review): the loop header is garbled ("i0)").  Presumably it ran
 * i over the statement length and printed the " &" separator for every
 * term but the first -- confirm against a clean copy.
 */
#define print_rule \
 ++rnum; \
 fprintf(rulef,"If"); \
 for (i=0; i0) fprintf(rulef," &"); \
 if ((wet[map[vec[i]]]>0) ^ (pass==0)) \
 fprintf(rulef," not"); \
 fprintf(rulef," %s%d",lhs,map[vec[i]]-fw); \
 } \
 fprintf(rulef," then"); \
 if (pass!=0) fprintf(rulef," not"); \
 fprintf(rulef," %s%d\n",rhs,n);

/**************************************************************************/

/*
 * load_weights: load the dimensions and weights from the neural net file.
 *
 * Opens netfn for reading and rulefn for writing, checks the version
 * token, then reads insiz, h1siz, h2siz, outsiz and numnodes.
 * Every failure path prints a message and calls exit(0), so the process
 * reports success to its parent even on error.
 */
void load_weights()
{
  char vers[11];  /* holds the 10-character version string plus the NUL */
  int i;

  netf = fopen(netfn,"r");
  if (netf==NULL) {
    printf("Error opening input.");
    exit(0);
  }
  rulef = fopen(rulefn,"w");
  if (rulef==NULL) {
    printf("Error opening output.");
    exit(0);
  }

  /* Skip one token, then read the version token.
     NOTE(review): %s has no field width, so a longer token overruns vers. */
  fscanf(netf,"%*s %s",vers);
  if (strcmp(vers,"1.2-100989")) { /* wrong version */
    printf("Network file incompatible, version '%s'.\n",vers);
    exit(0);
  }

  /* Skip two tokens, the rest of that line, and four more tokens. */
  fscanf(netf,"%*s %*s %*[^\n] %*s %*s %*s %*s");
  fscanf(netf,"%d %d %d %d %d",&insiz,&h1siz,&h2siz,&outsiz,&numnodes);
  printf("Input size %d\n",insiz);
  printf("Output size %d\n",outsiz);
  for (i=0; i<28; i++) /* skip extra stuff */
    fscanf(netf,"%*s");

  /* NOTE(review): text is missing between "i" and "maxwets" below.  The
     lost span presumably held the rest of this loop, the code that sets
     numwets and numlayers, and the left side of a comparison against
     maxwets -- confirm against a clean copy. */
  for (i=0; i maxwets) {
    printf("Network is too large, must have less than %d weights\n",maxwets);
    exit(0);
  }

  /* NOTE(review): text is missing between "i" and "bar" below.  The lost
     span presumably held the end of load_weights and the header and body
     of succeed(), which explore() calls; what remains looks like the
     tail of succeed(): emit the rule and return 1 when some total
     exceeds bar, otherwise return 0. */
  for (i=0; i bar) {
    print_rule;
    return 1;
  }
  return 0;
}

/*
 * explore: recursive search over statements.
 *
 * vec - statement to start from; only its first num entries are read
 * num - number of entries in the statement
 *
 * Works on a private heap copy with room for one extra entry and
 * repeatedly tests it with succeed():
 *   - On success, scans adjacent pairs from the end; a gap of exactly 2
 *     triggers a recursive search with the lower entry bumped by one
 *     (restored afterwards).  Then the last entry is incremented and
 *     the loop continues until it reaches wsiz.
 *   - On failure, if num < limit, appends last+1 as a new entry and
 *     recurses with num+1 (only when the new entry is below wsiz), then
 *     stops.
 *
 * NOTE(review): the malloc result is used without a NULL check, and the
 * scratch variable t is a global shared with every recursive call.
 */
void explore(int *vec, int num)
{
  int *myvec;
  int i;

  t = num*sizeof(int); /* compute size of vec in bytes */
  myvec = malloc(t+sizeof(int));
  memcpy(myvec,vec,t); /* make a copy of it */
  while (1)
    if (succeed(myvec,num)) { /* statement succeeded */
      for (i=num-2; i>=0; i--) {
        t = myvec[i+1] - myvec[i];
        if (t > 2) break; /* if t > 2 then no new path */
        if (t == 2) { /* if t==2 then explore new path */
          ++myvec[i];
          explore(myvec,num);
          --myvec[i];
          break; /* for */
        } /* if t==1 then keep looking */
      }
      if (++myvec[num-1] >= wsiz) /* keep exploring regular path */
        break; /* while */
    }
    else { /* statement failed */
      if (num < limit) {
        if ((myvec[num]=myvec[num-1]+1) < wsiz)
          explore(myvec,num+1);
      }
      break;
    } /* end of while-loop */
  free(myvec);
}

/**************************************************************************/

/*
 * compare: orders two indices into wet by weight magnitude.
 * Returns 1 when |wet[*i]| < |wet[*j]| and -1 otherwise, so larger
 * magnitudes sort first.
 *
 * NOTE(review): no call is visible in this copy; presumably it is the
 * qsort comparator used in the missing part of extract_rules.  It never
 * returns 0 for equal magnitudes and takes int pointers rather than
 * const void pointers -- confirm that is acceptable for the caller.
 */
int compare(int *i,int *j)
{
  return (fabs(wet[*i]) < fabs(wet[*j])) ? 1 : -1;
}

/**************************************************************************/

/*
 * extract_rules: this takes the weight array of one layer of nodes and
 * generates rules.
 *
 * firstnode - offset added to n when indexing thresh
 * numnode   - not referenced in the surviving text (see note below)
 * firstwet  - index in wet of the layer's first weight (copied to fw)
 * numwet    - number of weights per node (copied to wsiz)
 */
void extract_rules(int firstnode,int numnode,int firstwet,int numwet)
{
  int i;

  wsiz = numwet;
  fw = firstwet;
  limit = wsiz / 5;
  if (limit < 5) limit = 5; /* set limit to wsiz/5 but not less than 5 */

  /* NOTE(review): text is missing between "n" and "bar" below.  The lost
     span presumably held the rest of the per-node loop header (likely
     bounded by numnode), the set-up of map and wet2, the first-pass
     value of pass and totmin, and the left side of a comparison against
     bar -- confirm against a clean copy. */
  for (n=0; n bar)
      fprintf(rulef,"If TRUE then %s%d\n",rhs,n);
    else {
      vec[0] = 0;   /* start the search from the one-term statement {0} */
      explore(vec,1);
    }

    pass = 1; /* second pass: prove low outputs */
    /* NOTE(review): this loop header is garbled ("i0)"); presumably it
       ran i over the node's wsiz weights and subtracted only the
       positive ones -- confirm against a clean copy. */
    for (totmin=0,i=0; i0) totmin-=wet[i+fw];
    totmin -= thresh[n+firstnode];
    if (totmin > bar)
      fprintf(rulef,"If TRUE then not %s%d\n",rhs,n);
    else {
      vec[0] = 0;
      explore(vec,1);
    }
    fw += wsiz;   /* advance to the next node's block of weights */
  }
}

/*
 * main: reads the two file names and the bar value either from three
 * command-line arguments or interactively, loads the network, and runs
 * extract_rules once per weight layer (one call for a one-layer net,
 * two for a two-layer net).  Three-layer nets only print a message.
 *
 * NOTE(review): the return type is implicit int and nothing is
 * returned; strcpy and scanf("%s") write into the 80-byte name buffers
 * without a length limit; the switch has no default case; neither netf
 * nor rulef is closed in the visible code.
 */
main(int argc, char *argv[])
{
  if (argc == 4) {
    strcpy(netfn,argv[1]);
    printf("Input file : %s\n",netfn);
    strcpy(rulefn,argv[2]);
    printf("Output file : %s\n",rulefn);
    sscanf(argv[3],"%f",&bar);
    printf("Bar is : %1.2f\n",bar);
  }
  else {
    printf("Enter input neural net filename : ");
    scanf("%s",netfn);
    printf("Enter output rule filename : ");
    scanf("%s",rulefn);
    printf("Enter bar value : ");
    scanf("%f",&bar);
  }

  load_weights();
  rnum = 0;
  switch (numlayers) {
  case 1:
    /* inputs feed the outputs directly */
    strcpy(rhs,"OUT");
    strcpy(lhs,"IN");
    extract_rules(insiz,outsiz,0,insiz);
    break;
  case 2:
    /* first the input-to-hidden weights ... */
    strcpy(rhs,"HID");
    strcpy(lhs,"IN");
    extract_rules(insiz,h1siz,0,insiz);
    /* ... then the hidden-to-output weights, which start at insiz*h1siz */
    strcpy(rhs,"OUT");
    strcpy(lhs,"HID");
    extract_rules(insiz+h1siz,outsiz,insiz*h1siz,h1siz);
    break;
  case 3:
    printf("I can't do three layer nets yet");
  }
  printf("Finished.\n%ld rules generated.",rnum);
}