Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 24 additions & 23 deletions LLama/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -171,7 +172,7 @@ public SessionState GetSessionState()
{
var executorState = ((StatefulExecutorBase)Executor).GetStateData();
return new SessionState(
executorState.PastTokensCount > 0
executorState.PastTokensCount > 0
? Executor.Context.GetState() : null,
executorState,
History,
Expand Down Expand Up @@ -227,7 +228,7 @@ public void LoadSession(string path, bool loadTransforms = true)
if (state.ExecutorState is null)
{
var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
((StatefulExecutorBase)Executor).LoadState(filename: executorPath);
}
LoadSession(state, loadTransforms);
}
Expand Down Expand Up @@ -441,21 +442,21 @@ public async IAsyncEnumerable<string> ChatAsync(
prompt = HistoryTransform.HistoryToText(singleMessageHistory);
}

string assistantMessage = string.Empty;
StringBuilder assistantMessage = new();

await foreach (
string textToken
in ChatAsyncInternal(
prompt,
inferenceParams,
cancellationToken))
try
{
assistantMessage += textToken;
yield return textToken;
await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
assistantMessage.Append(textToken);
yield return textToken;
}
}
finally
{
// Add the assistant message to the history
AddAssistantMessage(assistantMessage.ToString());
}

// Add the assistant message to the history
AddAssistantMessage(assistantMessage);
}

/// <summary>
Expand Down Expand Up @@ -624,7 +625,7 @@ public record SessionState
/// <summary>
/// The input transform pipeline used in this session.
/// </summary>
public ITextTransform[] InputTransformPipeline { get; set; } = [ ];
public ITextTransform[] InputTransformPipeline { get; set; } = [];

/// <summary>
/// The output transform used in this session.
Expand All @@ -635,11 +636,11 @@ public record SessionState
/// The history transform used in this session.
/// </summary>
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

/// <summary>
/// The chat history messages for this session.
/// </summary>
public ChatHistory.Message[] History { get; set; } = [ ];
public ChatHistory.Message[] History { get; set; } = [];

/// <summary>
/// Create a new session state.
Expand All @@ -651,7 +652,7 @@ public record SessionState
/// <param name="outputTransform"></param>
/// <param name="historyTransform"></param>
public SessionState(
State? contextState, ExecutorBaseState executorState,
State? contextState, ExecutorBaseState executorState,
ChatHistory history, List<ITextTransform> inputTransformPipeline,
ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
{
Expand Down Expand Up @@ -738,22 +739,22 @@ public static SessionState Load(string path)
ITextTransform[] inputTransforms;
try
{
inputTransforms = File.Exists(inputTransformFilepath) ?
inputTransforms = File.Exists(inputTransformFilepath) ?
(JsonSerializer.Deserialize<ITextTransform[]>(File.ReadAllText(inputTransformFilepath))
?? throw new ArgumentException("Input transform file is invalid", nameof(path)))
: [ ];
: [];
}
catch (JsonException)
{
throw new ArgumentException("Input transform file is invalid", nameof(path));
}

string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);

ITextStreamTransform outputTransform;
try
{
outputTransform = File.Exists(outputTransformFilepath) ?
outputTransform = File.Exists(outputTransformFilepath) ?
(JsonSerializer.Deserialize<ITextStreamTransform>(File.ReadAllText(outputTransformFilepath))
?? throw new ArgumentException("Output transform file is invalid", nameof(path)))
: new LLamaTransforms.EmptyTextOutputStreamTransform();
Expand All @@ -767,7 +768,7 @@ public static SessionState Load(string path)
IHistoryTransform historyTransform;
try
{
historyTransform = File.Exists(historyTransformFilepath) ?
historyTransform = File.Exists(historyTransformFilepath) ?
(JsonSerializer.Deserialize<IHistoryTransform>(File.ReadAllText(historyTransformFilepath))
?? throw new ArgumentException("History transform file is invalid", nameof(path)))
: new LLamaTransforms.DefaultHistoryTransform();
Expand Down
Loading