@@ -432,11 +432,54 @@ def test_mlflow_register_model(tmp_path, monkeypatch):
432
432
name = 'my_model' ,
433
433
)
434
434
435
- assert mlflow .register_model .called_with (model_uri = local_mlflow_save_path ,
436
- name = 'my_catalog.my_schema.my_model' ,
437
- await_registration_for = 300 ,
438
- tags = None ,
439
- registry_uri = 'databricks-uc' )
435
+ mlflow .register_model .assert_called_with (
436
+ model_uri = local_mlflow_save_path ,
437
+ name = 'my_catalog.my_schema.my_model' ,
438
+ await_registration_for = 300 ,
439
+ tags = None ,
440
+ )
441
+ assert mlflow .get_registry_uri () == 'databricks-uc'
442
+
443
+ test_mlflow_logger .post_close ()
444
+
445
+
446
+ @pytest .mark .filterwarnings ('ignore:.*Setuptools is replacing distutils.*:UserWarning' )
447
+ @pytest .mark .filterwarnings ("ignore:.*The 'transformers' MLflow Models integration.*:FutureWarning" )
448
+ def test_mlflow_register_model_with_run_id (tmp_path , monkeypatch ):
449
+ mlflow = pytest .importorskip ('mlflow' )
450
+
451
+ mlflow_uri = tmp_path / Path ('my-test-mlflow-uri' )
452
+ mlflow_exp_name = 'test-log-model-exp-name'
453
+ test_mlflow_logger = MLFlowLogger (
454
+ tracking_uri = mlflow_uri ,
455
+ experiment_name = mlflow_exp_name ,
456
+ model_registry_prefix = 'my_catalog.my_schema' ,
457
+ model_registry_uri = 'databricks-uc' ,
458
+ )
459
+
460
+ monkeypatch .setattr (test_mlflow_logger ._mlflow_client , 'create_model_version' , MagicMock ())
461
+ monkeypatch .setattr (test_mlflow_logger ._mlflow_client , 'create_registered_model' ,
462
+ MagicMock (return_value = type ('MockResponse' , (), {'name' : 'my_catalog.my_schema.my_model' })))
463
+
464
+ mock_state = MagicMock ()
465
+ mock_state .run_name = 'dummy-run-name' # this run name should be unused.
466
+ mock_logger = MagicMock ()
467
+
468
+ local_mlflow_save_path = str (tmp_path / Path ('my_model_local' ))
469
+ test_mlflow_logger .init (state = mock_state , logger = mock_logger )
470
+
471
+ test_mlflow_logger .register_model_with_run_id (
472
+ model_uri = local_mlflow_save_path ,
473
+ name = 'my_model' ,
474
+ )
475
+
476
+ test_mlflow_logger ._mlflow_client .create_model_version .assert_called_with (
477
+ name = 'my_catalog.my_schema.my_model' ,
478
+ source = local_mlflow_save_path ,
479
+ run_id = test_mlflow_logger ._run_id ,
480
+ await_creation_for = 300 ,
481
+ tags = None ,
482
+ )
440
483
assert mlflow .get_registry_uri () == 'databricks-uc'
441
484
442
485
test_mlflow_logger .post_close ()
0 commit comments