diff --git a/scripts/plot.py b/scripts/plot.py index 3312ebb..4144e5e 100644 --- a/scripts/plot.py +++ b/scripts/plot.py @@ -33,8 +33,14 @@ def get_colors(log_dirs: list[str]): policies = [] for log_dir in log_dirs: log_dir = log_dir if log_dir.endswith("/") else log_dir + "/" - prefix = log_dir.split("/")[-2].split("-")[0] - policy = log_dir.split("/")[-2].split("_")[-1] + name = log_dir.split("/")[-2] + + prefix = name.split("_", maxsplit=1)[0] + if "__" in name: + policy = int(name.split("__")[-1]) + else: + policy = name.split("_", maxsplit=1)[-1].replace("-", " ") + if prefix not in prefixes: prefixes.append(prefix) if policy not in policies: @@ -57,8 +63,7 @@ def get_colors(log_dirs: list[str]): def smooth(scalars: list[float] | np.ndarray, weight: float) -> list[float]: """ - EMA implementation according to - https://github.com/tensorflow/tensorboard/blob/34877f15153e1a2087316b9952c931807a122aa7/tensorboard/components/vz_line_chart2/line-chart.ts#L699 + EMA implementation according to: https://github.com/tensorflow/tensorboard/blob/34877f15153e1a2087316b9952c931807a122aa7/tensorboard/components/vz_line_chart2/line-chart.ts#L699 """ last = 0 smoothed = [] @@ -94,7 +99,7 @@ def plot_tensorboard(log_dirs: list[str], tags: list[str], ema_weight=0.5, show= line_width = 0.5 - for j, log_dir in enumerate(log_dirs): + for log_dir in log_dirs: # Load TensorBoard logs event_acc = EventAccumulator(log_dir) event_acc.Reload() @@ -113,11 +118,18 @@ def plot_tensorboard(log_dirs: list[str], tags: list[str], ema_weight=0.5, show= # Plot specified tags for i, tag in enumerate(tags): # get color by n-balls and policy - prefix = log_dir.split("/")[-1].split("-")[0] - policy = log_dir.split("/")[-1].split("_")[-1] + + log_dir = log_dir if log_dir.endswith("/") else log_dir + "/" + name = log_dir.split("/")[-2] + + prefix = name.split("_", maxsplit=1)[0] + if "__" in name: + policy = int(name.split("__")[-1]) + else: + policy = name.split("_", maxsplit=1)[-1].replace("-", " ") + # if tag == "rollout/ep_rew_mean" and "reg" in policy: + # color = colors[prefix]["random-policy"] color = colors[prefix][policy] - if tag == "rollout/ep_rew_mean" and "reg" in policy: - color = colors[prefix]["random-policy"] if tag not in scalar_tags: print(f"Tag '{tag}' not found in '{log_dir}'.") @@ -130,13 +142,12 @@ def plot_tensorboard(log_dirs: list[str], tags: list[str], ema_weight=0.5, show= def action_space(n: int): return r"$\mathcal{A}_" + f"{n}" + r"$" - label = graph_name.replace("_", ", ").replace("-", " ") - if "cart" in graph_name.lower(): - label = label.split(", ")[0] + f", {action_space(2)}" - elif "reg" in graph_name.lower(): - label = label.split(", ")[0] + f", {action_space(1)}" - elif "random" in graph_name.lower(): - label = label.split(", ")[0] + ", random policy" + if "__" in graph_name: + suffix = action_space(int(graph_name.split("__")[-1])) + else: + suffix = graph_name.split("_", maxsplit=1)[-1].replace("-", " ") + + label = graph_name.split("_", maxsplit=1)[0] + ", " + suffix if ema_weight > 0: # Plot original data with lower opacity @@ -179,7 +190,11 @@ def action_space(n: int): plot_name = "plot-" + "-".join(tags).replace("/", "_") + ".pdf" plot_path = os.path.join(plot_dir, plot_name) os.makedirs(plot_dir, exist_ok=True) - plt.savefig(plot_path) + # plt.savefig(plot_path) + + # save tight layout + plt.savefig(plot_path, bbox_inches="tight", pad_inches=0) + print(f"Plot saved as '{plot_path}'") if show: plt.show()