Commit 9cf3c8b

add dtype and seed (#2430)

1 parent 318d070 · commit 9cf3c8b

File tree: 7 files changed, +84 -17 lines changed

- fastchat/llm_judge/README.md
- fastchat/llm_judge/gen_model_answer.py
- fastchat/model/model_adapter.py
- fastchat/serve/cli.py
- fastchat/serve/inference.py
- fastchat/serve/model_worker.py
- fastchat/utils.py

fastchat/llm_judge/README.md

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@ To automate the evaluation process, we prompt strong LLMs like GPT-4 to act as j
 - [Review Pre-Generated Model Answers and Judgments](#review-pre-generated-model-answers-and-judgments)
 - [MT-Bench](#mt-bench)
 - [Agreement Computation](#agreement-computation)
-- [Dataset](#dataset)
+- [Datasets](#datasets)
 - [Citation](#citation)
 
 ## Install
@@ -133,7 +133,7 @@ We released 3.3K human annotations for model responses generated by 6 models in
 
 This Colab [notebook](https://colab.research.google.com/drive/1ctgygDRJhVGUJTQy8-bRZCl1WNcT8De6?usp=sharing) shows how to compute the agreement between humans and GPT-4 judge with the dataset. Our results show that humans and GPT-4 judge achieve over 80\% agreement, the same level of agreement between humans.
 
-## Dataset
+## Datasets
 - [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations)
 - [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)
 

fastchat/llm_judge/gen_model_answer.py

Lines changed: 29 additions & 13 deletions

@@ -15,6 +15,7 @@
 
 from fastchat.llm_judge.common import load_questions, temperature_config
 from fastchat.model import load_model, get_conversation_template
+from fastchat.utils import str_to_torch_dtype
 
 
 def run_eval(
@@ -29,6 +30,7 @@ def run_eval(
     num_gpus_per_model,
     num_gpus_total,
     max_gpu_memory,
+    dtype,
 ):
     questions = load_questions(question_file, question_begin, question_end)
     # random shuffle the questions to balance the loading
@@ -45,7 +47,7 @@ def run_eval(
     else:
         get_answers_func = get_model_answers
 
-    chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) // 2
+    chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
     ans_handles = []
     for i in range(0, len(questions), chunk_size):
         ans_handles.append(
@@ -58,6 +60,7 @@ def run_eval(
                 num_choices,
                 num_gpus_per_model,
                 max_gpu_memory,
+                dtype=dtype,
             )
         )
 
@@ -75,12 +78,14 @@ def get_model_answers(
     num_choices,
     num_gpus_per_model,
     max_gpu_memory,
+    dtype,
 ):
     model, tokenizer = load_model(
         model_path,
         device="cuda",
         num_gpus=num_gpus_per_model,
         max_gpu_memory=max_gpu_memory,
+        dtype=dtype,
         load_8bit=False,
         cpu_offloading=False,
         debug=False,
@@ -192,7 +197,9 @@ def reorg_answer_file(answer_file):
         required=True,
         help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
     )
-    parser.add_argument("--model-id", type=str, required=True)
+    parser.add_argument(
+        "--model-id", type=str, required=True, help="A custom name for the model."
+    )
     parser.add_argument(
         "--bench-name",
         type=str,
@@ -234,6 +241,14 @@ def reorg_answer_file(answer_file):
         type=str,
         help="Maxmum GPU memory used for model weights per GPU.",
     )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
+        default=None,
+    )
 
     args = parser.parse_args()
 
@@ -250,17 +265,18 @@ def reorg_answer_file(answer_file):
     print(f"Output to {answer_file}")
 
     run_eval(
-        args.model_path,
-        args.model_id,
-        question_file,
-        args.question_begin,
-        args.question_end,
-        answer_file,
-        args.max_new_token,
-        args.num_choices,
-        args.num_gpus_per_model,
-        args.num_gpus_total,
-        args.max_gpu_memory,
+        model_path=args.model_path,
+        model_id=args.model_id,
+        question_file=question_file,
+        question_begin=args.question_begin,
+        question_end=args.question_end,
+        answer_file=answer_file,
+        max_new_token=args.max_new_token,
+        num_choices=args.num_choices,
+        num_gpus_per_model=args.num_gpus_per_model,
+        num_gpus_total=args.num_gpus_total,
+        max_gpu_memory=args.max_gpu_memory,
+        dtype=str_to_torch_dtype(args.dtype),
    )
 
     reorg_answer_file(answer_file)
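Two things change in this script: each Ray worker now receives a single, larger slice of the questions (the trailing `// 2` is gone), and the parsed `--dtype` string is converted with `str_to_torch_dtype` before being passed to `run_eval`. A rough illustration of the chunk-size arithmetic, with invented numbers:

```python
# Illustration only: how removing "// 2" changes work splitting across workers.
# The question count and GPU counts below are made-up examples, not from the commit.
questions = list(range(80))                            # e.g. 80 benchmark questions
num_gpus_total = 8
num_gpus_per_model = 2
num_workers = num_gpus_total // num_gpus_per_model     # 4 parallel workers

old_chunk_size = len(questions) // num_workers // 2    # 10 -> 8 chunks of 10 questions
new_chunk_size = len(questions) // num_workers         # 20 -> 4 chunks of 20 questions

print(old_chunk_size, new_chunk_size)                  # 10 20
```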

fastchat/model/model_adapter.py

Lines changed: 11 additions & 0 deletions

@@ -152,6 +152,7 @@ def load_model(
     device: str = "cuda",
     num_gpus: int = 1,
     max_gpu_memory: Optional[str] = None,
+    dtype: Optional[torch.dtype] = None,
     load_8bit: bool = False,
     cpu_offloading: bool = False,
     gptq_config: Optional[GptqConfig] = None,
@@ -282,6 +283,9 @@ def load_model(
         return model, tokenizer
     kwargs["revision"] = revision
 
+    if dtype is not None:  # Overwrite dtype if it is provided in the arguments.
+        kwargs["torch_dtype"] = dtype
+
     # Load model
     model, tokenizer = adapter.load_model(model_path, kwargs)
 
@@ -393,6 +397,13 @@ def add_model_args(parser):
         type=str,
         help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'",
     )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
+        default=None,
+    )
     parser.add_argument(
         "--load-8bit", action="store_true", help="Use 8-bit quantization"
     )
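The override is applied only when a dtype is explicitly given, so the adapter's existing per-device defaults (float16 on GPU, float32 on CPU) still apply otherwise. A minimal sketch of calling `load_model` with the new argument; the model path below is only a placeholder:

```python
import torch
from fastchat.model.model_adapter import load_model

# Placeholder model path: any local folder or Hugging Face repo ID works here.
model, tokenizer = load_model(
    "lmsys/vicuna-7b-v1.5",
    device="cuda",
    num_gpus=1,
    dtype=torch.bfloat16,  # forwarded to the loader as kwargs["torch_dtype"]
)
```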

fastchat/serve/cli.py

Lines changed: 3 additions & 0 deletions

@@ -26,11 +26,13 @@
 from rich.console import Console
 from rich.live import Live
 from rich.markdown import Markdown
+import torch
 
 from fastchat.model.model_adapter import add_model_args
 from fastchat.modules.gptq import GptqConfig
 from fastchat.modules.awq import AWQConfig
 from fastchat.serve.inference import ChatIO, chat_loop
+from fastchat.utils import str_to_torch_dtype
 
 
 class SimpleChatIO(ChatIO):
@@ -208,6 +210,7 @@ def main(args):
        args.device,
        args.num_gpus,
        args.max_gpu_memory,
+       str_to_torch_dtype(args.dtype),
        args.load_8bit,
        args.cpu_offloading,
        args.conv_template,

fastchat/serve/inference.py

Lines changed: 2 additions & 0 deletions

@@ -291,6 +291,7 @@ def chat_loop(
     device: str,
     num_gpus: int,
     max_gpu_memory: str,
+    dtype: Optional[torch.dtype],
     load_8bit: bool,
     cpu_offloading: bool,
     conv_template: Optional[str],
@@ -312,6 +313,7 @@ def chat_loop(
         device=device,
         num_gpus=num_gpus,
         max_gpu_memory=max_gpu_memory,
+        dtype=dtype,
         load_8bit=load_8bit,
         cpu_offloading=cpu_offloading,
         gptq_config=gptq_config,

fastchat/serve/model_worker.py

Lines changed: 22 additions & 2 deletions

@@ -34,6 +34,7 @@
 )
 import torch
 import torch.nn.functional as F
+from transformers import set_seed
 import uvicorn
 
 from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
@@ -46,7 +47,12 @@
 )
 from fastchat.modules.gptq import GptqConfig
 from fastchat.modules.awq import AWQConfig
-from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
+from fastchat.utils import (
+    build_logger,
+    pretty_print_semaphore,
+    get_context_length,
+    str_to_torch_dtype,
+)
 
 
 worker_id = str(uuid.uuid4())[:8]
@@ -190,13 +196,15 @@ def __init__(
         device: str,
         num_gpus: int,
         max_gpu_memory: str,
+        dtype: Optional[torch.dtype] = None,
         load_8bit: bool = False,
         cpu_offloading: bool = False,
         gptq_config: Optional[GptqConfig] = None,
         awq_config: Optional[AWQConfig] = None,
         stream_interval: int = 2,
-        conv_template: str = None,
+        conv_template: Optional[str] = None,
         embed_in_truncate: bool = False,
+        seed: Optional[int] = None,
         **kwargs,
     ):
         super().__init__(
@@ -215,6 +223,7 @@ def __init__(
             device=device,
             num_gpus=num_gpus,
             max_gpu_memory=max_gpu_memory,
+            dtype=dtype,
             load_8bit=load_8bit,
             cpu_offloading=cpu_offloading,
             gptq_config=gptq_config,
@@ -227,6 +236,7 @@ def __init__(
         self.generate_stream_func = get_generate_stream_function(self.model, model_path)
         self.stream_interval = stream_interval
         self.embed_in_truncate = embed_in_truncate
+        self.seed = seed
 
         if not no_register:
             self.init_heart_beat()
@@ -235,6 +245,8 @@ def generate_stream_gate(self, params):
         self.call_ct += 1
 
         try:
+            if self.seed is not None:
+                set_seed(self.seed)
             for output in self.generate_stream_func(
                 self.model,
                 self.tokenizer,
@@ -475,6 +487,12 @@ def create_model_worker():
     )
     parser.add_argument("--stream-interval", type=int, default=2)
     parser.add_argument("--no-register", action="store_true")
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="Overwrite the random seed for each generation.",
+    )
     args = parser.parse_args()
     logger.info(f"args: {args}")
 
@@ -508,13 +526,15 @@ def create_model_worker():
         device=args.device,
         num_gpus=args.num_gpus,
         max_gpu_memory=args.max_gpu_memory,
+        dtype=str_to_torch_dtype(args.dtype),
         load_8bit=args.load_8bit,
         cpu_offloading=args.cpu_offloading,
         gptq_config=gptq_config,
         awq_config=awq_config,
         stream_interval=args.stream_interval,
         conv_template=args.conv_template,
         embed_in_truncate=args.embed_in_truncate,
+        seed=args.seed,
     )
     return args, worker
 
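When `--seed` is provided, the worker calls `transformers.set_seed` immediately before each generation, which reseeds the Python, NumPy, and PyTorch RNGs so that repeated requests with identical sampling parameters become reproducible. A standalone sketch of that pattern (not FastChat code, just the behaviour the worker relies on):

```python
from typing import Optional

import torch
from transformers import set_seed


def sample_token(logits: torch.Tensor, seed: Optional[int] = None) -> int:
    # Mirror the worker: reseed right before generating so sampling is repeatable.
    if seed is not None:
        set_seed(seed)
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))


logits = torch.randn(32000)  # stand-in for vocabulary-sized logits
assert sample_token(logits, seed=42) == sample_token(logits, seed=42)
```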

fastchat/utils.py

Lines changed: 15 additions & 0 deletions

@@ -302,3 +302,18 @@ def get_context_length(config):
         if val is not None:
             return int(rope_scaling_factor * val)
     return 2048
+
+
+def str_to_torch_dtype(dtype: str):
+    import torch
+
+    if dtype is None:
+        return None
+    elif dtype == "float32":
+        return torch.float32
+    elif dtype == "float16":
+        return torch.float16
+    elif dtype == "bfloat16":
+        return torch.bfloat16
+    else:
+        raise ValueError(f"Unrecognized dtype: {dtype}")
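The helper simply maps the CLI string (or None, when the flag is omitted) to the matching torch dtype:

```python
import torch
from fastchat.utils import str_to_torch_dtype

assert str_to_torch_dtype(None) is None
assert str_to_torch_dtype("float16") == torch.float16
assert str_to_torch_dtype("bfloat16") == torch.bfloat16
# Any other string, e.g. "int8", raises ValueError("Unrecognized dtype: int8").
```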
