Add batching to the Intelligent Tool Selector for performance and scoring

This commit is contained in:
Timothy Carambat 2026-03-30 21:47:07 -07:00
parent 0bfd27c6df
commit 77d42e6564

View File

@ -1,7 +1,10 @@
const { TokenManager } = require("../../../helpers/tiktoken");
const { const {
NativeEmbeddingReranker, NativeEmbeddingReranker,
} = require("../../../EmbeddingRerankers/native"); } = require("../../../EmbeddingRerankers/native");
const { TokenManager } = require("../../../helpers/tiktoken");
const CHUNK_SIZE = 25;
const MAX_TEXT_LENGTH = 1000;
class ToolReranker { class ToolReranker {
/** /**
@ -11,12 +14,12 @@ class ToolReranker {
static defaultTopN = 15; static defaultTopN = 15;
static instance = null; static instance = null;
static #reranker = null;
constructor() { constructor() {
if (ToolReranker.instance) return ToolReranker.instance; if (ToolReranker.instance) return ToolReranker.instance;
ToolReranker.instance = this; ToolReranker.instance = this;
this.tokenManager = new TokenManager(); this.tokenManager = new TokenManager();
this.reranker = null;
} }
log(text, ...args) { log(text, ...args) {
@ -43,20 +46,69 @@ class ToolReranker {
} }
/** /**
* Get or create the reranker singleton * Truncate text to max length, trying to break at word boundary
* @returns {NativeEmbeddingReranker} */
#truncateText(text, maxLength = MAX_TEXT_LENGTH) {
if (!text || text.length <= maxLength) return text;
const truncated = text.slice(0, maxLength);
const lastSpace = truncated.lastIndexOf(" ");
return lastSpace > maxLength * 0.8
? truncated.slice(0, lastSpace)
: truncated;
}
/**
* Get or initialize the reranker instance
*/ */
async #getReranker() { async #getReranker() {
if (!ToolReranker.#reranker) { if (!this.reranker) {
ToolReranker.#reranker = new NativeEmbeddingReranker(); this.reranker = new NativeEmbeddingReranker();
await ToolReranker.#reranker.initClient(); await this.reranker.initClient();
} }
return ToolReranker.#reranker; return this.reranker;
}
/**
* Process documents in chunks and merge results to get global top K
*/
async #chunkedRerank(query, documents, topK) {
const reranker = await this.#getReranker();
const totalDocs = documents.length;
if (totalDocs <= CHUNK_SIZE) {
return await reranker.rerank(query, documents, { topK });
}
this.log(`Processing ${totalDocs} documents in chunks of ${CHUNK_SIZE}...`);
const allScored = [];
for (let i = 0; i < totalDocs; i += CHUNK_SIZE) {
const chunk = documents.slice(i, i + CHUNK_SIZE);
const chunkNum = Math.floor(i / CHUNK_SIZE) + 1;
const totalChunks = Math.ceil(totalDocs / CHUNK_SIZE);
this.log(
`Processing chunk ${chunkNum}/${totalChunks} (${chunk.length} docs)...`
);
const chunkResults = await reranker.rerank(query, chunk, {
topK: chunk.length,
});
chunkResults.forEach((result) => {
allScored.push({
...result,
rerank_corpus_id: result.rerank_corpus_id + i,
});
});
}
allScored.sort((a, b) => b.rerank_score - a.rerank_score);
return allScored.slice(0, topK);
} }
/** /**
* Convert a tool/function definition to a text representation for reranking. * Convert a tool/function definition to a text representation for reranking.
* Format follows the best practices benchmark we have: name, description, param descriptions, example prompts
* @param {Object} tool - The tool definition object * @param {Object} tool - The tool definition object
* @returns {{text: string, toolName: string, tool: Object, tokens: number}} * @returns {{text: string, toolName: string, tool: Object, tokens: number}}
*/ */
@ -79,7 +131,6 @@ class ToolReranker {
} }
} }
// Include example prompts if available (common in aibitat built-in tools)
if ( if (
tool.examples && tool.examples &&
Array.isArray(tool.examples) && Array.isArray(tool.examples) &&
@ -103,11 +154,12 @@ class ToolReranker {
} }
/** /**
* Rerank tools based on the user prompt and return the top N most relevant tools * Rerank tools based on the user prompt and return the top N most relevant tools.
* Uses chunked processing to handle large numbers of tools efficiently.
* @param {string} userPrompt - The user's query/prompt * @param {string} userPrompt - The user's query/prompt
* @param {Object[]} tools - Array of tool/function definitions from aibitat.functions * @param {Object[]} tools - Array of tool/function definitions from aibitat.functions
* @param {Object} options - Options for reranking * @param {Object} options - Options for reranking
* @param {number} options.topN - (optional) Number of top tools to return (default: ToolReranker.getTopN()) * @param {number} options.topN - Number of top tools to return
* @returns {Promise<Object[]>} - Array of reranked tools (top N) * @returns {Promise<Object[]>} - Array of reranked tools (top N)
*/ */
async rerank(userPrompt, tools = [], options = {}) { async rerank(userPrompt, tools = [], options = {}) {
@ -129,15 +181,24 @@ class ToolReranker {
); );
const startTime = Date.now(); const startTime = Date.now();
const reranker = await this.#getReranker(); // Truncate and format documents for reranking
const rerankedDocs = await reranker.rerank(userPrompt, documents, { const rerankDocs = documents.map((doc) => ({
topK: topN, text: this.#truncateText(doc.text),
}); }));
const reranked = await this.#chunkedRerank(userPrompt, rerankDocs, topN);
const elapsedMs = Date.now() - startTime; const elapsedMs = Date.now() - startTime;
const rerankedTools = rerankedDocs.map((doc) => doc.tool); const rerankedIndices = reranked.map((doc) => ({
const newTokenCount = rerankedDocs.reduce( index: doc.rerank_corpus_id,
(acc, doc) => acc + doc.tokens, score: doc.rerank_score,
}));
const rerankedTools = rerankedIndices.map(
({ index }) => documents[index].tool
);
const newTokenCount = rerankedIndices.reduce(
(acc, { index }) => acc + documents[index].tokens,
0 0
); );
const percentSaved = Math.round( const percentSaved = Math.round(
@ -148,8 +209,8 @@ Identified top ${rerankedTools.length} of ${tools.length} tools in ${elapsedMs}m
${originalTokenCount.toLocaleString()} -> ${newTokenCount.toLocaleString()} tokens \x1b[0;93m(${percentSaved}% reduction)\x1b[0m`); ${originalTokenCount.toLocaleString()} -> ${newTokenCount.toLocaleString()} tokens \x1b[0;93m(${percentSaved}% reduction)\x1b[0m`);
let logText = "Selected tools:\n"; let logText = "Selected tools:\n";
rerankedDocs.forEach((doc, index) => { rerankedIndices.forEach(({ index }, i) => {
logText += ` ${index + 1}. ${doc.toolName}\n`; logText += ` ${i + 1}. ${documents[index].toolName}\n`;
}); });
this.log(logText); this.log(logText);
return rerankedTools; return rerankedTools;