[Fix] In the Xnnpack EP, the conversion for the fused activation param isn't correct #23115
@@ -6,9 +6,11 @@
#include "core/common/logging/logging.h" | ||||||||||||||||||
#include "core/common/span_utils.h" | ||||||||||||||||||
#include "core/framework/float16.h" | ||||||||||||||||||
#include "core/framework/utils.h" | ||||||||||||||||||
#include "core/graph/graph.h" | ||||||||||||||||||
#include "core/providers/xnnpack/xnnpack_execution_provider.h" | ||||||||||||||||||
#include "core/providers/xnnpack/xnnpack_init.h" | ||||||||||||||||||
#include "core/session/inference_session.h" | ||||||||||||||||||
#include "core/session/onnxruntime_cxx_api.h" | ||||||||||||||||||
#include "core/session/onnxruntime_session_options_config_keys.h" | ||||||||||||||||||
|
@@ -89,6 +91,51 @@ TEST(XnnpackEP, TestNhwcConvReluClipFusion) {
  RunAndVerifyOutputsWithEP(ort_model_path, "TestNhwcConvReluClipFusion", std::move(ep), feeds, params);
}

#ifdef XNNPACK_FP16_SUPPORTED
TEST(XnnpackEP, TestNhwcConvReluClipFusion_FP16) {
  const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "nhwc_conv_clip_relu_fp16.onnx";

  RandomValueGenerator generator;
  TensorShape input_shape_x{1, 16, 16, 192};
  std::vector<MLFloat16> input_x = generator.Uniform<MLFloat16>(input_shape_x.GetDims(), -128, 128);

  OrtValue ml_value_x;
  CreateMLValue<MLFloat16>(input_shape_x.GetDims(), input_x.data(), OrtMemoryInfo(), &ml_value_x);

  NameMLValMap feeds;
  feeds.insert(std::make_pair("model_input", ml_value_x));

  std::function<void(const Graph&)> verify = [](const Graph& graph) -> void {
    ASSERT_EQ(graph.NumberOfNodes(), 3) << "Transpose nodes should have been removed, and "
                                           "Conv+Relu and Conv+Clip should have been fused, leaving 3 nodes.";
    auto node_iter = graph.Nodes().begin();
    auto check_node = [](const Node& node, const std::string& fusion_type) {
      const auto& attr = node.GetAttributes();
      auto activation = attr.find("activation");
      ASSERT_NE(activation, attr.cend()) << "Fused node should have activation attribute";
      ASSERT_EQ(activation->second.s(), fusion_type);
    };

    // check 2nd and 3rd nodes.
    // the first node is the Conv that does not get fused (created after first call to GetCapability)
    // the 2nd and 3rd nodes are the fused nodes (created after second call to GetCapability)
    ++node_iter;
    check_node(*node_iter, "Clip");
    ++node_iter;
    check_node(*node_iter, "Relu");
  };

  EPVerificationParams params;
  params.ep_node_assignment = ExpectedEPNodeAssignment::All;
  params.fp32_abs_err = 0.0002f;
  params.graph_verifier = &verify;

  auto ep = DefaultXnnpackExecutionProvider();
  // So far, the CPU EP doesn't support FP16 Conv fusion, so verify_outputs is skipped.
  RunAndVerifyOutputsWithEP(ort_model_path, "TestNhwcConvReluClipFusion_FP16", std::move(ep), feeds, params, {}, false);
Review comment: Not quite following. There should still be valid output from the CPU EP even if it doesn't fuse, so why can't we use verify_outputs?

Reply: thx, fixed

Reply: So far, the CPU EP doesn't implement FP16 Clip fusion. The output verification fails because it looks like the CPU EP falls back to FP32 Clip (see onnxruntime/onnxruntime/core/providers/cpu/fp16/fp16_activations.h, lines 74 to 77 in e0e8304).

To verify the correctness of the XNNPACK FP16 Conv fusion, I added a new test with a new FP16 model (Conv+Relu only). The current test (Conv+Clip+Relu) is kept because I want to make sure that the Conv+Clip fusion can run, that is, that the activation parameters are added correctly.
}
#endif
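
Per the reply in the thread above, a separate FP16 test with a Conv+Relu-only model was added so output verification against the CPU EP can stay enabled. A minimal sketch of what such a test could look like, reusing the helpers from the test above; the model file name nhwc_conv_relu_fp16.onnx and the test name are assumptions, not the PR's exact code:

#ifdef XNNPACK_FP16_SUPPORTED
TEST(XnnpackEP, TestNhwcConvReluFusion_FP16) {
  // Hypothetical model file name; the PR adds its own Conv+Relu FP16 model.
  const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "nhwc_conv_relu_fp16.onnx";

  RandomValueGenerator generator;
  TensorShape input_shape_x{1, 16, 16, 192};
  std::vector<MLFloat16> input_x = generator.Uniform<MLFloat16>(input_shape_x.GetDims(), -128, 128);

  OrtValue ml_value_x;
  CreateMLValue<MLFloat16>(input_shape_x.GetDims(), input_x.data(), OrtMemoryInfo(), &ml_value_x);

  NameMLValMap feeds;
  feeds.insert(std::make_pair("model_input", ml_value_x));

  EPVerificationParams params;
  params.ep_node_assignment = ExpectedEPNodeAssignment::All;
  params.fp32_abs_err = 0.0002f;

  auto ep = DefaultXnnpackExecutionProvider();
  // verify_outputs stays enabled here: the CPU EP can run FP16 Relu, so the
  // outputs of the two EPs can be compared directly.
  RunAndVerifyOutputsWithEP(ort_model_path, "TestNhwcConvReluFusion_FP16", std::move(ep), feeds, params);
}
#endif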
|
||||||||||||||||||
// test we can share the cpu ep allocator with the xnnpack EP | ||||||||||||||||||
TEST(XnnpackEP, TestAllocatorSharing) { | ||||||||||||||||||
auto init_session = [](std::vector<std::shared_ptr<IExecutionProvider>>& eps, | ||||||||||||||||||
|
Review comment: What if GetType(arg, arg_type) failed here?

Reply: Generally type info is always available, so I think this is ok. Shape info may be missing depending on the model.

The Conv op looks to be set up to allow fp32, u8, s8 and optionally fp16. Should this also handle u8 and s8, or should ClipReluChecker limit fusion to fp32 and fp16?
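
For readers outside the diff context, the pattern under discussion resolves a NodeArg's element type before deciding whether to fuse. A hypothetical sketch of such a check using public graph APIs; the helper name GetElementTypeOrDeclineFusion is illustrative, not the EP's actual GetType:

#include "core/graph/graph.h"
#include "core/graph/onnx_protobuf.h"

// Hypothetical helper mirroring the pattern discussed above: resolve the
// element type of a NodeArg, and decline fusion when type info is missing
// instead of proceeding with an unknown type.
static bool GetElementTypeOrDeclineFusion(const onnxruntime::NodeArg& arg, int32_t& arg_type) {
  const ONNX_NAMESPACE::TypeProto* type_proto = arg.TypeAsProto();
  if (type_proto == nullptr || !type_proto->tensor_type().has_elem_type()) {
    arg_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
    return false;  // caller should skip the fusion rather than guess
  }

  arg_type = type_proto->tensor_type().elem_type();
  return true;
}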
Reply: So far, core runtime Clip fusion only supports float too (onnxruntime/onnxruntime/core/optimizer/utils.cc, lines 335 to 349 in c6ba7ed). Shall we update them together?

Reply: cc @snnn

Reply: I'd leave the core Clip fusion as-is for now. It can be a separate PR if we think there's a use-case that would benefit. Are you planning on updating ClipReluChecker to limit the types?

Reply: I checked https://onnx.ai/onnx/operators/onnx__Conv.html#type-constraints; an ONNX Conv node shouldn't have u8 or s8 inputs. @skottmckay

Reply: The XNNPack EP's Conv implementation also handles QLinearConv, doesn't it?

Reply: But QLinearConv isn't in the node_to_be_fuse list yet. Could we add it in the next PR?

Reply: Should be good in that case. To be safer, it would be good to add an else branch that returns an error, so that if we get a datatype other than fp32 or fp16 it isn't silently ignored. If we add QLinearConv to the nodes that can fuse (not sure why we don't allow that - maybe xnnpack doesn't support it), the else will make it much easier for a developer to discover they need to update this code.

Reply: already added
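
To make the suggestion concrete, here is a sketch of what such a conversion with a loud else could look like. This is illustrative only (ConvertClipParams and ClampLimits are hypothetical names, not the PR's actual code), but the key points match the discussion: the fp16 path converts the half-precision value instead of reinterpreting its bits, and any other element type returns an error rather than being silently ignored.

#include <limits>

#include "core/common/common.h"
#include "core/framework/float16.h"
#include "core/graph/onnx_protobuf.h"

namespace {

struct ClampLimits {
  float min = -std::numeric_limits<float>::infinity();
  float max = std::numeric_limits<float>::infinity();
};

// arg_type is the ONNX element type of the Conv input; raw_min/raw_max point
// at the bytes of the fused Clip's min/max initializers.
onnxruntime::common::Status ConvertClipParams(int32_t arg_type,
                                              const void* raw_min,
                                              const void* raw_max,
                                              ClampLimits& out) {
  if (arg_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
    out.min = *static_cast<const float*>(raw_min);
    out.max = *static_cast<const float*>(raw_max);
  } else if (arg_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
    // Convert the half-precision value properly; reinterpreting the 16-bit
    // pattern as a float is the kind of bug this PR fixes.
    out.min = static_cast<const onnxruntime::MLFloat16*>(raw_min)->ToFloat();
    out.max = static_cast<const onnxruntime::MLFloat16*>(raw_max)->ToFloat();
  } else {
    // The else suggested in review: fail loudly so that adding e.g.
    // QLinearConv to the fusable nodes forces this code to be revisited.
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Unsupported data type for fused activation params: ", arg_type);
  }
  return onnxruntime::common::Status::OK();
}

}  // namespace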