Skip to content

Commit

Permalink
add gray_box_device option to docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
Robbybp committed Nov 5, 2024
1 parent 87e7dcf commit ca5b7d1
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion ext/MathOptAIPythonCallExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import MathOptAI
config::Dict = Dict{Any,Any}(),
reduced_space::Bool = false,
gray_box::Bool = false,
gray_box_hessian::Bool = false,
gray_box_device::String = "cpu",
)
Add a trained neural network from PyTorch via PythonCall.jl to `model`.
Expand All @@ -41,6 +43,8 @@ Add a trained neural network from PyTorch via PythonCall.jl to `model`.
nonlinear operator, with gradients provided by `torch.func.jacrev`.
* `gray_box_hessian`: if `true`, the gray box additionally computes the Hessian
of the output using `torch.func.hessian`.
* `gray_box_device`: device used to construct PyTorch tensors, e.g. `"cuda"`
to run on an Nvidia GPU.
"""
function MathOptAI.add_predictor(
model::JuMP.AbstractModel,
Expand All @@ -63,6 +67,7 @@ end
config::Dict = Dict{Any,Any}(),
gray_box::Bool = false,
gray_box_hessian::Bool = false,
gray_box_device::String = "cpu",
)
Convert a trained neural network from PyTorch via PythonCall.jl to a
Expand All @@ -87,13 +92,15 @@ Convert a trained neural network from PyTorch via PythonCall.jl to a
nonlinear operator, with gradients provided by `torch.func.jacrev`.
* `gray_box_hessian`: if `true`, the gray box additionally computes the Hessian
of the output using `torch.func.hessian`.
* `gray_box_device`: device used to construct PyTorch tensors, e.g. `"cuda"`
to run on an Nvidia GPU.
"""
function MathOptAI.build_predictor(
predictor::MathOptAI.PytorchModel;
config::Dict = Dict{Any,Any}(),
gray_box::Bool = false,
gray_box_hessian::Bool = false,
gray_box_device = "cpu",
gray_box_device::String = "cpu",
)
if gray_box
if !isempty(config)
Expand Down

0 comments on commit ca5b7d1

Please sign in to comment.