README.md (6 changes: 4 additions & 2 deletions)
@@ -31,16 +31,18 @@ trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count('
trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])
```
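As a hedged aside, the `reward_fn` shown truncated in the hunk header above takes a batch of generated samples and returns one score per sample; a minimal sketch (the function body is illustrative, not the original code):

```python
# Illustrative reward function: score each generated sample by a keyword count.
# Mirrors the shape of the truncated `reward_fn` lambda above.
def reward_fn(samples, **kwargs):
    return [float(sample.count('cats')) for sample in samples]
```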

#### Trainers provide a wrapper over their underlying model
```python
trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
```
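Assuming `trainer.generate` mirrors `transformers.PreTrainedModel.generate` and returns a tensor of token ids, the sampled continuation can be decoded back to text:

```python
output_ids = trainer.generate(
    **tokenizer('Q: Who rules the world? A:', return_tensors='pt'),
    do_sample=True,
)
# Decode the first sequence in the batch back to a string.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```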

#### Save the resulting model as a Hugging Face pretrained language model (ready to upload to the Hub!)
```python
trainer.save_pretrained('/path/to/output/folder/')
```

🩹 Warning: Only the `AcceleratePPOTrainer` can write HuggingFace transformers to disk with `save_pretrained` at the moment, as ILQL trainers require inference behavior currently unsupported by available `transformers` architectures.
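Because the saved folder is a standard Hugging Face checkpoint (see the `save_pretrained` docstring below), it can be re-loaded with `from_pretrained`; the path here is the illustrative one from the snippet above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Re-load the exported model and tokenizer from the saved directory.
model = AutoModelForCausalLM.from_pretrained('/path/to/output/folder/')
tokenizer = AutoTokenizer.from_pretrained('/path/to/output/folder/')
```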

#### Use 🤗 Accelerate to launch distributed training

```bash
…
```
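The actual commands are collapsed above; as a rough sketch of a typical Accelerate workflow (the script path is hypothetical, not the repository's):

```bash
# One-time interactive setup of the distributed environment.
accelerate config

# Launch a training script across the configured devices.
accelerate launch examples/ppo_example.py  # hypothetical script path
```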
trlx/trainer/accelerate_base_trainer.py (3 changes: 3 additions & 0 deletions)
@@ -253,6 +253,9 @@ def save(self, directory: Optional[str] = None):
def save_pretrained(self, directory: Optional[str] = None):
    """Save the model and its configuration file to a directory, so that it can be re-loaded with the
    `transformers.PreTrainedModel.from_pretrained` method.

    NOTE: If a `directory` is not provided, the model will be saved to a sub-directory
    of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`).
    """
    pass
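A brief usage sketch of the default-directory behavior described in the NOTE above (the checkpoint dir value is illustrative, not from the repository):

```python
# Assuming config.train.checkpoint_dir == "/ckpts" (illustrative value):
trainer.save_pretrained()              # writes to /ckpts/hf_model
trainer.save_pretrained("/out/run1")   # writes to /out/run1, used as given
```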

trlx/trainer/accelerate_ilql_trainer.py (3 changes: 3 additions & 0 deletions)
@@ -85,6 +85,9 @@ def prepare_learning(self):
self.total_steps = min(self.total_steps, self.config.train.total_steps)

def save_pretrained(self, directory: Optional[str] = None):
    """NOTE: If a `directory` is not provided, the model will be saved to a sub-directory
    of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`).
    """
    # TODO: Support saving with `transformers.PreTrainedModel.save_pretrained`.
    # This is currently not supported because `nn.ilql_models.CausalLMWithValueHeads`
    # requires a custom `generate` method using its (value/q) heads to steer
trlx/trainer/accelerate_ppo_trainer.py (6 changes: 5 additions & 1 deletion)
@@ -220,6 +220,10 @@ def prepare_learning(self):
self.total_steps = min(self.total_steps, self.config.train.total_steps)

def save_pretrained(self, directory: Optional[str] = None):
    """NOTE: If a `directory` is not provided, the model will be saved to a sub-directory
    of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`).
    """
    if directory is None:
        directory = f"{self.config.train.checkpoint_dir}/hf_model"
    self.accelerator.unwrap_model(self.model).base_model.save_pretrained(directory)
    self.tokenizer.save_pretrained(directory)
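As a side note on this implementation: `accelerator.unwrap_model` strips Accelerate's distributed wrapper before saving, and `.base_model` presumably exports just the underlying transformer (without the PPO value head), which is why the result can be re-loaded with plain `transformers.PreTrainedModel.from_pretrained`.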