Skip to content

Commit

Permalink
Add reshape_infer procedure (#646)
Browse files Browse the repository at this point in the history
Unlike numpy, `reshape` does not support having dimensions with value -1 to infer their value. To do so a new `reshape_infer` is added.

This is added as a separate procedure to avoid the (small) cost this adds on top of the usual reshape (which could be called relatively frequently).
  • Loading branch information
AngelEzquerra authored Apr 15, 2024
1 parent 9867253 commit d21362a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
31 changes: 28 additions & 3 deletions src/arraymancer/tensor/private/p_shapeshifting.nim
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import ../../laser/tensor/initialization,
./p_checks,
nimblas

import std / sequtils

proc contiguousImpl*[T](t: Tensor[T], layout: OrderType, result: var Tensor[T]) =
if layout == rowMajor:
result = t.map_inline(x)
Expand All @@ -28,16 +30,39 @@ proc contiguousImpl*[T](t: Tensor[T], layout: OrderType, result: var Tensor[T])
apply2_inline(result, t):
y

proc reshape_with_copy*[T](t: Tensor[T], new_shape: varargs[int]|Metadata, result: var Tensor[T]) =
proc reshape_with_copy*[T](t: Tensor[T], new_shape: varargs[int]|Metadata|seq[int], result: var Tensor[T]) =
result = newTensorUninit[T](new_shape)
result.apply2_inline(t,y)

proc reshape_no_copy*(t: AnyTensor, new_shape: varargs[int]|Metadata, result: var AnyTensor, layout: OrderType) {.noSideEffect.}=
proc reshape_no_copy*(t: AnyTensor, new_shape: varargs[int]|Metadata|seq[int], result: var AnyTensor, layout: OrderType) {.noSideEffect.}=
result.shape.copyFrom(new_shape)
shape_to_strides(result.shape, layout, result.strides)
result.offset = t.offset

proc reshapeImpl*(t: AnyTensor, new_shape: varargs[int]|Metadata, result: var AnyTensor) =
proc infer_shape*(t: Tensor, new_shape: varargs[int]): seq[int] {.noinit.} =
## Replace the single -1 value on `new_shape` with the value that
## makes the size the same as that of the input tensor
result = new_shape.toSeq
var auto_axis = -1
var auto_axis_count = 0
for n in 0 .. result.high:
if result[n] == -1:
auto_axis_count += 1
auto_axis = n
break
if auto_axis_count > 1:
raise newException(ValueError, "Only one dimension can be inferred by inferShape")
elif auto_axis_count == 0:
when compileOption("boundChecks"):
raise newException(ValueError, "At least one dimension must be inferred by inferShape")
else:
result[auto_axis] = t.size div result.filterIt(it != -1).prod

proc reshapeImpl*(t: AnyTensor, new_shape: varargs[int]|Metadata|seq[int],
result: var AnyTensor, infer: static bool) =
when infer:
let new_shape = t.infer_shape(new_shape)

when compileOption("boundChecks"):
check_reshape(t, new_shape)

Expand Down
22 changes: 19 additions & 3 deletions src/arraymancer/tensor/shapeshifting.nim
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,12 @@ proc reshape*(t: Tensor, new_shape: varargs[int]): Tensor {.noinit.} =
##
## Input:
## - a tensor
## - a new shape. Number of elements must be the same
## - a new shape. Number of elements must be the same. Unlike numpy,
## dimensions cannot be -1 to infer their value. If that is what you need
## you must use the alternative `reshape_infer` proc.
## Returns:
## - a tensor with the same data but reshaped.
reshapeImpl(t, new_shape, result)
reshapeImpl(t, new_shape, result, infer = false)

proc reshape*(t: Tensor, new_shape: Metadata): Tensor {.noinit.} =
## Reshape a tensor. If possible no data copy is done and the returned tensor
Expand All @@ -78,7 +80,21 @@ proc reshape*(t: Tensor, new_shape: Metadata): Tensor {.noinit.} =
## - a new shape. Number of elements must be the same
## Returns:
## - a tensor with the same data but reshaped.
reshapeImpl(t, new_shape, result)
reshapeImpl(t, new_shape, result, infer = false)

proc reshape_infer*(t: Tensor, new_shape: varargs[int]):
Tensor {.noinit.} =
## Reshape a tensor. If possible no data copy is done and the returned tensor
## shares data with the input. If input is not contiguous, this is not possible
## and a copy will be made.
##
## Input:
## - a tensor
## - a new shape. Number of elements must be the same. The new shape can
## contain -1 to infer the size of one (and only one) dimension
## Returns:
## - a tensor with the same data but reshaped.
reshapeImpl(t, new_shape, result, infer = true)

proc flatten*(t: Tensor): Tensor {.noinit,inline.} =
## Flatten a tensor, returning a rank-1 tensor with the same data as the input.
Expand Down
6 changes: 5 additions & 1 deletion tests/tensor/test_shapeshifting.nim
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,13 @@ proc main() =
check: a == b

test "Reshape":
let a = toSeq(1..4).toTensor().reshape(2,2)
let a = toSeq(1..4).toTensor().reshape(2, 2)
let b = toSeq(1..4).toTensor().reshape_infer(-1, 2)
let c = toSeq(1..4).toTensor().reshape_infer(2, -1)
check: a == [[1,2],
[3,4]].toTensor()
check: a == b
check: a == c

test "Unsafe reshape":
block:
Expand Down

0 comments on commit d21362a

Please sign in to comment.