# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bootstrap statistical inference methods for evaluation metrics."""

from collections.abc import Hashable, Mapping
import functools

import arch.bootstrap
import numpy as np
import xarray as xr

from weatherbenchX import aggregation
from weatherbenchX import xarray_tree
from weatherbenchX.metrics import base as metrics_base
from weatherbenchX.statistical_inference import base
from weatherbenchX.statistical_inference import utils


def stationary_bootstrap_indices(
    n_data: int,
    mean_block_length: float,
    n_replicates: int,
    dtype: np.typing.DTypeLike = np.int64,
) -> np.ndarray:
  """Samples stationary-bootstrap indices with shape (n_data, n_replicates).

  Each of the n_replicates columns resamples range(n_data) as a sequence of
  blocks of consecutive (circularly wrapped) indices, with block lengths
  geometrically distributed with mean `mean_block_length`.
  """
  end_block_prob = 1 / mean_block_length
  current_indices = np.random.randint(
      n_data, size=(n_replicates,), dtype=dtype)
  all_indices = [current_indices]
  for _ in range(1, n_data):
    end_block_flags = np.random.rand(n_replicates) < end_block_prob
    new_random_indices = np.random.randint(
        n_data, size=(n_replicates,), dtype=dtype)
    next_indices = (current_indices + 1) % n_data
    current_indices = np.where(
        end_block_flags, new_random_indices, next_indices)
    all_indices.append(current_indices)
  return np.stack(all_indices, axis=0)
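# Illustrative usage (not part of the library API): for a 1-D numpy array
# `series`, the sampled indices can be used to form bootstrap replicates of a
# statistic such as the mean. The block length of 5.0 and the 1000 replicates
# below are arbitrary placeholder values:
#
#   idx = stationary_bootstrap_indices(
#       n_data=len(series), mean_block_length=5.0, n_replicates=1000)
#   replicate_means = series[idx].mean(axis=0)  # shape (1000,)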


_REPLICATE_DIM = 'bootstrap_replicate'


class StationaryBootstrap(base.StatisticalInferenceMethod):
  r"""Stationary bootstrap method of Politis and Romano [1].

  With optimal block length selection from [2], [3].

  [1] Politis, D. N. & Romano, J. P. The stationary bootstrap. J. Am. Stat.
  Assoc. 89, 1303–1313 (1994).
  [2] Politis, D. N. & White, H. Automatic Block-Length Selection for the
  Dependent Bootstrap, Econometric Reviews, 23:1, 53-70 (2004).
  [3] Patton, A., Politis, D. N. & White, H. Correction to "Automatic
  Block-Length Selection for the Dependent Bootstrap" by D. Politis and
  H. White, Econometric Reviews, 28:4, 372-375 (2009).
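
  Example usage (an illustrative sketch: `metrics` and `agg_state` stand in
  for a metrics mapping and an aggregation.AggregationState constructed
  elsewhere, and 'init_time' is just a placeholder dimension name):

    bootstrap = StationaryBootstrap(
        metrics=metrics,
        aggregated_statistics=agg_state,
        experimental_unit_dim='init_time',
        n_replicates=1000)
    point_estimates = bootstrap.point_estimates()
    lower, upper = bootstrap.confidence_intervals(alpha=0.05)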
  """

  def __init__(
      self,
      metrics: Mapping[str, metrics_base.Metric],
      aggregated_statistics: aggregation.AggregationState,
      experimental_unit_dim: str,
      n_replicates: int,
      mean_block_length: float | None = None,
      block_length_rounding_resolution: float | None = 30.0,
      stationary_bootstrap_indices_cache_size: int = 10,
  ):
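    """Initializer.

    Args:
      metrics: Mapping from metric name to the Metric to bootstrap.
      aggregated_statistics: AggregationState whose statistics retain the
        `experimental_unit_dim` dimension, so that they can be resampled
        along it.
      experimental_unit_dim: Dimension indexing the experimental units
        (e.g. forecast initialization times) that are resampled.
      n_replicates: Number of bootstrap replicates to draw.
      mean_block_length: Mean block length for the stationary bootstrap. If
        None, an optimal block length is estimated separately for each
        variable/component using arch.bootstrap.optimal_block_length.
      block_length_rounding_resolution: If not None, estimated block lengths
        are logarithmically rounded at this resolution so that nearby values
        can share cached bootstrap indices.
      stationary_bootstrap_indices_cache_size: Maximum number of sampled
        bootstrap index arrays held in the LRU cache.
    """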
    self._experimental_unit_dim = experimental_unit_dim
    self._mean_block_length = mean_block_length
    self._n_replicates = n_replicates
    self._aggregated_statistics = aggregated_statistics
    self._block_length_rounding_resolution = block_length_rounding_resolution
    self._stationary_bootstrap_indices = functools.lru_cache(
        maxsize=stationary_bootstrap_indices_cache_size)(
            stationary_bootstrap_indices)
    self._original_values = {}
    self._resampled_values = {}
    for metric_name, metric in metrics.items():
      original_values, resampled_values = self._bootstrap_results_for_metric(
          metric)
      self._original_values[metric_name] = original_values
      self._resampled_values[metric_name] = resampled_values

  def _optimal_block_length(self, data_array: xr.DataArray) -> float:
    if self._mean_block_length is not None:
      return self._mean_block_length

    assert self._experimental_unit_dim in data_array.dims
    if data_array.sizes[self._experimental_unit_dim] < 8:
      # In particular, arch.bootstrap.optimal_block_length fails with an
      # unfriendly error if given a smaller array.
      raise ValueError(
          'Need at least 8 data points along experimental_unit_dim '
          f'{self._experimental_unit_dim} to set mean_block_length '
          'automatically -- and many more than 8 are recommended.')
    data_array = data_array.squeeze()
    assert data_array.ndim == 1

    # .stationary gives the mean block length for use with the stationary
    # bootstrap:
    result = arch.bootstrap.optimal_block_length(
        data_array.data).stationary.item()
    # Estimates below 1 can sometimes show up, but 1 is the minimum valid
    # mean block length.
    result = max(1.0, result)
    if self._block_length_rounding_resolution is not None:
      # Rounding this off makes it a useful key for LRU caching of the
      # bootstrap indices. These need to be sampled separately for each mean
      # block length used, and this forms a significant fraction of total
      # running time. The inference of an optimal block length is noisy enough
      # that rounding off to 1 or 2 significant figures (or the similar but
      # smoother logarithmic rounding below) should be perfectly acceptable.
      result = utils.logarithmic_round(
          result, self._block_length_rounding_resolution)
    return result

  def _bootstrap_results_for_metric(
      self, metric: metrics_base.Metric) -> tuple[
          Mapping[Hashable, xr.DataArray], Mapping[Hashable, xr.DataArray]]:

    overall_values = metrics_base.compute_metric_from_statistics(
        metric, self._aggregated_statistics.sum_along_dims(
            [self._experimental_unit_dim]).mean_statistics())
    per_unit_values = metrics_base.compute_metric_from_statistics(
        metric, self._aggregated_statistics.mean_statistics())
    sum_weighted_stats = {
        stat_name: self._aggregated_statistics.sum_weighted_statistics[
            stat.unique_name]
        for stat_name, stat in metric.statistics.items()
    }
    sum_weights = {
        stat_name: self._aggregated_statistics.sum_weights[
            stat.unique_name]
        for stat_name, stat in metric.statistics.items()
    }
    resampled_values = {}
    for var_name in overall_values.keys():
      # Results for different variables will need to be computed separately,
      # as the optimal block length will depend on the variable.
      #
      # We try to avoid recomputing results for *all* variables every time we
      # do a bootstrap resample whose block length was chosen for just one of
      # those variables, using the following logic:
      if (len(overall_values) > 1 and
          all(var_name in vars for vars in sum_weighted_stats.values())):
        # A corresponding variable is present in each Statistic. We assume
        # that this variable in the result only depends on these corresponding
        # variables in the statistics, and that we can therefore recompute the
        # Metric with the statistics restricted to just this single variable.
        # This saves us resampling statistics for all the other variables.
        sum_weighted_stats_for_this_var = {
            stat_name: {var_name: vars[var_name]}
            for stat_name, vars in sum_weighted_stats.items()
        }
        sum_weights_for_this_var = {
            stat_name: {var_name: vars[var_name]}
            for stat_name, vars in sum_weights.items()
        }
      else:
        # If there is only a single variable, it's fine to resample all the
        # statistics since this will only be done once.
        # If there are multiple variables and they don't correspond 1:1 to
        # variables in the statistics, then we can't do any better than
        # resampling all the statistics, even though this may result in some
        # redundant work. This should be a rare edge case though.
        sum_weighted_stats_for_this_var = sum_weighted_stats
        sum_weights_for_this_var = sum_weights

      # The optimal block length will also depend on the specific component
      # within the DataArray for the metric result; for example, different
      # degrees of autocorrelation may be observed for forecast metrics at
      # different lead times. Bootstrap indices therefore need to be sampled
      # separately for each component along any dimensions present in the
      # metric result.
      #
      # We assume that, where a dimension of the metric result also occurs in
      # the statistics, the metric values at index i along that dimension only
      # depend on the statistics at index i along the same dimension, and that
      # we can therefore slice the statistics down to a single index along any
      # such dimensions when computing a single index of the metric result.
      #
      # This assumption isn't strictly guaranteed, but it is true in the vast
      # majority of cases, including:
      # * The common case of a per-component metric like RMSE, which is a
      #   scalar quantity computed independently for each component.
      # * Metrics which introduce some additional internal dimensions on their
      #   statistics, but reduce them down to a scalar value in their output.
      # * Metrics which introduce some additional dimensions in their output
      #   which aren't present in the statistics, but use dimension names for
      #   them that differ from any dimensions used in the statistics.
      per_var_resampled_values = utils.apply_to_slices(
          functools.partial(self._bootstrap_results_for_metric_scalar,
                            metric, var_name),
          per_unit_values[var_name],
          sum_weighted_stats_for_this_var,
          sum_weights_for_this_var,
          dim=overall_values[var_name].dims,
      )
      resampled_values[var_name] = per_var_resampled_values
    return overall_values, resampled_values

  def _bootstrap_results_for_metric_scalar(
      self,
      metric: metrics_base.Metric,
      var_name: str,
      per_unit_values: xr.DataArray,
      sum_weighted_stats: Mapping[str, Mapping[Hashable, xr.DataArray]],
      sum_weights: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> xr.DataArray:
    n_data = per_unit_values.sizes[self._experimental_unit_dim]
    mean_block_length = self._optimal_block_length(per_unit_values)
    bootstrap_indices = self._stationary_bootstrap_indices(
        n_data=n_data,
        mean_block_length=mean_block_length,
        n_replicates=self._n_replicates,
    )
    bootstrap_indices = xr.DataArray(
        bootstrap_indices, dims=[self._experimental_unit_dim, _REPLICATE_DIM])

    def sum_of_resampled(data):
      return data.isel({self._experimental_unit_dim: bootstrap_indices}).sum(
          self._experimental_unit_dim)
    sum_weighted_stats, sum_weights = xarray_tree.map_structure(
        sum_of_resampled, (sum_weighted_stats, sum_weights))
    mean_stats = xarray_tree.map_structure(
        lambda x, y: x / y, sum_weighted_stats, sum_weights)
    del sum_weighted_stats, sum_weights

    return metric.values_from_mean_statistics(mean_stats)[var_name]

  def point_estimates(self) -> base.MetricValues:
    return self._original_values

  def standard_error_estimates(self) -> base.MetricValues:
    return xarray_tree.map_structure(
        lambda x: x.std(_REPLICATE_DIM, ddof=1), self._resampled_values)

  def confidence_intervals(
      self, alpha: float = 0.05
  ) -> tuple[base.MetricValues, base.MetricValues]:
    # TODO(matthjw): implement BCa intervals.
    return (
        xarray_tree.map_structure(
            lambda x: x.quantile(alpha / 2, _REPLICATE_DIM),
            self._resampled_values),
        xarray_tree.map_structure(
            lambda x: x.quantile(1 - alpha / 2, _REPLICATE_DIM),
            self._resampled_values),
    )

  def p_values(self, null_value: float = 0.) -> base.MetricValues:
    """p-values for two-sided tests against the given null hypothesis value."""

    # Obtained by inverting the percentile confidence interval above.
    # TODO(matthjw): replace with inverting the BCa interval when implemented.

    def p_value_numpy_1d(resampled: np.ndarray) -> float:
      data = np.sort(resampled)
      q = np.linspace(0, 1, data.shape[0])
      empirical_cdf_at_null = np.interp(null_value, data, q)
      return 2 * min(empirical_cdf_at_null, 1 - empirical_cdf_at_null)

    def p_value(resampled: xr.DataArray) -> xr.DataArray:
      return xr.apply_ufunc(
          p_value_numpy_1d,
          resampled,
          input_core_dims=[[_REPLICATE_DIM]],
          vectorize=True)

    return xarray_tree.map_structure(p_value, self._resampled_values)