Skip to content
27 changes: 23 additions & 4 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2097,15 +2097,34 @@ def _retrieve_from_online_store_v2(
entity_key_dict[key] = []
entity_key_dict[key].append(python_value)

table_entity_values, idxs, output_len = utils._get_unique_entities_from_values(
entity_key_dict,
)

features_to_request: List[str] = []
if requested_features:
features_to_request = requested_features + ["distance"]
# Add text_rank for text search queries
if query_string is not None:
features_to_request.append("text_rank")
else:
features_to_request = ["distance"]
# Add text_rank for text search queries
if query_string is not None:
features_to_request.append("text_rank")

if not datevals:
online_features_response = GetOnlineFeaturesResponse(results=[])
for feature in features_to_request:
field = online_features_response.results.add()
field.values.extend([])
field.statuses.extend([])
field.event_timestamps.extend([])
online_features_response.metadata.feature_names.val.extend(
features_to_request
)
return OnlineResponse(online_features_response)

table_entity_values, idxs, output_len = utils._get_unique_entities_from_values(
entity_key_dict,
)

feature_data = utils._convert_rows_to_protobuf(
requested_features=features_to_request,
read_rows=list(zip(datevals, list_of_feature_dicts)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,20 @@ def online_write_batch(

for feature_name, val in values.items():
vector_val = None
value_text = None

# Check if the feature type is STRING
if val.WhichOneof("val") == "string_val":
value_text = val.string_val

if config.online_store.vector_enabled:
vector_val = get_list_val_str(val)
insert_values.append(
(
entity_key_bin,
feature_name,
val.SerializeToString(),
value_text,
vector_val,
timestamp,
created_ts,
Expand All @@ -136,11 +143,12 @@ def online_write_batch(
sql_query = sql.SQL(
"""
INSERT INTO {}
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
VALUES (%s, %s, %s, %s, %s, %s)
(entity_key, feature_name, value, value_text, vector_value, event_ts, created_ts)
VALUES (%s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (entity_key, feature_name) DO
UPDATE SET
value = EXCLUDED.value,
value_text = EXCLUDED.value_text,
vector_value = EXCLUDED.vector_value,
event_ts = EXCLUDED.event_ts,
created_ts = EXCLUDED.created_ts;
Expand Down Expand Up @@ -308,6 +316,11 @@ def update(
else:
# keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
vector_value_type = "BYTEA"

has_string_features = any(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe there's a more explicit way to handle this? Feels like this could be cleaner.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do u think this will be a better version

has_string_features = any(
                    f.dtype.to_value_type() == ValueType.STRING 
                    for f in table.features
                )

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes!

f.dtype.to_value_type().value == 2 for f in table.features
) # 2 is STRING in ValueType

cur.execute(
sql.SQL(
"""
Expand All @@ -316,6 +329,7 @@ def update(
entity_key BYTEA,
feature_name TEXT,
value BYTEA,
value_text TEXT NULL, -- Added for FTS
vector_value {} NULL,
event_ts TIMESTAMPTZ,
created_ts TIMESTAMPTZ,
Expand All @@ -331,6 +345,16 @@ def update(
)
)

if has_string_features:
cur.execute(
sql.SQL(
"""CREATE INDEX IF NOT EXISTS {} ON {} USING GIN (to_tsvector('english', value_text));"""
).format(
sql.Identifier(f"{table_name}_fts_idx"),
sql.Identifier(table_name),
)
)

conn.commit()

def teardown(
Expand Down Expand Up @@ -456,6 +480,267 @@ def retrieve_online_documents(

return result

def retrieve_online_documents_v2(
self,
config: RepoConfig,
table: FeatureView,
requested_features: List[str],
embedding: Optional[List[float]],
top_k: int,
distance_metric: Optional[str] = None,
query_string: Optional[str] = None,
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[Dict[str, ValueProto]],
]
]:
"""
Retrieve documents using vector similarity search or keyword search in PostgreSQL.

Args:
config: Feast configuration object
table: FeatureView object as the table to search
requested_features: List of requested features to retrieve
embedding: Query embedding to search for (optional)
top_k: Number of items to return
distance_metric: Distance metric to use (optional)
query_string: The query string to search for using keyword search (optional)

Returns:
List of tuples containing the event timestamp, entity key, and feature values
"""
if not config.online_store.vector_enabled:
raise ValueError("Vector search is not enabled in the online store config")

if embedding is None and query_string is None:
raise ValueError("Either embedding or query_string must be provided")

distance_metric = distance_metric or "L2"

if distance_metric not in SUPPORTED_DISTANCE_METRICS_DICT:
raise ValueError(
f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}"
)

distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]

string_fields = []
for feature in table.features:
if (
feature.dtype.to_value_type().value == 2
and feature.name in requested_features
): # 2 is STRING
string_fields.append(feature.name)

table_name = _table_id(config.project, table)

with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
# Case 1: Hybrid Search (vector + text)
if embedding is not None and query_string is not None and string_fields:
tsquery_str = " & ".join(query_string.split())

query = sql.SQL(
"""
SELECT
entity_key,
feature_name,
value,
vector_value,
vector_value {distance_metric_sql} %s::vector as distance,
ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank,
event_ts,
created_ts
FROM {table_name}
WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s)
ORDER BY distance
LIMIT {top_k}
"""
).format(
distance_metric_sql=sql.SQL(distance_metric_sql),
table_name=sql.Identifier(table_name),
top_k=sql.Literal(top_k),
)

cur.execute(query, (embedding, tsquery_str, string_fields, tsquery_str))
rows = cur.fetchall()

# Case 2: Vector Search Only
elif embedding is not None:
query = sql.SQL(
"""
SELECT
entity_key,
feature_name,
value,
vector_value,
vector_value {distance_metric_sql} %s::vector as distance,
NULL as text_rank, -- Keep consistent columns
event_ts,
created_ts
FROM {table_name}
ORDER BY distance
LIMIT {top_k}
"""
).format(
distance_metric_sql=sql.SQL(distance_metric_sql),
table_name=sql.Identifier(table_name),
top_k=sql.Literal(top_k),
)

cur.execute(query, (embedding,))
rows = cur.fetchall()

# Case 3: Text Search Only
elif query_string is not None and string_fields:
tsquery_str = " & ".join(query_string.split())
query = sql.SQL(
"""
WITH text_matches AS (
SELECT DISTINCT entity_key, ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank
FROM {table_name}
WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s)
ORDER BY text_rank DESC
LIMIT {top_k}
)
SELECT
t1.entity_key,
t1.feature_name,
t1.value,
t1.vector_value,
NULL as distance,
t2.text_rank,
t1.event_ts,
t1.created_ts
FROM {table_name} t1
INNER JOIN text_matches t2 ON t1.entity_key = t2.entity_key
WHERE t1.feature_name = ANY(%s)
ORDER BY t2.text_rank DESC
"""
).format(
table_name=sql.Identifier(table_name),
top_k=sql.Literal(top_k),
)

cur.execute(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cur.execute and cur.fetchall() is repeated in all conditions.

if hybrid search, params = [embedding, tsquery_str, string_fields, tsquery_str]
if vector search, params = [embedding],
.....

cur.execute(query, params)
rows = cur.fetchall()

query, (tsquery_str, string_fields, tsquery_str, requested_features)
)
rows = cur.fetchall()

else:
raise ValueError(
"Either vector_enabled must be True for embedding search or string fields must be available for query_string search"
)

# Group by entity_key to build feature records
entities_dict: Dict[str, Dict[str, Any]] = defaultdict(
lambda: {
"features": {},
"timestamp": None,
"entity_key_proto": None,
"vector_distance": float("inf"),
"text_rank": 0.0,
}
)

for (
entity_key_bytes,
feature_name,
feature_val_bytes,
vector_val,
distance,
text_rank,
event_ts,
created_ts,
) in rows:
entity_key_proto = None
if entity_key_bytes:
from feast.infra.key_encoding_utils import deserialize_entity_key

entity_key_proto = deserialize_entity_key(entity_key_bytes)

key = entity_key_bytes.hex() if entity_key_bytes else None

if key is None:
continue

entities_dict[key]["entity_key_proto"] = entity_key_proto

if (
entities_dict[key]["timestamp"] is None
or event_ts > entities_dict[key]["timestamp"]
):
entities_dict[key]["timestamp"] = event_ts

val = ValueProto()
if feature_val_bytes:
val.ParseFromString(feature_val_bytes)

entities_dict[key]["features"][feature_name] = val

if distance is not None:
entities_dict[key]["vector_distance"] = min(
entities_dict[key]["vector_distance"], float(distance)
)
if text_rank is not None:
entities_dict[key]["text_rank"] = max(
entities_dict[key]["text_rank"], float(text_rank)
)

if embedding is not None and query_string is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be simplified ?

def sort_key(item: Dict[str, Any]) -> float:
            return item["vector_distance"] if embedding else item["text_rank"]


def sort_key(x):
return x["vector_distance"]
elif embedding is not None:

def sort_key(x):
return x["vector_distance"]
else: # Text only

def sort_key(x):
return x["text_rank"]

sorted_entities = sorted(
entities_dict.values(), key=sort_key, reverse=(embedding is None)
)[:top_k]

result: List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[Dict[str, ValueProto]],
]
] = []
for entity_data in sorted_entities:
features = (
entity_data["features"].copy()
if isinstance(entity_data["features"], dict)
else None
)

if features is not None:
if "vector_distance" in entity_data and entity_data[
"vector_distance"
] != float("inf"):
dist_val = ValueProto()
dist_val.double_val = entity_data["vector_distance"]
features["distance"] = dist_val

if embedding is None or query_string is not None:
rank_val = ValueProto()
rank_val.double_val = entity_data["text_rank"]
features["text_rank"] = rank_val

result.append(
(
entity_data["timestamp"],
entity_data["entity_key_proto"],
features,
)
)
return result


def _table_id(project: str, table: FeatureView) -> str:
return f"{project}_{table.name}"
Expand Down
Loading
Loading