[Static Runtime] Add PyTorchPredictor::predict_managed_result to return managed output tensors (pytorch#65598)

Summary:
Pull Request resolved: pytorch#65598

This change adds `PyTorchPredictor::predict_managed_result`, which enables Static Runtime to return managed output tensors, allocated and owned by Static Runtime, to accelerate inference workloads.
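As a rough caller-side illustration, a sketch of the intended pattern follows. `PyTorchPredictor` is an internal class not shown in this diff, so the exact signatures of `predict_managed_result` and the result's `get()` accessor, as well as the `makeInputs`/`consume` helpers, are assumptions based on this summary rather than the actual interface:

    // Hypothetical usage pattern; signatures are assumed, not authoritative.
    std::vector<c10::IValue> inputs = makeInputs(); // assumed helper
    {
      // The result's backing tensors are allocated and owned by Static Runtime.
      auto managed = predictor->predict_managed_result(inputs);
      const c10::IValue& out = managed.get();
      consume(out); // assumed helper: use the outputs while `managed` is alive
    } // destruction releases the managed tensors and recycles the runtime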

- `PyTorchPredictor::predict_managed_result` only does meaningful work in the override `PyTorchStaticRuntimePredictor::predict_managed_result`. For other subclasses, it returns a simple object that just wraps the returned `IValue`.

- When `manage_output_tensors` is enabled, a `StaticRuntime` cannot be reentered until its return value is deallocated by calling `StaticRuntime::deallocateOutputTensors`. Currently an instance of `StaticRuntime` is pushed back to `static_runtime_pool` immediately so that it can be reentered, which is not safe when `manage_output_tensors` is enabled. `PyTorchStaticRuntimePredictorManagedResult` therefore delays pushing the `StaticRuntime` instance back to the pool until `StaticRuntime::deallocateOutputTensors` has been called on it (see the sketch after this list).

- When `manage_output_tensors` is enabled, `PyTorchStaticRuntimePredictor::predict_managed_result` returns a prediction result whose backing memory is managed by an instance of `StaticRuntime`. The lifetime of any value reachable from `PyTorchStaticRuntimePredictorManagedResult.get()` is expected to end before `PyTorchStaticRuntimePredictorManagedResult` is destroyed. As explained above, destroying `PyTorchPredictorManagedResult` pushes the runtime instance that produced the result back to `static_runtime_pool` for reuse.

- The current API design of adding `predict_managed_result`, instead of forcing `operator()` to return `PyTorchPredictorManagedResult`, was motivated by the fact that `manage_output_tensors` will be enabled selectively for only a few models. If `manage_output_tensors` becomes a commonly used feature, we should revisit this design and merge the two entry points.
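A minimal sketch of the delayed return-to-pool mechanism described in the bullets above, assuming a shared-pointer-owned runtime; `ManagedResult` and `returnToPool` are illustrative names, not the actual fbcode implementation, while `StaticRuntime::deallocateOutputTensors` is the real API exercised by this diff:

    #include <memory>
    #include <utility>
    #include <ATen/core/ivalue.h>
    #include <torch/csrc/jit/runtime/static/impl.h> // torch::jit::StaticRuntime

    // Assumed pool hook: hands the runtime back to static_runtime_pool.
    void returnToPool(std::shared_ptr<torch::jit::StaticRuntime> runtime);

    class ManagedResult {
     public:
      ManagedResult(
          c10::IValue result,
          std::shared_ptr<torch::jit::StaticRuntime> runtime)
          : result_(std::move(result)), runtime_(std::move(runtime)) {}

      // Values reachable from get() must not outlive this object.
      const c10::IValue& get() const { return result_; }

      ~ManagedResult() {
        if (runtime_) {
          // Free the runtime-owned output tensors, then hand the runtime
          // back to the pool for reuse.
          runtime_->deallocateOutputTensors();
          returnToPool(std::move(runtime_));
        }
      }

     private:
      c10::IValue result_;
      std::shared_ptr<torch::jit::StaticRuntime> runtime_;
    };

Tying both the deallocation and the pool push-back to the result's destructor ensures the runtime is never reentered while any of its managed outputs are still reachable.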

Reviewed By: hlu1

Differential Revision: D31149323

fbshipit-source-id: 5ca026188077232d6a49a46759124a978439d7b2
d1jang authored and facebook-github-bot committed Nov 3, 2021
1 parent 18955d3 commit e86a5a3
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion torch/csrc/jit/runtime/static/impl.cpp
@@ -1145,17 +1145,24 @@ float StaticRuntime::benchmark_model(

  const bool is_kwargs_empty = kwargs_list.size() == 0;
  const std::unordered_map<std::string, c10::IValue> empty_kwargs;
  bool manage_output_tensors = static_module_.opts().manage_output_tensors;
  for (const auto i : c10::irange(warmup_runs)) {
    (void)i; // Suppress unused variable warning
    for (const auto j : c10::irange(args_list.size())) {
      operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
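      // Outputs are owned by the runtime when manage_output_tensors is on;
      // they must be released before this runtime instance can run again.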
      if (manage_output_tensors) {
        deallocateOutputTensors();
      }
    }
  }
  caffe2::Timer timer;
  for (const auto i : c10::irange(main_runs)) {
    (void)i; // Suppress unused variable warning
    for (const auto j : c10::irange(args_list.size())) {
      operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
      if (manage_output_tensors) {
        deallocateOutputTensors();
      }
    }
  }
  float millis = timer.MilliSeconds();
@@ -1253,7 +1260,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(

  const bool is_kwargs_empty = kwargs_list.size() == 0;
  const std::unordered_map<std::string, c10::IValue> empty_kwargs;

  bool manage_output_tensors = static_module_.opts().manage_output_tensors;
  // See comment on above use of InferenceMode for
  // explanation.
  c10::InferenceMode mode;
@@ -1273,13 +1280,19 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
  // iterations just use the already established memory planning.
  timer.Start();
  operator()(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
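  // Note: this deallocation happens inside the timed region, so it is
  // counted toward first_iter_time.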
  if (manage_output_tensors) {
    deallocateOutputTensors();
  }
  results.first_iter_time = timer.MilliSeconds();

  // warmup runs
  for (const auto i : c10::irange(warmup_runs - 1)) {
    (void)i; // Suppress unused variable warning
    for (const auto j : c10::irange(args_list.size())) {
      operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
      if (manage_output_tensors) {
        deallocateOutputTensors();
      }
    }
  }

@@ -1310,6 +1323,9 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
    // clean up owning refs of input tensors
    clean_up_input_ivalues();
  }
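  // Deallocation of managed output tensors is included in the
  // memory_dealloc_time measurement below.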
  if (manage_output_tensors) {
    deallocateOutputTensors();
  }
  millis = timer.MilliSeconds();
  results.memory_dealloc_time += millis;

