Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PWhiddy patches applied pokemon_red.py #49

Open
wants to merge 2 commits into
base: 0.5-cleanup
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 132 additions & 29 deletions pufferlib/environments/pokemon_red.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def __init__(
self.early_stopping = early_stopping
self.save_video = save_video
self.fast_video = fast_video
self.explore_weight = 1 if 'explore_weight' not in config else config['explore_weight']
self.use_screen_explore = True if 'use_screen_explore' not in config else config['use_screen_explore']
self.reward_scale = 1 if 'reward_scale' not in config else config['reward_scale']
self.extra_buttons = False if 'extra_buttons' not in config else config['extra_buttons']

self.save_screenshots = save_screenshots
self.video_interval = video_interval_mul * self.act_freq
self.downsample_factor = downsample_factor
Expand All @@ -100,6 +105,7 @@ def __init__(
self.instance_id = str(uuid.uuid4())[:8]

self.s_path.mkdir(exist_ok=True)
self.reset_count = 0
self.all_runs = []

# Set this in SOME subclasses
Expand All @@ -112,11 +118,15 @@ def __init__(
WindowEvent.PRESS_ARROW_RIGHT,
WindowEvent.PRESS_ARROW_UP,
WindowEvent.PRESS_BUTTON_A,
WindowEvent.PRESS_BUTTON_B,
WindowEvent.PRESS_BUTTON_START,
WindowEvent.PASS
WindowEvent.PRESS_BUTTON_B
]

if self.extra_buttons:
self.valid_actions.extend([
WindowEvent.PRESS_BUTTON_START,
WindowEvent.PASS
])

self.release_arrow = [
WindowEvent.RELEASE_ARROW_DOWN,
WindowEvent.RELEASE_ARROW_LEFT,
Expand Down Expand Up @@ -154,17 +164,20 @@ def __init__(
)

self.screen = self.pyboy.botsupport_manager().screen()
self.pyboy.set_emulation_speed(0 if headless else 6)
if not config['headless']:
self.pyboy.set_emulation_speed(6)
self.reset()

def reset(self, seed=None):
self.seed = seed

# restart game, skipping credits
with open(self.init_state, "rb") as f:
self.pyboy.load_state(f)

self.init_knn()

if self.use_screen_explore:
self.init_knn()
else:
self.init_map_mem

self.recent_memory = np.zeros(
(self.output_shape[1]*self.memory_height, 3),
Expand All @@ -176,6 +189,8 @@ def reset(self, seed=None):
), dtype=np.uint8,
)

self.agent_stats = []

if self.save_video:
base_dir = self.s_path / Path('rollouts')
base_dir.mkdir(exist_ok=True)
Expand All @@ -199,10 +214,20 @@ def reset(self, seed=None):
self.death_count = 0
self.step_count = 0
self.reset_count += 1

self.compute_rewards()
self.total_reward = sum(self.rewards.values())

self.levels_satisfied = False
self.base_explore = 0
self.max_opponent_level = 0
self.max_event_rew = 0
self.max_level_rew = 0
self.last_health = 1
self.total_healing_rew = 0
self.died_count = 0
self.step_count = 0
self.progress_reward = self.get_game_state_reward()
self.total_reward = sum([val for _, val in self.progress_reward.items()])
self.reset_count += 1

return self.render(), {}

Expand All @@ -213,6 +238,9 @@ def init_knn(self):
self.knn_index.init_index(
max_elements=self.num_elements, ef_construction=100, M=16)

def init_map_mem(self):
self.seen_coords = {}

def render(self, reduce_res=True, add_memory=True, update_mem=True):
game_pixels_render = self.screen.screen_ndarray() # (144, 160, 3)

Expand Down Expand Up @@ -240,9 +268,10 @@ def render(self, reduce_res=True, add_memory=True, update_mem=True):
), axis=0)

return game_pixels_render

def step(self, action):
self.run_action_on_emulator(action)
self.append_agent_stats(action)

self.recent_frames = np.roll(self.recent_frames, 1, axis=0)
obs_memory = self.render()
Expand All @@ -252,9 +281,15 @@ def step(self, action):
obs_flat = obs_memory[
frame_start:frame_start+self.output_shape[0], ...].flatten().astype(np.float32)

self.update_frame_knn_index(obs_flat)
if self.use_screen_explore:
self.update_frame_knn_index(obs_flat)
else:
self.update_seen_coords()

self.update_heal_reward()
new_reward, new_prog = self.compute_rewards()
self.cfg["state_params"]["health"] = self.read_hp_fraction()
self.last_health = self.read_hp_fraction()
#self.cfg["state_params"]["health"] = self.read_hp_fraction()

# shift over short term reward memory
self.recent_memory = np.roll(self.recent_memory, 3)
Expand All @@ -272,6 +307,8 @@ def step(self, action):
def run_action_on_emulator(self, action):
# press button then release after some steps
self.pyboy.send_input(self.valid_actions[action])
if not self.save_video and self.headless:
self.pyboy._rendering(False)
for i in range(self.act_freq):
# release action, so they are stateless
if i == 8:
Expand All @@ -281,14 +318,16 @@ def run_action_on_emulator(self, action):
if action > 3 and action < 6:
# release button
self.pyboy.send_input(self.release_button[action - 4])
if action == WindowEvent.PRESS_BUTTON_START:
if self.valid_actions[action] == WindowEvent.PRESS_BUTTON_START:
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START)
if self.save_video and not self.fast_video:
self.add_video_frame()
if i == self.act_freq-1:
self.pyboy._rendering(True)
self.pyboy.tick()
if self.save_video and self.fast_video:
self.add_video_frame()

def add_video_frame(self):
self.full_frame_writer.add_image(self.render(reduce_res=False, update_mem=False))
self.model_frame_writer.add_image(self.render(reduce_res=True, update_mem=False))
Expand All @@ -298,14 +337,29 @@ def get_agent_stats(self, action):
y_pos = self.read_m(0xD361)
map_n = self.read_m(0xD35E)
levels = [self.read_m(a) for a in [0xD18C, 0xD1B8, 0xD1E4, 0xD210, 0xD23C, 0xD268]]
if self.use_screen_explore:
expl = ('frames', self.knn_index.get_current_count())
else:
expl = ('coord_count', len(self.seen_coords))
return {
'step': self.step_count, 'x': x_pos, 'y': y_pos, 'map': map_n,
'last_action': action,
'pcount': self.read_m(0xD163), 'levels': levels, 'ptypes': self.read_party(),
'hp': self.read_hp_fraction(),
'frames': self.knn_index.get_current_count(),
'deaths': self.death_count, 'badge': self.get_badges(),
'event': self.rewards["event"], 'healr': self.cfg["reward_params"]["total_healing_rew"]
#'event': self.rewards["event"], 'healr': self.cfg["reward_params"]["total_healing_rew"]
'event': self.reward_scale*self.update_max_event_rew(),
#'party_xp': self.reward_scale*0.1*sum(poke_xps),
'level': self.reward_scale*self.get_levels_reward(),
'heal': self.reward_scale*self.total_healing_rew,
'op_lvl': self.reward_scale*self.update_max_op_level(),
'dead': self.reward_scale*-0.1*self.died_count,
'badge': self.reward_scale*self.get_badges() * 5,
#'op_poke': self.reward_scale*self.max_opponent_poke * 800,
#'money': self.reward_scale* money * 3,
#'seen_poke': self.reward_scale * seen_poke_count * 400,
'explore': self.reward_scale * self.get_knn_reward()
}

def update_frame_knn_index(self, frame_vec):
Expand All @@ -327,17 +381,53 @@ def update_frame_knn_index(self, frame_vec):
self.knn_index.add_items(
frame_vec, np.array([self.knn_index.get_current_count()]))

def update_seen_coords(self):
x_pos = self.read_m(0xD362)
y_pos = self.read_m(0xD361)
map_n = self.read_m(0xD35E)
coord_string = f"x:{x_pos} y:{y_pos} m:{map_n}"
if self.get_levels_sum() >= 22 and not self.levels_satisfied:
self.levels_satisfied = True
self.base_explore = len(self.seen_coords)
self.seen_coords = {}

self.seen_coords[coord_string] = self.step_count

def compute_rewards(self):
# addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map
# https://github.com/pret/pokered/blob/91dc3c9f9c8fd529bb6e8307b58b96efa0bec67e/constants/event_constants.asm

self.rewards_old = self.rewards.copy()

# adds up all event flags, exclude museum ticket
event_flags_start = 0xD747
event_flags_end = 0xD886
museum_ticket = (0xD754, 0)
base_event_flags = 13
return max(
sum(
[
self.bit_count(self.read_m(i))
for i in range(event_flags_start, event_flags_end)
]
)
- base_event_flags
- int(self.read_bit(museum_ticket[0], museum_ticket[1])),
0,)

# healing reward
curr_health = self.read_hp_fraction()
self.rewards["healing"] = self.cfg["rewards"]["healing_scale"] * max(0, curr_health - self.cfg["state_params"]["health"])
if self.cfg["state_params"]["health"] <= 0: self.death_count += 1
self.cfg["state_params"]["health"] = curr_health
'''
# Not sure where to integrate
prog = self.progress_reward
# these values are only used by memory
return (prog['level'] * 100 / self.reward_scale,
self.read_hp_fraction()*2000,
prog['explore'] * 150 / (self.explore_weight * self.reward_scale))
'''

# event reward
curr_event_rew = max(sum([self.bit_count(self.read_m(i)) for i in range(0xD747, 0xD886)]) - 13, 0)
Expand All @@ -362,6 +452,9 @@ def compute_rewards(self):
self.rewards["badges"] = self.cfg["rewards"]["badge_scale"] * self.get_badges()

# exploration reward
pre_rew = self.explore_weight * 0.005
post_rew = self.explore_weight * 0.01
cur_size = self.knn_index.get_current_count() if self.use_screen_explore else len(self.seen_coords)
curr_size = self.knn_index.get_current_count()
base = (self.cfg["state_params"]["base_explore"] if self.cfg["state_params"]["levels_satisfied"] else curr_size) * self.cfg["rewards"]["knn_pre_scale"]
post = (curr_size if self.cfg["state_params"]["levels_satisfied"] else 0) * self.cfg["rewards"]["knn_post_scale"]
Expand Down Expand Up @@ -430,24 +523,31 @@ def save_and_print_info(self, done, obs_memory):
print(f'\r{prog_string}', end='', flush=True)

if self.step_count % 50 == 0:
plt.imsave(
self.s_path / Path(f'curframe_{self.instance_id}.jpeg'),
self.render(reduce_res=False)
try:
plt.imsave(
self.s_path / Path(f'curframe_{self.instance_id}.jpeg'),
self.render(reduce_res=False))
except Exception as e:
print(f"Error saving iamge: {e}")
)

if self.print_rewards and done:
print('', flush=True)
if self.save_final_state:
fs_path = self.s_path / Path('final_states')
fs_path.mkdir(exist_ok=True)
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'),
obs_memory
)
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'),
self.render(reduce_res=False)
)
try:
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_small.jpeg'),
obs_memory)
except Exception as e:
print(f"Error saving image: {e}")
try:
plt.imsave(
fs_path / Path(f'frame_r{self.total_reward:.4f}_{self.reset_count}_full.jpeg'),
self.render(reduce_res=False))
except Exception as e:
print(f"Error saving iamge: {e}")

if self.save_video and done:
self.full_frame_writer.close()
Expand Down Expand Up @@ -476,9 +576,12 @@ def read_party(self):
def save_screenshot(self, name):
ss_dir = self.s_path / Path('screenshots')
ss_dir.mkdir(exist_ok=True)
plt.imsave(
ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'),
self.render(reduce_res=False))
try:
plt.imsave(
ss_dir / Path(f'frame{self.instance_id}_r{self.total_reward:.4f}_{self.reset_count}_{name}.jpeg'),
self.render(reduce_res=False))
except Exception as e:
print(f"Error saving iamge: {e}")

def read_hp_fraction(self):
hp_sum = sum([self.read_hp(add) for add in [0xD16C, 0xD198, 0xD1C4, 0xD1F0, 0xD21C, 0xD248]])
Expand Down