From e5cb8aa32ae43c9936f70b8ed65e196cfabdabb1 Mon Sep 17 00:00:00 2001
From: odow <o.dowson@gmail.com>
Date: Tue, 3 Sep 2024 15:13:44 +1200
Subject: [PATCH] WIP: add hooks for JuMP extensions

---
 src/predictors/Affine.jl             |  6 ++--
 src/predictors/BinaryDecisionTree.jl | 11 +++---
 src/predictors/GrayBox.jl            |  8 ++++-
 src/predictors/Quantile.jl           |  2 +-
 src/predictors/ReLU.jl               | 51 ++++++++++++++------------
 src/predictors/Scale.jl              |  7 ++--
 src/predictors/Sigmoid.jl            | 12 +++++--
 src/predictors/SoftMax.jl            | 25 ++++++++-----
 src/predictors/SoftPlus.jl           | 16 +++++++--
 src/predictors/Tanh.jl               |  6 ++--
 src/utilities.jl                     | 54 +++++++++++++++++++++++-----
 11 files changed, 138 insertions(+), 60 deletions(-)

diff --git a/src/predictors/Affine.jl b/src/predictors/Affine.jl
index 59880a7..f8e8d9c 100644
--- a/src/predictors/Affine.jl
+++ b/src/predictors/Affine.jl
@@ -61,8 +61,8 @@ end
 
 function add_predictor(model::JuMP.AbstractModel, predictor::Affine, x::Vector)
     m = size(predictor.A, 1)
-    y = JuMP.@variable(model, [1:m], base_name = "moai_Affine")
-    bounds = _get_variable_bounds.(x)
+    y = add_variables(model, predictor, x, m; base_name = "moai_Affine")
+    bounds = get_bounds.(x)
     for i in 1:size(predictor.A, 1)
         y_lb, y_ub = predictor.b[i], predictor.b[i]
         for j in 1:size(predictor.A, 2)
@@ -71,7 +71,7 @@ function add_predictor(model::JuMP.AbstractModel, predictor::Affine, x::Vector)
             y_ub += a_ij * ifelse(a_ij >= 0, ub, lb)
             y_lb += a_ij * ifelse(a_ij >= 0, lb, ub)
         end
-        _set_bounds_if_finite(y[i], y_lb, y_ub)
+        set_bounds(y[i], y_lb, y_ub)
     end
     JuMP.@constraint(model, predictor.A * x .+ predictor.b .== y)
     return y
diff --git a/src/predictors/BinaryDecisionTree.jl b/src/predictors/BinaryDecisionTree.jl
index 5b93c94..4beff62 100644
--- a/src/predictors/BinaryDecisionTree.jl
+++ b/src/predictors/BinaryDecisionTree.jl
@@ -75,14 +75,17 @@ function add_predictor(
     atol::Float64 = 0.0,
 )
     paths = _tree_to_paths(predictor)
-    z = JuMP.@variable(
+    vars = add_variables(
         model,
-        [1:length(paths)],
-        binary = true,
+        predictor,
+        x,
+        1 + length(paths);
         base_name = "moai_BinaryDecisionTree_z",
     )
+    y, z = vars[1], vars[2:end]
+    JuMP.set_name(y, "moai_BinaryDecisionTree_value")
+    JuMP.set_binary.(z)
     JuMP.@constraint(model, sum(z) == 1)
-    y = JuMP.@variable(model, base_name = "moai_BinaryDecisionTree_value")
     y_expr = JuMP.AffExpr(0.0)
     for (zi, (leaf, path)) in zip(z, paths)
         JuMP.add_to_expression!(y_expr, leaf, zi)
diff --git a/src/predictors/GrayBox.jl b/src/predictors/GrayBox.jl
index f22b0b3..3d1ff3c 100644
--- a/src/predictors/GrayBox.jl
+++ b/src/predictors/GrayBox.jl
@@ -73,7 +73,13 @@ end
 
 function add_predictor(model::JuMP.AbstractModel, predictor::GrayBox, x::Vector)
     op = add_predictor(model, ReducedSpace(predictor), x)
-    y = JuMP.@variable(model, [1:length(op)], base_name = "moai_GrayBox")
+    y = add_variables(
+        model,
+        predictor,
+        x,
+        length(op);
+        base_name = "moai_GrayBox",
+    )
     JuMP.@constraint(model, op .== y)
     return y
 end
diff --git a/src/predictors/Quantile.jl b/src/predictors/Quantile.jl
index 944ccb6..edc1aa1 100644
--- a/src/predictors/Quantile.jl
+++ b/src/predictors/Quantile.jl
@@ -44,7 +44,7 @@ function add_predictor(
     x::Vector,
 )
     M, N = length(x), length(predictor.quantiles)
-    y = JuMP.@variable(model, [1:N], base_name = "moai_quantile")
+    y = add_variables(model, predictor, x, N; base_name = "moai_quantile")
     quantile(q, x...) = Distributions.quantile(predictor.distribution(x...), q)
     for (qi, yi) in zip(predictor.quantiles, y)
         op_i = JuMP.add_nonlinear_operator(
diff --git a/src/predictors/ReLU.jl b/src/predictors/ReLU.jl
index 5441660..fcb44d1 100644
--- a/src/predictors/ReLU.jl
+++ b/src/predictors/ReLU.jl
@@ -43,10 +43,10 @@ julia> y = MathOptAI.add_predictor(model, MathOptAI.ReducedSpace(f), x)
 """
 struct ReLU <: AbstractPredictor end
 
-function add_predictor(model::JuMP.AbstractModel, ::ReLU, x::Vector)
-    ub = last.(_get_variable_bounds.(x))
-    y = JuMP.@variable(model, [1:length(x)], base_name = "moai_ReLU")
-    _set_bounds_if_finite.(y, 0, ub)
+function add_predictor(model::JuMP.AbstractModel, predictor::ReLU, x::Vector)
+    ub = last.(get_bounds.(x))
+    y = add_variables(model, predictor, x, length(x); base_name = "moai_ReLU")
+    set_bounds.(y, 0, ub)
     JuMP.@constraint(model, y .== max.(0, x))
     return y
 end
@@ -109,17 +109,25 @@ function add_predictor(
     x::Vector,
 )
     m = length(x)
-    bounds = _get_variable_bounds.(x)
-    y = JuMP.@variable(model, [1:m], base_name = "moai_ReLU")
-    _set_bounds_if_finite.(y, 0, last.(bounds))
+    bounds = get_bounds.(x)
+    vars = add_variables(
+        model,
+        predictor,
+        x,
+        2 * length(x);
+        base_name = "moai_ReLU",
+    )
+    y, z = vars[1:m], vars[m+1:end]
+    set_bounds.(y, 0, last.(bounds))
+    JuMP.set_binary.(z)
+    JuMP.set_name.(z, "")
     for i in 1:m
         lb, ub = bounds[i]
-        z = JuMP.@variable(model, binary = true)
         JuMP.@constraint(model, y[i] >= x[i])
         U = min(ub, predictor.M)
-        JuMP.@constraint(model, y[i] <= U * z)
+        JuMP.@constraint(model, y[i] <= U * z[i])
         L = min(max(0, -lb), predictor.M)
-        JuMP.@constraint(model, y[i] <= x[i] + L * (1 - z))
+        JuMP.@constraint(model, y[i] <= x[i] + L * (1 - z[i]))
     end
     return y
 end
@@ -178,14 +186,13 @@ function add_predictor(
     predictor::ReLUSOS1,
     x::Vector,
 )
-    m = length(x)
-    bounds = _get_variable_bounds.(x)
-    y = JuMP.@variable(model, [i in 1:m], base_name = "moai_ReLU")
-    _set_bounds_if_finite.(y, 0, last.(bounds))
-    z = JuMP.@variable(model, [1:m], lower_bound = 0, base_name = "_z")
-    _set_bounds_if_finite.(z, nothing, -first.(bounds))
+    bounds = get_bounds.(x)
+    y = add_variables(model, predictor, x, length(x); base_name = "moai_ReLU")
+    set_bounds.(y, 0, last.(bounds))
+    z = add_variables(model, predictor, x, length(x); base_name = "_z")
+    set_bounds.(z, 0, -first.(bounds))
     JuMP.@constraint(model, x .== y - z)
-    for i in 1:m
+    for i in 1:length(x)
         JuMP.@constraint(model, [y[i], z[i]] in MOI.SOS1([1.0, 2.0]))
     end
     return y
@@ -246,11 +253,11 @@ function add_predictor(
     x::Vector,
 )
     m = length(x)
-    bounds = _get_variable_bounds.(x)
-    y = JuMP.@variable(model, [1:m], base_name = "moai_ReLU")
-    _set_bounds_if_finite.(y, 0, last.(bounds))
-    z = JuMP.@variable(model, [1:m], base_name = "_z")
-    _set_bounds_if_finite.(z, 0, -first.(bounds))
+    bounds = get_bounds.(x)
+    y = add_variables(model, predictor, x, length(x); base_name = "moai_ReLU")
+    set_bounds.(y, 0, last.(bounds))
+    z = add_variables(model, predictor, x, length(x); base_name = "_z")
+    set_bounds.(z, 0, -first.(bounds))
     JuMP.@constraint(model, x .== y - z)
     JuMP.@constraint(model, y .* z .== 0)
     return y
diff --git a/src/predictors/Scale.jl b/src/predictors/Scale.jl
index 950cd53..d62cf76 100644
--- a/src/predictors/Scale.jl
+++ b/src/predictors/Scale.jl
@@ -54,15 +54,14 @@ function Base.show(io::IO, ::Scale)
 end
 
 function add_predictor(model::JuMP.AbstractModel, predictor::Scale, x::Vector)
-    m = length(predictor.scale)
-    y = JuMP.@variable(model, [1:m], base_name = "moai_Scale")
-    bounds = _get_variable_bounds.(x)
+    y = add_variables(model, predictor, x, length(x); base_name = "moai_Scale")
+    bounds = get_bounds.(x)
     for (i, scale) in enumerate(predictor.scale)
         y_lb = y_ub = predictor.bias[i]
         lb, ub = bounds[i]
         y_ub += scale * ifelse(scale >= 0, ub, lb)
         y_lb += scale * ifelse(scale >= 0, lb, ub)
-        _set_bounds_if_finite(y[i], y_lb, y_ub)
+        set_bounds(y[i], y_lb, y_ub)
     end
     JuMP.@constraint(model, predictor.scale .* x .+ predictor.bias .== y)
     return y
diff --git a/src/predictors/Sigmoid.jl b/src/predictors/Sigmoid.jl
index c07e087..a6f0564 100644
--- a/src/predictors/Sigmoid.jl
+++ b/src/predictors/Sigmoid.jl
@@ -45,9 +45,15 @@ julia> y = MathOptAI.add_predictor(model, MathOptAI.ReducedSpace(f), x)
 """
 struct Sigmoid <: AbstractPredictor end
 
-function add_predictor(model::JuMP.AbstractModel, ::Sigmoid, x::Vector)
-    y = JuMP.@variable(model, [1:length(x)], base_name = "moai_Sigmoid")
-    _set_bounds_if_finite.(y, 0, 1)
+function add_predictor(model::JuMP.AbstractModel, predictor::Sigmoid, x::Vector)
+    y = add_variables(
+        model,
+        predictor,
+        x,
+        length(x);
+        base_name = "moai_Sigmoid",
+    )
+    set_bounds.(y, 0, 1)
     JuMP.@constraint(model, [i in 1:length(x)], y[i] == 1 / (1 + exp(-x[i])))
     return y
 end
diff --git a/src/predictors/SoftMax.jl b/src/predictors/SoftMax.jl
index 2fa9e78..eaeb8ac 100644
--- a/src/predictors/SoftMax.jl
+++ b/src/predictors/SoftMax.jl
@@ -47,11 +47,18 @@ julia> y = MathOptAI.add_predictor(model, MathOptAI.ReducedSpace(f), x)
 """
 struct SoftMax <: AbstractPredictor end
 
-function add_predictor(model::JuMP.AbstractModel, ::SoftMax, x::Vector)
-    y = JuMP.@variable(model, [1:length(x)], base_name = "moai_SoftMax")
-    _set_bounds_if_finite.(y, 0, 1)
-    denom = JuMP.@variable(model, base_name = "moai_SoftMax_denom")
-    JuMP.set_lower_bound(denom, 0)
+function add_predictor(model::JuMP.AbstractModel, predictor::SoftMax, x::Vector)
+    vars = add_variables(
+        model,
+        predictor,
+        x,
+        1 + length(x);
+        base_name = "moai_SoftMax",
+    )
+    denom, y = vars[1], vars[2:end]
+    set_bounds.(y, 0, 1)
+    JuMP.set_name(denom, "moai_SoftMax_denom")
+    set_bounds(denom, 0, nothing)
     JuMP.@constraint(model, denom == sum(exp.(x)))
     JuMP.@constraint(model, y .== exp.(x) ./ denom)
     return y
@@ -59,11 +66,13 @@ end
 
 function add_predictor(
     model::JuMP.AbstractModel,
-    ::ReducedSpace{SoftMax},
+    predictor::ReducedSpace{SoftMax},
     x::Vector,
 )
-    denom = JuMP.@variable(model, base_name = "moai_SoftMax_denom")
-    JuMP.set_lower_bound(denom, 0)
+    vars =
+        add_variables(model, predictor, x, 1; base_name = "moai_SoftMax_denom")
+    denom = only(vars)
+    set_bounds(denom, 0, nothing)
     JuMP.@constraint(model, denom == sum(exp.(x)))
     return exp.(x) ./ denom
 end
diff --git a/src/predictors/SoftPlus.jl b/src/predictors/SoftPlus.jl
index 1e628f8..16a7ebf 100644
--- a/src/predictors/SoftPlus.jl
+++ b/src/predictors/SoftPlus.jl
@@ -43,9 +43,19 @@ julia> y = MathOptAI.add_predictor(model, MathOptAI.ReducedSpace(f), x)
 """
 struct SoftPlus <: AbstractPredictor end
 
-function add_predictor(model::JuMP.AbstractModel, ::SoftPlus, x::Vector)
-    y = JuMP.@variable(model, [1:length(x)], base_name = "moai_SoftPlus")
-    _set_bounds_if_finite.(y, 0, nothing)
+function add_predictor(
+    model::JuMP.AbstractModel,
+    predictor::SoftPlus,
+    x::Vector,
+)
+    y = add_variables(
+        model,
+        predictor,
+        x,
+        length(x);
+        base_name = "moai_SoftPlus",
+    )
+    set_bounds.(y, 0, nothing)
     JuMP.@constraint(model, y .== log.(1 .+ exp.(x)))
     return y
 end
diff --git a/src/predictors/Tanh.jl b/src/predictors/Tanh.jl
index 68c976e..e21602c 100644
--- a/src/predictors/Tanh.jl
+++ b/src/predictors/Tanh.jl
@@ -45,9 +45,9 @@ julia> y = MathOptAI.add_predictor(model, MathOptAI.ReducedSpace(f), x)
 """
 struct Tanh <: AbstractPredictor end
 
-function add_predictor(model::JuMP.AbstractModel, ::Tanh, x::Vector)
-    y = JuMP.@variable(model, [1:length(x)], base_name = "moai_Tanh")
-    _set_bounds_if_finite.(y, -1, 1)
+function add_predictor(model::JuMP.AbstractModel, predictor::Tanh, x::Vector)
+    y = add_variables(model, predictor, x, length(x); base_name = "moai_Tanh")
+    set_bounds.(y, -1, 1)
     JuMP.@constraint(model, y .== tanh.(x))
     return y
 end
diff --git a/src/utilities.jl b/src/utilities.jl
index 8a3055e..53020e0 100644
--- a/src/utilities.jl
+++ b/src/utilities.jl
@@ -4,7 +4,40 @@
 # Use of this source code is governed by a BSD-style license that can be found
 # in the LICENSE.md file.
 
-function _get_variable_bounds(x::JuMP.GenericVariableRef{T}) where {T}
+"""
+    add_variables(
+        model::JuMP.AbstractModel,
+        predictor::AbstractPredictor,
+        x::Vector, n::Int;
+        base_name::String,
+    )
+
+!!! note
+    This method is for JuMP extensions. It should not be called in regular usage
+    of MathOptAI.
+"""
+function add_variables(
+    model::JuMP.AbstractModel,
+    predictor::AbstractPredictor,
+    x::Vector,
+    n::Int;
+    base_name::String,
+)
+    return JuMP.@variable(model, [1:n], base_name = base_name)
+end
+
+"""
+    get_bounds(x::JuMP.AbstractVariableRef)::Tuple
+
+Return a tuple of the `(lower, upper)` bounds associated with variable `x`.
+
+!!! note
+    This method is for JuMP extensions. It should not be called in regular usage
+    of MathOptAI.
+"""
+get_bounds(::Any) = -Inf, Inf
+
+function get_bounds(x::JuMP.GenericVariableRef{T}) where {T}
     lb, ub = typemin(T), typemax(T)
     if JuMP.has_upper_bound(x)
         ub = JuMP.upper_bound(x)
@@ -21,7 +54,18 @@ function _get_variable_bounds(x::JuMP.GenericVariableRef{T}) where {T}
     return lb, ub
 end
 
-function _set_bounds_if_finite(
+"""
+    set_bounds(x::JuMP.AbstractVariableRef, lower, upper)::Nothing
+
+Set the lower and upper bounds of `x` to `lower` and `upper`, skipping any
+bound that is `nothing` or non-finite.
+!!! note
+    This method is for JuMP extensions. It should not be called in regular usage
+    of MathOptAI.
+"""
+set_bounds(::Any, ::Any, ::Any) = nothing
+
+function set_bounds(
     x::JuMP.GenericVariableRef{T},
     l::Union{Nothing,Real},
     u::Union{Nothing,Real},
@@ -34,9 +78,3 @@ function _set_bounds_if_finite(
     end
     return
 end
-
-# Default fallback: provide no detail on the bounds
-_get_variable_bounds(::Any) = -Inf, Inf
-
-# Default fallback: skip setting variable bound
-_set_bounds_if_finite(::Any, ::Any, ::Any) = nothing