Skip to content

Commit

Permalink
Merge pull request #63 from KatherLab/exp_odelia_multisite_revision
Browse files Browse the repository at this point in the history
Exp odelia multisite revision
  • Loading branch information
Ultimate-Storm authored Mar 28, 2024
2 parents 6b87a8e + ceb66e6 commit df0b484
Show file tree
Hide file tree
Showing 23 changed files with 940 additions and 714 deletions.
2 changes: 2 additions & 0 deletions workspace/automate_scripts/launch_sl/run_sn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ sudo $script_dir/../../swarm_learning_scripts/run-sn \
--cert=cert/sn-"$host_index"-cert.pem \
--capath=cert/ca/capath \
--apls-ip="$sentinel" \
-e SWARM_LOG_LEVEL=DEBUG \
-e SL_DEVMODE_KEY=REVWTU9ERS0yMDI0LTAzLTE5 \

echo "SN node started, waiting for the network to be ready"
echo "Use 'cklog --sn' to follow the logs of the SN node"
Expand Down
4 changes: 3 additions & 1 deletion workspace/automate_scripts/launch_sl/run_swci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ sudo "$script_dir/../../swarm_learning_scripts/run-swci" \
--capath="cert/ca/capath" \
-e "http_proxy=" -e "https_proxy=" --apls-ip="$sentinel" \
-e "SWCI_RUN_TASK_MAX_WAIT_TIME=5000" \
-e "SWCI_GENERIC_TASK_MAX_WAIT_TIME=5000"
-e "SWCI_GENERIC_TASK_MAX_WAIT_TIME=5000" \
-e SWARM_LOG_LEVEL=DEBUG \
-e SL_DEVMODE_KEY=REVWTU9ERS0yMDI0LTAzLTE5 \

echo "SWCI container started successfully"
echo "Use 'cklog --swci' to follow the logs of the SWCI node"
Expand Down
4 changes: 3 additions & 1 deletion workspace/automate_scripts/launch_sl/run_swop.sh
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ sudo $script_dir/../../swarm_learning_scripts/run-swop --rm -d\
--capath=cert/ca/capath \
-e http_proxy= -e https_proxy= \
--apls-ip="$sentinel" \
-e SWOP_KEEP_CONTAINERS=True
-e SWOP_KEEP_CONTAINERS=True \
-e SWARM_LOG_LEVEL=DEBUG \
-e SL_DEVMODE_KEY=REVWTU9ERS0yMDI0LTAzLTE5 \

echo "SWOP container started"
echo "Use 'cklog --swop' to follow the logs of the SWOP node"
Expand Down
Empty file.
11 changes: 2 additions & 9 deletions workspace/odelia-breast-mri/model/env_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,8 @@ def load_environment_variables():

def load_prediction_modules(prediction_flag):
    """Load the prediction entry point and pass the flag through.

    A single `predict` implementation now serves both internal and external
    evaluation; the caller forwards `prediction_flag` to it at each call site
    (`predict(..., prediction_flag=prediction_flag)`), so there is no longer
    a per-flag module to select here.

    Args:
        prediction_flag: evaluation-mode flag (e.g. 'internal' or 'ext'),
            returned unchanged for the caller to forward to `predict`.

    Returns:
        Tuple ``(predict, prediction_flag)`` — the callers unpack exactly
        this pair.
    """
    # Dead code removed: a stale copy of the old per-flag implementation
    # (if/elif importing predict_ext / predict_last and returning
    # (predict, predict_last)) preceded these lines and returned
    # unconditionally, making the lines below unreachable and contradicting
    # the (predict, prediction_flag) unpacking at the call sites.
    from predict import predict
    return predict, prediction_flag

def prepare_dataset(task_data_name, data_dir):
"""Prepare the dataset based on task data name."""
Expand Down
99 changes: 67 additions & 32 deletions workspace/odelia-breast-mri/model/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
from datetime import datetime
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import os
from pathlib import Path
from torchmetrics.functional import precision_recall_curve, average_precision, f1_score, precision, recall
import torch
import logging
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from torch.utils.data.dataset import Subset
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from data.datamodules import DataModule
from utils.roc_curve import plot_roc_curve, cm2acc, cm2x
import monai.networks.nets as nets
import torch
from swarmlearning.pyt import SwarmCallback
from pytorch_lightning.callbacks import Callback
Expand All @@ -41,11 +32,63 @@ def on_train_epoch_end(self, trainer, pl_module):
# self.swarmCallback.on_train_end()


class ExternalDatasetEvaluationCallback(Callback):
    """Periodically scores the current model on an external test set.

    Every ``every_n_epochs`` epochs, the metrics produced by the
    module-level ``predict`` function are written to the trainer's
    TensorBoard logger under the "External/" namespace.

    NOTE(review): ``on_epoch_end`` is a legacy Lightning hook name
    (superseded by ``on_train_epoch_end`` in newer releases) — confirm
    against the pinned pytorch_lightning version.
    """

    def __init__(self, model_dir, test_data_dir, model_name, every_n_epochs=5):
        super().__init__()
        self.model_dir = model_dir
        self.test_data_dir = test_data_dir
        self.model_name = model_name
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, trainer, pl_module):
        # Only pay for the external evaluation on every N-th epoch.
        completed_epochs = trainer.current_epoch + 1
        if completed_epochs % self.every_n_epochs != 0:
            return
        # NOTE(review): `predict` is resolved from module scope and is
        # assumed to return a {metric_name: value} mapping — confirm
        # against its definition.
        results = predict(self.model_dir, self.test_data_dir, self.model_name,
                          last_flag=False, prediction_flag='external')
        self.log_metrics(trainer, results)

    def log_metrics(self, trainer, metrics):
        """Write each metric as a scalar tagged ``External/<name>``."""
        board = trainer.logger
        for metric_name, metric_value in metrics.items():
            board.experiment.add_scalar(
                f"External/{metric_name}", metric_value, trainer.current_epoch)
class ExtendedValidationCallback(Callback):
    """Recomputes extra validation metrics (F1, PPV, "NPV", PR curve, AP)
    at the end of each validation epoch and logs them to TensorBoard.

    NOTE(review): this manually re-runs the entire validation dataloader,
    on top of whatever validation Lightning already performed this epoch —
    doubling validation cost.
    """
    def on_validation_epoch_end(self, trainer, pl_module):
        # Assumes the trainer was fitted with a datamodule that exposes
        # val_dataloader() — TODO confirm at the fit() call site.
        val_loader = trainer.datamodule.val_dataloader()
        all_preds, all_targets = [], []
        pl_module.eval() # Ensure model is in eval mode
        with torch.no_grad():
            for batch in val_loader:
                # Assumes each batch unpacks as (inputs, targets) — confirm
                # against the DataModule's collate format.
                inputs, targets = batch
                inputs = inputs.to(pl_module.device)
                preds = pl_module(inputs)
                all_preds.append(preds)
                all_targets.append(targets)

        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)

        # Compute metrics via torchmetrics' functional API.
        # NOTE(review): recent torchmetrics versions require a `task=`
        # argument on these functions — verify against the pinned version
        # (main.py elsewhere passes task="multiclass" explicitly).
        f1 = f1_score(all_preds, all_targets)
        ppv = precision(all_preds, all_targets)
        # NOTE(review): this is NOT NPV — recall is TP/(TP+FN) while
        # NPV is TN/(TN+FN); the "Validation/NPV" scalar logged below is
        # therefore misleading.
        npv = recall(all_preds, all_targets) # Using recall function as a placeholder for NPV calculation
        # precisions/recalls are computed but never used; add_pr_curve below
        # receives the raw predictions/targets instead.
        precisions, recalls, _ = precision_recall_curve(all_preds, all_targets)
        ap = average_precision(all_preds, all_targets)

        # Log to TensorBoard via the underlying SummaryWriter.
        trainer.logger.experiment.add_scalar("Validation/F1", f1, global_step=trainer.global_step)
        trainer.logger.experiment.add_scalar("Validation/PPV", ppv, global_step=trainer.global_step)
        trainer.logger.experiment.add_scalar("Validation/NPV", npv, global_step=trainer.global_step)
        trainer.logger.experiment.add_pr_curve("Validation/PRC", all_targets, all_preds, global_step=trainer.global_step)
        trainer.logger.experiment.add_scalar("Validation/AP", ap, global_step=trainer.global_step)

        # NOTE(review): unconditionally flips back to train mode, even if the
        # module was already in eval mode before this hook ran.
        pl_module.train() # Set model back to train mode

if __name__ == "__main__":
env_vars = load_environment_variables()
print('model_name: ', env_vars['model_name'])

predict, predict_last = load_prediction_modules(env_vars['prediction_flag'])
predict, prediction_flag = load_prediction_modules(env_vars['prediction_flag'])
ds, task_data_name = prepare_dataset(env_vars['task_data_name'], env_vars['data_dir'])
path_run_dir = generate_run_directory(env_vars['scratch_dir'], env_vars['task_data_name'], env_vars['model_name'], env_vars['local_compare_flag'])

Expand Down Expand Up @@ -83,9 +126,9 @@ def on_train_epoch_end(self, trainer, pl_module):
# print the total number of different labels in the training set:
print(f"Total number of different labels in the training set: {len(label_counts)}")

#adsValData = DataLoader(ds_val, batch_size=2, shuffle=False)
adsValData = DataLoader(ds_val, batch_size=2, shuffle=False)
# print adsValData type
#print('adsValData type: ', type(adsValData))
print('adsValData type: ', type(adsValData))

train_size = len(ds_train)
val_size = len(ds_val)
Expand All @@ -110,12 +153,7 @@ def on_train_epoch_end(self, trainer, pl_module):
to_monitor = "val/AUC_ROC"
min_max = "max"
log_every_n_steps = 1
early_stopping = EarlyStopping(
monitor=to_monitor,
min_delta=0.0, # minimum change in the monitored quantity to qualify as an improvement
patience=10, # number of checks with no improvement
mode=min_max
)

checkpointing = ModelCheckpoint(
dirpath=str(path_run_dir), # dirpath
monitor=to_monitor,
Expand All @@ -128,20 +166,14 @@ def on_train_epoch_end(self, trainer, pl_module):
useCuda = torch.cuda.is_available()


lFArgsDict={}
lFArgsDict['reduction']='sum'
mFArgsDict={}
mFArgsDict['task']="multiclass"
mFArgsDict['num_classes']=2

#model = model.to(torch.device('cuda'))
if local_compare_flag:
torch.autograd.set_detect_anomaly(True)
trainer = Trainer(
accelerator='gpu', devices=1,
precision=16,
default_root_dir=str(path_run_dir),
callbacks=[checkpointing, early_stopping],
callbacks=[checkpointing],#early_stopping
enable_checkpointing=True,
check_val_every_n_epoch=1,
min_epochs=50,
Expand All @@ -163,16 +195,17 @@ def on_train_epoch_end(self, trainer, pl_module):
#adsValBatchSize=2,
nodeWeightage=cal_weightage(train_size),
model=model,
mergeMethod = "geomedian",
#lossFunction="BCEWithLogitsLoss",
#lossFunctionArgs=lFArgsDict,
#metricFunction="F1Score",
#metricFunctionArgs=mFArgsDict
#metricFunction="AUROC",
)

torch.autograd.set_detect_anomaly(True)
swarmCallback.logger.setLevel(logging.DEBUG)
swarmCallback.on_train_begin()
trainer = Trainer(
#resume_from_checkpoint=env_vars['scratch_dir']+'/checkpoint.ckpt',
accelerator='gpu', devices=1,
precision=16,
default_root_dir=str(path_run_dir),
Expand All @@ -190,7 +223,7 @@ def on_train_epoch_end(self, trainer, pl_module):
swarmCallback.on_train_end()
model.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path)
model.save_last_checkpoint(trainer.logger.log_dir, checkpointing.last_model_path)

'''
import subprocess
# Get the container ID for the latest user-env container
Expand All @@ -207,6 +240,8 @@ def on_train_epoch_end(self, trainer, pl_module):
line = line.decode("utf-8").rstrip()
print(line)
log_file.write(line + "\n")
predict(path_run_dir, os.path.join(dataDir, task_data_name,'test'), model_name)
predict_last(path_run_dir, os.path.join(dataDir, task_data_name,'test'), model_name)
'''

predict(path_run_dir, os.path.join(dataDir, task_data_name,'test'), model_name, last_flag=False, prediction_flag = prediction_flag)
predict(path_run_dir, os.path.join(dataDir, task_data_name,'test'), model_name, last_flag=True, prediction_flag = prediction_flag)

168 changes: 0 additions & 168 deletions workspace/odelia-breast-mri/model/main_predict.py

This file was deleted.

Loading

0 comments on commit df0b484

Please sign in to comment.