Skip to content
7 changes: 4 additions & 3 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
Generator,
Generic,
Iterable,
Optional,
NoReturn,
Sequence,
Tuple,
Union,
)
Expand Down Expand Up @@ -555,7 +556,7 @@ def avg(self, field_ref: str | FieldPath, alias=None):
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
*,
Expand All @@ -568,7 +569,7 @@ def find_nearest(
Args:
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
query_vector(Union[Vector, Sequence[float]]): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
Expand Down Expand Up @@ -1000,7 +1001,7 @@ def _to_protobuf(self) -> StructuredQuery:
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
*,
Expand Down
9 changes: 6 additions & 3 deletions google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import abc
from abc import ABC
from enum import Enum
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union

from google.api_core import gapic_v1
from google.api_core import retry as retries
Expand Down Expand Up @@ -137,16 +137,19 @@ def get(
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[float] = None,
):
"""Finds the closest vector embeddings to the given query vector."""
if not isinstance(query_vector, Vector):
self._query_vector = Vector(query_vector)
else:
self._query_vector = query_vector
self._vector_field = vector_field
self._query_vector = query_vector
self._limit = limit
self._distance_measure = distance_measure
self._distance_result_field = distance_result_field
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type
from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Sequence, Type

from google.api_core import exceptions, gapic_v1
from google.api_core import retry as retries
Expand Down Expand Up @@ -269,7 +269,7 @@ def _retry_query_after_exception(self, exc, retry, transaction):
def find_nearest(
self,
vector_field: str,
query_vector: Vector,
query_vector: Union[Vector, Sequence[float]],
limit: int,
distance_measure: DistanceMeasure,
*,
Expand All @@ -282,7 +282,7 @@ def find_nearest(
Args:
vector_field (str): An indexed vector field to search upon. Only documents which contain
vectors whose dimensionality match the query_vector can be returned.
query_vector (Vector): The query vector that we are searching on. Must be a vector of no more
query_vector(Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more
than 2048 dimensions.
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/v1/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from google.cloud.firestore_v1.vector import Vector


def _make_commit_repsonse():
def _make_commit_response():
response = mock.create_autospec(firestore.CommitResponse)
response.write_results = [mock.sentinel.write_result]
response.commit_time = mock.sentinel.commit_time
Expand All @@ -35,7 +35,7 @@ def _make_commit_repsonse():
def _make_firestore_api():
firestore_api = mock.Mock()
firestore_api.commit.mock_add_spec(spec=["commit"])
firestore_api.commit.return_value = _make_commit_repsonse()
firestore_api.commit.return_value = _make_commit_response()
return firestore_api


Expand Down
62 changes: 62 additions & 0 deletions tests/unit/v1/test_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,68 @@ def test_vector_query_collection_group(distance_measure, expected_distance):
)


def test_vector_query_list_as_query_vector():
# Create a minimal fake GAPIC.
firestore_api = mock.Mock(spec=["run_query"])
client = make_client()
client._firestore_api_internal = firestore_api

# Make a **real** collection reference as parent.
parent = client.collection("dee")
query = make_query(parent)
parent_path, expected_prefix = parent._parent_info()

data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
response_pb1 = _make_query_response(
name="{}/test_doc".format(expected_prefix), data=data
)
response_pb2 = _make_query_response(
name="{}/test_doc".format(expected_prefix), data=data
)

kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)

# Execute the vector query and check the response.
firestore_api.run_query.return_value = iter([response_pb1, response_pb2])

vector_query = query.where("snooze", "==", 10).find_nearest(
vector_field="embedding",
query_vector=[1.0, 2.0, 3.0],
distance_measure=DistanceMeasure.EUCLIDEAN,
limit=5,
)

returned = vector_query.get(transaction=_transaction(client), **kwargs)
assert isinstance(returned, list)
assert len(returned) == 2
assert returned[0].to_dict() == data

expected_pb = _expected_pb(
parent=parent,
vector_field="embedding",
vector=Vector([1.0, 2.0, 3.0]),
distance_type=StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
limit=5,
)
expected_pb.where = StructuredQuery.Filter(
field_filter=StructuredQuery.FieldFilter(
field=StructuredQuery.FieldReference(field_path="snooze"),
op=StructuredQuery.FieldFilter.Operator.EQUAL,
value=encode_value(10),
)
)

firestore_api.run_query.assert_called_once_with(
request={
"parent": parent_path,
"structured_query": expected_pb,
"transaction": _TXN_ID,
},
metadata=client._rpc_metadata,
**kwargs,
)


def test_query_stream_multiple_empty_response_in_stream():
from google.cloud.firestore_v1 import stream_generator

Expand Down
Loading