diff --git a/pufferlib/environments/pokemon_red.py b/pufferlib/environments/pokemon_red.py index 40a77cc5..3eb7f808 100644 --- a/pufferlib/environments/pokemon_red.py +++ b/pufferlib/environments/pokemon_red.py @@ -87,6 +87,11 @@ def __init__( self.early_stopping = early_stopping self.save_video = save_video self.fast_video = fast_video + self.explore_weight = 1 if 'explore_weight' not in config else config['explore_weight'] + self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore'] + self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale'] + self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons'] + self.save_screenshots = save_screenshots self.video_interval = video_interval_mul * self.act_freq self.downsample_factor = downsample_factor @@ -100,6 +105,7 @@ def __init__( self.instance_id = str(uuid.uuid4())[:8] self.s_path.mkdir(exist_ok=True) + self.reset_count = 0 self.all_runs = [] # Set this in SOME subclasses @@ -112,11 +118,15 @@ def __init__( WindowEvent.PRESS_ARROW_RIGHT, WindowEvent.PRESS_ARROW_UP, WindowEvent.PRESS_BUTTON_A, - WindowEvent.PRESS_BUTTON_B, - WindowEvent.PRESS_BUTTON_START, - WindowEvent.PASS + WindowEvent.PRESS_BUTTON_B ] + if self.extra_buttons: + self.valid_actions.extend([ + WindowEvent.PRESS_BUTTON_START, + WindowEvent.PASS + ]) + self.release_arrow = [ WindowEvent.RELEASE_ARROW_DOWN, WindowEvent.RELEASE_ARROW_LEFT, @@ -154,17 +164,20 @@ def __init__( ) self.screen = self.pyboy.botsupport_manager().screen() - self.pyboy.set_emulation_speed(0 if headless else 6) + if not config['headless']: + self.pyboy.set_emulation_speed(6) self.reset() def reset(self, seed=None): self.seed = seed - # restart game, skipping credits with open(self.init_state, "rb") as f: self.pyboy.load_state(f) - - self.init_knn() + + if self.use_screen_explore: + self.init_knn() + else: + self.init_map_mem self.recent_memory = np.zeros( (self.output_shape[1]*self.memory_height, 3), @@ -176,6 +189,8 @@ def reset(self, seed=None): ), dtype=np.uint8, ) + self.agent_stats = [] + if self.save_video: base_dir = self.s_path / Path('rollouts') base_dir.mkdir(exist_ok=True) @@ -199,10 +214,20 @@ def reset(self, seed=None): self.death_count = 0 self.step_count = 0 self.reset_count += 1 - self.compute_rewards() self.total_reward = sum(self.rewards.values()) - + self.levels_satisfied = False + self.base_explore = 0 + self.max_opponent_level = 0 + self.max_event_rew = 0 + self.max_level_rew = 0 + self.last_health = 1 + self.total_healing_rew = 0 + self.died_count = 0 + self.step_count = 0 + self.progress_reward = self.get_game_state_reward() + self.total_reward = sum([val for _, val in self.progress_reward.items()]) + self.reset_count += 1 return self.render(), {} @@ -213,6 +238,9 @@ def init_knn(self): self.knn_index.init_index( max_elements=self.num_elements, ef_construction=100, M=16) + def init_map_mem(self): + self.seen_coords = {} + def render(self, reduce_res=True, add_memory=True, update_mem=True): game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3) @@ -240,9 +268,10 @@ def render(self, reduce_res=True, add_memory=True, update_mem=True): ), axis=0) return game_pixels_render - + def step(self, action): self.run_action_on_emulator(action) + self.append_agent_stats(action) self.recent_frames = np.roll(self.recent_frames, 1, axis=0) obs_memory = self.render() @@ -252,9 +281,15 @@ def step(self, action): obs_flat = obs_memory[ frame_start:frame_start+self.output_shape[0], ...].flatten().astype(np.float32) - self.update_frame_knn_index(obs_flat) + if self.use_screen_explore: + self.update_frame_knn_index(obs_flat) + else: + self.update_seen_coords() + + self.update_heal_reward() new_reward, new_prog = self.compute_rewards() - self.cfg["state_params"]["health"] = self.read_hp_fraction() + self.last_health = self.read_hp_fraction() + #self.cfg["state_params"]["health"] = self.read_hp_fraction() # shift over short term reward memory self.recent_memory = np.roll(self.recent_memory, 3) @@ -272,6 +307,8 @@ def step(self, action): def run_action_on_emulator(self, action): # press button then release after some steps self.pyboy.send_input(self.valid_actions[action]) + if not self.save_video and self.headless: + self.pyboy._rendering(False) for i in range(self.act_freq): # release action, so they are stateless if i == 8: @@ -281,14 +318,16 @@ def run_action_on_emulator(self, action): if action > 3 and action < 6: # release button self.pyboy.send_input(self.release_button[action - 4]) - if action == WindowEvent.PRESS_BUTTON_START: + if self.valid_actions[action] == WindowEvent.PRESS_BUTTON_START: self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START) if self.save_video and not self.fast_video: self.add_video_frame() + if i == self.act_freq-1: + self.pyboy._rendering(True) self.pyboy.tick() if self.save_video and self.fast_video: self.add_video_frame() - + def add_video_frame(self): self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False)) self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False)) @@ -298,6 +337,10 @@ def get_agent_stats(self, action): y_pos = self.read_m(0xD361) map_n = self.read_m(0xD35E) levels = [self.read_m(a) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]] + if self.use_screen_explore: + expl = ('frames', self.knn_index.get_current_count()) + else: + expl = ('coord_count', len(self.seen_coords)) return { 'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n, 'last_action': action, @@ -305,7 +348,18 @@ def get_agent_stats(self, action): 'hp': self.read_hp_fraction(), 'frames': self.knn_index.get_current_count(), 'deaths': self.death_count, 'badge': self.get_badges(), - 'event': self.rewards["event"], 'healr': self.cfg["reward_params"]["total_healing_rew"] + #'event': self.rewards["event"], 'healr': self.cfg["reward_params"]["total_healing_rew"] + 'event': self.reward_scale*self.update_max_event_rew(), + #'party_xp': self.reward_scale*0.1*sum(poke_xps), + 'level': self.reward_scale*self.get_levels_reward(), + 'heal': self.reward_scale*self.total_healing_rew, + 'op_lvl': self.reward_scale*self.update_max_op_level(), + 'dead': self.reward_scale*-0.1*self.died_count, + 'badge': self.reward_scale*self.get_badges() * 5, + #'op_poke': self.reward_scale*self.max_opponent_poke * 800, + #'money': self.reward_scale* money * 3, + #'seen_poke': self.reward_scale * seen_poke_count * 400, + 'explore': self.reward_scale * self.get_knn_reward() } def update_frame_knn_index(self, frame_vec): @@ -327,17 +381,53 @@ def update_frame_knn_index(self, frame_vec): self.knn_index.add_items( frame_vec, np.array([self.knn_index.get_current_count()])) + def update_seen_coords(self): + x_pos = self.read_m(0xD362) + y_pos = self.read_m(0xD361) + map_n = self.read_m(0xD35E) + coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}" + if self.get_levels_sum() >= 22 and not self.levels_satisfied: + self.levels_satisfied = True + self.base_explore = len(self.seen_coords) + self.seen_coords = {} + + self.seen_coords[coord_string] = self.step_count + def compute_rewards(self): # addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map # https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm self.rewards_old = self.rewards.copy() + # adds up all event flags, exclude museum ticket + event_flags_start = 0xD747 + event_flags_end = 0xD886 + museum_ticket = (0xD754, 0) + base_event_flags = 13 + return max( + sum( + [ + self.bit_count(self.read_m(i)) + for i in range(event_flags_start, event_flags_end) + ] + ) + - base_event_flags + - int(self.read_bit(museum_ticket[0], museum_ticket[1])), + 0,) + # healing reward curr_health = self.read_hp_fraction() self.rewards["healing"] = self.cfg["rewards"]["healing_scale"] * max(0, curr_health - self.cfg["state_params"]["health"]) if self.cfg["state_params"]["health"] <= 0: self.death_count += 1 self.cfg["state_params"]["health"] = curr_health + ''' + # Not sure where to integrate + prog = self.progress_reward + # these values are only used by memory + return (prog['level'] * 100 / self.reward_scale, + self.read_hp_fraction()*2000, + prog['explore'] * 150 / (self.explore_weight * self.reward_scale)) + ''' # event reward curr_event_rew = max(sum([self.bit_count(self.read_m(i)) for i in range(0xD747, 0xD886)]) - 13, 0) @@ -362,6 +452,9 @@ def compute_rewards(self): self.rewards["badges"] = self.cfg["rewards"]["badge_scale"] * self.get_badges() # exploration reward + pre_rew = self.explore_weight * 0.005 + post_rew = self.explore_weight * 0.01 + cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords) curr_size = self.knn_index.get_current_count() base = (self.cfg["state_params"]["base_explore"] if self.cfg["state_params"]["levels_satisfied"] else curr_size) * self.cfg["rewards"]["knn_pre_scale"] post = (curr_size if self.cfg["state_params"]["levels_satisfied"] else 0) * self.cfg["rewards"]["knn_post_scale"] @@ -430,9 +523,12 @@ def save_and_print_info(self, done, obs_memory): print(f'\r{prog_string}', end='', flush=True) if self.step_count % 50 == 0: - plt.imsave( - self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), - self.render(reduce_res=False) + try: + plt.imsave( + self.s_path / Path(f'curframe_{self.instance_id}.jpeg'), + self.render(reduce_res=False)) + except Exception as e: + print(f"Error saving iamge: {e}") ) if self.print_rewards and done: @@ -440,14 +536,18 @@ def save_and_print_info(self, done, obs_memory): if self.save_final_state: fs_path = self.s_path / Path('final_states') fs_path.mkdir(exist_ok=True) - plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'), - obs_memory - ) - plt.imsave( - fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'), - self.render(reduce_res=False) - ) + try: + plt.imsave( + fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'), + obs_memory) + except Exception as e: + print(f"Error saving image: {e}") + try: + plt.imsave( + fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'), + self.render(reduce_res=False)) + except Exception as e: + print(f"Error saving iamge: {e}") if self.save_video and done: self.full_frame_writer.close() @@ -476,9 +576,12 @@ def read_party(self): def save_screenshot(self, name): ss_dir = self.s_path / Path('screenshots') ss_dir.mkdir(exist_ok=True) - plt.imsave( - ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), - self.render(reduce_res=False)) + try: + plt.imsave( + ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'), + self.render(reduce_res=False)) + except Exception as e: + print(f"Error saving iamge: {e}") def read_hp_fraction(self): hp_sum = sum([self.read_hp(add) for add in [0xD16C, 0xD198, 0xD1C4, 0xD1F0, 0xD21C, 0xD248]])