Add support for more CausalLMs #103
Conversation
```python
@pytest.mark.parametrize(
    "model_name",
```
Nice, this is great!
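As a side note for readers, here is a minimal sketch of what such a parametrized smoke test could look like; the model names and the forward-pass assertion are illustrative, not the PR's actual test code:

```python
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

@pytest.mark.parametrize(
    "model_name",
    ["gpt2", "facebook/opt-125m", "EleutherAI/pythia-160m"],  # illustrative Hub ids
)
def test_forward_pass(model_name):
    # Smoke test: each supported CausalLM should load and produce logits.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    inputs = tokenizer("Hello there", return_tensors="pt")
    logits = model(**inputs).logits
    assert logits.shape[-1] == model.config.vocab_size
```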
```rst
**PPO**

.. autoclass:: trlx.model.accelerate_ppo_model.AcceleratePPOModel
```
You will probably have to merge with @reciprocated's latest commit (just FYI).
```python
import functools

def rgetattr(obj, attr, *args):
    # Recursive getattr: resolves dotted attribute paths such as 'a.b.c'.
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split("."))
```
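A quick usage example of this helper (the enclosing `rgetattr` wrapper and import above are reconstructed from the standard recipe, since the diff hunk shows only its body):

```python
import types

ns = types.SimpleNamespace(a=types.SimpleNamespace(b=42))
print(rgetattr(ns, "a.b"))         # 42
print(rgetattr(ns, "a.c", "n/a"))  # "n/a": the optional getattr-style default
```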
Nice
Looks very nice! For some reason the tests for the HydraHead are breaking (I'm trying to see why). Edit: I think it's negligible. We should change the assert equal to something like
@Dahoas Thanks for the review! This is still in draft as I'll be adding more tests and running examples. I'll bother you again for another review later, if you don't mind 😄
Is OPTForCausalLM working well? I changed the PPO trainer code and trained, but the results were bad, not like gpt2's. Also, the OPT model adds a special token.
Hi @dongs0104! The default PPO trainer uses a GPT hydra-head modification which will require reworking of the internals to support other models such as OPT. I plan to add hydras for OPT in this PR soon. Thanks for bringing awareness to possible special-token issues - we'll keep an eye on it 👍
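A quick illustration of the special-token difference mentioned above (not from this PR; `facebook/opt-125m` is just an example checkpoint):

```python
from transformers import AutoTokenizer

gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
opt_tok = AutoTokenizer.from_pretrained("facebook/opt-125m")

# gpt2 tokenizers do not prepend a BOS token by default, while OPT's does;
# this is one way special-token handling silently differs between models.
print(gpt2_tok("Hello").input_ids[0] == gpt2_tok.bos_token_id)  # False
print(opt_tok("Hello").input_ids[0] == opt_tok.bos_token_id)    # True
```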
Good job!
Looks great! Let's merge.
```diff
 logit_mask=None,
-pad_token_id=50256,
-eos_token_id=50256,
+pad_token_id=None,
```
Why the change?
`50256` is the default `eos_token_id` for gpt2-based tokenizers - this just ensures there is no implicit default for the other, non-gpt2 models.
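A minimal sketch of the pattern this enables, reading special-token ids from the tokenizer rather than assuming gpt2's `50256`; the fallback logic here is an assumption, not the PR's exact code:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Read the ids from the tokenizer: 50256 for gpt2, but 2 for OPT's EOS.
eos_token_id = tokenizer.eos_token_id
# Fall back to EOS for padding when the tokenizer defines no pad token.
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_token_id
```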
This is excellent!
This PR adds support for more `CausalLM`s in the HuggingFace Hub. Previously, only models that followed the `gpt2` architecture layer-naming convention were supported (except for ILQL, which also supports `gpt-neox`). This will allow one to use models such as OPT and Pythia out of the box.

`wandb` reports:

Notes:

Related Issue: #121
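For illustration, a hypothetical usage sketch under this PR; the `trlx.train` arguments shown are assumptions based on the repo's examples, and the reward function is a toy placeholder:

```python
import trlx

# A Hub model such as OPT can now be passed directly where previously only
# gpt2-convention models worked.
trainer = trlx.train(
    "facebook/opt-125m",
    reward_fn=lambda samples: [float(len(s)) for s in samples],  # toy reward
    prompts=["Hello, my name is"] * 8,
)
```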