From 2bbff7a455bcbc15892c7ff0935fbecee3fa98b0 Mon Sep 17 00:00:00 2001 From: Robert Parker Date: Thu, 16 Jan 2025 19:13:37 -0500 Subject: [PATCH] Fix PytorchModel when last layer doesn't support out_features (#166) --- ext/MathOptAIPythonCallExt.jl | 11 +++++--- test/test_PythonCall.jl | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/ext/MathOptAIPythonCallExt.jl b/ext/MathOptAIPythonCallExt.jl index cf0f1041..494e71be 100644 --- a/ext/MathOptAIPythonCallExt.jl +++ b/ext/MathOptAIPythonCallExt.jl @@ -151,9 +151,14 @@ function MathOptAI.GrayBox( torch_model = torch_model.to(device) J = torch.func.jacrev(torch_model) H = torch.func.hessian(torch_model) - # TODO(odow): I'm not sure if there is a better way to get the output - # dimension of a torch model object? - output_size(::Any) = PythonCall.pyconvert(Int, torch_model[-1].out_features) + function output_size(x::Vector) + # Get the output size by passing a zero vector through the torch model. + # We do this instead of `torch_model[-1].out_features` as the last layer + # may not support out_features. + z = torch.zeros(length(x)) + y = torch_model(z) + return PythonCall.pyconvert(Int, PythonCall.pybuiltins.len(y)) + end function callback(x) py_x = torch.tensor(collect(x); device = device) py_value = torch_model(py_x).detach().cpu().numpy() diff --git a/test/test_PythonCall.jl b/test/test_PythonCall.jl index 4c11fff8..a66111d9 100644 --- a/test/test_PythonCall.jl +++ b/test/test_PythonCall.jl @@ -453,6 +453,55 @@ function test_model_Tanh_vector_GrayBox_hessian() return end +function test_model_Sigmoid_last_layer_GrayBox() + dir = mktempdir() + filename = joinpath(dir, "model_Sigmoid_last_layer_GrayBox.pt") + PythonCall.pyexec( + """ + import torch + + model = torch.nn.Sequential( + torch.nn.Linear(3, 16), + torch.nn.Sigmoid(), + ) + + torch.save(model, filename) + """, + @__MODULE__, + (; filename = filename), + ) + # Full-space + model = Model(Ipopt.Optimizer) + set_silent(model) + @variable(model, x[i in 1:3] == i) + ml_model = MathOptAI.PytorchModel(filename) + y, formulation = + MathOptAI.add_predictor(model, ml_model, x; gray_box = true) + @test num_variables(model) == 19 + @test num_constraints(model; count_variable_in_set_constraints = true) == 19 + optimize!(model) + @test is_solved_and_feasible(model) + @test ≈(_evaluate_model(filename, value.(x)), value.(y); atol = 1e-5) + # Reduced-space + model = Model(Ipopt.Optimizer) + set_silent(model) + @variable(model, x[i in 1:3] == i) + ml_model = MathOptAI.PytorchModel(filename) + y, formulation = MathOptAI.add_predictor( + model, + ml_model, + x; + gray_box = true, + reduced_space = true, + ) + @test num_variables(model) == 3 + @test num_constraints(model; count_variable_in_set_constraints = true) == 3 + optimize!(model) + @test is_solved_and_feasible(model) + @test ≈(_evaluate_model(filename, value.(x)), value.(y); atol = 1e-5) + return +end + end # module TestPythonCallExt.runtests()