[functorch] Add l1 loss forward decompositions
Chillee authored and bigfootjon committed Jul 21, 2022
1 parent da3eb08 commit 151b1f4
Showing 1 changed file with 7 additions and 7 deletions.
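
The l1 loss hunk itself is not visible in this excerpt, so as a rough illustration only (an assumption, not this commit's code), an L1 loss forward decomposition typically rewrites the op in terms of primitive tensor ops, along these lines:

import torch
from torch import Tensor

# Illustrative sketch only -- an assumption about what an L1 loss forward
# decomposition usually looks like, not the code from this commit.
# The name l1_loss_sketch and the string-valued reduction argument are
# hypothetical; the real decomposition may use an integer Reduction enum.
def l1_loss_sketch(self: Tensor, target: Tensor, reduction: str = "mean") -> Tensor:
    loss = (self - target).abs()   # elementwise |input - target|
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss                    # reduction == "none"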
functorch/functorch/_src/decompositions.py (14 changes: 7 additions & 7 deletions)
@@ -53,17 +53,17 @@ class Reduction(Enum):


 @register_decomposition(aten.tanh_backward)
-def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
+def tanh_backward(out_grad: Tensor, y: Tensor):
     return out_grad * (1 - y * y)


 @register_decomposition(aten.sigmoid_backward)
-def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
+def sigmoid_backward(out_grad: Tensor, y: Tensor):
     return out_grad * (y * (1 - y))


 @register_decomposition(aten.softplus_backward)
-def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
+def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
     z = (x * beta).exp()
     return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))

@@ -103,7 +103,7 @@ def hardsigmoid(self: Tensor) -> Tensor:


 @register_decomposition(aten.hardsigmoid_backward)
-def hardsigmoid_backward_decomposition(grad_output: Tensor, self: Tensor):
+def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
     return torch.where((self > -3.0) & (self < 3.0), grad_output * (1.0 / 6.0), grad_output.new_zeros(()))


@@ -113,7 +113,7 @@ def hardtanh(self: Tensor, min_val: float = -1, max_val: float = 1) -> Tensor:


 @register_decomposition(aten.hardtanh_backward)
-def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
+def hardtanh_backward(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
     return torch.where((self <= min_val) | (self >= max_val), grad_output.new_zeros(()), grad_output)


@@ -133,7 +133,7 @@ def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:


 @register_decomposition(aten.threshold_backward)
-def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
+def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
     return torch.where(self <= threshold, grad_output.new_zeros((1,)), grad_output)


@@ -179,7 +179,7 @@ def gelu_backward(grad: Tensor, self: Tensor, approximate: str = 'none'):


 @register_decomposition(aten.mish_backward)
-def mish_backward_decomposition(grad_output: Tensor, input: Tensor):
+def mish_backward(grad_output: Tensor, input: Tensor):
     input_tanh_softplus = torch.tanh(F.softplus(input))
     input_sigmoid = torch.sigmoid(input)
     out = (input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus))
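
As a quick sanity check (not part of the commit), the renamed backward decompositions can be compared against autograd. A minimal sketch for the tanh_backward formula shown above:

import torch

# Sanity-check sketch: the decomposition's formula out_grad * (1 - y * y)
# should agree with autograd's gradient of tanh.
x = torch.randn(8, requires_grad=True)
y = torch.tanh(x)
out_grad = torch.randn_like(y)

decomposed = out_grad * (1 - y * y)                 # formula from the diff
(reference,) = torch.autograd.grad(y, x, out_grad)  # autograd reference
assert torch.allclose(decomposed, reference)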
