Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
import PackageDescription
import CompilerPluginSupport

let llamaVersion = "b6519"
let llamaVersion = "b6628"

// MARK: - Package Dependencies

var packageDependencies: [Package.Dependency] = [
.package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")),
.package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.2.0")),
.package(url: "https://github.com/huggingface/swift-jinja", .upToNextMinor(from: "2.0.0")),
.package(url: "https://github.com/swiftlang/swift-syntax", from: "600.0.0")
]

#if os(iOS) || os(macOS)
packageDependencies.append(contentsOf: [
.package(url: "https://github.com/huggingface/swift-transformers", .upToNextMinor(from: "0.1.21")),
.package(url: "https://github.com/ml-explore/mlx-swift-examples", branch: "main"),
.package(url: "https://github.com/apple/swift-docc-plugin", from: "1.4.0")
])
Expand Down Expand Up @@ -58,7 +57,8 @@ var packageTargets: [Target] = [
.target(
name: "LocalLLMClientCore",
dependencies: [
"LocalLLMClientUtility"
"LocalLLMClientUtility",
.product(name: "Jinja", package: "swift-jinja")
]
),

Expand Down Expand Up @@ -112,8 +112,7 @@ packageTargets.append(contentsOf: [
name: "LocalLLMClientLlama",
dependencies: [
"LocalLLMClientCore",
"LocalLLMClientLlamaC",
.product(name: "Jinja", package: "Jinja")
"LocalLLMClientLlamaC"
],
resources: [.process("Resources")],
swiftSettings: (Context.environment["BUILD_DOCC"] == nil ? [] : [
Expand Down Expand Up @@ -155,7 +154,7 @@ packageTargets.append(contentsOf: [
name: "LocalLLMClientLlamaFramework",
url:
"https://github.com/ggml-org/llama.cpp/releases/download/\(llamaVersion)/llama-\(llamaVersion)-xcframework.zip",
checksum: "df054bcebc3363f2e21d0c9a18a7f9e0ee4e5ff44d458fa665b0cc0cc64d6fde"
checksum: "b25aad9f424ecb7d589d843deabd0ebc0f0d6f9ea126f7d31bb1ac5204543fa9"
),
.target(
name: "LocalLLMClientLlamaC",
Expand Down Expand Up @@ -203,8 +202,7 @@ packageTargets.append(contentsOf: [
name: "LocalLLMClientLlama",
dependencies: [
"LocalLLMClientCore",
"LocalLLMClientLlamaC",
.product(name: "Jinja", package: "Jinja")
"LocalLLMClientLlamaC"
],
resources: [.process("Resources")],
swiftSettings: [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Jinja
struct TemplateContext {
let specialTokens: [String: String]
let additionalContext: [String: Any]

init(
specialTokens: [String: String] = [:],
additionalContext: [String: Any] = [:]
Expand All @@ -30,11 +30,11 @@ struct TemplateContext {
/// Standard Jinja-based template renderer
struct JinjaChatTemplateRenderer: ChatTemplateRenderer {
private let toolProcessor: ToolInstructionProcessor

init(toolProcessor: ToolInstructionProcessor = StandardToolInstructionProcessor()) {
self.toolProcessor = toolProcessor
}

func render(
messages: [LLMInput.ChatTemplateMessage],
template: String,
Expand All @@ -47,56 +47,63 @@ struct TemplateContext {
} catch {
throw LLMError.invalidParameter(reason: "Failed to parse template: \(error.localizedDescription)")
}

// Extract message data
var messagesData = messages.map(\.value)

// Process tool instructions if needed
let hasNativeToolSupport = toolProcessor.hasNativeToolSupport(in: template)
messagesData = try toolProcessor.processMessages(
messagesData,
tools: tools,
templateHasNativeSupport: hasNativeToolSupport
)

// Build template context
let templateContext = buildTemplateContext(
let templateContext = try buildTemplateContext(
messages: messagesData,
tools: tools,
hasNativeToolSupport: hasNativeToolSupport,
context: context
)

// Render template
do {
return try jinjaTemplate.render(templateContext)
let environment = Environment()
environment.lstripBlocks = true
environment.trimBlocks = true
return try jinjaTemplate.render(templateContext, environment: environment)
} catch {
throw LLMError.invalidParameter(reason: "Failed to render template: \(error.localizedDescription)")
}
}

private func buildTemplateContext(
messages: [[String: any Sendable]],
tools: [AnyLLMTool],
hasNativeToolSupport: Bool,
context: TemplateContext
) -> [String: Any] {
var templateContext: [String: Any] = [
"add_generation_prompt": true,
"messages": messages
]

// Add special tokens
templateContext.merge(context.specialTokens) { _, new in new }

// Add tools for templates with native support
if !tools.isEmpty && hasNativeToolSupport {
templateContext["tools"] = tools.compactMap { $0.toOAICompatJSON() }
) throws(LLMError) -> [String: Value] {
do {
var templateContext: [String: Value] = [
"add_generation_prompt": .boolean(true),
"messages": try Value(any: messages)
]

// Add special tokens
try templateContext.merge(context.specialTokens.mapValues { try Value(any: $0) }) { _, new in new }

// Add tools for templates with native support
if !tools.isEmpty && hasNativeToolSupport {
templateContext["tools"] = try Value(any: tools.compactMap { $0.toOAICompatJSON() })
}

// Add additional context
templateContext.merge(try context.additionalContext.mapValues { try Value(any: $0) }) { _, new in new }

return templateContext
} catch {
throw LLMError.invalidParameter(reason: "Failed to build template context: \(error.localizedDescription)")
}

// Add additional context
templateContext.merge(context.additionalContext) { _, new in new }

return templateContext
}
}
2 changes: 1 addition & 1 deletion Sources/LocalLLMClientLlamaC/exclude/llama.cpp
Submodule llama.cpp updated 365 files
4 changes: 2 additions & 2 deletions Tests/LocalLLMClientLlamaTests/MessageProcessorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ struct MessageProcessorTests {
.assistant(assistantMarker),
])
#expect(rendered.contains("<|start_header_id|>") && rendered.contains("<|end_header_id|>"))
#expect(chunks == [.text(" <|start_header_id|>user<|end_header_id|>\n\n\(userMarker)"), .image([.testImage]), .text("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\(assistantMarker)<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")])
#expect(chunks == [.text(" <|start_header_id|>user<|end_header_id|>\n\n\(userMarker)"), .image([.testImage]), .text("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\(assistantMarker)<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")])
}

@Test
Expand All @@ -60,7 +60,7 @@ struct MessageProcessorTests {
let processor = MessageProcessorFactory.llama32VisionProcessor()
let (rendered, chunks) = try validate(processor: processor, chatTemplate: template)
#expect(rendered.contains("<|header_start|>") && rendered.contains("<|header_end|>"))
#expect(chunks == [.text("<|header_start|>system<|header_end|>\n\n\(systemMarker)<|eot|><|header_start|>user<|header_end|>\n\n\(userMarker)"), .image([.testImage]), .text("<|eot|><|header_start|>assistant<|header_end|>\n\n\(assistantMarker)<|eot|><|header_start|>assistant<|header_end|>\n\n")])
#expect(chunks == [.text(" <|header_start|>system<|header_end|>\n\n\(systemMarker)<|eot|> <|header_start|>user<|header_end|>\n\n\(userMarker)"), .image([.testImage]), .text("<|eot|><|header_start|>assistant<|header_end|>\n\n\(assistantMarker)<|eot|><|header_start|>assistant<|header_end|>\n\n")])
}

@Test
Expand Down
Loading