Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 34 additions & 51 deletions cpp/src/arrow/filesystem/s3fs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@
#include "arrow/util/future.h"
#include "arrow/util/logging.h"
#include "arrow/util/optional.h"
#include "arrow/util/task_group.h"
#include "arrow/util/thread_pool.h"
#include "arrow/util/windows_fixup.h"

namespace arrow {

using internal::TaskGroup;
using internal::Uri;

namespace fs {
Expand Down Expand Up @@ -1142,82 +1144,65 @@ struct TreeWalker : public std::enable_shared_from_this<TreeWalker> {
recursion_handler_(std::move(recursion_handler)) {}

private:
std::shared_ptr<TaskGroup> task_group_;
std::mutex mutex_;
Future<> future_;
std::atomic<int32_t> num_in_flight_;

Status DoWalk() {
future_ = decltype(future_)::Make();
num_in_flight_ = 0;
task_group_ =
TaskGroup::MakeThreaded(io_context_.executor(), io_context_.stop_token());
WalkChild(base_dir_, /*nesting_depth=*/0);
// When this returns, ListObjectsV2 tasks either have finished or will exit early
return future_.status();
return task_group_->Finish();
}

bool is_finished() const { return future_.is_finished(); }

void ListObjectsFinished(Status st) {
const auto in_flight = --num_in_flight_;
if (!st.ok() || !in_flight) {
future_.MarkFinished(std::move(st));
}
}
bool ok() const { return task_group_->ok(); }

struct ListObjectsV2Handler {
std::shared_ptr<TreeWalker> walker;
std::string prefix;
int32_t nesting_depth;
S3Model::ListObjectsV2Request req;

void operator()(const Result<S3Model::ListObjectsV2Outcome>& result) {
Status operator()(const Result<S3Model::ListObjectsV2Outcome>& result) {
// Serialize calls to operation-specific handlers
std::unique_lock<std::mutex> guard(walker->mutex_);
if (walker->is_finished()) {
if (!walker->ok()) {
// Early exit: avoid executing handlers if DoWalk() returned
return;
return Status::OK();
}
if (!result.ok()) {
HandleError(result.status());
return;
return result.status();
}
const auto& outcome = *result;
if (!outcome.IsSuccess()) {
Status st = walker->error_handler_(outcome.GetError());
HandleError(std::move(st));
return;
{
std::lock_guard<std::mutex> guard(walker->mutex_);
return walker->error_handler_(outcome.GetError());
}
}
HandleResult(outcome.GetResult());
return HandleResult(outcome.GetResult());
}

void SpawnListObjectsV2() {
auto walker = this->walker;
auto req = this->req;
auto maybe_fut = walker->io_context_.executor()->Submit(
walker->io_context_.stop_token(),
[walker, req]() { return walker->client_->ListObjectsV2(req); });
if (!maybe_fut.ok()) {
HandleError(maybe_fut.status());
return;
}
maybe_fut->AddCallback(*this);
auto cb = *this;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're capturing *this by value then you don't need to capture req and walker as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the extraneous captures.

walker->task_group_->Append([cb]() mutable {
Result<S3Model::ListObjectsV2Outcome> result =
cb.walker->client_->ListObjectsV2(cb.req);
return cb(result);
});
}

void HandleError(Status status) { walker->ListObjectsFinished(std::move(status)); }

void HandleResult(const S3Model::ListObjectsV2Result& result) {
bool recurse = result.GetCommonPrefixes().size() > 0;
if (recurse) {
auto maybe_recurse = walker->recursion_handler_(nesting_depth + 1);
if (!maybe_recurse.ok()) {
walker->ListObjectsFinished(maybe_recurse.status());
return;
Status HandleResult(const S3Model::ListObjectsV2Result& result) {
bool recurse;
{
// Only one thread should be running result_handler_/recursion_handler_ at a time
std::lock_guard<std::mutex> guard(walker->mutex_);
recurse = result.GetCommonPrefixes().size() > 0;
if (recurse) {
ARROW_ASSIGN_OR_RAISE(auto maybe_recurse,
walker->recursion_handler_(nesting_depth + 1));
recurse &= maybe_recurse;
}
recurse &= *maybe_recurse;
}
Status st = walker->result_handler_(prefix, result);
if (!st.ok()) {
walker->ListObjectsFinished(std::move(st));
return;
RETURN_NOT_OK(walker->result_handler_(prefix, result));
}
if (recurse) {
walker->WalkChildren(result, nesting_depth + 1);
Expand All @@ -1228,9 +1213,8 @@ struct TreeWalker : public std::enable_shared_from_this<TreeWalker> {
DCHECK(!result.GetNextContinuationToken().empty());
req.SetContinuationToken(result.GetNextContinuationToken());
SpawnListObjectsV2();
} else {
walker->ListObjectsFinished(Status::OK());
}
return Status::OK();
}

void Start() {
Expand All @@ -1246,7 +1230,6 @@ struct TreeWalker : public std::enable_shared_from_this<TreeWalker> {

void WalkChild(std::string key, int32_t nesting_depth) {
ListObjectsV2Handler handler{shared_from_this(), std::move(key), nesting_depth, {}};
++num_in_flight_;
handler.Start();
}

Expand Down