Skip to content

Commit fc86ee7

Browse files
authored
Merge pull request STEllAR-GROUP#6760 from STEllAR-GROUP/collectives2
Optimizing collectives for num_sites equal to one
2 parents 8afc459 + 6d22925 commit fc86ee7

24 files changed

+484
-313
lines changed

.circleci/tests.unit1.targets

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
# Distributed under the Boost Software License, Version 1.0. (See accompanying
55
# file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
66

7-
tests.unit.modules.batch_environments
8-
tests.unit.modules.cache
9-
tests.unit.modules.checkpoint
10-
tests.unit.modules.checkpoint_base
11-
tests.unit.modules.collectives
12-
tests.unit.modules.command_line_handling_local
13-
tests.unit.modules.distribution_policies
14-
tests.unit.modules.execution
15-
tests.unit.modules.execution_base
16-
tests.unit.modules.executors
17-
tests.unit.modules.executors_distributed
7+
tests.unit.build
8+
tests.unit.components
9+
tests.unit.modules.actions
10+
tests.unit.modules.actions_base
11+
tests.unit.modules.affinity
12+
tests.unit.modules.agas
13+
tests.unit.modules.agas_base
14+
tests.unit.modules.assertion
15+
tests.unit.modules.async_base
16+
tests.unit.modules.async_colocated
17+
tests.unit.modules.async_combinators
18+
tests.unit.modules.async_distributed
19+
tests.unit.modules.async_local

libs/full/collectives/include/hpx/collectives/all_gather.hpp

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019-2024 Hartmut Kaiser
1+
// Copyright (c) 2019-2025 Hartmut Kaiser
22
//
33
// SPDX-License-Identifier: BSL-1.0
44
// Distributed under the Boost Software License, Version 1.0. (See accompanying
@@ -266,11 +266,11 @@ namespace hpx::collectives {
266266
template <typename T>
267267
hpx::future<std::vector<std::decay_t<T>>> all_gather(communicator fid,
268268
T&& local_result, this_site_arg this_site = this_site_arg(),
269-
generation_arg generation = generation_arg())
269+
generation_arg const generation = generation_arg())
270270
{
271271
using arg_type = std::decay_t<T>;
272272

273-
if (this_site == static_cast<std::size_t>(-1))
273+
if (this_site.is_default())
274274
{
275275
this_site = agas::get_locality_id();
276276
}
@@ -282,6 +282,22 @@ namespace hpx::collectives {
282282
"the generation number shouldn't be zero"));
283283
}
284284

285+
// Handle operation right away if there is only one value.
286+
if (auto [num_sites, comm_site] = fid.get_info(); num_sites == 1)
287+
{
288+
if (this_site != comm_site)
289+
{
290+
return hpx::make_exceptional_future<std::vector<arg_type>>(
291+
HPX_GET_EXCEPTION(hpx::error::bad_parameter,
292+
"hpx::collectives::all_gather",
293+
"the local site should be zero if only one site is "
294+
"involved"));
295+
}
296+
297+
std::vector<arg_type> result(1, HPX_FORWARD(T, local_result));
298+
return hpx::make_ready_future(HPX_MOVE(result));
299+
}
300+
285301
auto all_gather_data = [local_result = HPX_FORWARD(T, local_result),
286302
this_site,
287303
generation](communicator&& c) mutable
@@ -311,19 +327,19 @@ namespace hpx::collectives {
311327

312328
template <typename T>
313329
hpx::future<std::vector<std::decay_t<T>>> all_gather(communicator fid,
314-
T&& local_result, generation_arg generation,
315-
this_site_arg this_site = this_site_arg())
330+
T&& local_result, generation_arg const generation,
331+
this_site_arg const this_site = this_site_arg())
316332
{
317333
return all_gather(
318334
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation);
319335
}
320336

321337
template <typename T>
322338
hpx::future<std::vector<std::decay_t<T>>> all_gather(char const* basename,
323-
T&& local_result, num_sites_arg num_sites = num_sites_arg(),
324-
this_site_arg this_site = this_site_arg(),
325-
generation_arg generation = generation_arg(),
326-
root_site_arg root_site = root_site_arg())
339+
T&& local_result, num_sites_arg const num_sites = num_sites_arg(),
340+
this_site_arg const this_site = this_site_arg(),
341+
generation_arg const generation = generation_arg(),
342+
root_site_arg const root_site = root_site_arg())
327343
{
328344
return all_gather(create_communicator(basename, num_sites, this_site,
329345
generation, root_site),
@@ -334,8 +350,8 @@ namespace hpx::collectives {
334350
template <typename T>
335351
std::vector<std::decay_t<T>> all_gather(hpx::launch::sync_policy,
336352
communicator fid, T&& local_result,
337-
this_site_arg this_site = this_site_arg(),
338-
generation_arg generation = generation_arg())
353+
this_site_arg const this_site = this_site_arg(),
354+
generation_arg const generation = generation_arg())
339355
{
340356
return all_gather(
341357
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
@@ -344,8 +360,8 @@ namespace hpx::collectives {
344360

345361
template <typename T>
346362
std::vector<std::decay_t<T>> all_gather(hpx::launch::sync_policy,
347-
communicator fid, T&& local_result, generation_arg generation,
348-
this_site_arg this_site = this_site_arg())
363+
communicator fid, T&& local_result, generation_arg const generation,
364+
this_site_arg const this_site = this_site_arg())
349365
{
350366
return all_gather(
351367
HPX_MOVE(fid), HPX_FORWARD(T, local_result), this_site, generation)
@@ -355,10 +371,10 @@ namespace hpx::collectives {
355371
template <typename T>
356372
std::vector<std::decay_t<T>> all_gather(hpx::launch::sync_policy,
357373
char const* basename, T&& local_result,
358-
num_sites_arg num_sites = num_sites_arg(),
359-
this_site_arg this_site = this_site_arg(),
360-
generation_arg generation = generation_arg(),
361-
root_site_arg root_site = root_site_arg())
374+
num_sites_arg const num_sites = num_sites_arg(),
375+
this_site_arg const this_site = this_site_arg(),
376+
generation_arg const generation = generation_arg(),
377+
root_site_arg const root_site = root_site_arg())
362378
{
363379
return all_gather(create_communicator(basename, num_sites, this_site,
364380
generation, root_site),

libs/full/collectives/include/hpx/collectives/all_reduce.hpp

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,11 @@ namespace hpx::collectives {
310310
template <typename T, typename F>
311311
hpx::future<std::decay_t<T>> all_reduce(communicator fid, T&& local_result,
312312
F&& op, this_site_arg this_site = this_site_arg(),
313-
generation_arg generation = generation_arg())
313+
generation_arg const generation = generation_arg())
314314
{
315315
using arg_type = std::decay_t<T>;
316316

317-
if (this_site == static_cast<std::size_t>(-1))
317+
if (this_site.is_default())
318318
{
319319
this_site = agas::get_locality_id();
320320
}
@@ -325,6 +325,20 @@ namespace hpx::collectives {
325325
"the generation number shouldn't be zero"));
326326
}
327327

328+
// Handle operation right away if there is only one value.
329+
if (auto [num_sites, comm_site] = fid.get_info(); num_sites == 1)
330+
{
331+
if (this_site != comm_site)
332+
{
333+
return hpx::make_exceptional_future<arg_type>(HPX_GET_EXCEPTION(
334+
hpx::error::bad_parameter, "hpx::collectives::all_reduce",
335+
"the local site should be zero if only one site is "
336+
"involved"));
337+
}
338+
339+
return hpx::make_ready_future(HPX_FORWARD(T, local_result));
340+
}
341+
328342
auto all_reduce_data =
329343
[local_result = HPX_FORWARD(T, local_result),
330344
op = HPX_FORWARD(F, op), generation,
@@ -354,19 +368,20 @@ namespace hpx::collectives {
354368

355369
template <typename T, typename F>
356370
hpx::future<std::decay_t<T>> all_reduce(communicator fid, T&& local_result,
357-
F&& op, generation_arg generation,
358-
this_site_arg this_site = this_site_arg())
371+
F&& op, generation_arg const generation,
372+
this_site_arg const this_site = this_site_arg())
359373
{
360374
return all_reduce(HPX_MOVE(fid), HPX_FORWARD(T, local_result),
361375
HPX_FORWARD(F, op), this_site, generation);
362376
}
363377

364378
template <typename T, typename F>
365379
hpx::future<std::decay_t<T>> all_reduce(char const* basename,
366-
T&& local_result, F&& op, num_sites_arg num_sites = num_sites_arg(),
367-
this_site_arg this_site = this_site_arg(),
368-
generation_arg generation = generation_arg(),
369-
root_site_arg root_site = root_site_arg())
380+
T&& local_result, F&& op,
381+
num_sites_arg const num_sites = num_sites_arg(),
382+
this_site_arg const this_site = this_site_arg(),
383+
generation_arg const generation = generation_arg(),
384+
root_site_arg const root_site = root_site_arg())
370385
{
371386
return all_reduce(create_communicator(basename, num_sites, this_site,
372387
generation, root_site),
@@ -376,8 +391,9 @@ namespace hpx::collectives {
376391
////////////////////////////////////////////////////////////////////////////
377392
template <typename T, typename F>
378393
decltype(auto) all_reduce(hpx::launch::sync_policy, communicator fid,
379-
T&& local_result, F&& op, this_site_arg this_site = this_site_arg(),
380-
generation_arg generation = generation_arg())
394+
T&& local_result, F&& op,
395+
this_site_arg const this_site = this_site_arg(),
396+
generation_arg const generation = generation_arg())
381397
{
382398
return all_reduce(HPX_MOVE(fid), HPX_FORWARD(T, local_result),
383399
HPX_FORWARD(F, op), this_site, generation)
@@ -386,8 +402,8 @@ namespace hpx::collectives {
386402

387403
template <typename T, typename F>
388404
decltype(auto) all_reduce(hpx::launch::sync_policy, communicator fid,
389-
T&& local_result, F&& op, generation_arg generation,
390-
this_site_arg this_site = this_site_arg())
405+
T&& local_result, F&& op, generation_arg const generation,
406+
this_site_arg const this_site = this_site_arg())
391407
{
392408
return all_reduce(HPX_MOVE(fid), HPX_FORWARD(T, local_result),
393409
HPX_FORWARD(F, op), this_site, generation)
@@ -396,10 +412,11 @@ namespace hpx::collectives {
396412

397413
template <typename T, typename F>
398414
decltype(auto) all_reduce(hpx::launch::sync_policy, char const* basename,
399-
T&& local_result, F&& op, num_sites_arg num_sites = num_sites_arg(),
415+
T&& local_result, F&& op,
416+
num_sites_arg const num_sites = num_sites_arg(),
400417
this_site_arg this_site = this_site_arg(),
401-
generation_arg generation = generation_arg(),
402-
root_site_arg root_site = root_site_arg())
418+
generation_arg const generation = generation_arg(),
419+
root_site_arg const root_site = root_site_arg())
403420
{
404421
return all_reduce(create_communicator(basename, num_sites, this_site,
405422
generation, root_site),

libs/full/collectives/include/hpx/collectives/all_to_all.hpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,9 @@ namespace hpx::collectives {
279279
hpx::future<std::vector<T>> all_to_all(communicator fid,
280280
std::vector<T>&& local_result,
281281
this_site_arg this_site = this_site_arg(),
282-
generation_arg generation = generation_arg())
282+
generation_arg const generation = generation_arg())
283283
{
284-
if (this_site == static_cast<std::size_t>(-1))
284+
if (this_site.is_default())
285285
{
286286
this_site = agas::get_locality_id();
287287
}
@@ -293,6 +293,20 @@ namespace hpx::collectives {
293293
"the generation number shouldn't be zero"));
294294
}
295295

296+
// Handle operation right away if there is only one value.
297+
if (auto [num_sites, comm_site] = fid.get_info(); num_sites == 1)
298+
{
299+
if (this_site != comm_site)
300+
{
301+
return hpx::make_exceptional_future<std::vector<T>>(
302+
HPX_GET_EXCEPTION(hpx::error::bad_parameter,
303+
"hpx::collectives::all_to_all",
304+
"the local site should be zero if only one site is "
305+
"involved"));
306+
}
307+
return hpx::make_ready_future(HPX_MOVE(local_result));
308+
}
309+
296310
auto all_to_all_data =
297311
[local_result = HPX_MOVE(local_result), this_site, generation](
298312
communicator&& c) mutable -> hpx::future<std::vector<T>> {
@@ -320,8 +334,8 @@ namespace hpx::collectives {
320334

321335
template <typename T>
322336
hpx::future<std::vector<T>> all_to_all(communicator fid,
323-
std::vector<T>&& local_result, generation_arg generation,
324-
this_site_arg this_site = this_site_arg())
337+
std::vector<T>&& local_result, generation_arg const generation,
338+
this_site_arg const this_site = this_site_arg())
325339
{
326340
return all_to_all(
327341
HPX_MOVE(fid), HPX_MOVE(local_result), this_site, generation);
@@ -330,10 +344,10 @@ namespace hpx::collectives {
330344
template <typename T>
331345
hpx::future<std::vector<T>> all_to_all(char const* basename,
332346
std::vector<T>&& local_result,
333-
num_sites_arg num_sites = num_sites_arg(),
334-
this_site_arg this_site = this_site_arg(),
335-
generation_arg generation = generation_arg(),
336-
root_site_arg root_site = root_site_arg())
347+
num_sites_arg const num_sites = num_sites_arg(),
348+
this_site_arg const this_site = this_site_arg(),
349+
generation_arg const generation = generation_arg(),
350+
root_site_arg const root_site = root_site_arg())
337351
{
338352
return all_to_all(create_communicator(basename, num_sites, this_site,
339353
generation, root_site),
@@ -344,8 +358,8 @@ namespace hpx::collectives {
344358
template <typename T>
345359
std::vector<T> all_to_all(hpx::launch::sync_policy, communicator fid,
346360
std::vector<T>&& local_result,
347-
this_site_arg this_site = this_site_arg(),
348-
generation_arg generation = generation_arg())
361+
this_site_arg const this_site = this_site_arg(),
362+
generation_arg const generation = generation_arg())
349363
{
350364
return all_to_all(
351365
HPX_MOVE(fid), HPX_MOVE(local_result), this_site, generation)
@@ -354,8 +368,8 @@ namespace hpx::collectives {
354368

355369
template <typename T>
356370
std::vector<T> all_to_all(hpx::launch::sync_policy, communicator fid,
357-
std::vector<T>&& local_result, generation_arg generation,
358-
this_site_arg this_site = this_site_arg())
371+
std::vector<T>&& local_result, generation_arg const generation,
372+
this_site_arg const this_site = this_site_arg())
359373
{
360374
return all_to_all(
361375
HPX_MOVE(fid), HPX_MOVE(local_result), this_site, generation)
@@ -365,10 +379,10 @@ namespace hpx::collectives {
365379
template <typename T>
366380
std::vector<T> all_to_all(hpx::launch::sync_policy, char const* basename,
367381
std::vector<T>&& local_result,
368-
num_sites_arg num_sites = num_sites_arg(),
369-
this_site_arg this_site = this_site_arg(),
370-
generation_arg generation = generation_arg(),
371-
root_site_arg root_site = root_site_arg())
382+
num_sites_arg const num_sites = num_sites_arg(),
383+
this_site_arg const this_site = this_site_arg(),
384+
generation_arg const generation = generation_arg(),
385+
root_site_arg const root_site = root_site_arg())
372386
{
373387
return all_to_all(create_communicator(basename, num_sites, this_site,
374388
generation, root_site),

libs/full/collectives/include/hpx/collectives/argument_types.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ namespace hpx::collectives {
2222
struct argument_type
2323
{
2424
explicit constexpr argument_type(
25-
std::size_t argument = Default) noexcept
25+
std::size_t const argument = Default) noexcept
2626
: argument_(argument)
2727
{
2828
}
2929

30-
constexpr argument_type& operator=(std::size_t argument) noexcept
30+
constexpr argument_type& operator=(
31+
std::size_t const argument) noexcept
3132
{
3233
argument_ = argument;
3334
return *this;
@@ -38,6 +39,11 @@ namespace hpx::collectives {
3839
return argument_;
3940
}
4041

42+
[[nodiscard]] constexpr bool is_default() const noexcept
43+
{
44+
return argument_ == Default;
45+
}
46+
4147
std::size_t argument_;
4248
};
4349

0 commit comments

Comments
 (0)