Skip to content

Commit

Permalink
Fix plot script
Browse files Browse the repository at this point in the history
  • Loading branch information
MadsSR committed May 23, 2024
1 parent 4fd778c commit dde4b55
Showing 1 changed file with 32 additions and 17 deletions.
49 changes: 32 additions & 17 deletions scripts/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,14 @@ def get_colors(log_dirs: list[str]):
policies = []
for log_dir in log_dirs:
log_dir = log_dir if log_dir.endswith("/") else log_dir + "/"
prefix = log_dir.split("/")[-2].split("-")[0]
policy = log_dir.split("/")[-2].split("_")[-1]
name = log_dir.split("/")[-2]

prefix = name.split("_", maxsplit=1)[0]
if "__" in name:
policy = int(name.split("__")[-1])
else:
policy = name.split("_", maxsplit=1)[-1].replace("-", " ")

if prefix not in prefixes:
prefixes.append(prefix)
if policy not in policies:
Expand All @@ -57,8 +63,7 @@ def get_colors(log_dirs: list[str]):

def smooth(scalars: list[float] | np.ndarray, weight: float) -> list[float]:
"""
EMA implementation according to
https://github.com/tensorflow/tensorboard/blob/34877f15153e1a2087316b9952c931807a122aa7/tensorboard/components/vz_line_chart2/line-chart.ts#L699
EMA implementation according to: https://github.com/tensorflow/tensorboard/blob/34877f15153e1a2087316b9952c931807a122aa7/tensorboard/components/vz_line_chart2/line-chart.ts#L699
"""
last = 0
smoothed = []
Expand Down Expand Up @@ -94,7 +99,7 @@ def plot_tensorboard(log_dirs: list[str], tags: list[str], ema_weight=0.5, show=

line_width = 0.5

for j, log_dir in enumerate(log_dirs):
for log_dir in log_dirs:
# Load TensorBoard logs
event_acc = EventAccumulator(log_dir)
event_acc.Reload()
Expand All @@ -113,11 +118,18 @@ def plot_tensorboard(log_dirs: list[str], tags: list[str], ema_weight=0.5, show=
# Plot specified tags
for i, tag in enumerate(tags):
# get color by n-balls and policy
prefix = log_dir.split("/")[-1].split("-")[0]
policy = log_dir.split("/")[-1].split("_")[-1]

log_dir = log_dir if log_dir.endswith("/") else log_dir + "/"
name = log_dir.split("/")[-2]

prefix = name.split("_", maxsplit=1)[0]
if "__" in name:
policy = int(name.split("__")[-1])
else:
policy = name.split("_", maxsplit=1)[-1].replace("-", " ")
# if tag == "rollout/ep_rew_mean" and "reg" in policy:
# color = colors[prefix]["random-policy"]
color = colors[prefix][policy]
if tag == "rollout/ep_rew_mean" and "reg" in policy:
color = colors[prefix]["random-policy"]

if tag not in scalar_tags:
print(f"Tag '{tag}' not found in '{log_dir}'.")
Expand All @@ -130,13 +142,12 @@ def plot_tensorboard(log_dirs: list[str], tags: list[str], ema_weight=0.5, show=
def action_space(n: int):
return r"$\mathcal{A}_" + f"{n}" + r"$"

label = graph_name.replace("_", ", ").replace("-", " ")
if "cart" in graph_name.lower():
label = label.split(", ")[0] + f", {action_space(2)}"
elif "reg" in graph_name.lower():
label = label.split(", ")[0] + f", {action_space(1)}"
elif "random" in graph_name.lower():
label = label.split(", ")[0] + ", random policy"
if "__" in graph_name:
suffix = action_space(int(graph_name.split("__")[-1]))
else:
suffix = graph_name.split("_", maxsplit=1)[-1].replace("-", " ")

label = graph_name.split("_", maxsplit=1)[0] + ", " + suffix

if ema_weight > 0:
# Plot original data with lower opacity
Expand Down Expand Up @@ -179,7 +190,11 @@ def action_space(n: int):
plot_name = "plot-" + "-".join(tags).replace("/", "_") + ".pdf"
plot_path = os.path.join(plot_dir, plot_name)
os.makedirs(plot_dir, exist_ok=True)
plt.savefig(plot_path)
# plt.savefig(plot_path)

# save tight layout
plt.savefig(plot_path, bbox_inches="tight", pad_inches=0)

print(f"Plot saved as '{plot_path}'")
if show:
plt.show()
Expand Down

0 comments on commit dde4b55

Please sign in to comment.