James Jong
2013-Mar-06 02:47 UTC
[R] CARET and NNET fail to train a model when the input is high dimensional
The following code fails to train an nnet model on a random dataset using
caret:
library(caret)
# method.name and grid.len were not defined in the post; assumed values
method.name <- "nnet"
grid.len <- 3
nR <- 700
nCol <- 2000
myCtrl <- trainControl(method = "cv", number = 3,
                       preProcOptions = NULL,
                       classProbs = TRUE, summaryFunction = twoClassSummary)
# replicate() binds each draw as a column: nCol rows, nR predictors
trX <- data.frame(replicate(nR, rnorm(nCol)))
trY <- runif(1)*trX[,1]*trX[,2]^2 + runif(1)*trX[,3]/trX[,4]
trY <- as.factor(ifelse(sign(trY) > 0, 'X1', 'X0'))
my.grid <- createGrid(method.name, grid.len, data = trX)
my.model <- train(trX, trY, method = method.name, trace = FALSE,
                  trControl = myCtrl, tuneGrid = my.grid,
                  metric = "ROC")
print("Done")
The error I get is:
task 2 failed - "arguments imply differing number of rows: 1334, 666"
However, everything works if I reduce nR to, say, 20.
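A quick check of what nR actually controls. replicate(n, expr) evaluates expr n times and binds the results as columns, so the simulated data has nCol = 2000 rows and nR = 700 predictors. That appears consistent with the error message: 1334 + 666 = 2000, roughly a two-thirds/one-third split of 2000 rows under 3-fold cross-validation. A minimal sketch:

```r
# replicate() stacks each rnorm(nCol) draw as a column,
# so nR sets the number of predictors and nCol the number of rows.
nR <- 700
nCol <- 2000
trX <- data.frame(replicate(nR, rnorm(nCol)))

dim(trX)  # 2000  700
```

Shrinking nR therefore shrinks the number of inputs to the network, not the number of samples.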
Any thoughts on what may be causing this? Is there a place where I could
report this bug other than this mailing list?
Here is my session info:
> sessionInfo()
R version 2.15.2 (2012-10-26)
Platform: x86_64-unknown-linux-gnu (64-bit)
locale:
[1] C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] nnet_7.3-5 pROC_1.5.4 caret_5.15-052 foreach_1.4.0
[5] cluster_1.14.3 plyr_1.8 reshape2_1.2.2 lattice_0.20-13
loaded via a namespace (and not attached):
[1] codetools_0.2-8 compiler_2.15.2 grid_2.15.2 iterators_1.0.6
[5] stringr_0.6.2 tools_2.15.2
Thanks,
James
Max Kuhn
2013-Mar-06 14:59 UTC
James,
I did a fresh install from CRAN to get caret_5.15-61 and ran your code with
method.name = "nnet" and grid.len = 3.
I don't get an error, although there were issues:
In nominalTrainWorkflow(dat = trainData, info = trainInfo, ... :
There were missing values in resampled performance measures.
The results had:
Resampling results across tuning parameters:
size  decay  ROC    Sens   Spec   ROC SD   Sens SD  Spec SD
1     0      0.521  0.52   0.521  0.0148   0.0312   0.00901
1     1e-04  0.513  0.528  0.498  0.00616  0.00386  0.00552
1     0.1    0.515  0.522  0.514  0.0169   0.0284   0.0426
3     0      NaN    NaN    NaN    NA       NA       NA
3     1e-04  NaN    NaN    NaN    NA       NA       NA
3     0.1    NaN    NaN    NaN    NA       NA       NA
5     0      NaN    NaN    NaN    NA       NA       NA
5     1e-04  NaN    NaN    NaN    NA       NA       NA
5     0.1    NaN    NaN    NaN    NA       NA       NA
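The nine rows correspond to sizes 1, 3 and 5 crossed with decays 0, 1e-04 and 0.1. In case createGrid() is not available in your caret version, the sketch below builds the same grid by hand with base R's expand.grid(). It is an assumed equivalent inferred from the rows shown, not necessarily what createGrid() does internally. The leading-dot column names follow the my.grid$.size convention used later in this reply; newer caret releases, as far as I know, expect plain size and decay instead.

```r
# Hand-built grid assumed to match createGrid("nnet", 3, ...) in caret 5.x,
# reconstructed from the nine rows of the resampling table.
my.grid <- expand.grid(.size = c(1, 3, 5), .decay = c(0, 1e-4, 0.1))
nrow(my.grid)  # 9
```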
To test more, I ran:
> test <- nnet(trX, trY, size = 3, decay = 0)
Error in nnet.default(trX, trY, size = 3, decay = 0) :
too many (2107) weights
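The 2107 in that error can be reproduced by hand. For a single hidden layer without skip-layer connections, each hidden unit has one weight per input plus a bias, and each output unit has one weight per hidden unit plus a bias. The sketch assumes 700 inputs (the ncol(trX) the original code produces) and one output unit, which is what nnet uses for a two-class outcome. nnet_n_weights() is a hypothetical helper for illustration, not part of nnet.

```r
# Hypothetical helper: weight count for a single-hidden-layer nnet, no skip layer.
# (inputs + bias) * hidden units + (hidden units + bias) * outputs
nnet_n_weights <- function(n_in, size, n_out = 1) {
  (n_in + 1) * size + (size + 1) * n_out
}

sizes <- c(1, 3, 5)

# 700 predictors, one output unit for a two-class factor
n_wts <- nnet_n_weights(n_in = 700, size = sizes)
n_wts         # 703 2107 3511
n_wts > 1000  # FALSE TRUE TRUE (1000 is nnet's default MaxNWts)

# With only 20 predictors every size stays under the default limit
nnet_n_weights(n_in = 20, size = sizes)  # 23 67 111
```

Only size = 1 stays under the default limit, which matches the table above: the size 1 rows have results while sizes 3 and 5 are NaN. With 20 predictors all three sizes fit, which would explain why a smaller nR works.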
So, you need to pass MaxNWts to nnet() with a value large enough to let you fit
the model. Off the top of my head, you could use something like:
MaxNWts = length(levels(trY))*(max(my.grid$.size) * (nCol + 1) + max(my.grid$.size) + 1)
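A self-contained sketch of the fix, calling nnet() directly on smaller simulated data. The 200 x 400 dimensions and the outcome rule are made up so the example runs quickly. With size = 3 the network needs (400 + 1) * 3 + (3 + 1) = 1207 weights, so the default limit of 1000 is exceeded until MaxNWts is raised.

```r
library(nnet)

set.seed(1)
# Hypothetical simulated data: 200 rows, 400 predictors (X1..X400), two classes
n_obs <- 200
n_in <- 400
dat <- data.frame(matrix(rnorm(n_obs * n_in), nrow = n_obs))
dat$y <- factor(ifelse(dat$X1 * dat$X2^2 + dat$X3 > 0, "X1", "X0"))

size <- 3
# Weights needed: (inputs + bias) * hidden units + (hidden units + bias) * 1 output
needed <- (n_in + 1) * size + (size + 1)

# With the default MaxNWts (1000) nnet() stops before fitting
msg <- tryCatch(nnet(y ~ ., data = dat, size = size, trace = FALSE),
                error = function(e) conditionMessage(e))
print(msg)

# Raising MaxNWts to the required count lets the model fit
fit <- nnet(y ~ ., data = dat, size = size, decay = 0.1,
            MaxNWts = needed, maxit = 50, trace = FALSE)
length(fit$wts)  # 1207
```

In caret, arguments that train() does not use itself are passed on to the model function through `...`, the same way trace = FALSE already is in the original call, so a MaxNWts value can be added to that train() call directly.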
Also, this is one of the ways to get help (the other is to just email
me). I also try to keep up with Stack Exchange.
Max