Merged
sdk/python/kubeflow/trainer/hf_llm_training.py (16 additions, 0 deletions)
@@ -110,6 +110,13 @@ def load_and_preprocess_data(dataset_dir, transformer_type, tokenizer):
def setup_peft_model(model, lora_config):
    # Set up the PEFT model
    lora_config = LoraConfig(**json.loads(lora_config))
    reference_lora_config = LoraConfig()
Review comment:
Can you help me understand how reference_lora_config will be populated with the correct types through the Katib tune API? I believe the Katib arguments come in via --lora_config, which only populates lora_config.

Reply (Contributor Author):
@nsingl00 The logic mirrors the handling of TrainingArguments. When the Katib tune API is used, lora_config is populated with the parameters passed through "--lora_config". However, those parameters may not have the correct types (e.g., r might arrive as a string instead of an integer). After Katib populates lora_config, this code converts each attribute to the correct type by comparing it against the default values in reference_lora_config, so lora_config ends up correctly typed for further use. Does this clarify your question?

Review comment:
makes sense
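
A minimal, self-contained sketch of the coercion described above. The JSON payload and its string-typed values are hypothetical stand-ins for what Katib might pass via --lora_config; they are not part of the diff.

import json

from peft import LoraConfig

# Hypothetical payload: Katib forwards tuned hyperparameters as strings.
raw = '{"r": "8", "lora_alpha": "16"}'

lora_config = LoraConfig(**json.loads(raw))  # here lora_config.r is the str "8"
reference_lora_config = LoraConfig()         # defaults carry the intended types

for key, val in lora_config.__dict__.items():
    old_attr = getattr(reference_lora_config, key, None)
    if old_attr is not None:
        val = type(old_attr)(val)  # e.g. int("8") -> 8
    setattr(lora_config, key, val)

assert isinstance(lora_config.r, int)  # "8" was coerced to 8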

    for key, val in lora_config.__dict__.items():
        old_attr = getattr(reference_lora_config, key, None)
        if old_attr is not None:
            val = type(old_attr)(val)
        setattr(lora_config, key, val)

    model.enable_input_require_grads()
    model = get_peft_model(model, lora_config)
    return model
@@ -159,6 +166,15 @@ def parse_arguments():
logger.info("Starting HuggingFace LLM Trainer")
args = parse_arguments()
train_args = TrainingArguments(**json.loads(args.training_parameters))
reference_train_args = transformers.TrainingArguments(
    output_dir=train_args.output_dir
)

Review comment:

Why are we passing only train_args.output_dir? What about the other parameters of TrainingArguments?

Reply (Contributor Author):

Since output_dir is the only TrainingArguments parameter that has no default value and must be set explicitly, all other parameters are optional and have defaults; they are automatically set to their default values when the reference object is constructed, so the loop below can read a correctly typed default for every attribute it traverses. For detailed information, refer to: https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments
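
A minimal sketch, assuming transformers (and its torch dependency) is installed, of why passing only output_dir suffices; the output path is a hypothetical placeholder, and the printed defaults are as documented for recent transformers versions.

import transformers

# output_dir is the only TrainingArguments field that must be supplied; every
# other field falls back to its documented default, so the reference object
# carries a correctly typed value for each attribute.
reference = transformers.TrainingArguments(output_dir="/tmp/ref-out")
print(type(reference.learning_rate), reference.learning_rate)        # <class 'float'> 5e-05
print(type(reference.num_train_epochs), reference.num_train_epochs)  # <class 'float'> 3.0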
for key, val in train_args.to_dict().items():
    old_attr = getattr(reference_train_args, key, None)
    if old_attr is not None:
        val = type(old_attr)(val)
    setattr(train_args, key, val)

transformer_type = getattr(transformers, args.transformer_type)

logger.info("Setup model and tokenizer")