Skip to content

Commit b320e4e

Browse files
feat: optimize CachedRequestBuilder (#1716)
Co-authored-by: Chris Pulman <[email protected]>
1 parent 107d716 commit b320e4e

File tree

2 files changed

+232
-24
lines changed

2 files changed

+232
-24
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
using System.Net;
2+
using System.Net.Http;
3+
using System.Reflection;
4+
5+
using RichardSzalay.MockHttp;
6+
7+
using Xunit;
8+
9+
namespace Refit.Tests;
10+
11+
public interface IGeneralRequests
12+
{
13+
[Post("/foo")]
14+
Task Empty();
15+
16+
[Post("/foo")]
17+
Task SingleParameter(string id);
18+
19+
[Post("/foo")]
20+
Task MultiParameter(string id, string name);
21+
22+
[Post("/foo")]
23+
Task SingleGenericMultiParameter<TValue>(string id, string name, TValue generic);
24+
}
25+
26+
public interface IDuplicateNames
27+
{
28+
[Post("/foo")]
29+
Task SingleParameter(string id);
30+
31+
[Post("/foo")]
32+
Task SingleParameter(int id);
33+
}
34+
35+
public class CachedRequestBuilderTests
36+
{
37+
[Fact]
38+
public async Task CacheHasCorrectNumberOfElementsTest()
39+
{
40+
var mockHttp = new MockHttpMessageHandler();
41+
var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp };
42+
43+
var fixture = RestService.For<IGeneralRequests>("http://bar", settings);
44+
45+
// get internal dictionary to check count
46+
var requestBuilderField = fixture.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(x => x.Name == "requestBuilder");
47+
var requestBuilder = requestBuilderField.GetValue(fixture) as CachedRequestBuilderImplementation;
48+
49+
mockHttp
50+
.Expect(HttpMethod.Post, "http://bar/foo")
51+
.Respond(HttpStatusCode.OK);
52+
await fixture.Empty();
53+
Assert.Single(requestBuilder.MethodDictionary);
54+
55+
mockHttp
56+
.Expect(HttpMethod.Post, "http://bar/foo")
57+
.WithQueryString("id", "id")
58+
.Respond(HttpStatusCode.OK);
59+
await fixture.SingleParameter("id");
60+
Assert.Equal(2, requestBuilder.MethodDictionary.Count);
61+
62+
mockHttp
63+
.Expect(HttpMethod.Post, "http://bar/foo")
64+
.WithQueryString("id", "id")
65+
.WithQueryString("name", "name")
66+
.Respond(HttpStatusCode.OK);
67+
await fixture.MultiParameter("id", "name");
68+
Assert.Equal(3, requestBuilder.MethodDictionary.Count);
69+
70+
mockHttp
71+
.Expect(HttpMethod.Post, "http://bar/foo")
72+
.WithQueryString("id", "id")
73+
.WithQueryString("name", "name")
74+
.WithQueryString("generic", "generic")
75+
.Respond(HttpStatusCode.OK);
76+
await fixture.SingleGenericMultiParameter("id", "name", "generic");
77+
Assert.Equal(4, requestBuilder.MethodDictionary.Count);
78+
79+
mockHttp.VerifyNoOutstandingExpectation();
80+
}
81+
82+
[Fact]
83+
public async Task NoDuplicateEntriesTest()
84+
{
85+
var mockHttp = new MockHttpMessageHandler();
86+
var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp };
87+
88+
var fixture = RestService.For<IGeneralRequests>("http://bar", settings);
89+
90+
// get internal dictionary to check count
91+
var requestBuilderField = fixture.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(x => x.Name == "requestBuilder");
92+
var requestBuilder = requestBuilderField.GetValue(fixture) as CachedRequestBuilderImplementation;
93+
94+
// send the same request repeatedly to ensure that multiple dictionary entries are not created
95+
mockHttp
96+
.Expect(HttpMethod.Post, "http://bar/foo")
97+
.WithQueryString("id", "id")
98+
.Respond(HttpStatusCode.OK);
99+
await fixture.SingleParameter("id");
100+
Assert.Single(requestBuilder.MethodDictionary);
101+
102+
mockHttp
103+
.Expect(HttpMethod.Post, "http://bar/foo")
104+
.WithQueryString("id", "id")
105+
.Respond(HttpStatusCode.OK);
106+
await fixture.SingleParameter("id");
107+
Assert.Single(requestBuilder.MethodDictionary);
108+
109+
mockHttp
110+
.Expect(HttpMethod.Post, "http://bar/foo")
111+
.WithQueryString("id", "id")
112+
.Respond(HttpStatusCode.OK);
113+
await fixture.SingleParameter("id");
114+
Assert.Single(requestBuilder.MethodDictionary);
115+
116+
mockHttp.VerifyNoOutstandingExpectation();
117+
}
118+
119+
[Fact]
120+
public async Task SameNameDuplicateEntriesTest()
121+
{
122+
var mockHttp = new MockHttpMessageHandler();
123+
var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp };
124+
125+
var fixture = RestService.For<IDuplicateNames>("http://bar", settings);
126+
127+
// get internal dictionary to check count
128+
var requestBuilderField = fixture.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(x => x.Name == "requestBuilder");
129+
var requestBuilder = requestBuilderField.GetValue(fixture) as CachedRequestBuilderImplementation;
130+
131+
// send the two different requests with the same name
132+
mockHttp
133+
.Expect(HttpMethod.Post, "http://bar/foo")
134+
.WithQueryString("id", "id")
135+
.Respond(HttpStatusCode.OK);
136+
await fixture.SingleParameter("id");
137+
Assert.Single(requestBuilder.MethodDictionary);
138+
139+
mockHttp
140+
.Expect(HttpMethod.Post, "http://bar/foo")
141+
.WithQueryString("id", "10")
142+
.Respond(HttpStatusCode.OK);
143+
await fixture.SingleParameter(10);
144+
Assert.Equal(2, requestBuilder.MethodDictionary.Count);
145+
146+
mockHttp.VerifyNoOutstandingExpectation();
147+
}
148+
}

Refit/CachedRequestBuilderImplementation.cs

Lines changed: 84 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,33 @@ public CachedRequestBuilderImplementation(IRequestBuilder innerBuilder)
2020
}
2121

2222
readonly IRequestBuilder innerBuilder;
23-
readonly ConcurrentDictionary<
24-
string,
23+
internal readonly ConcurrentDictionary<
24+
MethodTableKey,
2525
Func<HttpClient, object[], object?>
26-
> methodDictionary = new();
26+
> MethodDictionary = new();
2727

2828
public Func<HttpClient, object[], object?> BuildRestResultFuncForMethod(
2929
string methodName,
3030
Type[]? parameterTypes = null,
3131
Type[]? genericArgumentTypes = null
3232
)
3333
{
34-
var cacheKey = GetCacheKey(
34+
var cacheKey = new MethodTableKey(
3535
methodName,
3636
parameterTypes ?? Array.Empty<Type>(),
3737
genericArgumentTypes ?? Array.Empty<Type>()
3838
);
39-
var func = methodDictionary.GetOrAdd(
40-
cacheKey,
39+
40+
if (MethodDictionary.TryGetValue(cacheKey, out var methodFunc))
41+
{
42+
return methodFunc;
43+
}
44+
45+
// use GetOrAdd with cloned array method table key. This prevents the array from being modified, breaking the dictionary.
46+
var func = MethodDictionary.GetOrAdd(
47+
new MethodTableKey(methodName,
48+
parameterTypes?.ToArray() ?? Array.Empty<Type>(),
49+
genericArgumentTypes?.ToArray() ?? Array.Empty<Type>()),
4150
_ =>
4251
innerBuilder.BuildRestResultFuncForMethod(
4352
methodName,
@@ -48,37 +57,88 @@ readonly ConcurrentDictionary<
4857

4958
return func;
5059
}
60+
}
5161

52-
static string GetCacheKey(
53-
string methodName,
54-
Type[] parameterTypes,
55-
Type[] genericArgumentTypes
56-
)
62+
/// <summary>
63+
/// Represents a method composed of its name, generic arguments and parameters.
64+
/// </summary>
65+
internal readonly struct MethodTableKey : IEquatable<MethodTableKey>
66+
{
67+
/// <summary>
68+
/// Constructs an instance of <see cref="MethodTableKey"/>.
69+
/// </summary>
70+
/// <param name="methodName">Represents the methods name.</param>
71+
/// <param name="parameters">Array containing the methods parameters.</param>
72+
/// <param name="genericArguments">Array containing the methods generic arguments.</param>
73+
public MethodTableKey (string methodName, Type[] parameters, Type[] genericArguments)
5774
{
58-
var genericDefinition = GetGenericString(genericArgumentTypes);
59-
var argumentString = GetArgumentString(parameterTypes);
60-
61-
return $"{methodName}{genericDefinition}({argumentString})";
75+
MethodName = methodName;
76+
Parameters = parameters;
77+
GenericArguments = genericArguments;
6278
}
6379

64-
static string GetArgumentString(Type[] parameterTypes)
80+
/// <summary>
81+
/// The methods name.
82+
/// </summary>
83+
string MethodName { get; }
84+
85+
/// <summary>
86+
/// Array containing the methods parameters.
87+
/// </summary>
88+
Type[] Parameters { get; }
89+
90+
/// <summary>
91+
/// Array containing the methods generic arguments.
92+
/// </summary>
93+
Type[] GenericArguments { get; }
94+
95+
public override int GetHashCode()
6596
{
66-
if (parameterTypes == null || parameterTypes.Length == 0)
97+
unchecked
6798
{
68-
return "";
69-
}
99+
var hashCode = MethodName.GetHashCode();
100+
101+
foreach (var argument in Parameters)
102+
{
103+
hashCode = (hashCode * 397) ^ argument.GetHashCode();
104+
}
70105

71-
return string.Join(", ", parameterTypes.Select(t => t.FullName));
106+
foreach (var genericArgument in GenericArguments)
107+
{
108+
hashCode = (hashCode * 397) ^ genericArgument.GetHashCode();
109+
}
110+
return hashCode;
111+
}
72112
}
73113

74-
static string GetGenericString(Type[] genericArgumentTypes)
114+
public bool Equals(MethodTableKey other)
75115
{
76-
if (genericArgumentTypes == null || genericArgumentTypes.Length == 0)
116+
if (Parameters.Length != other.Parameters.Length
117+
|| GenericArguments.Length != other.GenericArguments.Length
118+
|| MethodName != other.MethodName)
77119
{
78-
return "";
120+
return false;
79121
}
80122

81-
return "<" + string.Join(", ", genericArgumentTypes.Select(t => t.FullName)) + ">";
123+
for (var i = 0; i < Parameters.Length; i++)
124+
{
125+
if (Parameters[i] != other.Parameters[i])
126+
{
127+
return false;
128+
}
129+
}
130+
131+
for (var i = 0; i < GenericArguments.Length; i++)
132+
{
133+
if (GenericArguments[i] != other.GenericArguments[i])
134+
{
135+
return false;
136+
}
137+
}
138+
139+
return true;
82140
}
141+
142+
public override bool Equals(object? obj) => obj is MethodTableKey other && Equals(other);
83143
}
84144
}

0 commit comments

Comments
 (0)