Skip to content

Commit 056cfa1

Browse files
authored
feat: Support retrieval from multiple feature views with different join keys (#2835)
* feat: Support retrieving from multiple feature views
  Signed-off-by: Yongheng Lin <[email protected]>
* group by join keys instead of feature view
  Signed-off-by: Yongheng Lin <[email protected]>
* tolerate insufficient entities
  Signed-off-by: Yongheng Lin <[email protected]>
* mock registry.getEntityJoinKey
  Signed-off-by: Yongheng Lin <[email protected]>
* add integration test
  Signed-off-by: Yongheng Lin <[email protected]>
1 parent 86e9efd commit 056cfa1

File tree

5 files changed

+140
-16
lines changed

5 files changed

+140
-16
lines changed

java/serving/src/main/java/feast/serving/registry/Registry.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public class Registry {
3333
private Map<String, OnDemandFeatureViewProto.OnDemandFeatureViewSpec>
3434
onDemandFeatureViewNameToSpec;
3535
private final Map<String, FeatureServiceProto.FeatureServiceSpec> featureServiceNameToSpec;
36+
private final Map<String, String> entityNameToJoinKey;
3637

3738
Registry(RegistryProto.Registry registry) {
3839
this.registry = registry;
@@ -60,6 +61,12 @@ public class Registry {
6061
.collect(
6162
Collectors.toMap(
6263
FeatureServiceProto.FeatureServiceSpec::getName, Function.identity()));
64+
this.entityNameToJoinKey =
65+
registry.getEntitiesList().stream()
66+
.map(EntityProto.Entity::getSpec)
67+
.collect(
68+
Collectors.toMap(
69+
EntityProto.EntitySpecV2::getName, EntityProto.EntitySpecV2::getJoinKey));
6370
}
6471

6572
public RegistryProto.Registry getRegistry() {
@@ -115,4 +122,12 @@ public FeatureServiceProto.FeatureServiceSpec getFeatureServiceSpec(String name)
115122
}
116123
return spec;
117124
}
125+
126+
public String getEntityJoinKey(String name) {
127+
String joinKey = entityNameToJoinKey.get(name);
128+
if (joinKey == null) {
129+
throw new SpecRetrievalException(String.format("Unable to find entity with name: %s", name));
130+
}
131+
return joinKey;
132+
}
118133
}

java/serving/src/main/java/feast/serving/registry/RegistryRepository.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,8 @@ public Duration getMaxAge(ServingAPIProto.FeatureReferenceV2 featureReference) {
102102
public List<String> getEntitiesList(ServingAPIProto.FeatureReferenceV2 featureReference) {
103103
return getFeatureViewSpec(featureReference).getEntitiesList();
104104
}
105+
106+
public String getEntityJoinKey(String name) {
107+
return this.registry.getEntityJoinKey(name);
108+
}
105109
}

java/serving/src/main/java/feast/serving/service/OnlineServingServiceV2.java

Lines changed: 85 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import feast.serving.registry.RegistryRepository;
3535
import feast.serving.util.Metrics;
3636
import feast.storage.api.retriever.OnlineRetrieverV2;
37-
import io.grpc.Status;
3837
import io.opentracing.Span;
3938
import io.opentracing.Tracer;
4039
import java.util.*;
@@ -51,6 +50,11 @@ public class OnlineServingServiceV2 implements ServingServiceV2 {
5150
private final OnlineTransformationService onlineTransformationService;
5251
private final String project;
5352

53+
public static final String DUMMY_ENTITY_ID = "__dummy_id";
54+
public static final String DUMMY_ENTITY_VAL = "";
55+
public static final ValueProto.Value DUMMY_ENTITY_VALUE =
56+
ValueProto.Value.newBuilder().setStringVal(DUMMY_ENTITY_VAL).build();
57+
5458
public OnlineServingServiceV2(
5559
OnlineRetrieverV2 retriever,
5660
Tracer tracer,
@@ -103,31 +107,18 @@ public ServingAPIProto.GetOnlineFeaturesResponse getOnlineFeatures(
103107

104108
List<Map<String, ValueProto.Value>> entityRows = getEntityRows(request);
105109

106-
List<String> entityNames;
107-
if (retrievedFeatureReferences.size() > 0) {
108-
entityNames = this.registryRepository.getEntitiesList(retrievedFeatureReferences.get(0));
109-
} else {
110-
throw new RuntimeException("Requested features list must not be empty");
111-
}
112-
113110
Span storageRetrievalSpan = tracer.buildSpan("storageRetrieval").start();
114111
if (storageRetrievalSpan != null) {
115112
storageRetrievalSpan.setTag("entities", entityRows.size());
116113
storageRetrievalSpan.setTag("features", retrievedFeatureReferences.size());
117114
}
115+
118116
List<List<feast.storage.api.retriever.Feature>> features =
119-
retriever.getOnlineFeatures(entityRows, retrievedFeatureReferences, entityNames);
117+
retrieveFeatures(retrievedFeatureReferences, entityRows);
120118

121119
if (storageRetrievalSpan != null) {
122120
storageRetrievalSpan.finish();
123121
}
124-
if (features.size() != entityRows.size()) {
125-
throw Status.INTERNAL
126-
.withDescription(
127-
"The no. of FeatureRow obtained from OnlineRetriever"
128-
+ "does not match no. of entityRow passed.")
129-
.asRuntimeException();
130-
}
131122

132123
Span postProcessingSpan = tracer.buildSpan("postProcessing").start();
133124

@@ -255,6 +246,84 @@ private List<Map<String, ValueProto.Value>> getEntityRows(
255246
return entityRows;
256247
}
257248

249+
private List<List<feast.storage.api.retriever.Feature>> retrieveFeatures(
250+
List<FeatureReferenceV2> featureReferences, List<Map<String, ValueProto.Value>> entityRows) {
251+
// Prepare feature reference to index mapping. This mapping will be used to arrange the
252+
// retrieved features to the same order as in the input.
253+
if (featureReferences.isEmpty()) {
254+
throw new RuntimeException("Requested features list must not be empty.");
255+
}
256+
Map<FeatureReferenceV2, Integer> featureReferenceToIndexMap =
257+
new HashMap<>(featureReferences.size());
258+
for (int i = 0; i < featureReferences.size(); i++) {
259+
FeatureReferenceV2 featureReference = featureReferences.get(i);
260+
if (featureReferenceToIndexMap.containsKey(featureReference)) {
261+
throw new RuntimeException(
262+
String.format(
263+
"Found duplicate features %s:%s.",
264+
featureReference.getFeatureViewName(), featureReference.getFeatureName()));
265+
}
266+
featureReferenceToIndexMap.put(featureReference, i);
267+
}
268+
269+
// Create placeholders for retrieved features.
270+
List<List<feast.storage.api.retriever.Feature>> features = new ArrayList<>(entityRows.size());
271+
for (int i = 0; i < entityRows.size(); i++) {
272+
List<feast.storage.api.retriever.Feature> featuresPerEntity =
273+
new ArrayList<>(featureReferences.size());
274+
for (int j = 0; j < featureReferences.size(); j++) {
275+
featuresPerEntity.add(null);
276+
}
277+
features.add(featuresPerEntity);
278+
}
279+
280+
// Group feature references by join keys.
281+
Map<String, List<FeatureReferenceV2>> groupNameToFeatureReferencesMap =
282+
featureReferences.stream()
283+
.collect(
284+
Collectors.groupingBy(
285+
featureReference ->
286+
this.registryRepository.getEntitiesList(featureReference).stream()
287+
.map(this.registryRepository::getEntityJoinKey)
288+
.sorted()
289+
.collect(Collectors.joining(","))));
290+
291+
// Retrieve features one group at a time.
292+
for (List<FeatureReferenceV2> featureReferencesPerGroup :
293+
groupNameToFeatureReferencesMap.values()) {
294+
List<String> entityNames =
295+
this.registryRepository.getEntitiesList(featureReferencesPerGroup.get(0));
296+
List<Map<String, ValueProto.Value>> entityRowsPerGroup = new ArrayList<>(entityRows.size());
297+
for (Map<String, ValueProto.Value> entityRow : entityRows) {
298+
Map<String, ValueProto.Value> entityRowPerGroup = new HashMap<>();
299+
entityNames.stream()
300+
.map(this.registryRepository::getEntityJoinKey)
301+
.forEach(
302+
joinKey -> {
303+
if (joinKey.equals(DUMMY_ENTITY_ID)) {
304+
entityRowPerGroup.put(joinKey, DUMMY_ENTITY_VALUE);
305+
} else {
306+
ValueProto.Value value = entityRow.get(joinKey);
307+
if (value != null) {
308+
entityRowPerGroup.put(joinKey, value);
309+
}
310+
}
311+
});
312+
entityRowsPerGroup.add(entityRowPerGroup);
313+
}
314+
List<List<feast.storage.api.retriever.Feature>> featuresPerGroup =
315+
retriever.getOnlineFeatures(entityRowsPerGroup, featureReferencesPerGroup, entityNames);
316+
for (int i = 0; i < featuresPerGroup.size(); i++) {
317+
for (int j = 0; j < featureReferencesPerGroup.size(); j++) {
318+
int k = featureReferenceToIndexMap.get(featureReferencesPerGroup.get(j));
319+
features.get(i).set(k, featuresPerGroup.get(i).get(j));
320+
}
321+
}
322+
}
323+
324+
return features;
325+
}
326+
258327
private void populateOnDemandFeatures(
259328
List<FeatureReferenceV2> onDemandFeatureReferences,
260329
List<FeatureReferenceV2> onDemandFeatureSources,

java/serving/src/test/java/feast/serving/it/ServingBaseTests.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,5 +172,35 @@ public void shouldGetOnlineFeaturesWithStringEntity() {
172172
}
173173
}
174174

175+
@Test
176+
public void shouldGetOnlineFeaturesFromAllFeatureViews() {
177+
Map<String, ValueProto.RepeatedValue> entityRows =
178+
ImmutableMap.of(
179+
"entity",
180+
ValueProto.RepeatedValue.newBuilder()
181+
.addVal(DataGenerator.createStrValue("key-1"))
182+
.build(),
183+
"driver_id",
184+
ValueProto.RepeatedValue.newBuilder()
185+
.addVal(DataGenerator.createInt64Value(1005))
186+
.build());
187+
188+
ImmutableList<String> featureReferences =
189+
ImmutableList.of(
190+
"feature_view_0:feature_0",
191+
"feature_view_0:feature_1",
192+
"driver_hourly_stats:conv_rate",
193+
"driver_hourly_stats:avg_daily_trips");
194+
195+
ServingAPIProto.GetOnlineFeaturesRequest req =
196+
TestUtils.createOnlineFeatureRequest(featureReferences, entityRows);
197+
198+
ServingAPIProto.GetOnlineFeaturesResponse resp = servingStub.getOnlineFeatures(req);
199+
200+
for (final int featureIdx : List.of(0, 1, 2, 3)) {
201+
assertEquals(FieldStatus.PRESENT, resp.getResults(featureIdx).getStatuses(0));
202+
}
203+
}
204+
175205
abstract void updateRegistryFile(RegistryProto.Registry registry);
176206
}

java/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ public void shouldReturnResponseWithValuesAndMetadataIfKeysPresent() {
170170
.thenReturn(featureSpecs.get(0));
171171
when(registry.getFeatureSpec(mockedFeatureRows.get(3).getFeatureReference()))
172172
.thenReturn(featureSpecs.get(1));
173+
when(registry.getEntityJoinKey("entity1")).thenReturn("entity1");
174+
when(registry.getEntityJoinKey("entity2")).thenReturn("entity2");
173175

174176
when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));
175177

@@ -237,6 +239,8 @@ public void shouldReturnResponseWithUnsetValuesAndMetadataIfKeysNotPresent() {
237239
.thenReturn(featureSpecs.get(0));
238240
when(registry.getFeatureSpec(mockedFeatureRows.get(1).getFeatureReference()))
239241
.thenReturn(featureSpecs.get(1));
242+
when(registry.getEntityJoinKey("entity1")).thenReturn("entity1");
243+
when(registry.getEntityJoinKey("entity2")).thenReturn("entity2");
240244

241245
when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));
242246

@@ -314,6 +318,8 @@ public void shouldReturnResponseWithValuesAndMetadataIfMaxAgeIsExceeded() {
314318
.thenReturn(featureSpecs.get(1));
315319
when(registry.getFeatureSpec(mockedFeatureRows.get(5).getFeatureReference()))
316320
.thenReturn(featureSpecs.get(0));
321+
when(registry.getEntityJoinKey("entity1")).thenReturn("entity1");
322+
when(registry.getEntityJoinKey("entity2")).thenReturn("entity2");
317323

318324
when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class));
319325

0 commit comments

Comments
 (0)