diff --git a/src/arraymancer/tensor/shapeshifting.nim b/src/arraymancer/tensor/shapeshifting.nim index 8d6bbc07..0b0c49c7 100644 --- a/src/arraymancer/tensor/shapeshifting.nim +++ b/src/arraymancer/tensor/shapeshifting.nim @@ -687,3 +687,67 @@ proc repeat_values*[T](t: Tensor[T], reps: Tensor[int]): Tensor[T] {.noinit, inl ## version. ## ``` t.repeat_values(reps.toSeq1D) + +proc tile*[T](t: Tensor[T], reps: varargs[int]): Tensor[T] = + ## Construct a new tensor by repeating the input tensor a number of times on one or more axes + ## + ## Inputs: + ## - t: The tensor to repeat + ## - reps: One or more integers indicating the number of times to repeat + ## the tensor on each axis (starting with axis 0) + ## + ## Result: + ## - A new tensor whose shape is `t.shape *. reps` + ## + ## Notes: + ## - If a rep value is 1, the tensor is not repeated on that particular axis + ## - If there are more rep values than the input tensor has axes, additional + ## dimensions are prepended to the input tensor as needed. Note that this + ## is similar to numpy's `tile` function behavior, but different to + ## Matlab's `repmat` behavior, which appends missing dimensions instead + ## of prepending them. + ## - This function behavior is similar to nims `sequtils.repeat`, in that + ## it repeats the full tensor multiple times. If what you want is to + ## repeat the _elements_ of the tensor multiple times, rather than the + ## full tensor, use the `repeat_values` procedure instead. + ## + ## Examples: + ## ```nim + ## let x = arange(4).reshape(2, 2) + ## + ## # When the number of reps and tensor dimensions match, the ouptut tensor + ## # shape is the `reps *. t.shape` + ## echo tile(x, 2, 3) + ## > Tensor[system.int] of shape "[4, 6]" on backend "Cpu" + ## > |0 1 0 1 0 1| + ## > |2 3 2 3 2 3| + ## > |0 1 0 1 0 1| + ## > |2 3 2 3 2 3| + ## + ## # If there are fewer reps than tensor dimensions, start + ## # repeating on the first axis (leaving alone axis with missing reps) + ## echo tile(x, 2) + ## > Tensor[system.int] of shape "[4, 2]" on backend "Cpu" + ## > |0 1| + ## > |2 3| + ## > |0 1| + ## > |2 3| + ## + ## # If there are more reps than tensor dimensions, prepend the missing + ## # dimensions before repeating + ## echo tile(x, 1, 2, 3) + ## > Tensor[system.int] of shape "[1, 4, 6]" on backend "Cpu" + ## > 0 + ## > |0 1 0 1 0 1| + ## > |2 3 2 3 2 3| + ## > |0 1 0 1 0 1| + ## > |2 3 2 3 2 3| + ## ``` + result = t + for ax in countdown(reps.high, 0): + var concat_seq = repeat(result, reps[ax]) + if ax >= result.shape.len: + # mutate the repeated tensors to have one more axis + concat_seq.applyIt(unsqueeze(it, 0)) + result = concat(concat_seq, axis=ax) + diff --git a/tests/tensor/test_shapeshifting.nim b/tests/tensor/test_shapeshifting.nim index f174f3ec..a7dfadf1 100644 --- a/tests/tensor/test_shapeshifting.nim +++ b/tests/tensor/test_shapeshifting.nim @@ -344,5 +344,58 @@ proc main() = check: a.repeat_values([1, 0, 3, 2]) == expected check: a.repeat_values([1, 0, 3, 2].toTensor) == expected + test "Tile": + let t = arange(6).reshape(2, 3) + + block: # Tile over the first axis + let expected = [ + [0, 1, 2], + [3, 4, 5], + [0, 1, 2], + [3, 4, 5], + ].toTensor + check: t.tile(2) == expected + + block: # Tile over the all the axis of the input tensor + let expected = [ + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5] + ].toTensor + check: t.tile(2, 3) == expected + + block: # Tile over the more axis than the input tensor has + let expected = [ + [ + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5] + ], + [ + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5] + ], + [ + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5] + ], + [ + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5] + ] + ].toTensor + check: t.tile(4, 2, 3) == expected + + block: # tiling and repeating values are sometimes equivalent + check: t.tile(2, 1, 1) == t.unsqueeze(axis=0).repeat_values(2, axis = 0) + main() GC_fullCollect()