
Support for T5 for ILQL #204

@loganlebanoff

🐛 Describe the bug

It seems that seq2seq (encoder-decoder) models such as T5 are not supported when using the ILQL algorithm. Training fails while the base model is being loaded:

Traceback (most recent call last):
  File "/home/local/AA/logan.lebanoff/.pycharm_helpers/pydev/pydevd.py", line 1496, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/local/AA/logan.lebanoff/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/local/AA/logan.lebanoff/slamdunk/src/slamdunk/trl/test_trlx.py", line 311, in <module>
    main(args)
  File "/home/local/AA/logan.lebanoff/slamdunk/src/slamdunk/trl/test_trlx.py", line 289, in main
    model = trlx.train(
  File "/home/local/AA/logan.lebanoff/trlx/trlx/trlx.py", line 104, in train
    trainer = get_trainer(config.train.trainer)(
  File "/home/local/AA/logan.lebanoff/trlx/trlx/trainer/accelerate_ilql_trainer.py", line 16, in __init__
    super().__init__(config, **kwargs)
  File "/home/local/AA/logan.lebanoff/trlx/trlx/trainer/accelerate_base_trainer.py", line 51, in __init__
    self.model = self.setup_model()
  File "/home/local/AA/logan.lebanoff/trlx/trlx/trainer/accelerate_base_trainer.py", line 100, in setup_model
    model = self.get_arch(self.config)
  File "/home/local/AA/logan.lebanoff/trlx/trlx/trainer/accelerate_ilql_trainer.py", line 32, in get_arch
    return CausalLMWithValueHeads(
  File "/home/local/AA/logan.lebanoff/trlx/trlx/trainer/nn/ilql_models.py", line 221, in __init__
    self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config)
  File "/home/local/AA/logan.lebanoff/miniconda3/envs/slamdunk/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 466, in from_pretrained
    raise ValueError(
ValueError: Unrecognized configuration class <class 'transformers.models.t5.configuration_t5.T5Config'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, CodeGenConfig, CTRLConfig, Data2VecTextConfig, ElectraConfig, ErnieConfig, GPT2Config, GPTNeoConfig, GPTNeoXConfig, GPTJConfig, MarianConfig, MBartConfig, MegatronBertConfig, MvpConfig, OpenAIGPTConfig, OPTConfig, PegasusConfig, PLBartConfig, ProphetNetConfig, QDQBertConfig, ReformerConfig, RemBertConfig, RobertaConfig, RoFormerConfig, Speech2Text2Config, TransfoXLConfig, TrOCRConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig.
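The traceback shows where it breaks: `CausalLMWithValueHeads` always loads the base model with `transformers.AutoModelForCausalLM`, and `T5Config` is not registered for that Auto class. One possible direction, sketched below under stated assumptions, is to choose the Auto class from the config's `is_encoder_decoder` flag, which `T5Config` sets to `True` by default. The helper names `select_auto_class_name` and `load_base_model` are hypothetical and not part of trlX.

```python
from typing import Any


# Hypothetical helper (not part of trlX): pick the transformers Auto class
# name based on whether the config describes an encoder-decoder model.
def select_auto_class_name(config: Any) -> str:
    """Return the name of the transformers Auto class suited to `config`.

    Encoder-decoder configs (e.g. T5Config) need AutoModelForSeq2SeqLM;
    decoder-only configs (e.g. GPT2Config) work with AutoModelForCausalLM.
    Configs without the attribute are treated as decoder-only.
    """
    if getattr(config, "is_encoder_decoder", False):
        return "AutoModelForSeq2SeqLM"
    return "AutoModelForCausalLM"


# Hypothetical loader: read the config first, then load the weights with
# the matching Auto class instead of hard-coding AutoModelForCausalLM.
def load_base_model(model_path: str):
    """Load a pretrained base model, choosing the Auto class from its config.

    The transformers import is deferred so the selection logic above can be
    used (and tested) without transformers installed.
    """
    import transformers

    config = transformers.AutoConfig.from_pretrained(model_path)
    auto_cls = getattr(transformers, select_auto_class_name(config))
    return auto_cls.from_pretrained(model_path)
```

With this, `load_base_model("t5-small")` would return a `T5ForConditionalGeneration` instead of raising the `ValueError` above. Fixing the loader alone is probably not enough, though. The ILQL value heads and sampling code presumably assume a decoder-only forward pass, so they would also need to handle separate encoder inputs and `decoder_input_ids`.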

Which trlX version are you using?

0.3.0

Additional system and package information

No response

Labels: bug (Something isn't working)