diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..5c1416eb Binary files /dev/null and b/.DS_Store differ diff --git a/maze_transformer/evaluation/plot_maze.py b/maze_transformer/evaluation/plot_maze.py index 10ceff1e..cf76d1c3 100644 --- a/maze_transformer/evaluation/plot_maze.py +++ b/maze_transformer/evaluation/plot_maze.py @@ -6,10 +6,9 @@ import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np -from jaxtyping import Float +from jaxtyping import Array, Bool, Float from matplotlib.cm import ScalarMappable from matplotlib.colors import ListedColormap, Normalize -from muutils.tensor_utils import NDArray from maze_transformer.generation.lattice_maze import Coord, CoordArray, LatticeMaze @@ -225,8 +224,9 @@ def show(self, dpi: int = 100, title: str = "") -> None: self.plot(dpi=dpi, title=title) plt.show() - def _rowcol_to_coord(self, point: Coord) -> NDArray: - """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.""" + def _rowcol_to_coord(self, point: Coord) -> Float[Array, "2"]: + """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x + is the horizontal axis.""" point = np.array([point[1], point[0]]) return self.unit_length * (point + 0.5) @@ -286,7 +286,7 @@ def _plot_maze(self) -> None: self.ax.imshow(img, cmap=cmap, vmin=-1, vmax=1) - def _lattice_maze_to_img(self) -> NDArray["row col", bool]: + def _lattice_maze_to_img(self) -> Bool[Array, "row col"]: """ Build an image to visualise the maze. Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul. @@ -317,7 +317,7 @@ def _lattice_maze_to_img(self) -> NDArray["row col", bool]: connection_values = scaled_node_values # Create background image (all pixels set to -1, walls everywhere) - img: NDArray["row col", float] = -np.ones( + img: Float[Array, "row col"] = -np.ones( ( self.maze.grid_shape[0] * self.unit_length + 1, self.maze.grid_shape[1] * self.unit_length + 1, @@ -351,12 +351,13 @@ def _lattice_maze_to_img(self) -> NDArray["row col", bool]: return img def _plot_path(self, path_format: PathFormat) -> None: - p_transformed = np.array( + p_transformed: Float[Array, "coord 2"] = np.array( [self._rowcol_to_coord(coord) for coord in path_format.path] ) + if path_format.quiver_kwargs is not None: - x: NDArray = p_transformed[:, 0] - y: NDArray = p_transformed[:, 1] + x: Float[Array, "x"] = p_transformed[:, 0] + y: Float[Array, "y"] = p_transformed[:, 1] self.ax.quiver( x[:-1], y[:-1], @@ -392,7 +393,7 @@ def _plot_path(self, path_format: PathFormat) -> None: ms=10, ) - def as_ascii(self, start=None, end=None): + def as_ascii(self, start=None, end=None) -> str: """ Returns an ASCII visualization of the maze. Courtesy of ChatGPT diff --git a/maze_transformer/generation/constants.py b/maze_transformer/generation/constants.py index cd5f197c..adf5aae6 100644 --- a/maze_transformer/generation/constants.py +++ b/maze_transformer/generation/constants.py @@ -1,9 +1,9 @@ import numpy as np -from muutils.tensor_utils import NDArray +from jaxtyping import Array, Int, Int8 -Coord = NDArray["x y", np.int8] +Coord = Int8[Array, "x y"] CoordTup = tuple[int, int] -CoordArray = NDArray["coords", np.int8] +CoordArray = Int8[Array, "coords"] SPECIAL_TOKENS: dict[str, str] = dict( adj_list_start="", @@ -19,7 +19,7 @@ padding="", ) -DIRECTIONS_MAP: NDArray["direction axes", int] = np.array( +DIRECTIONS_MAP: Int[Array, "direction axes"] = np.array( [ [0, 1], # down [0, -1], # up @@ -29,7 +29,7 @@ ) -NEIGHBORS_MASK: NDArray["coord point", int] = np.array( +NEIGHBORS_MASK: Int[Array, "coord point"] = np.array( [ [0, 1], # down [0, -1], # up diff --git a/maze_transformer/generation/utils.py b/maze_transformer/generation/utils.py index c87cb3e5..cc0c7490 100644 --- a/maze_transformer/generation/utils.py +++ b/maze_transformer/generation/utils.py @@ -1,12 +1,12 @@ import math import numpy as np -from muutils.tensor_utils import NDArray +from jaxtyping import Array, Bool def bool_array_from_string( string: str, shape: list[int], true_symbol: str = "T" -) -> NDArray: +) -> Bool[Array, "..."]: """Transform a string into an ndarray of bools. Parameters @@ -20,8 +20,8 @@ def bool_array_from_string( Returns ------- - NDArray - A ndarray with dtype bool. + Bool[Array, "..."] + An ndarray array with dtype bool and an unknown shape. Examples -------- diff --git a/maze_transformer/training/tokenizer.py b/maze_transformer/training/tokenizer.py index 9f08f61e..e7e79244 100644 --- a/maze_transformer/training/tokenizer.py +++ b/maze_transformer/training/tokenizer.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Union # need Union as "a" | "b" doesn't work import torch -from muutils.tensor_utils import ATensor, NDArray +from muutils.tensor_utils import ATensor from transformers import PreTrainedTokenizer from transformers.tokenization_utils import BatchEncoding @@ -140,7 +140,7 @@ def batch_decode( def to_ascii( self, sequence: list[int | str] | ATensor, start=None, end=None - ) -> NDArray: + ) -> str: # Sequence should be a single maze (not batch) if isinstance(sequence, list) and isinstance(sequence[0], str): str_sequence = sequence # already decoded