This is the official repo for the paper One-Step Diffusion Distillation via Deep Equilibrium Models, by Zhengyang Geng*, Ashwini Pokle*, and J. Zico Kolter.
First, download the datasets EDM-Uncond-CIFAR and EDM-Cond-CIFAR from this link.
Set --data_path in run.sh to the directory where you stored the datasets, e.g., --data_path DATA_DIR/EDM-Uncond-CIFAR-1M.
In addition, download the precomputed dataset statistics from this link.
Set --stat_path in run.sh and eval.sh to your download directory followed by the statistics file name.
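Before launching, it can help to confirm that both paths resolve. The helper below is a small sketch, not part of the repo: `check_paths` is a hypothetical name, and it assumes only that --data_path is the dataset directory and --stat_path is the statistics file inside your download directory. The statistics file name is left as a placeholder.

```python
from pathlib import Path
from typing import Dict, Tuple


def check_paths(data_dir: str, dataset_name: str,
                stat_dir: str, stat_name: str) -> Dict[str, Tuple[str, bool]]:
    """Hypothetical preflight helper: compose the --data_path and --stat_path
    values and report whether each one exists on disk."""
    data_path = Path(data_dir) / dataset_name  # e.g. DATA_DIR/EDM-Uncond-CIFAR-1M
    stat_path = Path(stat_dir) / stat_name     # download dir plus stat name
    return {
        "--data_path": (str(data_path), data_path.exists()),
        "--stat_path": (str(stat_path), stat_path.exists()),
    }


# Placeholders: replace with your real directories and statistics file name.
for flag, (path, ok) in check_paths("DATA_DIR", "EDM-Uncond-CIFAR-1M",
                                    "STAT_DIR", "STAT_NAME").items():
    print(f"{flag} {path} {'ok' if ok else 'MISSING'}")
```

Any MISSING entry means the corresponding flag in run.sh or eval.sh likely points at the wrong place.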
To train a GET, run this command:
bash run.sh N_GPU DDP_PORT --model MODEL_NAME --name EXP_NAME
N_GPU is the number of GPUs used for training.
DDP_PORT is the port used to synchronize gradients during distributed training.
MODEL_NAME selects the model architecture (e.g., GET-S/2 or ViT-B/2).
To list all available models, run python train.py -h.
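N_GPU and DDP_PORT are standard PyTorch distributed-launch settings. As a rough sketch, assuming (not verified) that run.sh forwards them to PyTorch's torchrun launcher together with the remaining flags for train.py, the call expands to something like the command built below. `build_launch_cmd` is a hypothetical helper; read run.sh for the exact command it runs.

```python
import shlex
from typing import List


def build_launch_cmd(n_gpu: int, ddp_port: int, model: str, exp_name: str,
                     *extra: str) -> List[str]:
    """Hypothetical expansion of `bash run.sh N_GPU DDP_PORT --model ... --name ...`,
    assuming run.sh wraps torchrun. Check run.sh for the real command."""
    if n_gpu < 1:
        raise ValueError("N_GPU must be at least 1")
    if not 1024 <= ddp_port <= 65535:
        # Sanity check only: pick a free, unprivileged TCP port.
        raise ValueError("DDP_PORT should be an unprivileged TCP port")
    return [
        "torchrun",
        f"--nproc_per_node={n_gpu}",  # one worker process per GPU
        f"--master_port={ddp_port}",  # rendezvous port the workers use to sync gradients
        "train.py",
        "--model", model,
        "--name", exp_name,
        *extra,                       # e.g. "--cond" or "--mem"
    ]


print(shlex.join(build_launch_cmd(4, 12345, "GET-S/2", "test-GET")))
```

If two jobs run on the same machine, give each a different DDP_PORT so their rendezvous endpoints do not collide.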
The training log, checkpoints, and sampled images will be saved to ./results under your EXP_NAME.
For example, this command trains a GET-S/2 (patch size 2) on 4 GPUs:
bash run.sh 4 12345 --model GET-S/2 --name test-GET
To train a ViT, run this command:
bash run.sh N_GPU DDP_PORT --model ViT-B/2 --name EXP_NAME
To train conditional models, add the --cond flag.
For O(1)-memory training, add the --mem flag.
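Deep equilibrium models usually get constant memory by not storing solver iterates: the forward pass finds a fixed point z* = f(z*, x) without recording an autograd graph, and the backward pass applies the implicit function theorem at z* alone, giving dz*/dx = (∂f/∂x) / (1 − ∂f/∂z). The scalar example below illustrates that idea in plain Python. It is a generic sketch, not GET's --mem code path (the repo builds on TorchDEQ, whose gradient options may differ), and `f`, `solve_fixed_point`, and `implicit_grad` are hypothetical names.

```python
import math


def f(z: float, x: float) -> float:
    # Toy layer: |df/dz| = 0.5 * (1 - tanh(z)^2) <= 0.5, so iteration contracts.
    return 0.5 * math.tanh(z) + x


def solve_fixed_point(x: float, tol: float = 1e-12, max_iter: int = 1000) -> float:
    # Forward pass: iterate z <- f(z, x). Only the current iterate is kept,
    # so memory does not grow with the number of solver steps.
    z = 0.0
    for _ in range(max_iter):
        z_next = f(z, x)
        if abs(z_next - z) < tol:
            return z_next
        z = z_next
    return z


def implicit_grad(x: float) -> float:
    # Backward pass via the implicit function theorem, evaluated only at z*:
    # dz*/dx = (df/dx) / (1 - df/dz).
    z_star = solve_fixed_point(x)
    df_dz = 0.5 * (1.0 - math.tanh(z_star) ** 2)
    df_dx = 1.0
    return df_dx / (1.0 - df_dz)


print(implicit_grad(0.0))  # 2.0, since z* = 0 and df/dz = 0.5 there
```

A finite-difference check, (z*(x + h) − z*(x − h)) / 2h, agrees with implicit_grad(x) up to solver tolerance, even though the backward pass never revisits the forward iterates.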
Download pretrained models from this link.
To load a checkpoint for evaluation, run this command:
bash eval.sh N_GPU DDP_PORT --model MODEL_NAME --resume CKPT_PATH --name EXP_NAME
The evaluation log and sampled images will be saved to ./eval-results under your EXP_NAME.
To evaluate conditional models, add the --cond flag. Here is an example:
bash eval.sh 4 12345 --model GET-B/2 --cond --resume CKPT_DIR/GET-B-cond-2M-data-bs256.pth
You can see the generative performance here. The discussion there might be interesting.
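Assuming the precomputed statistics set by --stat_path are the usual FID reference statistics (mean μ_r and covariance Σ_r of Inception features of the real images), the evaluation score would be the Fréchet Inception Distance between them and the same statistics μ_g, Σ_g computed on generated samples:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower is better; the score is 0 only when both means and both covariances match.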
First, clone the EDM repo. Then, copy the files under /data to the /edm directory.
Set DATA_PATH in dataset.sh to the directory where the synthetic dataset should be stored.
Run the following command to generate both conditional and unconditional training sets.
bash dataset.sh
If you want to generate more data pairs, adjust the range of --seeds=0-MAX_SAMPLES.
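The --seeds value uses a range syntax. The sketch below mirrors, to our understanding, how EDM's generate.py expands it (comma-separated items, inclusive ranges); `parse_seeds` is a hypothetical name written for illustration, so check EDM's source for edge cases. Assuming one data pair per seed, --seeds=0-MAX_SAMPLES yields MAX_SAMPLES + 1 pairs.

```python
import re
from typing import List


def parse_seeds(spec: str) -> List[int]:
    """Expand a seed spec such as "0-63" or "0,5,10-12" into a list of ints.
    Ranges include both endpoints."""
    seeds: List[int] = []
    for part in spec.split(","):
        m = re.fullmatch(r"(\d+)-(\d+)", part)
        if m:
            seeds.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            seeds.append(int(part))
    return seeds


print(len(parse_seeds("0-999")))  # 1000 seeds, i.e. 1000 pairs under the assumption above
```

Because the upper bound is inclusive, a set of exactly N pairs corresponds to --seeds=0-(N-1) under this assumption.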
If you find our work helpful to your research, please consider citing this paper. :)
@inproceedings{
geng2023onestep,
title={One-Step Diffusion Distillation via Deep Equilibrium Models},
author={Zhengyang Geng and Ashwini Pokle and J Zico Kolter},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}
}
Feel free to contact us if you have additional questions! Please drop an email to [email protected] (or Twitter) or [email protected].
This project is built upon TorchDEQ, DiT, and timm. Thanks for the awesome projects!
