diff --git a/functorch/functorch/_src/decompositions.py b/functorch/functorch/_src/decompositions.py
index f619a7f5416b4d..5c6bd538b7fef3 100644
--- a/functorch/functorch/_src/decompositions.py
+++ b/functorch/functorch/_src/decompositions.py
@@ -53,17 +53,17 @@ class Reduction(Enum):
 
 
 @register_decomposition(aten.tanh_backward)
-def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
+def tanh_backward(out_grad: Tensor, y: Tensor):
     return out_grad * (1 - y * y)
 
 
 @register_decomposition(aten.sigmoid_backward)
-def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
+def sigmoid_backward(out_grad: Tensor, y: Tensor):
     return out_grad * (y * (1 - y))
 
 
 @register_decomposition(aten.softplus_backward)
-def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
+def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
     z = (x * beta).exp()
     return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
@@ -103,7 +103,7 @@ def hardsigmoid(self: Tensor) -> Tensor:
 
 
 @register_decomposition(aten.hardsigmoid_backward)
-def hardsigmoid_backward_decomposition(grad_output: Tensor, self: Tensor):
+def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
     return torch.where((self > -3.0) & (self < 3.0), grad_output * (1.0 / 6.0), grad_output.new_zeros(()))
 
 
@@ -113,7 +113,7 @@ def hardtanh(self: Tensor, min_val: float = -1, max_val: float = 1) -> Tensor:
 
 
 @register_decomposition(aten.hardtanh_backward)
-def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
+def hardtanh_backward(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
     return torch.where((self <= min_val) | (self >= max_val), grad_output.new_zeros(()), grad_output)
 
 
@@ -133,7 +133,7 @@ def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
 
 
 @register_decomposition(aten.threshold_backward)
-def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
+def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
     return torch.where(self <= threshold, grad_output.new_zeros((1,)), grad_output)
 
 
@@ -179,7 +179,7 @@ def gelu_backward(grad: Tensor, self: Tensor, approximate: str = 'none'):
 
 
 @register_decomposition(aten.mish_backward)
-def mish_backward_decomposition(grad_output: Tensor, input: Tensor):
+def mish_backward(grad_output: Tensor, input: Tensor):
     input_tanh_softplus = torch.tanh(F.softplus(input))
     input_sigmoid = torch.sigmoid(input)
     out = (input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus))
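
Note: the hunks above only drop the redundant `_decomposition` suffix from each registered function name; the function bodies are unchanged. As a minimal sanity check (not part of the diff, names chosen for illustration), the `tanh_backward` formula can be compared against the gradient PyTorch autograd computes:

import torch


def tanh_backward(out_grad, y):
    # Same formula the decomposition above registers for aten.tanh_backward.
    return out_grad * (1 - y * y)


x = torch.randn(8, dtype=torch.double, requires_grad=True)
y = torch.tanh(x)
out_grad = torch.randn_like(y)

# Reference gradient w.r.t. x from autograd, seeded with out_grad.
(reference,) = torch.autograd.grad(y, x, grad_outputs=out_grad)

# The decomposition takes the incoming gradient and y = tanh(x).
decomposed = tanh_backward(out_grad, y)

assert torch.allclose(reference, decomposed)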