diff --git a/maze_transformer/generation/generators.py b/maze_transformer/generation/generators.py index c609268e..56b88b2c 100644 --- a/maze_transformer/generation/generators.py +++ b/maze_transformer/generation/generators.py @@ -100,9 +100,9 @@ def gen_dfs( ) @classmethod - def gen_dfs_with_solution(cls, grid_shape: Coord): + def gen_dfs_with_solution(cls, grid_shape: Coord) -> SolvedMaze: maze = cls.gen_dfs(grid_shape) - solution = np.array(maze.generate_random_path()) + solution = maze.get_shortest_path_between_random_points() return SolvedMaze(maze, solution) @@ -199,6 +199,13 @@ def neighbor(current: Coord, direction: int) -> Coord: ), ) + @classmethod + def gen_wilson_with_solution(cls, grid_shape: Coord) -> SolvedMaze: + maze = cls.gen_wilson(grid_shape) + solution = maze.get_shortest_path_between_random_points() + + return SolvedMaze(maze, solution) + GENERATORS_MAP: dict[str, Callable[[Coord, Any], "LatticeMaze"]] = { "gen_dfs": LatticeMazeGenerators.gen_dfs, diff --git a/maze_transformer/generation/lattice_maze.py b/maze_transformer/generation/lattice_maze.py index 2c6f9cef..c8bdab51 100644 --- a/maze_transformer/generation/lattice_maze.py +++ b/maze_transformer/generation/lattice_maze.py @@ -143,7 +143,7 @@ def find_shortest_path( self, c_start: CoordTup, c_end: CoordTup, - ) -> list[Coord]: + ) -> list[CoordTup]: """find the shortest path between two coordinates, using A*""" g_score: dict[ @@ -215,7 +215,7 @@ def get_nodes(self) -> list[Coord]: for col in range(self.grid_shape[1]) ] - def generate_random_path(self) -> list[Coord]: + def get_shortest_path_between_random_points(self) -> list[CoordTup]: """ "return a path between randomly chosen start and end nodes""" # we can't create a "path" in a single-node maze diff --git a/maze_transformer/training/tokenizer.py b/maze_transformer/training/tokenizer.py index 9f08f61e..6387dac9 100644 --- a/maze_transformer/training/tokenizer.py +++ b/maze_transformer/training/tokenizer.py @@ -19,7 +19,7 @@ def maze_to_tokens( maze: LatticeMaze, - solution: list[Coord], + solution: list[CoordTup], node_token_map: dict[CoordTup, str], ) -> list[str]: """serialize maze and solution to tokens""" @@ -29,9 +29,9 @@ def maze_to_tokens( *chain.from_iterable( [ [ - node_token_map[tuple(c_s.tolist())], + node_token_map[tuple(list(c_s))], SPECIAL_TOKENS["connector"], - node_token_map[tuple(c_e.tolist())], + node_token_map[tuple(list(c_e))], SPECIAL_TOKENS["adjacency_endline"], ] for c_s, c_e in maze.as_adj_list() @@ -47,7 +47,7 @@ def maze_to_tokens( node_token_map[tuple(solution[-1])], SPECIAL_TOKENS["target_end"], SPECIAL_TOKENS["path_start"], - *[node_token_map[tuple(c.tolist())] for c in solution], + *[node_token_map[c] for c in solution], SPECIAL_TOKENS["path_end"], ] diff --git a/scripts/test_generation.py b/scripts/test_generation.py index 32ed3074..5ba7d7d8 100644 --- a/scripts/test_generation.py +++ b/scripts/test_generation.py @@ -25,7 +25,7 @@ def generate_solve_plot( if start and end: path = np.array(maze.find_shortest_path(start, end)) else: - path = np.array(maze.generate_random_path()) + path = np.array(maze.get_shortest_path_between_random_points()) print(f"solving time: {time.time() - solution_start}") diff --git a/tests/unit/maze_transformer/generation/test_generators.py b/tests/unit/maze_transformer/generation/test_generators.py index 090474e3..fae8bad3 100644 --- a/tests/unit/maze_transformer/generation/test_generators.py +++ b/tests/unit/maze_transformer/generation/test_generators.py @@ -29,3 +29,9 @@ def test_gen_dfs_with_solution(): def test_wilson_generation(): maze = LatticeMazeGenerators.gen_wilson(np.array([2, 2])) assert maze.connection_list.shape == (2, 2, 2) + + +def test_wilson_generation_with_solution(): + maze, solution = LatticeMazeGenerators.gen_wilson_with_solution(np.array([2, 2])) + assert maze.connection_list.shape == (2, 2, 2) + assert len(solution[0]) == 2 diff --git a/tests/unit/maze_transformer/generation/test_latticemaze.py b/tests/unit/maze_transformer/generation/test_latticemaze.py index bf323a1c..91da7d00 100644 --- a/tests/unit/maze_transformer/generation/test_latticemaze.py +++ b/tests/unit/maze_transformer/generation/test_latticemaze.py @@ -36,7 +36,7 @@ def test_get_nodes(): def test_generate_random_path(): maze = LatticeMazeGenerators.gen_dfs((2, 2)) - path = maze.generate_random_path() + path = maze.get_shortest_path_between_random_points() # len > 1 ensures that we have unique start and end nodes assert len(path) > 1 @@ -45,4 +45,4 @@ def test_generate_random_path(): def test_generate_random_path_size_1(): maze = LatticeMazeGenerators.gen_dfs((1, 1)) with pytest.raises(AssertionError): - maze.generate_random_path() + maze.get_shortest_path_between_random_points()