diff --git a/src/arraymancer/tensor/math_functions.nim b/src/arraymancer/tensor/math_functions.nim index 626c4bf3..839c4a1c 100644 --- a/src/arraymancer/tensor/math_functions.nim +++ b/src/arraymancer/tensor/math_functions.nim @@ -274,6 +274,40 @@ proc classify*[T: SomeFloat](t: Tensor[T]): Tensor[FloatClass] {.noinit.} = ## - fcNegInf: value is negative infinity t.map_inline(classify(x)) +proc almostEqual*[T: SomeFloat | Complex32 | Complex64](t1, t2: Tensor[T], + unitsInLastPlace: Natural = 4): Tensor[bool] {.noinit.} = + ## Element-wise almostEqual function + ## + ## Checks whether pairs of elements of two tensors are almost equal, using + ## the [machine epsilon](https://en.wikipedia.org/wiki/Machine_epsilon). + ## + ## For more details check the section covering the `almostEqual` procedure in + ## nim's standard library documentation. + ## + ## Inputs: + ## - t1, t2: Input (floating point or complex) tensors of the same shape. + ## - unitsInLastPlace: The max number of + ## [units in the last place](https://en.wikipedia.org/wiki/Unit_in_the_last_place) + ## difference tolerated when comparing two numbers. The + ## larger the value, the more error is allowed. A `0` + ## value means that two numbers must be exactly the + ## same to be considered equal. + ## + ## Result: + ## - A new boolean tensor of the same shape as the inputs, in which elements + ## are true if the two values in the same position on the two input tensors + ## are almost equal (and false if they are not). + ## + ## Note: + ## - You can combine this function with `all` to check if two real tensors + ## are almost equal. + map2_inline(t1, t2): + when T is Complex: + almostEqual(x.re, y.re, unitsInLastPlace=unitsInLastPlace) and + almostEqual(x.im, y.im, unitsInLastPlace=unitsInLastPlace) + else: + almostEqual(x, y, unitsInLastPlace=unitsInLastPlace) + type ConvolveMode* = enum full, same, valid proc convolveImpl[T: SomeNumber | Complex32 | Complex64]( diff --git a/tests/tensor/test_math_functions.nim b/tests/tensor/test_math_functions.nim index bafa5ca3..122c6b5e 100644 --- a/tests/tensor/test_math_functions.nim +++ b/tests/tensor/test_math_functions.nim @@ -163,6 +163,24 @@ proc main() = check: expected_isNaN == a.isNaN check: expected_classification == a.classify + test "almostEqual": + block: # Real + let t1 = arange(1.0, 5.0) + let t2 = t1.clone() + check: all(almostEqual(t1, t2)) == true + var t3 = t1.clone() + t3[0] += 2e-15 + check: almostEqual(t1, t3) == [false, true, true, true].toTensor() + check: all(almostEqual(t1, t3, unitsInLastPlace = 5)) == true + block: # Complex + let t1 = complex(arange(1.0, 5.0), arange(1.0, 5.0)) + let t2 = t1.clone() + check: all(almostEqual(t1, t2)) == true + var t3 = t1.clone() + t3[0] += complex(2e-15) + check: almostEqual(t1, t3) == [false, true, true, true].toTensor() + check: all(almostEqual(t1, t3, unitsInLastPlace = 5)) == true + test "1-D convolution": block: let a = arange(4)