28 changes: 26 additions & 2 deletions butterfree/_cli/migrate.py
@@ -8,6 +8,7 @@
import typer

from butterfree._cli import cli_logger
from butterfree.migrations.database_migration import ALLOWED_DATABASE
from butterfree.pipelines import FeatureSetPipeline

app = typer.Typer()
@@ -88,6 +89,28 @@ def __fs_objects(path: str) -> Set[FeatureSetPipeline]:
)


class Migrate:
"""Execute migration operations in a Database based on pipeline Writer.

Attributes:
pipelines: set of Feature Set Pipelines to migrate.
"""

def __init__(self, pipelines: Set[FeatureSetPipeline]) -> None:
self.pipelines = pipelines

def _send_logs_to_s3(self) -> None:
"""Send all migration logs to S3."""
pass

def run(self) -> None:

Contributor: We will not parse the feature sets within Butterfree anymore, right?

Contributor: I think this is not necessary anymore 😺

"""Construct and apply the migrations."""
for pipeline in self.pipelines:
for writer in pipeline.sink.writers:
migration = ALLOWED_DATABASE[writer.db_config.database]
migration.apply_migration(pipeline.feature_set, writer)


@app.callback()
def migrate(path: str = PATH) -> Set[FeatureSetPipeline]:
"""Scan and run database migrations for feature set pipelines defined under PATH.
@@ -100,5 +123,6 @@ def migrate(path: str = PATH) -> Set[FeatureSetPipeline]:
All pipelines must be under python modules inside path, so we can dynamically
import and instantiate them.
"""
# TODO call the Migration actor with all feature set pipeline objects
return __fs_objects(path)
pipe_set = __fs_objects(path)
Migrate(pipe_set).run()
return pipe_set
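
For context, a minimal sketch of the reworked entry point in use, mirroring the unit tests further down; the `pipelines/` path is a hypothetical example:

```python
# Minimal usage sketch; "pipelines/" is a hypothetical package holding
# FeatureSetPipeline definitions, as in the repo's tests/mocks/entities/.
from butterfree._cli.migrate import Migrate, migrate

# Scans the path, applies every migration, and returns the pipeline set.
all_pipelines = migrate("pipelines/")

# Equivalent, given pipeline objects already in hand:
Migrate(all_pipelines).run()
```
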
5 changes: 5 additions & 0 deletions butterfree/configs/db/abstract_config.py
@@ -7,6 +7,11 @@
class AbstractWriteConfig(ABC):
"""Abstract class for database write configurations with spark."""

@property
@abstractmethod
def database(self) -> str:
"""Database name."""

@property
@abstractmethod
def mode(self) -> Any:
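
To illustrate the new contract, here is a sketch of a hypothetical config subclass; the remaining abstract members of `AbstractWriteConfig` are elided for brevity:

```python
from butterfree.configs.db import AbstractWriteConfig


class InMemoryConfig(AbstractWriteConfig):
    """Hypothetical config; the other abstract members are elided."""

    @property
    def database(self) -> str:
        # This key is what the ALLOWED_DATABASE registry (see below)
        # matches against when selecting a migration.
        return "in_memory"
```
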
5 changes: 5 additions & 0 deletions butterfree/configs/db/cassandra_config.py
@@ -50,6 +50,11 @@ def __init__(
self.stream_output_mode = stream_output_mode
self.stream_checkpoint_path = stream_checkpoint_path

@property
def database(self) -> str:
"""Database name."""
return "cassandra"

@property
def username(self) -> Optional[str]:
"""Username used in connection to Cassandra DB."""
5 changes: 5 additions & 0 deletions butterfree/configs/db/kafka_config.py
@@ -41,6 +41,11 @@ def __init__(
self.stream_output_mode = stream_output_mode
self.stream_checkpoint_path = stream_checkpoint_path

@property
def database(self) -> str:
"""Database name."""
return "kafka"

@property
def kafka_topic(self) -> Optional[str]:
"""Kafka topic name."""
5 changes: 5 additions & 0 deletions butterfree/configs/db/metastore_config.py
@@ -35,6 +35,11 @@ def __init__(
self.format_ = format_
self.file_system = file_system

@property
def database(self) -> str:
"""Database name."""
return "metastore"

@property
def path(self) -> Optional[str]:
"""Bucket name."""
9 changes: 5 additions & 4 deletions butterfree/load/writers/historical_feature_store_writer.py
@@ -1,7 +1,7 @@
"""Holds the Historical Feature Store writer class."""

import os
from typing import Any, Union
from typing import Any

from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import dayofmonth, month, year
@@ -106,16 +106,17 @@ class HistoricalFeatureStoreWriter(Writer):

def __init__(
self,
db_config: Union[AbstractWriteConfig, MetastoreConfig] = None,
db_config: AbstractWriteConfig = None,
database: str = None,
num_partitions: int = None,
validation_threshold: float = DEFAULT_VALIDATION_THRESHOLD,
debug_mode: bool = False,
interval_mode: bool = False,
check_schema_hook: Hook = None,
):
super(HistoricalFeatureStoreWriter, self).__init__(debug_mode, interval_mode)
self.db_config = db_config or MetastoreConfig()
super(HistoricalFeatureStoreWriter, self).__init__(
db_config or MetastoreConfig(), debug_mode, interval_mode
)
self.database = database or environment.get_variable(
"FEATURE_STORE_HISTORICAL_DATABASE"
)
8 changes: 4 additions & 4 deletions butterfree/load/writers/online_feature_store_writer.py
@@ -80,15 +80,15 @@ class OnlineFeatureStoreWriter(Writer):

def __init__(
self,
db_config: Union[AbstractWriteConfig, CassandraConfig] = None,
db_config: AbstractWriteConfig = None,
debug_mode: bool = False,
write_to_entity: bool = False,
interval_mode: bool = False,
check_schema_hook: Hook = None,
):
super(OnlineFeatureStoreWriter, self).__init__(debug_mode, interval_mode)
self.db_config = db_config or CassandraConfig()
self.write_to_entity = write_to_entity
super(OnlineFeatureStoreWriter, self).__init__(
db_config or CassandraConfig(), debug_mode, interval_mode, write_to_entity
)
self.check_schema_hook = check_schema_hook

@staticmethod
11 changes: 10 additions & 1 deletion butterfree/load/writers/writer.py
@@ -7,6 +7,7 @@
from pyspark.sql.dataframe import DataFrame

from butterfree.clients import SparkClient
from butterfree.configs.db import AbstractWriteConfig
from butterfree.hooks import HookableComponent
from butterfree.transform import FeatureSet

@@ -19,11 +20,19 @@ class Writer(ABC, HookableComponent):

"""

def __init__(self, debug_mode: bool = False, interval_mode: bool = False) -> None:
def __init__(
self,
db_config: AbstractWriteConfig,
debug_mode: bool = False,
interval_mode: bool = False,
write_to_entity: bool = False,
) -> None:
super().__init__()
self.db_config = db_config
self.transformations: List[Dict[str, Any]] = []
self.debug_mode = debug_mode
self.interval_mode = interval_mode
self.write_to_entity = write_to_entity

def with_(
self, transformer: Callable[..., DataFrame], *args: Any, **kwargs: Any
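
Since `db_config` and `write_to_entity` now live on the `Writer` base class, subclasses forward them through `super().__init__` instead of assigning the attributes themselves. A sketch of a hypothetical subclass (the abstract methods of `Writer` are elided):

```python
from butterfree.configs.db import AbstractWriteConfig
from butterfree.load.writers.writer import Writer


class MyWriter(Writer):
    """Hypothetical writer; Writer's abstract methods are elided."""

    def __init__(self, db_config: AbstractWriteConfig, debug_mode: bool = False):
        # Forward db_config to the base class, just as the refactored
        # Historical/Online writers above now do.
        super().__init__(db_config, debug_mode, interval_mode=False)
```
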
3 changes: 0 additions & 3 deletions butterfree/migrations/__init__.py
@@ -1,4 +1 @@
"""Holds available migrations."""
from butterfree.migrations.migrate import Migrate

__all__ = ["Migrate"]
6 changes: 6 additions & 0 deletions butterfree/migrations/database_migration/__init__.py
@@ -9,3 +9,9 @@
)

__all__ = ["CassandraMigration", "MetastoreMigration", "Diff"]


ALLOWED_DATABASE = {
"cassandra": CassandraMigration(),
"metastore": MetastoreMigration(),
}
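
A short sketch of how this registry is consumed, matching `Migrate.run` above; `pipeline` is assumed to be an existing `FeatureSetPipeline`:

```python
from butterfree.migrations.database_migration import ALLOWED_DATABASE

# Assumes `pipeline` is an existing FeatureSetPipeline; each writer's
# db_config.database string ("cassandra", "metastore") keys the registry.
for writer in pipeline.sink.writers:
    migration = ALLOWED_DATABASE[writer.db_config.database]
    migration.apply_migration(pipeline.feature_set, writer)
```
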
@@ -32,11 +32,13 @@ class CassandraMigration(DatabaseMigration):

def __init__(self) -> None:
self._db_config = CassandraConfig()
self._client = CassandraClient(
host=[self._db_config.host],
keyspace=self._db_config.keyspace, # type: ignore
user=self._db_config.username,
password=self._db_config.password,
super(CassandraMigration, self).__init__(
CassandraClient(
host=[self._db_config.host],
keyspace=self._db_config.keyspace, # type: ignore
user=self._db_config.username,
password=self._db_config.password,
)
)

@staticmethod
49 changes: 41 additions & 8 deletions butterfree/migrations/database_migration/database_migration.py
@@ -5,6 +5,8 @@
from enum import Enum, auto
from typing import Any, Dict, List, Set

from butterfree.clients import AbstractClient
from butterfree.load.writers.writer import Writer
from butterfree.transform import FeatureSet


@@ -40,6 +42,9 @@ class DatabaseMigration(ABC):
class DatabaseMigration(ABC):
"""Abstract base class for Migrations."""

def __init__(self, client: AbstractClient) -> None:
self._client = client

@abstractmethod
def _get_create_table_query(
self, columns: List[Dict[str, Any]], table_name: str
@@ -173,10 +178,6 @@ def create_query(

return self._get_queries(schema_diff, table_name, write_on_entity)

def _apply_migration(self, feature_set: FeatureSet) -> None:
"""Apply the migration in the respective database."""
pass

@staticmethod
def _get_diff(
fs_schema: List[Dict[str, Any]], db_schema: List[Dict[str, Any]],
@@ -238,11 +239,43 @@ def _get_diff(
)
return schema_diff

def run(self, feature_set: FeatureSet) -> None:
"""Runs the migrations.
def _get_schema(self, table_name: str) -> List[Dict[str, Any]]:
"""Get a table schema in the respective database.

Args:
feature_set: the feature set.
table_name: Table name to get schema.

Returns:
The table schema, or an empty list if it cannot be fetched.
"""
pass
try:
db_schema = self._client.get_schema(table_name)
except Exception: # noqa
db_schema = []
return db_schema

def apply_migration(self, feature_set: FeatureSet, writer: Writer,) -> None:
"""Apply the migration in the respective database.

Args:
feature_set: the feature set.
writer: the writer being used to load the feature set.
"""
logging.info(f"Migrating feature set: {feature_set.name}")

table_name = (
feature_set.name if not writer.write_to_entity else feature_set.entity
)

fs_schema = writer.db_config.translate(feature_set.get_schema())
db_schema = self._get_schema(table_name)

queries = self.create_query(
fs_schema, table_name, db_schema, writer.write_to_entity
)

for q in queries:
logging.info(f"Applying {q}...")
self._client.sql(q)

logging.info(f"Feature Set migration finished successfully.")
@@ -2,7 +2,9 @@

from typing import Any, Dict, List

from butterfree.clients import SparkClient
from butterfree.configs import environment
from butterfree.configs.db import MetastoreConfig
from butterfree.constants.migrations import PARTITION_BY
from butterfree.migrations.database_migration.database_migration import (
DatabaseMigration,
@@ -28,12 +30,12 @@ class MetastoreMigration(DatabaseMigration):
data is being loaded into an entity table, then users can drop columns manually.
"""

def __init__(
self, database: str = None,
):
def __init__(self, database: str = None,) -> None:
self._db_config = MetastoreConfig()
self.database = database or environment.get_variable(
"FEATURE_STORE_HISTORICAL_DATABASE"
)
super(MetastoreMigration, self).__init__(SparkClient())

@staticmethod
def _get_parsed_columns(columns: List[Diff]) -> List[str]:
41 changes: 0 additions & 41 deletions butterfree/migrations/migrate.py

This file was deleted.

23 changes: 22 additions & 1 deletion tests/unit/butterfree/_cli/test_migrate.py
@@ -1,8 +1,29 @@
from unittest.mock import call

from butterfree._cli import migrate
from butterfree.migrations.database_migration import (
CassandraMigration,
MetastoreMigration,
)
from butterfree.pipelines import FeatureSetPipeline


def test_migrate_success():
def test_migrate_success(mocker):
mocker.patch.object(migrate.Migrate, "run")
all_fs = migrate.migrate("tests/mocks/entities/")
assert all(isinstance(fs, FeatureSetPipeline) for fs in all_fs)
assert sorted([fs.feature_set.name for fs in all_fs]) == ["first", "second"]


def test_migrate_all_pairs(mocker):
mocker.patch.object(MetastoreMigration, "apply_migration")
mocker.patch.object(CassandraMigration, "apply_migration")
all_fs = migrate.migrate("tests/mocks/entities/")

assert MetastoreMigration.apply_migration.call_count == 2
assert CassandraMigration.apply_migration.call_count == 2

metastore_pairs = [call(pipe.feature_set, pipe.sink.writers[0]) for pipe in all_fs]
cassandra_pairs = [call(pipe.feature_set, pipe.sink.writers[1]) for pipe in all_fs]
MetastoreMigration.apply_migration.assert_has_calls(metastore_pairs, any_order=True)
CassandraMigration.apply_migration.assert_has_calls(cassandra_pairs, any_order=True)