Skip to content

Commit

Permalink
[CPU EP] Implement Neg unary operation for int16
Browse files Browse the repository at this point in the history
  • Loading branch information
Zyrin committed Dec 18, 2024
1 parent 21e1486 commit 9f52833
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, float, Neg);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, double, Neg);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, int8_t, Neg);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, int16_t, Neg);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, int32_t, Neg);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12, int64_t, Neg);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 11, Pow);
Expand Down Expand Up @@ -700,6 +701,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int16_t, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Neg);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Neg);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Mod);
Expand Down Expand Up @@ -1323,6 +1325,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
double, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12,
int8_t, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12,
int16_t, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12,
int32_t, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 12,
Expand Down Expand Up @@ -2246,6 +2250,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int16_t, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Neg)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Mod)>,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/math/element_wise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,13 @@ REG_ELEMENTWISE_TYPED_KERNEL(Abs, 13, uint64_t, Abs);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, float, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, double, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int8_t, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int16_t, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int32_t, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int64_t, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, float, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, double, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int8_t, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int16_t, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int32_t, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int64_t, Neg);

Expand Down
22 changes: 21 additions & 1 deletion onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ TEST(MathOpTest, Abs_int32) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT parser: Int32 not allowed as input to this layer
}

TEST(MathOpTest, Neg) {
TEST(MathOpTest, Neg_float) {
OpTester test("Neg");
std::vector<int64_t> dims{2, 2};
test.AddInput<float>("X", dims,
Expand All @@ -961,6 +961,18 @@ TEST(MathOpTest, Neg) {
test.Run();
}

TEST(MathOpTest, Neg_double) {
OpTester test("Neg");
std::vector<int64_t> dims{2, 2};
test.AddInput<double>("X", dims,
{1.0, -2.0,
0.0, -10.0});
test.AddOutput<double>("Y", dims,
{-1.0, 2.0,
-0.0, 10.0});
test.Run();
}

TEST(MathOpTest, Neg_int8) {
OpTester test("Neg");
std::vector<int64_t> dims{4};
Expand All @@ -971,6 +983,14 @@ TEST(MathOpTest, Neg_int8) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: INT8 is not supported
}

TEST(MathOpTest, Neg_int16) {
OpTester test("Neg");
std::vector<int64_t> dims{4};
test.AddInput<int8_t>("X", dims, {1, -2, 0, -10});
test.AddOutput<int8_t>("Y", dims, {-1, 2, 0, 10});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Int16 not allowed as input to this layer
}

TEST(MathOpTest, Neg_int32) {
OpTester test("Neg");
std::vector<int64_t> dims{4};
Expand Down

0 comments on commit 9f52833

Please sign in to comment.