Skip to content

Conversation

mehdiataei
Copy link
Contributor

Category

  • New feature
  • Bugfix
  • Breaking change
  • Refactoring
  • Documentation
  • Other (please explain)

Description

  1. In warp/jax_experimental.py:
  • Modified the jax_kernel function to accept an optional launch_dims parameter.
  • Updated the abstract evaluation and lowering functions to use the provided launch_dims when available.
  • Changed jax.devices() to jax.local_devices() in the _get_jax_device function (this was a bug, as in multi-GPU settings Warp may select non-addressable devices).
  1. In warp/tests/test_jax.py:
    Added a new test case test_jax_kernel_launch_dims to verify the functionality of custom launch dimensions for both 1D and 2D kernels.
  2. In docs/modules/interoperability.rst:
  • Removed the limitation that output shapes must match launch dimensions given the new feature.
  • Added a new section on using shardmap for distributed multi-GPU computation with Warp and JAX.
  • Added a section on specifying launch dimensions for multi-GPU matrix operations.

Changelog

  • jax_kernel now accepts an optional launch_dims parameter. The launch dim is no longer limited the the shape of the first input.
  • Changed device selection from jax.devices() to jax.local_devices() to address multi-GPU launch issues.
  • Added tutorials on using JAX's shardmap for multi-GPU computations and specifying custom launch dimensions.

Before your PR is "Ready for review"

  • Do you agree to the terms under which contributions are accepted as described in Section 9 the Warp License?
  • Have you read the Contributor Guidelines?
  • Have you written any new necessary tests?
  • Have you added or updated any necessary documentation?
  • Have you added any files modified by compiling Warp and building the documentation to this PR (.e.g. stubs.py, functions.rst)?
  • Does your code pass ruff check and ruff format --check?

@shi-eric shi-eric requested a review from nvlukasz September 9, 2024 21:07
Copy link
Contributor

@nvlukasz nvlukasz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great and I really appreciate the documentation updates!

@shi-eric shi-eric merged commit 2b3a7c8 into NVIDIA:main Sep 17, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants