1 change: 1 addition & 0 deletions setup.py
@@ -16,6 +16,7 @@

base_requires = [
"apache_beam[gcp]>=2.31.0",
"arch>=5.0",
"cftime>=1.6.2",
"numpy>=2.1.3",
"pandas>=2.2.3",
356 changes: 354 additions & 2 deletions weatherbenchX/statistical_inference/bootstrap.py
@@ -13,9 +13,11 @@
# limitations under the License.
"""Bootstrap-based statistical inference methods for evaluation metrics."""

from collections.abc import Mapping, Hashable
import functools
from typing import final

import arch.bootstrap
import numpy as np
from weatherbenchX import aggregation
from weatherbenchX import xarray_tree
@@ -30,7 +32,31 @@


class Bootstrap(base.StatisticalInferenceMethod):
r"""Superclass for bootstrap-based statistical inference methods."""
r"""Superclass for bootstrap-based statistical inference methods.

# General caveats about bootstrap confidence intervals

While a somewhat general tool, the bootstrap is no magic bullet: it can
exhibit biases and its confidence intervals can have poor coverage, especially
when the sample size is not large. Unfortunately, coverage is often less than
advertised (i.e. the CIs are too narrow). Highly non-linear functions and
highly skewed distributions can also cause it to perform poorly, although the
BCa method is sometimes able to mitigate this.

Ideally we would like a confidence interval for the true underlying value
f(E[X]) where X are our statistics, E is their expectation (or limiting value
of the mean under infinite data) and f is our values_from_mean_statistics
function.

What bootstrap methods actually aim to give us is a confidence interval for
the expectation of the finite-sample estimator: E[f(1/N \sum_n X_n)].

In cases where f is linear, this estimator is unbiased for f(E[X]) and so
these two are the same thing, but when f is nonlinear we will have to live
with the mismatch. For example when f is concave, like the square root used
to obtain an RMSE from a mean of squared errors, Jensen's inequality gives
E[f(1/N \sum_n X_n)] <= f(E[X]), with the gap shrinking as N grows. In
particular this means that bootstrap intervals may not be strictly comparable
across different sample sizes (they are intervals for different quantities),
although if doing a paired test this is not a problem.
"""

# Subclass constructors must set these.
_resampled_values: base.MetricValues
@@ -204,3 +230,329 @@ def __init__(
[experimental_unit_dim]).mean_statistics())
self._resampled_values = metrics_base.compute_metrics_from_statistics(
metrics, resampled_stats.mean_statistics())


def stationary_bootstrap_indices(
n_data: int,
mean_block_length: float,
n_replicates: int,
dtype: np.typing.DTypeLike = np.int64,
) -> np.ndarray:
"""Samples indices for stationary bootstrap, shape (n_data, n_replicates)."""
end_block_prob = 1/mean_block_length
current_indices = np.random.randint(n_data, size=(n_replicates,), dtype=dtype)
all_indices = [current_indices]
for _ in range(1, n_data):
end_block_flags = np.random.rand(n_replicates) < end_block_prob
new_random_indices = np.random.randint(
n_data, size=(n_replicates,), dtype=dtype)
# Blocks wrap around in a periodic fashion. This feature of the stationary
# bootstrap method exists to avoid endpoint bias / ensure that each data
# point is equally likely to be sampled.
next_indices = (current_indices+1) % n_data
current_indices = np.where(
end_block_flags, new_random_indices, next_indices)
all_indices.append(current_indices)
return np.stack(all_indices, axis=0)
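

# A minimal usage sketch for `stationary_bootstrap_indices` (the series,
# block length and replicate count here are arbitrary illustrative choices):
#
#   series = np.random.randn(100)                      # (n_data,)
#   indices = stationary_bootstrap_indices(
#       n_data=100, mean_block_length=5.0, n_replicates=1000)
#   replicate_means = series[indices].mean(axis=0)     # (n_replicates,)
#   ci = np.percentile(replicate_means, [2.5, 97.5])   # naive 95% percentile CI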


class StationaryBootstrap(Bootstrap):
r"""Stationary bootstrap method of Politis and Romano [1].

This is a block bootstrap resampling method designed to work with stationary
time series data where there may be some temporal dependence. By default we
use the optimal block length selection procedure from [2], [3], and this is
done separately for every metric, variable, and index along any extra
dimensions present in the metric result.

The core method isn't limited to Metrics which are simple means or linear
functions of means and so has broader applicability than the t-test or
autocorrelation-corrected versions of it. There are still some caveats to
note, however:

# Optimal block length selection for functions of multiple time-series

The optimal block length selection algorithm we use was only designed to apply
to means of univariate time series, but our metrics in general can be computed
from arbitrary functions of the means of multiple time-series of statistics.

The compromise we make is to apply the block length selection procedure to
scalar values of the metric computed on a per-timestep basis. For metrics that
are a simple mean or a linear function of means this is using the method
exactly as intended. For non-linear functions of means, essentially we're
approximating f(mean(X)) as mean(f(X)) for the purposes of the block length
selection. This is justified when the function f is close to linear over the
range of variation of *per-timestep* values of X, but if f is very nonlinear
over this range then block length selection can fail badly and you are
advised to select an appropriate block length manually instead.
TODO(matthjw): Provide more options for this, e.g. a way to specify a
particular statistic to use for the block length selection instead of the
per-timestep values of the metric itself.

Other possible heuristic approaches seen in the literature are to base it on
the average or maximum of the optimal block lengths computed for each separate
univariate statistics time series, or on a VAR (vector auto-regressive) model
of the statistics. These may sometimes be too conservative, because they don't
take into account the potential of the function f to reduce the effect of
autocorrelation in some cases, for example where f computes something like
a difference of two positively-correlated time-series. A better solution may
be to linearize f around the mean, and then apply block length selection to
the per-timestep values of this linearized function. This would require the
gradient of f however.

From what I understand, automatic block-length selection for bootstrap methods
applied to multivariate time series data is a difficult open problem in
statistics. If the default approach doesn't work for you, you are free to
manually specify the block length to use, and you may sometimes need to.

# Stationarity assumption

While this method makes few distributional assumptions, one assumption it
does make is that the time series of statistics is stationary, meaning the
distribution (including marginal mean and variance, autocorrelation at
different lags, etc.) doesn't change over time. If the distribution of your
data has clear non-random trends over time, including seasonality, then
ideally you would detrend or de-seasonalize the data in some way beforehand,
or use a more tailored method.
Note that it's not uncommon to apply tests like this to data with mild
seasonality though, for better or worse. The hope is that when you are
comparing (e.g.) errors of two models, the errors may be less seasonal or
trended than the ground-truth data itself, and the difference of errors even
less so.

# Weightings

This method can handle non-constant weights being used (via the
AggregationState) to compute the means of the statistics. In this case the
weights are treated as randomly sampled alongside the statistics themselves,
with the joint distribution of weights and statistics assumed stationary in
time. So when we think about long-run properties of our confidence intervals
etc., this is in the context of repeated sampling of new weights as well as
new statistics.

In particular, if the per-experimental-unit weights are fixed, non-random
values that are different at different timesteps, this would violate the
stationarity assumption. It would be more of an issue the more uneven the
weights are.

# Finite-sample bias from the bootstrap

While quite a general tool, the bootstrap in general (and the block bootstrap
in particular) is no magic bullet and can exhibit biases, especially when the
sample size (or here, the effective sample size after taking into account
temporal dependence) is small.
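
# Example

A minimal construction sketch (the dimension name 'init_time' and the
replicate count are assumed, illustrative values; `metrics` and
`aggregated_statistics` are constructed as elsewhere in weatherbenchX):

  inference = StationaryBootstrap(
      metrics=metrics,
      aggregated_statistics=aggregated_statistics,
      experimental_unit_dim='init_time',
      n_replicates=1000,
  )

Point estimates and bootstrap replicates are then exposed via the
base.StatisticalInferenceMethod interface that this class implements.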

[1] Politis, D. N. & Romano, J. P. The stationary bootstrap. J. Am. Stat.
Assoc. 89, 1303–1313 (1994).
[2] Politis, D. N. & White, H. Automatic Block-Length Selection for the
Dependent Bootstrap, Econometric Reviews, 23:1, 53-70 (2004).
[3] Patton, A., Politis, D. N. & White, H. Correction to "Automatic
Block-Length Selection for the Dependent Bootstrap" by D. Politis and
H. White, Econometric Reviews, 28:4, 372-375 (2009).
"""

def __init__(
self,
metrics: Mapping[str, metrics_base.Metric],
aggregated_statistics: aggregation.AggregationState,
experimental_unit_dim: str,
n_replicates: int,
mean_block_length: float | None = None,
block_length_rounding_resolution: float | None = 30.0,
stationary_bootstrap_indices_cache_size: int = 50,
):
"""Initializer.

Args:
metrics: The metrics to compute.
aggregated_statistics: The statistics to use to compute the metrics.
experimental_unit_dim: The dimension over which to bootstrap, along which
any serial dependence occurs. Typically this will be a dimension
corresponding to time.
n_replicates: The number of bootstrap replicates to use.
mean_block_length: The mean block length to use. If None, an optimal
block length will be computed automatically for every time series
present in the metrics: for each metric, for each variable within that
metric, and for each index into any dimensions present besides the
`experimental_unit_dim`.
block_length_rounding_resolution: As a performance optimization, we round
off the block length and reuse bootstrap indices when the rounded block
length is the same. This setting controls how aggressively we round the
block length when doing this. The rounding is done in the log domain and
the resolution corresponds to the number of distinct rounded values
between consecutive powers of 10 (e.g. 1 and 10, 10 and 100, etc.).
You can set it to None to disable rounding altogether.
stationary_bootstrap_indices_cache_size: The size of the LRU cache used
to cache bootstrap indices as a function of the rounded block length.
This is a memory / speed trade-off.
"""
self._experimental_unit_dim = experimental_unit_dim
self._mean_block_length = mean_block_length
self._n_replicates = n_replicates
self._aggregated_statistics = aggregated_statistics
self._block_length_rounding_resolution = block_length_rounding_resolution
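# The LRU cache memoizes sampled index arrays keyed on the function's
# arguments (n_data, mean_block_length, n_replicates, dtype), so metrics and
# variables that share a (rounded) block length reuse the same indices.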
self._stationary_bootstrap_indices = functools.lru_cache(
maxsize=stationary_bootstrap_indices_cache_size)(
stationary_bootstrap_indices)
self._point_estimates = {}
self._resampled_values = {}
for metric_name, metric in metrics.items():
point_estimates, resampled_values = self._bootstrap_results_for_metric(
metric)
self._point_estimates[metric_name] = point_estimates
self._resampled_values[metric_name] = resampled_values

def _optimal_block_length(self, data_array: xr.DataArray) -> float:
if self._mean_block_length is not None:
return self._mean_block_length

assert self._experimental_unit_dim in data_array.dims
if data_array.sizes[self._experimental_unit_dim] < 8:
# At least, arch.bootstrap.optimal_block_length fails with a very
# unfriendly error if given a smaller array.
raise ValueError(
'Need at least 8 data points along experimental_unit_dim '
f'{self._experimental_unit_dim} to set mean_block_length '
'automatically -- and many more than 8 recommended.')
data_array = data_array.squeeze()
assert data_array.ndim == 1

# We use the arch library to compute optimal block length, since it's a
# somewhat fiddly procedure. (Ideally we would re-use their entire
# implementation of the stationary bootstrap, but it is quite slow and we
# would have to patch it awkwardly to fix some issues and extend it to
# produce p-values.)
#
# .stationary gives the mean block length for use with the stationary
# bootstrap:
result = arch.bootstrap.optimal_block_length(
data_array.data).stationary.item()
# Values <1 can sometimes show up, but 1 is the minimum.
result = max(1.0, result)
if self._block_length_rounding_resolution is not None:
# Rounding this off makes it a useful key for LRU caching of the
# bootstrap indices. These need to be sampled separately for each mean
# block length used, and this forms a significant fraction of total
# running time. The inference of an optimal block length is noisy enough
# that rounding off to 1 or 2 significant figures (or the similar but
# smoother logarithmic rounding below) should be perfectly acceptable.
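# A sketch of the idea, assuming utils.logarithmic_round matches the
# description in the initializer docstring:
#   rounded = 10 ** (round(log10(x) * resolution) / resolution)
# i.e. `resolution` distinct rounded values per factor-of-10 range.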
result = utils.logarithmic_round(
result, self._block_length_rounding_resolution)
return result

def _bootstrap_results_for_metric(
self, metric: metrics_base.Metric) -> tuple[
Mapping[Hashable, xr.DataArray], Mapping[Hashable, xr.DataArray]]:

point_estimates = metrics_base.compute_metric_from_statistics(
metric, self._aggregated_statistics.sum_along_dims(
[self._experimental_unit_dim]).mean_statistics())
per_unit_values = metrics_base.compute_metric_from_statistics(
metric, self._aggregated_statistics.mean_statistics())
sum_weighted_stats = {
stat_name: self._aggregated_statistics.sum_weighted_statistics[
stat.unique_name]
for stat_name, stat in metric.statistics.items()
}
sum_weights = {
stat_name: self._aggregated_statistics.sum_weights[
stat.unique_name]
for stat_name, stat in metric.statistics.items()
}
resampled_values = {}
for var_name in point_estimates.keys():
# Results for different variables will need to be computed separately,
# as the optimal block length will depend on the variable.
#
# However, we try to avoid recomputing results for *all* variables every
# time we do a bootstrap resample whose block length was selected for a
# single *one* of these variables, using the following logic:
if (len(point_estimates) > 1 and
all(var_name in vars for vars in sum_weighted_stats.values())):
# A corresponding variable is present in each Statistic and we make the
# assumption that this variable in the result only depends on these
# corresponding variables in the stats and that we can recompute the
# Metric with the statistics restricted just to this single variable.
# This saves us resampling statistics for all the other variables.
sum_weighted_stats_for_this_var = {
stat_name: {var_name: vars[var_name]}
for stat_name, vars in sum_weighted_stats.items()
}
sum_weights_for_this_var = {
stat_name: {var_name: vars[var_name]}
for stat_name, vars in sum_weights.items()
}
else:
# If there was only a single variable, it's fine to resample all the
# statistics since this will only be done once.
# If there are multiple variables and they don't correspond 1:1 to
# variables in the statistics, then we can't do any better than
# resampling all the statistics even though this may result in some
# redundant work. This should be a rare edge case though.
sum_weighted_stats_for_this_var = sum_weighted_stats
sum_weights_for_this_var = sum_weights

# The optimal block length will also depend on the specific index along
# any extra dimensions present in the metric result: for example, if a
# lead_time dimension is present, different degrees of autocorrelation may
# be observed for forecast metrics at different lead times. And so
# bootstrap indices will need to be sampled separately for each index
# along any extra dimensions present in the metric result.
#
# We assume that, where a dimension of the metric result also occurs in
# the statistics, the metric values at index i along that dimension only
# depend on the statistics at index i along the same dimension, and that
# therefore slice the statistics down to a single index along any such
# dimensions when computing a single index of the metric result.
#
# This assumption isn't strictly guaranteed, but it is true in the vast
# majority of cases, including:
# * The common case of a component-wise metric like RMSE, which is a
# scalar quantity computed independently for each component.
# * Metrics which introduce some additional internal dimensions on their
# statistics, but reduce them down to a scalar value in their output.
# * Metrics which introduce some additional dimensions in their output
# which aren't present in the statistics, but use a different dimension
# name for them to any dimensions used in the statistics.
per_var_resampled_values = utils.apply_to_slices(
functools.partial(self._bootstrap_results_for_metric_scalar,
metric, var_name),
per_unit_values[var_name],
sum_weighted_stats_for_this_var,
sum_weights_for_this_var,
dim=point_estimates[var_name].dims,
)
resampled_values[var_name] = per_var_resampled_values
return point_estimates, resampled_values

def _bootstrap_results_for_metric_scalar(
self,
metric: metrics_base.Metric,
var_name: str,
per_unit_values: xr.DataArray,
sum_weighted_stats: Mapping[str, Mapping[Hashable, xr.DataArray]],
sum_weights: Mapping[str, Mapping[Hashable, xr.DataArray]],
) -> xr.DataArray:
n_data = per_unit_values.sizes[self._experimental_unit_dim]
mean_block_length = self._optimal_block_length(per_unit_values)
bootstrap_indices = self._stationary_bootstrap_indices(
n_data=n_data,
mean_block_length=mean_block_length,
n_replicates=self._n_replicates,
)
bootstrap_indices = xr.DataArray(
bootstrap_indices, dims=[self._experimental_unit_dim, _REPLICATE_DIM])

def sum_of_resampled(data):
# Note that the dimensions of bootstrap_indices (experimental_unit_dim,
# _REPLICATE_DIM) used for selection will be present in the result of
# the isel call.
return data.isel({self._experimental_unit_dim: bootstrap_indices}).sum(
self._experimental_unit_dim)
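# Resample-and-sum both the weighted statistics and the weights, then take
# their ratio, so that each bootstrap replicate recomputes a properly
# weighted mean over its resampled units.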
sum_weighted_stats, sum_weights = xarray_tree.map_structure(
sum_of_resampled, (sum_weighted_stats, sum_weights))
mean_stats = xarray_tree.map_structure(
lambda x, y: x / y, sum_weighted_stats, sum_weights)
del sum_weighted_stats, sum_weights

return metric.values_from_mean_statistics(mean_stats)[var_name]