Skip to content

Commit

Permalink
Simplify to make extensions implement add_predictor (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Jun 6, 2024
1 parent e736aee commit a074f63
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 24 deletions.
31 changes: 26 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,50 @@ This project is inspired by two existing projects:

## Supported models

Use `add_predictor`:
Use `Omelette.add_predictor(model, predictor, x)` to add the relationship
`y = predictor(x)` to `model`:

```julia
y = Omelette.add_predictor(model, predictor, x)
```

### LinearRegression
The following predictors are supported. See their docstrings for details:

* `Omelette.LinearRegression`
* `Omelette.LogisticRegression`
* `Omelette.Pipeline`
* `Omelette.ReLUBigM`
* `Omelette.ReLUSOS1`
* `Omelette.ReLUQuadratic`

## Extensions

The following third-party package extensions are supported.

### [GLM.jl](https://github.com/JuliaStats/GLM.jl)

#### LinearRegression

```julia
using Omelette, GLM
X, Y = rand(10, 2), rand(10)
model_glm = GLM.lm(X, Y)
predictor = Omelette.LinearRegression(model_glm)
y = Omelette.add_predictor(model, model_glm, x)
```

### LogisticRegression
#### LogisticRegression

```julia
using Omelette, GLM
X, Y = rand(10, 2), rand(Bool, 10)
model_glm = GLM.glm(X, Y, GLM.Bernoulli())
predictor = Omelette.LogisticRegression(model_glm)
y = Omelette.add_predictor(model, model_glm, x)
```

### [Lux.jl](https://github.com/LuxDL/Lux.jl)

See `test/test_Lux.jl` for an example.

## Other constraints

### UnivariateNormalDistribution
Expand Down
19 changes: 14 additions & 5 deletions ext/OmeletteGLMExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,28 @@

module OmeletteGLMExt

import JuMP
import Omelette
import GLM

function Omelette.LinearRegression(model::GLM.LinearModel)
return Omelette.LinearRegression(GLM.coef(model))
function Omelette.add_predictor(
model::JuMP.Model,
predictor::GLM.LinearModel,
x::Vector{JuMP.VariableRef},
)
inner_predictor = Omelette.LinearRegression(GLM.coef(predictor))
return Omelette.add_predictor(model, inner_predictor, x)
end

function Omelette.LogisticRegression(
model::GLM.GeneralizedLinearModel{
function Omelette.add_predictor(
model::JuMP.Model,
predictor::GLM.GeneralizedLinearModel{
GLM.GlmResp{Vector{Float64},GLM.Bernoulli{Float64},GLM.LogitLink},
},
x::Vector{JuMP.VariableRef},
)
return Omelette.LogisticRegression(GLM.coef(model))
inner_predictor = Omelette.LogisticRegression(GLM.coef(predictor))
return Omelette.add_predictor(model, inner_predictor, x)
end

end #module
27 changes: 19 additions & 8 deletions ext/OmeletteLuxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,38 @@

module OmeletteLuxExt

import Omelette
import JuMP
import Lux
import Omelette

function _add_predictor(predictor::Omelette.Pipeline, layer::Lux.Dense, p)
function _add_predictor(
predictor::Omelette.Pipeline,
layer::Lux.Dense,
p;
relu::Omelette.AbstractPredictor,
)
push!(predictor.layers, Omelette.LinearRegression(p.weight, vec(p.bias)))
if layer.activation === identity
# Do nothing
elseif layer.activation === Lux.NNlib.relu
push!(predictor.layers, Omelette.ReLUBigM(1e4))
push!(predictor.layers, relu)
else
error("Unsupported activation function: $x")
end
return
end

function Omelette.Pipeline(x::Lux.Experimental.TrainState)
predictor = Omelette.Pipeline(Omelette.AbstractPredictor[])
for (layer, parameter) in zip(x.model.layers, x.parameters)
_add_predictor(predictor, layer, parameter)
function Omelette.add_predictor(
model::JuMP.Model,
predictor::Lux.Experimental.TrainState,
x::Vector{JuMP.VariableRef};
relu::Omelette.AbstractPredictor = Omelette.ReLUBigM(1e4),
)
inner_predictor = Omelette.Pipeline(Omelette.AbstractPredictor[])
for (layer, parameter) in zip(predictor.model.layers, predictor.parameters)
_add_predictor(inner_predictor, layer, parameter; relu)
end
return predictor
return Omelette.add_predictor(model, inner_predictor, x)
end

end #module
3 changes: 1 addition & 2 deletions test/test_LinearRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@ function test_LinearRegression_GLM()
model_glm = GLM.lm(X, Y)
model = Model(HiGHS.Optimizer)
set_silent(model)
model_ml = Omelette.LinearRegression(model_glm)
@variable(model, 0 <= x[1:num_features] <= 1)
@constraint(model, sum(x) == 1.5)
y = Omelette.add_predictor(model, model_ml, x)
y = Omelette.add_predictor(model, model_glm, x)
@objective(model, Max, only(y))
optimize!(model)
@assert is_solved_and_feasible(model)
Expand Down
3 changes: 1 addition & 2 deletions test/test_LogisticRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ function test_LogisticRegression_GLM()
model_glm = GLM.glm(X, Y, GLM.Bernoulli())
model = Model(Ipopt.Optimizer)
set_silent(model)
model_ml = Omelette.LogisticRegression(model_glm)
@variable(model, 0 <= x[1:num_features] <= 1)
@constraint(model, sum(x) == 1.5)
y = Omelette.add_predictor(model, model_ml, x)
y = Omelette.add_predictor(model, model_glm, x)
@objective(model, Max, only(y))
optimize!(model)
@assert is_solved_and_feasible(model)
Expand Down
3 changes: 1 addition & 2 deletions test/test_Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,10 @@ function test_end_to_end()
optimizer = Optimisers.Adam(0.03f0),
epochs = 250,
)
f = Omelette.Pipeline(state)
model = Model(HiGHS.Optimizer)
set_silent(model)
@variable(model, x)
y = Omelette.add_predictor(model, f, [x])
y = Omelette.add_predictor(model, state, [x])
@constraint(model, only(y) <= 4)
@objective(model, Min, x)
optimize!(model)
Expand Down

0 comments on commit a074f63

Please sign in to comment.