diff --git a/ext/MathOptAIPythonCallExt.jl b/ext/MathOptAIPythonCallExt.jl
index e952b7b..cf0f104 100644
--- a/ext/MathOptAIPythonCallExt.jl
+++ b/ext/MathOptAIPythonCallExt.jl
@@ -18,6 +18,8 @@ import MathOptAI
         config::Dict = Dict{Any,Any}(),
         reduced_space::Bool = false,
         gray_box::Bool = false,
+        gray_box_hessian::Bool = false,
+        gray_box_device::String = "cpu",
     )
 
 Add a trained neural network from PyTorch via PythonCall.jl to `model`.
@@ -41,6 +43,8 @@ Add a trained neural network from PyTorch via PythonCall.jl to `model`.
    nonlinear operator, with gradients provided by `torch.func.jacrev`.
  * `gray_box_hessian`: if `true`, the gray box additionally computes the Hessian
    of the output using `torch.func.hessian`.
+ * `gray_box_device`: PyTorch device for the gray box model and its input
+   tensors, e.g. `"cuda"` to run on an NVIDIA GPU.
 """
 function MathOptAI.add_predictor(
     model::JuMP.AbstractModel,
@@ -63,6 +67,7 @@ end
         config::Dict = Dict{Any,Any}(),
         gray_box::Bool = false,
         gray_box_hessian::Bool = false,
+        gray_box_device::String = "cpu",
     )
 
 Convert a trained neural network from PyTorch via PythonCall.jl to a
@@ -87,18 +92,25 @@ Convert a trained neural network from PyTorch via PythonCall.jl to a
    nonlinear operator, with gradients provided by `torch.func.jacrev`.
  * `gray_box_hessian`: if `true`, the gray box additionally computes the Hessian
    of the output using `torch.func.hessian`.
+ * `gray_box_device`: PyTorch device for the gray box model and its input
+   tensors, e.g. `"cuda"` to run on an NVIDIA GPU.
 """
 function MathOptAI.build_predictor(
     predictor::MathOptAI.PytorchModel;
     config::Dict = Dict{Any,Any}(),
     gray_box::Bool = false,
     gray_box_hessian::Bool = false,
+    gray_box_device::String = "cpu",
 )
     if gray_box
         if !isempty(config)
             error("cannot specify the `config` kwarg if `gray_box = true`")
         end
-        return MathOptAI.GrayBox(predictor; hessian = gray_box_hessian)
+        return MathOptAI.GrayBox(
+            predictor;
+            hessian = gray_box_hessian,
+            device = gray_box_device,
+        )
     end
     torch = PythonCall.pyimport("torch")
     nn = PythonCall.pyimport("torch.nn")
@@ -132,24 +144,26 @@ end
 function MathOptAI.GrayBox(
     predictor::MathOptAI.PytorchModel;
     hessian::Bool = false,
+    device::String = "cpu",
 )
     torch = PythonCall.pyimport("torch")
     torch_model = torch.load(predictor.filename; weights_only = false)
+    torch_model = torch_model.to(device)
     J = torch.func.jacrev(torch_model)
     H = torch.func.hessian(torch_model)
     # TODO(odow): I'm not sure if there is a better way to get the output
     # dimension of a torch model object?
     output_size(::Any) = PythonCall.pyconvert(Int, torch_model[-1].out_features)
     function callback(x)
-        py_x = torch.tensor(collect(x))
-        py_value = torch_model(py_x).detach().numpy()
+        py_x = torch.tensor(collect(x); device = device)
+        py_value = torch_model(py_x).detach().cpu().numpy()
         value = PythonCall.pyconvert(Vector, py_value)
-        py_jacobian = J(py_x).detach().numpy()
+        py_jacobian = J(py_x).detach().cpu().numpy()
         jacobian = PythonCall.pyconvert(Matrix, py_jacobian)
         if !hessian
             return (; value, jacobian)
         end
-        hessians = PythonCall.pyconvert(Array, H(py_x).detach().numpy())
+        hessians = PythonCall.pyconvert(Array, H(py_x).detach().cpu().numpy())
         return (; value, jacobian, hessian = permutedims(hessians, (2, 3, 1)))
     end
     return MathOptAI.GrayBox(output_size, callback; has_hessian = hessian)