Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <[email protected]>
  • Loading branch information
KumoLiu committed Jan 24, 2025
1 parent 4c877dc commit 0355bfb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
8 changes: 4 additions & 4 deletions monai/inferers/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,13 @@ def __init__(
self.store = store
if version_geq(get_package_version("zarr"), "3.0.0"):
if value_store is None:
with TemporaryDirectory() as tmpdir:
self.value_store = zarr.storage.LocalStore(tmpdir)
tmpdir = TemporaryDirectory()
self.value_store = zarr.storage.LocalStore(tmpdir.name)
else:
self.value_store = value_store
if count_store is None:
with TemporaryDirectory() as tmpdir:
self.count_store = zarr.storage.LocalStore(tmpdir)
tmpdir = TemporaryDirectory()
self.count_store = zarr.storage.LocalStore(tmpdir.name)
else:
self.count_store = count_store
else:
Expand Down
11 changes: 4 additions & 7 deletions tests/test_zarr_avg_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,19 +287,16 @@ class ZarrAvgMergerTests(unittest.TestCase):
]
)
def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected):
codec_reg = numcodecs.registry.codec_registry
if "compressor" in arguments:
if arguments["compressor"] != "default":
arguments["compressor"] = numcodecs.registry.codec_registry[arguments["compressor"].lower()]()
arguments["compressor"] = codec_reg[arguments["compressor"].lower()]()
if "value_compressor" in arguments:
if arguments["value_compressor"] != "default":
arguments["value_compressor"] = numcodecs.registry.codec_registry[
arguments["value_compressor"].lower()
]()
arguments["value_compressor"] = codec_reg[arguments["value_compressor"].lower()]()
if "count_compressor" in arguments:
if arguments["count_compressor"] != "default":
arguments["count_compressor"] = numcodecs.registry.codec_registry[
arguments["count_compressor"].lower()
]()
arguments["count_compressor"] = codec_reg[arguments["count_compressor"].lower()]()
merger = ZarrAvgMerger(**arguments)
for pl in patch_locations:
merger.aggregate(pl[0], pl[1])
Expand Down

0 comments on commit 0355bfb

Please sign in to comment.