Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,14 @@ private void SetMessageContent(TRequest request, HttpRequestMessage message)
GrpcProtocolConstants.GrpcContentTypeHeaderValue);
}

public void CancelCallFromCancellationToken()
{
using (StartScope())
{
CancelCall(new Status(StatusCode.Cancelled, "Call canceled by the client."));
}
}

private void CancelCall(Status status)
{
// Set overall call status first. Status can be used in throw RpcException from cancellation.
Expand Down Expand Up @@ -690,13 +698,7 @@ public Exception CreateFailureStatusException(Status status)
// The cancellation token will cancel the call CTS.
// This must be registered after the client writer has been created
// so that cancellation will always complete the writer.
_ctsRegistration = Options.CancellationToken.Register(() =>
{
using (StartScope())
{
CancelCall(new Status(StatusCode.Cancelled, "Call canceled by the client."));
}
});
_ctsRegistration = Options.CancellationToken.Register(CancelCallFromCancellationToken);
}

return (diagnosticSourceEnabled, activity);
Expand Down
17 changes: 6 additions & 11 deletions src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,16 @@ public Task<bool> MoveNext(CancellationToken cancellationToken)

private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
{
CancellationTokenSource? cts = null;
CancellationTokenRegistration? ctsRegistration = null;
try
{
// Linking tokens is expensive. Only create a linked token if the token passed in requires it
if (cancellationToken.CanBeCanceled)
{
cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _call.CancellationToken);
cancellationToken = cts.Token;
}
else
{
cancellationToken = _call.CancellationToken;
// The cancellation token will cancel the call CTS.
ctsRegistration = cancellationToken.Register(_call.CancelCallFromCancellationToken);
}

cancellationToken.ThrowIfCancellationRequested();
_call.CancellationToken.ThrowIfCancellationRequested();

if (_httpResponse == null)
{
Expand Down Expand Up @@ -167,7 +162,7 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
_responseStream,
_grpcEncoding,
singleMessage: false,
cancellationToken).ConfigureAwait(false);
_call.CancellationToken).ConfigureAwait(false);
if (Current == null)
{
// No more content in response so report status to call.
Expand Down Expand Up @@ -202,7 +197,7 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
}
finally
{
cts?.Dispose();
ctsRegistration?.Dispose();
}
}

Expand Down
48 changes: 48 additions & 0 deletions test/FunctionalTests/Client/StreamingTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
Expand Down Expand Up @@ -515,6 +516,53 @@ await context.WriteResponseHeadersAsync(new Metadata
Assert.AreEqual("Message", call.GetStatus().Detail);
}

[Test]
public async Task DuplexStreaming_CancelResponseMoveNext_CancellationSentToServer()
{
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);

async Task DuplexStreamingWithCancellation(IAsyncStreamReader<DataMessage> requestStream, IServerStreamWriter<DataMessage> responseStream, ServerCallContext context)
{
try
{
await foreach (var message in requestStream.ReadAllAsync())
{
await responseStream.WriteAsync(message);
}
}
catch (Exception ex)
{
tcs.TrySetException(ex);
}
}

// Arrange
var method = Fixture.DynamicGrpc.AddDuplexStreamingMethod<DataMessage, DataMessage>(DuplexStreamingWithCancellation);

var channel = CreateChannel();

var client = TestClientFactory.Create(channel, method);

// Act
var call = client.DuplexStreamingCall();

await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(Encoding.UTF8.GetBytes("Hello world")) });

await call.ResponseStream.MoveNext();

var cts = new CancellationTokenSource();
var task = call.ResponseStream.MoveNext(cts.Token);

cts.Cancel();

// Assert
var clientEx = await ExceptionAssert.ThrowsAsync<RpcException>(() => task);
Assert.AreEqual(StatusCode.Cancelled, clientEx.StatusCode);
Assert.AreEqual("Call canceled by the client.", clientEx.Status.Detail);

await ExceptionAssert.ThrowsAsync<IOException>(() => tcs.Task);
}

private static byte[] CreateTestData(int size)
{
var data = new byte[size];
Expand Down
30 changes: 13 additions & 17 deletions test/FunctionalTests/Web/Server/DeadlineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,28 +86,24 @@ static async Task<HelloReply> WaitUntilDeadline(HelloRequest request, ServerCall

var grpcWebClient = CreateGrpcWebClient();

// TODO(JamesNK): This test is/was flaky. Remove loop if this test is no longer a problem
for (int i = 0; i < 20; i++)
var requestMessage = new HelloRequest
{
var requestMessage = new HelloRequest
{
Name = "World"
};
Name = "World"
};

var requestStream = new MemoryStream();
MessageHelpers.WriteMessage(requestStream, requestMessage);
var requestStream = new MemoryStream();
MessageHelpers.WriteMessage(requestStream, requestMessage);

var httpRequest = GrpcHttpHelper.Create(method.FullName);
httpRequest.Headers.Add(GrpcProtocolConstants.TimeoutHeader, "50m");
httpRequest.Content = new GrpcStreamContent(requestStream);
var httpRequest = GrpcHttpHelper.Create(method.FullName);
httpRequest.Headers.Add(GrpcProtocolConstants.TimeoutHeader, "50m");
httpRequest.Content = new GrpcStreamContent(requestStream);

// Act
var response = await grpcWebClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();
// Act
var response = await grpcWebClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();

// Assert
response.AssertIsSuccessfulGrpcRequest();
response.AssertTrailerStatus(StatusCode.DeadlineExceeded, "Deadline Exceeded");
}
// Assert
response.AssertIsSuccessfulGrpcRequest();
response.AssertTrailerStatus(StatusCode.DeadlineExceeded, "Deadline Exceeded");
}

[Test]
Expand Down