Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Capacity aware partitioning #22766

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ struct OrtRunOptions;

namespace onnxruntime {

class IResourceAccountant;

/**
Logical device representation.
*/
Expand Down Expand Up @@ -130,7 +132,8 @@ class IExecutionProvider {
*/
virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& kernel_lookup) const;
const IKernelLookup& kernel_lookup,
IResourceAccountant* resource_accountant = nullptr) const;

/**
Get kernel registry per execution provider type.
Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/core/framework/op_kernel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ class OpKernelContext {

virtual OrtValue* GetOrCreateOutputMLValue(int index);

virtual int GetOrtValueIndexForInput(int input_index) const;

virtual int GetOrtValueIndexForOutput(int output_index) const;

private:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
int GetInputArgIndex(int index) const;
Expand Down
67 changes: 67 additions & 0 deletions include/onnxruntime/core/framework/resource_accountant.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <optional>
#include <string>
#include <variant>

namespace onnxruntime {

// Common holder for potentially different resource accounting
// for different EPs
using ResourceCount = std::variant<size_t, std::monostate>;

/// <summary>
/// This class is used for graph partitioning by EPs
/// It stores the cumulative amount of the resource such as
/// memory that would be consumed by the graph nodes if it is assigned to the EP.
///
/// It provides interfaces to add, remove and query the resource consumption.
///
/// Each provider may assign its own meaning to the resource according to its constraints.
/// </summary>
class IResourceAccountant {
protected:
IResourceAccountant() = default;
IResourceAccountant(const ResourceCount& threshold) : threshold_(threshold) {}

Check warning on line 28 in include/onnxruntime/core/framework/resource_accountant.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [4] Raw Output: include/onnxruntime/core/framework/resource_accountant.h:28: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]

public:
virtual ~IResourceAccountant() = default;
virtual ResourceCount GetConsumedAmount() const = 0;
virtual void AddConsumedAmount(const ResourceCount& amount) = 0;
virtual void RemoveConsumedAmount(const ResourceCount& amount) = 0;
virtual ResourceCount ComputeResourceCount(const std::string& node_name) const = 0;

std::optional<ResourceCount> GetThreshold() const {
return threshold_;
}

void SetStopAssignment() {
stop_assignment_ = true;
}

bool IsStopIssued() const noexcept { return stop_assignment_; }

private:
bool stop_assignment_ = false;
std::optional<ResourceCount> threshold_;
};

// This struct keeps accounting of the memory allocation stats
// for a kernel during runtime if enabled.
struct NodeAllocationStats {
size_t initializers_sizes = 0;
size_t total_dynamic_sizes = 0;
size_t total_temp_allocations = 0;

NodeAllocationStats& operator+=(const NodeAllocationStats& other) {
initializers_sizes += other.initializers_sizes;
total_dynamic_sizes += other.total_dynamic_sizes;
total_temp_allocations += other.total_temp_allocations;
return *this;
}
};

} // namespace onnxruntime
7 changes: 7 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,13 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return ConstGraphNodes(nodes_, std::move(filter_func));
}

/** Compute node memory requirements, which is mostly initializers
and large attributes that are copied on device (special cases for some nodes)

Returns no value if the node was not found.
*/
size_t ComputeNodeMemoryUsage(NodeIndex) const;

/** Gets the maximum NodeIndex value used in the Graph.
WARNING: This actually returns the max index value used + 1.
*/
Expand Down
49 changes: 49 additions & 0 deletions include/onnxruntime/core/graph/indexed_sub_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <string>
#include <vector>

#include "core/common/inlined_containers_fwd.h"
#include "core/framework/resource_accountant.h"
#include "core/graph/basic_types.h"
#include "core/graph/onnx_protobuf.h"

Expand Down Expand Up @@ -70,9 +72,56 @@ struct IndexedSubGraph {
return meta_def_.get();
}

// Check if the accounting is enabled for the current EP
bool IsAccountingEnabled() const {
return resource_accountant != nullptr &&
nodes_costs.size() == nodes.size();
}

// Should call IsAccountingEnabled() first
// Takes the previously computed ResourceCount for the node
// (usually during GetCapabiilty())
// if present and adds it to the consumed amount
void AccountForNode(size_t cost_index) const {
assert(cost_index < nodes_costs.size());
if (nodes_costs[cost_index].has_value()) {
resource_accountant->AddConsumedAmount(*nodes_costs[cost_index]);
}
}

// This computes and accounts for the resource cost for the node that just
// been fused from other nodes, and the EP did not had a chance to compute the costs.
void ComputeAndAccountForNode(const std::string& node_name) const {
assert(resource_accountant != nullptr);
resource_accountant->AddConsumedAmount(resource_accountant->ComputeResourceCount(node_name));
}

void SetAccountant(IResourceAccountant* res_accountant) {
resource_accountant = res_accountant;
}

// Append resource count to the list of costs for the nodes.
void AppendNodeCost(const ResourceCount& cost) {
assert(resource_accountant != nullptr);
nodes_costs.emplace_back(cost);
}

// Append an absent cost for the node that was already accounted for.
void AppendNodeEmptyCost() {
assert(resource_accountant != nullptr);
nodes_costs.emplace_back();
}

private:
// subgraph meta definition.
std::unique_ptr<MetaDef> meta_def_;
// Optional resource accountant for this subgraph.
IResourceAccountant* resource_accountant = nullptr;
// Vector with resource costs for nodes above. Should have the same size
// Some nodes that were previously accounted for because they already been assigned to an EP
// for example during multiple calls to GetCapabiility() will not have resource count present.
// may not have a resource count present, we skip it.
InlinedVector<std::optional<ResourceCount>> nodes_costs;
};

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,31 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMin
static const char* const kOrtSessionOptionsSavePrePackedConstantInitializers =
"session.save_external_prepacked_constant_initializers";

// Use this config when you want to collect memory stats for each node in the graph.
// The file format is a CSV file with the following columns:
// The file will be created if it does not exist, and will be overwritten if it does.
//
// The content of the file can be used to estimate memory requirements at run time including
// the temporary allocations. This operation is preferably done on a CPU device, as the model may exceed
// device memory limits in constrained environments. When enabling this option, it is important to disable
// memory patterns, as they tend to allocate large blocks to avoid fragmentation and accommodate needs of multiple
// kernels. Memory patterns may make it difficult to allocate on a device with limited memory.
//
// The collected stats then can be used to partition the graph among the devices in a way that only the
// required memory is allocated on each device.
//
// node_name, initializers_memory, dynamic_outputs_sizes, temp_allocations_size
//
// - "full path to file": there is not a default for this option. If the file can not be opened for writing, an error will be returned.
static const char* const kOrtSessionOptionsCollectNodeMemoryStatsToFile = "session.collect_node_memory_stats_to_file";

/// This is a composite CSV setting formatted as "memory limit in kb,file name for collected stats"
/// "limit > 0": enables Capacity Aware Partitioning for Cuda EP. The EP will place nodes on device
/// "file name" : this file is expected to be found at the same folder with the model. The file contains
/// pre-recorded stats collected when running with kOrtSessionOptionsCollectNodeMemoryStatsToFile enforce (see above)
static const char* const kOrtSessionOptionsResourceCudaPartitioningSettings =
"session.resource_cuda_partitioning_settings";

// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
// "0": disable. (default)
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/core/framework/execution_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include "core/framework/bfc_arena.h"

#include "core/session/onnxruntime_session_options_config_keys.h"

using namespace onnxruntime::common;

namespace onnxruntime {
Expand Down Expand Up @@ -391,6 +393,11 @@ ExecutionFrame::ExecutionFrame(gsl::span<const int> feed_mlvalue_idxs, gsl::span
}
}

#if !defined(ORT_MINIMAL_BUILD)
node_stats_file_name_ = session_state.GetSessionOptions().config_options.GetConfigOrDefault(
kOrtSessionOptionsCollectNodeMemoryStatsToFile, "");
#endif

// If the session enable memory pattern optimization
// and we have execution plan generated, try to setup
// memory pattern optimization.
Expand Down Expand Up @@ -614,6 +621,12 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
#endif
}

#if !defined(ORT_MINIMAL_BUILD)
if (IsNodeAllocationStatsEnabled()) {
ort_value_to_dynamic_allocations_size_.insert_or_assign(ort_value_index, size);
}
#endif

return Status::OK();
}

Expand Down
36 changes: 34 additions & 2 deletions onnxruntime/core/framework/execution_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <filesystem>

Check warning on line 6 in onnxruntime/core/framework/execution_frame.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 <filesystem> is an unapproved C++17 header. [build/c++17] [5] Raw Output: onnxruntime/core/framework/execution_frame.h:6: <filesystem> is an unapproved C++17 header. [build/c++17] [5]
#include <mutex>
#include <vector>

Expand Down Expand Up @@ -92,17 +93,23 @@

Status ReleaseMLValue(int ort_value_idx);

protected:
// get the ort_value_idx from NodeIndexInfo
int GetNodeIdxToMLValueIdx(int index) const;

virtual bool IsNodeAllocationStatsEnabled() const noexcept {
return false;
}

protected:
OrtValue& GetMutableMLValue(int ort_value_index) { return const_cast<OrtValue&>(GetMLValue(ort_value_index)); }

virtual Status ReleaseMLValueImpl(int ort_value_idx);

// returns true if the ort_value_idx is an output from the graph
bool IsOutput(int ort_value_idx) const;

const OrtValueNameIdxMap& GetOrtValueNameIdxMap() const noexcept { return ort_value_idx_map_; }

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(IExecutionFrame);

Expand Down Expand Up @@ -166,6 +173,25 @@
return planner_.has_value();
}

#if !defined(ORT_MINIMAL_BUILD)
bool IsNodeAllocationStatsEnabled() const noexcept override {
return !node_stats_file_name_.empty();
}

// SessionScope will make use of it
const std::filesystem::path& GetNodeStatsFileName() const noexcept {
return node_stats_file_name_;
}

std::optional<size_t> GetOrtValueDynamicAllocation(int ort_value_index) const {
auto it = ort_value_to_dynamic_allocations_size_.find(ort_value_index);
if (it != ort_value_to_dynamic_allocations_size_.end()) {
return it->second;
}
return std::nullopt;
}
#endif

// This function try retrieve the inferred shapes for the given NodeArg index.
// If the retrival is successful, this function returns true and false otherwise.
bool TryGetInferredShape(int index, TensorShape& shape) const override;
Expand Down Expand Up @@ -258,10 +284,16 @@
// This field is not physical memory size.
// dynamic_activation_memory_sizes_in_byte_[location] is the dynamic memory consumption on "location".
std::unordered_map<std::string, size_t> dynamic_activation_memory_sizes_in_byte_;
#endif

#if !defined(ORT_MINIMAL_BUILD)
std::filesystem::path node_stats_file_name_;

// OrtValue index to the size of dynamic memory allocation.
std::unordered_map<int, size_t> ort_value_to_dynamic_allocations_size_;

Check warning on line 293 in onnxruntime/core/framework/execution_frame.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/framework/execution_frame.h:293: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
#endif
// Mutex which should be acquired when executing non-thread-safe member functions.
// A current example is the tracker of dynamic memory allocation.
mutable std::mutex mtx_;
#endif
};
} // namespace onnxruntime
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ namespace onnxruntime {

std::vector<std::unique_ptr<ComputeCapability>>
IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
const IKernelLookup& kernel_lookup) const {
const IKernelLookup& kernel_lookup,
IResourceAccountant*) const {
std::vector<std::unique_ptr<ComputeCapability>> result;
for (const auto& node : graph.Nodes()) {
if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node);
Expand Down
Loading
Loading