Skip to content

Commit 95c6f08

Browse files
xingyousongcopybara-github
authored andcommitted
Handle edge case when all labels are NaN.
PiperOrigin-RevId: 673447471
1 parent fb55be9 commit 95c6f08

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

vizier/_src/algorithms/designers/gp/output_warpers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,11 +455,18 @@ class InfeasibleWarperComponent(OutputWarper):
455455
nearly equal to v_infeasible.
456456
"""
457457

458-
_shift: float = attr.field(default=np.nan)
458+
_shift: Optional[float] = attr.field(default=None)
459459

460460
def warp(self, labels_arr: types.Array) -> types.Array:
461461
labels_arr = _validate_labels(labels_arr)
462462
labels_arr = labels_arr.flatten()
463+
464+
if np.isnan(labels_arr).all():
465+
# Edge case when all values are NaN.
466+
self._shift = np.nan
467+
labels_arr[:] = 0
468+
return labels_arr[:, np.newaxis]
469+
463470
labels_range = np.nanmax(labels_arr) - np.nanmin(labels_arr)
464471
warped_bad_value = np.nanmin(labels_arr) - (0.5 * labels_range + 1)
465472
num_feasible = labels_arr.size - np.isnan(labels_arr).sum()
@@ -481,7 +488,7 @@ def warp(self, labels_arr: types.Array) -> types.Array:
481488
return labels_arr[:, np.newaxis]
482489

483490
def unwarp(self, labels_arr: types.Array) -> types.Array:
484-
if np.isnan(self._shift):
491+
if self._shift is None:
485492
raise ValueError('warp() needs to be called before unwarp() is called.')
486493
return labels_arr - self._shift
487494

vizier/_src/algorithms/designers/gp/output_warpers_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,16 @@ def test_warper_removes_nans(self):
490490
labels_warped_infeasible = warper_infeasible.warp(labels)
491491
self.assertEqual(np.isnan(labels_warped_infeasible).sum(), 0)
492492

493+
def test_all_nans(self):
494+
warper_infeasible = output_warpers.InfeasibleWarperComponent()
495+
labels = np.array([[np.nan], [np.nan], [np.nan], [np.nan]])
496+
labels_warped_infeasible = warper_infeasible.warp(labels)
497+
expected = np.array([[0.0], [0.0], [0.0], [0.0]])
498+
np.testing.assert_equal(labels_warped_infeasible, expected)
499+
500+
unwarped = warper_infeasible.unwarp(labels_warped_infeasible)
501+
np.testing.assert_equal(unwarped, labels)
502+
493503
def test_known_arrays(self):
494504
# TODO: Add a couple of parameterized test cases.
495505
self.skipTest('No test cases provided')

0 commit comments

Comments
 (0)