Commit d7e49a9: mypy compliant
1 parent: e7fb5d5

File tree: 8 files changed, +42 -23 lines changed

butterfree/configs/db/abstract_config.py

Lines changed: 5 additions & 0 deletions
@@ -7,6 +7,11 @@
 class AbstractWriteConfig(ABC):
     """Abstract class for database write configurations with spark."""
 
+    @property
+    @abstractmethod
+    def database(self) -> str:
+        """Database name."""
+
     @property
     @abstractmethod
     def mode(self) -> Any:
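With database declared as an abstract property on AbstractWriteConfig, mypy can check that every concrete write config exposes it, instead of each subclass carrying an untyped class attribute. A minimal sketch of what a subclass now has to provide (the SQLiteConfig name is hypothetical, used only for illustration):

from butterfree.configs.db import AbstractWriteConfig


class SQLiteConfig(AbstractWriteConfig):
    """Hypothetical write config, illustrating the new abstract contract."""

    @property
    def database(self) -> str:
        """Database name."""
        return "sqlite"

    # The remaining abstract members (mode, etc.) would also need concrete
    # implementations; they are omitted from this sketch.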

butterfree/configs/db/cassandra_config.py

Lines changed: 5 additions & 2 deletions
@@ -28,8 +28,6 @@ class CassandraConfig(AbstractWriteConfig):
 
     """
 
-    database = "cassandra"
-
     def __init__(
         self,
         username: str = None,
@@ -52,6 +50,11 @@ def __init__(
         self.stream_output_mode = stream_output_mode
         self.stream_checkpoint_path = stream_checkpoint_path
 
+    @property
+    def database(self) -> str:
+        """Database name."""
+        return "cassandra"
+
     @property
     def username(self) -> Optional[str]:
         """Username used in connection to Cassandra DB."""

butterfree/configs/db/kafka_config.py

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,11 @@ def __init__(
         self.stream_output_mode = stream_output_mode
         self.stream_checkpoint_path = stream_checkpoint_path
 
+    @property
+    def database(self) -> str:
+        """Database name."""
+        return "kafka"
+
     @property
     def kafka_topic(self) -> Optional[str]:
         """Kafka topic name."""

butterfree/configs/db/metastore_config.py

Lines changed: 5 additions & 2 deletions
@@ -23,8 +23,6 @@ class MetastoreConfig(AbstractWriteConfig):
 
     """
 
-    database = "metastore"
-
     def __init__(
         self,
         path: str = None,
@@ -37,6 +35,11 @@ def __init__(
         self.format_ = format_
         self.file_system = file_system
 
+    @property
+    def database(self) -> str:
+        """Database name."""
+        return "metastore"
+
     @property
     def path(self) -> Optional[str]:
         """Bucket name."""

butterfree/load/writers/historical_feature_store_writer.py

Lines changed: 5 additions & 4 deletions
@@ -1,7 +1,7 @@
 """Holds the Historical Feature Store writer class."""
 
 import os
-from typing import Any, Union
+from typing import Any
 
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.functions import dayofmonth, month, year
@@ -106,16 +106,17 @@ class HistoricalFeatureStoreWriter(Writer):
 
     def __init__(
         self,
-        db_config: Union[AbstractWriteConfig, MetastoreConfig] = None,
+        db_config: AbstractWriteConfig = None,
         database: str = None,
         num_partitions: int = None,
         validation_threshold: float = DEFAULT_VALIDATION_THRESHOLD,
         debug_mode: bool = False,
         interval_mode: bool = False,
         check_schema_hook: Hook = None,
     ):
-        super(HistoricalFeatureStoreWriter, self).__init__(debug_mode, interval_mode)
-        self.db_config = db_config or MetastoreConfig()
+        super(HistoricalFeatureStoreWriter, self).__init__(
+            db_config or MetastoreConfig(), debug_mode, interval_mode
+        )
         self.database = database or environment.get_variable(
             "FEATURE_STORE_HISTORICAL_DATABASE"
         )

butterfree/load/writers/online_feature_store_writer.py

Lines changed: 4 additions & 4 deletions
@@ -80,15 +80,15 @@ class OnlineFeatureStoreWriter(Writer):
 
     def __init__(
         self,
-        db_config: Union[AbstractWriteConfig, CassandraConfig] = None,
+        db_config: AbstractWriteConfig = None,
         debug_mode: bool = False,
         write_to_entity: bool = False,
         interval_mode: bool = False,
         check_schema_hook: Hook = None,
     ):
-        super(OnlineFeatureStoreWriter, self).__init__(debug_mode, interval_mode)
-        self.db_config = db_config or CassandraConfig()
-        self.write_to_entity = write_to_entity
+        super(OnlineFeatureStoreWriter, self).__init__(
+            db_config or CassandraConfig(), debug_mode, interval_mode, write_to_entity
+        )
         self.check_schema_hook = check_schema_hook
 
     @staticmethod

butterfree/load/writers/writer.py

Lines changed: 10 additions & 1 deletion
@@ -7,6 +7,7 @@
 from pyspark.sql.dataframe import DataFrame
 
 from butterfree.clients import SparkClient
+from butterfree.configs.db import AbstractWriteConfig
 from butterfree.hooks import HookableComponent
 from butterfree.transform import FeatureSet
 
@@ -19,11 +20,19 @@ class Writer(ABC, HookableComponent):
 
     """
 
-    def __init__(self, debug_mode: bool = False, interval_mode: bool = False) -> None:
+    def __init__(
+        self,
+        db_config: AbstractWriteConfig,
+        debug_mode: bool = False,
+        interval_mode: bool = False,
+        write_to_entity: bool = False,
+    ) -> None:
         super().__init__()
+        self.db_config = db_config
         self.transformations: List[Dict[str, Any]] = []
         self.debug_mode = debug_mode
         self.interval_mode = interval_mode
+        self.write_to_entity = write_to_entity
 
     def with_(
         self, transformer: Callable[..., DataFrame], *args: Any, **kwargs: Any
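Since the Writer base class now owns db_config (and write_to_entity), concrete writers forward those values through super().__init__ instead of assigning self.db_config themselves, as the two writer diffs above show. A minimal sketch of the pattern for a hypothetical subclass (CustomWriter is illustrative, not part of the codebase):

from butterfree.configs.db import AbstractWriteConfig
from butterfree.configs.db.cassandra_config import CassandraConfig
from butterfree.load.writers.writer import Writer


class CustomWriter(Writer):
    """Hypothetical writer, showing only the new constructor contract."""

    def __init__(
        self,
        db_config: AbstractWriteConfig = None,
        debug_mode: bool = False,
    ):
        # The base class stores db_config, debug_mode, interval_mode and
        # write_to_entity; the subclass just passes them up.
        super().__init__(db_config or CassandraConfig(), debug_mode)

    # Writer's remaining abstract methods are omitted from this sketch.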

butterfree/migrations/database_migration/database_migration.py

Lines changed: 3 additions & 10 deletions
@@ -3,12 +3,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Any, Dict, List, Set, Union
+from typing import Any, Dict, List, Set
 
-from butterfree.load.writers import (
-    HistoricalFeatureStoreWriter,
-    OnlineFeatureStoreWriter,
-)
+from butterfree.load.writers.writer import Writer
 from butterfree.transform import FeatureSet
 
 
@@ -129,11 +126,7 @@ def _get_schema(self, table_name: str) -> List[Dict[str, Any]]:
             db_schema = []
         return db_schema
 
-    def apply_migration(
-        self,
-        feature_set: FeatureSet,
-        writer: Union[HistoricalFeatureStoreWriter, OnlineFeatureStoreWriter],
-    ) -> None:
+    def apply_migration(self, feature_set: FeatureSet, writer: Writer,) -> None:
         """Apply the migration in the respective database.
 
         Args:
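Typing apply_migration against the Writer base class rather than a Union of the two concrete writers means any Writer subclass is accepted by the type checker with no further changes to the migration code. Roughly (sketch; the DatabaseMigration class name is assumed from the module name):

from butterfree.load.writers import HistoricalFeatureStoreWriter, OnlineFeatureStoreWriter
from butterfree.migrations.database_migration.database_migration import DatabaseMigration
from butterfree.transform import FeatureSet


def migrate_all(migration: DatabaseMigration, feature_set: FeatureSet) -> None:
    # Both concrete writers satisfy the Writer annotation, so no Union is needed.
    migration.apply_migration(feature_set, HistoricalFeatureStoreWriter())
    migration.apply_migration(feature_set, OnlineFeatureStoreWriter())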
