13
13
from kubeflow .katib import V1beta1TrialParameterSpec
14
14
from kubeflow .katib import V1beta1TrialTemplate
15
15
from kubeflow .katib .constants import constants
16
+ import kubeflow .katib .katib_api_pb2 as katib_api_pb2
16
17
from kubernetes .client import V1ObjectMeta
17
18
import pytest
18
19
@@ -238,7 +239,7 @@ def create_experiment(
238
239
239
240
240
241
@pytest .fixture
241
- def katib_client ():
242
+ def katib_client_create_experiment ():
242
243
with patch (
243
244
"kubernetes.client.CustomObjectsApi" ,
244
245
return_value = Mock (
@@ -255,14 +256,103 @@ def katib_client():
255
256
256
257
257
258
@pytest .mark .parametrize ("test_name,kwargs,expected_output" , test_create_experiment_data )
258
- def test_create_experiment (katib_client , test_name , kwargs , expected_output ):
259
+ def test_create_experiment (katib_client_create_experiment , test_name , kwargs , expected_output ):
259
260
"""
260
261
test create_experiment function of katib client
261
262
"""
262
263
print ("\n \n Executing test:" , test_name )
263
264
try :
264
- katib_client .create_experiment (** kwargs )
265
+ katib_client_create_experiment .create_experiment (** kwargs )
265
266
assert expected_output == TEST_RESULT_SUCCESS
266
267
except Exception as e :
267
268
assert type (e ) is expected_output
268
269
print ("test execution complete" )
270
+
271
+
272
+ def get_observation_log_response (* args , ** kwargs ):
273
+ if kwargs .get ("timeout" ) == 0 :
274
+ raise TimeoutError
275
+ elif args [0 ].trial_name == "invalid" :
276
+ raise RuntimeError
277
+ else :
278
+ return katib_api_pb2 .GetObservationLogReply (
279
+ observation_log = katib_api_pb2 .ObservationLog (
280
+ metric_logs = [
281
+ katib_api_pb2 .MetricLog (
282
+ time_stamp = "2024-07-29T15:09:08Z" ,
283
+ metric = katib_api_pb2 .Metric (name = "result" ,value = "0.99" )
284
+ )
285
+ ]
286
+ )
287
+ )
288
+
289
+ test_get_trial_metrics_data = [
290
+ (
291
+ "valid trial name" ,
292
+ {
293
+ "name" : "example" ,
294
+ "namespace" : "valid" ,
295
+ "timeout" : constants .DEFAULT_TIMEOUT
296
+ },
297
+ [
298
+ katib_api_pb2 .MetricLog (
299
+ time_stamp = "2024-07-29T15:09:08Z" ,
300
+ metric = katib_api_pb2 .Metric (name = "result" ,value = "0.99" )
301
+ )
302
+ ]
303
+ ),
304
+ (
305
+ "invalid trial name" ,
306
+ {
307
+ "name" : "invalid" ,
308
+ "namespace" : "invalid" ,
309
+ "timeout" : constants .DEFAULT_TIMEOUT
310
+ },
311
+ RuntimeError
312
+ ),
313
+ (
314
+ "GetObservationLog timeout error" ,
315
+ {
316
+ "name" : "example" ,
317
+ "namespace" : "valid" ,
318
+ "timeout" : 0
319
+ },
320
+ RuntimeError
321
+ )
322
+ ]
323
+
324
+
325
+ @pytest .fixture
326
+ def katib_client_get_trial_metrics ():
327
+ with patch (
328
+ "kubernetes.client.CustomObjectsApi" ,
329
+ return_value = Mock (),
330
+ ), patch (
331
+ "kubernetes.config.load_kube_config" ,
332
+ return_value = Mock ()
333
+ ):
334
+ client = KatibClient ()
335
+ yield client
336
+
337
+
338
+ @pytest .fixture
339
+ def mock_get_observation_log ():
340
+ with patch ("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub" ) as mock :
341
+ mock_instance = mock .return_value
342
+ mock_instance .GetObservationLog .side_effect = get_observation_log_response
343
+ yield mock_instance
344
+
345
+
346
+ @pytest .mark .parametrize ("test_name,kwargs,expected_output" , test_get_trial_metrics_data )
347
+ def test_get_trial_metrics (test_name , kwargs , expected_output , katib_client_get_trial_metrics , mock_get_observation_log ):
348
+ """
349
+ test get_trial_metrics function of katib client
350
+ """
351
+ print ("\n \n Executing test:" , test_name )
352
+ try :
353
+ metrics = katib_client_get_trial_metrics .get_trial_metrics (** kwargs )
354
+ for i in range (len (metrics )):
355
+ assert metrics [i ] == expected_output [i ]
356
+ except Exception as e :
357
+ assert type (e ) is expected_output
358
+ print ("test execution complete" )
0 commit comments