Add support for LogisticRegression (#6)
odow authored May 21, 2024
1 parent be7026d commit 94db7dc
Showing 5 changed files with 126 additions and 13 deletions.
30 changes: 17 additions & 13 deletions README.md
@@ -24,25 +24,29 @@ This project is inspired by two existing projects:
 
 ## Supported models
 
-Use `add_predictor` to add a model.
+Use `add_predictor`:
 ```julia
-Omelette.add_predictor(model, model_ml, x, y)
-y = Omelette.add_predictor(model, model_ml, x)
+y = Omelette.add_predictor(model, predictor, x)
+```
+or:
+```julia
+Omelette.add_predictor!(model, predictor, x, y)
 ```
 
 ### LinearRegression
 
 ```julia
-num_features, num_observations = 2, 10
-X = rand(num_observations, num_features)
-θ = rand(num_features)
-Y = X * θ + randn(num_observations)
+using Omelette, GLM
+X, Y = rand(10, 2), rand(10)
 model_glm = GLM.lm(X, Y)
 predictor = Omelette.LinearRegression(model_glm)
-model = Model(HiGHS.Optimizer)
-set_silent(model)
-@variable(model, 0 <= x[1:num_features] <= 1)
-@constraint(model, sum(x) == 1.5)
-y = Omelette.add_predictor(model, predictor, x)
-@objective(model, Max, y[1])
 ```
+
+### LogisticRegression
+
+```julia
+using Omelette, GLM
+X, Y = rand(10, 2), rand(Bool, 10)
+model_glm = GLM.glm(X, Y, GLM.Bernoulli())
+predictor = Omelette.LogisticRegression(model_glm)
+```
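Taken together, the new README snippets compose with JuMP as follows. This is a minimal sketch mirroring `test_LogisticRegression_GLM` below; the `Ipopt` solver choice is an assumption, made because the sigmoid equality constraint is nonlinear.

```julia
# Sketch: end-to-end use of the new LogisticRegression predictor.
# Mirrors test_LogisticRegression_GLM below; Ipopt is assumed because
# the sigmoid equality constraint is nonlinear.
using JuMP
import GLM, Ipopt, Omelette

X, Y = rand(10, 2), rand(Bool, 10)
model_glm = GLM.glm(X, Y, GLM.Bernoulli())
predictor = Omelette.LogisticRegression(model_glm)

model = Model(Ipopt.Optimizer)
set_silent(model)
@variable(model, 0 <= x[1:2] <= 1)
y = Omelette.add_predictor(model, predictor, x)  # one output variable
@objective(model, Max, only(y))
optimize!(model)
value.(x), value(only(y))
```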
8 changes: 8 additions & 0 deletions ext/OmeletteGLMExt.jl
@@ -12,4 +12,12 @@ function Omelette.LinearRegression(model::GLM.LinearModel)
     return Omelette.LinearRegression(GLM.coef(model))
 end
 
+function Omelette.LogisticRegression(
+    model::GLM.GeneralizedLinearModel{
+        GLM.GlmResp{Vector{Float64},GLM.Bernoulli{Float64},GLM.LogitLink},
+    },
+)
+    return Omelette.LogisticRegression(GLM.coef(model))
+end
+
 end #module
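The extension's conversion is just coefficient extraction. A small sketch of the same steps done by hand, with illustrative data:

```julia
# Sketch of what the extension does: GLM.coef extracts the fitted
# coefficient vector, which the LogisticRegression constructor
# reshapes into a 1×n parameter matrix (illustrative data).
import GLM, Omelette
X, Y = rand(10, 2), rand(Bool, 10)
model_glm = GLM.glm(X, Y, GLM.Bernoulli())  # Bernoulli uses LogitLink by default
θ = GLM.coef(model_glm)                     # Vector{Float64} of length 2
predictor = Omelette.LogisticRegression(θ)  # same result as passing model_glm
size(predictor)                             # (1, 2)
```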
24 changes: 24 additions & 0 deletions src/models/LogisticRegression.jl
@@ -0,0 +1,24 @@
# Copyright (c) 2024: Oscar Dowson and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

struct LogisticRegression <: AbstractPredictor
    parameters::Matrix{Float64}
end

function LogisticRegression(parameters::Vector{Float64})
    # Store a coefficient vector as a 1×n matrix; each row of
    # `parameters` defines one output of the predictor.
    return LogisticRegression(reshape(parameters, 1, length(parameters)))
end

Base.size(f::LogisticRegression) = size(f.parameters)

function _add_predictor_inner(
    model::JuMP.Model,
    predictor::LogisticRegression,
    x::Vector{JuMP.VariableRef},
    y::Vector{JuMP.VariableRef},
)
    # Constrain each output y[i] to equal the logistic (sigmoid)
    # function of the i-th linear predictor.
    JuMP.@constraint(model, 1 ./ (1 .+ exp.(-predictor.parameters * x)) .== y)
    return
end
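The constraint encodes the row-wise logistic link `y = 1 ./ (1 .+ exp.(-θ * x))`. A quick numeric sanity check, using the coefficients from `test_LogisticRegression` below and an assumed input point:

```julia
# Numeric check of the sigmoid formulation (x values are illustrative).
θ = [2.0 3.0]                 # 1×2 parameter matrix, as in the tests
x = [0.5, 0.5]                # θ * x == [2.5]
y = 1 ./ (1 .+ exp.(-θ * x))  # ≈ [0.924]
```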
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -1,13 +1,15 @@
 [deps]
 GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
 HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
+Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
 JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
 Omelette = "e52c2cb8-508e-4e12-9dd2-9c4755b60e73"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [compat]
 GLM = "1"
 HiGHS = "1"
+Ipopt = "1"
 JuMP = "1"
 Test = "<0.0.1, 1.6"
 julia = "1.9"
75 changes: 75 additions & 0 deletions test/test_LogisticRegression.jl
@@ -0,0 +1,75 @@
# Copyright (c) 2024: Oscar Dowson and contributors
#
# Use of this source code is governed by an MIT-style license that can be found
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.

module LogisticRegressionTests

using JuMP
using Test

import GLM
import Ipopt
import Omelette

is_test(x) = startswith(string(x), "test_")

function runtests()
    @testset "$name" for name in filter(is_test, names(@__MODULE__; all = true))
        getfield(@__MODULE__, name)()
    end
    return
end

function test_LogisticRegression()
    model = Model()
    @variable(model, x[1:2])
    @variable(model, y[1:1])
    f = Omelette.LogisticRegression([2.0, 3.0])
    Omelette.add_predictor!(model, f, x, y)
    cons = all_constraints(model; include_variable_in_set_constraints = false)
    obj = constraint_object(only(cons))
    @test obj.set == MOI.EqualTo(0.0)
    g = 1.0 / (1.0 + exp(-2.0 * x[1] - 3.0 * x[2])) - y[1]
    @test isequal_canonical(obj.func, g)
    return
end

function test_LogisticRegression_dimension_mismatch()
    model = Model()
    @variable(model, x[1:3])
    @variable(model, y[1:2])
    f = Omelette.LogisticRegression([2.0, 3.0])
    @test size(f) == (1, 2)
    @test_throws DimensionMismatch Omelette.add_predictor!(model, f, x, y[1:1])
    @test_throws DimensionMismatch Omelette.add_predictor!(model, f, x[1:2], y)
    g = Omelette.LogisticRegression([2.0 3.0; 4.0 5.0; 6.0 7.0])
    @test size(g) == (3, 2)
    @test_throws DimensionMismatch Omelette.add_predictor!(model, g, x, y)
    return
end

function test_LogisticRegression_GLM()
    num_features = 2
    num_observations = 10
    X = rand(num_observations, num_features)
    θ = rand(num_features)
    Y = X * θ + randn(num_observations) .>= 0
    model_glm = GLM.glm(X, Y, GLM.Bernoulli())
    model = Model(Ipopt.Optimizer)
    set_silent(model)
    model_ml = Omelette.LogisticRegression(model_glm)
    @variable(model, 0 <= x[1:num_features] <= 1)
    @constraint(model, sum(x) == 1.5)
    y = Omelette.add_predictor(model, model_ml, x)
    @objective(model, Max, only(y))
    optimize!(model)
    @assert is_solved_and_feasible(model)
    y_star_glm = GLM.predict(model_glm, value.(x)')
    @test isapprox(objective_value(model), only(y_star_glm); atol = 1e-6)
    return
end

end

LogisticRegressionTests.runtests()
