ci: cache reference metrics & clean audio tests #2335

Merged 29 commits on Feb 27, 2024
3 changes: 1 addition & 2 deletions .gitignore
@@ -40,9 +40,8 @@ pip-delete-this-directory.txt
# Unit test / coverage reports
tests/_data/
data.zip
tests/_reference-cache/
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
23 changes: 23 additions & 0 deletions tests/unittests/audio/__init__.py
@@ -1,7 +1,30 @@
import os
from typing import Callable, Optional

from torch import Tensor

from unittests import _PATH_ALL_TESTS

_SAMPLE_AUDIO_SPEECH = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech.wav")
_SAMPLE_AUDIO_SPEECH_BAB_DB = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech_bab_0dB.wav")
_SAMPLE_NUMPY_ISSUE_895 = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "issue_895.npz")


def _average_metric_wrapper(
preds: Tensor, target: Tensor, metric_func: Callable, res_index: Optional[int] = None
) -> Tensor:
"""Average the metric values.
Args:
preds: predictions, shape[batch, spk, time]
target: targets, shape[batch, spk, time]
metric_func: a function which return best_metric and best_perm
res_index: if not None, return best_metric[res_index]
Returns:
the average of best_metric
"""
if res_index is not None:
return metric_func(preds, target)[res_index].mean()
return metric_func(preds, target).mean()
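
For context, a minimal usage sketch of the new shared `_average_metric_wrapper` helper (illustrative only, not part of the diff; the toy reference metric, shapes, and values are assumptions, and the import assumes the tests/ directory is on the path):

```python
import torch

from unittests.audio import _average_metric_wrapper  # helper added in this PR


def _toy_pit_reference(preds, target):
    # Hypothetical PIT-style reference: returns a tuple (best_metric, best_perm).
    best_metric = -(preds - target).abs().mean(dim=-1).mean(dim=-1)      # shape [batch]
    best_perm = torch.arange(preds.shape[1]).repeat(preds.shape[0], 1)   # shape [batch, spk]
    return best_metric, best_perm


preds = torch.rand(4, 2, 16000)   # [batch, spk, time]
target = torch.rand(4, 2, 16000)

# PESQ-style references return the metric tensor directly -> leave res_index=None.
# PIT-style references return (best_metric, best_perm)    -> res_index=0 selects best_metric.
avg = _average_metric_wrapper(preds, target, _toy_pit_reference, res_index=0)
print(avg)  # scalar tensor: mean of best_metric over the batch
```
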
23 changes: 6 additions & 17 deletions tests/unittests/audio/test_pesq.py
@@ -22,7 +22,7 @@
from torchmetrics.functional.audio import perceptual_evaluation_speech_quality

from unittests import _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

@@ -41,7 +41,7 @@
)


def _pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
def _reference_pesq_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
"""Comparison function."""
# shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target [NUM_BATCHES*BATCH_SIZE, Time]
@@ -54,23 +54,12 @@ def _pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
return torch.tensor(mss)


def _average_metric(preds, target, metric_func):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


pesq_original_batch_8k_nb = partial(_pesq_original_batch, fs=8000, mode="nb")
pesq_original_batch_16k_nb = partial(_pesq_original_batch, fs=16000, mode="nb")
pesq_original_batch_16k_wb = partial(_pesq_original_batch, fs=16000, mode="wb")


@pytest.mark.parametrize(
"preds, target, ref_metric, fs, mode",
[
(inputs_8k.preds, inputs_8k.target, pesq_original_batch_8k_nb, 8000, "nb"),
(inputs_16k.preds, inputs_16k.target, pesq_original_batch_16k_nb, 16000, "nb"),
(inputs_16k.preds, inputs_16k.target, pesq_original_batch_16k_wb, 16000, "wb"),
(inputs_8k.preds, inputs_8k.target, partial(_reference_pesq_batch, fs=8000, mode="nb"), 8000, "nb"),
(inputs_16k.preds, inputs_16k.target, partial(_reference_pesq_batch, fs=16000, mode="nb"), 16000, "nb"),
(inputs_16k.preds, inputs_16k.target, partial(_reference_pesq_batch, fs=16000, mode="wb"), 16000, "wb"),
],
)
class TestPESQ(MetricTester):
@@ -89,7 +78,7 @@ def test_pesq(self, preds, target, ref_metric, fs, mode, num_processes, ddp):
preds,
target,
PerceptualEvaluationSpeechQuality,
reference_metric=partial(_average_metric, metric_func=ref_metric),
reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric),
metric_args={"fs": fs, "mode": mode, "n_processes": num_processes},
)

68 changes: 32 additions & 36 deletions tests/unittests/audio/test_pit.py
@@ -31,27 +31,28 @@
)

from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.audio import _average_metric_wrapper
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

TIME = 10
TIME_FRAME = 10


# three speaker examples to test _find_best_perm_by_linear_sum_assignment
inputs1 = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME_FRAME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME_FRAME),
)
# two speaker examples to test _find_best_perm_by_exhuastive_method
inputs2 = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME_FRAME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME_FRAME),
)


def naive_implementation_pit_scipy(
def _reference_scipy_pit(
preds: Tensor,
target: Tensor,
metric_func: Callable,
@@ -66,10 +67,8 @@ def naive_implementation_pit_scipy(
eval_func: min or max
Returns:
best_metric:
shape [batch]
best_perm:
shape [batch, spk]
best_metric: shape [batch]
best_perm: shape [batch, spk]
"""
batch_size, spk_num = target.shape[0:2]
@@ -88,62 +87,59 @@
return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms))


def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
"""Average the metric values.
def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
return _reference_scipy_pit(
preds=preds,
target=target,
metric_func=signal_noise_ratio,
eval_func="max",
)

Args:
preds: predictions, shape[batch, spk, time]
target: targets, shape[batch, spk, time]
metric_func: a function which return best_metric and best_perm
Returns:
the average of best_metric

"""
return metric_func(preds, target)[0].mean()


snr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=signal_noise_ratio, eval_func="max")
si_sdr_pit_scipy = partial(
naive_implementation_pit_scipy, metric_func=scale_invariant_signal_distortion_ratio, eval_func="max"
)
def _reference_scipy_pit_si_sdr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
return _reference_scipy_pit(
preds=preds,
target=target,
metric_func=scale_invariant_signal_distortion_ratio,
eval_func="max",
)


@pytest.mark.parametrize(
"preds, target, ref_metric, metric_func, mode, eval_func",
[
(inputs1.preds, inputs1.target, snr_pit_scipy, signal_noise_ratio, "speaker-wise", "max"),
(inputs1.preds, inputs1.target, _reference_scipy_pit_snr, signal_noise_ratio, "speaker-wise", "max"),
(
inputs1.preds,
inputs1.target,
si_sdr_pit_scipy,
_reference_scipy_pit_si_sdr,
scale_invariant_signal_distortion_ratio,
"speaker-wise",
"max",
),
(inputs2.preds, inputs2.target, snr_pit_scipy, signal_noise_ratio, "speaker-wise", "max"),
(inputs2.preds, inputs2.target, _reference_scipy_pit_snr, signal_noise_ratio, "speaker-wise", "max"),
(
inputs2.preds,
inputs2.target,
si_sdr_pit_scipy,
_reference_scipy_pit_si_sdr,
scale_invariant_signal_distortion_ratio,
"speaker-wise",
"max",
),
(inputs1.preds, inputs1.target, snr_pit_scipy, signal_noise_ratio, "permutation-wise", "max"),
(inputs1.preds, inputs1.target, _reference_scipy_pit_snr, signal_noise_ratio, "permutation-wise", "max"),
(
inputs1.preds,
inputs1.target,
si_sdr_pit_scipy,
_reference_scipy_pit_si_sdr,
scale_invariant_signal_distortion_ratio,
"permutation-wise",
"max",
),
(inputs2.preds, inputs2.target, snr_pit_scipy, signal_noise_ratio, "permutation-wise", "max"),
(inputs2.preds, inputs2.target, _reference_scipy_pit_snr, signal_noise_ratio, "permutation-wise", "max"),
(
inputs2.preds,
inputs2.target,
si_sdr_pit_scipy,
_reference_scipy_pit_si_sdr,
scale_invariant_signal_distortion_ratio,
"permutation-wise",
"max",
@@ -163,7 +159,7 @@ def test_pit(self, preds, target, ref_metric, metric_func, mode, eval_func, ddp)
preds,
target,
PermutationInvariantTraining,
reference_metric=partial(_average_metric, metric_func=ref_metric),
reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric, res_index=0),
metric_args={"metric_func": metric_func, "mode": mode, "eval_func": eval_func},
)

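
To see how the renamed scipy-based references plug into the shared wrapper above, a small illustrative sketch (not part of the diff; it assumes the tests/ directory is importable and uses toy random inputs):

```python
from functools import partial

import torch

from unittests.audio import _average_metric_wrapper
from unittests.audio.test_pit import _reference_scipy_pit_snr  # module path as laid out in this diff

preds = torch.rand(4, 3, 10)    # [batch, spk, time]
target = torch.rand(4, 3, 10)

best_metric, best_perm = _reference_scipy_pit_snr(preds, target)
print(best_metric.shape, best_perm.shape)   # torch.Size([4]) torch.Size([4, 3])

# Equivalent to the removed local _average_metric: average best_metric only (res_index=0).
reference_metric = partial(_average_metric_wrapper, metric_func=_reference_scipy_pit_snr, res_index=0)
print(reference_metric(preds, target))      # scalar tensor
```
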
26 changes: 15 additions & 11 deletions tests/unittests/audio/test_sa_sdr.py
@@ -38,7 +38,9 @@
)


def _ref_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool):
def _reference_local_sa_sdr(
preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool, reduce_mean: bool = False
):
# According to the original paper, the sa-sdr equals the si-sdr with inputs concatenated over the speaker
# dimension if scale_invariant==True. Accordingly, for scale_invariant==False, the sa-sdr equals the snr.
# shape: preds [BATCH_SIZE, Spk, Time] , target [BATCH_SIZE, Spk, Time]
@@ -51,14 +53,14 @@ def _ref_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean:
preds = preds.reshape(preds.shape[0], preds.shape[1] * preds.shape[2])
target = target.reshape(target.shape[0], target.shape[1] * target.shape[2])
if scale_invariant:
return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=False)
return signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean)


def _average_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return _ref_metric(preds, target, scale_invariant, zero_mean).mean()
sa_sdr = scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=False)
else:
sa_sdr = signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean)
if reduce_mean:
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return sa_sdr.mean()
return sa_sdr


@pytest.mark.parametrize(
@@ -83,7 +85,9 @@ def test_si_sdr(self, preds, target, scale_invariant, zero_mean, ddp):
preds,
target,
SourceAggregatedSignalDistortionRatio,
reference_metric=partial(_average_metric, scale_invariant=scale_invariant, zero_mean=zero_mean),
reference_metric=partial(
_reference_local_sa_sdr, scale_invariant=scale_invariant, zero_mean=zero_mean, reduce_mean=True
),
metric_args={
"scale_invariant": scale_invariant,
"zero_mean": zero_mean,
@@ -96,7 +100,7 @@ def test_sa_sdr_functional(self, preds, target, scale_invariant, zero_mean):
preds,
target,
source_aggregated_signal_distortion_ratio,
reference_metric=partial(_ref_metric, scale_invariant=scale_invariant, zero_mean=zero_mean),
reference_metric=partial(_reference_local_sa_sdr, scale_invariant=scale_invariant, zero_mean=zero_mean),
metric_args={
"scale_invariant": scale_invariant,
"zero_mean": zero_mean,
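
To make the reduce_mean consolidation concrete, a brief sketch of how the single reference now covers both the class test (batch average) and the functional test (per-sample values). Illustrative only; the input sizes are arbitrary and the import assumes the tests/ directory is on the path:

```python
from functools import partial

import torch

from unittests.audio.test_sa_sdr import _reference_local_sa_sdr  # module path as laid out in this diff

preds = torch.rand(4, 2, 8000)    # [batch, spk, time]
target = torch.rand(4, 2, 8000)

# Functional test path: per-sample values (what the former _ref_metric returned).
per_sample = _reference_local_sa_sdr(preds, target, scale_invariant=True, zero_mean=False)
print(per_sample.shape)   # torch.Size([4])

# Class test path: batch average (what the former _average_metric returned).
batch_avg = partial(_reference_local_sa_sdr, scale_invariant=True, zero_mean=False, reduce_mean=True)
print(batch_avg(preds, target))   # scalar tensor
```
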
44 changes: 19 additions & 25 deletions tests/unittests/audio/test_sdr.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable

import numpy as np
import pytest
@@ -43,7 +42,9 @@
)


def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool = False) -> Tensor:
def _reference_sdr_batch(
preds: Tensor, target: Tensor, compute_permutation: bool = False, reduce_mean: bool = False
) -> Tensor:
# shape: preds [BATCH_SIZE, spk, Time] , target [BATCH_SIZE, spk, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, spk, Time] , target [NUM_BATCHES*BATCH_SIZE, spk, Time]
target = target.detach().cpu().numpy()
@@ -52,56 +53,49 @@ def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool
for b in range(preds.shape[0]):
sdr_val_np, _, _, _ = bss_eval_sources(target[b], preds[b], compute_permutation)
mss.append(sdr_val_np)
return torch.tensor(np.array(mss))
sdr = torch.tensor(np.array(mss))
if reduce_mean:
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return sdr.mean()
return sdr


def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


original_impl_compute_permutation = partial(_sdr_original_batch)


@pytest.mark.skipif( # TODO: figure out why tests lead to cuda errors on latest torch
@pytest.mark.skipif( # FIXME: figure out why tests lead to cuda errors on latest torch
_TORCH_GREATER_EQUAL_1_11 and torch.cuda.is_available(), reason="tests lead to cuda errors on latest torch"
)
@pytest.mark.parametrize(
"preds, target, ref_metric",
[
(inputs_1spk.preds, inputs_1spk.target, original_impl_compute_permutation),
(inputs_2spk.preds, inputs_2spk.target, original_impl_compute_permutation),
],
"preds, target",
[(inputs_1spk.preds, inputs_1spk.target), (inputs_2spk.preds, inputs_2spk.target)],
)
class TestSDR(MetricTester):
"""Test class for `SignalDistortionRatio` metric."""

atol = 1e-2

@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
def test_sdr(self, preds, target, ref_metric, ddp):
def test_sdr(self, preds, target, ddp):
"""Test class implementation of metric."""
self.run_class_metric_test(
ddp,
preds,
target,
SignalDistortionRatio,
reference_metric=partial(_average_metric, metric_func=ref_metric),
reference_metric=partial(_reference_sdr_batch, reduce_mean=True),
metric_args={},
)

def test_sdr_functional(self, preds, target, ref_metric):
def test_sdr_functional(self, preds, target):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds,
target,
signal_distortion_ratio,
ref_metric,
_reference_sdr_batch,
metric_args={},
)

def test_sdr_differentiability(self, preds, target, ref_metric):
def test_sdr_differentiability(self, preds, target):
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
self.run_differentiability_test(
preds=preds,
@@ -110,7 +104,7 @@ def test_sdr_differentiability(self, preds, target, ref_metric):
metric_args={},
)

def test_sdr_half_cpu(self, preds, target, ref_metric):
def test_sdr_half_cpu(self, preds, target):
"""Test dtype support of the metric on CPU."""
self.run_precision_test_cpu(
preds=preds,
@@ -121,7 +115,7 @@ def test_sdr_half_cpu(self, preds, target, ref_metric):
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_sdr_half_gpu(self, preds, target, ref_metric):
def test_sdr_half_gpu(self, preds, target):
"""Test dtype support of the metric on GPU."""
self.run_precision_test_gpu(
preds=preds,
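
The SDR reference follows the same reduce_mean pattern; a short illustrative sketch (not part of the diff; assumes the tests/ directory is importable and the bss_eval backend used by the test is installed):

```python
from functools import partial

import torch

from unittests.audio.test_sdr import _reference_sdr_batch  # module path as laid out in this diff

preds = torch.rand(2, 1, 8000)    # [batch, spk, time]
target = torch.rand(2, 1, 8000)

per_sample = _reference_sdr_batch(preds, target)                    # functional-test reference
class_reference = partial(_reference_sdr_batch, reduce_mean=True)   # class-test reference
print(per_sample.shape, class_reference(preds, target))             # torch.Size([2, 1]) and its mean
```
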