Update generation utilities #172

Conversation
@PhungVanDuy Are we adding best-of-n here?

Yes, we can add it here.

Are we planning to merge tonight?

No, I still have to update every other example (currently grappling with T5's) and make regression plots. I will also add clarifying comments shortly, so don't rush with reviewing until then.
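(Re: the best-of-n question above, a minimal sketch of best-of-n sampling, not code from this PR; generate_fn and reward_fn here are hypothetical stand-ins, with reward_fn following the new (samples, prompts, outputs) signature:)

# Hypothetical best-of-n sketch: sample n completions per prompt and keep the
# one with the highest reward. generate_fn/reward_fn are stand-ins, not this PR's API.
def best_of_n(generate_fn, reward_fn, prompts, n=4):
    best = []
    for prompt in prompts:
        outputs = [generate_fn(prompt) for _ in range(n)]
        samples = [prompt + output for output in outputs]
        rewards = reward_fn(samples, [prompt] * n, outputs)
        best.append(outputs[max(range(n), key=lambda i: rewards[i])])
    return best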
Very nice update. I've left some comments for feedback. Could you also resolve the merge conflicts when you get a chance? Thanks 🙏
# Log and display evaluation metrics
if self.accelerator.is_main_process:
    rows = sum(list(map(list, zip(*table))), [])
    rich_table = Table(*columns, title=f"Evaluation #{self.nth_evaluation}")
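(For context on the hunk above, the sum/zip idiom transposes the column-wise table into row order and flattens it; a standalone illustration with made-up data:)

# table holds one list per column; zip(*table) transposes columns into rows,
# and sum(..., []) concatenates those rows into a single flat list
table = [
    ["prompt A", "prompt B"],   # prompts column
    ["output A", "output B"],   # outputs column
    ["0.1", "0.9"],             # rewards column
]
rows = sum(list(map(list, zip(*table))), [])
print(rows)  # ['prompt A', 'output A', '0.1', 'prompt B', 'output B', '0.9']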
self, prompts: List[torch.IntTensor], samples, prompt_sizes=None
) -> List[str]:
    """
    Decode samples into (samples: List[str], outputs: List[str], samples: List[str])
Should the return be documented instead as:
(samples: List[str], prompts: List[str], outputs: List[str])
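(A sketch of the docstring with the suggested ordering; the Tuple return annotation is an assumption on my part, since the quoted hunk still declares List[str]:)

def decode(
    self, prompts: List[torch.IntTensor], samples, prompt_sizes=None
) -> Tuple[List[str], List[str], List[str]]:
    """
    Decode samples into (samples: List[str], prompts: List[str], outputs: List[str])
    """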
trlx/trainer/__init__.py (outdated)
config: TRLConfig,
reward_fn=None,
metric_fn=None,
stop_word=None,
In general, I don't believe that stop sequences will only be words. For clarity, either document that this handles a stop sequence (string) or update the arg name (e.g. OpenAI's Completion API calls it stop).
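(A hypothetical sketch of the suggested rename; stop_sequence is one possible name, not the PR's final choice, and the free function here only stands in for the trainer's real constructor, with the import path assumed:)

from typing import Callable, Optional
from trlx.data.configs import TRLConfig  # import path assumed

# hypothetical signature sketch, not the PR's actual API
def make_trainer(
    config: TRLConfig,
    reward_fn: Optional[Callable] = None,
    metric_fn: Optional[Callable] = None,
    stop_sequence: Optional[str] = None,  # any string, e.g. "Human:", not only one word
):
    ...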
stop_word_ix = str_output.find(self.stop_word)
if stop_word_ix == -1:
    stop_word_ix = None
str_output = str_output[:stop_word_ix]
Maybe we want to right-strip this to avoid extra whitespace? (Otherwise, assume users expect to include any leading whitespace in their stop_word.)
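(A sketch of the suggested change applied to the hunk above:)

# trim at the stop word, then right-strip so whitespace preceding the stop
# word does not end up in the final output
stop_word_ix = str_output.find(self.stop_word)
if stop_word_ix == -1:
    stop_word_ix = None
str_output = str_output[:stop_word_ix].rstrip()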
for k, v in self.config.method.gen_kwargs.items():
    if isinstance(v, list):
        if self.generate_sweep_kwarg is not None:
            print(
Let's use utils.print_rank_0 or check for main process here to avoid the annoying prints from all processes.
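(A sketch of the suggestion applied to the hunk above; the import path and the message text are assumptions:)

from trlx.utils import print_rank_0  # module path assumed from the comment

for k, v in self.config.method.gen_kwargs.items():
    if isinstance(v, list):
        if self.generate_sweep_kwarg is not None:
            # printed once from the main process instead of from every worker
            print_rank_0(f"Only one sweep argument is allowed; ignoring {k}")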
This PR will:

- add an option to sweep over a single gen_kwargs value (e.g. temperature, top_k, beta) during periodic evaluations
- change reward_fn's signature from reward_fn(samples) to reward_fn(samples, prompts, responses)
- add a stop_word option under gen_kwargs (e.g. "Models indicate they have completed a response by generating a stop sequence, which is literally the string 'Human:'" – from Anthropic's 2021 "A General Language Assistant as a Laboratory for Alignment")

Among HF's generate arguments, StoppingCriteria can only stop all generations per batch, and eos_token_ids accepts only List[int]. The other option is to customize the generate method for every architecture, or to trim samples after generate, as was done here (see the sketch below).

https://wandb.ai/sorry/trlx/reports/-Update-generation-utilities-172---VmlldzozMzIxMjE1
https://wandb.ai/sorry/trlx/reports/Update-generation-utilities-172--VmlldzozMzIxMjE5
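(A standalone sketch of the trim-after-generate approach described above; the function name and usage line are illustrative, not this PR's actual code:)

from typing import List

def trim_at_stop_sequence(texts: List[str], stop_sequence: str) -> List[str]:
    """Cut each decoded sample at the first occurrence of the stop sequence."""
    trimmed = []
    for text in texts:
        ix = text.find(stop_sequence)
        trimmed.append(text if ix == -1 else text[:ix])
    return trimmed

# usage, assuming a HF tokenizer and generate outputs:
# samples = trim_at_stop_sequence(tokenizer.batch_decode(output_ids), "Human:")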