Skip to content

Commit

Permalink
Remove stale skips in testing; fix pre-commit (#5)
Browse files Browse the repository at this point in the history
* Remove stale skips in testing; fix Tripy repository url

* Ensure that add_copyright applies only to tripy/ files
  • Loading branch information
parthchadha authored Aug 1, 2024
1 parent f77c4f9 commit cb22081
Show file tree
Hide file tree
Showing 10 changed files with 15 additions and 29 deletions.
2 changes: 1 addition & 1 deletion tripy/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
hooks:
- id: add-license
name: Add License
entry: python tools/add_copyright.py
entry: python tripy/tools/add_copyright.py
language: python
stages: [pre-commit]
files: \.py$
3 changes: 1 addition & 2 deletions tripy/docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@
myst_url_schemes = {
"http": None,
"https": None,
# TODO (#release): Update link
"source": "https://github.com/NVIDIA/TensorRT-Incubator/tripy/-/blob/main/{{path}}",
"source": "https://github.com/NVIDIA/TensorRT-Incubator/tree/main/tripy/{{path}}",
}
myst_number_code_blocks = ["py", "rst"]

Expand Down
3 changes: 0 additions & 3 deletions tripy/tests/backend/mlir/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ def test_invalid_slice(self):
):
sliced.eval()

@pytest.mark.skip(
"TODO (#207): MLIR-TRT currently triggers a C-style abort, which we cannot handle. Needs to be fixed in MLIR-TRT."
)
def test_reshape_invalid_volume(self):
tensor = tp.ones((2, 2))
reshaped = tp.reshape(tensor, (3, 3))
Expand Down
1 change: 0 additions & 1 deletion tripy/tests/frontend/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def test_gather(self, values):
assert isinstance(s2.trace_tensor.producer, Gather)
assert cp.from_dlpack(s2).get().tolist() == [values[0], values[-1]]

@pytest.mark.skip("#186: Fix test_matrix_multiplication.py hang for 1D tensor.")
def test_matmul(self, values):
s1 = tp.Shape(values)
s2 = tp.Shape(values)
Expand Down
5 changes: 1 addition & 4 deletions tripy/tests/frontend/trace/ops/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,12 @@ def test_mismatched_dtypes_fails(self):
with helper.raises(tp.TripyException, match="Incompatible input data types.", has_stack_info_for=[a, b]):
c = a @ b

@pytest.mark.skip(
"mlir-tensorrt #860 fixes dynamic broadcast issue."
)
def test_incompatible_1d_shapes_fails(self):
a = tp.ones((2,), dtype=tp.float32)
b = tp.ones((3,), dtype=tp.float32)
c = a @ b

with helper.raises(tp.TripyException, match="Incompatible input shapes.", has_stack_info_for=[a, b, c]):
with helper.raises(tp.TripyException, match="contracting dimension sizes must match for lhs/rhs", has_stack_info_for=[a, b, c]):
c.eval()

def test_incompatible_2d_shapes_fails(self):
Expand Down
5 changes: 1 addition & 4 deletions tripy/tests/frontend/trace/ops/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,13 @@ def test_op_func(self):
assert isinstance(a, tp.Tensor)
assert isinstance(a.trace_tensor.producer, Squeeze)

@pytest.mark.skip(
"Program segfaulting instead of error being reported: mlir-tensorrt #855"
)
def test_incorrect_dims(self):
a = tp.Tensor(np.ones((1, 1, 4), dtype=np.int32))
b = tp.squeeze(a, 2)

with helper.raises(
tp.TripyException,
match="Cannot select an axis to squeeze out which has size not equal to one",
match="output_shape is incompatible with input type of operation: input has 4 elements, but output_shape has 1",
has_stack_info_for=[a, b],
):
b.eval()
Expand Down
3 changes: 1 addition & 2 deletions tripy/tests/frontend/trace/ops/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def test_bool_condition(self):
assert isinstance(w, tp.Tensor)
assert isinstance(w.trace_tensor.producer, Where)

@pytest.mark.skip("Test segfaults due to mlir-tensorrt #885")
def test_mismatched_input_shapes(self):
cond = tp.ones((2,), dtype=tp.float32) > tp.ones((2,), dtype=tp.float32)
a = tp.ones((2,), dtype=tp.float32)
Expand All @@ -48,7 +47,7 @@ def test_mismatched_input_shapes(self):

with helper.raises(
tp.TripyException,
match=re.escape("size of operand dimension 0 (2) is not compatible with size of result dimension 0 (3)"),
match=re.escape("size of operand dimension 0 (3) is not compatible with size of result dimension 0 (2)"),
has_stack_info_for=[a, b, c, cond],
):
c.eval()
Expand Down
4 changes: 1 addition & 3 deletions tripy/tests/integration/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def test_cast(self, input_dtype, target_dtype):

assert np.array_equal(cp.from_dlpack(output).get(), np_input.astype(target_dtype))

@pytest.mark.skip("#219: Quantize/dequantize fail with dynamic shapes")
# these dtypes don't have analogues in numpy
@pytest.mark.parametrize("source_dtype", [pytest.param(tp.float8, marks=skip_if_older_than_sm89), tp.int4])
def test_cast_quantized_dtypes_into_bool(self, source_dtype):
Expand All @@ -71,8 +70,7 @@ def test_cast_quantized_dtypes_into_bool(self, source_dtype):
output = tp.cast(q, tp.bool)
assert cp.from_dlpack(output).get().tolist() == [True, False, False, True]

@pytest.mark.skip("#219: Dequantize fails with dynamic shapes")
@pytest.mark.parametrize("target_dtype", [np.float32, np.float64, np.int32, np.int64, np.int8])
@pytest.mark.parametrize("target_dtype", [np.float32, np.int32, np.int64, np.int8])
def test_cast_from_bool(self, target_dtype):
from tripy.common.utils import convert_frontend_dtype_to_tripy_dtype

Expand Down
3 changes: 1 addition & 2 deletions tripy/tests/integration/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ class TestSplitOp:
((4,), (2, 0), lambda t: (t[:2], t[2:])),
((4,), (1, 0), lambda t: t[:]),
((4,), (4, 0), lambda t: (t[0:1], t[1:2], t[2:3], t[3:4])),
# Blocked on mlir-tensorrt #860
# ((4,), ([1, 2], 0), lambda t: (t[:1], t[1:2], t[2:])),
((4,), ([1, 2], 0), lambda t: (t[:1], t[1:2], t[2:])),
((12, 12), (3, 1), lambda t: (t[:, :4], t[:, 4:8], t[:, 8:])),
((12, 12), ([3], 1), lambda t: (t[:, :3], t[:, 3:])),
((12, 12), (4, 0), lambda t: (t[:3, :], t[3:6, :], t[6:9, :], t[9:12, :])),
Expand Down
15 changes: 8 additions & 7 deletions tripy/tools/add_copyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@
# limitations under the License.
#


import os
import re
import sys
import subprocess
import re
import sys
from datetime import datetime

current_year = str(datetime.now().year)

# Read the license text from the LICENSE file
def get_license_header():
with open('LICENSE', 'r', encoding='utf-8') as license_file:
return license_file.read().strip()
for license_path in ['LICENSE', 'tripy/LICENSE']:
if os.path.exists(license_path):
with open(license_path, 'r', encoding='utf-8') as license_file:
return license_file.read().strip()
raise FileNotFoundError('LICENSE file not found in the current directory or tripy folder')

license_text = get_license_header()

Expand Down Expand Up @@ -57,7 +58,7 @@ def get_files(mode):
raise ValueError("Invalid mode. Use 'new' or 'all'.")

result = subprocess.run(command, capture_output=True, text=True)
return [f for f in result.stdout.splitlines() if f.endswith('.py')]
return [f for f in result.stdout.splitlines() if f.endswith('.py') and f.startswith('tripy/')]

def main(mode):
files = get_files(mode)
Expand Down

0 comments on commit cb22081

Please sign in to comment.