Skip to content

[Neuron] Adding support for adding/ overriding neuron configuration a… #8062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 4, 2024

Conversation

hbikki
Copy link
Contributor

@hbikki hbikki commented Sep 1, 2024

…nd adding support for neuron model quantization configuration.

  1. Quantization : Added support for int8 quantization for the neuron models, this is achieved by passing the quant config as a param tot he neuron config, the NEURON_QUANT_DTYPE helps us customize the dtype for the quantization.

  2. Neuron Config: Update the default neuron config to support the latest neuron release's optimizations out of the box.

  3. Dynamic Config update/Override : Provided support to override or pass neuron config from the engine args -> ModelConfig -> NeuronConfig , this will minimize the updates to vllm and provides more customizable configuration for clients.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

github-actions bot commented Sep 1, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@hbikki hbikki force-pushed the neuron-config branch 4 times, most recently from e301a9f to f1688eb Compare September 2, 2024 00:24
@hbikki
Copy link
Contributor Author

hbikki commented Sep 2, 2024

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 2, 2024
@WoosukKwon WoosukKwon added the aws-neuron Related to AWS Inferentia & Trainium label Sep 2, 2024
Copy link
Contributor

@liangfu liangfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @hbikki for contribuing. The proposed change looks good overall. I left some comments.

Comment on lines 160 to 158
"quant":
neuron_quantization_config_builder(model_config.quantization)
if model_config.quantization else None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we will need weight_tiling flag, when quantization being enabled.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep added weight_tiling, thanks!

dequant_dtype: str = "f16",
quantize_method: str = "vector_dynamic",
) -> None:
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this configurable from existing vLLM config? can we pass existing vLLM config ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the datatype we use for quantization in neuron doesn't match the torch d types or existing quant , so created a new quant config to support it.

Copy link
Contributor

@liangfu liangfu Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we could map torch dtype to neuron dtype, we should be able to avoid inventing the environment variable specifically for neuron backend.
for instance, s8 can be translated to torch.int8, and f8e4m3fn and be translated to torch.float8_e4m3fn.
ideally, we should be able to configure quantization data type with existing vllm config.

help map the relation between neuron dtype and torch dtype?

def get_scaled_act_names(self) -> List[str]:
return []

def get_tnx_quantization_config(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

avoid tnx terminology? it will be confusing what is tnx.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it , as the quant config is specific to tnx, I can do a tnx package check before using the config , any thoughts ?

neuron_quantization_config_builder = lambda quant: get_quantization_config(
quant).from_config(quant_config).get_quant_method(None, "")
# TODO: Add Paged attention config to the default neuron arguments.
default_neuron_args = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1/ Can we use layout enum instead of strings (e.g. BSH)?
2/ Can we build dict with assignment statements?

This would be something like the following:

default_neuron_config_dict = dict(
fused_qkv=True, 
attention_layout=constants.Layout.BSH,
...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated for better readability, thanks!

overridden_neuron_config):
from transformers_neuronx.config import NeuronConfig
overridden_neuron_config = overridden_neuron_config or {}
combined_config = {**default_neuron_config, **overridden_neuron_config}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if there is an overlap?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overridden config replaces the default neuron config

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

example : when use use this as override config:

override_neuron_config={"cast_logits_dtype":"bfloat16", "collectives_layout":"Dummy"}

we get this result :

(Pdb) combined_config
{'collectives_layout': 'Dummy', 'attention_layout': 'BSH', 'fuse_qkv': True, 'quant': <transformers_neuronx.config.QuantizationConfig object at 0x7f259835a2f0>, 'continuous_batching': <transformers_neuronx.config.ContinuousBatchingConfig object at 0x7f259831d750>, 'weight_tiling': True, 'cast_logits_dtype': 'bfloat16'}

Copy link
Contributor

@liangfu liangfu Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for this line

combined_config = {**default_neuron_config, **overridden_neuron_config}

there would be an error when we override like this:

combined_config={"collectives_layout":"BSH", "collectives_layout":"HSB"}

right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no the dict supports only one key with the name , the second collectives_layout will overwrite the first one ,
Based on our offline sync update to .update() for readability.

@hbikki hbikki force-pushed the neuron-config branch 2 times, most recently from 4b6a00b to f4ac863 Compare September 4, 2024 17:46
quant=neuron_quantization_config_builder(model_config.quantization)
if model_config.quantization else None,
continuous_batching=continuous_batching_config,
weight_tiling=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might not want weight_tiling enabled, when weight quantization is not enabled.
pair weight_tiling with weight quantization config ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I couldn't find an env variable to enable it dynamically , so paired it with qunat

@@ -119,19 +136,51 @@ def _get_buckets(env: str, default_value: List[int]) -> List[int]:
return buckets_list


def _get_default_tnx_neuron_config(model_config: ModelConfig,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also avoid tnx terminology here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

…nd adding support for neuron model quantization configuration.
Copy link
Contributor

@liangfu liangfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @hbikki for the update. It looks good.

I think we can leave the dtype mapping between neuron-specific and torch native data types as a follow-up.

@simon-mo simon-mo merged commit 008cf88 into vllm-project:main Sep 4, 2024
53 checks passed
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
aws-neuron Related to AWS Inferentia & Trainium ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants