Description
Hi, I'm attempting to run the BART model example given in the README:
import torch
from transformers import BartTokenizer
from bart import MyBart
base_model = "facebook/bart-large"
unifiedqa_path = "unifiedQA-uncased/best-model.pt" # path to the downloaded checkpoint
tokenizer = BartTokenizer.from_pretrained(base_model)
model = MyBart.from_pretrained(base_model, state_dict=torch.load(unifiedqa_path))
model.eval()
x = model.generate_from_string("Which is best conductor? \n (A) iron (B) feather", tokenizer=tokenizer)
The .from_pretrained(..) line runs fine, but the .generate_from_string(..) line fails with this error:
TypeError: forward() got an unexpected keyword argument 'past_key_values'
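Python raises this error whenever a caller passes a keyword argument that the callee's signature does not declare. The stdlib-only sketch below assumes, without confirmation here, that `MyBart.forward` was written before the `past_key_values` argument that newer `generate()` code passes in. `LegacyModel`, its parameter names, and `accepts_kwarg` are hypothetical stand-ins, not the real BART or UnifiedQA code:

```python
import inspect


class LegacyModel:
    """Hypothetical stand-in for a model whose forward() does not declare past_key_values."""

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, use_cache=False):
        return input_ids


def accepts_kwarg(func, name):
    """Return True if func declares a parameter called `name` or accepts **kwargs."""
    params = inspect.signature(func).parameters.values()
    return any(p.name == name or p.kind is inspect.Parameter.VAR_KEYWORD for p in params)


model = LegacyModel()
print(accepts_kwarg(model.forward, "past_key_values"))  # False

try:
    # A newer-style caller forwarding a keyword the old signature never declared.
    model.forward([0, 1, 2], past_key_values=None)
except TypeError as err:
    print(err)  # exact wording varies slightly between Python versions
```

You could run the same `accepts_kwarg` check against the real `MyBart.forward`. That would show whether the installed `bart.py` and `transformers` version disagree on this argument.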
I also tried the run_model(..) method from the repository's main GitHub page, and it raises exactly the same error.
Any idea what might be causing this and how to fix it?
I am using Python 3.8.5 with transformers 4.4.2 and PyTorch 1.7.1.
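A small stdlib-only helper can read these versions directly from the environment instead of relying on memory. The sketch below assumes the distributions are installed under their PyPI names `transformers` and `torch`. The function name `report_versions` is hypothetical:

```python
import platform
from importlib import metadata  # available in Python 3.8+


def report_versions(packages=("transformers", "torch")):
    """Return the running Python version plus each package's installed version (None if absent)."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version}")
```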