ci: cache reference metrics & clean audio tests (#2335)
* cache reference metrics
* audio
* classif
* regress
* image
* others
* cleaning

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Borda and pre-commit-ci[bot] authored Feb 27, 2024
1 parent efb3a25 commit c53ea94
Showing 105 changed files with 809 additions and 780 deletions.
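The reference-metric caching named in the commit title presumably lives in the shared test helpers among the 105 changed files and is not part of the hunks shown on this page; only the new cache directory appears below in .gitignore. As a hedged sketch of the general pattern (every name here is hypothetical, not the committed API), a reference result can be keyed on the metric name and inputs and memoized under tests/_reference-cache/:

# Hypothetical sketch only: on-disk memoization of expensive reference-metric results.
import hashlib
import os
import pickle
from typing import Callable

from torch import Tensor

_REFERENCE_CACHE = os.path.join("tests", "_reference-cache")  # the directory ignored in .gitignore below


def _cached_reference(name: str, preds: Tensor, target: Tensor, compute: Callable):
    """Return the cached reference value for (name, preds, target), computing it once if missing."""
    os.makedirs(_REFERENCE_CACHE, exist_ok=True)
    key = hashlib.sha256(
        name.encode() + preds.cpu().numpy().tobytes() + target.cpu().numpy().tobytes()
    ).hexdigest()
    path = os.path.join(_REFERENCE_CACHE, f"{key}.pkl")
    if os.path.isfile(path):
        with open(path, "rb") as fp:
            return pickle.load(fp)
    value = compute(preds, target)
    with open(path, "wb") as fp:
        pickle.dump(value, fp)
    return value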
3 changes: 1 addition & 2 deletions .gitignore
@@ -40,9 +40,8 @@ pip-delete-this-directory.txt
# Unit test / coverage reports
tests/_data/
data.zip
tests/_reference-cache/
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
23 changes: 23 additions & 0 deletions tests/unittests/audio/__init__.py
@@ -1,7 +1,30 @@
import os
from typing import Callable, Optional

from torch import Tensor

from unittests import _PATH_ALL_TESTS

_SAMPLE_AUDIO_SPEECH = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech.wav")
_SAMPLE_AUDIO_SPEECH_BAB_DB = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech_bab_0dB.wav")
_SAMPLE_NUMPY_ISSUE_895 = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "issue_895.npz")


def _average_metric_wrapper(
preds: Tensor, target: Tensor, metric_func: Callable, res_index: Optional[int] = None
) -> Tensor:
"""Average the metric values.
Args:
preds: predictions, shape[batch, spk, time]
target: targets, shape[batch, spk, time]
metric_func: a function that returns best_metric and best_perm
res_index: if not None, return best_metric[res_index]
Returns:
the average of best_metric
"""
if res_index is not None:
return metric_func(preds, target)[res_index].mean()
return metric_func(preds, target).mean()
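For orientation, a small usage sketch of the helper above. The toy metric functions are invented for this example (real tests pass reference implementations such as the scipy-based PIT function further below), and the import assumes the repo's unittests package layout:

from functools import partial

import torch
from torch import Tensor

from unittests.audio import _average_metric_wrapper


def _toy_metric(preds: Tensor, target: Tensor) -> Tensor:
    # per-sample score, shape [batch]
    return -((preds - target) ** 2).mean(dim=(-2, -1))


def _toy_pit_metric(preds: Tensor, target: Tensor):
    # mimics a PIT-style reference returning (best_metric [batch], best_perm [batch, spk])
    best_metric = _toy_metric(preds, target)
    best_perm = torch.arange(preds.shape[1]).repeat(preds.shape[0], 1)
    return best_metric, best_perm


preds = torch.rand(4, 2, 100)  # [batch, spk, time]
target = torch.rand(4, 2, 100)

# plain metric: average the per-sample values directly (as test_pesq.py does)
avg = _average_metric_wrapper(preds, target, metric_func=_toy_metric)

# PIT-style metric: select best_metric via res_index=0 before averaging (as test_pit.py does)
pit_reference = partial(_average_metric_wrapper, metric_func=_toy_pit_metric, res_index=0)
avg_pit = pit_reference(preds, target)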
23 changes: 6 additions & 17 deletions tests/unittests/audio/test_pesq.py
@@ -22,7 +22,7 @@
from torchmetrics.functional.audio import perceptual_evaluation_speech_quality

from unittests import _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

@@ -41,7 +41,7 @@
)


def _pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
def _reference_pesq_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
"""Comparison function."""
# shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target [NUM_BATCHES*BATCH_SIZE, Time]
@@ -54,23 +54,12 @@ def _pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
return torch.tensor(mss)


def _average_metric(preds, target, metric_func):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


pesq_original_batch_8k_nb = partial(_pesq_original_batch, fs=8000, mode="nb")
pesq_original_batch_16k_nb = partial(_pesq_original_batch, fs=16000, mode="nb")
pesq_original_batch_16k_wb = partial(_pesq_original_batch, fs=16000, mode="wb")


@pytest.mark.parametrize(
"preds, target, ref_metric, fs, mode",
[
(inputs_8k.preds, inputs_8k.target, pesq_original_batch_8k_nb, 8000, "nb"),
(inputs_16k.preds, inputs_16k.target, pesq_original_batch_16k_nb, 16000, "nb"),
(inputs_16k.preds, inputs_16k.target, pesq_original_batch_16k_wb, 16000, "wb"),
(inputs_8k.preds, inputs_8k.target, partial(_reference_pesq_batch, fs=8000, mode="nb"), 8000, "nb"),
(inputs_16k.preds, inputs_16k.target, partial(_reference_pesq_batch, fs=16000, mode="nb"), 16000, "nb"),
(inputs_16k.preds, inputs_16k.target, partial(_reference_pesq_batch, fs=16000, mode="wb"), 16000, "wb"),
],
)
class TestPESQ(MetricTester):
@@ -89,7 +78,7 @@ def test_pesq(self, preds, target, ref_metric, fs, mode, num_processes, ddp):
preds,
target,
PerceptualEvaluationSpeechQuality,
reference_metric=partial(_average_metric, metric_func=ref_metric),
reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric),
metric_args={"fs": fs, "mode": mode, "n_processes": num_processes},
)

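The per-sample loop of _reference_pesq_batch is collapsed in the hunk above; it wraps the pesq package, scoring each (target, preds) pair at the given sampling rate and mode. A hedged sketch of that idea (the helper name and loop details are illustrative, not the committed code):

import pesq as pesq_backend
import torch
from torch import Tensor


def _reference_pesq_batch_sketch(preds: Tensor, target: Tensor, fs: int, mode: str) -> Tensor:
    # preds/target: [batch, time]; returns one PESQ score per sample
    preds_np = preds.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()
    scores = [pesq_backend.pesq(fs, ref, deg, mode) for ref, deg in zip(target_np, preds_np)]
    return torch.tensor(scores)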
68 changes: 32 additions & 36 deletions tests/unittests/audio/test_pit.py
@@ -31,27 +31,28 @@
)

from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.audio import _average_metric_wrapper
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

TIME = 10
TIME_FRAME = 10


# three speaker examples to test _find_best_perm_by_linear_sum_assignment
inputs1 = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME_FRAME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME_FRAME),
)
# two speaker examples to test _find_best_perm_by_exhuastive_method
inputs2 = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME_FRAME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME_FRAME),
)


def naive_implementation_pit_scipy(
def _reference_scipy_pit(
preds: Tensor,
target: Tensor,
metric_func: Callable,
@@ -66,10 +67,8 @@ def naive_implementation_pit_scipy(
eval_func: min or max
Returns:
best_metric:
shape [batch]
best_perm:
shape [batch, spk]
best_metric: shape [batch]
best_perm: shape [batch, spk]
"""
batch_size, spk_num = target.shape[0:2]
@@ -88,62 +87,59 @@ def naive_implementation_pit_scipy(
return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms))


def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
"""Average the metric values.
def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
return _reference_scipy_pit(
preds=preds,
target=target,
metric_func=signal_noise_ratio,
eval_func="max",
)

Args:
preds: predictions, shape[batch, spk, time]
target: targets, shape[batch, spk, time]
metric_func: a function which return best_metric and best_perm
Returns:
the average of best_metric

"""
return metric_func(preds, target)[0].mean()


snr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=signal_noise_ratio, eval_func="max")
si_sdr_pit_scipy = partial(
naive_implementation_pit_scipy, metric_func=scale_invariant_signal_distortion_ratio, eval_func="max"
)
def _reference_scipy_pit_si_sdr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
return _reference_scipy_pit(
preds=preds,
target=target,
metric_func=scale_invariant_signal_distortion_ratio,
eval_func="max",
)


@pytest.mark.parametrize(
"preds, target, ref_metric, metric_func, mode, eval_func",
[
(inputs1.preds, inputs1.target, snr_pit_scipy, signal_noise_ratio, "speaker-wise", "max"),
(inputs1.preds, inputs1.target, _reference_scipy_pit_snr, signal_noise_ratio, "speaker-wise", "max"),
(
inputs1.preds,
inputs1.target,
si_sdr_pit_scipy,
_reference_scipy_pit_si_sdr,
scale_invariant_signal_distortion_ratio,
"speaker-wise",
"max",
),
(inputs2.preds, inputs2.target, snr_pit_scipy, signal_noise_ratio, "speaker-wise", "max"),
(inputs2.preds, inputs2.target, _reference_scipy_pit_snr, signal_noise_ratio, "speaker-wise", "max"),
(
inputs2.preds,
inputs2.target,
si_sdr_pit_scipy,
_reference_scipy_pit_si_sdr,
scale_invariant_signal_distortion_ratio,
"speaker-wise",
"max",
),
(inputs1.preds, inputs1.target, snr_pit_scipy, signal_noise_ratio, "permutation-wise", "max"),
(inputs1.preds, inputs1.target, _reference_scipy_pit_snr, signal_noise_ratio, "permutation-wise", "max"),
(
inputs1.preds,
inputs1.target,
si_sdr_pit_scipy,
_reference_scipy_pit_si_sdr,
scale_invariant_signal_distortion_ratio,
"permutation-wise",
"max",
),
(inputs2.preds, inputs2.target, snr_pit_scipy, signal_noise_ratio, "permutation-wise", "max"),
(inputs2.preds, inputs2.target, _reference_scipy_pit_snr, signal_noise_ratio, "permutation-wise", "max"),
(
inputs2.preds,
inputs2.target,
si_sdr_pit_scipy,
_reference_scipy_pit_si_sdr,
scale_invariant_signal_distortion_ratio,
"permutation-wise",
"max",
@@ -163,7 +159,7 @@ def test_pit(self, preds, target, ref_metric, metric_func, mode, eval_func, ddp)
preds,
target,
PermutationInvariantTraining,
reference_metric=partial(_average_metric, metric_func=ref_metric),
reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric, res_index=0),
metric_args={"metric_func": metric_func, "mode": mode, "eval_func": eval_func},
)

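Most of _reference_scipy_pit is hidden in the collapsed hunk above. Conceptually it scores every (target, estimate) speaker pair and uses scipy's linear_sum_assignment to pick the best permutation per sample; the sketch below is an assumed reconstruction of that idea, not the committed implementation:

import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from torch import Tensor


def _pit_reference_sketch(preds: Tensor, target: Tensor, metric_func, eval_func: str = "max"):
    # preds/target: [batch, spk, time]; returns (best_metric [batch], best_perm [batch, spk])
    batch_size, spk_num = target.shape[:2]
    best_metrics, best_perms = [], []
    for b in range(batch_size):
        # metric_mtx[t, e] = metric(preds[b, e], target[b, t]) for every target/estimate pair
        metric_mtx = np.empty((spk_num, spk_num))
        for t in range(spk_num):
            for e in range(spk_num):
                metric_mtx[t, e] = metric_func(preds[b, e], target[b, t]).item()
        # one-to-one assignment maximising (or minimising) the summed score
        row, col = linear_sum_assignment(metric_mtx, maximize=(eval_func == "max"))
        best_metrics.append(metric_mtx[row, col].mean())
        best_perms.append(col)
    return torch.tensor(best_metrics), torch.tensor(np.stack(best_perms))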
26 changes: 15 additions & 11 deletions tests/unittests/audio/test_sa_sdr.py
@@ -38,7 +38,9 @@
)


def _ref_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool):
def _reference_local_sa_sdr(
preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool, reduce_mean: bool = False
):
# According to the original paper, the sa-sdr equals the si-sdr with inputs concatenated over the speaker
# dimension if scale_invariant==True. Accordingly, for scale_invariant==False, the sa-sdr equals the snr.
# shape: preds [BATCH_SIZE, Spk, Time] , target [BATCH_SIZE, Spk, Time]
@@ -51,14 +53,14 @@ def _ref_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean:
preds = preds.reshape(preds.shape[0], preds.shape[1] * preds.shape[2])
target = target.reshape(target.shape[0], target.shape[1] * target.shape[2])
if scale_invariant:
return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=False)
return signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean)


def _average_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return _ref_metric(preds, target, scale_invariant, zero_mean).mean()
sa_sdr = scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=False)
else:
sa_sdr = signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean)
if reduce_mean:
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return sa_sdr.mean()
return sa_sdr


@pytest.mark.parametrize(
@@ -83,7 +85,9 @@ def test_si_sdr(self, preds, target, scale_invariant, zero_mean, ddp):
preds,
target,
SourceAggregatedSignalDistortionRatio,
reference_metric=partial(_average_metric, scale_invariant=scale_invariant, zero_mean=zero_mean),
reference_metric=partial(
_reference_local_sa_sdr, scale_invariant=scale_invariant, zero_mean=zero_mean, reduce_mean=True
),
metric_args={
"scale_invariant": scale_invariant,
"zero_mean": zero_mean,
@@ -96,7 +100,7 @@ def test_sa_sdr_functional(self, preds, target, scale_invariant, zero_mean):
preds,
target,
source_aggregated_signal_distortion_ratio,
reference_metric=partial(_ref_metric, scale_invariant=scale_invariant, zero_mean=zero_mean),
reference_metric=partial(_reference_local_sa_sdr, scale_invariant=scale_invariant, zero_mean=zero_mean),
metric_args={
"scale_invariant": scale_invariant,
"zero_mean": zero_mean,
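The reference above leans on the equivalence stated in its comment: with scale_invariant=True, SA-SDR matches SI-SDR computed on the speaker-concatenated signals. A quick numerical check of that claim (example shapes, loose tolerance, zero_mean kept False to match the reference):

import torch
from torchmetrics.functional.audio import (
    scale_invariant_signal_distortion_ratio,
    source_aggregated_signal_distortion_ratio,
)

preds = torch.rand(4, 2, 1600)  # [batch, spk, time]
target = torch.rand(4, 2, 1600)

sa_sdr = source_aggregated_signal_distortion_ratio(preds, target, scale_invariant=True, zero_mean=False)
si_sdr_on_concat = scale_invariant_signal_distortion_ratio(
    preds.reshape(4, -1), target.reshape(4, -1), zero_mean=False
)
assert torch.allclose(sa_sdr, si_sdr_on_concat, atol=1e-2)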
44 changes: 19 additions & 25 deletions tests/unittests/audio/test_sdr.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable

import numpy as np
import pytest
@@ -43,7 +42,9 @@
)


def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool = False) -> Tensor:
def _reference_sdr_batch(
preds: Tensor, target: Tensor, compute_permutation: bool = False, reduce_mean: bool = False
) -> Tensor:
# shape: preds [BATCH_SIZE, spk, Time] , target [BATCH_SIZE, spk, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, spk, Time] , target [NUM_BATCHES*BATCH_SIZE, spk, Time]
target = target.detach().cpu().numpy()
@@ -52,56 +53,49 @@ def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool
for b in range(preds.shape[0]):
sdr_val_np, _, _, _ = bss_eval_sources(target[b], preds[b], compute_permutation)
mss.append(sdr_val_np)
return torch.tensor(np.array(mss))
sdr = torch.tensor(np.array(mss))
if reduce_mean:
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return sdr.mean()
return sdr


def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


original_impl_compute_permutation = partial(_sdr_original_batch)


@pytest.mark.skipif( # TODO: figure out why tests leads to cuda errors on latest torch
@pytest.mark.skipif( # FIXME: figure out why tests leads to cuda errors on latest torch
_TORCH_GREATER_EQUAL_1_11 and torch.cuda.is_available(), reason="tests leads to cuda errors on latest torch"
)
@pytest.mark.parametrize(
"preds, target, ref_metric",
[
(inputs_1spk.preds, inputs_1spk.target, original_impl_compute_permutation),
(inputs_2spk.preds, inputs_2spk.target, original_impl_compute_permutation),
],
"preds, target",
[(inputs_1spk.preds, inputs_1spk.target), (inputs_2spk.preds, inputs_2spk.target)],
)
class TestSDR(MetricTester):
"""Test class for `SignalDistortionRatio` metric."""

atol = 1e-2

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_sdr(self, preds, target, ref_metric, ddp):
def test_sdr(self, preds, target, ddp):
"""Test class implementation of metric."""
self.run_class_metric_test(
ddp,
preds,
target,
SignalDistortionRatio,
reference_metric=partial(_average_metric, metric_func=ref_metric),
reference_metric=partial(_reference_sdr_batch, reduce_mean=True),
metric_args={},
)

def test_sdr_functional(self, preds, target, ref_metric):
def test_sdr_functional(self, preds, target):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds,
target,
signal_distortion_ratio,
ref_metric,
_reference_sdr_batch,
metric_args={},
)

def test_sdr_differentiability(self, preds, target, ref_metric):
def test_sdr_differentiability(self, preds, target):
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
self.run_differentiability_test(
preds=preds,
@@ -110,7 +104,7 @@ def test_sdr_differentiability(self, preds, target, ref_metric):
metric_args={},
)

def test_sdr_half_cpu(self, preds, target, ref_metric):
def test_sdr_half_cpu(self, preds, target):
"""Test dtype support of the metric on CPU."""
self.run_precision_test_cpu(
preds=preds,
@@ -121,7 +115,7 @@ def test_sdr_half_cpu(self, preds, target, ref_metric):
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_sdr_half_gpu(self, preds, target, ref_metric):
def test_sdr_half_gpu(self, preds, target):
"""Test dtype support of the metric on GPU."""
self.run_precision_test_gpu(
preds=preds,