134 changes: 134 additions & 0 deletions tests/test_trainers.py
@@ -0,0 +1,134 @@
import os
import tempfile
import unittest
from typing import List, Mapping

import trlx.utils.logging as logging
from trlx.data.configs import (
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)
from trlx.models.modeling_ppo import PPOConfig
from trlx.utils.loading import get_pipeline, get_trainer

logging.disable_progress_bar()
logging.set_verbosity(logging.ERROR)


def get_default_train_and_eval_prompts() -> Mapping[str, List[str]]:
    return dict(
        train=[
            "The quick brown fox jumps over the lazy",
            "The cat sat on the mat next to the",
            "What sort of food does a",
            "The nextdoor neighbor's fence couldn't keep the",
            "When Tom got home from work he had to walk his",
        ],
        eval=[
            "I purchased a collar for my new",
            "I couldn't help but laugh when the mailman was chased by the",
        ],
    )


def get_default_reward_fn():
    def reward_fn(samples: List[str], **kwargs):
        return [sample.count("dog") for sample in samples]

    return reward_fn


class TestAccelerateBaseTrainer(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.prompt_dataset = get_default_train_and_eval_prompts()

    @classmethod
    def get_default_config(cls):
        return TRLConfig(
            train=TrainConfig(
                seq_length=16,
                epochs=1,
                total_steps=8,
                batch_size=2,
                checkpoint_interval=4,
                checkpoint_dir="checkpoints",
                eval_interval=8,
                pipeline="PromptPipeline",
                trainer="AcceleratePPOTrainer",
                tracker=None,
            ),
            model=ModelConfig(model_path="gpt2", num_layers_unfrozen=2),
            tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
            optimizer=OptimizerConfig(
                name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
            ),
            scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)),
            method=PPOConfig(
                name="PPOConfig",
                num_rollouts=128,
                chunk_size=128,
                ppo_epochs=4,
                init_kl_coef=0.05,
                target=6,
                horizon=10000,
                gamma=1,
                lam=0.95,
                cliprange=0.2,
                cliprange_value=0.2,
                vf_coef=1,
                scale_reward="ignored",
                ref_mean=None,
                ref_std=None,
                cliprange_reward=10,
                gen_kwargs=dict(
                    max_new_tokens=6,
                    top_k=0,
                    top_p=1.0,
                    do_sample=True,
                ),
            ),
        )

    def get_trainer(self, config: TRLConfig):
        trainer = get_trainer(config.train.trainer)(
            config=config,
            reward_fn=get_default_reward_fn(),
            metric_fn=None,
            stop_sequences=None,
            **config.train.trainer_kwargs,
        )

        max_prompt_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]
        train_pipeline = get_pipeline(config.train.pipeline)(
            self.prompt_dataset["train"], max_prompt_length, trainer.tokenizer
        )
        trainer.add_prompt_pipeline(train_pipeline)
        trainer.make_experience(config.method.num_rollouts)

        eval_pipeline = get_pipeline(config.train.pipeline)(
            self.prompt_dataset["eval"], max_prompt_length, trainer.tokenizer
        )
        trainer.add_eval_pipeline(eval_pipeline)
        return trainer

    def test_save_checkpoint(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config = self.get_default_config()
            config.train.checkpoint_dir = tmpdir

            trainer = self.get_trainer(config)
            trainer.learn()

            total_steps = config.train.total_steps
            interval = config.train.checkpoint_interval
            for i in range(interval, total_steps + 1, interval):
                checkpoint_dir = os.path.join(tmpdir, f"checkpoint_{i}")
                self.assertTrue(os.path.isdir(checkpoint_dir))
            if total_steps % interval != 0:
                self.assertTrue(os.path.isdir(os.path.join(tmpdir, f"checkpoint_{total_steps}")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "best_checkpoint")))
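For orientation (not part of the diff): with the config above, total_steps=8 and checkpoint_interval=4, so the assertions in test_save_checkpoint amount to checking the layout sketched below under the temporary checkpoint directory. The best_checkpoint entry is assumed to be written by the evaluation loop when results improve, as the final assertion implies.

# Sketch only: expected checkpoint layout for total_steps=8, checkpoint_interval=4.
# checkpoint_4 and checkpoint_8 come from the periodic saves in learn();
# best_checkpoint is assumed to come from evaluation (assumption, not shown in this diff).
expected_dirs = [f"checkpoint_{i}" for i in range(4, 8 + 1, 4)] + ["best_checkpoint"]
assert expected_dirs == ["checkpoint_4", "checkpoint_8", "best_checkpoint"]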
4 changes: 3 additions & 1 deletion trlx/data/configs.py
@@ -158,7 +158,9 @@ class TrainConfig:
    :param tracker: Tracker to use for logging. Default: "wandb"
    :type tracker: str

    :param checkpoint_interval: Save model every checkpoint_interval steps
    :param checkpoint_interval: Save model every checkpoint_interval steps.
        Each checkpoint is stored in a sub-directory of the `TrainConfig.checkpoint_dir`
        directory in the format `checkpoint_dir/checkpoint_{step}`.
    :type checkpoint_interval: int

    :param eval_interval: Evaluate model every eval_interval steps
4 changes: 2 additions & 2 deletions trlx/trainer/__init__.py
@@ -1,6 +1,6 @@
import sys
from abc import abstractmethod
from typing import Any, Callable, Dict, Iterable
from typing import Any, Callable, Dict, Iterable, Optional

from trlx.data.configs import TRLConfig
from trlx.pipeline import BaseRolloutStore
@@ -93,7 +93,7 @@ def learn(
        pass

    @abstractmethod
    def save(self, directory=None):
    def save(self, directory: Optional[str] = None):
        """Creates a checkpoint of training states"""
        pass

22 changes: 13 additions & 9 deletions trlx/trainer/accelerate_base_trainer.py
@@ -244,10 +244,6 @@ def generate_eval(self, input_ids, attention_mask=None, **kwargs):
                input_ids=input_ids, attention_mask=attention_mask, **kwargs
            )

    def save(self, directory: Optional[str] = None):
        """Creates a checkpoint of the optimizer, scheduler and model"""
        self.accelerator.save_state(directory or self.config.train.checkpoint_dir)

    def save_pretrained(self, directory: Optional[str] = None, **kwargs):
        """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for
        later use.
@@ -260,15 +256,19 @@ def save_pretrained(self, directory: Optional[str] = None, **kwargs):
        `save_pretrained` method.
        """
        if directory is None:
            directory = f"{self.config.train.checkpoint_dir}/hf_model"
            directory = os.path.join(self.config.train.checkpoint_dir, "hf_model")
        self.accelerator.wait_for_everyone()
        self.accelerator.unwrap_model(self.model).save_pretrained(directory, **kwargs)
        if self.accelerator.is_main_process:
            self.tokenizer.save_pretrained(directory)

    def load(self, directory=None):
    def save(self, directory: Optional[str] = None, **kwargs):
        """Creates a checkpoint of the optimizer, scheduler and model"""
        self.accelerator.save_state(directory or self.config.train.checkpoint_dir, **kwargs)

    def load(self, directory: Optional[str] = None, **kwargs):
        """Load checkpoint of optimizer, scheduler and a model"""
        self.accelerator.load_state(directory or self.config.train.checkpoint_dir)
        self.accelerator.load_state(directory or self.config.train.checkpoint_dir, **kwargs)

    def add_eval_pipeline(self, eval_pipeline):
"""Adds pipeline from with validation prompts"""
@@ -487,7 +487,9 @@ def learn(self):  # noqa: C901
                    self.iter_count += 1

                    if self.iter_count % self.config.train.checkpoint_interval == 0:
                        self.save()
                        subfolder = f"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}"
                        directory = os.path.join(self.config.train.checkpoint_dir, subfolder)
                        self.save(directory)

                    stats["time/forward"] = forward_time
                    stats["time/backward"] = backward_time
@@ -533,7 +535,9 @@ def learn(self):  # noqa: C901
                    tbar.update()

                    if self.iter_count >= self.total_steps:
                        self.save()
                        subfolder = f"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}"
                        directory = os.path.join(self.config.train.checkpoint_dir, subfolder)
                        self.save(directory)
                        return self.evaluate()

                self.post_backward_callback()