Zanj integration: datasets & training (#177)
Porting datasets and training loop to use ZANJ, as well as many, many other changes to the dataset.

# configs

Modifying configs from the command line is now easier!

- `ConfigHolder.get_config_multisource()` takes:
  - one of: a config object, a file to read the config from, or a list of names to get presets for each of the sub-configs
  - a dotlist-dict to modify any parameters of the config
- `GPTDataset().to_fname()` is used to generate the filename for saving a config (and also to find a matching config to load/download). `MazeDatasetConfig` also implements this in a custom way
- `MazeDatasetConfig` now has a `maze_ctor_kwargs` field, for passing keyword arguments to maze generation (see #183)

# maze dataset

You can now get a `MazeDataset` from just a config -- it will load, download, or generate a dataset on the fly. The mess of ways of storing a dataset we had before is now gone -- a `MazeDataset` contains a list of `SolvedMaze`, and it will return one of those when you call `__getitem__`. We also added filters and fixed some parallelization issues!

- `GPTDataset().from_config()` is a new, simplified way of getting a dataset: simply pass a config, and it will attempt to load from a local directory, download, or generate. Any of these can be disabled, and kwargs (for things like the number of cores to use) are passed down.
- canonical representation of the dataset as a list of `SolvedMaze`
  - `mazes_objs, mazes_tokens, mazes_array` are now cached properties. They will work, but might be slow due to no parallelization
  - `MazeDataset.__getitem__()` now returns a `SolvedMaze`
- `create_dataset()` is deprecated but should still work. remove this?
- filtering! You can specify filters in the config under the `applied_filters` field, or you can call `dataset.filter_by.your_filter_func(your_arg=your_val)`. Both of these work the same under the hood.
- you can specify in `from_config()` whether to run in parallel or not (default is no). This is useful because for small datasets, parallelization has huge overhead. Tests are now much faster.
- there may have been some issues with parallelization using the same fixed seed across all processes. This was fixed in #183, but in a hacky way

# training

Models are now saved as ZANJ objects, and the command line interface is improved.

- `train()` now:
  - saves models as ZANJ
  - returns the trained `ZanjHookedTransformer`
- `train_model()`:
  - now returns `TrainingResult`, which contains output path, model, and eventually logging info perhaps?
  - for config, the interface is inherited from `ConfigHolder.get_config_multisource()` and kwargs are passed as the modification dict

---------

Co-authored-by: mivanit <[email protected]>
Co-authored-by: Dan Valentine <[email protected]>
Co-authored-by: Can Rager <[email protected]>
Co-authored-by: canrager <[email protected]>
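The load, download, or generate fallback described for `from_config()` can be sketched roughly as follows. This is only an illustration under stated assumptions, not the repository's implementation: `ToyConfig`, `dataset_from_config`, the flag names, and the `downloader` callback are all hypothetical, and the "dataset" is just a list of integers stored as JSON.

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional


@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical stand-in for a dataset config (not the real MazeDatasetConfig)."""

    name: str
    n_items: int

    def to_fname(self) -> str:
        # Deterministic filename, so a previously saved dataset can be found again.
        return f"{self.name}-n{self.n_items}"


def dataset_from_config(
    cfg: ToyConfig,
    local_dir: Path,
    load_local: bool = True,
    do_download: bool = True,
    do_generate: bool = True,
    downloader: Optional[Callable[[ToyConfig], Optional[list]]] = None,
) -> tuple:
    """Return (data, source), trying local load, then download, then generation."""
    path = local_dir / f"{cfg.to_fname()}.json"
    if load_local and path.exists():
        return json.loads(path.read_text()), "local"
    if do_download and downloader is not None:
        data = downloader(cfg)
        if data is not None:
            return data, "download"
    if do_generate:
        data = list(range(cfg.n_items))  # placeholder for real maze generation
        local_dir.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(data))  # save so the next call can load locally
        return data, "generate"
    raise ValueError(f"could not obtain a dataset for {cfg.to_fname()}")


# Usage: the first call generates and saves, the second call finds the saved file.
with tempfile.TemporaryDirectory() as tmp:
    cfg = ToyConfig(name="demo", n_items=3)
    first = dataset_from_config(cfg, Path(tmp), do_download=False)
    second = dataset_from_config(cfg, Path(tmp), do_download=False)
```

Keying the saved file on `to_fname()` is what lets a config alone identify a matching dataset; each step can be switched off independently, which mirrors the "any of these can be disabled" note above.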
1 parent bf8605f, commit 06c8181
Showing 44 changed files with 2,358 additions and 878 deletions.
    @@ -0,0 +1,12 @@
    import typing

    from maze_transformer.generation.lattice_maze import SolvedMaze
    from maze_transformer.utils.utils import register_method

    MAZE_COMPLEXITY_EVALS: dict[str, typing.Callable[[SolvedMaze], float]] = dict()


    class MazeComplexityEvals:
        @register_method(MAZE_COMPLEXITY_EVALS)
        def solution_length(maze: SolvedMaze) -> float:
            return len(maze.solution)
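The body of `register_method` is not part of this hunk, so the following is only a guess at the registry-decorator pattern it appears to follow. `register_method_sketch`, `TOY_EVALS`, and `ToyEvals` are hypothetical names, and a plain list of coordinates stands in for `SolvedMaze`.

```python
import typing

# Hypothetical registry, analogous in shape to MAZE_COMPLEXITY_EVALS.
TOY_EVALS: dict[str, typing.Callable[[list], float]] = dict()


def register_method_sketch(registry: dict) -> typing.Callable:
    """Assumed behavior: store the function under its own name, then make it static."""

    def decorator(func: typing.Callable) -> staticmethod:
        registry[func.__name__] = func
        # staticmethod lets the function live in a class body without `self`.
        return staticmethod(func)

    return decorator


class ToyEvals:
    @register_method_sketch(TOY_EVALS)
    def solution_length(path: list) -> float:
        return float(len(path))


# Every registered eval can be applied by name, without touching the class.
toy_path = [(0, 0), (0, 1), (1, 1)]
scores = {name: fn(toy_path) for name, fn in TOY_EVALS.items()}
```

With this pattern, adding a new eval only requires defining another decorated function; any code that iterates over the registry picks it up automatically.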