Zanj integration: datasets & training #177
Merged
Commits (89)
25d745d  mivanit   wip
da2e05f  mivanit   Merge branch 'zanj-integration' into zanj-integration-2
7fdbdb0  mivanit   wip
f7abcb0  mivanit   bump muutils to 0.3.3, some zanj tests working with that
a31d4ba  mivanit   misc
705e1f6  mivanit   something with layernorm is causing the tensor elements not to match up
34a62fc  mivanit   ???
a6a5b32  mivanit   exact loading of model works!
0181b02  mivanit   ugh not quite, only working if layernorm folding disabled
9e2fe97  mivanit   wip
07aa160  mivanit   zanj save/load tests passing?
e1b28b4  mivanit   fixed some unit tests, test_eval_model still fails >:(
84d3ae8  mivanit   so confused, test only fails when model generated via training?
2019ed4  mivanit   merge with main (and bump muutils to 0.3.6)
570c2b1  mivanit   fixed folding issue
1db5c61  mivanit   Merge branch 'add-notebook-testing' into zanj-integration-2
075ff2b  mivanit   bump muutils to 0.3.7
808e333  mivanit   updated poetry.lock
04b9d09  mivanit   prelim to/from ascii and pixels methods, might need to be moved
9ab36f7  mivanit   run notebook
4548296  mivanit   merge with add-notebook-testing
377724a  mivanit   wip
2406dea  mivanit   wip
70e99f5  mivanit   this was some of the most paintful debugging ive ever done
a8a52af  mivanit   format
8ab6e79  mivanit   bump muutils
6bf592b  mivanit   merge with main
820f0b3  mivanit   fixes?
ecb1872  mivanit   format
b650af9  mivanit   update poetry lock
525c719  mivanit   fixes
93a31aa  mivanit   format
94c675d  mivanit   reworked mazeplot init
e612f09  mivanit   wip
3cf9041  mivanit   add unit length parameter to MazePlot
40f4efd  canrager  misspelled folder??
ea7a66a  mivanit   wip, but unit tests passing!
b09e707  mivanit   wip
e1b774f  mivanit   incomprehensible upstream issue in muutils
e2d3799  mivanit   reworking training script
16b5665  mivanit   wip
c3a9d69  mivanit   test_train_model working!
a8f8934  mivanit   wip
5d8bd00  mivanit   test_eval_model passing
5238158  mivanit   format
56ce56d  mivanit   wip refactor
bb04c45  mivanit   SolvedMaze now inherits from TargetedLatticeMaze
09876b1  mivanit   Really dumb bug tracked down, path would overwrite endpoints in as_pi…
cdb9ea7  mivanit   format
ea20e9a  mivanit   Merge branch 'add-maze-from-ascii' of https://github.com/AISC-underst…
20436ab  mivanit   remove MazePlot.show()
f65abbe  mivanit   aaaaA
134e0ea  mivanit   wip
fe4eae6  mivanit   merge
f248e5a  mivanit   wip
22518df  mivanit   wip filtering
2c0728e  mivanit   more filtering wip
360c940  mivanit   wip filters
1ae7d6e  mivanit   filters working!
1742ee4  mivanit   filteringgit add maze_transformer/ notebooks/!
41223af  mivanit   removed debug printing
52c2042  mivanit   format
92eae14  mivanit   simplified decorator, minor change to notebook
2180d19  mivanit   filtering improvements
f1e304c  mivanit   format
cee6204  mivanit   bump muutils to v0.3.9
f491f32  valedan   Add tests for MazeDataset
e64119e  valedan   Test custom filters
a8fd1e5  valedan   test dataset filters
a7148e9  mivanit   fixed minor bugs in tests from zanj-integration-datasets, needs to be…
da56b52  mivanit   initial version of maze complexity evals
2c13e51  mivanit   fixed bug in cut_percentile_shortest and ran formatting
a510d41  mivanit   merging in from main
990dbb0  mivanit   format, resolved a forgotten merge conflict
2d91858  mivanit   MazePath dissapeared again???
45e75dd  mivanit   format (removed jaxtyping import)
135435a  mivanit   added a TODO of something to implement for constrained dfs kwargs
88002f6  mivanit   dumb bug that probably doesnt matter since we will remove TargetedLat…
e0cd326  mivanit   Revert "dumb bug that probably doesnt matter since we will remove Tar…
15070b6  valedan   Zanj datasets getitem (#182)
88402cd  mivanit   format
e8b7196  mivanit   format
04486da  canrager  Constrained dfs, dataset modifications (#184)
6d942ef  mivanit   Merge branch 'zanj-integration-datasets' of https://github.com/AISC-u…
54ff5a0  mivanit   fixed maze dataset config hash usage, removed print from parallel wor…
c99f652  mivanit   format
e30f3f0  mivanit   fixed notebook test
e58c348  mivanit   bumpy pytest to 7.3.1 to resolve missing 'mocker' fixture
e2f9039  valedan   fix biased baseline
@@ -0,0 +1,44 @@
import torch
from muutils.zanj.torchutil import ConfigMismatchException, assert_model_cfg_equality

from maze_transformer.training.config import ZanjHookedTransformer


def assert_model_output_equality(
    model_a: ZanjHookedTransformer, model_b: ZanjHookedTransformer
):
    try:
        assert_model_cfg_equality(model_a, model_b)
    except ConfigMismatchException as e:
        if e.diff == {
            "model_cfg": {"are_weights_processed": {"self": False, "other": True}}
        } or e.diff == {
            "model_cfg": {
                "are_layernorms_folded": {"self": False, "other": True},
                "are_weights_processed": {"self": False, "other": True},
            }
        }:
            pass
        else:
            raise e

    # Random input tokens
    dataset_cfg = model_a.zanj_model_config.dataset_cfg
    input_sequence = torch.randint(
        low=0,
        high=len(dataset_cfg.token_arr),
        size=(1, min(dataset_cfg.seq_len_max, 10)),
    )

    # (copied from `test_eval_model.py`)
    # Check for equality in argsort (absolute values won't be equal due to centering the unembedding weight matrix)
    assert torch.all(
        model_a(input_sequence.clone()).argsort()
        == model_b(input_sequence.clone()).argsort()
    )
    # apply normalization (e.g. softmax) and check with atol v-small
    # (roughly 1E-7 for float error on logexp I think)
    output_a = torch.nn.functional.softmax(model_a(input_sequence.clone()), dim=-1)
    output_b = torch.nn.functional.softmax(model_b(input_sequence.clone()), dim=-1)

    assert torch.allclose(output_a, output_b, atol=1e-7)
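
For context, a rough sketch of how this helper might be exercised in a ZANJ save/load round-trip test. The `check_round_trip` wrapper and the `ZANJ().save(...)` / `ZANJ().read(...)` calls are illustrative assumptions about the muutils.zanj serializer, not part of this diff; only `assert_model_output_equality` comes from the file above.

# Hypothetical usage sketch (assumed API, not part of this PR's diff):
# round-trip a trained model through ZANJ serialization, then check that the
# reloaded copy behaves identically via the helper defined above.
from pathlib import Path

from muutils.zanj import ZANJ  # assumed import path for the ZANJ serializer

from maze_transformer.training.config import ZanjHookedTransformer


def check_round_trip(trained_model: ZanjHookedTransformer, tmp_path: Path) -> None:
    model_path: Path = tmp_path / "model.zanj"
    zanj: ZANJ = ZANJ()
    zanj.save(trained_model, model_path)  # write config + weights to disk
    loaded_model = zanj.read(model_path)  # reconstruct the model from the archive
    # configs may differ only in the weight-processing flags tolerated above;
    # outputs must agree in argsort and, after softmax, within atol=1e-7
    assert_model_output_equality(trained_model, loaded_model)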
Review comment: I think this should live in tests/helpers