00001
00009 #include "party.h"
00010
00011
00022 void C_TreeGrow(SEXP node, SEXP learnsample, SEXP fitmem,
00023 SEXP controls, int *where, int *nodenum) {
00024
00025 SEXP weights;
00026 int nobs, i;
00027 double *dweights;
00028
00029 weights = S3get_nodeweights(node);
00030
00031 if ((nodenum[0] == 2 || nodenum[0] == 3) &&
00032 get_stump(get_tgctrl(controls)))
00033 C_Node(node, learnsample, weights, fitmem, controls, 1);
00034 else
00035 C_Node(node, learnsample, weights, fitmem, controls, 0);
00036
00037 S3set_nodeID(node, nodenum[0]);
00038
00039 if (!S3get_nodeterminal(node)) {
00040
00041 C_splitnode(node, learnsample, controls);
00042
00043
00044 if (get_maxsurrogate(get_splitctrl(controls)) > 0) {
00045 C_surrogates(node, learnsample, weights, controls, fitmem);
00046 C_splitsurrogate(node, learnsample);
00047 }
00048
00049 nodenum[0] += 1;
00050 C_TreeGrow(S3get_leftnode(node), learnsample, fitmem,
00051 controls, where, nodenum);
00052
00053 nodenum[0] += 1;
00054 C_TreeGrow(S3get_rightnode(node), learnsample, fitmem,
00055 controls, where, nodenum);
00056 } else {
00057 dweights = REAL(weights);
00058 nobs = get_nobs(learnsample);
00059 for (i = 0; i < nobs; i++)
00060 if (dweights[i] > 0) where[i] = nodenum[0];
00061 }
00062 }
00063
00064
00074 SEXP R_TreeGrow(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP where) {
00075
00076 SEXP ans, nweights;
00077 double *dnweights, *dweights;
00078 int nobs, i, nodenum = 1;
00079
00080 nobs = get_nobs(learnsample);
00081 PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00082 C_init_node(ans, nobs, get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(controls)),
00083 ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym),
00084 PL2_jointtransfSym)));
00085
00086 nweights = S3get_nodeweights(ans);
00087 dnweights = REAL(nweights);
00088 dweights = REAL(weights);
00089 for (i = 0; i < nobs; i++) dnweights[i] = dweights[i];
00090
00091 C_TreeGrow(ans, learnsample, fitmem, controls, INTEGER(where), &nodenum);
00092 UNPROTECT(1);
00093 return(ans);
00094 }
00095
00096
00107 SEXP R_Ensemble(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP ans) {
00108
00109 SEXP nweights, tree, where;
00110 double *dnweights, *dweights, sw = 0.0, *prob;
00111 int nobs, i, b, B , nodenum = 1, *iweights, *iwhere;
00112
00113 B = LENGTH(ans);
00114 nobs = get_nobs(learnsample);
00115
00116 iweights = Calloc(nobs, int);
00117 prob = Calloc(nobs, double);
00118 dweights = REAL(weights);
00119
00120 for (i = 0; i < nobs; i++)
00121 sw += dweights[i];
00122 for (i = 0; i < nobs; i++)
00123 prob[i] = dweights[i]/sw;
00124
00125 for (b = 0; b < B; b++) {
00126 SET_VECTOR_ELT(ans, b, tree = allocVector(VECSXP, NODE_LENGTH + 1));
00127 SET_VECTOR_ELT(tree, NODE_LENGTH, where = allocVector(INTSXP, nobs));
00128 iwhere = INTEGER(where);
00129 for (i = 0; i < nobs; i++) iwhere[i] = 0;
00130
00131 C_init_node(tree, nobs, get_ninputs(learnsample),
00132 get_maxsurrogate(get_splitctrl(controls)),
00133 ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym),
00134 PL2_jointtransfSym)));
00135
00136
00137 GetRNGstate();
00138 rmultinom((int) sw, prob, nobs, iweights);
00139 PutRNGstate();
00140
00141 nweights = S3get_nodeweights(tree);
00142 dnweights = REAL(nweights);
00143 for (i = 0; i < nobs; i++) dnweights[i] = (double) iweights[i];
00144
00145 C_TreeGrow(tree, learnsample, fitmem, controls, iwhere, &nodenum);
00146 nodenum = 1;
00147 }
00148 Free(prob); Free(iweights);
00149 return(ans);
00150 }
00151