parent
f7b90571be
commit
8f0f9df4fc
@ -665,6 +665,7 @@ ${this.getHistory({ to: route.to })
|
||||
name,
|
||||
role: "function",
|
||||
content: `Function "${name}" not found. Try again.`,
|
||||
originalFunctionCall: completionStream.functionCall,
|
||||
},
|
||||
],
|
||||
functions,
|
||||
@ -712,6 +713,7 @@ ${this.getHistory({ to: route.to })
|
||||
name,
|
||||
role: "function",
|
||||
content: result,
|
||||
originalFunctionCall: completionStream.functionCall,
|
||||
},
|
||||
],
|
||||
functions,
|
||||
@ -757,6 +759,7 @@ ${this.getHistory({ to: route.to })
|
||||
name,
|
||||
role: "function",
|
||||
content: `Function "${name}" not found. Try again.`,
|
||||
originalFunctionCall: completion.functionCall,
|
||||
},
|
||||
],
|
||||
functions,
|
||||
@ -804,6 +807,7 @@ ${this.getHistory({ to: route.to })
|
||||
name,
|
||||
role: "function",
|
||||
content: result,
|
||||
originalFunctionCall: completion.functionCall,
|
||||
},
|
||||
],
|
||||
functions,
|
||||
|
||||
@ -134,7 +134,7 @@ class AnthropicProvider extends Provider {
|
||||
|
||||
/**
|
||||
* Stream a chat completion from the LLM with tool calling
|
||||
* Note: This using the OpenAI API format and may need to be adapted for other providers.
|
||||
* Note: This using the Anthropic API SDK and its implementation is specific to Anthropic.
|
||||
*
|
||||
* @param {any[]} messages - The messages to send to the LLM.
|
||||
* @param {any[]} functions - The functions to use in the LLM.
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
const OpenAI = require("openai");
|
||||
const Provider = require("./ai-provider.js");
|
||||
const { RetryError } = require("../error.js");
|
||||
const { v4 } = require("uuid");
|
||||
const { safeJsonParse } = require("../../../http");
|
||||
|
||||
/**
|
||||
* The agent provider for the OpenAI API.
|
||||
@ -28,6 +30,194 @@ class OpenAIProvider extends Provider {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format the messages to the OpenAI API Responses format.
|
||||
* - If the message is our internal `function` type, then we need to map it to a function call + output format
|
||||
* - Otherwise, map it to the input text format for user, system, and assistant messages
|
||||
*
|
||||
* @param {any[]} messages - The messages to format.
|
||||
* @returns {OpenAI.OpenAI.Responses.ResponseInput[]} The formatted messages.
|
||||
*/
|
||||
#formatToResponsesInput(messages) {
|
||||
let formattedMessages = [];
|
||||
messages.forEach((message) => {
|
||||
if (message.role === "function") {
|
||||
// If the message does not have an originalFunctionCall we cannot
|
||||
// map it to a function call id and OpenAI will throw an error.
|
||||
// so if this does not carry over - log and skip
|
||||
if (!message.hasOwnProperty("originalFunctionCall")) {
|
||||
this.providerLog(
|
||||
"[OpenAI.#formatToResponsesInput]: message did not pass back the originalFunctionCall. We need this to map the function call to the correct id.",
|
||||
{ message: JSON.stringify(message, null, 2) }
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
formattedMessages.push(
|
||||
{
|
||||
type: "function_call",
|
||||
name: message.originalFunctionCall.name,
|
||||
call_id: message.originalFunctionCall.id,
|
||||
arguments: JSON.stringify(message.originalFunctionCall.arguments),
|
||||
},
|
||||
{
|
||||
type: "function_call_output",
|
||||
call_id: message.originalFunctionCall.id,
|
||||
output: message.content,
|
||||
}
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
formattedMessages.push({
|
||||
role: message.role,
|
||||
content: [
|
||||
{
|
||||
type: message.role === "assistant" ? "output_text" : "input_text",
|
||||
text: message.content,
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
return formattedMessages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format the functions to the OpenAI API Responses format.
|
||||
*
|
||||
* @param {any[]} functions - The functions to format.
|
||||
* @returns {{
|
||||
* type: "function",
|
||||
* name: string,
|
||||
* description: string,
|
||||
* parameters: object,
|
||||
* strict: boolean,
|
||||
* }[]} The formatted functions.
|
||||
*/
|
||||
#formatFunctions(functions) {
|
||||
return functions.map((func) => ({
|
||||
type: "function",
|
||||
name: func.name,
|
||||
description: func.description,
|
||||
parameters: func.parameters,
|
||||
strict: false,
|
||||
}));
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream a chat completion from the LLM with tool calling
|
||||
* Note: This using the OpenAI API Responses SDK and its implementation is specific to OpenAI models.
|
||||
* Do not re-use this code for providers that do not EXACTLY implement the OpenAI API Responses SDK.
|
||||
*
|
||||
* @param {any[]} messages - The messages to send to the LLM.
|
||||
* @param {any[]} functions - The functions to use in the LLM.
|
||||
* @param {function} eventHandler - The event handler to use to report stream events.
|
||||
* @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion.
|
||||
*/
|
||||
async stream(messages, functions = [], eventHandler = null) {
|
||||
this.providerLog("OpenAI.stream - will process this chat completion.");
|
||||
try {
|
||||
const msgUUID = v4();
|
||||
|
||||
/** @type {OpenAI.OpenAI.Responses.Response} */
|
||||
const response = await this.client.responses.create({
|
||||
model: this.model,
|
||||
input: this.#formatToResponsesInput(messages),
|
||||
stream: true,
|
||||
store: false,
|
||||
parallel_tool_calls: false,
|
||||
...(Array.isArray(functions) && functions?.length > 0
|
||||
? { tools: this.#formatFunctions(functions) }
|
||||
: {}),
|
||||
});
|
||||
|
||||
const completion = {
|
||||
content: "",
|
||||
/** @type {null|{name: string, call_id: string, arguments: string|object}} */
|
||||
functionCall: null,
|
||||
};
|
||||
|
||||
for await (const streamEvent of response) {
|
||||
/** @type {OpenAI.OpenAI.Responses.ResponseStreamEvent} */
|
||||
const chunk = streamEvent;
|
||||
|
||||
if (chunk.type === "response.output_text.delta") {
|
||||
completion.content += chunk.delta;
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "textResponseChunk",
|
||||
uuid: msgUUID,
|
||||
content: chunk.delta,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (
|
||||
chunk.type === "response.output_item.added" &&
|
||||
chunk.item.type === "function_call"
|
||||
) {
|
||||
completion.functionCall = {
|
||||
name: chunk.item.name,
|
||||
call_id: chunk.item.call_id,
|
||||
arguments: chunk.item.arguments,
|
||||
};
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "toolCallInvocation",
|
||||
uuid: `${msgUUID}:tool_call_invocation`,
|
||||
content: `Assembling Tool Call: ${completion.functionCall.name}(${completion.functionCall.arguments})`,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (chunk.type === "response.function_call_arguments.delta") {
|
||||
completion.functionCall.arguments += chunk.delta;
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "toolCallInvocation",
|
||||
uuid: `${msgUUID}:tool_call_invocation`,
|
||||
content: `Assembling Tool Call: ${completion.functionCall.name}(${completion.functionCall.arguments})`,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (completion.functionCall) {
|
||||
completion.functionCall.arguments = safeJsonParse(
|
||||
completion.functionCall.arguments,
|
||||
{}
|
||||
);
|
||||
return {
|
||||
textResponse: completion.content,
|
||||
functionCall: {
|
||||
id: completion.functionCall.call_id,
|
||||
name: completion.functionCall.name,
|
||||
arguments: completion.functionCall.arguments,
|
||||
},
|
||||
cost: this.getCost(),
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
textResponse: completion.content,
|
||||
functionCall: null,
|
||||
cost: this.getCost(),
|
||||
};
|
||||
} catch (error) {
|
||||
// If invalid Auth error we need to abort because no amount of waiting
|
||||
// will make auth better.
|
||||
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||
|
||||
if (
|
||||
error instanceof OpenAI.RateLimitError ||
|
||||
error instanceof OpenAI.InternalServerError ||
|
||||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
|
||||
) {
|
||||
throw new RetryError(error.message);
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a completion based on the received messages.
|
||||
*
|
||||
@ -36,54 +226,67 @@ class OpenAIProvider extends Provider {
|
||||
* @returns The completion.
|
||||
*/
|
||||
async complete(messages, functions = []) {
|
||||
this.providerLog("OpenAI.complete - will process this chat completion.");
|
||||
try {
|
||||
const response = await this.client.chat.completions.create({
|
||||
const completion = {
|
||||
content: "",
|
||||
functionCall: null,
|
||||
};
|
||||
|
||||
/** @type {OpenAI.OpenAI.Responses.Response} */
|
||||
const response = await this.client.responses.create({
|
||||
model: this.model,
|
||||
stream: false,
|
||||
messages,
|
||||
store: false,
|
||||
parallel_tool_calls: false,
|
||||
input: this.#formatToResponsesInput(messages),
|
||||
...(Array.isArray(functions) && functions?.length > 0
|
||||
? { functions }
|
||||
? { tools: this.#formatFunctions(functions) }
|
||||
: {}),
|
||||
});
|
||||
|
||||
// Right now, we only support one completion,
|
||||
// so we just take the first one in the list
|
||||
const completion = response.choices[0].message;
|
||||
const cost = this.getCost(response.usage);
|
||||
// treat function calls
|
||||
if (completion.function_call) {
|
||||
let functionArgs = {};
|
||||
try {
|
||||
functionArgs = JSON.parse(completion.function_call.arguments);
|
||||
} catch (error) {
|
||||
// call the complete function again in case it gets a json error
|
||||
return this.complete(
|
||||
[
|
||||
...messages,
|
||||
{
|
||||
role: "function",
|
||||
name: completion.function_call.name,
|
||||
function_call: completion.function_call,
|
||||
content: error?.message,
|
||||
},
|
||||
],
|
||||
functions
|
||||
);
|
||||
for (const outputBlock of response.output) {
|
||||
// Grab intermediate text output if it exists
|
||||
// If no tools are used, this will be returned to the aibitat handler
|
||||
// Otherwise, this text will never be shown to the user
|
||||
if (outputBlock.type === "message") {
|
||||
if (outputBlock.content[0]?.type === "output_text") {
|
||||
completion.content = outputBlock.content[0].text;
|
||||
}
|
||||
}
|
||||
|
||||
// Grab function call output if it exists
|
||||
if (outputBlock.type === "function_call") {
|
||||
completion.functionCall = {
|
||||
name: outputBlock.name,
|
||||
call_id: outputBlock.call_id,
|
||||
arguments: outputBlock.arguments,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (completion.functionCall) {
|
||||
completion.functionCall.arguments = safeJsonParse(
|
||||
completion.functionCall.arguments,
|
||||
{}
|
||||
);
|
||||
return {
|
||||
textResponse: null,
|
||||
textResponse: completion.content,
|
||||
functionCall: {
|
||||
name: completion.function_call.name,
|
||||
arguments: functionArgs,
|
||||
// For OpenAI, the id is the call_id and we need it in followup requests
|
||||
// so we can match the function call output to its invocation in the message history.
|
||||
id: completion.functionCall.call_id,
|
||||
name: completion.functionCall.name,
|
||||
arguments: completion.functionCall.arguments,
|
||||
},
|
||||
cost,
|
||||
cost: this.getCost(),
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
textResponse: completion.content,
|
||||
cost,
|
||||
functionCall: null,
|
||||
cost: this.getCost(),
|
||||
};
|
||||
} catch (error) {
|
||||
// If invalid Auth error we need to abort because no amount of waiting
|
||||
|
||||
Loading…
Reference in New Issue
Block a user