From 262197faa5560841357f0d05c3d0a4bf2b9470a3 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 12 Jan 2025 18:20:57 +0100 Subject: [PATCH] Backport fix from #633 for KNNShap --- src/pydvl/value/shapley/knn.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/pydvl/value/shapley/knn.py b/src/pydvl/value/shapley/knn.py index 06cb9903d..c7eee1974 100644 --- a/src/pydvl/value/shapley/knn.py +++ b/src/pydvl/value/shapley/knn.py @@ -73,28 +73,26 @@ def knn_shapley(u: Utility, *, progress: bool = True) -> ValuationResult: # closest to farthest _, indices = nns.kneighbors(u.data.x_test) - values: NDArray[np.float64] = np.zeros_like(u.data.indices, dtype=np.float64) + res = np.zeros_like(u.data.indices, dtype=np.float64) n = len(u.data) yt = u.data.y_train iterator = enumerate(zip(u.data.y_test, indices), start=1) for j, (y, ii) in tqdm(iterator, disable=not progress): - value_at_x = int(yt[ii[-1]] == y) / n - values[ii[-1]] += (value_at_x - values[ii[-1]]) / j - for i in range(n - 2, n_neighbors, -1): # farthest to closest - value_at_x = ( - values[ii[i + 1]] + (int(yt[ii[i]] == y) - int(yt[ii[i + 1]] == y)) / i - ) - values[ii[i]] += (value_at_x - values[ii[i]]) / j - for i in range(n_neighbors, -1, -1): # farthest to closest - value_at_x = ( - values[ii[i + 1]] - + (int(yt[ii[i]] == y) - int(yt[ii[i + 1]] == y)) / n_neighbors - ) - values[ii[i]] += (value_at_x - values[ii[i]]) / j + values = np.zeros_like(u.data.indices, dtype=np.float64) + idx = ii[-1] + values[idx] = int(yt[idx] == y) / n + + for i in range(n - 1, 0, -1): + prev_idx = idx + idx = ii[i - 1] + values[idx] = values[prev_idx] + ( + int(yt[idx] == y) - int(yt[prev_idx] == y) + ) / max(n_neighbors, i) + res += values return ValuationResult( algorithm="knn_shapley", status=Status.Converged, - values=values, + values=res, data_names=u.data.data_names, )