
Medical Diffusion Classifier: Official PyTorch Implementation

Venue: MIDL 2025
Paper: Conditional Diffusion Models are Medical Image Classifiers that Provide Explainability and Uncertainty for Free
Authors: Gian Favero*, Parham Saremi*, Emily Kaczmarek, Brennan Nichyporuk, Tal Arbel
Institution(s): Mila - Quebec AI Institute, McGill University

Requirements

  • Linux is recommended for best performance, compatibility, and reproducibility.
  • All testing, training, and inference were performed on NVIDIA A100 GPUs (single or multiple).
  • 64-bit Python 3.10 and PyTorch 2.6. See https://pytorch.org for PyTorch install instructions.
  • A Python virtual environment is recommended to manage the packages for this repository (see the sketch below).
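
A minimal sketch of such a setup (the environment name .venv is illustrative):

python3.10 -m venv .venv
source .venv/bin/activate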

Getting Started

First, clone this repository.
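
For example:

git clone https://github.com/faverogian/med-diffusion-classifier.git
cd med-diffusion-classifier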

Installing Packages

Required packages are provided in the requirements.txt file and can be installed using the following command:

pip install -r requirements.txt

FlashAttention can be installed for faster inference (especially for DiT models). Prebuilt wheels are available from the FlashAttention GitHub releases; we used flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl for our CUDA and PyTorch versions. Download the wheel and install it with:

wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

pip install flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

Data Preparation

We use the CheXpert and ISIC datasets in our paper. Our train/validation/test CSV files for both datasets are in the splits folder.

Downloading Model Weights

All trained models can be downloaded from Google Drive (https://drive.google.com/drive/folders/1x7CKrbS8pxS45EXzpUKhpusBYCLgH82Y). Alternatively, use the following command with gdown:

gdown --folder 1x7CKrbS8pxS45EXzpUKhpusBYCLgH82Y

Configuration

Before using the models, modify the scripts/run.sh file:

  • PROJECT_ROOT: Absolute path to the root directory of the diffusion-classifier repository.
  • DATA_ROOT: Absolute path to the data directory containing isic_mel_balanced/, chexpert/, and sd_isic_chexpert/ folders.
  • INFERENCE_CHECKPOINT_FOLDER: Absolute path to the directory where the downloaded model weights are stored.
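
For example (all paths shown are illustrative):

export PROJECT_ROOT=/path/to/med-diffusion-classifier
export DATA_ROOT=/path/to/data
export INFERENCE_CHECKPOINT_FOLDER=/path/to/checkpoints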

If you want to use CometML for experiment tracking, set the COMET variables in run.sh. Additionally, you must set USE_COMET=1 to enable tracking.
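
For example, in run.sh:

export USE_COMET=1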

All training/inference hyperparameters are defined in their corresponding bash scripts. For example, the CheXpert-UNet hyperparameters are specified in scripts/unet/chexpert-unet.sh.

Below, we describe various use cases that require only simple customizations to our code. In every case, the desired experiment is launched from the parent folder of the repository with:

bash scripts/run.sh

Using Pre-Trained Models

Scripts to run inference with all models are provided in the scripts folder. To launch one, you only need to modify the run.sh file to select which model and data to run. For instance, to run UNet inference on the CheXpert dataset, set:
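
export MODEL="unet"
export FUNCTION="inference"
export DATA="chexpert"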

Baselines

Baseline classifiers for both datasets can be evaluated through scripts/run.sh. Change the VARIANT and BACKBONE environment variables to run different models. Available models:

VARIANT                        BACKBONE
resnet18                       resnet
resnet50                       resnet
efficientnet_b0                efficientnet
efficientnet_b4                efficientnet
swin_base_patch4_window7_224   swin
vit_base_patch16_224           vit
vit_small_patch16_224          vit
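
For example, to evaluate a ResNet-18 baseline:

export VARIANT="resnet18"
export BACKBONE="resnet"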

To evaluate on CheXpert, modify the run.sh file:

export MODEL="baseline" 
export FUNCTION="inference"
export DATA="chexpert" 

Or on ISIC:

export MODEL="baseline"
export FUNCTION="inference"
export DATA="chexpert" 

Inference: Diffusion Models

Use the following instructions to run inference with diffusion models.

For faster inference, set the FLASH_ATTENTION variable to true for DiT and UNet models.
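
For example, following the same pattern as the other run.sh variables:

export FLASH_ATTENTION=true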

CheXpert-UNet:

export MODEL="unet"
export FUNCTION="inference"
export DATA="chexpert" 

ISIC-UNet:

export MODEL="unet"
export FUNCTION="inference"
export DATA="isic" 

CheXpert-DiT:

export MODEL="dit"
export FUNCTION="inference"
export DATA="chexpert" 

ISIC-DiT:

export MODEL="dit"
export FUNCTION="inference"
export DATA="isic" 

CheXpert-StableDiffusion:

export MODEL="sd"
export FUNCTION="inference"
export DATA="chexpert" 

ISIC-StableDiffusion:

export MODEL="sd"
export FUNCTION="inference"
export DATA="dit" 

Counterfactual Generation

Counterfactual generation is currently supported only for UNet models. To generate counterfactuals, run explain.py instead of inference.py by changing the FUNCTION value to explain:

export MODEL="unet"
export FUNCTION="explain"
export DATA="chexpert" 

To improve visual quality, increase SAMPLING_STEPS to at least 256 in the unet scripts (scripts/unet/chexpert-unet.sh and scripts/unet/isic-unet.sh). CFG_W refers to the classifier-free guidance scale.
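
A sketch of the relevant lines (the CFG_W value is illustrative, not a recommendation from the paper):

export SAMPLING_STEPS=256   # at least 256 for better counterfactual quality
export CFG_W=3.0            # classifier-free guidance scale; illustrative value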

The images will be saved in the inference_images directory located within the experiment folder.

Training Models

Similar to counterfactual generation, training a model is as simple as changing FUNCTION to train. For example, to train the UNet model on CheXpert data, use the following environment variables:

export MODEL="unet"
export FUNCTION="train"
export DATA="chexpert" 

Note: The Stable Diffusion model is jointly trained on both the ISIC and CheXpert datasets. To train on a single dataset, modify the metadata.jsonl file in data/sd_isic_chexpert and adjust the training data folder accordingly.

To train the Stable Diffusion model, the diffusers package must be installed from source:

pip install git+https://github.com/huggingface/diffusers

The output directory for the Stable Diffusion model can be set in its train.sh script, while other models will save their checkpoints within their respective experiment folders.
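
For example, a hypothetical line in train.sh (the variable name OUTPUT_DIR is an assumption, not confirmed by this repository):

export OUTPUT_DIR=/path/to/sd-checkpoints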

License

This project is licensed under the MIT License. See the LICENSE file for details.

Citation

@misc{favero2025conditionaldiffusionmodelsmedical,
      title={Conditional Diffusion Models are Medical Image Classifiers that Provide Explainability and Uncertainty for Free}, 
      author={Gian Mario Favero and Parham Saremi and Emily Kaczmarek and Brennan Nichyporuk and Tal Arbel},
      year={2025},
      eprint={2502.03687},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2502.03687}, 
}
