-
-
Notifications
You must be signed in to change notification settings - Fork 104
Add support for Mistral 7B #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from typing import Any, Callable, Dict, Iterable, Optional, Tuple | ||
|
||
from confection import SimpleFrozenDict | ||
|
||
from ...compat import Literal, transformers | ||
from ...registry.util import registry | ||
from .base import HuggingFace | ||
|
||
|
||
class Mistral(HuggingFace): | ||
MODEL_NAMES = Literal["Mistral-7B-v0.1", "Mistral-7B-Instruct-v0.1"] # noqa: F722 | ||
|
||
def __init__( | ||
self, | ||
name: MODEL_NAMES, | ||
config_init: Optional[Dict[str, Any]], | ||
config_run: Optional[Dict[str, Any]], | ||
): | ||
self._tokenizer: Optional["transformers.AutoTokenizer"] = None | ||
self._device: Optional[str] = None | ||
self._is_instruct = "instruct" in name | ||
super().__init__(name=name, config_init=config_init, config_run=config_run) | ||
|
||
assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase) | ||
# self._config_run["pad_token_id"] = self._tokenizer.pad_token_id | ||
|
||
# Instantiate GenerationConfig object from config dict. | ||
self._hf_config_run = transformers.GenerationConfig.from_pretrained( | ||
self._name, **self._config_run | ||
) | ||
# To avoid deprecation warning regarding usage of `max_length`. | ||
self._hf_config_run.max_new_tokens = self._hf_config_run.max_length | ||
|
||
def init_model(self) -> Any: | ||
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._name) | ||
init_cfg = self._config_init | ||
if "device" in init_cfg: | ||
self._device = init_cfg.pop("device") | ||
|
||
model = transformers.AutoModelForCausalLM.from_pretrained( | ||
self._name, **init_cfg, resume_download=True | ||
) | ||
if self._device: | ||
model.to(self._device) | ||
|
||
return model | ||
|
||
@property | ||
def hf_account(self) -> str: | ||
return "mistralai" | ||
|
||
def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] | ||
assert callable(self._tokenizer) | ||
assert hasattr(self._model, "generate") | ||
assert hasattr(self._tokenizer, "batch_decode") | ||
prompts = list(prompts) | ||
|
||
tokenized_input_ids = [ | ||
self._tokenizer( | ||
prompt if not self._is_instruct else f"<s>[INST] {prompt} [/INST]", | ||
return_tensors="pt", | ||
).input_ids | ||
for prompt in prompts | ||
] | ||
if self._device: | ||
tokenized_input_ids = [tp.to(self._device) for tp in tokenized_input_ids] | ||
|
||
return [ | ||
self._tokenizer.decode( | ||
self._model.generate( | ||
input_ids=tok_ii, generation_config=self._hf_config_run | ||
)[:, tok_ii.shape[1] :][0], | ||
skip_special_tokens=True, | ||
) | ||
for tok_ii in tokenized_input_ids | ||
] | ||
|
||
@staticmethod | ||
def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: | ||
default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs() | ||
return ( | ||
default_cfg_init, | ||
default_cfg_run, | ||
) | ||
|
||
|
||
@registry.llm_models("spacy.Mistral.v1") | ||
def mistral_hf( | ||
name: Mistral.MODEL_NAMES, | ||
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), | ||
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), | ||
) -> Callable[[Iterable[str]], Iterable[str]]: | ||
"""Generates Mistral instance that can execute a set of prompts and return the raw responses. | ||
name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names(). | ||
config_init (Optional[Dict[str, Any]]): HF config for initializing the model. | ||
config_run (Optional[Dict[str, Any]]): HF config for running the model. | ||
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return | ||
the raw responses. | ||
""" | ||
return Mistral(name=name, config_init=config_init, config_run=config_run) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import copy | ||
|
||
import pytest | ||
import spacy | ||
from confection import Config # type: ignore[import] | ||
from thinc.compat import has_torch_cuda_gpu | ||
|
||
from ...compat import torch | ||
|
||
_PIPE_CFG = { | ||
"model": { | ||
"@llm_models": "spacy.Mistral.v1", | ||
"name": "Mistral-7B-v0.1", | ||
}, | ||
"task": {"@llm_tasks": "spacy.NoOp.v1"}, | ||
} | ||
|
||
_NLP_CONFIG = """ | ||
|
||
[nlp] | ||
lang = "en" | ||
pipeline = ["llm"] | ||
batch_size = 128 | ||
|
||
[components] | ||
|
||
[components.llm] | ||
factory = "llm" | ||
|
||
[components.llm.task] | ||
@llm_tasks = "spacy.NoOp.v1" | ||
|
||
[components.llm.model] | ||
@llm_models = "spacy.Mistral.v1" | ||
name = "Mistral-7B-v0.1" | ||
""" | ||
|
||
|
||
@pytest.mark.gpu | ||
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") | ||
def test_init(): | ||
"""Test initialization and simple run.""" | ||
nlp = spacy.blank("en") | ||
cfg = copy.deepcopy(_PIPE_CFG) | ||
nlp.add_pipe("llm", config=cfg) | ||
nlp("This is a test.") | ||
torch.cuda.empty_cache() | ||
|
||
|
||
@pytest.mark.gpu | ||
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") | ||
def test_init_from_config(): | ||
orig_config = Config().from_str(_NLP_CONFIG) | ||
nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) | ||
assert nlp.pipe_names == ["llm"] | ||
torch.cuda.empty_cache() | ||
|
||
|
||
@pytest.mark.gpu | ||
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") | ||
def test_invalid_model(): | ||
orig_config = Config().from_str(_NLP_CONFIG) | ||
config = copy.deepcopy(orig_config) | ||
config["components"]["llm"]["model"]["name"] = "x" | ||
with pytest.raises(ValueError, match="unexpected value; permitted"): | ||
spacy.util.load_model_from_config(config, auto_fill=True) | ||
torch.cuda.empty_cache() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.