wheelerb@imsweb.com
2005-May-25 16:15 UTC
[Rd] Error with user defined split function in rpart (PR#7895)
Full_Name: Bill Wheeler
Version: 2.0.1
OS: Windows 2000
Submission from: (NULL) (67.130.36.229)

The program to reproduce the error is below. I am calling rpart with a
user-defined split function for a binary response variable and one continuous
independent variable. The split function works for some datasets but not others.

The error is:

Error in "$<-.data.frame"(`*tmp*`, "yval2", value = c(0, 15, 10, 0.6, :
        replacement has 5 rows, data has 1

#
# Test out the "user mode" functions, with a binary response
#

rm(list=ls(all=TRUE))
options(warn = 1);
library(rpart);
set.seed(7);
nobs <- 25;
mydata <- data.frame(indx=1:nobs);
mydata[, "y"] <- floor(runif(nobs, min=0, max=2));
mydata[, "x"] <- runif(nobs, min=0, max=2);
mydata$indx <- NULL;

################################################################
# The 'evaluation' function.  Called once per node.
#  Produce a label (1 or more elements long) for labeling each node,
#  and a deviance.  The latter is
#       - of length 1
#       - equal to 0 if the node is "pure" in some sense (unsplittable)
#       - does not need to be a deviance: any measure that gets larger
#            as the node is less acceptable is fine.
#       - the measure underlies cost-complexity pruning, however
temp1 <- function(y, wt, parms) {
    print("***** START: TEMP1 *****");
    n <- length(y);

    # Get the number of y's in each category
    sumyEqual0 <- sum(y == 0);
    sumyEqual1 <- sum(y == 1);

    # Get the proportion of 0's and 1's
    p0 <- sumyEqual0/n;
    p1 <- sumyEqual1/n;

    if (p0 >= p1) {
        dev = sumyEqual1;
    } else {
        dev = sumyEqual0;
    }

    # Get the vector of labels
    labels <- matrix(nrow=1, ncol=5);
    # labels[1] is the fitted y category
    # labels[2] is sum(y == 0)
    # labels[3] is sum(y == 1)
    # labels[4] is sum(y == 0)/n
    # labels[5] is sum(y == 1)/n
    if (p0 >= p1) {
        labels[1] = 0;
    } else {
        labels[1] = 1;
    }
    labels[2] <- sumyEqual0;
    labels[3] <- sumyEqual1;
    labels[4] <- sumyEqual0/n;
    labels[5] <- sumyEqual1/n;

    ret <- list(label=labels, deviance=dev)
    print("***** END: TEMP1 *****");
    ret
}

# The split function, where most of the work occurs.
#   Called once per split variable per node.
# If continuous=T
#   The actual x variable is ordered
#   y is supplied in the sort order of x, with no missings,
#   return two vectors of length (n-1):
#      goodness = goodness of the split, larger numbers are better.
#                 0 = couldn't find any worthwhile split
#        the ith value of goodness evaluates splitting obs 1:i vs (i+1):n
#      direction= -1 = send "y< cutpoint" to the left side of the tree
#                  1 = send "y< cutpoint" to the right
#         this is not a big deal, but making larger "mean y's" move towards
#         the right of the tree, as we do here, seems to make it easier to
#         read
# If continuous=F, x is a set of integers defining the groups for an
#   unordered predictor.  In this case:
#       direction = a vector of length m= "# groups".  It asserts that the
#           best split can be found by lining the groups up in this order
#           and going from left to right, so that only m-1 splits need to
#           be evaluated rather than 2^(m-1)
#       goodness = m-1 values, as before.
#
# The reason for returning a vector of goodness is that the C routine
#   enforces the "minbucket" constraint.  It selects the best return value
#   that is not too close to an edge.
temp2 <- function(y, wt, x, parms, continuous) {
    print("***** START: TEMP2 *****");
    n <- length(y)

    # For binary y, get P(Y=0)/n and P(Y=1)/n at each split
    temp <- cumsum(y*wt)[-n]
    left.wt  <- cumsum(wt)[-n]
    right.wt <- sum(wt) - left.wt
    lp <- temp/left.wt

    rsum <- matrix(nrow=1, ncol=n-1, data=0);
    for (i in seq(1, n-1)) {
        for (j in seq(i+1, n)) {
            rsum[i] <- rsum[i] + y[j];
        }
    }
    rp <- rsum/right.wt

    lprop <- 1 - lp;
    rprop <- rp;

    # Get the direction
    direc <- matrix(nrow=1, ncol=length(lp), data=1);
    for (i in seq(1, length(lp))) {
        if (lprop[i] >= rprop[i]) direc[i] <- -1;
    }

    goodness <- (lprop + rprop);
    ret <- list(goodness= goodness, direction=direc)
    print("***** END: TEMP2 *****");
    ret
}

# The init function:
#   fix up y to deal with offsets
#   return a dummy parms list
#   numresp is the number of values produced by the eval routine's "label"
#   numy is the number of columns for y
#   summary is a function used to print one line in summary.rpart
#   text is a function used to put text on the plot in text.rpart
# In general, this function would also check for bad data, see rpart.poisson
#   for instance.
temp3 <- function(y, offset, parms, wt) {
    print("***** START: TEMP3 *****");
    if (!is.null(offset)) y <- y-offset
    ret <- list(y=y, parms=0, numresp=5, numy=1,
                summary= function(yval, dev, wt, ylevel, digits ) {
                    paste("  mean=", format(signif(yval, digits)),
                          ", MSE=" , format(signif(dev/wt, digits)),
                          sep='')
                },
                text= function(yval, dev, wt, ylevel, digits, n, use.n ) {
                    if(use.n) {paste(formatg(yval,digits),"\nn=", n,sep="")}
                    else{paste(formatg(yval,digits))}
                })
    print("***** END: TEMP3 *****");
    ret
}

alist <- list(eval=temp1, split=temp2, init=temp3);
fit1 <- rpart(y ~ ., data=mydata, method=alist, control=list(cp=0));
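
For reference, the per-split quantities that temp2 computes can be checked outside rpart. The following is a minimal sketch, not a diagnosis of the error above: the toy data, the unit weights, and the names rsum_loop and rsum_vec are hypothetical and do not come from the program in this report. It shows that the nested loop over the right-hand side equals the total minus the running left-hand sum, and that goodness and direction come out with the documented length n-1. It builds them as plain vectors, whereas the program above builds rsum and direc as 1-row matrices; whether that difference matters to rpart is not established here. Note also that the loop in temp2 sums y unweighted, so the two sides agree only when all weights are 1, as assumed below.

```r
# Hypothetical stand-alone check; toy data, not the dataset from the report.
y  <- c(0, 1, 1, 0, 1)       # binary response, assumed already sorted by x
wt <- rep(1, length(y))      # unit case weights (assumption)
n  <- length(y)

# Loop form, as in temp2: rsum[i] = y[i+1] + ... + y[n]
rsum_loop <- numeric(n - 1)
for (i in seq_len(n - 1)) {
  for (j in seq(i + 1, n)) {
    rsum_loop[i] <- rsum_loop[i] + y[j]
  }
}

# Vectorized form: total minus the running left-hand sum
rsum_vec <- sum(y) - cumsum(y)[-n]

# Same per-split quantities as temp2, built as plain vectors
left.wt   <- cumsum(wt)[-n]
right.wt  <- sum(wt) - left.wt
lprop     <- 1 - cumsum(y * wt)[-n] / left.wt  # proportion of 0's on the left
rprop     <- rsum_vec / right.wt               # proportion of 1's on the right
goodness  <- lprop + rprop
direction <- ifelse(lprop >= rprop, -1, 1)

# Both forms agree, and both return values have length n-1
stopifnot(isTRUE(all.equal(rsum_loop, rsum_vec)))
stopifnot(length(goodness) == n - 1, length(direction) == n - 1)
```

Running the same two stopifnot() checks on the values returned by a user-written split function, for a sorted y of known length, is one way to confirm that it honours the length (n-1) contract described in the comments before handing it to rpart.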