@@ -158,18 +158,16 @@ def _encode_tensor(
158
158
self , obj : torch .Tensor
159
159
) -> tuple [str , tuple [int , ...], Union [int , memoryview ]]:
160
160
assert self .aux_buffers is not None
161
- # this creates a copy of the tensor if it's not already contiguous
162
- obj = obj .contiguous ()
163
161
# view the tensor as a 1D array of bytes
164
- arr = obj .view (( obj . numel (), ) ).view (torch .uint8 ).numpy ()
162
+ arr = obj .flatten ( ).view (torch .uint8 ).numpy ()
165
163
if obj .nbytes < self .size_threshold :
166
164
# Smaller tensors are encoded inline, just like ndarrays.
167
165
data = msgpack .Ext (CUSTOM_TYPE_RAW_VIEW , arr .data )
168
166
else :
169
167
# Otherwise encode index of backing buffer to avoid copy.
170
168
data = len (self .aux_buffers )
171
169
self .aux_buffers .append (arr .data )
172
- dtype = str (obj .dtype )[ 6 :] # remove ' torch.' prefix
170
+ dtype = str (obj .dtype ). removeprefix ( " torch." )
173
171
return dtype , obj .shape , data
174
172
175
173
def _encode_nested_tensors (self , nt : NestedTensors ) -> Any :
@@ -245,7 +243,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
245
243
# zero-copy decode. We assume the ndarray will not be kept around,
246
244
# as it now locks the whole received message buffer in memory.
247
245
buffer = self .aux_buffers [data ] if isinstance (data , int ) else data
248
- return np .ndarray (buffer = buffer , dtype = np . dtype ( dtype ), shape = shape )
246
+ return np .frombuffer (buffer , dtype = dtype ). reshape ( shape )
249
247
250
248
def _decode_tensor (self , arr : Any ) -> torch .Tensor :
251
249
dtype , shape , data = arr
@@ -254,12 +252,15 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor:
254
252
# not complain about a readonly memoryview.
255
253
buffer = self .aux_buffers [data ] if isinstance (data , int ) \
256
254
else bytearray (data )
257
- # Create numpy wrapper around the bytes
258
- arr = np .ndarray (buffer = buffer , dtype = np .uint8 , shape = (len (buffer ), ))
259
255
torch_dtype = getattr (torch , dtype )
260
256
assert isinstance (torch_dtype , torch .dtype )
257
+ if not buffer : # torch.frombuffer doesn't like empty buffers
258
+ assert 0 in shape
259
+ return torch .empty (shape , dtype = torch_dtype )
260
+ # Create uint8 array
261
+ arr = torch .frombuffer (buffer , dtype = torch .uint8 )
261
262
# Convert back to proper shape & type
262
- return torch . from_numpy ( arr ) .view (torch_dtype ).view (shape )
263
+ return arr .view (torch_dtype ).view (shape )
263
264
264
265
def _decode_mm_items (self , obj : list ) -> list [MultiModalKwargsItem ]:
265
266
decoded_items = []
0 commit comments