Support Agent stream metric reporting (#5197)

This commit is contained in:
Timothy Carambat 2026-03-12 12:50:02 -07:00 committed by GitHub
parent f1439d7fcb
commit 15a84d5121
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 339 additions and 55 deletions

View File

@ -186,14 +186,17 @@ const HistoricalMessage = ({
export default memo( export default memo(
HistoricalMessage, HistoricalMessage,
// Skip re-render the historical message: // Skip re-render the historical message:
// if the content is the exact same AND (not streaming) // - if the content is the exact same
// the lastMessage status is the same (regen icon) // - AND (not streaming)
// and the chatID matches between renders. (feedback icons) // - the lastMessage status is the same (regen icon)
// - the chatID matches between renders. (feedback icons)
// - the metrics are the same (metrics are updated in real time)
(prevProps, nextProps) => { (prevProps, nextProps) => {
return ( return (
prevProps.message === nextProps.message && prevProps.message === nextProps.message &&
prevProps.isLastMessage === nextProps.isLastMessage && prevProps.isLastMessage === nextProps.isLastMessage &&
prevProps.chatId === nextProps.chatId prevProps.chatId === nextProps.chatId &&
JSON.stringify(prevProps.metrics) === JSON.stringify(nextProps.metrics)
); );
} }
); );

View File

@ -41,6 +41,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null, error: null,
animate: false, animate: false,
pending: false, pending: false,
metrics: {},
}, },
]; ];
}); });
@ -74,6 +75,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null, error: null,
animate: false, animate: false,
pending: false, pending: false,
metrics: {},
}, },
]; ];
} }
@ -95,6 +97,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null, error: null,
animate: false, animate: false,
pending: false, pending: false,
metrics: {},
}, },
]; ];
} }
@ -111,6 +114,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null, error: null,
animate: false, animate: false,
pending: false, pending: false,
metrics: {},
}, },
]; ];
} else { } else {
@ -127,6 +131,13 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
]; // If the message is known, replace it with the new content ]; // If the message is known, replace it with the new content
} }
if (type === "usageMetrics") {
if (!data.content.metrics) return prev;
return prev.map((msg) =>
msg.uuid === uuid ? { ...msg, metrics: data.content.metrics } : msg
);
}
if (type === "textResponseChunk") { if (type === "textResponseChunk") {
return prev return prev
.map((msg) => .map((msg) =>
@ -172,6 +183,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null, error: null,
animate: false, animate: false,
pending: false, pending: false,
metrics: data.metrics || {},
}, },
]; ];
}); });
@ -190,6 +202,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: data.content, error: data.content,
animate: false, animate: false,
pending: false, pending: false,
metrics: {},
}, },
]; ];
}); });
@ -208,6 +221,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null, error: null,
animate: data?.animate || false, animate: data?.animate || false,
pending: false, pending: false,
metrics: data.metrics || {},
}, },
]; ];
}); });

View File

@ -626,6 +626,8 @@ ${this.getHistory({ to: route.to })
); );
} }
// Store the active provider so plugins can access usage metrics
this.provider = provider;
this.newMessage({ ...route, content }); this.newMessage({ ...route, content });
return content; return content;
} }
@ -652,7 +654,7 @@ ${this.getHistory({ to: route.to })
this?.socket?.send(type, data); this?.socket?.send(type, data);
}; };
/** @type {{ functionCall: { name: string, arguments: string }, textResponse: string }} */ /** @type {{ functionCall: { name: string, arguments: string }, textResponse: string, uuid: string }} */
const completionStream = await provider.stream( const completionStream = await provider.stream(
messages, messages,
functions, functions,
@ -669,6 +671,11 @@ ${this.getHistory({ to: route.to })
); );
const finalStream = await provider.stream(messages, [], eventHandler); const finalStream = await provider.stream(messages, [], eventHandler);
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: finalStream?.uuid || v4(),
metrics: provider.getUsage(),
});
const finalResponse = const finalResponse =
finalStream?.textResponse || finalStream?.textResponse ||
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run."; "I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run.";
@ -726,11 +733,17 @@ ${this.getHistory({ to: route.to })
this.handlerProps?.log?.( this.handlerProps?.log?.(
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.` `${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
); );
const directOutputUUID = completionStream?.uuid || v4();
eventHandler?.("reportStreamEvent", { eventHandler?.("reportStreamEvent", {
type: "fullTextResponse", type: "fullTextResponse",
uuid: v4(), uuid: directOutputUUID,
content: result, content: result,
}); });
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: directOutputUUID,
metrics: provider.getUsage(),
});
return result; return result;
} }
@ -751,6 +764,11 @@ ${this.getHistory({ to: route.to })
); );
} }
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: completionStream?.uuid || v4(),
metrics: provider.getUsage(),
});
return completionStream?.textResponse; return completionStream?.textResponse;
} }
@ -762,6 +780,8 @@ ${this.getHistory({ to: route.to })
* @param messages * @param messages
* @param functions * @param functions
* @param byAgent * @param byAgent
* @param depth
* @param msgUUID - The message UUID to use for event correlation (created at depth=0)
* *
* @returns {Promise<string>} * @returns {Promise<string>}
*/ */
@ -770,8 +790,15 @@ ${this.getHistory({ to: route.to })
messages = [], messages = [],
functions = [], functions = [],
byAgent = null, byAgent = null,
depth = 0 depth = 0,
msgUUID = null
) { ) {
// Create a stable UUID at the start of execution for event correlation
if (!msgUUID) msgUUID = v4();
const eventHandler = (type, data) => {
this?.socket?.send(type, data);
};
// get the chat completion // get the chat completion
const completion = await provider.complete(messages, functions); const completion = await provider.complete(messages, functions);
@ -785,6 +812,11 @@ ${this.getHistory({ to: route.to })
); );
const finalCompletion = await provider.complete(messages, []); const finalCompletion = await provider.complete(messages, []);
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: msgUUID,
metrics: provider.getUsage(),
});
return ( return (
finalCompletion?.textResponse || finalCompletion?.textResponse ||
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run." "I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run."
@ -808,7 +840,8 @@ ${this.getHistory({ to: route.to })
], ],
functions, functions,
byAgent, byAgent,
depth + 1 depth + 1,
msgUUID
); );
} }
@ -836,6 +869,11 @@ ${this.getHistory({ to: route.to })
this.handlerProps?.log?.( this.handlerProps?.log?.(
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.` `${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
); );
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: msgUUID,
metrics: provider.getUsage(),
});
return result; return result;
} }
@ -852,10 +890,16 @@ ${this.getHistory({ to: route.to })
], ],
functions, functions,
byAgent, byAgent,
depth + 1 depth + 1,
msgUUID
); );
} }
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: msgUUID,
metrics: provider.getUsage(),
});
return completion?.textResponse; return completion?.textResponse;
} }

View File

@ -43,6 +43,7 @@ const chatHistory = {
}, },
_store: async function (aibitat, { prompt, response } = {}) { _store: async function (aibitat, { prompt, response } = {}) {
const invocation = aibitat.handlerProps.invocation; const invocation = aibitat.handlerProps.invocation;
const metrics = aibitat.provider?.getUsage?.() ?? {};
await WorkspaceChats.new({ await WorkspaceChats.new({
workspaceId: Number(invocation.workspace_id), workspaceId: Number(invocation.workspace_id),
prompt, prompt,
@ -50,6 +51,7 @@ const chatHistory = {
text: response, text: response,
sources: [], sources: [],
type: "chat", type: "chat",
metrics,
}, },
user: { id: invocation?.user_id || null }, user: { id: invocation?.user_id || null },
threadId: invocation?.thread_id || null, threadId: invocation?.thread_id || null,
@ -60,6 +62,7 @@ const chatHistory = {
{ prompt, response, options = {} } = {} { prompt, response, options = {} } = {}
) { ) {
const invocation = aibitat.handlerProps.invocation; const invocation = aibitat.handlerProps.invocation;
const metrics = aibitat.provider?.getUsage?.() ?? {};
await WorkspaceChats.new({ await WorkspaceChats.new({
workspaceId: Number(invocation.workspace_id), workspaceId: Number(invocation.workspace_id),
prompt, prompt,
@ -71,6 +74,7 @@ const chatHistory = {
? options.storedResponse(response) ? options.storedResponse(response)
: response, : response,
type: options?.saveAsType ?? "chat", type: options?.saveAsType ?? "chat",
metrics,
}, },
user: { id: invocation?.user_id || null }, user: { id: invocation?.user_id || null },
threadId: invocation?.thread_id || null, threadId: invocation?.thread_id || null,

View File

@ -33,6 +33,17 @@ const { OllamaAILLM } = require("../../../AiProviders/ollama");
const DEFAULT_WORKSPACE_PROMPT = const DEFAULT_WORKSPACE_PROMPT =
"You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions."; "You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
/**
* @typedef {Object} ProviderUsageMetrics
* @property {number} prompt_tokens - Number of tokens in the prompt/input
* @property {number} completion_tokens - Number of tokens in the completion/output
* @property {number} total_tokens - Total tokens used
* @property {number} duration - Duration in seconds
* @property {number} outputTps - Output tokens per second
* @property {string} model - Model name
* @property {Date} timestamp - Timestamp of the completion
*/
class Provider { class Provider {
_client; _client;
@ -51,6 +62,27 @@ class Provider {
*/ */
executingUserId = ""; executingUserId = "";
/**
* Stores the usage metrics from the last completion call.
* @type {ProviderUsageMetrics}
*/
lastUsage = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
duration: 0,
outputTps: 0,
model: null,
provider: null,
timestamp: null,
};
/**
* Timestamp when the current request started (for duration calculation).
* @type {number}
*/
_requestStartTime = 0;
constructor(client) { constructor(client) {
if (this.constructor == Provider) { if (this.constructor == Provider) {
return; return;
@ -407,6 +439,60 @@ class Provider {
return false; return false;
} }
/**
* Resets the usage metrics to zero and starts the request timer.
* Call this before each completion to ensure accurate per-call metrics.
*/
resetUsage() {
this._requestStartTime = Date.now();
this.lastUsage = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
outputTps: 0,
duration: 0,
model: null,
provider: null,
timestamp: null,
};
}
/**
* Updates the stored usage metrics from a provider response.
* Override in subclasses to handle provider-specific usage formats.
* @param {Object} usage - The usage object from the provider response
*/
recordUsage(usage = {}) {
let duration = 0;
if (this._requestStartTime > 0) {
duration = (Date.now() - this._requestStartTime) / 1000;
}
const promptTokens = usage.prompt_tokens || usage.input_tokens || 0;
const completionTokens =
usage.completion_tokens || usage.output_tokens || 0;
this.lastUsage = {
prompt_tokens: promptTokens,
completion_tokens: completionTokens,
total_tokens: usage.total_tokens || promptTokens + completionTokens,
outputTps:
completionTokens && duration > 0 ? completionTokens / duration : 0,
duration,
model: this.model,
provider: this.constructor.name,
timestamp: new Date(),
};
}
/**
* Get the usage metrics from the last completion.
* @returns {ProviderUsageMetrics} The usage metrics
*/
getUsage() {
return { ...this.lastUsage };
}
/** /**
* Stream a chat completion from the LLM with tool calling * Stream a chat completion from the LLM with tool calling
* Note: This using the OpenAI API format and may need to be adapted for other providers. * Note: This using the OpenAI API format and may need to be adapted for other providers.

View File

@ -180,9 +180,11 @@ class AnthropicProvider extends Provider {
* @param {any[]} messages - The messages to send to the LLM. * @param {any[]} messages - The messages to send to the LLM.
* @param {any[]} functions - The functions to use in the LLM. * @param {any[]} functions - The functions to use in the LLM.
* @param {function} eventHandler - The event handler to use to report stream events. * @param {function} eventHandler - The event handler to use to report stream events.
* @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion. * @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>} - The result of the chat completion.
*/ */
async stream(messages, functions = [], eventHandler = null) { async stream(messages, functions = [], eventHandler = null) {
this.resetUsage();
try { try {
const msgUUID = v4(); const msgUUID = v4();
const [systemPrompt, chats] = this.#prepareMessages(messages); const [systemPrompt, chats] = this.#prepareMessages(messages);
@ -205,7 +207,20 @@ class AnthropicProvider extends Provider {
textResponse: "", textResponse: "",
}; };
// Track usage from streaming events
const usage = { input_tokens: 0, output_tokens: 0 };
for await (const chunk of response) { for await (const chunk of response) {
// Capture input tokens from message_start event
if (chunk.type === "message_start" && chunk.message?.usage) {
usage.input_tokens = chunk.message.usage.input_tokens || 0;
}
// Capture output tokens from message_delta event
if (chunk.type === "message_delta" && chunk.usage) {
usage.output_tokens = chunk.usage.output_tokens || 0;
}
if (chunk.type === "content_block_start") { if (chunk.type === "content_block_start") {
if (chunk.content_block.type === "text") { if (chunk.content_block.type === "text") {
result.textResponse += chunk.content_block.text; result.textResponse += chunk.content_block.text;
@ -254,6 +269,8 @@ class AnthropicProvider extends Provider {
} }
} }
// Record accumulated usage
this.recordUsage(usage);
if (result.functionCall) { if (result.functionCall) {
result.functionCall.arguments = safeJsonParse( result.functionCall.arguments = safeJsonParse(
result.functionCall.arguments, result.functionCall.arguments,
@ -278,6 +295,7 @@ class AnthropicProvider extends Provider {
arguments: result.functionCall.arguments, arguments: result.functionCall.arguments,
}, },
cost: 0, cost: 0,
uuid: msgUUID,
}; };
} }
@ -285,6 +303,7 @@ class AnthropicProvider extends Provider {
textResponse: result.textResponse, textResponse: result.textResponse,
functionCall: null, functionCall: null,
cost: 0, cost: 0,
uuid: msgUUID,
}; };
} catch (error) { } catch (error) {
// If invalid Auth error we need to abort because no amount of waiting // If invalid Auth error we need to abort because no amount of waiting
@ -311,6 +330,8 @@ class AnthropicProvider extends Provider {
* @returns The completion. * @returns The completion.
*/ */
async complete(messages, functions = []) { async complete(messages, functions = []) {
this.resetUsage();
try { try {
const [systemPrompt, chats] = this.#prepareMessages(messages); const [systemPrompt, chats] = this.#prepareMessages(messages);
const response = await this.client.messages.create( const response = await this.client.messages.create(
@ -327,6 +348,9 @@ class AnthropicProvider extends Provider {
{ headers: { "anthropic-beta": "tools-2024-04-04" } } // Required to we can use tools. { headers: { "anthropic-beta": "tools-2024-04-04" } } // Required to we can use tools.
); );
// Record usage from response (Anthropic uses input_tokens/output_tokens)
if (response.usage) this.recordUsage(response.usage);
// We know that we need to call a tool. So we are about to recurse through completions/handleExecution // We know that we need to call a tool. So we are about to recurse through completions/handleExecution
// https://docs.anthropic.com/claude/docs/tool-use#how-tool-use-works // https://docs.anthropic.com/claude/docs/tool-use#how-tool-use-works
if (response.stop_reason === "tool_use") { if (response.stop_reason === "tool_use") {
@ -374,6 +398,7 @@ class AnthropicProvider extends Provider {
arguments: functionArgs, arguments: functionArgs,
}, },
cost: 0, cost: 0,
usage: this.getUsage(),
}; };
} }
@ -383,6 +408,7 @@ class AnthropicProvider extends Provider {
completion?.text ?? completion?.text ??
"The model failed to complete the task and return back a valid response.", "The model failed to complete the task and return back a valid response.",
cost: 0, cost: 0,
usage: this.getUsage(),
}; };
} catch (error) { } catch (error) {
// If invalid Auth error we need to abort because no amount of waiting // If invalid Auth error we need to abort because no amount of waiting

View File

@ -34,7 +34,7 @@ class AzureOpenAiProvider extends Provider {
* @param {any[]} messages * @param {any[]} messages
* @param {any[]} functions * @param {any[]} functions
* @param {function} eventHandler * @param {function} eventHandler
* @returns {Promise<{ functionCall: any, textResponse: string }>} * @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>}
*/ */
async stream(messages, functions = [], eventHandler = null) { async stream(messages, functions = [], eventHandler = null) {
this.providerLog("Provider.stream - will process this chat completion."); this.providerLog("Provider.stream - will process this chat completion.");
@ -45,7 +45,8 @@ class AzureOpenAiProvider extends Provider {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -75,7 +76,8 @@ class AzureOpenAiProvider extends Provider {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {

View File

@ -48,7 +48,10 @@ class DeepSeekProvider extends InheritMultiple([Provider, UnTooled]) {
} }
get #tooledOptions() { get #tooledOptions() {
return this.#isThinkingModel ? { injectReasoningContent: true } : {}; return {
provider: this,
...(this.#isThinkingModel ? { injectReasoningContent: true } : {}),
};
} }
async #handleFunctionCallChat({ messages = [] }) { async #handleFunctionCallChat({ messages = [] }) {

View File

@ -112,7 +112,8 @@ class DockerModelRunnerProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -151,7 +152,8 @@ class DockerModelRunnerProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {

View File

@ -165,6 +165,8 @@ class GeminiProvider extends Provider {
if (!this.supportsToolCalling) if (!this.supportsToolCalling)
throw new Error(`Gemini: ${this.model} does not support tool calling.`); throw new Error(`Gemini: ${this.model} does not support tool calling.`);
this.providerLog("Gemini.stream - will process this chat completion."); this.providerLog("Gemini.stream - will process this chat completion.");
this.resetUsage();
try { try {
const msgUUID = v4(); const msgUUID = v4();
/** @type {OpenAI.OpenAI.Chat.ChatCompletion} */ /** @type {OpenAI.OpenAI.Chat.ChatCompletion} */
@ -172,6 +174,7 @@ class GeminiProvider extends Provider {
model: this.model, model: this.model,
messages: this.#formatMessages(messages), messages: this.#formatMessages(messages),
stream: true, stream: true,
stream_options: { include_usage: true },
...(Array.isArray(functions) && functions?.length > 0 ...(Array.isArray(functions) && functions?.length > 0
? { tools: this.#formatFunctions(functions), tool_choice: "auto" } ? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
: {}), : {}),
@ -186,6 +189,9 @@ class GeminiProvider extends Provider {
for await (const streamEvent of response) { for await (const streamEvent of response) {
/** @type {OpenAI.OpenAI.Chat.ChatCompletionChunk} */ /** @type {OpenAI.OpenAI.Chat.ChatCompletionChunk} */
const chunk = streamEvent; const chunk = streamEvent;
// Capture usage from final chunk (when stream_options.include_usage is true)
if (chunk?.usage) this.recordUsage(chunk.usage);
const { content, tool_calls } = chunk?.choices?.[0]?.delta || {}; const { content, tool_calls } = chunk?.choices?.[0]?.delta || {};
if (content) { if (content) {
@ -228,6 +234,7 @@ class GeminiProvider extends Provider {
extra_content: completion.functionCall.extra_content, extra_content: completion.functionCall.extra_content,
}, },
cost: this.getCost(), cost: this.getCost(),
uuid: msgUUID,
}; };
} }
@ -235,6 +242,7 @@ class GeminiProvider extends Provider {
textResponse: completion.content, textResponse: completion.content,
functionCall: null, functionCall: null,
cost: this.getCost(), cost: this.getCost(),
uuid: msgUUID,
}; };
} catch (error) { } catch (error) {
if (error instanceof OpenAI.AuthenticationError) throw error; if (error instanceof OpenAI.AuthenticationError) throw error;
@ -261,6 +269,8 @@ class GeminiProvider extends Provider {
if (!this.supportsToolCalling) if (!this.supportsToolCalling)
throw new Error(`Gemini: ${this.model} does not support tool calling.`); throw new Error(`Gemini: ${this.model} does not support tool calling.`);
this.providerLog("Gemini.complete - will process this chat completion."); this.providerLog("Gemini.complete - will process this chat completion.");
this.resetUsage();
try { try {
const response = await this.client.chat.completions.create({ const response = await this.client.chat.completions.create({
model: this.model, model: this.model,
@ -271,6 +281,8 @@ class GeminiProvider extends Provider {
: {}), : {}),
}); });
if (response.usage) this.recordUsage(response.usage);
/** @type {OpenAI.OpenAI.Chat.ChatCompletionMessage} */ /** @type {OpenAI.OpenAI.Chat.ChatCompletionMessage} */
const completion = response.choices[0].message; const completion = response.choices[0].message;
const cost = this.getCost(response.usage); const cost = this.getCost(response.usage);
@ -287,12 +299,14 @@ class GeminiProvider extends Provider {
extra_content: toolCall.extra_content ?? null, extra_content: toolCall.extra_content ?? null,
}, },
cost, cost,
usage: this.getUsage(),
}; };
} }
return { return {
textResponse: completion.content, textResponse: completion.content,
cost, cost,
usage: this.getUsage(),
}; };
} catch (error) { } catch (error) {
// If invalid Auth error we need to abort because no amount of waiting // If invalid Auth error we need to abort because no amount of waiting

View File

@ -131,7 +131,8 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -170,7 +171,8 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {

View File

@ -111,7 +111,8 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -150,7 +151,8 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {

View File

@ -131,8 +131,9 @@ function formatMessagesForTools(messages, options = {}) {
* @param {Array} messages - Raw aibitat message history * @param {Array} messages - Raw aibitat message history
* @param {Array} functions - Aibitat function definitions * @param {Array} functions - Aibitat function definitions
* @param {function|null} eventHandler - Stream event handler * @param {function|null} eventHandler - Stream event handler
* @param {{injectReasoningContent?: boolean}} options - Provider-specific options forwarded to formatMessagesForTools * @param {{injectReasoningContent?: boolean, provider?: object}} options - Provider-specific options
* @returns {Promise<{textResponse: string, functionCall: object|null}>} * - provider: If passed, automatically handles usage tracking via provider.resetUsage()/recordUsage()
* @returns {Promise<{textResponse: string, functionCall: object|null, uuid: string, usage: object|null}>}
*/ */
async function tooledStream( async function tooledStream(
client, client,
@ -142,13 +143,23 @@ async function tooledStream(
eventHandler = null, eventHandler = null,
options = {} options = {}
) { ) {
const { provider, ...formatOptions } = options;
// Auto-reset usage if provider is passed
if (provider?.resetUsage) {
try {
provider.resetUsage();
} catch {}
}
const msgUUID = v4(); const msgUUID = v4();
const formattedMessages = formatMessagesForTools(messages, options); const formattedMessages = formatMessagesForTools(messages, formatOptions);
const tools = formatFunctionsToTools(functions); const tools = formatFunctionsToTools(functions);
const stream = await client.chat.completions.create({ const stream = await client.chat.completions.create({
model, model,
stream: true, stream: true,
stream_options: { include_usage: true },
messages: formattedMessages, messages: formattedMessages,
...(tools.length > 0 ? { tools } : {}), ...(tools.length > 0 ? { tools } : {}),
}); });
@ -159,8 +170,14 @@ async function tooledStream(
}; };
const toolCallsByIndex = {}; const toolCallsByIndex = {};
let usage = null;
for await (const chunk of stream) { for await (const chunk of stream) {
// Capture usage from final chunk (some providers send usage after finish_reason)
if (chunk?.usage) {
usage = chunk.usage;
}
if (!chunk?.choices?.[0]) continue; if (!chunk?.choices?.[0]) continue;
const choice = chunk.choices[0]; const choice = chunk.choices[0];
@ -203,6 +220,13 @@ async function tooledStream(
} }
} }
// Auto-record usage if provider is passed and usage is available
if (provider?.recordUsage && usage) {
try {
provider.recordUsage(usage);
} catch {}
}
const toolCallIndices = Object.keys(toolCallsByIndex).map(Number); const toolCallIndices = Object.keys(toolCallsByIndex).map(Number);
if (toolCallIndices.length > 0) { if (toolCallIndices.length > 0) {
const firstToolCall = toolCallsByIndex[Math.min(...toolCallIndices)]; const firstToolCall = toolCallsByIndex[Math.min(...toolCallIndices)];
@ -216,6 +240,8 @@ async function tooledStream(
return { return {
textResponse: result.textResponse, textResponse: result.textResponse,
functionCall: result.functionCall, functionCall: result.functionCall,
uuid: msgUUID,
usage,
}; };
} }
@ -228,8 +254,9 @@ async function tooledStream(
* @param {Array} messages - Raw aibitat message history * @param {Array} messages - Raw aibitat message history
* @param {Array} functions - Aibitat function definitions * @param {Array} functions - Aibitat function definitions
* @param {function} getCostFn - Provider's getCost function * @param {function} getCostFn - Provider's getCost function
* @param {{injectReasoningContent?: boolean}} options - Provider-specific options forwarded to formatMessagesForTools * @param {{injectReasoningContent?: boolean, provider?: object}} options - Provider-specific options
* @returns {Promise<{textResponse: string|null, functionCall: object|null, cost: number}>} * - provider: If passed, automatically handles usage tracking via provider.resetUsage()/recordUsage()
* @returns {Promise<{textResponse: string|null, functionCall: object|null, cost: number, usage: object|null}>}
*/ */
async function tooledComplete( async function tooledComplete(
client, client,
@ -239,7 +266,16 @@ async function tooledComplete(
getCostFn = () => 0, getCostFn = () => 0,
options = {} options = {}
) { ) {
const formattedMessages = formatMessagesForTools(messages, options); const { provider, ...formatOptions } = options;
// Auto-reset usage if provider is passed
if (provider?.resetUsage) {
try {
provider.resetUsage();
} catch {}
}
const formattedMessages = formatMessagesForTools(messages, formatOptions);
const tools = formatFunctionsToTools(functions); const tools = formatFunctionsToTools(functions);
const response = await client.chat.completions.create({ const response = await client.chat.completions.create({
@ -251,6 +287,14 @@ async function tooledComplete(
const completion = response.choices[0].message; const completion = response.choices[0].message;
const cost = getCostFn(response.usage); const cost = getCostFn(response.usage);
const usage = response.usage || null;
// Auto-record usage if provider is passed and usage is available
if (provider?.recordUsage && usage) {
try {
provider.recordUsage(usage);
} catch {}
}
if (completion.tool_calls && completion.tool_calls.length > 0) { if (completion.tool_calls && completion.tool_calls.length > 0) {
const toolCall = completion.tool_calls[0]; const toolCall = completion.tool_calls[0];
@ -270,6 +314,7 @@ async function tooledComplete(
}, },
}, },
cost, cost,
usage,
}; };
} }
@ -281,12 +326,14 @@ async function tooledComplete(
arguments: functionArgs, arguments: functionArgs,
}, },
cost, cost,
usage,
}; };
} }
return { return {
textResponse: completion.content, textResponse: completion.content,
cost, cost,
usage,
}; };
} }

View File

@ -135,7 +135,8 @@ class LemonadeProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -175,7 +176,8 @@ class LemonadeProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {

View File

@ -113,7 +113,8 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -152,7 +153,8 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {

View File

@ -118,7 +118,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -158,7 +159,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {

View File

@ -115,7 +115,8 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -154,7 +155,8 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {

View File

@ -112,7 +112,8 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -151,7 +152,8 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {

View File

@ -245,6 +245,7 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
this.providerLog( this.providerLog(
"OllamaProvider.stream (tooled) - will process this chat completion." "OllamaProvider.stream (tooled) - will process this chat completion."
); );
this.resetUsage();
await OllamaAILLM.cacheContextWindows(); await OllamaAILLM.cacheContextWindows();
const msgUUID = v4(); const msgUUID = v4();
const formattedMessages = this.#formatMessagesForOllamaTools(messages); const formattedMessages = this.#formatMessagesForOllamaTools(messages);
@ -262,6 +263,14 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
let toolCalls = null; let toolCalls = null;
for await (const chunk of stream) { for await (const chunk of stream) {
// Capture usage from final chunk (Ollama sends usage when done=true)
if (chunk.done === true) {
this.recordUsage({
prompt_tokens: chunk.prompt_eval_count || 0,
completion_tokens: chunk.eval_count || 0,
});
}
if (!chunk?.message) continue; if (!chunk?.message) continue;
if (chunk.message.content) { if (chunk.message.content) {
@ -297,10 +306,12 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
name: toolCall.function.name, name: toolCall.function.name,
arguments: args, arguments: args,
}, },
cost: 0,
uuid: msgUUID,
}; };
} }
return { textResponse, functionCall: null }; return { textResponse, functionCall: null, cost: 0, uuid: msgUUID };
} }
// Fallback: UnTooled prompt-based approach via the native Ollama SDK // Fallback: UnTooled prompt-based approach via the native Ollama SDK
@ -431,6 +442,7 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
functions.length > 0 && (await this.supportsNativeToolCalling()); functions.length > 0 && (await this.supportsNativeToolCalling());
if (useNative) { if (useNative) {
this.resetUsage();
await OllamaAILLM.cacheContextWindows(); await OllamaAILLM.cacheContextWindows();
const formattedMessages = this.#formatMessagesForOllamaTools(messages); const formattedMessages = this.#formatMessagesForOllamaTools(messages);
const tools = formatFunctionsToTools(functions); const tools = formatFunctionsToTools(functions);
@ -442,6 +454,12 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
options: this.queryOptions, options: this.queryOptions,
}); });
// Record usage (Ollama uses prompt_eval_count/eval_count)
this.recordUsage({
prompt_tokens: response.prompt_eval_count || 0,
completion_tokens: response.eval_count || 0,
});
if (response.message?.tool_calls?.length > 0) { if (response.message?.tool_calls?.length > 0) {
const toolCall = response.message.tool_calls[0]; const toolCall = response.message.tool_calls[0];
const args = const args =
@ -457,12 +475,14 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
arguments: args, arguments: args,
}, },
cost: 0, cost: 0,
usage: this.getUsage(),
}; };
} }
return { return {
textResponse: response.message?.content || null, textResponse: response.message?.content || null,
cost: 0, cost: 0,
usage: this.getUsage(),
}; };
} }

View File

@ -117,6 +117,8 @@ class OpenAIProvider extends Provider {
*/ */
async stream(messages, functions = [], eventHandler = null) { async stream(messages, functions = [], eventHandler = null) {
this.providerLog("OpenAI.stream - will process this chat completion."); this.providerLog("OpenAI.stream - will process this chat completion.");
this.resetUsage();
try { try {
const msgUUID = v4(); const msgUUID = v4();
@ -178,6 +180,13 @@ class OpenAIProvider extends Provider {
}); });
continue; continue;
} }
if (chunk.type === "response.completed") {
const completedResponse = chunk.response;
if (!completedResponse?.usage) continue;
this.recordUsage(completedResponse.usage);
continue;
}
} }
if (completion.functionCall) { if (completion.functionCall) {
@ -193,6 +202,7 @@ class OpenAIProvider extends Provider {
arguments: completion.functionCall.arguments, arguments: completion.functionCall.arguments,
}, },
cost: this.getCost(), cost: this.getCost(),
uuid: msgUUID,
}; };
} }
@ -200,16 +210,15 @@ class OpenAIProvider extends Provider {
textResponse: completion.content, textResponse: completion.content,
functionCall: null, functionCall: null,
cost: this.getCost(), cost: this.getCost(),
uuid: msgUUID,
}; };
} catch (error) { } catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
// will make auth better.
if (error instanceof OpenAI.AuthenticationError) throw error; if (error instanceof OpenAI.AuthenticationError) throw error;
if ( if (
error instanceof OpenAI.RateLimitError || error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError || error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!! error instanceof OpenAI.APIError
) { ) {
throw new RetryError(error.message); throw new RetryError(error.message);
} }
@ -227,6 +236,8 @@ class OpenAIProvider extends Provider {
*/ */
async complete(messages, functions = []) { async complete(messages, functions = []) {
this.providerLog("OpenAI.complete - will process this chat completion."); this.providerLog("OpenAI.complete - will process this chat completion.");
this.resetUsage();
try { try {
const completion = { const completion = {
content: "", content: "",
@ -245,17 +256,14 @@ class OpenAIProvider extends Provider {
: {}), : {}),
}); });
if (response.usage) this.recordUsage(response.usage);
for (const outputBlock of response.output) { for (const outputBlock of response.output) {
// Grab intermediate text output if it exists
// If no tools are used, this will be returned to the aibitat handler
// Otherwise, this text will never be shown to the user
if (outputBlock.type === "message") { if (outputBlock.type === "message") {
if (outputBlock.content[0]?.type === "output_text") { if (outputBlock.content[0]?.type === "output_text") {
completion.content = outputBlock.content[0].text; completion.content = outputBlock.content[0].text;
} }
} }
// Grab function call output if it exists
if (outputBlock.type === "function_call") { if (outputBlock.type === "function_call") {
completion.functionCall = { completion.functionCall = {
name: outputBlock.name, name: outputBlock.name,
@ -273,13 +281,12 @@ class OpenAIProvider extends Provider {
return { return {
textResponse: completion.content, textResponse: completion.content,
functionCall: { functionCall: {
// For OpenAI, the id is the call_id and we need it in followup requests
// so we can match the function call output to its invocation in the message history.
id: completion.functionCall.call_id, id: completion.functionCall.call_id,
name: completion.functionCall.name, name: completion.functionCall.name,
arguments: completion.functionCall.arguments, arguments: completion.functionCall.arguments,
}, },
cost: this.getCost(), cost: this.getCost(),
usage: this.getUsage(),
}; };
} }
@ -287,16 +294,14 @@ class OpenAIProvider extends Provider {
textResponse: completion.content, textResponse: completion.content,
functionCall: null, functionCall: null,
cost: this.getCost(), cost: this.getCost(),
usage: this.getUsage(),
}; };
} catch (error) { } catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
// will make auth better.
if (error instanceof OpenAI.AuthenticationError) throw error; if (error instanceof OpenAI.AuthenticationError) throw error;
if ( if (
error instanceof OpenAI.RateLimitError || error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError || error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!! error instanceof OpenAI.APIError
) { ) {
throw new RetryError(error.message); throw new RetryError(error.message);
} }
@ -307,9 +312,7 @@ class OpenAIProvider extends Provider {
/** /**
* Get the cost of the completion. * Get the cost of the completion.
* * @returns {number} The cost of the completion (currently returns 0).
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
*/ */
getCost() { getCost() {
return 0; return 0;

View File

@ -121,7 +121,8 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
eventHandler eventHandler,
{ provider: this }
); );
} catch (error) { } catch (error) {
console.error(error.message, error); console.error(error.message, error);
@ -160,7 +161,8 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
this.model, this.model,
messages, messages,
functions, functions,
this.getCost.bind(this) this.getCost.bind(this),
{ provider: this }
); );
if (result.retryWithError) { if (result.retryWithError) {