BCOTrainer version upgrade fixes #2867


Merged (2 commits) on Mar 24, 2025
1 change: 1 addition & 0 deletions setup.py
@@ -89,6 +89,7 @@
     "peft": ["peft>=0.8.0"],
     "quantization": ["bitsandbytes"],
     "scikit": ["scikit-learn"],
+    "bco": ["scikit-learn", "joblib"],
     "test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
     # vllm is not available on Windows
     # vllm 0.7.3 causes hanging while gathering. temporary pinning the version until the issue is resolved
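Note: this gives BCO's optional dependencies their own extras group. Assuming the standard pip extras syntax, they can then be pulled in with `pip install "trl[bco]"` instead of installing scikit-learn and joblib by hand.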
4 changes: 2 additions & 2 deletions tests/testing_utils.py
@@ -18,7 +18,7 @@
 from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available

 from trl import BaseBinaryJudge, BasePairwiseJudge, is_diffusers_available, is_llm_blender_available
-from trl.import_utils import is_mergekit_available
+from trl.import_utils import is_joblib_available, is_mergekit_available


 # transformers.testing_utils contains a require_bitsandbytes function, but relies on pytest markers which we don't use
@@ -62,7 +62,7 @@ def require_sklearn(test_case):
     """
     Decorator marking a test that requires sklearn. Skips the test if sklearn is not available.
     """
-    return unittest.skipUnless(is_sklearn_available(), "test requires sklearn")(test_case)
+    return unittest.skipUnless(is_sklearn_available() and is_joblib_available(), "test requires sklearn")(test_case)


 def require_comet(test_case):
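For context, the tightened decorator is consumed the usual `unittest` way; here is a minimal sketch (the test class and test body are hypothetical, not part of this PR):

```python
import unittest

from tests.testing_utils import require_sklearn


class BCOUdmTests(unittest.TestCase):
    @require_sklearn  # now skipped unless BOTH scikit-learn and joblib are importable
    def test_classifier_fits(self):
        from sklearn.linear_model import LogisticRegression

        clf = LogisticRegression().fit([[0.0], [1.0]], [0, 1])
        self.assertEqual(int(clf.predict([[1.0]])[0]), 1)
```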
5 changes: 5 additions & 0 deletions trl/import_utils.py
@@ -29,6 +29,7 @@
 _rich_available = _is_package_available("rich")
 _unsloth_available = _is_package_available("unsloth")
 _vllm_available = _is_package_available("vllm")
+_joblib_available = _is_package_available("joblib")


 def is_deepspeed_available() -> bool:
@@ -59,6 +60,10 @@ def is_vllm_available() -> bool:
     return _vllm_available


+def is_joblib_available() -> bool:
+    return _joblib_available
+
+
 class _LazyModule(ModuleType):
     """
     Module class that surfaces all objects but only performs associated imports when the objects are requested.
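The new helper follows the module's existing probe-once pattern: check the import system at module import time, then expose a cheap predicate. A self-contained sketch of the same idiom (an illustration, not necessarily TRL's exact `_is_package_available`):

```python
import importlib.util


def _is_package_available(name: str) -> bool:
    # find_spec probes the import machinery without actually importing the package
    return importlib.util.find_spec(name) is not None


_joblib_available = _is_package_available("joblib")


def is_joblib_available() -> bool:
    return _joblib_available


# callers then guard optional imports exactly as bco_trainer.py now does:
if is_joblib_available():
    import joblib
```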
92 changes: 46 additions & 46 deletions trl/trainer/bco_trainer.py
@@ -31,6 +31,7 @@
 import torch.nn.functional as F
 import transformers
 from accelerate import PartialState
+from accelerate.logging import get_logger
 from accelerate.utils import is_deepspeed_available, tqdm
 from datasets import Dataset
 from packaging import version
@@ -54,6 +55,7 @@
 from transformers.utils import is_peft_available

 from ..data_utils import maybe_apply_chat_template
+from ..import_utils import is_joblib_available
 from ..models import PreTrainedModelWrapper, create_reference_model
 from .bco_config import BCOConfig
 from .utils import (
@@ -78,14 +80,19 @@
 if is_sklearn_available():
     from sklearn.linear_model import LogisticRegression

+if is_joblib_available():
+    import joblib
+
 if is_deepspeed_available():
     import deepspeed

 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer

+logger = get_logger(__name__)
+
 RUNNING_NAME = "running.json"
-CLF_NAME = "clf.pt"
+CLF_NAME = "clf.pkl"


 def _tokenize(
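The rename from `clf.pt` to `clf.pkl` tracks a real bug fix in the checkpointing hunk below: `torch.save(self.clf.get_params(), ...)` only persisted the constructor hyperparameters (`C`, `penalty`, `class_weight`, ...), not the fitted `coef_`/`intercept_`, so a classifier restored from checkpoint was effectively untrained. Serializing the whole estimator with joblib keeps the fitted state. A minimal sketch of the round trip (file name is illustrative):

```python
import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.random.randn(200, 8)
y = (X[:, 0] > 0).astype(int)
clf = LogisticRegression(class_weight="balanced").fit(X, y)

# get_params() carries no fitted coefficients:
assert "coef_" not in clf.get_params()

joblib.dump(clf, "clf.pkl", compress=True)
restored = joblib.load("clf.pkl")
assert np.allclose(clf.coef_, restored.coef_)  # fitted state survives the round trip
```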
@@ -346,15 +353,15 @@ def __init__(
         embedding_func: Optional[Callable] = None,
         embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
     ):
-        if not is_sklearn_available():
+        if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()):
             raise ImportError(
-                "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
+                "BCOTrainer with UDM requires the scikit-learn and joblib libraries. Please install them with `pip install scikit-learn joblib`."
             )

         if type(args) is TrainingArguments:
             raise ValueError("Please use `BCOConfig` instead of `TrainingArguments`.")

-        if not isinstance(model, str) and ref_model is model:
+        if not isinstance(model, str) and model is not None and ref_model is model:
             raise ValueError(
                 "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                 "same as `model`, you must pass a copy of it, or `None` if you use peft."
             )
@@ -600,10 +607,7 @@ def make_inputs_require_grad(module, input, output):
             fn_kwargs={"tokenizer": processing_class},
             num_proc=args.dataset_num_proc,
         )
-        # Shuffle the datasets
-        train_dataset = train_dataset.shuffle(seed=args.data_seed)
-        if eval_dataset is not None:
-            eval_dataset = eval_dataset.shuffle(seed=args.data_seed)

         # Tokenize and prepare the training datasets
         train_dataset = train_dataset.map(
             _tokenize,
@@ -666,9 +670,6 @@ def make_inputs_require_grad(module, input, output):
             lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
         )

-        desirable = desirable.shuffle(seed=args.data_seed)
-        undesirable = undesirable.shuffle(seed=args.data_seed)
-
         super().__init__(
             model=model,
             args=args,
@@ -717,7 +718,7 @@ def make_inputs_require_grad(module, input, output):

         self.running = RunningMoments(accelerator=self.accelerator)

-        if self.embedding_func is None:
+        if self.embedding_func is None or args.resume_from_checkpoint:
             return

         chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
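With `args.resume_from_checkpoint` set, `__init__` now returns before re-sampling prompt embeddings and refitting the UDM classifier; the fitted classifier is restored from the checkpoint instead in `_load_optimizer_and_scheduler` (see the hunk further down). Resumption itself goes through the standard Trainer entry point, e.g. `trainer.train(resume_from_checkpoint=True)`.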
@@ -731,6 +732,13 @@ def make_inputs_require_grad(module, input, output):
         self.clf = LogisticRegression(class_weight="balanced").fit(
             embeddings.cpu().float().numpy(), labels.cpu().numpy()
         )
+        chosen_mean = self.clf.score(
+            chosen_embeddings.cpu().float().numpy(), torch.ones_like(chosen_embeddings[:, 0]).cpu().numpy()
+        )
+        rejected_mean = self.clf.score(
+            rejected_embeddings.cpu().float().numpy(), torch.zeros_like(rejected_embeddings[:, 0]).cpu().numpy()
+        )
+        logger.info(f"UDM classifier training scores: chosen: {chosen_mean}, rejected: {rejected_mean}")

     @property
     def match_underlying_distribution(self):
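`LogisticRegression.score` returns mean accuracy, so the two logged values are the classifier's accuracy on the desirable and undesirable embedding sets respectively; `torch.ones_like(chosen_embeddings[:, 0])` is just a length-matched all-ones label vector. Equivalently, as a sketch (function name hypothetical):

```python
import numpy as np


def udm_training_scores(clf, chosen: np.ndarray, rejected: np.ndarray) -> tuple[float, float]:
    # for sklearn classifiers, clf.score(X, y) == (clf.predict(X) == y).mean()
    chosen_acc = float((clf.predict(chosen) == 1).mean())
    rejected_acc = float((clf.predict(rejected) == 0).mean())
    return chosen_acc, rejected_acc
```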
@@ -870,30 +878,32 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
         return model

     def _save_optimizer_and_scheduler(self, output_dir):
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
         super()._save_optimizer_and_scheduler(output_dir)

-        # When saving optimizer and scheduler to checkpoint, save also the running delta object.
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-
-        self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
-
-        if self.match_underlying_distribution:
-            torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
+        if self.accelerator.is_main_process:
+            # When saving optimizer and scheduler to checkpoint, save also the running delta object.
+            self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
+
+            if self.match_underlying_distribution:
+                joblib.dump(self.clf, os.path.join(output_dir, CLF_NAME), compress=True)

     def _load_optimizer_and_scheduler(self, checkpoint):
-        super()._load_optimizer_and_scheduler(checkpoint)
-
         if checkpoint is None:
+            logger.warning_once(f"Missing Checkpoint {checkpoint}")
             return

+        super()._load_optimizer_and_scheduler(checkpoint)
+
         # when loading optimizer and scheduler from checkpoint, also load the running delta object.
         running_file = os.path.join(checkpoint, RUNNING_NAME)
         if os.path.isfile(running_file):
             self.running = RunningMoments.load_from_json(self.accelerator, running_file)

         if self.match_underlying_distribution:
             clf_file = os.path.join(checkpoint, CLF_NAME)
-            if os.path.isfile(running_file):
-                self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
+            if os.path.isfile(clf_file):
+                self.clf = joblib.load(clf_file)

     @contextmanager
     def null_ref_context(self):
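Two fixes land in this hunk: only the main process now writes the shared checkpoint artifacts (previously every rank raced on the same `running.json`/`clf.pkl`), and the load guard checks `clf_file` instead of the stale `running_file` handle, so the classifier is actually restored. A minimal sketch of the rank-0 write pattern under `accelerate` (helper name hypothetical):

```python
import joblib
from accelerate import Accelerator

accelerator = Accelerator()


def save_shared_artifact(obj, path: str) -> None:
    # one writer avoids concurrent writes to the same file...
    if accelerator.is_main_process:
        joblib.dump(obj, path, compress=True)
    # ...and a barrier guarantees the file exists before any rank reads it
    accelerator.wait_for_everyone()
```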
@@ -1138,6 +1148,7 @@ def bco_loss(
         reference_rejected_logps: torch.FloatTensor,
         chosen_embeddings: Optional[torch.FloatTensor],
         rejected_embeddings: Optional[torch.FloatTensor],
+        do_train: bool = True,
     ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """Compute the BCO loss for a batch of policy and reference model log probabilities.

@@ -1156,31 +1167,18 @@ def bco_loss(
             The delta value contains the moving average of all implicit rewards.
         """

-        if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
-            chosen_logratios = policy_chosen_logps - reference_chosen_logps
-            chosen_rewards = self.beta * chosen_logratios
-        else:
-            # lists can't be empty -- if they are, then accelerate.gather will hang
-            chosen_losses = torch.Tensor([]).to(self.accelerator.device)
-            chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
+        chosen_logratios = policy_chosen_logps - reference_chosen_logps
+        chosen_rewards = self.beta * chosen_logratios

-        if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
-            rejected_logratios = policy_rejected_logps - reference_rejected_logps
-            rejected_rewards = self.beta * rejected_logratios
-        else:
-            # lists can't be empty -- if they are, then accelerate.gather will hang
-            rejected_losses = torch.Tensor([]).to(self.accelerator.device)
-            rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
+        rejected_logratios = policy_rejected_logps - reference_rejected_logps
+        rejected_rewards = self.beta * rejected_logratios

-        rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
-        self.running.update(rewards)
-        delta = self.running.mean
+        if do_train:
+            self.running.update(torch.cat((chosen_rewards, rejected_rewards), 0).detach())
+        delta = torch.as_tensor(self.running.mean, device=chosen_rewards.device)

-        if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
-            chosen_losses = -F.logsigmoid(chosen_rewards - delta)
-
-        if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
-            rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
+        chosen_losses = -F.logsigmoid(chosen_rewards - delta)
+        rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))

         if self.match_underlying_distribution:
             chosen_weight = torch.ones_like(chosen_losses)
@@ -1190,12 +1188,13 @@ def bco_loss(
         else:
             losses = torch.cat((chosen_losses, rejected_losses), dim=0)

-        return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
+        return losses, chosen_rewards, rejected_rewards, delta

     def get_batch_loss_metrics(
         self,
         model,
         batch: dict[str, Union[list, torch.LongTensor]],
+        do_train: bool = True,
     ):
         """Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
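For reference, the rewritten `bco_loss` computes the BCO objective directly: with implicit rewards $r(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}$ and $\delta$ the running mean of all rewards seen so far,

$$
\mathcal{L}_{\mathrm{BCO}} = -\mathbb{E}_{\mathrm{chosen}}\left[\log \sigma\big(r(x, y) - \delta\big)\right] - \mathbb{E}_{\mathrm{rejected}}\left[\log \sigma\big({-(r(x, y) - \delta)}\big)\right].
$$

The new `do_train` flag matters here: evaluation batches no longer feed `self.running.update(...)`, so eval runs cannot drift the reward baseline $\delta$ that training depends on.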
@@ -1245,6 +1244,7 @@ def get_batch_loss_metrics(
             reference_rejected_logps,
             chosen_embeddings,
             rejected_embeddings,
+            do_train=do_train,
         )
         metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()

@@ -1375,7 +1375,7 @@ def prediction_step(

         prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
         with torch.no_grad(), prediction_context_manager:
-            loss, metrics = self.get_batch_loss_metrics(model, inputs)
+            loss, metrics = self.get_batch_loss_metrics(model, inputs, do_train=False)

         # force log the metrics
         if self.accelerator.is_main_process:
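Taken together, a hedged end-to-end sketch of the path these fixes exercise (the base model, dataset, and config values are placeholders, not taken from this PR):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import BCOConfig, BCOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# BCO trains on unpaired preference data: prompt, completion, binary label
train_dataset = Dataset.from_dict(
    {
        "prompt": ["The capital of France is", "2 + 2 equals"],
        "completion": [" Paris.", " 5."],
        "label": [True, False],
    }
)

args = BCOConfig(output_dir="bco-out", per_device_train_batch_size=2, report_to="none")
trainer = BCOTrainer(
    model=model,
    ref_model=None,  # a frozen reference copy is created internally
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
# trainer.train(resume_from_checkpoint=True)  # exercises the new running-delta / clf.pkl restore path
```

Without an `embedding_func`, scikit-learn and joblib are no longer required at all, which is exactly what the relaxed ImportError guard in `__init__` allows.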