SurrogateSplits.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00022                   SEXP fitmem) {
00023 
00024     SEXP x, y, expcovinf; 
00025     SEXP splitctrl, inputs; 
00026     SEXP split, thiswhichNA;
00027     int nobs, ninputs, i, j, k, jselect, maxsurr, *order;
00028     double ms, cp, *thisweights, *cutpoint, *maxstat, 
00029            *splitstat, *dweights, *tweights, *dx, *dy;
00030     double cut, *twotab;
00031     
00032     nobs = get_nobs(learnsample);
00033     ninputs = get_ninputs(learnsample);
00034     splitctrl = get_splitctrl(controls);
00035     maxsurr = get_maxsurrogate(splitctrl);
00036 
00037     if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00038         error("nodes does not have %d surrogate splits", maxsurr);
00039     if ((ninputs - 1 - maxsurr) < 1)
00040         error("cannot set up %d surrogate splits with only %d input variable(s)", 
00041               maxsurr, ninputs);
00042 
00043     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00044     jselect = S3get_variableID(S3get_primarysplit(node));
00045     y = S3get_nodeweights(VECTOR_ELT(node, 7));
00046 
00047     tweights = Calloc(nobs, double);
00048     dweights = REAL(weights);
00049     for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00050     if (has_missings(inputs, jselect)) {
00051         thiswhichNA = get_missings(inputs, jselect);
00052         for (k = 0; k < LENGTH(thiswhichNA); k++)
00053             tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00054     }
00055 
00056     expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00057     C_ExpectCovarInfluence(REAL(y), 1, REAL(weights), nobs, expcovinf);
00058     
00059     splitstat = REAL(get_splitstatistics(fitmem));
00060     /* <FIXME> extend `TreeFitMemory' to those as well ... */
00061     maxstat = Calloc(ninputs, double);
00062     cutpoint = Calloc(ninputs, double);
00063     order = Calloc(ninputs, int);
00064     /* <FIXME> */
00065     
00066     /* this is essentially an exhaustive search */
00067     /* <FIXME>: we don't want to do this for random forest like trees 
00068        </FIXME>
00069      */
00070     for (j = 0; j < ninputs; j++) {
00071     
00072          order[j] = j + 1;
00073          maxstat[j] = 0.0;
00074          cutpoint[j] = 0.0;
00075 
00076          /* ordered input variables only (for the moment) */
00077          if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00078              continue;
00079 
00080          x = get_variable(inputs, j + 1);
00081 
00082          if (has_missings(inputs, j + 1)) {
00083 
00084              thisweights = REAL(get_weights(fitmem, j + 1));
00085              for (i = 0; i < nobs; i++) thisweights[i] = tweights[i];
00086              thiswhichNA = get_missings(inputs, j + 1);
00087              for (k = 0; k < LENGTH(thiswhichNA); k++)
00088                  thisweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00089                  
00090              C_ExpectCovarInfluence(REAL(y), 1, thisweights, nobs, expcovinf);
00091              
00092              C_split(REAL(x), 1, REAL(y), 1, thisweights, nobs,
00093                      INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00094                      GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00095                      expcovinf, &cp, &ms, splitstat);
00096          } else {
00097          
00098              C_split(REAL(x), 1, REAL(y), 1, tweights, nobs,
00099              INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00100              GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00101              expcovinf, &cp, &ms, splitstat);
00102          }
00103 
00104          maxstat[j] = -ms;
00105          cutpoint[j] = cp;
00106     }
00107 
00108     /* order with respect to maximal statistic */
00109     rsort_with_index(maxstat, order, ninputs);
00110     
00111     twotab = Calloc(4, double);
00112     
00113     /* the best `maxsurr' ones are implemented */
00114     for (j = 0; j < maxsurr; j++) {
00115 
00116         for (i = 0; i < 4; i++) twotab[i] = 0.0;
00117         cut = cutpoint[order[j] - 1];
00118         SET_VECTOR_ELT(S3get_surrogatesplits(node), j, 
00119                        split = allocVector(VECSXP, SPLIT_LENGTH));
00120         C_init_orderedsplit(split, 0);
00121         S3set_variableID(split, order[j]);
00122         REAL(S3get_splitpoint(split))[0] = cut;
00123         dx = REAL(get_variable(inputs, order[j]));
00124         dy = REAL(y);
00125 
00126         /* OK, this is a dirty hack: determine if the split 
00127            goes left or right by the Pearson residual of a 2x2 table.
00128            I don't want to use the big caliber here 
00129         */
00130         for (i = 0; i < nobs; i++) {
00131             twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00132             twotab[1] += (dy[i] == 1) * tweights[i];
00133             twotab[2] += (dx[i] <= cut) * tweights[i];
00134             twotab[3] += tweights[i];
00135         }
00136         S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / 
00137                      twotab[3]) > 0);
00138     }
00139     
00140     Free(maxstat);
00141     Free(cutpoint);
00142     Free(order);
00143     Free(tweights);
00144     Free(twotab);
00145 }
00146 
00157 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00158                   SEXP fitmem) {
00159 
00160     C_surrogates(node, learnsample, weights, controls, fitmem);
00161     return(S3get_surrogatesplits(node));
00162     
00163 }
00164 
00172 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00173 
00174     SEXP weights, split, surrsplit;
00175     SEXP inputs, whichNA;
00176     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00177     int *iwhichNA, k;
00178     int nobs, i, nna, ns;
00179                     
00180     weights = S3get_nodeweights(node);
00181     dweights = REAL(weights);
00182     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00183     nobs = get_nobs(learnsample);
00184             
00185     leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00186     rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00187     surrsplit = S3get_surrogatesplits(node);
00188 
00189     /* if the primary split has any missings */
00190     split = S3get_primarysplit(node);
00191     if (has_missings(inputs, S3get_variableID(split))) {
00192 
00193         /* where are the missings? */
00194         whichNA = get_missings(inputs, S3get_variableID(split));
00195         iwhichNA = INTEGER(whichNA);
00196         nna = LENGTH(whichNA);
00197 
00198         /* for all missing values ... */
00199         for (k = 0; k < nna; k++) {
00200             ns = 0;
00201             i = iwhichNA[k] - 1;
00202             if (dweights[i] == 0) continue;
00203             
00204             /* loop over surrogate splits until an appropriate one is found */
00205             while(TRUE) {
00206             
00207                 if (ns >= LENGTH(surrsplit)) break;
00208             
00209                 split = VECTOR_ELT(surrsplit, ns);
00210                 if (has_missings(inputs, S3get_variableID(split))) {
00211                     if (INTEGER(get_missings(inputs, 
00212                             S3get_variableID(split)))[i]) {
00213                         ns++;
00214                         continue;
00215                     }
00216                 }
00217 
00218                 cutpoint = REAL(S3get_splitpoint(split))[0];
00219                 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00220 
00221                 if (S3get_toleft(split)) {
00222                     if (dx[i] <= cutpoint) {
00223                         leftweights[i] = dweights[i];
00224                         rightweights[i] = 0.0;
00225                     } else {
00226                         rightweights[i] = dweights[i];
00227                         leftweights[i] = 0.0;
00228                     }
00229                 } else {
00230                     if (dx[i] <= cutpoint) {
00231                         rightweights[i] = dweights[i];
00232                         leftweights[i] = 0.0;
00233                     } else {
00234                         leftweights[i] = dweights[i];
00235                         rightweights[i] = 0.0;
00236                     }
00237                 }
00238                 break;
00239             }
00240         }
00241     }
00242 }

Generated on Wed Jun 20 15:55:33 2007 for party by  doxygen 1.4.6