Support Agent stream metric reporting (#5197)
This commit is contained in:
parent
f1439d7fcb
commit
15a84d5121
@ -186,14 +186,17 @@ const HistoricalMessage = ({
|
||||
export default memo(
|
||||
HistoricalMessage,
|
||||
// Skip re-rendering the historical message:
|
||||
// if the content is the exact same AND (not streaming)
|
||||
// the lastMessage status is the same (regen icon)
|
||||
// and the chatID matches between renders. (feedback icons)
|
||||
// - if the content is the exact same
|
||||
// - AND (not streaming)
|
||||
// - the lastMessage status is the same (regen icon)
|
||||
// - the chatID matches between renders. (feedback icons)
|
||||
// - the metrics are the same (metrics are updated in real time)
|
||||
(prevProps, nextProps) => {
|
||||
return (
|
||||
prevProps.message === nextProps.message &&
|
||||
prevProps.isLastMessage === nextProps.isLastMessage &&
|
||||
prevProps.chatId === nextProps.chatId
|
||||
prevProps.chatId === nextProps.chatId &&
|
||||
JSON.stringify(prevProps.metrics) === JSON.stringify(nextProps.metrics)
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
@ -41,6 +41,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
||||
error: null,
|
||||
animate: false,
|
||||
pending: false,
|
||||
metrics: {},
|
||||
},
|
||||
];
|
||||
});
|
||||
@ -74,6 +75,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
||||
error: null,
|
||||
animate: false,
|
||||
pending: false,
|
||||
metrics: {},
|
||||
},
|
||||
];
|
||||
}
|
||||
@ -95,6 +97,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
||||
error: null,
|
||||
animate: false,
|
||||
pending: false,
|
||||
metrics: {},
|
||||
},
|
||||
];
|
||||
}
|
||||
@ -111,6 +114,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
||||
error: null,
|
||||
animate: false,
|
||||
pending: false,
|
||||
metrics: {},
|
||||
},
|
||||
];
|
||||
} else {
|
||||
@ -127,6 +131,13 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
||||
]; // If the message is known, replace it with the new content
|
||||
}
|
||||
|
||||
if (type === "usageMetrics") {
|
||||
if (!data.content.metrics) return prev;
|
||||
return prev.map((msg) =>
|
||||
msg.uuid === uuid ? { ...msg, metrics: data.content.metrics } : msg
|
||||
);
|
||||
}
|
||||
|
||||
if (type === "textResponseChunk") {
|
||||
return prev
|
||||
.map((msg) =>
|
||||
@ -172,6 +183,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
||||
error: null,
|
||||
animate: false,
|
||||
pending: false,
|
||||
metrics: data.metrics || {},
|
||||
},
|
||||
];
|
||||
});
|
||||
@ -190,6 +202,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
||||
error: data.content,
|
||||
animate: false,
|
||||
pending: false,
|
||||
metrics: {},
|
||||
},
|
||||
];
|
||||
});
|
||||
@ -208,6 +221,7 @@ export default function handleSocketResponse(socket, event, setChatHistory) {
|
||||
error: null,
|
||||
animate: data?.animate || false,
|
||||
pending: false,
|
||||
metrics: data.metrics || {},
|
||||
},
|
||||
];
|
||||
});
|
||||
|
||||
@ -626,6 +626,8 @@ ${this.getHistory({ to: route.to })
|
||||
);
|
||||
}
|
||||
|
||||
// Store the active provider so plugins can access usage metrics
|
||||
this.provider = provider;
|
||||
this.newMessage({ ...route, content });
|
||||
return content;
|
||||
}
|
||||
@ -652,7 +654,7 @@ ${this.getHistory({ to: route.to })
|
||||
this?.socket?.send(type, data);
|
||||
};
|
||||
|
||||
/** @type {{ functionCall: { name: string, arguments: string }, textResponse: string }} */
|
||||
/** @type {{ functionCall: { name: string, arguments: string }, textResponse: string, uuid: string }} */
|
||||
const completionStream = await provider.stream(
|
||||
messages,
|
||||
functions,
|
||||
@ -669,6 +671,11 @@ ${this.getHistory({ to: route.to })
|
||||
);
|
||||
|
||||
const finalStream = await provider.stream(messages, [], eventHandler);
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "usageMetrics",
|
||||
uuid: finalStream?.uuid || v4(),
|
||||
metrics: provider.getUsage(),
|
||||
});
|
||||
const finalResponse =
|
||||
finalStream?.textResponse ||
|
||||
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run.";
|
||||
@ -726,11 +733,17 @@ ${this.getHistory({ to: route.to })
|
||||
this.handlerProps?.log?.(
|
||||
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
|
||||
);
|
||||
const directOutputUUID = completionStream?.uuid || v4();
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "fullTextResponse",
|
||||
uuid: v4(),
|
||||
uuid: directOutputUUID,
|
||||
content: result,
|
||||
});
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "usageMetrics",
|
||||
uuid: directOutputUUID,
|
||||
metrics: provider.getUsage(),
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -751,6 +764,11 @@ ${this.getHistory({ to: route.to })
|
||||
);
|
||||
}
|
||||
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "usageMetrics",
|
||||
uuid: completionStream?.uuid || v4(),
|
||||
metrics: provider.getUsage(),
|
||||
});
|
||||
return completionStream?.textResponse;
|
||||
}
|
||||
|
||||
@ -762,6 +780,8 @@ ${this.getHistory({ to: route.to })
|
||||
* @param messages
|
||||
* @param functions
|
||||
* @param byAgent
|
||||
* @param depth
|
||||
* @param msgUUID - The message UUID to use for event correlation (created at depth=0)
|
||||
*
|
||||
* @returns {Promise<string>}
|
||||
*/
|
||||
@ -770,8 +790,15 @@ ${this.getHistory({ to: route.to })
|
||||
messages = [],
|
||||
functions = [],
|
||||
byAgent = null,
|
||||
depth = 0
|
||||
depth = 0,
|
||||
msgUUID = null
|
||||
) {
|
||||
// Create a stable UUID at the start of execution for event correlation
|
||||
if (!msgUUID) msgUUID = v4();
|
||||
const eventHandler = (type, data) => {
|
||||
this?.socket?.send(type, data);
|
||||
};
|
||||
|
||||
// get the chat completion
|
||||
const completion = await provider.complete(messages, functions);
|
||||
|
||||
@ -785,6 +812,11 @@ ${this.getHistory({ to: route.to })
|
||||
);
|
||||
|
||||
const finalCompletion = await provider.complete(messages, []);
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "usageMetrics",
|
||||
uuid: msgUUID,
|
||||
metrics: provider.getUsage(),
|
||||
});
|
||||
return (
|
||||
finalCompletion?.textResponse ||
|
||||
"I reached the maximum number of tool calls allowed for a single response. Here is what I have so far based on the tools I was able to run."
|
||||
@ -808,7 +840,8 @@ ${this.getHistory({ to: route.to })
|
||||
],
|
||||
functions,
|
||||
byAgent,
|
||||
depth + 1
|
||||
depth + 1,
|
||||
msgUUID
|
||||
);
|
||||
}
|
||||
|
||||
@ -836,6 +869,11 @@ ${this.getHistory({ to: route.to })
|
||||
this.handlerProps?.log?.(
|
||||
`${fn.caller} tool call resulted in direct output! Returning raw result as string. NO MORE TOOL CALLS WILL BE EXECUTED.`
|
||||
);
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "usageMetrics",
|
||||
uuid: msgUUID,
|
||||
metrics: provider.getUsage(),
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -852,10 +890,16 @@ ${this.getHistory({ to: route.to })
|
||||
],
|
||||
functions,
|
||||
byAgent,
|
||||
depth + 1
|
||||
depth + 1,
|
||||
msgUUID
|
||||
);
|
||||
}
|
||||
|
||||
eventHandler?.("reportStreamEvent", {
|
||||
type: "usageMetrics",
|
||||
uuid: msgUUID,
|
||||
metrics: provider.getUsage(),
|
||||
});
|
||||
return completion?.textResponse;
|
||||
}
|
||||
|
||||
|
||||
@ -43,6 +43,7 @@ const chatHistory = {
|
||||
},
|
||||
_store: async function (aibitat, { prompt, response } = {}) {
|
||||
const invocation = aibitat.handlerProps.invocation;
|
||||
const metrics = aibitat.provider?.getUsage?.() ?? {};
|
||||
await WorkspaceChats.new({
|
||||
workspaceId: Number(invocation.workspace_id),
|
||||
prompt,
|
||||
@ -50,6 +51,7 @@ const chatHistory = {
|
||||
text: response,
|
||||
sources: [],
|
||||
type: "chat",
|
||||
metrics,
|
||||
},
|
||||
user: { id: invocation?.user_id || null },
|
||||
threadId: invocation?.thread_id || null,
|
||||
@ -60,6 +62,7 @@ const chatHistory = {
|
||||
{ prompt, response, options = {} } = {}
|
||||
) {
|
||||
const invocation = aibitat.handlerProps.invocation;
|
||||
const metrics = aibitat.provider?.getUsage?.() ?? {};
|
||||
await WorkspaceChats.new({
|
||||
workspaceId: Number(invocation.workspace_id),
|
||||
prompt,
|
||||
@ -71,6 +74,7 @@ const chatHistory = {
|
||||
? options.storedResponse(response)
|
||||
: response,
|
||||
type: options?.saveAsType ?? "chat",
|
||||
metrics,
|
||||
},
|
||||
user: { id: invocation?.user_id || null },
|
||||
threadId: invocation?.thread_id || null,
|
||||
|
||||
@ -33,6 +33,17 @@ const { OllamaAILLM } = require("../../../AiProviders/ollama");
|
||||
const DEFAULT_WORKSPACE_PROMPT =
|
||||
"You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";
|
||||
|
||||
/**
|
||||
* @typedef {Object} ProviderUsageMetrics
|
||||
* @property {number} prompt_tokens - Number of tokens in the prompt/input
|
||||
* @property {number} completion_tokens - Number of tokens in the completion/output
|
||||
* @property {number} total_tokens - Total tokens used
|
||||
* @property {number} duration - Duration in seconds
|
||||
* @property {number} outputTps - Output tokens per second
|
||||
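* @property {string} model - Model name
* @property {string|null} provider - Constructor name of the provider class that recorded the usage (null until a completion is recorded)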
|
||||
* @property {Date} timestamp - Timestamp of the completion
|
||||
*/
|
||||
|
||||
class Provider {
|
||||
_client;
|
||||
|
||||
@ -51,6 +62,27 @@ class Provider {
|
||||
*/
|
||||
executingUserId = "";
|
||||
|
||||
/**
|
||||
* Stores the usage metrics from the last completion call.
|
||||
* @type {ProviderUsageMetrics}
|
||||
*/
|
||||
lastUsage = {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
duration: 0,
|
||||
outputTps: 0,
|
||||
model: null,
|
||||
provider: null,
|
||||
timestamp: null,
|
||||
};
|
||||
|
||||
/**
|
||||
* Timestamp when the current request started (for duration calculation).
|
||||
* @type {number}
|
||||
*/
|
||||
_requestStartTime = 0;
|
||||
|
||||
constructor(client) {
|
||||
if (this.constructor == Provider) {
|
||||
return;
|
||||
@ -407,6 +439,60 @@ class Provider {
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the usage metrics to zero and starts the request timer.
|
||||
* Call this before each completion to ensure accurate per-call metrics.
|
||||
*/
|
||||
resetUsage() {
|
||||
this._requestStartTime = Date.now();
|
||||
this.lastUsage = {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
outputTps: 0,
|
||||
duration: 0,
|
||||
model: null,
|
||||
provider: null,
|
||||
timestamp: null,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the stored usage metrics from a provider response.
|
||||
* Override in subclasses to handle provider-specific usage formats.
|
||||
* @param {Object} usage - The usage object from the provider response
|
||||
*/
|
||||
recordUsage(usage = {}) {
|
||||
let duration = 0;
|
||||
if (this._requestStartTime > 0) {
|
||||
duration = (Date.now() - this._requestStartTime) / 1000;
|
||||
}
|
||||
|
||||
const promptTokens = usage.prompt_tokens || usage.input_tokens || 0;
|
||||
const completionTokens =
|
||||
usage.completion_tokens || usage.output_tokens || 0;
|
||||
|
||||
this.lastUsage = {
|
||||
prompt_tokens: promptTokens,
|
||||
completion_tokens: completionTokens,
|
||||
total_tokens: usage.total_tokens || promptTokens + completionTokens,
|
||||
outputTps:
|
||||
completionTokens && duration > 0 ? completionTokens / duration : 0,
|
||||
duration,
|
||||
model: this.model,
|
||||
provider: this.constructor.name,
|
||||
timestamp: new Date(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the usage metrics from the last completion.
|
||||
* @returns {ProviderUsageMetrics} The usage metrics
|
||||
*/
|
||||
getUsage() {
|
||||
return { ...this.lastUsage };
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream a chat completion from the LLM with tool calling
|
||||
* Note: This uses the OpenAI API format and may need to be adapted for other providers.
|
||||
|
||||
@ -180,9 +180,11 @@ class AnthropicProvider extends Provider {
|
||||
* @param {any[]} messages - The messages to send to the LLM.
|
||||
* @param {any[]} functions - The functions to use in the LLM.
|
||||
* @param {function} eventHandler - The event handler to use to report stream events.
|
||||
* @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion.
|
||||
* @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>} - The result of the chat completion.
|
||||
*/
|
||||
async stream(messages, functions = [], eventHandler = null) {
|
||||
this.resetUsage();
|
||||
|
||||
try {
|
||||
const msgUUID = v4();
|
||||
const [systemPrompt, chats] = this.#prepareMessages(messages);
|
||||
@ -205,7 +207,20 @@ class AnthropicProvider extends Provider {
|
||||
textResponse: "",
|
||||
};
|
||||
|
||||
// Track usage from streaming events
|
||||
const usage = { input_tokens: 0, output_tokens: 0 };
|
||||
|
||||
for await (const chunk of response) {
|
||||
// Capture input tokens from message_start event
|
||||
if (chunk.type === "message_start" && chunk.message?.usage) {
|
||||
usage.input_tokens = chunk.message.usage.input_tokens || 0;
|
||||
}
|
||||
|
||||
// Capture output tokens from message_delta event
|
||||
if (chunk.type === "message_delta" && chunk.usage) {
|
||||
usage.output_tokens = chunk.usage.output_tokens || 0;
|
||||
}
|
||||
|
||||
if (chunk.type === "content_block_start") {
|
||||
if (chunk.content_block.type === "text") {
|
||||
result.textResponse += chunk.content_block.text;
|
||||
@ -254,6 +269,8 @@ class AnthropicProvider extends Provider {
|
||||
}
|
||||
}
|
||||
|
||||
// Record accumulated usage
|
||||
this.recordUsage(usage);
|
||||
if (result.functionCall) {
|
||||
result.functionCall.arguments = safeJsonParse(
|
||||
result.functionCall.arguments,
|
||||
@ -278,6 +295,7 @@ class AnthropicProvider extends Provider {
|
||||
arguments: result.functionCall.arguments,
|
||||
},
|
||||
cost: 0,
|
||||
uuid: msgUUID,
|
||||
};
|
||||
}
|
||||
|
||||
@ -285,6 +303,7 @@ class AnthropicProvider extends Provider {
|
||||
textResponse: result.textResponse,
|
||||
functionCall: null,
|
||||
cost: 0,
|
||||
uuid: msgUUID,
|
||||
};
|
||||
} catch (error) {
|
||||
// If invalid Auth error we need to abort because no amount of waiting
|
||||
@ -311,6 +330,8 @@ class AnthropicProvider extends Provider {
|
||||
* @returns The completion.
|
||||
*/
|
||||
async complete(messages, functions = []) {
|
||||
this.resetUsage();
|
||||
|
||||
try {
|
||||
const [systemPrompt, chats] = this.#prepareMessages(messages);
|
||||
const response = await this.client.messages.create(
|
||||
@ -327,6 +348,9 @@ class AnthropicProvider extends Provider {
|
||||
{ headers: { "anthropic-beta": "tools-2024-04-04" } } // Required to we can use tools.
|
||||
);
|
||||
|
||||
// Record usage from response (Anthropic uses input_tokens/output_tokens)
|
||||
if (response.usage) this.recordUsage(response.usage);
|
||||
|
||||
// We know that we need to call a tool. So we are about to recurse through completions/handleExecution
|
||||
// https://docs.anthropic.com/claude/docs/tool-use#how-tool-use-works
|
||||
if (response.stop_reason === "tool_use") {
|
||||
@ -374,6 +398,7 @@ class AnthropicProvider extends Provider {
|
||||
arguments: functionArgs,
|
||||
},
|
||||
cost: 0,
|
||||
usage: this.getUsage(),
|
||||
};
|
||||
}
|
||||
|
||||
@ -383,6 +408,7 @@ class AnthropicProvider extends Provider {
|
||||
completion?.text ??
|
||||
"The model failed to complete the task and return back a valid response.",
|
||||
cost: 0,
|
||||
usage: this.getUsage(),
|
||||
};
|
||||
} catch (error) {
|
||||
// If invalid Auth error we need to abort because no amount of waiting
|
||||
|
||||
@ -34,7 +34,7 @@ class AzureOpenAiProvider extends Provider {
|
||||
* @param {any[]} messages
|
||||
* @param {any[]} functions
|
||||
* @param {function} eventHandler
|
||||
* @returns {Promise<{ functionCall: any, textResponse: string }>}
|
||||
* @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>}
|
||||
*/
|
||||
async stream(messages, functions = [], eventHandler = null) {
|
||||
this.providerLog("Provider.stream - will process this chat completion.");
|
||||
@ -45,7 +45,8 @@ class AzureOpenAiProvider extends Provider {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -75,7 +76,8 @@ class AzureOpenAiProvider extends Provider {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
@ -48,7 +48,10 @@ class DeepSeekProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
}
|
||||
|
||||
get #tooledOptions() {
|
||||
return this.#isThinkingModel ? { injectReasoningContent: true } : {};
|
||||
return {
|
||||
provider: this,
|
||||
...(this.#isThinkingModel ? { injectReasoningContent: true } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
async #handleFunctionCallChat({ messages = [] }) {
|
||||
|
||||
@ -112,7 +112,8 @@ class DockerModelRunnerProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -151,7 +152,8 @@ class DockerModelRunnerProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
@ -165,6 +165,8 @@ class GeminiProvider extends Provider {
|
||||
if (!this.supportsToolCalling)
|
||||
throw new Error(`Gemini: ${this.model} does not support tool calling.`);
|
||||
this.providerLog("Gemini.stream - will process this chat completion.");
|
||||
this.resetUsage();
|
||||
|
||||
try {
|
||||
const msgUUID = v4();
|
||||
/** @type {OpenAI.OpenAI.Chat.ChatCompletion} */
|
||||
@ -172,6 +174,7 @@ class GeminiProvider extends Provider {
|
||||
model: this.model,
|
||||
messages: this.#formatMessages(messages),
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
...(Array.isArray(functions) && functions?.length > 0
|
||||
? { tools: this.#formatFunctions(functions), tool_choice: "auto" }
|
||||
: {}),
|
||||
@ -186,6 +189,9 @@ class GeminiProvider extends Provider {
|
||||
for await (const streamEvent of response) {
|
||||
/** @type {OpenAI.OpenAI.Chat.ChatCompletionChunk} */
|
||||
const chunk = streamEvent;
|
||||
|
||||
// Capture usage from final chunk (when stream_options.include_usage is true)
|
||||
if (chunk?.usage) this.recordUsage(chunk.usage);
|
||||
const { content, tool_calls } = chunk?.choices?.[0]?.delta || {};
|
||||
|
||||
if (content) {
|
||||
@ -228,6 +234,7 @@ class GeminiProvider extends Provider {
|
||||
extra_content: completion.functionCall.extra_content,
|
||||
},
|
||||
cost: this.getCost(),
|
||||
uuid: msgUUID,
|
||||
};
|
||||
}
|
||||
|
||||
@ -235,6 +242,7 @@ class GeminiProvider extends Provider {
|
||||
textResponse: completion.content,
|
||||
functionCall: null,
|
||||
cost: this.getCost(),
|
||||
uuid: msgUUID,
|
||||
};
|
||||
} catch (error) {
|
||||
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||
@ -261,6 +269,8 @@ class GeminiProvider extends Provider {
|
||||
if (!this.supportsToolCalling)
|
||||
throw new Error(`Gemini: ${this.model} does not support tool calling.`);
|
||||
this.providerLog("Gemini.complete - will process this chat completion.");
|
||||
this.resetUsage();
|
||||
|
||||
try {
|
||||
const response = await this.client.chat.completions.create({
|
||||
model: this.model,
|
||||
@ -271,6 +281,8 @@ class GeminiProvider extends Provider {
|
||||
: {}),
|
||||
});
|
||||
|
||||
if (response.usage) this.recordUsage(response.usage);
|
||||
|
||||
/** @type {OpenAI.OpenAI.Chat.ChatCompletionMessage} */
|
||||
const completion = response.choices[0].message;
|
||||
const cost = this.getCost(response.usage);
|
||||
@ -287,12 +299,14 @@ class GeminiProvider extends Provider {
|
||||
extra_content: toolCall.extra_content ?? null,
|
||||
},
|
||||
cost,
|
||||
usage: this.getUsage(),
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
textResponse: completion.content,
|
||||
cost,
|
||||
usage: this.getUsage(),
|
||||
};
|
||||
} catch (error) {
|
||||
// If invalid Auth error we need to abort because no amount of waiting
|
||||
|
||||
@ -131,7 +131,8 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -170,7 +171,8 @@ class GenericOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
@ -111,7 +111,8 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -150,7 +151,8 @@ class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
@ -131,8 +131,9 @@ function formatMessagesForTools(messages, options = {}) {
|
||||
* @param {Array} messages - Raw aibitat message history
|
||||
* @param {Array} functions - Aibitat function definitions
|
||||
* @param {function|null} eventHandler - Stream event handler
|
||||
* @param {{injectReasoningContent?: boolean}} options - Provider-specific options forwarded to formatMessagesForTools
|
||||
* @returns {Promise<{textResponse: string, functionCall: object|null}>}
|
||||
* @param {{injectReasoningContent?: boolean, provider?: object}} options - Provider-specific options
|
||||
* - provider: If passed, automatically handles usage tracking via provider.resetUsage()/recordUsage()
|
||||
* @returns {Promise<{textResponse: string, functionCall: object|null, uuid: string, usage: object|null}>}
|
||||
*/
|
||||
async function tooledStream(
|
||||
client,
|
||||
@ -142,13 +143,23 @@ async function tooledStream(
|
||||
eventHandler = null,
|
||||
options = {}
|
||||
) {
|
||||
const { provider, ...formatOptions } = options;
|
||||
|
||||
// Auto-reset usage if provider is passed
|
||||
if (provider?.resetUsage) {
|
||||
try {
|
||||
provider.resetUsage();
|
||||
} catch {}
|
||||
}
|
||||
|
||||
const msgUUID = v4();
|
||||
const formattedMessages = formatMessagesForTools(messages, options);
|
||||
const formattedMessages = formatMessagesForTools(messages, formatOptions);
|
||||
const tools = formatFunctionsToTools(functions);
|
||||
|
||||
const stream = await client.chat.completions.create({
|
||||
model,
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
messages: formattedMessages,
|
||||
...(tools.length > 0 ? { tools } : {}),
|
||||
});
|
||||
@ -159,8 +170,14 @@ async function tooledStream(
|
||||
};
|
||||
|
||||
const toolCallsByIndex = {};
|
||||
let usage = null;
|
||||
|
||||
for await (const chunk of stream) {
|
||||
// Capture usage from final chunk (some providers send usage after finish_reason)
|
||||
if (chunk?.usage) {
|
||||
usage = chunk.usage;
|
||||
}
|
||||
|
||||
if (!chunk?.choices?.[0]) continue;
|
||||
const choice = chunk.choices[0];
|
||||
|
||||
@ -203,6 +220,13 @@ async function tooledStream(
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-record usage if provider is passed and usage is available
|
||||
if (provider?.recordUsage && usage) {
|
||||
try {
|
||||
provider.recordUsage(usage);
|
||||
} catch {}
|
||||
}
|
||||
|
||||
const toolCallIndices = Object.keys(toolCallsByIndex).map(Number);
|
||||
if (toolCallIndices.length > 0) {
|
||||
const firstToolCall = toolCallsByIndex[Math.min(...toolCallIndices)];
|
||||
@ -216,6 +240,8 @@ async function tooledStream(
|
||||
return {
|
||||
textResponse: result.textResponse,
|
||||
functionCall: result.functionCall,
|
||||
uuid: msgUUID,
|
||||
usage,
|
||||
};
|
||||
}
|
||||
|
||||
@ -228,8 +254,9 @@ async function tooledStream(
|
||||
* @param {Array} messages - Raw aibitat message history
|
||||
* @param {Array} functions - Aibitat function definitions
|
||||
* @param {function} getCostFn - Provider's getCost function
|
||||
* @param {{injectReasoningContent?: boolean}} options - Provider-specific options forwarded to formatMessagesForTools
|
||||
* @returns {Promise<{textResponse: string|null, functionCall: object|null, cost: number}>}
|
||||
* @param {{injectReasoningContent?: boolean, provider?: object}} options - Provider-specific options
|
||||
* - provider: If passed, automatically handles usage tracking via provider.resetUsage()/recordUsage()
|
||||
* @returns {Promise<{textResponse: string|null, functionCall: object|null, cost: number, usage: object|null}>}
|
||||
*/
|
||||
async function tooledComplete(
|
||||
client,
|
||||
@ -239,7 +266,16 @@ async function tooledComplete(
|
||||
getCostFn = () => 0,
|
||||
options = {}
|
||||
) {
|
||||
const formattedMessages = formatMessagesForTools(messages, options);
|
||||
const { provider, ...formatOptions } = options;
|
||||
|
||||
// Auto-reset usage if provider is passed
|
||||
if (provider?.resetUsage) {
|
||||
try {
|
||||
provider.resetUsage();
|
||||
} catch {}
|
||||
}
|
||||
|
||||
const formattedMessages = formatMessagesForTools(messages, formatOptions);
|
||||
const tools = formatFunctionsToTools(functions);
|
||||
|
||||
const response = await client.chat.completions.create({
|
||||
@ -251,6 +287,14 @@ async function tooledComplete(
|
||||
|
||||
const completion = response.choices[0].message;
|
||||
const cost = getCostFn(response.usage);
|
||||
const usage = response.usage || null;
|
||||
|
||||
// Auto-record usage if provider is passed and usage is available
|
||||
if (provider?.recordUsage && usage) {
|
||||
try {
|
||||
provider.recordUsage(usage);
|
||||
} catch {}
|
||||
}
|
||||
|
||||
if (completion.tool_calls && completion.tool_calls.length > 0) {
|
||||
const toolCall = completion.tool_calls[0];
|
||||
@ -270,6 +314,7 @@ async function tooledComplete(
|
||||
},
|
||||
},
|
||||
cost,
|
||||
usage,
|
||||
};
|
||||
}
|
||||
|
||||
@ -281,12 +326,14 @@ async function tooledComplete(
|
||||
arguments: functionArgs,
|
||||
},
|
||||
cost,
|
||||
usage,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
textResponse: completion.content,
|
||||
cost,
|
||||
usage,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -135,7 +135,8 @@ class LemonadeProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -175,7 +176,8 @@ class LemonadeProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
@ -113,7 +113,8 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -152,7 +153,8 @@ class LiteLLMProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
@ -118,7 +118,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -158,7 +159,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
@ -115,7 +115,8 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -154,7 +155,8 @@ class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
@ -112,7 +112,8 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -151,7 +152,8 @@ class NovitaProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
@ -245,6 +245,7 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.providerLog(
|
||||
"OllamaProvider.stream (tooled) - will process this chat completion."
|
||||
);
|
||||
this.resetUsage();
|
||||
await OllamaAILLM.cacheContextWindows();
|
||||
const msgUUID = v4();
|
||||
const formattedMessages = this.#formatMessagesForOllamaTools(messages);
|
||||
@ -262,6 +263,14 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
let toolCalls = null;
|
||||
|
||||
for await (const chunk of stream) {
|
||||
// Capture usage from final chunk (Ollama sends usage when done=true)
|
||||
if (chunk.done === true) {
|
||||
this.recordUsage({
|
||||
prompt_tokens: chunk.prompt_eval_count || 0,
|
||||
completion_tokens: chunk.eval_count || 0,
|
||||
});
|
||||
}
|
||||
|
||||
if (!chunk?.message) continue;
|
||||
|
||||
if (chunk.message.content) {
|
||||
@ -297,10 +306,12 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
name: toolCall.function.name,
|
||||
arguments: args,
|
||||
},
|
||||
cost: 0,
|
||||
uuid: msgUUID,
|
||||
};
|
||||
}
|
||||
|
||||
return { textResponse, functionCall: null };
|
||||
return { textResponse, functionCall: null, cost: 0, uuid: msgUUID };
|
||||
}
|
||||
|
||||
// Fallback: UnTooled prompt-based approach via the native Ollama SDK
|
||||
@ -431,6 +442,7 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
functions.length > 0 && (await this.supportsNativeToolCalling());
|
||||
|
||||
if (useNative) {
|
||||
this.resetUsage();
|
||||
await OllamaAILLM.cacheContextWindows();
|
||||
const formattedMessages = this.#formatMessagesForOllamaTools(messages);
|
||||
const tools = formatFunctionsToTools(functions);
|
||||
@ -442,6 +454,12 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
options: this.queryOptions,
|
||||
});
|
||||
|
||||
// Record usage (Ollama uses prompt_eval_count/eval_count)
|
||||
this.recordUsage({
|
||||
prompt_tokens: response.prompt_eval_count || 0,
|
||||
completion_tokens: response.eval_count || 0,
|
||||
});
|
||||
|
||||
if (response.message?.tool_calls?.length > 0) {
|
||||
const toolCall = response.message.tool_calls[0];
|
||||
const args =
|
||||
@ -457,12 +475,14 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
arguments: args,
|
||||
},
|
||||
cost: 0,
|
||||
usage: this.getUsage(),
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
textResponse: response.message?.content || null,
|
||||
cost: 0,
|
||||
usage: this.getUsage(),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -117,6 +117,8 @@ class OpenAIProvider extends Provider {
|
||||
*/
|
||||
async stream(messages, functions = [], eventHandler = null) {
|
||||
this.providerLog("OpenAI.stream - will process this chat completion.");
|
||||
this.resetUsage();
|
||||
|
||||
try {
|
||||
const msgUUID = v4();
|
||||
|
||||
@ -178,6 +180,13 @@ class OpenAIProvider extends Provider {
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (chunk.type === "response.completed") {
|
||||
const completedResponse = chunk.response;
|
||||
if (!completedResponse?.usage) continue;
|
||||
this.recordUsage(completedResponse.usage);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (completion.functionCall) {
|
||||
@ -193,6 +202,7 @@ class OpenAIProvider extends Provider {
|
||||
arguments: completion.functionCall.arguments,
|
||||
},
|
||||
cost: this.getCost(),
|
||||
uuid: msgUUID,
|
||||
};
|
||||
}
|
||||
|
||||
@ -200,16 +210,15 @@ class OpenAIProvider extends Provider {
|
||||
textResponse: completion.content,
|
||||
functionCall: null,
|
||||
cost: this.getCost(),
|
||||
uuid: msgUUID,
|
||||
};
|
||||
} catch (error) {
|
||||
// If invalid Auth error we need to abort because no amount of waiting
|
||||
// will make auth better.
|
||||
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||
|
||||
if (
|
||||
error instanceof OpenAI.RateLimitError ||
|
||||
error instanceof OpenAI.InternalServerError ||
|
||||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
|
||||
error instanceof OpenAI.APIError
|
||||
) {
|
||||
throw new RetryError(error.message);
|
||||
}
|
||||
@ -227,6 +236,8 @@ class OpenAIProvider extends Provider {
|
||||
*/
|
||||
async complete(messages, functions = []) {
|
||||
this.providerLog("OpenAI.complete - will process this chat completion.");
|
||||
this.resetUsage();
|
||||
|
||||
try {
|
||||
const completion = {
|
||||
content: "",
|
||||
@ -245,17 +256,14 @@ class OpenAIProvider extends Provider {
|
||||
: {}),
|
||||
});
|
||||
|
||||
if (response.usage) this.recordUsage(response.usage);
|
||||
for (const outputBlock of response.output) {
|
||||
// Grab intermediate text output if it exists
|
||||
// If no tools are used, this will be returned to the aibitat handler
|
||||
// Otherwise, this text will never be shown to the user
|
||||
if (outputBlock.type === "message") {
|
||||
if (outputBlock.content[0]?.type === "output_text") {
|
||||
completion.content = outputBlock.content[0].text;
|
||||
}
|
||||
}
|
||||
|
||||
// Grab function call output if it exists
|
||||
if (outputBlock.type === "function_call") {
|
||||
completion.functionCall = {
|
||||
name: outputBlock.name,
|
||||
@ -273,13 +281,12 @@ class OpenAIProvider extends Provider {
|
||||
return {
|
||||
textResponse: completion.content,
|
||||
functionCall: {
|
||||
// For OpenAI, the id is the call_id and we need it in followup requests
|
||||
// so we can match the function call output to its invocation in the message history.
|
||||
id: completion.functionCall.call_id,
|
||||
name: completion.functionCall.name,
|
||||
arguments: completion.functionCall.arguments,
|
||||
},
|
||||
cost: this.getCost(),
|
||||
usage: this.getUsage(),
|
||||
};
|
||||
}
|
||||
|
||||
@ -287,16 +294,14 @@ class OpenAIProvider extends Provider {
|
||||
textResponse: completion.content,
|
||||
functionCall: null,
|
||||
cost: this.getCost(),
|
||||
usage: this.getUsage(),
|
||||
};
|
||||
} catch (error) {
|
||||
// If invalid Auth error we need to abort because no amount of waiting
|
||||
// will make auth better.
|
||||
if (error instanceof OpenAI.AuthenticationError) throw error;
|
||||
|
||||
if (
|
||||
error instanceof OpenAI.RateLimitError ||
|
||||
error instanceof OpenAI.InternalServerError ||
|
||||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
|
||||
error instanceof OpenAI.APIError
|
||||
) {
|
||||
throw new RetryError(error.message);
|
||||
}
|
||||
@ -307,9 +312,7 @@ class OpenAIProvider extends Provider {
|
||||
|
||||
/**
|
||||
* Get the cost of the completion.
|
||||
*
|
||||
* @param _usage The completion to get the cost for.
|
||||
* @returns The cost of the completion.
|
||||
* @returns {number} The cost of the completion (currently returns 0).
|
||||
*/
|
||||
getCost() {
|
||||
return 0;
|
||||
|
||||
@ -121,7 +121,8 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
eventHandler
|
||||
eventHandler,
|
||||
{ provider: this }
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(error.message, error);
|
||||
@ -160,7 +161,8 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||
this.model,
|
||||
messages,
|
||||
functions,
|
||||
this.getCost.bind(this)
|
||||
this.getCost.bind(this),
|
||||
{ provider: this }
|
||||
);
|
||||
|
||||
if (result.retryWithError) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user