Skip to content

Commit 0b643bb

Browse files
dakinggg authored and mvpatel2000 committed
Add a register_model_with_run_id api to MLflowLogger (#2967)
1 parent de89606 commit 0b643bb

File tree

2 files changed

+95
-5
lines changed

2 files changed

+95
-5
lines changed

composer/loggers/mlflow_logger.py

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -346,6 +346,53 @@ def log_model(self, flavor: Literal['transformers'], **kwargs):
346346
else:
347347
raise NotImplementedError(f'flavor {flavor} not supported.')
348348

349+
def register_model_with_run_id(
    self,
    model_uri: str,
    name: str,
    await_creation_for: int = 300,
    tags: Optional[Dict[str, Any]] = None,
) -> None:
    """Similar to ``register_model``, but uses a different MLflow API to allow passing in the run id.

    The registered model is created first (an "already exists" error from MLflow is tolerated),
    then a new model version pointing at ``model_uri`` is created with this logger's run id.
    This is a no-op when the logger is not enabled.

    Args:
        model_uri (str): The URI of the model to register.
        name (str): The name of the model to register. Will be appended to ``model_registry_prefix``.
        await_creation_for (int, optional): The number of seconds to wait for the model to be registered. Defaults to 300.
        tags (Optional[Dict[str, Any]], optional): A dictionary of tags to add to the model. Defaults to None.

    Raises:
        MlflowException: If creating the registered model fails with any error code other than
            ``RESOURCE_ALREADY_EXISTS`` / ``ALREADY_EXISTS``. Errors raised while creating the
            model version are not caught here.
    """
    if not self._enabled:
        return

    # Imported lazily so that mlflow is only required when the logger is actually enabled.
    from mlflow.exceptions import MlflowException
    from mlflow.protos.databricks_pb2 import ALREADY_EXISTS, RESOURCE_ALREADY_EXISTS, ErrorCode

    # An empty prefix leaves the model name untouched.
    full_name = f'{self.model_registry_prefix}.{name}' if self.model_registry_prefix else name

    # This try/catch code is copied from
    # https://github.com/mlflow/mlflow/blob/3ba1e50e90a38be19920cb9118593a43d7cfa90e/mlflow/tracking/_model_registry/fluent.py#L90-L103
    try:
        create_model_response = self._mlflow_client.create_registered_model(full_name)
    except MlflowException as e:
        already_exists_codes = (
            ErrorCode.Name(RESOURCE_ALREADY_EXISTS),
            ErrorCode.Name(ALREADY_EXISTS),
        )
        if e.error_code not in already_exists_codes:
            # Bare raise keeps the original traceback intact.
            raise
        log.info(f'Registered model {name} already exists. Creating a new version of this model...')
    else:
        log.info(f'Successfully registered model {name} with {create_model_response.name}')

    create_version_response = self._mlflow_client.create_model_version(
        name=full_name,
        source=model_uri,
        run_id=self._run_id,
        await_creation_for=await_creation_for,
        tags=tags,
    )

    log.info(
        f'Successfully created model version {create_version_response.version} for model {create_version_response.name}'
    )
395+
349396
def log_images(
350397
self,
351398
images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]],

tests/loggers/test_mlflow_logger.py

Lines changed: 48 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -432,11 +432,54 @@ def test_mlflow_register_model(tmp_path, monkeypatch):
432432
name='my_model',
433433
)
434434

435-
assert mlflow.register_model.called_with(model_uri=local_mlflow_save_path,
436-
name='my_catalog.my_schema.my_model',
437-
await_registration_for=300,
438-
tags=None,
439-
registry_uri='databricks-uc')
435+
mlflow.register_model.assert_called_with(
436+
model_uri=local_mlflow_save_path,
437+
name='my_catalog.my_schema.my_model',
438+
await_registration_for=300,
439+
tags=None,
440+
)
441+
assert mlflow.get_registry_uri() == 'databricks-uc'
442+
443+
test_mlflow_logger.post_close()
444+
445+
446+
@pytest.mark.filterwarnings('ignore:.*Setuptools is replacing distutils.*:UserWarning')
@pytest.mark.filterwarnings("ignore:.*The 'transformers' MLflow Models integration.*:FutureWarning")
def test_mlflow_register_model_with_run_id(tmp_path, monkeypatch):
    """Check that ``register_model_with_run_id`` forwards the prefixed name and run id to the MLflow client."""
    mlflow = pytest.importorskip('mlflow')

    mlflow_uri = tmp_path / Path('my-test-mlflow-uri')
    mlflow_exp_name = 'test-log-model-exp-name'
    test_mlflow_logger = MLFlowLogger(
        tracking_uri=mlflow_uri,
        experiment_name=mlflow_exp_name,
        model_registry_prefix='my_catalog.my_schema',
        model_registry_uri='databricks-uc',
    )

    # Replace both client calls made by register_model_with_run_id with mocks.
    monkeypatch.setattr(test_mlflow_logger._mlflow_client, 'create_model_version', MagicMock())
    monkeypatch.setattr(
        test_mlflow_logger._mlflow_client,
        'create_registered_model',
        MagicMock(return_value=type('MockResponse', (), {'name': 'my_catalog.my_schema.my_model'})),
    )

    mock_state = MagicMock()
    mock_state.run_name = 'dummy-run-name'  # this run name should be unused.
    mock_logger = MagicMock()

    local_mlflow_save_path = str(tmp_path / Path('my_model_local'))
    test_mlflow_logger.init(state=mock_state, logger=mock_logger)

    test_mlflow_logger.register_model_with_run_id(
        model_uri=local_mlflow_save_path,
        name='my_model',
    )

    # The registered model must be created under the prefixed name as well.
    test_mlflow_logger._mlflow_client.create_registered_model.assert_called_with('my_catalog.my_schema.my_model')
    test_mlflow_logger._mlflow_client.create_model_version.assert_called_with(
        name='my_catalog.my_schema.my_model',
        source=local_mlflow_save_path,
        run_id=test_mlflow_logger._run_id,
        await_creation_for=300,
        tags=None,
    )
    assert mlflow.get_registry_uri() == 'databricks-uc'

    test_mlflow_logger.post_close()

0 commit comments

Comments
 (0)