Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public static SqlScalarExpression VisitBuiltinFunctionCall(MethodCallExpression
case nameof(CosmosLinqExtensions.DocumentId):
case nameof(CosmosLinqExtensions.RRF):
case nameof(CosmosLinqExtensions.FullTextScore):
case nameof(CosmosLinqExtensions.VectorDistance):
return OtherBuiltinSystemFunctions.Visit(methodCallExpression, context);
default:
return TypeCheckFunctions.Visit(methodCallExpression, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ namespace Microsoft.Azure.Cosmos.Linq
using System.Collections.ObjectModel;
using System.Globalization;
using System.Linq.Expressions;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.Azure.Cosmos.CosmosElements;
using Microsoft.Azure.Cosmos.SqlObjects;

internal static class OtherBuiltinSystemFunctions
Expand Down Expand Up @@ -41,19 +44,22 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
throw new ArgumentException(
string.Format(
CultureInfo.CurrentCulture,
"Expressions of type {0} is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to {1}.",
"Expressions of type {0} is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to {1}, {2}.",
argument.Type,
nameof(CosmosLinqExtensions.FullTextScore)));
nameof(CosmosLinqExtensions.FullTextScore),
nameof(CosmosLinqExtensions.VectorDistance)));
}

if (functionCallExpression.Method.Name != nameof(CosmosLinqExtensions.FullTextScore))
if (functionCallExpression.Method.Name != nameof(CosmosLinqExtensions.FullTextScore) &&
functionCallExpression.Method.Name != nameof(CosmosLinqExtensions.VectorDistance))
{
throw new ArgumentException(
string.Format(
CultureInfo.CurrentCulture,
"Method {0} is not supported as an argument to CosmosLinqExtensions.RRF. Supported methods are {1}.",
"Method {0} is not supported as an argument to CosmosLinqExtensions.RRF. Supported methods are {1}, {2}.",
functionCallExpression.Method.Name,
nameof(CosmosLinqExtensions.FullTextScore)));
nameof(CosmosLinqExtensions.FullTextScore),
nameof(CosmosLinqExtensions.VectorDistance)));
}

arguments.Add(ExpressionToSql.VisitNonSubqueryScalarExpression(argument, context));
Expand Down Expand Up @@ -108,6 +114,55 @@ protected override SqlScalarExpression VisitExplicit(MethodCallExpression method
}
}

private class VectorDistanceVisit : SqlBuiltinFunctionVisitor
{
public VectorDistanceVisit()
: base("VectorDistance",
true,
new List<Type[]>()
{
new Type[]{typeof(float[]), typeof(float[]), typeof(bool), typeof(CosmosLinqExtensions.VectorDistanceOptions)},
new Type[]{typeof(sbyte[]), typeof(sbyte[]), typeof(bool), typeof(CosmosLinqExtensions.VectorDistanceOptions)},
new Type[]{typeof(byte[]), typeof(byte[]), typeof(bool), typeof(CosmosLinqExtensions.VectorDistanceOptions)},
})
{
}

protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
if (methodCallExpression.Arguments.Count != 4) throw new ArgumentException();

List<SqlScalarExpression> arguments = new List<SqlScalarExpression>
{
ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[0], context),
ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[1], context),
ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[2], context)
};

if (methodCallExpression.Arguments[3] is ConstantExpression optionExpression && optionExpression.Value != null)
{
JsonSerializerOptions options = new JsonSerializerOptions
{
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
};
options.Converters.Add(new JsonStringEnumConverter(JsonNamingPolicy.CamelCase));

string serializedConstant = JsonSerializer.Serialize(
optionExpression.Value,
options);

arguments.Add(CosmosElement.Parse(serializedConstant).Accept(CosmosElementToSqlScalarExpressionVisitor.Singleton));
}

return SqlFunctionCallScalarExpression.CreateBuiltin(SqlFunctionCallScalarExpression.Names.VectorDistance, arguments.ToImmutableArray());
}

protected override SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
return null;
}
}

private static Dictionary<string, BuiltinFunctionVisitor> FunctionsDefinitions { get; set; }

static OtherBuiltinSystemFunctions()
Expand All @@ -123,6 +178,7 @@ static OtherBuiltinSystemFunctions()
}),
[nameof(CosmosLinqExtensions.RRF)] = new RRFVisit(),
[nameof(CosmosLinqExtensions.FullTextScore)] = new FullTextScoreVisit(),
[nameof(CosmosLinqExtensions.VectorDistance)] = new VectorDistanceVisit(),
};
}

Expand Down
101 changes: 98 additions & 3 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
namespace Microsoft.Azure.Cosmos.Linq
{
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Diagnostics;
Expand All @@ -22,6 +21,33 @@ namespace Microsoft.Azure.Cosmos.Linq
/// </summary>
public static class CosmosLinqExtensions
{
/// <summary>
/// Object representing the options for vector distance calculation. All field are optional. if a field is not specified, the default value will be used.
/// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/vectordistance.
/// </summary>
public sealed class VectorDistanceOptions
{
/// <summary>
/// The metric used to compute distance/similarity. Valid values are "cosine", "dotproduct", "euclidean".
/// If not specified, the default value is what is defined in the container policy
/// </summary>
[JsonPropertyName("distanceFunction")]
public DistanceFunction? DistanceFunction { get; set; }

/// <summary>
/// The data type of the vectors. float32, int8, uint8 values. Default value is float32.
/// </summary>
[JsonPropertyName("dataType")]
public VectorDataType? DataType { get; set; }

/// <summary>
/// An integer specifying the size of the search list when conducting a vector search on the DiskANN index.
/// Increasing this may improve accuracy at the expense of RU cost and latency. Min=1, Default=10, Max=100.
/// </summary>
[JsonPropertyName("searchListSizeMultiplier")]
public int? SearchListSizeMultiplier { get; set; }
}

/// <summary>
/// Returns the integer identifier corresponding to a specific item within a physical partition.
/// This method is to be used in LINQ expressions only and will be evaluated on server.
Expand Down Expand Up @@ -239,6 +265,75 @@ public static bool RegexMatch(this object obj, string regularExpression, string
throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented);
}

/// <summary>
/// Returns the similarity score between two specified vectors.
/// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/vectordistance.
/// This method is to be used in LINQ expressions only and will be evaluated on server.
/// There's no implementation provided in the client library.
/// </summary>
/// <param name="vector1">The first vector.</param>
/// <param name="vector2">The second vector.</param>
/// <param name="isBruteForce">A boolean specifying how the computed value is used in an ORDER BY expression. If true, then brute force is used. A value of false uses any index defined on the vector property, if it exists. </param>
/// <param name="options">An JSON formatted object literal used to specify options for the vector distance calculation. </param>
/// <returns>Returns the similarity score between two specified vectors.</returns>
/// <example>
/// <code>
/// <![CDATA[
/// var matched = documents.Select(document => document.vector1.VectorDistance(<vector2>, true, new VectorDistanceOptions() { DistanceFunction = DistanceFunction.Cosine, DataType = VectorDataType.Float32}));
/// ]]>
/// </code>
/// </example>
public static double VectorDistance(this float[] vector1, float[] vector2, bool isBruteForce, VectorDistanceOptions options)
{
throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented);
}

/// <summary>
/// Returns the similarity score between two specified vectors.
/// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/vectordistance.
/// This method is to be used in LINQ expressions only and will be evaluated on server.
/// There's no implementation provided in the client library.
/// </summary>
/// <param name="vector1">The first vector.</param>
/// <param name="vector2">The second vector.</param>
/// <param name="isBruteForce">A boolean specifying how the computed value is used in an ORDER BY expression. If true, then brute force is used. A value of false uses any index defined on the vector property, if it exists. </param>
/// <param name="options">An JSON formatted object literal used to specify options for the vector distance calculation. </param>
/// <returns>Returns the similarity score between two specified vectors.</returns>
/// <example>
/// <code>
/// <![CDATA[
/// var matched = documents.Select(document => document.vector1.VectorDistance(<vector2>, true, new VectorDistanceOptions() { DistanceFunction = DistanceFunction.Cosine, DataType = VectorDataType.Int8}));
/// ]]>
/// </code>
/// </example>
public static double VectorDistance(this byte[] vector1, byte[] vector2, bool isBruteForce, VectorDistanceOptions options)
{
throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented);
}

/// <summary>
/// Returns the similarity score between two specified vectors.
/// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/vectordistance.
/// This method is to be used in LINQ expressions only and will be evaluated on server.
/// There's no implementation provided in the client library.
/// </summary>
/// <param name="vector1">The first vector.</param>
/// <param name="vector2">The second vector.</param>
/// <param name="isBruteForce">A boolean specifying how the computed value is used in an ORDER BY expression. If true, then brute force is used. A value of false uses any index defined on the vector property, if it exists. </param>
/// <param name="options">An JSON formatted object literal used to specify options for the vector distance calculation. </param>
/// <returns>Returns the similarity score between two specified vectors.</returns>
/// <example>
/// <code>
/// <![CDATA[
/// var matched = documents.Select(document => document.vector1.VectorDistance(<vector2>, true, new VectorDistanceOptions() { DistanceFunction = DistanceFunction.Cosine, DataType = VectorDataType.Uint8}));
/// ]]>
/// </code>
/// </example>
public static double VectorDistance(this sbyte[] vector1, sbyte[] vector2, bool isBruteForce, VectorDistanceOptions options)
{
throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented);
}

/// <summary>
/// Returns a boolean indicating whether the keyword string expression is contained in a specified property path.
/// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/fulltextcontains.
Expand Down Expand Up @@ -310,7 +405,7 @@ public static bool FullTextContainsAny(this object obj, params string[] searches
/// </summary>
/// <param name="obj"></param>
/// <param name="terms">A nonempty array of string literals.</param>
/// <returns>Returns true BM25 score value that can only be used in an ORDER BY RANK clause.</returns>
/// <returns>Returns a BM25 score value that can only be used in an ORDER BY RANK clause.</returns>
/// <example>
/// <code>
/// <![CDATA[
Expand Down
6 changes: 5 additions & 1 deletion Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2050,7 +2050,11 @@ private static SqlOrderByClause VisitOrderBy(ReadOnlyCollection<Expression> argu

LambdaExpression lambda = Utilities.GetLambda(arguments[1]);
SqlScalarExpression sqlfunc = ExpressionToSql.VisitScalarExpression(lambda, context);
SqlOrderByItem orderByItem = SqlOrderByItem.Create(sqlfunc, isDescending);

// Order By VectorDistance is a special case, since there is no ordering required.
bool isVectorDistance = (sqlfunc is SqlFunctionCallScalarExpression functionCall) && (functionCall.Name.Value == SqlFunctionCallScalarExpression.Names.VectorDistance);

SqlOrderByItem orderByItem = SqlOrderByItem.Create(sqlfunc, isVectorDistance ? null : isDescending);
SqlOrderByClause orderby = SqlOrderByClause.Create(new SqlOrderByItem[] { orderByItem });
return orderby;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ public static class Names
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword
public const string Trim = "TRIM";
public const string Trunc = "TRUNC";
public const string Upper = "UPPER";
public const string Upper = "UPPER";
public const string VectorDistance = "VectorDistance";
}

public static class Identifiers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ WHERE (RRF(FullTextScore(root["StringField"], "test1"), FullTextScore(root["Stri
</Input>
<Output>
<SqlQuery><![CDATA[]]></SqlQuery>
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore.]]></ErrorMessage>
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore, VectorDistance.]]></ErrorMessage>
</Output>
</Result>
<Result>
Expand All @@ -87,7 +87,7 @@ WHERE (RRF(FullTextScore(root["StringField"], "test1"), FullTextScore(root["Stri
</Input>
<Output>
<SqlQuery><![CDATA[]]></SqlQuery>
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore.]]></ErrorMessage>
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore, VectorDistance.]]></ErrorMessage>
</Output>
</Result>
<Result>
Expand All @@ -97,7 +97,7 @@ WHERE (RRF(FullTextScore(root["StringField"], "test1"), FullTextScore(root["Stri
</Input>
<Output>
<SqlQuery><![CDATA[]]></SqlQuery>
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore.]]></ErrorMessage>
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore, VectorDistance.]]></ErrorMessage>
</Output>
</Result>
<Result>
Expand All @@ -107,7 +107,7 @@ WHERE (RRF(FullTextScore(root["StringField"], "test1"), FullTextScore(root["Stri
</Input>
<Output>
<SqlQuery><![CDATA[]]></SqlQuery>
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore.]]></ErrorMessage>
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore, VectorDistance.]]></ErrorMessage>
</Output>
</Result>
<Result>
Expand All @@ -117,7 +117,7 @@ WHERE (RRF(FullTextScore(root["StringField"], "test1"), FullTextScore(root["Stri
</Input>
<Output>
<SqlQuery><![CDATA[]]></SqlQuery>
<ErrorMessage><![CDATA[Method RRF is not supported as an argument to CosmosLinqExtensions.RRF. Supported methods are FullTextScore.]]></ErrorMessage>
<ErrorMessage><![CDATA[Method RRF is not supported as an argument to CosmosLinqExtensions.RRF. Supported methods are FullTextScore, VectorDistance.]]></ErrorMessage>
</Output>
</Result>
</Results>
Loading
Loading