Conversation

@maxreciprocate (Collaborator) commented on Jan 9, 2023

This PR will

  • enable sweeping over a single gen_kwargs value (e.g. temperature, top_k, beta) during periodic evaluations (see the sketch after this list)
  • change reward_fn's signature from reward_fn(samples) to reward_fn(samples, prompts, responses)
  • add an optional stop keyword to gen_kwargs that truncates generations at a given stop sequence (e.g. "Models indicate they have completed a response by generating a stop sequence, which is literally the string 'Human:'" – from Anthropic's 2021 "A General Language Assistant as a Laboratory for Alignment")
  • clean up the output of evaluation results
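
For illustration, here is a rough sketch of how the first three points could look from a user's perspective; everything beyond the names listed above (the reward body, the exact config layout) is my own assumption rather than code from this PR:

```python
from typing import List

# New-style reward function: still receives the full samples, but now also the
# prompts and the model's responses split out separately.
def reward_fn(samples: List[str], prompts: List[str], responses: List[str]) -> List[float]:
    # toy reward: prefer shorter responses (placeholder logic)
    return [-float(len(response)) for response in responses]

# Sweeping a single gen_kwargs value during periodic evaluations by passing a
# list for that key, plus the optional stop string (layout is assumed):
gen_kwargs = {
    "max_new_tokens": 64,
    "temperature": [0.5, 1.0, 1.5],  # swept over during evaluation
    "stop": "Human:",                # generations are trimmed at this sequence
}
```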

Among HF's generate arguments, StoppingCriteria can only stop all generations in a batch at once, and eos_token_ids accepts only a List[int]. The remaining options are to customize the generate method for every architecture or to trim samples after generate, as was done here:

https://wandb.ai/sorry/trlx/reports/-Update-generation-utilities-172---VmlldzozMzIxMjE1
https://wandb.ai/sorry/trlx/reports/Update-generation-utilities-172--VmlldzozMzIxMjE5
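
For context on the first limitation: StoppingCriteria's __call__ returns a single bool for the whole batch, so it cannot end one finished sequence while the others keep generating. A minimal sketch (the criterion class here is illustrative, not part of this PR):

```python
from typing import List

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokenSequence(StoppingCriteria):
    """Stop only once *every* sequence in the batch ends with `stop_ids`."""

    def __init__(self, stop_ids: List[int]):
        self.stop_ids = torch.tensor(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # A single bool covers the entire batch: returning True halts generation
        # for all sequences at once, so per-sequence stopping is not possible here.
        tails = input_ids[:, -len(self.stop_ids):].cpu()
        return bool((tails == self.stop_ids).all(dim=-1).all())

# usage (model and inputs are assumed to exist):
# model.generate(**inputs, stopping_criteria=StoppingCriteriaList([StopOnTokenSequence(stop_ids)]))
```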

@LouisCastricato (Contributor) commented:

@PhungVanDuy Are we adding best-of-n here?

@PhungVanDuy (Collaborator) commented:

> @PhungVanDuy Are we adding best-of-n here?

Yes, we can add it here.

@jon-tow added this to the v0.4.0 milestone on Jan 9, 2023
@LouisCastricato (Contributor) commented:

Are we planning to merge tonight?

@maxreciprocate (Collaborator, Author) commented:

No, I still have to update every other example (currently grappling with T5's) and make regression plots. I will also add clarifying comments shortly, so until then don't rush with reviewing.

@maxreciprocate marked this pull request as ready for review on January 12, 2023, 17:24
@jon-tow (Collaborator) left a comment:

Very nice update. I've left some comments for feedback. Could you also resolve the merge conflicts when you get a chance? Thanks 🙏

# Log and display evaluation metrics
if self.accelerator.is_main_process:
    rows = sum(list(map(list, zip(*table))), [])
    rich_table = Table(*columns, title=f"Evaluation #{self.nth_evaluation}")

It's so beautiful 😭 Do you think we should add show_lines=True in the Table constructor? I noticed some of the outputs kind of bleed together but this might just be my terminal settings. Fine either way - much better than before!

[Screenshots: evaluation tables rendered in the terminal]
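
For reference, a small self-contained sketch of the show_lines=True suggestion (the table contents here are made up):

```python
from rich.console import Console
from rich.table import Table

# show_lines=True draws a horizontal rule between rows, so long outputs
# don't visually bleed into one another.
table = Table("prompt", "output", "reward", title="Evaluation #1", show_lines=True)
table.add_row("Human: Hello!", "Assistant: Hi there.", "0.87")
table.add_row("Human: What is 2 + 2?", "Assistant: 4.", "0.93")
Console().print(table)
```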

    self, prompts: List[torch.IntTensor], samples, prompt_sizes=None
) -> List[str]:
    """
    Decode samples into (samples: List[str], outputs: List[str], samples: List[str])

Should the return be documented instead as:
(samples: List[str], prompts: List[str], outputs: List[str])

config: TRLConfig,
reward_fn=None,
metric_fn=None,
stop_word=None,

In general, I don't believe that stop sequences will only be words. For clarity, either document that this handles a stop sequence (string) or update the arg name (e.g. OpenAI's Completion API calls it stop).

stop_word_ix = str_output.find(self.stop_word)
if stop_word_ix == -1:
    stop_word_ix = None
str_output = str_output[:stop_word_ix]
@jon-tow (Collaborator) commented on Jan 13, 2023:

Maybe we want to right-strip this to avoid extra white space? (Otherwise, assume users expect to include any leading white space in their stop_word)
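
Concretely, the suggested variant might look something like this (a standalone sketch of the same trimming logic, with the right-strip added):

```python
def trim_at_stop(str_output: str, stop_word: str) -> str:
    """Cut the decoded sample at the stop sequence and drop trailing whitespace."""
    stop_word_ix = str_output.find(stop_word)
    if stop_word_ix == -1:
        stop_word_ix = None
    return str_output[:stop_word_ix].rstrip()

assert trim_at_stop("Assistant: Sure.\n\nHuman: thanks!", "Human:") == "Assistant: Sure."
```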

for k, v in self.config.method.gen_kwargs.items():
    if isinstance(v, list):
        if self.generate_sweep_kwarg is not None:
            print(
@jon-tow (Collaborator) commented on Jan 13, 2023:

Let's use utils.print_rank_0 or check for main process here to avoid the annoying prints from all processes.
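
A minimal sketch of that guard; the helper name follows the utils.print_rank_0 mentioned above, but this body is my own assumption:

```python
import torch.distributed as dist

def print_rank_0(*args, **kwargs):
    """Print only from the rank-0 process to avoid duplicated lines under DDP."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)

# inside the loop above, something like:
# print_rank_0(f"Sweeping generation's `{k}` over {v}")
```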

@LouisCastricato merged commit 84dd156 into main on Jan 13, 2023