Skip to content

Commit 7a3236a

Browse files
committed
Implement tools handling
1 parent 3bb3aa8 commit 7a3236a

File tree

2 files changed

+100
-11
lines changed

2 files changed

+100
-11
lines changed

readme.MD

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,7 @@ Implemented via [`inlineData`](https://ai.google.dev/api/caching#Part).
135135

136136
- [x] `chat/completions`
137137

138-
Currently, most of the parameters that are applicable to both APIs have been implemented,
139-
with the exception of function calls.
138+
Currently, most of the parameters that are applicable to both APIs have been implemented.
140139
<details>
141140

142141
- [x] `messages`
@@ -145,9 +144,9 @@ Implemented via [`inlineData`](https://ai.google.dev/api/caching#Part).
145144
- [x] "system" (=>`system_instruction`)
146145
- [x] "user"
147146
- [x] "assistant"
148-
- [ ] "tool" (v1beta)
147+
- [x] "tool"
149148
- [ ] `name`
150-
- [ ] `tool_calls`
149+
- [x] `tool_calls`
151150
- [x] `model`
152151
- [x] `frequency_penalty`
153152
- [ ] `logit_bias`
@@ -167,9 +166,9 @@ Implemented via [`inlineData`](https://ai.google.dev/api/caching#Part).
167166
- [x] `include_usage`
168167
- [x] `temperature` (0.0..2.0 for OpenAI, but Gemini supports up to infinity)
169168
- [x] `top_p`
170-
- [ ] `tools` (v1beta)
171-
- [ ] `tool_choice` (v1beta)
172-
- [ ] `parallel_tool_calls`
169+
- [x] `tools`
170+
- [x] `tool_choice`
171+
- [ ] `parallel_tool_calls` (is always active in Gemini)
173172

174173
</details>
175174
- [ ] `completions`

src/worker.mjs

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,52 @@ const parseImg = async (url) => {
287287
};
288288
};
289289

290-
const transformMsg = async ({ content }) => {
290+
const transformMsg = async ({ content, tool_calls, tool_call_id }, fnames) => {
291291
const parts = [];
292+
if (tool_call_id !== undefined) {
293+
let response;
294+
try {
295+
response = JSON.parse(content);
296+
} catch (err) {
297+
console.error("Error parsing function response content:", err);
298+
throw new HttpError("Invalid function response: " + content, 400);
299+
}
300+
if (typeof response !== "object" || response === null || Array.isArray(response)) {
301+
response = { result: response };
302+
}
303+
parts.push({
304+
functionResponse: {
305+
id: tool_call_id.startsWith("{") ? null : tool_call_id,
306+
name: fnames[tool_call_id],
307+
response,
308+
}
309+
});
310+
return parts;
311+
}
312+
if (tool_calls) {
313+
for (const tcall of tool_calls) {
314+
if (tcall.type !== "function") {
315+
throw new HttpError(`Unsupported tool_call type: "${tcall.type}"`, 400);
316+
}
317+
const { function: { arguments: argstr, name }, id } = tcall;
318+
let args;
319+
try {
320+
args = JSON.parse(argstr);
321+
} catch (err) {
322+
console.error("Error parsing function arguments:", err);
323+
throw new HttpError("Invalid function arguments: " + argstr, 400);
324+
}
325+
parts.push({
326+
functionCall: {
327+
id: id.startsWith("{") ? null : id,
328+
name,
329+
args,
330+
}
331+
});
332+
fnames[id] = name;
333+
}
334+
return parts;
335+
}
292336
if (!Array.isArray(content)) {
293337
// system, user: string
294338
// assistant: string or null (Required unless tool_calls is specified.)
@@ -329,18 +373,26 @@ const transformMessages = async (messages) => {
329373
if (!messages) { return; }
330374
const contents = [];
331375
let system_instruction;
376+
const fnames = {}; // cache function names by tool_call_id between messages
332377
for (const item of messages) {
333378
if (item.role === "system") {
334379
system_instruction = { parts: await transformMsg(item) };
335380
} else {
336381
if (item.role === "assistant") {
337382
item.role = "model";
383+
} else if (item.role === "tool") {
384+
const prev = contents[contents.length - 1];
385+
if (prev?.role === "function") {
386+
prev.parts.push(...await transformMsg(item, fnames));
387+
continue;
388+
}
389+
item.role = "function"; // ignored
338390
} else if (item.role !== "user") {
339391
throw HttpError(`Unknown message role: "${item.role}"`, 400);
340392
}
341393
contents.push({
342394
role: item.role,
343-
parts: await transformMsg(item)
395+
parts: await transformMsg(item, fnames)
344396
});
345397
}
346398
}
@@ -351,10 +403,32 @@ const transformMessages = async (messages) => {
351403
return { system_instruction, contents };
352404
};
353405

406+
const transformTools = (req) => {
407+
let tools, tool_config;
408+
if (req.tools) {
409+
const funcs = req.tools.filter(tool => tool.type === "function");
410+
funcs.forEach(adjustSchema);
411+
tools = [{ function_declarations: funcs.map(schema => schema.function) }];
412+
}
413+
if (req.tool_choice) {
414+
const allowed_function_names = req.tool_choice?.type === "function" ? [ req.tool_choice?.function?.name ] : undefined;
415+
if (allowed_function_names || typeof req.tool_choice === "string") {
416+
tool_config = {
417+
function_calling_config: {
418+
mode: allowed_function_names ? "ANY" : req.tool_choice.toUpperCase(),
419+
allowed_function_names
420+
}
421+
};
422+
}
423+
}
424+
return { tools, tool_config };
425+
};
426+
354427
const transformRequest = async (req) => ({
355428
...await transformMessages(req.messages),
356429
safetySettings,
357430
generationConfig: transformConfig(req),
431+
...transformTools(req),
358432
});
359433

360434
const generateChatcmplId = () => {
@@ -370,13 +444,25 @@ const reasonsMap = { //https://ai.google.dev/api/rest/v1/GenerateContentResponse
370444
"SAFETY": "content_filter",
371445
"RECITATION": "content_filter",
372446
//"OTHER": "OTHER",
373-
// :"function_call",
374447
};
375448
const SEP = "\n\n|>";
376449
const transformCandidates = (key, cand) => {
377450
const message = { role: "assistant", content: [] };
378451
for (const part of cand.content?.parts ?? []) {
379-
message.content.push(part.text);
452+
if (part.functionCall) {
453+
const fc = part.functionCall;
454+
message.tool_calls = message.tool_calls ?? [];
455+
message.tool_calls.push({
456+
id: fc.id ?? `{${fc.name}}`,
457+
type: "function",
458+
function: {
459+
name: fc.name,
460+
arguments: JSON.stringify(fc.args),
461+
}
462+
});
463+
} else {
464+
message.content.push(part.text);
465+
}
380466
}
381467
message.content = message.content.join(SEP) || null;
382468
return {
@@ -430,10 +516,14 @@ function transformResponseStream (data, special) {
430516
const item = transformCandidatesDelta(data.candidates[0]);
431517
switch (special) {
432518
case "stop":
519+
if (item.delta.tool_calls) {
520+
item.finish_reason = "tool_calls";
521+
}
433522
item.delta = {};
434523
break;
435524
case "first":
436525
item.delta.content = "";
526+
delete item.delta.tool_calls;
437527
// eslint-disable-next-line no-fallthrough
438528
default:
439529
delete item.delta.role;

0 commit comments

Comments
 (0)