January clean-up (#438)
* remove bors mention

* elseif to remove duplicate warning

* warning for exactly duplicate columns of G_ens

* localizer and failure handler, better readout on verbose

* condition throwing 2 warnings now throws 1

* update std_of_corr calculation, and remove precomputation call (though method remains)

* update loc example for readability and defaults

* remove unused method

* format

* remove unnecessary restriction on building observations

* typo and format

* remove unused type in SECNice

* example improvement for readme

* update readme

* update readme

* update readme

* update readme

* update readme

* working example

* try png in readme

* try png in readme

* try png in readme

* forgot to upload png

* relative path

* README adjustment

* README.md

* README.md

* check nonunique G rows

* format

* README.md

* fix index error

* format
odunbar authored Jan 15, 2025
1 parent 747ea72 commit 53170ba
Showing 8 changed files with 135 additions and 79 deletions.
84 changes: 78 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
# EnsembleKalmanProcesses.jl
Implements optimization and approximate uncertainty quantification algorithms, Ensemble Kalman Inversion, and Ensemble Kalman Processes.


## Citing us

If you use the examples or code, please cite our article at JOSS in your published materials.
Implements optimization and approximate uncertainty quantification algorithms, Ensemble Kalman Inversion, and other Ensemble Kalman Processes.


| **Documentation** | [![dev][docs-latest-img]][docs-latest-url] |
@@ -36,6 +31,77 @@ If you use the examples or code, please cite our article at JOSS in your publish
### Requirements
Julia LTS version or newer

## What does the package do?
EnsembleKalmanProcesses (EKP) enables users to find a (locally) optimal parameter set `u` for a computer code `G` to fit some (noisy) observational data `y`. It uses a suite of methods from the Ensemble Kalman filtering literature that have a long history of success in the weather forecasting community.

What makes EKP different?
- EKP algorithms are efficient (complexity doesn't strongly scale with the number of parameters), and can optimize with noisy and complex parameter-to-data landscapes.
- We don't require differentiating the model `G` at all! You just need to be able to run it at different parameter configurations.
- We don't even require `G` to be coded up in Julia!
- Ensemble model evaluations are fully parallelizable, so we can exploit the capabilities of our HPC systems!
- We provide some lego-like interfaces for creating complex priors and observations.
- We provide easy interfaces to toggle between many different algorithms and configurable features.

## What does it look like to use?
Below we will outline the current user experience for using `EnsembleKalmanProcesses.jl`. Copy-paste the snippets to reproduce the results (up to random number generation).

We solve the classic inverse problem of learning `u` from data `y = G(u)`, where the forward map `G` is noisy, with noise distributed as `N(0,Γ)`. For example,
```julia
using LinearAlgebra
G(u) = [
    1/abs(u[1]),
    sum(u[2:5]),
    prod(u[3:4]),
    u[1]^2-u[2]-u[3],
    u[4],
    u[5]^3,
] .+ 0.1*randn(6)
true_u = [3, 1, 2, -3, -4]
y = G(true_u)
Γ = (0.1)^2*I
```
We assume some prior knowledge of the parameters `u` in the problem (such as approximate scales, and the first parameter being positive), then we are ready to go!

```julia
using EnsembleKalmanProcesses
using EnsembleKalmanProcesses.ParameterDistributions

prior_u1 = constrained_gaussian("positive_with_mean_2", 2, 1, 0, Inf)
prior_u2 = constrained_gaussian("four_with_spread_5", 0, 5, -Inf, Inf, repeats=4)
prior = combine_distributions([prior_u1, prior_u2])

N_ensemble = 50
initial_ensemble = construct_initial_ensemble(prior, N_ensemble)
ensemble_kalman_process = EnsembleKalmanProcess(
    initial_ensemble, y, Γ, Inversion(), verbose=true)

N_iterations = 10
for i in 1:N_iterations
    params_i = get_ϕ_final(prior, ensemble_kalman_process)

    # `j` indexes ensemble members (kept distinct from the iteration index `i`)
    G_matrix = hcat(
        [G(params_i[:, j]) for j in 1:N_ensemble]... # Parallelize here!
    )

    update_ensemble!(ensemble_kalman_process, G_matrix)
end

final_solution = get_ϕ_mean_final(prior, ensemble_kalman_process)


# Let's see what's going on!
using Plots
p = plot(prior)
for (i,sp) in enumerate(p.subplots)
    vline!(sp, [true_u[i]], lc="black", lw=4)
    vline!(sp, [final_solution[i]], lc="magenta", lw=4)
end
display(p)
```
![quick-readme-example](docs/src/assets/readme_example.png)
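
The `# Parallelize here!` comment in the loop above can be illustrated with a minimal multithreading sketch. This is only one possible approach, not the package's own mechanism: `evaluate_ensemble` and `G_det` are hypothetical names introduced here, and `G_det` is a noise-free copy of the README forward map so that the output is reproducible.

```julia
using Base.Threads

# Hypothetical noise-free stand-in for the README forward map `G`
G_det(u) = [1 / abs(u[1]), sum(u[2:5]), prod(u[3:4]), u[1]^2 - u[2] - u[3], u[4], u[5]^3]

# Hypothetical helper: evaluate `G` on each column of `params`, one member per task
function evaluate_ensemble(G, params::AbstractMatrix)
    N_ens = size(params, 2)
    first_out = G(params[:, 1])                     # evaluate once to learn the output size
    G_matrix = Matrix{Float64}(undef, length(first_out), N_ens)
    G_matrix[:, 1] = first_out
    @threads for j in 2:N_ens                       # members are independent of each other
        G_matrix[:, j] = G(params[:, j])
    end
    return G_matrix
end

# Two ensemble members (columns), five parameters each
params = [3.0 2.0; 1.0 1.0; 2.0 2.0; -3.0 -3.0; -4.0 -4.0]
G_matrix = evaluate_ensemble(G_det, params)
```

Each column is written by exactly one loop iteration, so no locking is needed. Start Julia with several threads (for example `julia -t 4`) for the loop to actually run in parallel.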

See a similar working example [here!](https://clima.github.io/EnsembleKalmanProcesses.jl/dev/literated/sinusoid_example/) Check out our many example scripts above in `examples/`.

# Quick links!

- [How do I build prior distributions?](https://clima.github.io/EnsembleKalmanProcesses.jl/dev/parameter_distributions/)
@@ -47,6 +113,12 @@ Julia LTS version or newer
- [What is this error/warning/message?](https://clima.github.io/EnsembleKalmanProcesses.jl/dev/troubleshooting/)
- [Where can I walk through a simple example?](https://clima.github.io/EnsembleKalmanProcesses.jl/dev/literated/sinusoid_example/)


## Citing us

If you use the examples or code, please cite [our article at JOSS](https://joss.theoj.org/papers/10.21105/joss.04869) in your published materials.


### Getting Started
![eki-getting-started](https://github.com/CliMA/EnsembleKalmanProcesses.jl/assets/45243236/e083ab8c-4f93-432f-9ad5-97aff22764ad)
<!---
Binary file added docs/src/assets/readme_example.png
6 changes: 1 addition & 5 deletions docs/src/contributing.md
@@ -156,8 +156,4 @@ do not require the unit tests to be run.

### The merge process

We use [`bors`](https://bors.tech/) to manage merging PR's in the `EnsembleKalmanProcesses` repo.
If you're a collaborator and have the necessary permissions, you can type
`bors try` in a comment on a PR to have integration test suite run on that
PR, or `bors r+` to try and merge the code. Bors ensures that all integration tests
for a given PR always pass before merging into `main`. The integration tests currently run example cases in `examples/`. Any breaking changes will need to also update the `examples/`, else bors will fail.
If you're a collaborator and have the necessary permissions, and if you have both an approved code review and the (necessary) integration tests passing, then you may merge the pull request into `main`. Our preferred method is to click the `Squash and Merge` button, which is set as the default on the pull request.
15 changes: 11 additions & 4 deletions examples/Localization/localization_example_lorenz96.jl
@@ -78,8 +78,15 @@ initial_ensemble = EKP.construct_initial_ensemble(rng, prior, N_ens)


# Solve problem without localization
ekiobj_vanilla =
EKP.EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(); rng = rng, scheduler = DefaultScheduler())
ekiobj_vanilla = EKP.EnsembleKalmanProcess(
initial_ensemble,
y,
Γ,
Inversion();
rng = rng,
scheduler = DefaultScheduler(),
localization_method = NoLocalization(),
)
for i in 1:N_iter
g_ens_vanilla = G(get_ϕ_final(prior, ekiobj_vanilla))
EKP.update_ensemble!(ekiobj_vanilla, g_ens_vanilla, deterministic_forward_map = true)
@@ -94,7 +101,7 @@ ekiobj_inflated = EKP.EnsembleKalmanProcess(
Inversion();
rng = rng,
scheduler = DefaultScheduler(),
# localization_method = BernoulliDropout(0.98),
localization_method = NoLocalization(),
)

for i in 1:N_iter
@@ -199,7 +206,7 @@ fig = plot(
bottom_margin = 5Plots.mm,
left_margin = 5Plots.mm,
)
plot!(get_error(ekiobj_inflated), label = "Inflation only", lw = 6)
plot!(get_error(ekiobj_inflated), label = "Inflation only", lw = 6, ls = :dash)
plot!(get_error(ekiobj_sec), label = "SEC (Lee, 2021)", lw = 6)
plot!(get_error(ekiobj_sec_fisher), label = "SECFisher (Flowerdew, 2015)", lw = 6)
plot!(get_error(ekiobj_sec_cutoff), label = "SEC with cutoff", lw = 6)
22 changes: 15 additions & 7 deletions src/EnsembleKalmanProcess.jl
@@ -219,8 +219,7 @@ function EnsembleKalmanProcess(

if N_ens < 10
@warn "Recommended minimum ensemble size (`N_ens`) is 10. Got `N_ens` = $(N_ens)."
end
if (N_par < 10) && (N_ens < 10 * N_par)
elseif (N_par < 10) && (N_ens < 10 * N_par)
@warn "For $(N_par) parameters, the recommended minimum ensemble size (`N_ens`) is $(10*(N_par)). Got `N_ens` = $(N_ens)."
end
if (N_par >= 10) && (N_ens < 100)
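
The effect of switching the second check to `elseif` can be sketched in isolation. The helper below is hypothetical (`ensemble_size_warnings` is not part of the package) and collects messages instead of logging them. The message of the third branch is an assumption, since the body of that branch is not shown in this hunk.

```julia
# Hypothetical helper mirroring the branch structure of the size checks above
function ensemble_size_warnings(N_par::Int, N_ens::Int)
    msgs = String[]
    if N_ens < 10
        push!(msgs, "Recommended minimum ensemble size is 10. Got $(N_ens).")
    elseif (N_par < 10) && (N_ens < 10 * N_par)
        # only reached when N_ens >= 10, so it can no longer fire together with the first check
        push!(msgs, "For $(N_par) parameters, the recommended minimum is $(10 * N_par). Got $(N_ens).")
    end
    # Assumption: illustrative message only; the real branch body is not visible in the diff
    if (N_par >= 10) && (N_ens < 100)
        push!(msgs, "For $(N_par) parameters, an ensemble size below 100 was given. Got $(N_ens).")
    end
    return msgs
end

few_params_small = ensemble_size_warnings(2, 9)    # first branch only
few_params_mid = ensemble_size_warnings(2, 15)     # elseif branch only
many_params = ensemble_size_warnings(60, 99)       # third check only
```

Under this sketch, a problem with fewer than 10 parameters and 9 ensemble members yields one message rather than two, which seems to be what the updated `@test_logs (:warn,)` test in this commit relies on.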
@@ -261,18 +260,19 @@ function EnsembleKalmanProcess(
end

# failure handler
failure_handler = FailureHandler(process, configuration["failure_handler_method"])
fh_method = configuration["failure_handler_method"]
failure_handler = FailureHandler(process, fh_method)

# localizer
if isa(process, TransformInversion) && !(isa(configuration["localization_method"], NoLocalization))
loc_method = configuration["localization_method"]
if isa(process, TransformInversion) && !(isa(loc_method, NoLocalization))
throw(ArgumentError("`TransformInversion` cannot currently be used with localization."))
end

localizer = Localizer(configuration["localization_method"], N_ens, FT)

localizer = Localizer(loc_method, N_ens, FT)

if verbose
@info "Initializing ensemble Kalman process of type $(nameof(typeof(process)))\nNumber of ensemble members: $(N_ens)\nLocalization: $(nameof(typeof(localizer)))\nFailure handler: $(nameof(typeof(failure_handler)))\nScheduler: $(nameof(typeof(scheduler)))\nAccelerator: $(nameof(typeof(accelerator)))"
@info "Initializing ensemble Kalman process of type $(nameof(typeof(process)))\nNumber of ensemble members: $(N_ens)\nLocalization: $(nameof(typeof(loc_method)))\nFailure handler: $(nameof(typeof(fh_method)))\nScheduler: $(nameof(typeof(scheduler)))\nAccelerator: $(nameof(typeof(accelerator)))"
end

EnsembleKalmanProcess{FT, IT, P, RS, AC, VVV}(
@@ -927,6 +927,14 @@ function update_ensemble!(
),
)
end
# check if columns of g are the same (and not NaN)
n_nans = sum(isnan.(sum(g, dims = 1)))
nan_adjust = (n_nans > 0) ? -n_nans + 1 : 0
# as unique reduces NaNs to one column if present. or 0 if not
if length(unique(eachcol(g))) < size(g, 2) + nan_adjust
nonunique_cols = size(g, 2) + nan_adjust - length(unique(eachcol(g)))
@warn "Detected $(nonunique_cols) clashes where forward map evaluations are exactly equal (and not NaN), this is likely to cause `LinearAlgebra` difficulty. Please check forward evaluations for bugs."
end

terminate = calculate_timestep!(ekp, g, Δt_new)
if isnothing(terminate)
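
The duplicate-column check added in this hunk can be tried on small matrices with the standalone sketch below. `count_duplicate_columns` is a hypothetical name. The NaN adjustment follows the comment in the hunk and assumes that failed evaluations appear as columns that are identical to each other (for example all-NaN), so that `unique` collapses them into a single column.

```julia
# Hypothetical helper: number of exactly repeated, non-NaN columns of `g`
function count_duplicate_columns(g::AbstractMatrix)
    n_cols = size(g, 2)
    # a column sum is NaN whenever the column contains at least one NaN
    n_nan_cols = sum(isnan.(sum(g, dims = 1)))
    # assumption: all NaN columns are identical, so `unique` keeps exactly one of them
    nan_adjust = n_nan_cols > 0 ? 1 - n_nan_cols : 0
    n_unique = length(unique(eachcol(g)))
    return max(n_cols + nan_adjust - n_unique, 0)
end

g_dup = [1.0 2.0 2.0; 3.0 4.0 4.0]            # columns 2 and 3 are equal
g_nan = [1.0 NaN NaN 2.0; 1.0 NaN NaN 2.0]    # two failed (all-NaN) columns, no real clash
n_dup = count_duplicate_columns(g_dup)
n_nan = count_duplicate_columns(g_nan)
```

The design point is that `unique` compares with `isequal`, which treats `NaN` as equal to `NaN`, so without the adjustment repeated failed columns would be reported as clashes.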
64 changes: 11 additions & 53 deletions src/Localizers.jl
@@ -7,7 +7,6 @@ using Interpolations

export NoLocalization, Delta, RBF, BernoulliDropout, SEC, SECFisher, SECNice
export LocalizationMethod, Localizer
export approximate_corr_std
abstract type LocalizationMethod end

"Idempotent localization method."
@@ -104,19 +103,16 @@ Thus no algorithm parameters are required, though some tuning of the discrepancy
$(TYPEDFIELDS)
"""
struct SECNice{FT <: Real, AV <: AbstractVector} <: LocalizationMethod
struct SECNice{FT <: Real} <: LocalizationMethod
"number of samples to approximate the std of correlation distribution (default 1000)"
n_samples::Int
"scaling for discrepancy principle for ug correlation (default 1.0)"
δ_ug::FT
"scaling for discrepancy principle for gg correlation (default 1.0)"
δ_gg::FT
"A vector that will house a Interpolation object on first call to the localizer"
std_of_corr::AV
end
SECNice() = SECNice(1000, 1.0, 1.0)
SECNice() = SECNice(1000, 1.0, 1.0) # best
SECNice(δ_ug, δ_gg) = SECNice(1000, δ_ug, δ_gg)
SECNice(n_samples, δ_ug, δ_gg) = SECNice(n_samples, δ_ug, δ_gg, []) # always start with empty

"""
Localizer{LM <: LocalizationMethod, T}
@@ -274,43 +270,12 @@ function Localizer(localization::SECFisher, J::Int, T = Float64)
return Localizer{SECFisher, T}((cov, T, p, d, J) -> sec_fisher(cov, J))
end

"""
For `N_ens >= 6`: The sampling distribution of a correlation coefficient for Gaussian random variables is, under the Fisher transformation, approximately Gaussian. To estimate the standard deviation in the sampling distribution of the correlation coefficient, we draw samples from a Gaussian, apply the inverse Fisher transformation to them, and estimate an empirical standard deviation from the transformed samples.
For `N_ens < 6`: Approximate the standard deviation of correlation coefficient empirically by sampling between two correlated Gaussians of known coefficient.
"""
function approximate_corr_std(r, N_ens, n_samples)

if N_ens >= 6 # apply Fisher Transform
# ρ = arctanh(r) from Fisher
# assume r input is the mean value, i.e. assume arctanh(E(r)) = E(arctanh(r))

ρ = r # approx solution is the identity
#sample in ρ space
ρ_samples = rand(Normal(0.5 * log((1 + ρ) / (1 - ρ)), 1 / sqrt(N_ens - 3)), n_samples)

# map back through Fisher to get std of r from samples tanh(ρ)
return std(tanh.(ρ_samples))
else # transformation not appropriate for N < 6
# Generate sample pairs with a correlation coefficient r
samples_1 = rand(Normal(0, 1), N_ens, n_samples)
samples_2 = rand(Normal(0, 1), N_ens, n_samples)
samples_corr_with_1 = r * samples_1 + sqrt(1 - r^2) * samples_2 # will have correlation r with samples_1

corrs = zeros(n_samples)
for i in 1:n_samples
corrs[i] = cor(samples_1[:, i], samples_corr_with_1[:, i])
end
return std(corrs)
end

end


"""
Function that performs sampling error correction as per Vishny, Morzfeld, et al. (2024).
The input is assumed to be a covariance matrix, hence square.
The input is assumed to be a covariance matrix, hence square. The standard deviation for a correlation `corr` with `N_ens` samples is internally estimated simply by `std_corrs = (1 .- corr)/sqrt(N_ens)`. This requires no precomputation and appears sufficiently accurate.
"""
function sec_nice(cov, std_of_corr, δ_ug, δ_gg, N_ens, p, d)
function sec_nice(cov, δ_ug, δ_gg, N_ens, p, d)
bd_tol = 1e8 * eps()

v = sqrt.(diag(cov))
@@ -330,9 +295,11 @@ function sec_nice(cov, std_of_corr, δ_ug, δ_gg, N_ens, p, d)
for (idx_set, δ) in zip([ug_idx, gg_idx], [δ_ug, δ_gg])

corr_tmp = corr[idx_set...]
# use find the variability in the corr coeff matrix entries
# std_corrs = approximate_corr_std.(corr_tmp, N_ens, n_samples) # !! slowest part of code -> could speed up by precomputing/using an interpolation
std_corrs = std_of_corr.(corr_tmp)

# Find the variability in the corr coeff matrix entries
# Below has no precomputation and is surprisingly fine accuracy! (~10^-4 error to empirical at N_ens=20)
std_corrs = (1 .- corr_tmp) / sqrt(N_ens)

std_tol = sqrt(sum(std_corrs .^ 2))
γ_min_exceeded = max_exponent
for γ in 2:2:max_exponent # even exponents give a PSD correction
@@ -366,16 +333,7 @@ end

"Sampling error correction of Vishny, Morzfeld, et al. (2024) constructor"
function Localizer(localization::SECNice, J::Int, T = Float64)
if length(localization.std_of_corr) == 0 #i.e. if the user hasn't provided an interpolation
dr = 0.001
grid = LinRange(-1, 1, Int(1 / dr + 1))
std_grid = approximate_corr_std.(grid, J, localization.n_samples) # odd number to include 0
push!(localization.std_of_corr, linear_interpolation(grid, std_grid)) # pw-linear interpolation
end

return Localizer{SECNice, T}(
(cov, T, p, d, J) -> sec_nice(cov, localization.std_of_corr[1], localization.δ_ug, localization.δ_gg, J, p, d),
)
return Localizer{SECNice, T}((cov, T, p, d, J) -> sec_nice(cov, localization.δ_ug, localization.δ_gg, J, p, d))
end
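
For readers who want to compare the two estimators that appear in this file's diff, here is a minimal side-by-side sketch. The function names are hypothetical. The first is the closed form quoted in the new `sec_nice` docstring. The second re-implements the Fisher-transform branch of the removed `approximate_corr_std` using only the standard library (the original used `Distributions.Normal`). No claim is made here about how closely they agree across all correlations.

```julia
using Random, Statistics

# Closed form quoted in the `sec_nice` docstring
std_corr_closed_form(r, N_ens) = (1 - r) / sqrt(N_ens)

# Sampling estimate via the Fisher transformation, as in the removed code (N_ens >= 6 branch)
function std_corr_fisher(rng, r, N_ens, n_samples)
    # atanh(r) equals 0.5 * log((1 + r) / (1 - r)), the mean used in the removed code
    z_samples = atanh(r) .+ randn(rng, n_samples) ./ sqrt(N_ens - 3)
    return std(tanh.(z_samples))    # map back to correlation space
end

rng = MersenneTwister(1234)
closed = std_corr_closed_form(0.0, 25)
sampled = std_corr_fisher(rng, 0.0, 19, 10_000)
```

Evaluating both over a grid of `r` and `N_ens` values is a quick way to see where the cheaper closed form tracks the sampling estimate and where it departs from it.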


6 changes: 3 additions & 3 deletions src/Observations.jl
@@ -787,10 +787,10 @@ function get_obs(os::OS; build = true) where {OS <: ObservationSeries}
if !build # return y as vec of vecs
return get_obs.(observations_vec, build = false)
else # stack y
sample_length = length(get_obs(observations_vec[1], build = true))
minibatch_samples = zeros(sample_length * minibatch_length)
sample_lengths = [length(get_obs(ov, build = true)) for ov in observations_vec]
minibatch_samples = zeros(sum(sample_lengths))
for (i, observation) in enumerate(observations_vec)
idx = ((i - 1) * sample_length + 1):(i * sample_length)
idx = (sum(sample_lengths[1:(i - 1)]) + 1):sum(sample_lengths[1:i])
minibatch_samples[idx] = get_obs(observation, build = true)
end
return minibatch_samples
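
The change in this hunk drops the assumption that every observation in a minibatch has the same length. A standalone sketch of the same stacking idea is below. `stack_samples` is a hypothetical name, and it uses `cumsum` for the block offsets rather than the repeated partial sums in the hunk, which is a small variation and not the committed code.

```julia
# Hypothetical helper: concatenate vectors of possibly different lengths into one vector
function stack_samples(samples::AbstractVector{<:AbstractVector{<:Real}})
    lengths = length.(samples)
    ends = cumsum(lengths)                  # last index of each block in the stacked vector
    stacked = zeros(sum(lengths))
    for (i, sample) in enumerate(samples)
        start = i == 1 ? 1 : ends[i - 1] + 1
        stacked[start:ends[i]] = sample
    end
    return stacked
end

stacked = stack_samples([[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]])
```

With equal-length inputs this reduces to the previous fixed-stride indexing, so the old behaviour is a special case.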
17 changes: 16 additions & 1 deletion test/EnsembleKalmanProcess/runtests.jl
@@ -495,7 +495,7 @@ end
#check for small ens
y_obs_tmp, G_tmp, Γy_tmp, A_tmp = inv_problems[1]
initial_ensemble_small = EKP.construct_initial_ensemble(rng, prior, 9)
@test_logs (:warn,) (:warn,) EKP.EnsembleKalmanProcess(initial_ensemble_small, y_obs_tmp, Γy_tmp, Inversion()) # throws two warnings
@test_logs (:warn,) EKP.EnsembleKalmanProcess(initial_ensemble_small, y_obs_tmp, Γy_tmp, Inversion())
prior_60dims = constrained_gaussian("60dims", 0, 1, -Inf, Inf, repeats = 60)
initial_ensemble_small = EKP.construct_initial_ensemble(rng, prior_60dims, 99)
@test_logs (:warn,) EKP.EnsembleKalmanProcess(initial_ensemble_small, y_obs_tmp, Γy_tmp, Inversion())
@@ -527,6 +527,16 @@ end
scheduler = deepcopy(scheduler),
localization_method = deepcopy(localization_method),
)
ekiobj2 = EKP.EnsembleKalmanProcess(
initial_ensemble,
y_obs,
Γy,
Inversion();
rng = copy(rng),
failure_handler_method = SampleSuccGauss(),
scheduler = deepcopy(scheduler),
localization_method = deepcopy(localization_method),
)
ekiobj_unsafe = EKP.EnsembleKalmanProcess(
initial_ensemble,
y_obs,
@@ -580,6 +590,11 @@ end
@test_throws DimensionMismatch EKP.update_ensemble!(ekiobj, g_ens_t)
end

# test for additional warning if two columns of g_ens are equal
g_ens_nonunique = copy(g_ens)
g_ens_nonunique[:, 2] = g_ens_nonunique[:, 3]
@test_logs (:warn,) update_ensemble!(ekiobj2, g_ens_nonunique)

# test the deterministic flag on only one iteration for errors
EKP.update_ensemble!(ekiobj_nonoise_update, g_ens, deterministic_forward_map = false)
@info "No error with flag deterministic_forward_map = false"