Skip to content

Commit 70d4a13

Browse files
authored
fix: Dynamodb deduplicate batch write request by partition keys (#2515)
Signed-off-by: Miguel Trejo <[email protected]>
1 parent 6bf8df0 commit 70d4a13

File tree

3 files changed

+55
-18
lines changed

3 files changed

+55
-18
lines changed

sdk/python/feast/infra/online_stores/dynamodb.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -191,21 +191,7 @@ def online_write_batch(
191191
table_instance = dynamodb_resource.Table(
192192
_get_table_name(online_config, config, table)
193193
)
194-
with table_instance.batch_writer() as batch:
195-
for entity_key, features, timestamp, created_ts in data:
196-
entity_id = compute_entity_id(entity_key)
197-
batch.put_item(
198-
Item={
199-
"entity_id": entity_id, # PartitionKey
200-
"event_ts": str(utils.make_tzaware(timestamp)),
201-
"values": {
202-
k: v.SerializeToString()
203-
for k, v in features.items() # Serialized Features
204-
},
205-
}
206-
)
207-
if progress:
208-
progress(1)
194+
self._write_batch_non_duplicates(table_instance, data, progress)
209195

210196
@log_exceptions_and_usage(online_store="dynamodb")
211197
def online_read(
@@ -299,6 +285,32 @@ def _sort_dynamodb_response(self, responses: list, order: list):
299285
_, table_responses_ordered = zip(*table_responses_ordered)
300286
return table_responses_ordered
301287

288+
@log_exceptions_and_usage(online_store="dynamodb")
289+
def _write_batch_non_duplicates(
290+
self,
291+
table_instance,
292+
data: List[
293+
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
294+
],
295+
progress: Optional[Callable[[int], Any]],
296+
):
297+
"""Deduplicate write batch request items on ``entity_id`` primary key."""
298+
with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch:
299+
for entity_key, features, timestamp, created_ts in data:
300+
entity_id = compute_entity_id(entity_key)
301+
batch.put_item(
302+
Item={
303+
"entity_id": entity_id, # PartitionKey
304+
"event_ts": str(utils.make_tzaware(timestamp)),
305+
"values": {
306+
k: v.SerializeToString()
307+
for k, v in features.items() # Serialized Features
308+
},
309+
}
310+
)
311+
if progress:
312+
progress(1)
313+
302314

303315
def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None):
304316
return boto3.client("dynamodb", region_name=region, endpoint_url=endpoint_url)

sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from copy import deepcopy
12
from dataclasses import dataclass
23

4+
import boto3
35
import pytest
46
from moto import mock_dynamodb2
57

@@ -162,7 +164,7 @@ def test_online_read(repo_config, n_samples):
162164
data = _create_n_customer_test_samples(n=n_samples)
163165
_insert_data_test_table(data, PROJECT, f"{TABLE_NAME}_{n_samples}", REGION)
164166

165-
entity_keys, features = zip(*data)
167+
entity_keys, features, *rest = zip(*data)
166168
dynamodb_store = DynamoDBOnlineStore()
167169
returned_items = dynamodb_store.online_read(
168170
config=repo_config,
@@ -171,3 +173,24 @@ def test_online_read(repo_config, n_samples):
171173
)
172174
assert len(returned_items) == len(data)
173175
assert [item[1] for item in returned_items] == list(features)
176+
177+
178+
@mock_dynamodb2
179+
def test_write_batch_non_duplicates(repo_config):
180+
"""Test DynamoDBOnline Store deduplicate write batch request items."""
181+
dynamodb_tbl = f"{TABLE_NAME}_batch_non_duplicates"
182+
_create_test_table(PROJECT, dynamodb_tbl, REGION)
183+
data = _create_n_customer_test_samples()
184+
data_duplicate = deepcopy(data)
185+
dynamodb_resource = boto3.resource("dynamodb", region_name=REGION)
186+
table_instance = dynamodb_resource.Table(f"{PROJECT}.{dynamodb_tbl}")
187+
dynamodb_store = DynamoDBOnlineStore()
188+
# Insert duplicate data
189+
dynamodb_store._write_batch_non_duplicates(
190+
table_instance, data + data_duplicate, progress=None
191+
)
192+
# Request more items than inserted
193+
response = table_instance.scan(Limit=20)
194+
returned_items = response.get("Items", None)
195+
assert returned_items is not None
196+
assert len(returned_items) == len(data)

sdk/python/tests/utils/online_store_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def _create_n_customer_test_samples(n=10):
1919
"name": ValueProto(string_val="John"),
2020
"age": ValueProto(int64_val=3),
2121
},
22+
datetime.utcnow(),
23+
None,
2224
)
2325
for i in range(n)
2426
]
@@ -42,13 +44,13 @@ def _delete_test_table(project, tbl_name, region):
4244
def _insert_data_test_table(data, project, tbl_name, region):
4345
dynamodb_resource = boto3.resource("dynamodb", region_name=region)
4446
table_instance = dynamodb_resource.Table(f"{project}.{tbl_name}")
45-
for entity_key, features in data:
47+
for entity_key, features, timestamp, created_ts in data:
4648
entity_id = compute_entity_id(entity_key)
4749
with table_instance.batch_writer() as batch:
4850
batch.put_item(
4951
Item={
5052
"entity_id": entity_id,
51-
"event_ts": str(utils.make_tzaware(datetime.utcnow())),
53+
"event_ts": str(utils.make_tzaware(timestamp)),
5254
"values": {k: v.SerializeToString() for k, v in features.items()},
5355
}
5456
)

0 commit comments

Comments
 (0)