Automatic Context window detection (#3817)
* Add context window finder from litellm maintained list apply to all cloud providers, have client cache for 3 days * linting
This commit is contained in:
parent
4445f39bf8
commit
e80492606a
1
server/storage/models/.gitignore
vendored
1
server/storage/models/.gitignore
vendored
@ -9,3 +9,4 @@ gemini
|
||||
togetherAi
|
||||
tesseract
|
||||
ppio
|
||||
context-windows/*
|
||||
@ -45,11 +45,11 @@ class AnthropicLLM {
|
||||
}
|
||||
|
||||
static promptWindowLimit(modelName) {
|
||||
return MODEL_MAP.anthropic[modelName] ?? 100_000;
|
||||
return MODEL_MAP.get("anthropic", modelName) ?? 100_000;
|
||||
}
|
||||
|
||||
promptWindowLimit() {
|
||||
return MODEL_MAP.anthropic[this.model] ?? 100_000;
|
||||
return MODEL_MAP.get("anthropic", this.model) ?? 100_000;
|
||||
}
|
||||
|
||||
isValidChatCompletionModel(_modelName = "") {
|
||||
|
||||
@ -63,11 +63,11 @@ class CohereLLM {
|
||||
}
|
||||
|
||||
static promptWindowLimit(modelName) {
|
||||
return MODEL_MAP.cohere[modelName] ?? 4_096;
|
||||
return MODEL_MAP.get("cohere", modelName) ?? 4_096;
|
||||
}
|
||||
|
||||
promptWindowLimit() {
|
||||
return MODEL_MAP.cohere[this.model] ?? 4_096;
|
||||
return MODEL_MAP.get("cohere", this.model) ?? 4_096;
|
||||
}
|
||||
|
||||
async isValidChatCompletionModel(model = "") {
|
||||
|
||||
@ -29,7 +29,9 @@ class DeepSeekLLM {
|
||||
|
||||
this.embedder = embedder ?? new NativeEmbedder();
|
||||
this.defaultTemp = 0.7;
|
||||
this.log("Initialized with model:", this.model);
|
||||
this.log(
|
||||
`Initialized ${this.model} with context window ${this.promptWindowLimit()}`
|
||||
);
|
||||
}
|
||||
|
||||
log(text, ...args) {
|
||||
@ -53,11 +55,11 @@ class DeepSeekLLM {
|
||||
}
|
||||
|
||||
static promptWindowLimit(modelName) {
|
||||
return MODEL_MAP.deepseek[modelName] ?? 8192;
|
||||
return MODEL_MAP.get("deepseek", modelName) ?? 8192;
|
||||
}
|
||||
|
||||
promptWindowLimit() {
|
||||
return MODEL_MAP.deepseek[this.model] ?? 8192;
|
||||
return MODEL_MAP.get("deepseek", this.model) ?? 8192;
|
||||
}
|
||||
|
||||
async isValidChatCompletionModel(modelName = "") {
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
const { MODEL_MAP } = require("../modelMap");
|
||||
|
||||
const stableModels = [
|
||||
// %STABLE_MODELS% - updated 2025-04-07T20:29:49.276Z
|
||||
// %STABLE_MODELS% - updated 2025-05-13T23:13:58.920Z
|
||||
"gemini-1.5-pro-001",
|
||||
"gemini-1.5-pro-002",
|
||||
"gemini-1.5-pro",
|
||||
@ -14,6 +14,7 @@ const stableModels = [
|
||||
"gemini-2.0-flash-001",
|
||||
"gemini-2.0-flash-lite-001",
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.0-flash-preview-image-generation",
|
||||
// %EOC_STABLE_MODELS%
|
||||
];
|
||||
|
||||
@ -22,7 +23,7 @@ const stableModels = [
|
||||
// generally, v1beta models have `exp` in the name, but not always
|
||||
// so we check for both against a static list as well via API.
|
||||
const v1BetaModels = [
|
||||
// %V1BETA_MODELS% - updated 2025-04-07T20:29:49.276Z
|
||||
// %V1BETA_MODELS% - updated 2025-05-13T23:13:58.920Z
|
||||
"gemini-1.5-pro-latest",
|
||||
"gemini-1.5-flash-latest",
|
||||
"gemini-1.5-flash-8b-latest",
|
||||
@ -30,6 +31,9 @@ const v1BetaModels = [
|
||||
"gemini-1.5-flash-8b-exp-0924",
|
||||
"gemini-2.5-pro-exp-03-25",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
"gemini-2.5-flash-preview-04-17",
|
||||
"gemini-2.5-flash-preview-04-17-thinking",
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.0-flash-exp",
|
||||
"gemini-2.0-flash-exp-image-generation",
|
||||
"gemini-2.0-flash-lite-preview-02-05",
|
||||
@ -41,6 +45,7 @@ const v1BetaModels = [
|
||||
"gemini-2.0-flash-thinking-exp",
|
||||
"gemini-2.0-flash-thinking-exp-1219",
|
||||
"learnlm-1.5-pro-experimental",
|
||||
"learnlm-2.0-flash-experimental",
|
||||
"gemma-3-1b-it",
|
||||
"gemma-3-4b-it",
|
||||
"gemma-3-12b-it",
|
||||
@ -48,17 +53,17 @@ const v1BetaModels = [
|
||||
// %EOC_V1BETA_MODELS%
|
||||
];
|
||||
|
||||
const defaultGeminiModels = [
|
||||
const defaultGeminiModels = () => [
|
||||
...stableModels.map((model) => ({
|
||||
id: model,
|
||||
name: model,
|
||||
contextWindow: MODEL_MAP.gemini[model],
|
||||
contextWindow: MODEL_MAP.get("gemini", model),
|
||||
experimental: false,
|
||||
})),
|
||||
...v1BetaModels.map((model) => ({
|
||||
id: model,
|
||||
name: model,
|
||||
contextWindow: MODEL_MAP.gemini[model],
|
||||
contextWindow: MODEL_MAP.get("gemini", model),
|
||||
experimental: true,
|
||||
})),
|
||||
];
|
||||
|
||||
@ -107,7 +107,7 @@ class GeminiLLM {
|
||||
try {
|
||||
const cacheModelPath = path.resolve(cacheFolder, "models.json");
|
||||
if (!fs.existsSync(cacheModelPath))
|
||||
return MODEL_MAP.gemini[modelName] ?? 30_720;
|
||||
return MODEL_MAP.get("gemini", modelName) ?? 30_720;
|
||||
|
||||
const models = safeJsonParse(fs.readFileSync(cacheModelPath));
|
||||
const model = models.find((model) => model.id === modelName);
|
||||
@ -118,15 +118,14 @@ class GeminiLLM {
|
||||
return model.contextWindow;
|
||||
} catch (e) {
|
||||
console.error(`GeminiLLM:promptWindowLimit`, e.message);
|
||||
return MODEL_MAP.gemini[modelName] ?? 30_720;
|
||||
return MODEL_MAP.get("gemini", modelName) ?? 30_720;
|
||||
}
|
||||
}
|
||||
|
||||
promptWindowLimit() {
|
||||
try {
|
||||
if (!fs.existsSync(this.cacheModelPath))
|
||||
return MODEL_MAP.gemini[this.model] ?? 30_720;
|
||||
|
||||
return MODEL_MAP.get("gemini", this.model) ?? 30_720;
|
||||
const models = safeJsonParse(fs.readFileSync(this.cacheModelPath));
|
||||
const model = models.find((model) => model.id === this.model);
|
||||
if (!model)
|
||||
@ -136,7 +135,7 @@ class GeminiLLM {
|
||||
return model.contextWindow;
|
||||
} catch (e) {
|
||||
console.error(`GeminiLLM:promptWindowLimit`, e.message);
|
||||
return MODEL_MAP.gemini[this.model] ?? 30_720;
|
||||
return MODEL_MAP.get("gemini", this.model) ?? 30_720;
|
||||
}
|
||||
}
|
||||
|
||||
@ -281,7 +280,7 @@ class GeminiLLM {
|
||||
|
||||
if (allModels.length === 0) {
|
||||
console.error(`Gemini:getGeminiModels - No models found`);
|
||||
return defaultGeminiModels;
|
||||
return defaultGeminiModels();
|
||||
}
|
||||
|
||||
console.log(
|
||||
|
||||
@ -9,7 +9,6 @@
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import dotenv from "dotenv";
|
||||
import { MODEL_MAP } from "../modelMap.js";
|
||||
|
||||
dotenv.config({ path: `../../../.env.development` });
|
||||
const existingCachePath = path.resolve('../../../storage/models/gemini')
|
||||
@ -46,34 +45,4 @@ function updateDefaultModelsFile(models) {
|
||||
fs.writeFileSync(path.join("./defaultModels.js"), defaultModelFileContents);
|
||||
console.log("Updated defaultModels.js. Dont forget to `yarn lint` and commit!");
|
||||
}
|
||||
|
||||
function updateModelMap(models) {
|
||||
const existingModelMap = MODEL_MAP;
|
||||
console.log('Updating modelMap.js `gemini` object...')
|
||||
console.log(`Removed existing gemini object (${Object.keys(existingModelMap.gemini).length} models) from modelMap.js`);
|
||||
existingModelMap.gemini = {};
|
||||
|
||||
for (const model of models) existingModelMap.gemini[model.id] = model.contextWindow;
|
||||
console.log(`Updated modelMap.js 'gemini' object with ${Object.keys(existingModelMap.gemini).length} models from API`);
|
||||
|
||||
// Update the modelMap.js file
|
||||
const contents = `/**
|
||||
* The model name and context window for all know model windows
|
||||
* that are available through providers which has discrete model options.
|
||||
* This file is automatically generated by syncStaticLists.mjs
|
||||
* and should not be edited manually.
|
||||
*
|
||||
* Last updated: ${new Date().toISOString()}
|
||||
*/
|
||||
const MODEL_MAP = {
|
||||
${Object.entries(existingModelMap).map(([key, value]) => `${key}: ${JSON.stringify(value, null, 2)}`).join(',\n')}
|
||||
};
|
||||
|
||||
module.exports = { MODEL_MAP };
|
||||
`;
|
||||
fs.writeFileSync(path.resolve("../modelMap.js"), contents);
|
||||
console.log('Updated modelMap.js `gemini` object. Dont forget to `yarn lint` and commit!');
|
||||
}
|
||||
|
||||
updateDefaultModelsFile(models);
|
||||
updateModelMap(models);
|
||||
|
||||
@ -49,11 +49,11 @@ class GroqLLM {
|
||||
}
|
||||
|
||||
static promptWindowLimit(modelName) {
|
||||
return MODEL_MAP.groq[modelName] ?? 8192;
|
||||
return MODEL_MAP.get("groq", modelName) ?? 8192;
|
||||
}
|
||||
|
||||
promptWindowLimit() {
|
||||
return MODEL_MAP.groq[this.model] ?? 8192;
|
||||
return MODEL_MAP.get("groq", this.model) ?? 8192;
|
||||
}
|
||||
|
||||
async isValidChatCompletionModel(modelName = "") {
|
||||
|
||||
140
server/utils/AiProviders/modelMap/index.js
Normal file
140
server/utils/AiProviders/modelMap/index.js
Normal file
@ -0,0 +1,140 @@
|
||||
const path = require("path");
|
||||
const fs = require("fs");
|
||||
const LEGACY_MODEL_MAP = require("./legacy");
|
||||
|
||||
/**
 * Singleton that resolves model context-window sizes from the remote,
 * litellm-maintained price/context list, with a 3-day on-disk cache.
 * Falls back to the bundled legacy map when no cached remote data is
 * available for a provider/model.
 */
class ContextWindowFinder {
  static instance = null;
  // Bundled fallback map (provider -> model -> context window) used when
  // the remote list has not been fetched yet or lacks an entry.
  static modelMap = LEGACY_MODEL_MAP;

  /**
   * Mapping for AnythingLLM provider <> LiteLLM provider
   * @type {Record<string, string>}
   */
  static trackedProviders = {
    anthropic: "anthropic",
    openai: "openai",
    cohere: "cohere_chat",
    gemini: "vertex_ai-language-models",
    groq: "groq",
    xai: "xai",
    deepseek: "deepseek",
  };
  static expiryMs = 1000 * 60 * 60 * 24 * 3; // 3 days
  static remoteUrl =
    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";

  cacheLocation = path.resolve(
    process.env.STORAGE_DIR
      ? path.resolve(process.env.STORAGE_DIR, "models", "context-windows")
      : path.resolve(__dirname, `../../../storage/models/context-windows`)
  );
  cacheFilePath = path.resolve(this.cacheLocation, "context-windows.json");
  cacheFileExpiryPath = path.resolve(this.cacheLocation, ".cached_at");

  constructor() {
    if (ContextWindowFinder.instance) return ContextWindowFinder.instance;
    ContextWindowFinder.instance = this;
    if (!fs.existsSync(this.cacheLocation))
      fs.mkdirSync(this.cacheLocation, { recursive: true });

    // Only refresh when the cache has actually expired — fetching on every
    // boot would defeat the 3-day cache. Fire-and-forget: the method handles
    // fetch errors itself, but guard against unexpected failures (e.g. a
    // failed disk write) so the promise never rejects unhandled.
    if (this.isCacheStale)
      void this.#pullRemoteModelMap().catch((error) =>
        this.log("Unexpected error while syncing remote model map", error)
      );
  }

  log(text, ...args) {
    console.log(`\x1b[33m[ContextWindowFinder]\x1b[0m ${text}`, ...args);
  }

  /**
   * Checks if the cache is stale by checking if the cache file exists and if
   * the cache file is older than the expiry time.
   * @returns {boolean}
   */
  get isCacheStale() {
    if (!fs.existsSync(this.cacheFileExpiryPath)) return true;
    const cachedAt = Number(
      fs.readFileSync(this.cacheFileExpiryPath, "utf8")
    );
    // A corrupt/unparsable timestamp means we cannot trust the cache.
    if (Number.isNaN(cachedAt)) return true;
    return Date.now() - cachedAt > ContextWindowFinder.expiryMs;
  }

  /**
   * The cached, formatted model map — or null when the cache is missing,
   * stale, or unreadable.
   * @returns {Record<string, Record<string, number>>|null}
   */
  get cache() {
    if (this.isCacheStale) return null;
    // The expiry marker can exist without the data file (partial write) —
    // verify the data file too before reading.
    if (!fs.existsSync(this.cacheFilePath)) return null;
    try {
      return JSON.parse(
        fs.readFileSync(this.cacheFilePath, { encoding: "utf8" })
      );
    } catch (error) {
      this.log("Failed to read cached model map", error);
      return null;
    }
  }

  /**
   * Pulls the remote model map from the remote URL, formats it and caches it.
   * @returns {Promise<Record<string, Record<string, number>>|null>} - The formatted model map
   */
  async #pullRemoteModelMap() {
    const remoteContextWindowMap = await fetch(ContextWindowFinder.remoteUrl)
      .then((res) => {
        if (!res.ok)
          throw new Error(`Unexpected response status ${res.status}`);
        return res.json();
      })
      .catch((error) => {
        this.log("Error syncing remote model map", error);
        return null;
      });
    if (!remoteContextWindowMap) return null;

    // Only persist the FORMATTED map — writing the raw payload first (and
    // then overwriting it) risks leaving a wrong-format cache on disk if
    // formatting fails part-way.
    const modelMap = this.#formatModelMap(remoteContextWindowMap);
    fs.writeFileSync(this.cacheFilePath, JSON.stringify(modelMap, null, 2));
    fs.writeFileSync(this.cacheFileExpiryPath, Date.now().toString());
    this.log("Remote model map synced and cached");
    return modelMap;
  }

  /**
   * Formats the remote model map to a format that is compatible with how we
   * store the model map for all providers who use it.
   * @param {Record<string, any>} modelMap - The remote model map
   * @returns {Record<string, Record<string, number>>} - The formatted model map
   */
  #formatModelMap(modelMap = {}) {
    const formattedModelMap = {};

    for (const [provider, liteLLMProviderTag] of Object.entries(
      ContextWindowFinder.trackedProviders
    )) {
      formattedModelMap[provider] = {};
      const matches = Object.entries(modelMap).filter(
        ([_key, config]) => config.litellm_provider === liteLLMProviderTag
      );
      for (const [key, config] of matches) {
        const contextWindow = Number(config.max_input_tokens);
        if (Number.isNaN(contextWindow)) continue;

        // Some models have a provider/model-tag format, so we need to get the last part since we dont do paths
        // for names with the exception of some router-providers like OpenRouter or Together.
        const modelName = key.split("/").pop();
        formattedModelMap[provider][modelName] = contextWindow;
      }
    }
    return formattedModelMap;
  }

  /**
   * Gets the context window for a given provider and model.
   * Prefers the cached remote (litellm) data and falls back to the bundled
   * legacy map when the cache has no answer yet — so callers still get a
   * sensible value before the first successful remote sync.
   * @param {string|null} provider - The provider to get the context window for
   * @param {string|null} model - The model to get the context window for
   * @returns {number|Record<string, number>|null} - The context window for the given provider and model
   */
  get(provider = null, model = null) {
    if (!provider) return null;
    // Read the cache once per call — the getter hits the disk each time.
    const cached = this.cache;
    const providerMap =
      cached?.[provider] ?? ContextWindowFinder.modelMap[provider] ?? null;
    if (!providerMap) return null;
    if (!model) return providerMap;

    const contextWindow =
      providerMap[model] ?? ContextWindowFinder.modelMap[provider]?.[model];
    if (contextWindow == null || Number.isNaN(Number(contextWindow))) {
      this.log("Invalid access to model context window - not found in cache", {
        provider,
        model,
      });
      return null;
    }
    return Number(contextWindow);
  }
}

module.exports = { MODEL_MAP: new ContextWindowFinder() };
|
||||
@ -1,12 +1,4 @@
|
||||
/**
|
||||
* The model name and context window for all know model windows
|
||||
* that are available through providers which has discrete model options.
|
||||
* This file is automatically generated by syncStaticLists.mjs
|
||||
* and should not be edited manually.
|
||||
*
|
||||
* Last updated: 2025-04-07T20:29:49.277Z
|
||||
*/
|
||||
const MODEL_MAP = {
|
||||
const LEGACY_MODEL_MAP = {
|
||||
anthropic: {
|
||||
"claude-instant-1.2": 100000,
|
||||
"claude-2.0": 100000,
|
||||
@ -117,5 +109,4 @@ const MODEL_MAP = {
|
||||
"grok-beta": 131072,
|
||||
},
|
||||
};
|
||||
|
||||
module.exports = { MODEL_MAP };
|
||||
module.exports = LEGACY_MODEL_MAP;
|
||||
@ -25,6 +25,13 @@ class OpenAiLLM {
|
||||
|
||||
this.embedder = embedder ?? new NativeEmbedder();
|
||||
this.defaultTemp = 0.7;
|
||||
this.log(
|
||||
`Initialized ${this.model} with context window ${this.promptWindowLimit()}`
|
||||
);
|
||||
}
|
||||
|
||||
log(text, ...args) {
|
||||
console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -54,11 +61,11 @@ class OpenAiLLM {
|
||||
}
|
||||
|
||||
static promptWindowLimit(modelName) {
|
||||
return MODEL_MAP.openai[modelName] ?? 4_096;
|
||||
return MODEL_MAP.get("openai", modelName) ?? 4_096;
|
||||
}
|
||||
|
||||
promptWindowLimit() {
|
||||
return MODEL_MAP.openai[this.model] ?? 4_096;
|
||||
return MODEL_MAP.get("openai", this.model) ?? 4_096;
|
||||
}
|
||||
|
||||
// Short circuit if name has 'gpt' since we now fetch models from OpenAI API
|
||||
|
||||
@ -28,7 +28,9 @@ class XAiLLM {
|
||||
|
||||
this.embedder = embedder ?? new NativeEmbedder();
|
||||
this.defaultTemp = 0.7;
|
||||
this.log("Initialized with model:", this.model);
|
||||
this.log(
|
||||
`Initialized ${this.model} with context window ${this.promptWindowLimit()}`
|
||||
);
|
||||
}
|
||||
|
||||
log(text, ...args) {
|
||||
@ -52,11 +54,11 @@ class XAiLLM {
|
||||
}
|
||||
|
||||
static promptWindowLimit(modelName) {
|
||||
return MODEL_MAP.xai[modelName] ?? 131_072;
|
||||
return MODEL_MAP.get("xai", modelName) ?? 131_072;
|
||||
}
|
||||
|
||||
promptWindowLimit() {
|
||||
return MODEL_MAP.xai[this.model] ?? 131_072;
|
||||
return MODEL_MAP.get("xai", this.model) ?? 131_072;
|
||||
}
|
||||
|
||||
isValidChatCompletionModel(_modelName = "") {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user