@@ -190,6 +190,7 @@ def _run():
190
190
191
191
@test_util .jax_disable_test_missing_functionality ('fit_with_hmc' )
192
192
def test_forecast_from_hmc (self ):
193
+ self .skipTest ('b/275876892' )
193
194
if not tf1 .control_flow_v2_enabled ():
194
195
self .skipTest ('test_forecast_from_hmc does not currently work with TF1' )
195
196
@@ -202,24 +203,13 @@ def test_forecast_from_hmc(self):
202
203
observed_time_series = self ._build_tensor (np .random .randn (
203
204
* (batch_shape + [num_timesteps ])))
204
205
model = self ._build_model (observed_time_series )
205
- try :
206
- samples , _ = fitting .fit_with_hmc (
207
- model ,
208
- observed_time_series ,
209
- num_results = num_results ,
210
- num_warmup_steps = 2 ,
211
- num_variational_steps = 2 ,
212
- )
213
- except NotImplementedError as e :
214
- err_str = str (e )
215
- if "'Tensor' object has no attribute '_name'" in err_str :
216
- # TODO(b/279596122): Enable the test after the upstream issue is fixed.
217
- self .skipTest (
218
- 'test_forecast_from_hmc is failing due to Tensor object'
219
- 'has no attribute `shape`.'
220
- )
221
- else :
222
- raise e
206
+ samples , _ = fitting .fit_with_hmc (
207
+ model ,
208
+ observed_time_series ,
209
+ num_results = num_results ,
210
+ num_warmup_steps = 2 ,
211
+ num_variational_steps = 2 ,
212
+ )
223
213
224
214
forecast_dist = forecast .forecast (
225
215
model ,
0 commit comments