
Slow inference performance for large Llama models compared to naive MP #66

@sgsdxzy

Description

Inference with naive model parallelism is much faster than with tensor parallelism:

Setup: Llama-30B on 4x 2080 Ti 22 GB
Naive MP: 31.64 s
4-way TP, main branch: 177.78 s
4-way TP, llama branch: 102.22 s
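
The numbers come from wrapping model.generate() in time.time() calls (see the scripts below). For a slightly stricter comparison, a minimal sketch of a timing helper is shown here; timed_generate and the short warm-up length are my own choices, not part of the original scripts. It warms up once and synchronizes every GPU before reading the clock, so queued CUDA work is not attributed to the wrong phase.

import time
import torch

def timed_generate(model, batch, max_length=200):
    with torch.no_grad():
        # Warm-up pass so CUDA kernels and allocator caches are initialized before timing
        model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=32)
        # Drain queued work on every GPU before starting the clock
        for i in range(torch.cuda.device_count()):
            torch.cuda.synchronize(i)
        t0 = time.time()
        out = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=max_length)
        for i in range(torch.cuda.device_count()):
            torch.cuda.synchronize(i)
    return out, time.time() - t0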

The code for naive inference:

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'models/llama-30b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="balanced" lets accelerate place whole layers on each of the 4 GPUs
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half, device_map="balanced")

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))
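
Side note: with device_map="balanced", accelerate assigns whole decoder layers to individual GPUs, i.e. plain pipeline-style model parallelism with no intra-layer splitting. A quick sketch to confirm the placement (hf_device_map is populated by from_pretrained whenever a device_map is passed):

# Print the layer-to-device assignment that accelerate produced
for module_name, device in model.hf_device_map.items():
    print(f"{module_name} -> {device}")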

The code for TP:

import json
import time

import accelerate
import torch
import tensor_parallel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

model_name = 'models/llama-30b'

tokenizer = AutoTokenizer.from_pretrained(model_name)
with accelerate.init_empty_weights():
    model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)).half()
    model = tensor_parallel.TensorParallelPreTrainedModel(model)

device_map = tensor_parallel.infer_sharded_device_map(model)  # <- the model is on the meta device, but this helper
                                                               #    can still deduce the target device for each weight

# Collect the checkpoint shard filenames from the index
with open(f"{model_name}/pytorch_model.bin.index.json", "r") as index_file:
    shard_filenames = set(json.load(index_file)["weight_map"].values())

for shard_filename in sorted(shard_filenames):
    # Path to a shard of the original checkpoint
    shard_path = f"{model_name}/{shard_filename}"
    print(shard_path)

    # Convert the shard using the tensor_parallel helper function:
    # it creates a tensor_parallel state dict from a regular one
    converted_state_dict = tensor_parallel.convert_state_dict(
        torch.load(shard_path),
        model.tensor_parallel_config,
        world_size=4,
        for_pretrained=True,
    )
    torch.save(converted_state_dict, "/tmp/shard.bin")
    del converted_state_dict

    # Dispatch the shard
    accelerate.load_checkpoint_in_model(
        model,
        checkpoint="/tmp/shard.bin",
        device_map=device_map,
    )

torch.cuda.empty_cache()
model = model.eval()
with torch.no_grad():
    batch = tokenizer(
        "DeepSpeed first included offloading capabilities with ZeRO-Offload, a system ",
        return_tensors="pt"
    )
    batch = {k: v.cuda(0) for k, v in batch.items()}
    print("Start")
    t0 = time.time()
    generated = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)
    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds")
    print(tokenizer.decode(generated[0]))
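
To see where the extra time goes in the tensor-parallel run (for example, whether cross-GPU communication dominates over compute), one option is a short pass under torch.profiler. A minimal sketch, assuming the model and batch from the script above are already set up:

from torch.profiler import profile, ProfilerActivity

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=200)

# Sort by total CUDA time to see which ops dominate the run
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))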
