Update trainer to ensure type consistency for train_args and lora_config #2181
Changes from all commits
```diff
@@ -110,6 +110,13 @@ def load_and_preprocess_data(dataset_dir, transformer_type, tokenizer):
 def setup_peft_model(model, lora_config):
     # Set up the PEFT model
     lora_config = LoraConfig(**json.loads(lora_config))
+    reference_lora_config = LoraConfig()
+    for key, val in lora_config.__dict__.items():
+        old_attr = getattr(reference_lora_config, key, None)
+        if old_attr is not None:
+            val = type(old_attr)(val)
+        setattr(lora_config, key, val)
+
     model.enable_input_require_grads()
     model = get_peft_model(model, lora_config)
     return model
```
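A minimal runnable sketch of the pattern in this hunk (not part of the diff; it assumes the `peft` package is installed, and the sample values are hypothetical). A freshly constructed `LoraConfig` carries defaults with the intended types, and each incoming value is cast to the type of the matching default:

```python
import json

from peft import LoraConfig

# Hyperparameters may arrive as strings (e.g. after Katib substitution),
# so numeric fields such as "r" can be str rather than int.
raw = '{"r": "8", "lora_alpha": "16", "lora_dropout": "0.1"}'
lora_config = LoraConfig(**json.loads(raw))  # lora_config.r is the str "8"

reference_lora_config = LoraConfig()         # defaults carry the true types
for key, val in lora_config.__dict__.items():
    old_attr = getattr(reference_lora_config, key, None)
    if old_attr is not None:
        val = type(old_attr)(val)            # e.g. int("8") -> 8
    setattr(lora_config, key, val)

assert isinstance(lora_config.r, int) and lora_config.r == 8
```

Note that fields whose default is `None` (for example `target_modules`) are skipped, since `None` provides no type to cast to.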
```diff
@@ -159,6 +166,15 @@ def parse_arguments():
     logger.info("Starting HuggingFace LLM Trainer")
     args = parse_arguments()
     train_args = TrainingArguments(**json.loads(args.training_parameters))
+    reference_train_args = transformers.TrainingArguments(
+        output_dir=train_args.output_dir
+    )
+    for key, val in train_args.to_dict().items():
+        old_attr = getattr(reference_train_args, key, None)
+        if old_attr is not None:
+            val = type(old_attr)(val)
+        setattr(train_args, key, val)
+
     transformer_type = getattr(transformers, args.transformer_type)

     logger.info("Setup model and tokenizer")
```

Review comment: Why are we only passing `train_args.output_dir`? What about the other parameters for …

Reply: Since …
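For illustration, a hedged sketch of the same coercion applied to `TrainingArguments` (not part of the diff; it assumes `transformers` and its `torch`/`accelerate` dependencies are installed, and the sample values are hypothetical). The reference instance is built from `output_dir` alone, presumably because it is the one constructor argument without a default in the targeted `transformers` versions; every other field falls back to a default that carries the canonical type:

```python
import json

import transformers
from transformers import TrainingArguments

raw = '{"output_dir": "/tmp/out", "learning_rate": "2e-5", "num_train_epochs": "3"}'
train_args = TrainingArguments(**json.loads(raw))  # learning_rate is a str here

reference_train_args = transformers.TrainingArguments(
    output_dir=train_args.output_dir               # only the required field is forwarded
)
for key, val in train_args.to_dict().items():
    old_attr = getattr(reference_train_args, key, None)
    if old_attr is not None:
        val = type(old_attr)(val)                  # e.g. float("2e-5") -> 2e-05
    setattr(train_args, key, val)

assert isinstance(train_args.learning_rate, float)
```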
nsingl00: Can you help me understand how `reference_lora_config` will be populated with the correct types through the Katib tune API? I believe the Katib arguments come in via `--lora_config`, which only populates `lora_config`.

Reply: @nsingl00 The logic is quite similar to the handling of `TrainingArguments`. When using the Katib tune API, `lora_config` is populated with the parameters passed through `--lora_config`. However, those parameters might not have the correct types (e.g., `r` might be a string instead of an integer). This code ensures that after Katib populates `lora_config`, its attributes are converted to the correct types by comparing them against the default values in `reference_lora_config`, so that `lora_config` ends up correctly typed for further use. Does this clarify your question?
Reply: Makes sense.