add batching Intelligent Tool Selector for performance and scoring
This commit is contained in:
parent
0bfd27c6df
commit
77d42e6564
@ -1,7 +1,10 @@
|
|||||||
|
const { TokenManager } = require("../../../helpers/tiktoken");
|
||||||
const {
|
const {
|
||||||
NativeEmbeddingReranker,
|
NativeEmbeddingReranker,
|
||||||
} = require("../../../EmbeddingRerankers/native");
|
} = require("../../../EmbeddingRerankers/native");
|
||||||
const { TokenManager } = require("../../../helpers/tiktoken");
|
|
||||||
|
const CHUNK_SIZE = 25;
|
||||||
|
const MAX_TEXT_LENGTH = 1000;
|
||||||
|
|
||||||
class ToolReranker {
|
class ToolReranker {
|
||||||
/**
|
/**
|
||||||
@ -11,12 +14,12 @@ class ToolReranker {
|
|||||||
static defaultTopN = 15;
|
static defaultTopN = 15;
|
||||||
|
|
||||||
static instance = null;
|
static instance = null;
|
||||||
static #reranker = null;
|
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
if (ToolReranker.instance) return ToolReranker.instance;
|
if (ToolReranker.instance) return ToolReranker.instance;
|
||||||
ToolReranker.instance = this;
|
ToolReranker.instance = this;
|
||||||
this.tokenManager = new TokenManager();
|
this.tokenManager = new TokenManager();
|
||||||
|
this.reranker = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
log(text, ...args) {
|
log(text, ...args) {
|
||||||
@ -43,20 +46,69 @@ class ToolReranker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get or create the reranker singleton
|
* Truncate text to max length, trying to break at word boundary
|
||||||
* @returns {NativeEmbeddingReranker}
|
*/
|
||||||
|
#truncateText(text, maxLength = MAX_TEXT_LENGTH) {
|
||||||
|
if (!text || text.length <= maxLength) return text;
|
||||||
|
const truncated = text.slice(0, maxLength);
|
||||||
|
const lastSpace = truncated.lastIndexOf(" ");
|
||||||
|
return lastSpace > maxLength * 0.8
|
||||||
|
? truncated.slice(0, lastSpace)
|
||||||
|
: truncated;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get or initialize the reranker instance
|
||||||
*/
|
*/
|
||||||
async #getReranker() {
|
async #getReranker() {
|
||||||
if (!ToolReranker.#reranker) {
|
if (!this.reranker) {
|
||||||
ToolReranker.#reranker = new NativeEmbeddingReranker();
|
this.reranker = new NativeEmbeddingReranker();
|
||||||
await ToolReranker.#reranker.initClient();
|
await this.reranker.initClient();
|
||||||
}
|
}
|
||||||
return ToolReranker.#reranker;
|
return this.reranker;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Process documents in chunks and merge results to get global top K
|
||||||
|
*/
|
||||||
|
async #chunkedRerank(query, documents, topK) {
|
||||||
|
const reranker = await this.#getReranker();
|
||||||
|
const totalDocs = documents.length;
|
||||||
|
|
||||||
|
if (totalDocs <= CHUNK_SIZE) {
|
||||||
|
return await reranker.rerank(query, documents, { topK });
|
||||||
|
}
|
||||||
|
|
||||||
|
this.log(`Processing ${totalDocs} documents in chunks of ${CHUNK_SIZE}...`);
|
||||||
|
const allScored = [];
|
||||||
|
|
||||||
|
for (let i = 0; i < totalDocs; i += CHUNK_SIZE) {
|
||||||
|
const chunk = documents.slice(i, i + CHUNK_SIZE);
|
||||||
|
const chunkNum = Math.floor(i / CHUNK_SIZE) + 1;
|
||||||
|
const totalChunks = Math.ceil(totalDocs / CHUNK_SIZE);
|
||||||
|
|
||||||
|
this.log(
|
||||||
|
`Processing chunk ${chunkNum}/${totalChunks} (${chunk.length} docs)...`
|
||||||
|
);
|
||||||
|
|
||||||
|
const chunkResults = await reranker.rerank(query, chunk, {
|
||||||
|
topK: chunk.length,
|
||||||
|
});
|
||||||
|
|
||||||
|
chunkResults.forEach((result) => {
|
||||||
|
allScored.push({
|
||||||
|
...result,
|
||||||
|
rerank_corpus_id: result.rerank_corpus_id + i,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
allScored.sort((a, b) => b.rerank_score - a.rerank_score);
|
||||||
|
return allScored.slice(0, topK);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert a tool/function definition to a text representation for reranking.
|
* Convert a tool/function definition to a text representation for reranking.
|
||||||
* Format follows the best practices benchmark we have: name, description, param descriptions, example prompts
|
|
||||||
* @param {Object} tool - The tool definition object
|
* @param {Object} tool - The tool definition object
|
||||||
* @returns {{text: string, toolName: string, tool: Object, tokens: number}}
|
* @returns {{text: string, toolName: string, tool: Object, tokens: number}}
|
||||||
*/
|
*/
|
||||||
@ -79,7 +131,6 @@ class ToolReranker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include example prompts if available (common in aibitat built-in tools)
|
|
||||||
if (
|
if (
|
||||||
tool.examples &&
|
tool.examples &&
|
||||||
Array.isArray(tool.examples) &&
|
Array.isArray(tool.examples) &&
|
||||||
@ -103,11 +154,12 @@ class ToolReranker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Rerank tools based on the user prompt and return the top N most relevant tools
|
* Rerank tools based on the user prompt and return the top N most relevant tools.
|
||||||
|
* Uses chunked processing to handle large numbers of tools efficiently.
|
||||||
* @param {string} userPrompt - The user's query/prompt
|
* @param {string} userPrompt - The user's query/prompt
|
||||||
* @param {Object[]} tools - Array of tool/function definitions from aibitat.functions
|
* @param {Object[]} tools - Array of tool/function definitions from aibitat.functions
|
||||||
* @param {Object} options - Options for reranking
|
* @param {Object} options - Options for reranking
|
||||||
* @param {number} options.topN - (optional) Number of top tools to return (default: ToolReranker.getTopN())
|
* @param {number} options.topN - Number of top tools to return
|
||||||
* @returns {Promise<Object[]>} - Array of reranked tools (top N)
|
* @returns {Promise<Object[]>} - Array of reranked tools (top N)
|
||||||
*/
|
*/
|
||||||
async rerank(userPrompt, tools = [], options = {}) {
|
async rerank(userPrompt, tools = [], options = {}) {
|
||||||
@ -129,15 +181,24 @@ class ToolReranker {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
const reranker = await this.#getReranker();
|
// Truncate and format documents for reranking
|
||||||
const rerankedDocs = await reranker.rerank(userPrompt, documents, {
|
const rerankDocs = documents.map((doc) => ({
|
||||||
topK: topN,
|
text: this.#truncateText(doc.text),
|
||||||
});
|
}));
|
||||||
|
|
||||||
|
const reranked = await this.#chunkedRerank(userPrompt, rerankDocs, topN);
|
||||||
const elapsedMs = Date.now() - startTime;
|
const elapsedMs = Date.now() - startTime;
|
||||||
|
|
||||||
const rerankedTools = rerankedDocs.map((doc) => doc.tool);
|
const rerankedIndices = reranked.map((doc) => ({
|
||||||
const newTokenCount = rerankedDocs.reduce(
|
index: doc.rerank_corpus_id,
|
||||||
(acc, doc) => acc + doc.tokens,
|
score: doc.rerank_score,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const rerankedTools = rerankedIndices.map(
|
||||||
|
({ index }) => documents[index].tool
|
||||||
|
);
|
||||||
|
const newTokenCount = rerankedIndices.reduce(
|
||||||
|
(acc, { index }) => acc + documents[index].tokens,
|
||||||
0
|
0
|
||||||
);
|
);
|
||||||
const percentSaved = Math.round(
|
const percentSaved = Math.round(
|
||||||
@ -148,8 +209,8 @@ Identified top ${rerankedTools.length} of ${tools.length} tools in ${elapsedMs}m
|
|||||||
${originalTokenCount.toLocaleString()} -> ${newTokenCount.toLocaleString()} tokens \x1b[0;93m(${percentSaved}% reduction)\x1b[0m`);
|
${originalTokenCount.toLocaleString()} -> ${newTokenCount.toLocaleString()} tokens \x1b[0;93m(${percentSaved}% reduction)\x1b[0m`);
|
||||||
|
|
||||||
let logText = "Selected tools:\n";
|
let logText = "Selected tools:\n";
|
||||||
rerankedDocs.forEach((doc, index) => {
|
rerankedIndices.forEach(({ index }, i) => {
|
||||||
logText += ` ${index + 1}. ${doc.toolName}\n`;
|
logText += ` ${i + 1}. ${documents[index].toolName}\n`;
|
||||||
});
|
});
|
||||||
this.log(logText);
|
this.log(logText);
|
||||||
return rerankedTools;
|
return rerankedTools;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user