alsoRun
2008-Mar-12 18:57 UTC
[R] R code for kernel density using kd-tree, looking for speed up
Dear R-help-list,

The following is an R function I wrote for computing a multi-dimensional kernel density estimate. I am seeking R experts who can make the code run faster, ideally 50 times faster.

Specifically, for the function kernel.estimate = function(points, bw), the argument points is a d by n matrix holding the n points in d-dimensional space, and bw is the bandwidth. The function computes the kernel density estimate at the n points of the input matrix points.

To avoid the n^2 computational burden, I build a kd-tree, which makes it possible to quickly determine whether a source point is too far away from the target point and can thus be skipped in the summation (I use a finite-support kernel). The kd-tree is built from R lists, in a structure provided by R core member Thomas Lumley.

Interestingly, this is an example where Luke Tierney's R compiler provides a threefold speed-up.

Thanks.
alsoRun

###############################
# points are the d by n matrix of the source points

# half the length of the diagonal of the bounding box
get.diameter = function(box.lower.limit, box.upper.limit) {
  temp = box.lower.limit - box.upper.limit
  sqrt(sum(temp*temp))/2
}

###############################
# normalizing constant of the kernel (1 - r^2)^3 in d dimensions,
# divided by n and bw^d
Kconst = function(d, n, bw) {
  con = gamma(d/2+1)/pi^(d/2)
  con = con/( (2-2*d)/(d+2) + 2*d*(d+7)/(d+4)/(d+6) )
  con/n/bw^d
}

###############################
## create an empty node
newtree = function() {
  list(center=NULL, diameter=NULL, left=NULL, right=NULL)
}

###############################
## add a node to the kdtree
## note: assumes distinct points; if all points in a node coincide,
## the split below cannot separate them and the recursion does not terminate
addNode = function(tree, points) {
  numOfPoints = ncol(points)
  if(numOfPoints==1) {
    # leaf: a single point, no children
    tree$center = as.vector(points)
    return(tree)
  }
  box.lower.limit = apply(points, 1, min)
  box.upper.limit = apply(points, 1, max)
  tree$center = (box.lower.limit + box.upper.limit)/2
  tree$diameter = get.diameter(box.lower.limit, box.upper.limit)
  # preparing for the left and right tree:
  # split at the midpoint of the widest dimension
  diff = box.upper.limit - box.lower.limit
  split.dim = which.max(diff)
  split.mean = tree$center[split.dim]
  index1 = (points[split.dim,] < split.mean)
  leftPoints = points[,index1,drop=FALSE]
  rightPoints = points[,!index1,drop=FALSE]
  tree$left = addNode(newtree(), leftPoints)
  tree$right = addNode(newtree(), rightPoints)
  return(tree)
}

###############################
# returns a function that sums the kernel over a tree for one target point
evaluate.element.obj = function(target.element, bw) {
  bw2 = bw*bw
  func1 = function(tree) {
    temp = target.element - tree$center
    dis2 = sum(temp*temp)
    if(is.null(tree$left)) {
      # leaf: kernel value for this source point
      temp = 1. - dis2/bw2
      if(temp<0.) 0. else temp^3
    } else {
      # internal node: skip the whole box if it is out of reach
      temp2 = bw + tree$diameter
      if( dis2 > temp2*temp2 ) 0.
      else func1(tree$left) + func1(tree$right) #faster than using Recall
    }
  }
  func1
}

evaluate = function(target, tree, bw) {
  func = function(x) evaluate.element.obj(x, bw)(tree)
  estimate = apply(target, 2, FUN=func)
  estimate*Kconst(d=nrow(target), n=ncol(target), bw)
}

###############################
kernel.estimate = function(points, bw) {
  tree = addNode(newtree(), points)
  print(date())
  evaluate(points, tree, bw)
}

main1 = function(n, d) {
  bw = 1.4794953 - 1/(d+2)*log(n)   # log of the bandwidth
  x = rnorm(n*d)
  dim(x) = c(d, n)
  result = kernel.estimate(x, exp(bw))
  hist(result)
}

print(date())
print(system.time(main1(1000*4,2)))
print(date())
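
When testing a faster rewrite, a brute-force O(n^2) reference on a small input can confirm that the results are unchanged. The following is only a minimal sketch and not part of the code above: the name brute.force.sum is hypothetical, and it assumes the same d by n layout for points and the same finite-support kernel, (1 - r^2/bw^2)^3 inside radius bw. It returns the unnormalized sum for each point, i.e. the value before multiplication by Kconst.

```r
# Brute-force reference (hypothetical helper, for checking only):
# for each point, sum (1 - r^2/bw^2)^3 over all source points with r < bw.
brute.force.sum = function(points, bw) {
  # points is d by n, so transpose to get one row per point for dist()
  d2 = as.matrix(dist(t(points)))^2   # n by n matrix of squared distances
  w = 1 - d2/(bw*bw)
  w[w < 0] = 0                        # finite support: no contribution beyond bw
  colSums(w^3)                        # includes each point's own contribution of 1
}

# three points on a line in 2 dimensions, bandwidth 2
pts = matrix(c(0,0, 1,0, 3,0), nrow=2)
print(brute.force.sum(pts, 2))
```

For a small n, kernel.estimate(x, bw) should agree with brute.force.sum(x, bw) * Kconst(nrow(x), ncol(x), bw) up to floating-point error, since the tree only skips boxes that cannot contribute.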