alsoRun
2008-Mar-12 18:57 UTC
[R] R code for kernel density using kd-tree, looking for speed up
Dear R-help-list,
The following is an R function I wrote for computing a multi-dimensional kernel
density estimate. I am looking for R experts who can make the code run faster,
ideally 50 times faster.
Specifically, for function
kernel.estimate = function(points, bw),
the argument points is a d by n matrix holding n points in d-dimensional
space (one point per column), and bw is the bandwidth. The function computes
the kernel density estimate at the n points of the input matrix points. To avoid the n^2
computational burden, I build a kd-tree, which makes it possible to quickly determine whether a
source point is too far away from target point and thus can be skipped in the
summation (I used a finite support kernel). The kd-tree was built using R list
in a structure provided by R core member Thomas Lumley.
Interestingly, this is an example where Luke Tierney's R compiler provides
a threefold speed-up.
Thanks.
alsoRun
###############################
#points are the d by n matrix of the source points
## Half the length of the diagonal of an axis-aligned bounding box.
##
## Arguments:
##   box.lower.limit - numeric vector, per-dimension lower corner of the box.
##   box.upper.limit - numeric vector, per-dimension upper corner of the box.
##
## Value: a single number, the Euclidean distance between the two corners
## divided by two.
get.diameter <- function(box.lower.limit, box.upper.limit) {
  corner.offset <- box.lower.limit - box.upper.limit
  diagonal <- sqrt(sum(corner.offset^2))
  diagonal / 2
}
########################################
## Normalising constant of the kernel density estimate.
##
## Arguments:
##   d  - dimension of the space.
##   n  - number of points the kernel sum is averaged over.
##   bw - bandwidth.
##
## Value: the factor by which a raw sum of (1 - r^2)^3 kernel values
## (the kernel used by evaluate.element.obj) is multiplied.
Kconst <- function(d, n, bw) {
  # Reciprocal of the volume of the unit ball in d dimensions.
  inverse.ball.volume <- gamma(d/2 + 1) / pi^(d/2)
  # Integral of (1 - r^2)^3 over the unit ball, relative to the ball volume.
  relative.kernel.mass <- (2 - 2*d)/(d + 2) + 2*d*(d + 7)/(d + 4)/(d + 6)
  inverse.ball.volume / relative.kernel.mass / n / bw^d
}
###################################################################
## Create an empty kd-tree node: a list whose four fields
## (center, diameter, left, right) are all NULL.
newtree <- function() {
  node <- vector("list", 4L)
  names(node) <- c("center", "diameter", "left", "right")
  node
}
####################################################################################
## Add a node to the kd-tree (recursively builds the whole subtree).
##
## Arguments:
##   tree   - an empty node as returned by newtree(); it is filled in and
##            returned.
##   points - numeric d by n matrix, one point per column, n >= 1.
##
## Value: the node, a list with
##   center     - the point itself (leaf) or the midpoint of the bounding box,
##   diameter   - half the bounding-box diagonal (stays NULL on a leaf),
##   left/right - child subtrees (both stay NULL on a leaf).
##
## Every leaf holds exactly one column of `points`, so duplicated points
## end up as separate leaves.
addNode <- function(tree, points) {
  if (!is.matrix(points)) {
    stop("'points' must be a d by n matrix", call. = FALSE)
  }
  numOfPoints <- ncol(points)
  if (numOfPoints < 1L) {
    stop("'points' must contain at least one point (column)", call. = FALSE)
  }
  if (numOfPoints == 1L) {
    tree$center <- as.vector(points)
    return(tree)
  }

  # Axis-aligned bounding box of the points in this node.
  box.lower.limit <- apply(points, 1, min)
  box.upper.limit <- apply(points, 1, max)
  tree$center <- (box.lower.limit + box.upper.limit) / 2
  tree$diameter <- get.diameter(box.lower.limit, box.upper.limit)

  # Split at the midpoint of the widest dimension.
  split.dim <- which.max(box.upper.limit - box.lower.limit)
  split.mean <- tree$center[split.dim]
  goes.left <- points[split.dim, ] < split.mean

  # If every point falls on the same side (all points coincide, or the
  # midpoint rounds onto an extreme), the midpoint split would hand the
  # full set to one child and recurse forever. Split by position instead,
  # so each child gets at least one point.
  n.left <- sum(goes.left)
  if (n.left == 0L || n.left == numOfPoints) {
    goes.left <- seq_len(numOfPoints) <= numOfPoints %/% 2L
  }

  tree$left <- addNode(newtree(), points[, goes.left, drop = FALSE])
  tree$right <- addNode(newtree(), points[, !goes.left, drop = FALSE])
  tree
}
## Build the evaluator for a single target point.
##
## Arguments:
##   target.element - numeric vector, the point at which to evaluate.
##   bw             - bandwidth; the kernel is zero beyond distance bw.
##
## Value: a function of one argument `tree` (a kd-tree node) returning the
## unnormalised sum of (1 - dist^2 / bw^2)^3 over the leaves below that node.
evaluate.element.obj <- function(target.element, bw) {
  bw.squared <- bw * bw

  visit <- function(tree) {
    offset <- target.element - tree$center
    dist.squared <- sum(offset * offset)

    # Leaf: a single source point, contribute its kernel value.
    if (is.null(tree$left)) {
      weight <- 1. - dist.squared / bw.squared
      if (weight < 0.) {
        return(0.)
      }
      return(weight^3)
    }

    # Internal node: skip the whole subtree when the target is further from
    # the box centre than the bandwidth plus the box half-diagonal.
    reach <- bw + tree$diameter
    if (dist.squared > reach * reach) {
      return(0.)
    }
    visit(tree$left) + visit(tree$right) # faster than using Recall
  }

  visit
}
## Evaluate the kernel density estimate at every column of `target`.
##
## Arguments:
##   target - numeric d by m matrix, one evaluation point per column.
##   tree   - kd-tree of the source points, as built by addNode().
##   bw     - bandwidth; the kernel is zero beyond distance bw.
##
## Value: the density estimate at each target column, as returned by
## apply() over the columns.
evaluate <- function(target, tree, bw) {
  # The estimate is averaged over the SOURCE points, which are the leaves
  # of the tree (a leaf is a node without a left child). Using
  # ncol(target) here would only be right when the targets happen to be
  # as numerous as the source points.
  count.leaves <- function(node) {
    if (is.null(node$left)) {
      1
    } else {
      count.leaves(node$left) + count.leaves(node$right)
    }
  }

  func <- function(x) evaluate.element.obj(x, bw)(tree)
  estimate <- apply(target, 2, FUN = func)
  estimate * Kconst(d = nrow(target), n = count.leaves(tree), bw = bw)
}
################################################################
## Kernel density estimate at the points themselves.
##
## Arguments:
##   points - d by n matrix, one point per column.
##   bw     - bandwidth.
##
## Builds the kd-tree over `points`, prints the current date/time once the
## tree is built, then returns the result of evaluate() with `points` as
## both source and target.
kernel.estimate <- function(points, bw) {
  kd.tree <- addNode(tree = newtree(), points = points)
  print(date())
  evaluate(target = points, tree = kd.tree, bw = bw)
}
## Demo driver: draw n standard-normal points in d dimensions, compute the
## kernel density estimate at those points and plot a histogram of it.
##
## Arguments:
##   n - number of points.
##   d - dimension.
main1 <- function(n, d) {
  # The bandwidth is specified on the log scale and exponentiated below.
  log.bandwidth <- 1.4794953 - 1/(d + 2)*log(n)
  # d by n matrix, filled column by column: one point per column.
  sample.points <- matrix(rnorm(n*d), nrow = d, ncol = n)
  result <- kernel.estimate(sample.points, exp(log.bandwidth))
  hist(result)
}
# Timing run: 4000 points in 2 dimensions, with the date/time printed
# before and after.
print(date())
print(system.time(main1(n = 4000, d = 2)))
print(date())
.
---------------------------------
[[alternative HTML version deleted]]
