Migrate gemini agents away from Untooled (#4505)
* Migrate gemini agents away from `Untooled`
* Disable agents for gemma models as they are not supported for tool calling
* Dev build
* Resolve #4452 via function name prefix and then stripping within provider
parent cf3fbcbf0f
commit 0ee0a96506
.github/workflows/dev-build.yaml (vendored, 2 changes)
@@ -6,7 +6,7 @@ concurrency:
 on:
   push:
-    branches: ['improve-url-handler-collector'] # put your current branch to create a build. Core team only.
+    branches: ['gemini-migration-agents'] # put your current branch to create a build. Core team only.
     paths-ignore:
       - '**.md'
       - 'cloud-deployments/*'
@@ -1,19 +1,14 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
const {
  NO_SYSTEM_PROMPT_MODELS,
} = require("../../../AiProviders/gemini/index.js");
const { APIError } = require("../error.js");
const { v4 } = require("uuid");
const { RetryError } = require("../error.js");
const { safeJsonParse } = require("../../../http");
const { v4 } = require("uuid");

/**
 * The agent provider for the Gemini provider.
 * We wrap Gemini in UnTooled because its tool-calling is not supported via the dedicated OpenAI API.
 */
class GeminiProvider extends InheritMultiple([Provider, UnTooled]) {
class GeminiProvider extends Provider {
  model;

  constructor(config = {}) {
@@ -35,6 +30,11 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) {
    return this._client;
  }

  get supportsToolCalling() {
    if (!this.model.startsWith("gemini")) return false;
    return true;
  }

  get supportsAgentStreaming() {
    // Tool call streaming results in a 400/503 error for all non-gemini models
    // using the compatible v1beta/openai/ endpoint
@@ -44,205 +44,250 @@ class GeminiProvider extends InheritMultiple([Provider, UnTooled]) {
      );
      return false;
    }

    return true;
  }
  /**
   * Format the messages to the format required by the Gemini API since some models do not support system prompts.
   * @see {NO_SYSTEM_PROMPT_MODELS}
   * @param {import("openai").OpenAI.ChatCompletionMessage[]} messages
   * @returns {import("openai").OpenAI.ChatCompletionMessage[]}
   * Gemini specifically will throw an error if the tool call's function name
   * starts with a non-alpha character. So we need to prefix the function names
   * with a valid prefix to ensure they are always valid and then strip them back
   * so they may properly be used in the tool call.
   *
   * So for all tools, we force the prefix to be gtc__ to avoid issues.
   * Agent flows are already prefixed with flow__ but since we strip the prefix
   * anyway pre and post-reply, we do it anyway to ensure consistency across all tools.
   *
   * This specifically impacts the custom Agent Skills since they can be a short alphanumeric
   * and can definitely start with a number, e.g. '12xdaya31bas' -> invalid in gemini tools.
   *
   * Even if the tool is never called, if it is in the `tools` array and this prefix
   * patch is not applied, gemini will throw an error.
   *
   * This is undocumented by google, but it is the only way to ensure that tool calls
   * are valid.
   *
   * @param {string} functionName - The name of the function to prefix.
   * @param {'add' | 'strip'} action - The action to take.
   * @returns {string} The prefixed function name.
   * @returns {string} The prefix to use for tool call ids.
   */
  formatMessages(messages) {
    if (!NO_SYSTEM_PROMPT_MODELS.includes(this.model)) return messages;

    // Replace the system message with a user/assistant message pair
    const formattedMessages = [];
    for (const message of messages) {
      if (message.role === "system") {
        formattedMessages.push({
          role: "user",
          content: message.content,
        });
        formattedMessages.push({
          role: "assistant",
          content: "Okay, I'll follow your instructions.",
        });
        continue;
      }
      formattedMessages.push(message);
    }
    return formattedMessages;
  prefixToolCall(functionName, action = "add") {
    if (action === "add") return `gtc__${functionName}`;
    // must start with gtc__ to be valid and we only strip the first instance
    return functionName.startsWith("gtc__")
      ? functionName.split("gtc__")[1]
      : functionName;
  }
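A quick illustration of the round trip described in the doc comment above (not part of the commit; the skill id comes from the comment's own example and the model name is assumed):

```js
const provider = new GeminiProvider({ model: "gemini-2.0-flash" }); // model name assumed

// Outbound: tool names are made Gemini-safe before being placed in the `tools` array.
provider.prefixToolCall("12xdaya31bas", "add"); // => "gtc__12xdaya31bas"

// Inbound: the prefix is stripped before the agent resolves the tool by name.
provider.prefixToolCall("gtc__12xdaya31bas", "strip"); // => "12xdaya31bas"

// Names without the prefix pass through the strip action unchanged.
provider.prefixToolCall("web-browsing", "strip"); // => "web-browsing"
```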
  /**
   * Format the functions for the LLM.
   * @param {any[]} functions - The functions to format.
   * @returns {any[]} - The formatted functions.
   * Format the messages to the Gemini API Responses format.
   * - Gemini has some loosely documented format for tool calls and it can change at any time.
   * - We need to map the function call to the correct id and Gemini will throw an error if it does not.
   * @param {any[]} messages - The messages to format.
   * @returns {OpenAI.OpenAI.Responses.ResponseInput[]} The formatted messages.
   */
  formatFunctions(functions = []) {
    return functions.map((fn) => ({
  #formatMessages(messages) {
    let formattedMessages = [];
    messages.forEach((message) => {
      if (message.role === "function") {
        // If the message does not have an originalFunctionCall we cannot
        // map it to a function call id and Gemini will throw an error.
        // so if this does not carry over - log and skip
        if (!message.hasOwnProperty("originalFunctionCall")) {
          this.providerLog(
            "[Gemini.#formatMessages]: message did not pass back the originalFunctionCall. We need this to map the function call to the correct id.",
            { message: JSON.stringify(message, null, 2) }
          );
          return;
        }

        formattedMessages.push(
          {
            role: "assistant",
            tool_calls: [
              {
                type: "function",
                function: {
                  arguments: JSON.stringify(
                    message.originalFunctionCall.arguments
                  ),
                  name: message.originalFunctionCall.name,
                },
                id: message.originalFunctionCall.id,
              },
            ],
          },
          {
            role: "tool",
            tool_call_id: message.originalFunctionCall.id,
            content: message.content,
          }
        );
        return;
      }

      formattedMessages.push({
        role: message.role,
        content: message.content,
      });
    });

    return formattedMessages;
  }
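For reference, a minimal sketch (not part of the commit; the ids and values are made up) of the rewrite #formatMessages performs on a `function`-role message that carries `originalFunctionCall`:

```js
// A tool-result message produced by a prior run, shape assumed from the code above.
const input = {
  role: "function",
  content: "The weather in Berlin is 18C.",
  originalFunctionCall: {
    id: "call_abc123",
    name: "get-weather",
    arguments: { city: "Berlin" },
  },
};

// #formatMessages replaces it with an assistant tool_call plus a tool result:
// [
//   { role: "assistant", tool_calls: [{ type: "function", id: "call_abc123",
//       function: { name: "get-weather", arguments: "{\"city\":\"Berlin\"}" } }] },
//   { role: "tool", tool_call_id: "call_abc123", content: "The weather in Berlin is 18C." },
// ]
```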
  #formatFunctions(functions) {
    return functions.map((func) => ({
      type: "function",
      function: {
        name: fn.name,
        description: fn.description,
        parameters: {
          type: "object",
          properties: fn.parameters.properties,
        },
        name: this.prefixToolCall(func.name, "add"),
        description: func.description,
        parameters: func.parameters,
      },
    }));
  }
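And a sketch of the tool entry #formatFunctions emits for a custom agent skill, showing where the gtc__ prefix lands (the skill name and parameters are invented for illustration):

```js
// A custom agent skill with a short alphanumeric id, the case the commit calls out.
const skill = {
  name: "12xdaya31bas", // unprefixed, Gemini rejects names starting with a digit
  description: "Look up a customer record",
  parameters: { type: "object", properties: { customerId: { type: "string" } } },
};

// #formatFunctions([skill]) produces the OpenAI-compatible tool definition:
// {
//   type: "function",
//   function: {
//     name: "gtc__12xdaya31bas",
//     description: "Look up a customer record",
//     parameters: { type: "object", properties: { customerId: { type: "string" } } },
//   },
// }
```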
  async #handleFunctionCallChat({ messages = [] }) {
    return await this.client.chat.completions
      .create({
        model: this.model,
        messages: this.cleanMsgs(this.formatMessages(messages)),
      })
      .then((result) => {
        if (!result.hasOwnProperty("choices"))
          throw new Error("Gemini chat: No results!");
        if (result.choices.length === 0)
          throw new Error("Gemini chat: No results length!");
        return result.choices[0].message.content;
      })
      .catch((_) => {
        return null;
      });
  }
  /**
   * Streaming for Gemini only supports `tools` and not `functions`, so
   * we need to apply some transformations to the messages and functions.
   *
   * @see {formatFunctions}
   * @param {*} messages
   * @param {*} functions
   * @param {*} eventHandler
   * @returns
   */
  async stream(messages, functions = [], eventHandler = null) {
    const msgUUID = v4();
    const stream = await this.client.chat.completions.create({
      model: this.model,
      stream: true,
      messages: this.cleanMsgs(this.formatMessages(messages)),
      ...(Array.isArray(functions) && functions?.length > 0
        ? {
            tools: this.formatFunctions(functions),
            tool_choice: "auto",
          }
        : {}),
    });
    if (!this.supportsToolCalling)
      throw new Error(`Gemini: ${this.model} does not support tool calling.`);
    this.providerLog("Gemini.stream - will process this chat completion.");
    try {
      const msgUUID = v4();
      /** @type {OpenAI.OpenAI.Chat.ChatCompletion} */
      const response = await this.client.chat.completions.create({
        model: this.model,
        messages: this.#formatMessages(messages),
        stream: true,
        ...(Array.isArray(functions) && functions?.length > 0
          ? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
          : {}),
      });
      const result = {
        functionCall: null,
        textResponse: "",
      };
      const completion = {
        content: "",
        /** @type {null|{name: string, call_id: string, arguments: string|object}} */
        functionCall: null,
      };
      for await (const chunk of stream) {
        if (!chunk?.choices?.[0]) continue; // Skip if no choices
        const choice = chunk.choices[0];
      for await (const streamEvent of response) {
        /** @type {OpenAI.OpenAI.Chat.ChatCompletionChunk} */
        const chunk = streamEvent;
        const { content, tool_calls } = chunk?.choices?.[0]?.delta || {};

        if (choice.delta?.content) {
          result.textResponse += choice.delta.content;
          eventHandler?.("reportStreamEvent", {
            type: "textResponseChunk",
            uuid: msgUUID,
            content: choice.delta.content,
          });
        }

        if (choice.delta?.tool_calls && choice.delta.tool_calls.length > 0) {
          const toolCall = choice.delta.tool_calls[0];
          if (result.functionCall)
            result.functionCall.arguments += toolCall.function.arguments;
          else {
            result.functionCall = {
              name: toolCall.function.name,
              arguments: toolCall.function.arguments,
            };
        if (content) {
          completion.content += content;
          eventHandler?.("reportStreamEvent", {
            type: "textResponseChunk",
            uuid: msgUUID,
            content,
          });
        }

          eventHandler?.("reportStreamEvent", {
            uuid: `${msgUUID}:tool_call_invocation`,
            type: "toolCallInvocation",
            content: `Assembling Tool Call: ${result.functionCall.name}(${result.functionCall.arguments})`,
          });
        if (tool_calls) {
          const toolCall = tool_calls[0];
          completion.functionCall = {
            name: this.prefixToolCall(toolCall.function.name, "strip"),
            call_id: toolCall.id,
            arguments: toolCall.function.arguments,
          };
          eventHandler?.("reportStreamEvent", {
            type: "toolCallInvocation",
            uuid: `${msgUUID}:tool_call_invocation`,
            content: `Assembling Tool Call: ${completion.functionCall.name}(${completion.functionCall.arguments})`,
          });
        }
      }
    }
      // If there are arguments, parse them as json so that the tools can use them
      if (!!result.functionCall?.arguments)
        result.functionCall.arguments = safeJsonParse(
          result.functionCall.arguments,
          {}
        );
      return result;
      if (completion.functionCall) {
        completion.functionCall.arguments = safeJsonParse(
          completion.functionCall.arguments,
          {}
        );
        return {
          textResponse: completion.content,
          functionCall: {
            id: completion.functionCall.call_id,
            name: completion.functionCall.name,
            arguments: completion.functionCall.arguments,
          },
          cost: this.getCost(),
        };
      }

      return {
        textResponse: completion.content,
        functionCall: null,
        cost: this.getCost(),
      };
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) throw error;
      if (
        error instanceof OpenAI.RateLimitError ||
        error instanceof OpenAI.InternalServerError ||
        error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
      ) {
        throw new RetryError(error.message);
      }

      throw error;
    }
  }
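A hedged usage sketch of the new stream() path (the surrounding variables and handler are assumptions, not part of the commit): the provider reports chunks through the event handler and resolves with either plain text or a functionCall whose name has already had the gtc__ prefix stripped.

```js
const result = await provider.stream(messages, functions, (event, payload) => {
  // event is always "reportStreamEvent" here; payload.type distinguishes the chunk kind
  if (payload.type === "textResponseChunk") process.stdout.write(payload.content);
  if (payload.type === "toolCallInvocation") console.log(`\n${payload.content}`);
});

if (result.functionCall) {
  // arguments were already parsed into an object via safeJsonParse inside stream()
  console.log("Run tool:", result.functionCall.name, result.functionCall.arguments);
} else {
  console.log("Final answer:", result.textResponse);
}
```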
  /**
   * Create a completion based on the received messages.
   *
   * TODO: see stream() - tool_calls are now supported, so we can use that instead of Untooled
   *
   * @param messages A list of messages to send to the API.
   * @param messages A list of messages to send to the Gemini API.
   * @param functions
   * @returns The completion.
   */
  async complete(messages, functions = []) {
    if (!this.supportsToolCalling)
      throw new Error(`Gemini: ${this.model} does not support tool calling.`);
    this.providerLog("Gemini.complete - will process this chat completion.");
    try {
      let completion;
      const response = await this.client.chat.completions.create({
        model: this.model,
        stream: false,
        messages: this.#formatMessages(messages),
        ...(Array.isArray(functions) && functions?.length > 0
          ? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
          : {}),
      });
      if (functions.length > 0) {
        const { toolCall, text } = await this.functionCall(
          this.cleanMsgs(this.formatMessages(messages)),
          functions,
          this.#handleFunctionCallChat.bind(this)
        );

        if (toolCall !== null) {
          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
          return {
            result: null,
            functionCall: {
              name: toolCall.name,
              arguments: toolCall.arguments,
            },
            cost: 0,
          };
        }
        completion = { content: text };
      /** @type {OpenAI.OpenAI.Chat.ChatCompletionMessage} */
      const completion = response.choices[0].message;
      const cost = this.getCost(response.usage);
      if (completion?.tool_calls?.length > 0) {
        const toolCall = completion.tool_calls[0];
        let functionArgs = safeJsonParse(toolCall.function.arguments, {});
        return {
          textResponse: null,
          functionCall: {
            name: this.prefixToolCall(toolCall.function.name, "strip"),
            arguments: functionArgs,
            id: toolCall.id,
          },
          cost,
        };
      }
      if (!completion?.content) {
        this.providerLog(
          "Will assume chat completion without tool call inputs."
        );
        const response = await this.client.chat.completions.create({
          model: this.model,
          messages: this.cleanMsgs(this.formatMessages(messages)),
        });
        completion = response.choices[0].message;
      }

      // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
      // from calling the exact same function over and over in a loop within a single chat exchange
      // _but_ we should enable it to call previously used tools in a new chat interaction.
      this.deduplicator.reset("runs");
      return {
        textResponse: completion.content,
        cost: 0,
        cost,
      };
    } catch (error) {
      throw new APIError(
        error?.message
          ? `${this.className} encountered an error while executing the request: ${error.message}`
          : "There was an error with the Gemini provider executing the request"
      );
      // If invalid Auth error we need to abort because no amount of waiting
      // will make auth better.
      if (error instanceof OpenAI.AuthenticationError) throw error;

      if (
        error instanceof OpenAI.RateLimitError ||
        error instanceof OpenAI.InternalServerError ||
        error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
      ) {
        throw new RetryError(error.message);
      }

      throw error;
    }
  }