Reveal to Revise: An Explainable AI Life Cycle for Iterative Bias Correction of Deep Models

Description

State-of-the-art machine learning models often learn spurious correlations embedded in the training data. This poses risks when deploying these models for high-stakes decision-making, such as in medical applications like skin cancer detection.

To tackle this problem, we propose Reveal to Revise (R2R), a framework entailing the entire eXplainable Artificial Intelligence (XAI) life cycle, enabling practitioners to iteratively identify, mitigate, and (re-)evaluate spurious model behavior with a minimal amount of human interaction.

Overview of the Reveal to Revise method.

In the first step (1), R2R reveals model weaknesses by finding outliers in attributions or through inspection of latent concepts learned by the model. Secondly (2), the responsible artifacts are detected and spatially localized in the input data, which is then leveraged to (3) revise the model behavior. Concretely, we apply the methods of RRR, CDEP and ClArC for model correction, and (4) (re-)evaluate the model’s performance and remaining sensitivity towards the artifact.
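
As a rough illustration of the revise step, the following sketch shows a ClArC-style projection on latent activations: it estimates an artifact direction and removes an activation's component along it. It is a simplified sketch under stated assumptions, not the implementation used in this repository. In particular, the direction is assumed to be the normalized difference between the mean artifact activation and the mean clean activation, and all function names are hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def mean_vector(rows: Sequence[Sequence[float]]) -> Vector:
    """Column-wise mean of a non-empty list of equally long vectors."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def artifact_direction(artifact_acts: Sequence[Sequence[float]],
                       clean_acts: Sequence[Sequence[float]]) -> Vector:
    """Unit vector pointing from the clean mean to the artifact mean.

    Assumption: a mean-difference direction stands in for a learned
    concept activation vector.
    """
    mu_artifact = mean_vector(artifact_acts)
    mu_clean = mean_vector(clean_acts)
    diff = [a - c for a, c in zip(mu_artifact, mu_clean)]
    norm = sum(d * d for d in diff) ** 0.5
    if norm == 0.0:
        raise ValueError("artifact and clean activations have identical means")
    return [d / norm for d in diff]


def suppress_artifact(act: Sequence[float],
                      direction: Sequence[float],
                      clean_mean: Sequence[float]) -> Vector:
    """Shift `act` along `direction` so its projection matches the clean mean."""
    offset = sum((a - c) * d for a, c, d in zip(act, clean_mean, direction))
    return [a - offset * d for a, d in zip(act, direction)]


# Toy 2-D activations: the artifact only shifts the first coordinate.
clean = [[0.0, 1.0], [0.0, 3.0]]
artifact = [[2.0, 1.0], [4.0, 3.0]]
direction = artifact_direction(artifact, clean)
corrected = suppress_artifact([5.0, 7.0], direction, mean_vector(clean))
print(direction, corrected)
```

In this toy example the second coordinate is left untouched, because the estimated artifact direction lies entirely along the first coordinate.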

Using two medical benchmark datasets for melanoma detection and bone age estimation, we apply our R2R framework to VGG, ResNet and EfficientNet architectures and thereby reveal and correct real dataset-intrinsic artifacts, as well as synthetic variants in a controlled setting.

If you use our work in your research, please cite:

@inproceedings{pahde2023reveal,
  author       = {Frederik Pahde and
                  Maximilian Dreyer and
                  Wojciech Samek and
                  Sebastian Lapuschkin},
  editor       = {Hayit Greenspan and
                  Anant Madabhushi and
                  Parvin Mousavi and
                  Septimiu Salcudean and
                  James Duncan and
                  Tanveer F. Syeda{-}Mahmood and
                  Russell H. Taylor},
  title        = {Reveal to Revise: An Explainable {AI} Life Cycle for Iterative Bias
                  Correction of Deep Models},
  booktitle    = {Medical Image Computing and Computer Assisted Intervention - {MICCAI}
                  2023 - 26th International Conference, Proceedings, Part {II}},
  series       = {Lecture Notes in Computer Science},
  volume       = {14221},
  pages        = {596--606},
  publisher    = {Springer},
  year         = {2023},
  url          = {https://doi.org/10.1007/978-3-031-43895-0\_56},
  doi          = {10.1007/978-3-031-43895-0\_56},
}

Installation

We use Python 3.8.10. To install the required packages, run:

pip install -r requirements.txt

Next, download the datasets. To download and extract the ISIC 2019 dataset, run:

mkdir datasets
cd datasets
wget https://isic-challenge-data.s3.amazonaws.com/2019/ISIC_2019_Training_Input.zip
wget https://isic-challenge-data.s3.amazonaws.com/2019/ISIC_2019_Training_GroundTruth.csv
unzip ISIC_2019_Training_Input.zip

Then, download and extract the Bone Age Estimation dataset:

wget https://s3.amazonaws.com/east1.public.rsna.org/AI/2017/Bone+Age+Training+Set.zip
wget https://s3.amazonaws.com/east1.public.rsna.org/AI/2017/Bone+Age+Training+Set+Annotations.zip
unzip Bone+Age+Training+Set.zip
unzip Bone+Age+Training+Set+Annotations.zip
cd ..
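
Before moving on, it can help to confirm that all four downloads succeeded. The following is a minimal sketch that checks for the file names used in the wget commands above. The helper name missing_downloads is hypothetical, and the check assumes the files were saved under their original names in the datasets directory.

```python
from pathlib import Path
from typing import List

# File names taken from the download commands above.
EXPECTED_DOWNLOADS = [
    "ISIC_2019_Training_Input.zip",
    "ISIC_2019_Training_GroundTruth.csv",
    "Bone+Age+Training+Set.zip",
    "Bone+Age+Training+Set+Annotations.zip",
]


def missing_downloads(dataset_dir: str = "datasets") -> List[str]:
    """Return the expected file names that are not present in `dataset_dir`."""
    root = Path(dataset_dir)
    return [name for name in EXPECTED_DOWNLOADS if not (root / name).is_file()]


missing = missing_downloads()
if missing:
    print("Missing downloads:", ", ".join(missing))
else:
    print("All dataset downloads are present.")
```

Run it from the repository root, i.e. the directory that contains the datasets folder.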

Model Training (optional)

Having installed the required packages and datasets, we can begin training the models. For convenience, we use config files to specify the model, dataset and training parameters.

Config Files

We can generate config files using the config file generator. To do so, specify the config file generator parameters in config_files/config_generator_training.py and run:

cd config_files
python config_generator_training.py
cd ..
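
To illustrate the general idea of a config file generator, the sketch below expands a parameter grid into one YAML file per combination. It does not reproduce config_generator_training.py: the parameter names (model_name, lr), the file naming scheme and the output directory are all hypothetical and serve only as an example.

```python
import itertools
from pathlib import Path
from typing import Dict, List


def build_configs(grid: Dict[str, list]) -> List[dict]:
    """Expand a parameter grid into one config dict per combination."""
    keys = list(grid)
    combos = itertools.product(*(grid[k] for k in keys))
    return [dict(zip(keys, values)) for values in combos]


def config_name(config: dict) -> str:
    """Derive a file name from the config values (hypothetical scheme)."""
    return "_".join(f"{k}-{v}" for k, v in config.items()) + ".yaml"


def write_configs(grid: Dict[str, list], out_dir: str) -> List[Path]:
    """Write each config as flat `key: value` lines, which is valid YAML
    for simple scalar values."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    for config in build_configs(grid):
        path = out / config_name(config)
        path.write_text("".join(f"{k}: {v}\n" for k, v in config.items()))
        paths.append(path)
    return paths


# Hypothetical grid: two architectures times two learning rates.
example_grid = {"model_name": ["vgg16", "resnet18"], "lr": [0.001, 0.005]}
print([config_name(c) for c in build_configs(example_grid)])
```

Generating one file per combination keeps each training run reproducible from a single file.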

Note: We suggest using wandb to track the training progress and results.

Training

Using the previously generated config files, we can train the models. To do so, run:

python -m model_training.start_training --config_file "config_files/training/your_config_file.yaml"
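
If several config files were generated, the training command above can be assembled for each of them. The following sketch only builds the command lines and leaves execution to the caller. The helper name training_commands is hypothetical, and the sketch assumes that all training configs are .yaml files directly inside config_files/training.

```python
import sys
from pathlib import Path
from typing import List


def training_commands(config_dir: str = "config_files/training") -> List[List[str]]:
    """Build one training command per .yaml file found in `config_dir`."""
    configs = sorted(Path(config_dir).glob("*.yaml"))
    return [
        [sys.executable, "-m", "model_training.start_training",
         "--config_file", str(path)]
        for path in configs
    ]


# Print the commands instead of running them; to launch the runs one after
# another, pass each command to subprocess.run(cmd, check=True).
for cmd in training_commands():
    print(" ".join(cmd))
```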

NOTE: We provide model checkpoints! All checkpoints can be found in the download scripts in the scripts folder. An example for the ISIC models can be found here:

echo "Download VGG"
wget https://datacloud.hhi.fraunhofer.de/s/Tq2fGgtEZ2mMFQ4/download/checkpoint_vgg16_isic_last.pth

echo "Download ResNet18"
wget https://datacloud.hhi.fraunhofer.de/s/pijw3KFNdcZNi2e/download/checkpoint_resnet18_isic_last.pth

echo "Download EfficientNet B0"
wget https://datacloud.hhi.fraunhofer.de/s/3WoL9Ge38jJ93xN/download/checkpoint_efficientnet_b0_isic_last.pth

Revealing Spurious Behavior

To reveal spurious behavior, we propose the following steps:

Create new config files for the model correction experiments, which are also used for revealing model bias. To do so, specify the config file generator parameters in config_files/config_generator_correcting_isic.py and run:

cd config_files
python config_generator_correcting_isic.py
cd ..

In the following, you can choose between using CRP concept visualizations to find bias concepts, or SpRAy to find input samples with outlier explanations.

CRP

Interactive Tool based on the Dash package

(1) Run CRP preprocessing:

python -m experiments.reveal_artifacts.crp_run --config_file "config_files/correction/your_config_file.yaml"

QUICKSTART: To get started with CRP right away, we provide the CRP analysis file here for the VGG-16 model trained on the ISIC dataset, which needs to be extracted in the crp_files directory.

(2) Plot the CRP concepts of a chosen layer. Here, concept_layer is a placeholder for the name of the layer in which the CRP concepts are investigated, e.g. features.28 for VGG-16:

python -m experiments.reveal_artifacts.crp_plot_concept_embedding_plotly  --config_file "config_files/correction/your_config_file.yaml" \
                                                                          --layer_name "concept_layer"

This will start an interactive plot based on Dash, allowing you to investigate the learned concepts and find spurious ones. Having found spurious concepts, you are able to export the corresponding reference samples in the tool.

QUICKSTART: Simply run python -m experiments.reveal_artifacts.crp_plot_concept_embedding_plotly to investigate the concepts in layer features.28 of the VGG-16 model trained on the ISIC dataset.

SpRAy

Interactive Tool based on the Dash package

(1) Run SpRAy preprocessing, which consists of two steps: first collect relevances and activations, then compute the SpRAy embedding:

config_path="config_files/correction/your_config_file.yaml"
analyze_corrected_model=False

# Replace {0,1,...} with the class ids of your dataset.
for class_id in {0,1,...}; do
  python3 -m experiments.preprocessing.global_collect_relevances_and_activations  --config_file $config_path \
                                                                                  --class_id $class_id \
                                                                                  --use_corrected_ckpt $analyze_corrected_model
done

python -m experiments.reveal_artifacts.run_spray  --config_file $config_path \
                                                  --layer_name "concept_layer" \
                                                  --analysis-file "results/spray/analysis.hdf5" \
                                                  --corrected_model $analyze_corrected_model
QUICKSTART: To get started with SpRAy right away, we provide the analysis file here for the VGG-16 model trained on the ISIC dataset, which needs to be extracted in the results/spray directory.

(2) Show SpRAy embedding:

python -m experiments.reveal_artifacts.plot_spray_embedding_plotly --config_file "config_files/correction/your_config_file.yaml"

QUICKSTART: Simply run python -m experiments.reveal_artifacts.plot_spray_embedding_plotly to investigate the VGG-16 model trained on the ISIC dataset.
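
To give an intuition for what "outlier explanations" means, the sketch below ranks samples by how far their flattened attribution heatmap lies from all others. This is a deliberately simplified distance-based outlier score, not the embedding that run_spray computes, and the function names are hypothetical.

```python
from typing import List, Sequence


def l2_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two equally long vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5


def outlier_scores(heatmaps: Sequence[Sequence[float]]) -> List[float]:
    """Mean distance of each flattened heatmap to all other heatmaps."""
    n = len(heatmaps)
    if n < 2:
        raise ValueError("need at least two heatmaps")
    return [
        sum(l2_distance(h, other)
            for j, other in enumerate(heatmaps) if j != i) / (n - 1)
        for i, h in enumerate(heatmaps)
    ]


def top_outliers(heatmaps: Sequence[Sequence[float]], k: int = 1) -> List[int]:
    """Indices of the k heatmaps with the largest outlier score."""
    scores = outlier_scores(heatmaps)
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


# Three heatmaps focus on the first pixel; the last one focuses elsewhere.
toy_heatmaps = [
    [1.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]
print(top_outliers(toy_heatmaps, k=1))
```

Samples flagged this way are candidates for manual inspection, since an unusual explanation may indicate that the model relies on an artifact.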

Please note that we already provide the artifact sample ids for the ISIC and Bone Age datasets.

Artifact Localization

Having found the artifact samples, we can localize them in the input data by running:

python -m experiments.reveal_artifacts.localize_artifacts --config_file "config_files/correction/your_config_file.yaml"

This saves the artifact heatmaps in the data/localized_artifacts folder. Please note that we already provide the artifact heatmaps for the ISIC and Bone Age datasets.
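
A localization heatmap can be turned into a binary artifact mask by normalizing and thresholding it. The following is a minimal sketch of that idea. The min-max normalization, the threshold value of 0.5 and the function name are assumptions made for illustration and are not taken from localize_artifacts.

```python
from typing import List


def heatmap_to_mask(heatmap: List[List[float]],
                    threshold: float = 0.5) -> List[List[int]]:
    """Min-max normalize a 2-D heatmap and mark pixels at or above `threshold`."""
    flat = [v for row in heatmap for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        # A constant heatmap carries no localization information.
        return [[0] * len(row) for row in heatmap]
    return [
        [1 if (v - lo) / (hi - lo) >= threshold else 0 for v in row]
        for row in heatmap
    ]


print(heatmap_to_mask([[0.0, 0.2], [0.8, 1.0]]))
```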

Model Correction

Now, we can correct the model bias by running:

python -m model_training.start_model_correction --config_file "config_files/correcting_isic/your_config_file.yaml"

Model Evaluation

Finally, we can evaluate the corrected model by running:

python -m experiments.evaluation.artifact_relevance --config_file "config_files/correction/your_config_file.yaml"

for artifact in {band_aid,ruler,skin_marker}; do
  python -m experiments.evaluation.accuracy_attacked_auto --config_file "config_files/correction/your_config_file.yaml" --artifact $artifact
done

python -m experiments.evaluation.evaluate_by_subset --config_file "config_files/correction/your_config_file.yaml"

Here, the loop iterates over the ISIC data-intrinsic artifacts (band_aid, ruler, skin_marker) as an example.
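
The following sketch illustrates two quantities in the spirit of this evaluation: the share of attribution relevance that falls on the artifact region, and the accuracy on clean inputs compared with inputs that have the artifact inserted. It is one plausible way to quantify remaining artifact sensitivity under stated assumptions; the exact metrics computed by the evaluation scripts may differ, and all names and toy values are hypothetical.

```python
from typing import List, Sequence


def artifact_relevance(relevance: List[List[float]],
                       mask: List[List[int]]) -> float:
    """Share of the total absolute relevance that lies inside the binary mask."""
    total = sum(abs(r) for row in relevance for r in row)
    if total == 0:
        return 0.0
    inside = sum(
        abs(r) * m
        for rel_row, mask_row in zip(relevance, mask)
        for r, m in zip(rel_row, mask_row)
    )
    return inside / total


def accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of predictions that match the labels."""
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


# Toy example: half of the relevance lies on the masked artifact pixel.
relevance_map = [[1.0, -1.0], [2.0, 0.0]]
artifact_mask = [[0, 0], [1, 0]]
print(artifact_relevance(relevance_map, artifact_mask))

# Toy example: accuracy drops once the artifact is inserted into the inputs.
labels = [1, 0, 0, 1]
clean_predictions = [1, 0, 0, 1]
attacked_predictions = [1, 1, 1, 1]
print(accuracy(clean_predictions, labels), accuracy(attacked_predictions, labels))
```

A successfully corrected model would be expected to show both a lower artifact relevance and a smaller gap between clean and attacked accuracy.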