Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 78 additions & 11 deletions xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@

package io.grpc.xds;

import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY;
import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY;

import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.auth.oauth2.IdTokenCredentials;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.UnsignedLongs;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
Expand All @@ -34,8 +39,11 @@
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.StatusOr;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper;
import io.grpc.xds.MetadataRegistry.MetadataValueParser;
import io.grpc.xds.XdsConfig.XdsClusterConfig;
import io.grpc.xds.client.XdsResourceType.ResourceInvalidException;
import java.util.LinkedHashMap;
import java.util.Map;
Expand All @@ -52,6 +60,13 @@
static final String TYPE_URL =
"type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig";

final String filterInstanceName;

GcpAuthenticationFilter(String name) {
filterInstanceName = checkNotNull(name, "name");
}


static final class Provider implements Filter.Provider {
@Override
public String[] typeUrls() {
Expand All @@ -65,7 +80,7 @@

@Override
public GcpAuthenticationFilter newInstance(String name) {
return new GcpAuthenticationFilter();
return new GcpAuthenticationFilter(name);

Check warning on line 83 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L83

Added line #L83 was not covered by tests
}

@Override
Expand Down Expand Up @@ -119,22 +134,66 @@
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {

/*String clusterName = callOptions.getOption(XdsAttributes.ATTR_CLUSTER_NAME);
String clusterName = callOptions.getOption(CLUSTER_SELECTION_KEY);
if (clusterName == null) {
return new FailingClientCall<>(
Status.UNAVAILABLE.withDescription(
String.format(

Check warning on line 141 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L139-L141

Added lines #L139 - L141 were not covered by tests
"GCP Authn for %s does not contain cluster resource", filterInstanceName)));
}

if (!clusterName.startsWith("cluster:")) {
return next.newCall(method, callOptions);
}*/
}
XdsConfig xdsConfig = callOptions.getOption(XDS_CONFIG_CALL_OPTION_KEY);
if (xdsConfig == null) {
return new FailingClientCall<>(
Status.UNAVAILABLE.withDescription(
String.format(

Check warning on line 152 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L150-L152

Added lines #L150 - L152 were not covered by tests
"GCP Authn for %s with %s does not contain xds configuration",
filterInstanceName, clusterName)));
}

// TODO: Fetch the CDS resource for the cluster.
// If the CDS resource is not available, fail the RPC with Status.UNAVAILABLE.
StatusOr<XdsClusterConfig> xdsCluster =
xdsConfig.getClusters().get(clusterName.substring(8)); // get rid of prefix "cluster:"
if (xdsCluster == null) {
return new FailingClientCall<>(
Status.UNAVAILABLE.withDescription(
String.format(

Check warning on line 162 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L160-L162

Added lines #L160 - L162 were not covered by tests
"GCP Authn for %s with %s does not contain xds cluster",
filterInstanceName, clusterName)));
}

// TODO: Extract the audience from the CDS resource metadata.
// If the audience is not found or is in the wrong format, fail the RPC.
String audience = "TEST_AUDIENCE";
if (!xdsCluster.hasValue()) {
return next.newCall(method, callOptions);

Check warning on line 168 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L168

Added line #L168 was not covered by tests
}
ImmutableMap<String, Object> parsedMetadata = xdsCluster.getValue().getClusterResource()
.parsedMetadata();

if (parsedMetadata == null || !parsedMetadata.containsKey(filterInstanceName)) {
return next.newCall(method, callOptions);

Check warning on line 174 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L174

Added line #L174 was not covered by tests
}

AudienceWrapper audience;
if (parsedMetadata.get(filterInstanceName) instanceof AudienceWrapper) {
audience = (AudienceWrapper) parsedMetadata.get(filterInstanceName);
if (audience.audience == null) {
return next.newCall(method, callOptions);

Check warning on line 181 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L181

Added line #L181 was not covered by tests
}
}
else {
return new FailingClientCall<>(
Status.UNAVAILABLE.withDescription(
String.format("GCP Authn found wrong type in %s metadata: %s=%s",

Check warning on line 187 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L185-L187

Added lines #L185 - L187 were not covered by tests
clusterName, filterInstanceName,
parsedMetadata.get(filterInstanceName) == null
? null : parsedMetadata.get(filterInstanceName))));

Check warning on line 190 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L190

Added line #L190 was not covered by tests
}

try {
CallCredentials existingCallCredentials = callOptions.getCredentials();
CallCredentials newCallCredentials =
getCallCredentials(callCredentialsCache, audience, credentials);
getCallCredentials(callCredentialsCache, audience.audience, credentials);
if (existingCallCredentials != null) {
callOptions = callOptions.withCallCredentials(
new CompositeCallCredentials(existingCallCredentials, newCallCredentials));
Expand Down Expand Up @@ -235,13 +294,21 @@

static class AudienceMetadataParser implements MetadataValueParser {

static final class AudienceWrapper {
final String audience;

AudienceWrapper(String audience) {
this.audience = audience;
}
}

@Override
public String getTypeUrl() {
return "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.Audience";
}

@Override
public String parse(Any any) throws ResourceInvalidException {
public AudienceWrapper parse(Any any) throws ResourceInvalidException {
Audience audience;
try {
audience = any.unpack(Audience.class);
Expand All @@ -253,7 +320,7 @@
throw new ResourceInvalidException(
"Audience URL is empty. Metadata value must contain a valid URL.");
}
return url;
return new AudienceWrapper(url);
}
}
}
22 changes: 17 additions & 5 deletions xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package io.grpc.xds;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY;
import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;

import com.google.protobuf.Any;
import com.google.protobuf.Empty;
Expand All @@ -34,6 +37,7 @@
import io.grpc.Channel;
import io.grpc.ClientInterceptor;
import io.grpc.MethodDescriptor;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.testing.TestMethodDescriptors;
import io.grpc.xds.GcpAuthenticationFilter.GcpAuthenticationConfig;
import org.junit.Test;
Expand Down Expand Up @@ -92,21 +96,29 @@ public void testParseFilterConfig_withInvalidMessageType() {
}

@Test
public void testClientInterceptor_createsAndReusesCachedCredentials() {
public void testClientInterceptor_createsAndReusesCachedCredentials() throws Exception {
String serverName = InProcessServerBuilder.generateName();
XdsConfig defaultXdsConfig = XdsTestUtils.getDefaultXdsConfigWithCdsUpdate(serverName);

GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter();
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");

// Create interceptor
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();

// Mock channel and capture CallOptions
Channel mockChannel = Mockito.mock(Channel.class);
Channel mockChannel = mock(Channel.class);
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class);

// Set CallOptions with required keys
CallOptions callOptionsWithXds = CallOptions.DEFAULT
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);

// Execute interception twice to check caching
interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel);
interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel);
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);

// Capture and verify CallOptions for CallCredentials presence
Mockito.verify(mockChannel, Mockito.times(2))
Expand Down
15 changes: 8 additions & 7 deletions xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
import io.grpc.xds.Endpoints.LbEndpoint;
import io.grpc.xds.Endpoints.LocalityLbEndpoints;
import io.grpc.xds.Filter.FilterConfig;
import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper;
import io.grpc.xds.MetadataRegistry.MetadataValueParser;
import io.grpc.xds.RouteLookupServiceClusterSpecifierPlugin.RlsPluginConfig;
import io.grpc.xds.VirtualHost.Route;
Expand Down Expand Up @@ -2417,8 +2418,7 @@ public Object parse(Any value) {
}

@Test
public void processCluster_parsesAudienceMetadata()
throws ResourceInvalidException, InvalidProtocolBufferException {
public void processCluster_parsesAudienceMetadata() throws Exception {
MetadataRegistry.getInstance();

Audience audience = Audience.newBuilder()
Expand Down Expand Up @@ -2462,7 +2462,10 @@ public void processCluster_parsesAudienceMetadata()
"FILTER_METADATA", ImmutableMap.of(
"key1", "value1",
"key2", 42.0));
assertThat(update.parsedMetadata()).isEqualTo(expectedParsedMetadata);
assertThat(update.parsedMetadata().get("FILTER_METADATA"))
.isEqualTo(expectedParsedMetadata.get("FILTER_METADATA"));
assertThat(update.parsedMetadata().get("AUDIENCE_METADATA"))
.isInstanceOf(AudienceWrapper.class);
}

@Test
Expand Down Expand Up @@ -2519,8 +2522,7 @@ public void processCluster_parsesAddressMetadata() throws Exception {
}

@Test
public void processCluster_metadataKeyCollision_resolvesToTypedMetadata()
throws ResourceInvalidException, InvalidProtocolBufferException {
public void processCluster_metadataKeyCollision_resolvesToTypedMetadata() throws Exception {
MetadataRegistry metadataRegistry = MetadataRegistry.getInstance();

MetadataValueParser testParser =
Expand Down Expand Up @@ -2575,8 +2577,7 @@ public Object parse(Any value) {
}

@Test
public void parseNonAggregateCluster_withHttp11ProxyTransportSocket()
throws ResourceInvalidException, InvalidProtocolBufferException {
public void parseNonAggregateCluster_withHttp11ProxyTransportSocket() throws Exception {
XdsClusterResource.isEnabledXdsHttpConnect = true;

Http11ProxyUpstreamTransport http11ProxyUpstreamTransport =
Expand Down
59 changes: 59 additions & 0 deletions xds/src/test/java/io/grpc/xds/XdsTestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
import io.grpc.stub.StreamObserver;
import io.grpc.xds.Endpoints.LbEndpoint;
import io.grpc.xds.Endpoints.LocalityLbEndpoints;
import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper;
import io.grpc.xds.XdsClusterResource.CdsUpdate;
import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig;
import io.grpc.xds.client.Bootstrapper;
import io.grpc.xds.client.Locality;
Expand Down Expand Up @@ -280,6 +282,63 @@ static XdsConfig getDefaultXdsConfig(String serverHostName)
return builder.build();
}

static XdsConfig getDefaultXdsConfigWithCdsUpdate(String serverHostName)
throws XdsResourceType.ResourceInvalidException, IOException {
XdsConfig.XdsConfigBuilder builder = new XdsConfig.XdsConfigBuilder();

Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig(
serverHostName, RouterFilter.ROUTER_CONFIG);

HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName(
0L, RDS_NAME, Collections.singletonList(routerFilterConfig));
XdsListenerResource.LdsUpdate ldsUpdate =
XdsListenerResource.LdsUpdate.forApiListener(httpConnectionManager);

RouteConfiguration routeConfiguration =
buildRouteConfiguration(serverHostName, RDS_NAME, CLUSTER_NAME);
Bootstrapper.ServerInfo serverInfo = null;
XdsResourceType.Args args = new XdsResourceType.Args(serverInfo, "0", "0", null, null, null);
XdsRouteConfigureResource.RdsUpdate rdsUpdate =
XdsRouteConfigureResource.getInstance().doParse(args, routeConfiguration);

// Take advantage of knowing that there is only 1 virtual host in the route configuration
assertThat(rdsUpdate.virtualHosts).hasSize(1);
VirtualHost virtualHost = rdsUpdate.virtualHosts.get(0);

// Need to create endpoints to create locality endpoints map to create edsUpdate
Map<Locality, LocalityLbEndpoints> lbEndpointsMap = new HashMap<>();
LbEndpoint lbEndpoint = LbEndpoint.create(
serverHostName, ENDPOINT_PORT, 0, true, ENDPOINT_HOSTNAME, ImmutableMap.of());
lbEndpointsMap.put(
Locality.create("", "", ""),
LocalityLbEndpoints.create(ImmutableList.of(lbEndpoint), 10, 0, ImmutableMap.of()));

// Need to create EdsUpdate to create CdsUpdate to create XdsClusterConfig for builder
XdsEndpointResource.EdsUpdate edsUpdate = new XdsEndpointResource.EdsUpdate(
EDS_NAME, lbEndpointsMap, Collections.emptyList());

// Use ImmutableMap.Builder to construct the map
ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("TEST_AUDIENCE"));

CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds(
CLUSTER_NAME, EDS_NAME, null, null, null, null, false)
.lbPolicyConfig(getWrrLbConfigAsMap());
cdsUpdate.parsedMetadata(parsedMetadata.build());
XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
CLUSTER_NAME,
cdsUpdate.build(),
new EndpointConfig(StatusOr.fromValue(edsUpdate)));

builder
.setListener(ldsUpdate)
.setRoute(rdsUpdate)
.setVirtualHost(virtualHost)
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig));

return builder.build();
}

static Map<Locality, LocalityLbEndpoints> createMinimalLbEndpointsMap(String serverHostName) {
Map<Locality, LocalityLbEndpoints> lbEndpointsMap = new HashMap<>();
LbEndpoint lbEndpoint = LbEndpoint.create(
Expand Down