4 changes: 2 additions & 2 deletions fastchat/llm_judge/README.md
@@ -10,7 +10,7 @@ To automate the evaluation process, we prompt strong LLMs like GPT-4 to act as j
- [Review Pre-Generated Model Answers and Judgments](#review-pre-generated-model-answers-and-judgments)
- [MT-Bench](#mt-bench)
- [Agreement Computation](#agreement-computation)
- [Dataset](#dataset)
- [Datasets](#datasets)
- [Citation](#citation)

## Install
@@ -133,7 +133,7 @@ We released 3.3K human annotations for model responses generated by 6 models in

This Colab [notebook](https://colab.research.google.com/drive/1ctgygDRJhVGUJTQy8-bRZCl1WNcT8De6?usp=sharing) shows how to compute the agreement between humans and the GPT-4 judge with the dataset. Our results show that humans and the GPT-4 judge achieve over 80% agreement, the same level of agreement as between humans.

## Dataset
## Datasets
- [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations)
- [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)

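For reference, both datasets listed above are published on the Hugging Face Hub, so a minimal loading sketch (not part of this PR) looks like the following; the split names are assumptions, and the Arena dataset may be gated, so check each dataset card before running it.

```python
# Minimal sketch for loading the two datasets above with the `datasets` package.
# Split names are assumptions; the Arena dataset may require accepting its terms on the Hub.
from datasets import load_dataset

arena = load_dataset("lmsys/chatbot_arena_conversations", split="train")
judgments = load_dataset("lmsys/mt_bench_human_judgments")

print(arena[0])    # one record
print(judgments)   # inspect the available splits
```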
42 changes: 29 additions & 13 deletions fastchat/llm_judge/gen_model_answer.py
@@ -15,6 +15,7 @@

from fastchat.llm_judge.common import load_questions, temperature_config
from fastchat.model import load_model, get_conversation_template
from fastchat.utils import str_to_torch_dtype


def run_eval(
@@ -29,6 +30,7 @@ def run_eval(
num_gpus_per_model,
num_gpus_total,
max_gpu_memory,
dtype,
):
questions = load_questions(question_file, question_begin, question_end)
# random shuffle the questions to balance the loading
@@ -45,7 +47,7 @@
else:
get_answers_func = get_model_answers

chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) // 2
chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
ans_handles = []
for i in range(0, len(questions), chunk_size):
ans_handles.append(
@@ -58,6 +60,7 @@
num_choices,
num_gpus_per_model,
max_gpu_memory,
dtype=dtype,
)
)

@@ -75,12 +78,14 @@ def get_model_answers(
num_choices,
num_gpus_per_model,
max_gpu_memory,
dtype,
):
model, tokenizer = load_model(
model_path,
device="cuda",
num_gpus=num_gpus_per_model,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=False,
cpu_offloading=False,
debug=False,
@@ -192,7 +197,9 @@ def reorg_answer_file(answer_file):
required=True,
help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument("--model-id", type=str, required=True)
parser.add_argument(
"--model-id", type=str, required=True, help="A custom name for the model."
)
parser.add_argument(
"--bench-name",
type=str,
@@ -234,6 +241,14 @@ def reorg_answer_file(answer_file):
type=str,
help="Maxmum GPU memory used for model weights per GPU.",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float32", "float16", "bfloat16"],
help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
default=None,
)

args = parser.parse_args()

if args.num_gpus_total // args.num_gpus_per_model > 1:
@@ -250,17 +265,18 @@ def reorg_answer_file(answer_file):
print(f"Output to {answer_file}")

run_eval(
args.model_path,
args.model_id,
question_file,
args.question_begin,
args.question_end,
answer_file,
args.max_new_token,
args.num_choices,
args.num_gpus_per_model,
args.num_gpus_total,
args.max_gpu_memory,
model_path=args.model_path,
model_id=args.model_id,
question_file=question_file,
question_begin=args.question_begin,
question_end=args.question_end,
answer_file=answer_file,
max_new_token=args.max_new_token,
num_choices=args.num_choices,
num_gpus_per_model=args.num_gpus_per_model,
num_gpus_total=args.num_gpus_total,
max_gpu_memory=args.max_gpu_memory,
dtype=str_to_torch_dtype(args.dtype),
)

reorg_answer_file(answer_file)
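For context, the new `--dtype` flag just converts a string to a `torch.dtype` and threads it through `run_eval` into `load_model`. Below is a hypothetical single-GPU call that mirrors the keyword arguments in the diff above; the model path, model ID, and file paths are placeholders.

```python
# Hypothetical direct call of run_eval with the new dtype argument;
# model path, model ID, and file paths are placeholders.
from fastchat.llm_judge.gen_model_answer import run_eval, reorg_answer_file
from fastchat.utils import str_to_torch_dtype

answer_file = "data/mt_bench/model_answer/my-model.jsonl"
run_eval(
    model_path="lmsys/vicuna-7b-v1.5",
    model_id="my-model",
    question_file="data/mt_bench/question.jsonl",
    question_begin=None,
    question_end=None,
    answer_file=answer_file,
    max_new_token=1024,
    num_choices=1,
    num_gpus_per_model=1,
    num_gpus_total=1,
    max_gpu_memory=None,
    dtype=str_to_torch_dtype("bfloat16"),  # None keeps the default: fp16 on GPU, fp32 on CPU
)
reorg_answer_file(answer_file)
```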
11 changes: 11 additions & 0 deletions fastchat/model/model_adapter.py
@@ -152,6 +152,7 @@ def load_model(
device: str = "cuda",
num_gpus: int = 1,
max_gpu_memory: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
@@ -275,6 +276,9 @@ def load_model(
return model, tokenizer
kwargs["revision"] = revision

if dtype is not None: # Overwrite dtype if it is provided in the arguments.
kwargs["torch_dtype"] = dtype

# Load model
model, tokenizer = adapter.load_model(model_path, kwargs)

@@ -385,6 +389,13 @@ def add_model_args(parser):
type=str,
help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float32", "float16", "bfloat16"],
help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
default=None,
)
parser.add_argument(
"--load-8bit", action="store_true", help="Use 8-bit quantization"
)
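As a usage sketch, the new `dtype` parameter of `load_model` can also be set directly from Python; the model path below is a placeholder, and the call mirrors the one in `gen_model_answer.py` above.

```python
# Sketch of an explicit dtype override when loading a model (model path is a placeholder).
import torch
from fastchat.model import load_model

model, tokenizer = load_model(
    "lmsys/vicuna-7b-v1.5",
    device="cuda",
    num_gpus=1,
    max_gpu_memory=None,
    dtype=torch.bfloat16,  # omit or pass None to keep the fp16 (GPU) / fp32 (CPU) default
    load_8bit=False,
    cpu_offloading=False,
)
```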
3 changes: 3 additions & 0 deletions fastchat/serve/cli.py
@@ -26,11 +26,13 @@
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
import torch

from fastchat.model.model_adapter import add_model_args
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.serve.inference import ChatIO, chat_loop
from fastchat.utils import str_to_torch_dtype


class SimpleChatIO(ChatIO):
@@ -208,6 +210,7 @@ def main(args):
args.device,
args.num_gpus,
args.max_gpu_memory,
str_to_torch_dtype(args.dtype),
args.load_8bit,
args.cpu_offloading,
args.conv_template,
2 changes: 2 additions & 0 deletions fastchat/serve/inference.py
@@ -288,6 +288,7 @@ def chat_loop(
device: str,
num_gpus: int,
max_gpu_memory: str,
dtype: Optional[torch.dtype],
load_8bit: bool,
cpu_offloading: bool,
conv_template: Optional[str],
@@ -309,6 +310,7 @@
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
24 changes: 22 additions & 2 deletions fastchat/serve/model_worker.py
@@ -34,6 +34,7 @@
)
import torch
import torch.nn.functional as F
from transformers import set_seed
import uvicorn

from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG
@@ -46,7 +47,12 @@
)
from fastchat.modules.gptq import GptqConfig
from fastchat.modules.awq import AWQConfig
from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length
from fastchat.utils import (
build_logger,
pretty_print_semaphore,
get_context_length,
str_to_torch_dtype,
)


worker_id = str(uuid.uuid4())[:8]
@@ -190,13 +196,15 @@ def __init__(
device: str,
num_gpus: int,
max_gpu_memory: str,
dtype: Optional[torch.dtype] = None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
stream_interval: int = 2,
conv_template: str = None,
conv_template: Optional[str] = None,
embed_in_truncate: bool = False,
seed: Optional[int] = None,
**kwargs,
):
super().__init__(
@@ -215,6 +223,7 @@ def __init__(
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
@@ -227,6 +236,7 @@ def __init__(
self.generate_stream_func = get_generate_stream_function(self.model, model_path)
self.stream_interval = stream_interval
self.embed_in_truncate = embed_in_truncate
self.seed = seed

if not no_register:
self.init_heart_beat()
@@ -235,6 +245,8 @@ def generate_stream_gate(self, params):
self.call_ct += 1

try:
if self.seed is not None:
set_seed(self.seed)
for output in self.generate_stream_func(
self.model,
self.tokenizer,
@@ -473,6 +485,12 @@ def create_model_worker():
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
parser.add_argument(
"--seed",
type=int,
default=None,
help="Overwrite the random seed for each generation.",
)
args = parser.parse_args()
logger.info(f"args: {args}")

@@ -506,13 +524,15 @@ def create_model_worker():
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
dtype=str_to_torch_dtype(args.dtype),
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
seed=args.seed,
)
return args, worker

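The new `--seed` option simply calls `transformers.set_seed` before every generation request, which makes sampled outputs reproducible across calls. A minimal illustration of that behavior in isolation (the seed value and the launch command in the comment are only examples):

```python
# Minimal illustration of the per-request seeding added in generate_stream_gate.
# Example worker launch (hypothetical model path):
#   python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --dtype bfloat16 --seed 42
from transformers import set_seed

seed = 42  # example value passed via --seed
if seed is not None:
    set_seed(seed)  # seeds Python, NumPy, and PyTorch RNGs before generating
```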
15 changes: 15 additions & 0 deletions fastchat/utils.py
@@ -302,3 +302,18 @@ def get_context_length(config):
if val is not None:
return int(rope_scaling_factor * val)
return 2048


def str_to_torch_dtype(dtype: str):
import torch

if dtype is None:
return None
elif dtype == "float32":
return torch.float32
elif dtype == "float16":
return torch.float16
elif dtype == "bfloat16":
return torch.bfloat16
else:
raise ValueError(f"Unrecognized dtype: {dtype}")
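
A quick usage check of the helper added above; the `"int8"` string is just an example of an unsupported value.

```python
# Expected behavior of str_to_torch_dtype as defined above.
import torch

from fastchat.utils import str_to_torch_dtype

assert str_to_torch_dtype(None) is None
assert str_to_torch_dtype("float16") is torch.float16
assert str_to_torch_dtype("bfloat16") is torch.bfloat16

try:
    str_to_torch_dtype("int8")  # unsupported strings raise ValueError
except ValueError as err:
    print(err)  # Unrecognized dtype: int8
```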