141 changes: 141 additions & 0 deletions assert_padding_loss.py
@@ -0,0 +1,141 @@
import math

import torch
import typer
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration

from functionary.prompt_template import get_prompt_template_by_version
from functionary.train_vision.qwen2_vl_dataset import LazyVisionDataset


def get_raw_data():
raw_data = [
{
"messages": [
{"role": "user", "content": "what do you do for a living?"},
{"role": "assistant", "content": "I am a doctor"},
],
"tools": [],
},
{
"messages": [
{"role": "user", "content": "what do you do for a living?"},
{"role": "assistant", "content": "I am a doctor"},
{
"role": "user",
"content": f"can you count number of letters s in this string: "
+ " ".join(["this" for _ in range(100)]),
},
{"role": "assistant", "content": "The number is 100"},
],
"tools": [],
},
        {  # the last image is partially truncated
"messages": [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": "file://assets/s.png"},
},
{"type": "text", "text": "can you describe this image"},
],
},
{"role": "assistant", "content": "the first image is letter s"},
{
"role": "user",
"content": [
{"type": "text", "text": "how about this image ?"},
{
"type": "image_url",
"image_url": {"url": "file://assets/as.png"},
},
],
},
{
"role": "assistant",
"content": "this image is the word: 'as' I can predict this",
},
],
"tools": [],
},
]
return raw_data


def get_loss_from_ds(ds, model):
loss_list = []
model.eval()
with torch.no_grad():
for i in range(len(ds)):
print("------------------")
data = ds[i]
print("data: ", data)
            # add a batch dimension so shapes become (1, seq_len)
            for key in ["input_ids", "labels", "attention_mask"]:
                data[key] = data[key][None, :]

for key in data:
data[key] = data[key].to(model.device)
print(f"{key}: {data[key].shape}")

labels = data["labels"]
label_count = (labels != -100).sum().item()
            if label_count == 0:
                # no trainable tokens in this sample; record -1 as a sentinel loss
                loss_list.append((label_count, -1))
else:
output = model.forward(**data)
loss = output.loss.item()
loss_list.append((label_count, loss))
return loss_list


def main(max_length: int):
pretrained_path = "Qwen/Qwen2-VL-7B-Instruct"
prompt_template = get_prompt_template_by_version("qwen2-vl")
tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
tokenizer.chat_template = prompt_template.get_chat_template_jinja()

raw_data = get_raw_data()
# raw_data = raw_data[-1: ]
ds = LazyVisionDataset(
raw_data,
tokenizer,
pretrained_path=pretrained_path,
pad_img_path="functionary/train_vision/pad_img2.png",
max_length=max_length,
use_img_pad_token=False,
)

pad_ds = LazyVisionDataset(
raw_data,
tokenizer,
pretrained_path=pretrained_path,
pad_img_path="functionary/train_vision/pad_img2.png",
max_length=max_length,
use_img_pad_token=True,
)

model = Qwen2VLForConditionalGeneration.from_pretrained(
pretrained_path,
torch_dtype=torch.bfloat16,
device_map="auto",
        attn_implementation="flash_attention_2",
)
loss_list = get_loss_from_ds(ds, model)
print("loss_list: ", loss_list)
pad_loss_list = get_loss_from_ds(pad_ds, model)
print("pad_loss_list: ", pad_loss_list)
print("----------------------------")
    for loss, pad_loss in zip(loss_list, pad_loss_list):
        count1, loss1 = loss
        count2, loss2 = pad_loss
        if loss1 == -1:  # skip samples recorded with the no-label sentinel
            continue
        percentage = math.fabs(loss2 - loss1) * 100 / loss1
        print(
            f"count1: {count1}; count2: {count2}, loss1: {loss1}, loss2: {loss2}; percentage={percentage:.2f} %"
        )


if __name__ == "__main__":
typer.run(main)
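
The loop in main() only prints the per-sample comparison. A minimal sketch that turns it into a hard check, reusing the (label_count, loss) pairs returned by get_loss_from_ds; the 2% tolerance is an arbitrary illustrative choice, not something this script defines:

import math

def assert_losses_close(loss_list, pad_loss_list, tol_percent=2.0):
    # Fail if padded and unpadded per-sample losses diverge beyond tol_percent.
    for (count1, loss1), (count2, loss2) in zip(loss_list, pad_loss_list):
        assert count1 == count2, f"label counts differ: {count1} vs {count2}"
        if loss1 == -1:  # sentinel: the sample had no trainable labels
            continue
        diff = math.fabs(loss2 - loss1) * 100 / loss1
        assert diff <= tol_percent, f"loss diverged by {diff:.2f}%"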
39 changes: 39 additions & 0 deletions create_training_data.py
@@ -0,0 +1,39 @@
import json
import random

def read_data(path):
with open(path, "r") as f:
return [json.loads(line) for line in f]


def main():
text_train_data = read_data("2024-03-27/2024-03-27_train.jsonl")
text_dev_data = read_data("2024-03-27/2024-03-27_val.jsonl")

img_train_data = read_data("2024-07-30_train.jsonl")
img_dev_data = read_data("2024-07-30_val.jsonl")

print(f"text_train_data:{len(text_train_data)}; text_dev_data: {len(text_dev_data)}; img_train_data: {len(img_train_data)}; img_dev_data:{len(img_dev_data)}")

for item in [text_train_data, text_dev_data, img_train_data, img_dev_data]:
random.shuffle(item)

total_train = text_train_data + img_train_data
total_dev = text_dev_data + img_dev_data

random.shuffle(total_train)
random.shuffle(total_dev)

print(f"number of total_train: {len(total_train)}; total_dev: {len(total_dev)}")

with open("total_train.jsonl", "w") as f:
for item in total_train:
f.write(json.dumps(item, ensure_ascii=False) + "\n")

with open("total_dev.jsonl", "w") as f:
for item in total_dev:
f.write(json.dumps(item, ensure_ascii=False) + "\n")


if __name__ == "__main__":
main()
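
Note that random.shuffle here is unseeded, so each run writes the files in a different order. A small sketch of a reproducible alternative (the seed value is an arbitrary assumption):

import random

def shuffle_reproducibly(items: list, seed: int = 42) -> None:
    # A dedicated Random instance avoids reseeding the global RNG.
    rng = random.Random(seed)
    rng.shuffle(items)

Swapping this in for the bare random.shuffle calls above would make total_train.jsonl and total_dev.jsonl stable across runs.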
12 changes: 6 additions & 6 deletions functionary/inference.py
@@ -1,18 +1,18 @@
from typing import Dict, List, Optional, Union

import torch
from lmformatenforcer import CharacterLevelParser, JsonSchemaParser
from lmformatenforcer.integrations.vllm import build_vllm_logits_processor
#from lmformatenforcer import CharacterLevelParser, JsonSchemaParser
#from lmformatenforcer.integrations.vllm import build_vllm_logits_processor
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
_cached_build_vllm_token_enforcer_tokenizer_data,
_normalize_json_schema_object,
)
# from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
# _cached_build_vllm_token_enforcer_tokenizer_data,
# _normalize_json_schema_object,
# )
from vllm.sampling_params import LogitsProcessor

from functionary.openai_types import ChatMessage, Function, FunctionCall, Tool
2 changes: 1 addition & 1 deletion functionary/openai_types.py
@@ -24,7 +24,7 @@ class Function(BaseModel):


class Tool(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
type: Literal["function", "code_interpreter", "reasoning"] = "function"
function: Optional[Function] = None
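
With the new literal, reasoning can be toggled through the same Tool model that carries functions and the code interpreter. For illustration (a sketch; how callers consume the flag lives outside this file):

from functionary.openai_types import Tool

reasoning_tool = Tool(type="reasoning")  # carries no function payload
assert reasoning_tool.function is None
code_tool = Tool(type="code_interpreter")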


6 changes: 5 additions & 1 deletion functionary/prompt_template/__init__.py
@@ -8,7 +8,10 @@
from functionary.prompt_template.llava_prompt_template import LlavaLlama
from functionary.prompt_template.prompt_template_v1 import PromptTemplateV1
from functionary.prompt_template.prompt_template_v2 import PromptTemplateV2
from functionary.prompt_template.llama31_reasoning_prompt_template import Llama31ReasoningTemplate
from functionary.prompt_template.llama31_reasoning_prompt_template import (
Llama31ReasoningTemplate,
)
from functionary.prompt_template.qwen25_template import Qwen25PromptTemplate


def get_available_prompt_template_versions() -> List[PromptTemplate]:
Expand All @@ -29,6 +32,7 @@ def get_available_prompt_template_versions() -> List[PromptTemplate]:
    # add LlavaLlama directly since it is not a direct subclass of PromptTemplate but a subclass of Llama3TemplateV3
    # we don't use get_prompt_template here, or it would return the parent class
all_templates_obj.append(LlavaLlama.get_prompt_template())
# all_templates_obj.append(Qwen2VLTemplate.get_prompt_template())
all_templates_obj.append(Llama31ReasoningTemplate.get_prompt_template())
return all_templates_obj
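
A quick way to confirm the registry picks up the newly added templates (a sketch; it assumes each template object exposes a version attribute, which get_prompt_template_by_version implies):

from functionary.prompt_template import get_available_prompt_template_versions

for template in get_available_prompt_template_versions():
    print(template.version)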

27 changes: 20 additions & 7 deletions functionary/prompt_template/base_template.py
@@ -11,7 +11,8 @@
from functionary.prompt_template.prompt_utils import resolve_json_refs
from functionary.openai_types import Function, Tool
from functionary.prompt_template import prompt_utils

from PIL import Image

def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)
@@ -129,12 +130,21 @@ def get_prompt_from_messages(

tools = resolve_json_refs(tools_or_functions=tools_or_functions)

prompt = self._jinja_template.render(
messages=messages,
tools=tools,
bos_token=bos_token,
add_generation_prompt=add_generation_prompt,
)
try:
prompt = self._jinja_template.render(
messages=messages,
tools=tools,
bos_token=bos_token,
add_generation_prompt=add_generation_prompt,
)
except Exception as e:
print(f"Error in get_prompt_from_messages: {e}")
print(f"messages: {messages}")
print(f"tools: {tools}")
print(f"bos_token: {bos_token}")
print(f"add_generation_prompt: {add_generation_prompt}")
            raise

return prompt

@@ -386,6 +396,9 @@ def get_generation_prefix_for_tool_choice(self, tool_choice: Any):
"tool-choice must be one of: None, none, auto, required, or a specific tool"
)

    def preprocess_image_input(self, image: Image.Image) -> Image.Image:
        # Identity by default; subclasses may override to transform images.
        return image

@classmethod
def get_prompt_template(cls):
if cls._instances.get(cls, None) is None:
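
The preprocess_image_input hook added above is an identity by default, so vision templates can override it without changing any call sites. A hypothetical override, not part of this PR (it assumes the base class is importable as shown, which the registry's type hints suggest):

from PIL import Image

from functionary.prompt_template.base_template import PromptTemplate

class RGBVisionTemplate(PromptTemplate):
    def preprocess_image_input(self, image: Image.Image) -> Image.Image:
        # Most vision preprocessors expect 3-channel input.
        return image.convert("RGB")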
77 changes: 77 additions & 0 deletions functionary/prompt_template/jinja_templates/qwen2-vl.txt
@@ -0,0 +1,77 @@
{# version=qwen2-vl #}{%- if not tools -%}
{%- set tools = [] -%}
{%- endif -%}
{%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%}
{%- if has_reasoning -%}
{%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%}
{%- endif -%}

{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%}
{%- if has_code_interpreter -%}
{%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%}
{%- endif -%}

{{- bos_token + '<|im_start|>system\nYou are capable of executing available function(s) if required.\nOnly execute function(s) when absolutely necessary.\nTo send text to user, use this format:\n>>>all\n{content}\n' -}}
{%- if tools %}
{{- "\nYou have access to the following functions:\n\n" }}
{%- for t in tools %}
{%- if "type" in t -%}
{{ "Use the function '" + t["function"]["name"] + "' to '" + t["function"]["description"] + "'\n" + t["function"] | tojson() }}
{%- else -%}
{{ "Use the function '" + t["name"] + "' to '" + t["description"] + "'\n" + t | tojson }}
{%- endif -%}
{{- "\n\n" }}
{%- endfor %}
{{- '\nThink very carefully before calling functions.\nIf you choose to call a function ONLY reply in the following format:\n>>>{function_name}\n{arguments}' -}}
{%- endif %}

{%- if has_code_interpreter -%}
{{- '\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.' -}}
{%- endif -%}

{%- if has_reasoning %}
{{- "\n\nReasoning Mode: On" }}
{%- else -%}
{{ "\n\nReasoning Mode: Off" }}
{%- endif %}
{{- "<|im_end|>\n" -}}

{%- for message in messages -%}
{%- if message['role'] == 'user' -%}
{%- if message['content'] -%}
{%- if message['content'] is string -%}
{{ '<|im_start|>user\n' + message['content'] }}
{%- else -%}
{{ '<|im_start|>user\n' }}
{%- for content in message['content'] -%}
{%- if content['type'] == 'text' -%}
{{ content['text'] }}
{%- else -%}
{{ '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{{ '<|im_end|>\n' }}
{%- endif -%}
{%- elif message['role'] == 'system' -%}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' -}}
{%- elif message['role'] == 'tool' -%}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' -}}
{%- else -%}
{%- if (message['content'] and message['content']|length > 0) or ('tool_calls' in message and message['tool_calls'] and message['tool_calls']|length > 0) -%}
{{- '<|im_start|>' + message['role'] + '\n'-}}
{%- endif -%}
{%- if message['content'] and message['content']|length > 0 -%}
{{- '>>>all\n' + message['content'] -}}
{%- endif -%}
{%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls']|length > 0 -%}
{%- for tool_call in message['tool_calls'] -%}
{{- '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] -}}
{%- endfor -%}
{%- endif -%}
{%- if (message['content'] and message['content']|length > 0) or ('tool_calls' in message and message['tool_calls'] and message['tool_calls']|length > 0) -%}
{{- '<|im_end|>\n' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{% if add_generation_prompt %}{{- '<|im_start|>assistant\n>>>' -}}{% endif %}
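
A quick render check for this template (a sketch; it assumes the functionary package is importable and mirrors the render call in base_template.py above):

from functionary.prompt_template import get_prompt_template_by_version

template = get_prompt_template_by_version("qwen2-vl")
messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "hi"},
]
# With no tools, the system prompt should end with "Reasoning Mode: Off".
print(template.get_prompt_from_messages(messages=messages, tools_or_functions=[]))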