-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: Add retrieve online documents v2 method into pgvector #5253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
a03b8c4
75233cf
776c327
e1f0cae
246d6a6
6e3413c
0169cb6
d35e0ba
86d40b0
55dec54
15dcabf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -119,13 +119,20 @@ def online_write_batch( | |
|
|
||
| for feature_name, val in values.items(): | ||
| vector_val = None | ||
| value_text = None | ||
|
|
||
| # Check if the feature type is STRING | ||
| if val.WhichOneof("val") == "string_val": | ||
| value_text = val.string_val | ||
|
|
||
| if config.online_store.vector_enabled: | ||
| vector_val = get_list_val_str(val) | ||
| insert_values.append( | ||
| ( | ||
| entity_key_bin, | ||
| feature_name, | ||
| val.SerializeToString(), | ||
| value_text, | ||
| vector_val, | ||
| timestamp, | ||
| created_ts, | ||
|
|
@@ -136,11 +143,12 @@ def online_write_batch( | |
| sql_query = sql.SQL( | ||
| """ | ||
| INSERT INTO {} | ||
| (entity_key, feature_name, value, vector_value, event_ts, created_ts) | ||
| VALUES (%s, %s, %s, %s, %s, %s) | ||
| (entity_key, feature_name, value, value_text, vector_value, event_ts, created_ts) | ||
| VALUES (%s, %s, %s, %s, %s, %s, %s) | ||
| ON CONFLICT (entity_key, feature_name) DO | ||
| UPDATE SET | ||
| value = EXCLUDED.value, | ||
| value_text = EXCLUDED.value_text, | ||
| vector_value = EXCLUDED.vector_value, | ||
| event_ts = EXCLUDED.event_ts, | ||
| created_ts = EXCLUDED.created_ts; | ||
|
|
@@ -308,6 +316,11 @@ def update( | |
| else: | ||
| # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility | ||
| vector_value_type = "BYTEA" | ||
|
|
||
| has_string_features = any( | ||
| f.dtype.to_value_type().value == 2 for f in table.features | ||
| ) # 2 is STRING in ValueType | ||
|
|
||
| cur.execute( | ||
| sql.SQL( | ||
| """ | ||
|
|
@@ -316,6 +329,7 @@ def update( | |
| entity_key BYTEA, | ||
| feature_name TEXT, | ||
| value BYTEA, | ||
| value_text TEXT NULL, -- Added for FTS | ||
| vector_value {} NULL, | ||
| event_ts TIMESTAMPTZ, | ||
| created_ts TIMESTAMPTZ, | ||
|
|
@@ -331,6 +345,16 @@ def update( | |
| ) | ||
| ) | ||
|
|
||
| if has_string_features: | ||
| cur.execute( | ||
| sql.SQL( | ||
| """CREATE INDEX IF NOT EXISTS {} ON {} USING GIN (to_tsvector('english', value_text));""" | ||
| ).format( | ||
| sql.Identifier(f"{table_name}_fts_idx"), | ||
| sql.Identifier(table_name), | ||
| ) | ||
| ) | ||
|
|
||
| conn.commit() | ||
|
|
||
| def teardown( | ||
|
|
@@ -456,6 +480,267 @@ def retrieve_online_documents( | |
|
|
||
| return result | ||
|
|
||
| def retrieve_online_documents_v2( | ||
| self, | ||
| config: RepoConfig, | ||
| table: FeatureView, | ||
| requested_features: List[str], | ||
| embedding: Optional[List[float]], | ||
| top_k: int, | ||
| distance_metric: Optional[str] = None, | ||
| query_string: Optional[str] = None, | ||
| ) -> List[ | ||
| Tuple[ | ||
| Optional[datetime], | ||
| Optional[EntityKeyProto], | ||
| Optional[Dict[str, ValueProto]], | ||
| ] | ||
| ]: | ||
| """ | ||
| Retrieve documents using vector similarity search or keyword search in PostgreSQL. | ||
|
|
||
| Args: | ||
| config: Feast configuration object | ||
| table: FeatureView object as the table to search | ||
| requested_features: List of requested features to retrieve | ||
| embedding: Query embedding to search for (optional) | ||
| top_k: Number of items to return | ||
| distance_metric: Distance metric to use (optional) | ||
| query_string: The query string to search for using keyword search (optional) | ||
|
|
||
| Returns: | ||
| List of tuples containing the event timestamp, entity key, and feature values | ||
| """ | ||
| if not config.online_store.vector_enabled: | ||
| raise ValueError("Vector search is not enabled in the online store config") | ||
|
|
||
| if embedding is None and query_string is None: | ||
| raise ValueError("Either embedding or query_string must be provided") | ||
|
|
||
| distance_metric = distance_metric or "L2" | ||
|
|
||
| if distance_metric not in SUPPORTED_DISTANCE_METRICS_DICT: | ||
| raise ValueError( | ||
| f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}" | ||
| ) | ||
|
|
||
| distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric] | ||
|
|
||
| string_fields = [] | ||
YassinNouh21 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for feature in table.features: | ||
| if ( | ||
| feature.dtype.to_value_type().value == 2 | ||
| and feature.name in requested_features | ||
| ): # 2 is STRING | ||
| string_fields.append(feature.name) | ||
|
|
||
| table_name = _table_id(config.project, table) | ||
|
|
||
| with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur: | ||
| # Case 1: Hybrid Search (vector + text) | ||
| if embedding is not None and query_string is not None and string_fields: | ||
| tsquery_str = " & ".join(query_string.split()) | ||
|
|
||
| query = sql.SQL( | ||
| """ | ||
| SELECT | ||
| entity_key, | ||
| feature_name, | ||
| value, | ||
| vector_value, | ||
| vector_value {distance_metric_sql} %s::vector as distance, | ||
| ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank, | ||
| event_ts, | ||
| created_ts | ||
| FROM {table_name} | ||
| WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s) | ||
| ORDER BY distance | ||
| LIMIT {top_k} | ||
| """ | ||
| ).format( | ||
| distance_metric_sql=sql.SQL(distance_metric_sql), | ||
| table_name=sql.Identifier(table_name), | ||
| top_k=sql.Literal(top_k), | ||
| ) | ||
|
|
||
| cur.execute(query, (embedding, tsquery_str, string_fields, tsquery_str)) | ||
| rows = cur.fetchall() | ||
|
|
||
| # Case 2: Vector Search Only | ||
| elif embedding is not None: | ||
| query = sql.SQL( | ||
| """ | ||
| SELECT | ||
| entity_key, | ||
| feature_name, | ||
| value, | ||
| vector_value, | ||
| vector_value {distance_metric_sql} %s::vector as distance, | ||
| NULL as text_rank, -- Keep consistent columns | ||
| event_ts, | ||
| created_ts | ||
| FROM {table_name} | ||
| ORDER BY distance | ||
| LIMIT {top_k} | ||
| """ | ||
| ).format( | ||
| distance_metric_sql=sql.SQL(distance_metric_sql), | ||
| table_name=sql.Identifier(table_name), | ||
| top_k=sql.Literal(top_k), | ||
| ) | ||
|
|
||
| cur.execute(query, (embedding,)) | ||
| rows = cur.fetchall() | ||
|
|
||
| # Case 3: Text Search Only | ||
| elif query_string is not None and string_fields: | ||
| tsquery_str = " & ".join(query_string.split()) | ||
| query = sql.SQL( | ||
| """ | ||
| WITH text_matches AS ( | ||
| SELECT DISTINCT entity_key, ts_rank(to_tsvector('english', value_text), to_tsquery('english', %s)) as text_rank | ||
| FROM {table_name} | ||
| WHERE feature_name = ANY(%s) AND to_tsvector('english', value_text) @@ to_tsquery('english', %s) | ||
| ORDER BY text_rank DESC | ||
| LIMIT {top_k} | ||
| ) | ||
| SELECT | ||
| t1.entity_key, | ||
| t1.feature_name, | ||
| t1.value, | ||
| t1.vector_value, | ||
| NULL as distance, | ||
| t2.text_rank, | ||
| t1.event_ts, | ||
| t1.created_ts | ||
| FROM {table_name} t1 | ||
| INNER JOIN text_matches t2 ON t1.entity_key = t2.entity_key | ||
| WHERE t1.feature_name = ANY(%s) | ||
| ORDER BY t2.text_rank DESC | ||
| """ | ||
| ).format( | ||
| table_name=sql.Identifier(table_name), | ||
| top_k=sql.Literal(top_k), | ||
| ) | ||
|
|
||
| cur.execute( | ||
|
||
| query, (tsquery_str, string_fields, tsquery_str, requested_features) | ||
| ) | ||
| rows = cur.fetchall() | ||
|
|
||
| else: | ||
| raise ValueError( | ||
| "Either vector_enabled must be True for embedding search or string fields must be available for query_string search" | ||
| ) | ||
|
|
||
| # Group by entity_key to build feature records | ||
| entities_dict: Dict[str, Dict[str, Any]] = defaultdict( | ||
| lambda: { | ||
| "features": {}, | ||
| "timestamp": None, | ||
| "entity_key_proto": None, | ||
| "vector_distance": float("inf"), | ||
| "text_rank": 0.0, | ||
| } | ||
| ) | ||
|
|
||
| for ( | ||
| entity_key_bytes, | ||
| feature_name, | ||
| feature_val_bytes, | ||
| vector_val, | ||
| distance, | ||
| text_rank, | ||
| event_ts, | ||
| created_ts, | ||
| ) in rows: | ||
| entity_key_proto = None | ||
| if entity_key_bytes: | ||
| from feast.infra.key_encoding_utils import deserialize_entity_key | ||
|
|
||
| entity_key_proto = deserialize_entity_key(entity_key_bytes) | ||
|
|
||
| key = entity_key_bytes.hex() if entity_key_bytes else None | ||
|
|
||
| if key is None: | ||
| continue | ||
|
|
||
| entities_dict[key]["entity_key_proto"] = entity_key_proto | ||
|
|
||
| if ( | ||
| entities_dict[key]["timestamp"] is None | ||
| or event_ts > entities_dict[key]["timestamp"] | ||
| ): | ||
| entities_dict[key]["timestamp"] = event_ts | ||
|
|
||
| val = ValueProto() | ||
| if feature_val_bytes: | ||
| val.ParseFromString(feature_val_bytes) | ||
|
|
||
| entities_dict[key]["features"][feature_name] = val | ||
|
|
||
| if distance is not None: | ||
| entities_dict[key]["vector_distance"] = min( | ||
| entities_dict[key]["vector_distance"], float(distance) | ||
| ) | ||
| if text_rank is not None: | ||
| entities_dict[key]["text_rank"] = max( | ||
| entities_dict[key]["text_rank"], float(text_rank) | ||
| ) | ||
|
|
||
| if embedding is not None and query_string is not None: | ||
|
||
|
|
||
| def sort_key(x): | ||
| return x["vector_distance"] | ||
| elif embedding is not None: | ||
|
|
||
| def sort_key(x): | ||
| return x["vector_distance"] | ||
| else: # Text only | ||
|
|
||
| def sort_key(x): | ||
| return x["text_rank"] | ||
|
|
||
| sorted_entities = sorted( | ||
| entities_dict.values(), key=sort_key, reverse=(embedding is None) | ||
| )[:top_k] | ||
|
|
||
| result: List[ | ||
| Tuple[ | ||
| Optional[datetime], | ||
| Optional[EntityKeyProto], | ||
| Optional[Dict[str, ValueProto]], | ||
| ] | ||
| ] = [] | ||
| for entity_data in sorted_entities: | ||
| features = ( | ||
| entity_data["features"].copy() | ||
| if isinstance(entity_data["features"], dict) | ||
| else None | ||
| ) | ||
|
|
||
| if features is not None: | ||
| if "vector_distance" in entity_data and entity_data[ | ||
| "vector_distance" | ||
| ] != float("inf"): | ||
| dist_val = ValueProto() | ||
| dist_val.double_val = entity_data["vector_distance"] | ||
| features["distance"] = dist_val | ||
|
|
||
| if embedding is None or query_string is not None: | ||
| rank_val = ValueProto() | ||
| rank_val.double_val = entity_data["text_rank"] | ||
| features["text_rank"] = rank_val | ||
|
|
||
| result.append( | ||
| ( | ||
| entity_data["timestamp"], | ||
| entity_data["entity_key_proto"], | ||
| features, | ||
| ) | ||
| ) | ||
| return result | ||
|
|
||
|
|
||
| def _table_id(project: str, table: FeatureView) -> str: | ||
| return f"{project}_{table.name}" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe there's a more explicit way to handle this? Feels like this could be cleaner.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do u think this will be a better version
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes!