Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions extensions/AzureOpenAI/AzureOpenAITextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Azure.Identity;
Expand All @@ -21,15 +22,17 @@ public class AzureOpenAITextEmbeddingGenerator : ITextEmbeddingGenerator
public AzureOpenAITextEmbeddingGenerator(
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<AzureOpenAITextEmbeddingGenerator>())
ILoggerFactory? loggerFactory = null,
HttpClient? httpClient = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<AzureOpenAITextEmbeddingGenerator>(), httpClient)
{
}

public AzureOpenAITextEmbeddingGenerator(
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILogger<AzureOpenAITextEmbeddingGenerator>? log = null)
ILogger<AzureOpenAITextEmbeddingGenerator>? log = null,
HttpClient? httpClient = null)
{
this._log = log ?? DefaultLogger<AzureOpenAITextEmbeddingGenerator>.Instance;

Expand All @@ -52,15 +55,17 @@ public AzureOpenAITextEmbeddingGenerator(
deploymentName: config.Deployment,
modelId: config.Deployment,
endpoint: config.Endpoint,
credential: new DefaultAzureCredential());
credential: new DefaultAzureCredential(),
httpClient: httpClient);
break;

case AzureOpenAIConfig.AuthTypes.APIKey:
this._client = new AzureOpenAITextEmbeddingGenerationService(
deploymentName: config.Deployment,
modelId: config.Deployment,
endpoint: config.Endpoint,
apiKey: config.APIKey);
apiKey: config.APIKey,
httpClient: httpClient);
break;

default:
Expand Down
14 changes: 11 additions & 3 deletions extensions/AzureOpenAI/AzureOpenAITextGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -27,15 +28,17 @@ public class AzureOpenAITextGenerator : ITextGenerator
public AzureOpenAITextGenerator(
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<AzureOpenAITextGenerator>())
ILoggerFactory? loggerFactory = null,
HttpClient? httpClient = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<AzureOpenAITextGenerator>(), httpClient)
{
}

public AzureOpenAITextGenerator(
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILogger<AzureOpenAITextGenerator>? log = null)
ILogger<AzureOpenAITextGenerator>? log = null,
HttpClient? httpClient = null)
{
this._log = log ?? DefaultLogger<AzureOpenAITextGenerator>.Instance;

Expand Down Expand Up @@ -73,6 +76,11 @@ public AzureOpenAITextGenerator(
}
};

if (httpClient is not null)
{
options.Transport = new HttpClientTransport(httpClient);
}

switch (config.Auth)
{
case AzureOpenAIConfig.AuthTypes.AzureIdentity:
Expand Down
26 changes: 18 additions & 8 deletions extensions/AzureOpenAI/DependencyInjection.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
Expand All @@ -20,13 +21,15 @@ public static partial class KernelMemoryBuilderExtensions
/// <param name="textTokenizer">Tokenizer used to count tokens sent to the embedding generator</param>
/// <param name="loggerFactory">.NET Logger factory</param>
/// <param name="onlyForRetrieval">Whether to use this embedding generator only during data ingestion, and not for retrieval (search and ask API)</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithAzureOpenAITextEmbeddingGeneration(
this IKernelMemoryBuilder builder,
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null,
bool onlyForRetrieval = false)
bool onlyForRetrieval = false,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
Expand All @@ -38,7 +41,8 @@ public static IKernelMemoryBuilder WithAzureOpenAITextEmbeddingGeneration(
new AzureOpenAITextEmbeddingGenerator(
config: config,
textTokenizer: textTokenizer,
loggerFactory: loggerFactory));
loggerFactory: loggerFactory,
httpClient));
}

return builder;
Expand All @@ -50,15 +54,17 @@ public static IKernelMemoryBuilder WithAzureOpenAITextEmbeddingGeneration(
/// <param name="builder">Kernel Memory builder</param>
/// <param name="config">Azure OpenAI settings</param>
/// <param name="textTokenizer">Tokenizer used to count tokens used by prompts</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithAzureOpenAITextGeneration(
this IKernelMemoryBuilder builder,
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
builder.Services.AddAzureOpenAITextGeneration(config, textTokenizer);
builder.Services.AddAzureOpenAITextGeneration(config, textTokenizer, httpClient);
return builder;
}
}
Expand All @@ -68,28 +74,32 @@ public static partial class DependencyInjection
public static IServiceCollection AddAzureOpenAIEmbeddingGeneration(
this IServiceCollection services,
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
return services
.AddSingleton<ITextEmbeddingGenerator>(serviceProvider => new AzureOpenAITextEmbeddingGenerator(
config,
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
loggerFactory: serviceProvider.GetService<ILoggerFactory>(),
httpClient));
}

public static IServiceCollection AddAzureOpenAITextGeneration(
this IServiceCollection services,
AzureOpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
return services
.AddSingleton<ITextGenerator>(serviceProvider => new AzureOpenAITextGenerator(
config: config,
textTokenizer: textTokenizer,
log: serviceProvider.GetService<ILogger<AzureOpenAITextGenerator>>()));
log: serviceProvider.GetService<ILogger<AzureOpenAITextGenerator>>(),
httpClient: httpClient));
}
}
44 changes: 29 additions & 15 deletions extensions/OpenAI/DependencyInjection.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
Expand All @@ -26,6 +27,7 @@ public static partial class KernelMemoryBuilderExtensions
/// <param name="textEmbeddingTokenizer">Tokenizer used to count tokens sent to the embedding generator</param>
/// <param name="loggerFactory">.NET Logger factory</param>
/// <param name="onlyForRetrieval">Whether to use OpenAI defaults only for ingestion, and not for retrieval (search and ask API)</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithOpenAIDefaults(
this IKernelMemoryBuilder builder,
Expand All @@ -34,7 +36,8 @@ public static IKernelMemoryBuilder WithOpenAIDefaults(
ITextTokenizer? textGenerationTokenizer = null,
ITextTokenizer? textEmbeddingTokenizer = null,
ILoggerFactory? loggerFactory = null,
bool onlyForRetrieval = false)
bool onlyForRetrieval = false,
HttpClient? httpClient = null)
{
textGenerationTokenizer ??= new DefaultGPTTokenizer();
textEmbeddingTokenizer ??= new DefaultGPTTokenizer();
Expand All @@ -50,15 +53,16 @@ public static IKernelMemoryBuilder WithOpenAIDefaults(
};
openAIConfig.Validate();

builder.Services.AddOpenAITextEmbeddingGeneration(openAIConfig, textEmbeddingTokenizer);
builder.Services.AddOpenAITextGeneration(openAIConfig, textGenerationTokenizer);
builder.Services.AddOpenAITextEmbeddingGeneration(openAIConfig, textEmbeddingTokenizer, httpClient);
builder.Services.AddOpenAITextGeneration(openAIConfig, textGenerationTokenizer, httpClient);

if (!onlyForRetrieval)
{
builder.AddIngestionEmbeddingGenerator(new OpenAITextEmbeddingGenerator(
config: openAIConfig,
textTokenizer: textEmbeddingTokenizer,
loggerFactory: loggerFactory));
loggerFactory: loggerFactory,
httpClient: httpClient));
}

return builder;
Expand All @@ -72,19 +76,21 @@ public static IKernelMemoryBuilder WithOpenAIDefaults(
/// <param name="textGenerationTokenizer">Tokenizer used to count tokens used by prompts</param>
/// <param name="textEmbeddingTokenizer">Tokenizer used to count tokens sent to the embedding generator</param>
/// <param name="onlyForRetrieval">Whether to use OpenAI only for ingestion, not for retrieval (search and ask API)</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithOpenAI(
this IKernelMemoryBuilder builder,
OpenAIConfig config,
ITextTokenizer? textGenerationTokenizer = null,
ITextTokenizer? textEmbeddingTokenizer = null,
bool onlyForRetrieval = false)
bool onlyForRetrieval = false,
HttpClient? httpClient = null)
{
config.Validate();
textGenerationTokenizer ??= new DefaultGPTTokenizer();
textEmbeddingTokenizer ??= new DefaultGPTTokenizer();

builder.WithOpenAITextEmbeddingGeneration(config, textEmbeddingTokenizer, onlyForRetrieval);
builder.WithOpenAITextEmbeddingGeneration(config, textEmbeddingTokenizer, onlyForRetrieval, httpClient);
builder.WithOpenAITextGeneration(config, textGenerationTokenizer);
return builder;
}
Expand All @@ -96,21 +102,23 @@ public static IKernelMemoryBuilder WithOpenAI(
/// <param name="config">OpenAI settings</param>
/// <param name="textTokenizer">Tokenizer used to count tokens sent to the embedding generator</param>
/// <param name="onlyForRetrieval">Whether to use OpenAI only for ingestion, not for retrieval (search and ask API)</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithOpenAITextEmbeddingGeneration(
this IKernelMemoryBuilder builder,
OpenAIConfig config,
ITextTokenizer? textTokenizer = null,
bool onlyForRetrieval = false)
bool onlyForRetrieval = false,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();

builder.Services.AddOpenAITextEmbeddingGeneration(config);
builder.Services.AddOpenAITextEmbeddingGeneration(config, httpClient: httpClient);
if (!onlyForRetrieval)
{
builder.AddIngestionEmbeddingGenerator(
new OpenAITextEmbeddingGenerator(config, textTokenizer, loggerFactory: null));
new OpenAITextEmbeddingGenerator(config, textTokenizer, loggerFactory: null, httpClient));
}

return builder;
Expand All @@ -122,16 +130,18 @@ public static IKernelMemoryBuilder WithOpenAITextEmbeddingGeneration(
/// <param name="builder">Kernel Memory builder</param>
/// <param name="config">OpenAI settings</param>
/// <param name="textTokenizer">Tokenizer used to count tokens used by prompts</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <returns>KM builder instance</returns>
public static IKernelMemoryBuilder WithOpenAITextGeneration(
this IKernelMemoryBuilder builder,
OpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();

builder.Services.AddOpenAITextGeneration(config, textTokenizer);
builder.Services.AddOpenAITextGeneration(config, textTokenizer, httpClient);
return builder;
}
}
Expand All @@ -141,7 +151,8 @@ public static partial class DependencyInjection
public static IServiceCollection AddOpenAITextEmbeddingGeneration(
this IServiceCollection services,
OpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
Expand All @@ -151,13 +162,15 @@ public static IServiceCollection AddOpenAITextEmbeddingGeneration(
serviceProvider => new OpenAITextEmbeddingGenerator(
config: config,
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
loggerFactory: serviceProvider.GetService<ILoggerFactory>(),
httpClient));
}

public static IServiceCollection AddOpenAITextGeneration(
this IServiceCollection services,
OpenAIConfig config,
ITextTokenizer? textTokenizer = null)
ITextTokenizer? textTokenizer = null,
HttpClient? httpClient = null)
{
config.Validate();
textTokenizer ??= new DefaultGPTTokenizer();
Expand All @@ -166,6 +179,7 @@ public static IServiceCollection AddOpenAITextGeneration(
.AddSingleton<ITextGenerator, OpenAITextGenerator>(serviceProvider => new OpenAITextGenerator(
config: config,
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
loggerFactory: serviceProvider.GetService<ILoggerFactory>(),
httpClient));
}
}
12 changes: 8 additions & 4 deletions extensions/OpenAI/OpenAITextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
Expand All @@ -18,15 +19,17 @@ public class OpenAITextEmbeddingGenerator : ITextEmbeddingGenerator
public OpenAITextEmbeddingGenerator(
OpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILoggerFactory? loggerFactory = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<OpenAITextEmbeddingGenerator>())
ILoggerFactory? loggerFactory = null,
HttpClient? httpClient = null)
: this(config, textTokenizer, loggerFactory?.CreateLogger<OpenAITextEmbeddingGenerator>(), httpClient)
{
}

public OpenAITextEmbeddingGenerator(
OpenAIConfig config,
ITextTokenizer? textTokenizer = null,
ILogger<OpenAITextEmbeddingGenerator>? log = null)
ILogger<OpenAITextEmbeddingGenerator>? log = null,
HttpClient? httpClient = null)
{
this._log = log ?? DefaultLogger<OpenAITextEmbeddingGenerator>.Instance;

Expand All @@ -45,7 +48,8 @@ public OpenAITextEmbeddingGenerator(
this._client = new OpenAITextEmbeddingGenerationService(
modelId: config.EmbeddingModel,
apiKey: config.APIKey,
organization: config.OrgId);
organization: config.OrgId,
httpClient: httpClient);
}

/// <inheritdoc/>
Expand Down
Loading