@@ -532,24 +532,27 @@ def to_netcdf(
532
532
return filename
533
533
534
534
def to_datatree (self ):
535
- """Convert InferenceData object to a :class:`~datatree .DataTree`."""
535
+ """Convert InferenceData object to a :class:`~xarray .DataTree`."""
536
536
try :
537
- from datatree import DataTree
538
- except ModuleNotFoundError as err :
539
- raise ModuleNotFoundError (
540
- "datatree must be installed in order to use InferenceData.to_datatree"
537
+ from xarray import DataTree
538
+ except ImportError as err :
539
+ raise ImportError (
540
+ "xarray must be have DataTree in order to use InferenceData.to_datatree. "
541
+ "Update to xarray>=2024.11.0"
541
542
) from err
542
543
return DataTree .from_dict ({group : ds for group , ds in self .items ()})
543
544
544
545
@staticmethod
545
546
def from_datatree (datatree ):
546
- """Create an InferenceData object from a :class:`~datatree .DataTree`.
547
+ """Create an InferenceData object from a :class:`~xarray .DataTree`.
547
548
548
549
Parameters
549
550
----------
550
551
datatree : DataTree
551
552
"""
552
- return InferenceData (** {group : sub_dt .to_dataset () for group , sub_dt in datatree .items ()})
553
+ return InferenceData (
554
+ ** {group : child .to_dataset () for group , child in datatree .children .items ()}
555
+ )
553
556
554
557
def to_dict (self , groups = None , filter_groups = None ):
555
558
"""Convert InferenceData to a dictionary following xarray naming conventions.
@@ -1531,9 +1534,8 @@ def add_groups(
1531
1534
import xarray as xr
1532
1535
from xarray_einstats.stats import XrDiscreteRV
1533
1536
from scipy.stats import poisson
1534
- dist = XrDiscreteRV(poisson)
1535
- log_lik = xr.Dataset()
1536
- log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
1537
+ dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
1538
+ log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
1537
1539
idata2.add_groups({"log_likelihood": log_lik})
1538
1540
idata2
1539
1541
0 commit comments