From 9f52833083f96f60901c65621ae931be028da496 Mon Sep 17 00:00:00 2001 From: Alexis Tsogias <1114095+Zyrin@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:30:14 +0100 Subject: [PATCH] [CPU EP] Implement Neg unary operation for int16 --- .../providers/cpu/cpu_execution_provider.cc | 5 +++++ .../providers/cpu/math/element_wise_ops.cc | 2 ++ .../cpu/math/element_wise_ops_test.cc | 22 ++++++++++++++++++- 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index b26265ccbda41..5ae1a6047c19c 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -107,6 +107,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Neg); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Neg); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, int8_t, Neg); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, int16_t, Neg); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, int32_t, Neg); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, int64_t, Neg); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 11, Pow); @@ -700,6 +701,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, Neg); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int16_t, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Neg); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Mod); @@ -1323,6 +1325,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { double, Neg)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index d3590be6b561e..8131cda9480cc 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -270,11 +270,13 @@ REG_ELEMENTWISE_TYPED_KERNEL(Abs, 13, uint64_t, Abs); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, float, Neg); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, double, Neg); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int8_t, Neg); +REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int16_t, Neg); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int32_t, Neg); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int64_t, Neg); REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, float, Neg); REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, double, Neg); REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int8_t, Neg); +REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int16_t, Neg); REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int32_t, Neg); REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int64_t, Neg); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 8a600135895a2..c0141d5519d61 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -949,7 +949,7 @@ TEST(MathOpTest, Abs_int32) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT parser: Int32 not allowed as input to this layer } -TEST(MathOpTest, Neg) { +TEST(MathOpTest, Neg_float) { OpTester test("Neg"); std::vector dims{2, 2}; test.AddInput("X", dims, @@ -961,6 +961,18 @@ TEST(MathOpTest, Neg) { test.Run(); } +TEST(MathOpTest, Neg_double) { + OpTester test("Neg"); + std::vector dims{2, 2}; + test.AddInput("X", dims, + {1.0, -2.0, + 0.0, -10.0}); + test.AddOutput("Y", dims, + {-1.0, 2.0, + -0.0, 10.0}); + test.Run(); +} + TEST(MathOpTest, Neg_int8) { OpTester test("Neg"); std::vector dims{4}; @@ -971,6 +983,14 @@ TEST(MathOpTest, Neg_int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: INT8 is not supported } +TEST(MathOpTest, Neg_int16) { + OpTester test("Neg"); + std::vector dims{4}; + test.AddInput("X", dims, {1, -2, 0, -10}); + test.AddOutput("Y", dims, {-1, 2, 0, 10}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Int16 not allowed as input to this layer +} + TEST(MathOpTest, Neg_int32) { OpTester test("Neg"); std::vector dims{4};