Add xCEBRA implementation (AISTATS 2025) #225
base: main
Conversation
- Add multiobjective solver and regularized training
- Add example for multiobjective training
- Add jacobian regularizer and SAM
- update license headers
- add api draft for multiobjective training
- add all necessary modules to run the complete xcebra pipeline
- add notebooks to reproduce xcebra pipeline
- add first working notebook
- add notebook with hybrid learning
- add notebook with creation of synthetic data
- add notebook with hybrid training
- add plot with R2 for different parts of the embedding
- add new API
- update api wrapper with more checks and messages
- add tests and notebook with new api
- merge xcebra into attribution
- separate xcebra dataset from cebra
- some minor refactoring of cebra dataset
- separate xcebra loader from cebra
- remove xcebra distributions from cebra
- minor refactoring with distributions
- separate xcebra criterions from cebra
- minor refactoring on criterion
- separate xcebra models/criterions/layers from cebra
- refactoring multiobjective
- more refactoring...
- separate xcebra solvers from cebra
- more refactoring
- move xcebra to its own package
- move more files into xcebra package
- more files and remove changes with the registry
- remove unnecessary import
- add folder structure
- move back distributions
- add missing init
- remove wrong init
- make loader and dataset run with new imports
- making it run!
- make attribution run
- Run pre-commit
- move xcebra repo one level up
- update gitignore and add __init__ from data
- add init to distributions
- add correct init for attribution package
- add correct init for model package
- fix remaining imports
- fix tests
- add examples back to xcebra repo
- update imports from graphs_xcebra
- add setup.py to create a package
- update imports of graph_xcebra
- update notebooks
- Formatting code for submission (Co-authored-by: Rodrigo Gonzalez <[email protected]>)
- move test into xcebra
- Add README
- move distributions back to main package
- clean up examples
- adapt tests
- Add LICENSE
- add train/eval notebook again
- add notebook with clean results
- rm synthetic data
- change name from xcebra to regcl
- change names of modules and adapt imports
- change name from graphs_xcebra to synthetic_data
- Integrate into CEBRA
- Fix remaining imports and make notebook runnable
- Add dependencies, add version flag
- Remove synthetic data files
- reset dockerfile, move vmf
- apply pre-commit
- Update notice
- add some docstrings
- Apply license headers
- add new scd notebook
- add notebook with scd

Co-authored-by: Steffen Schneider <[email protected]>
Let's move the demos in /examples to https://github.com/AdaptiveMotorControlLab/CEBRA-demos :D
I left comments throughout, thank you!!
@@ -35,3 +35,83 @@
 - 'tests/**/*.py'
 - 'docs/**/*.py'
 - 'conda/**/*.yml'
I think it's best to keep this just in the related files, not here?
import torch
import torch.nn as nn
import tqdm
from captum.attr import NeuronFeatureAblation
Do we really need all the baselines in this package? We should write up docs for this also ...
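(For anyone unfamiliar with the captum baselines in question: a minimal sketch of how `NeuronFeatureAblation` is typically invoked. The encoder, shapes, and data below are invented for illustration, not the PR's code.)

```python
import torch
import torch.nn as nn
from captum.attr import NeuronFeatureAblation

# Hypothetical encoder: 100-channel neural data -> 8-dim embedding.
encoder = nn.Sequential(nn.Linear(100, 32), nn.GELU(), nn.Linear(32, 8))
x = torch.randn(16, 100)  # a made-up batch of input data

# Attribute output dimension 0 of the final layer back to the inputs
# by ablating one input feature at a time.
ablator = NeuronFeatureAblation(encoder, layer=encoder[-1])
attributions = ablator.attribute(x, neuron_selector=0)
print(attributions.shape)  # torch.Size([16, 100]): one score per input feature
```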
@@ -29,7 +30,7 @@
 def _description(stats: Dict[str, float]):
-    stats_str = [f"{key}: {value: .4f}" for key, value in stats.items()]
+    stats_str = [f"{key}: {value:.3f}" for key, value in stats.items()]
why this change?
see above, will revert
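(For context on the two specs: `: .4f` reserves a sign column before four decimals, while `:.3f` drops the padding and keeps three. A quick illustration:)

```python
>>> f"{0.12345: .4f}"  # space flag: leading blank where a minus sign would go
' 0.1234'
>>> f"{0.12345:.3f}"   # no sign padding, three decimals
'0.123'
```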
@@ -73,7 +74,9 @@ class ProgressBar:
     "Log and display values during training."

     loader: Iterable
     log_format: str
+    logger: logging.Logger = None
more details needed
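(A sketch of what the optional logger could be for: falling back to `tqdm` display when no logger is configured. Everything beyond the three fields shown in the diff is my assumption:)

```python
import dataclasses
import logging
from typing import Dict, Iterable, Optional

import tqdm


@dataclasses.dataclass
class ProgressBar:
    "Log and display values during training."

    loader: Iterable
    log_format: str
    logger: Optional[logging.Logger] = None

    def report(self, stats: Dict[str, float]) -> None:
        message = ", ".join(f"{k}: {v:.4f}" for k, v in stats.items())
        if self.logger is not None:
            self.logger.info(message)  # headless runs: write to the log
        else:
            tqdm.tqdm.write(message)   # interactive runs: print via tqdm
```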
examples/synthetic_data.pkl
remove and move to demos
examples/train_and_evaluate.ipynb
remove and move to demos
I made a few edits:
- The old `MultiobjectiveSolver` is now accessible again (important for the hybrid model in Fig. 2 of the CEBRA paper); it is now called `LegacyMultiobjectiveSolver`. For an end user using the sklearn API with `hybrid=True`, this can still be used.
- Reverted some changes from the research code base that are not important for the release.
- Added additional tests, incl. integration tests from the notebooks.
- Resolved some additional review comments.
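(For context, this is roughly how an end user reaches that path through the sklearn API; the architecture and hyperparameter values below are placeholders, not recommendations:)

```python
import cebra

# hybrid=True trains with the two-objective (behavior + time) setup,
# i.e. the path now backed by LegacyMultiobjectiveSolver.
model = cebra.CEBRA(
    model_architecture="offset10-model",
    batch_size=512,
    max_iterations=5000,
    hybrid=True,
)
# model.fit(neural_data, behavior_labels)
```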
@@ -130,6 +131,16 @@ def _inference(self, batch):
 class SingleSessionHybridSolver(abc_.MultiobjectiveSolver):
     """Single session training, contrasting neural data against behavior."""

+    log: Dict = dataclasses.field(default_factory=lambda: ({
I will revert some of these changes; they were part of the research code base and should not go into the package...
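(Side note on the `dataclasses.field(default_factory=...)` pattern in that hunk: mutable defaults on a dataclass must go through a factory. The key names below are invented for illustration:)

```python
import dataclasses
from typing import Dict, List


@dataclasses.dataclass
class TrainingLog:
    # A bare mutable default like `log: Dict = {}` raises ValueError at
    # class-creation time; default_factory builds a fresh dict per instance.
    log: Dict[str, List[float]] = dataclasses.field(
        default_factory=lambda: {"total": [], "behavior": [], "time": []})


a, b = TrainingLog(), TrainingLog()
a.log["total"].append(0.5)
assert b.log["total"] == []  # no shared state between instances
```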
# NOTE(stes): Temporarily disable, INCLUDE BEFORE MERGE!
#- name: Check that no binary files have been added to repo
#  if: matrix.os == 'ubuntu-latest'
#  run: |
#    make check_for_binary
Pragmatic workaround until I've moved the demo files; will remove once fixed.
        negative=None,
    )

    def load_batch_contrastive(self, index: BatchIndex) -> Batch:
That issue only appears once #168 is used, so I think it's not a concern here. This function simply constructs the batch from indices; it should be unrelated to the batched-inference issue.
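(For readers without the diff open, a rough sketch of the kind of index-to-batch gathering being discussed; this is a guess at the body, not the PR's exact code:)

```python
def load_batch_contrastive(self, index: BatchIndex) -> Batch:
    """Assemble one contrastive batch from precomputed time indices."""
    # Each field of `index` holds indices into the stored recording;
    # indexing the dataset yields (batch, features) tensors.
    return Batch(
        reference=self[index.reference],
        positive=self[index.positive],
        negative=self[index.negative],
    )
```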
cebra/datasets/__init__.py
@@ -96,7 +96,6 @@ def get_datapath(path: str = None) -> str:
 from cebra.datasets.gaussian_mixture import *
 from cebra.datasets.hippocampus import *
 from cebra.datasets.monkey_reaching import *
-from cebra.datasets.synthetic_data import *
Fixed in 5e30829.
next:
xCEBRA
eXplainable CEBRA 🔎🦓
This PR adds the following features:
This code supports the following paper:
https://openreview.net/forum?id=aGrCXoTB4P
Abstract:
Gradient-based attribution methods aim to explain decisions of deep learning models, but so far lack identifiability guarantees. Here, we propose a method to generate attribution maps with identifiability guarantees by developing a regularized contrastive learning algorithm (RegCL) trained on time-series data. We show theoretically that RegCL has favorable properties for identifying the Jacobian matrix of the data generating process. Empirically, we demonstrate robust approximation of zero vs. non-zero entries in the ground-truth attribution map on synthetic datasets, and significant improvements over previous attribution methods based on feature ablation, Shapley values, and other gradient-based methods. Our work constitutes a first example of identifiable inference of time-series attribution maps, and opens avenues for better understanding of time-series data, such as neural dynamics and decision processes within neural networks.
Outline of the Method:
Identifiable attribution maps for time-series data. Using time-series data (such as neural data recorded during navigation, as depicted), our inference framework estimates the ground-truth Jacobian matrix J_g (i.e., x is the observed neural data linked to latents z and c, where c is the explicit [auxiliary] behavioral variable that would be linked to grid cells) by identifying the inverse data generation process up to a linear indeterminacy L. Then, we estimate the Jacobian J_f of the encoder model f by minimizing a generalized InfoNCE objective. The pseudoinverse of this Jacobian, J_f^+, approximates J_g and allows us to construct the attribution maps.