
Add xCEBRA implementation (AISTATS 2025) #225

Open · wants to merge 20 commits into base: main
Conversation

@gonlairo (Contributor) commented Feb 17, 2025

xCEBRA

eXplainable CEBRA 🔎🦓

This PR adds the following features:

  • multiobjective solver -> fit multiple subspaces with a new API
  • attribution methods (via captum), including our new method, the inverted neuron gradient
  • regularized contrastive learning using Jacobian regularization (required for identifiable attribution maps, but also useful for regularizing training more generally); a.k.a. xCEBRA
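The Jacobian-regularization idea can be sketched outside the package: add a Frobenius-norm penalty on the encoder's Jacobian to the training loss. Everything below (the toy linear encoder, the finite-difference helper) is illustrative only and is not the PR's actual implementation, which lives in the new multiobjective/attribution modules.

```python
import numpy as np

def encoder(x, W):
    """Toy stand-in encoder: a linear map (hypothetical, for illustration)."""
    return W @ x

def jacobian_fd(f, x, eps=1e-5):
    """Numerical Jacobian of f at x via central differences."""
    y = f(x)
    J = np.zeros((y.size, x.size))
    for i in range(x.size):
        dx = np.zeros_like(x)
        dx[i] = eps
        J[:, i] = (f(x + dx) - f(x - dx)) / (2 * eps)
    return J

def jacobian_penalty(f, x):
    """Frobenius-norm penalty on the encoder Jacobian: the kind of term a
    regularized contrastive loss could add on top of the InfoNCE objective."""
    J = jacobian_fd(f, x)
    return float(np.sum(J ** 2))

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 5))
x = rng.normal(size=5)
penalty = jacobian_penalty(lambda v: encoder(v, W), x)
# For a linear map the Jacobian is W itself, so the penalty equals ||W||_F^2.
assert abs(penalty - float(np.sum(W ** 2))) < 1e-6
```

In an actual training loop this penalty would be weighted and added to the contrastive loss; the PR's implementation operates on the PyTorch encoder directly rather than via finite differences.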

This code supports the following paper:

https://openreview.net/forum?id=aGrCXoTB4P

@inproceedings{schneider2025timeseries,
  title={Time-series attribution maps with regularized contrastive learning},
  author={Steffen Schneider and Rodrigo Gonz{\'a}lez Laiz and Anastasiia Filippova and Markus Frey and Mackenzie W Mathis},
  booktitle={The 28th International Conference on Artificial Intelligence and Statistics},
  year={2025},
  url={https://openreview.net/forum?id=aGrCXoTB4P}
}

Abstract:

Gradient-based attribution methods aim to explain the decisions of deep learning models, but so far lack identifiability guarantees. Here, we propose a method to generate attribution maps with identifiability guarantees by developing a regularized contrastive learning algorithm (RegCL) trained on time-series data. We show theoretically that RegCL has favorable properties for identifying the Jacobian matrix of the data generating process. Empirically, we demonstrate robust approximation of zero vs. non-zero entries in the ground-truth attribution map on synthetic datasets, and significant improvements over previous attribution methods based on feature ablation, Shapley values, and other gradient-based methods. Our work constitutes a first example of identifiable inference of time-series attribution maps, and opens avenues for better understanding of time-series data, such as neural dynamics and decision processes within neural networks.

Outline of the Method:

[Figure 1] Identifiable attribution maps for time-series data. Using time-series data (such as neural data recorded during navigation, as depicted), our inference framework estimates the ground-truth Jacobian matrix J_g (i.e., x is the observed neural data linked to latents z and c, where c is the explicit [auxiliary] behavioral variable that would be linked to grid cells) by identifying the inverse data generation process up to a linear indeterminacy L. Then, we estimate the Jacobian J_f of the encoder model (f) by minimizing a generalized InfoNCE objective. Inverting this Jacobian, J_f^+, which approximates J_g, allows us to construct the attributions.
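The last step of the figure can be mimicked numerically in a linear toy setting. The matrices A and L below are hypothetical stand-ins for the data generating process and the linear indeterminacy; none of these names come from the PR's code.

```python
import numpy as np

# Hypothetical ground-truth Jacobian J_g of a linear data generating
# process x = g(z) = A z: latents (3) -> observations (10).
rng = np.random.default_rng(42)
A = rng.normal(size=(10, 3))

# Suppose the trained encoder f recovers the latents up to a linear
# indeterminacy L, i.e. f(x) = L g^{-1}(x); its Jacobian is then J_f = L A^+.
L = rng.normal(size=(3, 3))
J_f = L @ np.linalg.pinv(A)

# Attribution map: pseudo-invert the encoder Jacobian. In this full-rank
# linear setting, pinv(J_f) equals A @ inv(L), so its zero vs. non-zero
# support matches J_g up to the linear indeterminacy.
attribution = np.linalg.pinv(J_f)          # shape (10, 3)

assert attribution.shape == A.shape
# Columns of the attribution map lie in the column space of A.
proj_A = A @ np.linalg.pinv(A)             # orthogonal projector onto col(A)
assert np.allclose(proj_A @ attribution, attribution, atol=1e-6)
```

In the PR, the encoder Jacobian is of course not available in closed form; it is computed from the trained PyTorch model, with captum-based baselines and the inverted neuron gradient as attribution methods.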

gonlairo and others added 6 commits February 17, 2025 21:16
* Add multiobjective solver and regularized training

* Add example for multiobjective training

* Add jacobian regularizer and SAM

* update license headers

* add api draft for multiobjective training

* add all necessary modules to run the complete xcebra pipeline

* add notebooks to reproduce xcebra pipeline

* add first working notebook

* add notebook with hybrid learning

* add notebook with creation of synthetic data

* add notebook with hybrid training

* add plot with R2 for different parts of the embedding

* add new API

* update api wrapper with more checks and messages

* add tests and notebook with new api

* merge xcebra into attribution

* separate xcebra dataset from cebra

* some minor refactoring of cebra dataset

* separate xcebra loader from cebra

* remove xcebra distributions from cebra

* minor refactoring with distributions

* separate xcebra criterions from cebra

* minor refactoring on criterion

* separate xcebra models/criterions/layers from cebra

* refactoring multiobjective

* more refactoring...

* separate xcebra solvers from cebra

* more refactoring

* move xcebra to its own package

* move more files into xcebra package

* more files and remove changes with the registry

* remove unnecessary import

* add folder structure

* move back distributions

* add missing init

* remove wrong init

* make loader and dataset run with new imports

* making it run!

* make attribution run

* Run pre-commit

* move xcebra repo one level up

* update gitignore and add __init__ from data

* add init to distributions

* add correct init for attribution package

* add correct init for model package

* fix remaining imports

* fix tests

* add examples back to xcebra repo

* update imports from graphs_xcebra

* add setup.py to create a package

* update imports of graph_xcebra

* update notebooks

* Formatting code for submission

Co-authored-by: Rodrigo Gonzalez <[email protected]>

* move test into xcebra

* Add README

* move distributions back to main package

* clean up examples

* adapt tests

* Add LICENSE

* add train/eval notebook again

* add notebook with clean results

* rm synthetic data

* change name from xcebra to regcl

* change names of modules and adapt imports

* change name from graphs_xcebra to synthetic_data

* Integrate into CEBRA

* Fix remaining imports and make notebook runnable

* Add dependencies, add version flag

* Remove synthetic data files

* reset dockerfile, move vmf

* apply pre-commit

* Update notice

* add some docstrings

* Apply license headers

* add new scd notebook

* add notebook with scd

---------

Co-authored-by: Steffen Schneider <[email protected]>
* bump version

* update dockerfile

* fix progress bar

* remove outdated test

* rename models
@stes changed the title from "Aistats2025" to "Add xCEBRA implementation (AISTATS 2025)" on Feb 17, 2025
@MMathisLab (Member) commented:

Let's move demos/examples to https://github.com/AdaptiveMotorControlLab/CEBRA-demos :D

@MMathisLab (Member) left a review: I left comments throughout, thank you!!

@@ -35,3 +35,83 @@
- 'tests/**/*.py'
- 'docs/**/*.py'
- 'conda/**/*.yml'

Review comment (Member): I think best just in the related files, not here?

import torch
import torch.nn as nn
import tqdm
from captum.attr import NeuronFeatureAblation
Review comment (Member): Do we really need all the baselines in this package? We should write up docs for this also ...

@@ -29,7 +30,7 @@


def _description(stats: Dict[str, float]):
stats_str = [f"{key}: {value: .4f}" for key, value in stats.items()]
stats_str = [f"{key}: {value:.3f}" for key, value in stats.items()]
Review comment (Member): why this change?
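For context on the diff above: the two format specs differ in the space sign flag and in precision. A quick sketch of the behavior (plain Python, values illustrative):

```python
# "{value: .4f}": the space after the colon is the sign option — positive
# numbers get a leading blank where a minus sign would go; 4 decimals.
old = f"{1.5: .4f}"
# "{value:.3f}": no sign slot, 3 decimals.
new = f"{1.5:.3f}"
assert old == " 1.5000"
assert new == "1.500"
# Negative values fill the sign slot, so columns of mixed-sign stats align:
assert f"{-1.5: .4f}" == "-1.5000"
```

So the change shortens the progress-bar stats and drops the sign alignment; the later review note says it will be reverted.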

Reply (Member): see above, will revert back

@@ -73,7 +74,9 @@ class ProgressBar:
"Log and display values during training."

loader: Iterable
log_format: str
logger: logging.Logger = None
Review comment (Member): more details needed

Review comment (Member), repeated on three files: remove and move to demos

cla-bot added the "CLA signed" label on Feb 18, 2025
@stes (Member) left a review:
I made a few edits:

  • The old MultiobjectiveSolver is accessible again (important for the hybrid model in Fig. 2 of the CEBRA paper); it is now called LegacyMultiobjectiveSolver. End users of the sklearn API with hybrid=True can still use it.
  • Reverted some changes from the research code base that are not important for the release
  • Added additional tests, incl. integration tests from the notebooks
  • (Resolved some additional review comments)

@@ -130,6 +131,16 @@ def _inference(self, batch):
class SingleSessionHybridSolver(abc_.MultiobjectiveSolver):
"""Single session training, contrasting neural data against behavior."""

log: Dict = dataclasses.field(default_factory=lambda: ({
Review comment (Member): I will revert some of these changes; they were part of the research code base and should not go into the package...


Comment on lines +74 to +78
# NOTE(stes): Temporarily disable, INCLUDE BEFORE MERGE!
#- name: Check that no binary files have been added to repo
# if: matrix.os == 'ubuntu-latest'
# run: |
# make check_for_binary
Review comment (Member): pragmatic workaround until I've moved the demo files; will remove once fixed

negative=None,
)

def load_batch_contrastive(self, index: BatchIndex) -> Batch:
Reply (Member): That issue only appears once #168 is used, so I think it is not a concern here. This function simply constructs the batch from indices; it should be unrelated to the batched-inference issue.

@@ -96,7 +96,6 @@ def get_datapath(path: str = None) -> str:
from cebra.datasets.gaussian_mixture import *
from cebra.datasets.hippocampus import *
from cebra.datasets.monkey_reaching import *
from cebra.datasets.synthetic_data import *
Reply (Member): fixed in 5e30829

@stes (Member) commented Feb 19, 2025:

next:

  • check docstring coverage and write docstrings
  • remove the ratinabox==1.8 and ephysiopy==1.9.62 dependencies for xcebra? I think they are not needed for core functionality; if dropped, they could be mentioned in the docs somewhere
  • fix and consolidate the naming of some of the newly added classes
  • cebra.distributions.DeltaVMFDistribution seems to be missing; check!

@stes (Member) commented Feb 19, 2025:

tests build!

[screenshot: passing test run]
