Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions sdk/python/feast/infra/offline_stores/bigquery_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict, Iterable, Optional, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple

from feast import type_map
from feast.data_source import DataSource
Expand Down Expand Up @@ -119,18 +119,20 @@ def get_table_column_names_and_types(

client = bigquery.Client()
if self.table_ref is not None:
table_schema = client.get_table(self.table_ref).schema
if not isinstance(table_schema[0], bigquery.schema.SchemaField):
schema = client.get_table(self.table_ref).schema
if not isinstance(schema[0], bigquery.schema.SchemaField):
raise TypeError("Could not parse BigQuery table schema.")

name_type_pairs = [(field.name, field.field_type) for field in table_schema]
else:
bq_columns_query = f"SELECT * FROM ({self.query}) LIMIT 1"
queryRes = client.query(bq_columns_query).result()
name_type_pairs = [
(schema_field.name, schema_field.field_type)
for schema_field in queryRes.schema
]
schema = queryRes.schema

name_type_pairs: List[Tuple[str, str]] = []
for field in schema:
bq_type_as_str = field.field_type
if field.mode == "REPEATED":
bq_type_as_str = "ARRAY<" + bq_type_as_str + ">"
name_type_pairs.append((field.name, bq_type_as_str))

return name_type_pairs

Expand Down
86 changes: 56 additions & 30 deletions sdk/python/feast/type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from datetime import datetime
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import numpy as np
Expand Down Expand Up @@ -92,7 +91,7 @@ def feast_value_type_to_pandas_type(value_type: ValueType) -> Any:
ValueType.DOUBLE: "float",
ValueType.BYTES: "bytes",
ValueType.BOOL: "bool",
ValueType.UNIX_TIMESTAMP: "datetime",
ValueType.UNIX_TIMESTAMP: "datetime64[ns]",
}
if value_type.name.endswith("_LIST"):
return "object"
Expand Down Expand Up @@ -121,7 +120,6 @@ def python_type_to_feast_value_type(
Feast Value Type
"""
type_name = (type_name or type(value).__name__).lower()

type_map = {
"int": ValueType.INT64,
"str": ValueType.STRING,
Expand All @@ -141,6 +139,7 @@ def python_type_to_feast_value_type(
"datetime": ValueType.UNIX_TIMESTAMP,
"datetime64[ns]": ValueType.UNIX_TIMESTAMP,
"datetime64[ns, tz]": ValueType.UNIX_TIMESTAMP,
"date": ValueType.UNIX_TIMESTAMP,
"category": ValueType.STRING,
}

Expand Down Expand Up @@ -284,6 +283,21 @@ def _python_value_to_proto_value(feast_value_type: ValueType, value: Any) -> Pro
if value is None:
return ProtoValue()

if feast_value_type == ValueType.UNIX_TIMESTAMP_LIST:
converted_value = []
for sub_value in value:
if isinstance(sub_value, datetime):
converted_value.append(int(sub_value.timestamp()))
elif isinstance(sub_value, date):
converted_value.append(
int(datetime(*sub_value.timetuple()[:6]).timestamp())
)
elif isinstance(sub_value, Timestamp):
converted_value.append(int(sub_value.ToSeconds()))
else:
converted_value.append(sub_value)
value = converted_value

if feast_value_type in PYTHON_LIST_VALUE_TYPE_TO_PROTO_VALUE:
proto_type, field_name, valid_types = PYTHON_LIST_VALUE_TYPE_TO_PROTO_VALUE[
feast_value_type
Expand All @@ -307,6 +321,10 @@ def _python_value_to_proto_value(feast_value_type: ValueType, value: Any) -> Pro
if feast_value_type == ValueType.UNIX_TIMESTAMP:
if isinstance(value, datetime):
return ProtoValue(int64_val=int(value.timestamp()))
elif isinstance(value, date):
return ProtoValue(
int64_val=int(datetime(*value.timetuple()[:6]).timestamp())
)
elif isinstance(value, Timestamp):
return ProtoValue(int64_val=int(value.ToSeconds()))
return ProtoValue(int64_val=int(value))
Expand Down Expand Up @@ -374,31 +392,40 @@ def _proto_value_to_value_type(proto_value: ProtoValue) -> ValueType:


def pa_to_feast_value_type(pa_type_as_str: str) -> ValueType:
    """Convert a PyArrow type string to the corresponding Feast ``ValueType``.

    Handles scalar types (e.g. ``"int64"``, ``"string"``), timestamp types
    with any unit/timezone suffix (e.g. ``"timestamp[us, tz=UTC]"``), and
    list types of either (e.g. ``"list<item: int32>"``), which map to the
    ``*_LIST`` value types.

    Args:
        pa_type_as_str: The string form of a ``pyarrow.DataType``
            (i.e. ``str(pa_type)``).

    Returns:
        The matching ``ValueType`` member.

    Raises:
        KeyError: If the (element) type string is not one of the supported
            scalar types.
    """
    # Unwrap "list<item: X>" to the element type X and remember that the
    # result must be promoted to the corresponding *_LIST value type.
    is_list = False
    if pa_type_as_str.startswith("list<item: "):
        is_list = True
        pa_type_as_str = pa_type_as_str[11:-1]

    # Timestamp strings carry a unit (and optional timezone) suffix, so
    # match on the prefix rather than looking up the exact string.
    if pa_type_as_str.startswith("timestamp"):
        value_type = ValueType.UNIX_TIMESTAMP
    else:
        type_map = {
            "int32": ValueType.INT32,
            "int64": ValueType.INT64,
            "double": ValueType.DOUBLE,
            "float": ValueType.FLOAT,
            "string": ValueType.STRING,
            "binary": ValueType.BYTES,
            "bool": ValueType.BOOL,
            "null": ValueType.NULL,
        }
        value_type = type_map[pa_type_as_str]

    if is_list:
        # ValueType members follow the naming convention "<SCALAR>_LIST".
        value_type = ValueType[value_type.name + "_LIST"]

    return value_type


def bq_to_feast_value_type(bq_type_as_str: str) -> ValueType:
is_list = False
if bq_type_as_str.startswith("ARRAY<"):
is_list = True
bq_type_as_str = bq_type_as_str[6:-1]

type_map: Dict[str, ValueType] = {
"DATE": ValueType.UNIX_TIMESTAMP,
"DATETIME": ValueType.UNIX_TIMESTAMP,
"TIMESTAMP": ValueType.UNIX_TIMESTAMP,
"INTEGER": ValueType.INT64,
Expand All @@ -409,15 +436,14 @@ def bq_to_feast_value_type(bq_type_as_str: str) -> ValueType:
"BYTES": ValueType.BYTES,
"BOOL": ValueType.BOOL,
"BOOLEAN": ValueType.BOOL, # legacy sql data type
"ARRAY<INT64>": ValueType.INT64_LIST,
"ARRAY<FLOAT64>": ValueType.DOUBLE_LIST,
"ARRAY<STRING>": ValueType.STRING_LIST,
"ARRAY<BYTES>": ValueType.BYTES_LIST,
"ARRAY<BOOL>": ValueType.BOOL_LIST,
"NULL": ValueType.NULL,
}

return type_map[bq_type_as_str]
value_type = type_map[bq_type_as_str]
if is_list:
value_type = ValueType[value_type.name + "_LIST"]

return value_type


def redshift_to_feast_value_type(redshift_type_as_str: str) -> ValueType:
Expand Down
16 changes: 15 additions & 1 deletion sdk/python/tests/data/data_creator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import date, datetime, timedelta
from typing import List

import pandas as pd
Expand Down Expand Up @@ -60,6 +60,20 @@ def get_feature_values_for_dtype(
"float": [1.0, None, 3.0, 4.0, 5.0],
"string": ["1", None, "3", "4", "5"],
"bool": [True, None, False, True, False],
"datetime": [
datetime(2020, 1, 2),
None,
datetime(2020, 1, 3),
datetime(2020, 1, 4),
datetime(2020, 1, 5),
],
"date": [
date(2020, 1, 2),
None,
date(2020, 1, 3),
date(2020, 1, 4),
date(2020, 1, 5),
],
}
non_list_val = dtype_map[dtype]
if is_list:
Expand Down
59 changes: 43 additions & 16 deletions sdk/python/tests/integration/registration/test_universal_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from datetime import date, datetime, timedelta
from typing import List

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.types
import pytest

import feast
from feast.infra.offline_stores.offline_store import RetrievalJob
from feast.value_type import ValueType
from tests.data.data_creator import create_dataset
Expand All @@ -25,6 +28,8 @@ def populate_test_configs(offline: bool):
(ValueType.INT64, "int64"),
(ValueType.STRING, "float"),
(ValueType.STRING, "bool"),
(ValueType.INT32, "datetime"),
(ValueType.INT32, "date"),
]
configs: List[TypeTestConfig] = []
for test_repo_config in FULL_REPO_CONFIGS:
Expand Down Expand Up @@ -149,6 +154,20 @@ def test_entity_inference_types_match(offline_types_test_fixtures):
@pytest.mark.universal
def test_feature_get_historical_features_types_match(offline_types_test_fixtures):
environment, config, data_source, fv = offline_types_test_fixtures

# TODO: improve how FileSource handles Arrow schema inference.
if (
config.feature_dtype == "date"
and config.feature_is_list
and config.has_empty_list
and isinstance(data_source, feast.FileSource)
):
pytest.xfail(
"`feast.FileSource` cannot deal with returning all empty "
"`List[date]` features to Arrow as it infers the schema "
"from the Pandas Dataframe which does not have a dtype to represent `date`"
)

fs = environment.feature_store
fv = create_feature_view(
config.feature_dtype, config.feature_is_list, config.has_empty_list, data_source
Expand Down Expand Up @@ -216,6 +235,8 @@ def test_feature_get_online_features_types_match(online_types_test_fixtures):
"float": float,
"string": str,
"bool": bool,
"date": int,
"datetime": int,
}
expected_dtype = feature_list_dtype_to_expected_online_response_value_type[
config.feature_dtype
Expand All @@ -240,6 +261,8 @@ def create_feature_view(feature_dtype, feature_is_list, has_empty_list, data_sou
value_type = ValueType.FLOAT_LIST
elif feature_dtype == "bool":
value_type = ValueType.BOOL_LIST
elif feature_dtype in ("date", "datetime"):
value_type = ValueType.UNIX_TIMESTAMP_LIST
else:
if feature_dtype == "int32":
value_type = ValueType.INT32
Expand All @@ -249,6 +272,8 @@ def create_feature_view(feature_dtype, feature_is_list, has_empty_list, data_sou
value_type = ValueType.FLOAT
elif feature_dtype == "bool":
value_type = ValueType.BOOL
elif feature_dtype in ("date", "datetime"):
value_type = ValueType.UNIX_TIMESTAMP
return driver_feature_view(data_source, value_type=value_type,)


Expand All @@ -262,6 +287,8 @@ def assert_expected_historical_feature_types(
"float": (pd.api.types.is_float_dtype,),
"string": (pd.api.types.is_string_dtype,),
"bool": (pd.api.types.is_bool_dtype, pd.api.types.is_object_dtype),
"date": (pd.api.types.is_object_dtype,),
"datetime": (pd.api.types.is_datetime64_any_dtype,),
}
dtype_checkers = feature_dtype_to_expected_historical_feature_dtype[feature_dtype]
assert any(
Expand All @@ -288,6 +315,8 @@ def assert_feature_list_types(
bool,
np.bool_,
), # Can be `np.bool_` if from `np.array` rather that `list`
"datetime": (datetime, np.datetime64,),
"date": (date, np.datetime64),
}
expected_dtype = feature_list_dtype_to_expected_historical_feature_list_dtype[
feature_dtype
Expand All @@ -307,24 +336,22 @@ def assert_expected_arrow_types(
):
print("Asserting historical feature arrow types")
historical_features_arrow = historical_features.to_arrow()
print(historical_features_arrow)
feature_list_dtype_to_expected_historical_feature_arrow_type = {
"int32": "int64",
"int64": "int64",
"float": "double",
"string": "string",
"bool": "bool",
"int32": pa.types.is_int64,
"int64": pa.types.is_int64,
"float": pa.types.is_float64,
"string": pa.types.is_string,
"bool": pa.types.is_boolean,
"date": pa.types.is_date,
"datetime": pa.types.is_timestamp,
}
arrow_type = feature_list_dtype_to_expected_historical_feature_arrow_type[
arrow_type_checker = feature_list_dtype_to_expected_historical_feature_arrow_type[
feature_dtype
]
pa_type = historical_features_arrow.schema.field("value").type

if feature_is_list:
assert (
str(historical_features_arrow.schema.field_by_name("value").type)
== f"list<item: {arrow_type}>"
)
assert pa.types.is_list(pa_type)
assert arrow_type_checker(pa_type.value_type)
else:
assert (
str(historical_features_arrow.schema.field_by_name("value").type)
== arrow_type
)
assert arrow_type_checker(pa_type)