Fix fmapi_chat for instruct models and custom tokenizers #914

Merged 9 commits on Jan 29, 2024
53 changes: 26 additions & 27 deletions llmfoundry/models/inference_api_wrapper/fmapi.py
@@ -20,41 +20,40 @@
log = logging.getLogger(__name__)


def block_until_ready(base_url: str):
"""Block until the endpoint is ready."""
sleep_s = 5
    timeout_s = 5 * 60  # At max, wait 5 minutes

ping_url = f'{base_url}/ping'

waited_s = 0
while True:
try:
requests.get(ping_url)
log.info(f'Endpoint {ping_url} is ready')
break
except requests.exceptions.ConnectionError:
log.debug(
f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds'
)
time.sleep(sleep_s)
waited_s += sleep_s

        if waited_s >= timeout_s:
            raise TimeoutError(
                f'Endpoint {ping_url} did not become ready after {waited_s:,} seconds, exiting'
)


class FMAPIEvalInterface(OpenAIEvalInterface):

def block_until_ready(self, base_url: str):
"""Block until the endpoint is ready."""
sleep_s = 5
        timeout_s = 5 * 60  # At max, wait 5 minutes

ping_url = f'{base_url}/ping'

waited_s = 0
while True:
try:
requests.get(ping_url)
log.info(f'Endpoint {ping_url} is ready')
break
except requests.exceptions.ConnectionError:
log.debug(
f'Endpoint {ping_url} not ready yet. Sleeping {sleep_s} seconds'
)
time.sleep(sleep_s)
waited_s += sleep_s

            if waited_s >= timeout_s:
                raise TimeoutError(
                    f'Endpoint {ping_url} did not become ready after {waited_s:,} seconds, exiting'
)

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
is_local = model_cfg.pop('local', False)
if is_local:
base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT',
'http://0.0.0.0:8080/v2')
model_cfg['base_url'] = base_url
block_until_ready(base_url)
self.block_until_ready(base_url)

if 'base_url' not in model_cfg:
raise ValueError(
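Usage sketch (not part of the change; the endpoint and model name are illustrative and mirror the new test below): with local: True, the wrapper resolves the base URL from MOSAICML_MODEL_ENDPOINT (defaulting to http://0.0.0.0:8080/v2) and calls self.block_until_ready(base_url), which polls {base_url}/ping every 5 seconds for up to 5 minutes.

import os
from transformers import AutoTokenizer
from llmfoundry.models.inference_api_wrapper import FMAPICasualLMEvalWrapper

# Point the wrapper at a locally served model (illustrative endpoint).
os.environ['MOSAICML_MODEL_ENDPOINT'] = 'http://0.0.0.0:8080/v2'

tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b-8k-instruct')
model = FMAPICasualLMEvalWrapper(
    model_cfg={'local': True, 'name': 'mosaicml/mpt-7b-8k-instruct'},
    tokenizer=tokenizer)
# __init__ blocks here until {base_url}/ping responds, or a 5 minute timeout is hit.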
4 changes: 2 additions & 2 deletions llmfoundry/models/inference_api_wrapper/interface.py
@@ -68,7 +68,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
output_logits = torch.nn.functional.one_hot(
torch.tensor(tokens[1:cont_idxs[0]]),
num_classes=self.tokenizer.vocab_size)
num_classes=len(self.tokenizer))
for i in range(len(expected_cont_tokens)):
# decode one token at a time
prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] +
@@ -81,7 +81,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
next_logit_tensor.reshape(1, -1)])
padding = torch.nn.functional.one_hot(
torch.full((seqlen - output_logits.shape[0],), padding_tok),
num_classes=self.tokenizer.vocab_size)
num_classes=len(self.tokenizer))
output_logits = torch.cat([output_logits, padding])
output_logits_batch.append(output_logits)

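Why len(self.tokenizer) instead of self.tokenizer.vocab_size: for Hugging Face tokenizers, vocab_size reports only the base vocabulary, while len(tokenizer) also counts added special tokens, so a one-hot tensor sized with vocab_size can be too small for ids the model actually produces. A minimal sketch of the difference (gpt2 chosen only for illustration):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '<|pad|>'})

print(tokenizer.vocab_size)  # 50257: base vocabulary only
print(len(tokenizer))        # 50258: base vocabulary plus the added pad token

# The added token's id equals the old vocab_size, so one-hot encoding it
# only works when num_classes accounts for added tokens.
pad_id = tokenizer.pad_token_id  # 50257
one_hot = torch.nn.functional.one_hot(torch.tensor([pad_id]), num_classes=len(tokenizer))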
21 changes: 14 additions & 7 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -204,7 +204,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
output_logits = torch.nn.functional.one_hot(
torch.tensor(tokens[1:cont_idxs[0]]),
num_classes=self.tokenizer.vocab_size)
num_classes=len(self.tokenizer))

prompt = self.tokenizer.decode(tokens[:cont_idxs[0]])
next_logit_tensor = self.get_next_token_logit_tensor(
@@ -214,7 +214,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
output_logits = torch.cat([output_logits, next_logit_tensor])
padding = torch.nn.functional.one_hot(
torch.full((seqlen - output_logits.shape[0],), padding_tok),
num_classes=self.tokenizer.vocab_size)
num_classes=len(self.tokenizer))
output_logits = torch.cat([output_logits, padding])
output_logits_batch.append(output_logits)

@@ -228,9 +228,10 @@ def process_result(self, completion: Optional['ChatCompletion']):
tensors = []
for t in self.tokenizer(
completion.choices[0].message.content)['input_ids']:
tensors.append(
self.tokenizer.construct_logit_tensor(
{self.tokenizer.decode([t]): 0.0}))
# Not real logprobs
tensor = torch.tensor([0] * (len(self.tokenizer)))
tensor[t] = 1.0
tensors.append(tensor)

if len(tensors) == 0:
return None
@@ -263,8 +264,14 @@ def process_result(self, completion: Optional['Completion']):
assert isinstance(completion.choices[0].logprobs.top_logprobs, list)

if len(completion.choices[0].logprobs.top_logprobs[0]) > 0:
tensor = self.tokenizer.construct_logit_tensor(
dict(completion.choices[0].logprobs.top_logprobs[0]))
# Construct tensor of shape (vocab_size,) with logprobs for each token
tokenizer_logprobs = dict(
completion.choices[0].logprobs.top_logprobs[0])
tensor = torch.tensor([min(tokenizer_logprobs.values()) - 1] *
(len(self.tokenizer)))
for k in tokenizer_logprobs:
encoding = self.tokenizer(k)['input_ids']
tensor[encoding[0]] = tokenizer_logprobs[k]
return tensor
else:
# the model sometimes stops early even though we are still requesting tokens!
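Toy illustration of the two replacements above (a hypothetical 5-token vocabulary stands in for len(self.tokenizer)): the chat path fabricates a one-hot style score per generated token because the chat API returns no logprobs, while the completions path spreads the returned top_logprobs over a dense tensor, with every other position set just below the smallest returned logprob.

import torch

vocab_size = 5  # stands in for len(self.tokenizer)

# Chat path: no logprobs from the API, so mark the generated token id (here 3).
chat_scores = torch.tensor([0] * vocab_size)
chat_scores[3] = 1.0

# Completions path: dense tensor from top_logprobs; in the real code the keys are
# string tokens that are first encoded to ids via self.tokenizer(k)['input_ids'][0].
top_logprobs = {1: -0.2, 4: -1.7}
completion_scores = torch.tensor([min(top_logprobs.values()) - 1] * vocab_size)
for token_id, logprob in top_logprobs.items():
    completion_scores[token_id] = logprob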
15 changes: 0 additions & 15 deletions llmfoundry/tokenizers/tiktoken.py
@@ -3,7 +3,6 @@
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import PreTrainedTokenizer

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible."""
@@ -358,19 +357,5 @@ def sanitize_special_tokens(self) -> int:

return self.add_tokens(actual_new_tokens, special_tokens=True)

def construct_logit_tensor(self, logprobs: Dict[str,
float]) -> torch.Tensor:
"""Construct tensor of shape (vocab_size,) mapping words to logprobs.

Args:
logprobs (Dict[str, float]): Dictionary mapping tokens to log probabilities assigned to them by the model.
"""
tensor = torch.tensor([min(logprobs.values()) - 1] * (self.vocab_size))
for k in logprobs:
encoding = self(k)['input_ids']
idx = encoding[0]
tensor[idx] = logprobs[k]
return tensor


TiktokenTokenizerWrapper.register_for_auto_class()
159 changes: 159 additions & 0 deletions tests/models/inference_api_wrapper/test_fmapi.py
@@ -0,0 +1,159 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Dict
from unittest.mock import patch

import pytest
import transformers
from omegaconf import DictConfig, ListConfig

from llmfoundry.models.inference_api_wrapper import (FMAPICasualLMEvalWrapper,
FMAPIChatAPIEvalWrapper)
from llmfoundry.models.inference_api_wrapper.fmapi import FMAPIEvalInterface
from llmfoundry.utils.builders import build_icl_evaluators


def load_icl_config():
return DictConfig({
'icl_tasks':
ListConfig([
DictConfig({
'label':
'jeopardy',
'dataset_uri':
'scripts/eval/local_data/world_knowledge/jeopardy_all.jsonl',
'num_fewshot': [0, 1],
'icl_task_type':
'language_modeling',
'continuation_delimiter':
'\nAnswer: ',
'has_categories':
True
})
])
})


class MockTopLogProb:

def __init__(self, expected_token: str) -> None:
self.top_logprobs = [{expected_token: 0}]


class MockLogprob:

def __init__(self, expected_token: str) -> None:
self.logprobs = MockTopLogProb(expected_token)


class MockCompletion:

def __init__(self, expected_token: str) -> None:
self.choices = [MockLogprob(expected_token)]


class MockContent:

def __init__(self, expected_token: str) -> None:
setattr(self, 'content', expected_token)


class MockMessage:

def __init__(self, expected_token: str) -> None:
setattr(self, 'message', MockContent(expected_token))


class MockChatCompletion:

def __init__(self, expected_token: str) -> None:
setattr(self, 'choices', [MockMessage(expected_token)])


def mock_create(**kwargs: Dict[str, str]):
prompt = kwargs['prompt']
if prompt == 'AMERICAN HISTORY: On May 29, 1765 Patrick Henrys Stamp Act protest was interrupted with this one word\nAnswer:': # pyright: ignore[reportUnnecessaryComparison]
return MockCompletion(' Tre')

elif prompt == 'AMERICAN HISTORY: On May 29, 1765 Patrick Henrys Stamp Act protest was interrupted with this one word\nAnswer: Tre': # pyright: ignore[reportUnnecessaryComparison]
return MockCompletion('ason')

elif prompt == 'AMERICAN HISTORY: On May 29, 1765 Patrick Henrys Stamp Act protest was interrupted with this one word\nAnswer: Treason': # pyright: ignore[reportUnnecessaryComparison]
return MockCompletion('!')

else:
# dummy token to make sure the model is incorrect on any other prompt
return MockCompletion(' ')


def test_casual_fmapi_wrapper(tmp_path: str):
# patch block_until_ready
with patch.object(FMAPIEvalInterface, 'block_until_ready') as mock:

_ = pytest.importorskip('openai')

tokenizer = transformers.AutoTokenizer.from_pretrained(
'mosaicml/mpt-7b-8k-instruct')
model = FMAPICasualLMEvalWrapper(model_cfg={
'local': True,
'name': 'mosaicml/mpt-7b-8k-instruct'
},
tokenizer=tokenizer)
with patch.object(model, 'client') as mock:
mock.completions.create = mock_create

task_cfg = load_icl_config()
evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks,
tokenizer,
1024,
2,
destination_dir=str(tmp_path))

batch = next(evaluators[0].dataloader.dataloader.__iter__())
result = model.eval_forward(batch)
model.update_metric(
batch,
result,
metric=model.get_metrics()
['InContextLearningLMAccuracy']) # pyright: ignore
acc = model.get_metrics(
)['InContextLearningLMAccuracy'].compute( # pyright: ignore
) # pyright: ignore
assert acc == 0.5


def test_chat_fmapi_wrapper(tmp_path: str):
with patch.object(FMAPIEvalInterface, 'block_until_ready') as mock:
_ = pytest.importorskip('openai')

tokenizer = transformers.AutoTokenizer.from_pretrained(
'mosaicml/mpt-7b-8k-instruct')
chatmodel = FMAPIChatAPIEvalWrapper(model_cfg={
'local': True,
'name': 'mosaicml/mpt-7b-8k-instruct'
},
tokenizer=tokenizer)

with patch.object(chatmodel, 'client') as mock:
mock.chat.completions.create.return_value = MockChatCompletion(
'Treason!')

task_cfg = load_icl_config()
evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks,
tokenizer,
1024,
2,
destination_dir=str(tmp_path))

batch = next(evaluators[0].dataloader.dataloader.__iter__())
result = chatmodel.eval_forward(batch)
chatmodel.update_metric(
batch,
result,
metric=chatmodel.get_metrics()
['InContextLearningLMAccuracy']) # pyright: ignore
acc = chatmodel.get_metrics(
)['InContextLearningLMAccuracy'].compute( # pyright: ignore
)
assert acc == 0.5