Skip to content

Commit 11853f0

Browse files
author
WeatherBenchX authors
committed
Adds a regular grid Latitude and longitude binning method.
PiperOrigin-RevId: 809092370
1 parent cc53740 commit 11853f0

File tree

2 files changed

+200
-8
lines changed

2 files changed

+200
-8
lines changed

weatherbenchX/binning.py

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,30 +49,43 @@ def create_bin_mask(
4949
"""
5050

5151

52-
def _region_to_mask(
53-
lat: xr.DataArray,
54-
lon: xr.DataArray,
55-
lat_lims: Tuple[int, int],
56-
lon_lims: Tuple[int, int],
52+
def _create_lat_mask(
53+
lat: xr.DataArray, lat_lims: Tuple[int, int]
5754
) -> xr.DataArray:
58-
"""Computes a boolean mask for a lat/lon limits region."""
55+
"""Computes a boolean mask for a latitude limits region."""
5956
if lat_lims[0] >= lat_lims[1]:
6057
raise ValueError(
6158
f'`lat_lims[0]` must be smaller than `lat_lims[1]`, got {lat_lims}`'
6259
)
63-
lat_mask = np.logical_and(lat >= lat_lims[0], lat <= lat_lims[1])
60+
return np.logical_and(lat >= lat_lims[0], lat <= lat_lims[1])
6461

62+
63+
def _create_lon_mask(
64+
lon: xr.DataArray, lon_lims: Tuple[int, int]
65+
) -> xr.DataArray:
66+
"""Computes a boolean mask for a longitude limits region."""
6567
# Make sure we are in the [0, 360] interval.
6668
lon = np.mod(lon, 360)
6769
lon_lims = np.mod(lon_lims[0], 360), np.mod(lon_lims[1], 360)
68-
6970
if lon_lims[1] > lon_lims[0]:
7071
# Same as the latitude.
7172
lon_mask = np.logical_and(lon >= lon_lims[0], lon <= lon_lims[1])
7273
else:
7374
# In this case it means we need to wrap longitude around the other side of
7475
# the globe.
7576
lon_mask = np.logical_or(lon <= lon_lims[1], lon >= lon_lims[0])
77+
return lon_mask
78+
79+
80+
def _region_to_mask(
81+
lat: xr.DataArray,
82+
lon: xr.DataArray,
83+
lat_lims: Tuple[int, int],
84+
lon_lims: Tuple[int, int],
85+
) -> xr.DataArray:
86+
"""Computes a boolean mask for a lat/lon limits region."""
87+
lat_mask = _create_lat_mask(lat, lat_lims)
88+
lon_mask = _create_lon_mask(lon, lon_lims)
7689
return np.logical_and(lat_mask, lon_mask)
7790

7891

@@ -185,6 +198,93 @@ def create_bin_mask(
185198
return masks
186199

187200

201+
class LatitudeBins(Binning):
202+
"""Class for binning by latitude bands."""
203+
204+
def __init__(
205+
self,
206+
degrees: float,
207+
lat_range: Tuple[int, int] = (-90, 90),
208+
bin_dim_name: str = 'latitude_bins',
209+
):
210+
"""Init.
211+
212+
Args:
213+
degrees: Grid spacing in degrees.
214+
lat_range: Tuple of (min_lat, max_lat).
215+
bin_dim_name: Name of binning dimension.
216+
"""
217+
super().__init__(bin_dim_name)
218+
self._degrees = degrees
219+
self._lat_bins = np.arange(
220+
lat_range[0], lat_range[1] + self._degrees, self._degrees
221+
)
222+
223+
def create_bin_mask(
224+
self,
225+
statistic: xr.DataArray,
226+
) -> xr.DataArray:
227+
"""Creates a bin mask for a statistic."""
228+
masks = []
229+
for lat_start in self._lat_bins[:-1]:
230+
lat_end = lat_start + self._degrees
231+
mask = _create_lat_mask(
232+
statistic.latitude,
233+
(lat_start, lat_end),
234+
)
235+
# Broadcast the mask to the shape of statistic
236+
mask = mask.broadcast_like(statistic)
237+
mask = mask.expand_dims(dim=self.bin_dim_name, axis=0)
238+
mask.coords[self.bin_dim_name] = np.array([lat_start])
239+
masks.append(mask)
240+
return xr.concat(masks, dim=self.bin_dim_name)
241+
242+
243+
class LongitudeBins(Binning):
244+
"""Class for binning by longitude bands."""
245+
246+
def __init__(
247+
self,
248+
degrees: float,
249+
lon_range: Tuple[int, int] = (0, 360),
250+
bin_dim_name: str = 'longitude_bins',
251+
):
252+
"""Init.
253+
254+
Args:
255+
degrees: Grid spacing in degrees.
256+
lon_range: Tuple of (min_lon, max_lon).
257+
bin_dim_name: Name of binning dimension.
258+
"""
259+
super().__init__(bin_dim_name)
260+
self._degrees = degrees
261+
lon_end = lon_range[1]
262+
if lon_range[0] >= lon_range[1]:
263+
lon_end += 360
264+
self._lon_bins = np.arange(
265+
lon_range[0], lon_end + self._degrees, self._degrees
266+
)
267+
268+
def create_bin_mask(
269+
self,
270+
statistic: xr.DataArray,
271+
) -> xr.DataArray:
272+
"""Creates a bin mask for a statistic."""
273+
masks = []
274+
for lon_start in self._lon_bins[:-1]:
275+
lon_end = lon_start + self._degrees
276+
mask = _create_lon_mask(
277+
statistic.longitude,
278+
(lon_start, lon_end),
279+
)
280+
# Broadcast the mask to the shape of statistic
281+
mask = mask.broadcast_like(statistic)
282+
mask = mask.expand_dims(dim=self.bin_dim_name, axis=0)
283+
mask.coords[self.bin_dim_name] = np.array([np.mod(lon_start, 360)])
284+
masks.append(mask)
285+
return xr.concat(masks, dim=self.bin_dim_name)
286+
287+
188288
def vectorized_coord_mask(
189289
coord: xr.DataArray,
190290
coord_name: str,

weatherbenchX/binning_test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,37 @@ def test_by_time_unit_from_seconds_binning(self):
151151
mask['prediction_timedelta_sec_hour'], np.arange(0, 6)
152152
)
153153

154+
@parameterized.parameters(
155+
('second', None, np.arange(0, 60)),
156+
('second', [0, 15, 30, 45], [0, 15, 30, 45]),
157+
('minute', None, np.arange(0, 60)),
158+
('minute', [0, 30], [0, 30]),
159+
('hour', None, np.arange(0, 24)),
160+
('hour', [0, 6, 12, 18], [0, 6, 12, 18]),
161+
)
162+
def test_by_time_unit_from_seconds_binning_with_units(
163+
self, unit, bins, expected_bins
164+
):
165+
statistic_values = test_utils.mock_prediction_data(
166+
time_start='2020-01-01T00',
167+
time_stop='2020-01-01T01',
168+
time_resolution='1 hr',
169+
lead_resolution='1 second',
170+
lead_stop='24 hour',
171+
)['2m_temperature']
172+
statistic_values = statistic_values.assign_coords({
173+
'prediction_timedelta_sec': (
174+
statistic_values.prediction_timedelta.dt.total_seconds()
175+
)
176+
})
177+
binning_obj = binning.ByTimeUnitFromSeconds(
178+
unit, 'prediction_timedelta_sec', bins=bins
179+
)
180+
mask = binning_obj.create_bin_mask(statistic_values)
181+
np.testing.assert_array_equal(
182+
mask[f'prediction_timedelta_sec_{unit}'].values, expected_bins
183+
)
184+
154185
def test_by_coord_bins(self):
155186
target_path = resources.files('weatherbenchX').joinpath(
156187
'test_data/metar-timeNominal-by-month'
@@ -222,6 +253,67 @@ def test_by_sets(self):
222253
self.assertEqual(mask.sum('index').sel(station_subset='wrong_set'), 0)
223254
self.assertLen(statistic, mask.sum('index').sel(station_subset='global'))
224255

256+
@parameterized.parameters(
257+
(10, (-90, 90), 18),
258+
(30, (-90, 90), 6),
259+
(20, (0, 60), 3),
260+
)
261+
def test_latitude_bins(self, degrees, lat_range, expected_bins):
262+
statistic_values = test_utils.mock_prediction_data(
263+
time_start='2020-01-01T00', time_stop='2020-01-01T01'
264+
)['2m_temperature']
265+
binning_obj = binning.LatitudeBins(degrees, lat_range)
266+
mask = binning_obj.create_bin_mask(statistic_values)
267+
self.assertEqual(mask.latitude_bins.shape[0], expected_bins)
268+
self.assertTrue(np.all(mask.latitude_bins.values >= lat_range[0]))
269+
self.assertTrue(np.all(mask.latitude_bins.values < lat_range[1]))
270+
self.assertEqual(mask.shape, (expected_bins,) + statistic_values.shape)
271+
# Check that a point is in the correct bin
272+
# Find the latitude closest to 25
273+
lat_val = 25
274+
if not (lat_range[0] <= lat_val < lat_range[1]):
275+
# If 25 is not in range, pick a value that is.
276+
lat_val = (lat_range[0] + lat_range[1]) / 2
277+
lat_idx = np.argmin(np.abs(statistic_values.latitude.values - lat_val))
278+
lon_idx = np.argmin(np.abs(statistic_values.longitude.values - 0))
279+
280+
# Find the bin that contains the selected latitude
281+
expected_bin_idx = (statistic_values.latitude.values[lat_idx] - lat_range[0]) // degrees
282+
self.assertTrue(
283+
mask.isel(latitude_bins=int(expected_bin_idx), latitude=lat_idx, longitude=lon_idx).values.all()
284+
)
285+
286+
@parameterized.parameters(
287+
(10, (0, 360), 36, 10),
288+
(30, (0, 360), 12, 150),
289+
(60, (-180, 180), 6, 0),
290+
(90, (270, 360), 1, 300),
291+
)
292+
def test_longitude_bins(self, degrees, lon_range, expected_bins, test_lon):
293+
statistic_values = test_utils.mock_prediction_data(
294+
time_start='2020-01-01T00', time_stop='2020-01-01T01'
295+
)['2m_temperature']
296+
binning_obj = binning.LongitudeBins(degrees, lon_range)
297+
mask = binning_obj.create_bin_mask(statistic_values)
298+
self.assertEqual(mask.longitude_bins.shape[0], expected_bins)
299+
self.assertEqual(mask.shape, (expected_bins,) + statistic_values.shape)
300+
# Check wrapping
301+
if lon_range == (-180, 180):
302+
self.assertTrue(0 in mask.longitude_bins.values)
303+
304+
# Find the longitude closest to test_lon
305+
lon_idx = np.argmin(np.abs(statistic_values.longitude.values - test_lon))
306+
lat_idx = np.argmin(np.abs(statistic_values.latitude.values - 0))
307+
lon_val = statistic_values.longitude.values[lon_idx]
308+
309+
# Calculate expected bin index: (lon_val - lon_range[0]) // degrees
310+
# This works even with wrapping ranges because lon_bins is constructed correctly.
311+
expected_bin_idx = (lon_val - lon_range[0]) // degrees
312+
313+
self.assertTrue(
314+
mask.isel(longitude_bins=int(expected_bin_idx), latitude=lat_idx, longitude=lon_idx).values.all()
315+
)
316+
225317

226318
if __name__ == '__main__':
227319
absltest.main()

0 commit comments

Comments
 (0)