Skip to content

[MODEL] add Exaone model support #7819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 30, 2024
Merged

Conversation

nayohan
Copy link
Contributor

@nayohan nayohan commented Aug 23, 2024

Recently, The new model exaone released. I would love to contribute the new model to vLLM as well.

In this PR, I have provided the implementation of EXAONE-3.0 model and add model configs.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 23, 2024
@shing100
Copy link
Contributor

#7236

@nayohan nayohan force-pushed the add_exaone branch 2 times, most recently from c857d59 to 449aa30 Compare August 28, 2024 01:54
@nayohan
Copy link
Contributor Author

nayohan commented Aug 28, 2024

Checked the ruff format and fixed the code.

@nayohan
Copy link
Contributor Author

nayohan commented Aug 28, 2024

Solve #7236

Summary of changes

  • Add ExaoneModel

    • vllm/model_executor/models/init.py
    • vllm/model_executor/models/exaone.py
  • Add ExaoneConfig

    • vllm/transformers_utils/config.py
    • vllm/transformers_utils/configs/init.py. (New)
    • vllm/transformers_utils/configs/exaone.py (New)

Test Result

Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from vllm import LLM, SamplingParams
WARNING 08-28 11:10:49 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead, and make sure to uninstall `pynvml`. When both of them are installed, `pynvml` will take precedence and cause errors. See https://pypi.org/project/pynvml for more information.
>>> model = LLM("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", download_dir="/data/project/yohan/98_model")
INFO 08-28 11:11:56 config.py:1610] Downcasting torch.float32 to torch.float16.
INFO 08-28 11:11:56 llm_engine.py:210] Initializing an LLM engine (v0.5.5) with config: model='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', speculative_config=None, tokenizer='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir='/data/project/yohan/98_model', load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct, use_v2_block_manager=False, num_scheduler_steps=1, enable_prefix_caching=False, use_async_output_proc=True)
INFO 08-28 11:11:57 model_runner.py:906] Starting to load model LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct...
INFO 08-28 11:11:58 weight_utils.py:236] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  14% Completed | 1/7 [00:02<00:13,  2.33s/it]
Loading safetensors checkpoint shards:  29% Completed | 2/7 [00:05<00:14,  2.85s/it]
Loading safetensors checkpoint shards:  43% Completed | 3/7 [00:08<00:12,  3.01s/it]
Loading safetensors checkpoint shards:  57% Completed | 4/7 [00:11<00:08,  2.98s/it]
Loading safetensors checkpoint shards:  71% Completed | 5/7 [00:14<00:06,  3.05s/it]
Loading safetensors checkpoint shards:  86% Completed | 6/7 [00:17<00:02,  2.97s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:18<00:00,  2.12s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:18<00:00,  2.58s/it]

INFO 08-28 11:12:16 model_runner.py:917] Loading model weights took 14.5640 GB
INFO 08-28 11:12:17 gpu_executor.py:121] # GPU blocks: 10030, # CPU blocks: 2048
INFO 08-28 11:12:20 model_runner.py:1212] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 08-28 11:12:20 model_runner.py:1216] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 08-28 11:12:29 model_runner.py:1331] Graph capturing finished in 9 secs.
>>> model.generate("Hello!")
Processed prompts: 100%|█████████████████████████████| 1/1 [00:00<00:00,  3.76it/s, est. speed input: 7.52 toks/s, output: 60.14 toks/s]
[RequestOutput(request_id=0, prompt='Hello!', prompt_token_ids=[33381, 362], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=" It looks like you're interested in understanding the MMa MparamItem and", token_ids=array('l', [1533, 7589, 1664, 904, 368, 628, 9124, 666, 6835, 629, 13995, 426, 852, 23219, 9314, 686]), cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1724843556.9958072, last_token_time=1724843556.9958072, first_scheduled_time=1724843557.0000691, first_token_time=1724843557.0201893, time_in_queue=0.004261970520019531, finished_time=1724843557.2531652, scheduler_time=0.001150771975517273, model_forward_time=None, model_execute_time=None), lora_request=None)]

@nayohan
Copy link
Contributor Author

nayohan commented Aug 28, 2024

Here is benchmark result with A100 40GB * 2. (tensor-parallel-size 2)

git clone https://github.com/nayohan/vllm
cd vllm
pip install -e . 

# vllm 0.5.5+cu124   /data/project/yohan/01_project/vllm

python3 benchmark_throughput.py --model LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct --max_model_len 4096 --tensor-parallel-size 2 --gpu-memory-utilization 0.95 --dataset "ShareGPT_V3_unfiltered_cleaned_split.json" --output_json o
Throuhput benchmark result (Click to Expand)
python3 benchmark_throughput.py --model LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct --max_model_len 4096 --tensor-parallel-size 2 --gpu-memory-utilization 0.95 --dataset "ShareGPT_V3_unfiltered_cleaned_split.json" --output_json o
WARNING 08-28 10:31:54 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead, and make sure to uninstall `pynvml`. When both of them are installed, `pynvml` will take precedence and cause errors. See https://pypi.org/project/pynvml for more information.
Namespace(backend='vllm', dataset='ShareGPT_V3_unfiltered_cleaned_split.json', input_len=None, output_len=None, model='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', tokenizer=None, quantization=None, tensor_parallel_size=2, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False, max_model_len=4096, dtype='auto', gpu_memory_utilization=0.95, enforce_eager=False, kv_cache_dtype='auto', quantization_param_path=None, device='auto', num_scheduler_steps=1, use_v2_block_manager=False, enable_prefix_caching=False, enable_chunked_prefill=False, max_num_batched_tokens=None, download_dir='/data/project/yohan/98_model', output_json='o', distributed_executor_backend=None, load_format='auto')
Namespace(backend='vllm', dataset='ShareGPT_V3_unfiltered_cleaned_split.json', input_len=None, output_len=None, model='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', tokenizer='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', quantization=None, tensor_parallel_size=2, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False, max_model_len=4096, dtype='auto', gpu_memory_utilization=0.95, enforce_eager=False, kv_cache_dtype='auto', quantization_param_path=None, device='auto', num_scheduler_steps=1, use_v2_block_manager=False, enable_prefix_caching=False, enable_chunked_prefill=False, max_num_batched_tokens=None, download_dir='/data/project/yohan/98_model', output_json='o', distributed_executor_backend=None, load_format='auto')
INFO 08-28 10:32:03 config.py:1610] Downcasting torch.float32 to torch.float16.
INFO 08-28 10:32:03 config.py:864] Defaulting to use mp for distributed inference
INFO 08-28 10:32:03 llm_engine.py:210] Initializing an LLM engine (v0.5.5) with config: model='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', speculative_config=None, tokenizer='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir='/data/project/yohan/98_model', load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct, use_v2_block_manager=False, num_scheduler_steps=1, enable_prefix_caching=False, use_async_output_proc=True)
WARNING 08-28 10:32:03 multiproc_gpu_executor.py:55] Reducing Torch parallelism from 255 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 08-28 10:32:03 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
(VllmWorkerProcess pid=2175918) INFO 08-28 10:32:04 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
INFO 08-28 10:32:04 utils.py:976] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=2175918) INFO 08-28 10:32:04 utils.py:976] Found nccl from library libnccl.so.2
INFO 08-28 10:32:04 pynccl.py:63] vLLM is using nccl==2.20.5
(VllmWorkerProcess pid=2175918) INFO 08-28 10:32:04 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 08-28 10:32:05 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorkerProcess pid=2175918) INFO 08-28 10:32:05 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 08-28 10:32:05 shm_broadcast.py:235] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7f33b56855d0>, local_subscribe_port=46443, remote_subscribe_port=None)
INFO 08-28 10:32:05 model_runner.py:906] Starting to load model LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct...
(VllmWorkerProcess pid=2175918) INFO 08-28 10:32:05 model_runner.py:906] Starting to load model LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct...
(VllmWorkerProcess pid=2175918) INFO 08-28 10:32:05 weight_utils.py:236] Using model weights format ['*.safetensors']
INFO 08-28 10:32:05 weight_utils.py:236] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  14% Completed | 1/7 [00:05<00:34,  5.73s/it]
Loading safetensors checkpoint shards:  29% Completed | 2/7 [00:13<00:33,  6.66s/it]
Loading safetensors checkpoint shards:  43% Completed | 3/7 [00:20<00:27,  6.99s/it]
Loading safetensors checkpoint shards:  57% Completed | 4/7 [00:27<00:20,  6.89s/it]
Loading safetensors checkpoint shards:  71% Completed | 5/7 [00:34<00:13,  6.99s/it]
Loading safetensors checkpoint shards:  86% Completed | 6/7 [00:41<00:06,  7.00s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:42<00:00,  5.24s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:42<00:00,  6.14s/it]

INFO 08-28 10:32:49 model_runner.py:917] Loading model weights took 7.2827 GB
(VllmWorkerProcess pid=2175918) INFO 08-28 10:32:49 model_runner.py:917] Loading model weights took 7.2827 GB
INFO 08-28 10:32:50 distributed_gpu_executor.py:56] # GPU blocks: 29314, # CPU blocks: 4096
INFO 08-28 10:32:52 model_runner.py:1212] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 08-28 10:32:52 model_runner.py:1216] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=2175918) INFO 08-28 10:32:52 model_runner.py:1212] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(VllmWorkerProcess pid=2175918) INFO 08-28 10:32:52 model_runner.py:1216] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=2175918) INFO 08-28 10:33:10 custom_all_reduce.py:223] Registering 2275 cuda graph addresses
INFO 08-28 10:33:10 custom_all_reduce.py:223] Registering 2275 cuda graph addresses
(VllmWorkerProcess pid=2175918) INFO 08-28 10:33:10 model_runner.py:1331] Graph capturing finished in 18 secs.
INFO 08-28 10:33:10 model_runner.py:1331] Graph capturing finished in 18 secs.
Processed prompts: 100%|████████████████| 1000/1000 [01:10<00:00, 14.12it/s, est. speed input: 3349.74 toks/s, output: 3108.96 toks/s]
Throughput: 13.87 requests/s, 6343.01 tokens/s

Throughput: 13.87 requests/s, 6343.01 tokens/s

@nayohan
Copy link
Contributor Author

nayohan commented Aug 28, 2024

Here is benchmark result with A100 40GB * 1. (--quantization fp8)

CUDA_VISIBLE_DEVICES=0 python3 benchmark_throughput.py --model LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct --max_model_len 4096 --tensor-parallel-size 1 --gpu-memory-utilization 0.95 --dataset "ShareGPT_V3_unfiltered_cleaned_split.json" --output_json o --quantization fp8
Throuhput benchmark result
CUDA_VISIBLE_DEVICES=0 python3 benchmark_throughput.py --model LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct --max_model_len 4096 --tensor-parallel-size 1 --gpu-memory-utilization 0.95 --dataset "ShareGPT_V3_unfiltered_cleaned_split.json" --output_json o --quantization fp8
 WARNING 08-28 10:47:22 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead, and make sure to uninstall `pynvml`. When both of them are installed, `pynvml` will take precedence and cause errors. See https://pypi.org/project/pynvml for more information.
Namespace(backend='vllm', dataset='ShareGPT_V3_unfiltered_cleaned_split.json', input_len=None, output_len=None, model='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', tokenizer=None, quantization='fp8', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False, max_model_len=4096, dtype='auto', gpu_memory_utilization=0.95, enforce_eager=False, kv_cache_dtype='auto', quantization_param_path=None, device='auto', num_scheduler_steps=1, use_v2_block_manager=False, enable_prefix_caching=False, enable_chunked_prefill=False, max_num_batched_tokens=None, download_dir='/data/project/yohan/98_model', output_json='o', distributed_executor_backend=None, load_format='auto')
Namespace(backend='vllm', dataset='ShareGPT_V3_unfiltered_cleaned_split.json', input_len=None, output_len=None, model='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', tokenizer='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', quantization='fp8', tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False, max_model_len=4096, dtype='auto', gpu_memory_utilization=0.95, enforce_eager=False, kv_cache_dtype='auto', quantization_param_path=None, device='auto', num_scheduler_steps=1, use_v2_block_manager=False, enable_prefix_caching=False, enable_chunked_prefill=False, max_num_batched_tokens=None, download_dir='/data/project/yohan/98_model', output_json='o', distributed_executor_backend=None, load_format='auto')
INFO 08-28 10:47:30 config.py:1610] Downcasting torch.float32 to torch.float16.
INFO 08-28 10:47:30 llm_engine.py:210] Initializing an LLM engine (v0.5.5) with config: model='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', speculative_config=None, tokenizer='LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir='/data/project/yohan/98_model', load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct, use_v2_block_manager=False, num_scheduler_steps=1, enable_prefix_caching=False, use_async_output_proc=True)
INFO 08-28 10:47:31 model_runner.py:906] Starting to load model LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct...
INFO 08-28 10:47:32 weight_utils.py:236] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  14% Completed | 1/7 [00:02<00:14,  2.36s/it]
Loading safetensors checkpoint shards:  29% Completed | 2/7 [00:05<00:14,  2.92s/it]
Loading safetensors checkpoint shards:  43% Completed | 3/7 [00:09<00:12,  3.15s/it]
Loading safetensors checkpoint shards:  57% Completed | 4/7 [00:12<00:09,  3.16s/it]
Loading safetensors checkpoint shards:  71% Completed | 5/7 [00:15<00:06,  3.21s/it]
Loading safetensors checkpoint shards:  86% Completed | 6/7 [00:18<00:03,  3.18s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:19<00:00,  2.29s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:19<00:00,  2.74s/it]

WARNING 08-28 10:47:51 utils.py:722] Your GPU does not have native support for FP8 computation but FP8 quantization is being used. Weight-only FP8 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
INFO 08-28 10:47:52 model_runner.py:917] Loading model weights took 8.0678 GB
INFO 08-28 10:47:53 gpu_executor.py:121] # GPU blocks: 14181, # CPU blocks: 2048
INFO 08-28 10:47:55 model_runner.py:1212] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 08-28 10:47:55 model_runner.py:1216] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 08-28 10:48:06 model_runner.py:1331] Graph capturing finished in 11 secs.
Processed prompts: 100%|██████████████████| 1000/1000 [01:18<00:00, 12.71it/s, est. speed input: 3014.60 toks/s, output: 2797.91 toks/s]
Throughput: 12.60 requests/s, 5765.20 tokens/s

Throughput: 12.60 requests/s, 5765.20 tokens/s

@nayohan
Copy link
Contributor Author

nayohan commented Aug 28, 2024

I checked the other PR (#6611 , #7615 ) to add and added the code.
After completing all the work, I tested it in multi-gpu environment and quantization.

please let me know if there is anything missing that should be added. I'll update it. @mgoin @simon-mo

@nayohan
Copy link
Contributor Author

nayohan commented Aug 28, 2024

/ready

Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what to "choose" between this and the other PR #7942, but this one does have the README update and also gets a good accuracy score, so I am accepting this one.

lm_eval --model vllm --model_args pretrained=LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct,max_model_len=4096,enable_chunked_prefill=True --tasks gsm8k --batch_size auto
vllm (pretrained=LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct,max_model_len=4096,enable_chunked_prefill=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8044|±  |0.0109|
|     |       |strict-match    |     5|exact_match|↑  |0.8021|±  |0.0110|

@simon-mo
Copy link
Collaborator

Sorry before we merge, a common question we ask is how is this different from llama implementation, and why can't the existing llama implementation run it. For example, we have Mistral, InternLMForCausalLM, and AquilaForCausalLM all mapped directly to llama.py

@nayohan
Copy link
Contributor Author

nayohan commented Aug 29, 2024

Thank you for accepting this PR!

I'll add to @Deepfocused's partial answer to your question and explain the whole change. (#7942 (comment))

The Exaone3 model is Llama based code, but when it pre-trained the model from scratch, it changed the tokenizer and changed some model configs such as keys and values.

The changes are as follows:

  1. model State_dict key changed. Compared to llama3, there are some key changes for each layer.
model.embed_tokens.weight -> transformer.wte.weight
model.layers.0.input_layernorm.weight -> transformer.h.0.ln_1.weight
model.layers.0.self_attn.o_proj.weight -> transformer.h.0.attn.attention.out_proj.weight 
model.layers.0.mlp.gate_proj.weight -> transformer.h.0.mlp.c_fc_0.weight
model.layers.0.mlp.up_proj.weight -> transformer.h.0.mlp.c_fc_1.weight
model.layers.0.mlp.down_proj.weight -> transformer.h.0.mlp.c_proj.weight
model.layers.0.post_attention_layernorm.weight -> transformer.h.0.ln_2.weight
model.norm.weight -> transformer.ln_f.weight
  1. model config key changed. There are some key changes in other parts.
hidden_act -> activation_function
num_hidden_layers -> num_layers	
rms_norm_eps -> layer_norm_epsilon

These two differences make it unlikely that a directly mapping to llama.py would be applicable. If there is another way to map it, please leave a reference PR. I'll update the code.

(While off-topic, It's nice to have a convenient way to evaluate performance using lm_eval. If I do a new model PR in the future, I will include performance evaluation results. Thanks for letting me know!)

@simon-mo simon-mo merged commit dc13e99 into vllm-project:main Aug 30, 2024
36 checks passed
@DarkLight1337
Copy link
Member

DarkLight1337 commented Aug 30, 2024

There appears to be some incompatibilities between the HF model file and the current version of vLLM. It is causing the CI to fail.

Update: It's just the modeling file inside vLLM that's broken. I'll open a PR to fix it.

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants