Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
maybe_create_remote_uploader_downloader_from_uri,
parse_uri)
from composer.utils.misc import create_interval_scheduler
from mlflow.transformers import _fetch_model_card, _write_license_information
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
Expand All @@ -32,17 +33,41 @@
_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)


def _maybe_get_license_filename(
        local_dir: str,
        pretrained_model_name: Optional[str] = None) -> Optional[str]:
    """Returns the name of the license file if it exists in the local_dir.

    Note: This is intended to be consistent with the code in MLflow.
    https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152

    Since LLM Foundry supports local model files being used rather than
    fetching the files from the Hugging Face Hub, MLflow's logic to fetch and
    write the license information on model save is not applicable; it will try
    to search for a Hugging Face repo named after the local path. However, the
    user can provide the original pretrained model name, in which case this
    function will use that to fetch the correct license information and
    overwrite the license file that is already present in ``local_dir``.

    Args:
        local_dir: Directory to search for a file whose name matches
            ``_LICENSE_FILE_PATTERN``.
        pretrained_model_name: Optional Hugging Face Hub model name. When
            given *and* a license file already exists in ``local_dir``, that
            file is removed and regenerated through MLflow's
            ``_write_license_information`` helper.

    Returns:
        The name (not the full path) of the license file in ``local_dir``, or
        ``None`` if no license file exists (including the case where the
        regeneration step did not produce one).
    """

    def _find_license_file() -> Optional[str]:
        # First directory entry whose name looks like a license file, if any.
        return next((entry for entry in os.listdir(local_dir)
                     if _LICENSE_FILE_PATTERN.search(entry)), None)

    license_filename = _find_license_file()

    # Nothing to overwrite, or no Hub model name to take the license from.
    if license_filename is None or pretrained_model_name is None:
        return license_filename

    log.info(
        f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub'
    )

    # Fetch the model card before touching the directory, so that an error
    # raised while fetching leaves the existing license file in place.
    model_card = _fetch_model_card(pretrained_model_name)

    # The stale file must be removed before the new one is written: the
    # regenerated license may have a different filename, and two matching
    # files would make the lookup below ambiguous.
    os.remove(os.path.join(local_dir, license_filename))
    _write_license_information(pretrained_model_name, model_card,
                               Path(local_dir).absolute())

    # Look the file up again, since its name may have changed.
    return _find_license_file()

Expand Down Expand Up @@ -330,8 +355,11 @@ def _save_checkpoint(self, state: State, logger: Logger):

mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
license_filename = _maybe_get_license_filename(
local_save_path)
local_save_path,
self.mlflow_logging_config['metadata'].get(
'pretrained_model_name', None))
if license_filename is not None:
mlflow_logger._mlflow_client.log_artifact(
mlflow_logger._run_id,
Expand Down