Skip to content

Commit 689b24c

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Skip some more tests with Tensor.name resolution issues.
PiperOrigin-RevId: 527958188
1 parent 25c64d5 commit 689b24c

File tree

2 files changed

+9
-18
lines changed

2 files changed

+9
-18
lines changed

tensorflow_probability/python/sts/fitting_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ class HMCTestsStatic32(test_util.TestCase, _HMCTests):
340340
@test_util.disable_test_for_backend(
341341
disable_jax=True, reason='No variables in JAX backend.')
342342
def test_chain_batch_shape(self, shape_in, expected_batch_shape_out):
343+
self.skipTest('b/275876892')
343344
batch_shape = [2, 3]
344345
num_results = 1
345346
num_timesteps = 5

tensorflow_probability/python/sts/forecast_test.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def _run():
190190

191191
@test_util.jax_disable_test_missing_functionality('fit_with_hmc')
192192
def test_forecast_from_hmc(self):
193+
self.skipTest('b/275876892')
193194
if not tf1.control_flow_v2_enabled():
194195
self.skipTest('test_forecast_from_hmc does not currently work with TF1')
195196

@@ -202,24 +203,13 @@ def test_forecast_from_hmc(self):
202203
observed_time_series = self._build_tensor(np.random.randn(
203204
*(batch_shape + [num_timesteps])))
204205
model = self._build_model(observed_time_series)
205-
try:
206-
samples, _ = fitting.fit_with_hmc(
207-
model,
208-
observed_time_series,
209-
num_results=num_results,
210-
num_warmup_steps=2,
211-
num_variational_steps=2,
212-
)
213-
except NotImplementedError as e:
214-
err_str = str(e)
215-
if "'Tensor' object has no attribute '_name'" in err_str:
216-
# TODO(b/279596122): Enable the test after the upstream issue is fixed.
217-
self.skipTest(
218-
'test_forecast_from_hmc is failing due to Tensor object'
219-
'has no attribute `shape`.'
220-
)
221-
else:
222-
raise e
206+
samples, _ = fitting.fit_with_hmc(
207+
model,
208+
observed_time_series,
209+
num_results=num_results,
210+
num_warmup_steps=2,
211+
num_variational_steps=2,
212+
)
223213

224214
forecast_dist = forecast.forecast(
225215
model,

0 commit comments

Comments
 (0)