@@ -1400,20 +1400,26 @@ void BindImperative(py::module *m_ptr) {
 
         )DOC")
       .def("cuda",
-           [](const std::shared_ptr<imperative::VarBase> &self, int device_id,
-              bool blocking) {
+           [](const std::shared_ptr<imperative::VarBase> &self,
+              py::handle &handle, bool blocking) {
 #if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
              PADDLE_THROW(platform::errors::PermissionDenied(
                  "Cannot copy this Tensor to GPU in CPU version Paddle, "
                  "Please recompile or reinstall Paddle with CUDA support."));
 #else
              int device_count = platform::GetCUDADeviceCount();
-             if (device_id == -1) {
+             int device_id = 0;
+             if (handle == py::none()) {
                if (platform::is_gpu_place(self->Place())) {
                  return self;
-               } else {
-                 device_id = 0;
                }
+             } else {
+               PyObject *py_obj = handle.ptr();
+               PADDLE_ENFORCE_EQ(
+                   PyCheckInteger(py_obj), true,
+                   platform::errors::InvalidArgument(
+                       "'device_id' must be a positive integer"));
+               device_id = py::cast<int>(handle);
+             }
              PADDLE_ENFORCE_GE(
                  device_id, 0,
@@ -1437,26 +1443,30 @@ void BindImperative(py::module *m_ptr) {
              }
 #endif
            },
-           py::arg("device_id") = -1, py::arg("blocking") = true, R"DOC(
+           py::arg("device_id") = py::none(), py::arg("blocking") = true, R"DOC(
        Returns a copy of this Tensor in GPU memory.
 
        If this Tensor is already in GPU memory and device_id is default,
        then no copy is performed and the original Tensor is returned.
 
        Args:
-            device_id(int, optional): The destination GPU device id. Defaults to the current device.
+            device_id(int, optional): The destination GPU device id. Default: None, which means the current device.
             blocking(bool, optional): If False and the source is in pinned memory, the copy will be
               asynchronous with respect to the host. Otherwise, the argument has no effect. Default: False.
 
        Examples:
            .. code-block:: python
 
+              # required: gpu
               import paddle
               x = paddle.to_tensor(1.0, place=paddle.CPUPlace())
               print(x.place)    # CPUPlace
 
               y = x.cuda()
               print(y.place)    # CUDAPlace(0)
+
+              y = x.cuda(None)
+              print(y.place)    # CUDAPlace(0)
 
               y = x.cuda(1)
               print(y.place)    # CUDAPlace(1)
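
For context, a minimal sketch of the user-facing behavior this diff produces, assuming a CUDA-enabled build of Paddle with at least one visible device; the exception type raised by the PADDLE_ENFORCE_EQ check is not pinned down here, only the message it carries:

    # Sketch only: assumes a GPU build of Paddle; mirrors the binding above.
    import paddle

    x = paddle.to_tensor(1.0, place=paddle.CPUPlace())

    y = x.cuda()       # device_id defaults to None -> current device, CUDAPlace(0)
    z = x.cuda(None)   # passing None explicitly takes the same path
    w = x.cuda(0)      # an integer still selects the destination device

    # Non-integer arguments are now rejected by the PyCheckInteger check
    # instead of being silently coerced to int.
    try:
        x.cuda("0")
    except Exception as err:
        print(err)     # message includes: 'device_id' must be a positive integer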