-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add jax torch test
- Loading branch information
Showing
1 changed file
with
52 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import torch | ||
import torch_xla2 | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
import unittest | ||
|
||
|
||
class JaxTorchTest(unittest.TestCase): | ||
|
||
|
||
def test_matmul_bfloat16_xla2(self): | ||
jax.config.update('jax_platform_name', 'cpu') | ||
torch.set_default_dtype(torch.bfloat16) | ||
r = c = 1000 | ||
q = torch.randn((r, c)) | ||
k = torch.randn((r, c)) | ||
print(f"torch matlmul: {q.shape} * {k.shape}") | ||
result = torch.matmul(q, k) | ||
|
||
jax_q = torch_xla2.tensor.t2j(q) | ||
jax_k = torch_xla2.tensor.t2j(k) | ||
print(f"torch matlmul: {jax_q.shape} * {jax_k.shape}") | ||
jax_result = jnp.matmul(jax_q, jax_k) | ||
target_result = torch_xla2.tensor.j2t(jax_result) | ||
print(f"----------------------- matmul: Diff norm {(target_result - result).norm()}") | ||
self.assertTrue(torch.allclose(target_result, result, atol=1)) | ||
|
||
|
||
def test_matmul_bfloat32(self): | ||
jax.config.update('jax_platform_name', 'cpu') | ||
torch.set_default_dtype(torch.float32) | ||
r = c = 1000 | ||
q = torch.randn((r, c)) | ||
k = torch.randn((r, c)) | ||
print(f"torch matlmul: {q.shape} * {k.shape}") | ||
result = torch.matmul(q, k) | ||
|
||
|
||
jax_q = torch_xla2.tensor.t2j(q) | ||
jax_k = torch_xla2.tensor.t2j(k) | ||
print(f"torch matlmul: {jax_q.shape} * {jax_k.shape}") | ||
jax_result = jnp.matmul(jax_q, jax_k) | ||
target_result = torch_xla2.tensor.j2t(jax_result) | ||
print(f"----------------------- matmul: Diff norm {(target_result - result).norm()}") | ||
self.assertTrue(torch.allclose(target_result, result, atol=1e-4)) | ||
|
||
|
||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |