Automatic Context window detection (#3817)

* Add a context window finder sourced from LiteLLM's maintained model list;
  applied to all cloud providers, with a client-side cache that expires after 3 days

* linting
Timothy Carambat 2025-05-14 11:03:19 -07:00 committed by GitHub
parent 4445f39bf8
commit e80492606a
12 changed files with 183 additions and 67 deletions
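The recurring pattern in the hunks below: each provider swaps its lookup into the static, hand-maintained map for a call into the new LiteLLM-backed cache, keeping its old hard-coded limit as the fallback. A minimal sketch of that pattern, lifted from the Anthropic hunk (the require path here is illustrative; the provider tag and fallback value vary per provider):

  const { MODEL_MAP } = require("./AiProviders/modelMap"); // illustrative path

  class AnthropicLLM {
    // Before: static lookup in the hand-maintained map
    //   return MODEL_MAP.anthropic[modelName] ?? 100_000;
    // After: cache-backed lookup. get() returns null on a cache miss,
    // so `??` still falls through to the historical default.
    static promptWindowLimit(modelName) {
      return MODEL_MAP.get("anthropic", modelName) ?? 100_000;
    }
  }

  AnthropicLLM.promptWindowLimit("claude-3-opus-20240229"); // cached value, else 100_000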

View File

@@ -9,3 +9,4 @@ gemini
 togetherAi
 tesseract
 ppio
+context-windows/*

View File

@@ -45,11 +45,11 @@ class AnthropicLLM {
  }

  static promptWindowLimit(modelName) {
-    return MODEL_MAP.anthropic[modelName] ?? 100_000;
+    return MODEL_MAP.get("anthropic", modelName) ?? 100_000;
  }

  promptWindowLimit() {
-    return MODEL_MAP.anthropic[this.model] ?? 100_000;
+    return MODEL_MAP.get("anthropic", this.model) ?? 100_000;
  }

  isValidChatCompletionModel(_modelName = "") {

View File

@@ -63,11 +63,11 @@ class CohereLLM {
  }

  static promptWindowLimit(modelName) {
-    return MODEL_MAP.cohere[modelName] ?? 4_096;
+    return MODEL_MAP.get("cohere", modelName) ?? 4_096;
  }

  promptWindowLimit() {
-    return MODEL_MAP.cohere[this.model] ?? 4_096;
+    return MODEL_MAP.get("cohere", this.model) ?? 4_096;
  }

  async isValidChatCompletionModel(model = "") {

View File

@@ -29,7 +29,9 @@ class DeepSeekLLM {
    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
-    this.log("Initialized with model:", this.model);
+    this.log(
+      `Initialized ${this.model} with context window ${this.promptWindowLimit()}`
+    );
  }

  log(text, ...args) {
@@ -53,11 +55,11 @@ class DeepSeekLLM {
  }

  static promptWindowLimit(modelName) {
-    return MODEL_MAP.deepseek[modelName] ?? 8192;
+    return MODEL_MAP.get("deepseek", modelName) ?? 8192;
  }

  promptWindowLimit() {
-    return MODEL_MAP.deepseek[this.model] ?? 8192;
+    return MODEL_MAP.get("deepseek", this.model) ?? 8192;
  }

  async isValidChatCompletionModel(modelName = "") {

View File

@@ -1,7 +1,7 @@
 const { MODEL_MAP } = require("../modelMap");

 const stableModels = [
-  // %STABLE_MODELS% - updated 2025-04-07T20:29:49.276Z
+  // %STABLE_MODELS% - updated 2025-05-13T23:13:58.920Z
   "gemini-1.5-pro-001",
   "gemini-1.5-pro-002",
   "gemini-1.5-pro",
@@ -14,6 +14,7 @@ const stableModels = [
   "gemini-2.0-flash-001",
   "gemini-2.0-flash-lite-001",
   "gemini-2.0-flash-lite",
+  "gemini-2.0-flash-preview-image-generation",
   // %EOC_STABLE_MODELS%
 ];
@@ -22,7 +23,7 @@ const stableModels = [
 // generally, v1beta models have `exp` in the name, but not always
 // so we check for both against a static list as well via API.
 const v1BetaModels = [
-  // %V1BETA_MODELS% - updated 2025-04-07T20:29:49.276Z
+  // %V1BETA_MODELS% - updated 2025-05-13T23:13:58.920Z
   "gemini-1.5-pro-latest",
   "gemini-1.5-flash-latest",
   "gemini-1.5-flash-8b-latest",
@@ -30,6 +31,9 @@ const v1BetaModels = [
   "gemini-1.5-flash-8b-exp-0924",
   "gemini-2.5-pro-exp-03-25",
   "gemini-2.5-pro-preview-03-25",
+  "gemini-2.5-flash-preview-04-17",
+  "gemini-2.5-flash-preview-04-17-thinking",
+  "gemini-2.5-pro-preview-05-06",
   "gemini-2.0-flash-exp",
   "gemini-2.0-flash-exp-image-generation",
   "gemini-2.0-flash-lite-preview-02-05",
@@ -41,6 +45,7 @@ const v1BetaModels = [
   "gemini-2.0-flash-thinking-exp",
   "gemini-2.0-flash-thinking-exp-1219",
   "learnlm-1.5-pro-experimental",
+  "learnlm-2.0-flash-experimental",
   "gemma-3-1b-it",
   "gemma-3-4b-it",
   "gemma-3-12b-it",
@@ -48,17 +53,17 @@ const v1BetaModels = [
   // %EOC_V1BETA_MODELS%
 ];

-const defaultGeminiModels = [
+const defaultGeminiModels = () => [
   ...stableModels.map((model) => ({
     id: model,
     name: model,
-    contextWindow: MODEL_MAP.gemini[model],
+    contextWindow: MODEL_MAP.get("gemini", model),
     experimental: false,
   })),
   ...v1BetaModels.map((model) => ({
     id: model,
     name: model,
-    contextWindow: MODEL_MAP.gemini[model],
+    contextWindow: MODEL_MAP.get("gemini", model),
     experimental: true,
   })),
 ];
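Note the signature change above: `defaultGeminiModels` becomes a function rather than a module-level constant, presumably so the `MODEL_MAP.get()` calls run when the fallback list is actually requested instead of at require() time, when the context-window cache may not have been synced yet. A sketch of the difference, under that assumption:

  // Before: the array was built once at require() time, freezing
  // contextWindow values before the remote sync could complete.
  //   const defaultGeminiModels = [ ... ];

  // After: building the array is deferred until a caller needs the
  // fallback list, so MODEL_MAP.get() reads the current cache.
  const models = defaultGeminiModels();

The call site in GeminiLLM changes accordingly (`return defaultGeminiModels();` in a hunk further down).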

View File

@@ -107,7 +107,7 @@ class GeminiLLM {
    try {
      const cacheModelPath = path.resolve(cacheFolder, "models.json");
      if (!fs.existsSync(cacheModelPath))
-        return MODEL_MAP.gemini[modelName] ?? 30_720;
+        return MODEL_MAP.get("gemini", modelName) ?? 30_720;

      const models = safeJsonParse(fs.readFileSync(cacheModelPath));
      const model = models.find((model) => model.id === modelName);
@@ -118,15 +118,14 @@ class GeminiLLM {
      return model.contextWindow;
    } catch (e) {
      console.error(`GeminiLLM:promptWindowLimit`, e.message);
-      return MODEL_MAP.gemini[modelName] ?? 30_720;
+      return MODEL_MAP.get("gemini", modelName) ?? 30_720;
    }
  }

  promptWindowLimit() {
    try {
      if (!fs.existsSync(this.cacheModelPath))
-        return MODEL_MAP.gemini[this.model] ?? 30_720;
+        return MODEL_MAP.get("gemini", this.model) ?? 30_720;

      const models = safeJsonParse(fs.readFileSync(this.cacheModelPath));
      const model = models.find((model) => model.id === this.model);
      if (!model)
@@ -136,7 +135,7 @@ class GeminiLLM {
      return model.contextWindow;
    } catch (e) {
      console.error(`GeminiLLM:promptWindowLimit`, e.message);
-      return MODEL_MAP.gemini[this.model] ?? 30_720;
+      return MODEL_MAP.get("gemini", this.model) ?? 30_720;
    }
  }
@@ -281,7 +280,7 @@ class GeminiLLM {
    if (allModels.length === 0) {
      console.error(`Gemini:getGeminiModels - No models found`);
-      return defaultGeminiModels;
+      return defaultGeminiModels();
    }

    console.log(

View File

@@ -9,7 +9,6 @@
 import fs from "fs";
 import path from "path";
 import dotenv from "dotenv";
-import { MODEL_MAP } from "../modelMap.js";

 dotenv.config({ path: `../../../.env.development` });
 const existingCachePath = path.resolve('../../../storage/models/gemini')
@@ -46,34 +45,4 @@ function updateDefaultModelsFile(models) {
   fs.writeFileSync(path.join("./defaultModels.js"), defaultModelFileContents);
   console.log("Updated defaultModels.js. Dont forget to `yarn lint` and commit!");
 }

-function updateModelMap(models) {
-  const existingModelMap = MODEL_MAP;
-  console.log('Updating modelMap.js `gemini` object...')
-  console.log(`Removed existing gemini object (${Object.keys(existingModelMap.gemini).length} models) from modelMap.js`);
-  existingModelMap.gemini = {};
-  for (const model of models) existingModelMap.gemini[model.id] = model.contextWindow;
-  console.log(`Updated modelMap.js 'gemini' object with ${Object.keys(existingModelMap.gemini).length} models from API`);
-
-  // Update the modelMap.js file
-  const contents = `/**
- * The model name and context window for all know model windows
- * that are available through providers which has discrete model options.
- * This file is automatically generated by syncStaticLists.mjs
- * and should not be edited manually.
- *
- * Last updated: ${new Date().toISOString()}
- */
-const MODEL_MAP = {
-${Object.entries(existingModelMap).map(([key, value]) => `${key}: ${JSON.stringify(value, null, 2)}`).join(',\n')}
-};
-module.exports = { MODEL_MAP };
-`;
-  fs.writeFileSync(path.resolve("../modelMap.js"), contents);
-  console.log('Updated modelMap.js `gemini` object. Dont forget to `yarn lint` and commit!');
-}

 updateDefaultModelsFile(models);
-updateModelMap(models);

View File

@@ -49,11 +49,11 @@ class GroqLLM {
  }

  static promptWindowLimit(modelName) {
-    return MODEL_MAP.groq[modelName] ?? 8192;
+    return MODEL_MAP.get("groq", modelName) ?? 8192;
  }

  promptWindowLimit() {
-    return MODEL_MAP.groq[this.model] ?? 8192;
+    return MODEL_MAP.get("groq", this.model) ?? 8192;
  }

  async isValidChatCompletionModel(modelName = "") {

View File

@@ -0,0 +1,140 @@
const path = require("path");
const fs = require("fs");
const LEGACY_MODEL_MAP = require("./legacy");

class ContextWindowFinder {
  static instance = null;
  static modelMap = LEGACY_MODEL_MAP;

  /**
   * Mapping for AnythingLLM provider <> LiteLLM provider
   * @type {Record<string, string>}
   */
  static trackedProviders = {
    anthropic: "anthropic",
    openai: "openai",
    cohere: "cohere_chat",
    gemini: "vertex_ai-language-models",
    groq: "groq",
    xai: "xai",
    deepseek: "deepseek",
  };
  static expiryMs = 1000 * 60 * 60 * 24 * 3; // 3 days
  static remoteUrl =
    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";
  cacheLocation = path.resolve(
    process.env.STORAGE_DIR
      ? path.resolve(process.env.STORAGE_DIR, "models", "context-windows")
      : path.resolve(__dirname, `../../../storage/models/context-windows`)
  );
  cacheFilePath = path.resolve(this.cacheLocation, "context-windows.json");
  cacheFileExpiryPath = path.resolve(this.cacheLocation, ".cached_at");

  constructor() {
    if (ContextWindowFinder.instance) return ContextWindowFinder.instance;
    ContextWindowFinder.instance = this;
    if (!fs.existsSync(this.cacheLocation))
      fs.mkdirSync(this.cacheLocation, { recursive: true });
    this.#pullRemoteModelMap();
  }

  log(text, ...args) {
    console.log(`\x1b[33m[ContextWindowFinder]\x1b[0m ${text}`, ...args);
  }
  /**
   * Checks whether the cache is stale: stale if the expiry marker file is
   * missing or the cached copy is older than the expiry time.
   * @returns {boolean}
   */
  get isCacheStale() {
    if (!fs.existsSync(this.cacheFileExpiryPath)) return true;
    const cachedAt = fs.readFileSync(this.cacheFileExpiryPath, "utf8");
    // Subtraction coerces the cached timestamp string to a number.
    return Date.now() - cachedAt > ContextWindowFinder.expiryMs;
  }

  get cache() {
    if (!fs.existsSync(this.cacheFileExpiryPath)) return null;
    if (!this.isCacheStale)
      return JSON.parse(
        fs.readFileSync(this.cacheFilePath, { encoding: "utf8" })
      );
    return null;
  }
  /**
   * Pulls the remote model map from the remote URL, formats it, and caches it.
   * @returns {Record<string, Record<string, number>>} - The formatted model map
   */
  async #pullRemoteModelMap() {
    const remoteContextWindowMap = await fetch(ContextWindowFinder.remoteUrl)
      .then((res) => res.json())
      .then((data) => {
        fs.writeFileSync(this.cacheFilePath, JSON.stringify(data, null, 2));
        fs.writeFileSync(this.cacheFileExpiryPath, Date.now().toString());
        this.log("Remote model map synced and cached");
        return data;
      })
      .catch((error) => {
        this.log("Error syncing remote model map", error);
        return null;
      });
    if (!remoteContextWindowMap) return null;

    const modelMap = this.#formatModelMap(remoteContextWindowMap);
    // Overwrite the raw payload written above with the formatted per-provider map.
    fs.writeFileSync(this.cacheFilePath, JSON.stringify(modelMap, null, 2));
    fs.writeFileSync(this.cacheFileExpiryPath, Date.now().toString());
    return modelMap;
  }
  /**
   * Formats the remote model map into the shape we store for every provider
   * that uses it.
   * @param {Record<string, any>} modelMap - The remote model map
   * @returns {Record<string, Record<string, number>>} - The formatted model map
   */
  #formatModelMap(modelMap = {}) {
    const formattedModelMap = {};
    for (const [provider, liteLLMProviderTag] of Object.entries(
      ContextWindowFinder.trackedProviders
    )) {
      formattedModelMap[provider] = {};
      const matches = Object.entries(modelMap).filter(
        ([_key, config]) => config.litellm_provider === liteLLMProviderTag
      );
      for (const [key, config] of matches) {
        const contextWindow = Number(config.max_input_tokens);
        if (isNaN(contextWindow)) continue;
        // Some models have a provider/model-tag format, so we take the last
        // part, since we don't use path-style names except for some
        // router providers like OpenRouter or Together.
        const modelName = key.split("/").pop();
        formattedModelMap[provider][modelName] = contextWindow;
      }
    }
    return formattedModelMap;
  }
  /**
   * Gets the context window for a given provider and model.
   * @param {string} provider - The provider to get the context window for
   * @param {string} model - The model to get the context window for
   * @returns {number} - The context window for the given provider and model
   */
  get(provider = null, model = null) {
    if (!provider || !this.cache || !this.cache[provider]) return null;
    if (!model) return this.cache[provider];

    const modelContextWindow = this.cache[provider][model];
    if (!modelContextWindow) {
      this.log("Invalid access to model context window - not found in cache", {
        provider,
        model,
      });
      return null;
    }
    return Number(modelContextWindow);
  }
}

module.exports = { MODEL_MAP: new ContextWindowFinder() };
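For orientation, a hedged sketch of how the exported singleton behaves. Because the constructor returns the existing instance, every require() shares one cache. The entry shape mirrors the LiteLLM JSON fields referenced in #formatModelMap above (`litellm_provider`, `max_input_tokens`); the model name and numbers are illustrative, and the require path is hypothetical:

  const { MODEL_MAP } = require("./modelMap"); // hypothetical path

  // A remote entry such as
  //   "claude-3-5-sonnet-20240620": { "max_input_tokens": 200000,
  //                                   "litellm_provider": "anthropic", ... }
  // is flattened by #formatModelMap into
  //   { anthropic: { "claude-3-5-sonnet-20240620": 200000 }, ... }

  MODEL_MAP.get("anthropic", "claude-3-5-sonnet-20240620"); // 200000 (illustrative)
  MODEL_MAP.get("anthropic"); // the whole provider map when no model is given
  MODEL_MAP.get("anthropic", "unknown-model"); // null, so callers fall back via `??`

Worth noting: the constructor fires #pullRemoteModelMap() without awaiting it, so a get() issued immediately after first construction can see an empty cache and return null; the `?? <default>` at every provider call site covers exactly that window.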

View File

@@ -1,12 +1,4 @@
-/**
- * The model name and context window for all know model windows
- * that are available through providers which has discrete model options.
- * This file is automatically generated by syncStaticLists.mjs
- * and should not be edited manually.
- *
- * Last updated: 2025-04-07T20:29:49.277Z
- */
-const MODEL_MAP = {
+const LEGACY_MODEL_MAP = {
   anthropic: {
     "claude-instant-1.2": 100000,
     "claude-2.0": 100000,
@@ -117,5 +109,4 @@ const MODEL_MAP = {
     "grok-beta": 131072,
   },
 };
-module.exports = { MODEL_MAP };
+module.exports = LEGACY_MODEL_MAP;

View File

@@ -25,6 +25,13 @@ class OpenAiLLM {
    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
+    this.log(
+      `Initialized ${this.model} with context window ${this.promptWindowLimit()}`
+    );
  }

+  log(text, ...args) {
+    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
+  }

  /**
@@ -54,11 +61,11 @@ class OpenAiLLM {
  }

  static promptWindowLimit(modelName) {
-    return MODEL_MAP.openai[modelName] ?? 4_096;
+    return MODEL_MAP.get("openai", modelName) ?? 4_096;
  }

  promptWindowLimit() {
-    return MODEL_MAP.openai[this.model] ?? 4_096;
+    return MODEL_MAP.get("openai", this.model) ?? 4_096;
  }

  // Short circuit if name has 'gpt' since we now fetch models from OpenAI API

View File

@@ -28,7 +28,9 @@ class XAiLLM {
    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
-    this.log("Initialized with model:", this.model);
+    this.log(
+      `Initialized ${this.model} with context window ${this.promptWindowLimit()}`
+    );
  }

  log(text, ...args) {
@@ -52,11 +54,11 @@ class XAiLLM {
  }

  static promptWindowLimit(modelName) {
-    return MODEL_MAP.xai[modelName] ?? 131_072;
+    return MODEL_MAP.get("xai", modelName) ?? 131_072;
  }

  promptWindowLimit() {
-    return MODEL_MAP.xai[this.model] ?? 131_072;
+    return MODEL_MAP.get("xai", this.model) ?? 131_072;
  }

  isValidChatCompletionModel(_modelName = "") {