Add overlooked overload information on torchlib functions #919

Merged
merged 3 commits on Jul 25, 2023
50 changes: 25 additions & 25 deletions onnxscript/function_libs/torch_lib/ops/core.py
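The diff below widens each `torch_op` registration from a single ATen name to a tuple that also lists the overloads the function covers (for example `aten::add.Tensor` alongside `aten::add`), so a lookup by the overloaded name finds the same implementation. As a rough sketch of that registration pattern, using a hypothetical `torch_op_sketch` decorator and `_SKETCH_REGISTRY` dict rather than the real torch_lib internals:

from typing import Callable, Dict, List, Tuple, Union

# Hypothetical registry for illustration only; the real registration lives in
# onnxscript's torch_lib and is not reproduced here.
_SKETCH_REGISTRY: Dict[str, List[Callable]] = {}


def torch_op_sketch(name: Union[str, Tuple[str, ...]], trace_only: bool = False):
    """Register one implementation under one or several ATen overload names."""
    names = (name,) if isinstance(name, str) else name

    def decorator(func: Callable) -> Callable:
        for qualified_name in names:
            # The same function becomes the lookup target for every listed
            # name, e.g. both "aten::add" and "aten::add.Tensor".
            _SKETCH_REGISTRY.setdefault(qualified_name, []).append(func)
        return func

    return decorator


@torch_op_sketch(("aten::add", "aten::add.Tensor"))
def aten_add_sketch(self, other, alpha: float = 1.0):
    ...  # the real implementation builds ONNX ops, e.g. op.Add(self, other)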
@@ -81,7 +81,7 @@ def aten_acosh(self: TFloat) -> TFloat:
return op.Acosh(self)


@torch_op("aten::add")
@torch_op(("aten::add", "aten::add.Tensor"))
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# TODO(microsoft/onnxruntime#15977): Improve fp16 precision
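For reference (not part of the diff): the suffix in each added string matches the overload names PyTorch itself exposes on the corresponding operator packet, which can be listed at runtime:

import torch

# Each ATen operator packet knows its overload names; "Tensor" below is the
# suffix used in the registration string "aten::add.Tensor".
print(torch.ops.aten.add.overloads())     # expected to include 'Tensor' and 'Scalar'
print(torch.ops.aten.add.Tensor._schema)  # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor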
@@ -1235,7 +1235,7 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
return op.SplitToSequence(self, list_split, axis=dim)


@torch_op("aten::clamp", trace_only=True)
@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True)
def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal:
"""clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
clamped = self
@@ -2184,7 +2184,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
raise NotImplementedError()


@torch_op("aten::div")
@torch_op(("aten::div", "aten::div.Tensor"))
def aten_div(self: TFloat, other: TFloat) -> TFloat:
"""div.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -2299,7 +2299,7 @@ def aten_embedding_sparse_backward(
raise NotImplementedError()


@torch_op("aten::empty")
@torch_op(("aten::empty", "aten::empty.memory_format"))
def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var]
# empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

@@ -2353,7 +2353,7 @@ def aten_empty_strided(
return op.Expand(zero, size)


@torch_op("aten::eq")
@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar"))
def aten_eq(self: TTensor, other: TTensor) -> BOOL:
"""eq.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -2563,7 +2563,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
raise NotImplementedError()


@torch_op("aten::fill")
@torch_op(("aten::fill", "aten::fill.Tensor"))
def aten_fill(self: TTensor, value: TTensor) -> TTensor:
"""fill.Tensor(Tensor self, Tensor value) -> Tensor"""

@@ -2748,7 +2748,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::ge")
@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar"))
def aten_ge(self: TReal, other: TReal) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -2905,7 +2905,7 @@ def aten_gru_cell(
raise NotImplementedError()


@torch_op("aten::gt")
@torch_op(("aten::gt", "aten::gt.Scalar"))
def aten_gt(self: TReal, other: TReal) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -3595,7 +3595,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::le")
@torch_op(("aten::le", "aten::le.Tensor"))
def aten_le(self: TReal, other: TReal) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -3884,7 +3884,7 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


@torch_op("aten::lt")
@torch_op(("aten::lt", "aten::lt.Scalar"))
def aten_lt(self: TReal, other: TReal) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -3957,7 +3957,7 @@ def aten_margin_ranking_loss(
raise NotImplementedError()


@torch_op("aten::masked_fill")
@torch_op(("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"))
def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
"""masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"""
# NOTE: Do not attempt to cast `mask` to BOOL because mask should not take any other types.
@@ -4462,15 +4462,15 @@ def aten_msort(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::mul")
@torch_op(("aten::mul", "aten::mul.Tensor"))
def aten_mul(self: TReal, other: TReal) -> TReal:
"""mul.Tensor(Tensor self, Tensor other) -> Tensor"""
# FIXME(titaiwang): get rid of this when we have type_promotion
other = op.CastLike(other, self)
return op.Mul(self, other)


@torch_op("aten::mul")
@torch_op(("aten::mul", "aten::mul.Tensor"))
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

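(Aside, not part of the diff: the Boolean variant above works because, for Boolean operands, elementwise multiplication coincides with logical AND, which is why an ONNX And node can stand in for Mul. A quick illustrative check:)

import numpy as np

a = np.array([True, True, False, False])
b = np.array([True, False, True, False])

# For Boolean inputs, elementwise multiplication equals logical AND.
assert np.array_equal(a * b, np.logical_and(a, b))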
@@ -4883,7 +4883,7 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
raise NotImplementedError()


@torch_op("aten::ne")
@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor"))
def aten_ne(self: TReal, other: TReal) -> BOOL:
"""ne.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -5223,7 +5223,7 @@ def aten_positive(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::pow")
@torch_op(("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar"))
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
"""pow(Tensor self, Tensor exponent) -> Tensor"""

@@ -5756,7 +5756,7 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
return op.Reciprocal(op.Sqrt(self))


@torch_op("aten::rsub")
@torch_op(("aten::rsub", "aten::rsub.Scalar"))
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# FIXME(titaiwang): get rid of this when we have type_promotion
@@ -5785,7 +5785,7 @@ def aten_scatter_add(
return op.ScatterElements(self, index, src, axis=dim, reduction="add")


@torch_op("aten::scatter_reduce", trace_only=True)
@torch_op(("aten::scatter_reduce", "aten::scatter_reduce.two"), trace_only=True)
def aten_scatter_reduce(
self: TReal,
dim: int, # we have to use int here because ScatterElements() will use this attribute
@@ -5855,7 +5855,7 @@ def aten_segment_reduce(
raise NotImplementedError()


@torch_op("aten::select")
@torch_op(("aten::select", "aten::select.int"))
def aten_select(self: TTensor, dim: int, index: int) -> TTensor:
"""select(Tensor self, int dim, int index) -> Tensor"""

@@ -5935,7 +5935,7 @@ def aten_sinh(self: TFloat) -> TFloat:
return op.Sinh(self)


@torch_op("aten::slice", trace_only=True)
@torch_op(("aten::slice", "aten::slice.Tensor"), trace_only=True)
def aten_slice(
self: TTensor,
dim: int = 0,
@@ -6081,7 +6081,7 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::split")
@torch_op(("aten::split", "aten::split.Tensor"))
def aten_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor:
"""split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]"""

@@ -6309,7 +6309,7 @@ def aten_stft(
return result


@torch_op("aten::sub")
@torch_op(("aten::sub", "aten::sub.Tensor"))
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
alpha = op.CastLike(alpha, other)
@@ -6324,7 +6324,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1.0) -> Te
raise NotImplementedError()


@torch_op("aten::sum", trace_only=True)
@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
def aten_sum_dim_IntList(
self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
) -> TReal:
@@ -6634,7 +6634,7 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
raise NotImplementedError()


@torch_op("aten::transpose", trace_only=True)
@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True)
def aten_transpose(self, dim0: int, dim1: int):
"""transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"""

@@ -6729,7 +6729,7 @@ def aten_type_as(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::unbind")
@torch_op(("aten::unbind", "aten::unbind.int"))
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""

@@ -7082,7 +7082,7 @@ def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
return op.ConcatFromSequence(tensors, axis=0)


@torch_op("aten::where")
@torch_op(("aten::where", "aten::where.self"))
def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
"""where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""
