
Commit 93acaf1

Authored by stephanie-wang, kevin85421, and edoakes
[core][gpu objects] Integrate single-controller collective APIs with GPU objects (#53720)
Adds integration between the single-controller collective APIs introduced in #53319 and the GPU objects feature prototyped in #52938. Actor collectives created through `ray.experimental.collective.create_collective_group` will now be automatically used if a task declares a tensor transport other than the default OBJECT_STORE. This also adds support for allocating the torch tensors on the correct device (GPU for NCCL and CPU for GLOO). See updates in test_gpu_objects.py for examples.

---------

Signed-off-by: Stephanie wang <[email protected]>
Signed-off-by: Stephanie Wang <[email protected]>
Co-authored-by: Kai-Hsun Chen <[email protected]>
Co-authored-by: Edward Oakes <[email protected]>
1 parent 42c5837 commit 93acaf1
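For context, a minimal usage sketch of the flow described in the commit message: two actors, a collective group created through `ray.experimental.collective.create_collective_group`, and an actor method annotated with a non-default `tensor_transport`. The actor class, resource requests, and exact keyword forms below are illustrative assumptions; the authoritative examples live in test_gpu_objects.py.

import ray
import torch
from ray.experimental.collective import create_collective_group


# Illustrative actor; real examples are in test_gpu_objects.py.
@ray.remote(num_gpus=1)
class Worker:
    # A non-default transport keeps the returned tensors in the actor's GPU
    # object store instead of the Ray object store.
    @ray.method(tensor_transport="nccl")
    def produce(self):
        return torch.ones(4, device="cuda")

    def consume(self, tensor):
        return tensor.sum().item()


sender, receiver = Worker.remote(), Worker.remote()
# The group created here is discovered automatically when `receiver`
# consumes the GPU object produced by `sender`.
create_collective_group([sender, receiver], backend="nccl")
ref = sender.produce.remote()
print(ray.get(receiver.consume.remote(ref)))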

File tree

13 files changed (+361, -185 lines)


ci/lint/pydoclint-baseline.txt

Lines changed: 2 additions & 2 deletions
@@ -328,9 +328,9 @@ python/ray/_private/worker.py
     DOC201: Function `remote` does not have a return section in docstring
 --------------------
 python/ray/actor.py
-    DOC101: Function `method`: Docstring contains fewer arguments than in function signature.
+    DOC102: Function `method`: Docstring contains more arguments than in function signature.
     DOC106: Function `method`: The option `--arg-type-hints-in-signature` is `True` but there are no argument type hints in the signature
-    DOC103: Function `method`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: , *args: ]. Arguments in the docstring but not in the function signature: [num_returns: ].
+    DOC103: Function `method`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: , *args: ]. Arguments in the docstring but not in the function signature: [concurrency_group: , max_task_retries: , num_returns: , retry_exceptions: , tensor_transport: ].
     DOC201: Function `method` does not have a return section in docstring
     DOC107: Method `ActorMethod.__init__`: The option `--arg-type-hints-in-signature` is `True` but not all args in the signature have type hints
     DOC101: Method `ActorMethod.options`: Docstring contains fewer arguments than in function signature.

python/ray/_private/custom_types.py

Lines changed: 15 additions & 11 deletions
@@ -1,9 +1,7 @@
+from enum import Enum
 from typing import Literal
 
 from ray.core.generated.common_pb2 import (
-    GLOO,
-    NCCL,
-    OBJECT_STORE,
     ErrorType,
     Language,
     TaskStatus,
@@ -122,13 +120,19 @@
 LANGUAGE = ["PYTHON", "JAVA", "CPP"]
 
 # See `common.proto` for more details.
-TENSOR_TRANSPORT = [
-    "OBJECT_STORE",
-    "NCCL",
-    "GLOO",
-]
-TypeTensorTransport = Literal[tuple(TENSOR_TRANSPORT)]
-TypeTensorTransportEnum = Literal[OBJECT_STORE, NCCL, GLOO]
+class TensorTransportEnum(Enum):
+    OBJECT_STORE = TensorTransport.Value("OBJECT_STORE")
+    NCCL = TensorTransport.Value("NCCL")
+    GLOO = TensorTransport.Value("GLOO")
+
+    @classmethod
+    def from_str(cls, name: str) -> "TensorTransportEnum":
+        name = name.upper()
+        if name not in cls.__members__:
+            raise ValueError(
+                f"Invalid tensor transport {name}, must be one of {list(cls.__members__.keys())}."
+            )
+        return cls[name]
 
 
 def validate_protobuf_enum(grpc_enum, custom_enum):
@@ -157,4 +161,4 @@ def validate_protobuf_enum(grpc_enum, custom_enum):
 validate_protobuf_enum(TaskType, TASK_TYPE)
 validate_protobuf_enum(ErrorType, ERROR_TYPE)
 validate_protobuf_enum(Language, LANGUAGE)
-validate_protobuf_enum(TensorTransport, TENSOR_TRANSPORT)
+validate_protobuf_enum(TensorTransport, list(TensorTransportEnum.__members__.keys()))
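A small usage sketch of the new enum added above (it imports the private module shown in this diff; the invalid name below is just an example):

from ray._private.custom_types import TensorTransportEnum

# Case-insensitive lookup, mirroring strings like tensor_transport="nccl".
assert TensorTransportEnum.from_str("nccl") is TensorTransportEnum.NCCL
assert TensorTransportEnum.from_str("OBJECT_STORE") is TensorTransportEnum.OBJECT_STORE

try:
    TensorTransportEnum.from_str("rdma")  # not a member of the enum
except ValueError as err:
    print(err)  # Invalid tensor transport RDMA, must be one of [...]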
python/ray/_private/gpu_object_manager.py

Lines changed: 97 additions & 76 deletions

@@ -1,16 +1,40 @@
-from collections import namedtuple
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple
 
+from ray._private.custom_types import TensorTransportEnum
 from ray._raylet import ObjectRef
 from ray.actor import ActorHandle
 
+# Avoid importing util until needed because it requires several external
+# dependencies like torch and cupy. These dependencies can significantly slow
+# down normal worker startup time.
+util = None
+
 if TYPE_CHECKING:
     import torch
 
-# GPUObjectMeta is a named tuple containing the source actor and tensor metadata.
-# The tensor metadata is a list of tuples, each containing the shape and dtype
-# of a tensor in the GPU object store.
-GPUObjectMeta = namedtuple("GPUObjectMeta", ["src_actor", "tensor_meta"])
+    from ray._private import gpu_object_manager_util as util
+
+
+def _get_or_import_util():
+    """Lazily import the gpu_object_manager_util module."""
+    global util
+    if util is None:
+        from ray._private import gpu_object_manager_util as util
+    return util
+
+
+# GPUObjectMeta is a named tuple containing the source actor, tensor transport
+# backend, and tensor metadata.
+# - The tensor transport backend is the backend used to transport the tensors.
+#   Currently, the supported backends are "nccl" and "torch_gloo".
+# - The tensor metadata is a list of tuples, each containing the shape and dtype
+#   of a tensor in the GPU object store.
+class GPUObjectMeta(NamedTuple):
+    src_actor: ActorHandle
+    # Must be a valid backend name as defined in
+    # `ray.util.collective.types.Backend`.
+    tensor_transport_backend: str
+    tensor_meta: List[Tuple["torch.Size", "torch.dtype"]]
 
 
 class GPUObjectManager:
@@ -55,14 +79,30 @@ def __ray_get_tensor_meta__(self, obj_id: str):
 
         return src_actor.__ray_call__.remote(__ray_get_tensor_meta__, obj_id)
 
-    def add_gpu_object_ref(self, obj_ref: ObjectRef, src_actor: ActorHandle):
-        # `obj_ref` is an ObjectRef generated by the `src_actor`'s actor task
-        # that is annotated with `@ray.method(tensor_transport=...)`. This function
-        # adds the `obj_ref` to the `gpu_object_refs` dictionary so that the coordinator
-        # process can determine whether the `obj_ref` is a GPU object reference or not.
+    def add_gpu_object_ref(
+        self,
+        obj_ref: ObjectRef,
+        src_actor: ActorHandle,
+        tensor_transport: TensorTransportEnum,
+    ):
+        """Add a GPU object reference to the GPU object manager. This should be
+        called whenever the current process calls a task that is annotated with
+        `@ray.method(tensor_transport=...)`.
+
+        Args:
+            obj_ref: The ObjectRef of the task output.
+            src_actor: The actor that executes the task and that creates the GPU object.
+            tensor_transport: The tensor transport protocol to use for the GPU object.
+        """
+        util = _get_or_import_util()
+        tensor_transport_backend = util.tensor_transport_to_collective_backend(
+            tensor_transport
+        )
         tensor_meta = self._get_tensor_meta(src_actor, obj_ref.hex())
         self.gpu_object_refs[obj_ref] = GPUObjectMeta(
-            src_actor=src_actor, tensor_meta=tensor_meta
+            src_actor=src_actor,
+            tensor_transport_backend=tensor_transport_backend,
+            tensor_meta=tensor_meta,
         )
 
     # TODO(kevin85421): Call this function to remove the `obj_ref` from the `gpu_object_refs` dictionary
@@ -76,58 +116,30 @@ def _get_gpu_object_ref(self, obj_ref: ObjectRef) -> Optional[GPUObjectMeta]:
     def _is_gpu_object_ref(self, obj_ref: ObjectRef) -> bool:
         return obj_ref in self.gpu_object_refs
 
-    def _send_gpu_object(self, src_actor: ActorHandle, obj_id: str, dst_rank: int):
+    def _send_gpu_object(
+        self, communicator_name: str, src_actor: ActorHandle, obj_id: str, dst_rank: int
+    ):
         # Send tensors stored in the `src_actor`'s GPU object store to the
         # destination rank `dst_rank`.
-        def __ray_send__(self, obj_id: str, dst_rank: int):
-            import torch.distributed as dist
-
-            from ray._private.worker import global_worker
-
-            gpu_object_manager = global_worker.gpu_object_manager
-            assert gpu_object_manager.has_gpu_object(
-                obj_id
-            ), f"obj_id={obj_id} not found in GPU object store"
-            tensors = gpu_object_manager.get_gpu_object(obj_id)
-            for tensor in tensors:
-                dist.send(tensor, dst_rank)
-            # TODO(kevin85421): The current garbage collection implementation for the
-            # in-actor object store is naive. We garbage collect each object after it
-            # is consumed once.
-            gpu_object_manager.remove_gpu_object(obj_id)
-
-        src_actor.__ray_call__.remote(__ray_send__, obj_id, dst_rank)
+        util = _get_or_import_util()
+        src_actor.__ray_call__.remote(
+            util.__ray_send__, communicator_name, obj_id, dst_rank
+        )
 
     def _recv_gpu_object(
         self,
+        communicator_name: str,
         dst_actor: ActorHandle,
         obj_id: str,
         src_rank: int,
        tensor_meta: List[Tuple["torch.Size", "torch.dtype"]],
     ):
         # Receive tensors from the source rank and store them in the
         # `dst_actor`'s GPU object store.
-        def __ray_recv__(
-            self,
-            obj_id: str,
-            src_rank: int,
-            tensor_meta: List[Tuple["torch.Size", "torch.dtype"]],
-        ):
-            import torch
-            import torch.distributed as dist
-
-            from ray._private.worker import global_worker
-
-            gpu_object_manager = global_worker.gpu_object_manager
-            tensors = []
-            for meta in tensor_meta:
-                shape, dtype = meta
-                tensor = torch.zeros(shape, dtype=dtype)
-                dist.recv(tensor, src_rank)
-                tensors.append(tensor)
-            gpu_object_manager.add_gpu_object(obj_id, tensors)
-
-        dst_actor.__ray_call__.remote(__ray_recv__, obj_id, src_rank, tensor_meta)
+        util = _get_or_import_util()
+        dst_actor.__ray_call__.remote(
+            util.__ray_recv__, communicator_name, obj_id, src_rank, tensor_meta
+        )
 
     def trigger_out_of_band_tensor_transfer(
         self, dst_actor: ActorHandle, task_args: Tuple[Any, ...]
@@ -150,11 +162,6 @@ def trigger_out_of_band_tensor_transfer(
             dst_actor: The target actor to receive tensors
             task_args: List of arguments for the target actor task that may contain ObjectRefs.
         """
-        from ray.experimental.channel import ChannelContext
-
-        ctx = ChannelContext.get_current()
-
-        actor_id_to_rank = {}
         for arg in task_args:
             # If an ObjectRef exists in `gpu_object_refs`, it means the ObjectRef
             # is in-actor tensors. Therefore, this function will trigger a tensor
@@ -164,37 +171,51 @@
 
             if not self._is_gpu_object_ref(arg):
                 continue
+
+            # Import get_collective_groups here to avoid dependency on
+            # collective libraries for default Ray installation.
+            from ray.experimental.collective import get_collective_groups
+
             gpu_object_meta = self._get_gpu_object_ref(arg)
 
             src_actor = gpu_object_meta.src_actor
             tensor_meta = gpu_object_meta.tensor_meta
-            if not actor_id_to_rank:
-                # TODO(kevin85421): Support multiple communicators.
-                if len(ctx.communicators) != 1:
-                    raise ValueError(
-                        f"There are {len(ctx.communicators)} communicators in the current context. "
-                        "Currently, GPU objects only support 1 communicator. Please make sure only "
-                        "one communicator exists."
-                    )
-                actor_id_to_rank = {
-                    a._ray_actor_id: i for i, a in enumerate(ctx.communicators[0])
-                }
-            if src_actor._ray_actor_id not in actor_id_to_rank:
+            communicators = get_collective_groups(
+                [src_actor, dst_actor], backend=gpu_object_meta.tensor_transport_backend
+            )
+            # TODO(kevin85421): Support multiple communicators.
+            if len(communicators) == 0:
+                raise ValueError(
+                    f"No communicators found for actors {src_actor} and {dst_actor}. "
+                    "Create a communicator with "
+                    "`ray.experimental.collective.create_collective_group` "
+                    "before calling actor tasks."
+                )
+            elif len(communicators) > 1:
+                raise ValueError(
+                    f"There are {len(communicators)} possible communicators that contain actors {src_actor} and {dst_actor}. "
+                    "Currently, GPU objects only support one communicator. Please make sure only "
+                    "one communicator exists."
+                )
+            communicator = communicators[0]
+            src_rank = communicator.get_rank(src_actor)
+            if src_rank == -1:
                 raise ValueError(
-                    f"Sender actor {src_actor._ray_actor_id} not found in communicator. "
+                    f"Sender actor {src_actor} not found in communicator. "
                     "Please make sure the sender and receiver are in the same communicator."
                 )
-            if dst_actor._ray_actor_id not in actor_id_to_rank:
+            dst_rank = communicator.get_rank(dst_actor)
+            if dst_rank == -1:
                 raise ValueError(
-                    f"Receiver actor {dst_actor._ray_actor_id} not found in communicator. "
+                    f"Receiver actor {dst_actor} not found in communicator. "
                    "Please make sure the sender and receiver are in the same communicator."
                 )
-            src_rank = actor_id_to_rank[src_actor._ray_actor_id]
-            dst_rank = actor_id_to_rank[dst_actor._ray_actor_id]
             if src_rank == dst_rank:
                 # If the source and destination ranks are the same, the tensors can
                 # be transferred intra-process, so we skip the out-of-band tensor
                 # transfer.
                 continue
-            self._send_gpu_object(src_actor, arg.hex(), dst_rank)
-            self._recv_gpu_object(dst_actor, arg.hex(), src_rank, tensor_meta)
+            self._send_gpu_object(communicator.name, src_actor, arg.hex(), dst_rank)
+            self._recv_gpu_object(
+                communicator.name, dst_actor, arg.hex(), src_rank, tensor_meta
+            )
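The rank-resolution logic in `trigger_out_of_band_tensor_transfer` boils down to the following standalone sketch (the helper name `resolve_transfer` is hypothetical; `get_collective_groups` and `get_rank` are used as in the diff above):

from ray.experimental.collective import get_collective_groups


def resolve_transfer(src_actor, dst_actor, backend):
    # Exactly one collective group must contain both actors for the chosen backend.
    groups = get_collective_groups([src_actor, dst_actor], backend=backend)
    if len(groups) != 1:
        raise ValueError(f"Expected exactly one communicator, found {len(groups)}.")
    group = groups[0]
    src_rank, dst_rank = group.get_rank(src_actor), group.get_rank(dst_actor)
    if -1 in (src_rank, dst_rank):
        raise ValueError("Sender and receiver must both be members of the communicator.")
    # A transfer is only needed when the ranks differ; otherwise the tensors
    # already live in the destination actor's GPU object store.
    return None if src_rank == dst_rank else (group.name, src_rank, dst_rank)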
python/ray/_private/gpu_object_manager_util.py

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+from typing import List, Tuple
+
+try:
+    import torch
+except ImportError:
+    raise ImportError(
+        "`tensor_transport` requires PyTorch. "
+        "Please install torch with 'pip install torch' to use this feature."
+    )
+
+import ray.util.collective as collective
+from ray._private.custom_types import TensorTransportEnum
+from ray._private.worker import global_worker
+from ray.util.collective.types import Backend
+
+TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND = {
+    TensorTransportEnum.NCCL: Backend.NCCL,
+    TensorTransportEnum.GLOO: Backend.TORCH_GLOO,
+}
+
+COLLECTIVE_BACKEND_TO_TORCH_DEVICE = {
+    Backend.NCCL: torch.device("cuda"),
+    Backend.TORCH_GLOO: torch.device("cpu"),
+}
+
+
+def tensor_transport_to_collective_backend(
+    tensor_transport: TensorTransportEnum,
+) -> Backend:
+    try:
+        return TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND[tensor_transport]
+    except KeyError:
+        raise ValueError(
+            f"Invalid tensor transport {tensor_transport.name}, must be one of {list(TENSOR_TRANSPORT_TO_COLLECTIVE_BACKEND.keys())}."
+        )
+
+
+def __ray_send__(self, communicator_name: str, obj_id: str, dst_rank: int):
+    """Helper function that runs on the src actor to send tensors to the dst actor."""
+    gpu_object_manager = global_worker.gpu_object_manager
+    assert gpu_object_manager.has_gpu_object(
+        obj_id
+    ), f"obj_id={obj_id} not found in GPU object store"
+    tensors = gpu_object_manager.get_gpu_object(obj_id)
+
+    backend = collective.get_group_handle(communicator_name).backend()
+    device = COLLECTIVE_BACKEND_TO_TORCH_DEVICE[backend]
+
+    for tensor in tensors:
+        if tensor.device.type != device.type:
+            # TODO(swang): Right now there is no way to catch this error
+            # and the receiving Ray task will hang.
+            raise ValueError(
+                f"tensor device {tensor.device} does not match device {device}"
+            )
+        collective.send(tensor, dst_rank, group_name=communicator_name)
+    # TODO(kevin85421): The current garbage collection implementation for the
+    # in-actor object store is naive. We garbage collect each object after it
+    # is consumed once.
+    gpu_object_manager.remove_gpu_object(obj_id)
+
+
+def __ray_recv__(
+    self,
+    communicator_name: str,
+    obj_id: str,
+    src_rank: int,
+    tensor_meta: List[Tuple["torch.Size", "torch.dtype"]],
+):
+    """Helper function that runs on the dst actor to receive tensors from the src actor."""
+    from ray._private.worker import global_worker
+
+    backend = collective.get_group_handle(communicator_name).backend()
+    device = COLLECTIVE_BACKEND_TO_TORCH_DEVICE[backend]
+
+    gpu_object_manager = global_worker.gpu_object_manager
+    tensors = []
+    for meta in tensor_meta:
+        shape, dtype = meta
+        tensor = torch.zeros(shape, dtype=dtype, device=device)
+        collective.recv(tensor, src_rank, group_name=communicator_name)
+        tensors.append(tensor)
+    gpu_object_manager.add_gpu_object(obj_id, tensors)
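For intuition, a stripped-down sketch of the send/recv pattern that `__ray_send__` and `__ray_recv__` wrap, using `ray.util.collective` directly with the CPU-friendly torch_gloo backend (the group name, world size, and tensor shape here are illustrative assumptions):

import ray
import torch
import ray.util.collective as collective


@ray.remote
class Peer:
    def setup(self, world_size, rank, group_name):
        collective.init_collective_group(
            world_size, rank, backend="torch_gloo", group_name=group_name
        )

    def send(self, dst_rank, group_name):
        # torch_gloo transports CPU tensors; NCCL would require device="cuda".
        collective.send(torch.arange(4, dtype=torch.float32), dst_rank, group_name=group_name)

    def recv(self, src_rank, group_name):
        # The receiver pre-allocates a buffer from the known shape/dtype,
        # just as __ray_recv__ does from the GPU object's tensor metadata.
        buf = torch.zeros(4, dtype=torch.float32)
        collective.recv(buf, src_rank, group_name=group_name)
        return buf


ray.init()
a, b = Peer.remote(), Peer.remote()
ray.get([a.setup.remote(2, 0, "demo"), b.setup.remote(2, 1, "demo")])
a.send.remote(1, "demo")
print(ray.get(b.recv.remote(0, "demo")))  # tensor([0., 1., 2., 3.])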
