Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 52 additions & 17 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import contextlib
import itertools
import logging
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -138,6 +138,38 @@ async def close(self):
def async_supported(self) -> SupportedAsyncMethods:
return SupportedAsyncMethods(read=True, write=True)

@staticmethod
def _table_tags(online_config, table_instance) -> list[dict[str, str]]:
table_instance_tags = table_instance.tags or {}
online_tags = online_config.tags or {}

common_tags = [
{"Key": key, "Value": table_instance_tags.get(key) or value}
for key, value in online_tags.items()
]
table_tags = [
{"Key": key, "Value": value}
for key, value in table_instance_tags.items()
if key not in online_tags
]
Comment on lines +146 to +154
Copy link
Contributor Author

@robhowley robhowley Apr 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the table level tags override the global where applicable. eg

# in yaml, default to platform team owns the infra
team: platform
# in feature view py file we override for one particular instance
tags={"team": "product-team"}


return common_tags + table_tags

@staticmethod
def _update_tags(dynamodb_client, table_name: str, new_tags: list[dict[str, str]]):
table_arn = dynamodb_client.describe_table(TableName=table_name)["Table"][
"TableArn"
]
current_tags = dynamodb_client.list_tags_of_resource(ResourceArn=table_arn)[
"Tags"
]
if current_tags:
remove_keys = [tag["Key"] for tag in current_tags]
dynamodb_client.untag_resource(ResourceArn=table_arn, TagKeys=remove_keys)

if new_tags:
dynamodb_client.tag_resource(ResourceArn=table_arn, Tags=new_tags)

def update(
self,
config: RepoConfig,
Expand Down Expand Up @@ -167,40 +199,43 @@ def update(
online_config.endpoint_url,
online_config.session_based_auth,
)
# Add Tags attribute to creation request only if configured to prevent
# TagResource permission issues, even with an empty Tags array.
kwargs = (
{
"Tags": [
{"Key": key, "Value": value}
for key, value in online_config.tags.items()
]
}
if online_config.tags
else {}
)

do_tag_updates = defaultdict(bool)
for table_instance in tables_to_keep:
# Add Tags attribute to creation request only if configured to prevent
# TagResource permission issues, even with an empty Tags array.
table_tags = self._table_tags(online_config, table_instance)
kwargs = {"Tags": table_tags} if table_tags else {}

table_name = _get_table_name(online_config, config, table_instance)
try:
dynamodb_resource.create_table(
TableName=_get_table_name(online_config, config, table_instance),
TableName=table_name,
KeySchema=[{"AttributeName": "entity_id", "KeyType": "HASH"}],
AttributeDefinitions=[
{"AttributeName": "entity_id", "AttributeType": "S"}
],
BillingMode="PAY_PER_REQUEST",
**kwargs,
)

except ClientError as ce:
do_tag_updates[table_name] = True

# If the table creation fails with ResourceInUseException,
# it means the table already exists or is being created.
# Otherwise, re-raise the exception
if ce.response["Error"]["Code"] != "ResourceInUseException":
raise

for table_instance in tables_to_keep:
dynamodb_client.get_waiter("table_exists").wait(
TableName=_get_table_name(online_config, config, table_instance)
)
table_name = _get_table_name(online_config, config, table_instance)
dynamodb_client.get_waiter("table_exists").wait(TableName=table_name)
# once table is confirmed to exist, update the tags.
# tags won't be updated in the create_table call if the table already exists
if do_tag_updates[table_name]:
tags = self._table_tags(online_config, table_instance)
self._update_tags(dynamodb_client, table_name, tags)
Comment on lines +236 to +238
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the table already exists, then we should perform tag updates. otherwise we can skip that bc the tags would've been added in the create_table call


for table_to_delete in tables_to_delete:
_delete_table_idempotent(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

import boto3
import pytest
Expand Down Expand Up @@ -32,6 +33,12 @@
@dataclass
class MockFeatureView:
name: str
tags: Optional[dict[str, str]] = None


@dataclass
class MockOnlineConfig:
tags: Optional[dict[str, str]] = None


@pytest.fixture
Expand Down Expand Up @@ -209,6 +216,13 @@ def test_dynamodb_online_store_online_write_batch(
assert [item[1] for item in stored_items] == list(features)


def _get_tags(dynamodb_client, table_name):
table_arn = dynamodb_client.describe_table(TableName=table_name)["Table"][
"TableArn"
]
return dynamodb_client.list_tags_of_resource(ResourceArn=table_arn).get("Tags")


@mock_dynamodb
def test_dynamodb_online_store_update(repo_config, dynamodb_online_store):
"""Test DynamoDBOnlineStore update method."""
Expand All @@ -222,7 +236,7 @@ def test_dynamodb_online_store_update(repo_config, dynamodb_online_store):
dynamodb_online_store.update(
config=repo_config,
tables_to_delete=[MockFeatureView(name=db_table_delete_name)],
tables_to_keep=[MockFeatureView(name=db_table_keep_name)],
tables_to_keep=[MockFeatureView(name=db_table_keep_name, tags={"some": "tag"})],
entities_to_delete=None,
entities_to_keep=None,
partial=None,
Expand All @@ -237,6 +251,98 @@ def test_dynamodb_online_store_update(repo_config, dynamodb_online_store):
assert len(existing_tables) == 1
assert existing_tables[0] == f"test_aws.{db_table_keep_name}"

assert _get_tags(dynamodb_client, existing_tables[0]) == [
{"Key": "some", "Value": "tag"}
]


@mock_dynamodb
def test_dynamodb_online_store_update_tags(repo_config, dynamodb_online_store):
"""Test DynamoDBOnlineStore update method."""
# create dummy table to update with new tags and tag values
table_name = f"{TABLE_NAME}_keep_update_tags"
create_test_table(PROJECT, table_name, REGION)

# add tags on update
dynamodb_online_store.update(
config=repo_config,
tables_to_delete=[],
tables_to_keep=[
MockFeatureView(
name=table_name, tags={"key1": "val1", "key2": "val2", "key3": "val3"}
)
],
entities_to_delete=[],
entities_to_keep=[],
partial=None,
)

# update tags
dynamodb_online_store.update(
config=repo_config,
tables_to_delete=[],
tables_to_keep=[
MockFeatureView(
name=table_name,
tags={"key1": "new-val1", "key2": "val2", "key4": "val4"},
)
],
entities_to_delete=[],
entities_to_keep=[],
partial=None,
)

# check only db_table_keep_name exists
dynamodb_client = dynamodb_online_store._get_dynamodb_client(REGION)
existing_tables = dynamodb_client.list_tables().get("TableNames", None)

expected_tags = [
{"Key": "key1", "Value": "new-val1"},
{"Key": "key2", "Value": "val2"},
{"Key": "key4", "Value": "val4"},
]
assert _get_tags(dynamodb_client, existing_tables[0]) == expected_tags

# and then remove all tags
dynamodb_online_store.update(
config=repo_config,
tables_to_delete=[],
tables_to_keep=[MockFeatureView(name=table_name, tags=None)],
entities_to_delete=[],
entities_to_keep=[],
partial=None,
)

assert _get_tags(dynamodb_client, existing_tables[0]) == []


@mock_dynamodb
@pytest.mark.parametrize(
"global_tags, table_tags, expected",
[
(None, {"key": "val"}, [{"Key": "key", "Value": "val"}]),
({"key": "val"}, None, [{"Key": "key", "Value": "val"}]),
(
{"key1": "val1"},
{"key2": "val2"},
[{"Key": "key1", "Value": "val1"}, {"Key": "key2", "Value": "val2"}],
),
(
{"key": "val", "key2": "val2"},
{"key": "new-val"},
[{"Key": "key", "Value": "new-val"}, {"Key": "key2", "Value": "val2"}],
),
],
)
def test_dynamodb_online_store_tag_priority(
global_tags, table_tags, expected, dynamodb_online_store
):
actual = dynamodb_online_store._table_tags(
MockOnlineConfig(tags=global_tags),
MockFeatureView(name="table", tags=table_tags),
)
assert actual == expected


@mock_dynamodb
def test_dynamodb_online_store_teardown(repo_config, dynamodb_online_store):
Expand Down
Loading