Predict.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00011 
00021 void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
00022 
00023     SEXP weights, leftnode, rightnode, split;
00024     SEXP responses, inputs, whichNA;
00025     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00026     double sleft = 0.0, sright = 0.0;
00027     int *ix, *levelset, *iwhichNA;
00028     int nobs, i, nna;
00029                     
00030     weights = S3get_nodeweights(node);
00031     dweights = REAL(weights);
00032     responses = GET_SLOT(learnsample, PL2_responsesSym);
00033     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00034     nobs = get_nobs(learnsample);
00035             
00036     /* set up memory for the left daughter */
00037     SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
00038     C_init_node(leftnode, nobs, 
00039         get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00040         ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
00041     leftweights = REAL(S3get_nodeweights(leftnode));
00042 
00043     /* set up memory for the right daughter */
00044     SET_VECTOR_ELT(node, S3_RIGHT, 
00045                    rightnode = allocVector(VECSXP, NODE_LENGTH));
00046     C_init_node(rightnode, nobs, 
00047         get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
00048         ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
00049     rightweights = REAL(S3get_nodeweights(rightnode));
00050 
00051     /* split according to the primary split */
00052     split = S3get_primarysplit(node);
00053     if (has_missings(inputs, S3get_variableID(split))) {
00054         whichNA = get_missings(inputs, S3get_variableID(split));
00055         iwhichNA = INTEGER(whichNA);
00056         nna = LENGTH(whichNA);
00057     } else {
00058         nna = 0;
00059         whichNA = R_NilValue;
00060         iwhichNA = NULL;
00061     }
00062     
00063     if (S3is_ordered(split)) {
00064         cutpoint = REAL(S3get_splitpoint(split))[0];
00065         dx = REAL(get_variable(inputs, S3get_variableID(split)));
00066         for (i = 0; i < nobs; i++) {
00067             if (nna > 0) {
00068                 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00069             }
00070             if (dx[i] <= cutpoint) 
00071                 leftweights[i] = dweights[i]; 
00072             else 
00073                 leftweights[i] = 0.0;
00074             rightweights[i] = dweights[i] - leftweights[i];
00075             sleft += leftweights[i];
00076             sright += rightweights[i];
00077         }
00078     } else {
00079         levelset = INTEGER(S3get_splitpoint(split));
00080         ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
00081 
00082         for (i = 0; i < nobs; i++) {
00083             if (nna > 0) {
00084                 if (i_in_set(i + 1, iwhichNA, nna)) continue;
00085             }
00086             if (levelset[ix[i] - 1])
00087                 leftweights[i] = dweights[i];
00088             else 
00089                 leftweights[i] = 0.0;
00090             rightweights[i] = dweights[i] - leftweights[i];
00091             sleft += leftweights[i];
00092             sright += rightweights[i];
00093         }
00094     }
00095     
00096     /* for the moment: NA's go with majority */
00097     if (nna > 0) {
00098         for (i = 0; i < nna; i++) {
00099             if (sleft > sright) {
00100                 leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00101                 rightweights[iwhichNA[i] - 1] = 0.0;
00102             } else {
00103                 rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
00104                 leftweights[iwhichNA[i] - 1] = 0.0;
00105             }
00106         }
00107     }
00108 }
00109 
00110 
00120 SEXP C_get_node(SEXP subtree, SEXP newinputs, 
00121                 double mincriterion, int numobs) {
00122 
00123     SEXP split, whichNA, weights, ssplit, surrsplit;
00124     double cutpoint, x, *dweights, swleft, swright;
00125     int level, *levelset, i, ns;
00126 
00127     if (S3get_nodeterminal(subtree) || 
00128         REAL(S3get_maxcriterion(subtree))[0] < mincriterion) 
00129         return(subtree);
00130     
00131     split = S3get_primarysplit(subtree);
00132 
00133     /* missing values. Maybe store the proportions left / 
00134        right in each node? */
00135     if (has_missings(newinputs, S3get_variableID(split))) {
00136         whichNA = get_missings(newinputs, S3get_variableID(split));
00137     
00138         /* numobs 0 ... n - 1 but whichNA has 1:n */
00139         if (C_i_in_set(numobs + 1, whichNA)) {
00140         
00141             surrsplit = S3get_surrogatesplits(subtree);
00142             ns = 0;
00143             i = numobs;      
00144 
00145             /* try to find a surrogate split */
00146             while(TRUE) {
00147     
00148                 if (ns >= LENGTH(surrsplit)) break;
00149             
00150                 ssplit = VECTOR_ELT(surrsplit, ns);
00151                 if (has_missings(newinputs, S3get_variableID(ssplit))) {
00152                     if (INTEGER(get_missings(newinputs, 
00153                                              S3get_variableID(ssplit)))[i]) {
00154                         ns++;
00155                         continue;
00156                     }
00157                 }
00158 
00159                 cutpoint = REAL(S3get_splitpoint(ssplit))[0];
00160                 x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
00161                      
00162                 if (S3get_toleft(ssplit)) {
00163                     if (x <= cutpoint) {
00164                         return(C_get_node(S3get_leftnode(subtree),
00165                                           newinputs, mincriterion, numobs));
00166                     } else {
00167                         return(C_get_node(S3get_rightnode(subtree),
00168                                newinputs, mincriterion, numobs));
00169                     }
00170                 } else {
00171                     if (x <= cutpoint) {
00172                         return(C_get_node(S3get_rightnode(subtree),
00173                                           newinputs, mincriterion, numobs));
00174                     } else {
00175                         return(C_get_node(S3get_leftnode(subtree),
00176                                newinputs, mincriterion, numobs));
00177                     }
00178                 }
00179                 break;
00180             }
00181 
00182             /* if this was not successful, we go with the majority */
00183             swleft = S3get_sumweights(S3get_leftnode(subtree));
00184             swright = S3get_sumweights(S3get_rightnode(subtree));
00185             if (swleft > swright) {
00186                 return(C_get_node(S3get_leftnode(subtree), 
00187                                   newinputs, mincriterion, numobs));
00188             } else {
00189                 return(C_get_node(S3get_rightnode(subtree), 
00190                                   newinputs, mincriterion, numobs));
00191             }
00192         }
00193     }
00194     
00195     if (S3is_ordered(split)) {
00196         cutpoint = REAL(S3get_splitpoint(split))[0];
00197         x = REAL(get_variable(newinputs, 
00198                      S3get_variableID(split)))[numobs];
00199         if (x <= cutpoint) {
00200             return(C_get_node(S3get_leftnode(subtree), 
00201                               newinputs, mincriterion, numobs));
00202         } else {
00203             return(C_get_node(S3get_rightnode(subtree), 
00204                               newinputs, mincriterion, numobs));
00205         }
00206     } else {
00207         levelset = INTEGER(S3get_splitpoint(split));
00208         level = INTEGER(get_variable(newinputs, 
00209                             S3get_variableID(split)))[numobs];
00210         /* level is in 1, ..., K */
00211         if (levelset[level - 1]) {
00212             return(C_get_node(S3get_leftnode(subtree), newinputs, 
00213                               mincriterion, numobs));
00214         } else {
00215             return(C_get_node(S3get_rightnode(subtree), newinputs, 
00216                               mincriterion, numobs));
00217         }
00218     }
00219 }
00220 
00221 
00230 SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion, 
00231                 SEXP numobs) {
00232     return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
00233                       INTEGER(numobs)[0] - 1));
00234 }
00235 
00236 
00243 SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
00244     
00245     if (nodenum == S3get_nodeID(subtree)) return(subtree);
00246 
00247     if (S3get_nodeterminal(subtree)) 
00248         error("no node with number %d\n", nodenum);
00249 
00250     if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
00251         return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
00252     } else {
00253         return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
00254     }
00255 }
00256 
00257 
00264 SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
00265     return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
00266 }
00267 
00268 
00277 SEXP C_get_prediction(SEXP subtree, SEXP newinputs, 
00278                       double mincriterion, int numobs) {
00279     return(S3get_prediction(C_get_node(subtree, newinputs, 
00280                             mincriterion, numobs)));
00281 }
00282 
00283 
00292 SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs, 
00293                        double mincriterion, int numobs) {
00294     return(S3get_nodeweights(C_get_node(subtree, newinputs, 
00295                              mincriterion, numobs)));
00296 }
00297 
00298 
00307 int C_get_nodeID(SEXP subtree, SEXP newinputs,
00308                   double mincriterion, int numobs) {
00309      return(S3get_nodeID(C_get_node(subtree, newinputs, 
00310             mincriterion, numobs)));
00311 }
00312 
00313 
00321 SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00322 
00323     SEXP ans;
00324     int nobs, i, *dans;
00325             
00326     nobs = get_nobs(newinputs);
00327     PROTECT(ans = allocVector(INTSXP, nobs));
00328     dans = INTEGER(ans);
00329     for (i = 0; i < nobs; i++)
00330          dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00331     UNPROTECT(1);
00332     return(ans);
00333 }
00334 
00335 
00344 void C_predict(SEXP tree, SEXP newinputs, double mincriterion, SEXP ans) {
00345     
00346     int nobs, i;
00347     
00348     nobs = get_nobs(newinputs);    
00349     if (LENGTH(ans) != nobs) 
00350         error("ans is not of length %d\n", nobs);
00351         
00352     for (i = 0; i < nobs; i++)
00353         SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs, 
00354                        mincriterion, i));
00355 }
00356 
00357 
00365 SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion) {
00366 
00367     SEXP ans;
00368     int nobs;
00369     
00370     nobs = get_nobs(newinputs);
00371     PROTECT(ans = allocVector(VECSXP, nobs));
00372     C_predict(tree, newinputs, REAL(mincriterion)[0], ans);
00373     UNPROTECT(1);
00374     return(ans);
00375 }
00376 
00377 
00385 void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
00386 
00387     int nobs, i, *iwhere;
00388     
00389     nobs = LENGTH(where);
00390     iwhere = INTEGER(where);
00391     if (LENGTH(ans) != nobs)
00392         error("ans is not of length %d\n", nobs);
00393         
00394     for (i = 0; i < nobs; i++)
00395         SET_VECTOR_ELT(ans, i, S3get_prediction(
00396             C_get_nodebynum(tree, iwhere[i])));
00397 }
00398 
00399 
00406 SEXP R_getpredictions(SEXP tree, SEXP where) {
00407 
00408     SEXP ans;
00409     int nobs;
00410             
00411     nobs = LENGTH(where);
00412     PROTECT(ans = allocVector(VECSXP, nobs));
00413     C_getpredictions(tree, where, ans);
00414     UNPROTECT(1);
00415     return(ans);
00416 }                        
00417 
00428 SEXP R_predictRF_weights(SEXP forest, SEXP where, SEXP weights, 
00429                          SEXP newinputs, SEXP mincriterion, SEXP oobpred) {
00430 
00431     SEXP ans, tree, bw;
00432     int ntrees, nobs, i, b, j, q, iwhere, oob = 0, count = 0, ntrain;
00433     
00434     if (LOGICAL(oobpred)[0]) oob = 1;
00435     
00436     nobs = get_nobs(newinputs);
00437     ntrees = LENGTH(forest);
00438     q = LENGTH(S3get_prediction(
00439                    C_get_nodebynum(VECTOR_ELT(forest, 0), 1)));
00440 
00441     if (oob) {
00442         if (LENGTH(VECTOR_ELT(weights, 0)) != nobs)
00443             error("number of observations don't match");
00444     }    
00445     
00446     tree = VECTOR_ELT(forest, 0);
00447     ntrain = LENGTH(VECTOR_ELT(weights, 0));
00448     
00449     PROTECT(ans = allocVector(VECSXP, nobs));
00450     
00451     for (i = 0; i < nobs; i++) {
00452         count = 0;
00453         SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
00454         for (j = 0; j < ntrain; j++)
00455             REAL(bw)[j] = 0.0;
00456         for (b = 0; b < ntrees; b++) {
00457             tree = VECTOR_ELT(forest, b);
00458 
00459             if (oob && 
00460                 REAL(VECTOR_ELT(weights, b))[i] > 0.0) 
00461                 continue;
00462 
00463             iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i);
00464             
00465             for (j = 0; j < ntrain; j++) {
00466                 if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
00467                     REAL(bw)[j] += REAL(VECTOR_ELT(weights, b))[j];
00468             }
00469             count++;
00470         }
00471         if (count == 0) 
00472             error("cannot compute out-of-bag predictions for obs ", i + 1);
00473     }
00474     UNPROTECT(1);
00475     return(ans);
00476 }
00477 
00478 
00484 SEXP R_proximity(SEXP where) {
00485 
00486     SEXP ans, bw, bin;
00487     int ntrees, nobs, i, b, j, iwhere;
00488     
00489     ntrees = LENGTH(where);
00490     nobs = LENGTH(VECTOR_ELT(where, 0));
00491     
00492     PROTECT(ans = allocVector(VECSXP, nobs));
00493     PROTECT(bin = allocVector(INTSXP, nobs));
00494      
00495     for (i = 0; i < nobs; i++) {
00496         SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, nobs));
00497         for (j = 0; j < nobs; j++) {
00498             REAL(bw)[j] = 0.0;
00499             INTEGER(bin)[j] = 0;
00500         }
00501         for (b = 0; b < ntrees; b++) {
00502             /* don't look at out-of-bag observations */
00503             if (INTEGER(VECTOR_ELT(where, b))[i] == 0)
00504                 continue;
00505             iwhere = INTEGER(VECTOR_ELT(where, b))[i];
00506             for (j = 0; j < nobs; j++) {
00507                 if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
00508                     /* only count the number of trees; no weights */
00509                     REAL(bw)[j]++;
00510                 if (INTEGER(VECTOR_ELT(where, b))[j] > 0)
00511                     /* count the number of bootstrap samples
00512                     containing both i and j */
00513                     INTEGER(bin)[j]++;
00514             }
00515         }
00516         for (j = 0; j < nobs; j++)
00517             REAL(bw)[j] = REAL(bw)[j] / INTEGER(bin)[j];
00518     }
00519     UNPROTECT(2);
00520     return(ans);
00521 }

Generated on Tue Jun 16 09:15:20 2009 for party by  doxygen 1.5.8