Thursday, February 18, 2016

Visualizing model predictions in 3d

Here is a brief exploration of the misc3d package, which has some nice functions for 3d graphics that work together with rgl. I am especially pleased with the output of contour3d, which I have used to plot GAM predictions in 3d.
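The example uses a simple simulated dataset: x, y, and z values drawn uniformly between -10 and 10, which were used to calculate a fourth variable, "value", with the following equation (the script also multiplies the result by lognormal noise):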

$$\text{value} = -0.01x^3 - 0.2y^2 - 0.3z^2$$
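The equation describes only the noise-free signal. In the script, each value is also multiplied by lognormal noise drawn with rlnorm(n, 0, 0.25). Here is a minimal base-R sketch of that signal, using a hypothetical helper name, true_value(), with a few spot checks and the mean of the noise factor:

```r
# Hypothetical helper: the noise-free signal from the equation above
true_value <- function(x, y, z) {
  -0.01 * x^3 - 0.2 * y^2 - 0.3 * z^2
}

# A few spot checks on the signal
true_value(0, 0, 0)    # returns 0
true_value(10, 0, 0)   # returns -10
true_value(0, 10, 10)  # returns -50

# The script multiplies the signal by rlnorm(n, meanlog = 0, sdlog = 0.25);
# that noise factor has median exp(0) = 1 and mean exp(sdlog^2 / 2)
noise_mean <- exp(0.25^2 / 2)
noise_mean
```

Because the noise is multiplicative with a median of 1, the simulated points scatter more widely around the signal where |value| is large.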

Fitting a GAM with a smooth term for each predictor, value ~ s(x) + s(y) + s(z), gave the spline terms plotted in gamfit.png by the script below.
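Since the signal is a sum of one-dimensional terms, each s() term should roughly recover one of them. Two adjustments are needed before comparing shapes. First, mgcv centers each smooth so that it sums to zero over the observed covariate values. Second, with the default Gaussian family the fit targets the conditional mean, which the multiplicative noise scales by exp(0.25^2 / 2). The base-R sketch below applies both adjustments to the true terms; centered_term() is a hypothetical helper, not part of mgcv:

```r
# Reproduce the covariates from the example script
set.seed(1)
n <- 1000
x <- runif(n, min = -10, max = 10)
y <- runif(n, min = -10, max = 10)
z <- runif(n, min = -10, max = 10)

# Mean of the multiplicative lognormal noise, exp(sdlog^2 / 2)
noise_mean <- exp(0.25^2 / 2)

# Hypothetical helper: a true additive term evaluated at the observed values,
# scaled to the conditional-mean scale and centered to sum to zero
centered_term <- function(f, v) {
  fv <- noise_mean * f(v)
  fv - mean(fv)
}

term_x <- centered_term(function(v) -0.01 * v^3, x)
term_y <- centered_term(function(v) -0.2 * v^2, y)
term_z <- centered_term(function(v) -0.3 * v^2, z)
```

As a suggestion (not part of the original script), calling lines(sort(x), term_x[order(x)], col = "red") after plot.gam(fit, select = 1, ...) should overlay the centered true curve on the first panel for a visual check.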



Next, the fitted GAM was used to predict values on a regular 3d grid for plotting with the rgl package. The resulting plot (saved by the script as contour3d.png) shows the original data as spheres colored by value, with blue for low values and red for high values. Finally, the contour3d function adds the GAM predictions as colored contour surfaces.
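The script pads each axis range by limExt = 10% on both sides and uses reso = 30 points per axis. When contour3d() receives a function, it evaluates it over the full grid of x, y, z combinations, which is why fun passes vectors straight to predict(). The base-R sketch below builds the same kind of grid and the equivalent 3d array. padded_seq() is a hypothetical helper, and the noise-free signal stands in for predict(fit, ...) so the sketch runs without mgcv:

```r
# Hypothetical helper: a regular sequence covering range(v), padded by
# `ext` times its width on each side (mirrors xs/ys/zs in the script)
padded_seq <- function(v, ext = 0.1, reso = 30) {
  r <- range(v)
  pad <- diff(r) * ext
  seq(r[1] - pad, r[2] + pad, length.out = reso)
}

# Noise-free signal, standing in for predict(fit, ...)
signal <- function(x, y, z) -0.01 * x^3 - 0.2 * y^2 - 0.3 * z^2

xs <- padded_seq(c(-10, 10))
ys <- padded_seq(c(-10, 10))
zs <- padded_seq(c(-10, 10))

# Every (x, y, z) combination; expand.grid varies x fastest, matching R's
# column-major array layout, so arr[i, j, k] is the value at (xs[i], ys[j], zs[k])
g <- expand.grid(x = xs, y = ys, z = zs)
arr <- array(signal(g$x, g$y, g$z),
             dim = c(length(xs), length(ys), length(zs)))

dim(arr)   # 30 30 30
range(xs)  # -12 12
```

With the fitted model, an array built the same way from fun(g$x, g$y, g$z) could be passed to contour3d() as f instead of the function, since contour3d also accepts a three-dimensional array there.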


I got some helpful insight from the R code accompanying Murrell (2011).

References:
Murrell, P., 2011. R Graphics. CRC Press.


Example script:

# make data ---------------------------------------------------------------
set.seed(1)
n <- 1000
x <- runif(n, min=-10, max=10)
y <- runif(n, min=-10, max=10)
z <- runif(n, min=-10, max=10)
# noise-free signal multiplied by lognormal noise (median 1)
value <- (-0.01*x^3 + -0.2*y^2 + -0.3*z^2) * rlnorm(n, 0, 0.25)
dat <- data.frame(x=x, y=y, z=z, value=value)
 
 
# fit model (GAM) ---------------------------------------------------------
library(mgcv)
fit <- gam(value ~ s(x) + s(y) + s(z), data = dat)
png("gamfit.png", width=8, height=3, units="in", type="cairo", res=400)
op <- par(mar=c(3,3,0.5,0.5), ps=10, mfrow=c(1,3), mgp=c(2,0.25,0), tcl=0.25)
for(i in seq(3)){plot.gam(fit, select=i, shade=TRUE, residuals=TRUE, rug=FALSE)}
par(op)
dev.off()
 
 
# predict to new grid -----------------------------------------------------
reso <- 30
limExt <- 0.1
ranx <- range(x)
rany <- range(y)
ranz <- range(z)
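# regular grid extending limExt (10%) beyond the data range on each side
xs <- seq(ranx[1]-diff(ranx)*limExt, ranx[2]+diff(ranx)*limExt, length.out=reso)
ys <- seq(rany[1]-diff(rany)*limExt, rany[2]+diff(rany)*limExt, length.out=reso)
zs <- seq(ranz[1]-diff(ranz)*limExt, ranz[2]+diff(ranz)*limExt, length.out=reso)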
 
 
# 3d contours -------------------------------------------------------------
library(sinkr) # val2col(), jetPal(); https://github.com/marchtaylor/sinkr (install from GitHub, e.g. remotes::install_github("marchtaylor/sinkr"))
library(rgl)
library(misc3d)
 
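nlevs <- 5 # number of contour surfaces
vran <- range(dat$value)
# nlevs evenly spaced levels strictly inside the data range (endpoints dropped)
levs <- seq(vran[1], vran[2], length.out=nlevs+2)[-c(1, nlevs+2)]
levcols <- val2col(levs, jetPal(nlevs), zlim = vran)
# vectorized prediction function; contour3d() evaluates it over the xs/ys/zs grid
fun <- function(x,y,z){predict(fit, data.frame(x=x, y=y, z=z))}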
with(dat, spheres3d(x,y,z, col=val2col(value, col=jetPal(20), zlim=vran), radius = 0.2))
contour3d(fun, level = levs, 
          x=xs, y=ys, z=zs,
          color=levcols,
          engine="rgl", add=TRUE, alpha=0.5
)
box3d()
snapshot3d("contour3d.png")
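The level and color choices in the script can be mimicked in base R. The script takes nlevs = 5 evenly spaced levels strictly inside the data range, presumably because isosurfaces at the exact minimum or maximum would be empty or degenerate, and colors them with sinkr's val2col() and jetPal(). The sketch below is a stand-in, not sinkr's implementation: interior_levels(), jet_like() and value_to_color() are hypothetical names, and the palette stops are an assumption chosen to run from dark blue (low) to dark red (high).

```r
# Hypothetical helper: `nlevs` evenly spaced levels strictly inside `vran`,
# as in levs <- seq(...)[-c(1, nlevs + 2)] in the script
interior_levels <- function(vran, nlevs) {
  s <- seq(vran[1], vran[2], length.out = nlevs + 2)
  s[-c(1, nlevs + 2)]
}

# A jet-like blue-to-red palette built with base R (an assumption; not
# sinkr's jetPal())
jet_like <- colorRampPalette(c("#00007F", "blue", "cyan",
                               "yellow", "red", "#7F0000"))

# Hypothetical stand-in for val2col(): bin values into equal-width
# intervals over zlim and return the matching palette color
value_to_color <- function(v, cols, zlim) {
  breaks <- seq(zlim[1], zlim[2], length.out = length(cols) + 1)
  idx <- findInterval(v, breaks, rightmost.closed = TRUE, all.inside = TRUE)
  cols[idx]
}

levs <- interior_levels(c(-60, 0), nlevs = 5)
levs  # -50 -40 -30 -20 -10
levcols <- value_to_color(levs, jet_like(5), zlim = c(-60, 0))
```

Here each of the five levels falls in a different color bin, so the surfaces are drawn in order from blue to red.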