Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Aspire.OpenAI;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using OpenAI;

namespace Microsoft.Extensions.Hosting;
Expand Down Expand Up @@ -44,6 +45,11 @@ public static ChatClientBuilder AddKeyedChatClient(
services => CreateInnerChatClient(services, builder, deploymentName));
}

/// <summary>
/// Wrap the <see cref="OpenAIClient"/> in a telemetry client if tracing is enabled.
/// Note that this doesn't use ".UseOpenTelemetry()" because the order of the clients would be incorrect.
/// We want the telemetry client to be the innermost client, right next to the inner <see cref="OpenAIClient"/>.
/// </summary>
private static IChatClient CreateInnerChatClient(
IServiceProvider services,
AspireOpenAIClientBuilder builder,
Expand All @@ -56,6 +62,12 @@ private static IChatClient CreateInnerChatClient(
deploymentName ??= builder.GetRequiredDeploymentName();
var result = openAiClient.AsChatClient(deploymentName);

return builder.DisableTracing ? result : new OpenTelemetryChatClient(result);
if (builder.DisableTracing)
{
return result;
}

var loggerFactory = services.GetService<ILoggerFactory>();
return new OpenTelemetryChatClient(result, loggerFactory?.CreateLogger(typeof(OpenTelemetryChatClient)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Aspire.OpenAI;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using OpenAI;

namespace Microsoft.Extensions.Hosting;
Expand Down Expand Up @@ -44,6 +45,11 @@ public static EmbeddingGeneratorBuilder<string, Embedding<float>> AddKeyedEmbedd
services => CreateInnerEmbeddingGenerator(services, builder, deploymentName));
}

/// <summary>
/// Wrap the <see cref="OpenAIClient"/> in a telemetry client if tracing is enabled.
/// Note that this doesn't use ".UseOpenTelemetry()" because the order of the clients would be incorrect.
/// We want the telemetry client to be the innermost client, right next to the inner <see cref="OpenAIClient"/>.
/// </summary>
private static IEmbeddingGenerator<string, Embedding<float>> CreateInnerEmbeddingGenerator(
IServiceProvider services,
AspireOpenAIClientBuilder builder,
Expand All @@ -56,8 +62,14 @@ private static IEmbeddingGenerator<string, Embedding<float>> CreateInnerEmbeddin
deploymentName ??= builder.GetRequiredDeploymentName();
var result = openAiClient.AsEmbeddingGenerator(deploymentName);

return builder.DisableTracing
? result
: new OpenTelemetryEmbeddingGenerator<string, Embedding<float>>(result);
if (builder.DisableTracing)
{
return result;
}

var loggerFactory = services.GetService<ILoggerFactory>();
return new OpenTelemetryEmbeddingGenerator<string, Embedding<float>>(
result,
loggerFactory?.CreateLogger(typeof(OpenTelemetryEmbeddingGenerator<string, Embedding<float>>)));
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Aspire.Components.ConformanceTests;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Xunit;

namespace Aspire.OpenAI.Tests;
Expand Down Expand Up @@ -217,8 +219,52 @@ public async Task CanConfigurePipelineAsync(bool useKeyed)

var completion = await client.GetResponseAsync("Whatever");
Assert.Equal("Hello from middleware", completion.Message.Text);
}

static Task<ChatResponse> TestMiddleware(IList<ChatMessage> list, ChatOptions? options, IChatClient client, CancellationToken token)
=> Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Hello from middleware")));
[Theory]
[InlineData(true, false)]
[InlineData(false, false)]
[InlineData(true, true)]
[InlineData(false, true)]
public async Task LogsCorrectly(bool useKeyed, bool disableOpenTelemetry)
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake"),
new("Aspire:OpenAI:DisableTracing", disableOpenTelemetry.ToString()),
]);

builder.Services.AddSingleton<ILoggerFactory, TestLoggerFactory>();

if (useKeyed)
{
builder.AddOpenAIClient("openai").AddKeyedChatClient("openai_chatclient", "testdeployment1").Use(TestMiddleware, null);
}
else
{
builder.AddOpenAIClient("openai").AddChatClient("testdeployment1").Use(TestMiddleware, null);
}

using var host = builder.Build();
var client = useKeyed ?
host.Services.GetRequiredKeyedService<IChatClient>("openai_chatclient") :
host.Services.GetRequiredService<IChatClient>();
var loggerFactory = (TestLoggerFactory)host.Services.GetRequiredService<ILoggerFactory>();

var completion = await client.GetResponseAsync("Whatever");
Assert.Equal("Hello from middleware", completion.Message.Text);

const string category = "Microsoft.Extensions.AI.OpenTelemetryChatClient";
if (disableOpenTelemetry)
{
Assert.DoesNotContain(category, loggerFactory.Categories);
}
else
{
Assert.Contains(category, loggerFactory.Categories);
}
}

private static Task<ChatResponse> TestMiddleware(IList<ChatMessage> list, ChatOptions? options, IChatClient client, CancellationToken token)
=> Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Hello from middleware")));
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Aspire.Components.ConformanceTests;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Xunit;

namespace Aspire.OpenAI.Tests;
Expand Down Expand Up @@ -219,6 +221,50 @@ public async Task CanConfigurePipelineAsync(bool useKeyed)
Assert.Equal(1.23f, vector.ToArray().Single());
}

[Theory]
[InlineData(true, false)]
[InlineData(false, false)]
[InlineData(true, true)]
[InlineData(false, true)]
public async Task LogsCorrectly(bool useKeyed, bool disableOpenTelemetry)
{
var builder = Host.CreateEmptyApplicationBuilder(null);
builder.Configuration.AddInMemoryCollection([
new("ConnectionStrings:openai", $"Endpoint=https://aspireopenaitests.openai.azure.com/;Key=fake"),
new("Aspire:OpenAI:DisableTracing", disableOpenTelemetry.ToString()),
]);

builder.Services.AddSingleton<ILoggerFactory, TestLoggerFactory>();

if (useKeyed)
{
builder.AddOpenAIClient("openai").AddKeyedEmbeddingGenerator("openai_embeddinggenerator", "testdeployment1").Use(TestMiddleware);
}
else
{
builder.AddOpenAIClient("openai").AddEmbeddingGenerator("testdeployment1").Use(TestMiddleware);
}

using var host = builder.Build();
var generator = useKeyed ?
host.Services.GetRequiredKeyedService<IEmbeddingGenerator<string, Embedding<float>>>("openai_embeddinggenerator") :
host.Services.GetRequiredService<IEmbeddingGenerator<string, Embedding<float>>>();
var loggerFactory = (TestLoggerFactory)host.Services.GetRequiredService<ILoggerFactory>();

var vector = await generator.GenerateEmbeddingVectorAsync("Hello");
Assert.Equal(1.23f, vector.ToArray().Single());

const string category = "Microsoft.Extensions.AI.OpenTelemetryEmbeddingGenerator";
if (disableOpenTelemetry)
{
Assert.DoesNotContain(category, loggerFactory.Categories);
}
else
{
Assert.Contains(category, loggerFactory.Categories);
}
}

private Task<GeneratedEmbeddings<Embedding<float>>> TestMiddleware(IEnumerable<string> inputs, EmbeddingGenerationOptions? options, IEmbeddingGenerator<string, Embedding<float>> nextAsync, CancellationToken cancellationToken)
{
float[] floats = [1.23f];
Expand Down