
Official PyTorch implementation for "A Theory for Conditional Generative Modeling on Multiple Data Sources"

ML-GSAI/Multi-Source-GM

A Theory for Conditional Generative Modeling on Multiple Data Sources

This is the official implementation for A Theory for Conditional Generative Modeling on Multiple Data Sources.

Dependencies

  • 64-bit Python 3.9 and PyTorch 2.1 (or later). See https://pytorch.org for PyTorch installation instructions. We recommend using Anaconda3 to create your environment:
    conda create -n [your_env_name] python=3.9
  • Install the requirements:
    pip install -r requirements.txt
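To confirm that an environment meets these requirements, a small version check can help. This is a sketch of our own, not part of the repository: the helper names parse_version and meets_minimum are hypothetical, and the torch import is optional so the script also runs before PyTorch is installed.

```python
import platform
import struct
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Return the leading numeric release part, e.g. '2.1.0+cu121' -> (2, 1, 0)."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(version: str, minimum: Tuple[int, ...]) -> bool:
    """Compare release tuples lexicographically, so (2, 1, 0) >= (2, 1) holds."""
    return parse_version(version) >= minimum


if __name__ == "__main__":
    # A 64-bit interpreter uses 8-byte pointers.
    print("64-bit Python:", struct.calcsize("P") == 8)
    print("Python >= 3.9:", meets_minimum(platform.python_version(), (3, 9)))
    try:
        import torch  # optional: only checked if PyTorch is already installed
        print("PyTorch >= 2.1:", meets_minimum(torch.__version__, (2, 1)))
    except ImportError:
        print("PyTorch is not installed yet")
```

Only the numeric release prefix is compared, so local build tags such as +cu121 or pre-release suffixes such as rc1 do not affect the result.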

Simulations on conditional Gaussian estimation

First, change the working directory with cd simulations.

Run experiments

Run the following commands to reproduce the simulation results in Figure 1 of our paper:

python run_gaussian.py --tag=K  # Run experiment for K
python run_gaussian.py --tag=n  # Run experiment for n
python run_gaussian.py --tag=sim  # Run experiment for beta_sim

Results are saved under results/gaussian/.

Visualize results

Use the following commands to visualize the results:

python plot_gaussian_K.py
python plot_gaussian_n.py
python plot_gaussian_sim.py
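The three experiments and their plots can also be chained from a single Python script. This is a minimal sketch, assuming it is run from the simulations directory and that each plotting script reads what run_gaussian.py wrote to results/gaussian/; the TAG_TO_PLOT mapping and the gaussian_pipeline helper are hypothetical names, not part of the repository. It only prints the commands unless --execute is passed.

```python
import subprocess
import sys
from typing import Dict, List

# Pairs each --tag value of run_gaussian.py with the script that plots it.
TAG_TO_PLOT: Dict[str, str] = {
    "K": "plot_gaussian_K.py",
    "n": "plot_gaussian_n.py",
    "sim": "plot_gaussian_sim.py",
}


def gaussian_pipeline(tags: List[str]) -> List[List[str]]:
    """Build the run-then-plot command sequence for the requested tags."""
    commands = []
    for tag in tags:
        commands.append([sys.executable, "run_gaussian.py", f"--tag={tag}"])
        commands.append([sys.executable, TAG_TO_PLOT[tag]])
    return commands


if __name__ == "__main__":
    execute = "--execute" in sys.argv[1:]  # dry run unless --execute is given
    for cmd in gaussian_pipeline(list(TAG_TO_PLOT)):
        print(("Running: " if execute else "Would run: ") + " ".join(cmd))
        if execute:
            subprocess.run(cmd, check=True)  # stop at the first failing step
```

Running each plot immediately after its experiment means a failed run is caught before any stale results are plotted.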

Real-world experiments on conditional diffusion models

First, change the working directory with cd real_world_experiments.

Preparing datasets

Datasets are stored as uncompressed ZIP archives containing uncompressed PNG or NPY files, along with a metadata file dataset.json for labels. When using latent diffusion, it is necessary to create two different versions of a given dataset: the original RGB version, used for evaluation, and a VAE-encoded latent version, used for training.
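A quick way to sanity-check an archive in this format is to open it with Python's zipfile module. The sketch below is a hypothetical helper, not part of the repository. It assumes the EDM2-style metadata layout {"labels": [[filename, class_index], ...]} in dataset.json (with "labels" possibly null for unlabeled data), and treats "uncompressed" as every member being stored with ZIP_STORED.

```python
import json
import zipfile
from typing import IO, Dict, Union


def inspect_dataset_zip(source: Union[str, IO[bytes]]) -> Dict[str, object]:
    """Summarize a dataset ZIP: number of PNG/NPY files, compression, labels."""
    with zipfile.ZipFile(source) as zf:
        infos = zf.infolist()
        names = [info.filename for info in infos]
        data_files = [n for n in names if n.lower().endswith((".png", ".npy"))]
        # An uncompressed archive stores every member with ZIP_STORED.
        all_stored = all(info.compress_type == zipfile.ZIP_STORED for info in infos)
        labels: Dict[str, int] = {}
        if "dataset.json" in names:
            meta = json.loads(zf.read("dataset.json"))
            # Assumption: EDM2-style {"labels": [[filename, class_index], ...]} or null.
            labels = {fname: label for fname, label in (meta.get("labels") or [])}
    return {"num_data_files": len(data_files), "all_stored": all_stored, "labels": labels}
```

For example, inspect_dataset_zip("datasets/img256.zip") would report how many images the RGB archive holds and whether each one has a class label.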

To set up ImageNet-256:

  1. Download the ILSVRC2012 data archive from Kaggle and extract it somewhere, e.g., downloads/imagenet.

  2. Crop and resize the images to create the original RGB dataset:

# Convert raw ImageNet data to a ZIP archive at 256x256 resolution, for example:
python dataset_tool.py convert --source=downloads/imagenet/train \
 --dest=datasets/img256.zip --resolution=256x256 --transform=center-crop-dhariwal
  3. Run the images through a pre-trained VAE encoder to create the corresponding latent dataset:
# Convert the pixel data to VAE latents, for example:
python dataset_tool.py encode --source=datasets/img256.zip \
 --dest=datasets/img256-sd.zip
  4. Calculate reference statistics for the original RGB dataset, to be used with calculate_metrics.py:
# Compute dataset reference statistics for calculating metrics, for example:
python calculate_metrics.py ref --data=datasets/img256.zip \
 --dest=dataset-refs/img256.pkl
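Steps 2 to 4 depend on each other: both the VAE encoding and the reference statistics read the RGB archive produced in step 2. The sketch below (a hypothetical helper, not part of the repository) builds the three example commands in that order, using only the paths and flags shown above and assuming it is run from real_world_experiments. It prints the commands unless --execute is passed.

```python
import subprocess
import sys
from typing import List


def imagenet256_prep_commands(raw_train_dir: str = "downloads/imagenet/train",
                              out_dir: str = "datasets",
                              ref_dir: str = "dataset-refs") -> List[List[str]]:
    """Return the convert -> encode -> ref commands from steps 2-4, in order."""
    rgb_zip = f"{out_dir}/img256.zip"
    latent_zip = f"{out_dir}/img256-sd.zip"
    py = sys.executable
    return [
        # Step 2: crop/resize raw images into the RGB dataset (used for evaluation).
        [py, "dataset_tool.py", "convert", f"--source={raw_train_dir}",
         f"--dest={rgb_zip}", "--resolution=256x256",
         "--transform=center-crop-dhariwal"],
        # Step 3: encode the RGB dataset into VAE latents (used for training).
        [py, "dataset_tool.py", "encode", f"--source={rgb_zip}", f"--dest={latent_zip}"],
        # Step 4: reference statistics of the RGB dataset for calculate_metrics.py.
        [py, "calculate_metrics.py", "ref", f"--data={rgb_zip}",
         f"--dest={ref_dir}/img256.pkl"],
    ]


if __name__ == "__main__":
    execute = "--execute" in sys.argv[1:]  # print-only unless --execute is given
    for cmd in imagenet256_prep_commands():
        print(" ".join(cmd))
        if execute:
            subprocess.run(cmd, check=True)
```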

For convenience, we provide bash scripts in the bash_files/convert/ directory that can be run directly, e.g., bash bash_files/convert/convert_1c.sh sim3.

Training new models

New models can be trained using train_edm2.py. For example, to train an XS-sized conditional model for ImageNet-256 using the same hyperparameters as in our paper, run:

# Train XS-sized model for ImageNet-256 using 8 GPUs

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=25641 train_edm2.py \
 --outdir=training-runs/{similarity_name}-edm2-img256-xs-{classes_setup}-{sample_num} \
 --data=datasets/latent_datasets/{similarity_name}/inet-256-{classes_setup}-{sample_num}.zip \
 --preset=edm2-img256-xs-{classes_setup}-{sample_num} \
 --batch=$(( 1024 * 8 )) --status=64Ki --snapshot=16Mi --checkpoint=64Mi

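Because our paper covers many setups, replace the placeholders in braces ({similarity_name}, {classes_setup}, and {sample_num}) with the values for the setup you want to train. Example training scripts are provided in the bash_files/train/ directory.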

This example performs single-node training using 8 GPUs.
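The same three placeholders appear in --outdir, --data, and --preset, so they must be filled consistently. A minimal sketch of doing this in Python follows; torchrun_command and TEMPLATES are hypothetical names, and the templates are copied from the command above. Any values you pass in are illustrative: only preset names actually defined in the repository's train_edm2.py will work, and we have not checked which combinations exist.

```python
from typing import Dict, List

# Path/preset templates copied from the example command; the {} fields are
# the per-setup placeholders.
TEMPLATES: Dict[str, str] = {
    "outdir": "training-runs/{similarity_name}-edm2-img256-xs-{classes_setup}-{sample_num}",
    "data": "datasets/latent_datasets/{similarity_name}/inet-256-{classes_setup}-{sample_num}.zip",
    "preset": "edm2-img256-xs-{classes_setup}-{sample_num}",
}


def torchrun_command(similarity_name: str, classes_setup: str, sample_num: str,
                     nproc: int = 8, master_port: int = 25641) -> List[str]:
    """Fill every placeholder from one set of values so the three arguments agree."""
    fields = {
        "similarity_name": similarity_name,
        "classes_setup": classes_setup,
        "sample_num": sample_num,
    }
    cmd = ["torchrun", f"--nproc_per_node={nproc}", f"--master_port={master_port}",
           "train_edm2.py"]
    cmd += [f"--{key}={template.format(**fields)}" for key, template in TEMPLATES.items()]
    # $(( 1024 * 8 )) in the shell command evaluates to a batch of 8192.
    cmd += [f"--batch={1024 * 8}", "--status=64Ki", "--snapshot=16Mi", "--checkpoint=64Mi"]
    return cmd
```

To reproduce the GPU selection of the shell example, the returned list can be passed to subprocess.run with env={**os.environ, "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7"}.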

By default, the training script prints status every 128Ki (= 128 kibi = 128 × 2^10) training images (controlled by --status), saves network snapshots every 8Mi (= 8 × 2^20) training images (controlled by --snapshot), and dumps training checkpoints every 128Mi training images (controlled by --checkpoint); the example command above overrides these with --status=64Ki, --snapshot=16Mi, and --checkpoint=64Mi. The status is saved in log.txt (one-line summary) and stats.json (comprehensive set of statistics). The network snapshots are saved in network-snapshot-*.pkl and can be used directly with, e.g., generate_images.py.
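The Ki/Mi suffixes are binary multiples counted in training images. The sketch below converts them to plain counts and, assuming each optimizer step consumes exactly --batch images, to the number of steps between events. parse_nimg and steps_between are our own hypothetical helpers that mirror the notation; they are not the repository's parser.

```python
def parse_nimg(spec: str) -> int:
    """Convert '64Ki', '16Mi', '2Gi', or a plain integer string into an image count."""
    shifts = {"Ki": 10, "Mi": 20, "Gi": 30}
    for suffix, shift in shifts.items():
        if spec.endswith(suffix):
            return int(spec[: -len(suffix)]) << shift
    return int(spec)


def steps_between(spec: str, batch: int) -> float:
    """Optimizer steps between two events, assuming `batch` images per step."""
    return parse_nimg(spec) / batch


if __name__ == "__main__":
    batch = 1024 * 8  # the --batch value from the example command
    for flag, spec in [("--status", "64Ki"), ("--snapshot", "16Mi"), ("--checkpoint", "64Mi")]:
        print(f"{flag}={spec}: {parse_nimg(spec)} images = {steps_between(spec, batch):g} steps")
```

With the example command's batch of 8192 images, status is reported every 8 steps, snapshots are saved every 2048 steps, and checkpoints every 8192 steps.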

The training checkpoints, saved in training-state-*.pt, can be used to resume the training at a later time. When the training script starts, it will automatically look for the highest-numbered checkpoint and load it if available. To resume training, simply run the same train_edm2.py command line again — it is important to use the same set of options to avoid accidentally changing the hyperparameters mid-training.
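The "highest-numbered checkpoint" lookup can be illustrated with a small directory scan. This is a hypothetical sketch, not the repository's resume logic: it assumes the number in training-state-*.pt directly follows the prefix, and it compares the numbers as integers rather than as strings.

```python
import re
from pathlib import Path
from typing import Optional, Union


def latest_numbered_file(run_dir: Union[str, Path], prefix: str = "training-state-",
                         suffix: str = ".pt") -> Optional[Path]:
    """Return the file named <prefix><digits><suffix> with the largest number, or None."""
    pattern = re.compile(re.escape(prefix) + r"(\d+)" + re.escape(suffix))
    best: Optional[Path] = None
    best_num = -1
    for path in Path(run_dir).iterdir():
        match = pattern.fullmatch(path.name)
        if match and int(match.group(1)) > best_num:
            best, best_num = path, int(match.group(1))
    return best
```

For example, latest_numbered_file("training-runs/<your-run>") would point at the state file a restarted run is expected to pick up, which is a quick check before relaunching the same train_edm2.py command.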

Acknowledgments

The code and README for the real-world experiments are based on the EDM2 repository. We thank its authors for their excellent implementation.

License

This code is released under the MIT License. See LICENSE for details.
