Skip to content

Commit 4b52639

Browse files
authored
xds: implement per-RPC hash generation (#7922)
Generates a hash value for each RPC based on the HashPolicies configured for the Route that the RPC is routed to.
1 parent b5c0a4a commit 4b52639

File tree

4 files changed

+268
-124
lines changed

4 files changed

+268
-124
lines changed

xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
interface ThreadSafeRandom {
2424
int nextInt(int bound);
2525

26+
long nextLong();
27+
2628
final class ThreadSafeRandomImpl implements ThreadSafeRandom {
2729

2830
static final ThreadSafeRandom instance = new ThreadSafeRandomImpl();
@@ -33,5 +35,10 @@ private ThreadSafeRandomImpl() {}
3335
public int nextInt(int bound) {
3436
return ThreadLocalRandom.current().nextInt(bound);
3537
}
38+
39+
@Override
40+
public long nextLong() {
41+
return ThreadLocalRandom.current().nextLong();
42+
}
3643
}
3744
}

xds/src/main/java/io/grpc/xds/XdsNameResolver.java

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import io.grpc.xds.VirtualHost.Route;
5858
import io.grpc.xds.VirtualHost.Route.RouteAction;
5959
import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight;
60+
import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy;
6061
import io.grpc.xds.VirtualHost.Route.RouteMatch;
6162
import io.grpc.xds.XdsClient.LdsResourceWatcher;
6263
import io.grpc.xds.XdsClient.LdsUpdate;
@@ -95,6 +96,8 @@ final class XdsNameResolver extends NameResolver {
9596

9697
static final CallOptions.Key<String> CLUSTER_SELECTION_KEY =
9798
CallOptions.Key.create("io.grpc.xds.CLUSTER_SELECTION_KEY");
99+
static final CallOptions.Key<Long> RPC_HASH_KEY =
100+
CallOptions.Key.create("io.grpc.xds.RPC_HASH_KEY");
98101
@VisibleForTesting
99102
static boolean enableTimeout =
100103
Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT"));
@@ -119,13 +122,15 @@ final class XdsNameResolver extends NameResolver {
119122
@VisibleForTesting
120123
static AtomicLong activeFaultInjectedStreamCounter = new AtomicLong();
121124

125+
private final InternalLogId logId;
122126
private final XdsLogger logger;
123127
private final String authority;
124128
private final ServiceConfigParser serviceConfigParser;
125129
private final SynchronizationContext syncContext;
126130
private final ScheduledExecutorService scheduler;
127131
private final XdsClientPoolFactory xdsClientPoolFactory;
128132
private final ThreadSafeRandom random;
133+
private final XxHash64 hashFunc = XxHash64.INSTANCE;
129134
private final ConcurrentMap<String, AtomicInteger> clusterRefs = new ConcurrentHashMap<>();
130135
private final ConfigSelector configSelector = new ConfigSelector();
131136

@@ -152,7 +157,8 @@ final class XdsNameResolver extends NameResolver {
152157
this.scheduler = checkNotNull(scheduler, "scheduler");
153158
this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory");
154159
this.random = checkNotNull(random, "random");
155-
logger = XdsLogger.withLogId(InternalLogId.allocate("xds-resolver", name));
160+
logId = InternalLogId.allocate("xds-resolver", name);
161+
logger = XdsLogger.withLogId(logId);
156162
logger.log(XdsLogLevel.INFO, "Created resolver for {0}", name);
157163
}
158164

@@ -347,26 +353,33 @@ static boolean matchHostName(String hostName, String pattern) {
347353
private final class ConfigSelector extends InternalConfigSelector {
348354
@Override
349355
public Result selectConfig(PickSubchannelArgs args) {
350-
// Index ASCII headers by keys.
351-
Map<String, Iterable<String>> asciiHeaders = new HashMap<>();
356+
// Index ASCII headers by key, multi-value headers are concatenated for matching purposes.
357+
Map<String, String> asciiHeaders = new HashMap<>();
352358
Metadata headers = args.getHeaders();
353359
for (String headerName : headers.keys()) {
354360
if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
355361
continue;
356362
}
357363
Metadata.Key<String> key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER);
358-
asciiHeaders.put(headerName, headers.getAll(key));
364+
Iterable<String> values = headers.getAll(key);
365+
if (values != null) {
366+
asciiHeaders.put(headerName, Joiner.on(",").join(values));
367+
}
359368
}
369+
// Special hack for exposing headers: "content-type".
370+
asciiHeaders.put("content-type", "application/grpc");
360371
String cluster = null;
361372
Route selectedRoute = null;
362373
HttpFault selectedFaultConfig;
374+
RoutingConfig routingCfg;
363375
do {
364-
selectedFaultConfig = routingConfig.faultConfig;
365-
for (Route route : routingConfig.routes) {
376+
routingCfg = routingConfig;
377+
selectedFaultConfig = routingCfg.faultConfig;
378+
for (Route route : routingCfg.routes) {
366379
if (matchRoute(route.routeMatch(), "/" + args.getMethodDescriptor().getFullMethodName(),
367380
asciiHeaders, random)) {
368381
selectedRoute = route;
369-
if (routingConfig.applyFaultInjection && route.httpFault() != null) {
382+
if (routingCfg.applyFaultInjection && route.httpFault() != null) {
370383
selectedFaultConfig = route.httpFault();
371384
}
372385
break;
@@ -390,7 +403,7 @@ public Result selectConfig(PickSubchannelArgs args) {
390403
accumulator += weightedCluster.weight();
391404
if (select < accumulator) {
392405
cluster = weightedCluster.name();
393-
if (routingConfig.applyFaultInjection && weightedCluster.httpFault() != null) {
406+
if (routingCfg.applyFaultInjection && weightedCluster.httpFault() != null) {
394407
selectedFaultConfig = weightedCluster.httpFault();
395408
}
396409
break;
@@ -403,7 +416,7 @@ public Result selectConfig(PickSubchannelArgs args) {
403416
if (enableTimeout) {
404417
Long timeoutNano = selectedRoute.routeAction().timeoutNano();
405418
if (timeoutNano == null) {
406-
timeoutNano = routingConfig.fallbackTimeoutNano;
419+
timeoutNano = routingCfg.fallbackTimeoutNano;
407420
}
408421
if (timeoutNano > 0) {
409422
rawServiceConfig = generateServiceConfigWithMethodTimeoutConfig(timeoutNano);
@@ -417,7 +430,6 @@ public Result selectConfig(PickSubchannelArgs args) {
417430
parsedServiceConfig.getError().augmentDescription(
418431
"Failed to parse service config (method config)"));
419432
}
420-
final String finalCluster = cluster;
421433
if (selectedFaultConfig != null && selectedFaultConfig.maxActiveFaults() != null
422434
&& activeFaultInjectedStreamCounter.get() >= selectedFaultConfig.maxActiveFaults()) {
423435
selectedFaultConfig = null;
@@ -447,15 +459,18 @@ public Result selectConfig(PickSubchannelArgs args) {
447459
abortStatus = determineFaultAbortStatus(selectedFaultConfig.faultAbort(), headers);
448460
}
449461
}
462+
final String finalCluster = cluster;
450463
final Long finalDelayNanos = delayNanos;
451464
final Status finalAbortStatus = abortStatus;
465+
final long hash = generateHash(selectedRoute.routeAction().hashPolicies(), asciiHeaders);
452466
class ClusterSelectionInterceptor implements ClientInterceptor {
453467
@Override
454468
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
455469
final MethodDescriptor<ReqT, RespT> method, CallOptions callOptions,
456470
final Channel next) {
457471
final CallOptions callOptionsForCluster =
458-
callOptions.withOption(CLUSTER_SELECTION_KEY, finalCluster);
472+
callOptions.withOption(CLUSTER_SELECTION_KEY, finalCluster)
473+
.withOption(RPC_HASH_KEY, hash);
459474
Supplier<ClientCall<ReqT, RespT>> configApplyingCallSupplier =
460475
new Supplier<ClientCall<ReqT, RespT>>() {
461476
@Override
@@ -553,6 +568,36 @@ public void run() {
553568
}
554569
}
555570

571+
private long generateHash(List<HashPolicy> hashPolicies, Map<String, String> headers) {
572+
Long hash = null;
573+
for (HashPolicy policy : hashPolicies) {
574+
Long newHash = null;
575+
if (policy.type() == HashPolicy.Type.HEADER) {
576+
if (headers.containsKey(policy.headerName())) {
577+
String value = headers.get(policy.headerName());
578+
if (policy.regEx() != null && policy.regExSubstitution() != null) {
579+
value = policy.regEx().matcher(value).replaceAll(policy.regExSubstitution());
580+
}
581+
newHash = hashFunc.hashAsciiString(value);
582+
}
583+
} else if (policy.type() == HashPolicy.Type.CHANNEL_ID) {
584+
newHash = hashFunc.hashLong(logId.getId());
585+
}
586+
if (newHash != null ) {
587+
// Rotating the old value prevents duplicate hash rules from cancelling each other out
588+
// and preserves all of the entropy.
589+
long oldHash = hash != null ? ((hash << 1L) | (hash >> 63L)) : 0;
590+
hash = oldHash ^ newHash;
591+
}
592+
// If the policy is a terminal policy and a hash has been generated, ignore
593+
// the rest of the hash policies.
594+
if (policy.isTerminal() && hash != null) {
595+
break;
596+
}
597+
}
598+
return hash == null ? random.nextLong() : hash;
599+
}
600+
556601
@Nullable
557602
private Long determineFaultDelayNanos(FaultDelay faultDelay, Metadata headers) {
558603
Long delayNanos;
@@ -748,7 +793,7 @@ public void sendMessage(ReqT message) {}
748793

749794
@VisibleForTesting
750795
static boolean matchRoute(RouteMatch routeMatch, String fullMethodName,
751-
Map<String, Iterable<String>> headers, ThreadSafeRandom random) {
796+
Map<String, String> headers, ThreadSafeRandom random) {
752797
if (!matchPath(routeMatch.pathMatcher(), fullMethodName)) {
753798
return false;
754799
}
@@ -774,52 +819,41 @@ static boolean matchPath(PathMatcher pathMatcher, String fullMethodName) {
774819
}
775820

776821
private static boolean matchHeaders(
777-
List<HeaderMatcher> headerMatchers, Map<String, Iterable<String>> headers) {
822+
List<HeaderMatcher> headerMatchers, Map<String, String> headers) {
778823
for (HeaderMatcher headerMatcher : headerMatchers) {
779-
Iterable<String> headerValues = headers.get(headerMatcher.name());
780-
// Special cases for hiding headers: "grpc-previous-rpc-attempts".
781-
if (headerMatcher.name().equals("grpc-previous-rpc-attempts")) {
782-
headerValues = null;
783-
}
784-
// Special case for exposing headers: "content-type".
785-
if (headerMatcher.name().equals("content-type")) {
786-
headerValues = Collections.singletonList("application/grpc");
787-
}
788-
if (!matchHeader(headerMatcher, headerValues)) {
824+
if (!matchHeader(headerMatcher, headers.get(headerMatcher.name()))) {
789825
return false;
790826
}
791827
}
792828
return true;
793829
}
794830

795831
@VisibleForTesting
796-
static boolean matchHeader(HeaderMatcher headerMatcher,
797-
@Nullable Iterable<String> headerValues) {
832+
static boolean matchHeader(HeaderMatcher headerMatcher, @Nullable String value) {
798833
if (headerMatcher.present() != null) {
799-
return (headerValues == null) == headerMatcher.present().equals(headerMatcher.inverted());
834+
return (value == null) == headerMatcher.present().equals(headerMatcher.inverted());
800835
}
801-
if (headerValues == null) {
836+
if (value == null) {
802837
return false;
803838
}
804-
String valueStr = Joiner.on(",").join(headerValues);
805839
boolean baseMatch;
806840
if (headerMatcher.exactValue() != null) {
807-
baseMatch = headerMatcher.exactValue().equals(valueStr);
841+
baseMatch = headerMatcher.exactValue().equals(value);
808842
} else if (headerMatcher.safeRegEx() != null) {
809-
baseMatch = headerMatcher.safeRegEx().matches(valueStr);
843+
baseMatch = headerMatcher.safeRegEx().matches(value);
810844
} else if (headerMatcher.range() != null) {
811845
long numValue;
812846
try {
813-
numValue = Long.parseLong(valueStr);
847+
numValue = Long.parseLong(value);
814848
baseMatch = numValue >= headerMatcher.range().start()
815849
&& numValue <= headerMatcher.range().end();
816850
} catch (NumberFormatException ignored) {
817851
baseMatch = false;
818852
}
819853
} else if (headerMatcher.prefix() != null) {
820-
baseMatch = valueStr.startsWith(headerMatcher.prefix());
854+
baseMatch = value.startsWith(headerMatcher.prefix());
821855
} else {
822-
baseMatch = valueStr.endsWith(headerMatcher.suffix());
856+
baseMatch = value.endsWith(headerMatcher.suffix());
823857
}
824858
return baseMatch != headerMatcher.inverted();
825859
}
@@ -1033,7 +1067,7 @@ public void run() {
10331067
}
10341068

10351069
/**
1036-
* Grouping of the list of usable routes and their corresponding fallback timeout value.
1070+
* VirtualHost-level configuration for request routing.
10371071
*/
10381072
private static class RoutingConfig {
10391073
private final long fallbackTimeoutNano;
@@ -1042,7 +1076,7 @@ private static class RoutingConfig {
10421076
@Nullable
10431077
private final HttpFault faultConfig;
10441078

1045-
private static RoutingConfig empty =
1079+
private static final RoutingConfig empty =
10461080
new RoutingConfig(0L, Collections.<Route>emptyList(), false, null);
10471081

10481082
private RoutingConfig(

xds/src/test/java/io/grpc/xds/WeightedRandomPickerTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ public int nextInt(int bound) {
9797
assertThat(nextInt).isLessThan(bound);
9898
return nextInt;
9999
}
100+
101+
@Override
102+
public long nextLong() {
103+
throw new UnsupportedOperationException("Should not be called");
104+
}
100105
}
101106

102107
private final FakeRandom fakeRandom = new FakeRandom();

0 commit comments

Comments
 (0)