@@ -724,9 +724,30 @@ def load_from_hdf5(self, file, index, task=None, func=None):
724
724
if task is None :
725
725
task = self .name
726
726
dset = file ['tasks' ][task ]
727
- if not np .all (dset .attrs ['grid_space' ]):
728
- raise ValueError ("Can only load data from grid space" )
729
- self .load_from_global_grid_data (dset , pre_slices = (index ,), func = func )
727
+ if np .all (dset .attrs ['grid_space' ]):
728
+ self .load_from_global_grid_data (dset , pre_slices = (index ,), func = func )
729
+ elif np .all (~ dset .attrs ['grid_space' ]):
730
+ self .load_from_global_coeff_data (dset , pre_slices = (index ,), func = func )
731
+ else :
732
+ raise ValueError ("Can only load global data from pure grid or coeff space" )
733
+
734
+ def load_from_global_coeff_data (self , global_data , pre_slices = tuple (), func = None ):
735
+ """Load local coeff data from array-like global coeff data."""
736
+ dim = self .dist .dim
737
+ layout = self .dist .coeff_layout
738
+ # Check shapes
739
+ data_shape = global_data .shape [- dim :]
740
+ self_shape = layout .global_shape (self .domain , scales = 1 )
741
+ if data_shape != self_shape :
742
+ raise ValueError ("Cannot change global shape when loading coeff data." )
743
+ # Extract local data from global data
744
+ component_slices = tuple (slice (None ) for cs in self .tensorsig )
745
+ spatial_slices = layout .slices (self .domain , scales = 1 )
746
+ local_slices = pre_slices + component_slices + spatial_slices
747
+ if func is None :
748
+ self [layout ] = global_data [local_slices ]
749
+ else :
750
+ self [layout ] = func (global_data [local_slices ])
730
751
731
752
def load_from_global_grid_data (self , global_data , pre_slices = tuple (), func = None ):
732
753
"""Load local grid data from array-like global grid data."""
0 commit comments