Skip to content

Commit

Permalink
Add Jax torch test (#17)
Browse files Browse the repository at this point in the history
add jax torch test
  • Loading branch information
FanhaiLu1 authored Apr 12, 2024
1 parent cd26fe1 commit e03d13d
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions tests/test_jax_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch
import torch_xla2
import jax
import jax.numpy as jnp

import unittest


class JaxTorchTest(unittest.TestCase):


def test_matmul_bfloat16_xla2(self):
jax.config.update('jax_platform_name', 'cpu')
torch.set_default_dtype(torch.bfloat16)
r = c = 1000
q = torch.randn((r, c))
k = torch.randn((r, c))
print(f"torch matlmul: {q.shape} * {k.shape}")
result = torch.matmul(q, k)

jax_q = torch_xla2.tensor.t2j(q)
jax_k = torch_xla2.tensor.t2j(k)
print(f"torch matlmul: {jax_q.shape} * {jax_k.shape}")
jax_result = jnp.matmul(jax_q, jax_k)
target_result = torch_xla2.tensor.j2t(jax_result)
print(f"----------------------- matmul: Diff norm {(target_result - result).norm()}")
self.assertTrue(torch.allclose(target_result, result, atol=1))


def test_matmul_bfloat32(self):
jax.config.update('jax_platform_name', 'cpu')
torch.set_default_dtype(torch.float32)
r = c = 1000
q = torch.randn((r, c))
k = torch.randn((r, c))
print(f"torch matlmul: {q.shape} * {k.shape}")
result = torch.matmul(q, k)


jax_q = torch_xla2.tensor.t2j(q)
jax_k = torch_xla2.tensor.t2j(k)
print(f"torch matlmul: {jax_q.shape} * {jax_k.shape}")
jax_result = jnp.matmul(jax_q, jax_k)
target_result = torch_xla2.tensor.j2t(jax_result)
print(f"----------------------- matmul: Diff norm {(target_result - result).norm()}")
self.assertTrue(torch.allclose(target_result, result, atol=1e-4))




if __name__ == '__main__':
unittest.main()

0 comments on commit e03d13d

Please sign in to comment.