Skip to content

Commit d4dbf1a

Browse files
yaniyuval and Weatherbench2 authors
authored and committed
Add the option to exclude years from the probabilistic climatology if those years fall within a certain range of the initial time.
PiperOrigin-RevId: 755443287
1 parent d29e269 commit d4dbf1a

File tree

2 files changed

+220
-24
lines changed

2 files changed

+220
-24
lines changed

scripts/compute_probabilistic_climatological_forecasts.py

Lines changed: 116 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -253,6 +253,26 @@
253253
' with_replacement=False.'
254254
),
255255
)
256+
LEAVE_OUT_IF_IN_CLIMATOLOGY = flags.DEFINE_boolean(
257+
'leave_out_if_in_climatology',
258+
False,
259+
help=(
260+
'If True, and an initial_time is within the climatology year range, '
261+
'sampling will exclude the year of the initial_time and the '
262+
'subsequent year from the pool of available climatology years '
263+
'for that specific initial_time. Only applies if with_replacement=True.'
264+
),
265+
)
266+
NUM_YEARS_TO_EXCLUDE = flags.DEFINE_integer(
267+
'num_years_to_exclude',
268+
0,
269+
help=(
270+
'Relevant only if LEAVE_OUT_IF_IN_CLIMATOLOGY=True. Number of years '
271+
'to exclude from sampling for each initial_time (not including the '
272+
'initial time-s year). '
273+
),
274+
)
275+
256276
SEED = flags.DEFINE_integer(
257277
'seed', 802701, help='Seed for the random number generator.'
258278
)
@@ -341,9 +361,16 @@ def _get_ensemble_size(
341361
climatology_start_year: int,
342362
climatology_end_year: int,
343363
day_window_size: int,
364+
leave_out_if_in_climatology: bool,
344365
) -> int:
345366
"""Computes the ensemble size from FLAGS."""
346367
if ensemble_size_flag_value == -1:
368+
if leave_out_if_in_climatology:
369+
raise flags.ValidationError(
370+
'ensemble_size=-1 (all combinations) is not supported when '
371+
'leave_out_if_in_climatology=True, as the number of available '
372+
'years can change per initial_time.'
373+
)
347374
return len(
348375
_get_possible_year_values(climatology_start_year, climatology_end_year)
349376
) * len(_get_possible_day_perturbation_values(day_window_size))
@@ -375,6 +402,8 @@ def _get_sampled_init_times(
375402
with_replacement: bool,
376403
sample_hold_days: int,
377404
initial_time_edge_behavior: str,
405+
leave_out_if_in_climatology: bool,
406+
num_years_to_exclude: int,
378407
seed: int,
379408
) -> np.ndarray:
380409
"""For each output time, get the times to sample from observations.
@@ -405,6 +434,10 @@ def _get_sampled_init_times(
405434
perturbation. 0 means switch perturbations every consecutive init time.
406435
initial_time_edge_behavior: How to deal with perturbations that move the
407436
sampled day outside of sampled year.
437+
leave_out_if_in_climatology: If True, and initial_time's year is within
438+
climatology, exclude its year and subsequent num_years_to_exclude years.
439+
num_years_to_exclude: Number of years after initial_time's year to
440+
exclude.
408441
seed: Integer seed for the RNG.
409442
410443
Returns:
@@ -428,41 +461,89 @@ def _get_sampled_init_times(
428461
day_perturbation_values = _get_possible_day_perturbation_values(
429462
day_window_size
430463
)
431-
year_values = _get_possible_year_values(
464+
base_climatology_year_pool = _get_possible_year_values(
432465
climatology_start_year, climatology_end_year
433466
)
434467
n_days = len(day_perturbation_values)
435-
n_years = len(year_values)
468+
n_base_years = len(base_climatology_year_pool)
436469
n_times = len(output_times)
437470
if ensemble_size > 0:
438471
pass
439472
elif ensemble_size == -1:
440-
ensemble_size = n_days * n_years
473+
if leave_out_if_in_climatology:
474+
raise flags.ValidationError(
475+
'ensemble_size=-1 (all combinations) is not supported when '
476+
'leave_out_if_in_climatology=True.'
477+
)
478+
ensemble_size = n_days * n_base_years
441479
else:
442480
raise ValueError(f'{ensemble_size=} was not > 0 or -1.')
443-
444-
sample_shape = (ensemble_size, len(output_times))
481+
sample_shape = (ensemble_size, n_times)
482+
years = np.zeros(sample_shape, dtype=int)
445483

446484
# Get sampled years and day_perturbations.
447485
if with_replacement:
448-
# In this case, years and days are iid samples. Easy!
449-
years = rng.integers(
450-
year_values.min(),
451-
year_values.max() + 1, # +1 because the interval is open on the right.
452-
size=sample_shape,
453-
)
454486
day_perturbations = rng.integers(
455487
day_perturbation_values.min(),
456488
day_perturbation_values.max() + 1,
457489
size=sample_shape,
458490
)
459-
else:
491+
if leave_out_if_in_climatology:
492+
if not base_climatology_year_pool.size and ensemble_size > 0:
493+
raise ValueError(
494+
'Climatology year range is empty. Cannot sample years for '
495+
'leave_out_if_in_climatology=True. Check flags: '
496+
f'climatology_start_year={climatology_start_year}, '
497+
f'climatology_end_year={climatology_end_year}.'
498+
)
499+
for j, current_output_time in enumerate(output_times):
500+
current_output_year = current_output_time.year
501+
available_years_for_this_time = [
502+
y
503+
for y in base_climatology_year_pool
504+
if y < current_output_year
505+
or y > current_output_year + num_years_to_exclude
506+
]
507+
508+
if not available_years_for_this_time:
509+
if ensemble_size > 0:
510+
raise ValueError(
511+
'No available climatology years to sample for output_time'
512+
)
513+
elif ensemble_size > 0:
514+
years[:, j] = rng.choice(
515+
available_years_for_this_time, size=ensemble_size, replace=True
516+
)
517+
else:
518+
if not n_base_years and ensemble_size > 0:
519+
raise ValueError(
520+
'Climatology year range is empty. Cannot sample years. '
521+
f'Check climatology_start_year={climatology_start_year}, '
522+
f'climatology_end_year={climatology_end_year}.'
523+
)
524+
if n_base_years > 0:
525+
years = rng.integers(
526+
base_climatology_year_pool.min(),
527+
base_climatology_year_pool.max() + 1,
528+
size=sample_shape,
529+
)
530+
531+
else: # with_replacement == False
532+
if leave_out_if_in_climatology:
533+
raise NotImplementedError(
534+
'leave_out_if_in_climatology=True is not currently supported with'
535+
' with_replacement=False due to the complexity of ensuring unique'
536+
' samples per output time with dynamically changing year'
537+
' availability.'
538+
)
460539
if not isinstance(seed, int):
461540
raise AssertionError(
462541
f'{seed=} was not an integer. Seeding with None causes a nasty bug'
463542
' whereby different choices will be used for day_perturbations and'
464543
' years!'
465544
)
545+
n_years = len(base_climatology_year_pool)
546+
year_values = base_climatology_year_pool
466547
tiled_day_window_values = _repeat_along_new_axis(
467548
# tiled_day_window_values.shape = [n_years, n_days, n_times].
468549
# tiled_day_window_values[i, :, j] = day_window_values for every i, j.
@@ -497,27 +578,35 @@ def _get_sampled_init_times(
497578
dayofyears = output_times.dayofyear.values + day_perturbations
498579

499580
if initial_time_edge_behavior == WRAP_YEAR:
500-
for year in range(climatology_start_year, climatology_end_year + 1):
501-
mask = years == year
502-
days_in_this_year = 365 + calendar.isleap(year)
581+
for year_in_sample in np.unique(years):
582+
if ensemble_size == 0:
583+
continue
584+
mask = years == year_in_sample
585+
days_in_this_year = 365 + calendar.isleap(year_in_sample)
503586
dayofyears[mask] = (dayofyears[mask] - 1) % days_in_this_year + 1
504587

505588
elif initial_time_edge_behavior == REFLECT_RANGE:
506-
for year in {climatology_start_year, climatology_end_year}:
507-
mask = years == year
508-
days_in_this_year = 365 + calendar.isleap(year)
509-
if year == climatology_start_year:
510-
# Transform e.g. 1 --> 1, 0 --> 2, -1 --> 3
589+
for year_at_climatology_edge in {
590+
climatology_start_year,
591+
climatology_end_year,
592+
}:
593+
if ensemble_size == 0:
594+
continue
595+
mask = years == year_at_climatology_edge
596+
if not np.any(mask):
597+
continue
598+
599+
days_in_this_year = 365 + calendar.isleap(year_at_climatology_edge)
600+
if year_at_climatology_edge == climatology_start_year:
511601
dayofyears[mask] = np.where(
512602
dayofyears[mask] >= 1,
513603
dayofyears[mask],
514604
np.abs(dayofyears[mask]) + 2,
515605
)
516-
elif year == climatology_end_year:
606+
elif year_at_climatology_edge == climatology_end_year:
517607
dayofyears[mask] = np.where(
518608
dayofyears[mask] <= days_in_this_year,
519609
dayofyears[mask],
520-
# If d > 365, set to 2*365 - d = 365 - (d - 365)
521610
2 * days_in_this_year - dayofyears[mask],
522611
)
523612
elif initial_time_edge_behavior == NO_EDGE:
@@ -549,10 +638,10 @@ def _get_sampled_init_times(
549638
)
550639
hold_idx = np.repeat(
551640
# E.g. hold_idx = [0, 0, ..., 0, 1, 1, ..., 1, 2, ...]
552-
np.arange(len(output_times) // hold_stride + 1)[:, np.newaxis],
641+
np.arange(n_times // hold_stride + 1)[:, np.newaxis],
553642
hold_stride,
554643
axis=1,
555-
).ravel()[: len(output_times)]
644+
).ravel()[:n_times]
556645

557646
# Convert np datetimes into δ days, sample-hold, then add back to datetimes.
558647
delta_days = np.array(
@@ -696,6 +785,7 @@ def main(argv: abc.Sequence[str]) -> None:
696785
CLIMATOLOGY_START_YEAR.value,
697786
CLIMATOLOGY_END_YEAR.value,
698787
DAY_WINDOW_SIZE.value,
788+
LEAVE_OUT_IF_IN_CLIMATOLOGY.value,
699789
)
700790

701791
# Define output times and the template.
@@ -723,6 +813,8 @@ def main(argv: abc.Sequence[str]) -> None:
723813
with_replacement=WITH_REPLACEMENT.value,
724814
initial_time_edge_behavior=INITIAL_TIME_EDGE_BEHAVIOR.value,
725815
sample_hold_days=SAMPLE_HOLD_DAYS.value,
816+
leave_out_if_in_climatology=LEAVE_OUT_IF_IN_CLIMATOLOGY.value,
817+
num_years_to_exclude=NUM_YEARS_TO_EXCLUDE.value,
726818
seed=SEED.value,
727819
).ravel()
728820

0 commit comments

Comments
 (0)