
Commit f2574f3

Merge branch 'develop' into load_recover_tensor_name
2 parents: 982bfb6 + adc2b74

354 files changed: 6,945 additions and 4,044 deletions

(Large commit; only a subset of the changed files is shown below.)

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@ endif()
 # https://cmake.org/cmake/help/v3.0/policy/CMP0026.html?highlight=cmp0026
 cmake_policy(SET CMP0026 OLD)
 cmake_policy(SET CMP0079 NEW)
+if(${CMAKE_VERSION} VERSION_GREATER_EQUAL 3.18)
+  cmake_policy(SET CMP0104 OLD)
+endif()
 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})

cmake/external/pybind11.cmake

Lines changed: 9 additions & 0 deletions
@@ -21,6 +21,14 @@ set(PYBIND_TAG v2.10.3)
 set(PYBIND_INCLUDE_DIR ${THIRD_PARTY_PATH}/pybind/src/extern_pybind/include)
 include_directories(${PYBIND_INCLUDE_DIR})
 
+set(PYBIND_PATCH_COMMAND "")
+if(NOT WIN32)
+  file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/pybind/cast.h.patch
+       native_dst)
+  set(PYBIND_PATCH_COMMAND patch -d ${PYBIND_INCLUDE_DIR}/pybind11 <
+      ${native_dst})
+endif()
+
 ExternalProject_Add(
   extern_pybind
   ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
@@ -33,6 +41,7 @@ ExternalProject_Add(
   # third-party library version changes cannot be incorporated.
   # reference: https://cmake.org/cmake/help/latest/module/ExternalProject.html
   UPDATE_COMMAND ""
+  PATCH_COMMAND ${PYBIND_PATCH_COMMAND}
   CONFIGURE_COMMAND ""
   BUILD_COMMAND ""
   INSTALL_COMMAND ""

cmake/external/warpctc.cmake

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ else()
   set(WARPCTC_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
   set(WARPCTC_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
 endif()
+
 ExternalProject_Add(
   extern_warpctc
   ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}

paddle/fluid/distributed/collective/process_group_custom.cc

Lines changed: 144 additions & 0 deletions
@@ -576,6 +576,150 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
       false);
 }
 
+void CheckTensorsInDifferentCustomDevices(
+    const std::vector<phi::DenseTensor>& tensors, const size_t num_devices) {
+  PADDLE_ENFORCE_EQ(
+      tensors.size() == 0,
+      false,
+      phi::errors::InvalidArgument("Tensor list must be nonempty."));
+  PADDLE_ENFORCE_LE(
+      tensors.size(),
+      num_devices,
+      phi::errors::InvalidArgument("Tensor list mustn't be larger than the "
+                                   "number of available CustomDevice."));
+
+  std::set<Place> used_devices;
+
+  for (const auto& t : tensors) {
+    PADDLE_ENFORCE_EQ(platform::is_custom_place(t.place()),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "Tensors must be CustomDevice and dense tensor."));
+
+    const auto inserted = used_devices.insert(t.place()).second;
+    PADDLE_ENFORCE_EQ(inserted,
+                      true,
+                      phi::errors::InvalidArgument(
+                          "Tensors must be on distinct custom devices."));
+  }
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
+    phi::DenseTensor* tensor,
+    int src_rank,
+    int64_t offset,
+    int64_t numel,
+    bool sync_op,
+    bool use_calc_stream) {
+  // numel > 0 indicates the tensor need to be sliced
+  phi::DenseTensor partial_tensor;
+  if (numel > 0) {
+    partial_tensor = GetPartialTensor(*tensor, offset, numel);
+    tensor = &partial_tensor;
+  }
+  phi::distributed::CommStaticCheck::CheckShape(
+      *tensor, rank_, size_, phi::AllocationType::CUSTOM);
+  std::vector<phi::DenseTensor> in_wrapper{*tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*tensor};
+  return Collective(
+      in_wrapper,
+      out_wrapper,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        phi::DeviceManager::CCLRecv(device_type_,
+                                    output.data(),
+                                    output.numel(),
+                                    phi::ccl::ToCCLDataType(output.dtype()),
+                                    src_rank,
+                                    comm,
+                                    stream);
+      },
+      CommType::RECV,
+      sync_op,
+      use_calc_stream);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
+    std::vector<phi::DenseTensor>& tensors, int src_rank) {
+  CheckTensorsInDifferentCustomDevices(tensors, static_cast<size_t>(GetSize()));
+  return Collective(
+      tensors,
+      tensors,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        phi::DeviceManager::CCLRecv(device_type_,
+                                    output.data(),
+                                    output.numel(),
+                                    phi::ccl::ToCCLDataType(output.dtype()),
+                                    src_rank,
+                                    comm,
+                                    stream);
+      },
+      CommType::RECV,
+      false,
+      false);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
+    const phi::DenseTensor& tensor,
+    int dst_rank,
+    int64_t offset,
+    int64_t numel,
+    bool sync_op,
+    bool use_calc_stream) {
+  // numel > 0 indicates the tensor need to be sliced
+  const phi::DenseTensor& tensor_maybe_partial =
+      numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
+  phi::distributed::CommStaticCheck::CheckShape(
+      tensor_maybe_partial, rank_, size_, phi::AllocationType::CUSTOM);
+  std::vector<phi::DenseTensor> in_wrapper{tensor_maybe_partial};
+  std::vector<phi::DenseTensor> out_wrapper{tensor_maybe_partial};
+  return Collective(
+      in_wrapper,
+      out_wrapper,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        phi::DeviceManager::CCLSend(device_type_,
+                                    input.data(),
+                                    input.numel(),
+                                    phi::ccl::ToCCLDataType(input.dtype()),
+                                    dst_rank,
+                                    comm,
+                                    stream);
+      },
+      CommType::SEND,
+      sync_op,
+      use_calc_stream);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
+    std::vector<phi::DenseTensor>& tensors, int dst_rank) {
+  CheckTensorsInDifferentCustomDevices(tensors, static_cast<size_t>(GetSize()));
+  return Collective(
+      tensors,
+      tensors,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        phi::DeviceManager::CCLSend(device_type_,
+                                    input.data(),
+                                    input.numel(),
+                                    phi::ccl::ToCCLDataType(input.dtype()),
+                                    dst_rank,
+                                    comm,
+                                    stream);
+      },
+      CommType::SEND,
+      false,
+      false);
+}
 std::shared_ptr<ProcessGroupCustom>
 ProcessGroupCustom::CreateProcessGroupCustom(
     const std::shared_ptr<phi::distributed::Store>& store,
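For orientation, a minimal usage sketch of the new point-to-point overloads (not part of this commit): it assumes an already-initialized ProcessGroupCustom `pg` and a CustomDevice-resident tensor, and relies on the `numel > 0` slicing check above, so a non-positive `numel` transfers the whole tensor. The function name `PingPong` is illustrative.

#include "paddle/fluid/distributed/collective/process_group_custom.h"

// Hypothetical sketch: rank 0 sends a dense tensor to rank 1 through
// the Send/Recv overloads added in this commit. offset=0 / numel=-1
// skips GetPartialTensor, so the whole tensor is transferred.
void PingPong(paddle::distributed::ProcessGroupCustom& pg,
              phi::DenseTensor* tensor) {
  if (pg.GetRank() == 0) {
    auto task = pg.Send(*tensor,
                        /*dst_rank=*/1,
                        /*offset=*/0,
                        /*numel=*/-1,
                        /*sync_op=*/true,
                        /*use_calc_stream=*/false);
    task->Wait();  // block until the send completes
  } else if (pg.GetRank() == 1) {
    auto task = pg.Recv(tensor,
                        /*src_rank=*/0,
                        /*offset=*/0,
                        /*numel=*/-1,
                        /*sync_op=*/true,
                        /*use_calc_stream=*/false);
    task->Wait();  // block until the tensor has arrived
  }
}

The vector overloads take one tensor per device; CheckTensorsInDifferentCustomDevices enforces that the list is nonempty, no larger than the device count, and placed on distinct custom devices.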

paddle/fluid/distributed/collective/process_group_custom.h

Lines changed: 20 additions & 0 deletions
@@ -143,6 +143,26 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
                                                const BroadcastOptions& opts,
                                                bool sync_op) override;
 
+  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
+                                           int dst_rank,
+                                           int64_t offset,
+                                           int64_t numel,
+                                           bool sync_op,
+                                           bool use_calc_stream) override;
+
+  std::shared_ptr<ProcessGroup::Task> Send(
+      std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
+
+  std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
+                                           int src_rank,
+                                           int64_t offset,
+                                           int64_t numel,
+                                           bool sync_op,
+                                           bool use_calc_stream) override;
+
+  std::shared_ptr<ProcessGroup::Task> Recv(
+      std::vector<phi::DenseTensor>& tensors, int src_rank) override;
+
  protected:
   virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
       std::vector<Place> places,

paddle/fluid/distributed/common/afs_warpper.h

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@
 
 #include "paddle/fluid/distributed/the_one_ps.pb.h"
 #include "paddle/fluid/string/string_helper.h"
-
+#include "paddle/phi/core/macros.h"
 namespace paddle {
 namespace distributed {
 struct FsDataConverter {
@@ -43,7 +43,7 @@ class FsReadChannel {
   virtual ~FsReadChannel() {}
   FsReadChannel(FsReadChannel&&) = delete;
   FsReadChannel(const FsReadChannel&) = delete;
-  int open(std::shared_ptr<FILE> fp, const FsChannelConfig& config) {
+  int open(std::shared_ptr<FILE> fp, const FsChannelConfig& config UNUSED) {
     _file = fp;
     return 0;
   }
@@ -83,7 +83,7 @@ class FsWriteChannel {
   FsWriteChannel(FsWriteChannel&&) = delete;
   FsWriteChannel(const FsWriteChannel&) = delete;
 
-  int open(std::shared_ptr<FILE> fp, const FsChannelConfig& config) {
+  int open(std::shared_ptr<FILE> fp, const FsChannelConfig& config UNUSED) {
     _file = fp;
 
     // the buffer has set in fs.cc
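The `UNUSED` annotation on the now-ignored `config` parameters comes from the newly included paddle/phi/core/macros.h; the same pattern recurs in the env.h diff below. As a sketch of the usual idiom behind such a macro (an assumption; Paddle's actual definition may differ in detail):

// Illustrative only: the common compiler-attribute idiom for marking
// a deliberately unused parameter. Paddle's real macro lives in
// paddle/phi/core/macros.h.
#if defined(__GNUC__) || defined(__clang__)
#define UNUSED __attribute__((unused))
#else
#define UNUSED
#endif

// Applied to a parameter, the attribute lets default no-op
// implementations keep descriptive names without triggering
// -Wunused-parameter in warnings-as-errors builds:
int open_stub(int fd UNUSED) { return 0; }

The design point is to silence unused-parameter warnings on these stub implementations while preserving self-documenting signatures.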

paddle/fluid/distributed/ps/service/env.h

File mode changed: 100755 → 100644
Lines changed: 9 additions & 4 deletions
@@ -26,6 +26,7 @@
 #include <vector>
 
 #include "gflags/gflags.h"
+#include "paddle/phi/core/macros.h"
 
 namespace paddle {
 namespace distributed {
@@ -115,19 +116,23 @@ class PSEnvironment {
   explicit PSEnvironment() {}  // NOLINT
   virtual ~PSEnvironment() {}
 
-  virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
+  virtual int32_t SetPsServers(uint64_t *host_sign_list UNUSED,
+                               int node_num UNUSED) {
     return 0;
   }
   virtual int32_t SetPsServers(
-      const std::vector<std::string> *host_endpoint_list, int node_num) {
+      const std::vector<std::string> *host_endpoint_list UNUSED,
+      int node_num UNUSED) {
     return 0;
   }
 
-  virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
+  virtual int32_t SetPsClients(uint64_t *host_sign_list UNUSED,
+                               int node_num UNUSED) {
     return 0;
   }
 
-  virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
+  virtual int32_t SetPsClients(std::string *host_endpoint_list UNUSED,
+                               int node_num UNUSED) {
     return 0;
   }
 