diff --git a/tests/influence/torch/test_influence_model.py b/tests/influence/torch/test_influence_model.py
index ecdc2de6b..b8266df58 100644
--- a/tests/influence/torch/test_influence_model.py
+++ b/tests/influence/torch/test_influence_model.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Callable, NamedTuple, Tuple
+from typing import Callable, Dict, NamedTuple, OrderedDict, Tuple
 
 import numpy as np
 import pytest
@@ -25,6 +25,7 @@
     PreConditioner,
 )
 from pydvl.influence.torch.util import BlockMode, SecondOrderMode
+from pydvl.influence.types import UnsupportedInfluenceModeException
 from tests.influence.torch.conftest import minimal_training
 
 torch = pytest.importorskip("torch")
@@ -246,12 +247,30 @@ def model_and_data(
 
 
 @fixture
-def direct_influence_function_model(model_and_data, test_case: TestCase):
+def block_structure(request):
+    return getattr(request, "param", BlockMode.FULL)
+
+
+@fixture
+def second_order_mode(request):
+    return getattr(request, "param", SecondOrderMode.HESSIAN)
+
+
+@fixture
+def direct_influence_function_model(
+    model_and_data, test_case: TestCase, block_structure: BlockMode, second_order_mode
+):
     model, loss, x_train, y_train, x_test, y_test = model_and_data
     train_dataloader = DataLoader(
         TensorDataset(x_train, y_train), batch_size=test_case.batch_size
     )
-    return DirectInfluence(model, loss, test_case.hessian_reg).fit(train_dataloader)
+    return DirectInfluence(
+        model,
+        loss,
+        test_case.hessian_reg,
+        block_structure=block_structure,
+        second_order_mode=second_order_mode,
+    ).fit(train_dataloader)
 
 
 @fixture
@@ -259,11 +278,23 @@ def direct_influences(
     direct_influence_function_model: DirectInfluence,
     model_and_data,
     test_case: TestCase,
-) -> NDArray:
+) -> torch.Tensor:
     model, loss, x_train, y_train, x_test, y_test = model_and_data
     return direct_influence_function_model.influences(
         x_test, y_test, x_train, y_train, mode=test_case.mode
-    ).numpy()
+    )
+
+
+@fixture
+def direct_influences_by_block(
+    direct_influence_function_model: DirectInfluence,
+    model_and_data,
+    test_case: TestCase,
+) -> Dict[str, torch.Tensor]:
+    model, loss, x_train, y_train, x_test, y_test = model_and_data
+    return direct_influence_function_model.influences_by_block(
+        x_test, y_test, x_train, y_train, mode=test_case.mode
+    )
 
 
 @fixture
@@ -271,11 +302,11 @@ def direct_sym_influences(
     direct_influence_function_model: DirectInfluence,
     model_and_data,
     test_case: TestCase,
-) -> NDArray:
+) -> torch.Tensor:
     model, loss, x_train, y_train, x_test, y_test = model_and_data
     return direct_influence_function_model.influences(
         x_train, y_train, mode=test_case.mode
-    ).numpy()
+    )
 
 
 @fixture
@@ -283,9 +314,45 @@ def direct_factors(
     direct_influence_function_model: DirectInfluence,
     model_and_data,
     test_case: TestCase,
-) -> NDArray:
+) -> torch.Tensor:
+    model, loss, x_train, y_train, x_test, y_test = model_and_data
+    return direct_influence_function_model.influence_factors(x_train, y_train)
+
+
+@fixture
+def direct_factors_by_block(
+    direct_influence_function_model: DirectInfluence,
+    model_and_data,
+    test_case: TestCase,
+) -> OrderedDict[str, torch.Tensor]:
+    model, loss, x_train, y_train, x_test, y_test = model_and_data
+    return direct_influence_function_model.influence_factors_by_block(x_train, y_train)
+
+
+@fixture
+def direct_influences_from_factors_by_block(
+    direct_influence_function_model: DirectInfluence,
+    direct_factors_by_block,
+    model_and_data,
+    test_case: TestCase,
+) -> Dict[str, torch.Tensor]:
     model, loss, x_train, y_train, x_test, y_test = model_and_data
-    return direct_influence_function_model.influence_factors(x_train, y_train).numpy()
+    return direct_influence_function_model.influences_from_factors_by_block(
+        direct_factors_by_block, x_train, y_train, mode=test_case.mode
+    )
+
+
+@fixture
+def direct_influences_from_factors(
+    direct_influence_function_model: DirectInfluence,
+    direct_factors,
+    model_and_data,
+    test_case: TestCase,
+) -> torch.Tensor:
+    model, loss, x_train, y_train, x_test, y_test = model_and_data
+    return direct_influence_function_model.influences_from_factors(
+        direct_factors, x_train, y_train, mode=test_case.mode
+    )
 
 
 @pytest.mark.parametrize(
@@ -668,10 +735,10 @@ def test_influences_ekfac(
     assert np.allclose(ekfac_influence_values, influence_from_factors)
     assert np.allclose(ekfac_influence_values, accumulated_inf_by_layer)
     check_influence_correlations(
-        direct_influences, ekfac_influence_values, threshold=0.94
+        direct_influences.numpy(), ekfac_influence_values, threshold=0.94
     )
     check_influence_correlations(
-        direct_sym_influences, ekfac_self_influence, threshold=0.94
+        direct_sym_influences.numpy(), ekfac_self_influence, threshold=0.94
    )
 
 
@@ -766,8 +833,13 @@ def test_influences_cg(
 ]
 
 
+def check_correlation(arr_1, arr_2, corr_val):
+    assert np.all(pearsonr(arr_1, arr_2).statistic > corr_val)
+    assert np.all(spearmanr(arr_1, arr_2).statistic > corr_val)
+
+
 @pytest.mark.parametrize("composable_influence_factory", composable_influence_factories)
-@pytest.mark.parametrize("block_mode", [mode for mode in BlockMode])
+@pytest.mark.parametrize("block_structure", [mode for mode in BlockMode], indirect=True)
 @pytest.mark.torch
 def test_composable_influence(
     test_case: TestCase,
@@ -781,8 +853,12 @@ def test_composable_influence(
     ],
     direct_influences,
     direct_sym_influences,
+    direct_factors,
+    direct_influences_by_block,
+    direct_factors_by_block,
+    direct_influences_from_factors_by_block,
     device: torch.device,
-    block_mode,
+    block_structure,
     composable_influence_factory,
 ):
     model, loss, x_train, y_train, x_test, y_test = model_and_data
@@ -791,25 +867,58 @@ def test_composable_influence(
         TensorDataset(x_train, y_train), batch_size=test_case.batch_size
     )
 
-    harmonic_mean_influence = composable_influence_factory(
-        model, loss, test_case.hessian_reg, block_structure=block_mode
+    infl_model = composable_influence_factory(
+        model, loss, test_case.hessian_reg, block_structure=block_structure
     ).to(device)
-    harmonic_mean_influence = harmonic_mean_influence.fit(train_dataloader)
-    harmonic_mean_influence_values = (
-        harmonic_mean_influence.influences(
-            x_test, y_test, x_train, y_train, mode=test_case.mode
-        )
-        .cpu()
-        .numpy()
+
+    with pytest.raises(NotFittedException):
+        infl_model.influences(x_test, y_test, x_train, y_train, mode=test_case.mode)
+
+    infl_model = infl_model.fit(train_dataloader)
+
+    with pytest.raises(UnsupportedInfluenceModeException):
+        infl_model.influences(x_test, y_test, mode="strange_mode")
+
+    infl_values = infl_model.influences(
+        x_test, y_test, x_train, y_train, mode=test_case.mode
     )
-    threshold = 0.999
-    flat_direct_influences = direct_influences.reshape(-1)
-    flat_harmonic_influences = harmonic_mean_influence_values.reshape(-1)
-    assert np.all(
-        pearsonr(flat_direct_influences, flat_harmonic_influences).statistic > threshold
+    threshold = 1 - 1e-3
+    check_correlation(
+        direct_influences.reshape(-1), infl_values.reshape(-1), corr_val=threshold
    )
-    assert np.all(
-        spearmanr(flat_direct_influences, flat_harmonic_influences).statistic
-        > threshold
+
+    sym_infl_values = infl_model.influences(x_train, y_train, mode=test_case.mode)
+    check_correlation(
+        direct_sym_influences.reshape(-1), sym_infl_values.reshape(-1), threshold
    )
+
+    infl_factors = infl_model.influence_factors(x_train, y_train)
+    flat_factors = infl_factors.reshape(-1)
+    flat_direct_factors = direct_factors.reshape(-1)
+    check_correlation(flat_factors, flat_direct_factors, threshold)
+
+    infl_factors_by_block = infl_model.influence_factors_by_block(x_train, y_train)
+    for block, infl in infl_factors_by_block.items():
+        check_correlation(
+            infl.reshape(-1), direct_factors_by_block[block].reshape(-1), threshold
+        )
+
+    infl_by_block = infl_model.influences_by_block(
+        x_test, y_test, x_train, y_train, mode=test_case.mode
+    )
+    for block, infl in infl_by_block.items():
+        check_correlation(
+            infl.reshape(-1), direct_influences_by_block[block].reshape(-1), threshold
+        )
+
+    infl_from_factors_by_block = infl_model.influences_from_factors_by_block(
+        infl_factors_by_block, x_train, y_train, mode=test_case.mode
+    )
+
+    for block, infl in infl_from_factors_by_block.items():
+        check_correlation(
+            infl.reshape(-1),
+            direct_influences_from_factors_by_block[block].reshape(-1),
+            threshold,
+        )
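Note on the fixture wiring introduced above: the new block_structure and second_order_mode fixtures read the test parameter from request.param, which pytest only provides when the fixture is parametrized with indirect=True; the getattr fallback keeps the old defaults (BlockMode.FULL, SecondOrderMode.HESSIAN) for every test that does not parametrize them. The following is a minimal, self-contained sketch of that mechanism only; the enum and test names in it are illustrative stand-ins, not the pyDVL definitions.

from enum import Enum

import pytest


class BlockMode(Enum):
    # Stand-in for pydvl.influence.torch.util.BlockMode, used only for illustration.
    FULL = "full"
    LAYER_WISE = "layer_wise"


@pytest.fixture
def block_structure(request):
    # request.param only exists when a test parametrizes this fixture with
    # indirect=True; otherwise fall back to the default mode.
    return getattr(request, "param", BlockMode.FULL)


@pytest.mark.parametrize("block_structure", list(BlockMode), indirect=True)
def test_parametrized(block_structure):
    # Runs once per BlockMode member; the value is injected through the fixture.
    assert isinstance(block_structure, BlockMode)


def test_default(block_structure):
    # Without indirect parametrization the fixture returns BlockMode.FULL.
    assert block_structure is BlockMode.FULL

This is why tests that request direct_influence_function_model without parametrizing these fixtures keep the previous behaviour, while test_composable_influence exercises every BlockMode via the indirect parametrization added in this diff.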