Python implementation of "Synthetic Data for Robust Stroke Segmentation" published in Machine Learning for Biomedical Imaging (MELBA) 2025.
This repository contains the implementation of our MELBA 2025 paper on synthetic data generation for stroke lesion segmentation. The method uses synthetic data to improve model generalization across different imaging protocols and patient populations.
Features:
- Synthetic data generation pipeline using healthy brain MRI
- Multi-tissue segmentation (lesions and healthy brain tissue)
- Mixed precision training with configurable loss functions
- Test-time augmentation (TTA) for improved inference accuracy
- Hugging Face Hub integration with PyTorchModelHubMixin
- 6 pre-trained models available for immediate use
- Easy model loading and inference via synthstroke_model.py
Paper: Chalcroft, L., Pappas, I., Price, C. J., & Ashburner, J. (2025). Synthetic Data for Robust Stroke Segmentation. Machine Learning for Biomedical Imaging, 3, 317–346.
- Python 3.10+
- CUDA-capable GPU (recommended)
- 8GB+ RAM
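Before installing, it can help to confirm that the interpreter meets the Python requirement and, once PyTorch is present, whether a CUDA device is visible. The following is a minimal sketch and not part of the repository; `check_environment` is a hypothetical helper name, and it treats a missing PyTorch install as "unknown" rather than as a failure.

```python
import sys


def check_environment(min_python=(3, 10)):
    """Report whether the interpreter and (optionally) CUDA meet the listed requirements."""
    report = {
        "python_version": tuple(sys.version_info[:3]),
        "python_ok": tuple(sys.version_info[:2]) >= tuple(min_python),
        "cuda_available": None,  # None means PyTorch is not installed yet
    }
    try:
        import torch  # imported lazily so the check also runs before installation

        report["cuda_available"] = bool(torch.cuda.is_available())
    except ImportError:
        pass
    return report


if __name__ == "__main__":
    print(check_environment())
```

A GPU is only recommended, so `cuda_available` being `False` or `None` does not prevent CPU use.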
- Clone the repository
    git clone https://github.com/liamchalcroft/synthstroke.git
    cd synthstroke
- Set up environment
    conda create -n synthstroke python=3.10
    conda activate synthstroke
- Install dependencies
    pip install -r requirements.txt
Baseline Model (Real Data Only)
Train a baseline model using only real stroke imaging data:
python train.py \
--name baseline_model \
--logdir ./ \
--baseline \
--l2 50 \
--patch 128 \
--amp \
--epochs 500 \
--epoch_length 200 \
--lr 0.001 \
--val_interval 2
Parameters:
- --baseline: Use real stroke images (no synthetic data)
- --l2 50: L2 loss for first 50 epochs, then switches to Dice loss
- --patch 128: Random crop size for training patches
- --amp: Enable automatic mixed precision training
- --val_interval 2: Validate and save weights every 2 epochs
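The `--l2 50` schedule described above can be pictured as a warm-up with an L2 loss followed by a switch to Dice. The sketch below is illustrative only and is not taken from `train.py`: the function names are hypothetical, NumPy stands in for PyTorch, and it assumes zero-indexed epochs, so that epochs 0 to 49 use L2 and epoch 50 onwards uses Dice.

```python
import numpy as np


def l2_loss(probs, target):
    """Mean squared error between predicted probabilities and one-hot targets."""
    return float(np.mean((probs - target) ** 2))


def dice_loss(probs, target, eps=1e-6):
    """Soft Dice loss averaged over channels; arrays have shape [C, ...]."""
    axes = tuple(range(1, probs.ndim))
    intersection = np.sum(probs * target, axis=axes)
    denominator = np.sum(probs, axis=axes) + np.sum(target, axis=axes)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return float(1.0 - dice.mean())


def select_loss(epoch, l2_epochs=50):
    """Assumption: epochs are zero-indexed, so epochs 0..l2_epochs-1 use L2."""
    return l2_loss if epoch < l2_epochs else dice_loss
```

With `--epochs 500` and `--l2 50`, this schedule would spend 50 epochs on L2 and the remaining 450 on Dice.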
SynthStroke Model (With Synthetic Data)
Train the model with synthetic data augmentation:
python train.py \
--name synthstroke_model \
--logdir ./ \
--mbhealthy \
--fade \
--lesion_weight 2 \
--l2 50 \
--patch 128 \
--amp \
--epochs 500 \
--epoch_length 200 \
--lr 0.001 \
--val_interval 2
Key Features:
- --mbhealthy: Enable MultiBrain healthy tissue segmentation
- --fade: Apply intensity non-uniformity fields to simulate penumbra
- --lesion_weight 2: Increase lesion class weight for better sensitivity
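One common way a lesion class weight can enter a multi-class loss is as a per-class multiplier on the Dice terms. The sketch below shows that idea under stated assumptions: how `train.py` actually applies `--lesion_weight` is not documented here, the helper names are hypothetical, and NumPy is used in place of PyTorch for brevity.

```python
import numpy as np


def make_class_weights(num_classes, lesion_index, lesion_weight=2.0):
    """Weight 1.0 for every class except the lesion class (assumed layout)."""
    weights = np.ones(num_classes)
    weights[lesion_index] = lesion_weight
    return weights


def weighted_dice_loss(probs, target, class_weights, eps=1e-6):
    """Class-weighted soft Dice loss for arrays of shape [C, ...]."""
    axes = tuple(range(1, probs.ndim))
    intersection = np.sum(probs * target, axis=axes)
    denominator = np.sum(probs, axis=axes) + np.sum(target, axis=axes)
    per_class_loss = 1.0 - (2.0 * intersection + eps) / (denominator + eps)
    weights = np.asarray(class_weights, dtype=float)
    # Normalise by the weight sum so the loss stays in [0, 1].
    return float(np.sum(weights * per_class_loss) / np.sum(weights))
```

Under this formulation, a missed lesion contributes twice as much to the loss as an equally poor prediction for any other class, which pushes the model towards higher lesion sensitivity.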
Prediction on New Data
Run inference on new stroke MRI scans:
python test.py \
--weights ./synthstroke_model/checkpoint.pt \
--tta \
--mb \
--patch 128 \
--savedir /path/to/output/ \
--files "/path/to/input/*.nii.gz"
Options:
- --tta: Enable test-time augmentation
- --mb: Output multi-brain tissue labels alongside lesions
- --files: Path pattern or text file with input paths
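Since `--files` accepts either a path pattern or a text file of input paths, the two cases might be resolved along the following lines. This is a hedged sketch rather than the logic in `test.py`: `resolve_input_files` is a hypothetical helper, and it assumes that a listing file ends in `.txt` and holds one path per line.

```python
import glob
import os


def resolve_input_files(files_arg):
    """Return input paths from a glob pattern or from a text file of paths."""
    if os.path.isfile(files_arg) and files_arg.endswith(".txt"):
        with open(files_arg, "r", encoding="utf-8") as handle:
            # One path per line; blank lines are skipped.
            return [line.strip() for line in handle if line.strip()]
    # Otherwise treat the argument as a glob pattern; sort for a stable order.
    return sorted(glob.glob(files_arg))
```

When passing a pattern on the command line, quoting it (as in the command above) stops the shell from expanding it before the script sees it.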
For MATLAB/SPM users, check out our SPM Toolbox for seamless integration with SPM preprocessing pipelines.
The synthstroke_model.py file provides a complete Python interface for using our models:
- PyTorchModelHubMixin Integration: Seamless loading from Hugging Face Hub
- Multiple Model Variants: Support for all 6 model types (baseline, synth, synth_pseudo, synth_plus, qatlas, qsynth)
- Test-Time Augmentation: Built-in flip-based TTA for improved accuracy
- Automatic Device Handling: Smart GPU/CPU device management
- Input Validation: Robust error checking and validation
- Model Information: Detailed metadata and parameter counts
The model library is included in this repository and requires the same dependencies as the training code.
Model | Description | Hugging Face |
---|---|---|
SynthStroke Baseline | Model trained on real ATLAS T1w data | 🤗 synthstroke-baseline |
SynthStroke Synth | Multi-tissue segmentation with synthetic data | 🤗 synthstroke-synth |
SynthStroke SynthPseudo | Synthetic data + pseudo-label augmentation | 🤗 synthstroke-synth-pseudo |
SynthStroke SynthPlus | Synthetic data + real multi-dataset training | 🤗 synthstroke-synth-plus |
SynthStroke qATLAS | qMRI-based model trained on synthetic parameters | 🤗 synthstroke-qatlas |
SynthStroke qSynth | qMRI-constrained synthetic data training | 🤗 synthstroke-qsynth |
The synthstroke_model.py module provides easy access to all pre-trained models using Hugging Face's PyTorchModelHubMixin:
import torch
from synthstroke_model import SynthStrokeModel
# Load any model from Hugging Face Hub
model = SynthStrokeModel.from_pretrained("liamchalcroft/synthstroke-baseline")
# Prepare your MRI data (T1-weighted, shape: [batch, 1, H, W, D])
mri_volume = torch.randn(1, 1, 192, 192, 192)
# Run inference with optional Test-Time Augmentation
predictions = model.predict_segmentation(mri_volume, use_tta=True)
# For baseline model: get lesion probability map (channel 1)
lesion_probs = predictions[:, 1]
# For multi-tissue models (synth, synth_pseudo, synth_plus, qsynth):
# Get all tissue probability maps
background = predictions[:, 0] # Background
gray_matter = predictions[:, 1] # Gray Matter
white_matter = predictions[:, 2] # White Matter
partial_volume = predictions[:, 3] # Gray/White Partial Volume
csf = predictions[:, 4] # Cerebro-Spinal Fluid
stroke = predictions[:, 5] # Stroke Lesion
# Get detailed model information
info = model.get_model_info()
print(f"Model type: {info['model_type']}")
print(f"Input channels: {info['input_channels']}")
print(f"Output channels: {info['output_channels']}")
print(f"TTA support: {info['tta_support']}")
print(f"Parameters: {info['parameters']:,}")
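The per-channel probability maps above are often reduced to a single label map and a lesion volume. The sketch below shows one way to do that with NumPy. It assumes the six-channel layout listed in the comments above, that the tensor has been moved to NumPy first (for example with `predictions.cpu().numpy()`), and a voxel spacing supplied by the caller. The helper names are hypothetical and not part of `synthstroke_model.py`.

```python
import numpy as np

# Channel order as listed for the multi-tissue models above.
TISSUE_NAMES = ["background", "gray_matter", "white_matter",
                "partial_volume", "csf", "stroke"]


def probabilities_to_labels(probs):
    """Convert [batch, C, H, W, D] probabilities to a [batch, H, W, D] label map."""
    return np.argmax(probs, axis=1)


def lesion_mask(probs, lesion_channel=5):
    """Binary mask of voxels whose most likely class is the lesion channel."""
    return probabilities_to_labels(probs) == lesion_channel


def lesion_volume_ml(mask, voxel_size_mm=(1.0, 1.0, 1.0)):
    """Lesion volume in millilitres, given the voxel spacing in millimetres."""
    voxel_mm3 = float(np.prod(voxel_size_mm))
    return float(mask.sum()) * voxel_mm3 / 1000.0
```

For the two-class models, the same helpers apply with `lesion_channel=1`.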
from synthstroke_model import (
create_baseline_model, # 2-class: Background + Stroke
create_synth_model, # 6-class: Multi-tissue + Stroke
create_synth_pseudo_model, # 6-class: With pseudo-labels
create_synth_plus_model, # 6-class: Multi-dataset training
create_qatlas_model, # 2-class: qMRI-based
create_qsynth_model # 6-class: qMRI-constrained
)
# Create models locally (without downloading from Hub)
baseline_model = create_baseline_model()
synth_model = create_synth_model()
All models support flip-based TTA for improved inference accuracy:
# Enable TTA for more robust predictions
predictions_with_tta = model.predict_segmentation(mri_volume, use_tta=True)
# This uses 8 augmentations (original + 7 flipped versions) and averages results
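For intuition, flip-based TTA with eight augmentations can be written as every combination of flips along the three spatial axes, with each prediction flipped back before averaging. The sketch below uses NumPy and a generic `predict_fn` callable. It is an illustration of the general technique, not the code inside `predict_segmentation`, and the choice of axes (2, 3, 4) assumes the `[batch, channel, H, W, D]` layout used above.

```python
import itertools

import numpy as np


def flip_tta_predict(predict_fn, volume, spatial_axes=(2, 3, 4)):
    """Average predictions over all flip combinations of the spatial axes."""
    outputs = []
    # r = 0 gives the unflipped original; r = 1..3 give the 7 flipped variants.
    for r in range(len(spatial_axes) + 1):
        for axes in itertools.combinations(spatial_axes, r):
            flipped = np.flip(volume, axis=axes) if axes else volume
            pred = predict_fn(flipped)
            # Undo the flip so every prediction is back in the original orientation.
            outputs.append(np.flip(pred, axis=axes) if axes else pred)
    return np.mean(outputs, axis=0)
```

Because each prediction is un-flipped before averaging, a model that is exactly flip-equivariant returns the same output with or without TTA; the benefit comes from averaging out orientation-dependent errors, at the cost of eight forward passes.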
- Framework: MONAI UNet with PyTorch
- Input: 3D MRI volumes (T1-weighted for most models, qMRI parameters for qATLAS/qSynth)
- Architecture: 3D UNet with configurable channels and strides
- Training: Mixed precision (AMP) with custom loss functions
- Inference: Optional Test-Time Augmentation support
For detailed model specifications, see the individual model cards on Hugging Face Hub.
For issues or questions, please open an issue on GitHub.
If you use SynthStroke in your research, please cite:
@article{Chalcroft2025,
title = {Synthetic Data for Robust Stroke Segmentation},
volume = {3},
ISSN = {2766-905X},
url = {http://dx.doi.org/10.59275/j.melba.2025-f3g6},
DOI = {10.59275/j.melba.2025-f3g6},
number = {August 2025},
journal = {Machine Learning for Biomedical Imaging},
publisher = {Machine Learning for Biomedical Imaging},
author = {Chalcroft, Liam and Pappas, Ioannis and Price, Cathy J. and Ashburner, John},
year = {2025},
month = aug,
pages = {317--346}
}
This project is licensed under the MIT License. See the LICENSE file for details.