Skip to content

Commit 5397cda

Browse files
lgeiger authored and amitm02 committed
[Core] Improve Tensor serialisation (vllm-project#18774)
Signed-off-by: Lukas Geiger <[email protected]>
Signed-off-by: amit <[email protected]>
1 parent b19f788 commit 5397cda

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

vllm/v1/serial_utils.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -158,18 +158,16 @@ def _encode_tensor(
158158
self, obj: torch.Tensor
159159
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
160160
assert self.aux_buffers is not None
161-
# this creates a copy of the tensor if it's not already contiguous
162-
obj = obj.contiguous()
163161
# view the tensor as a 1D array of bytes
164-
arr = obj.view((obj.numel(), )).view(torch.uint8).numpy()
162+
arr = obj.flatten().view(torch.uint8).numpy()
165163
if obj.nbytes < self.size_threshold:
166164
# Smaller tensors are encoded inline, just like ndarrays.
167165
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
168166
else:
169167
# Otherwise encode index of backing buffer to avoid copy.
170168
data = len(self.aux_buffers)
171169
self.aux_buffers.append(arr.data)
172-
dtype = str(obj.dtype)[6:] # remove 'torch.' prefix
170+
dtype = str(obj.dtype).removeprefix("torch.")
173171
return dtype, obj.shape, data
174172

175173
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
@@ -245,7 +243,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
245243
# zero-copy decode. We assume the ndarray will not be kept around,
246244
# as it now locks the whole received message buffer in memory.
247245
buffer = self.aux_buffers[data] if isinstance(data, int) else data
248-
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
246+
return np.frombuffer(buffer, dtype=dtype).reshape(shape)
249247

250248
def _decode_tensor(self, arr: Any) -> torch.Tensor:
251249
dtype, shape, data = arr
@@ -254,12 +252,15 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor:
254252
# not complain about a readonly memoryview.
255253
buffer = self.aux_buffers[data] if isinstance(data, int) \
256254
else bytearray(data)
257-
# Create numpy wrapper around the bytes
258-
arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), ))
259255
torch_dtype = getattr(torch, dtype)
260256
assert isinstance(torch_dtype, torch.dtype)
257+
if not buffer: # torch.frombuffer doesn't like empty buffers
258+
assert 0 in shape
259+
return torch.empty(shape, dtype=torch_dtype)
260+
# Create uint8 array
261+
arr = torch.frombuffer(buffer, dtype=torch.uint8)
261262
# Convert back to proper shape & type
262-
return torch.from_numpy(arr).view(torch_dtype).view(shape)
263+
return arr.view(torch_dtype).view(shape)
263264

264265
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
265266
decoded_items = []

0 commit comments

Comments
 (0)