Support Agent stream metric reporting (#5197)

Timothy Carambat 2026-03-12 12:50:02 -07:00 committed by GitHub
parent f1439d7fcb
commit 15a84d5121
21 changed files with 339 additions and 55 deletions

View File

@ -186,14 +186,17 @@ const HistoricalMessage = ({
export default memo(
HistoricalMessage,
// Skip re-rendering the historical message:
// if the content is the exact same AND (not streaming)
// the lastMessage status is the same (regen icon)
// and the chatID matches between renders. (feedback icons)
// - if the content is the exact same
// - AND (not streaming)
// - the lastMessage status is the same (regen icon)
// - the chatID matches between renders. (feedback icons)
// - the metrics are the same (metrics are updated in real time)
(prevProps, nextProps) => {
return (
prevProps.message === nextProps.message &&
prevProps.isLastMessage === nextProps.isLastMessage &&
prevProps.chatId === nextProps.chatId
prevProps.chatId === nextProps.chatId &&
JSON.stringify(prevProps.metrics) === JSON.stringify(nextProps.metrics)
);
}
);
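
The `JSON.stringify` comparison in the memo guard matters because each socket event replaces the `metrics` object wholesale, so a reference check alone would defeat the memo. A minimal sketch (sample values illustrative):

```js
// Each socket event builds a fresh metrics object, so identity fails even
// when nothing changed; comparing serialized values lets memo skip the render.
const prev = { metrics: { completion_tokens: 42, outputTps: 18.5 } };
const next = { metrics: { completion_tokens: 42, outputTps: 18.5 } };

console.log(prev.metrics === next.metrics); // false -- new object per event
console.log(
  JSON.stringify(prev.metrics) === JSON.stringify(next.metrics) // true -- values equal
);
```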

View File

@ -41,6 +41,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null,
animate: false,
pending: false,
metrics: {},
},
];
});
@ -74,6 +75,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null,
animate: false,
pending: false,
metrics: {},
},
];
}
@ -95,6 +97,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null,
animate: false,
pending: false,
metrics: {},
},
];
}
@ -111,6 +114,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null,
animate: false,
pending: false,
metrics: {},
},
];
} else {
@ -127,6 +131,13 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
]; // If the message is known, replace it with the new content
}
if (type === "usageMetrics") {
if (!data.content.metrics) return prev;
return prev.map((msg) =>
msg.uuid === uuid ? { ...msg, metrics: data.content.metrics } : msg
);
}
if (type === "textResponseChunk") {
return prev
.map((msg) =>
@ -172,6 +183,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null,
animate: false,
pending: false,
metrics: data.metrics || {},
},
];
});
@ -190,6 +202,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: data.content,
animate: false,
pending: false,
metrics: {},
},
];
});
@ -208,6 +221,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
error: null,
animate: data?.animate || false,
pending: false,
metrics: data.metrics || {},
},
];
});
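
For reference, the merge performed for a `usageMetrics` event reduces to the following self-contained sketch; the payload shape is inferred from the handler above, and the sample uuids are made up:

```js
const history = [
  { uuid: "msg-1", content: "Hello!", metrics: {} },
  { uuid: "msg-2", content: "Working on it...", metrics: {} },
];
const data = {
  uuid: "msg-2",
  type: "usageMetrics",
  content: { metrics: { prompt_tokens: 120, completion_tokens: 48 } },
};

// Only the message with a matching uuid picks up the metrics; every other
// entry is returned untouched, so in-flight streaming text is never disturbed.
const updated = history.map((msg) =>
  msg.uuid === data.uuid ? { ...msg, metrics: data.content.metrics } : msg
);
console.log(updated[1].metrics); // { prompt_tokens: 120, completion_tokens: 48 }
```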

View File

@ -626,6 +626,8 @@ ${this.getHistory({ to: route.to })
);
}
// Store the active provider so plugins can access usage metrics
this.provider = provider;
this.newMessage({ ...route, content });
return content;
}
@ -652,7 +654,7 @@ ${this.getHistory({ to: route.to })
this?.socket?.send(type, data);
};
/** @type {{ functionCall: { name: string, arguments: string }, textResponse: string }} */
/** @type {{ functionCall: { name: string, arguments: string }, textResponse: string, uuid: string }} */
const completionStream = await provider.stream(
messages,
functions,
@ -669,6 +671,11 @@ ${this.getHistory({ to: route.to })
);
const finalStream = await provider.stream(messages, [], eventHandler);
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: finalStream?.uuid || v4(),
metrics: provider.getUsage(),
});
const finalResponse =
finalStream?.textResponse ||
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run.";
@ -726,11 +733,17 @@ ${this.getHistory({ to: route.to })
this.handlerProps?.log?.(
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
);
const directOutputUUID = completionStream?.uuid || v4();
eventHandler?.("reportStreamEvent", {
type: "fullTextResponse",
uuid: v4(),
uuid: directOutputUUID,
content: result,
});
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: directOutputUUID,
metrics: provider.getUsage(),
});
return result;
}
@ -751,6 +764,11 @@ ${this.getHistory({ to: route.to })
);
}
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: completionStream?.uuid || v4(),
metrics: provider.getUsage(),
});
return completionStream?.textResponse;
}
@ -762,6 +780,8 @@ ${this.getHistory({ to: route.to })
* @param messages
* @param functions
* @param byAgent
* @param depth
* @param msgUUID - The message UUID to use for event correlation (created at depth=0)
*
* @returns {Promise<string>}
*/
@ -770,8 +790,15 @@ ${this.getHistory({ to: route.to })
messages = [],
functions = [],
byAgent = null,
depth = 0
depth = 0,
msgUUID = null
) {
// Create a stable UUID at the start of execution for event correlation
if (!msgUUID) msgUUID = v4();
const eventHandler = (type, data) => {
this?.socket?.send(type, data);
};
// get the chat completion
const completion = await provider.complete(messages, functions);
@ -785,6 +812,11 @@ ${this.getHistory({ to: route.to })
);
const finalCompletion = await provider.complete(messages, []);
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: msgUUID,
metrics: provider.getUsage(),
});
return (
finalCompletion?.textResponse ||
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run."
@ -808,7 +840,8 @@ ${this.getHistory({ to: route.to })
],
functions,
byAgent,
depth + 1
depth + 1,
msgUUID
);
}
@ -836,6 +869,11 @@ ${this.getHistory({ to: route.to })
this.handlerProps?.log?.(
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
);
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: msgUUID,
metrics: provider.getUsage(),
});
return result;
}
@ -852,10 +890,16 @@ ${this.getHistory({ to: route.to })
],
functions,
byAgent,
depth + 1
depth + 1,
msgUUID
);
}
eventHandler?.("reportStreamEvent", {
type: "usageMetrics",
uuid: msgUUID,
metrics: provider.getUsage(),
});
return completion?.textResponse;
}
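
Because `handleExecution` now threads one `msgUUID` through every recursive call, all `usageMetrics` events for a single agent reply arrive under the same uuid. A hypothetical consumer on the socket side (the event names come from this commit; everything else is illustrative):

```js
function onReportStreamEvent({ type, uuid, metrics, content }) {
  if (type === "usageMetrics") {
    // Fires at every depth of the tool-call recursion with the same uuid,
    // so later events simply overwrite earlier (partial) metrics.
    console.log(`[${uuid}] usage`, metrics);
  }
  if (type === "fullTextResponse") {
    console.log(`[${uuid}] final text`, content);
  }
}
```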

View File

@ -43,6 +43,7 @@ const chatHistory = {
},
_store: async function (aibitat, { prompt, response } = {}) {
const invocation = aibitat.handlerProps.invocation;
const metrics = aibitat.provider?.getUsage?.() ?? {};
await WorkspaceChats.new({
workspaceId: Number(invocation.workspace_id),
prompt,
@ -50,6 +51,7 @@ const chatHistory = {
text: response,
sources: [],
type: "chat",
metrics,
},
user: { id: invocation?.user_id || null },
threadId: invocation?.thread_id || null,
@ -60,6 +62,7 @@ const chatHistory = {
{ prompt, response, options = {} } = {}
) {
const invocation = aibitat.handlerProps.invocation;
const metrics = aibitat.provider?.getUsage?.() ?? {};
await WorkspaceChats.new({
workspaceId: Number(invocation.workspace_id),
prompt,
@ -71,6 +74,7 @@ const chatHistory = {
? options.storedResponse(response)
: response,
type: options?.saveAsType ?? "chat",
metrics,
},
user: { id: invocation?.user_id || null },
threadId: invocation?.thread_id || null,
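
The double optional chain is deliberate: `aibitat.provider` is only assigned once a completion has run, and a provider may not implement `getUsage()`. A minimal sketch of the fallback:

```js
const aibitat = { provider: undefined }; // e.g. nothing has streamed yet
const metrics = aibitat.provider?.getUsage?.() ?? {};
console.log(metrics); // {} -- the chat record still saves, just without usage
```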

View File

@ -33,6 +33,17 @@ const { OllamaAILLM } = require("../../../AiProviders/ollama");
const DEFAULT_WORKSPACE_PROMPT =
"You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
/**
* @typedef {Object} ProviderUsageMetrics
* @property {number} prompt_tokens - Number of tokens in the prompt/input
* @property {number} completion_tokens - Number of tokens in the completion/output
* @property {number} total_tokens - Total tokens used
* @property {number} duration - Duration in seconds
* @property {number} outputTps - Output tokens per second
* @property {string} model - Model name
* @property {Date} timestamp - Timestamp of the completion
*/
class Provider {
_client;
@ -51,6 +62,27 @@ class Provider {
*/
executingUserId = "";
/**
* Stores the usage metrics from the last completion call.
* @type {ProviderUsageMetrics}
*/
lastUsage = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
duration: 0,
outputTps: 0,
model: null,
provider: null,
timestamp: null,
};
/**
* Timestamp when the current request started (for duration calculation).
* @type {number}
*/
_requestStartTime = 0;
constructor(client) {
if (this.constructor == Provider) {
return;
@ -407,6 +439,60 @@ class Provider {
return false;
}
/**
* Resets the usage metrics to zero and starts the request timer.
* Call this before each completion to ensure accurate per-call metrics.
*/
resetUsage() {
this._requestStartTime = Date.now();
this.lastUsage = {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
outputTps: 0,
duration: 0,
model: null,
provider: null,
timestamp: null,
};
}
/**
* Updates the stored usage metrics from a provider response.
* Override in subclasses to handle provider-specific usage formats.
* @param {Object} usage - The usage object from the provider response
*/
recordUsage(usage = {}) {
let duration = 0;
if (this._requestStartTime > 0) {
duration = (Date.now() - this._requestStartTime) / 1000;
}
const promptTokens = usage.prompt_tokens || usage.input_tokens || 0;
const completionTokens =
usage.completion_tokens || usage.output_tokens || 0;
this.lastUsage = {
prompt_tokens: promptTokens,
completion_tokens: completionTokens,
total_tokens: usage.total_tokens || promptTokens + completionTokens,
outputTps:
completionTokens && duration > 0 ? completionTokens / duration : 0,
duration,
model: this.model,
provider: this.constructor.name,
timestamp: new Date(),
};
}
/**
* Get the usage metrics from the last completion.
* @returns {ProviderUsageMetrics} The usage metrics
*/
getUsage() {
return { ...this.lastUsage };
}
/**
* Stream a chat completion from the LLM with tool calling
* Note: This uses the OpenAI API format and may need to be adapted for other providers.
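
The normalization in `recordUsage()` is what lets one code path serve both OpenAI-style (`prompt_tokens`/`completion_tokens`) and Anthropic-style (`input_tokens`/`output_tokens`) payloads. A self-contained sketch mirroring that logic:

```js
function normalizeUsage(usage, startedAt) {
  const duration = (Date.now() - startedAt) / 1000;
  const promptTokens = usage.prompt_tokens || usage.input_tokens || 0;
  const completionTokens = usage.completion_tokens || usage.output_tokens || 0;
  return {
    prompt_tokens: promptTokens,
    completion_tokens: completionTokens,
    total_tokens: usage.total_tokens || promptTokens + completionTokens,
    outputTps: completionTokens && duration > 0 ? completionTokens / duration : 0,
    duration,
  };
}

const startedAt = Date.now() - 2000; // pretend the request took ~2s
console.log(normalizeUsage({ prompt_tokens: 120, completion_tokens: 48 }, startedAt));
console.log(normalizeUsage({ input_tokens: 120, output_tokens: 48 }, startedAt)); // same shape
```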

View File

@ -180,9 +180,11 @@ class AnthropicProvider extends Provider {
* @param {any[]} messages - The messages to send to the LLM.
* @param {any[]} functions - The functions to use in the LLM.
* @param {function} eventHandler - The event handler to use to report stream events.
* @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion.
* @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>} - The result of the chat completion.
*/
async stream(messages, functions = [], eventHandler = null) {
this.resetUsage();
try {
const msgUUID = v4();
const [systemPrompt, chats] = this.#prepareMessages(messages);
@ -205,7 +207,20 @@ class AnthropicProvider extends Provider {
textResponse: "",
};
// Track usage from streaming events
const usage = { input_tokens: 0, output_tokens: 0 };
for await (const chunk of response) {
// Capture input tokens from message_start event
if (chunk.type === "message_start" && chunk.message?.usage) {
usage.input_tokens = chunk.message.usage.input_tokens || 0;
}
// Capture output tokens from message_delta event
if (chunk.type === "message_delta" && chunk.usage) {
usage.output_tokens = chunk.usage.output_tokens || 0;
}
if (chunk.type === "content_block_start") {
if (chunk.content_block.type === "text") {
result.textResponse += chunk.content_block.text;
@ -254,6 +269,8 @@ class AnthropicProvider extends Provider {
}
}
// Record accumulated usage
this.recordUsage(usage);
if (result.functionCall) {
result.functionCall.arguments = safeJsonParse(
result.functionCall.arguments,
@ -278,6 +295,7 @@ class AnthropicProvider extends Provider {
arguments: result.functionCall.arguments,
},
cost: 0,
uuid: msgUUID,
};
}
@ -285,6 +303,7 @@ class AnthropicProvider extends Provider {
textResponse: result.textResponse,
functionCall: null,
cost: 0,
uuid: msgUUID,
};
} catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
@ -311,6 +330,8 @@ class AnthropicProvider extends Provider {
* @returns The completion.
*/
async complete(messages, functions = []) {
this.resetUsage();
try {
const [systemPrompt, chats] = this.#prepareMessages(messages);
const response = await this.client.messages.create(
@ -327,6 +348,9 @@ class AnthropicProvider extends Provider {
{ headers: { "anthropic-beta": "tools-2024-04-04" } } // Required to we can use tools.
);
// Record usage from response (Anthropic uses input_tokens/output_tokens)
if (response.usage) this.recordUsage(response.usage);
// We know that we need to call a tool. So we are about to recurse through completions/handleExecution
// https://docs.anthropic.com/claude/docs/tool-use#how-tool-use-works
if (response.stop_reason === "tool_use") {
@ -374,6 +398,7 @@ class AnthropicProvider extends Provider {
arguments: functionArgs,
},
cost: 0,
usage: this.getUsage(),
};
}
@ -383,6 +408,7 @@ class AnthropicProvider extends Provider {
completion?.text ??
"The model failed to complete the task and return back a valid response.",
cost: 0,
usage: this.getUsage(),
};
} catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
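
Anthropic splits token counts across stream events, which is why the loop accumulates into a local `usage` object before calling `recordUsage()`. A sketch with representative chunks (token counts illustrative):

```js
const chunks = [
  { type: "message_start", message: { usage: { input_tokens: 214 } } },
  { type: "content_block_start", content_block: { type: "text", text: "Hi" } },
  { type: "message_delta", usage: { output_tokens: 57 } },
];

const usage = { input_tokens: 0, output_tokens: 0 };
for (const chunk of chunks) {
  if (chunk.type === "message_start" && chunk.message?.usage)
    usage.input_tokens = chunk.message.usage.input_tokens || 0;
  if (chunk.type === "message_delta" && chunk.usage)
    usage.output_tokens = chunk.usage.output_tokens || 0; // cumulative, last wins
}
console.log(usage); // { input_tokens: 214, output_tokens: 57 }
```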

View File

@ -34,7 +34,7 @@ class AzureOpenAiProvider extends Provider {
* @param {any[]} messages
* @param {any[]} functions
* @param {function} eventHandler
* @returns {Promise<{ functionCall: any, textResponse: string }>}
* @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>}
*/
async stream(messages, functions = [], eventHandler = null) {
this.providerLog("Provider.stream - will process this chat completion.");
@ -45,7 +45,8 @@ class AzureOpenAiProvider extends Provider {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -75,7 +76,8 @@ class AzureOpenAiProvider extends Provider {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {

View File

@ -48,7 +48,10 @@ class DeepSeekProvider extends InheritMultiple([Provider, UnTooled]) {
}
get #tooledOptions() {
return this.#isThinkingModel ? { injectReasoningContent: true } : {};
return {
provider: this,
...(this.#isThinkingModel ? { injectReasoningContent: true } : {}),
};
}
async #handleFunctionCallChat({ messages = [] }) {

View File

@ -112,7 +112,8 @@ class DockerModelRunnerProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -151,7 +152,8 @@ class DockerModelRunnerProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {

View File

@ -165,6 +165,8 @@ class GeminiProvider extends Provider {
if (!this.supportsToolCalling)
throw new Error(`Gemini: ${this.model} does not support tool calling.`);
this.providerLog("Gemini.stream - will process this chat completion.");
this.resetUsage();
try {
const msgUUID = v4();
/** @type {OpenAI.OpenAI.Chat.ChatCompletion} */
@ -172,6 +174,7 @@ class GeminiProvider extends Provider {
model: this.model,
messages: this.#formatMessages(messages),
stream: true,
stream_options: { include_usage: true },
...(Array.isArray(functions) && functions?.length > 0
? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
: {}),
@ -186,6 +189,9 @@ class GeminiProvider extends Provider {
for await (const streamEvent of response) {
/** @type {OpenAI.OpenAI.Chat.ChatCompletionChunk} */
const chunk = streamEvent;
// Capture usage from final chunk (when stream_options.include_usage is true)
if (chunk?.usage) this.recordUsage(chunk.usage);
const { content, tool_calls } = chunk?.choices?.[0]?.delta || {};
if (content) {
@ -228,6 +234,7 @@ class GeminiProvider extends Provider {
extra_content: completion.functionCall.extra_content,
},
cost: this.getCost(),
uuid: msgUUID,
};
}
@ -235,6 +242,7 @@ class GeminiProvider extends Provider {
textResponse: completion.content,
functionCall: null,
cost: this.getCost(),
uuid: msgUUID,
};
} catch (error) {
if (error instanceof OpenAI.AuthenticationError) throw error;
@ -261,6 +269,8 @@ class GeminiProvider extends Provider {
if (!this.supportsToolCalling)
throw new Error(`Gemini: ${this.model} does not support tool calling.`);
this.providerLog("Gemini.complete - will process this chat completion.");
this.resetUsage();
try {
const response = await this.client.chat.completions.create({
model: this.model,
@ -271,6 +281,8 @@ class GeminiProvider extends Provider {
: {}),
});
if (response.usage) this.recordUsage(response.usage);
/** @type {OpenAI.OpenAI.Chat.ChatCompletionMessage} */
const completion = response.choices[0].message;
const cost = this.getCost(response.usage);
@ -287,12 +299,14 @@ class GeminiProvider extends Provider {
extra_content: toolCall.extra_content ?? null,
},
cost,
usage: this.getUsage(),
};
}
return {
textResponse: completion.content,
cost,
usage: this.getUsage(),
};
} catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
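
With `stream_options: { include_usage: true }`, OpenAI-compatible endpoints append one final chunk whose `usage` field is populated and whose `choices` may be empty, which is why the loop checks `chunk?.usage` independently of the delta. A representative final chunk (values illustrative):

```js
const finalChunk = {
  choices: [], // no delta on the usage-bearing chunk
  usage: { prompt_tokens: 88, completion_tokens: 40, total_tokens: 128 },
};
if (finalChunk?.usage) {
  // Fed straight to recordUsage(); the empty choices array yields an empty
  // delta via the `chunk?.choices?.[0]?.delta || {}` fallback in the loop.
  console.log(finalChunk.usage);
}
```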

View File

@ -131,7 +131,8 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -170,7 +171,8 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {

View File

@ -111,7 +111,8 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -150,7 +151,8 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {

View File

@ -131,8 +131,9 @@ function formatMessagesForTools(messages, options = {}) {
* @param {Array} messages - Raw aibitat message history
* @param {Array} functions - Aibitat function definitions
* @param {function|null} eventHandler - Stream event handler
* @param {{injectReasoningContent?: boolean}} options - Provider-specific options forwarded to formatMessagesForTools
* @returns {Promise<{textResponse: string, functionCall: object|null}>}
* @param {{injectReasoningContent?: boolean, provider?: object}} options - Provider-specific options
* - provider: If passed, automatically handles usage tracking via provider.resetUsage()/recordUsage()
* @returns {Promise<{textResponse: string, functionCall: object|null, uuid: string, usage: object|null}>}
*/
async function tooledStream(
client,
@ -142,13 +143,23 @@ async function tooledStream(
eventHandler = null,
options = {}
) {
const { provider, ...formatOptions } = options;
// Auto-reset usage if provider is passed
if (provider?.resetUsage) {
try {
provider.resetUsage();
} catch {}
}
const msgUUID = v4();
const formattedMessages = formatMessagesForTools(messages, options);
const formattedMessages = formatMessagesForTools(messages, formatOptions);
const tools = formatFunctionsToTools(functions);
const stream = await client.chat.completions.create({
model,
stream: true,
stream_options: { include_usage: true },
messages: formattedMessages,
...(tools.length > 0 ? { tools } : {}),
});
@ -159,8 +170,14 @@ async function tooledStream(
};
const toolCallsByIndex = {};
let usage = null;
for await (const chunk of stream) {
// Capture usage from final chunk (some providers send usage after finish_reason)
if (chunk?.usage) {
usage = chunk.usage;
}
if (!chunk?.choices?.[0]) continue;
const choice = chunk.choices[0];
@ -203,6 +220,13 @@ async function tooledStream(
}
}
// Auto-record usage if provider is passed and usage is available
if (provider?.recordUsage && usage) {
try {
provider.recordUsage(usage);
} catch {}
}
const toolCallIndices = Object.keys(toolCallsByIndex).map(Number);
if (toolCallIndices.length > 0) {
const firstToolCall = toolCallsByIndex[Math.min(...toolCallIndices)];
@ -216,6 +240,8 @@ async function tooledStream(
return {
textResponse: result.textResponse,
functionCall: result.functionCall,
uuid: msgUUID,
usage,
};
}
@ -228,8 +254,9 @@ async function tooledStream(
* @param {Array} messages - Raw aibitat message history
* @param {Array} functions - Aibitat function definitions
* @param {function} getCostFn - Provider's getCost function
* @param {{injectReasoningContent?: boolean}} options - Provider-specific options forwarded to formatMessagesForTools
* @returns {Promise<{textResponse: string|null, functionCall: object|null, cost: number}>}
* @param {{injectReasoningContent?: boolean, provider?: object}} options - Provider-specific options
* - provider: If passed, automatically handles usage tracking via provider.resetUsage()/recordUsage()
* @returns {Promise<{textResponse: string|null, functionCall: object|null, cost: number, usage: object|null}>}
*/
async function tooledComplete(
client,
@ -239,7 +266,16 @@ async function tooledComplete(
getCostFn = () => 0,
options = {}
) {
const formattedMessages = formatMessagesForTools(messages, options);
const { provider, ...formatOptions } = options;
// Auto-reset usage if provider is passed
if (provider?.resetUsage) {
try {
provider.resetUsage();
} catch {}
}
const formattedMessages = formatMessagesForTools(messages, formatOptions);
const tools = formatFunctionsToTools(functions);
const response = await client.chat.completions.create({
@ -251,6 +287,14 @@ async function tooledComplete(
const completion = response.choices[0].message;
const cost = getCostFn(response.usage);
const usage = response.usage || null;
// Auto-record usage if provider is passed and usage is available
if (provider?.recordUsage && usage) {
try {
provider.recordUsage(usage);
} catch {}
}
if (completion.tool_calls && completion.tool_calls.length > 0) {
const toolCall = completion.tool_calls[0];
@ -270,6 +314,7 @@ async function tooledComplete(
},
},
cost,
usage,
};
}
@ -281,12 +326,14 @@ async function tooledComplete(
arguments: functionArgs,
},
cost,
usage,
};
}
return {
textResponse: completion.content,
cost,
usage,
};
}
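
Providers opt in by passing themselves through the options bag; `tooledStream`/`tooledComplete` then own the reset/record bookkeeping. A sketch of a call site, where `streamWithMetrics` is a hypothetical wrapper and `tooledStream` is the helper defined above:

```js
// Sketch only: assumes `tooledStream` (above), a provider instance, and the
// usual chat arguments are in scope at a provider call site.
async function streamWithMetrics(provider, messages, functions, eventHandler) {
  const result = await tooledStream(
    provider.client, // OpenAI-compatible SDK client
    provider.model,
    messages,
    functions,
    eventHandler,
    { provider } // resetUsage() runs before the request, recordUsage() after
  );
  // result.uuid  -- stable id for correlating stream events
  // result.usage -- raw usage object from the final chunk (or null)
  // provider.getUsage() -- normalized metrics: tokens, duration, outputTps
  return result;
}
```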

View File

@ -135,7 +135,8 @@ class LemonadeProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -175,7 +176,8 @@ class LemonadeProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {

View File

@ -113,7 +113,8 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -152,7 +153,8 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {

View File

@ -118,7 +118,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -158,7 +159,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {

View File

@ -115,7 +115,8 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -154,7 +155,8 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {

View File

@ -112,7 +112,8 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -151,7 +152,8 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {

View File

@ -245,6 +245,7 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
this.providerLog(
"OllamaProvider.stream (tooled) - will process this chat completion."
);
this.resetUsage();
await OllamaAILLM.cacheContextWindows();
const msgUUID = v4();
const formattedMessages = this.#formatMessagesForOllamaTools(messages);
@ -262,6 +263,14 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
let toolCalls = null;
for await (const chunk of stream) {
// Capture usage from final chunk (Ollama sends usage when done=true)
if (chunk.done === true) {
this.recordUsage({
prompt_tokens: chunk.prompt_eval_count || 0,
completion_tokens: chunk.eval_count || 0,
});
}
if (!chunk?.message) continue;
if (chunk.message.content) {
@ -297,10 +306,12 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
name: toolCall.function.name,
arguments: args,
},
cost: 0,
uuid: msgUUID,
};
}
return { textResponse, functionCall: null };
return { textResponse, functionCall: null, cost: 0, uuid: msgUUID };
}
// Fallback: UnTooled prompt-based approach via the native Ollama SDK
@ -431,6 +442,7 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
functions.length > 0 && (await this.supportsNativeToolCalling());
if (useNative) {
this.resetUsage();
await OllamaAILLM.cacheContextWindows();
const formattedMessages = this.#formatMessagesForOllamaTools(messages);
const tools = formatFunctionsToTools(functions);
@ -442,6 +454,12 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
options: this.queryOptions,
});
// Record usage (Ollama uses prompt_eval_count/eval_count)
this.recordUsage({
prompt_tokens: response.prompt_eval_count || 0,
completion_tokens: response.eval_count || 0,
});
if (response.message?.tool_calls?.length > 0) {
const toolCall = response.message.tool_calls[0];
const args =
@ -457,12 +475,14 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
arguments: args,
},
cost: 0,
usage: this.getUsage(),
};
}
return {
textResponse: response.message?.content || null,
cost: 0,
usage: this.getUsage(),
};
}
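
Ollama reports native counters (`prompt_eval_count`/`eval_count`) on its final `done: true` chunk rather than an OpenAI-style usage block, hence the explicit field mapping. A self-contained sketch (counts illustrative):

```js
const finalChunk = { done: true, prompt_eval_count: 182, eval_count: 64 };

const usage =
  finalChunk.done === true
    ? {
        prompt_tokens: finalChunk.prompt_eval_count || 0,
        completion_tokens: finalChunk.eval_count || 0,
      }
    : null;
console.log(usage); // { prompt_tokens: 182, completion_tokens: 64 } -- fed to recordUsage()
```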

View File

@ -117,6 +117,8 @@ class OpenAIProvider extends Provider {
*/
async stream(messages, functions = [], eventHandler = null) {
this.providerLog("OpenAI.stream - will process this chat completion.");
this.resetUsage();
try {
const msgUUID = v4();
@ -178,6 +180,13 @@ class OpenAIProvider extends Provider {
});
continue;
}
if (chunk.type === "response.completed") {
const completedResponse = chunk.response;
if (!completedResponse?.usage) continue;
this.recordUsage(completedResponse.usage);
continue;
}
}
if (completion.functionCall) {
@ -193,6 +202,7 @@ class OpenAIProvider extends Provider {
arguments: completion.functionCall.arguments,
},
cost: this.getCost(),
uuid: msgUUID,
};
}
@ -200,16 +210,15 @@ class OpenAIProvider extends Provider {
textResponse: completion.content,
functionCall: null,
cost: this.getCost(),
uuid: msgUUID,
};
} catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
// will make auth better.
if (error instanceof OpenAI.AuthenticationError) throw error;
if (
error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
error instanceof OpenAI.APIError
) {
throw new RetryError(error.message);
}
@ -227,6 +236,8 @@ class OpenAIProvider extends Provider {
*/
async complete(messages, functions = []) {
this.providerLog("OpenAI.complete - will process this chat completion.");
this.resetUsage();
try {
const completion = {
content: "",
@ -245,17 +256,14 @@ class OpenAIProvider extends Provider {
: {}),
});
if (response.usage) this.recordUsage(response.usage);
for (const outputBlock of response.output) {
// Grab intermediate text output if it exists
// If no tools are used, this will be returned to the aibitat handler
// Otherwise, this text will never be shown to the user
if (outputBlock.type === "message") {
if (outputBlock.content[0]?.type === "output_text") {
completion.content = outputBlock.content[0].text;
}
}
// Grab function call output if it exists
if (outputBlock.type === "function_call") {
completion.functionCall = {
name: outputBlock.name,
@ -273,13 +281,12 @@ class OpenAIProvider extends Provider {
return {
textResponse: completion.content,
functionCall: {
// For OpenAI, the id is the call_id and we need it in followup requests
// so we can match the function call output to its invocation in the message history.
id: completion.functionCall.call_id,
name: completion.functionCall.name,
arguments: completion.functionCall.arguments,
},
cost: this.getCost(),
usage: this.getUsage(),
};
}
@ -287,16 +294,14 @@ class OpenAIProvider extends Provider {
textResponse: completion.content,
functionCall: null,
cost: this.getCost(),
usage: this.getUsage(),
};
} catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
// will make auth better.
if (error instanceof OpenAI.AuthenticationError) throw error;
if (
error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
error instanceof OpenAI.APIError
) {
throw new RetryError(error.message);
}
@ -307,9 +312,7 @@ class OpenAIProvider extends Provider {
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* @returns {number} The cost of the completion (currently returns 0).
*/
getCost() {
return 0;
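
For the Responses API stream, usage arrives once on the terminal `response.completed` event, using the `input_tokens`/`output_tokens` variant that `recordUsage()` already normalizes. A representative chunk (values illustrative):

```js
const chunk = {
  type: "response.completed",
  response: {
    usage: { input_tokens: 96, output_tokens: 31, total_tokens: 127 },
  },
};

const usage =
  chunk.type === "response.completed" ? chunk.response?.usage ?? null : null;
console.log(usage); // passed to recordUsage(), which maps input_/output_tokens
```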

View File

@ -121,7 +121,8 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
eventHandler
eventHandler,
{ provider: this }
);
} catch (error) {
console.error(error.message, error);
@ -160,7 +161,8 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
this.model,
messages,
functions,
this.getCost.bind(this)
this.getCost.bind(this),
{ provider: this }
);
if (result.retryWithError) {