Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 77 additions & 46 deletions src/indexes/text/term.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@ TermIterator::TermIterator(std::vector<Postings::KeyIterator>&& key_iterators,
const InternedStringSet* untracked_keys)
: query_field_mask_(query_field_mask),
key_iterators_(std::move(key_iterators)),
pos_iterators_(),
current_key_(),
current_position_(std::nullopt),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: No need to initialize with std::nullopt, it's the default value.

current_field_mask_(0ULL),
untracked_keys_(untracked_keys) {
untracked_keys_(untracked_keys),
key_heap_(),
pos_heap_(),
current_key_indices_(),
current_pos_indices_() {
Comment on lines +21 to +25
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: no need to explicitly call the default constructor

// Prime the first key and position if they exist.
if (!key_iterators_.empty()) {
TermIterator::NextKey();
Expand All @@ -39,40 +45,57 @@ const InternedStringPtr& TermIterator::CurrentKey() const {
return current_key_;
}

// Pushes the key iterator at index `idx` onto key_heap_, keyed by the first
// key (at or after its current one) that ContainsFields(query_field_mask_).
// Keys that do not match are skipped; if the iterator becomes invalid before
// a match is found, nothing is pushed.
void TermIterator::InsertValidKeyIterator(size_t idx) {
  auto& iter = key_iterators_[idx];
  for (; iter.IsValid(); iter.NextKey()) {
    if (!iter.ContainsFields(query_field_mask_)) {
      continue;  // Key has no entry in the queried fields; try the next one.
    }
    key_heap_.emplace(iter.GetKey(), idx);
    return;
  }
}

// Selects the smallest key among the key iterators whose current entry
// matches query_field_mask_, gathers the position iterators for that key into
// pos_iterators_, and primes the first position through NextPosition().
// Returns false, with current_key_, current_position_ and current_field_mask_
// reset, when no valid key remains; returns true otherwise.
//
// NOTE(review): this span is a rendered diff. Lines of the previous
// linear-scan implementation and of the new heap-based implementation appear
// interleaved without +/- markers, so it does not read as one function body.
bool TermIterator::FindMinimumValidKey() {
current_key_ = nullptr;
current_position_ = std::nullopt;
current_field_mask_ = 0ULL;
for (auto& key_iter : key_iterators_) {
while (key_iter.IsValid() && !key_iter.ContainsFields(query_field_mask_)) {
key_iter.NextKey();
}
if (key_iter.IsValid()) {
const auto& key = key_iter.GetKey();
if (!current_key_ || key < current_key_) {
pos_iterators_.clear();
pos_iterators_.emplace_back(key_iter.GetPositionIterator());
current_key_ = key;
} else if (key == current_key_) {
pos_iterators_.emplace_back(key_iter.GetPositionIterator());
}
// Build heap only if empty
if (key_heap_.empty()) {
for (size_t i = 0; i < key_iterators_.size(); ++i) {
InsertValidKeyIterator(i);
}
}
if (!current_key_) {
// Heap still empty after the (re)build: no iterator has a valid key left.
if (key_heap_.empty()) {
current_key_ = nullptr;
current_position_ = std::nullopt;
current_field_mask_ = 0ULL;
return false;
}
// No need to check since we know that at least one position exists based on
// ContainsFields.
// Get minimum key
current_key_ = key_heap_.top().first;
pos_iterators_.clear();
current_key_indices_.clear();
// Collect all iterators with minimum key
// Each popped index is remembered in current_key_indices_ so NextKey() can
// advance exactly those iterators and push them back onto the heap.
while (!key_heap_.empty() && key_heap_.top().first == current_key_) {
size_t idx = key_heap_.top().second;
key_heap_.pop();
current_key_indices_.push_back(idx);
pos_iterators_.emplace_back(key_iterators_[idx].GetPositionIterator());
}
// Clear position state for new key
// With current_position_ empty, NextPosition() takes its "initialize heap"
// branch and rebuilds pos_heap_ from the fresh pos_iterators_.
pos_heap_ = {};
current_position_ = std::nullopt;
TermIterator::NextPosition();
return true;
}

bool TermIterator::NextKey() {
if (current_key_) {
for (auto& key_iter : key_iterators_) {
if (key_iter.IsValid() && key_iter.GetKey() == current_key_) {
key_iter.NextKey();
}
// First advance all iterators at current key
for (size_t idx : current_key_indices_) {
key_iterators_[idx].NextKey();
}
// Then insert them back if still valid
for (size_t idx : current_key_indices_) {
InsertValidKeyIterator(idx);
}
}
return FindMinimumValidKey();
Expand Down Expand Up @@ -102,39 +125,47 @@ const PositionRange& TermIterator::CurrentPosition() const {
return current_position_.value();
}

// Pushes the position iterator at index `idx` onto pos_heap_, keyed by the
// first position (at or after its current one) whose field mask intersects
// query_field_mask_. Positions in other fields are skipped; if the iterator
// becomes invalid before a match is found, nothing is pushed.
void TermIterator::InsertValidPositionIterator(size_t idx) {
  auto& iter = pos_iterators_[idx];
  for (; iter.IsValid(); iter.NextPosition()) {
    if (!(iter.GetFieldMask() & query_field_mask_)) {
      continue;  // Position is not in a queried field; try the next one.
    }
    pos_heap_.emplace(iter.GetPosition(), idx);
    return;
  }
}

// Moves to the next position for the current key. If a current position is
// set, the iterators sitting on it are advanced first; then the smallest
// remaining position whose field mask intersects query_field_mask_ is chosen.
// On success current_position_ becomes the single-point range
// {min_position, min_position}, current_field_mask_ is taken from an iterator
// at that position, and true is returned. When no position remains,
// current_position_ and current_field_mask_ are cleared and false is returned.
//
// NOTE(review): this span is a rendered diff. Lines of the previous
// linear-scan implementation and of the new heap-based implementation appear
// interleaved without +/- markers, so it does not read as one function body.
bool TermIterator::NextPosition() {
if (current_position_.has_value()) {
for (auto& pos_iter : pos_iterators_) {
if (pos_iter.IsValid() &&
pos_iter.GetPosition() == current_position_.value().start) {
pos_iter.NextPosition();
}
// Advance all iterators at current position
for (size_t idx : current_pos_indices_) {
pos_iterators_[idx].NextPosition();
}
}
uint32_t min_position = UINT32_MAX;
bool found = false;
FieldMaskPredicate field;
for (auto& pos_iter : pos_iterators_) {
while (pos_iter.IsValid() &&
!(pos_iter.GetFieldMask() & query_field_mask_)) {
pos_iter.NextPosition();
// Then insert them back if still valid
for (size_t idx : current_pos_indices_) {
InsertValidPositionIterator(idx);
}
if (pos_iter.IsValid()) {
uint32_t position = pos_iter.GetPosition();
if (position < min_position) {
min_position = position;
field = pos_iter.GetFieldMask();
found = true;
}
} else {
// Initialize heap (new key)
for (size_t i = 0; i < pos_iterators_.size(); ++i) {
InsertValidPositionIterator(i);
}
}
if (!found) {
// Empty heap: every position iterator for this key is exhausted.
if (pos_heap_.empty()) {
current_position_ = std::nullopt;
current_field_mask_ = 0ULL;
return false;
}
uint32_t min_position = pos_heap_.top().first;
current_pos_indices_.clear();
// Collect all iterators at minimum position
while (!pos_heap_.empty() && pos_heap_.top().first == min_position) {
current_pos_indices_.push_back(pos_heap_.top().second);
pos_heap_.pop();
}
current_position_ = PositionRange{min_position, min_position};
current_field_mask_ = field;
// NOTE(review): when several iterators share min_position, only the first
// popped iterator's field mask is recorded -- confirm the masks of the
// others never need to be combined.
current_field_mask_ = pos_iterators_[current_pos_indices_[0]].GetFieldMask();
return true;
}

Expand Down
16 changes: 16 additions & 0 deletions src/indexes/text/term.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#ifndef _VALKEY_SEARCH_INDEXES_TEXT_TERM_H_
#define _VALKEY_SEARCH_INDEXES_TEXT_TERM_H_

#include <queue>
#include <vector>

#include "src/indexes/text/text_iterator.h"
Expand Down Expand Up @@ -66,7 +67,22 @@ class TermIterator : public TextIterator {
std::optional<PositionRange> current_position_;
FieldMaskPredicate current_field_mask_;
const InternedStringSet* untracked_keys_;

// Heap for efficient min-finding
// Both queues use std::greater<>, so top() is the smallest pair.
// key_heap_ entries are (key, index into key_iterators_).
std::priority_queue<std::pair<Key, size_t>,
std::vector<std::pair<Key, size_t>>, std::greater<>>
key_heap_;
// pos_heap_ entries are (position, index into pos_iterators_).
std::priority_queue<std::pair<uint32_t, size_t>,
std::vector<std::pair<uint32_t, size_t>>, std::greater<>>
pos_heap_;

// Track which iterators have current key/position
// Indices into key_iterators_ popped from key_heap_ for current_key_.
std::vector<size_t> current_key_indices_;
// Indices into pos_iterators_ popped from pos_heap_ for the current position.
std::vector<size_t> current_pos_indices_;

// Picks the minimum valid key, collects its position iterators and primes the
// first position; returns false when no key remains.
bool FindMinimumValidKey();
// Skips key_iterators_[idx] past keys that do not match the query field mask
// and, if it is still valid, pushes it onto key_heap_.
void InsertValidKeyIterator(size_t idx);
// Skips pos_iterators_[idx] past positions outside the query field mask and,
// if it is still valid, pushes it onto pos_heap_.
void InsertValidPositionIterator(size_t idx);
};

} // namespace valkey_search::indexes::text
Expand Down
Loading