# Format Benchmark Scripts with Ruff (linkedin#516)
## Summary

Format the benchmark scripts with the newly migrated Ruff toolchain, following up on linkedin#483.
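
The mechanical changes visible in the diffs below are: unused module-level imports dropped, combined `from ... import A, B` statements split into one import per line, and lambda assignments that previously wrapped across three lines joined onto a single line. These are consistent with isort's `force-single-line` style and a line length above roughly 100 characters; a minimal configuration sketch (settings assumed for illustration, not copied from the repository's `pyproject.toml`):

```toml
# Hypothetical Ruff settings that would produce the diffs below.
[tool.ruff]
line-length = 120  # long enough that the joined one-line lambdas fit

[tool.ruff.lint]
select = ["F", "I"]  # F401 flags unused imports; "I" enables import sorting

[tool.ruff.lint.isort]
force-single-line = true  # one `from x import y` per line
```

With settings along these lines, `ruff check --fix` removes the unused imports and splits the combined ones, and `ruff format` re-flows the long lambda assignments.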


## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Austin Liu <[email protected]>
austin362667 authored Jan 10, 2025 · 1 parent 9586a87 · commit ba72b8e
Showing 4 changed files with 28 additions and 49 deletions.
benchmark/scripts/benchmark_cpo_loss.py (23 changes: 8 additions, 15 deletions)
@@ -11,7 +11,6 @@
 from utils import parse_benchmark_script_args
 from utils import run_benchmarks
 
-from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -27,7 +26,8 @@
 def bench_memory_fused_linear_cpo_loss(
     input: SingleBenchmarkRunInput,
 ) -> SingleBenchmarkRunOutput:
-    from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO
+    from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
+    from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO
 
     B = input.x
     T = input.extra_benchmark_config["T"]
@@ -36,12 +36,8 @@ def bench_memory_fused_linear_cpo_loss(
     dtype = input.extra_benchmark_config["dtype"]
     provider = input.kernel_provider
 
-    torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
-    liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
+    torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
@@ -72,7 +68,8 @@ def full():
 def bench_speed_fused_linear_cpo_loss(
     input: SingleBenchmarkRunInput,
 ) -> SingleBenchmarkRunOutput:
-    from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO
+    from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
+    from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO
 
     B = input.x
     T = input.extra_benchmark_config["T"]
@@ -82,12 +79,8 @@ def bench_speed_fused_linear_cpo_loss(
     provider = input.kernel_provider
     mode = input.kernel_operation_mode
 
-    torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
-    liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
+    torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
benchmark/scripts/benchmark_dpo_loss.py (8 changes: 4 additions, 4 deletions)
@@ -4,15 +4,13 @@
 import torch
 import triton
 
-from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
 from utils import SingleBenchmarkRunOutput
 from utils import _test_memory
 from utils import parse_benchmark_script_args
 from utils import run_benchmarks
 
-from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -21,7 +19,8 @@
 
 
 def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO
+    from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
+    from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO
 
     B = input.x
     T = input.extra_benchmark_config["T"]
@@ -70,7 +69,8 @@ def full():
 
 
 def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO
+    from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
+    from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO
 
     B = input.x
     T = input.extra_benchmark_config["T"]
benchmark/scripts/benchmark_orpo_loss.py (23 changes: 8 additions, 15 deletions)
@@ -11,7 +11,6 @@
 from utils import parse_benchmark_script_args
 from utils import run_benchmarks
 
-from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -27,7 +26,8 @@
 def bench_memory_fused_linear_orpo_loss(
     input: SingleBenchmarkRunInput,
 ) -> SingleBenchmarkRunOutput:
-    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO
+    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
+    from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO
 
     B = input.x
     T = input.extra_benchmark_config["T"]
@@ -36,12 +36,8 @@ def bench_memory_fused_linear_orpo_loss(
     dtype = input.extra_benchmark_config["dtype"]
     provider = input.kernel_provider
 
-    torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
-    liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
+    torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
@@ -73,7 +69,8 @@ def full():
 def bench_speed_fused_linear_orpo_loss(
     input: SingleBenchmarkRunInput,
 ) -> SingleBenchmarkRunOutput:
-    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO
+    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
+    from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO
 
     B = input.x
     T = input.extra_benchmark_config["T"]
@@ -83,12 +80,8 @@ def bench_speed_fused_linear_orpo_loss(
     provider = input.kernel_provider
     mode = input.kernel_operation_mode
 
-    torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
-    liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
+    torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
benchmark/scripts/benchmark_simpo_loss.py (23 changes: 8 additions, 15 deletions)
@@ -11,7 +11,6 @@
 from utils import parse_benchmark_script_args
 from utils import run_benchmarks
 
-from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -27,7 +26,8 @@
 def bench_memory_fused_linear_simpo_loss(
     input: SingleBenchmarkRunInput,
 ) -> SingleBenchmarkRunOutput:
-    from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO
+    from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
+    from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO
 
     B = input.x
     T = input.extra_benchmark_config["T"]
@@ -36,12 +36,8 @@ def bench_memory_fused_linear_simpo_loss(
     dtype = input.extra_benchmark_config["dtype"]
     provider = input.kernel_provider
 
-    torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
-    liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
+    torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
@@ -72,7 +68,8 @@ def full():
 def bench_speed_fused_linear_simpo_loss(
     input: SingleBenchmarkRunInput,
 ) -> SingleBenchmarkRunOutput:
-    from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO
+    from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
+    from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO
 
     B = input.x
     T = input.extra_benchmark_config["T"]
@@ -82,12 +79,8 @@ def bench_speed_fused_linear_simpo_loss(
     provider = input.kernel_provider
     mode = input.kernel_operation_mode
 
-    torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
-    liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(
-        device
-    )(x, target)[0]
+    torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
+    liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
 
     _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (B, T), dtype=torch.long, device=device)
