import sys
import torch
from transformers import AutoModelForCausalLM, AutoConfig
import warnings
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE


class Janus(BaseModel):
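    """Wrapper around DeepSeek's Janus multimodal model
    (https://github.com/deepseek-ai/Janus) for interleaved image-text inference."""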

    INSTALL_REQ = True
    INTERLEAVE = True

    def check_install(self):
        try:
            import janus
        except Exception as e:
            logging.critical(
                'Please install janus from source first: https://github.com/deepseek-ai/Janus')
            raise e

    def __init__(self, model_path='deepseek-ai/Janus-1.3B', **kwargs):
        self.check_install()
        assert model_path is not None
        self.model_path = model_path
        from janus.models import VLChatProcessor

        self.vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
        self.tokenizer = self.vl_chat_processor.tokenizer

        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
        self.model = model.to(torch.bfloat16).cuda().eval()

        torch.cuda.empty_cache()
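        # Greedy decoding by default; callers can override any of these via **kwargs.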
        default_kwargs = dict(
            max_new_tokens=512,
            do_sample=False,
            use_cache=True,
            output_logits=False,
            output_scores=False,
            return_dict_in_generate=False)

        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs
        warnings.warn(f'Following kwargs received: {self.kwargs}, will be used as generation config.')

    def prepare_inputs(self, message):
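        """Convert an evaluation message into a Janus conversation.

        Supports both a flat interleaved list of image/text items (a single
        user turn) and a list of role-tagged turns, each with its own items.
        """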
        def prepare_itlist(msgs):
            # Flatten interleaved items into one content string, inserting an
            # <image_placeholder> token per image and collecting image paths.
            content, images = '', []
            for s in msgs:
                if s['type'] == 'image':
                    images.append(s['value'])
                    content += '<image_placeholder>'
                elif s['type'] == 'text':
                    content += s['value']
            return content, images
        conversation = []
        if 'role' not in message[0]:
            content, images = prepare_itlist(message)
            conversation.append(dict(role='User', content=content, images=images))
        else:
            role_map = {'user': 'User', 'assistant': 'Assistant'}
            for msgs in message:
                role = role_map[msgs['role']]
                content, images = prepare_itlist(msgs['content'])
                conversation.append(dict(role=role, content=content, images=images))
        conversation.append(dict(role='Assistant', content=''))
        return conversation

    def generate_inner(self, message, dataset=None):
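        """Build the conversation, pack images and text with the chat
        processor, embed, and generate with the underlying language model."""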
        # MMVet gets a step-by-step system prompt; all other datasets use none.
        # Guard against dataset=None, which would crash `'MMVet' in dataset`.
        if dataset is not None and 'MMVet' in dataset:
            self.vl_chat_processor.system_prompt = "You are a helpful assistant. Please answer truthfully and write out your thinking step by step to be sure you get the right answer."  # noqa: E501
        else:
            self.vl_chat_processor.system_prompt = ""

        conversation = self.prepare_inputs(message)
        from janus.utils.io import load_pil_images
        pil_images = load_pil_images(conversation)
        prepare_inputs = self.vl_chat_processor(conversations=conversation, images=pil_images, force_batchify=True)
        prepare_inputs = prepare_inputs.to(self.model.device, dtype=torch.bfloat16)
        inputs_embeds = self.model.prepare_inputs_embeds(**prepare_inputs)

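        # Generate directly from the precomputed multimodal embeddings; pad is
        # mapped to eos since the tokenizer defines no dedicated pad token.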
        outputs = self.model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            **self.kwargs)
        answer = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return answer

    def chat_inner(self, message, dataset=None):
        return self.generate_inner(message, dataset=dataset)

    def use_custom_prompt(self, dataset):
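        """Use a dataset-specific prompt for Y/N, MCQ, and MMVet datasets."""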
        assert dataset is not None
        if DATASET_TYPE(dataset) in ('Y/N', 'MCQ') or dataset == 'MMVet':
            return True
        return False

    def build_prompt(self, line, dataset=None):
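        """Build a dataset-specific prompt: a short-answer instruction for Y/N
        sets (stripping POPE's built-in suffix), lettered options for MCQ
        sets, and the raw question for MMVet."""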
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)
        question = line['question']
        if DATASET_TYPE(dataset) == 'Y/N':
            if dataset == 'POPE':
                question = question.replace(" Please answer yes or no.", "")
            prompt = '\n' + question + "\nAnswer the question using a single word or phrase."
        elif DATASET_TYPE(dataset) == 'MCQ':
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ''
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'

            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = f'\nHint: {hint}\n' if hint is not None else '\n'
            prompt += f'{question}\n'
            prompt += (
                f"{options_prompt}\nAnswer with the option's letter from the given choices directly."
                if len(options) else 'Answer the question directly.'
            )
        elif dataset == 'MMVet':
            prompt = '\n' + question
        else:
            raise NotImplementedError

        message = [dict(type='image', value=s) for s in tgt_path]
        message.append(dict(type='text', value=prompt))
        return message
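

# A minimal usage sketch (hypothetical image path and prompt; assumes a CUDA
# GPU and the `janus` package installed from source):
#
#   model = Janus('deepseek-ai/Janus-1.3B', max_new_tokens=128)
#   message = [
#       dict(type='image', value='demo.jpg'),
#       dict(type='text', value='Describe this image.'),
#   ]
#   print(model.generate_inner(message, dataset='MMVet'))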