1- using System ;
1+ // Copyright (c) Microsoft. All rights reserved.
2+ // Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+ using System ;
5+ using System . Buffers ;
26using System . Collections . Concurrent ;
37using System . Threading ;
48using System . Threading . Tasks ;
9+ using MessagePack ;
510using Microsoft . Azure . SignalR . Common ;
611using Microsoft . Azure . SignalR . Protocol ;
712
@@ -35,15 +40,27 @@ public Task<AckStatus> CreateSingleAck(out int id, TimeSpan? ackTimeout = defaul
3540 {
3641 return Task . FromResult ( AckStatus . Ok ) ;
3742 }
38- var info = ( IAckInfo < AckStatus > ) _acks . GetOrAdd ( id , _ => new SingleAckInfo ( ackTimeout ?? _defaultAckTimeout ) ) ;
39- if ( info is MultiAckInfo )
43+ var info = ( IAckInfo < AckStatus > ) _acks . GetOrAdd ( id , _ => new SingleAckWithStatusInfo ( ackTimeout ?? _defaultAckTimeout ) ) ;
44+ if ( info is MultiAckWithStatusInfo )
4045 {
4146 throw new InvalidOperationException ( ) ;
4247 }
4348 cancellationToken . Register ( ( ) => info . Cancel ( ) ) ;
4449 return info . Task ;
4550 }
4651
52+ public Task < T > CreateSingleAck < T > ( out int id , TimeSpan ? ackTimeout = default , CancellationToken cancellationToken = default ) where T : IMessagePackSerializable , new ( )
53+ {
54+ id = NextId ( ) ;
55+ if ( _disposed )
56+ {
57+ return Task . FromResult ( new T ( ) ) ;
58+ }
59+ var info = ( IAckInfo < IMessagePackSerializable > ) _acks . GetOrAdd ( id , _ => new SingleAckWithMessagePackPayloadInfo < T > ( ackTimeout ?? _defaultAckTimeout ) ) ;
60+ cancellationToken . Register ( info . Cancel ) ;
61+ return info . Task . ContinueWith ( task => ( T ) task . Result ) ;
62+ }
63+
4764 public static bool HandleAckStatus ( IAckableMessage message , AckStatus status )
4865 {
4966 return status switch
@@ -62,29 +79,19 @@ public Task<AckStatus> CreateMultiAck(out int id, TimeSpan? ackTimeout = default
6279 {
6380 return Task . FromResult ( AckStatus . Ok ) ;
6481 }
65- var info = ( IAckInfo < AckStatus > ) _acks . GetOrAdd ( id , _ => new MultiAckInfo ( ackTimeout ?? _defaultAckTimeout ) ) ;
66- if ( info is SingleAckInfo )
82+ var info = ( IAckInfo < AckStatus > ) _acks . GetOrAdd ( id , _ => new MultiAckWithStatusInfo ( ackTimeout ?? _defaultAckTimeout ) ) ;
83+ if ( info is SingleAckInfo < AckStatus > )
6784 {
6885 throw new InvalidOperationException ( ) ;
6986 }
7087 return info . Task ;
7188 }
7289
73- public void TriggerAck ( int id , AckStatus status = AckStatus . Ok )
90+ public void TriggerAck ( int id , AckStatus status = AckStatus . Ok , ReadOnlySequence < byte > ? payload = default )
7491 {
75- if ( _acks . TryGetValue ( id , out var info ) )
92+ if ( _acks . TryGetValue ( id , out var info ) && info . Ack ( status , payload ) )
7693 {
77- switch ( info )
78- {
79- case IAckInfo < AckStatus > ackInfo :
80- if ( ackInfo . Ack ( status ) )
81- {
82- _acks . TryRemove ( id , out _ ) ;
83- }
84- break ;
85- default :
86- throw new InvalidCastException ( $ "Expected: IAckInfo<{ typeof ( IAckInfo < AckStatus > ) . Name } >, actual type: { info . GetType ( ) . Name } ") ;
87- }
94+ _acks . TryRemove ( id , out _ ) ;
8895 }
8996 }
9097
@@ -125,11 +132,11 @@ private void CheckAcks()
125132 {
126133 if ( _acks . TryRemove ( id , out _ ) )
127134 {
128- if ( ack is SingleAckInfo singleAckInfo )
135+ if ( ack is SingleAckInfo < AckStatus > singleAckInfo )
129136 {
130137 singleAckInfo . Ack ( AckStatus . Timeout ) ;
131138 }
132- else if ( ack is MultiAckInfo multipleAckInfo )
139+ else if ( ack is MultiAckWithStatusInfo multipleAckInfo )
133140 {
134141 multipleAckInfo . ForceAck ( AckStatus . Timeout ) ;
135142 }
@@ -170,39 +177,57 @@ private interface IAckInfo
170177 {
171178 DateTime TimeoutAt { get ; }
172179 void Cancel ( ) ;
180+ bool Ack ( AckStatus status , ReadOnlySequence < byte > ? payload = null ) ;
173181 }
174182
175183 private interface IAckInfo < T > : IAckInfo
176184 {
177185 Task < T > Task { get ; }
178- bool Ack ( T status ) ;
179186 }
180187
181188 public interface IMultiAckInfo
182189 {
183190 bool SetExpectedCount ( int expectedCount ) ;
184191 }
185192
186- private sealed class SingleAckInfo : IAckInfo < AckStatus >
193+ private abstract class SingleAckInfo < T > : IAckInfo < T >
187194 {
188- public readonly TaskCompletionSource < AckStatus > _tcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
189-
195+ public readonly TaskCompletionSource < T > _tcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
190196 public DateTime TimeoutAt { get ; }
191-
192197 public SingleAckInfo ( TimeSpan timeout )
193198 {
194199 TimeoutAt = DateTime . UtcNow + timeout ;
195200 }
201+ public abstract bool Ack ( AckStatus status , ReadOnlySequence < byte > ? payload = null ) ;
202+ public Task < T > Task => _tcs . Task ;
203+ public void Cancel ( ) => _tcs . TrySetCanceled ( ) ;
204+ }
196205
197- public bool Ack ( AckStatus status = AckStatus . Ok ) = >
198- _tcs . TrySetResult ( status ) ;
206+ private class SingleAckWithStatusInfo : SingleAckInfo < AckStatus >
207+ {
199208
200- public Task < AckStatus > Task => _tcs . Task ;
209+ public SingleAckWithStatusInfo ( TimeSpan timeout ) : base ( timeout ) { }
201210
202- public void Cancel ( ) => _tcs . TrySetCanceled ( ) ;
211+ public override bool Ack ( AckStatus status , ReadOnlySequence < byte > ? payload = null ) =>
212+ _tcs . TrySetResult ( status ) ;
213+ }
214+
215+ private sealed class SingleAckWithMessagePackPayloadInfo < T > : SingleAckInfo < IMessagePackSerializable > where T : IMessagePackSerializable , new ( )
216+ {
217+ public SingleAckWithMessagePackPayloadInfo ( TimeSpan timeout ) : base ( timeout ) { }
218+ public override bool Ack ( AckStatus status , ReadOnlySequence < byte > ? payload = null )
219+ {
220+ if ( payload == null )
221+ {
222+ throw new ArgumentNullException ( nameof ( payload ) ) ;
223+ }
224+ var reader = new MessagePackReader ( payload . Value ) ;
225+ var result = reader . Deserialize < T > ( string . Empty ) ;
226+ return _tcs . TrySetResult ( result ) ;
227+ }
203228 }
204229
205- private sealed class MultiAckInfo : IAckInfo < AckStatus > , IMultiAckInfo
230+ private sealed class MultiAckWithStatusInfo : IAckInfo < AckStatus > , IMultiAckInfo
206231 {
207232 public readonly TaskCompletionSource < AckStatus > _tcs = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
208233
@@ -211,7 +236,7 @@ private sealed class MultiAckInfo : IAckInfo<AckStatus>, IMultiAckInfo
211236
212237 public DateTime TimeoutAt { get ; }
213238
214- public MultiAckInfo ( TimeSpan timeout )
239+ public MultiAckWithStatusInfo ( TimeSpan timeout )
215240 {
216241 TimeoutAt = DateTime . UtcNow + timeout ;
217242 }
@@ -239,7 +264,7 @@ public bool SetExpectedCount(int expectedCount)
239264 return result ;
240265 }
241266
242- public bool Ack ( AckStatus status = AckStatus . Ok )
267+ public bool Ack ( AckStatus status = AckStatus . Ok , ReadOnlySequence < byte > ? payload = null )
243268 {
244269 bool result ;
245270 lock ( _tcs )
0 commit comments