Skip to content

Commit 196db7e

Browse files
authored
dist dataloader: add cuda compilation check (#8099)
1 parent 60b9755 commit 196db7e

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

paddlenlp/data/dist_dataloader.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,17 @@ def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_r
             assert data.dtype == datatype, "input has data type {} which " "is different than {}".format(
                 data.dtype, datatype
             )
-        data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
+        if paddle.is_compiled_with_cuda():
+            data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
+        else:
+            data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
+
         assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
     else:
-        data_b = paddle.empty([numel], dtype=datatype).cuda()
+        if paddle.is_compiled_with_cuda():
+            data_b = paddle.empty([numel], dtype=datatype).cuda()
+        else:
+            data_b = paddle.empty([numel], dtype=datatype)

     # Broadcast
     paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()

0 commit comments

Comments
 (0)