Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions LLama.KernelMemory/LlamaSharpTextGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)options.Temperature,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
}
};
Expand All @@ -107,8 +107,8 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)options.Temperature,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
}
};
Expand Down
4 changes: 2 additions & 2 deletions LLama.SemanticKernel/ExtensionMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LL
{
Temperature = (float)requestSettings.Temperature,
TopP = (float)requestSettings.TopP,
AlphaPresence = (float)requestSettings.PresencePenalty,
AlphaFrequency = (float)requestSettings.FrequencyPenalty,
PresencePenalty = (float)requestSettings.PresencePenalty,
FrequencyPenalty = (float)requestSettings.FrequencyPenalty,
}
};
}
Expand Down
6 changes: 3 additions & 3 deletions LLama/Extensions/LLamaExecutorExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ private string CreatePrompt(IList<ChatMessage> messages)
MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
SamplingPipeline = new DefaultSamplingPipeline()
{
AlphaFrequency = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
AlphaPresence = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
PenalizeEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool eos) is true ? eos : s_defaultPipeline.PenalizeEOS,
FrequencyPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.FrequencyPenalty), out float af) is true ? af : s_defaultPipeline.FrequencyPenalty,
PresencePenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PresencePenalty), out float ap) is true ? ap : s_defaultPipeline.PresencePenalty,
PreventEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PreventEOS), out bool eos) is true ? eos : s_defaultPipeline.PreventEOS,
PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
RepeatPenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,
Expand Down
79 changes: 66 additions & 13 deletions LLama/Sampling/DefaultSamplingPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,76 @@ public sealed class DefaultSamplingPipeline
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
[Obsolete($"Use {nameof(FrequencyPenalty)} instead.")]
public float AlphaFrequency
{
get => _alphaFreq;
get => _frequencyPenalty;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaFrequency)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
_alphaFreq = value;
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaFrequency)} must be less than 2");
_frequencyPenalty = value;
}
}
private readonly float _alphaFreq;

/// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
[Obsolete($"Use {nameof(PresencePenalty)} instead.")]
public float AlphaPresence
{
get => _alphaPresence;
get => _presencePenalty;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaPresence)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
_alphaPresence = value;
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(AlphaPresence)} must be less than 2");
_presencePenalty = value;
}
}
private readonly float _alphaPresence;

/// <summary>
/// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
public float FrequencyPenalty
{
get => _frequencyPenalty;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(FrequencyPenalty)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(FrequencyPenalty)} must be less than 2");
_frequencyPenalty = value;
}
}
private readonly float _frequencyPenalty;

/// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
public float PresencePenalty
{
get => _presencePenalty;
init
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(PresencePenalty)} must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), $"{nameof(PresencePenalty)} must be less than 2");
_presencePenalty = value;
}
}
private readonly float _presencePenalty;

/// <summary>
/// How many tokens should be considered for penalizing repetition
Expand All @@ -71,8 +109,14 @@ public float AlphaPresence
/// <summary>
/// Whether the EOS token should be protected from being modified by penalty
/// </summary>
[Obsolete($"This doesn't do what the name implies. If you're sure you want to use it, use {nameof(PreventEOS)}.")]
public bool PenalizeEOS { get; init; } = false;

/// <summary>
/// Whether the EOS token should be suppressed. Setting this to 'true' prevents EOS from being sampled
/// </summary>
public bool PreventEOS { get; init; } = false;

/// <summary>
/// Temperature to apply (higher temperature is more "creative")
/// </summary>
Expand Down Expand Up @@ -111,7 +155,16 @@ public float AlphaPresence
/// <summary>
/// Seed to use for random sampling
/// </summary>
public uint Seed { get; set; } = 42;
public uint Seed { get; set; } = GetRandomSeed();


private static Random RandomSeedGenerator = new();
private static uint GetRandomSeed()
{
lock (RandomSeedGenerator)
return (uint) RandomSeedGenerator.Next(0, int.MaxValue) + (uint) RandomSeedGenerator.Next(0, int.MaxValue);
}


/// <inheritdoc />
protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
Expand Down Expand Up @@ -147,8 +200,8 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
context.VocabCount,
context.ModelHandle.Tokens.EOS, context.ModelHandle.Tokens.Newline ?? 0,
RepeatPenaltyCount, RepeatPenalty,
AlphaFrequency, AlphaPresence,
PenalizeNewline, PenalizeEOS
FrequencyPenalty, PresencePenalty,
PenalizeNewline, PreventEOS
);

chain.AddTopK(TopK);
Expand Down
Loading