Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding regression plots to test script, add opset option to export #8

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ dependencies:
- jupyterlab
- seaborn
- rich

- matplotlib
- scipy
6 changes: 2 additions & 4 deletions spanet/dataset/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ def __init__(self, event_info: EventInfo):

cluster_group = self.target_groups[names[0]]
for name in names:
assert (
self.target_groups[name] == cluster_group,
"Invalid Symmetry Group. Invariant targets have different structures."
)
assert self.target_groups[name] == cluster_group, "Invalid Symmetry Group. Invariant targets have different structures."


cluster_groups.append((cluster_name, names, cluster_group))

Expand Down
2 changes: 1 addition & 1 deletion spanet/dataset/jet_reconstruction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
self.num_events = limit_index.shape[0]
self.num_vectors = sum(source.num_vectors() for source in self.sources.values())

print(f"Index Range: {limit_index[0]}...{limit_index[-1]}")
#print(f"Index Range: {limit_index[0]}...{limit_index[-1]}")

# Optionally remove any events where any of the targets are missing.
if not partial_events:
Expand Down
9 changes: 6 additions & 3 deletions spanet/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ def main(
input_log_transform: bool,
output_log_transform: bool,
output_embeddings: bool,
gpu: bool
gpu: bool,
opset_version: int
):
model = load_model(log_directory, cuda=gpu)

# Create wrapped model with flat inputs and outputs
wrapped_model = WrappedModel(model, input_log_transform, output_log_transform, output_embeddings)
wrapped_model.to(model.device)
Expand All @@ -153,7 +154,7 @@ def main(
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=13
opset_version=opset_version
)


Expand All @@ -177,5 +178,7 @@ def main(
parser.add_argument("--output-embeddings", action="store_true",
help="Exported model will also output the embeddings for every part of the event.")

parser.add_argument("--opset-version", type=int, default=13, help="Opset version to use in export.")

arguments = parser.parse_args()
main(**arguments.__dict__)
99 changes: 96 additions & 3 deletions spanet/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from spanet.evaluation import evaluate_on_test_dataset, load_model
from spanet.dataset.types import Evaluation

import matplotlib.pyplot as plt
import math
import os
from scipy.stats import wasserstein_distance

def formatter(value: Any) -> str:
""" A monolithic formatter function to convert possible values to output strings.
Expand Down Expand Up @@ -220,19 +224,99 @@ def evaluate_predictions(predictions: ArrayLike, num_vectors: ArrayLike, targets

return results, jet_limits, evaluator.clusters


def plot_regression_performance(key, predictions, targets, outdir):
    """Produce diagnostic plots comparing regression predictions to targets.

    Writes three PNG files into ``outdir`` (created if missing):

    * ``<key>.png``             -- overlaid prediction/target histograms with a
                                   ratio pad, annotated with the earth mover's
                                   (Wasserstein) distance between the binned densities.
    * ``<key>_delta.png``       -- histogram of (prediction - target).
    * ``<key>_percent_err.png`` -- histogram of (prediction - target) / target,
                                   with entries where target == 0 mapped to 0.

    Parameters
    ----------
    key : str
        Name of the regression quantity; "/" is replaced by "_" for filenames.
    predictions, targets : array-like of float
        Predicted and true values of equal length.
    outdir : str
        Directory the plots are saved into.
    """
    # makedirs is race-free and creates missing parent directories
    # (the previous isdir-then-mkdir check did neither).
    os.makedirs(outdir, exist_ok=True)

    key_filename = key.replace("/", "_")
    nbins = 50

    predictions = np.asarray(predictions, dtype=float)
    targets = np.asarray(targets, dtype=float)

    # --- Comparison of target and prediction, with a ratio pad below ---
    plt.subplots(2, 1, height_ratios=[3, 1])

    plt.subplot(2, 1, 1)
    pred_hist, bins, _ = plt.hist(predictions, nbins, density=True, alpha=0.5,
                                  label="{} Predicted".format(key))
    targ_hist, _, _ = plt.hist(targets, bins=bins, density=True, alpha=0.5,
                               label="{} Target".format(key))
    plt.legend()
    plt.ylabel("A.U.")

    # NOTE(review): this is the Wasserstein distance between the two vectors of
    # bin heights, not between the raw samples -- kept as-is for compatibility.
    EMD = wasserstein_distance(pred_hist, targ_hist)
    plt.text(0.02, 1.02, "EMD = {}".format(EMD), transform=plt.gca().transAxes)

    # Ratio pad: per-bin prediction/target, 0 where the target bin is empty.
    # Vectorized replacement for the original Python-level list comprehension.
    plt.subplot(2, 1, 2)
    ratio = np.divide(pred_hist, targ_hist,
                      out=np.zeros_like(pred_hist), where=targ_hist != 0)
    bincenter = 0.5 * (bins[1:] + bins[:-1])

    plt.plot(bincenter, ratio)
    plt.ylim([0.5, 2.0])
    plt.plot(bincenter, np.full(nbins, 1.0), color="black", linestyle="dashed")
    plt.ylabel("Pred/Target")
    plt.xlabel(key)

    plt.savefig("{}/{}.png".format(outdir, key_filename))
    print("Regression plot for {} saved in {}, EMD = {}".format(key, outdir, EMD))
    plt.close()

    # --- Delta (prediction - target) distribution ---
    delta = predictions - targets
    plt.hist(delta, nbins, density=True, alpha=0.5,
             label="{} Delta (Pred - Target)".format(key))
    plt.legend()
    plt.ylabel("A.U.")
    # Raw string: "\D" is an invalid escape sequence in a normal string literal
    # (SyntaxWarning on recent Python); the rendered text is unchanged.
    plt.xlabel(r"$\Delta$ ({})".format(key))

    plt.savefig("{}/{}_delta.png".format(outdir, key_filename))
    plt.close()

    # --- Percent-error distribution (0 where target == 0) ---
    percent_err = np.divide(delta, targets,
                            out=np.zeros_like(delta), where=targets != 0)

    plt.hist(percent_err, nbins, density=True, alpha=0.5,
             label="{} % Error".format(key))
    plt.legend()
    plt.ylabel("A.U.")
    plt.xlabel("% Error ({})".format(key))

    plt.savefig("{}/{}_percent_err.png".format(outdir, key_filename))
    plt.close()

def main(
log_directory: str,
test_file: Optional[str],
event_file: Optional[str],
batch_size: Optional[int],
lines: int,
gpu: bool,
latex: bool
latex: bool,
outdir: str,
checkpoint: Optional[str],
):
model = load_model(log_directory, test_file, event_file, batch_size, gpu)

model = load_model(log_directory, test_file, event_file, batch_size, gpu, checkpoint)

evaluation = evaluate_on_test_dataset(model)


# make some plots for regressions
regressions = list(evaluation.regressions.values())
keys = list(evaluation.regressions.keys())

regression_targets = [reg.cpu().numpy() for reg in model.testing_dataset.regressions.values()]

if not outdir:
outdir = log_directory

for key, pred, truth in zip(keys, regressions, regression_targets):
plot_regression_performance(key, pred, truth, outdir)



# Flatten predictions
predictions = list(evaluation.assignments.values())

Expand All @@ -247,6 +331,9 @@ def main(
display_table(results, jet_limits, clusters)





if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("log_directory", type=str,
Expand All @@ -272,6 +359,12 @@ def main(
parser.add_argument("-tex", "--latex", action="store_true",
help="Output a latex table.")

parser.add_argument("--outdir", type=str, default=None,
help="Output directory for regression performance plots (default:log_directory)")

parser.add_argument("--checkpoint", "--chk", type=str, default=None,
help="Select which checkpoint to load")

arguments = parser.parse_args()
main(**arguments.__dict__)

Expand Down