[intel] update driver.py and pytorch patches
Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev committed Jan 8, 2025
1 parent 96bda72 commit 0021f48
Showing 5 changed files with 95 additions and 31 deletions.
1 change: 0 additions & 1 deletion benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -68,7 +68,6 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan
fn()
end_event.record()
synchronize()
# wait_on_sycl_queue?
estimate_ms = start_event.elapsed_time(end_event) / 5

# The cache is also maintained in `triton_do_bench` function,
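The change here removes the leftover "# wait_on_sycl_queue?" note; the timing logic itself is the standard event-based estimate. A minimal sketch of that pattern, assuming a PyTorch build with the torch.xpu backend whose Event/synchronize API mirrors torch.cuda (the function name is illustrative, not the benchmark module's):

import torch

def estimate_elapsed_ms(fn, inner_iters=5):
    # Event-based timing: record start, launch the kernel a few times,
    # record end, then flush the SYCL queue once before reading the timer.
    start_event = torch.xpu.Event(enable_timing=True)
    end_event = torch.xpu.Event(enable_timing=True)
    start_event.record()
    for _ in range(inner_iters):
        fn()
    end_event.record()
    torch.xpu.synchronize()  # single queue flush instead of per-call waits
    return start_event.elapsed_time(end_event) / inner_iters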
9 changes: 1 addition & 8 deletions python/tutorials/02-fused-softmax.py
@@ -175,14 +175,7 @@ def allocated_slm_size(size_smem):
num_programs = n_rows

# Create a number of persistent programs.
kernel[(num_programs, 1, 1)](
y,
x,
x.stride(0),
y.stride(0),
n_rows,
n_cols,
)
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE)
return y


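For reference, the launch now passes BLOCK_SIZE explicitly as the last positional argument. A self-contained sketch of the persistent-softmax shape this tutorial uses, assuming the usual kernel structure (the real tutorial kernel additionally tunes num_warps and pipelining; this only shows where BLOCK_SIZE enters the call):

import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride,
                   n_rows, n_cols, BLOCK_SIZE: tl.constexpr):
    # Persistent scheme: each program walks over rows pid, pid + grid, ...
    pid = tl.program_id(0)
    for row in range(pid, n_rows, tl.num_programs(0)):
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        x = tl.load(input_ptr + row * input_row_stride + offsets,
                    mask=mask, other=-float("inf"))
        x = x - tl.max(x, axis=0)
        numerator = tl.exp(x)
        y = numerator / tl.sum(numerator, axis=0)
        tl.store(output_ptr + row * output_row_stride + offsets, y, mask=mask)

def softmax(x, BLOCK_SIZE=1024):
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    num_programs = n_rows
    # BLOCK_SIZE is a tl.constexpr, so it is passed explicitly at launch time.
    softmax_kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0),
                                         n_rows, n_cols, BLOCK_SIZE)
    return y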
2 changes: 1 addition & 1 deletion scripts/patch-pytorch.sh
@@ -17,5 +17,5 @@ echo "Applying PyTorch patches in $REPO_ROOT"
cd "$REPO_ROOT"

curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -

git apply "${SCRIPT_DIR}/pytorch.patch"

79 changes: 72 additions & 7 deletions scripts/pytorch.patch
@@ -110,6 +110,24 @@ index 00031a56b8d..59086d41b40 100644
self.triton_meta = triton_meta

for tree in self.range_trees:
diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
index 8b8c29bbb15..c89a76e9868 100644
--- a/torch/_inductor/codegen/triton_utils.py
+++ b/torch/_inductor/codegen/triton_utils.py
@@ -165,12 +165,4 @@ def config_of(
else:
divisible_by_16 = ()

- equal_to_1 = tuple(
- i
- for i, arg in zip(indices, args)
- if isinstance(arg, SizeArg)
- and isinstance(arg.expr, (int, sympy.Integer))
- and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type]
- )
-
- return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
+ return AttrsDescriptorWrapper(divisible_by_16)
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 2ab2b326354..42d76b8bf94 100644
--- a/torch/_inductor/codegen/wrapper.py
@@ -123,22 +141,39 @@ index 2ab2b326354..42d76b8bf94 100644
"configs": [
config_of(
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index 276c01f3f42..5c633b7963b 100644
index fa2a1334380..4d730fd45de 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -48,8 +48,8 @@ if _is_triton_available():
@@ -44,25 +44,14 @@ def _is_triton_available() -> bool:
# Define `AttrsDescriptorWrapper` function with clear conditional handling
if _is_triton_available():
try:
- from triton.backends.compiler import AttrsDescriptor

def AttrsDescriptorWrapper(
divisible_by_16=None,
- equal_to_1=None,
):
# Prepare the arguments for AttrsDescriptor
- # Prepare the arguments for AttrsDescriptor
kwargs = {
- "tt.divisibility": divisible_by_16,
- "tt.equal_to": equal_to_1,
+ "tt.divisibility": tuple([(i,) for i in divisible_by_16]),
+ "tt.equal_to": tuple([(i,) for i in equal_to_1]),
+ tuple([(i,) for i in divisible_by_16]): [["tt.divisibility", 16]],
}
-
- # Instantiate AttrsDescriptor with the prepared arguments
- res = AttrsDescriptor.from_dict(
- {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
- )
- assert res.property_values["tt.divisibility"] == 16
- assert res.property_values["tt.equal_to"] == 1
- return res
+ return kwargs

# Instantiate AttrsDescriptor with the prepared arguments
except ImportError:
from triton.compiler.compiler import AttrsDescriptor
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 281d0e78ba4..901263df4aa 100644
index 281d0e78ba4..3b059a365c9 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -414,10 +414,21 @@ class CachingAutotuner(KernelInterface):
@@ -164,3 +199,33 @@ index 281d0e78ba4..901263df4aa 100644
compile_meta["constants"],
compile_meta["configs"][0],
),
@@ -502,13 +513,11 @@ class CachingAutotuner(KernelInterface):
call_args = [
arg
for i, arg in enumerate(self.fn.arg_names)
- if i not in self.fn.constexprs and arg not in none_args
]

def_args = [
name
for name in self.fn.arg_names
- if name not in cfg.kwargs and name not in none_args
]
binary_shared = (
binary.shared if hasattr(binary, "shared") else binary.metadata.shared
@@ -952,6 +961,7 @@ class CachingAutotuner(KernelInterface):
):
return launcher(
*args,
+ **launcher.config.kwargs,
**kwargs,
grid=grid,
stream=stream,
@@ -959,6 +969,7 @@ class CachingAutotuner(KernelInterface):
else:
return launcher(
*args,
+ **launcher.config.kwargs,
**kwargs,
grid=grid,
stream=stream,
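The heart of the patch above is the new hint format: instead of instantiating Triton's AttrsDescriptor with tt.divisibility/tt.equal_to properties, the wrapper now returns a plain dict of index paths, and the autotuner forwards launcher.config.kwargs at launch because constexpr kwargs are no longer baked into the compiled call. A standalone sketch of the wrapper as patched, mirroring the hunk above (whether this exact key layout is accepted depends on the Triton version the patch targets):

def AttrsDescriptorWrapper(divisible_by_16=None):
    # Specialization hints as a plain mapping: the key collects the (index,)
    # paths of 16-divisible arguments, the value names the property.
    divisible_by_16 = divisible_by_16 or ()
    kwargs = {
        tuple((i,) for i in divisible_by_16): [["tt.divisibility", 16]],
    }
    return kwargs

# e.g. arguments 0 and 2 known to be 16-divisible pointers:
#   AttrsDescriptorWrapper((0, 2)) -> {((0,), (2,)): [["tt.divisibility", 16]]}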
35 changes: 21 additions & 14 deletions third_party/intel/backend/driver.py
@@ -194,7 +194,7 @@ def wait(self):


def ty_to_cpp(ty):
if ty[0] == '*' or ty == "none":
if ty[0] == '*':
return "void*"
return {
"i1": "int32_t",
@@ -215,10 +215,12 @@ def ty_to_cpp(ty):
}[ty]


def make_launcher(constants, signature, ids):
def make_launcher(constants, signature):

def _extracted_type(ty):
if ty[0] == '*' or ty == "none":
if ty == "constexpr":
return "PyObject*"
if ty[0] == '*':
return "PyObject*"
if ty[0] == '[':
if ty == "[]":
@@ -252,7 +254,6 @@ def format_of(ty):
"uint64_t": "K",
}[ty]

signature = {k: v for k, v in signature.items() if v != 'constexpr'}
args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
format = "iiiOOOOOO" + args_format
signature = ','.join(signature.values()).replace('[', '').replace(']', '')
@@ -262,16 +263,22 @@ def format_of(ty):

# Record the end of regular arguments;
# subsequent arguments are architecture-specific descriptors.
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
internal_args_list = []
for i, ty in signature.items():
if ty[0] == "*" or ty == "none":
if ty[0] == "*":
internal_args_list.append(f"ptr_info{i}.dev_ptr")
else:
elif ty != "constexpr":
internal_args_list.append(f"_arg{i}")

# generate glue code
params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"]
newline = '\n '
ptr_decls = [
f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;"
for i, ty in signature.items()
if ty[0] == "*"
]
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
src = f"""
#include <cstddef>
#include <string>
@@ -394,7 +401,7 @@ def format_of(ty):
assert(num_params == expected_num_params && "number of kernel param not matched");
// Submit the imported kernel.
auto cgf = [&](sycl::handler &cgh) {{
{" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))}
{" ".join(f'set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);' for idx, item in enumerate([signature[i] for i in signature if signature[i] != "constexpr"]))}
if (shared_memory) {{
using share_mem_t = sycl::local_accessor<int8_t, 1>;
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
@@ -418,7 +425,7 @@ def format_of(ty):
PyObject *py_obj_stream;
PyObject* py_kernel;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
{newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel,
&kernel_metadata, &launch_metadata,
&launch_enter_hook, &launch_exit_hook {args_list})) {{
@@ -467,7 +474,7 @@ def format_of(ty):
if(kernel_ptr == nullptr) return NULL;
sycl::kernel kernel = *kernel_ptr;
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
{newline.join(ptr_decls)}
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {',' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
if(launch_exit_hook != Py_None){{
@@ -568,11 +575,11 @@ def serialize_args(args, constants, signature):
class XPULauncher(object):

def __init__(self, src, metadata):
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
constants = src.constants if hasattr(src, "constants") else dict()
self.constants = {idx: value for idx, value in constants.items()}
arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
self.constants = {arg_idx(idx): value for idx, value in constants.items()}
self.signature = {idx: value for idx, value in src.signature.items()}
src = make_launcher(self.constants, self.signature, ids)
src = make_launcher(self.constants, self.signature)
mod = compile_module_from_src(src, "__triton_launcher")
self.launch = mod.launch

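Taken together, the driver changes drop the old ids/"none" plumbing: constexpr arguments stay in the signature but are skipped wherever glue code is generated, and constants are normalized to argument-index tuples before make_launcher is called. A compressed, self-contained sketch of that bookkeeping with illustrative values (the real signature, constants and arg_names come from the compiled src object):

signature = {0: "*fp32", 1: "i32", 2: "constexpr"}   # illustrative
constants = {"BLOCK": 128}                           # may be keyed by name
arg_names = ["x_ptr", "n", "BLOCK"]

# Normalize constant keys to argument-index tuples, as XPULauncher now does.
arg_idx = lambda x: (arg_names.index(x), ) if isinstance(x, str) else x
constants = {arg_idx(k): v for k, v in constants.items()}
assert constants == {(2, ): 128}

# constexpr entries never reach the SYCL kernel: they are filtered out of the
# C argument declarations, the PyArg_ParseTuple format and set_scalar_arg calls.
arg_decls = [f"arg{i}" for i, ty in signature.items() if ty != "constexpr"]
ptr_args = [i for i, ty in signature.items() if ty[0] == "*"]
assert arg_decls == ["arg0", "arg1"] and ptr_args == [0]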
