Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

'tt.expand_dims' op inferred type(s) are incompatible with return type(s) of operation #3107

Open
Stonepia opened this issue Jan 7, 2025 · 1 comment
Labels
bug Something isn't working upstream: pytorch

Comments

@Stonepia
Copy link
Contributor

Stonepia commented Jan 7, 2025

Describe the bug

We hit this bug in a total of 15 tests:

/tmp/torchinductor_pt-gpu/pj/cpjyiy3hzzp5fva5tt2hznb57kep2jnaxubw4yzgwc7hb7cdcjae.py:28:38: error: 'tt.expand_dims' op inferred type(s) 'tensor<1x8xi32, #ttg.linear<{register = [], lane = [[0, 1], [0, 2], [0, 4], [0, 0], [0, 0]], warp = [[0, 0]], block = []}>>' are incompatible with return type(s) of operation 'tensor<1x8xi32, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}>>'
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
                                     ^
/tmp/torchinductor_pt-gpu/pj/cpjyiy3hzzp5fva5tt2hznb57kep2jnaxubw4yzgwc7hb7cdcjae.py:28:38: error: 'tt.expand_dims' op failed to infer returned types
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
                                     ^
/tmp/torchinductor_pt-gpu/pj/cpjyiy3hzzp5fva5tt2hznb57kep2jnaxubw4yzgwc7hb7cdcjae.py:28:38: note: see current operation: %78 = "tt.expand_dims"(%49) <{axis = 0 : i32}> : (tensor<8xi32, #ttg.slice<{dim = 0, parent = #ttg.linear<{register = [], lane = [[0, 1], [0, 2], [0, 4], [0, 0], [0, 0]], warp = [[0, 0]], block = []}>}>>) -> tensor<1x8xi32, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}>>

The reproducer is:

import torch
from torch._inductor.async_compile import AsyncCompile
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
    grid,
)
from torch._C import _xpu_getCurrentRawStream as get_raw_stream
from torch._C import _xpu_getCurrentRawStream as get_raw_stream

# Module-level aliases that appear to be the standard preamble of an
# Inductor-generated wrapper module. In this reproducer only
# `assert_size_stride`, `empty_strided_xpu` (both used in `call`) and
# `async_compile` (used to compile and wait for the kernel) are referenced;
# the remaining aliases are unused here and presumably kept so the file stays
# identical to the generated output.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler front-end used below to submit the Triton kernel source.
async_compile = AsyncCompile()
# NOTE(review): this attribute chain goes through torch._C._distributed_c10d;
# presumably it requires a build with distributed support — confirm if the
# reproducer fails with AttributeError here instead of the Triton error.
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /tmp/tmp0r9psgjh/pj/cpjyiy3hzzp5fva5tt2hznb57kep2jnaxubw4yzgwc7hb7cdcjae.py
# Topologically Sorted Source Nodes: [argsort], Original ATen: [aten.sort]
# Source node to ATen node mapping:
#   argsort => sort
# Graph fragment:
#   %sort : [num_users=1] = call_function[target=torch.ops.aten.sort.default](args = (%arg0_1, -3, True), kwargs = {})
#
# NOTE: compiling this kernel is what fails in this issue. The reported
# 'tt.expand_dims' error is raised for the generated file with the same name
# (cpjyiy3hzzp5...py) and points at the line
#     r0_index = tl.arange(0, R0_BLOCK)[None, :]
# inside the kernel source string below. The string is kept verbatim so the
# reproducer stays faithful; do not edit it.
#
# What the kernel source does (read from the string itself):
#   * each program instance covers XBLOCK consecutive x indices out of
#     xnumel = 25 (masked by xmask);
#   * for each x index it loads r0_numel = 5 float32 values from
#     in_ptr0 + x0 + 25 * r0, padded to R0_BLOCK = 8 and masked by r0_mask;
#   * the r0 positions are cast to int16 and passed with the values to
#     triton_helpers.sort_with_index(..., descending=True); the second result
#     (presumably the permuted int16 positions, i.e. the argsort output) is
#     stored to out_ptr0 at the same offsets.
triton_per_fused_sort_0 = async_compile.triton('triton_per_fused_sort_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 8},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=64, cc={'architecture': 13136561920, 'driver_version': '1.3.30049+10', 'gpu_eu_count': 512, 'gpu_subslice_count': 64, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1550', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 68719476736, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': [(0,), (1,)], 'tt.equal_to': []}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_sort_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'FB40EBCFCB9A06A14744E64088225D5F51204F154F792FD0F599F4DD0B53BAD5', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_sort_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 25
    r0_numel = 5
    R0_BLOCK: tl.constexpr = 8
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 25*r0_1), r0_mask & xmask, other=0.0)
    tmp1 = r0_1
    tmp2 = tmp1.to(tl.int16)
    tl.static_assert(tmp2.dtype == tl.int16)
    tmp3 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp4 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
    tmp5, tmp6, = triton_helpers.sort_with_index(tmp3, tmp4, rnumel, 1, stable=False, descending=True)
    tl.store(out_ptr0 + (x0 + 25*r0_1), tmp6, r0_mask & xmask)
''', device_str='xpu')



# Wait for the kernel(s) submitted through async_compile above. Passing
# globals() presumably lets the compiler replace the pending
# `triton_per_fused_sort_0` handle with the compiled kernel — confirm against
# AsyncCompile.wait. Compilation errors such as the 'tt.expand_dims' failure in
# this issue are presumably surfaced here at the latest.
async_compile.wait(globals())
# The compiler front-end is no longer needed once all kernels are ready.
del async_compile

def call(args):
    """Run the compiled argsort graph on a single input tensor.

    ``args`` must be a one-element list holding the input tensor, which is
    checked to have size (5, 5, 5) and strides (25, 5, 1). The list is
    emptied in place, and the local reference is dropped with ``del`` right
    after the kernel launch (presumably so the input can be freed early).

    Returns a one-element tuple containing the freshly allocated int16 XPU
    buffer of size (5, 5, 5) / strides (25, 5, 1) written by the kernel.
    """
    (src,) = args
    args.clear()
    assert_size_stride(src, (5, 5, 5), (25, 5, 1))
    device_index = 0
    with torch.xpu._DeviceGuard(device_index):
        torch.xpu.set_device(device_index)
        out_indices = empty_strided_xpu((5, 5, 5), (25, 5, 1), torch.int16)
        # Topologically Sorted Source Nodes: [argsort], Original ATen: [aten.sort]
        launch_stream = get_raw_stream(device_index)
        # Positional sizes match the kernel signature: xnumel=25, r0_numel=5.
        triton_per_fused_sort_0.run(
            src,
            out_indices,
            25,
            5,
            grid=grid(25),
            stream=launch_stream,
        )
        del src
    return (out_indices,)



def benchmark_compiled_module(times=10, repeat=10):
    """Time ``call`` on a randomly generated input.

    Creates a float32 tensor of size (5, 5, 5) with strides (25, 5, 1) on
    ``xpu:0`` and passes a zero-argument callable that runs ``call`` on it to
    Inductor's ``print_performance``, forwarding ``times`` and ``repeat``.
    Returns whatever ``print_performance`` returns.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    sample = rand_strided((5, 5, 5), (25, 5, 1), device='xpu:0', dtype=torch.float32)

    def fn():
        # `call` clears the list it is given, so build a fresh list per run.
        return call([sample])

    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    # Script entry point: hand `benchmark_compiled_module` to Inductor's
    # generic compiled-module driver. The kernel has already been compiled at
    # module level (async_compile.triton / async_compile.wait) by the time this
    # runs. NOTE(review): the literal 'None' first argument is what the
    # generated wrapper emitted (presumably a benchmark-name placeholder) —
    # confirm against compiled_module_main before changing it.
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Environment details

Triton: latest main (b59bb9a)
PyTorch: latest main (d0f5df83a50d9bb630764c92ac63fcb2640b1f94) + triton patch
OS: Ubuntu 24.10 / Windows 11

@Stonepia Stonepia added bug Something isn't working upstream: pytorch labels Jan 7, 2025
@riverliuintel
Copy link

@vlad-penkin: new bugs for upstream PyTorch 2.7.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working upstream: pytorch
Projects
None yet
Development

No branches or pull requests

2 participants