
Commit 7f8a5eb

mjwillson authored and WeatherBenchX authors committed
Add a statistical inference method based on the stationary bootstrap, with automatic block length selection.
PiperOrigin-RevId: 811459311
1 parent 11853f0 commit 7f8a5eb
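
In outline, the new class added below can be driven like this (an editor's sketch, not code from this commit: the module path and the dimension name 'init_time' are assumptions, since the new file's path isn't shown in this view; the constructor arguments and methods follow the diff below):

    from weatherbenchX.statistical_inference import stationary_bootstrap  # path assumed

    method = stationary_bootstrap.StationaryBootstrap(
        metrics=metrics,                    # Mapping[str, metrics_base.Metric]
        aggregated_statistics=agg_state,    # an aggregation.AggregationState
        experimental_unit_dim='init_time',  # hypothetical unit dimension
        n_replicates=1000,
        mean_block_length=None,  # None selects the block length automatically
    )
    point = method.point_estimates()
    lower, upper = method.confidence_intervals(alpha=0.05)
    p = method.p_values(null_value=0.0)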

File tree

8 files changed: +818 -93 lines changed


setup.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 
 base_requires = [
     "apache_beam[gcp]>=2.31.0",
+    "arch>=5.0",
     "cftime>=1.6.2",
     "numpy>=2.1.3",
     "pandas>=2.2.3",

weatherbenchX/aggregation.py

Lines changed: 13 additions & 7 deletions
@@ -15,7 +15,7 @@
 
 import collections
 import dataclasses
-from typing import Any, Collection, Hashable, Mapping, Sequence
+from typing import Any, Callable, Collection, Hashable, Mapping, Sequence
 
 from weatherbenchX import binning
 from weatherbenchX import weighting
@@ -135,14 +135,20 @@ def sum_along_dims(self, dims: Collection[str]) -> 'AggregationState':
     if self.sum_weighted_statistics is None:
       # Further reduction of a generic zero state is also a zero state.
       return self
+    else:
+      return self.map(lambda x: x.sum(dims, skipna=False))
+
+  def map(
+      self,
+      func: Callable[[xr.DataArray], xr.DataArray],
+  ) -> 'AggregationState':
+    """Apply a function to all DataArrays in the AggregationState."""
+    if self.sum_weighted_statistics is None:
+      raise ValueError('Cannot map a zero AggregationState.')
     sum_weighted_statistics = xarray_tree.map_structure(
-        lambda x: x.sum(dims, skipna=False),
-        self.sum_weighted_statistics,
-    )
+        func, self.sum_weighted_statistics)
     sum_weights = xarray_tree.map_structure(
-        lambda x: x.sum(dims, skipna=False),
-        self.sum_weights,
-    )
+        func, self.sum_weights)
     return AggregationState(sum_weighted_statistics, sum_weights)
 
   def to_data_tree(self) -> xr.DataTree:
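
The new map method generalizes the summing reduction in sum_along_dims; a hedged usage sketch (assuming an existing non-zero AggregationState named state and a hypothetical dimension name 'init_time'):

    # Equivalent to state.sum_along_dims(['init_time']):
    reduced = state.map(lambda x: x.sum(['init_time'], skipna=False))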
New file

Lines changed: 280 additions & 0 deletions
@@ -0,0 +1,280 @@

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bootstrap-based statistical inference methods for evaluation metrics."""

from collections.abc import Mapping, Hashable
import functools

import arch.bootstrap
import numpy as np
from weatherbenchX import aggregation
from weatherbenchX import xarray_tree
from weatherbenchX.metrics import base as metrics_base
from weatherbenchX.statistical_inference import base
from weatherbenchX.statistical_inference import utils

import xarray as xr


def stationary_bootstrap_indices(
    n_data: int,
    mean_block_length: float,
    n_replicates: int,
    dtype: np.typing.DTypeLike = np.int64,
) -> np.ndarray:
  """Samples indices with shape (n_data, n_replicates)."""
  end_block_prob = 1 / mean_block_length
  current_indices = np.random.randint(n_data, size=(n_replicates,), dtype=dtype)
  all_indices = [current_indices]
  for _ in range(1, n_data):
    end_block_flags = np.random.rand(n_replicates) < end_block_prob
    new_random_indices = np.random.randint(
        n_data, size=(n_replicates,), dtype=dtype)
    next_indices = (current_indices + 1) % n_data
    current_indices = np.where(
        end_block_flags, new_random_indices, next_indices)
    all_indices.append(current_indices)
  return np.stack(all_indices, axis=0)

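
# Illustrative usage (an editor's sketch, not part of the committed file):
#
#   idx = stationary_bootstrap_indices(
#       n_data=10, mean_block_length=3.0, n_replicates=4)
#   idx.shape  # (10, 4): one column of resampled time indices per replicate.
#
# Each column concatenates blocks of consecutive, circularly-wrapped indices
# whose lengths are geometrically distributed with mean mean_block_length.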

_REPLICATE_DIM = 'bootstrap_replicate'


class StationaryBootstrap(base.StatisticalInferenceMethod):
  r"""Stationary bootstrap method of Politis and Romano [1].

  With optimal block length selection from [2], [3].

  [1] Politis, D. N. & Romano, J. P. The stationary bootstrap. J. Am. Stat.
      Assoc. 89, 1303–1313 (1994).
  [2] Politis, D. N. & White, H. Automatic Block-Length Selection for the
      Dependent Bootstrap, Econometric Reviews, 23:1, 53-70 (2004).
  [3] Patton, A., Politis, D. N. & White, H. Correction to "Automatic
      Block-Length Selection for the Dependent Bootstrap" by D. Politis and
      H. White, Econometric Reviews, 28:4, 372-375 (2009).
  """

  def __init__(
      self,
      metrics: Mapping[str, metrics_base.Metric],
      aggregated_statistics: aggregation.AggregationState,
      experimental_unit_dim: str,
      n_replicates: int,
      mean_block_length: float | None = None,
      block_length_rounding_resolution: float | None = 30.0,
      stationary_bootstrap_indices_cache_size: int = 10,
  ):
    self._experimental_unit_dim = experimental_unit_dim
    self._mean_block_length = mean_block_length
    self._n_replicates = n_replicates
    self._aggregated_statistics = aggregated_statistics
    self._block_length_rounding_resolution = block_length_rounding_resolution
    self._stationary_bootstrap_indices = functools.lru_cache(
        maxsize=stationary_bootstrap_indices_cache_size)(
            stationary_bootstrap_indices)
    self._original_values = {}
    self._resampled_values = {}
    for metric_name, metric in metrics.items():
      original_values, resampled_values = self._bootstrap_results_for_metric(
          metric)
      self._original_values[metric_name] = original_values
      self._resampled_values[metric_name] = resampled_values

  def _optimal_block_length(self, data_array: xr.DataArray) -> float:
    if self._mean_block_length is not None:
      return self._mean_block_length

    assert self._experimental_unit_dim in data_array.dims
    if data_array.sizes[self._experimental_unit_dim] < 8:
      # For one thing, arch.bootstrap.optimal_block_length fails with a very
      # unfriendly error if given a smaller array.
      raise ValueError(
          'Need at least 8 data points along experimental_unit_dim '
          f'{self._experimental_unit_dim} to set mean_block_length '
          'automatically -- and many more than 8 recommended.')
    data_array = data_array.squeeze()
    assert data_array.ndim == 1

    # .stationary gives the mean block length for use with the stationary
    # bootstrap:
    result = arch.bootstrap.optimal_block_length(
        data_array.data).stationary.item()
    # Values <1 can sometimes show up, but 1 is the minimum.
    result = max(1.0, result)
    if self._block_length_rounding_resolution is not None:
      # Rounding this off makes it a useful key for LRU caching of the
      # bootstrap indices. These need to be sampled separately for each mean
      # block length used, and this forms a significant fraction of total
      # running time. The inference of an optimal block length is noisy enough
      # that rounding off to 1 or 2 significant figures (or the similar but
      # smoother logarithmic rounding below) should be perfectly acceptable.
      result = utils.logarithmic_round(
          result, self._block_length_rounding_resolution)
    return result

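
  # Editor's note (illustrative; assumes utils.logarithmic_round's semantics,
  # which are defined elsewhere in this commit): with the default resolution
  # of 30.0, nearby noisy estimates such as 17.3 and 17.9 may round to the
  # same grid value, letting the lru_cache set up in __init__ reuse one set
  # of sampled bootstrap indices for both.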
  def _bootstrap_results_for_metric(
      self, metric: metrics_base.Metric) -> tuple[
          Mapping[Hashable, xr.DataArray], Mapping[Hashable, xr.DataArray]]:

    overall_values = metrics_base.compute_metric_from_statistics(
        metric, self._aggregated_statistics.sum_along_dims(
            [self._experimental_unit_dim]).mean_statistics())
    per_unit_values = metrics_base.compute_metric_from_statistics(
        metric, self._aggregated_statistics.mean_statistics())
    sum_weighted_stats = {
        stat_name: self._aggregated_statistics.sum_weighted_statistics[
            stat.unique_name]
        for stat_name, stat in metric.statistics.items()
    }
    sum_weights = {
        stat_name: self._aggregated_statistics.sum_weights[
            stat.unique_name]
        for stat_name, stat in metric.statistics.items()
    }
    resampled_values = {}
    for var_name in overall_values.keys():
      # Results for different variables will need to be computed separately,
      # as the optimal block length will depend on the variable.
      #
      # However, we try to avoid resampling the statistics for *all* variables
      # every time we bootstrap based on the optimal block length of a single
      # *one* of them, using the following logic:
      if (len(overall_values) > 1 and
          all(var_name in vars for vars in sum_weighted_stats.values())):
        # A corresponding variable is present in each Statistic, and we make
        # the assumption that this variable in the result only depends on
        # these corresponding variables in the stats, and that we can
        # recompute the Metric with the statistics restricted to just this
        # single variable. This saves us resampling statistics for all the
        # other variables.
        sum_weighted_stats_for_this_var = {
            stat_name: {var_name: vars[var_name]}
            for stat_name, vars in sum_weighted_stats.items()
        }
        sum_weights_for_this_var = {
            stat_name: {var_name: vars[var_name]}
            for stat_name, vars in sum_weights.items()
        }
      else:
        # If there is only a single variable, it's fine to resample all the
        # statistics, since this will only be done once.
        # If there are multiple variables and they don't correspond 1:1 to
        # variables in the statistics, then we can't do any better than
        # resampling all the statistics, even though this may result in some
        # redundant work. This should be a rare edge case though.
        sum_weighted_stats_for_this_var = sum_weighted_stats
        sum_weights_for_this_var = sum_weights

      # The optimal block length will also depend on the specific component
      # within the DataArray for the metric result; for example, different
      # degrees of autocorrelation may be observed for forecast metrics at
      # different lead times. So bootstrap indices need to be sampled
      # separately for each component along any dimensions present in the
      # metric result.
      #
      # We assume that where a dimension of the metric result also occurs in
      # the statistics, the metrics at index i along that dimension only
      # depend on the statistics at index i along the same dimension, and
      # that we can therefore slice the statistics down to a single index
      # along any such dimensions when computing a single index of the metric
      # result.
      #
      # This assumption isn't strictly guaranteed, but it is true in the vast
      # majority of cases, including:
      # * The common case of a per-component metric like RMSE, which is a
      #   scalar quantity computed independently for each component.
      # * Metrics which introduce some additional internal dimensions on
      #   their statistics, but reduce them down to a scalar value in their
      #   output.
      # * Metrics which introduce some additional dimensions in their output
      #   which aren't present in the statistics, but give them dimension
      #   names that differ from any used in the statistics.
      per_var_resampled_values = utils.apply_to_slices(
          functools.partial(self._bootstrap_results_for_metric_scalar,
                            metric, var_name),
          per_unit_values[var_name],
          sum_weighted_stats_for_this_var,
          sum_weights_for_this_var,
          dim=overall_values[var_name].dims,
      )
      resampled_values[var_name] = per_var_resampled_values
    return overall_values, resampled_values

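
  # Editor's note, with hypothetical statistic/variable names: the
  # restriction above turns
  #   {'squared_error': {'t2m': <DataArray>, 'z500': <DataArray>}}
  # into
  #   {'squared_error': {'t2m': <DataArray>}}
  # when resampling for var_name == 't2m'.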
  def _bootstrap_results_for_metric_scalar(
      self,
      metric: metrics_base.Metric,
      var_name: str,
      per_unit_values: xr.DataArray,
      sum_weighted_stats: Mapping[str, Mapping[Hashable, xr.DataArray]],
      sum_weights: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> xr.DataArray:
    n_data = per_unit_values.sizes[self._experimental_unit_dim]
    mean_block_length = self._optimal_block_length(per_unit_values)
    bootstrap_indices = self._stationary_bootstrap_indices(
        n_data=n_data,
        mean_block_length=mean_block_length,
        n_replicates=self._n_replicates,
    )
    bootstrap_indices = xr.DataArray(
        bootstrap_indices, dims=[self._experimental_unit_dim, _REPLICATE_DIM])

    def sum_of_resampled(data):
      return data.isel({self._experimental_unit_dim: bootstrap_indices}).sum(
          self._experimental_unit_dim)

    sum_weighted_stats, sum_weights = xarray_tree.map_structure(
        sum_of_resampled, (sum_weighted_stats, sum_weights))
    mean_stats = xarray_tree.map_structure(
        lambda x, y: x / y, sum_weighted_stats, sum_weights)
    del sum_weighted_stats, sum_weights

    return metric.values_from_mean_statistics(mean_stats)[var_name]

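
  # Editor's note: the resampling above relies on xarray's vectorized
  # indexing. Indexing with a 2-D integer DataArray whose dims are
  # (experimental_unit_dim, 'bootstrap_replicate') replaces the unit
  # dimension with both indexer dims, so the subsequent .sum() over the unit
  # dimension leaves one aggregated value per bootstrap replicate.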
  def point_estimates(self) -> base.MetricValues:
    return self._original_values

  def standard_error_estimates(self) -> base.MetricValues:
    return xarray_tree.map_structure(
        lambda x: x.std(_REPLICATE_DIM, ddof=1), self._resampled_values)

  def confidence_intervals(
      self, alpha: float = 0.05
  ) -> tuple[base.MetricValues, base.MetricValues]:
    # TODO(matthjw): implement BCa intervals.
    return (
        xarray_tree.map_structure(
            lambda x: x.quantile(alpha / 2, _REPLICATE_DIM),
            self._resampled_values),
        xarray_tree.map_structure(
            lambda x: x.quantile(1 - alpha / 2, _REPLICATE_DIM),
            self._resampled_values),
    )

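
  # Editor's note: with the default alpha=0.05 these are the 2.5% and 97.5%
  # quantiles across the 'bootstrap_replicate' dimension, i.e. a plain
  # percentile interval (BCa is left as a TODO above).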
  def p_values(self, null_value: float = 0.) -> base.MetricValues:
    """p-value for a two-sided test with the given null hypothesis value."""

    # Obtained by inverting the percentile confidence interval above.
    # TODO(matthjw): replace with inverting the BCa interval when implemented.

    def p_value_numpy_1d(resampled: np.ndarray) -> float:
      data = np.sort(resampled)
      q = np.linspace(0, 1, data.shape[0])
      empirical_cdf_at_null = np.interp(null_value, data, q)
      return 2 * min(empirical_cdf_at_null, 1 - empirical_cdf_at_null)

    def p_value(resampled: xr.DataArray) -> xr.DataArray:
      return xr.apply_ufunc(
          p_value_numpy_1d,
          resampled,
          input_core_dims=[[_REPLICATE_DIM]],
          vectorize=True)

    return xarray_tree.map_structure(p_value, self._resampled_values)
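
As a worked check of the p-value inversion above (editor's illustration): if 990 of 1000 bootstrap replicates of a score difference lie above null_value = 0, the empirical CDF at 0 is roughly 0.01, so the two-sided p-value is 2 * min(0.01, 0.99) = 0.02.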
