From 5aff5807c035aff3c21282246855953675cc37ab Mon Sep 17 00:00:00 2001 From: Dan Valentine Date: Thu, 15 Jun 2023 04:59:03 -0400 Subject: [PATCH] training loop evals (#181) added evals to training loop, and a bunch of options in the `TrainingConfig` related to this mega PR, merging without review because I need to run experiments lol. --------- Co-authored-by: mivanit --- examples/coverage/coverage.svg | 4 +- examples/coverage/coverage.txt | 79 +++--- makefile | 4 +- maze_transformer/evaluation/eval_model.py | 109 +++++++- maze_transformer/evaluation/path_evals.py | 38 ++- maze_transformer/test_helpers/stub_logger.py | 7 + maze_transformer/training/config.py | 169 ++++++++++- maze_transformer/training/train_model.py | 39 ++- maze_transformer/training/train_save_files.py | 31 ++ maze_transformer/training/training.py | 140 ++++++---- maze_transformer/training/wandb_logger.py | 28 +- notebooks/train_model.ipynb | 132 +++++---- notebooks/train_model_hallway.ipynb | 264 ++++++++++++++++-- notes.md | 23 ++ poetry.lock | 2 +- pyrightconfig.json | 19 ++ tests/integration/test_eval_model.py | 14 +- tests/integration/test_training.py | 129 +++++++++ .../config/test_train_cfg_intervals.py | 259 +++++++++++++++++ .../training/config/test_train_config.py | 22 +- ...est_training.py => test_get_dataloader.py} | 0 21 files changed, 1284 insertions(+), 228 deletions(-) create mode 100644 maze_transformer/training/train_save_files.py create mode 100644 notes.md create mode 100644 pyrightconfig.json create mode 100644 tests/integration/test_training.py create mode 100644 tests/unit/maze_transformer/training/config/test_train_cfg_intervals.py rename tests/unit/maze_transformer/training/{test_training.py => test_get_dataloader.py} (100%) diff --git a/examples/coverage/coverage.svg b/examples/coverage/coverage.svg index 318685c0..6963b3e1 100644 --- a/examples/coverage/coverage.svg +++ b/examples/coverage/coverage.svg @@ -15,7 +15,7 @@ coverage coverage - 85% - 85% + 87% + 87% diff --git a/examples/coverage/coverage.txt b/examples/coverage/coverage.txt index 143aef73..ad77b545 100644 --- a/examples/coverage/coverage.txt +++ b/examples/coverage/coverage.txt @@ -1,38 +1,41 @@ -Name Stmts Miss Cover Missing ------------------------------------------------------------------------------------------------------- -maze_transformer\__init__.py 0 0 100% -maze_transformer\evaluation\__init__.py 0 0 100% -maze_transformer\evaluation\baseline_models.py 69 7 90% 62-63, 158, 160-163, 169 -maze_transformer\evaluation\eval_model.py 65 23 65% 36-45, 60-95 -maze_transformer\evaluation\maze_complexity_evals.py 8 0 100% -maze_transformer\evaluation\path_evals.py 98 14 86% 19, 52-56, 83, 155, 165-169, 178-181 -maze_transformer\evaluation\plot_attention.py 89 89 0% 2-261 -maze_transformer\test_helpers\assertions.py 22 1 95% 44 -maze_transformer\test_helpers\stub_logger.py 21 7 67% 15-17, 20, 23, 26, 29 -maze_transformer\tokenizer.py 67 11 84% 13, 50-51, 57, 89-90, 96-98, 101-104 -maze_transformer\training\__init__.py 0 0 100% -maze_transformer\training\config.py 121 12 90% 95, 259, 281-285, 327-328, 341, 350, 356-359, 405, 410 -maze_transformer\training\train_model.py 45 3 93% 33, 60-61 -maze_transformer\training\training.py 69 0 100% -maze_transformer\training\wandb_logger.py 40 3 92% 55-57 -setup.py 3 3 0% 3-6 -tests\conftest.py 8 0 100% -tests\integration\test_create_dataset.py 27 0 100% -tests\integration\test_eval_model.py 49 0 100% -tests\integration\test_train_model.py 8 0 100% -tests\unit\maze_transformer\evaluation\test_baseline_models.py 23 0 100% -tests\unit\maze_transformer\evaluation\test_maze_complexity_evals.py 15 0 100% -tests\unit\maze_transformer\evaluation\test_path_evals.py 48 2 96% 30-32 -tests\unit\maze_transformer\test_tokenizers.py 58 0 100% -tests\unit\maze_transformer\training\config\test_base_gpt_config.py 29 0 100% -tests\unit\maze_transformer\training\config\test_cfg_save_load.py 67 0 100% -tests\unit\maze_transformer\training\config\test_config_holder.py 38 5 87% 44-50 -tests\unit\maze_transformer\training\config\test_train_config.py 32 0 100% -tests\unit\maze_transformer\training\test_dataset.py 53 7 87% 44-59 -tests\unit\maze_transformer\training\test_maze_dataset_construction.py 8 0 100% -tests\unit\maze_transformer\training\test_model_loading_old.py 19 0 100% -tests\unit\maze_transformer\training\test_tokenizer.py 13 0 100% -tests\unit\maze_transformer\training\test_training.py 18 0 100% -tests\unit\maze_transformer\training\zanj\test_zanj_ht_save_load.py 37 0 100% ------------------------------------------------------------------------------------------------------- -TOTAL 1267 187 85% +Name Stmts Miss Cover Missing +------------------------------------------------------------------------------------------------------- +maze_transformer\__init__.py 0 0 100% +maze_transformer\evaluation\__init__.py 0 0 100% +maze_transformer\evaluation\baseline_models.py 69 7 90% 62-63, 158, 160-163, 169 +maze_transformer\evaluation\eval_model.py 78 28 64% 46-55, 70-105, 173-186, 251 +maze_transformer\evaluation\maze_complexity_evals.py 8 0 100% +maze_transformer\evaluation\path_evals.py 106 9 92% 19, 64-68, 88, 91-95, 167, 177, 190 +maze_transformer\evaluation\plot_attention.py 89 89 0% 2-261 +maze_transformer\test_helpers\assertions.py 22 1 95% 44 +maze_transformer\test_helpers\stub_logger.py 26 2 92% 23, 26 +maze_transformer\tokenizer.py 67 11 84% 13, 50-51, 57, 89-90, 96-98, 101-104 +maze_transformer\training\__init__.py 0 0 100% +maze_transformer\training\config.py 156 17 89% 80-83, 143, 190, 194, 232, 402, 424-428, 470-471, 484, 493, 499-502, 548, 553 +maze_transformer\training\train_model.py 59 6 90% 34, 61-62, 131-138 +maze_transformer\training\train_save_files.py 16 0 100% +maze_transformer\training\training.py 80 3 96% 61, 86-89 +maze_transformer\training\wandb_logger.py 51 4 92% 56-58, 61 +setup.py 3 3 0% 3-6 +tests\conftest.py 8 0 100% +tests\integration\test_create_dataset.py 27 0 100% +tests\integration\test_eval_model.py 49 0 100% +tests\integration\test_train_model.py 8 0 100% +tests\integration\test_training.py 62 0 100% +tests\unit\maze_transformer\evaluation\test_baseline_models.py 23 0 100% +tests\unit\maze_transformer\evaluation\test_maze_complexity_evals.py 15 0 100% +tests\unit\maze_transformer\evaluation\test_path_evals.py 48 2 96% 30-32 +tests\unit\maze_transformer\test_tokenizers.py 58 0 100% +tests\unit\maze_transformer\training\config\test_base_gpt_config.py 29 0 100% +tests\unit\maze_transformer\training\config\test_cfg_save_load.py 67 0 100% +tests\unit\maze_transformer\training\config\test_config_holder.py 38 5 87% 44-50 +tests\unit\maze_transformer\training\config\test_train_cfg_intervals.py 89 0 100% +tests\unit\maze_transformer\training\config\test_train_config.py 33 0 100% +tests\unit\maze_transformer\training\test_dataset.py 53 7 87% 44-59 +tests\unit\maze_transformer\training\test_get_dataloader.py 18 0 100% +tests\unit\maze_transformer\training\test_maze_dataset_construction.py 8 0 100% +tests\unit\maze_transformer\training\test_model_loading_old.py 19 0 100% +tests\unit\maze_transformer\training\test_tokenizer.py 13 0 100% +tests\unit\maze_transformer\training\zanj\test_zanj_ht_save_load.py 37 0 100% +------------------------------------------------------------------------------------------------------- +TOTAL 1532 194 87% diff --git a/makefile b/makefile index 1ea9f892..6a496571 100644 --- a/makefile +++ b/makefile @@ -39,7 +39,7 @@ unit: .PHONY: integration integration: @echo "run integration tests" - $(POETRY_RUN_PYTHON) -m pytest -s tests/integration + $(POETRY_RUN_PYTHON) -m pytest tests/integration .PHONY: convert_notebooks @@ -62,7 +62,7 @@ test: clean unit integration test_notebooks .PHONY: cov cov: @echo "run tests and generate coverage reports" - $(POETRY_RUN_PYTHON) -m pytest --cov=. -s tests/ + $(POETRY_RUN_PYTHON) -m pytest --cov=. tests/ $(POETRY_RUN_PYTHON) -m coverage report -m > $(COVERAGE_REPORTS_DIR)/coverage.txt $(POETRY_RUN_PYTHON) -m coverage_badge -f -o $(COVERAGE_REPORTS_DIR)/coverage.svg $(POETRY_RUN_PYTHON) -m coverage html diff --git a/maze_transformer/evaluation/eval_model.py b/maze_transformer/evaluation/eval_model.py index ded6fa44..e5bde8be 100644 --- a/maze_transformer/evaluation/eval_model.py +++ b/maze_transformer/evaluation/eval_model.py @@ -4,20 +4,30 @@ import numpy as np import torch -from maze_dataset import SPECIAL_TOKENS, CoordTup, MazeDataset, MazeDatasetConfig +from jaxtyping import Float +from maze_dataset import ( + SPECIAL_TOKENS, + CoordTup, + MazeDataset, + MazeDatasetConfig, + SolvedMaze, +) from maze_dataset.tokenization.token_utils import ( WhenMissing, get_context_tokens, get_path_tokens, + remove_padding_from_token_str, tokens_to_coords, ) from muutils.mlutils import chunks from muutils.statcounter import StatCounter from transformer_lens import HookedTransformer +from transformer_lens import utils as tl_utils from maze_transformer.evaluation.path_evals import PathEvalFunction, PathEvals +from maze_transformer.tokenizer import HuggingMazeTokenizer from maze_transformer.training.config import ConfigHolder -from maze_transformer.training.training import TRAIN_SAVE_FILES +from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES # pylint: disable=protected-access @@ -109,10 +119,10 @@ def predict_maze_paths( # check types assert isinstance( - tokens_batch, list + tokens_batch, (list, tuple) ), f"tokens_batch must be a list, got {type(tokens_batch)}" assert all( - isinstance(tokens, list) for tokens in tokens_batch + isinstance(tokens, (list, tuple)) for tokens in tokens_batch ), f"tokens_batch must be a list of lists, got {[type(tokens) for tokens in tokens_batch] = }" assert all( isinstance(x, str) for tokens in tokens_batch for x in tokens @@ -131,6 +141,7 @@ def predict_maze_paths( max_new_tokens=max_new_tokens, verbose=verbose, temperature=temperature, + # use_past_kv_cache=False, ) assert isinstance( prediction, str @@ -154,27 +165,59 @@ def predict_maze_paths( return paths +def evaluate_path_predictions( + solved_mazes: list[SolvedMaze], + predictions: list[list[tuple[int, int]]], + path_evals: dict[str, PathEvalFunction], +) -> dict[str, StatCounter]: + path_scores: dict[str, StatCounter] = { + name: StatCounter() for name in path_evals.keys() + } + for name, func in path_evals.items(): + path_scores[name].update( + func( + maze=solved_maze.maze, + solution=np.array(solved_maze.solution), + prediction=np.array(prediction), + ) + for solved_maze, prediction in zip(solved_mazes, predictions) + ) + + return path_scores + + def evaluate_model( model: HookedTransformer, dataset: MazeDataset, + dataset_tokens: list[list[str]] | None = None, eval_functions: dict[str, PathEvalFunction] | None = None, max_new_tokens: int = 8, batch_size: int = 64, verbose: bool = False, ) -> dict[str, StatCounter]: - """Run a set of eval functions on a model for a given dataset. Returns a seperate StatCounter for each eval function.""" + """Run a set of eval functions on a model for a given dataset. Returns a seperate StatCounter for each eval function. + + if dataset_tokens is provided, we assume that the dataset has already been tokenized and we skip tokenization. MAKE SURE THERE IS NOT A MISMATCH BETWEEN THE DATASET AND DATASET_TOKENS + """ + if not eval_functions: - eval_functions = PathEvals.evals + # TODO: potentially model evals which aren't path evals? + eval_functions = PathEvals.EVALS score_counters: dict[str, StatCounter] = { - name: StatCounter() for name in eval_functions.keys() + name: StatCounter() for name in eval_functions } - for maze_batch in chunks(dataset, batch_size): - tokens_batch = [ - maze.as_tokens(dataset.cfg.node_token_map) for maze in maze_batch - ] - predictions = predict_maze_paths( + if dataset_tokens is None: + dataset_tokens = dataset.as_tokens(join_tokens_individual_maze=False) + else: + assert len(dataset) == len( + dataset_tokens + ), f"dataset and dataset_tokens must be the same length and must be from corresponding mazes, got {len(dataset) = } and {len(dataset_tokens) = }" + + for batch in chunks(zip(dataset, dataset_tokens), batch_size): + maze_batch, tokens_batch = zip(*batch) + predictions: list[str | list[tuple[int, int]]] = predict_maze_paths( tokens_batch=tokens_batch, data_cfg=dataset.cfg, model=model, @@ -194,3 +237,45 @@ def evaluate_model( ) return score_counters + + +def evaluate_logits( + logits: Float[torch.Tensor, "batch pos d_vocab"], + batch: list[int], + config: ConfigHolder, + tokenizer: HuggingMazeTokenizer, + path_evals: dict[str, PathEvalFunction] | None = None, +) -> dict[str, StatCounter]: + """Runs a set of eval functions on the provided logits. For path evals, an attempt will be made to extract a predicted path from the logits (it is assumed that the logits are an entire sequence output from training, so they contain the adj_list plus path)""" + + raise NotImplementedError( + "evaluate_logits does not function correctly, and at the moment there are only path evals anyway" + ) + + scores: dict[str, StatCounter] = {} + + if path_evals: + # TODO: this is pretty much wrong -- sampling from the logits over the sequence should not produce a valid path + sampled_logits = tl_utils.sample_logits(logits) + prediction_tokens = tokenizer.batch_decode(sampled_logits) + predicted_paths = [] + for tokens in prediction_tokens: + # this returns first path_start to end of list. Early in training there may be multiple path_start tokens, so results should be treated with caution + path_tokens = get_path_tokens(tokens.split(" ")) + path_coords = tokens_to_coords( + path_tokens, maze_data_cfg=config.dataset_cfg, when_noncoord="skip" + ) + predicted_paths.append(cast(list[tuple[int, int]], path_coords)) + + maze_tokens = [ + remove_padding_from_token_str(token_str) + for token_str in tokenizer.batch_decode(batch) + ] + + solved_mazes = [ + SolvedMaze.from_tokens(tokens.split(" "), config.dataset_cfg) + for tokens in maze_tokens + ] + scores |= evaluate_path_predictions(solved_mazes, predicted_paths, path_evals) + + return scores diff --git a/maze_transformer/evaluation/path_evals.py b/maze_transformer/evaluation/path_evals.py index 59cf2fff..2b732ff5 100644 --- a/maze_transformer/evaluation/path_evals.py +++ b/maze_transformer/evaluation/path_evals.py @@ -38,9 +38,21 @@ def is_adjacent(node1: Coord, node2: Coord) -> bool: class PathEvals: """array path based eval functions""" - evals: dict[str, PathEvalFunction] = {} + # We split evals into fast and slow. Fast ones can be used more frequently during training + fast: dict[str, PathEvalFunction] = {} + slow: dict[str, PathEvalFunction] = {} - @register_method(evals) + PATH_EVALS_MAP: dict[str, dict[str, PathEvalFunction]] = { + "eval_fast": fast, + "eval_slow": slow, + } + + @classmethod + @property + def EVALS(cls): + return {**cls.fast, **cls.slow} + + @register_method(fast) @staticmethod def node_overlap(solution: CoordArray, prediction: CoordArray, **_) -> float: """number of shared nodes (any order) / total number of (unique) nodes in solution""" @@ -57,7 +69,7 @@ def node_overlap(solution: CoordArray, prediction: CoordArray, **_) -> float: return len(prediction_set & solution_set) / len(solution_set) - @register_method(evals) + @register_method(fast) @staticmethod def num_connections_adjacent_lattice(prediction: CoordArray, **_) -> float: """number of the connections in prediction which actually connect nodes that are adjacent on the lattice, ignoring if they are adjacent on the maze""" @@ -68,10 +80,12 @@ def num_connections_adjacent_lattice(prediction: CoordArray, **_) -> float: return n_adj - @register_method(evals) + @register_method(fast) @staticmethod def fraction_connections_adjacent_lattice(prediction: CoordArray, **_) -> float: """fraction of the connections in prediction which actually connect nodes that are adjacent on the lattice, ignoring if they are adjacent on the maze""" + if len(prediction) == 0: + return 0 if len(prediction) <= 1: warnings.warn( @@ -82,7 +96,7 @@ def fraction_connections_adjacent_lattice(prediction: CoordArray, **_) -> float: return PathEvals.num_connections_adjacent_lattice(prediction) / len(prediction) - @register_method(evals) + @register_method(fast) @staticmethod def num_connections_adjacent(maze: LatticeMaze, prediction: MazePath, **_) -> float: """number of connections in prediction which are valid paths on the maze""" @@ -94,7 +108,7 @@ def num_connections_adjacent(maze: LatticeMaze, prediction: MazePath, **_) -> fl return n_connected - @register_method(evals) + @register_method(fast) @staticmethod def fraction_connections_adjacent( maze: LatticeMaze, prediction: CoordArray, **_ @@ -106,7 +120,7 @@ def fraction_connections_adjacent( num_connections, 1.0 ) - @register_method(evals) + @register_method(fast) @staticmethod def exact_path_predicted( solution: CoordArray, prediction: CoordArray, **_ @@ -114,12 +128,12 @@ def exact_path_predicted( """Was the maze successfully solved?""" return float(np.array_equal(solution, prediction)) - @register_method(evals) + @register_method(fast) @staticmethod def solution_length(solution: CoordArray, **_) -> float: return float(len(solution)) - @register_method(evals) + @register_method(fast) @staticmethod def streak_length_until_incorrect( solution: CoordArray, @@ -143,7 +157,7 @@ def streak_length_until_incorrect( return streak_length - @register_method(evals) + @register_method(fast) @staticmethod def distance_between_end_nodes( solution: MazePath, prediction: MazePath, **_ @@ -154,7 +168,7 @@ def distance_between_end_nodes( return np.linalg.norm(solution[-1] - prediction[-1]) - @register_method(evals) + @register_method(fast) @staticmethod def corner_jumps(prediction: MazePath, **_) -> float: """Looks for corner jumps in the predicted path. A corner jump is if the transformer predicts predicts @@ -168,7 +182,7 @@ def corner_jumps(prediction: MazePath, **_) -> float: normed_distances = np.linalg.norm(distance_between_nodes, axis=1) return np.count_nonzero(normed_distances == np.sqrt(2)) - @register_method(evals) + @register_method(fast) @staticmethod def average_predicted_step_size(prediction: MazePath, **_) -> float: """Returns average step size in the predicted path.""" diff --git a/maze_transformer/test_helpers/stub_logger.py b/maze_transformer/test_helpers/stub_logger.py index 1da8f4e4..9bfcb1df 100644 --- a/maze_transformer/test_helpers/stub_logger.py +++ b/maze_transformer/test_helpers/stub_logger.py @@ -25,8 +25,15 @@ def upload_dataset(self, *args, **kwargs) -> None: def log_metric(self, *args, **kwargs) -> None: self._log("Metric logged.", args, kwargs) + def log_metric_hist(self, *args, **kwargs) -> None: + self._log("Metric (Statcounter) logged.", args, kwargs) + def summary(self, *args, **kwargs) -> None: self._log("Summary logged.", args, kwargs) def progress(self, message: str) -> None: self._log(f"[INFO] - {message}") + + @property + def url(self) -> str: + return "stub logger, not a url" diff --git a/maze_transformer/training/config.py b/maze_transformer/training/config.py index 9ed3dae6..f5c0b962 100644 --- a/maze_transformer/training/config.py +++ b/maze_transformer/training/config.py @@ -65,6 +65,23 @@ def summary(self) -> dict: # ================================================== +_DEFAULT_INTERVAL_COUNTS: typing.Callable[[], dict[str, int]] = lambda: dict( + print_loss=100, + checkpoint=10, + eval_fast=20, + eval_slow=10, +) + + +def _intervals_loading_fn(data: dict) -> dict[str, int]: + if "intervals" in data: + return data["intervals"] + else: + warnings.warn( + "`intervals` not found in config (probably trying to load a legacy config), using None!" + ) + return None + def _optimizer_save_fn(optim: Type[torch.optim.Optimizer]) -> str: """convert torch optimizer to string, while checking that the conversion is reversible""" @@ -76,9 +93,40 @@ def _optimizer_save_fn(optim: Type[torch.optim.Optimizer]) -> str: @serializable_dataclass(kw_only=True) class TrainConfig(SerializableDataclass): - """full training configuration""" + """full training configuration + + # Usage: + - get the optimizer via calling `train_cfg.get_optimizer(model.parameters())` + - get the intervals via `train_cfg.get_intervals()` + + # Parameters + + - `name: str`: name of the training configuration + - `optimizer: Type[torch.optim.Optimizer]`: optimizer class to use + - `optimizer_kwargs: dict[str, Any]`: kwargs to pass to the optimizer + - `batch_size: int`: batch size + - `dataloader_cfg: dict`: kwargs to pass to the dataloader + - `intervals: dict[str, int]`: intervals at which to perform certain actions: + "print_loss", "checkpoint", "eval_fast", "eval_slow" + - `intervals_count: dict[str, int]`: how many of each action to do over the course of the training run + - `evals_max_new_tokens: int`: how many new tokens to generate during evaluation + - `validation_dataset_cfg: None|int|GPTDatasetConfig`: validation dataset + - if `None`, evals are disabled + - if `int`, a dataset of that size is created by sampling from the training dataset using `torch.utils.data.random_split` + - if `GPTDatasetConfig`, a dataset is created from the specified config TODO: this is not implemented yet + + """ name: str + # TODO: loaders specified here only because of legacy models, remove this after some time and models are updated + evals_max_new_tokens: int = serializable_field( + default=8, + loading_fn=lambda data: data.get("evals_max_new_tokens", 8), + ) + validation_dataset_cfg: None | int | GPTDatasetConfig = serializable_field( + default=None, + loading_fn=lambda data: data.get("validation_dataset_cfg", None), + ) optimizer: Type[torch.optim.Optimizer] = serializable_field( # type: ignore default_factory=lambda: torch.optim.RMSprop, @@ -107,8 +155,92 @@ def get_optimizer(self, params) -> Type[torch.optim.Optimizer]: ) ) - print_loss_interval: int = serializable_field(default=1000) - checkpoint_interval: int = serializable_field(default=50000) + intervals: dict[str, int] | None = serializable_field( + default=None, + loading_fn=_intervals_loading_fn, + ) + + intervals_count: dict[str, int] | None = serializable_field( + default=None, + loading_fn=lambda data: data.get("intervals_count", None), + ) + + def get_intervals( + self, + dataset_n_samples: int | None = None, + use_defaults_if_missing: bool = True, + mod_batch_size: bool = True, + ) -> dict[str, int | float]: + """get the intervals""" + + # handle the case where both are missing + if (self.intervals is None) and (self.intervals_count is None): + if use_defaults_if_missing: + self.intervals_count = _DEFAULT_INTERVAL_COUNTS() + else: + raise ValueError( + "both `intervals` and `intervals_count` are missing, and `use_defaults_if_missing` is False. Don't know what intervals to use!" + ) + + # checks + intervals_new: dict[str, int | float] + try: + match (self.intervals is not None, self.intervals_count is not None): + case (False, False): + raise ValueError( + "both `intervals` and `intervals_count` are None! this state should be inaccessible" + ) + case (True, True): + raise ValueError( + "both `intervals` and `intervals_count` are specified, this is not allowed!" + ) + case (True, False): + intervals_new = self.intervals + case (False, True): + if isinstance(dataset_n_samples, int): + intervals_new = { + k: ( + int(dataset_n_samples / v) + if v > 0 + else float("inf") + # setting a count to < 0 means "dont do it" + ) + for k, v in self.intervals_count.items() + } + else: + raise ValueError( + f"{dataset_n_samples = }, but we need an integer to compute the intervals" + ) + + except ValueError as e: + _debug_vals: str = f"{dataset_n_samples=}, {use_defaults_if_missing=}, {mod_batch_size=},\n{self.intervals=},\n{self.intervals_count=}" + raise ValueError(f"{_debug_vals}\ntriggered error:\n{e}") from e + + # disable if set to 0 or negative + intervals_new = { + k: ( + v + if v > 0 + else float("inf") # mod by infinity is always the number itself + ) + for k, v in intervals_new.items() + } + + # check all expected keys are present + for k in _DEFAULT_INTERVAL_COUNTS().keys(): + if k not in intervals_new: + raise ValueError(f"missing key {k} in {intervals_new = }") + + # actually return the intervals + if mod_batch_size: + return { + k: max(1, v // self.batch_size) + if isinstance(v, int) + else v # if float, leave it as is since its float("inf") + for k, v in intervals_new.items() + } + else: + return intervals_new def summary(self) -> dict: """return a human-readable summary of the config""" @@ -118,8 +250,17 @@ def summary(self) -> dict: optimizer_kwargs=self.optimizer_kwargs, batch_size=self.batch_size, dataloader_cfg=self.dataloader_cfg, - print_loss_interval=self.print_loss_interval, - checkpoint_interval=self.checkpoint_interval, + intervals=self.intervals, + intervals_count=self.intervals_count, + evals_max_new_tokens=self.evals_max_new_tokens, + validation_dataset_cfg=( + self.validation_dataset_cfg + if ( + isinstance(self.validation_dataset_cfg, int) + or self.validation_dataset_cfg is None + ) + else self.validation_dataset_cfg.summary() + ), ) @@ -171,8 +312,13 @@ def summary(self) -> dict: num_workers=0, drop_last=False, ), - print_loss_interval=10, - checkpoint_interval=100, + intervals_count=dict( + print_loss=100, + checkpoint=2, + eval_fast=4, + eval_slow=2, + ), + validation_dataset_cfg=10, ), TrainConfig( name="tiny-v1", @@ -185,8 +331,7 @@ def summary(self) -> dict: persistent_workers=True, drop_last=True, ), - print_loss_interval=1000, - checkpoint_interval=5000, + validation_dataset_cfg=10, ), TrainConfig( name="gpt2-small", @@ -199,8 +344,7 @@ def summary(self) -> dict: persistent_workers=True, drop_last=True, ), - print_loss_interval=50, - checkpoint_interval=10000, + validation_dataset_cfg=10, ), TrainConfig( name="sweep-v1", @@ -213,8 +357,7 @@ def summary(self) -> dict: persistent_workers=True, drop_last=True, ), - print_loss_interval=1000, - checkpoint_interval=5000, + validation_dataset_cfg=50, ), ] diff --git a/maze_transformer/training/train_model.py b/maze_transformer/training/train_model.py index 35516063..fe920e45 100644 --- a/maze_transformer/training/train_model.py +++ b/maze_transformer/training/train_model.py @@ -4,11 +4,11 @@ from typing import Union import torch -from maze_dataset import MazeDataset +from maze_dataset import MazeDataset, MazeDatasetConfig from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS from muutils.json_serialize import SerializableDataclass, serializable_dataclass from muutils.mlutils import get_device -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, random_split from maze_transformer.training.config import ( GPT_CONFIGS, @@ -16,7 +16,8 @@ ConfigHolder, ZanjHookedTransformer, ) -from maze_transformer.training.training import TRAIN_SAVE_FILES, get_dataloader, train +from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES +from maze_transformer.training.training import get_dataloader, train from maze_transformer.training.wandb_logger import ( WandbJobType, WandbLogger, @@ -40,7 +41,7 @@ def train_model( cfg_file: str | Path | None = None, cfg_names: typing.Sequence[str] | None = None, do_generate_dataset: bool = False, - dataset_verbose: bool = True, + dataset_verbose: bool = False, device: torch.device | None = None, help: bool = False, **kwargs, @@ -109,7 +110,34 @@ def train_model( local_base_path=base_path, verbose=dataset_verbose, ) - logger.progress("finished getting dataset") + logger.progress(f"finished getting training dataset with {len(dataset)} samples") + # validation dataset, if applicable + val_dataset: MazeDataset | None = None + if cfg.train_cfg.validation_dataset_cfg is not None: + if isinstance(cfg.train_cfg.validation_dataset_cfg, int): + # split the training dataset + split_dataset_sizes: tuple[int, int] = [ + len(dataset) - cfg.train_cfg.validation_dataset_cfg, + cfg.train_cfg.validation_dataset_cfg, + ] + sub_dataset, sub_val_dataset = random_split(dataset, split_dataset_sizes) + dataset = sub_dataset.dataset + val_dataset = sub_val_dataset.dataset + dataset.update_self_config() + val_dataset.update_self_config() + logger.progress( + f"got validation dataset by splitting training dataset into {len(dataset)} train and {len(val_dataset)} validation samples" + ) + elif isinstance(cfg.train_cfg.validation_dataset_cfg, MazeDatasetConfig): + val_dataset = MazeDataset.from_config( + cfg=cfg.train_cfg.validation_dataset_cfg, + do_generate=do_generate_dataset, + local_base_path=base_path, + verbose=dataset_verbose, + ) + logger.progress( + f"got custom validation dataset with {len(val_dataset)} samples" + ) # get dataloader and then train dataloader: DataLoader = get_dataloader(dataset, cfg, logger) @@ -121,6 +149,7 @@ def train_model( logger=logger, output_dir=output_path, device=device, + val_dataset=val_dataset, ) return TrainingResult( diff --git a/maze_transformer/training/train_save_files.py b/maze_transformer/training/train_save_files.py new file mode 100644 index 00000000..8cc833a7 --- /dev/null +++ b/maze_transformer/training/train_save_files.py @@ -0,0 +1,31 @@ +from datetime import datetime +from typing import Callable + +from muutils.misc import freeze, sanitize_fname # type: ignore[import] + +from maze_transformer.training.config import ConfigHolder + + +@freeze +class TRAIN_SAVE_FILES: + """namespace for filenames/formats for saving training data""" + + # old + data_cfg: str = "data_config.json" + train_cfg: str = "train_config.json" + model_checkpt: Callable[[int], str] = lambda iteration: f"model.iter_{iteration}.pt" + model_final: str = "model.final.pt" + + # keep these + config_holder: str = "config.json" + checkpoints: str = "checkpoints" + log: str = "log.jsonl" + model_checkpt_zanj: Callable[ + [int], str + ] = lambda iteration: f"model.iter_{iteration}.zanj" + model_final_zanj: str = "model.final.zanj" + model_run_dir: Callable[ + [ConfigHolder], str + ] = ( + lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" + ) diff --git a/maze_transformer/training/training.py b/maze_transformer/training/training.py index 272e1008..c2eb0bbb 100644 --- a/maze_transformer/training/training.py +++ b/maze_transformer/training/training.py @@ -1,45 +1,23 @@ -from datetime import datetime +import warnings from functools import partial from pathlib import Path -from typing import Callable import torch from jaxtyping import Float from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze -from muutils.misc import freeze, sanitize_fname # type: ignore[import] +from muutils.statcounter import StatCounter from torch.utils.data import DataLoader from transformer_lens.HookedTransformer import SingleLoss from zanj import ZANJ +from maze_transformer.evaluation.eval_model import evaluate_model +from maze_transformer.evaluation.path_evals import PathEvals +from maze_transformer.tokenizer import HuggingMazeTokenizer from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer +from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES from maze_transformer.training.wandb_logger import WandbLogger -@freeze -class TRAIN_SAVE_FILES: - """namespace for filenames/formats for saving training data""" - - # old - data_cfg: str = "data_config.json" - train_cfg: str = "train_config.json" - model_checkpt: Callable[[int], str] = lambda iteration: f"model.iter_{iteration}.pt" - model_final: str = "model.final.pt" - - # keep these - config_holder: str = "config.json" - checkpoints: str = "checkpoints" - log: str = "log.jsonl" - model_checkpt_zanj: Callable[ - [int], str - ] = lambda iteration: f"model.iter_{iteration}.zanj" - model_final_zanj: str = "model.final.zanj" - model_run_dir: Callable[ - [ConfigHolder], str - ] = ( - lambda cfg: f"{sanitize_fname(cfg.name)}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}" - ) - - def collate_batch(batch: list[SolvedMaze], config: MazeDatasetConfig) -> list[str]: return [" ".join(maze.as_tokens(config.node_token_map)) for maze in batch] @@ -65,12 +43,23 @@ def train( logger: WandbLogger, output_dir: Path, device: torch.device, + val_dataset: MazeDataset | None = None, zanj: ZANJ | None = None, + model: ZanjHookedTransformer | None = None, ) -> ZanjHookedTransformer: + # initialize + # ============================== if zanj is None: zanj = ZANJ() - logger.progress("Initializing model") - model: ZanjHookedTransformer = cfg.create_model_zanj() + + # init model & optimizer + if model is None: + logger.progress(f"Initializing model") + model: ZanjHookedTransformer = cfg.create_model_zanj() + model.to(device) + else: + logger.progress("Using existing model") + logger.summary({"device": str(device), "model.device": model.cfg.device}) logger.progress("Initializing optimizer") @@ -80,61 +69,116 @@ def train( ) logger.summary(dict(model_n_params=model.cfg.n_params)) - model.train() - logger.progress("Starting training") - n_batches: int = len(dataloader) - logger.summary({"n_batches": n_batches}) + # add wandb run url to model + model.training_records = { + "wandb_url": logger.url, + } + + # figure out whether to run evals, and validation dataset + evals_enabled: bool = cfg.train_cfg.validation_dataset_cfg is not None + if evals_enabled: + assert ( + val_dataset is not None + ), "val_dataset must be provided if evals are enabled" + + # Only the HuggingMazeTokenizer has token decoding implemented, which is required for evals + if not type(model.tokenizer) == HuggingMazeTokenizer: + warnings.warn( + "Using a tokenizer that cannot decode. Disabling evals for this run even though TrainConfig says to enable them" + ) + evals_enabled = False - checkpoint_interval_iters: int = max( - 1, - int(cfg.train_cfg.checkpoint_interval // cfg.train_cfg.batch_size), + val_dataset_tokens: list[list[str]] = val_dataset.as_tokens( + join_tokens_individual_maze=False + ) + + # compute intervals + n_samples: int = len(dataloader.dataset) + n_batches: int = len(dataloader) + intervals: dict[str, int] = cfg.train_cfg.get_intervals( + dataset_n_samples=n_samples, + mod_batch_size=True, ) - loss_interval_iters: int = max( - 1, int(cfg.train_cfg.print_loss_interval // cfg.train_cfg.batch_size) + if not evals_enabled: + intervals = { + key: value if not key.startswith("eval") else float("inf") + for key, value in intervals.items() + } + logger.summary( + {"n_batches": n_batches, "n_samples": n_samples, "intervals": intervals} ) logger.progress( - f"will train for {n_batches} batches, {checkpoint_interval_iters = }, {loss_interval_iters = }" + f"will train for {n_batches} batches, {evals_enabled=}, with intervals: {intervals}" ) + + # start up training + # ============================== + model.train() + logger.progress("Starting training") + for iteration, batch in enumerate(dataloader): + # forward pass + # ------------------------------ loss: SingleLoss logits: Float[torch.Tensor, "batch pos d_vocab"] logits, loss = model(batch, return_type="both") + + # backward pass + # ------------------------------ # Remove the last logit because it's the prediction for what comes after PATH_END (and so is meaningless) # Do this after computing loss because the loss_fn already ignores the last logit logits = logits[:, :-1, :] loss.backward() - optimizer.step() optimizer.zero_grad() - logger.log_metric({"loss": loss}) - - if iteration % loss_interval_iters == 0: + # log metrics + # ------------------------------ + metrics: dict[str, int | float | StatCounter] = {"loss": float(loss)} + + if evals_enabled: + for interval_key, evals_dict in PathEvals.PATH_EVALS_MAP.items(): + if iteration % intervals[interval_key] == 0: + logger.progress(f"Running evals: {interval_key}") + scores: dict[str, StatCounter] = evaluate_model( + model=model, + dataset=val_dataset, + dataset_tokens=val_dataset_tokens, + eval_functions=evals_dict, + batch_size=cfg.train_cfg.batch_size, + max_new_tokens=cfg.train_cfg.evals_max_new_tokens, + ) + metrics.update(scores) + logger.log_metric_hist(metrics) + + if iteration % intervals["print_loss"] == 0: logger.progress( f"iteration {iteration}/{n_batches}: loss={loss.item():.3f}" ) del loss - if iteration % checkpoint_interval_iters == 0: + # checkpoints + # ------------------------------ + if iteration % intervals["checkpoint"] == 0: model_save_path: Path = ( output_dir / TRAIN_SAVE_FILES.checkpoints / TRAIN_SAVE_FILES.model_checkpt_zanj(iteration) ) - logger.progress(f"Saving model to {model_save_path.as_posix()}") + logger.progress(f"Saving model checkpoint to {model_save_path.as_posix()}") zanj.save(model, model_save_path) logger.upload_model( model_save_path, aliases=["latest", f"iter-{iteration}"] ) # save the final model - # ================================================== + # ============================== final_model_path: Path = output_dir / TRAIN_SAVE_FILES.model_final_zanj logger.progress(f"Saving final model to {final_model_path.as_posix()}") zanj.save(model, final_model_path) logger.upload_model(final_model_path, aliases=["latest", "final"]) - logger.progress("Done!") + logger.progress("Done training!") return model diff --git a/maze_transformer/training/wandb_logger.py b/maze_transformer/training/wandb_logger.py index 6e8d6f3d..7442d6a7 100644 --- a/maze_transformer/training/wandb_logger.py +++ b/maze_transformer/training/wandb_logger.py @@ -7,7 +7,8 @@ from typing import Any, Dict, Union import wandb -from wandb.sdk.wandb_run import Run +from muutils.statcounter import StatCounter +from wandb.sdk.wandb_run import Artifact, Run class WandbProject(Enum): @@ -36,32 +37,49 @@ def create( datefmt="%Y-%m-%d %H:%M:%S", ) - run = wandb.init( + run: Run = wandb.init( config=config, project=(project.value if isinstance(project, WandbProject) else project), job_type=job_type.value, ) - logger = WandbLogger(run) + logger: WandbLogger = WandbLogger(run) logger.progress(f"{config =}") return logger def upload_model(self, model_path: Path, aliases=None) -> None: - artifact = wandb.Artifact(name=wandb.run.id, type="model") + artifact: Artifact = wandb.Artifact(name=wandb.run.id, type="model") artifact.add_file(str(model_path)) self._run.log_artifact(artifact, aliases=aliases) def upload_dataset(self, name: str, path: Path) -> None: - artifact = wandb.Artifact(name=name, type="dataset") + artifact: Artifact = wandb.Artifact(name=name, type="dataset") artifact.add_dir(local_path=str(path)) self._run.log_artifact(artifact) def log_metric(self, data: Dict[str, Any]) -> None: self._run.log(data) + def log_metric_hist(self, data: dict[str, float | int | StatCounter]) -> None: + # TODO: store the statcounters themselves somehow + data_processed: dict[str, int | float] = dict() + for key, value in data.items(): + if isinstance(value, StatCounter): + # we use the mean, since then smoothing a whole bunch of evals gives us an idea of the distribution + # data_processed[key + "-median"] = value.median() + data_processed[key + "-mean"] = value.mean() + # data_processed[key + "-std"] = value.std() + else: + data_processed[key] = value + self._run.log(data_processed) + def summary(self, data: Dict[str, Any]) -> None: self._run.summary.update(data) + @property + def url(self) -> str: + return self._run.get_url() + @staticmethod def progress(message: str) -> None: logging.info(message) diff --git a/notebooks/train_model.ipynb b/notebooks/train_model.ipynb index f67447b3..e2d28a67 100644 --- a/notebooks/train_model.ipynb +++ b/notebooks/train_model.ipynb @@ -102,8 +102,12 @@ " num_workers=0,\n", " drop_last=False,\n", " ),\n", - " print_loss_interval=100,\n", - " checkpoint_interval=1000,\n", + " intervals_count=dict(\n", + " print_loss=100,\n", + " checkpoint=5,\n", + " eval_fast=10,\n", + " eval_slow=5,\n", + " )\n", " ),\n", ")" ] @@ -189,8 +193,15 @@ " \"num_workers\": 0,\n", " \"drop_last\": false\n", " },\n", - " \"print_loss_interval\": 10,\n", - " \"checkpoint_interval\": 100\n", + " \"intervals\": null,\n", + " \"intervals_count\": {\n", + " \"print_loss\": 100,\n", + " \"checkpoint\": 2,\n", + " \"eval_fast\": 4,\n", + " \"eval_slow\": 2\n", + " },\n", + " \"evals_max_new_tokens\": 8,\n", + " \"validation_dataset_cfg\": 10\n", " },\n", " \"pretrainedtokenizer_kwargs\": null\n", "}\n" @@ -210,36 +221,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "seeing if we can download the dataset...\n", - "no download found, or download failed\n", - "generating dataset...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "generating & solving mazes: 100%|██████████| 100/100 [00:00<00:00, 965.73maze/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "saving dataset to ..\\data\\demo_small-g3-n100-a_dfs-h58410.zanj\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "loading dataset from ../data/demo_small-g3-n100-a_dfs-h58410.zanj\n", "Got dataset demo_small with 100 items. output.cfg.to_fname() = 'demo_small-g3-n100-a_dfs-h58410'\n" ] } @@ -262,7 +244,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-06-14 16:11:03 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" + "2023-06-15 02:27:36 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" ] }, { @@ -300,7 +282,7 @@ { "data": { "text/html": [ - "Run data is saved locally in f:\\KNC\\maze-transformer\\notebooks\\wandb\\run-20230614_161106-168sq19w" + "Run data is saved locally in f:\\KNC\\maze-transformer\\notebooks\\wandb\\run-20230615_022738-nnxf6kk4" ], "text/plain": [ "" @@ -312,7 +294,7 @@ { "data": { "text/html": [ - "Syncing run neat-sky-150 to Weights & Biases (docs)
" + "Syncing run pleasant-wildflower-196 to Weights & Biases (docs)
" ], "text/plain": [ "" @@ -336,7 +318,7 @@ { "data": { "text/html": [ - " View run at https://wandb.ai/miv/integration-tests/runs/168sq19w" + " View run at https://wandb.ai/miv/integration-tests/runs/nnxf6kk4" ], "text/plain": [ "" @@ -349,30 +331,60 @@ "name": "stdout", "output_type": "stream", "text": [ - "2023-06-14 16:11:07 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'demo_small', 'seq_len_min': 1, 'seq_len_max': 512, 'seed': 42, 'applied_filters': [], 'grid_n': 3, 'n_mazes': 100, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', ' # Arguments', ' - `grid_shape: Coord`: the shape of the grid', ' - `lattice_dim: int`: the dimension of the lattice', ' (default: `2`)', ' - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', ' (default: `None`)', ' - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', ' (default: `None`)', ' - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.', ' - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.', '', '', ' # algorithm', ' 1. Choose the initial cell, mark it as visited and push it to the stack', ' 2. While the stack is not empty', ' 1. Pop a cell from the stack and make it a current cell', ' 2. If the current cell has any neighbours which have not been visited', ' 1. Push the current cell to the stack', ' 2. Choose one of the unvisited neighbours', ' 3. Remove the wall between the current cell and the chosen cell', ' 4. Mark the chosen cell as visited and push it to the stack', ' '], 'source_code': [' @staticmethod', ' def gen_dfs(', ' grid_shape: Coord,', ' lattice_dim: int = 2,', ' n_accessible_cells: int | None = None,', ' max_tree_depth: int | None = None,', ' do_forks: bool = True,', ' start_coord: Coord | None = None,', ' ) -> LatticeMaze:', ' \"\"\"generate a lattice maze using depth first search, iterative', '', ' # Arguments', ' - `grid_shape: Coord`: the shape of the grid', ' - `lattice_dim: int`: the dimension of the lattice', ' (default: `2`)', ' - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', ' (default: `None`)', ' - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', ' (default: `None`)', ' - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.', ' - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.', '', '', ' # algorithm', ' 1. Choose the initial cell, mark it as visited and push it to the stack', ' 2. While the stack is not empty', ' 1. Pop a cell from the stack and make it a current cell', ' 2. If the current cell has any neighbours which have not been visited', ' 1. Push the current cell to the stack', ' 2. Choose one of the unvisited neighbours', ' 3. Remove the wall between the current cell and the chosen cell', ' 4. Mark the chosen cell as visited and push it to the stack', ' \"\"\"', '', ' # Default values if no constraints have been passed', ' grid_shape: Coord = np.array(grid_shape)', ' n_total_cells: int = int(np.prod(grid_shape))', ' if n_accessible_cells is None:', ' n_accessible_cells = n_total_cells', ' if max_tree_depth is None:', ' max_tree_depth = (', ' 2 * n_total_cells', ' ) # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.', '', ' start_coord = _random_start_coord(grid_shape, start_coord)', '', ' # initialize the maze with no connections', ' connection_list: ConnectionList = np.zeros(', ' (lattice_dim, grid_shape[0], grid_shape[1]), dtype=np.bool_', ' )', '', ' # initialize the stack with the target coord', ' visited_cells: set[tuple[int, int]] = set()', ' visited_cells.add(tuple(start_coord))', ' stack: list[Coord] = [start_coord]', '', ' # initialize tree_depth_counter', ' current_tree_depth: int = 1', '', ' # loop until the stack is empty or n_connected_cells is reached', ' while stack and (len(visited_cells) < n_accessible_cells):', ' # get the current coord from the stack', ' current_coord: Coord = stack.pop()', '', ' # filter neighbors by being within grid bounds and being unvisited', ' unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [', ' (neighbor, delta)', ' for neighbor, delta in zip(', ' current_coord + NEIGHBORS_MASK, NEIGHBORS_MASK', ' )', ' if (', ' (tuple(neighbor) not in visited_cells)', ' and (0 <= neighbor[0] < grid_shape[0])', ' and (0 <= neighbor[1] < grid_shape[1])', ' )', ' ]', '', \" # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)\", ' if unvisited_neighbors_deltas and (', ' current_tree_depth <= max_tree_depth / 2', ' ):', \" # if we want a maze without forks, simply don't add the current coord back to the stack\", ' if do_forks:', ' stack.append(current_coord)', '', ' # choose one of the unvisited neighbors', ' chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)', '', ' # add connection', ' dim: int = np.argmax(np.abs(delta))', ' # if positive, down/right from current coord', ' # if negative, up/left from current coord (down/right from neighbor)', ' clist_node: Coord = (', ' current_coord if (delta.sum() > 0) else chosen_neighbor', ' )', ' connection_list[dim, clist_node[0], clist_node[1]] = True', '', ' # add to visited cells and stack', ' visited_cells.add(tuple(chosen_neighbor))', ' stack.append(chosen_neighbor)', '', ' # Update current tree depth', ' current_tree_depth += 1', ' else:', ' current_tree_depth -= 1', '', ' return LatticeMaze(', ' connection_list=connection_list,', ' generation_meta=dict(', ' func_name=\"gen_dfs\",', ' grid_shape=grid_shape,', ' start_coord=start_coord,', ' n_accessible_cells=int(n_accessible_cells),', ' max_tree_depth=int(max_tree_depth),', ' fully_connected=bool(len(visited_cells) == n_accessible_cells),', ' visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},', ' ),', ' )']}, 'maze_ctor_kwargs': {}, 'padding_token_index': 10, 'token_arr': ['', '', '', '', '', '', '', '', '<-->', ';', '', '(0,0)', '(0,1)', '(1,0)', '(1,1)', '(0,2)', '(2,0)', '(1,2)', '(2,1)', '(2,2)'], 'tokenizer_map': {'': 0, '': 1, '': 2, '': 3, '': 4, '': 5, '': 6, '': 7, '<-->': 8, ';': 9, '': 10, '(0,0)': 11, '(0,1)': 12, '(1,0)': 13, '(1,1)': 14, '(0,2)': 15, '(2,0)': 16, '(1,2)': 17, '(2,1)': 18, '(2,2)': 19}, 'grid_shape': (3, 3), 'token_node_map': {'(0,0)': (0, 0), '(0,1)': (0, 1), '(1,0)': (1, 0), '(1,1)': (1, 1), '(0,2)': (0, 2), '(2,0)': (2, 0), '(1,2)': (1, 2), '(2,1)': (2, 1), '(2,2)': (2, 2)}, 'n_tokens': 20}, 'model_cfg': {'__format__': 'BaseGPTConfig(SerializableDataclass)', 'name': 'nano-v1', 'act_fn': 'gelu', 'd_model': 8, 'd_head': 4, 'n_layers': 2, 'weight_processing': {'are_layernorms_folded': False, 'are_weights_processed': False}, 'n_heads': 2}, 'train_cfg': {'__format__': 'TrainConfig(SerializableDataclass)', 'name': 'test-v1', 'optimizer': 'RMSprop', 'optimizer_kwargs': {'lr': 0.0001}, 'batch_size': 16, 'dataloader_cfg': {'shuffle': True, 'num_workers': 0, 'drop_last': False}, 'print_loss_interval': 10, 'checkpoint_interval': 100}, 'name': 'multsrc_demo_small-g3-n100-a_dfs-h58410_nano-v1_test-v1', 'pretrainedtokenizer_kwargs': None}\n", - "2023-06-14 16:11:07 INFO Initialized logger\n", - "2023-06-14 16:11:07 INFO Summary logged, getting dataset\n", + "2023-06-15 02:27:39 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'demo_small', 'seq_len_min': 1, 'seq_len_max': 512, 'seed': 42, 'applied_filters': [], 'grid_n': 3, 'n_mazes': 100, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', ' # Arguments', ' - `grid_shape: Coord`: the shape of the grid', ' - `lattice_dim: int`: the dimension of the lattice', ' (default: `2`)', ' - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', ' (default: `None`)', ' - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', ' (default: `None`)', ' - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.', ' - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.', '', '', ' # algorithm', ' 1. Choose the initial cell, mark it as visited and push it to the stack', ' 2. While the stack is not empty', ' 1. Pop a cell from the stack and make it a current cell', ' 2. If the current cell has any neighbours which have not been visited', ' 1. Push the current cell to the stack', ' 2. Choose one of the unvisited neighbours', ' 3. Remove the wall between the current cell and the chosen cell', ' 4. Mark the chosen cell as visited and push it to the stack', ' '], 'source_code': [' @staticmethod', ' def gen_dfs(', ' grid_shape: Coord,', ' lattice_dim: int = 2,', ' n_accessible_cells: int | None = None,', ' max_tree_depth: int | None = None,', ' do_forks: bool = True,', ' start_coord: Coord | None = None,', ' ) -> LatticeMaze:', ' \"\"\"generate a lattice maze using depth first search, iterative', '', ' # Arguments', ' - `grid_shape: Coord`: the shape of the grid', ' - `lattice_dim: int`: the dimension of the lattice', ' (default: `2`)', ' - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', ' (default: `None`)', ' - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', ' (default: `None`)', ' - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.', ' - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.', '', '', ' # algorithm', ' 1. Choose the initial cell, mark it as visited and push it to the stack', ' 2. While the stack is not empty', ' 1. Pop a cell from the stack and make it a current cell', ' 2. If the current cell has any neighbours which have not been visited', ' 1. Push the current cell to the stack', ' 2. Choose one of the unvisited neighbours', ' 3. Remove the wall between the current cell and the chosen cell', ' 4. Mark the chosen cell as visited and push it to the stack', ' \"\"\"', '', ' # Default values if no constraints have been passed', ' grid_shape: Coord = np.array(grid_shape)', ' n_total_cells: int = int(np.prod(grid_shape))', ' if n_accessible_cells is None:', ' n_accessible_cells = n_total_cells', ' if max_tree_depth is None:', ' max_tree_depth = (', ' 2 * n_total_cells', ' ) # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.', '', ' start_coord = _random_start_coord(grid_shape, start_coord)', '', ' # initialize the maze with no connections', ' connection_list: ConnectionList = np.zeros(', ' (lattice_dim, grid_shape[0], grid_shape[1]), dtype=np.bool_', ' )', '', ' # initialize the stack with the target coord', ' visited_cells: set[tuple[int, int]] = set()', ' visited_cells.add(tuple(start_coord))', ' stack: list[Coord] = [start_coord]', '', ' # initialize tree_depth_counter', ' current_tree_depth: int = 1', '', ' # loop until the stack is empty or n_connected_cells is reached', ' while stack and (len(visited_cells) < n_accessible_cells):', ' # get the current coord from the stack', ' current_coord: Coord = stack.pop()', '', ' # filter neighbors by being within grid bounds and being unvisited', ' unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [', ' (neighbor, delta)', ' for neighbor, delta in zip(', ' current_coord + NEIGHBORS_MASK, NEIGHBORS_MASK', ' )', ' if (', ' (tuple(neighbor) not in visited_cells)', ' and (0 <= neighbor[0] < grid_shape[0])', ' and (0 <= neighbor[1] < grid_shape[1])', ' )', ' ]', '', \" # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)\", ' if unvisited_neighbors_deltas and (', ' current_tree_depth <= max_tree_depth / 2', ' ):', \" # if we want a maze without forks, simply don't add the current coord back to the stack\", ' if do_forks:', ' stack.append(current_coord)', '', ' # choose one of the unvisited neighbors', ' chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)', '', ' # add connection', ' dim: int = np.argmax(np.abs(delta))', ' # if positive, down/right from current coord', ' # if negative, up/left from current coord (down/right from neighbor)', ' clist_node: Coord = (', ' current_coord if (delta.sum() > 0) else chosen_neighbor', ' )', ' connection_list[dim, clist_node[0], clist_node[1]] = True', '', ' # add to visited cells and stack', ' visited_cells.add(tuple(chosen_neighbor))', ' stack.append(chosen_neighbor)', '', ' # Update current tree depth', ' current_tree_depth += 1', ' else:', ' current_tree_depth -= 1', '', ' return LatticeMaze(', ' connection_list=connection_list,', ' generation_meta=dict(', ' func_name=\"gen_dfs\",', ' grid_shape=grid_shape,', ' start_coord=start_coord,', ' n_accessible_cells=int(n_accessible_cells),', ' max_tree_depth=int(max_tree_depth),', ' fully_connected=bool(len(visited_cells) == n_accessible_cells),', ' visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},', ' ),', ' )']}, 'maze_ctor_kwargs': {}, 'padding_token_index': 10, 'token_arr': ['', '', '', '', '', '', '', '', '<-->', ';', '', '(0,0)', '(0,1)', '(1,0)', '(1,1)', '(0,2)', '(2,0)', '(1,2)', '(2,1)', '(2,2)'], 'tokenizer_map': {'': 0, '': 1, '': 2, '': 3, '': 4, '': 5, '': 6, '': 7, '<-->': 8, ';': 9, '': 10, '(0,0)': 11, '(0,1)': 12, '(1,0)': 13, '(1,1)': 14, '(0,2)': 15, '(2,0)': 16, '(1,2)': 17, '(2,1)': 18, '(2,2)': 19}, 'grid_shape': (3, 3), 'token_node_map': {'(0,0)': (0, 0), '(0,1)': (0, 1), '(1,0)': (1, 0), '(1,1)': (1, 1), '(0,2)': (0, 2), '(2,0)': (2, 0), '(1,2)': (1, 2), '(2,1)': (2, 1), '(2,2)': (2, 2)}, 'n_tokens': 20}, 'model_cfg': {'__format__': 'BaseGPTConfig(SerializableDataclass)', 'name': 'nano-v1', 'act_fn': 'gelu', 'd_model': 8, 'd_head': 4, 'n_layers': 2, 'weight_processing': {'are_layernorms_folded': False, 'are_weights_processed': False}, 'n_heads': 2}, 'train_cfg': {'__format__': 'TrainConfig(SerializableDataclass)', 'name': 'test-v1', 'evals_max_new_tokens': 8, 'validation_dataset_cfg': 10, 'optimizer': 'RMSprop', 'optimizer_kwargs': {'lr': 0.0001}, 'batch_size': 16, 'dataloader_cfg': {'shuffle': True, 'num_workers': 0, 'drop_last': False}, 'intervals': None, 'intervals_count': {'print_loss': 100, 'checkpoint': 2, 'eval_fast': 4, 'eval_slow': 2}}, 'name': 'multsrc_demo_small-g3-n100-a_dfs-h58410_nano-v1_test-v1', 'pretrainedtokenizer_kwargs': None}\n", + "2023-06-15 02:27:39 INFO Initialized logger\n", + "2023-06-15 02:27:39 INFO Summary logged, getting dataset\n", "loading dataset from ../data/demo_small-g3-n100-a_dfs-h58410.zanj\n", "Got dataset demo_small with 100 items. output.cfg.to_fname() = 'demo_small-g3-n100-a_dfs-h58410'\n", - "2023-06-14 16:11:07 INFO finished getting dataset\n", - "2023-06-14 16:11:07 INFO Loaded 100 sequences\n", - "2023-06-14 16:11:07 INFO Creating dataloader\n", - "2023-06-14 16:11:07 INFO finished dataloader, passing to train()\n", - "2023-06-14 16:11:07 INFO Initializing model\n", - "2023-06-14 16:11:07 INFO Initializing optimizer\n", - "2023-06-14 16:11:07 INFO Starting training\n", - "2023-06-14 16:11:07 INFO will train for 7 batches, checkpoint_interval_iters = 6, loss_interval_iters = 1\n", - "2023-06-14 16:11:07 INFO iteration 0/7: loss=3.271\n", - "2023-06-14 16:11:07 INFO Saving model to ../data/multsrc_demo_small-g3-n100-a_dfs-h58410_nano-v1_test-v1_2023-06-14-16-11-00/checkpoints/model.iter_0.zanj\n", - "2023-06-14 16:11:07 INFO iteration 1/7: loss=3.275\n", - "2023-06-14 16:11:07 INFO iteration 2/7: loss=3.241\n", - "2023-06-14 16:11:07 INFO iteration 3/7: loss=3.231\n", - "2023-06-14 16:11:07 INFO iteration 4/7: loss=3.220\n", - "2023-06-14 16:11:08 INFO iteration 5/7: loss=3.191\n", - "2023-06-14 16:11:08 INFO iteration 6/7: loss=3.159\n", - "2023-06-14 16:11:08 INFO Saving model to ../data/multsrc_demo_small-g3-n100-a_dfs-h58410_nano-v1_test-v1_2023-06-14-16-11-00/checkpoints/model.iter_6.zanj\n", - "2023-06-14 16:11:08 INFO Saving final model to ../data/multsrc_demo_small-g3-n100-a_dfs-h58410_nano-v1_test-v1_2023-06-14-16-11-00/model.final.zanj\n", - "2023-06-14 16:11:08 INFO Done!\n" + "2023-06-15 02:27:39 INFO finished getting training dataset with 100 samples\n", + "2023-06-15 02:27:39 INFO got validation dataset by splitting training dataset into 100 train and 100 validation samples\n", + "2023-06-15 02:27:39 INFO Loaded 100 sequences\n", + "2023-06-15 02:27:39 INFO Creating dataloader\n", + "2023-06-15 02:27:39 INFO finished dataloader, passing to train()\n", + "2023-06-15 02:27:39 INFO Initializing model\n", + "Moving model to device: cpu\n", + "2023-06-15 02:27:39 INFO Initializing optimizer\n", + "2023-06-15 02:27:39 INFO will train for 7 batches, evals_enabled=True, with intervals: {'print_loss': 1, 'checkpoint': 3, 'eval_fast': 1, 'eval_slow': 3}\n", + "2023-06-15 02:27:39 INFO Starting training\n", + "2023-06-15 02:27:39 INFO Running evals: eval_fast\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "F:\\KNC\\maze-transformer\\maze_transformer\\evaluation\\path_evals.py:91: RuntimeWarning:\n", + "\n", + "fraction_connections_adjacent_lattice called on path of length less than 2, retuning NaN\n", + "prediction = array([[2, 0]])\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-06-15 02:27:44 INFO Running evals: eval_slow\n", + "2023-06-15 02:27:51 INFO iteration 0/7: loss=3.310\n", + "2023-06-15 02:27:51 INFO Saving model checkpoint to ../data/multsrc_demo_small-g3-n100-a_dfs-h58410_nano-v1_test-v1_2023-06-15-02-27-35/checkpoints/model.iter_0.zanj\n", + "2023-06-15 02:27:51 INFO Running evals: eval_fast\n", + "2023-06-15 02:27:57 INFO iteration 1/7: loss=3.302\n", + "2023-06-15 02:27:57 INFO Running evals: eval_fast\n", + "2023-06-15 02:28:01 INFO iteration 2/7: loss=3.290\n", + "2023-06-15 02:28:01 INFO Running evals: eval_fast\n", + "2023-06-15 02:28:05 INFO Running evals: eval_slow\n", + "2023-06-15 02:28:10 INFO iteration 3/7: loss=3.272\n", + "2023-06-15 02:28:10 INFO Saving model checkpoint to ../data/multsrc_demo_small-g3-n100-a_dfs-h58410_nano-v1_test-v1_2023-06-15-02-27-35/checkpoints/model.iter_3.zanj\n", + "2023-06-15 02:28:10 INFO Running evals: eval_fast\n", + "2023-06-15 02:28:15 INFO iteration 4/7: loss=3.262\n", + "2023-06-15 02:28:15 INFO Running evals: eval_fast\n", + "2023-06-15 02:28:19 INFO iteration 5/7: loss=3.273\n", + "2023-06-15 02:28:19 INFO Running evals: eval_fast\n", + "2023-06-15 02:28:23 INFO Running evals: eval_slow\n", + "2023-06-15 02:28:27 INFO iteration 6/7: loss=3.258\n", + "2023-06-15 02:28:27 INFO Saving model checkpoint to ../data/multsrc_demo_small-g3-n100-a_dfs-h58410_nano-v1_test-v1_2023-06-15-02-27-35/checkpoints/model.iter_6.zanj\n", + "2023-06-15 02:28:28 INFO Saving final model to ../data/multsrc_demo_small-g3-n100-a_dfs-h58410_nano-v1_test-v1_2023-06-15-02-27-35/model.final.zanj\n", + "2023-06-15 02:28:28 INFO Done training!\n" ] } ], diff --git a/notebooks/train_model_hallway.ipynb b/notebooks/train_model_hallway.ipynb index 244047d1..ba353ac8 100644 --- a/notebooks/train_model_hallway.ipynb +++ b/notebooks/train_model_hallway.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -35,9 +35,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DEVICE = device(type='cpu')\n" + ] + } + ], "source": [ "# set global defaults for ZANJ\n", "ZANJ_GLOBAL_DEFAULTS.external_array_threshold = 1024\n", @@ -54,9 +62,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "list(MAZE_DATASET_CONFIGS.keys()) = ['test-g3-n5-a_dfs-h89001', 'demo_small-g3-n100-a_dfs-h58410', 'demo-g6-n10K-a_dfs-h86254']\n" + ] + } + ], "source": [ "print(f\"{list(MAZE_DATASET_CONFIGS.keys()) = }\")\n", "\n", @@ -89,8 +105,6 @@ " num_workers=4,\n", " drop_last=False,\n", " ),\n", - " print_loss_interval=100,\n", - " checkpoint_interval=1000,\n", " ),\n", ")\n", "\n", @@ -121,15 +135,13 @@ " shuffle=True,\n", " drop_last=False,\n", " ),\n", - " print_loss_interval=4,\n", - " checkpoint_interval=1000,\n", " ),\n", ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -139,18 +151,110 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"name\": \"hallway-nano\",\n", + " \"dataset_cfg\": {\n", + " \"name\": \"custom-hallway\",\n", + " \"fname\": \"custom-hallway-g3-n8-a_dfs-h77610\",\n", + " \"sdc_hash\": 101966949714109448691488851066661809366629862507217849163669177948570579077610,\n", + " \"seed\": 42,\n", + " \"seq_len_min\": 1,\n", + " \"seq_len_max\": 512,\n", + " \"padding_token_index\": 10,\n", + " \"token_arr_joined\": \" <--> ; (0,0) (0,1) (1,0) (1,1) (0,2) (2,0) (1,2) (2,1) (2,2)\",\n", + " \"applied_filters\": [],\n", + " \"grid_n\": 3,\n", + " \"grid_shape\": [\n", + " 3,\n", + " 3\n", + " ],\n", + " \"n_mazes\": 8,\n", + " \"maze_ctor_name\": \"gen_dfs\",\n", + " \"maze_ctor_kwargs\": {\n", + " \"do_forks\": false\n", + " }\n", + " },\n", + " \"model_cfg\": {\n", + " \"name\": \"custom-model\",\n", + " \"act_fn\": \"gelu\",\n", + " \"d_model\": 8,\n", + " \"d_head\": 2,\n", + " \"n_layers\": 2,\n", + " \"weight_processing\": {\n", + " \"are_layernorms_folded\": false,\n", + " \"are_weights_processed\": false\n", + " },\n", + " \"n_heads\": 4\n", + " },\n", + " \"train_cfg\": {\n", + " \"name\": \"custom-train\",\n", + " \"optimizer\": \"AdamW\",\n", + " \"optimizer_kwargs\": {\n", + " \"lr\": 0.0001\n", + " },\n", + " \"batch_size\": 4,\n", + " \"dataloader_cfg\": {\n", + " \"shuffle\": true,\n", + " \"drop_last\": false\n", + " },\n", + " \"intervals\": null,\n", + " \"intervals_count\": null,\n", + " \"evals_max_new_tokens\": 8,\n", + " \"validation_dataset_cfg\": null\n", + " },\n", + " \"pretrainedtokenizer_kwargs\": null\n", + "}\n" + ] + } + ], "source": [ "print(json.dumps(CFG.summary(), indent=2))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "seeing if we can download the dataset...\n", + "no download found, or download failed\n", + "generating dataset...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "generating & solving mazes: 100%|██████████| 8/8 [00:00<00:00, 253.48maze/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving dataset to ..\\data\\custom-hallway-g3-n8-a_dfs-h77610.zanj\n", + "Got dataset custom-hallway with 8 items. output.cfg.to_fname() = 'custom-hallway-g3-n8-a_dfs-h77610'\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# get just the dataset, generating it if needed. \n", "# This step can be skipped if you set `do_generate_dataset=True` when calling `train_model`\n", @@ -167,7 +271,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -177,9 +281,135 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-06-15 02:29:30 ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmiv\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "22698976caa44b888ff3d62cbfaf61d9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.016666666666666666, max=1.0…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "wandb version 0.15.4 is available! To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.13.11" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in f:\\KNC\\maze-transformer\\notebooks\\wandb\\run-20230615_022933-3jy69jog" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run leafy-jazz-197 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/miv/integration-tests" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/miv/integration-tests/runs/3jy69jog" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-06-15 02:29:42 INFO config ={'__format__': 'ConfigHolder(SerializableDataclass)', 'dataset_cfg': {'__format__': 'MazeDatasetConfig(SerializableDataclass)', 'name': 'custom-hallway', 'seq_len_min': 1, 'seq_len_max': 512, 'seed': 42, 'applied_filters': [{'name': 'collect_generation_meta', 'args': (), 'kwargs': {}}], 'grid_n': 3, 'n_mazes': 8, 'maze_ctor': {'__name__': 'gen_dfs', '__module__': 'maze_dataset.generation.generators', '__doc__': ['generate a lattice maze using depth first search, iterative', '', ' # Arguments', ' - `grid_shape: Coord`: the shape of the grid', ' - `lattice_dim: int`: the dimension of the lattice', ' (default: `2`)', ' - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', ' (default: `None`)', ' - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', ' (default: `None`)', ' - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.', ' - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.', '', '', ' # algorithm', ' 1. Choose the initial cell, mark it as visited and push it to the stack', ' 2. While the stack is not empty', ' 1. Pop a cell from the stack and make it a current cell', ' 2. If the current cell has any neighbours which have not been visited', ' 1. Push the current cell to the stack', ' 2. Choose one of the unvisited neighbours', ' 3. Remove the wall between the current cell and the chosen cell', ' 4. Mark the chosen cell as visited and push it to the stack', ' '], 'source_code': [' @staticmethod', ' def gen_dfs(', ' grid_shape: Coord,', ' lattice_dim: int = 2,', ' n_accessible_cells: int | None = None,', ' max_tree_depth: int | None = None,', ' do_forks: bool = True,', ' start_coord: Coord | None = None,', ' ) -> LatticeMaze:', ' \"\"\"generate a lattice maze using depth first search, iterative', '', ' # Arguments', ' - `grid_shape: Coord`: the shape of the grid', ' - `lattice_dim: int`: the dimension of the lattice', ' (default: `2`)', ' - `n_accessible_cells: int | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid.', ' (default: `None`)', ' - `max_tree_depth: int | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_accessible_cells`.', ' (default: `None`)', ' - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.', ' - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.', '', '', ' # algorithm', ' 1. Choose the initial cell, mark it as visited and push it to the stack', ' 2. While the stack is not empty', ' 1. Pop a cell from the stack and make it a current cell', ' 2. If the current cell has any neighbours which have not been visited', ' 1. Push the current cell to the stack', ' 2. Choose one of the unvisited neighbours', ' 3. Remove the wall between the current cell and the chosen cell', ' 4. Mark the chosen cell as visited and push it to the stack', ' \"\"\"', '', ' # Default values if no constraints have been passed', ' grid_shape: Coord = np.array(grid_shape)', ' n_total_cells: int = int(np.prod(grid_shape))', ' if n_accessible_cells is None:', ' n_accessible_cells = n_total_cells', ' if max_tree_depth is None:', ' max_tree_depth = (', ' 2 * n_total_cells', ' ) # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.', '', ' start_coord = _random_start_coord(grid_shape, start_coord)', '', ' # initialize the maze with no connections', ' connection_list: ConnectionList = np.zeros(', ' (lattice_dim, grid_shape[0], grid_shape[1]), dtype=np.bool_', ' )', '', ' # initialize the stack with the target coord', ' visited_cells: set[tuple[int, int]] = set()', ' visited_cells.add(tuple(start_coord))', ' stack: list[Coord] = [start_coord]', '', ' # initialize tree_depth_counter', ' current_tree_depth: int = 1', '', ' # loop until the stack is empty or n_connected_cells is reached', ' while stack and (len(visited_cells) < n_accessible_cells):', ' # get the current coord from the stack', ' current_coord: Coord = stack.pop()', '', ' # filter neighbors by being within grid bounds and being unvisited', ' unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [', ' (neighbor, delta)', ' for neighbor, delta in zip(', ' current_coord + NEIGHBORS_MASK, NEIGHBORS_MASK', ' )', ' if (', ' (tuple(neighbor) not in visited_cells)', ' and (0 <= neighbor[0] < grid_shape[0])', ' and (0 <= neighbor[1] < grid_shape[1])', ' )', ' ]', '', \" # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)\", ' if unvisited_neighbors_deltas and (', ' current_tree_depth <= max_tree_depth / 2', ' ):', \" # if we want a maze without forks, simply don't add the current coord back to the stack\", ' if do_forks:', ' stack.append(current_coord)', '', ' # choose one of the unvisited neighbors', ' chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)', '', ' # add connection', ' dim: int = np.argmax(np.abs(delta))', ' # if positive, down/right from current coord', ' # if negative, up/left from current coord (down/right from neighbor)', ' clist_node: Coord = (', ' current_coord if (delta.sum() > 0) else chosen_neighbor', ' )', ' connection_list[dim, clist_node[0], clist_node[1]] = True', '', ' # add to visited cells and stack', ' visited_cells.add(tuple(chosen_neighbor))', ' stack.append(chosen_neighbor)', '', ' # Update current tree depth', ' current_tree_depth += 1', ' else:', ' current_tree_depth -= 1', '', ' return LatticeMaze(', ' connection_list=connection_list,', ' generation_meta=dict(', ' func_name=\"gen_dfs\",', ' grid_shape=grid_shape,', ' start_coord=start_coord,', ' n_accessible_cells=int(n_accessible_cells),', ' max_tree_depth=int(max_tree_depth),', ' fully_connected=bool(len(visited_cells) == n_accessible_cells),', ' visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},', ' ),', ' )']}, 'maze_ctor_kwargs': {'do_forks': False}, 'padding_token_index': 10, 'token_arr': ['', '', '', '', '', '', '', '', '<-->', ';', '', '(0,0)', '(0,1)', '(1,0)', '(1,1)', '(0,2)', '(2,0)', '(1,2)', '(2,1)', '(2,2)'], 'tokenizer_map': {'': 0, '': 1, '': 2, '': 3, '': 4, '': 5, '': 6, '': 7, '<-->': 8, ';': 9, '': 10, '(0,0)': 11, '(0,1)': 12, '(1,0)': 13, '(1,1)': 14, '(0,2)': 15, '(2,0)': 16, '(1,2)': 17, '(2,1)': 18, '(2,2)': 19}, 'grid_shape': (3, 3), 'token_node_map': {'(0,0)': (0, 0), '(0,1)': (0, 1), '(1,0)': (1, 0), '(1,1)': (1, 1), '(0,2)': (0, 2), '(2,0)': (2, 0), '(1,2)': (1, 2), '(2,1)': (2, 1), '(2,2)': (2, 2)}, 'n_tokens': 20}, 'model_cfg': {'__format__': 'BaseGPTConfig(SerializableDataclass)', 'name': 'custom-model', 'act_fn': 'gelu', 'd_model': 8, 'd_head': 2, 'n_layers': 2, 'weight_processing': {'are_layernorms_folded': False, 'are_weights_processed': False}, 'n_heads': 4}, 'train_cfg': {'__format__': 'TrainConfig(SerializableDataclass)', 'name': 'custom-train', 'evals_max_new_tokens': 8, 'validation_dataset_cfg': None, 'optimizer': 'AdamW', 'optimizer_kwargs': {'lr': 0.0001}, 'batch_size': 4, 'dataloader_cfg': {'shuffle': True, 'drop_last': False}, 'intervals': None, 'intervals_count': None}, 'name': 'hallway-nano', 'pretrainedtokenizer_kwargs': None}\n", + "2023-06-15 02:29:42 INFO Initialized logger\n", + "2023-06-15 02:29:42 INFO Summary logged, getting dataset\n", + "loading dataset from ../data/custom-hallway-g3-n8-a_dfs-h76723.zanj\n", + "Got dataset custom-hallway with 8 items. output.cfg.to_fname() = 'custom-hallway-g3-n8-a_dfs-h76723'\n", + "2023-06-15 02:29:42 INFO finished getting training dataset with 8 samples\n", + "2023-06-15 02:29:42 INFO Loaded 8 sequences\n", + "2023-06-15 02:29:42 INFO Creating dataloader\n", + "2023-06-15 02:29:42 INFO finished dataloader, passing to train()\n", + "2023-06-15 02:29:42 INFO Initializing model\n", + "Moving model to device: cpu\n", + "2023-06-15 02:29:42 INFO Initializing optimizer\n", + "2023-06-15 02:29:42 INFO will train for 2 batches, evals_enabled=False, with intervals: {'print_loss': inf, 'checkpoint': inf, 'eval_fast': inf, 'eval_slow': inf}\n", + "2023-06-15 02:29:42 INFO Starting training\n", + "2023-06-15 02:29:42 INFO iteration 0/2: loss=3.215\n", + "2023-06-15 02:29:42 INFO Saving model checkpoint to ../data/hallway-nano_2023-06-15-02-29-29/checkpoints/model.iter_0.zanj\n", + "2023-06-15 02:29:44 INFO Saving final model to ../data/hallway-nano_2023-06-15-02-29-29/model.final.zanj\n", + "2023-06-15 02:29:45 INFO Done training!\n" + ] + } + ], "source": [ "result: TrainingResult = train_model(\n", "\tbase_path=PATH_DATA,\n", diff --git a/notes.md b/notes.md new file mode 100644 index 00000000..fb6d4a63 --- /dev/null +++ b/notes.md @@ -0,0 +1,23 @@ +# Objectives + + - Dataloader yields shuffled, tokenized mazes to training loop + - Shuffling should be: + - Order of adj list + - Order of coord pairs within adj list + - Order of sections (adjlist, target, origin, NOT path) + - Shuffling should be controllable with config (not essential for MVP) + - Training loop performance should be as high as possible + - Either shuffling a batch needs to be a negligible perf impact, or we precompute it somehow (but then we may have memory concerns) + - Right now shuffling is done as part of __getitem__, it's not a batch operation, so we're unlikely to make perf worse with these changes) + - Need to run some benchmarks on different approaches to get clarity here + - __getitem__ should NOT shuffle (we want dataset[0] == dataset[0]) + - Outside of the training loop, it should be easily possible to get a shuffled or unshuffled tokenized maze (either shuffled adjlist, or shuffled entire tokenized maze) + - __getitem__ should perhaps return SolvedMazes (probably the most ergonomic option) + - If possible, remove some duplicate accessor stuff from dataset + - + +# Approach + + - Dataloader collate_fn looks promising. Operated on yielded batch of samples from Dataset. This seems like a good place for tokenization and shuffling (and would likely be faster as it happens on the batch) + + diff --git a/poetry.lock b/poetry.lock index 21a0a610..49a1bbcd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2059,7 +2059,7 @@ zanj = "^0.1.4" type = "git" url = "https://github.com/AISC-understanding-search/maze-dataset.git" reference = "HEAD" -resolved_reference = "991347a3176c7aa989335d74df23822b6b0ea506" +resolved_reference = "ce914b05c2182ccfcf93842783a585bff375d715" [[package]] name = "mdurl" diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..1bf81633 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,19 @@ +{ + "verbose": true, + "include": [ + "maze_transformer", + "tests", + "scripts" + ], + "exclude": [ + "**/.venv**", + "**/.git**", + "**/.github**", + "**/.idea**", + "**/.pytest_cache**", + "**/data**", + "**/wandb**" + ], + "venvPath": ".", + "venv": ".venv" +} diff --git a/tests/integration/test_eval_model.py b/tests/integration/test_eval_model.py index b1ba2c29..e359ebe3 100644 --- a/tests/integration/test_eval_model.py +++ b/tests/integration/test_eval_model.py @@ -18,7 +18,7 @@ from maze_transformer.test_helpers.assertions import assert_model_output_equality from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer from maze_transformer.training.train_model import TrainingResult, train_model -from maze_transformer.training.training import TRAIN_SAVE_FILES +from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES from maze_transformer.training.wandb_logger import WandbProject temp_dir: Path = Path("tests/_temp/test_eval_model") @@ -70,11 +70,9 @@ def test_predict_maze_paths(): ) model: ZanjHookedTransformer = result.model - dataset: MazeDataset = MazeDataset.from_config( - cfg=cfg.dataset_cfg, - ) + dataset: MazeDataset = MazeDataset.from_config(cfg=cfg.dataset_cfg) - max_new_tokens = 2 + max_new_tokens = 3 paths = predict_maze_paths( tokens_batch=dataset.as_tokens(), data_cfg=cfg.dataset_cfg, @@ -117,11 +115,9 @@ def test_evaluate_model(temp_dir): ) model: ZanjHookedTransformer = result.model - dataset: MazeDataset = MazeDataset.from_config( - cfg=cfg.dataset_cfg, - ) + dataset: MazeDataset = MazeDataset.from_config(cfg=cfg.dataset_cfg) - path_evals = PathEvals.evals + path_evals = PathEvals.fast eval_names = [name for name in path_evals.keys()] scores = evaluate_model(dataset=dataset, model=model) diff --git a/tests/integration/test_training.py b/tests/integration/test_training.py new file mode 100644 index 00000000..0e01022a --- /dev/null +++ b/tests/integration/test_training.py @@ -0,0 +1,129 @@ +import re +from copy import deepcopy +from pathlib import Path + +import pytest +from maze_dataset import MazeDataset, MazeDatasetConfig +from muutils.mlutils import get_device + +from maze_transformer.evaluation.path_evals import PathEvals +from maze_transformer.test_helpers.stub_logger import StubLogger +from maze_transformer.training.config import GPT_CONFIGS, TRAINING_CONFIGS, ConfigHolder +from maze_transformer.training.train_save_files import TRAIN_SAVE_FILES +from maze_transformer.training.training import get_dataloader, train +from maze_transformer.training.wandb_logger import WandbJobType, WandbProject + + +@pytest.mark.usefixtures("temp_dir") +def test_train_model_without_evals(temp_dir: Path): + dataset = _create_dataset() + cfg = _create_tokenizer_config(dataset.cfg, batch_size=5) + + output_path = _create_output_path(cfg, temp_dir) + logger = _create_logger(cfg) + dataloader = get_dataloader(dataset, cfg, logger) + device = get_device() + cfg.train_cfg.validation_dataset_cfg = None + + train( + dataloader=dataloader, + cfg=cfg, + logger=logger, + output_dir=output_path, + device=device, + ) + + metrics = _get_metrics(logger.logs) + assert len(metrics) == 2 + assert list(metrics[0].keys()) == ["loss"] + + +@pytest.mark.usefixtures("temp_dir") +def test_train_model_with_evals(temp_dir: Path): + dataset = _create_dataset() + cfg = _create_tokenizer_config(dataset.cfg, batch_size=5) + + output_path = _create_output_path(cfg, temp_dir) + logger = _create_logger(cfg) + dataloader = get_dataloader(dataset, cfg, logger) + device = get_device() + + # fast should run every 5 mazes (1 batch), slow every 10 mazes (2 batches) + cfg.train_cfg.intervals = dict( + print_loss=1, + checkpoint=10, + eval_fast=5, + eval_slow=10, + ) + cfg.train_cfg.intervals_count = None + cfg.train_cfg.validation_dataset_cfg = deepcopy(cfg.dataset_cfg) + val_dataset: MazeDataset = MazeDataset.from_config( + cfg.train_cfg.validation_dataset_cfg, + ) + + train( + dataloader=dataloader, + cfg=cfg, + logger=logger, + output_dir=output_path, + device=device, + val_dataset=val_dataset, + ) + + metrics = _get_metrics(logger.logs) + + # we should have 1 loop with fast evals and 1 loop with fast and slow + assert len(metrics) == 2 + assert set(metrics[0].keys()) == {"loss", *PathEvals.fast.keys()} + assert set(metrics[0].keys()) == { + "loss", + *PathEvals.fast.keys(), + *PathEvals.slow.keys(), + } + + +def _create_dataset(n_mazes: int = 10, grid_n: int = 3) -> MazeDataset: + dataset_cfg: MazeDatasetConfig = MazeDatasetConfig( + name="test", n_mazes=n_mazes, grid_n=grid_n + ) + dataset = MazeDataset.from_config(dataset_cfg) + # dataset.cfg.seq_len_max = 32 + # TODO(@mivanit): the above line caused me much pain. setting the sequence length in the tokenizer to below the length of the actual sequence passed causes horrible things to happen in `predict_maze_paths()` + return dataset + + +def _create_logger(cfg: ConfigHolder) -> StubLogger: + logger = StubLogger.create( + config=cfg.serialize(), + project=WandbProject.INTEGRATION_TESTS, + job_type=WandbJobType.TRAIN_MODEL, + ) + return logger + + +def _create_output_path(cfg: ConfigHolder, temp_dir: Path) -> Path: + output_dir_name = TRAIN_SAVE_FILES.model_run_dir(cfg) + output_path: Path = temp_dir / output_dir_name + (output_path / TRAIN_SAVE_FILES.checkpoints).mkdir(parents=True) + return output_path + + +def _create_tokenizer_config( + dataset_cfg: MazeDatasetConfig, batch_size: int = 5 +) -> ConfigHolder: + cfg: ConfigHolder = ConfigHolder( + dataset_cfg=dataset_cfg, + model_cfg=GPT_CONFIGS["tiny-v1"], + train_cfg=TRAINING_CONFIGS["tiny-v1"], + ) + cfg.train_cfg.dataloader_cfg["shuffle"] = False + cfg.train_cfg.batch_size = batch_size + return cfg + + +def _get_metrics(logs: list): + # for x in logs: + # print(x) + metrics = [log[1][0] for log in logs if re.match("metric", log[0], re.IGNORECASE)] + + return metrics diff --git a/tests/unit/maze_transformer/training/config/test_train_cfg_intervals.py b/tests/unit/maze_transformer/training/config/test_train_cfg_intervals.py new file mode 100644 index 00000000..e757c873 --- /dev/null +++ b/tests/unit/maze_transformer/training/config/test_train_cfg_intervals.py @@ -0,0 +1,259 @@ +import itertools + +import pytest +from torch.optim import RMSprop + +from maze_transformer.training.config import _DEFAULT_INTERVAL_COUNTS, TrainConfig + + +def test_get_intervals_with_default_values(): + n_samples: int = 100 + config = TrainConfig( + name="test", optimizer=RMSprop, optimizer_kwargs={"lr": 0.001}, batch_size=32 + ) + intervals = config.get_intervals( + n_samples, use_defaults_if_missing=True, mod_batch_size=False + ) + assert isinstance(intervals, dict) + default_counts: dict[str, int] = _DEFAULT_INTERVAL_COUNTS() + for k, v in intervals.items(): + assert isinstance(k, str) + assert isinstance(v, int) + assert v > 0 + assert abs(v - n_samples // default_counts[k]) <= 1 + + +def test_get_intervals_with_custom_intervals(): + # inputs + batch_size: int = 5 + intervals = {"print_loss": 5, "checkpoint": 20, "eval_fast": 10, "eval_slow": 40} + # expected result + intervals_mod_batch_size = { + "print_loss": 1, + "checkpoint": 4, + "eval_fast": 2, + "eval_slow": 8, + } + + config = TrainConfig( + name="test", + optimizer=RMSprop, + optimizer_kwargs={"lr": 0.001}, + batch_size=batch_size, + intervals=intervals, + ) + + for dataset_n_samples, use_defaults in itertools.product( + [100, None], [True, False] + ): + calculated_intervals = config.get_intervals( + dataset_n_samples, + mod_batch_size=False, + use_defaults_if_missing=use_defaults, + ) + assert isinstance(calculated_intervals, dict) + assert calculated_intervals == intervals + + calculated_intervals_batched = config.get_intervals( + dataset_n_samples, mod_batch_size=True, use_defaults_if_missing=use_defaults + ) + assert isinstance(calculated_intervals_batched, dict) + assert calculated_intervals_batched == intervals_mod_batch_size + + +def test_get_intervals_with_custom_counts(): + # inputs + dataset_n_samples: int = 100 + batch_size: int = 5 + intervals_count = { + "print_loss": 2, + "checkpoint": 5, + "eval_fast": 4, + "eval_slow": 10, + } + # expected result + intervals_expected = { + "print_loss": 50, + "checkpoint": 20, + "eval_fast": 25, + "eval_slow": 10, + } + intervals_expected_batched = { + "print_loss": 10, + "checkpoint": 4, + "eval_fast": 5, + "eval_slow": 2, + } + + config = TrainConfig( + name="test", + optimizer=RMSprop, + optimizer_kwargs={"lr": 0.001}, + batch_size=batch_size, + intervals_count=intervals_count, + ) + + for use_defaults in [True, False]: + calculated_intervals = config.get_intervals( + dataset_n_samples, + mod_batch_size=False, + use_defaults_if_missing=use_defaults, + ) + assert isinstance(calculated_intervals, dict) + assert calculated_intervals == intervals_expected + + calculated_intervals_batched = config.get_intervals( + dataset_n_samples, mod_batch_size=True, use_defaults_if_missing=use_defaults + ) + assert isinstance(calculated_intervals_batched, dict) + assert calculated_intervals_batched == intervals_expected_batched + + +def _plus_minus_proportion( + value: float, proportion: float = 0.1 +) -> tuple[float, float]: + return ( + value * (1 - proportion), + value * (1 + proportion), + ) + + +def _in_interval(value: float, interval: tuple[float, float]) -> bool: + return interval[0] <= value <= interval[1] + + +def test_get_intervals_with_custom_counts_approx(): + # inputs + dataset_n_samples: int = 100_000 + batch_size: int = 5 + intervals_count = { + "print_loss": 1000, + "checkpoint": 10, + "eval_fast": 100, + "eval_slow": 20, + } + # expected result + intervals_expected = { + "print_loss": _plus_minus_proportion(100), + "checkpoint": _plus_minus_proportion(10_000), + "eval_fast": _plus_minus_proportion(1000), + "eval_slow": _plus_minus_proportion(5000), + } + intervals_expected_batched = { + "print_loss": _plus_minus_proportion(20), + "checkpoint": _plus_minus_proportion(2_000), + "eval_fast": _plus_minus_proportion(200), + "eval_slow": _plus_minus_proportion(1000), + } + + config = TrainConfig( + name="test", + optimizer=RMSprop, + optimizer_kwargs={"lr": 0.001}, + batch_size=batch_size, + intervals_count=intervals_count, + ) + + for use_defaults in [True, False]: + calculated_intervals = config.get_intervals( + dataset_n_samples, + mod_batch_size=False, + use_defaults_if_missing=use_defaults, + ) + assert isinstance(calculated_intervals, dict) + for k, v in calculated_intervals.items(): + assert _in_interval(v, intervals_expected[k]) + + calculated_intervals_batched = config.get_intervals( + dataset_n_samples, mod_batch_size=True, use_defaults_if_missing=use_defaults + ) + assert isinstance(calculated_intervals_batched, dict) + for k, v in calculated_intervals_batched.items(): + assert _in_interval(v, intervals_expected_batched[k]) + + +def test_get_intervals_raises_with_missing_values(): + config = TrainConfig( + name="test", optimizer=RMSprop, optimizer_kwargs={"lr": 0.001}, batch_size=32 + ) + with pytest.raises(ValueError): + config.get_intervals(None, use_defaults_if_missing=False) + + +def test_get_intervals_raises_with_missing_counts_and_dataset_size(): + intervals_count = { + "print_loss": 2, + "checkpoint": 5, + "eval_fast": 4, + "eval_slow": 10, + } + config = TrainConfig( + name="test", + optimizer=RMSprop, + optimizer_kwargs={"lr": 0.001}, + batch_size=32, + intervals_count=intervals_count, + ) + with pytest.raises(ValueError): + config.get_intervals(None) + + +def test_get_intervals_with_no_mod_batch_size(): + intervals = {"print_loss": 5, "checkpoint": 20, "eval_fast": 10, "eval_slow": 40} + config = TrainConfig( + name="test", + optimizer=RMSprop, + optimizer_kwargs={"lr": 0.001}, + batch_size=32, + intervals=intervals, + ) + calculated_intervals = config.get_intervals(100, mod_batch_size=False) + assert calculated_intervals == intervals + + +def test_get_intervals_disabled_evals(): + # inputs + dataset_n_samples: int = 100 + batch_size: int = 5 + intervals_count = { + "print_loss": 2, + "checkpoint": 5, + "eval_fast": 0, + "eval_slow": 0, + } + # expected result + intervals_expected = { + "print_loss": 50, + "checkpoint": 20, + "eval_fast": float("inf"), + "eval_slow": float("inf"), + } + intervals_expected_batched = { + "print_loss": 10, + "checkpoint": 4, + "eval_fast": float("inf"), + "eval_slow": float("inf"), + } + + config = TrainConfig( + name="test", + optimizer=RMSprop, + optimizer_kwargs={"lr": 0.001}, + batch_size=batch_size, + intervals_count=intervals_count, + ) + + for use_defaults in [True, False]: + calculated_intervals = config.get_intervals( + dataset_n_samples, + mod_batch_size=False, + use_defaults_if_missing=use_defaults, + ) + assert isinstance(calculated_intervals, dict) + assert calculated_intervals == intervals_expected + + calculated_intervals_batched = config.get_intervals( + dataset_n_samples, mod_batch_size=True, use_defaults_if_missing=use_defaults + ) + assert isinstance(calculated_intervals_batched, dict) + assert calculated_intervals_batched == intervals_expected_batched diff --git a/tests/unit/maze_transformer/training/config/test_train_config.py b/tests/unit/maze_transformer/training/config/test_train_config.py index 103c2f23..6e2e9e64 100644 --- a/tests/unit/maze_transformer/training/config/test_train_config.py +++ b/tests/unit/maze_transformer/training/config/test_train_config.py @@ -22,6 +22,7 @@ def test_serialize_custom_values(): def test_load_custom_values(): loaded = TrainConfig.load(_custom_serialized_config()) assert loaded.optimizer == torch.optim.SGD + assert loaded.diff(_custom_train_config()) == {} assert loaded == _custom_train_config() @@ -32,8 +33,14 @@ def _custom_train_config() -> TrainConfig: optimizer_kwargs=dict(lr=0.01, momentum=0.9), batch_size=64, dataloader_cfg=dict(num_workers=8, drop_last=False), - print_loss_interval=500, - checkpoint_interval=1000, + intervals=dict( + print_loss=100, + checkpoint=10, + eval_fast=20, + eval_slow=10, + ), + evals_max_new_tokens=16, + validation_dataset_cfg=100, ) @@ -44,8 +51,15 @@ def _custom_serialized_config() -> Dict[Any, Any]: "optimizer_kwargs": {"lr": 0.01, "momentum": 0.9}, "batch_size": 64, "dataloader_cfg": {"num_workers": 8, "drop_last": False}, - "print_loss_interval": 500, - "checkpoint_interval": 1000, + "intervals": { + "print_loss": 100, + "checkpoint": 10, + "eval_fast": 20, + "eval_slow": 10, + }, + "intervals_count": None, + "evals_max_new_tokens": 16, + "validation_dataset_cfg": 100, "__format__": "TrainConfig(SerializableDataclass)", } diff --git a/tests/unit/maze_transformer/training/test_training.py b/tests/unit/maze_transformer/training/test_get_dataloader.py similarity index 100% rename from tests/unit/maze_transformer/training/test_training.py rename to tests/unit/maze_transformer/training/test_get_dataloader.py