Skip to content

Commit d0066c2

Browse files
lukedukeusmysticmind
authored andcommitted
Add support for string comparison methods (#3744)
* add support for string.compare * add support for string.compareto * add test cases
1 parent e2a58cf commit d0066c2

File tree

4 files changed

+132
-1
lines changed

4 files changed

+132
-1
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using System.Linq;
2+
using System.Threading.Tasks;
3+
using Marten.Testing.Documents;
4+
using Marten.Testing.Harness;
5+
using Shouldly;
6+
7+
namespace LinqTests.Operators;
8+
9+
public class string_compare_operator: IntegrationContext
10+
{
11+
[Fact]
12+
public async Task string_compare_works()
13+
{
14+
theSession.Store(new Target { String = "Apple" });
15+
theSession.Store(new Target { String = "Banana" });
16+
theSession.Store(new Target { String = "Cherry" });
17+
theSession.Store(new Target { String = "Durian" });
18+
await theSession.SaveChangesAsync();
19+
20+
var queryable = theSession.Query<Target>().Where(x => string.Compare(x.String, "Cherry") > 0);
21+
22+
queryable.ToList().Count.ShouldBe(1);
23+
}
24+
25+
[Fact]
26+
public async Task string_compare_to_works()
27+
{
28+
theSession.Store(new Target { String = "Apple" });
29+
theSession.Store(new Target { String = "Banana" });
30+
theSession.Store(new Target { String = "Cherry" });
31+
theSession.Store(new Target { String = "Durian" });
32+
await theSession.SaveChangesAsync();
33+
34+
var queryable = theSession.Query<Target>().Where(x => x.String.CompareTo("Banana") > 0);
35+
36+
queryable.ToList().Count.ShouldBe(2);
37+
}
38+
39+
public string_compare_operator(DefaultStoreFixture fixture) : base(fixture)
40+
{
41+
}
42+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using System.Linq.Expressions;
2+
using Marten.Exceptions;
3+
using Marten.Linq.Members;
4+
using Weasel.Postgresql.SqlGeneration;
5+
6+
namespace Marten.Linq.Parsing;
7+
8+
internal class CompareToComparable: IComparableMember
9+
{
10+
private readonly SimpleExpression _left;
11+
private readonly SimpleExpression _right;
12+
13+
public CompareToComparable(SimpleExpression left, SimpleExpression right)
14+
{
15+
_left = left;
16+
_right = right;
17+
}
18+
19+
public ISqlFragment CreateComparison(string op, ConstantExpression constant)
20+
{
21+
// Only compare to 0 is valid: CompareTo() > 0 → ">", CompareTo() == 0 → "=", etc.
22+
if (constant.Value is int intValue && intValue == 0)
23+
{
24+
var leftFragment = _left.FindValueFragment();
25+
var rightFragment = _right.FindValueFragment();
26+
27+
return new ComparisonFilter(leftFragment, rightFragment, op);
28+
}
29+
30+
throw new BadLinqExpressionException(
31+
"string.CompareTo() must be compared to 0 (e.g., x.Name.CompareTo(\"A\") > 0)");
32+
}
33+
}
34+

src/Marten/Linq/Parsing/SimpleExpression.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#nullable disable
22
using System;
33
using System.Collections.Generic;
4-
using System.Diagnostics;
54
using System.Linq;
65
using System.Linq.Expressions;
76
using System.Reflection;
@@ -306,6 +305,29 @@ protected override Expression VisitParameter(ParameterExpression node)
306305

307306
protected override Expression VisitMethodCall(MethodCallExpression node)
308307
{
308+
if (node.Object == null &&
309+
node.Method.DeclaringType == typeof(string) &&
310+
node.Method.Name == "Compare" &&
311+
node.Arguments.Count == 2 &&
312+
node.Method.IsStatic)
313+
{
314+
var leftArg = new SimpleExpression(_queryableMembers, node.Arguments[0]);
315+
var rightArg = new SimpleExpression(_queryableMembers, node.Arguments[1]);
316+
Comparable = new StringCompareComparable(leftArg, rightArg);
317+
return null;
318+
}
319+
320+
if (node.Method.DeclaringType == typeof(string) &&
321+
node.Method.Name == "CompareTo" &&
322+
node.Arguments.Count == 1)
323+
{
324+
var left = new SimpleExpression(_queryableMembers, node.Object);
325+
var right = new SimpleExpression(_queryableMembers, node.Arguments[0]);
326+
327+
Comparable = new CompareToComparable(left, right);
328+
return null;
329+
}
330+
309331
if (node.Object == null && !(node.Arguments.FirstOrDefault() is MemberExpression))
310332
{
311333
// It's a method of a static, so this has to be a constant
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System.Linq.Expressions;
2+
using Marten.Exceptions;
3+
using Marten.Linq.Members;
4+
using Weasel.Postgresql.SqlGeneration;
5+
6+
namespace Marten.Linq.Parsing;
7+
8+
internal class StringCompareComparable: IComparableMember
9+
{
10+
private readonly SimpleExpression _left;
11+
private readonly SimpleExpression _right;
12+
13+
internal StringCompareComparable(SimpleExpression left, SimpleExpression right)
14+
{
15+
_left = left;
16+
_right = right;
17+
}
18+
19+
public ISqlFragment CreateComparison(string op, ConstantExpression constant)
20+
{
21+
if (constant.Value is int intValue && intValue == 0)
22+
{
23+
var leftFragment = _left.FindValueFragment();
24+
var rightFragment = _right.FindValueFragment();
25+
26+
return new ComparisonFilter(leftFragment, rightFragment, op);
27+
}
28+
else
29+
{
30+
throw new BadLinqExpressionException("string.Compare must be compared to 0");
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)