Skip to content

Commit f5a4911

Browse files
authored
Improve Support for Mistral-Instruct (#2547)
1 parent cd7d048 commit f5a4911

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

fastchat/conversation.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class SeparatorStyle(IntEnum):
2828
PHOENIX = auto()
2929
ROBIN = auto()
3030
FALCON_CHAT = auto()
31+
MISTRAL_INSTRUCT = auto()
3132

3233

3334
@dataclasses.dataclass
@@ -212,6 +213,17 @@ def get_prompt(self) -> str:
212213
ret += role + ":"
213214

214215
return ret
216+
elif self.sep_style == SeparatorStyle.MISTRAL_INSTRUCT:
217+
ret = self.sep
218+
for i, (role, message) in enumerate(self.messages):
219+
if role == "user":
220+
if self.system_message and i == 0:
221+
ret += "[INST] " + system_prompt + " " + message + " [/INST]"
222+
else:
223+
ret += "[INST] " + message + " [/INST]"
224+
elif role == "assistant" and message:
225+
ret += message + self.sep2 + " "
226+
return ret
215227
else:
216228
raise ValueError(f"Invalid style: {self.sep_style}")
217229

@@ -840,16 +852,21 @@ def get_conv_template(name: str) -> Conversation:
840852
)
841853
)
842854

843-
# Mistral template
855+
# Mistral instruct template
844856
# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template
857+
# https://docs.mistral.ai/usage/guardrailing/
858+
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
845859
register_conv_template(
846860
Conversation(
847-
name="mistral",
848-
system_template="",
849-
roles=("[INST] ", " [/INST]"),
850-
sep_style=SeparatorStyle.LLAMA2,
851-
sep="",
852-
sep2=" </s>",
861+
name="mistral-instruct",
862+
system_message="Always assist with care, respect, and truth. "
863+
"Respond with utmost utility yet securely. "
864+
"Avoid harmful, unethical, prejudiced, or negative content. "
865+
"Ensure replies promote fairness and positivity.",
866+
roles=("user", "assistant"),
867+
sep_style=SeparatorStyle.MISTRAL_INSTRUCT,
868+
sep="<s>",
869+
sep2="</s>",
853870
)
854871
)
855872

fastchat/model/model_adapter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,11 +1283,11 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
12831283
return get_conv_template("starchat")
12841284

12851285

1286-
class MistralAdapter(BaseModelAdapter):
1287-
"""The model adapter for Mistral AI models"""
1286+
class MistralInstructAdapter(BaseModelAdapter):
1287+
"""The model adapter for Mistral Instruct AI models"""
12881288

12891289
def match(self, model_path: str):
1290-
return "mistral" in model_path.lower()
1290+
return "mistral" in model_path.lower() and "instruct" in model_path.lower()
12911291

12921292
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
12931293
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
@@ -1296,7 +1296,7 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
12961296
return model, tokenizer
12971297

12981298
def get_default_conv_template(self, model_path: str) -> Conversation:
1299-
return get_conv_template("mistral")
1299+
return get_conv_template("mistral-instruct")
13001300

13011301

13021302
class Llama2Adapter(BaseModelAdapter):
@@ -1716,7 +1716,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
17161716
register_model_adapter(InternLMChatAdapter)
17171717
register_model_adapter(StarChatAdapter)
17181718
register_model_adapter(Llama2Adapter)
1719-
register_model_adapter(MistralAdapter)
1719+
register_model_adapter(MistralInstructAdapter)
17201720
register_model_adapter(CuteGPTAdapter)
17211721
register_model_adapter(OpenOrcaAdapter)
17221722
register_model_adapter(WizardCoderAdapter)

fastchat/model/model_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def get_model_info(name: str) -> ModelInfo:
308308
)
309309
register_model_info(
310310
["mistral-7b-instruct"],
311-
"Mistral",
311+
"Mistral-Instruct",
312312
"https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1",
313313
"a large language model by Mistral AI team",
314314
)

0 commit comments

Comments
 (0)