Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Robbybp committed Nov 5, 2024
1 parent 594911f commit 1d9e8f2
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions ext/MathOptAIPythonCallExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,15 @@ function MathOptAI.GrayBox(
device::String = "cpu",
)
torch = PythonCall.pyimport("torch")
torch_model = torch.load(
predictor.filename;
weights_only = false,
).to(device)
torch_model = torch.load(predictor.filename; weights_only = false)
torch_model = torch_model.to(device)
J = torch.func.jacrev(torch_model)
H = torch.func.hessian(torch_model)
# TODO(odow): I'm not sure if there is a better way to get the output
# dimension of a torch model object?
output_size(::Any) = PythonCall.pyconvert(Int, torch_model[-1].out_features)
function callback(x)
py_x = torch.tensor(collect(x), device = device)
py_x = torch.tensor(collect(x); device = device)
py_value = torch_model(py_x).detach().cpu().numpy()
value = PythonCall.pyconvert(Vector, py_value)
py_jacobian = J(py_x).detach().cpu().numpy()
Expand Down

0 comments on commit 1d9e8f2

Please sign in to comment.