A highly optimized PyTorch implementation of the U-Net architecture featuring Flash Attention and modern deep-learning optimizations. It is tuned specifically for the NVIDIA RTX 4070 GPU and can be used for a variety of image segmentation tasks.
- Flash Attention integration for faster attention computation
- Mixed precision training (FP16)
- Memory-efficient gradient checkpointing
- Optimized for RTX 4070's tensor cores
- Automatic dataset handling via Hugging Face
- Real-time visualization of training progress
- Depth-wise separable convolutions for efficiency (see the sketch below)
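For reference, this is the depth-wise separable convolution pattern, sketched with plain `torch.nn` modules; the class name is illustrative, not the project's actual implementation:

```python
import torch
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    """Factor a standard convolution into a per-channel (depth-wise) 3x3
    convolution plus a 1x1 point-wise convolution, cutting parameters
    and FLOPs roughly by a factor of the kernel size squared."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # groups=in_channels: each 3x3 filter sees only its own channel
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                   padding=1, groups=in_channels, bias=False)
        # 1x1 convolution mixes information across channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pointwise(self.depthwise(x))
```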
- GPU: Optimized for NVIDIA RTX 4070 (12GB VRAM)
- RAM: Minimum 16GB recommended
- Storage: Space for dataset and checkpoints
Note: If you're using a different GPU, you might need to adjust the batch size and other parameters accordingly. The current settings are optimized for RTX 4070's 12GB VRAM.
- Python 3.8 or higher
- CUDA 11.7 or higher
- Create and activate a virtual environment:
```bash
python -m venv venv
source venv/bin/activate  # On Windows, use: venv\Scripts\activate
```
- Install required packages:
```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install flash-attn einops datasets tqdm matplotlib numpy
```
Run the main script:
```bash
python main.py
```
The script will:
- Load the dataset from Hugging Face
- Initialize the optimized U-Net model
- Start training with real-time visualization
- Display both training and validation metrics
- Selective Attention: Only used in deeper layers where it matters most
- Gradient Checkpointing: Applied selectively to manage memory usage
- Batch Size: Optimized to 8 for RTX 4070's 12GB VRAM
- Mixed Precision: Uses FP16 for efficient memory usage
- Flash Attention: Faster and more memory-efficient attention mechanism (see the usage sketch after this list)
- Depth-wise Separable Convolutions: More efficient than standard convolutions
- Tensor Core Utilization: Channel sizes aligned for optimal tensor core usage
- Efficient Data Loading: Optimized worker count and memory pinning
- AdamW Optimizer: Modern optimizer with decoupled weight decay
- OneCycleLR: Advanced learning rate scheduling
- Efficient Gradient Clearing: Frees gradients instead of re-zeroing them (see the training-step sketch after this list)
- Compiled Model: Uses torch.compile with "max-autotune" mode
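As a reference for the Flash Attention item above, here is a minimal usage sketch assuming the `flash_attn` 2.x interface installed earlier; the shapes and sizes are illustrative:

```python
import torch
from flash_attn import flash_attn_func

# flash_attn_func expects (batch, seq_len, num_heads, head_dim) tensors
# in FP16/BF16 on a CUDA device; a head_dim of 64 maps well to tensor cores.
batch, seq_len, num_heads, head_dim = 8, 1024, 4, 64
q = torch.randn(batch, seq_len, num_heads, head_dim,
                dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Equivalent to softmax(q @ k^T / sqrt(head_dim)) @ v, computed without
# ever materializing the full seq_len x seq_len attention matrix.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
print(out.shape)  # torch.Size([8, 1024, 4, 64])
```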
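And here is a minimal training-step sketch tying the optimizer items together (AdamW, OneCycleLR, FP16 autocast with loss scaling, `set_to_none` gradient clearing, and `torch.compile`); the model, loss, and step counts are hypothetical stand-ins, not the ones in `main.py`:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real U-Net, loss, and schedule in main.py.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 2, 1)).cuda()
model = torch.compile(model, mode="max-autotune")
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
steps_per_epoch, num_epochs = 100, 10  # illustrative values
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=3e-4, total_steps=steps_per_epoch * num_epochs)
scaler = torch.cuda.amp.GradScaler()

def train_step(images: torch.Tensor, labels: torch.Tensor) -> float:
    optimizer.zero_grad(set_to_none=True)  # free grads instead of zeroing them
    with torch.cuda.amp.autocast():        # FP16 forward pass on tensor cores
        loss = criterion(model(images), labels)
    scaler.scale(loss).backward()          # loss scaling avoids FP16 underflow
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()                       # OneCycleLR is stepped once per batch
    return loss.item()
```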
Key parameters you might want to adjust based on your setup:
```python
BATCH_SIZE = 8          # Adjust based on your GPU memory
NUM_WORKERS = 4         # Adjust based on your CPU cores
MIXED_PRECISION = True  # Set to False if you encounter issues
```
The project uses Hugging Face datasets. To use your own dataset:
- Upload it to Hugging Face in the correct format
- Modify the dataset path in the code
- Ensure your data follows the expected format (image and label pairs); see the sketch below
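For example, with a placeholder dataset name (replace it with your own repository):

```python
from datasets import load_dataset

# "your-username/your-dataset" is a placeholder, not a real repository.
dataset = load_dataset("your-username/your-dataset")

sample = dataset["train"][0]
image, mask = sample["image"], sample["label"]  # the expected image/label pair
```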
- Out of Memory (OOM)
  - Reduce the batch size
  - Disable some attention layers
  - Reduce the model size
- Slow Training
  - Check CUDA version compatibility (see the check below)
  - Adjust the number of workers
  - Ensure proper tensor core utilization
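A quick way to check the CUDA side, plus the TF32 switches that let FP32 matmuls and convolutions run on tensor cores:

```python
import torch

print(torch.__version__, torch.version.cuda)  # PyTorch build / CUDA version
print(torch.cuda.get_device_name(0))          # should report your GPU

# TF32 lets FP32 matmuls/convolutions use tensor cores (Ampere and newer).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```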
- Poor Convergence
  - Adjust the learning rate
  - Modify the OneCycleLR parameters (see the sketch below)
  - Check data normalization
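If convergence stalls, these are the usual OneCycleLR knobs; the values shown are illustrative starting points to experiment with, not the project's settings:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter for illustration
optimizer = torch.optim.AdamW(params, lr=1e-3)
steps_per_epoch, num_epochs = 100, 10          # illustrative values

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,        # peak learning rate; lower it if the loss diverges
    total_steps=steps_per_epoch * num_epochs,
    pct_start=0.3,      # fraction of the cycle spent ramping the LR up
    div_factor=25.0,    # initial LR = max_lr / div_factor
)
```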
If you encounter CUDA out of memory errors:
```python
# Try reducing batch size
BATCH_SIZE = 4  # Instead of 8

# Or disable attention in more layers
self.conv_down_2 = OptimizedResBlock(8, 16, use_attention=False)
```
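Gradient checkpointing, which this implementation already applies selectively, is another memory lever; a minimal sketch with `torch.utils.checkpoint` (the block shown is illustrative, not the project's `OptimizedResBlock`):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU()).cuda()
x = torch.randn(8, 16, 64, 64, device="cuda", requires_grad=True)

# Activations inside `block` are dropped after the forward pass and
# recomputed during backward, trading extra compute for lower peak VRAM.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```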
Feel free to:
- Report issues
- Suggest optimizations
- Submit pull requests
- Share your results
This project is MIT licensed. Feel free to use it in your own projects.
- Flash Attention implementation based on Tri Dao's work
- U-Net architecture inspired by the original paper (Ronneberger et al., 2015)
- Optimizations tuned for modern GPU architectures