Migrate gemini agents away from Untooled (#4505)

* Migrate gemini agents away from `Untooled`

* disable agents for gemma models as they are not supported for tool calling

* Dev build
resolve #4452 via function name prefix and then stripping within provider
This commit is contained in:
Timothy Carambat 2025-10-07 11:40:00 -07:00 committed by GitHub
parent cf3fbcbf0f
commit 0ee0a96506
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 213 additions and 168 deletions

View File

@ -6,7 +6,7 @@ concurrency:
on: on:
push: push:
branches: ['improve-url-handler-collector'] # put your current branch to create a build. Core team only. branches: ['gemini-migration-agents'] # put your current branch to create a build. Core team only.
paths-ignore: paths-ignore:
- '**.md' - '**.md'
- 'cloud-deployments/*' - 'cloud-deployments/*'

View File

@ -1,19 +1,14 @@
const OpenAI = require("openai"); const OpenAI = require("openai");
const Provider = require("./ai-provider.js"); const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js"); const { RetryError } = require("../error.js");
const UnTooled = require("./helpers/untooled.js");
const {
NO_SYSTEM_PROMPT_MODELS,
} = require("../../../AiProviders/gemini/index.js");
const { APIError } = require("../error.js");
const { v4 } = require("uuid");
const { safeJsonParse } = require("../../../http"); const { safeJsonParse } = require("../../../http");
const { v4 } = require("uuid");
/** /**
* The agent provider for the Gemini provider. * The agent provider for the Gemini provider.
* We wrap Gemini in UnTooled because its tool-calling is not supported via the dedicated OpenAI API. * We wrap Gemini in UnTooled because its tool-calling is not supported via the dedicated OpenAI API.
*/ */
class GeminiProvider extends InheritMultiple([Provider, UnTooled]) { class GeminiProvider extends Provider {
model; model;
constructor(config = {}) { constructor(config = {}) {
@ -35,6 +30,11 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) {
return this._client; return this._client;
} }
get supportsToolCalling() {
if (!this.model.startsWith("gemini")) return false;
return true;
}
get supportsAgentStreaming() { get supportsAgentStreaming() {
// Tool call streaming results in a 400/503 error for all non-gemini models // Tool call streaming results in a 400/503 error for all non-gemini models
// using the compatible v1beta/openai/ endpoint // using the compatible v1beta/openai/ endpoint
@ -44,205 +44,250 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) {
); );
return false; return false;
} }
return true; return true;
} }
/** /**
 * Format the messages to the format required by the Gemini API since some models do not support system prompts. * Gemini specifically will throw an error if the tool call's function name
* @see {NO_SYSTEM_PROMPT_MODELS} * starts with a non-alpha character. So we need to prefix the function names
* @param {import("openai").OpenAI.ChatCompletionMessage[]} messages * with a valid prefix to ensure they are always valid and then strip them back
* @returns {import("openai").OpenAI.ChatCompletionMessage[]} * so they may properly be used in the tool call.
*
* So for all tools, we force the prefix to be gtc__ to avoid issues
* Agent flows are already prefixed with flow__ but since we strip the prefix
* anyway pre and post-reply, we do it anyway to ensure consistency across all tools.
*
* This specifically impacts the custom Agent Skills since they can be a short alphanumeric
 * and can definitely start with a number. eg: '12xdaya31bas' -> invalid in gemini tools.
*
* Even if the tool is never called, if it is in the `tools` array and this prefix
* patch is not applied, gemini will throw an error.
*
* This is undocumented by google, but it is the only way to ensure that tool calls
* are valid.
*
* @param {string} functionName - The name of the function to prefix.
* @param {'add' | 'strip'} action - The action to take.
 * @returns {string} The prefixed function name when action is 'add',
 * or the function name with the 'gtc__' prefix stripped when action is 'strip'.
*/ */
formatMessages(messages) { prefixToolCall(functionName, action = "add") {
if (!NO_SYSTEM_PROMPT_MODELS.includes(this.model)) return messages; if (action === "add") return `gtc__${functionName}`;
// must start with gtc__ to be valid and we only strip the first instance
// Replace the system message with a user/assistant message pair return functionName.startsWith("gtc__")
const formattedMessages = []; ? functionName.split("gtc__")[1]
for (const message of messages) { : functionName;
if (message.role === "system") {
formattedMessages.push({
role: "user",
content: message.content,
});
formattedMessages.push({
role: "assistant",
content: "Okay, I'll follow your instructions.",
});
continue;
}
formattedMessages.push(message);
}
return formattedMessages;
} }
/** /**
* Format the functions for the LLM. * Format the messages to the Gemini API Responses format.
* @param {any[]} functions - The functions to format. * - Gemini has some loosely documented format for tool calls and it can change at any time.
* @returns {any[]} - The formatted functions. * - We need to map the function call to the correct id and Gemini will throw an error if it does not.
* @param {any[]} messages - The messages to format.
* @returns {OpenAI.OpenAI.Responses.ResponseInput[]} The formatted messages.
*/ */
formatFunctions(functions = []) { #formatMessages(messages) {
return functions.map((fn) => ({ let formattedMessages = [];
messages.forEach((message) => {
if (message.role === "function") {
// If the message does not have an originalFunctionCall we cannot
// map it to a function call id and Gemini will throw an error.
// so if this does not carry over - log and skip
if (!message.hasOwnProperty("originalFunctionCall")) {
this.providerLog(
"[Gemini.#formatMessages]: message did not pass back the originalFunctionCall. We need this to map the function call to the correct id.",
{ message: JSON.stringify(message, null, 2) }
);
return;
}
formattedMessages.push(
{
role: "assistant",
tool_calls: [
{
type: "function",
function: {
arguments: JSON.stringify(
message.originalFunctionCall.arguments
),
name: message.originalFunctionCall.name,
},
id: message.originalFunctionCall.id,
},
],
},
{
role: "tool",
tool_call_id: message.originalFunctionCall.id,
content: message.content,
}
);
return;
}
formattedMessages.push({
role: message.role,
content: message.content,
});
});
return formattedMessages;
}
#formatFunctions(functions) {
return functions.map((func) => ({
type: "function", type: "function",
function: { function: {
name: fn.name, name: this.prefixToolCall(func.name, "add"),
description: fn.description, description: func.description,
parameters: { parameters: func.parameters,
type: "object",
properties: fn.parameters.properties,
},
}, },
})); }));
} }
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
messages: this.cleanMsgs(this.formatMessages(messages)),
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("Gemini chat: No results!");
if (result.choices.length === 0)
throw new Error("Gemini chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Streaming for Gemini only supports `tools` and not `functions`, so
* we need to apply some transformations to the messages and functions.
*
* @see {formatFunctions}
* @param {*} messages
* @param {*} functions
* @param {*} eventHandler
* @returns
*/
async stream(messages, functions = [], eventHandler = null) { async stream(messages, functions = [], eventHandler = null) {
const msgUUID = v4(); if (!this.supportsToolCalling)
const stream = await this.client.chat.completions.create({ throw new Error(`Gemini: ${this.model} does not support tool calling.`);
model: this.model, this.providerLog("Gemini.stream - will process this chat completion.");
stream: true, try {
messages: this.cleanMsgs(this.formatMessages(messages)), const msgUUID = v4();
...(Array.isArray(functions) && functions?.length > 0 /** @type {OpenAI.OpenAI.Chat.ChatCompletion} */
? { const response = await this.client.chat.completions.create({
tools: this.formatFunctions(functions), model: this.model,
tool_choice: "auto", messages: this.#formatMessages(messages),
} stream: true,
: {}), ...(Array.isArray(functions) && functions?.length > 0
}); ? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
: {}),
});
const result = { const completion = {
functionCall: null, content: "",
textResponse: "", /** @type {null|{name: string, call_id: string, arguments: string|object}} */
}; functionCall: null,
};
for await (const chunk of stream) { for await (const streamEvent of response) {
if (!chunk?.choices?.[0]) continue; // Skip if no choices /** @type {OpenAI.OpenAI.Chat.ChatCompletionChunk} */
const choice = chunk.choices[0]; const chunk = streamEvent;
const { content, tool_calls } = chunk?.choices?.[0]?.delta || {};
if (choice.delta?.content) { if (content) {
result.textResponse += choice.delta.content; completion.content += content;
eventHandler?.("reportStreamEvent", { eventHandler?.("reportStreamEvent", {
type: "textResponseChunk", type: "textResponseChunk",
uuid: msgUUID, uuid: msgUUID,
content: choice.delta.content, content,
}); });
}
if (choice.delta?.tool_calls && choice.delta.tool_calls.length > 0) {
const toolCall = choice.delta.tool_calls[0];
if (result.functionCall)
result.functionCall.arguments += toolCall.function.arguments;
else {
result.functionCall = {
name: toolCall.function.name,
arguments: toolCall.function.arguments,
};
} }
eventHandler?.("reportStreamEvent", { if (tool_calls) {
uuid: `${msgUUID}:tool_call_invocation`, const toolCall = tool_calls[0];
type: "toolCallInvocation", completion.functionCall = {
content: `Assembling Tool Call: ${result.functionCall.name}(${result.functionCall.arguments})`, name: this.prefixToolCall(toolCall.function.name, "strip"),
}); call_id: toolCall.id,
arguments: toolCall.function.arguments,
};
eventHandler?.("reportStreamEvent", {
type: "toolCallInvocation",
uuid: `${msgUUID}:tool_call_invocation`,
content: `Assembling Tool Call: ${completion.functionCall.name}(${completion.functionCall.arguments})`,
});
}
} }
}
// If there are arguments, parse them as json so that the tools can use them if (completion.functionCall) {
if (!!result.functionCall?.arguments) completion.functionCall.arguments = safeJsonParse(
result.functionCall.arguments = safeJsonParse( completion.functionCall.arguments,
result.functionCall.arguments, {}
{} );
); return {
return result; textResponse: completion.content,
functionCall: {
id: completion.functionCall.call_id,
name: completion.functionCall.name,
arguments: completion.functionCall.arguments,
},
cost: this.getCost(),
};
}
return {
textResponse: completion.content,
functionCall: null,
cost: this.getCost(),
};
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) throw error;
if (
error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
) {
throw new RetryError(error.message);
}
throw error;
}
} }
/** /**
* Create a completion based on the received messages. * Create a completion based on the received messages.
* *
* TODO: see stream() - tool_calls are now supported, so we can use that instead of Untooled * @param messages A list of messages to send to the Gemini API.
*
* @param messages A list of messages to send to the API.
* @param functions * @param functions
* @returns The completion. * @returns The completion.
*/ */
async complete(messages, functions = []) { async complete(messages, functions = []) {
if (!this.supportsToolCalling)
throw new Error(`Gemini: ${this.model} does not support tool calling.`);
this.providerLog("Gemini.complete - will process this chat completion.");
try { try {
let completion; const response = await this.client.chat.completions.create({
model: this.model,
stream: false,
messages: this.#formatMessages(messages),
...(Array.isArray(functions) && functions?.length > 0
? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
: {}),
});
if (functions.length > 0) { /** @type {OpenAI.OpenAI.Chat.ChatCompletionMessage} */
const { toolCall, text } = await this.functionCall( const completion = response.choices[0].message;
this.cleanMsgs(this.formatMessages(messages)), const cost = this.getCost(response.usage);
functions, if (completion?.tool_calls?.length > 0) {
this.#handleFunctionCallChat.bind(this) const toolCall = completion.tool_calls[0];
); let functionArgs = safeJsonParse(toolCall.function.arguments, {});
return {
if (toolCall !== null) { textResponse: null,
this.providerLog(`Valid tool call found - running ${toolCall.name}.`); functionCall: {
this.deduplicator.trackRun(toolCall.name, toolCall.arguments); name: this.prefixToolCall(toolCall.function.name, "strip"),
return { arguments: functionArgs,
result: null, id: toolCall.id,
functionCall: { },
name: toolCall.name, cost,
arguments: toolCall.arguments, };
},
cost: 0,
};
}
completion = { content: text };
} }
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(this.formatMessages(messages)),
});
completion = response.choices[0].message;
}
// The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
// from calling the exact same function over and over in a loop within a single chat exchange
// _but_ we should enable it to call previously used tools in a new chat interaction.
this.deduplicator.reset("runs");
return { return {
textResponse: completion.content, textResponse: completion.content,
cost: 0, cost,
}; };
} catch (error) { } catch (error) {
throw new APIError( // If invalid Auth error we need to abort because no amount of waiting
error?.message // will make auth better.
? `${this.className} encountered an error while executing the request: ${error.message}` if (error instanceof OpenAI.AuthenticationError) throw error;
: "There was an error with the Gemini provider executing the request"
); if (
error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
) {
throw new RetryError(error.message);
}
throw error;
} }
} }