feat: Checkpoint pipeline infrastructure (Phase 1) #501
base: main
Conversation
- Add CheckpointContext dataclass for carrying state through pipeline
- Add PipelineStage abstract base class for all pipeline stages
- Add comprehensive exception hierarchy for checkpoint operations
- Establish foundation for three-layer checkpoint architecture

Part of Phase 1 checkpoint pipeline infrastructure (#493)
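A minimal sketch of what these two abstractions could look like; the field and method names beyond CheckpointContext and PipelineStage themselves are illustrative, not the final API:

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any


@dataclass
class CheckpointContext:
    """State carried through the pipeline from stage to stage (fields are illustrative)."""

    checkpoint_path: Path | None = None
    checkpoint_data: dict[str, Any] | None = None
    metadata: dict[str, Any] = field(default_factory=dict)


class PipelineStage(ABC):
    """Base class every source, loader, and modifier stage derives from."""

    @abstractmethod
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        """Consume the incoming context and return an (updated) context."""
```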
- Add CheckpointPipeline class for orchestrating stage execution
- Support async and sync execution modes
- Implement context passing between stages
- Add stage management (add/remove/clear)
- Include error handling with optional continuation

Part of Phase 1 checkpoint pipeline infrastructure (#493)
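The orchestration described above could look roughly like the sketch below; this is an illustration of the stage loop with optional continuation, not the actual CheckpointPipeline implementation (which also supports Hydra configuration and validation):

```python
import asyncio
import logging

LOGGER = logging.getLogger(__name__)


class PipelineSketch:
    def __init__(self, stages, continue_on_error: bool = False) -> None:
        self.stages = list(stages)
        self.continue_on_error = continue_on_error

    async def execute(self, context):
        # Each stage receives the context produced by the previous stage.
        for stage in self.stages:
            try:
                context = await stage.process(context)
            except Exception:
                LOGGER.exception("Stage %s failed", type(stage).__name__)
                if not self.continue_on_error:
                    raise
        return context

    def execute_sync(self, context):
        # Sync mode simply drives the async loop to completion.
        return asyncio.run(self.execute(context))
```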
- Implement dynamic component discovery using module inspection
- Add hybrid abstract class detection (ABC + name-based)
- Support automatic discovery of sources, loaders, and modifiers
- Convert class names to simple identifiers automatically
- Use Hydra instantiate for component creation

Replaces static registry pattern with true discovery mechanism. Part of Phase 1 checkpoint pipeline infrastructure (#493)
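A hedged sketch of discovery by module inspection; the helper names and the class-name-to-identifier convention here are assumptions for illustration, not the catalog's actual code:

```python
import importlib
import inspect
import re


def discover_components(module_path: str, base_class: type) -> dict[str, type]:
    """Map simple identifiers (e.g. 's3') to concrete subclasses found in a module."""
    module = importlib.import_module(module_path)
    found = {}
    for _name, obj in inspect.getmembers(module, inspect.isclass):
        if issubclass(obj, base_class) and not inspect.isabstract(obj):
            found[_class_name_to_id(obj.__name__)] = obj
    return found


def _class_name_to_id(class_name: str) -> str:
    """Convert e.g. 'S3Source' -> 's3' by stripping a known suffix and lowercasing."""
    for suffix in ("Source", "Loader", "Modifier", "Strategy"):
        class_name = class_name.removesuffix(suffix)
    # CamelCase -> snake_case for multi-word names
    return re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()
```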
- Add download_with_retry with exponential backoff
- Add checkpoint validation utilities
- Add metadata extraction without full loading
- Add state dict comparison utilities
- Include comprehensive unit tests for all components
- Test dynamic discovery with mocked modules

Part of Phase 1 checkpoint pipeline infrastructure (#493)
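For the retry utility, the pattern is standard exponential backoff; the sketch below assumes a simple urllib-based download and illustrative argument names rather than the utility's actual signature:

```python
import time
import urllib.request
from pathlib import Path


def download_with_retry(url: str, destination: Path, max_retries: int = 3, base_delay: float = 1.0) -> Path:
    """Download a checkpoint, retrying with exponentially growing delays."""
    for attempt in range(max_retries):
        try:
            urllib.request.urlretrieve(url, destination)
            return destination
        except OSError:
            if attempt == max_retries - 1:
                raise
            # Wait 1s, 2s, 4s, ... between attempts
            time.sleep(base_delay * (2**attempt))
    return destination
```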
for more information, see https://pre-commit.ci
- Add checkpoint format detection (Lightning, PyTorch, safetensors, state_dict)
- Enhance CheckpointContext with format-aware fields
- Add format conversion utilities
- Make safetensors an optional dependency
- Refactor ComponentCatalog for reduced complexity
- Add comprehensive error handling and validation

Implements core infrastructure for flexible checkpoint handling across different formats while maintaining backward compatibility.
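Format detection along these lines could use the file extension plus the checkpoint's top-level keys; this is a hedged sketch, and the real heuristics in formats.py may differ:

```python
from pathlib import Path

import torch


def detect_checkpoint_format(path: Path) -> str:
    """Guess the checkpoint format from the file extension and its top-level keys."""
    if path.suffix == ".safetensors":
        return "safetensors"
    data = torch.load(path, map_location="cpu")
    if isinstance(data, dict) and "state_dict" in data:
        # Lightning checkpoints wrap the weights together with trainer state
        return "lightning" if "pytorch-lightning_version" in data else "pytorch"
    return "state_dict"
```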
for more information, see https://pre-commit.ci
Add foundational checkpoint pipeline infrastructure:

- CheckpointContext: Type-safe state container for pipeline operations
- PipelineStage: Abstract base class for all checkpoint pipeline stages

This establishes the base architecture for extensible checkpoint operations with clean interfaces and proper type checking.
…elineStage base class

- Add CheckpointContext dataclass for pipeline state management
- Implement smart validation with helpful warnings for common issues
- Add validate_for_stage() method for pipeline stage requirements
- Include metadata management methods (update, get)
- Implement PipelineStage abstract base class for pipeline components
- Support multiple checkpoint formats (lightning, pytorch, safetensors)
- Add comprehensive docstrings with usage examples
…dling

- Extract _handle_unknown_loader() for loader error handling
- Extract _handle_unknown_modifier() for modifier error handling
- Add _build_loader_error_message() for detailed error messages
- Add _build_modifier_error_message() for modifier errors
- Add _get_loader_type_descriptions() for loader documentation
- Add _get_modifier_type_descriptions() for modifier documentation
- Add _find_similar_names() for smart suggestions
- Use list comprehensions for better performance
- Preserve all helpful error messages and suggestions
- Reduce McCabe complexity from 11 to acceptable levels
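The "smart suggestions" item above could be implemented with difflib; this is one possible sketch, not necessarily the code added in this commit:

```python
import difflib


def _find_similar_names(name: str, known_names: list[str], cutoff: float = 0.6) -> list[str]:
    """Return known component names that look close to a misspelled one."""
    return difflib.get_close_matches(name, known_names, n=3, cutoff=cutoff)


# e.g. _find_similar_names("lightening", ["lightning", "pytorch", "state_dict"]) -> ["lightning"]
```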
…t validation

- Add CheckpointPipeline class for stage orchestration
- Implement async and sync execution modes
- Add Hydra-based configuration support with instantiation
- Implement smart pipeline composition validation
- Check source-loader-modifier ordering automatically
- Detect duplicate stages and provide warnings
- Suggest missing stages based on pipeline composition
- Add pre-execution validation with health checks
- Support dynamic stage management (add/remove/clear)
- Include comprehensive error handling with context
- Track stage execution in metadata for debugging
- Support continue_on_error for resilient pipelines
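The ordering check mentioned in this list could look roughly like the sketch below; it assumes each stage exposes a simple kind label ("source", "loader", "modifier"), which is an illustrative convention rather than the pipeline's actual interface:

```python
_EXPECTED_ORDER = {"source": 0, "loader": 1, "modifier": 2}


def validate_stage_order(kinds: list[str]) -> list[str]:
    """Warn when a loader appears before a source, a modifier before a loader, etc."""
    warnings = []
    ranks = [_EXPECTED_ORDER[k] for k in kinds if k in _EXPECTED_ORDER]
    for previous, current in zip(ranks, ranks[1:]):
        if current < previous:
            warnings.append("Stages are not in source -> loader -> modifier order")
            break
    return warnings


# validate_stage_order(["loader", "source"]) -> one warning
# validate_stage_order(["source", "loader", "modifier"]) -> []
```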
The conftest.py used the hardcoded relative path "../src/anemoi/training/config", which only worked from the training/tests/ directory, causing Hydra to hang when tests were run from the training/ root.

Changes:
- Add _get_config_path() helper to dynamically locate the config directory
- Supports running tests from any directory (training/ or training/tests/)
- Add lazy import of AnemoiDatasetsDataModule for performance
- Tests now run in 40s instead of hanging for 2+ minutes

Fixes issue where `cd training && pytest tests/checkpoint/` would time out.
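Roughly what such a helper could look like, assuming the conftest lives at training/tests/conftest.py; anchoring on the test file's own location is what makes it independent of the working directory:

```python
from pathlib import Path


def _get_config_path() -> str:
    """Locate training/src/anemoi/training/config relative to this test file."""
    tests_dir = Path(__file__).resolve().parent  # .../training/tests
    config_dir = tests_dir.parent / "src" / "anemoi" / "training" / "config"
    return str(config_dir)
```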
Updated error handling tests to expect the proper CheckpointConfigError exception instead of generic ValueError, matching the actual implementation in catalog.py.

Changes:
- test_get_source_target_when_empty: expect CheckpointConfigError
- test_get_loader_target_when_empty: expect CheckpointConfigError
- test_get_modifier_target_when_empty: expect CheckpointConfigError
- Update assertions to match actual error message format
Remove undefined placeholder functions from __all__ that don't exist yet:
- create_error_context
- ensure_checkpoint_error
- log_checkpoint_error
- map_pytorch_error_to_checkpoint_error
- validate_checkpoint_keys

Apply ruff auto-fixes:
- Sort __all__ alphabetically within sections
- Add trailing commas where missing

Fixes 5 F822 ruff errors (undefined name in __all__).
Fix three categories of pre-commit failures without using noqa suppressions:

1. RUF039: Add raw string prefix to regex pattern in mlflow/logger.py
   - Changed re.compile(b"...") to re.compile(rb"...")
2. docsig: Add Returns section to freeze_submodule_by_name docstring
   - Added Returns section documenting bool return value
   - Improved documentation completeness for deprecated function
3. ARG002: Rename unused parameters in Lightning callback methods
   - Prefixed unused parameters with underscore (_trainer, _pl_module)
   - Maintains interface compliance while signaling intentionally unused params
   - Follows Python convention for unused interface method arguments

All pre-commit hooks now pass successfully.
## Phase 1: Checkpoint Pipeline Infrastructure - Progress Update

### Summary

This PR implements the foundational infrastructure for the new checkpoint architecture, establishing the core abstractions, pipeline pattern, and utilities that all checkpoint operations will build upon.

Related Issues: #493 (Phase 1), part of the 5-phase checkpoint architecture refactoring.

### 🎯 Motivation

The current checkpoint handling in anemoi-training is monolithic and difficult to extend. This PR establishes a clean, three-layer pipeline architecture.

### 📦 Changes Included

#### New Files (VERIFIED)

#### Modified Files (VERIFIED)

#### Pending Additions (Not Yet in PR)

### 🔑 Key Features

#### 1. Pipeline Pattern

Clean, composable stages for checkpoint processing:

```python
# Example usage (programmatic)
pipeline = CheckpointPipeline(
    stages=[
        CheckpointSource(...),   # Phase 2
        LoadingStrategy(...),    # Phase 2
        ModelModifier(...),      # Phase 2
    ],
    async_execution=True,
)
context = await pipeline.execute(initial_context)
```

#### 2. Component Catalog

Automatic discovery of pipeline components using reflection:

```python
# No manual registration needed!
ComponentCatalog.list_sources()    # Auto-discovers CheckpointSource subclasses
ComponentCatalog.list_loaders()    # Auto-discovers LoadingStrategy subclasses
ComponentCatalog.list_modifiers()  # Auto-discovers ModelModifier subclasses
```

#### 3. Multi-Format Support

Seamless handling of different checkpoint formats. Auto-detection and conversion included.

#### 4. Comprehensive Error Handling

Rich exception hierarchy for debugging:

```python
try:
    context = await pipeline.execute(context)
except CheckpointNotFoundError as e:
    logger.error(f"Checkpoint not found: {e}")
except CheckpointIncompatibleError as e:
    logger.error(f"Incompatible checkpoint: {e}")
```

#### 5. Async-First Design

Efficient async operations with sync compatibility:

```python
# Async mode (default)
context = await pipeline.execute(context)

# Sync mode (when needed)
pipeline = CheckpointPipeline(stages, async_execution=False)
context = pipeline.execute(context)
```

### 🧪 Testing Strategy

#### Coverage Metrics (VERIFIED)

#### Test Execution

```bash
# Run all checkpoint tests
pytest training/tests/checkpoint/ -v

# Run with coverage
pytest training/tests/checkpoint/ --cov=anemoi.training.checkpoint --cov-report=html

# Quick verification
pytest training/tests/checkpoint/ -q
# Output: 211 passed, 2 skipped in 46.46s
```

#### Test Coverage by Module

### 🔄 Integration Points

#### With PR #464 (Checkpoint Acquisition)

```python
# PR #464 will implement CheckpointSource subclasses
class S3Source(PipelineStage):  # Uses base.py abstractions
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        # Load from S3, populate context.checkpoint_data
        ...
```

#### With PR #494 (Loading Orchestration)

```python
# PR #494 will implement LoadingStrategy subclasses
class TransferLearningLoader(PipelineStage):  # Uses base.py abstractions
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        # Apply checkpoint to model with flexibility
        ...
```

#### With PR #442 (Model Modifiers)

```python
# PR #442 will implement ModelModifier as PipelineStage
class FreezingModifier(PipelineStage):  # Uses base.py abstractions
    async def process(self, context: CheckpointContext) -> CheckpointContext:
        # Freeze model layers post-loading
        ...
```
```python
    def _use_modern_checkpoint_pipeline(self, model: GraphForecaster) -> GraphForecaster:
        """Use the modern checkpoint pipeline system."""
        try:
            import asyncio
```
Is the asyncio functionality optional? (As in, could people use the pipeline without it?)
```python
import torch.nn as nn

# Import pipeline exceptions for consistent error handling
from anemoi.training.checkpoint.exceptions import CheckpointConfigError
```
@bdvllrs also implemented some exceptions in https://github.com/ecmwf/anemoi-core/blob/main/models/src/anemoi/models/migrations/migrator.py for the migration component. Some might address different aspects, but I wonder if we should have a common module to store all of these and avoid duplication.
```python
# Load the checkpoint with proper error handling
# Get device from model parameters, defaulting to CPU
device = next(model.parameters()).device if len(list(model.parameters())) > 0 else "cpu"
```
Could we add a log to show which device is being used?
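For example, something along these lines (illustrative only, using whatever logger the module already defines):

```python
LOGGER.info("Loading checkpoint onto device: %s", device)
```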
```python
logger = logging.getLogger(__name__)


class ComponentCatalog:
```
Seems like useful functionality but I wonder if this should live in anemoi-utils or potentially https://github.com/ecmwf/anemoi-registry/tree/main?
(As in, how much of this is specific to checkpoint-related classes versus extensible to any other anemoi class?)
```python
def load_checkpoint(
    checkpoint_path: Path | str,
    checkpoint_format: Literal["lightning", "pytorch", "safetensors", "state_dict"] | None = None,
```
Can we use all of these formats at the moment? (As safetensors is not yet supported, I would be in favour of keeping this simple for now and extending it when safetensors functionality is introduced!)
```python
LOGGER = logging.getLogger(__name__)


class CheckpointPipelineCallback(Callback):
```
Do I understand correctly that the pipeline is then orchestrated and executed via a PyTorch Lightning callback?
```python
msg = (
    f"Cannot load safetensors checkpoint '{path}': safetensors library not available.\n"
    "Install with: pip install safetensors\n"
    "Or use a different checkpoint format (.ckpt, .pt, .pth)"
```
What would be the benefit of introducing this many formats here? And how easy would it then be to share checkpoints between users across many formats? (I guess this relates to my comment that we could consider safetensors in a second phase.)
```python
>>> # Get component target path
>>> target = ComponentCatalog.get_source_target('s3')
>>> print(target)
>>> 'anemoi.training.checkpoint.sources.S3Source'
```
In the case of an S3 bucket, how would authentication be handled? I'd focus for now on supporting local checkpoints.
```python
LOGGER = logging.getLogger(__name__)
```
Could this file possibly be unified with callbacks/checkpoint.py?
Remove safetensors format support to simplify checkpoint handling.

Source changes:
- Remove safetensors optional dependency from pyproject.toml
- Remove safetensors import and HAS_SAFETENSORS flag from formats.py
- Update checkpoint_format type hints to exclude "safetensors"
- Remove safetensors loading/saving logic and helper functions
- Remove .safetensors from supported file extensions in exceptions
- Update documentation to reflect supported formats only

Test changes:
- Remove safetensors fixture from conftest.py
- Remove 8 safetensors test methods from test_formats.py
- Remove safetensors skip logic from test_utils.py
- Remove unused unittest.mock.patch import

Net changes: -55 lines from source, -157 lines from tests

Supported formats after this change: lightning, pytorch, state_dict
## Summary

This PR implements Phase 1 of the checkpoint pipeline infrastructure, establishing the core abstractions and pipeline pattern for flexible checkpoint handling in Anemoi.

## Changes

## Technical Details
The component catalog uses a hybrid approach for identifying abstract classes:
This ensures true dynamic discovery without hardcoded component lists, making the system easily extensible.
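A hedged sketch of what that hybrid check could look like; the exact name-based convention used by ComponentCatalog may differ:

```python
import inspect


def _is_abstract(cls: type) -> bool:
    """Treat a class as abstract if ABC says so, or if its name marks it as a base."""
    if inspect.isabstract(cls):  # ABC-based: has unimplemented @abstractmethod members
        return True
    # Name-based fallback for base classes that define no abstract methods
    return cls.__name__.startswith("Base") or cls.__name__ in {
        "CheckpointSource",
        "LoadingStrategy",
        "ModelModifier",
    }
```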
## Testing

## Related Issues

Closes #493

## Next Steps

This is Phase 1 of a multi-phase implementation:
📚 Documentation preview 📚: https://anemoi-training--501.org.readthedocs.build/en/501/
📚 Documentation preview 📚: https://anemoi-graphs--501.org.readthedocs.build/en/501/
📚 Documentation preview 📚: https://anemoi-models--501.org.readthedocs.build/en/501/