Commit b32380f (parent: 36defab)
Showing 38 changed files with 477 additions and 0 deletions.
(Binary data files in this commit are not shown.)
@@ -0,0 +1,66 @@

#!/usr/bin/env python

"""An abstract class that specifies the Agent API for RL-Glue-py."""

from __future__ import print_function
from abc import ABCMeta, abstractmethod


class BaseAgent:
    """Implements the agent for an RL-Glue environment.

    Note:
        agent_init, agent_start, agent_step, agent_end, agent_cleanup, and
        agent_message are required methods.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        pass

    @abstractmethod
    def agent_init(self, agent_info={}):
        """Setup for the agent called when the experiment first starts."""

    @abstractmethod
    def agent_start(self, observation):
        """The first method called when the experiment starts, called after
        the environment starts.

        Args:
            observation (Numpy array): the state observation from the
                environment's env_start function.

        Returns:
            The first action the agent takes.
        """

    @abstractmethod
    def agent_step(self, reward, observation):
        """A step taken by the agent.

        Args:
            reward (float): the reward received for taking the last action.
            observation (Numpy array): the state observation from the
                environment's step, i.e. where the agent ended up after the
                last action.

        Returns:
            The action the agent is taking.
        """

    @abstractmethod
    def agent_end(self, reward):
        """Run when the agent terminates.

        Args:
            reward (float): the reward the agent received for entering the
                terminal state.
        """

    @abstractmethod
    def agent_cleanup(self):
        """Cleanup done after the agent ends."""

    @abstractmethod
    def agent_message(self, message):
        """A function used to pass information from the agent to the experiment.

        Args:
            message: The message passed to the agent.

        Returns:
            The response (or answer) to the message.
        """
@@ -0,0 +1,70 @@

#!/usr/bin/env python

"""Abstract environment base class for RL-Glue-py."""

from __future__ import print_function

from abc import ABCMeta, abstractmethod


class BaseEnvironment:
    """Implements the environment for an RL-Glue experiment.

    Note:
        env_init, env_start, env_step, env_cleanup, and env_message are
        required methods.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        reward = None
        state = None
        termination = None
        self.reward_state_term = (reward, state, termination)

    @abstractmethod
    def env_init(self, env_info={}):
        """Setup for the environment called when the experiment first starts.

        Note:
            Initialize a tuple with the reward, the first state, and a boolean
            indicating whether the state is terminal.
        """

    @abstractmethod
    def env_start(self):
        """The first method called when the experiment starts, called before
        the agent starts.

        Returns:
            The first state from the environment.
        """

    @abstractmethod
    def env_step(self, action):
        """A step taken by the environment.

        Args:
            action: The action taken by the agent.

        Returns:
            (float, state, Boolean): a tuple of the reward, the state, and a
            boolean indicating whether the state is terminal.
        """

    @abstractmethod
    def env_cleanup(self):
        """Cleanup done after the environment ends."""

    @abstractmethod
    def env_message(self, message):
        """A message asking the environment for information.

        Args:
            message: the message passed to the environment.

        Returns:
            the response (or answer) to the message.
        """
@@ -0,0 +1,71 @@

import numpy as np
import matplotlib.pyplot as plt


# Plot the learned state values and the learning curve (RMSVE) for each
# state-aggregation / step-size combination.
def plot_result(agent_parameters, directory):

    true_V = np.load('/content/drive/MyDrive/Colab Notebooks/RL_assignments/semi_TD/data/true_V.npy')

    for num_g in agent_parameters["num_groups"]:
        plt1_agent_sweeps = []
        plt2_agent_sweeps = []

        # two plots: learned state-value and learning curve (RMSVE)
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

        for step_size in agent_parameters["step_size"]:

            # plot 1: learned state values; strip '.' from step_size so it
            # can be used in a filename
            filename = 'V_TD_agent_agg_states_{}_step_size_{}'.format(num_g, step_size).replace('.', '')
            current_agent_V = np.load('/content/drive/MyDrive/Colab Notebooks/RL_assignments/semi_TD/{}/{}.npy'.format(directory, filename))

            plt1_x_legend = range(1, len(current_agent_V) + 1)
            graph_current_agent_V, = ax[0].plot(plt1_x_legend, current_agent_V, label="approximate values: state aggregation: {}, step-size: {}".format(num_g, step_size))
            plt1_agent_sweeps.append(graph_current_agent_V)

            # plot 2: learning curve (RMSVE)
            filename = 'RMSVE_TD_agent_agg_states_{}_step_size_{}'.format(num_g, step_size).replace('.', '')
            current_agent_RMSVE = np.load('/content/drive/MyDrive/Colab Notebooks/RL_assignments/semi_TD/{}/{}.npy'.format(directory, filename))

            plt2_x_legend = range(1, len(current_agent_RMSVE) + 1)
            graph_current_agent_RMSVE, = ax[1].plot(plt2_x_legend, current_agent_RMSVE, label="approximate values: state aggregation: {}, step-size: {}".format(num_g, step_size))
            plt2_agent_sweeps.append(graph_current_agent_RMSVE)

        # plot 1: add the true value function
        plt1_x_legend = range(1, len(true_V) + 1)
        graph_true_V, = ax[0].plot(plt1_x_legend, true_V, label=r"$v_\pi$")

        ax[0].legend(handles=[*plt1_agent_sweeps, graph_true_V])

        ax[0].set_title("Learned State Value after 2000 episodes")
        ax[0].set_xlabel('State')
        ax[0].set_ylabel('Value\n scale', rotation=0, labelpad=15)

        plt1_xticks = [1, 100, 200, 300, 400, 500]
        plt1_yticks = [-1.0, 0.0, 1.0]
        ax[0].set_xticks(plt1_xticks)
        ax[0].set_xticklabels(plt1_xticks)
        ax[0].set_yticks(plt1_yticks)
        ax[0].set_yticklabels(plt1_yticks)

        # plot 2: legend, labels, and ticks for the learning curve
        ax[1].legend(handles=plt2_agent_sweeps)

        ax[1].set_title("Learning Curve")
        ax[1].set_xlabel('Episodes')
        ax[1].set_ylabel('RMSVE\n averaged over 50 runs', rotation=0, labelpad=40)

        # tick positions 0..200 are relabeled 0..2000, i.e. one recorded
        # point per 10 episodes
        plt2_xticks = range(0, 210, 20)
        plt2_xticklabels = range(0, 2100, 200)
        plt2_yticks = [0, 0.1, 0.2, 0.3, 0.4, 0.5]
        ax[1].set_xticks(plt2_xticks)
        ax[1].set_xticklabels(plt2_xticklabels)
        ax[1].set_yticks(plt2_yticks)
        ax[1].set_yticklabels(plt2_yticks)

        plt.tight_layout()
        plt.suptitle("{}-State Aggregation".format(num_g), fontsize=16, fontweight='bold', y=1.03)
        plt.show()
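A hypothetical call, assuming the V_TD_agent_agg_states_* and RMSVE_TD_agent_agg_states_* .npy files referenced by plot_result already exist under the given results directory; the parameter values below are illustrative, not taken from this commit.

agent_parameters = {
    "num_groups": [10],          # state-aggregation resolutions to plot
    "step_size": [0.01, 0.05],   # step sizes that were swept over
}
plot_result(agent_parameters, "results")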