Skip to content

Commit

Permalink
Fixed incorrect PTX parsing of ret instruction after branch label (#…
Browse files Browse the repository at this point in the history
…17859)

The PTX parser replaces PTX code with inline PTX code (using inline ASM blocks).
It treats a branch label and the instruction immediately following it as a single unit to process.
During the ASM->CUDA transform step, it searches the string for the `ret;` instruction and replaces the whole statement rather than only the `ret;` substring, which means an expression like:

```asm

BB0_1:
ret;
```

is parsed as: 

```asm 

BB0_1: ret;

```

and then transformed to: 

```asm

bra RETTGT

``` 

instead of:

```asm 

BB0_1: bra RETTGT

```

This merge request fixes this bug.

Authors:
  - Basit Ayantunde (https://github.com/lamarrr)

Approvers:
  - David Wendt (https://github.com/davidwendt)
  - Shruti Shivakumar (https://github.com/shrshi)

URL: #17859
  • Loading branch information
lamarrr authored Feb 4, 2025
1 parent 7baf1e9 commit 0e91baf
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
13 changes: 9 additions & 4 deletions cpp/src/jit/parser.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -375,9 +375,14 @@ std::string ptx_parser::parse()

// Don't use std::accumulate until C++20 when rvalue references are supported
auto final_output = fn_header_output + "\n asm volatile (\"{\");";
for (auto const& line : fn_body_output)
final_output += line.find("ret;") != std::string::npos ? " asm volatile (\"bra RETTGT;\");\n"
: " " + line + "\n";
for (auto const& line : fn_body_output) {
std::string output{line};
std::string_view const ret_instruction = "ret;";
if (auto start = output.find(ret_instruction); start != std::string::npos) {
output.replace(start, ret_instruction.size(), "bra RETTGT;");
}
final_output += " " + output + "\n";
}
return final_output + " asm volatile (\"RETTGT:}\");}";
}

Expand Down
25 changes: 24 additions & 1 deletion cpp/tests/jit/parse_ptx_function.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -69,6 +69,7 @@ __device__ __inline__ void GENERIC_OP(
){
asm volatile ("{");
asm volatile ("bra RETTGT;");
/** ret*/
asm volatile ("RETTGT:}");}
)";

Expand All @@ -78,6 +79,28 @@ __device__ __inline__ void GENERIC_OP(
EXPECT_TRUE(ptx_equal(cuda_source, expected));
}

// Regression test: when `ret;` directly follows a branch label, the generated
// inline asm is expected to keep the label (`BB0:`) and only swap the `ret;`
// for `bra RETTGT;`.
TEST_F(JitParseTest, PTXWithBranchLabel)
{
  // CUDA source the parser is expected to emit for the PTX below.
  std::string expected_cuda{R"(
__device__ __inline__ void GENERIC_OP(){
asm volatile ("{");
asm volatile ("BB0: bra RETTGT;");
/** BB0: ret*/
asm volatile ("RETTGT:}");}
)"};

  // Minimal PTX function: a single label immediately followed by a return.
  std::string ptx_input{R"(
.visible .func _Z1flPaS_(){
BB0:
ret;
}
)"};

  std::string generated_cuda = cudf::jit::parse_single_function_ptx(ptx_input, "GENERIC_OP", {});

  EXPECT_TRUE(ptx_equal(generated_cuda, expected_cuda));
}

TEST_F(JitParseTest, PTXWithPragma)
{
std::string raw_ptx = R"(
Expand Down

0 comments on commit 0e91baf

Please sign in to comment.