Patch AzureOpenAI tool calling from function to tool (#4922)
This commit is contained in:
parent
246969eede
commit
9584ebcd2c
@ -2,9 +2,13 @@ const { OpenAI } = require("openai");
|
|||||||
const { AzureOpenAiLLM } = require("../../../AiProviders/azureOpenAi");
|
const { AzureOpenAiLLM } = require("../../../AiProviders/azureOpenAi");
|
||||||
const Provider = require("./ai-provider.js");
|
const Provider = require("./ai-provider.js");
|
||||||
const { RetryError } = require("../error.js");
|
const { RetryError } = require("../error.js");
|
||||||
|
const { v4 } = require("uuid");
|
||||||
|
const { safeJsonParse } = require("../../../http");
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The agent provider for the Azure OpenAI API.
|
* The agent provider for the Azure OpenAI API.
|
||||||
|
* Uses the tool calling format (not legacy function calling) for compatibility
|
||||||
|
* with newer Azure OpenAI models.
|
||||||
*/
|
*/
|
||||||
class AzureOpenAiProvider extends Provider {
|
class AzureOpenAiProvider extends Provider {
|
||||||
model;
|
model;
|
||||||
@ -23,8 +27,215 @@ class AzureOpenAiProvider extends Provider {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert legacy function definitions to the tools format.
|
||||||
|
* @param {Array} functions - Legacy function definitions
|
||||||
|
* @returns {Array} Tools in the new format
|
||||||
|
*/
|
||||||
|
#formatFunctionsToTools(functions) {
|
||||||
|
if (!Array.isArray(functions) || functions.length === 0) return [];
|
||||||
|
return functions.map((func) => ({
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: func.name,
|
||||||
|
description: func.description,
|
||||||
|
parameters: func.parameters,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Format messages to use tool calling format instead of legacy function format.
|
||||||
|
* Converts role: "function" messages to role: "tool" messages.
|
||||||
|
* @param {Array} messages - Messages array that may contain legacy function messages
|
||||||
|
* @returns {Array} Messages formatted for tool calling
|
||||||
|
*/
|
||||||
|
#formatMessagesForTools(messages) {
|
||||||
|
const formattedMessages = [];
|
||||||
|
|
||||||
|
for (const message of messages) {
|
||||||
|
if (message.role === "function") {
|
||||||
|
// Convert legacy function result to tool result format
|
||||||
|
// We need the tool_call_id from the originalFunctionCall
|
||||||
|
if (message.originalFunctionCall?.id) {
|
||||||
|
// First, add the assistant message with the tool_call if not already present
|
||||||
|
// Check if previous message already has this tool call
|
||||||
|
const prevMsg = formattedMessages[formattedMessages.length - 1];
|
||||||
|
if (!prevMsg || prevMsg.role !== "assistant" || !prevMsg.tool_calls) {
|
||||||
|
formattedMessages.push({
|
||||||
|
role: "assistant",
|
||||||
|
content: null,
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
id: message.originalFunctionCall.id,
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: message.originalFunctionCall.name,
|
||||||
|
arguments:
|
||||||
|
typeof message.originalFunctionCall.arguments === "string"
|
||||||
|
? message.originalFunctionCall.arguments
|
||||||
|
: JSON.stringify(
|
||||||
|
message.originalFunctionCall.arguments
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// Add the tool result
|
||||||
|
formattedMessages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: message.originalFunctionCall.id,
|
||||||
|
content:
|
||||||
|
typeof message.content === "string"
|
||||||
|
? message.content
|
||||||
|
: JSON.stringify(message.content),
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// Fallback: generate a tool_call_id if not present
|
||||||
|
const toolCallId = `call_${v4()}`;
|
||||||
|
formattedMessages.push({
|
||||||
|
role: "assistant",
|
||||||
|
content: null,
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
id: toolCallId,
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: message.name,
|
||||||
|
arguments: "{}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
formattedMessages.push({
|
||||||
|
role: "tool",
|
||||||
|
tool_call_id: toolCallId,
|
||||||
|
content:
|
||||||
|
typeof message.content === "string"
|
||||||
|
? message.content
|
||||||
|
: JSON.stringify(message.content),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
formattedMessages.push(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return formattedMessages;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stream a chat completion from the LLM with tool calling.
|
||||||
|
* Uses the tool calling format instead of legacy function calling.
|
||||||
|
*
|
||||||
|
* @param {any[]} messages - The messages to send to the LLM.
|
||||||
|
* @param {any[]} functions - The functions to use in the LLM.
|
||||||
|
* @param {function} eventHandler - The event handler to use to report stream events.
|
||||||
|
* @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion.
|
||||||
|
*/
|
||||||
|
async stream(messages, functions = [], eventHandler = null) {
|
||||||
|
this.providerLog("Provider.stream - will process this chat completion.");
|
||||||
|
const msgUUID = v4();
|
||||||
|
|
||||||
|
try {
|
||||||
|
const formattedMessages = this.#formatMessagesForTools(messages);
|
||||||
|
const tools = this.#formatFunctionsToTools(functions);
|
||||||
|
|
||||||
|
const stream = await this.client.chat.completions.create({
|
||||||
|
model: this.model,
|
||||||
|
stream: true,
|
||||||
|
messages: formattedMessages,
|
||||||
|
...(tools.length > 0 ? { tools } : {}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = {
|
||||||
|
functionCall: null,
|
||||||
|
textResponse: "",
|
||||||
|
};
|
||||||
|
|
||||||
|
// For accumulating tool calls during streaming
|
||||||
|
let currentToolCall = null;
|
||||||
|
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
if (!chunk?.choices?.[0]) continue;
|
||||||
|
const choice = chunk.choices[0];
|
||||||
|
|
||||||
|
if (choice.delta?.content) {
|
||||||
|
result.textResponse += choice.delta.content;
|
||||||
|
eventHandler?.("reportStreamEvent", {
|
||||||
|
type: "textResponseChunk",
|
||||||
|
uuid: msgUUID,
|
||||||
|
content: choice.delta.content,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool calls (new format)
|
||||||
|
if (choice.delta?.tool_calls) {
|
||||||
|
for (const toolCall of choice.delta.tool_calls) {
|
||||||
|
if (toolCall.id) {
|
||||||
|
// New tool call starting
|
||||||
|
currentToolCall = {
|
||||||
|
id: toolCall.id,
|
||||||
|
name: toolCall.function?.name || "",
|
||||||
|
arguments: toolCall.function?.arguments || "",
|
||||||
|
};
|
||||||
|
} else if (currentToolCall) {
|
||||||
|
// Continuation of existing tool call
|
||||||
|
if (toolCall.function?.name) {
|
||||||
|
currentToolCall.name += toolCall.function.name;
|
||||||
|
}
|
||||||
|
if (toolCall.function?.arguments) {
|
||||||
|
currentToolCall.arguments += toolCall.function.arguments;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (currentToolCall) {
|
||||||
|
eventHandler?.("reportStreamEvent", {
|
||||||
|
uuid: `${msgUUID}:tool_call_invocation`,
|
||||||
|
type: "toolCallInvocation",
|
||||||
|
content: `Assembling Tool Call: ${currentToolCall.name}(${currentToolCall.arguments})`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the function call result if we have a tool call
|
||||||
|
if (currentToolCall) {
|
||||||
|
result.functionCall = {
|
||||||
|
id: currentToolCall.id,
|
||||||
|
name: currentToolCall.name,
|
||||||
|
arguments: safeJsonParse(currentToolCall.arguments, {}),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
textResponse: result.textResponse,
|
||||||
|
functionCall: result.functionCall,
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
console.error(error.message, error);
|
||||||
|
|
||||||
|
// If invalid Auth error we need to abort because no amount of waiting
|
||||||
|
// will make auth better.
|
||||||
|
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||||
|
|
||||||
|
if (
|
||||||
|
error instanceof OpenAI.RateLimitError ||
|
||||||
|
error instanceof OpenAI.InternalServerError ||
|
||||||
|
error instanceof OpenAI.APIError
|
||||||
|
) {
|
||||||
|
throw new RetryError(error.message);
|
||||||
|
}
|
||||||
|
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a completion based on the received messages.
|
* Create a completion based on the received messages.
|
||||||
|
* Uses the tool calling format instead of legacy function calling.
|
||||||
*
|
*
|
||||||
* @param messages A list of messages to send to the OpenAI API.
|
* @param messages A list of messages to send to the OpenAI API.
|
||||||
* @param functions
|
* @param functions
|
||||||
@ -32,45 +243,53 @@ class AzureOpenAiProvider extends Provider {
|
|||||||
*/
|
*/
|
||||||
async complete(messages, functions = []) {
|
async complete(messages, functions = []) {
|
||||||
try {
|
try {
|
||||||
|
const formattedMessages = this.#formatMessagesForTools(messages);
|
||||||
|
const tools = this.#formatFunctionsToTools(functions);
|
||||||
|
|
||||||
const response = await this.client.chat.completions.create({
|
const response = await this.client.chat.completions.create({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
stream: false,
|
stream: false,
|
||||||
messages,
|
messages: formattedMessages,
|
||||||
...(Array.isArray(functions) && functions?.length > 0
|
...(tools.length > 0 ? { tools } : {}),
|
||||||
? { functions }
|
|
||||||
: {}),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Right now, we only support one completion,
|
// Right now, we only support one completion,
|
||||||
// so we just take the first one in the list
|
// so we just take the first one in the list
|
||||||
const completion = response.choices[0].message;
|
const completion = response.choices[0].message;
|
||||||
const cost = this.getCost(response.usage);
|
const cost = this.getCost(response.usage);
|
||||||
// treat function calls
|
|
||||||
if (completion.function_call) {
|
// Handle tool calls (new format)
|
||||||
|
if (completion.tool_calls && completion.tool_calls.length > 0) {
|
||||||
|
const toolCall = completion.tool_calls[0];
|
||||||
let functionArgs = {};
|
let functionArgs = {};
|
||||||
try {
|
try {
|
||||||
functionArgs = JSON.parse(completion.function_call.arguments);
|
functionArgs = JSON.parse(toolCall.function.arguments);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// call the complete function again in case it gets a json error
|
// Call the complete function again in case of JSON error
|
||||||
|
const toolCallId = toolCall.id;
|
||||||
return this.complete(
|
return this.complete(
|
||||||
[
|
[
|
||||||
...messages,
|
...messages,
|
||||||
{
|
{
|
||||||
role: "function",
|
role: "function",
|
||||||
name: completion.function_call.name,
|
name: toolCall.function.name,
|
||||||
function_call: completion.function_call,
|
|
||||||
content: error?.message,
|
content: error?.message,
|
||||||
|
originalFunctionCall: {
|
||||||
|
id: toolCallId,
|
||||||
|
name: toolCall.function.name,
|
||||||
|
arguments: toolCall.function.arguments,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
functions
|
functions
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// console.log(completion, { functionArgs })
|
|
||||||
return {
|
return {
|
||||||
textResponse: null,
|
textResponse: null,
|
||||||
functionCall: {
|
functionCall: {
|
||||||
name: completion.function_call.name,
|
id: toolCall.id,
|
||||||
|
name: toolCall.function.name,
|
||||||
arguments: functionArgs,
|
arguments: functionArgs,
|
||||||
},
|
},
|
||||||
cost,
|
cost,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user