Commit 68ee7ef

Search client update (#344)
## Motivation and Context (Why the change? What's the scenario?)

With the exception of `MaxTokens`, the properties of the `TextGenerationOptions` object used by `AskAsync` are hard-coded.

## High level description (Approach, Design)

Complete the work started by PR #341: update SearchClient.cs and the corresponding text generators to use the new LLM request settings now available in SearchClientConfig.cs.
1 parent 023bdde commit 68ee7ef
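
As a hedged usage sketch (not code from this commit): after this change, the request settings can be supplied through `SearchClientConfig`. The property names below are exactly those read in the SearchClient.cs diff further down; the types and values are illustrative assumptions.

```csharp
// Illustrative only: property names come from the SearchClient.cs diff in this
// commit; the values (and the List/Dictionary types) are example assumptions.
var config = new SearchClientConfig
{
    Temperature = 0,
    TopP = 0,
    PresencePenalty = 0,
    FrequencyPenalty = 0,
    AnswerTokens = 300,                                   // becomes MaxTokens
    StopSequences = new List<string> { "###" },           // forwarded as-is
    TokenSelectionBiases = new Dictionary<int, float>(),  // token id -> bias
};
```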

File tree

4 files changed: +32 −10 lines

- extensions/AzureOpenAI/AzureOpenAITextGenerator.cs
- extensions/LlamaSharp/LlamaSharp/LlamaSharpTextGenerator.cs
- extensions/OpenAI/OpenAITextGenerator.cs
- service/Core/Search/SearchClient.cs

extensions/AzureOpenAI/AzureOpenAITextGenerator.cs (10 additions, 0 deletions)

```diff
@@ -138,6 +138,11 @@ public async IAsyncEnumerable<string> GenerateTextAsync(
         foreach (var s in options.StopSequences) { openaiOptions.StopSequences.Add(s); }
     }
 
+    if (options.TokenSelectionBiases is { Count: > 0 })
+    {
+        foreach (var (token, bias) in options.TokenSelectionBiases) { openaiOptions.TokenSelectionBiases.Add(token, (int)bias); }
+    }
+
     StreamingResponse<Completions>? response = await this._client.GetCompletionsStreamingAsync(openaiOptions, cancellationToken).ConfigureAwait(false);
     await foreach (Completions? completions in response.EnumerateValues().WithCancellation(cancellationToken).ConfigureAwait(false))
     {
@@ -165,6 +170,11 @@ public async IAsyncEnumerable<string> GenerateTextAsync(
         foreach (var s in options.StopSequences) { openaiOptions.StopSequences.Add(s); }
     }
 
+    if (options.TokenSelectionBiases is { Count: > 0 })
+    {
+        foreach (var (token, bias) in options.TokenSelectionBiases) { openaiOptions.TokenSelectionBiases.Add(token, (int)bias); }
+    }
+
     openaiOptions.Messages.Add(new ChatRequestSystemMessage(prompt));
 
     StreamingResponse<StreamingChatCompletionsUpdate>? response = await this._client.GetChatCompletionsStreamingAsync(openaiOptions, cancellationToken).ConfigureAwait(false);
```
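In both hunks the additions guard on `options.TokenSelectionBiases is { Count: > 0 }`, a C# property pattern that matches only a non-null, non-empty dictionary, so one expression covers both the null check and the emptiness check. A self-contained sketch of the same pattern:

```csharp
using System;
using System.Collections.Generic;

// Standalone illustration of the guard used above: a property pattern never
// matches a null reference, and `Count: > 0` rejects an empty dictionary.
Dictionary<int, float>? biases = null;
Console.WriteLine(biases is { Count: > 0 });   // False: null never matches

biases = new Dictionary<int, float> { [42] = -100f };
Console.WriteLine(biases is { Count: > 0 });   // True: non-null and non-empty
```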

extensions/LlamaSharp/LlamaSharp/LlamaSharpTextGenerator.cs (6 additions, 3 deletions)

```diff
@@ -106,7 +106,7 @@ public IAsyncEnumerable<string> GenerateTextAsync(
         TopP = (float)options.TopP,
         PresencePenalty = (float)options.PresencePenalty,
         FrequencyPenalty = (float)options.FrequencyPenalty,
-        AntiPrompts = options.StopSequences.ToList(),
+        AntiPrompts = options.StopSequences?.ToList() ?? new(),
         LogitBias = new(),
         // RepeatLastTokensCount = 0, // [int] last n tokens to penalize (0 = disable penalty, -1 = context size)
         // TopK = 0, // [int] The number of highest probability vocabulary tokens to keep for top-k-filtering.
@@ -121,9 +121,12 @@ public IAsyncEnumerable<string> GenerateTextAsync(
         // Grammar = null // SafeLLamaGrammarHandle
     };
 
-    foreach (KeyValuePair<int, float> b in options.TokenSelectionBiases)
+    if (options.TokenSelectionBiases is { Count: > 0 })
     {
-        settings.LogitBias!.Add(b.Key, b.Value);
+        foreach (var (token, bias) in options.TokenSelectionBiases)
+        {
+            settings.LogitBias!.Add(token, bias);
+        }
     }
 
     return executor.InferAsync(prompt, settings, cancellationToken);
```
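One detail worth noting across the three generators: `TokenSelectionBiases` carries `float` values (`KeyValuePair<int, float>` in the removed loop above). LlamaSharp's `LogitBias` accepts the float directly, while the OpenAI and Azure OpenAI hunks cast each value to `int` because that SDK's bias dictionary takes integers. A minimal sketch of the difference:

```csharp
using System;
using System.Collections.Generic;

// Token id 42 and bias -100.5 are arbitrary illustration values.
var biases = new Dictionary<int, float> { [42] = -100.5f };

foreach (var (token, bias) in biases)
{
    int openAiValue = (int)bias;   // OpenAI/Azure OpenAI hunks: cast to int (-100)
    float llamaValue = bias;       // LlamaSharp hunk: float stored as-is (-100.5)
    Console.WriteLine($"token {token}: openai={openAiValue}, llama={llamaValue}");
}
```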

extensions/OpenAI/OpenAITextGenerator.cs (10 additions, 0 deletions)

```diff
@@ -119,6 +119,11 @@ public async IAsyncEnumerable<string> GenerateTextAsync(
         foreach (var s in options.StopSequences) { openaiOptions.StopSequences.Add(s); }
     }
 
+    if (options.TokenSelectionBiases is { Count: > 0 })
+    {
+        foreach (var (token, bias) in options.TokenSelectionBiases) { openaiOptions.TokenSelectionBiases.Add(token, (int)bias); }
+    }
+
     StreamingResponse<Completions>? response = await this._client.GetCompletionsStreamingAsync(openaiOptions, cancellationToken).ConfigureAwait(false);
     await foreach (Completions? completions in response.EnumerateValues().WithCancellation(cancellationToken).ConfigureAwait(false))
     {
@@ -146,6 +151,11 @@ public async IAsyncEnumerable<string> GenerateTextAsync(
         foreach (var s in options.StopSequences) { openaiOptions.StopSequences.Add(s); }
     }
 
+    if (options.TokenSelectionBiases is { Count: > 0 })
+    {
+        foreach (var (token, bias) in options.TokenSelectionBiases) { openaiOptions.TokenSelectionBiases.Add(token, (int)bias); }
+    }
+
     openaiOptions.Messages.Add(new ChatRequestSystemMessage(prompt));
 
     StreamingResponse<StreamingChatCompletionsUpdate>? response = await this._client.GetChatCompletionsStreamingAsync(openaiOptions, cancellationToken).ConfigureAwait(false);
```

service/Core/Search/SearchClient.cs (6 additions, 7 deletions)

```diff
@@ -337,16 +337,15 @@ private IAsyncEnumerable<string> GenerateAnswerAsync(string question, string fac
 
     prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase);
 
-    // TODO: receive options from API: https://github.com/microsoft/kernel-memory/issues/137
     var options = new TextGenerationOptions
     {
-        // Temperature = 0,
-        // TopP = 0,
-        // PresencePenalty = 0,
-        // FrequencyPenalty = 0,
+        Temperature = this._config.Temperature,
+        TopP = this._config.TopP,
+        PresencePenalty = this._config.PresencePenalty,
+        FrequencyPenalty = this._config.FrequencyPenalty,
         MaxTokens = this._config.AnswerTokens,
-        // StopSequences = null,
-        // TokenSelectionBiases = null
+        StopSequences = this._config.StopSequences,
+        TokenSelectionBiases = this._config.TokenSelectionBiases,
     };
 
     if (this._log.IsEnabled(LogLevel.Debug))
```
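For context, a sketch of how the options built here flow into the generators patched above. The `GenerateTextAsync(prompt, options, cancellationToken)` shape comes from the diff hunks; the interface name `ITextGenerator` and the helper itself are assumptions for illustration only.

```csharp
using System.Text;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper, not from this commit: stream an answer from any of the
// patched generators. After this change, StopSequences and TokenSelectionBiases
// set on the options are actually forwarded to the underlying model.
static async Task<string> RenderAnswerAsync(
    ITextGenerator generator,            // interface name is an assumption
    string prompt,
    TextGenerationOptions options,
    CancellationToken cancellationToken)
{
    var answer = new StringBuilder();
    await foreach (string chunk in generator.GenerateTextAsync(prompt, options, cancellationToken))
    {
        answer.Append(chunk);
    }
    return answer.ToString();
}
```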
