Skip to content

Commit

Permalink
edits to pfilter weight argument
Browse files Browse the repository at this point in the history
Modified tests + examples for weights
  • Loading branch information
jeswheel authored and kingaa committed Dec 12, 2024
1 parent 9bd52a4 commit 8086bb1
Show file tree
Hide file tree
Showing 45 changed files with 171 additions and 101 deletions.
36 changes: 27 additions & 9 deletions R/pfilter.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@
##' @param filter.traj logical; if \code{TRUE}, a filtered trajectory is returned for the state variables and parameters.
##' See \code{\link{filter_traj}} for more information.
##' @param save.states character;
##' If \code{save.states="unweighted"}, the state-vector for each unweighted particle at each time is saved.
##' If \code{save.states="weighted"}, the state-vector for each weighted particle at each time is saved, along with the corresponding weight.
##' If \code{save.states="no"}, information on the latent states is not saved.
##' \code{"FALSE"} is a synonym for \code{"no"} and \code{"TRUE"} is a synonym for \code{"unweighted"}.
##' If \code{save.states="no"} (the default), information on the latent states is not saved.
##' If \code{save.states="filter"}, the state-vector for each filtered particle \eqn{X_{n,j}^F} at each time \eqn{n} is saved.
##' If \code{save.states="prediction"}, the state-vector for each prediction particle \eqn{X_{n,j}^P} at each time \eqn{n} is saved, along with the corresponding weight \eqn{w_{n,j} = f_{Y_n|X_n}(y^*|X_{n, j}^P;\theta)}.
##' The options "unweighted", "weighted", TRUE, and FALSE are deprecated and will issue a warning if used, mapping to the new values for backward compatibility.
##' The option "unweighted" and TRUE are synonymous with "filter";
##' the option "weighted" is synonymous with "prediction";
##' the option FALSE is synonymous with "no".
##' To retrieve the saved states, apply \code{\link{saved_states}} to the result of the \code{pfilter} computation.
##' @param ... additional arguments are passed to \code{\link{pomp}}.
##' This allows one to set, unset, or modify \link[=basic_components]{basic model components} within a call to this function.
Expand Down Expand Up @@ -129,7 +132,8 @@ setMethod(
pred.var = FALSE,
filter.mean = FALSE,
filter.traj = FALSE,
save.states = c("no", "weighted", "unweighted", "FALSE", "TRUE"),
save.states = c("no", "filter", "prediction",
"weighted", "unweighted", "FALSE", "TRUE"),
verbose = getOption("verbose", FALSE)
) {

Expand Down Expand Up @@ -168,7 +172,8 @@ setMethod(
pred.var = FALSE,
filter.mean = FALSE,
filter.traj = FALSE,
save.states = c("no", "weighted", "unweighted", "FALSE", "TRUE"),
save.states = c("no", "filter", "prediction",
"weighted", "unweighted", "FALSE", "TRUE"),
verbose = getOption("verbose", FALSE)
) {

Expand Down Expand Up @@ -213,7 +218,8 @@ pfilter_internal <- function (
...,
pred.mean = FALSE, pred.var = FALSE, filter.mean = FALSE,
filter.traj = FALSE, cooling, cooling.m,
save.states = c("no", "weighted", "unweighted", "FALSE", "TRUE"),
save.states = c("no", "filter", "prediction",
"weighted", "unweighted", "FALSE", "TRUE"),
.gnsi = TRUE, verbose = FALSE
) {

Expand All @@ -229,9 +235,21 @@ pfilter_internal <- function (
pred.var <- as.logical(pred.var)
filter.mean <- as.logical(filter.mean)
filter.traj <- as.logical(filter.traj)

save.states <- as.character(save.states)
save.states <- match.arg(save.states)

## Check if input is deprecated argument, and return warning.
state_arg_map <- c(
`FALSE`="no",`TRUE`="filter",
weighted="prediction",unweighted="filter"
)
if (save.states %in% names(state_arg_map)) {
pWarn_("The ", sQuote("save.states"), " option ",
sQuote(save.states)," is deprecated and will be removed in a future version.\nUse ", sQuote(state_arg_map[save.states]), " instead.")
save.states <- state_arg_map[save.states]
}

params <- coef(object)
times <- time(object,t0=TRUE)
ntimes <- length(times)-1
Expand All @@ -247,8 +265,8 @@ pfilter_internal <- function (
x <- init.x

## set up storage for saving samples from filtering distributions
stsav <- save.states %in% c("unweighted","TRUE")
wtsav <- save.states == "weighted"
stsav <- save.states %in% c("filter")
wtsav <- save.states == "prediction"
if (stsav || wtsav || filter.traj) {
xparticles <- vector(mode="list",length=ntimes)
if (wtsav) xweights <- xparticles
Expand Down
5 changes: 3 additions & 2 deletions R/saved_states.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
##'
##' Retrieve latent state trajectories from a particle filter calculation.
##'
##' When one calls \code{\link{pfilter}} with \code{save.states=TRUE}, the latent state vector associated with each particle is saved.
##' When one calls \code{\link{pfilter}} with \code{save.states="filter"} or \code{save.states="prediction"}, the latent state vector associated with each particle is saved.
##' This can be extracted by calling \code{saved_states} on the \sQuote{pfilterd.pomp} object.
##' These are the \emph{unweighted} particles, saved \emph{after} resampling.
##' If the filtered particles are saved, these particles are \emph{unweighted}, saved \emph{after} resampling using their normalized weights.
##' If the argument \code{save.states="prediction"} was used, the particles correspond to simulations from \code{rprocess}, and their corresponding unnormalized weights are included in the output.
##'
##' @name saved_states
##' @aliases saved_states,ANY-method saved_states,missing-method
Expand Down
2 changes: 1 addition & 1 deletion examples/pfilter.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ eff_sample_size(pf) ## effective sample size
logLik(pfilter(pf)) ## run it again with 1000 particles

## run it again with 2000 particles
pf <- pfilter(pf,Np=2000,filter.mean=TRUE,filter.traj=TRUE,save.states="weighted")
pf <- pfilter(pf,Np=2000,filter.mean=TRUE,filter.traj=TRUE,save.states="filter")
fm <- filter_mean(pf) ## extract the filtering means
ft <- filter_traj(pf) ## one draw from the smoothing distribution
ss <- saved_states(pf,format="d") ## the latent-state portion of each particle
Expand Down
19 changes: 12 additions & 7 deletions man/pfilter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions man/saved_states.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/bake.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,6 @@ stopifnot(
identical(x5,x6),
length(out)==2,
grepl("^NOTE: creating archive directory",out),
grepl("results/bob/mary'.",out),
grepl("results/results/bob/mary'.",out[2])
grepl("results/bob/mary.",out),
grepl("results/results/bob/mary.",out[2])
)
4 changes: 2 additions & 2 deletions tests/bake.Rout.save
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ NOTE: in 'stew': recomputing archive tmp2.rda.
+ identical(x5,x6),
+ length(out)==2,
+ grepl("^NOTE: creating archive directory",out),
+ grepl("results/bob/mary'.",out),
+ grepl("results/results/bob/mary'.",out[2])
+ grepl("results/bob/mary.",out),
+ grepl("results/results/bob/mary.",out[2])
+ )
>
14 changes: 7 additions & 7 deletions tests/bsplines1.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ try(y <- bspline_basis(x,degree=3,nbasis=3,names=letters[1:3]))
y <- bspline_basis(x,degree=3,nbasis=12,names=letters[1:12])
y <- bspline_basis(x,degree=3,nbasis=9,names="basis")
y <- bspline_basis(x,degree=3,nbasis=9,names="basis%02d")
matplot(x,y,type='l',ylim=c(0,1.1))
matplot(x,y,type="l",ylim=c(0,1.1))
lines(x,apply(y,1,sum),lwd=2)

x <- seq(-1,2,by=0.01)
try(y <- periodic_bspline_basis(x,nbasis=6,names=letters[1:2]))
y <- periodic_bspline_basis(x,nbasis=6,names=tail(letters,6))
y <- periodic_bspline_basis(x,nbasis=5,names="spline")
y <- periodic_bspline_basis(x,nbasis=5,names="spline%d")
matplot(x,y,type='l')
matplot(x,y,type="l")

x <- seq(0,1,length=5)
try(bspline_basis(x,degree=-1,nbasis=9))
Expand All @@ -39,13 +39,13 @@ d2 <- bspline_basis(x,nbasis=nb,degree=deg,deriv=2)

B <- apply(B,2,function(x)x-x[1])
dd <- apply(d,2,function(x){y <- diffinv(x); (head(y,-1)+tail(y,-1))/2*dx})
matplot(B,dd,type='l')
matplot(B,dd,type="l")
abline(a=0,b=1)
stopifnot(all(signif(diag(cor(B,dd)),6)==1))

d <- apply(d,2,function(x) x-x[1])
dd <- apply(d2,2,function(x){y <- diffinv(x); (head(y,-1)+tail(y,-1))/2*dx})
matplot(d,dd,type='l')
matplot(d,dd,type="l")
abline(a=0,b=1)
stopifnot(all(signif(diag(cor(d,dd)),6)==1))

Expand All @@ -56,13 +56,13 @@ d2<- periodic_bspline_basis(x,nbasis=nb,degree=deg,deriv=2)

B <- apply(B,2,function(x)x-x[1])
dd <- apply(d,2,function(x){y <- diffinv(x); (head(y,-1)+tail(y,-1))/2*dx})
matplot(B,dd,type='l')
matplot(B,dd,type="l")
abline(a=0,b=1)
stopifnot(all(signif(diag(cor(B,dd)),6)==1))

d <- apply(d,2,function(x) x-x[1])
dd <- apply(d2,2,function(x){y <- diffinv(x); (head(y,-1)+tail(y,-1))/2*dx})
matplot(d,dd,type='l')
matplot(d,dd,type="l")
abline(a=0,b=1)
stopifnot(all(signif(diag(cor(d,dd)),6)==1))

Expand All @@ -75,6 +75,6 @@ stopifnot(isTRUE(all(B==0)))
try(bspline_basis(x,degree=1,nbasis=6,rg=c(4,3)))
try(bspline_basis(x,degree=1,nbasis=6,rg=c(4,4)))
B <- bspline_basis(x,degree=1,nbasis=6,rg=c(-1,3))
matplot(x,B,type='l')
matplot(x,B,type="l")

dev.off()
14 changes: 7 additions & 7 deletions tests/bsplines1.Rout.save
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Error : in 'bspline_basis': must have nbasis > degree
> y <- bspline_basis(x,degree=3,nbasis=12,names=letters[1:12])
> y <- bspline_basis(x,degree=3,nbasis=9,names="basis")
> y <- bspline_basis(x,degree=3,nbasis=9,names="basis%02d")
> matplot(x,y,type='l',ylim=c(0,1.1))
> matplot(x,y,type="l",ylim=c(0,1.1))
> lines(x,apply(y,1,sum),lwd=2)
>
> x <- seq(-1,2,by=0.01)
Expand All @@ -39,7 +39,7 @@ Error : in 'periodic_bspline_basis': 'names' must be of length 1 or 6
> y <- periodic_bspline_basis(x,nbasis=6,names=tail(letters,6))
> y <- periodic_bspline_basis(x,nbasis=5,names="spline")
> y <- periodic_bspline_basis(x,nbasis=5,names="spline%d")
> matplot(x,y,type='l')
> matplot(x,y,type="l")
>
> x <- seq(0,1,length=5)
> try(bspline_basis(x,degree=-1,nbasis=9))
Expand Down Expand Up @@ -67,13 +67,13 @@ Error : in 'periodic_bspline_basis': must have deriv >= 0
>
> B <- apply(B,2,function(x)x-x[1])
> dd <- apply(d,2,function(x){y <- diffinv(x); (head(y,-1)+tail(y,-1))/2*dx})
> matplot(B,dd,type='l')
> matplot(B,dd,type="l")
> abline(a=0,b=1)
> stopifnot(all(signif(diag(cor(B,dd)),6)==1))
>
> d <- apply(d,2,function(x) x-x[1])
> dd <- apply(d2,2,function(x){y <- diffinv(x); (head(y,-1)+tail(y,-1))/2*dx})
> matplot(d,dd,type='l')
> matplot(d,dd,type="l")
> abline(a=0,b=1)
> stopifnot(all(signif(diag(cor(d,dd)),6)==1))
>
Expand All @@ -84,13 +84,13 @@ Error : in 'periodic_bspline_basis': must have deriv >= 0
>
> B <- apply(B,2,function(x)x-x[1])
> dd <- apply(d,2,function(x){y <- diffinv(x); (head(y,-1)+tail(y,-1))/2*dx})
> matplot(B,dd,type='l')
> matplot(B,dd,type="l")
> abline(a=0,b=1)
> stopifnot(all(signif(diag(cor(B,dd)),6)==1))
>
> d <- apply(d,2,function(x) x-x[1])
> dd <- apply(d2,2,function(x){y <- diffinv(x); (head(y,-1)+tail(y,-1))/2*dx})
> matplot(d,dd,type='l')
> matplot(d,dd,type="l")
> abline(a=0,b=1)
> stopifnot(all(signif(diag(cor(d,dd)),6)==1))
>
Expand All @@ -109,7 +109,7 @@ Error : in 'bspline_basis': improper range 'rg'
> try(bspline_basis(x,degree=1,nbasis=6,rg=c(4,4)))
Error : in 'bspline_basis': improper range 'rg'
> B <- bspline_basis(x,degree=1,nbasis=6,rg=c(-1,3))
> matplot(x,B,type='l')
> matplot(x,B,type="l")
>
> dev.off()
null device
Expand Down
2 changes: 1 addition & 1 deletion tests/dacca.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pfilter(
pred.mean=TRUE,
pred.var=TRUE,
filter.traj=TRUE,
save.states=TRUE
save.states="filter"
) -> pf

stopifnot(
Expand Down
2 changes: 1 addition & 1 deletion tests/dacca.Rout.save
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Type 'q()' to quit R.
+ pred.mean=TRUE,
+ pred.var=TRUE,
+ filter.traj=TRUE,
+ save.states=TRUE
+ save.states="filter"
+ ) -> pf
>
> stopifnot(
Expand Down
2 changes: 1 addition & 1 deletion tests/ebola.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pfilter(
pred.mean=TRUE,
pred.var=TRUE,
filter.traj=TRUE,
save.states=TRUE
save.states="filter"
) -> pf

logLik(pf)
Expand Down
2 changes: 1 addition & 1 deletion tests/ebola.Rout.save
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Type 'q()' to quit R.
+ pred.mean=TRUE,
+ pred.var=TRUE,
+ filter.traj=TRUE,
+ save.states=TRUE
+ save.states="filter"
+ ) -> pf
>
> logLik(pf)
Expand Down
2 changes: 1 addition & 1 deletion tests/gompertz.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pfilter(
pred.mean=TRUE,
pred.var=TRUE,
filter.traj=TRUE,
save.states=TRUE
save.states="filter"
) -> pf

stopifnot(
Expand Down
2 changes: 1 addition & 1 deletion tests/gompertz.Rout.save
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Type 'q()' to quit R.
+ pred.mean=TRUE,
+ pred.var=TRUE,
+ filter.traj=TRUE,
+ save.states=TRUE
+ save.states="filter"
+ ) -> pf
>
> stopifnot(
Expand Down
4 changes: 2 additions & 2 deletions tests/lookup.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ lookup(ct,t=c(1,2.3,4,7,20))
ct <- covariate_table(x=20:30,y=10:0,times=seq(0,10),order="constant")
lookup(ct,t=c(1,2.3,4,7))
lookup(ct,t=6.1)
plot(y~t,data=lookup(ct,t=seq(0,10.5,by=0.01)),type='l')
plot(y~t,data=lookup(ct,t=seq(0,10.5,by=0.01)),type="l")
lines(seq(0,10),10:0,col="blue",type="s")
plot(x~t,data=lookup(ct,t=seq(0,10.5,by=0.01)),type='l')
plot(x~t,data=lookup(ct,t=seq(0,10.5,by=0.01)),type="l")
lines(seq(0,10),20:30,col="blue",type="s")

ct <- covariate_table(x=20:31,y=12:1,times=c(0:5,5:10),order="constant")
Expand Down
4 changes: 2 additions & 2 deletions tests/lookup.Rout.save
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ in 'table_lookup': extrapolating at 2.000000e+01.
> lookup(ct,t=6.1)
t x y
1 6.1 26 4
> plot(y~t,data=lookup(ct,t=seq(0,10.5,by=0.01)),type='l')
> plot(y~t,data=lookup(ct,t=seq(0,10.5,by=0.01)),type="l")
There were 50 or more warnings (use warnings() to see the first 50)
> lines(seq(0,10),10:0,col="blue",type="s")
> plot(x~t,data=lookup(ct,t=seq(0,10.5,by=0.01)),type='l')
> plot(x~t,data=lookup(ct,t=seq(0,10.5,by=0.01)),type="l")
There were 50 or more warnings (use warnings() to see the first 50)
> lines(seq(0,10),20:30,col="blue",type="s")
>
Expand Down
2 changes: 1 addition & 1 deletion tests/ou2.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pfilter(
pred.mean=TRUE,
pred.var=TRUE,
filter.traj=TRUE,
save.states=TRUE
save.states="filter"
) -> pf

plot(pf,yax.flip=TRUE)
Expand Down
2 changes: 1 addition & 1 deletion tests/ou2.Rout.save
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Type 'q()' to quit R.
+ pred.mean=TRUE,
+ pred.var=TRUE,
+ filter.traj=TRUE,
+ save.states=TRUE
+ save.states="filter"
+ ) -> pf
>
> plot(pf,yax.flip=TRUE)
Expand Down
Loading

0 comments on commit 8086bb1

Please sign in to comment.