Commit 28d1e2e
[XPU] feat: add xpu async memory copy to enable zero cost checkpoint (PaddlePaddle#71168)
* [XPU] feat: add xpu async memory copy to enable zero cost checkpoint
1 parent cccfde1 commit 28d1e2e

File tree

8 files changed (+763 −2 lines)

paddle/fluid/distributed/collective/CMakeLists.txt

Lines changed: 5 additions & 0 deletions

@@ -41,6 +41,11 @@ if(WITH_XPU_BKCL)
     process_group_bkcl
     SRCS process_group_bkcl.cc bkcl_tools.cc common.cc
     DEPS process_group phi)
+
+  cc_library(
+    xpu_async_load
+    SRCS xpu_async_load.cc
+    DEPS process_group phi ${DEVICE_EVENT_LIBS})
 endif()

 if(WITH_MPI)
paddle/fluid/distributed/collective/xpu_async_load.cc

Lines changed: 260 additions & 0 deletions

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/xpu_async_load.h"

#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/memory_utils.h"  // phi::memory_utils::Copy
#include "paddle/phi/core/compat/convert_utils.h"

namespace paddle {
namespace distributed {

using phi::is_xpu_place;

/**
 * Helper: Insert or retrieve a DeviceEvent in the map without
 * default-constructing it.
 * - If place is XPU, event usage is skipped entirely; a dummy CPU event is
 *   stored so a valid reference can still be returned.
 * - If place is NOT XPU, a real DeviceEvent is created for that place.
 */
static platform::DeviceEvent& GetOrCreateEvent(
    std::unordered_map<std::string, platform::DeviceEvent>* event_map,
    const std::string& key,
    const phi::Place& place) {
  phi::Place event_place = is_xpu_place(place) ? phi::CPUPlace() : place;
  unsigned int flags = platform::GenerateDeviceEventFlag();

  auto it = event_map->find(key);
  if (it == event_map->end()) {
    // Insert with piecewise_construct to avoid requiring a default
    // constructor on DeviceEvent.
    auto emplace_result =
        event_map->emplace(std::piecewise_construct,
                           std::forward_as_tuple(key),
                           std::forward_as_tuple(event_place, flags));
    it = emplace_result.first;  // newly inserted
  }
  return it->second;
}

/* ------------------- Task Implementation ------------------- */

XpuAsyncLoad::Task::Task(const Place& place)
    : use_event_(!is_xpu_place(place)),
      // If place is XPU, store a CPU event so load_event_ is always valid;
      // it is a dummy fallback and is never actually used.
      load_event_(use_event_ ? place : phi::CPUPlace(),
                  platform::GenerateDeviceEventFlag()),
      task_place_(place) {}

XpuAsyncLoad::Task::~Task() = default;

bool XpuAsyncLoad::Task::IsCompleted() {
  if (!use_event_) {
    // For XPU, skip real event usage and report completion directly.
    return true;
  }
  return load_event_.Query();
}

void XpuAsyncLoad::Task::XpuSynchronize() {
  if (!use_event_) {
    return;
  }
  auto* calc_ctx = phi::DeviceContextPool::Instance().Get(task_place_);
  load_event_.Wait(platform::Place2DeviceType(task_place_), calc_ctx);
}

void XpuAsyncLoad::Task::CpuSynchronize() {
  if (!use_event_) {
    return;
  }
  load_event_.Finish();
}

void XpuAsyncLoad::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {
  if (!use_event_) {
    return;
  }
  load_event_.Record(&ctx);
}

/* ------------------- XpuAsyncLoad Implementation ------------------- */

std::shared_ptr<XpuAsyncLoad::Task> XpuAsyncLoad::CreateTask(
    const Place& place) {
  return std::make_shared<XpuAsyncLoad::Task>(place);
}

void XpuAsyncLoad::PrepareLoadEnv(const std::string& key, const Place& place) {
  if (!is_initialized_) {
    is_initialized_ = true;
    xpu_place_ = place;
    // If not XPU, create a real event; if XPU, store a dummy CPU event.
    (void)GetOrCreateEvent(&place_to_calc_event_, key, place);

    // Create an XPUContext for the offload.
    load_ctx_ = std::make_unique<phi::XPUContext>(place);
  }
}

void XpuAsyncLoad::SyncCalcuStream(const Place& place,
                                   phi::XPUContext* offload_ctx,
                                   platform::DeviceEvent* calc_event) {
  if (is_xpu_place(place)) {
    // Event-based stream sync is skipped for XPU places.
    return;
  }
  auto* calc_ctx = phi::DeviceContextPool::Instance().Get(place);
  calc_event->Record(calc_ctx);
  calc_event->Wait(platform::Place2DeviceType(place), offload_ctx);
}

/* ------------ Offload (XPU -> CPU pinned or CPU) ------------ */
std::shared_ptr<XpuAsyncLoad::Task> XpuAsyncLoad::Offload(
    phi::DenseTensor* dst, const phi::DenseTensor& src) {
  PADDLE_ENFORCE_EQ(
      is_xpu_place(src.place()),
      true,
      phi::errors::InvalidArgument("Offload only supports XPU source."));

  std::string key = "load_key";
  PrepareLoadEnv(key, src.place());
  // Retrieve or create the calc event.
  auto& calc_event = GetOrCreateEvent(&place_to_calc_event_, key, src.place());
  // Sync the calculation stream with the offload context.
  SyncCalcuStream(xpu_place_, load_ctx_.get(), &calc_event);

  // Do a synchronous copy to CPU.
  dst->Resize(src.dims());
  size_t size = src.numel() * phi::SizeOf(src.dtype());
  auto cpu_place = phi::CPUPlace();
  auto* cpu_ctx = phi::DeviceContextPool::Instance().Get(cpu_place);
  void* dst_ptr = cpu_ctx->Alloc(dst, src.dtype(), size);
  const void* src_ptr = src.data();

  phi::memory_utils::Copy(cpu_place,
                          dst_ptr,
                          src.place(),
                          src_ptr,
                          size,
                          /*stream=*/nullptr);

  auto task = CreateTask(src.place());
  task->UpdateWaitChain(*load_ctx_);
  return task;
}

/* ------------ OffloadWithOffset (XPU -> CPU partial) ------------ */
std::shared_ptr<XpuAsyncLoad::Task> XpuAsyncLoad::OffloadWithOffset(
    phi::DenseTensor* dst,
    const phi::DenseTensor& src,
    size_t dst_offset,
    size_t src_offset,
    size_t offload_size) {
  PADDLE_ENFORCE_EQ(
      is_xpu_place(src.place()),
      true,
      phi::errors::InvalidArgument("OffloadWithOffset requires XPU source."));

  PADDLE_ENFORCE_EQ(dst->initialized(),
                    true,
                    phi::errors::PreconditionNotMet(
                        "dst must be initialized for partial offload."));

  PADDLE_ENFORCE_LE(
      src_offset + offload_size,
      src.numel(),
      phi::errors::InvalidArgument("src offset + size out of range."));
  PADDLE_ENFORCE_LE(
      dst_offset + offload_size,
      dst->numel(),
      phi::errors::InvalidArgument("dst offset + size out of range."));

  std::string key = "load_key";
  PrepareLoadEnv(key, src.place());
  auto& calc_event = GetOrCreateEvent(&place_to_calc_event_, key, src.place());
  SyncCalcuStream(xpu_place_, load_ctx_.get(), &calc_event);

  size_t elem_size = phi::SizeOf(src.dtype());
  size_t copy_bytes = offload_size * elem_size;
  const void* src_ptr =
      static_cast<const char*>(src.data()) + src_offset * elem_size;
  void* dst_ptr = static_cast<char*>(dst->data()) + dst_offset * elem_size;

  phi::memory_utils::Copy(dst->place(),
                          dst_ptr,
                          src.place(),
                          src_ptr,
                          copy_bytes,
                          /*stream=*/nullptr);

  auto task = CreateTask(src.place());
  task->UpdateWaitChain(*load_ctx_);
  return task;
}

/* ------------ Reload (CPU -> XPU) ------------ */
std::shared_ptr<XpuAsyncLoad::Task> XpuAsyncLoad::Reload(
    phi::DenseTensor* dst, const phi::DenseTensor& src) {
  PADDLE_ENFORCE_EQ(
      is_initialized_,
      true,
      phi::errors::PreconditionNotMet("Call Offload before Reload."));

  // src is expected to live in CPU (or pinned) memory; it is treated as a
  // CPU tensor here without an explicit check.
  std::string key = "load_key";
  auto& calc_event = GetOrCreateEvent(&place_to_calc_event_, key, xpu_place_);
  SyncCalcuStream(xpu_place_, load_ctx_.get(), &calc_event);

  // Now do CPU -> XPU.
  dst->Resize(src.dims());
  size_t size = src.numel() * phi::SizeOf(src.dtype());

  auto* xpu_ctx = phi::DeviceContextPool::Instance().Get(xpu_place_);
  void* dst_ptr = xpu_ctx->Alloc(dst, src.dtype(), size, /*pinned=*/false);
  const void* src_ptr = src.data();

  phi::memory_utils::Copy(xpu_place_,
                          dst_ptr,
                          src.place(),
                          src_ptr,
                          size,
                          /*stream=*/nullptr);

  auto task = CreateTask(xpu_place_);
  task->UpdateWaitChain(*load_ctx_);
  return task;
}

}  // namespace distributed
}  // namespace paddle
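
The insert-or-retrieve idiom in GetOrCreateEvent generalizes: std::unordered_map::emplace with std::piecewise_construct builds the mapped value in place, so the value type never needs a default constructor (platform::DeviceEvent has none). A minimal standalone sketch of the same pattern, using a hypothetical Event type as a stand-in for DeviceEvent:

#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>

// Stand-in for platform::DeviceEvent: constructible only with (place, flags).
struct Event {
  Event(int place, unsigned int flags) : place(place), flags(flags) {}
  Event() = delete;  // like DeviceEvent: not default-constructible
  int place;
  unsigned int flags;
};

// Insert-or-retrieve without requiring Event to be default-constructible.
static Event& GetOrCreate(std::unordered_map<std::string, Event>* events,
                          const std::string& key,
                          int place,
                          unsigned int flags) {
  auto it = events->find(key);
  if (it == events->end()) {
    // piecewise_construct builds key and value in place from separate
    // argument tuples; no temporary Event is ever created.
    it = events
             ->emplace(std::piecewise_construct,
                       std::forward_as_tuple(key),
                       std::forward_as_tuple(place, flags))
             .first;
  }
  return it->second;
}

int main() {
  std::unordered_map<std::string, Event> events;
  Event& e = GetOrCreate(&events, "load_key", /*place=*/0, /*flags=*/1u);
  return e.flags == 1u ? 0 : 1;
}

Note that operator[] would not compile here at all, since it requires the mapped type to be default-constructible; emplace with piecewise_construct is the standard workaround.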
paddle/fluid/distributed/collective/xpu_async_load.h

Lines changed: 94 additions & 0 deletions

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/platform/device_event_base.h"

namespace paddle {
namespace distributed {

using Place = phi::Place;

/**
 * An async-load helper that does NOT use platform::DeviceEvent when the
 * place is XPU.
 */
class XpuAsyncLoad {
 public:
  class Task {
   public:
    explicit Task(const Place& place);
    virtual ~Task();

    bool IsCompleted();

    // Replaces CudaSynchronize with XpuSynchronize.
    void XpuSynchronize();
    void CpuSynchronize();

    // If not XPU, record the event; if XPU, do nothing.
    void UpdateWaitChain(const phi::DeviceContext& ctx);

   private:
    bool use_event_;  // false if place is XPU
    platform::DeviceEvent load_event_;
    Place task_place_;
  };

  // Offload (XPU -> CPU).
  std::shared_ptr<Task> Offload(phi::DenseTensor* dst,
                                const phi::DenseTensor& src);

  // OffloadWithOffset (partial XPU -> CPU).
  std::shared_ptr<Task> OffloadWithOffset(phi::DenseTensor* dst,
                                          const phi::DenseTensor& src,
                                          size_t dst_offset,
                                          size_t src_offset,
                                          size_t offload_size);

  // Reload (CPU -> XPU).
  std::shared_ptr<Task> Reload(phi::DenseTensor* dst,
                               const phi::DenseTensor& src);

 private:
  bool is_initialized_{false};

  // A fallback "offload context"; no multi-stream sync is done for XPU.
  std::unique_ptr<phi::XPUContext> load_ctx_;
  Place xpu_place_;

  std::shared_ptr<Task> CreateTask(const Place& place);

  // If not XPU, store the calc event; if XPU, skip.
  std::unordered_map<std::string, platform::DeviceEvent> place_to_calc_event_;

  // Prepare the load environment on first use.
  void PrepareLoadEnv(const std::string& key, const Place& place);

  // If not XPU, do event sync; if XPU, skip.
  void SyncCalcuStream(const Place& place,
                       phi::XPUContext* offload_ctx,
                       platform::DeviceEvent* calc_event);
};

}  // namespace distributed
}  // namespace paddle
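
For orientation, a minimal sketch of how this API might be driven, assuming the caller holds an XpuAsyncLoad instance and an initialized, XPU-resident phi::DenseTensor; the function and variable names below are illustrative, not part of this commit:

#include "paddle/fluid/distributed/collective/xpu_async_load.h"

// Hypothetical round trip: stash an XPU tensor on the host, then restore it.
void CheckpointRoundTrip(paddle::distributed::XpuAsyncLoad* loader,
                         phi::DenseTensor* weight /* XPU-resident */) {
  // Offload: XPU -> CPU. dst is resized and allocated internally.
  phi::DenseTensor cpu_copy;
  auto off_task = loader->Offload(&cpu_copy, *weight);
  off_task->CpuSynchronize();  // block the host until the copy is done

  // ... the XPU copy of `weight` may now be released or overwritten ...

  // Reload: CPU -> XPU, back onto the place recorded by Offload.
  phi::DenseTensor restored;
  auto re_task = loader->Reload(&restored, cpu_copy);
  re_task->XpuSynchronize();  // make the calc stream wait for the load
}

Since the copies in this commit are issued with stream=nullptr (synchronous) and events are skipped for XPU places, the synchronize calls above are effectively no-ops on XPU; they mirror the shape of the CUDA async_load API so callers can stay path-agnostic.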

paddle/fluid/pybind/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -181,7 +181,7 @@ if(WITH_PYTHON)
   set(PYBIND_DEPS ${PYBIND_DEPS} process_group_nccl async_load)
 endif()
 if(WITH_XPU_BKCL)
-  set(PYBIND_DEPS ${PYBIND_DEPS} process_group_bkcl)
+  set(PYBIND_DEPS ${PYBIND_DEPS} process_group_bkcl xpu_async_load)
 endif()
 if(WITH_GLOO)
   set(PYBIND_DEPS ${PYBIND_DEPS} process_group_gloo)