-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathsim.lm.hi.R
106 lines (93 loc) · 3.88 KB
/
sim.lm.hi.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
## Linear regression: test out conformal intervals and parametric intervals
## across a variety of high-dimensional settings
library(conformalInference)
# Global simulation settings shared by everything defined below
n <- 500                       # Number of observations
p <- 490                       # Number of features
s <- 10                        # How many features are truly relevant
n0 <- 100                      # How many points to predict at
nrep <- 50                     # Repetitions per simulation setting
sigma <- 1                     # Marginal standard deviation of the errors
bval <- 1                      # Size of each nonzero coefficient
lambda <- c(0, 1, 10, 50)      # Ridge penalties to try
alpha <- 0.1                   # Target miscoverage level
# Define conformal inference functions: these are basically just wrappers
# around a particular instantiation of conformal.pred, conformal.pred.jack, or
# conformal.pred.split
# Training/prediction functions built by lm.funs for the global lambda grid
my.lm.funs <- lm.funs(lambda = lambda)

# Wrapper around conformal.pred with the (x, y, x0) signature used by every
# method in this script. The miscoverage level is read from the global `alpha`.
# NOTE(review): `verb` presumably partial-matches the `verbose` argument of
# conformal.pred -- confirm against the package signature.
my.conf.fun <- function(x, y, x0) {
  trainer <- my.lm.funs$train
  predictor <- my.lm.funs$predict
  conformal.pred(
    x, y, x0,
    train.fun = trainer,
    predict.fun = predictor,
    alpha = alpha,
    verb = "\t\t"
  )
}
# Wrapper around conformal.pred.jack with the common (x, y, x0) signature.
# Forwards the train, predict and special functions held in my.lm.funs and
# reads the miscoverage level from the global `alpha`.
my.jack.fun <- function(x, y, x0) {
  funs <- my.lm.funs
  conformal.pred.jack(
    x, y, x0,
    train.fun = funs$train,
    predict.fun = funs$predict,
    special.fun = funs$special,
    alpha = alpha,
    verb = "\t\t"
  )
}
# Second set of lm functions for the split method: same lambda grid, except
# that the first value is replaced by 1e-8.
# NOTE(review): presumably this avoids an unpenalised fit on the half-sample,
# which has fewer rows than features -- confirm.
my.lm.funs.2 <- lm.funs(lambda = c(1e-8, lambda[-1]))

# Wrapper around conformal.pred.split with the common (x, y, x0) signature.
# Reads the miscoverage level from the global `alpha`.
my.split.fun <- function(x, y, x0) {
  trainer <- my.lm.funs.2$train
  predictor <- my.lm.funs.2$predict
  conformal.pred.split(
    x, y, x0,
    train.fun = trainer,
    predict.fun = predictor,
    alpha = alpha
  )
}
# Hack together our own "conformal" inference function, really, just one that
# returns the parametric (t-based) prediction intervals. It has the same
# (x, y, x0) signature and output layout as the conformal wrappers so that it
# can sit in the same list of methods.
#
# Arguments:
#   x   matrix of training features (one row per observation)
#   y   response vector, one entry per row of x
#   x0  matrix of points at which to predict (same columns as x)
#
# Returns a list with components
#   pred  nrow(x0) x m matrix of predictions, m = number of columns returned
#         by my.lm.funs$predict
#   lo    nrow(x0) x m matrix of lower interval endpoints
#   up    nrow(x0) x m matrix of upper interval endpoints
#   fit   nrow(x) x m matrix of in-sample fitted values
#
# Reads the miscoverage level from the global `alpha`.
my.param.fun <- function(x, y, x0) {
  n <- nrow(x)
  n0 <- nrow(x0)
  out <- my.lm.funs$train(x, y)
  fit <- matrix(my.lm.funs$predict(out, x), nrow = n)
  pred <- matrix(my.lm.funs$predict(out, x0), nrow = n0)
  m <- ncol(pred)

  # Design matrix at the prediction points, with an intercept column
  x1 <- cbind(rep(1, n0), x0)

  # Residual degrees of freedom, taken from the data actually passed in rather
  # than from the global `p`, so the t quantile and the variance estimate below
  # always use the same value.
  df.resid <- n - ncol(x1)
  q <- qt(1 - alpha / 2, df.resid)
  lo <- up <- matrix(0, n0, m)

  for (j in seq_len(m)) {
    sig.hat <- sqrt(sum((y - fit[, j])^2) / df.resid)
    # NOTE(review): out$chol.R[[j]] is presumably the Cholesky factor of the
    # Gram matrix behind the j-th fit, making g the leverage-type term at each
    # row of x1 -- confirm against lm.funs / chol.solve.
    g <- diag(x1 %*% chol.solve(out$chol.R[[j]], t(x1)))
    half.width <- sqrt(1 + g) * sig.hat * q
    lo[, j] <- pred[, j] - half.width
    up[, j] <- pred[, j] + half.width
  }

  # Return proper outputs in proper formatting
  list(pred = pred, lo = lo, up = up, fit = fit)
}
# Finally, a split conformal method whose train/predict functions come from
# ridge.funs with cross-validation switched on
my.ridge.funs <- ridge.funs(cv = TRUE)

# Wrapper around conformal.pred.split with the common (x, y, x0) signature.
# Reads the miscoverage level from the global `alpha`.
my.split.cv.fun <- function(x, y, x0) {
  conformal.pred.split(
    x, y, x0,
    train.fun = my.ridge.funs$train,
    predict.fun = my.ridge.funs$predict,
    alpha = alpha
  )
}
# Collect every inference method in one named list; the names are the display
# labels for each method
conformal.pred.funs <- list(
  "Conformal" = my.conf.fun,
  "Jackknife" = my.jack.fun,
  "Split conformal" = my.split.fun,
  "Parametric" = my.param.fun,
  "Split + CV" = my.split.cv.fun
)

# NOTE(review): `path` is presumably read by the sourced setting script as a
# prefix for its output files -- confirm.
path <- "rds/lm.hi2."

# Run one simulation setting; swap the commented lines to run another
#source("sim.setting.a.R")
#source("sim.setting.b.R")
source("sim.setting.c.R")
## # What values of lambda did split+CV choose?
## lambda.a = lambda.b = lambda.c = c()
## for (r in 1:nrep) {
## xy.a = sim.xy(n, p, x.dist="normal", cor="none",
## mean.fun="linear", s=s, error.dist="normal", sigma=sigma,
## bval=bval, sigma.type="const")
## i = sample(n,floor(n/2))
## lambda.a = c(lambda.a, n/2*cv.glmnet(xy.a$x[i,],xy.a$y[i])$lambda.min)
## xy.b = sim.xy(n, p, x.dist="normal", cor="none",
## mean.fun="additive", m=4, s=s, error.dist="t", df=2,
## sigma=sigma, bval=bval, sigma.type="const")
## i = sample(n,floor(n/2))
## lambda.b = c(lambda.b, n/2*cv.glmnet(xy.b$x[i,],xy.b$y[i])$lambda.min)
## xy.c = sim.xy(n, p, x.dist="mix", cor="auto", k=5,
## mean.fun="linear", error.dist="t", df=2, sigma=sigma, bval=bval,
## sigma.type="var")
## i = sample(n,floor(n/2))
## lambda.c = c(lambda.c, n/2*cv.glmnet(xy.c$x[i,],xy.c$y[i])$lambda.min)
## cat(paste0(r,"."))
## }
## cat("\n")
## cat(mean(lambda.a),"\n") # 42.72443
## cat(mean(lambda.b),"\n") # 100.7886
## cat(mean(lambda.c),"\n") # 618.3248