Merged
Changes from 27 commits
Commits
40 commits
db10f80
tmp
stephanie-wang May 23, 2025
81eeebc
Working basic test
stephanie-wang May 24, 2025
2f4ea06
tests and group
stephanie-wang May 26, 2025
e31ec81
doc
stephanie-wang May 26, 2025
4393996
lint
stephanie-wang May 26, 2025
e866300
lint
stephanie-wang May 27, 2025
541b54a
test
stephanie-wang May 27, 2025
91da8fc
more tests
stephanie-wang May 27, 2025
7fcaeac
lint
stephanie-wang May 29, 2025
0ce186e
Update python/ray/experimental/collective/util.py
stephanie-wang Jun 3, 2025
f909028
Update python/ray/experimental/collective/communicator.py
stephanie-wang Jun 3, 2025
0fbf8fb
Update python/ray/experimental/collective/collective.py
stephanie-wang Jun 3, 2025
7cc982c
Update python/ray/experimental/collective/collective.py
stephanie-wang Jun 3, 2025
ee69c3d
cleanup
stephanie-wang Jun 3, 2025
8a84643
update
stephanie-wang Jun 3, 2025
32b767a
Merge commit '2ff7298b1a69ea68b0c51a8036acacf147dc8cdf' into gpu-obje…
stephanie-wang Jun 4, 2025
aa7dac9
Unit tests work now
stephanie-wang Jun 4, 2025
2594dc1
Specify backend
stephanie-wang Jun 6, 2025
681704c
Allocate on correct device
stephanie-wang Jun 6, 2025
a72f78c
GPU test
stephanie-wang Jun 6, 2025
cff9244
Merge remote-tracking branch 'upstream/master' into gpu-object-collec…
stephanie-wang Jun 10, 2025
083a3fa
doc
stephanie-wang Jun 10, 2025
a30ea59
more docs
stephanie-wang Jun 10, 2025
909d298
lint
stephanie-wang Jun 10, 2025
ce69966
doc
stephanie-wang Jun 10, 2025
edb45ae
test
stephanie-wang Jun 10, 2025
253cfce
comment
stephanie-wang Jun 10, 2025
c4091bf
fix, lint
stephanie-wang Jun 11, 2025
10e6624
fix and lint
stephanie-wang Jun 11, 2025
309558c
lint
stephanie-wang Jun 12, 2025
5e2ac05
lint
stephanie-wang Jun 12, 2025
d8f457c
avoid torch import
stephanie-wang Jun 12, 2025
816baa5
lint
stephanie-wang Jun 12, 2025
0e751d2
lint
stephanie-wang Jun 12, 2025
88d80be
Merge remote-tracking branch 'upstream/master' into gpu-object-collec…
stephanie-wang Jun 12, 2025
4ca5195
fix imports
stephanie-wang Jun 13, 2025
42ba38a
ignore
stephanie-wang Jun 13, 2025
384fb46
Merge remote-tracking branch 'upstream/master' into gpu-object-collec…
stephanie-wang Jun 16, 2025
a7a79cb
fix
stephanie-wang Jun 16, 2025
2bfccdf
fix test
stephanie-wang Jun 16, 2025
26 changes: 15 additions & 11 deletions python/ray/_private/custom_types.py
@@ -1,9 +1,7 @@
from enum import Enum
from typing import Literal

from ray.core.generated.common_pb2 import (
    GLOO,
    NCCL,
    OBJECT_STORE,
    ErrorType,
    Language,
    TaskStatus,
@@ -122,13 +120,19 @@
LANGUAGE = ["PYTHON", "JAVA", "CPP"]

# See `common.proto` for more details.
TENSOR_TRANSPORT = [
    "OBJECT_STORE",
    "NCCL",
    "GLOO",
]
TypeTensorTransport = Literal[tuple(TENSOR_TRANSPORT)]
TypeTensorTransportEnum = Literal[OBJECT_STORE, NCCL, GLOO]
class TensorTransportEnum(Enum):
    OBJECT_STORE = TensorTransport.Value("OBJECT_STORE")
    NCCL = TensorTransport.Value("NCCL")
    GLOO = TensorTransport.Value("GLOO")

    @classmethod
    def from_str(cls, name: str) -> "TensorTransportEnum":
        name = name.upper()
        if name not in cls.__members__:
            raise ValueError(
                f"Invalid tensor transport {name}, must be one of {list(cls.__members__.keys())}."
            )
        return cls[name]


def validate_protobuf_enum(grpc_enum, custom_enum):
@@ -157,4 +161,4 @@ def validate_protobuf_enum(grpc_enum, custom_enum):
validate_protobuf_enum(TaskType, TASK_TYPE)
validate_protobuf_enum(ErrorType, ERROR_TYPE)
validate_protobuf_enum(Language, LANGUAGE)
validate_protobuf_enum(TensorTransport, TENSOR_TRANSPORT)
validate_protobuf_enum(TensorTransport, list(TensorTransportEnum.__members__.keys()))
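
For illustration, a minimal sketch of how the new enum behaves, inferred from the code above (only the import path `ray._private.custom_types` and the member names come from this diff; the rest follows from `from_str` and the protobuf-backed values):

# Sketch based on the diff above: TensorTransportEnum mirrors the protobuf
# TensorTransport values and adds a case-insensitive string lookup.
from ray._private.custom_types import TensorTransportEnum

assert TensorTransportEnum.from_str("nccl") is TensorTransportEnum.NCCL
assert TensorTransportEnum.from_str("OBJECT_STORE") is TensorTransportEnum.OBJECT_STORE

try:
    TensorTransportEnum.from_str("infiniband")  # not a member
except ValueError as e:
    # "Invalid tensor transport INFINIBAND, must be one of
    #  ['OBJECT_STORE', 'NCCL', 'GLOO']."
    print(e)

Note that `validate_protobuf_enum` now checks the enum member names against the protobuf definition, so the Python enum cannot silently drift from `common.proto`.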
138 changes: 93 additions & 45 deletions python/ray/_private/gpu_object_manager.py
@@ -1,16 +1,37 @@
from collections import namedtuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple

import torch

from ray._private.custom_types import TensorTransportEnum
from ray._raylet import ObjectRef
from ray.actor import ActorHandle
from ray.util.collective.types import Backend

if TYPE_CHECKING:
    import torch

# GPUObjectMeta is a named tuple containing the source actor and tensor metadata.
# The tensor metadata is a list of tuples, each containing the shape and dtype
# of a tensor in the GPU object store.
GPUObjectMeta = namedtuple("GPUObjectMeta", ["src_actor", "tensor_meta"])
TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND = {
    TensorTransportEnum.NCCL: Backend.NCCL,
    TensorTransportEnum.GLOO: Backend.TORCH_GLOO,
}

COLLECTIVE_BACKEND_TO_TORCH_DEVICE = {
    Backend.NCCL: torch.device("cuda"),
    Backend.TORCH_GLOO: torch.device("cpu"),
}

# GPUObjectMeta is a named tuple containing the source actor, tensor transport
# backend, and tensor metadata.
# - The tensor transport backend is the backend used to transport the tensors.
#   Currently, the supported backends are "nccl" and "torch_gloo".
# - The tensor metadata is a list of tuples, each containing the shape and dtype
#   of a tensor in the GPU object store.
class GPUObjectMeta(NamedTuple):
    src_actor: ActorHandle
    # Must be a valid backend name as defined in
    # `ray.util.collective.types.Backend`.
    tensor_transport_backend: str
    tensor_meta: List[Tuple["torch.Size", "torch.dtype"]]


class GPUObjectManager:
@@ -55,14 +76,27 @@ def __ray_get_tensor_meta__(self, obj_id: str):

        return src_actor.__ray_call__.remote(__ray_get_tensor_meta__, obj_id)

    def add_gpu_object_ref(self, obj_ref: ObjectRef, src_actor: ActorHandle):
        # `obj_ref` is an ObjectRef generated by the `src_actor`'s actor task
        # that is annotated with `@ray.method(tensor_transport=...)`. This function
        # adds the `obj_ref` to the `gpu_object_refs` dictionary so that the coordinator
        # process can determine whether the `obj_ref` is a GPU object reference or not.
    def add_gpu_object_ref(self, obj_ref: ObjectRef, src_actor: ActorHandle, tensor_transport: TensorTransportEnum):
        """Add a GPU object reference to the GPU object manager. This should be
        called whenever the current process calls a task that is annotated with
        `@ray.method(tensor_transport=...)`.

        Args:
            obj_ref: The ObjectRef of the task output.
            src_actor: The actor that executes the task and that creates the GPU object.
            tensor_transport: The tensor transport protocol to use for the GPU object.
        """
        try:
            tensor_transport_backend = TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND[tensor_transport]
        except KeyError:
            raise ValueError(
                f"Invalid tensor transport {tensor_transport.name}, must be one of {list(TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND.keys())}."
            )
        tensor_meta = self._get_tensor_meta(src_actor, obj_ref.hex())
        self.gpu_object_refs[obj_ref] = GPUObjectMeta(
            src_actor=src_actor, tensor_meta=tensor_meta
            src_actor=src_actor,
            tensor_transport_backend=tensor_transport_backend,
            tensor_meta=tensor_meta
        )

    # TODO(kevin85421): Call this function to remove the `obj_ref` from the `gpu_object_refs` dictionary
@@ -76,30 +110,38 @@ def _get_gpu_object_ref(self, obj_ref: ObjectRef) -> Optional[GPUObjectMeta]:
    def _is_gpu_object_ref(self, obj_ref: ObjectRef) -> bool:
        return obj_ref in self.gpu_object_refs

    def _send_gpu_object(self, src_actor: ActorHandle, obj_id: str, dst_rank: int):
    def _send_gpu_object(self, communicator_name: str, src_actor: ActorHandle, obj_id: str, dst_rank: int):
        # Send tensors stored in the `src_actor`'s GPU object store to the
        # destination rank `dst_rank`.
        def __ray_send__(self, obj_id: str, dst_rank: int):
            import torch.distributed as dist

        def __ray_send__(self, communicator_name: str, obj_id: str, dst_rank: int):
            import ray.util.collective as collective
            from ray._private.worker import global_worker

            gpu_object_manager = global_worker.gpu_object_manager
            assert gpu_object_manager.has_gpu_object(
                obj_id
            ), f"obj_id={obj_id} not found in GPU object store"
            tensors = gpu_object_manager.get_gpu_object(obj_id)

            backend = collective.get_group_handle(communicator_name).backend()
            device = COLLECTIVE_BACKEND_TO_TORCH_DEVICE[backend]

            for tensor in tensors:
                dist.send(tensor, dst_rank)
                if tensor.device.type != device.type:
                    # TODO(swang): Right now there is no way to catch this error
                    # and the receiving Ray task will hang.
                    raise ValueError(f"tensor device {tensor.device} does not match device {device}")
                collective.send(tensor, dst_rank, group_name=communicator_name)
            # TODO(kevin85421): The current garbage collection implementation for the
            # in-actor object store is naive. We garbage collect each object after it
            # is consumed once.
            gpu_object_manager.remove_gpu_object(obj_id)

        src_actor.__ray_call__.remote(__ray_send__, obj_id, dst_rank)
        src_actor.__ray_call__.remote(__ray_send__, communicator_name, obj_id, dst_rank)

    def _recv_gpu_object(
        self,
        communicator_name: str,
        dst_actor: ActorHandle,
        obj_id: str,
        src_rank: int,
@@ -109,25 +151,29 @@ def _recv_gpu_object(
        # `dst_actor`'s GPU object store.
        def __ray_recv__(
            self,
            communicator_name: str,
            obj_id: str,
            src_rank: int,
            tensor_meta: List[Tuple["torch.Size", "torch.dtype"]],
        ):
            import torch
            import torch.distributed as dist

            import ray.util.collective as collective
            from ray._private.worker import global_worker

            backend = collective.get_group_handle(communicator_name).backend()
            device = COLLECTIVE_BACKEND_TO_TORCH_DEVICE[backend]

            gpu_object_manager = global_worker.gpu_object_manager
            tensors = []
            for meta in tensor_meta:
                shape, dtype = meta
                tensor = torch.zeros(shape, dtype=dtype)
                dist.recv(tensor, src_rank)
                tensor = torch.zeros(shape, dtype=dtype, device=device)
                collective.recv(tensor, src_rank, group_name=communicator_name)
                tensors.append(tensor)
            gpu_object_manager.add_gpu_object(obj_id, tensors)

        dst_actor.__ray_call__.remote(__ray_recv__, obj_id, src_rank, tensor_meta)
        dst_actor.__ray_call__.remote(__ray_recv__, communicator_name, obj_id, src_rank, tensor_meta)

    def trigger_out_of_band_tensor_transfer(
        self, dst_actor: ActorHandle, task_args: Tuple[Any, ...]
@@ -150,11 +196,8 @@ def trigger_out_of_band_tensor_transfer(
            dst_actor: The target actor to receive tensors
            task_args: List of arguments for the target actor task that may contain ObjectRefs.
        """
        from ray.experimental.channel import ChannelContext

        ctx = ChannelContext.get_current()
        from ray.experimental.collective import get_collective_groups

        actor_id_to_rank = {}
        for arg in task_args:
            # If an ObjectRef exists in `gpu_object_refs`, it means the ObjectRef
            # is in-actor tensors. Therefore, this function will trigger a tensor
@@ -168,32 +211,37 @@

            src_actor = gpu_object_meta.src_actor
            tensor_meta = gpu_object_meta.tensor_meta
            if not actor_id_to_rank:
                # TODO(kevin85421): Support multiple communicators.
                if len(ctx.communicators) != 1:
                    raise ValueError(
                        f"There are {len(ctx.communicators)} communicators in the current context. "
                        "Currently, GPU objects only support 1 communicator. Please make sure only "
                        "one communicator exists."
                    )
                actor_id_to_rank = {
                    a._ray_actor_id: i for i, a in enumerate(ctx.communicators[0])
                }
            if src_actor._ray_actor_id not in actor_id_to_rank:
            communicators = get_collective_groups([src_actor, dst_actor], backend=gpu_object_meta.tensor_transport_backend)
            # TODO(kevin85421): Support multiple communicators.
            if len(communicators) == 0:
                raise ValueError(
                    f"No communicators found for actors {src_actor} and {dst_actor}. "
                    "Create a communicator with "
                    "`ray.experimental.collective.create_collective_group` "
                    "before calling actor tasks."
                )
            elif len(communicators) > 1:
                raise ValueError(
                    f"There are {len(communicators)} possible communicators that contain actors {src_actor} and {dst_actor}. "
                    "Currently, GPU objects only support one communicator. Please make sure only "
                    "one communicator exists."
                )
            communicator = communicators[0]
            src_rank = communicator.get_rank(src_actor)
            if src_rank == -1:
                raise ValueError(
                    f"Sender actor {src_actor._ray_actor_id} not found in communicator. "
                    f"Sender actor {src_actor} not found in communicator. "
                    "Please make sure the sender and receiver are in the same communicator."
                )
            if dst_actor._ray_actor_id not in actor_id_to_rank:
            dst_rank = communicator.get_rank(dst_actor)
            if dst_rank == -1:
                raise ValueError(
                    f"Receiver actor {dst_actor._ray_actor_id} not found in communicator. "
                    f"Receiver actor {dst_actor} not found in communicator. "
                    "Please make sure the sender and receiver are in the same communicator."
                )
            src_rank = actor_id_to_rank[src_actor._ray_actor_id]
            dst_rank = actor_id_to_rank[dst_actor._ray_actor_id]
            if src_rank == dst_rank:
                raise ValueError(
                    f"src_rank: {src_rank} and dst_rank: {dst_rank} are the same. This may cause deadlock for transports like NCCL."
                )
            self._send_gpu_object(src_actor, arg.hex(), dst_rank)
            self._recv_gpu_object(dst_actor, arg.hex(), src_rank, tensor_meta)
            self._send_gpu_object(communicator.name, src_actor, arg.hex(), dst_rank)
            self._recv_gpu_object(communicator.name, dst_actor, arg.hex(), src_rank, tensor_meta)
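
Taken together with the docstrings and error messages above, the intended calling pattern looks roughly like the sketch below. This is illustrative only: the `tensor_transport=` method option, `create_collective_group`, and the `"torch_gloo"` backend name are taken from strings in this diff, but exact signatures may differ in the merged API.

# Rough usage sketch inferred from this diff; not copied from the PR's tests.
import ray
import torch
from ray.experimental.collective import create_collective_group


@ray.remote
class Worker:
    @ray.method(tensor_transport="gloo")
    def produce(self):
        # The returned tensor stays in the actor's GPU object store; only
        # tensor metadata (shape, dtype) travels through the regular path.
        return torch.ones(4)

    def consume(self, tensor: torch.Tensor) -> float:
        return tensor.sum().item()


sender, receiver = Worker.remote(), Worker.remote()

# Both actors must share exactly one collective group before any transfer;
# "torch_gloo" maps to CPU tensors and "nccl" to CUDA tensors.
create_collective_group([sender, receiver], backend="torch_gloo")

ref = sender.produce.remote()
# Passing `ref` to another actor triggers trigger_out_of_band_tensor_transfer,
# which pairs __ray_send__ on the sender with __ray_recv__ on the receiver.
print(ray.get(receiver.consume.remote(ref)))  # 4.0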
8 changes: 5 additions & 3 deletions python/ray/_private/worker.py
@@ -58,7 +58,6 @@
from ray._private import ray_option_utils
from ray._private.client_mode_hook import client_mode_hook
from ray._private.function_manager import FunctionActorManager
from ray._private.gpu_object_manager import GPUObjectManager
from ray._private.inspect_util import is_cython
from ray._private.ray_logging import (
    global_worker_stdstream_dispatcher,
@@ -448,7 +447,7 @@ def __init__(self):
        self.actors = {}
        # GPU object manager to manage GPU object lifecycles, including coordinating out-of-band
        # tensor transfers between actors, storing and retrieving GPU objects, and garbage collection.
        self._gpu_object_manager = GPUObjectManager()
        self._gpu_object_manager = None
        # When the worker is constructed. Record the original value of the
        # (CUDA_VISIBLE_DEVICES, ONEAPI_DEVICE_SELECTOR, HIP_VISIBLE_DEVICES,
        # NEURON_RT_VISIBLE_CORES, TPU_VISIBLE_CHIPS, ..) environment variables.
@@ -499,7 +498,10 @@ def __init__(self):
        self._is_connected: bool = False

    @property
    def gpu_object_manager(self) -> GPUObjectManager:
    def gpu_object_manager(self) -> "ray._private.gpu_object_manager.GPUObjectManager":
        if self._gpu_object_manager is None:
            from ray._private.gpu_object_manager import GPUObjectManager
            self._gpu_object_manager = GPUObjectManager()
Collaborator
why's this made to be lazy?

Contributor Author
Ah this is to avoid pulling in any dependencies needed by GPUObjectManager that aren't required by ray usually (currently torch).

Collaborator
Got it. Would be nice if we came up with a more structured way to quarantine soft dependencies so we don't need lazy imports for first party code. I'll play around with it at some point.

        return self._gpu_object_manager

    @property
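
As the review thread above notes, the property is lazy so that `import ray` does not pull in GPUObjectManager's soft dependencies (currently torch). A hypothetical check of that behavior, not taken from this PR's test suite, and assuming no other part of ray imports torch eagerly:

# Hypothetical check (not from this PR): importing ray in a fresh interpreter
# should not import torch, because the GPUObjectManager import is deferred
# until worker.gpu_object_manager is first accessed.
import subprocess
import sys

snippet = "import ray, sys; assert 'torch' not in sys.modules"
subprocess.run([sys.executable, "-c", snippet], check=True)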