Support Agent stream metric reporting (#5197)
This commit is contained in:
parent
f1439d7fcb
commit
15a84d5121
@ -186,14 +186,17 @@ const HistoricalMessage = ({
|
|||||||
export default memo(
|
export default memo(
|
||||||
HistoricalMessage,
|
HistoricalMessage,
|
||||||
// Skip re-render the historical message:
|
// Skip re-render the historical message:
|
||||||
// if the content is the exact same AND (not streaming)
|
// - if the content is the exact same
|
||||||
// the lastMessage status is the same (regen icon)
|
// - AND (not streaming)
|
||||||
// and the chatID matches between renders. (feedback icons)
|
// - the lastMessage status is the same (regen icon)
|
||||||
|
// - the chatID matches between renders. (feedback icons)
|
||||||
|
// - the metrics are the same (metrics are updated in real time)
|
||||||
(prevProps, nextProps) => {
|
(prevProps, nextProps) => {
|
||||||
return (
|
return (
|
||||||
prevProps.message === nextProps.message &&
|
prevProps.message === nextProps.message &&
|
||||||
prevProps.isLastMessage === nextProps.isLastMessage &&
|
prevProps.isLastMessage === nextProps.isLastMessage &&
|
||||||
prevProps.chatId === nextProps.chatId
|
prevProps.chatId === nextProps.chatId &&
|
||||||
|
JSON.stringify(prevProps.metrics) === JSON.stringify(nextProps.metrics)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|||||||
@ -41,6 +41,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
|||||||
error: null,
|
error: null,
|
||||||
animate: false,
|
animate: false,
|
||||||
pending: false,
|
pending: false,
|
||||||
|
metrics: {},
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
});
|
});
|
||||||
@ -74,6 +75,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
|||||||
error: null,
|
error: null,
|
||||||
animate: false,
|
animate: false,
|
||||||
pending: false,
|
pending: false,
|
||||||
|
metrics: {},
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
@ -95,6 +97,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
|||||||
error: null,
|
error: null,
|
||||||
animate: false,
|
animate: false,
|
||||||
pending: false,
|
pending: false,
|
||||||
|
metrics: {},
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
@ -111,6 +114,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
|||||||
error: null,
|
error: null,
|
||||||
animate: false,
|
animate: false,
|
||||||
pending: false,
|
pending: false,
|
||||||
|
metrics: {},
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
} else {
|
} else {
|
||||||
@ -127,6 +131,13 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
|||||||
]; // If the message is known, replace it with the new content
|
]; // If the message is known, replace it with the new content
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === "usageMetrics") {
|
||||||
|
if (!data.content.metrics) return prev;
|
||||||
|
return prev.map((msg) =>
|
||||||
|
msg.uuid === uuid ? { ...msg, metrics: data.content.metrics } : msg
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === "textResponseChunk") {
|
if (type === "textResponseChunk") {
|
||||||
return prev
|
return prev
|
||||||
.map((msg) =>
|
.map((msg) =>
|
||||||
@ -172,6 +183,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
|||||||
error: null,
|
error: null,
|
||||||
animate: false,
|
animate: false,
|
||||||
pending: false,
|
pending: false,
|
||||||
|
metrics: data.metrics || {},
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
});
|
});
|
||||||
@ -190,6 +202,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
|||||||
error: data.content,
|
error: data.content,
|
||||||
animate: false,
|
animate: false,
|
||||||
pending: false,
|
pending: false,
|
||||||
|
metrics: {},
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
});
|
});
|
||||||
@ -208,6 +221,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
|||||||
error: null,
|
error: null,
|
||||||
animate: data?.animate || false,
|
animate: data?.animate || false,
|
||||||
pending: false,
|
pending: false,
|
||||||
|
metrics: data.metrics || {},
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
});
|
});
|
||||||
|
|||||||
@ -626,6 +626,8 @@ ${this.getHistory({ to: route.to })
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store the active provider so plugins can access usage metrics
|
||||||
|
this.provider = provider;
|
||||||
this.newMessage({ ...route, content });
|
this.newMessage({ ...route, content });
|
||||||
return content;
|
return content;
|
||||||
}
|
}
|
||||||
@ -652,7 +654,7 @@ ${this.getHistory({ to: route.to })
|
|||||||
this?.socket?.send(type, data);
|
this?.socket?.send(type, data);
|
||||||
};
|
};
|
||||||
|
|
||||||
/** @type {{ functionCall: { name: string, arguments: string }, textResponse: string }} */
|
/** @type {{ functionCall: { name: string, arguments: string }, textResponse: string, uuid: string }} */
|
||||||
const completionStream = await provider.stream(
|
const completionStream = await provider.stream(
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
@ -669,6 +671,11 @@ ${this.getHistory({ to: route.to })
|
|||||||
);
|
);
|
||||||
|
|
||||||
const finalStream = await provider.stream(messages, [], eventHandler);
|
const finalStream = await provider.stream(messages, [], eventHandler);
|
||||||
|
eventHandler?.("reportStreamEvent", {
|
||||||
|
type: "usageMetrics",
|
||||||
|
uuid: finalStream?.uuid || v4(),
|
||||||
|
metrics: provider.getUsage(),
|
||||||
|
});
|
||||||
const finalResponse =
|
const finalResponse =
|
||||||
finalStream?.textResponse ||
|
finalStream?.textResponse ||
|
||||||
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run.";
|
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run.";
|
||||||
@ -726,11 +733,17 @@ ${this.getHistory({ to: route.to })
|
|||||||
this.handlerProps?.log?.(
|
this.handlerProps?.log?.(
|
||||||
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
|
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
|
||||||
);
|
);
|
||||||
|
const directOutputUUID = completionStream?.uuid || v4();
|
||||||
eventHandler?.("reportStreamEvent", {
|
eventHandler?.("reportStreamEvent", {
|
||||||
type: "fullTextResponse",
|
type: "fullTextResponse",
|
||||||
uuid: v4(),
|
uuid: directOutputUUID,
|
||||||
content: result,
|
content: result,
|
||||||
});
|
});
|
||||||
|
eventHandler?.("reportStreamEvent", {
|
||||||
|
type: "usageMetrics",
|
||||||
|
uuid: directOutputUUID,
|
||||||
|
metrics: provider.getUsage(),
|
||||||
|
});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -751,6 +764,11 @@ ${this.getHistory({ to: route.to })
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
eventHandler?.("reportStreamEvent", {
|
||||||
|
type: "usageMetrics",
|
||||||
|
uuid: completionStream?.uuid || v4(),
|
||||||
|
metrics: provider.getUsage(),
|
||||||
|
});
|
||||||
return completionStream?.textResponse;
|
return completionStream?.textResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -762,6 +780,8 @@ ${this.getHistory({ to: route.to })
|
|||||||
* @param messages
|
* @param messages
|
||||||
* @param functions
|
* @param functions
|
||||||
* @param byAgent
|
* @param byAgent
|
||||||
|
* @param depth
|
||||||
|
* @param msgUUID - The message UUID to use for event correlation (created at depth=0)
|
||||||
*
|
*
|
||||||
* @returns {Promise<string>}
|
* @returns {Promise<string>}
|
||||||
*/
|
*/
|
||||||
@ -770,8 +790,15 @@ ${this.getHistory({ to: route.to })
|
|||||||
messages = [],
|
messages = [],
|
||||||
functions = [],
|
functions = [],
|
||||||
byAgent = null,
|
byAgent = null,
|
||||||
depth = 0
|
depth = 0,
|
||||||
|
msgUUID = null
|
||||||
) {
|
) {
|
||||||
|
// Create a stable UUID at the start of execution for event correlation
|
||||||
|
if (!msgUUID) msgUUID = v4();
|
||||||
|
const eventHandler = (type, data) => {
|
||||||
|
this?.socket?.send(type, data);
|
||||||
|
};
|
||||||
|
|
||||||
// get the chat completion
|
// get the chat completion
|
||||||
const completion = await provider.complete(messages, functions);
|
const completion = await provider.complete(messages, functions);
|
||||||
|
|
||||||
@ -785,6 +812,11 @@ ${this.getHistory({ to: route.to })
|
|||||||
);
|
);
|
||||||
|
|
||||||
const finalCompletion = await provider.complete(messages, []);
|
const finalCompletion = await provider.complete(messages, []);
|
||||||
|
eventHandler?.("reportStreamEvent", {
|
||||||
|
type: "usageMetrics",
|
||||||
|
uuid: msgUUID,
|
||||||
|
metrics: provider.getUsage(),
|
||||||
|
});
|
||||||
return (
|
return (
|
||||||
finalCompletion?.textResponse ||
|
finalCompletion?.textResponse ||
|
||||||
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run."
|
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run."
|
||||||
@ -808,7 +840,8 @@ ${this.getHistory({ to: route.to })
|
|||||||
],
|
],
|
||||||
functions,
|
functions,
|
||||||
byAgent,
|
byAgent,
|
||||||
depth + 1
|
depth + 1,
|
||||||
|
msgUUID
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -836,6 +869,11 @@ ${this.getHistory({ to: route.to })
|
|||||||
this.handlerProps?.log?.(
|
this.handlerProps?.log?.(
|
||||||
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
|
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
|
||||||
);
|
);
|
||||||
|
eventHandler?.("reportStreamEvent", {
|
||||||
|
type: "usageMetrics",
|
||||||
|
uuid: msgUUID,
|
||||||
|
metrics: provider.getUsage(),
|
||||||
|
});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -852,10 +890,16 @@ ${this.getHistory({ to: route.to })
|
|||||||
],
|
],
|
||||||
functions,
|
functions,
|
||||||
byAgent,
|
byAgent,
|
||||||
depth + 1
|
depth + 1,
|
||||||
|
msgUUID
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
eventHandler?.("reportStreamEvent", {
|
||||||
|
type: "usageMetrics",
|
||||||
|
uuid: msgUUID,
|
||||||
|
metrics: provider.getUsage(),
|
||||||
|
});
|
||||||
return completion?.textResponse;
|
return completion?.textResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -43,6 +43,7 @@ const chatHistory = {
|
|||||||
},
|
},
|
||||||
_store: async function (aibitat, { prompt, response } = {}) {
|
_store: async function (aibitat, { prompt, response } = {}) {
|
||||||
const invocation = aibitat.handlerProps.invocation;
|
const invocation = aibitat.handlerProps.invocation;
|
||||||
|
const metrics = aibitat.provider?.getUsage?.() ?? {};
|
||||||
await WorkspaceChats.new({
|
await WorkspaceChats.new({
|
||||||
workspaceId: Number(invocation.workspace_id),
|
workspaceId: Number(invocation.workspace_id),
|
||||||
prompt,
|
prompt,
|
||||||
@ -50,6 +51,7 @@ const chatHistory = {
|
|||||||
text: response,
|
text: response,
|
||||||
sources: [],
|
sources: [],
|
||||||
type: "chat",
|
type: "chat",
|
||||||
|
metrics,
|
||||||
},
|
},
|
||||||
user: { id: invocation?.user_id || null },
|
user: { id: invocation?.user_id || null },
|
||||||
threadId: invocation?.thread_id || null,
|
threadId: invocation?.thread_id || null,
|
||||||
@ -60,6 +62,7 @@ const chatHistory = {
|
|||||||
{ prompt, response, options = {} } = {}
|
{ prompt, response, options = {} } = {}
|
||||||
) {
|
) {
|
||||||
const invocation = aibitat.handlerProps.invocation;
|
const invocation = aibitat.handlerProps.invocation;
|
||||||
|
const metrics = aibitat.provider?.getUsage?.() ?? {};
|
||||||
await WorkspaceChats.new({
|
await WorkspaceChats.new({
|
||||||
workspaceId: Number(invocation.workspace_id),
|
workspaceId: Number(invocation.workspace_id),
|
||||||
prompt,
|
prompt,
|
||||||
@ -71,6 +74,7 @@ const chatHistory = {
|
|||||||
? options.storedResponse(response)
|
? options.storedResponse(response)
|
||||||
: response,
|
: response,
|
||||||
type: options?.saveAsType ?? "chat",
|
type: options?.saveAsType ?? "chat",
|
||||||
|
metrics,
|
||||||
},
|
},
|
||||||
user: { id: invocation?.user_id || null },
|
user: { id: invocation?.user_id || null },
|
||||||
threadId: invocation?.thread_id || null,
|
threadId: invocation?.thread_id || null,
|
||||||
|
|||||||
@ -33,6 +33,17 @@ const { OllamaAILLM } = require("../../../AiProviders/ollama");
|
|||||||
const DEFAULT_WORKSPACE_PROMPT =
|
const DEFAULT_WORKSPACE_PROMPT =
|
||||||
"You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
|
"You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @typedef {Object} ProviderUsageMetrics
|
||||||
|
* @property {number} prompt_tokens - Number of tokens in the prompt/input
|
||||||
|
* @property {number} completion_tokens - Number of tokens in the completion/output
|
||||||
|
* @property {number} total_tokens - Total tokens used
|
||||||
|
* @property {number} duration - Duration in seconds
|
||||||
|
* @property {number} outputTps - Output tokens per second
|
||||||
|
* @property {string} model - Model name
|
||||||
|
* @property {Date} timestamp - Timestamp of the completion
|
||||||
|
*/
|
||||||
|
|
||||||
class Provider {
|
class Provider {
|
||||||
_client;
|
_client;
|
||||||
|
|
||||||
@ -51,6 +62,27 @@ class Provider {
|
|||||||
*/
|
*/
|
||||||
executingUserId = "";
|
executingUserId = "";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stores the usage metrics from the last completion call.
|
||||||
|
* @type {ProviderUsageMetrics}
|
||||||
|
*/
|
||||||
|
lastUsage = {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
duration: 0,
|
||||||
|
outputTps: 0,
|
||||||
|
model: null,
|
||||||
|
provider: null,
|
||||||
|
timestamp: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Timestamp when the current request started (for duration calculation).
|
||||||
|
* @type {number}
|
||||||
|
*/
|
||||||
|
_requestStartTime = 0;
|
||||||
|
|
||||||
constructor(client) {
|
constructor(client) {
|
||||||
if (this.constructor == Provider) {
|
if (this.constructor == Provider) {
|
||||||
return;
|
return;
|
||||||
@ -407,6 +439,60 @@ class Provider {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resets the usage metrics to zero and starts the request timer.
|
||||||
|
* Call this before each completion to ensure accurate per-call metrics.
|
||||||
|
*/
|
||||||
|
resetUsage() {
|
||||||
|
this._requestStartTime = Date.now();
|
||||||
|
this.lastUsage = {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
outputTps: 0,
|
||||||
|
duration: 0,
|
||||||
|
model: null,
|
||||||
|
provider: null,
|
||||||
|
timestamp: null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Updates the stored usage metrics from a provider response.
|
||||||
|
* Override in subclasses to handle provider-specific usage formats.
|
||||||
|
* @param {Object} usage - The usage object from the provider response
|
||||||
|
*/
|
||||||
|
recordUsage(usage = {}) {
|
||||||
|
let duration = 0;
|
||||||
|
if (this._requestStartTime > 0) {
|
||||||
|
duration = (Date.now() - this._requestStartTime) / 1000;
|
||||||
|
}
|
||||||
|
|
||||||
|
const promptTokens = usage.prompt_tokens || usage.input_tokens || 0;
|
||||||
|
const completionTokens =
|
||||||
|
usage.completion_tokens || usage.output_tokens || 0;
|
||||||
|
|
||||||
|
this.lastUsage = {
|
||||||
|
prompt_tokens: promptTokens,
|
||||||
|
completion_tokens: completionTokens,
|
||||||
|
total_tokens: usage.total_tokens || promptTokens + completionTokens,
|
||||||
|
outputTps:
|
||||||
|
completionTokens && duration > 0 ? completionTokens / duration : 0,
|
||||||
|
duration,
|
||||||
|
model: this.model,
|
||||||
|
provider: this.constructor.name,
|
||||||
|
timestamp: new Date(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the usage metrics from the last completion.
|
||||||
|
* @returns {ProviderUsageMetrics} The usage metrics
|
||||||
|
*/
|
||||||
|
getUsage() {
|
||||||
|
return { ...this.lastUsage };
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stream a chat completion from the LLM with tool calling
|
* Stream a chat completion from the LLM with tool calling
|
||||||
* Note: This using the OpenAI API format and may need to be adapted for other providers.
|
* Note: This using the OpenAI API format and may need to be adapted for other providers.
|
||||||
|
|||||||
@ -180,9 +180,11 @@ class AnthropicProvider extends Provider {
|
|||||||
* @param {any[]} messages - The messages to send to the LLM.
|
* @param {any[]} messages - The messages to send to the LLM.
|
||||||
* @param {any[]} functions - The functions to use in the LLM.
|
* @param {any[]} functions - The functions to use in the LLM.
|
||||||
* @param {function} eventHandler - The event handler to use to report stream events.
|
* @param {function} eventHandler - The event handler to use to report stream events.
|
||||||
* @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion.
|
* @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>} - The result of the chat completion.
|
||||||
*/
|
*/
|
||||||
async stream(messages, functions = [], eventHandler = null) {
|
async stream(messages, functions = [], eventHandler = null) {
|
||||||
|
this.resetUsage();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const msgUUID = v4();
|
const msgUUID = v4();
|
||||||
const [systemPrompt, chats] = this.#prepareMessages(messages);
|
const [systemPrompt, chats] = this.#prepareMessages(messages);
|
||||||
@ -205,7 +207,20 @@ class AnthropicProvider extends Provider {
|
|||||||
textResponse: "",
|
textResponse: "",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Track usage from streaming events
|
||||||
|
const usage = { input_tokens: 0, output_tokens: 0 };
|
||||||
|
|
||||||
for await (const chunk of response) {
|
for await (const chunk of response) {
|
||||||
|
// Capture input tokens from message_start event
|
||||||
|
if (chunk.type === "message_start" && chunk.message?.usage) {
|
||||||
|
usage.input_tokens = chunk.message.usage.input_tokens || 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture output tokens from message_delta event
|
||||||
|
if (chunk.type === "message_delta" && chunk.usage) {
|
||||||
|
usage.output_tokens = chunk.usage.output_tokens || 0;
|
||||||
|
}
|
||||||
|
|
||||||
if (chunk.type === "content_block_start") {
|
if (chunk.type === "content_block_start") {
|
||||||
if (chunk.content_block.type === "text") {
|
if (chunk.content_block.type === "text") {
|
||||||
result.textResponse += chunk.content_block.text;
|
result.textResponse += chunk.content_block.text;
|
||||||
@ -254,6 +269,8 @@ class AnthropicProvider extends Provider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record accumulated usage
|
||||||
|
this.recordUsage(usage);
|
||||||
if (result.functionCall) {
|
if (result.functionCall) {
|
||||||
result.functionCall.arguments = safeJsonParse(
|
result.functionCall.arguments = safeJsonParse(
|
||||||
result.functionCall.arguments,
|
result.functionCall.arguments,
|
||||||
@ -278,6 +295,7 @@ class AnthropicProvider extends Provider {
|
|||||||
arguments: result.functionCall.arguments,
|
arguments: result.functionCall.arguments,
|
||||||
},
|
},
|
||||||
cost: 0,
|
cost: 0,
|
||||||
|
uuid: msgUUID,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -285,6 +303,7 @@ class AnthropicProvider extends Provider {
|
|||||||
textResponse: result.textResponse,
|
textResponse: result.textResponse,
|
||||||
functionCall: null,
|
functionCall: null,
|
||||||
cost: 0,
|
cost: 0,
|
||||||
|
uuid: msgUUID,
|
||||||
};
|
};
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// If invalid Auth error we need to abort because no amount of waiting
|
// If invalid Auth error we need to abort because no amount of waiting
|
||||||
@ -311,6 +330,8 @@ class AnthropicProvider extends Provider {
|
|||||||
* @returns The completion.
|
* @returns The completion.
|
||||||
*/
|
*/
|
||||||
async complete(messages, functions = []) {
|
async complete(messages, functions = []) {
|
||||||
|
this.resetUsage();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const [systemPrompt, chats] = this.#prepareMessages(messages);
|
const [systemPrompt, chats] = this.#prepareMessages(messages);
|
||||||
const response = await this.client.messages.create(
|
const response = await this.client.messages.create(
|
||||||
@ -327,6 +348,9 @@ class AnthropicProvider extends Provider {
|
|||||||
{ headers: { "anthropic-beta": "tools-2024-04-04" } } // Required to we can use tools.
|
{ headers: { "anthropic-beta": "tools-2024-04-04" } } // Required to we can use tools.
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Record usage from response (Anthropic uses input_tokens/output_tokens)
|
||||||
|
if (response.usage) this.recordUsage(response.usage);
|
||||||
|
|
||||||
// We know that we need to call a tool. So we are about to recurse through completions/handleExecution
|
// We know that we need to call a tool. So we are about to recurse through completions/handleExecution
|
||||||
// https://docs.anthropic.com/claude/docs/tool-use#how-tool-use-works
|
// https://docs.anthropic.com/claude/docs/tool-use#how-tool-use-works
|
||||||
if (response.stop_reason === "tool_use") {
|
if (response.stop_reason === "tool_use") {
|
||||||
@ -374,6 +398,7 @@ class AnthropicProvider extends Provider {
|
|||||||
arguments: functionArgs,
|
arguments: functionArgs,
|
||||||
},
|
},
|
||||||
cost: 0,
|
cost: 0,
|
||||||
|
usage: this.getUsage(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -383,6 +408,7 @@ class AnthropicProvider extends Provider {
|
|||||||
completion?.text ??
|
completion?.text ??
|
||||||
"The model failed to complete the task and return back a valid response.",
|
"The model failed to complete the task and return back a valid response.",
|
||||||
cost: 0,
|
cost: 0,
|
||||||
|
usage: this.getUsage(),
|
||||||
};
|
};
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// If invalid Auth error we need to abort because no amount of waiting
|
// If invalid Auth error we need to abort because no amount of waiting
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class AzureOpenAiProvider extends Provider {
|
|||||||
* @param {any[]} messages
|
* @param {any[]} messages
|
||||||
* @param {any[]} functions
|
* @param {any[]} functions
|
||||||
* @param {function} eventHandler
|
* @param {function} eventHandler
|
||||||
* @returns {Promise<{ functionCall: any, textResponse: string }>}
|
* @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>}
|
||||||
*/
|
*/
|
||||||
async stream(messages, functions = [], eventHandler = null) {
|
async stream(messages, functions = [], eventHandler = null) {
|
||||||
this.providerLog("Provider.stream - will process this chat completion.");
|
this.providerLog("Provider.stream - will process this chat completion.");
|
||||||
@ -45,7 +45,8 @@ class AzureOpenAiProvider extends Provider {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -75,7 +76,8 @@ class AzureOpenAiProvider extends Provider {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
@ -48,7 +48,10 @@ class DeepSeekProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
get #tooledOptions() {
|
get #tooledOptions() {
|
||||||
return this.#isThinkingModel ? { injectReasoningContent: true } : {};
|
return {
|
||||||
|
provider: this,
|
||||||
|
...(this.#isThinkingModel ? { injectReasoningContent: true } : {}),
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
async #handleFunctionCallChat({ messages = [] }) {
|
async #handleFunctionCallChat({ messages = [] }) {
|
||||||
|
|||||||
@ -112,7 +112,8 @@ class DockerModelRunnerProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -151,7 +152,8 @@ class DockerModelRunnerProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
@ -165,6 +165,8 @@ class GeminiProvider extends Provider {
|
|||||||
if (!this.supportsToolCalling)
|
if (!this.supportsToolCalling)
|
||||||
throw new Error(`Gemini: ${this.model} does not support tool calling.`);
|
throw new Error(`Gemini: ${this.model} does not support tool calling.`);
|
||||||
this.providerLog("Gemini.stream - will process this chat completion.");
|
this.providerLog("Gemini.stream - will process this chat completion.");
|
||||||
|
this.resetUsage();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const msgUUID = v4();
|
const msgUUID = v4();
|
||||||
/** @type {OpenAI.OpenAI.Chat.ChatCompletion} */
|
/** @type {OpenAI.OpenAI.Chat.ChatCompletion} */
|
||||||
@ -172,6 +174,7 @@ class GeminiProvider extends Provider {
|
|||||||
model: this.model,
|
model: this.model,
|
||||||
messages: this.#formatMessages(messages),
|
messages: this.#formatMessages(messages),
|
||||||
stream: true,
|
stream: true,
|
||||||
|
stream_options: { include_usage: true },
|
||||||
...(Array.isArray(functions) && functions?.length > 0
|
...(Array.isArray(functions) && functions?.length > 0
|
||||||
? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
|
? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
|
||||||
: {}),
|
: {}),
|
||||||
@ -186,6 +189,9 @@ class GeminiProvider extends Provider {
|
|||||||
for await (const streamEvent of response) {
|
for await (const streamEvent of response) {
|
||||||
/** @type {OpenAI.OpenAI.Chat.ChatCompletionChunk} */
|
/** @type {OpenAI.OpenAI.Chat.ChatCompletionChunk} */
|
||||||
const chunk = streamEvent;
|
const chunk = streamEvent;
|
||||||
|
|
||||||
|
// Capture usage from final chunk (when stream_options.include_usage is true)
|
||||||
|
if (chunk?.usage) this.recordUsage(chunk.usage);
|
||||||
const { content, tool_calls } = chunk?.choices?.[0]?.delta || {};
|
const { content, tool_calls } = chunk?.choices?.[0]?.delta || {};
|
||||||
|
|
||||||
if (content) {
|
if (content) {
|
||||||
@ -228,6 +234,7 @@ class GeminiProvider extends Provider {
|
|||||||
extra_content: completion.functionCall.extra_content,
|
extra_content: completion.functionCall.extra_content,
|
||||||
},
|
},
|
||||||
cost: this.getCost(),
|
cost: this.getCost(),
|
||||||
|
uuid: msgUUID,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,6 +242,7 @@ class GeminiProvider extends Provider {
|
|||||||
textResponse: completion.content,
|
textResponse: completion.content,
|
||||||
functionCall: null,
|
functionCall: null,
|
||||||
cost: this.getCost(),
|
cost: this.getCost(),
|
||||||
|
uuid: msgUUID,
|
||||||
};
|
};
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof OpenAI.AuthenticationError) throw error;
|
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||||
@ -261,6 +269,8 @@ class GeminiProvider extends Provider {
|
|||||||
if (!this.supportsToolCalling)
|
if (!this.supportsToolCalling)
|
||||||
throw new Error(`Gemini: ${this.model} does not support tool calling.`);
|
throw new Error(`Gemini: ${this.model} does not support tool calling.`);
|
||||||
this.providerLog("Gemini.complete - will process this chat completion.");
|
this.providerLog("Gemini.complete - will process this chat completion.");
|
||||||
|
this.resetUsage();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await this.client.chat.completions.create({
|
const response = await this.client.chat.completions.create({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
@ -271,6 +281,8 @@ class GeminiProvider extends Provider {
|
|||||||
: {}),
|
: {}),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (response.usage) this.recordUsage(response.usage);
|
||||||
|
|
||||||
/** @type {OpenAI.OpenAI.Chat.ChatCompletionMessage} */
|
/** @type {OpenAI.OpenAI.Chat.ChatCompletionMessage} */
|
||||||
const completion = response.choices[0].message;
|
const completion = response.choices[0].message;
|
||||||
const cost = this.getCost(response.usage);
|
const cost = this.getCost(response.usage);
|
||||||
@ -287,12 +299,14 @@ class GeminiProvider extends Provider {
|
|||||||
extra_content: toolCall.extra_content ?? null,
|
extra_content: toolCall.extra_content ?? null,
|
||||||
},
|
},
|
||||||
cost,
|
cost,
|
||||||
|
usage: this.getUsage(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
textResponse: completion.content,
|
textResponse: completion.content,
|
||||||
cost,
|
cost,
|
||||||
|
usage: this.getUsage(),
|
||||||
};
|
};
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// If invalid Auth error we need to abort because no amount of waiting
|
// If invalid Auth error we need to abort because no amount of waiting
|
||||||
|
|||||||
@ -131,7 +131,8 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -170,7 +171,8 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
@ -111,7 +111,8 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -150,7 +151,8 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
@ -131,8 +131,9 @@ function formatMessagesForTools(messages, options = {}) {
|
|||||||
* @param {Array} messages - Raw aibitat message history
|
* @param {Array} messages - Raw aibitat message history
|
||||||
* @param {Array} functions - Aibitat function definitions
|
* @param {Array} functions - Aibitat function definitions
|
||||||
* @param {function|null} eventHandler - Stream event handler
|
* @param {function|null} eventHandler - Stream event handler
|
||||||
* @param {{injectReasoningContent?: boolean}} options - Provider-specific options forwarded to formatMessagesForTools
|
* @param {{injectReasoningContent?: boolean, provider?: object}} options - Provider-specific options
|
||||||
* @returns {Promise<{textResponse: string, functionCall: object|null}>}
|
* - provider: If passed, automatically handles usage tracking via provider.resetUsage()/recordUsage()
|
||||||
|
* @returns {Promise<{textResponse: string, functionCall: object|null, uuid: string, usage: object|null}>}
|
||||||
*/
|
*/
|
||||||
async function tooledStream(
|
async function tooledStream(
|
||||||
client,
|
client,
|
||||||
@ -142,13 +143,23 @@ async function tooledStream(
|
|||||||
eventHandler = null,
|
eventHandler = null,
|
||||||
options = {}
|
options = {}
|
||||||
) {
|
) {
|
||||||
|
const { provider, ...formatOptions } = options;
|
||||||
|
|
||||||
|
// Auto-reset usage if provider is passed
|
||||||
|
if (provider?.resetUsage) {
|
||||||
|
try {
|
||||||
|
provider.resetUsage();
|
||||||
|
} catch {}
|
||||||
|
}
|
||||||
|
|
||||||
const msgUUID = v4();
|
const msgUUID = v4();
|
||||||
const formattedMessages = formatMessagesForTools(messages, options);
|
const formattedMessages = formatMessagesForTools(messages, formatOptions);
|
||||||
const tools = formatFunctionsToTools(functions);
|
const tools = formatFunctionsToTools(functions);
|
||||||
|
|
||||||
const stream = await client.chat.completions.create({
|
const stream = await client.chat.completions.create({
|
||||||
model,
|
model,
|
||||||
stream: true,
|
stream: true,
|
||||||
|
stream_options: { include_usage: true },
|
||||||
messages: formattedMessages,
|
messages: formattedMessages,
|
||||||
...(tools.length > 0 ? { tools } : {}),
|
...(tools.length > 0 ? { tools } : {}),
|
||||||
});
|
});
|
||||||
@ -159,8 +170,14 @@ async function tooledStream(
|
|||||||
};
|
};
|
||||||
|
|
||||||
const toolCallsByIndex = {};
|
const toolCallsByIndex = {};
|
||||||
|
let usage = null;
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
|
// Capture usage from final chunk (some providers send usage after finish_reason)
|
||||||
|
if (chunk?.usage) {
|
||||||
|
usage = chunk.usage;
|
||||||
|
}
|
||||||
|
|
||||||
if (!chunk?.choices?.[0]) continue;
|
if (!chunk?.choices?.[0]) continue;
|
||||||
const choice = chunk.choices[0];
|
const choice = chunk.choices[0];
|
||||||
|
|
||||||
@ -203,6 +220,13 @@ async function tooledStream(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Auto-record usage if provider is passed and usage is available
|
||||||
|
if (provider?.recordUsage && usage) {
|
||||||
|
try {
|
||||||
|
provider.recordUsage(usage);
|
||||||
|
} catch {}
|
||||||
|
}
|
||||||
|
|
||||||
const toolCallIndices = Object.keys(toolCallsByIndex).map(Number);
|
const toolCallIndices = Object.keys(toolCallsByIndex).map(Number);
|
||||||
if (toolCallIndices.length > 0) {
|
if (toolCallIndices.length > 0) {
|
||||||
const firstToolCall = toolCallsByIndex[Math.min(...toolCallIndices)];
|
const firstToolCall = toolCallsByIndex[Math.min(...toolCallIndices)];
|
||||||
@ -216,6 +240,8 @@ async function tooledStream(
|
|||||||
return {
|
return {
|
||||||
textResponse: result.textResponse,
|
textResponse: result.textResponse,
|
||||||
functionCall: result.functionCall,
|
functionCall: result.functionCall,
|
||||||
|
uuid: msgUUID,
|
||||||
|
usage,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,8 +254,9 @@ async function tooledStream(
|
|||||||
* @param {Array} messages - Raw aibitat message history
|
* @param {Array} messages - Raw aibitat message history
|
||||||
* @param {Array} functions - Aibitat function definitions
|
* @param {Array} functions - Aibitat function definitions
|
||||||
* @param {function} getCostFn - Provider's getCost function
|
* @param {function} getCostFn - Provider's getCost function
|
||||||
* @param {{injectReasoningContent?: boolean}} options - Provider-specific options forwarded to formatMessagesForTools
|
* @param {{injectReasoningContent?: boolean, provider?: object}} options - Provider-specific options
|
||||||
* @returns {Promise<{textResponse: string|null, functionCall: object|null, cost: number}>}
|
* - provider: If passed, automatically handles usage tracking via provider.resetUsage()/recordUsage()
|
||||||
|
* @returns {Promise<{textResponse: string|null, functionCall: object|null, cost: number, usage: object|null}>}
|
||||||
*/
|
*/
|
||||||
async function tooledComplete(
|
async function tooledComplete(
|
||||||
client,
|
client,
|
||||||
@ -239,7 +266,16 @@ async function tooledComplete(
|
|||||||
getCostFn = () => 0,
|
getCostFn = () => 0,
|
||||||
options = {}
|
options = {}
|
||||||
) {
|
) {
|
||||||
const formattedMessages = formatMessagesForTools(messages, options);
|
const { provider, ...formatOptions } = options;
|
||||||
|
|
||||||
|
// Auto-reset usage if provider is passed
|
||||||
|
if (provider?.resetUsage) {
|
||||||
|
try {
|
||||||
|
provider.resetUsage();
|
||||||
|
} catch {}
|
||||||
|
}
|
||||||
|
|
||||||
|
const formattedMessages = formatMessagesForTools(messages, formatOptions);
|
||||||
const tools = formatFunctionsToTools(functions);
|
const tools = formatFunctionsToTools(functions);
|
||||||
|
|
||||||
const response = await client.chat.completions.create({
|
const response = await client.chat.completions.create({
|
||||||
@ -251,6 +287,14 @@ async function tooledComplete(
|
|||||||
|
|
||||||
const completion = response.choices[0].message;
|
const completion = response.choices[0].message;
|
||||||
const cost = getCostFn(response.usage);
|
const cost = getCostFn(response.usage);
|
||||||
|
const usage = response.usage || null;
|
||||||
|
|
||||||
|
// Auto-record usage if provider is passed and usage is available
|
||||||
|
if (provider?.recordUsage && usage) {
|
||||||
|
try {
|
||||||
|
provider.recordUsage(usage);
|
||||||
|
} catch {}
|
||||||
|
}
|
||||||
|
|
||||||
if (completion.tool_calls && completion.tool_calls.length > 0) {
|
if (completion.tool_calls && completion.tool_calls.length > 0) {
|
||||||
const toolCall = completion.tool_calls[0];
|
const toolCall = completion.tool_calls[0];
|
||||||
@ -270,6 +314,7 @@ async function tooledComplete(
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
cost,
|
cost,
|
||||||
|
usage,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -281,12 +326,14 @@ async function tooledComplete(
|
|||||||
arguments: functionArgs,
|
arguments: functionArgs,
|
||||||
},
|
},
|
||||||
cost,
|
cost,
|
||||||
|
usage,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
textResponse: completion.content,
|
textResponse: completion.content,
|
||||||
cost,
|
cost,
|
||||||
|
usage,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -135,7 +135,8 @@ class LemonadeProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -175,7 +176,8 @@ class LemonadeProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
@ -113,7 +113,8 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -152,7 +153,8 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
@ -118,7 +118,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -158,7 +159,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
@ -115,7 +115,8 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -154,7 +155,8 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
@ -112,7 +112,8 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -151,7 +152,8 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
@ -245,6 +245,7 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.providerLog(
|
this.providerLog(
|
||||||
"OllamaProvider.stream (tooled) - will process this chat completion."
|
"OllamaProvider.stream (tooled) - will process this chat completion."
|
||||||
);
|
);
|
||||||
|
this.resetUsage();
|
||||||
await OllamaAILLM.cacheContextWindows();
|
await OllamaAILLM.cacheContextWindows();
|
||||||
const msgUUID = v4();
|
const msgUUID = v4();
|
||||||
const formattedMessages = this.#formatMessagesForOllamaTools(messages);
|
const formattedMessages = this.#formatMessagesForOllamaTools(messages);
|
||||||
@ -262,6 +263,14 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
let toolCalls = null;
|
let toolCalls = null;
|
||||||
|
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
|
// Capture usage from final chunk (Ollama sends usage when done=true)
|
||||||
|
if (chunk.done === true) {
|
||||||
|
this.recordUsage({
|
||||||
|
prompt_tokens: chunk.prompt_eval_count || 0,
|
||||||
|
completion_tokens: chunk.eval_count || 0,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
if (!chunk?.message) continue;
|
if (!chunk?.message) continue;
|
||||||
|
|
||||||
if (chunk.message.content) {
|
if (chunk.message.content) {
|
||||||
@ -297,10 +306,12 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
name: toolCall.function.name,
|
name: toolCall.function.name,
|
||||||
arguments: args,
|
arguments: args,
|
||||||
},
|
},
|
||||||
|
cost: 0,
|
||||||
|
uuid: msgUUID,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return { textResponse, functionCall: null };
|
return { textResponse, functionCall: null, cost: 0, uuid: msgUUID };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback: UnTooled prompt-based approach via the native Ollama SDK
|
// Fallback: UnTooled prompt-based approach via the native Ollama SDK
|
||||||
@ -431,6 +442,7 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
functions.length > 0 && (await this.supportsNativeToolCalling());
|
functions.length > 0 && (await this.supportsNativeToolCalling());
|
||||||
|
|
||||||
if (useNative) {
|
if (useNative) {
|
||||||
|
this.resetUsage();
|
||||||
await OllamaAILLM.cacheContextWindows();
|
await OllamaAILLM.cacheContextWindows();
|
||||||
const formattedMessages = this.#formatMessagesForOllamaTools(messages);
|
const formattedMessages = this.#formatMessagesForOllamaTools(messages);
|
||||||
const tools = formatFunctionsToTools(functions);
|
const tools = formatFunctionsToTools(functions);
|
||||||
@ -442,6 +454,12 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
options: this.queryOptions,
|
options: this.queryOptions,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Record usage (Ollama uses prompt_eval_count/eval_count)
|
||||||
|
this.recordUsage({
|
||||||
|
prompt_tokens: response.prompt_eval_count || 0,
|
||||||
|
completion_tokens: response.eval_count || 0,
|
||||||
|
});
|
||||||
|
|
||||||
if (response.message?.tool_calls?.length > 0) {
|
if (response.message?.tool_calls?.length > 0) {
|
||||||
const toolCall = response.message.tool_calls[0];
|
const toolCall = response.message.tool_calls[0];
|
||||||
const args =
|
const args =
|
||||||
@ -457,12 +475,14 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
arguments: args,
|
arguments: args,
|
||||||
},
|
},
|
||||||
cost: 0,
|
cost: 0,
|
||||||
|
usage: this.getUsage(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
textResponse: response.message?.content || null,
|
textResponse: response.message?.content || null,
|
||||||
cost: 0,
|
cost: 0,
|
||||||
|
usage: this.getUsage(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -117,6 +117,8 @@ class OpenAIProvider extends Provider {
|
|||||||
*/
|
*/
|
||||||
async stream(messages, functions = [], eventHandler = null) {
|
async stream(messages, functions = [], eventHandler = null) {
|
||||||
this.providerLog("OpenAI.stream - will process this chat completion.");
|
this.providerLog("OpenAI.stream - will process this chat completion.");
|
||||||
|
this.resetUsage();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const msgUUID = v4();
|
const msgUUID = v4();
|
||||||
|
|
||||||
@ -178,6 +180,13 @@ class OpenAIProvider extends Provider {
|
|||||||
});
|
});
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (chunk.type === "response.completed") {
|
||||||
|
const completedResponse = chunk.response;
|
||||||
|
if (!completedResponse?.usage) continue;
|
||||||
|
this.recordUsage(completedResponse.usage);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (completion.functionCall) {
|
if (completion.functionCall) {
|
||||||
@ -193,6 +202,7 @@ class OpenAIProvider extends Provider {
|
|||||||
arguments: completion.functionCall.arguments,
|
arguments: completion.functionCall.arguments,
|
||||||
},
|
},
|
||||||
cost: this.getCost(),
|
cost: this.getCost(),
|
||||||
|
uuid: msgUUID,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -200,16 +210,15 @@ class OpenAIProvider extends Provider {
|
|||||||
textResponse: completion.content,
|
textResponse: completion.content,
|
||||||
functionCall: null,
|
functionCall: null,
|
||||||
cost: this.getCost(),
|
cost: this.getCost(),
|
||||||
|
uuid: msgUUID,
|
||||||
};
|
};
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// If invalid Auth error we need to abort because no amount of waiting
|
|
||||||
// will make auth better.
|
|
||||||
if (error instanceof OpenAI.AuthenticationError) throw error;
|
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||||
|
|
||||||
if (
|
if (
|
||||||
error instanceof OpenAI.RateLimitError ||
|
error instanceof OpenAI.RateLimitError ||
|
||||||
error instanceof OpenAI.InternalServerError ||
|
error instanceof OpenAI.InternalServerError ||
|
||||||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
|
error instanceof OpenAI.APIError
|
||||||
) {
|
) {
|
||||||
throw new RetryError(error.message);
|
throw new RetryError(error.message);
|
||||||
}
|
}
|
||||||
@ -227,6 +236,8 @@ class OpenAIProvider extends Provider {
|
|||||||
*/
|
*/
|
||||||
async complete(messages, functions = []) {
|
async complete(messages, functions = []) {
|
||||||
this.providerLog("OpenAI.complete - will process this chat completion.");
|
this.providerLog("OpenAI.complete - will process this chat completion.");
|
||||||
|
this.resetUsage();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const completion = {
|
const completion = {
|
||||||
content: "",
|
content: "",
|
||||||
@ -245,17 +256,14 @@ class OpenAIProvider extends Provider {
|
|||||||
: {}),
|
: {}),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (response.usage) this.recordUsage(response.usage);
|
||||||
for (const outputBlock of response.output) {
|
for (const outputBlock of response.output) {
|
||||||
// Grab intermediate text output if it exists
|
|
||||||
// If no tools are used, this will be returned to the aibitat handler
|
|
||||||
// Otherwise, this text will never be shown to the user
|
|
||||||
if (outputBlock.type === "message") {
|
if (outputBlock.type === "message") {
|
||||||
if (outputBlock.content[0]?.type === "output_text") {
|
if (outputBlock.content[0]?.type === "output_text") {
|
||||||
completion.content = outputBlock.content[0].text;
|
completion.content = outputBlock.content[0].text;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grab function call output if it exists
|
|
||||||
if (outputBlock.type === "function_call") {
|
if (outputBlock.type === "function_call") {
|
||||||
completion.functionCall = {
|
completion.functionCall = {
|
||||||
name: outputBlock.name,
|
name: outputBlock.name,
|
||||||
@ -273,13 +281,12 @@ class OpenAIProvider extends Provider {
|
|||||||
return {
|
return {
|
||||||
textResponse: completion.content,
|
textResponse: completion.content,
|
||||||
functionCall: {
|
functionCall: {
|
||||||
// For OpenAI, the id is the call_id and we need it in followup requests
|
|
||||||
// so we can match the function call output to its invocation in the message history.
|
|
||||||
id: completion.functionCall.call_id,
|
id: completion.functionCall.call_id,
|
||||||
name: completion.functionCall.name,
|
name: completion.functionCall.name,
|
||||||
arguments: completion.functionCall.arguments,
|
arguments: completion.functionCall.arguments,
|
||||||
},
|
},
|
||||||
cost: this.getCost(),
|
cost: this.getCost(),
|
||||||
|
usage: this.getUsage(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -287,16 +294,14 @@ class OpenAIProvider extends Provider {
|
|||||||
textResponse: completion.content,
|
textResponse: completion.content,
|
||||||
functionCall: null,
|
functionCall: null,
|
||||||
cost: this.getCost(),
|
cost: this.getCost(),
|
||||||
|
usage: this.getUsage(),
|
||||||
};
|
};
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// If invalid Auth error we need to abort because no amount of waiting
|
|
||||||
// will make auth better.
|
|
||||||
if (error instanceof OpenAI.AuthenticationError) throw error;
|
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||||
|
|
||||||
if (
|
if (
|
||||||
error instanceof OpenAI.RateLimitError ||
|
error instanceof OpenAI.RateLimitError ||
|
||||||
error instanceof OpenAI.InternalServerError ||
|
error instanceof OpenAI.InternalServerError ||
|
||||||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
|
error instanceof OpenAI.APIError
|
||||||
) {
|
) {
|
||||||
throw new RetryError(error.message);
|
throw new RetryError(error.message);
|
||||||
}
|
}
|
||||||
@ -307,9 +312,7 @@ class OpenAIProvider extends Provider {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the cost of the completion.
|
* Get the cost of the completion.
|
||||||
*
|
* @returns {number} The cost of the completion (currently returns 0).
|
||||||
* @param _usage The completion to get the cost for.
|
|
||||||
* @returns The cost of the completion.
|
|
||||||
*/
|
*/
|
||||||
getCost() {
|
getCost() {
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
@ -121,7 +121,8 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
eventHandler
|
eventHandler,
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error.message, error);
|
console.error(error.message, error);
|
||||||
@ -160,7 +161,8 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.model,
|
this.model,
|
||||||
messages,
|
messages,
|
||||||
functions,
|
functions,
|
||||||
this.getCost.bind(this)
|
this.getCost.bind(this),
|
||||||
|
{ provider: this }
|
||||||
);
|
);
|
||||||
|
|
||||||
if (result.retryWithError) {
|
if (result.retryWithError) {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user