merlyn/server/utils/AiProviders/modelMap/index.js
Timothy Carambat e80492606a
Automatic Context window detection (#3817)
* Add context window finder from litellm maintained list
apply to all cloud providers, have client cache for 3 days

* linting
2025-05-14 11:03:19 -07:00

141 lines
4.9 KiB
JavaScript

const path = require("path");
const fs = require("fs");
const LEGACY_MODEL_MAP = require("./legacy");
class ContextWindowFinder {
static instance = null;
static modelMap = LEGACY_MODEL_MAP;
/**
* Mapping for AnythingLLM provider <> LiteLLM provider
* @type {Record<string, string>}
*/
static trackedProviders = {
anthropic: "anthropic",
openai: "openai",
cohere: "cohere_chat",
gemini: "vertex_ai-language-models",
groq: "groq",
xai: "xai",
deepseek: "deepseek",
};
static expiryMs = 1000 * 60 * 60 * 24 * 3; // 3 days
static remoteUrl =
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json";
cacheLocation = path.resolve(
process.env.STORAGE_DIR
? path.resolve(process.env.STORAGE_DIR, "models", "context-windows")
: path.resolve(__dirname, `../../../storage/models/context-windows`)
);
cacheFilePath = path.resolve(this.cacheLocation, "context-windows.json");
cacheFileExpiryPath = path.resolve(this.cacheLocation, ".cached_at");
constructor() {
if (ContextWindowFinder.instance) return ContextWindowFinder.instance;
ContextWindowFinder.instance = this;
if (!fs.existsSync(this.cacheLocation))
fs.mkdirSync(this.cacheLocation, { recursive: true });
this.#pullRemoteModelMap();
}
log(text, ...args) {
console.log(`\x1b[33m[ContextWindowFinder]\x1b[0m ${text}`, ...args);
}
/**
* Checks if the cache is stale by checking if the cache file exists and if the cache file is older than the expiry time.
* @returns {boolean}
*/
get isCacheStale() {
if (!fs.existsSync(this.cacheFileExpiryPath)) return true;
const cachedAt = fs.readFileSync(this.cacheFileExpiryPath, "utf8");
return Date.now() - cachedAt > ContextWindowFinder.expiryMs;
}
get cache() {
if (!fs.existsSync(this.cacheFileExpiryPath)) return null;
if (!this.isCacheStale)
return JSON.parse(
fs.readFileSync(this.cacheFilePath, { encoding: "utf8" })
);
return null;
}
/**
* Pulls the remote model map from the remote URL, formats it and caches it.
* @returns {Record<string, Record<string, number>>} - The formatted model map
*/
async #pullRemoteModelMap() {
const remoteContexWindowMap = await fetch(ContextWindowFinder.remoteUrl)
.then((res) => res.json())
.then((data) => {
fs.writeFileSync(this.cacheFilePath, JSON.stringify(data, null, 2));
fs.writeFileSync(this.cacheFileExpiryPath, Date.now().toString());
this.log("Remote model map synced and cached");
return data;
})
.catch((error) => {
this.log("Error syncing remote model map", error);
return null;
});
if (!remoteContexWindowMap) return null;
const modelMap = this.#formatModelMap(remoteContexWindowMap);
fs.writeFileSync(this.cacheFilePath, JSON.stringify(modelMap, null, 2));
fs.writeFileSync(this.cacheFileExpiryPath, Date.now().toString());
return modelMap;
}
/**
* Formats the remote model map to a format that is compatible with how we store the model map
* for all providers who use it.
* @param {Record<string, any>} modelMap - The remote model map
* @returns {Record<string, Record<string, number>>} - The formatted model map
*/
#formatModelMap(modelMap = {}) {
const formattedModelMap = {};
for (const [provider, liteLLMProviderTag] of Object.entries(
ContextWindowFinder.trackedProviders
)) {
formattedModelMap[provider] = {};
const matches = Object.entries(modelMap).filter(
([_key, config]) => config.litellm_provider === liteLLMProviderTag
);
for (const [key, config] of matches) {
const contextWindow = Number(config.max_input_tokens);
if (isNaN(contextWindow)) continue;
// Some models have a provider/model-tag format, so we need to get the last part since we dont do paths
// for names with the exception of some router-providers like OpenRouter or Together.
const modelName = key.split("/").pop();
formattedModelMap[provider][modelName] = contextWindow;
}
}
return formattedModelMap;
}
/**
* Gets the context window for a given provider and model.
* @param {string} provider - The provider to get the context window for
* @param {string} model - The model to get the context window for
* @returns {number} - The context window for the given provider and model
*/
get(provider = null, model = null) {
if (!provider || !this.cache || !this.cache[provider]) return null;
if (!model) return this.cache[provider];
const modelContextWindow = this.cache[provider][model];
if (!modelContextWindow) {
this.log("Invalid access to model context window - not found in cache", {
provider,
model,
});
return null;
}
return Number(modelContextWindow);
}
}
module.exports = { MODEL_MAP: new ContextWindowFinder() };