1+ from collections .abc import Mapping
12from os .path import relpath
23from pathlib import Path
3- from typing import Callable , Concatenate , TypeAlias
4+ from typing import Any , Callable , Concatenate , TypeAlias , overload
45
56import numpy as np
67import pytest
2324 dataset_from_kerchunk_refs ,
2425)
2526
26- RoundtripFunction : TypeAlias = Callable [Concatenate [xr .Dataset , Path , ...], xr .Dataset ]
27+ RoundtripFunction : TypeAlias = Callable [
28+ Concatenate [xr .Dataset | xr .DataTree , Path , ...], xr .Dataset | xr .DataTree
29+ ]
2730
2831
2932def test_kerchunk_roundtrip_in_memory_no_concat (array_v3_metadata ):
@@ -111,7 +114,22 @@ def roundtrip_as_kerchunk_parquet(vds: xr.Dataset, tmpdir, **kwargs):
111114 return xr .open_dataset (f"{ tmpdir } /refs.parquet" , engine = "kerchunk" , ** kwargs )
112115
113116
114- def roundtrip_as_in_memory_icechunk (vds : xr .Dataset , tmpdir , ** kwargs ):
117+ @overload
118+ def roundtrip_as_in_memory_icechunk (
119+ vdata : xr .Dataset , tmp_path : Path , ** kwargs
120+ ) -> xr .Dataset : ...
121+ @overload
122+ def roundtrip_as_in_memory_icechunk (
123+ vdata : xr .DataTree , tmp_path : Path , ** kwargs
124+ ) -> xr .DataTree : ...
125+
126+
127+ def roundtrip_as_in_memory_icechunk (
128+ vdata : xr .Dataset | xr .DataTree ,
129+ tmp_path : Path ,
130+ virtualize_kwargs : Mapping [str , Any ] | None = None ,
131+ ** kwargs ,
132+ ) -> xr .Dataset | xr .DataTree :
115133 from icechunk import Repository , Storage
116134
117135 # create an in-memory icechunk store
@@ -120,7 +138,17 @@ def roundtrip_as_in_memory_icechunk(vds: xr.Dataset, tmpdir, **kwargs):
120138 session = repo .writable_session ("main" )
121139
122140 # write those references to an icechunk store
123- vds .virtualize .to_icechunk (session .store )
141+ vdata .virtualize .to_icechunk (session .store , ** (virtualize_kwargs or {}))
142+
143+ if isinstance (vdata , xr .DataTree ):
144+ # read the datatree from icechunk
145+ return xr .open_datatree (
146+ session .store , # type: ignore
147+ engine = "zarr" ,
148+ zarr_format = 3 ,
149+ consolidated = False ,
150+ ** kwargs ,
151+ )
124152
125153 # read the dataset from icechunk
126154 return xr .open_zarr (session .store , zarr_format = 3 , consolidated = False , ** kwargs )
@@ -219,16 +247,14 @@ def test_kerchunk_roundtrip_concat(
219247
220248 roundtrip = roundtrip_func (vds , tmp_path , decode_times = decode_times )
221249
222- if decode_times is False :
223- # assert all_close to original dataset
224- xrt .assert_allclose (roundtrip , ds )
250+ # assert all_close to original dataset
251+ xrt .assert_allclose (roundtrip , ds )
225252
226- # assert coordinate attributes are maintained
227- for coord in ds .coords :
228- assert ds .coords [coord ].attrs == roundtrip .coords [coord ].attrs
229- else :
230- # they are very very close! But assert_allclose doesn't seem to work on datetimes
231- assert (roundtrip .time - ds .time ).sum () == 0
253+ # assert coordinate attributes are maintained
254+ for coord in ds .coords :
255+ assert ds .coords [coord ].attrs == roundtrip .coords [coord ].attrs
256+
257+ if decode_times :
232258 assert roundtrip .time .dtype == ds .time .dtype
233259 assert roundtrip .time .encoding ["units" ] == ds .time .encoding ["units" ]
234260 assert (
@@ -303,6 +329,102 @@ def test_datetime64_dtype_fill_value(
303329 assert roundtrip .a .attrs == vds .a .attrs
304330
305331
332+ @parametrize_over_hdf_backends
333+ @pytest .mark .parametrize (
334+ "roundtrip_func" , [roundtrip_as_in_memory_icechunk ] if has_icechunk else []
335+ )
336+ @pytest .mark .parametrize ("decode_times" , (False , True ))
337+ @pytest .mark .parametrize ("time_vars" , ([], ["time" ]))
338+ @pytest .mark .parametrize ("inherit" , (False , True ))
339+ def test_datatree_roundtrip (
340+ tmp_path : Path ,
341+ roundtrip_func : RoundtripFunction ,
342+ hdf_backend : type [VirtualBackend ],
343+ decode_times : bool ,
344+ time_vars : list [str ],
345+ inherit : bool ,
346+ ):
347+ # set up example xarray dataset
348+ with xr .tutorial .open_dataset ("air_temperature" , decode_times = decode_times ) as ds :
349+ # split into two datasets
350+ ds1 = ds .isel (time = slice (None , 1460 ))
351+ ds2 = ds .isel (time = slice (1460 , None ))
352+
353+ # save both halves to disk as netCDF (in temporary directory)
354+ air1_nc_path = tmp_path / "air1.nc"
355+ air2_nc_path = tmp_path / "air2.nc"
356+ ds1 .to_netcdf (air1_nc_path )
357+ ds2 .to_netcdf (air2_nc_path )
358+
359+ # use open_virtual_dataset to read both files as virtual references
360+ with (
361+ open_virtual_dataset (
362+ str (air1_nc_path ),
363+ loadable_variables = time_vars ,
364+ decode_times = decode_times ,
365+ backend = hdf_backend ,
366+ ) as vds1 ,
367+ open_virtual_dataset (
368+ str (air2_nc_path ),
369+ loadable_variables = time_vars ,
370+ decode_times = decode_times ,
371+ backend = hdf_backend ,
372+ ) as vds2 ,
373+ ):
374+ if not decode_times or not time_vars :
375+ assert vds1 .time .dtype == np .dtype ("float32" )
376+ assert vds2 .time .dtype == np .dtype ("float32" )
377+ else :
378+ assert vds1 .time .dtype == np .dtype ("<M8[ns]" )
379+ assert vds2 .time .dtype == np .dtype ("<M8[ns]" )
380+ assert "units" in vds1 .time .encoding
381+ assert "units" in vds2 .time .encoding
382+ assert "calendar" in vds1 .time .encoding
383+ assert "calendar" in vds2 .time .encoding
384+
385+ vdt = xr .DataTree .from_dict ({"/vds1" : vds1 , "/nested/vds2" : vds2 })
386+
387+ with roundtrip_func (
388+ vdt ,
389+ tmp_path ,
390+ virtualize_kwargs = dict (write_inherited_coords = inherit ),
391+ decode_times = decode_times ,
392+ ) as roundtrip :
393+ assert isinstance (roundtrip , xr .DataTree )
394+
395+ # assert each roundtripped node is all_close to its original dataset
396+ roundtrip_vds1 = roundtrip ["/vds1" ].to_dataset ()
397+ roundtrip_vds2 = roundtrip ["/nested/vds2" ].to_dataset ()
398+ xrt .assert_allclose (roundtrip_vds1 , ds1 )
399+ xrt .assert_allclose (roundtrip_vds2 , ds2 )
400+
401+ # assert coordinate attributes are maintained
402+ for coord in ds1 .coords :
403+ assert ds1 .coords [coord ].attrs == roundtrip_vds1 .coords [coord ].attrs
404+ for coord in ds2 .coords :
405+ assert ds2 .coords [coord ].attrs == roundtrip_vds2 .coords [coord ].attrs
406+
407+ if decode_times :
408+ assert roundtrip_vds1 .time .dtype == ds1 .time .dtype
409+ assert roundtrip_vds2 .time .dtype == ds2 .time .dtype
410+ assert (
411+ roundtrip_vds1 .time .encoding ["units" ]
412+ == ds1 .time .encoding ["units" ]
413+ )
414+ assert (
415+ roundtrip_vds2 .time .encoding ["units" ]
416+ == ds2 .time .encoding ["units" ]
417+ )
418+ assert (
419+ roundtrip_vds1 .time .encoding ["calendar" ]
420+ == ds1 .time .encoding ["calendar" ]
421+ )
422+ assert (
423+ roundtrip_vds2 .time .encoding ["calendar" ]
424+ == ds2 .time .encoding ["calendar" ]
425+ )
426+
427+
306428@parametrize_over_hdf_backends
307429def test_open_scalar_variable (tmp_path : Path , hdf_backend : type [VirtualBackend ]):
308430 # regression test for GH issue #100
0 commit comments