2 Using generalized additive models (GAMs) to learn non-monotone relationships
CHAPTER 9
Exploring advanced methods
Figure 9.2 A spline that has been fit through a series of points (horizontal axis: x; vertical axis: y)
9.2.2 A one-dimensional regression example
Let’s consider a toy example where the response y is a noisy nonlinear function of the
input variable x (in fact, it’s the function shown in figure 9.2). As usual, we’ll split the
data into training and test sets.
Listing 9.6 Preparing an artificial problem
set.seed(602957)
x <- rnorm(1000)
noise <- rnorm(1000, sd=1.5)
y <- 3*sin(2*x) + cos(0.75*x) - 1.5*(x^2) + noise
select <- runif(1000)
frame <- data.frame(y=y, x=x)
train <- frame[select > 0.1,]
test <- frame[select <= 0.1,]
Given the data is from the nonlinear functions sin() and cos(), there shouldn’t be a
good linear fit from x to y. We’ll start by building a (poor) linear regression.
Listing 9.7 Linear regression applied to our artificial example
> lin.model <- lm(y ~ x, data=train)
> summary(lin.model)

Call:
lm(formula = y ~ x, data = train)

Residuals:
    Min      1Q  Median      3Q     Max
-17.698  -1.774   0.193   2.499   7.529

Coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept)  -0.8330     0.1161  -7.175 1.51e-12 ***
x             0.7395     0.1197   6.180 9.74e-10 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 3.485 on 899 degrees of freedom
Multiple R-squared: 0.04075, Adjusted R-squared: 0.03968
F-statistic: 38.19 on 1 and 899 DF, p-value: 9.737e-10

#
# calculate the root mean squared error (rmse)
#
> resid.lin <- train$y-predict(lin.model)
> sqrt(mean(resid.lin^2))
[1] 3.481091
The resulting model’s predictions are plotted versus true response in figure 9.3. As
expected, it’s a very poor fit, with an R-squared of about 0.04. In particular, the errors
are heteroscedastic (see footnote 5): there are regions where the model systematically underpredicts
and regions where it systematically overpredicts. If the relationship between x and y
were truly linear (with noise), then the errors would be homoscedastic: the errors would
be evenly distributed (mean 0) around the predicted value everywhere.
[Figure 9.3 plot: "Response v. Prediction, linear model"; horizontal axis: pred; vertical axis: actual]
Figure 9.3 Linear model’s predictions versus actual response. The solid line is the line of
perfect prediction (prediction=actual).
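One way to check for this kind of systematic over- and underprediction is to bin the predictions and average the residuals within each bin. The following is a minimal sketch on freshly simulated data shaped like the toy example; the seed, the sample size, and the choice of five equal-count bins are illustrative assumptions, not part of the original listings.

```r
# Simulated data resembling the toy problem (seed and size are assumptions).
set.seed(1)
x <- rnorm(500)
y <- 3*sin(2*x) + cos(0.75*x) - 1.5*(x^2) + rnorm(500, sd=1.5)
d <- data.frame(x=x, y=y)

# Fit the (poor) linear model and compute its residuals.
lin <- lm(y ~ x, data=d)
d$pred <- as.numeric(predict(lin))
d$resid <- d$y - d$pred

# Split the predictions into five equal-count bins
# and average the residuals inside each bin.
breaks <- quantile(d$pred, probs=seq(0, 1, by=0.2))
d$bin <- cut(d$pred, breaks=breaks, include.lowest=TRUE)
bin.means <- tapply(d$resid, d$bin, mean)
print(round(bin.means, 2))
```

The residuals average to zero over the whole training set by construction, so the overall mean hides the problem. Bin means that sit far from zero indicate ranges of the prediction where the model is consistently too high or too low.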
Let’s try finding a nonlinear model that maps x to y. We’ll use the function gam()
in the package mgcv (see footnote 6). When using gam(), you can model variables as either linear or
nonlinear. You model a variable x as nonlinear by wrapping it in the s() notation. In
this example, rather than using the formula y ~ x to describe the model, you’d use the
formula y ~ s(x). Then gam() will search for the spline s() that best describes the
relationship between x and y, as shown in listing 9.8. Only terms surrounded by s()
get the GAM/spline treatment.
Listing 9.8 GAM applied to our artificial example
> # Load the mgcv package.
> library(mgcv)
> # Build the model, specifying that x should be treated as a nonlinear variable.
> glin.model <- gam(y~s(x), data=train)
> # The converged parameter tells you if the algorithm converged.
> # You should only trust the output if this is TRUE.
> glin.model$converged
[1] TRUE
> summary(glin.model)

# Setting family=gaussian and link=identity tells you that the model was treated with
# the same distribution assumptions as a standard linear regression.
Family: gaussian
Link function: identity

Formula:
y ~ s(x)

# The parametric coefficients are the linear terms (in this example, only the
# constant term). This section of the summary tells you which linear terms
# were significantly different from 0.
Parametric coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept) -0.83467    0.04852   -17.2   <2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# The smooth terms are the nonlinear terms. This section of the summary tells you
# which nonlinear terms were significantly different from 0. It also tells you the
# effective degrees of freedom (edf) used up to build each smooth term. An edf near 1
# indicates that the variable has an approximately linear relationship to the output.
Approximate significance of smooth terms:
       edf Ref.df     F p-value
s(x) 8.685  8.972 497.8  <2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# “R-sq.(adj)” is the adjusted R-squared.
# “Deviance explained” is the raw R-squared (0.834).
R-sq.(adj) = 0.832   Deviance explained = 83.4%
GCV score = 2.144  Scale est. = 2.121  n = 901

#
# calculate the root mean squared error (rmse)
#
> resid.glin <- train$y-predict(glin.model)
> sqrt(mean(resid.glin^2))
[1] 1.448514

Footnote 5: Heteroscedastic errors are errors whose magnitude is correlated with the quantity to be predicted. Heteroscedastic errors are bad because they’re systematic and violate the assumption that errors are uncorrelated with outcomes, which is used in many proofs of the good properties of regression methods.
Footnote 6: There’s an older package called gam, written by Hastie and Tibshirani, the inventors of GAMs. The gam package works fine. But it’s incompatible with the mgcv package, which ggplot already loads. Since we’re using ggplot for plotting, we’ll use mgcv for our examples.
The resulting model’s predictions are plotted versus true response in figure 9.4. This fit
is much better: the model explains over 80% of the variance (R-squared of 0.83), and
the root mean squared error (RMSE) over the training data is less than half the RMSE
of the linear model. Note that the points in figure 9.4 are distributed more or less evenly
around the line of perfect prediction. The GAM has been fit to be homoscedastic, and
any given prediction is as likely to be an overprediction as an underprediction.
[Figure 9.4 plot: "Response v. Prediction, GAM"; horizontal axis: pred; vertical axis: actual]
Figure 9.4 GAM’s predictions versus actual response. The solid line is the theoretical line
of perfect prediction (prediction=actual).
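For a Gaussian GAM with the identity link, the "Deviance explained" figure reported by summary() should match the raw R-squared computed by hand as one minus the ratio of the residual sum of squares to the total sum of squares. The sketch below checks this on simulated data; the seed, the sample size, and the variable names are assumptions made for illustration.

```r
library(mgcv)

# Simulated data resembling the toy problem (seed and size are assumptions).
set.seed(2)
x <- rnorm(400)
y <- 3*sin(2*x) + cos(0.75*x) - 1.5*(x^2) + rnorm(400, sd=1.5)
d <- data.frame(x=x, y=y)

model <- gam(y ~ s(x), data=d)

# Raw R-squared: 1 - (residual sum of squares / total sum of squares).
rss <- sum((d$y - as.numeric(predict(model)))^2)
tss <- sum((d$y - mean(d$y))^2)
rsq <- 1 - rss/tss

# The same quantity as reported by summary.gam().
dev.expl <- summary(model)$dev.expl

print(c(rsq=rsq, dev.expl=dev.expl))
```

The equality holds because the Gaussian deviance is the residual sum of squares. For other families, such as the binomial family used later in this section, "Deviance explained" is a pseudo R-squared instead.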
Modeling linear relationships using gam()
By default, gam() will perform standard linear regression. If you were to call gam()
with the formula y ~ x, you’d get the same model that you got using lm(). More
generally, the call gam(y ~ x1 + s(x2), data=...) would model the variable x1
as having a linear relationship with y, and try to fit the best possible smooth curve to
model the relationship between x2 and y. Of course, the best smooth curve could be
a straight line, so if you’re not sure whether the relationship between x and y is linear, you can use s(x). If the summary reports an edf (effective degrees
of freedom; see the model summary in listing 9.8) of about 1 for that smooth term, then you can try refitting the variable as a linear term.
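The claim that gam() with a purely linear formula reproduces lm() can be checked directly. This is a minimal sketch on simulated data where the true relationship is linear; the data, the seed, and the names lm.fit, gam.fit, and smooth.fit are hypothetical.

```r
library(mgcv)

# Simulated data with a truly linear relationship (an assumption for illustration).
set.seed(3)
d <- data.frame(x1 = rnorm(300))
d$y <- 2*d$x1 - 1 + rnorm(300, sd=0.5)

# The same linear formula fit with lm() and with gam().
lm.fit <- lm(y ~ x1, data=d)
gam.fit <- gam(y ~ x1, data=d)
coef.gap <- max(abs(coef(lm.fit) - coef(gam.fit)))
print(coef.gap)

# Fit x1 as a smooth term and inspect its effective degrees of freedom.
smooth.fit <- gam(y ~ s(x1), data=d)
print(summary(smooth.fit)$edf)
```

The coefficient gap should be zero up to numerical precision. The edf of s(x1) will usually be close to 1 on data like this, though the exact value depends on the sample.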
The use of splines gives GAMs a richer model space to choose from; this increased flexibility brings a higher risk of overfitting. Let’s check the models’ performances on the
test data.
Listing 9.9 Comparing linear regression and GAM performance
> actual <- test$y
> # Call both models on the test data.
> pred.lin <- predict(lin.model, newdata=test)
> pred.glin <- predict(glin.model, newdata=test)
> resid.lin <- actual-pred.lin
> resid.glin <- actual-pred.glin
> # Compare the RMSE of the linear model and the GAM on the test data.
> sqrt(mean(resid.lin^2))
[1] 2.792653
> sqrt(mean(resid.glin^2))
[1] 1.401399
> # Compare the R-squared of the linear model and the GAM on test data.
> cor(actual, pred.lin)^2
[1] 0.1543172
> cor(actual, pred.glin)^2
[1] 0.7828869
The GAM performed similarly on both sets (RMSE of 1.40 on test versus 1.45 on training; R-squared of 0.78 on test versus 0.83 on training). So there’s likely no overfit.
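The training-versus-test comparison can be wrapped in two small helper functions so that the same check is easy to repeat for any model. The sketch below uses simulated data; the helper names rmse() and rsq(), the seed, and the 90/10 split are illustrative assumptions.

```r
library(mgcv)

# Hypothetical helpers: root mean squared error and squared correlation.
rmse <- function(actual, pred) sqrt(mean((actual - pred)^2))
rsq <- function(actual, pred) cor(actual, pred)^2

# Simulated data resembling the toy problem, split roughly 90/10.
set.seed(4)
x <- rnorm(1000)
y <- 3*sin(2*x) + cos(0.75*x) - 1.5*(x^2) + rnorm(1000, sd=1.5)
frame <- data.frame(x=x, y=y)
select <- runif(1000)
train <- frame[select > 0.1,]
test <- frame[select <= 0.1,]

model <- gam(y ~ s(x), data=train)
pred.train <- as.numeric(predict(model, newdata=train))
pred.test <- as.numeric(predict(model, newdata=test))

# One row per data set; a large gap between rows would suggest overfit.
perf <- data.frame(
  set = c("train", "test"),
  rmse = c(rmse(train$y, pred.train), rmse(test$y, pred.test)),
  rsq = c(rsq(train$y, pred.train), rsq(test$y, pred.test))
)
print(perf)
```

A test RMSE much larger than the training RMSE, or a test R-squared much smaller, would be a sign that the extra flexibility of the spline has fit noise rather than signal.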
9.2.3 Extracting the nonlinear relationships
Once you fit a GAM, you’ll probably be interested in what the s() functions look
like. Calling plot() on a GAM will give you a plot for each s() curve, so you can visualize nonlinearities. In our example, plot(glin.model) produces the top curve in
figure 9.5.
The shape of the curve is quite similar to the scatter plot we saw in figure 9.2
(which is reproduced as the lower half of figure 9.5). In fact, the spline that’s superimposed on the scatter plot in figure 9.2 is the same curve.
[Figure 9.5 plots: top, s(x,8.69) versus x; bottom, y versus x]
Figure 9.5 Top: The nonlinear function s(x) discovered by gam(), as output by
plot(glin.model). Bottom: The same spline superimposed over the training data.
We can extract the data points that were used to make this graph by using the
predict() function with the argument type="terms". This produces a matrix where
the ith column represents s(x[,i]). Listing 9.10 demonstrates how to reproduce the
lower plot in figure 9.5.
Listing 9.10 Extracting a learned spline from a GAM
> sx <- predict(glin.model, type="terms")
> summary(sx)
      s(x)
 Min.   :-17.527035
 1st Qu.: -2.378636
 Median :  0.009427
 Mean   :  0.000000
 3rd Qu.:  2.869166
 Max.   :  4.084999
> xframe <- cbind(train, sx=sx[,1])
> ggplot(xframe, aes(x=x)) + geom_point(aes(y=y), alpha=0.4) +
    geom_line(aes(y=sx))
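With more than one smooth term, the matrix returned by predict(..., type="terms") has one column per term, and for a Gaussian model with the identity link the full prediction is the intercept plus the row sums of that matrix. The following sketch illustrates this on simulated data with two inputs; the data, the seed, and the variable names are assumptions.

```r
library(mgcv)

# Simulated data with two nonlinear inputs (an assumption for illustration).
set.seed(5)
d <- data.frame(x1 = runif(300, -2, 2), x2 = runif(300, -2, 2))
d$y <- sin(2*d$x1) + d$x2^2 + rnorm(300, sd=0.3)

model <- gam(y ~ s(x1) + s(x2), data=d)

# One column per smooth term, named after the term.
terms <- predict(model, type="terms")
print(colnames(terms))

# Rebuild the full prediction from the intercept and the per-term contributions.
intercept <- coef(model)["(Intercept)"]
rebuilt <- as.numeric(intercept + rowSums(terms))
direct <- as.numeric(predict(model))
print(max(abs(rebuilt - direct)))
```

This additive decomposition is what lets each s() curve be plotted and interpreted separately from the others.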
Now that we’ve worked through a simple example, let’s try a more realistic example
with more variables.
9.2.4 Using GAM on actual data
For this example, we’ll predict a newborn baby’s weight (DBWT) using data from the
CDC 2010 natality dataset that we used in section 7.2 (though this is not the risk data
used in that chapter; see footnote 7). As input, we’ll consider mother’s weight (PWGT), mother’s pregnancy weight gain (WTGAIN), mother’s age (MAGER), and the number of prenatal medical visits (UPREVIS) (see footnote 8).
In the following listing, we’ll fit a linear model and a GAM, and compare.
Listing 9.11 Applying linear regression (with and without GAM) to health data
> library(mgcv)
> library(ggplot2)
> load("NatalBirthData.rData")
> train <- sdata[sdata$ORIGRANDGROUP<=5,]
> test <- sdata[sdata$ORIGRANDGROUP>5,]
> # Build a linear model with four variables.
> form.lin <- as.formula("DBWT ~ PWGT + WTGAIN + MAGER + UPREVIS")
> linmodel <- lm(form.lin, data=train)
> summary(linmodel)

Call:
lm(formula = form.lin, data = train)

Residuals:
     Min       1Q   Median       3Q      Max
-3155.43  -272.09    45.04   349.81  2870.55

Coefficients:
             Estimate Std. Error t value Pr(>|t|)
(Intercept) 2419.7090    31.9291  75.784  < 2e-16 ***
PWGT           2.1713     0.1241  17.494  < 2e-16 ***
WTGAIN         7.5773     0.3178  23.840  < 2e-16 ***
MAGER          5.3213     0.7787   6.834  8.6e-12 ***
UPREVIS       12.8753     1.1786  10.924  < 2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# The model explains about 7% of the variance;
# all coefficients are significantly different from 0.
Residual standard error: 562.7 on 14381 degrees of freedom
Multiple R-squared: 0.06596, Adjusted R-squared: 0.0657
F-statistic: 253.9 on 4 and 14381 DF, p-value: < 2.2e-16

> # Build a GAM with the same variables.
> form.glin <- as.formula("DBWT ~ s(PWGT) + s(WTGAIN) +
                           s(MAGER) + s(UPREVIS)")
> glinmodel <- gam(form.glin, data=train)
> # Verify that the model has converged.
> glinmodel$converged
[1] TRUE
> summary(glinmodel)

Family: gaussian
Link function: identity

Formula:
DBWT ~ s(PWGT) + s(WTGAIN) + s(MAGER) + s(UPREVIS)

Parametric coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept) 3276.948      4.623   708.8   <2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Approximate significance of smooth terms:
             edf Ref.df       F  p-value
s(PWGT)    5.374  6.443  68.981  < 2e-16 ***
s(WTGAIN)  4.719  5.743 102.313  < 2e-16 ***
s(MAGER)   7.742  8.428   6.959 1.82e-09 ***
s(UPREVIS) 5.491  6.425  48.423  < 2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# The model explains just under 10% of the variance; all variables
# have a nonlinear effect significantly different from 0.
R-sq.(adj) = 0.0927   Deviance explained = 9.42%
GCV score = 3.0804e+05  Scale est. = 3.0752e+05  n = 14386

Footnote 7: The dataset can be found at https://github.com/WinVector/zmPDSwR/blob/master/CDC/NatalBirthData.rData. A script for preparing the dataset from the original CDC extract can be found at https://github.com/WinVector/zmPDSwR/blob/master/CDC/prepBirthWeightData.R.
Footnote 8: We’ve chosen this example to highlight the mechanisms of gam(), not to find the best model for birth weight. Adding other variables beyond the four we’ve chosen will improve the fit, but obscure the exposition.
The GAM has improved the fit, and all four variables seem to have a nonlinear relationship with birth weight, as evidenced by edfs all greater than 1. We could use
plot(glinmodel) to examine the shape of the s() functions; instead, we’ll compare
them with a direct smoothing curve of each variable against birth weight.
Listing 9.12 Plotting GAM results
> # Get the matrix of s() functions.
> terms <- predict(glinmodel, type="terms")
> # Bind in birth weight; convert to data frame.
> # (The right-hand side of this assignment is cut off in the source; the call
> # below is reconstructed from the annotation and is an assumption.)
> tframe <- cbind(DBWT = train$DBWT, as.data.frame(terms))
> # Make the column names reference-friendly
> # (“s(PWGT)” is converted to “sPWGT”, etc.).
> colnames(tframe) <- gsub('[()]', '', colnames(tframe))
> # Bind in the input variables.
> pframe <- cbind(tframe, train[,c("PWGT", "WTGAIN",
                                   "MAGER", "UPREVIS")])
> p1 <- ggplot(pframe, aes(x=PWGT)) +
    # Plot s(PWGT) shifted to be zero mean versus PWGT (mother’s weight) as points.
    geom_point(aes(y=scale(sPWGT, scale=F))) +
    # Plot the smoothing curve of DBWT (birth weight) shifted to be zero mean
    # versus PWGT (mother’s weight).
    geom_smooth(aes(y=scale(DBWT, scale=F))) +
[...]
# Repeat for remaining variables (omitted for brevity).
The plots of the s() splines compared with the smooth curves directly relating the
input variables to birth weight are shown in figure 9.6. The smooth curves in each case
are similar to the corresponding s() in shape, and nonlinear for all of the variables.
As usual, we should check for overfit with hold-out data.
Listing 9.13 Checking GAM model performance on hold-out data
# Run both the linear model and the GAM on the test data.
pred.lin <- predict(linmodel, newdata=test)
pred.glin <- predict(glinmodel, newdata=test)
# Calculate R-squared for both models.
cor(pred.lin, test$DBWT)^2
# [1] 0.0616812
cor(pred.glin, test$DBWT)^2
# [1] 0.08857426
Each model performed about as well on the test set as it did on the training set
(R-squared of 0.062 versus 0.066 for the linear model, and 0.089 versus 0.094 for the GAM), so in this example there’s no substantial overfit.
[Figure 9.6 panels: scale(sPWGT, scale = F) versus PWGT; scale(sWTGAIN, scale = F) versus WTGAIN; scale(sMAGER, scale = F) versus MAGER; scale(sUPREVIS, scale = F) versus UPREVIS]
Figure 9.6 Smoothing curves of each of the four input variables plotted against birth
weight, compared with the splines discovered by gam(). All curves have been shifted
to be zero mean for comparison of shape.
9.2.5 Using GAM for logistic regression
The gam() function can be used for logistic regression as well. Suppose that we wanted
to predict the birth of underweight babies (defined as DBWT < 2000) from the same
variables we’ve been using. The logistic regression call to do that would be as shown in
the following listing.
Listing 9.14
GLM logistic regression
form <- as.formula("DBWT < 2000 ~ PWGT + WTGAIN + MAGER + UPREVIS")
logmod <- glm(form, data=train, family=binomial(link="logit"))
The corresponding call to gam() also specifies the binomial family with the logit link.
Listing 9.15 GAM logistic regression
> form2 <- as.formula("DBWT<2000~s(PWGT)+s(WTGAIN)+
                       s(MAGER)+s(UPREVIS)")
> glogmod <- gam(form2, data=train, family=binomial(link="logit"))
> glogmod$converged
[1] TRUE
> summary(glogmod)

Family: binomial
Link function: logit

Formula:
DBWT < 2000 ~ s(PWGT) + s(WTGAIN) + s(MAGER) + s(UPREVIS)

Parametric coefficients:
            Estimate Std. Error z value Pr(>|z|)
(Intercept) -3.94085    0.06794     -58   <2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# Note that there’s no proof that the mother’s weight (PWGT)
# has a significant effect on outcome.
Approximate significance of smooth terms:
             edf Ref.df  Chi.sq  p-value
s(PWGT)    1.905  2.420   2.463  0.36412
s(WTGAIN)  3.674  4.543  64.426 1.72e-12 ***
s(MAGER)   1.003  1.005   8.335  0.00394 **
s(UPREVIS) 6.802  7.216 217.631  < 2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# “Deviance explained” is the pseudo R-squared: 1 - (deviance/null.deviance)
R-sq.(adj) = 0.0331   Deviance explained = 9.14%
UBRE score = -0.76987  Scale est. = 1  n = 14386
As with the standard logistic regression call, we recover the class probabilities with the
call predict(glogmod, newdata=train, type="response"). Again these models
are coming out with low quality, and in practice we would look for more explanatory
variables to build better screening models.
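The difference between type="response" and the link scale can be seen on a small simulated example. In this sketch the data, the seed, and the names event, prob, and logodds are hypothetical; only gam(), predict(), and the binomial family come from the text.

```r
library(mgcv)

# Simulated binary outcome with a nonlinear effect of x (an assumption).
set.seed(6)
d <- data.frame(x = runif(500, -3, 3))
p.true <- plogis(-1 + 2*sin(d$x))
d$event <- runif(500) < p.true

model <- gam(event ~ s(x), data=d, family=binomial(link="logit"))

# type="response" returns class probabilities;
# type="link" returns the linear predictor (log-odds).
prob <- as.numeric(predict(model, newdata=d, type="response"))
logodds <- as.numeric(predict(model, newdata=d, type="link"))

# Applying the inverse logit to the log-odds recovers the probabilities.
print(max(abs(prob - plogis(logodds))))
print(range(prob))
```

Forgetting type="response" is an easy mistake to make, because the link-scale values can be negative or greater than 1 and so cannot be read as probabilities.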
The gam() function requires explicit formulas as input
You may have noticed that when calling lm(), glm(), or rpart(), we can input the
formula specification as a string. These three functions quietly convert the string
into a formula object. Unfortunately, neither gam() nor randomForest(), which you
saw in section 9.1.2, will do this automatic conversion. You must explicitly call
as.formula() to convert the string into a formula object.
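Building the formula string programmatically and converting it once with as.formula() keeps the conversion explicit. The sketch below assembles the GAM formula used earlier from a vector of variable names; the names vars, form.string, and form are illustrative.

```r
# Input variables from the birth weight example.
vars <- c("PWGT", "WTGAIN", "MAGER", "UPREVIS")

# Wrap each variable in s() and join the terms into a formula string.
form.string <- paste("DBWT ~", paste0("s(", vars, ")", collapse=" + "))
print(form.string)

# Convert the string into a formula object before passing it to gam().
form <- as.formula(form.string)

class(form.string)  # "character"
class(form)         # "formula"

# The variables the formula refers to (s is a function, so it is not listed).
print(all.vars(form))
```

The resulting object can then be passed as the first argument of gam(), in the same way as form.glin in listing 9.11.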