diff --git a/server/utils/agents/aibitat/utils/toolReranker.js b/server/utils/agents/aibitat/utils/toolReranker.js index 1b793434..18bca37c 100644 --- a/server/utils/agents/aibitat/utils/toolReranker.js +++ b/server/utils/agents/aibitat/utils/toolReranker.js @@ -1,7 +1,10 @@ +const { TokenManager } = require("../../../helpers/tiktoken"); const { NativeEmbeddingReranker, } = require("../../../EmbeddingRerankers/native"); -const { TokenManager } = require("../../../helpers/tiktoken"); + +const CHUNK_SIZE = 25; +const MAX_TEXT_LENGTH = 1000; class ToolReranker { /** @@ -11,12 +14,12 @@ class ToolReranker { static defaultTopN = 15; static instance = null; - static #reranker = null; constructor() { if (ToolReranker.instance) return ToolReranker.instance; ToolReranker.instance = this; this.tokenManager = new TokenManager(); + this.reranker = null; } log(text, ...args) { @@ -43,20 +46,69 @@ class ToolReranker { } /** - * Get or create the reranker singleton - * @returns {NativeEmbeddingReranker} + * Truncate text to max length, trying to break at word boundary + */ + #truncateText(text, maxLength = MAX_TEXT_LENGTH) { + if (!text || text.length <= maxLength) return text; + const truncated = text.slice(0, maxLength); + const lastSpace = truncated.lastIndexOf(" "); + return lastSpace > maxLength * 0.8 + ? truncated.slice(0, lastSpace) + : truncated; + } + + /** + * Get or initialize the reranker instance */ async #getReranker() { - if (!ToolReranker.#reranker) { - ToolReranker.#reranker = new NativeEmbeddingReranker(); - await ToolReranker.#reranker.initClient(); + if (!this.reranker) { + this.reranker = new NativeEmbeddingReranker(); + await this.reranker.initClient(); } - return ToolReranker.#reranker; + return this.reranker; + } + + /** + * Process documents in chunks and merge results to get global top K + */ + async #chunkedRerank(query, documents, topK) { + const reranker = await this.#getReranker(); + const totalDocs = documents.length; + + if (totalDocs <= CHUNK_SIZE) { + return await reranker.rerank(query, documents, { topK }); + } + + this.log(`Processing ${totalDocs} documents in chunks of ${CHUNK_SIZE}...`); + const allScored = []; + + for (let i = 0; i < totalDocs; i += CHUNK_SIZE) { + const chunk = documents.slice(i, i + CHUNK_SIZE); + const chunkNum = Math.floor(i / CHUNK_SIZE) + 1; + const totalChunks = Math.ceil(totalDocs / CHUNK_SIZE); + + this.log( + `Processing chunk ${chunkNum}/${totalChunks} (${chunk.length} docs)...` + ); + + const chunkResults = await reranker.rerank(query, chunk, { + topK: chunk.length, + }); + + chunkResults.forEach((result) => { + allScored.push({ + ...result, + rerank_corpus_id: result.rerank_corpus_id + i, + }); + }); + } + + allScored.sort((a, b) => b.rerank_score - a.rerank_score); + return allScored.slice(0, topK); } /** * Convert a tool/function definition to a text representation for reranking. - * Format follows the best practices benchmark we have: name, description, param descriptions, example prompts * @param {Object} tool - The tool definition object * @returns {{text: string, toolName: string, tool: Object, tokens: number}} */ @@ -79,7 +131,6 @@ class ToolReranker { } } - // Include example prompts if available (common in aibitat built-in tools) if ( tool.examples && Array.isArray(tool.examples) && @@ -103,11 +154,12 @@ class ToolReranker { } /** - * Rerank tools based on the user prompt and return the top N most relevant tools + * Rerank tools based on the user prompt and return the top N most relevant tools. + * Uses chunked processing to handle large numbers of tools efficiently. * @param {string} userPrompt - The user's query/prompt * @param {Object[]} tools - Array of tool/function definitions from aibitat.functions * @param {Object} options - Options for reranking - * @param {number} options.topN - (optional) Number of top tools to return (default: ToolReranker.getTopN()) + * @param {number} options.topN - Number of top tools to return * @returns {Promise} - Array of reranked tools (top N) */ async rerank(userPrompt, tools = [], options = {}) { @@ -129,15 +181,24 @@ class ToolReranker { ); const startTime = Date.now(); - const reranker = await this.#getReranker(); - const rerankedDocs = await reranker.rerank(userPrompt, documents, { - topK: topN, - }); + // Truncate and format documents for reranking + const rerankDocs = documents.map((doc) => ({ + text: this.#truncateText(doc.text), + })); + + const reranked = await this.#chunkedRerank(userPrompt, rerankDocs, topN); const elapsedMs = Date.now() - startTime; - const rerankedTools = rerankedDocs.map((doc) => doc.tool); - const newTokenCount = rerankedDocs.reduce( - (acc, doc) => acc + doc.tokens, + const rerankedIndices = reranked.map((doc) => ({ + index: doc.rerank_corpus_id, + score: doc.rerank_score, + })); + + const rerankedTools = rerankedIndices.map( + ({ index }) => documents[index].tool + ); + const newTokenCount = rerankedIndices.reduce( + (acc, { index }) => acc + documents[index].tokens, 0 ); const percentSaved = Math.round( @@ -148,8 +209,8 @@ Identified top ${rerankedTools.length} of ${tools.length} tools in ${elapsedMs}m ${originalTokenCount.toLocaleString()} -> ${newTokenCount.toLocaleString()} tokens \x1b[0;93m(${percentSaved}% reduction)\x1b[0m`); let logText = "Selected tools:\n"; - rerankedDocs.forEach((doc, index) => { - logText += ` ${index + 1}. ${doc.toolName}\n`; + rerankedIndices.forEach(({ index }, i) => { + logText += ` ${i + 1}. ${documents[index].toolName}\n`; }); this.log(logText); return rerankedTools;