55 changes: 55 additions & 0 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -4,6 +4,7 @@
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Dict,
@@ -303,6 +304,60 @@ def write_logged_features(
job_config=job_config,
)

@staticmethod
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, BigQueryOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when bigquery type required"
)
if not isinstance(feature_view.batch_source, BigQuerySource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not bigquery source"
)

pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

if table.schema != pa_schema:
table = table.cast(pa_schema)

client = _get_bigquery_client(
project=config.offline_store.project_id,
location=config.offline_store.location,
)

job_config = bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.PARQUET,
schema=arrow_schema_to_bq_schema(pa_schema),
write_disposition="WRITE_APPEND", # Default but included for clarity
)

with tempfile.TemporaryFile() as parquet_temp_file:
pyarrow.parquet.write_table(table=table, where=parquet_temp_file)

parquet_temp_file.seek(0)

client.load_table_from_file(
file_obj=parquet_temp_file,
destination=feature_view.batch_source.table,
job_config=job_config,
)


class BigQueryRetrievalJob(RetrievalJob):
def __init__(
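For context, the new BigQuery `offline_write_batch` stages the Arrow table as a Parquet temp file and hands it to the load API rather than round-tripping through pandas. A minimal standalone sketch of that load pattern, with placeholder project/dataset/table names (not values from this PR):

```python
import tempfile

import pyarrow as pa
import pyarrow.parquet
from google.cloud import bigquery

# Hypothetical table and destination, used only to illustrate the load pattern.
table = pa.table({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]})
client = bigquery.Client(project="my-project")  # placeholder project

job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.PARQUET,
    write_disposition="WRITE_APPEND",  # append, as in the store above
    # The PR also passes an explicit schema converted from the Arrow schema;
    # here BigQuery infers it from the Parquet file.
)

with tempfile.TemporaryFile() as tmp:
    pyarrow.parquet.write_table(table, tmp)  # serialize the Arrow table as Parquet
    tmp.seek(0)
    load_job = client.load_table_from_file(
        tmp, destination="my-project.my_dataset.my_table", job_config=job_config
    )
    load_job.result()  # block until the load job completes
```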
28 changes: 18 additions & 10 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -27,6 +27,7 @@
)
from feast.infra.offline_stores.offline_utils import (
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
get_pyarrow_schema_from_batch_source,
)
from feast.infra.provider import (
_get_requested_feature_views_to_features_dict,
@@ -408,7 +409,7 @@ def write_logged_features(
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
data: pyarrow.Table,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
@@ -423,20 +424,27 @@ def offline_write_batch(
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not file source"
)

pa_schema, column_names = get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

file_options = feature_view.batch_source.file_options
filesystem, path = FileSource.create_filesystem_and_path(
file_options.uri, file_options.s3_endpoint_override
)

prev_table = pyarrow.parquet.read_table(path, memory_map=True)
if prev_table.column_names != data.column_names:
raise ValueError(
f"Input dataframe has incorrect schema or wrong order, expected columns are: {prev_table.column_names}"
)
if data.schema != prev_table.schema:
data = data.cast(prev_table.schema)
new_table = pyarrow.concat_tables([data, prev_table])
writer = pyarrow.parquet.ParquetWriter(path, data.schema, filesystem=filesystem)
if table.schema != prev_table.schema:
table = table.cast(prev_table.schema)
new_table = pyarrow.concat_tables([table, prev_table])
writer = pyarrow.parquet.ParquetWriter(
path, table.schema, filesystem=filesystem
)
writer.write_table(new_table)
writer.close()

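The file store appends by reading the existing Parquet file, casting the incoming table to the stored schema, concatenating, and rewriting the file in place. A minimal sketch of that pyarrow pattern on its own, using a placeholder local path:

```python
import pyarrow as pa
import pyarrow.parquet as pq

path = "/tmp/feature_data.parquet"  # placeholder path

# Existing data on disk plus a new batch to append (illustrative values).
pq.write_table(pa.table({"driver_id": [1], "conv_rate": [0.1]}), path)
new_batch = pa.table({"driver_id": [2], "conv_rate": [0.5]})

prev_table = pq.read_table(path, memory_map=True)
if new_batch.schema != prev_table.schema:
    new_batch = new_batch.cast(prev_table.schema)  # align types before concatenating

combined = pa.concat_tables([new_batch, prev_table])
with pq.ParquetWriter(path, combined.schema) as writer:
    writer.write_table(combined)  # rewrite the file with old + new rows
```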
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/offline_stores/offline_store.py
@@ -275,7 +275,7 @@ def write_logged_features(
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
data: pyarrow.Table,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
"""
@@ -286,8 +286,8 @@ def offline_write_batch(

Args:
config: Repo configuration object
table: FeatureView to write the data to.
data: pyarrow table containing feature data and timestamp column for historical feature retrieval
feature_view: FeatureView to write the data to.
table: pyarrow table containing feature data and timestamp column for historical feature retrieval
progress: Optional function to be called once every mini-batch of rows is written to
the offline store. Can be used to display progress.
"""
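The signature above is the hook a custom offline store would implement to support batch writes. A hedged skeleton of such an implementation (the class name is hypothetical, and the other abstract OfflineStore methods are omitted for brevity):

```python
from typing import Any, Callable, Optional

import pyarrow

from feast.feature_view import FeatureView
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.repo_config import RepoConfig


class MyCustomOfflineStore(OfflineStore):
    """Hypothetical store; only the new batch-write hook is sketched here."""

    @staticmethod
    def offline_write_batch(
        config: RepoConfig,
        feature_view: FeatureView,
        table: pyarrow.Table,
        progress: Optional[Callable[[int], Any]],
    ):
        # Validate `table` against feature_view.batch_source, append it to the
        # store's storage, then report the rows written if a callback was given.
        if progress:
            progress(table.num_rows)
```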
26 changes: 26 additions & 0 deletions sdk/python/feast/infra/offline_stores/offline_utils.py
@@ -5,9 +5,11 @@

import numpy as np
import pandas as pd
import pyarrow as pa
from jinja2 import BaseLoader, Environment
from pandas import Timestamp

from feast.data_source import DataSource
from feast.errors import (
EntityTimestampInferenceException,
FeastEntityDFMissingColumnsError,
@@ -17,6 +19,8 @@
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.infra.provider import _get_requested_feature_views_to_features_dict
from feast.registry import BaseRegistry
from feast.repo_config import RepoConfig
from feast.type_map import feast_value_type_to_pa
from feast.utils import to_naive_utc

DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL = "event_timestamp"
@@ -217,3 +221,25 @@ def get_offline_store_from_config(offline_store_config: Any) -> OfflineStore:
class_name = qualified_name.replace("Config", "")
offline_store_class = import_class(module_name, class_name, "OfflineStore")
return offline_store_class()


def get_pyarrow_schema_from_batch_source(
config: RepoConfig, batch_source: DataSource
) -> Tuple[pa.Schema, List[str]]:
"""Returns the pyarrow schema and column names for the given batch source."""
column_names_and_types = batch_source.get_table_column_names_and_types(config)

pa_schema = []
column_names = []
for column_name, column_type in column_names_and_types:
pa_schema.append(
(
column_name,
feast_value_type_to_pa(
batch_source.source_datatype_to_feast_value_type()(column_type)
),
)
)
column_names.append(column_name)

return pa.schema(pa_schema), column_names
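
This helper centralizes the source-schema lookup that each store previously duplicated; the stores then compare column order and cast the Arrow types to match. A small wrapper sketch showing the intended usage (the function name `align_table_to_source` is illustrative, not part of the PR):

```python
import pyarrow as pa

from feast.data_source import DataSource
from feast.infra.offline_stores import offline_utils
from feast.repo_config import RepoConfig


def align_table_to_source(
    config: RepoConfig, batch_source: DataSource, table: pa.Table
) -> pa.Table:
    """Validate column order against the batch source and cast types to match."""
    pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
        config, batch_source
    )
    if table.column_names != column_names:
        raise ValueError(f"Expected columns (in this order): {column_names}")
    if table.schema != pa_schema:
        table = table.cast(pa_schema)
    return table
```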
27 changes: 8 additions & 19 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -42,7 +42,6 @@
from feast.registry import BaseRegistry
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import feast_value_type_to_pa, redshift_to_feast_value_type
from feast.usage import log_exceptions_and_usage


@@ -318,33 +317,23 @@ def offline_write_batch(
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
)
redshift_options = feature_view.batch_source.redshift_options
redshift_client = aws_utils.get_redshift_data_client(
config.offline_store.region
)

column_name_to_type = feature_view.batch_source.get_table_column_names_and_types(
config
pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
)
pa_schema_list = []
column_names = []
for column_name, redshift_type in column_name_to_type:
pa_schema_list.append(
(
column_name,
feast_value_type_to_pa(redshift_to_feast_value_type(redshift_type)),
)
)
column_names.append(column_name)
pa_schema = pa.schema(pa_schema_list)
if column_names != table.column_names:
raise ValueError(
f"Input dataframe has incorrect schema or wrong order, expected columns are: {column_names}"
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

if table.schema != pa_schema:
table = table.cast(pa_schema)

redshift_options = feature_view.batch_source.redshift_options
redshift_client = aws_utils.get_redshift_data_client(
config.offline_store.region
)
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

aws_utils.upload_arrow_table_to_redshift(
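After the schema check, the Redshift store still stages the Arrow table on S3 and appends it with a COPY, via Feast's aws_utils helpers. A rough standalone sketch of that staging pattern using boto3 and the Redshift Data API, with placeholder bucket, cluster, and role identifiers (not Feast's actual helper):

```python
import io

import boto3
import pyarrow as pa
import pyarrow.parquet as pq

# Placeholder identifiers; in Feast these come from RedshiftOfflineStoreConfig.
BUCKET, KEY = "my-staging-bucket", "feast/driver_stats.parquet"
IAM_ROLE = "arn:aws:iam::123456789012:role/my-redshift-role"

table = pa.table({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]})

# Stage the Arrow table on S3 as Parquet.
buf = io.BytesIO()
pq.write_table(table, buf)
boto3.client("s3").put_object(Bucket=BUCKET, Key=KEY, Body=buf.getvalue())

# Append into the target table with a COPY issued through the Redshift Data API.
boto3.client("redshift-data").execute_statement(
    ClusterIdentifier="my-cluster",
    Database="dev",
    DbUser="awsuser",
    Sql=f"COPY driver_stats FROM 's3://{BUCKET}/{KEY}' IAM_ROLE '{IAM_ROLE}' FORMAT AS PARQUET;",
)
```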
42 changes: 42 additions & 0 deletions sdk/python/feast/infra/offline_stores/snowflake.py
@@ -3,6 +3,7 @@
from datetime import datetime
from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Dict,
@@ -306,6 +307,47 @@ def write_logged_features(
auto_create_table=True,
)

@staticmethod
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
raise ValueError(
"feature view does not have a batch source to persist offline data"
)
if not isinstance(config.offline_store, SnowflakeOfflineStoreConfig):
raise ValueError(
f"offline store config is of type {type(config.offline_store)} when snowflake type required"
)
if not isinstance(feature_view.batch_source, SnowflakeSource):
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not snowflake source"
)

pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

if table.schema != pa_schema:
table = table.cast(pa_schema)

snowflake_conn = get_snowflake_conn(config.offline_store)

write_pandas(
snowflake_conn,
table.to_pandas(),
table_name=feature_view.batch_source.table,
auto_create_table=True,
)


class SnowflakeRetrievalJob(RetrievalJob):
def __init__(
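The Snowflake path converts the Arrow table to pandas and defers to the connector's write_pandas helper, which creates the destination table on first write and appends afterwards. A standalone sketch with placeholder credentials (assuming a snowflake-connector-python version that supports auto_create_table):

```python
import pandas as pd
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

# Placeholder connection parameters; in Feast these come from SnowflakeOfflineStoreConfig.
conn = snowflake.connector.connect(
    account="my_account",
    user="my_user",
    password="...",
    database="FEAST",
    schema="PUBLIC",
)

df = pd.DataFrame({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]})

# auto_create_table builds the destination table from the dataframe's schema on
# the first write; later writes append, matching WRITE_APPEND semantics.
success, num_chunks, num_rows, _ = write_pandas(
    conn, df, table_name="DRIVER_STATS", auto_create_table=True
)
```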
@@ -76,7 +76,7 @@

OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, DataSourceCreator] = {
"file": ("local", FileDataSourceCreator),
"gcp": ("gcp", BigQueryDataSourceCreator),
"bigquery": ("gcp", BigQueryDataSourceCreator),
"redshift": ("aws", RedshiftDataSourceCreator),
"snowflake": ("aws", RedshiftDataSourceCreator),
}
@@ -109,7 +109,7 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources):


@pytest.mark.integration
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
@pytest.mark.universal_offline_stores
@pytest.mark.universal_online_stores(only=["sqlite"])
def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
store = environment.feature_store
@@ -16,7 +16,7 @@


@pytest.mark.integration
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
@pytest.mark.universal_offline_stores
@pytest.mark.universal_online_stores(only=["sqlite"])
def test_push_features_and_read_from_offline_store(environment, universal_data_sources):
store = environment.feature_store