Skip to content

Commit

Permalink
Addressed comments
Browse files Browse the repository at this point in the history
- Don't pass immutable simple int value by reference
- Populate std::string_view, std::map
  • Loading branch information
Honry committed Jan 20, 2025
1 parent c924c3c commit 264626a
Show file tree
Hide file tree
Showing 24 changed files with 89 additions and 92 deletions.
42 changes: 21 additions & 21 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
return supported_nodes;
}

bool AreInputDataTypesSame(const std::string& op_type,
bool AreInputDataTypesSame(const std::string_view op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger) {
for (size_t i = 1; i < input_types.size(); i++) {
Expand All @@ -136,52 +136,52 @@ bool AreInputDataTypesSame(const std::string& op_type,
return true;
}

bool IsSupportedDataType(const int32_t& onnx_data_type, const emscripten::val& webnn_supported_data_types) {
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) {
auto it = onnx_to_webnn_data_type_map.find(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_data_type));
if (it == onnx_to_webnn_data_type_map.end())
return false;

std::string webnn_data_type = it->second;
const std::string_view webnn_data_type = it->second;

// Check if WebNN supports the data type.
emscripten::val is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
emscripten::val(webnn_data_type));
emscripten::val is_supported =
webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val(webnn_data_type.data()));
return is_supported.as<bool>();
}

// Check if the input or output data type of ONNX node is supported by the WebNN operator.
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
const int32_t& onnx_data_type,
bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger) {
std::string webnn_op_type;
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
return false;
const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type);

return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
return !webnn_op_type.empty() &&
IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
webnn_input_output_name, onnx_input_output_name, logger);
}

bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
const int32_t& onnx_data_type,
bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type,
const std::string_view webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger) {
if (wnn_limits[webnn_op_type].isUndefined()) {
if (wnn_limits[webnn_op_type.data()].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now";
return false;
}

if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) {
if (wnn_limits[webnn_op_type.data()][webnn_input_output_name.data()].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter ["
<< webnn_input_output_name << "]";
return false;
}
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
if (!IsSupportedDataType(
onnx_data_type, wnn_limits[webnn_op_type.data()][webnn_input_output_name.data()]["dataTypes"])) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: ["
<< onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now";
return false;
Expand Down
42 changes: 19 additions & 23 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,14 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
const logging::Logger& logger);

// Some ONNX ops are supported by decomposed WebNN ops.
static const InlinedHashMap<std::string, std::vector<std::string>> decomposed_op_map = {
const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = {
{"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
{"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "split"}},
{"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
{"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
};
// ONNX op type to WebNN op type mapping.
static const InlinedHashMap<std::string, std::string> op_map = {
const std::map<std::string_view, std::string_view> op_map = {
{"Abs", "abs"},
{"Add", "add"},
{"And", "logicalAnd"},
Expand Down Expand Up @@ -307,7 +307,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {

// WebNN op name to its first input name mapping, only record the name that is different from "input".
// This map is used to determine the first input name of a WebNN op and is utilized by OpSupportLimits.
static const InlinedHashMap<std::string, std::string> webnn_op_first_input_name_map = {
const std::map<std::string_view, std::string_view> webnn_op_first_input_name_map = {
{"add", "a"},
{"concat", "inputs"},
{"div", "a"},
Expand All @@ -333,22 +333,18 @@ static const InlinedHashMap<std::string, std::string> webnn_op_first_input_name_
// Retrieve the first input name of a WebNN op used for validating supported input data types.
// WebNN ops have various first input names such as 'a', 'input', 'inputs', etc.
// Special names other than 'input' are recorded in the webnn_op_first_input_name_map.
inline std::string GetWebNNOpFirstInputName(const std::string& webnn_op_type) {
inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) {
auto it = webnn_op_first_input_name_map.find(webnn_op_type);
return (it != webnn_op_first_input_name_map.end()) ? it->second : "input";
}

inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
inline std::string_view GetWebNNOpType(const std::string_view op_type) {
auto it = op_map.find(op_type);
// Returns false if the op_type is not listed in the op_map.
if (it == op_map.end()) {
return false;
}
webnn_op_type = it->second;
return true;
// Return an empty string if the op_type is not listed in the op_map.
return (it != op_map.end()) ? it->second : "";
}

static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_webnn_data_type_map = {

Check warning on line 347 in onnxruntime/core/providers/webnn/builders/helper.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <map> for map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/helper.h:347: Add #include <map> for map<> [build/include_what_you_use] [4]
{ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
Expand All @@ -362,22 +358,22 @@ static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> o
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
};

bool AreInputDataTypesSame(const std::string& op_type,
bool AreInputDataTypesSame(const std::string_view op_type,
gsl::span<const int32_t> input_types,
const logging::Logger& logger);
bool IsSupportedDataType(const int32_t& onnx_data_type, const emscripten::val& webnn_supported_data_types);
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
const int32_t& onnx_data_type,
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger);
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
const int32_t& onnx_data_type,
bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type,
const std::string_view webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const std::string_view webnn_input_output_name,
const std::string_view onnx_input_output_name,
const logging::Logger& logger);

bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializ
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.
const auto& input = *node.InputDefs()[0];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
std::string webnn_op_type;
if (!GetWebNNOpType(op_type, webnn_op_type))
const std::string_view webnn_op_type = GetWebNNOpType(op_type);
if (webnn_op_type.empty())
return false;

const auto webnn_input_name = GetWebNNOpFirstInputName(op_type);
const std::string_view webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type);
return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits,
webnn_input_name, "input", logger);
}
Expand All @@ -88,7 +88,7 @@ bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
const logging::Logger& logger) const {
// We only check the type of output 0 by default, specific op builder can override this.
const auto& output = *node.OutputDefs()[0];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t output_type;
if (!GetType(output, output_type, logger))
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;

if (!GetType(*input_defs[0], input0_type, logger))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type; // input data type
int32_t input1_type; // weight data type
int32_t input2_type; // bias or x_zero_point data type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* init
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;
bool has_input1 = TensorExists(input_defs, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet&
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* in
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();

int32_t data_type;
int32_t indices_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* init
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input_type;
int32_t indices_type;
if (!GetType(input, input_type, logger) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer
bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type; // A data type
int32_t input1_type; // B data type
int32_t input2_type; // C or a_zero_point data type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c
bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input_X_type = 0; // input data type
int32_t input_W_type = 0; // weight data type
int32_t input_R_type = 0; // recurrent weight data type
Expand Down Expand Up @@ -226,7 +226,7 @@ bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& output_defs = node.OutputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t Y_type = 0;
int32_t Y_h_type = 0;
bool has_Y = TensorExists(output_defs, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

std::string webnn_op_type;
ORT_RETURN_IF_NOT(GetWebNNOpType(op_type, webnn_op_type), "Cannot get WebNN op type");
const std::string_view webnn_op_type = GetWebNNOpType(op_type);
ORT_RETURN_IF(webnn_op_type.empty(), "Cannot get WebNN op type");

if (input_defs.size() == 1) {
// Not
output = model_builder.GetBuilder().call<emscripten::val>(webnn_op_type.c_str(), input0, options);
output = model_builder.GetBuilder().call<emscripten::val>(webnn_op_type.data(), input0, options);
} else {
input1 = model_builder.GetOperand(input_defs[1]->Name());
output = model_builder.GetBuilder().call<emscripten::val>(webnn_op_type.c_str(), input0, input1, options);
output = model_builder.GetBuilder().call<emscripten::val>(webnn_op_type.data(), input0, input1, options);
}

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
Expand Down Expand Up @@ -74,7 +74,7 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
bool LogicalOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
const std::string_view op_type = node.OpType();
int32_t input0_type;
int32_t input1_type;

Expand Down
Loading

0 comments on commit 264626a

Please sign in to comment.