TRL Example fix #59

Merged: bfineran merged 2 commits into main from session-mixin-fixes on Aug 6, 2024

Conversation

@rahul-tuli (Collaborator) commented on Aug 6, 2024

Description:

This PR addresses a compatibility issue between llmcompressor's TRL example and the recently updated TRL library. TRL recently introduced the SFTConfig class to define supervised fine-tuning arguments. This SFTConfig class inherits from HF Transformers' TrainingArguments, which led to errors when llmcompressor's TrainingArguments were used for initialization.

To resolve this, we have implemented our own SFTConfig class within the llmcompressor framework. Our SFTConfig is built on llmcompressor.transformers.TrainingArguments, so the example's training arguments integrate cleanly with the updated TRL API.

Key Changes:

  • Addition of a custom SFTConfig class that inherits from llmcompressor.transformers.TrainingArguments.
  • Overriding the __init__ method of our SFTTrainer to instantiate the custom SFTConfig and pass it to the superclass constructors (sketched below).
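
A minimal sketch of the approach is below, assuming TRL's public SFTConfig/SFTTrainer classes and an assumed import path for llmcompressor's SessionManagerMixIn; the merged diff is the authoritative implementation, so treat this as illustrative:

```python
from dataclasses import dataclass

from trl import SFTConfig as TRLSFTConfig
from trl import SFTTrainer as TRLSFTTrainer

from llmcompressor.transformers import TrainingArguments
# Import path assumed for illustration; adjust to the repo layout.
from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn


@dataclass
class SFTConfig(TrainingArguments, TRLSFTConfig):
    """TRL-compatible config whose training-argument fields come from
    llmcompressor's TrainingArguments rather than HF Transformers'."""


class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
    def __init__(self, *args, **kwargs):
        # If plain llmcompressor TrainingArguments were passed in, re-wrap
        # them in the custom SFTConfig so each superclass __init__ receives
        # a config of the type it expects.
        training_args = kwargs.get("args")
        if isinstance(training_args, TrainingArguments) and not isinstance(
            training_args, SFTConfig
        ):
            kwargs["args"] = SFTConfig(**training_args.to_dict())
        super().__init__(*args, **kwargs)
```

With this in place, callers can keep constructing llmcompressor's TrainingArguments and passing them as args; the trainer converts them before TRL sees them.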

This fix ensures that llmcompressor's TRL example runs without initialization errors.

Testing:

The fix was tested by running examples/trl_mixin/ex_trl_constant.py; the output (truncated) is as follows:

{'loss': 10.0674, 'grad_norm': 0.8202081322669983, 'learning_rate': 1.8805704099821748e-05, 'epoch': 0.37}                                                
{'step_loss': 10.032647132873535, 'perplexity': 22757.431640625, 'epoch': 0.37}                                                                           
{'loss': 10.0082, 'grad_norm': 0.8343890309333801, 'learning_rate': 1.4349376114081997e-05, 'epoch': 0.43}                                                
{'step_loss': 9.981490135192871, 'perplexity': 21622.509765625, 'epoch': 0.43}                                                                            
{'loss': 9.9634, 'grad_norm': 0.8964446783065796, 'learning_rate': 9.893048128342247e-06, 'epoch': 0.48}                                                  
{'step_loss': 9.951886177062988, 'perplexity': 20991.779296875, 'epoch': 0.48}                                                                            
{'loss': 9.9321, 'grad_norm': 0.9238964319229126, 'learning_rate': 5.436720142602496e-06, 'epoch': 0.53}                                                  
 89%|██████████████████████████████████████████████████████████████████████████████████████████████████████▍            | 500/561 [02:10<00:15,  4.05it/s]2024-08-06T17:32:52.579281+0000 | save_pretrained_wrapper | INFO - Inferring a sparsity configuration requires a global sparsity calculation. This can be costly for large models. To skip the calculation of compression statistics set skip_compression_stats=True
Calculating model sparsity: 100%|███████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:01<00:00, 90.72it/s]
2024-08-06T17:32:54.602756+0000 | save_model | INFO - Saved LLM Compressor recipe with model state to ./output_trl_sft_test_7b_gsm8k_sft_data/checkpoint-500/recipe.yaml
/root/llm-compressor/.venv/lib/python3.10/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
{'step_loss': 9.903632164001465, 'perplexity': 20002.892578125, 'epoch': 0.53}                                                                            
{'loss': 9.9112, 'grad_norm': 0.9303483963012695, 'learning_rate': 9.80392156862745e-07, 'epoch': 0.59}                                                   
{'step_loss': 9.910439491271973, 'perplexity': 20139.5234375, 'epoch': 0.59}                                                                              
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 561/561 [02:29<00:00,  3.74it/s]2024-08-06T17:33:11.123361+0000 | save_pretrained_wrapper | INFO - Inferring a sparsity configuration requires a global sparsity calculation. This can be costly for large models. To skip the calculation of compression statistics set skip_compression_stats=True
Calculating model sparsity: 100%|██████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 471.05it/s]
2024-08-06T17:33:12.121165+0000 | save_model | INFO - Saved LLM Compressor recipe with model state to ./output_trl_sft_test_7b_gsm8k_sft_data/checkpoint-561/recipe.yaml
{'train_runtime': 151.2166, 'train_samples_per_second': 29.652, 'train_steps_per_second': 3.71, 'train_loss': 10.133576921813203, 'epoch': 0.6}           
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 561/561 [02:31<00:00,  3.71it/s]
manager stage: Modifiers finalized
2024-08-06T17:33:12.926893+0000 | finalize | INFO - Compression lifecycle finalized for 2 modifiers
2024-08-06T17:33:12.927108+0000 | finalize_session | INFO - Finalized LLM Compressor session
2024-08-06T17:33:12.958588+0000 | log_model_sparsification | INFO - Sparsification info for LlamaForCausalLM: 134105856 total params. 
Calculating model sparsity: 100%|██████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 513.99it/s]
2024-08-06T17:33:13.179669+0000 | log_model_sparsification | INFO - There are 134105856 prunable params which have 0.01% avg sparsity.
2024-08-06T17:33:13.187199+0000 | log_model_sparsification | INFO - There are 134105856 quantizable params, with a quantization percentage of 0.00%.
2024-08-06T17:33:13.187508+0000 | save_pretrained_wrapper | INFO - Inferring a sparsity configuration requires a global sparsity calculation. This can be costly for large models. To skip the calculation of compression statistics set skip_compression_stats=True
Calculating model sparsity: 100%|██████████████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 650.71it/s]
2024-08-06T17:33:14.168212+0000 | save_model | INFO - Saved LLM Compressor recipe with model state to ./output_trl_sft_test_7b_gsm8k_sft_data/recipe.yaml

@rahul-tuli force-pushed the session-mixin-fixes branch from 8d0eb12 to 4b06d64 on August 6, 2024 at 17:38
@rahul-tuli force-pushed the session-mixin-fixes branch from 4b06d64 to f88906f on August 6, 2024 at 17:43
@rahul-tuli self-assigned this on Aug 6, 2024
@rahul-tuli marked this pull request as ready for review on August 6, 2024 at 17:55
@bfineran merged commit b692a07 into main on Aug 6, 2024 (8 of 12 checks passed)
@bfineran deleted the session-mixin-fixes branch on August 6, 2024 at 19:35
markmc pushed a commit to markmc/llm-compressor that referenced this pull request on Nov 13, 2024