00001 
00009 #include "party.h"
00010 
00021 void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00022                   SEXP fitmem) {
00023 
00024     SEXP x, y, expcovinf; 
00025     SEXP splitctrl, inputs; 
00026     SEXP split, thiswhichNA;
00027     int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0;
00028     double ms, cp, *thisweights, *cutpoint, *maxstat, 
00029            *splitstat, *dweights, *tweights, *dx, *dy;
00030     double cut, *twotab, *ytmp, sumw = 0.0;
00031     
00032     nobs = get_nobs(learnsample);
00033     ninputs = get_ninputs(learnsample);
00034     splitctrl = get_splitctrl(controls);
00035     maxsurr = get_maxsurrogate(splitctrl);
00036     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00037     jselect = S3get_variableID(S3get_primarysplit(node));
00038     
00039     
00040     y = S3get_nodeweights(VECTOR_ELT(node, S3_LEFT));
00041     ytmp = Calloc(nobs, double);
00042     for (i = 0; i < nobs; i++) {
00043         ytmp[i] = REAL(y)[i];
00044         if (ytmp[i] > 1.0) ytmp[i] = 1.0;
00045     }
00046 
00047     for (j = 0; j < ninputs; j++) {
00048         if (is_nominal(inputs, j + 1)) continue;
00049         nvar++;
00050     }
00051     nvar--;
00052 
00053     if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
00054         error("nodes does not have %d surrogate splits", maxsurr);
00055     if (maxsurr > nvar)
00056         error("cannot set up %d surrogate splits with only %d ordered input variable(s)", 
00057               maxsurr, nvar);
00058 
00059     tweights = Calloc(nobs, double);
00060     dweights = REAL(weights);
00061     for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
00062     if (has_missings(inputs, jselect)) {
00063         thiswhichNA = get_missings(inputs, jselect);
00064         for (k = 0; k < LENGTH(thiswhichNA); k++)
00065             tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
00066     }
00067 
00068     
00069     sumw = 0.0;
00070     for (i = 0; i < nobs; i++) sumw += tweights[i];
00071     if (sumw < 2.0)
00072         error("can't implement surrogate splits, not enough observations available");
00073 
00074     expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
00075     C_ExpectCovarInfluence(ytmp, 1, tweights, nobs, expcovinf);
00076     
00077     splitstat = REAL(get_splitstatistics(fitmem));
00078     
00079     maxstat = Calloc(ninputs, double);
00080     cutpoint = Calloc(ninputs, double);
00081     order = Calloc(ninputs, int);
00082     
00083     
00084     
00085     
00086 
00087 
00088     for (j = 0; j < ninputs; j++) {
00089     
00090          order[j] = j + 1;
00091          maxstat[j] = 0.0;
00092          cutpoint[j] = 0.0;
00093 
00094          
00095          if ((j + 1) == jselect || is_nominal(inputs, j + 1))
00096              continue;
00097 
00098          x = get_variable(inputs, j + 1);
00099 
00100          if (has_missings(inputs, j + 1)) {
00101 
00102              thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
00103 
00104              
00105              sumw = 0.0;
00106              for (i = 0; i < nobs; i++) sumw += thisweights[i];
00107              if (sumw < 2.0) continue;
00108                  
00109              C_ExpectCovarInfluence(ytmp, 1, thisweights, nobs, expcovinf);
00110              
00111              C_split(REAL(x), 1, ytmp, 1, thisweights, nobs,
00112                      INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00113                      GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00114                      expcovinf, &cp, &ms, splitstat);
00115          } else {
00116          
00117              C_split(REAL(x), 1, ytmp, 1, tweights, nobs,
00118              INTEGER(get_ordering(inputs, j + 1)), splitctrl,
00119              GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00120              expcovinf, &cp, &ms, splitstat);
00121          }
00122 
00123          maxstat[j] = -ms;
00124          cutpoint[j] = cp;
00125     }
00126 
00127     
00128 
00129 
00130 
00131 
00132 
00133     
00134     rsort_with_index(maxstat, order, ninputs);
00135     
00136     twotab = Calloc(4, double);
00137     
00138     
00139     for (j = 0; j < maxsurr; j++) {
00140 
00141         for (i = 0; i < 4; i++) twotab[i] = 0.0;
00142         cut = cutpoint[order[j] - 1];
00143         SET_VECTOR_ELT(S3get_surrogatesplits(node), j, 
00144                        split = allocVector(VECSXP, SPLIT_LENGTH));
00145         C_init_orderedsplit(split, 0);
00146         S3set_variableID(split, order[j]);
00147         REAL(S3get_splitpoint(split))[0] = cut;
00148         dx = REAL(get_variable(inputs, order[j]));
00149         dy = REAL(y);
00150 
00151         
00152 
00153 
00154 
00155         for (i = 0; i < nobs; i++) {
00156             twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
00157             twotab[1] += (dy[i] == 1) * tweights[i];
00158             twotab[2] += (dx[i] <= cut) * tweights[i];
00159             twotab[3] += tweights[i];
00160         }
00161         S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / 
00162                      twotab[3]) > 0);
00163     }
00164     
00165     Free(maxstat);
00166     Free(cutpoint);
00167     Free(order);
00168     Free(tweights);
00169     Free(twotab);
00170     Free(ytmp);
00171 }
00172 
00183 SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
00184                   SEXP fitmem) {
00185 
00186     C_surrogates(node, learnsample, weights, controls, fitmem);
00187     return(S3get_surrogatesplits(node));
00188     
00189 }
00190 
00198 void C_splitsurrogate(SEXP node, SEXP learnsample) {
00199 
00200     SEXP weights, split, surrsplit;
00201     SEXP inputs, whichNA;
00202     double cutpoint, *dx, *dweights, *leftweights, *rightweights;
00203     int *iwhichNA, k;
00204     int nobs, i, nna, ns;
00205                     
00206     weights = S3get_nodeweights(node);
00207     dweights = REAL(weights);
00208     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00209     nobs = get_nobs(learnsample);
00210             
00211     leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
00212     rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
00213     surrsplit = S3get_surrogatesplits(node);
00214 
00215     
00216     split = S3get_primarysplit(node);
00217     if (has_missings(inputs, S3get_variableID(split))) {
00218 
00219         
00220         whichNA = get_missings(inputs, S3get_variableID(split));
00221         iwhichNA = INTEGER(whichNA);
00222         nna = LENGTH(whichNA);
00223 
00224         
00225         for (k = 0; k < nna; k++) {
00226             ns = 0;
00227             i = iwhichNA[k] - 1;
00228             if (dweights[i] == 0) continue;
00229             
00230             
00231             while(TRUE) {
00232             
00233                 if (ns >= LENGTH(surrsplit)) break;
00234             
00235                 split = VECTOR_ELT(surrsplit, ns);
00236                 if (has_missings(inputs, S3get_variableID(split))) {
00237                     if (INTEGER(get_missings(inputs, 
00238                             S3get_variableID(split)))[i]) {
00239                         ns++;
00240                         continue;
00241                     }
00242                 }
00243 
00244                 cutpoint = REAL(S3get_splitpoint(split))[0];
00245                 dx = REAL(get_variable(inputs, S3get_variableID(split)));
00246 
00247                 if (S3get_toleft(split)) {
00248                     if (dx[i] <= cutpoint) {
00249                         leftweights[i] = dweights[i];
00250                         rightweights[i] = 0.0;
00251                     } else {
00252                         rightweights[i] = dweights[i];
00253                         leftweights[i] = 0.0;
00254                     }
00255                 } else {
00256                     if (dx[i] <= cutpoint) {
00257                         rightweights[i] = dweights[i];
00258                         leftweights[i] = 0.0;
00259                     } else {
00260                         leftweights[i] = dweights[i];
00261                         rightweights[i] = 0.0;
00262                     }
00263                 }
00264                 break;
00265             }
00266         }
00267     }
00268 }