Skip to content

Commit f4646f7

Browse files
hills-codewuchengyuekennymckormick
authored
[Model] Add Janus-1.3B (#541)
* add janus eval * update * [Fix] Fix Lint --------- Co-authored-by: wuchengyue <[email protected]> Co-authored-by: kennymckormick <[email protected]>
1 parent db31bbb commit f4646f7

File tree

4 files changed

+144
-2
lines changed

4 files changed

+144
-2
lines changed

vlmeval/config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@
204204
'deepseek_vl_1.3b': partial(DeepSeekVL, model_path='deepseek-ai/deepseek-vl-1.3b-chat'),
205205
}
206206

207+
208+
janus_series = {
209+
'janus_1.3b': partial(Janus, model_path='deepseek-ai/Janus-1.3B')
210+
}
211+
207212
cogvlm_series = {
208213
'cogvlm-grounding-generalist': partial(CogVlm, model_path='THUDM/cogvlm-grounding-generalist-hf', tokenizer_name='lmsys/vicuna-7b-v1.5'),
209214
'cogvlm-chat': partial(CogVlm, model_path='THUDM/cogvlm-chat-hf', tokenizer_name='lmsys/vicuna-7b-v1.5'),
@@ -322,7 +327,7 @@
322327
ungrouped, api_models,
323328
xtuner_series, qwen_series, llava_series, internvl_series, yivl_series,
324329
xcomposer_series, minigpt4_series, idefics_series, instructblip_series,
325-
deepseekvl_series, minicpm_series, cogvlm_series, wemm_series,
330+
deepseekvl_series, janus_series, minicpm_series, cogvlm_series, wemm_series,
326331
cambrian_series, chameleon_series, video_models, ovis_series, vila_series,
327332
mantis_series, mmalaya_series, phi3_series, xgen_mm_series, qwen2vl_series,
328333
slime_series, eagle_series, moondream_series, llama_series, molmo_series,

vlmeval/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
'TransCore_M', 'emu2_chat', 'MiniCPM-V', 'MiniCPM-V-2', 'OmniLMM_12B',
7272
'cogvlm-grounding-generalist', 'cogvlm-chat', 'cogvlm2-llama3-chat-19B',
7373
'mPLUG-Owl3'
74-
] + list(xtuner_series) + list(yivl_series) + list(deepseekvl_series) + list(cambrian_series),
74+
] + list(xtuner_series) + list(yivl_series) + list(deepseekvl_series) + list(janus_series) + list(cambrian_series),
7575
'4.36.2': ['Moondream1'],
7676
'4.40.0': [
7777
'idefics2_8b', 'Bunny-llama3-8B', 'MiniCPM-Llama3-V-2_5', '360VL-70B', 'Phi-3-Vision',

vlmeval/vlm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .yi_vl import Yi_VL
2929
from .internvl_chat import InternVLChat
3030
from .deepseek_vl import DeepSeekVL
31+
from .janus import Janus
3132
from .mgm import Mini_Gemini
3233
from .bunnyllama3 import BunnyLLama3
3334
from .vxverse import VXVERSE

vlmeval/vlm/janus.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import sys
2+
import torch
3+
from transformers import AutoModelForCausalLM, AutoConfig
4+
import warnings
5+
from .base import BaseModel
6+
from ..smp import *
7+
from ..dataset import DATASET_TYPE
8+
9+
10+
class Janus(BaseModel):
11+
12+
INSTALL_REQ = True
13+
INTERLEAVE = True
14+
15+
def check_install(self):
16+
try:
17+
import janus
18+
except Exception as e:
19+
logging.critical(
20+
'Please first install janus from source codes in: https://github.com/deepseek-ai/Janus')
21+
raise e
22+
23+
def __init__(self, model_path='deepseek-ai/Janus-1.3B', **kwargs):
24+
self.check_install()
25+
assert model_path is not None
26+
self.model_path = model_path
27+
from janus.models import VLChatProcessor
28+
29+
self.vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
30+
self.tokenizer = self.vl_chat_processor.tokenizer
31+
32+
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
33+
self.model = model.to(torch.bfloat16).cuda().eval()
34+
35+
torch.cuda.empty_cache()
36+
default_kwargs = dict(
37+
max_new_tokens=512,
38+
do_sample=False,
39+
use_cache=True,
40+
output_logits=False,
41+
output_scores=False,
42+
return_dict_in_generate=False)
43+
44+
default_kwargs.update(kwargs)
45+
self.kwargs = default_kwargs
46+
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')
47+
48+
def prepare_inputs(self, message):
49+
def prepare_itlist(msgs):
50+
content, images = '', []
51+
for s in msgs:
52+
if s['type'] == 'image':
53+
images.append(s['value'])
54+
content += '<image_placeholder>'
55+
elif s['type'] == 'text':
56+
content += s['value']
57+
return content, images
58+
conversation = []
59+
if 'role' not in message[0]:
60+
content, images = prepare_itlist(message)
61+
conversation.append(dict(role='User', content=content, images=images))
62+
else:
63+
role_map = {'user': 'User', 'assistant': 'Assistant'}
64+
for msgs in message:
65+
role = role_map[msgs['role']]
66+
content, images = prepare_itlist(msgs['content'])
67+
conversation.append(dict(role=role, content=content, images=images))
68+
conversation.append(dict(role='Assistant', content=''))
69+
return conversation
70+
71+
def generate_inner(self, message, dataset=None):
72+
if not ('MMVet' in dataset):
73+
self.vl_chat_processor.system_prompt = ""
74+
else:
75+
self.vl_chat_processor.system_prompt = "You are a helpful assistant. Please answer truthfully and write out your thinking step by step to be sure you get the right answer." # noqa: E501
76+
77+
conversation = self.prepare_inputs(message)
78+
from janus.utils.io import load_pil_images
79+
pil_images = load_pil_images(conversation)
80+
prepare_inputs = self.vl_chat_processor(conversations=conversation, images=pil_images, force_batchify=True)
81+
prepare_inputs = prepare_inputs.to(self.model.device, dtype=torch.bfloat16)
82+
inputs_embeds = self.model.prepare_inputs_embeds(**prepare_inputs)
83+
84+
outputs = self.model.language_model.generate(
85+
inputs_embeds=inputs_embeds,
86+
attention_mask=prepare_inputs.attention_mask,
87+
pad_token_id=self.tokenizer.eos_token_id,
88+
bos_token_id=self.tokenizer.bos_token_id,
89+
eos_token_id=self.tokenizer.eos_token_id,
90+
**self.kwargs)
91+
answer = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
92+
return answer
93+
94+
def chat_inner(self, message, dataset=None):
95+
return self.generate_inner(message, dataset=dataset)
96+
97+
def use_custom_prompt(self, dataset):
98+
assert dataset is not None
99+
if DATASET_TYPE(dataset) == 'Y/N' or DATASET_TYPE(dataset) == 'MCQ' or dataset == 'MMVet':
100+
return True
101+
return False
102+
103+
def build_prompt(self, line, dataset=None):
104+
assert dataset is None or isinstance(dataset, str)
105+
assert self.use_custom_prompt(dataset)
106+
tgt_path = self.dump_image(line, dataset)
107+
question = line['question']
108+
if DATASET_TYPE(dataset) == 'Y/N':
109+
if dataset == 'POPE':
110+
question = question.replace(" Please answer yes or no.", "")
111+
prompt = '\n' + question + "\nAnswer the question using a single word or phrase."
112+
elif DATASET_TYPE(dataset) == 'MCQ':
113+
options = {
114+
cand: line[cand]
115+
for cand in string.ascii_uppercase
116+
if cand in line and not pd.isna(line[cand])
117+
}
118+
options_prompt = ''
119+
for key, item in options.items():
120+
options_prompt += f'{key}. {item}\n'
121+
122+
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
123+
prompt = f'\nHint: {hint}\n' if hint is not None else '\n'
124+
prompt += f'{question}\n'
125+
prompt += (
126+
f"{options_prompt}\nAnswer with the option's letter from the given choices directly."
127+
if len(options) else 'Answer the question directly. '
128+
)
129+
elif dataset == 'MMVet':
130+
prompt = '\n' + question
131+
else:
132+
raise NotImplementedError
133+
134+
message = [dict(type='image', value=s) for s in tgt_path]
135+
message.extend([dict(type='text', value=prompt)])
136+
return message

0 commit comments

Comments
 (0)