diff --git a/README.md b/README.md index 82da858..d905e9d 100644 --- a/README.md +++ b/README.md @@ -24,29 +24,50 @@ This project is inspired by two existing projects: ## Supported models -Use `add_predictor`: +Use `Omelette.add_predictor(model, predictor, x)` to add the relationship +`y = predictor(x)` to `model`: + ```julia y = Omelette.add_predictor(model, predictor, x) ``` -### LinearRegression +The following predictors are supported. See their docstrings for details: + + * `Omelette.LinearRegression` + * `Omelette.LogisticRegression` + * `Omelette.Pipeline` + * `Omelette.ReLUBigM` + * `Omelette.ReLUSOS1` + * `Omelette.ReLUQuadratic` + +## Extensions + +The following third-party package extensions are supported. + +### [GLM.jl](https://github.com/JuliaStats/GLM.jl) + +#### LinearRegression ```julia using Omelette, GLM X, Y = rand(10, 2), rand(10) model_glm = GLM.lm(X, Y) -predictor = Omelette.LinearRegression(model_glm) +y = Omelette.add_predictor(model, model_glm, x) ``` -### LogisticRegression +#### LogisticRegression ```julia using Omelette, GLM X, Y = rand(10, 2), rand(Bool, 10) model_glm = GLM.glm(X, Y, GLM.Bernoulli()) -predictor = Omelette.LogisticRegression(model_glm) +y = Omelette.add_predictor(model, model_glm, x) ``` +### [Lux.jl](https://github.com/LuxDL/Lux.jl) + +See `test/test_Lux.jl` for an example. + ## Other constraints ### UnivariateNormalDistribution diff --git a/ext/OmeletteGLMExt.jl b/ext/OmeletteGLMExt.jl index 880d76a..12dc8ac 100644 --- a/ext/OmeletteGLMExt.jl +++ b/ext/OmeletteGLMExt.jl @@ -5,19 +5,28 @@ module OmeletteGLMExt +import JuMP import Omelette import GLM -function Omelette.LinearRegression(model::GLM.LinearModel) - return Omelette.LinearRegression(GLM.coef(model)) +function Omelette.add_predictor( + model::JuMP.Model, + predictor::GLM.LinearModel, + x::Vector{JuMP.VariableRef}, +) + inner_predictor = Omelette.LinearRegression(GLM.coef(predictor)) + return Omelette.add_predictor(model, inner_predictor, x) end -function Omelette.LogisticRegression( - model::GLM.GeneralizedLinearModel{ +function Omelette.add_predictor( + model::JuMP.Model, + predictor::GLM.GeneralizedLinearModel{ GLM.GlmResp{Vector{Float64},GLM.Bernoulli{Float64},GLM.LogitLink}, }, + x::Vector{JuMP.VariableRef}, ) - return Omelette.LogisticRegression(GLM.coef(model)) + inner_predictor = Omelette.LogisticRegression(GLM.coef(predictor)) + return Omelette.add_predictor(model, inner_predictor, x) end end #module diff --git a/ext/OmeletteLuxExt.jl b/ext/OmeletteLuxExt.jl index 6f544f7..2ba3cb7 100644 --- a/ext/OmeletteLuxExt.jl +++ b/ext/OmeletteLuxExt.jl @@ -5,27 +5,38 @@ module OmeletteLuxExt -import Omelette +import JuMP import Lux +import Omelette -function _add_predictor(predictor::Omelette.Pipeline, layer::Lux.Dense, p) +function _add_predictor( + predictor::Omelette.Pipeline, + layer::Lux.Dense, + p; + relu::Omelette.AbstractPredictor, +) push!(predictor.layers, Omelette.LinearRegression(p.weight, vec(p.bias))) if layer.activation === identity # Do nothing elseif layer.activation === Lux.NNlib.relu - push!(predictor.layers, Omelette.ReLUBigM(1e4)) + push!(predictor.layers, relu) else error("Unsupported activation function: $x") end return end -function Omelette.Pipeline(x::Lux.Experimental.TrainState) - predictor = Omelette.Pipeline(Omelette.AbstractPredictor[]) - for (layer, parameter) in zip(x.model.layers, x.parameters) - _add_predictor(predictor, layer, parameter) +function Omelette.add_predictor( + model::JuMP.Model, + predictor::Lux.Experimental.TrainState, + x::Vector{JuMP.VariableRef}; + relu::Omelette.AbstractPredictor = Omelette.ReLUBigM(1e4), +) + inner_predictor = Omelette.Pipeline(Omelette.AbstractPredictor[]) + for (layer, parameter) in zip(predictor.model.layers, predictor.parameters) + _add_predictor(inner_predictor, layer, parameter; relu) end - return predictor + return Omelette.add_predictor(model, inner_predictor, x) end end #module diff --git a/test/test_LinearRegression.jl b/test/test_LinearRegression.jl index bce234b..09f019c 100644 --- a/test/test_LinearRegression.jl +++ b/test/test_LinearRegression.jl @@ -42,10 +42,9 @@ function test_LinearRegression_GLM() model_glm = GLM.lm(X, Y) model = Model(HiGHS.Optimizer) set_silent(model) - model_ml = Omelette.LinearRegression(model_glm) @variable(model, 0 <= x[1:num_features] <= 1) @constraint(model, sum(x) == 1.5) - y = Omelette.add_predictor(model, model_ml, x) + y = Omelette.add_predictor(model, model_glm, x) @objective(model, Max, only(y)) optimize!(model) @assert is_solved_and_feasible(model) diff --git a/test/test_LogisticRegression.jl b/test/test_LogisticRegression.jl index 38d7315..202db8f 100644 --- a/test/test_LogisticRegression.jl +++ b/test/test_LogisticRegression.jl @@ -43,10 +43,9 @@ function test_LogisticRegression_GLM() model_glm = GLM.glm(X, Y, GLM.Bernoulli()) model = Model(Ipopt.Optimizer) set_silent(model) - model_ml = Omelette.LogisticRegression(model_glm) @variable(model, 0 <= x[1:num_features] <= 1) @constraint(model, sum(x) == 1.5) - y = Omelette.add_predictor(model, model_ml, x) + y = Omelette.add_predictor(model, model_glm, x) @objective(model, Max, only(y)) optimize!(model) @assert is_solved_and_feasible(model) diff --git a/test/test_Lux.jl b/test/test_Lux.jl index cdcc138..5ae0ac4 100644 --- a/test/test_Lux.jl +++ b/test/test_Lux.jl @@ -71,11 +71,10 @@ function test_end_to_end() optimizer = Optimisers.Adam(0.03f0), epochs = 250, ) - f = Omelette.Pipeline(state) model = Model(HiGHS.Optimizer) set_silent(model) @variable(model, x) - y = Omelette.add_predictor(model, f, [x]) + y = Omelette.add_predictor(model, state, [x]) @constraint(model, only(y) <= 4) @objective(model, Min, x) optimize!(model)