Add dynamic max_tokens retrieval for Anthropic models (#5255)
parent 9d242bc053
commit 1b0add0318
@@ -36,11 +36,17 @@ class AnthropicLLM {
       user: this.promptWindowLimit() * 0.7,
     };
 
+    this.maxTokens = null;
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7;
     this.log(
       `Initialized with ${this.model}. Cache ${this.cacheControl ? `enabled (${this.cacheControl.ttl})` : "disabled"}`
     );
+
+    AnthropicLLM.fetchModelMaxTokens(this.model).then((maxTokens) => {
+      this.maxTokens = maxTokens;
+      this.log(`Model ${this.model} max tokens: ${this.maxTokens}`);
+    });
   }
 
   log(text, ...args) {
@@ -63,6 +69,35 @@ class AnthropicLLM {
     return true;
   }
 
+  async assertModelMaxTokens() {
+    if (this.maxTokens) return this.maxTokens;
+    this.maxTokens = await AnthropicLLM.fetchModelMaxTokens(this.model);
+    return this.maxTokens;
+  }
+
+  /**
+   * Fetches the maximum number of tokens the model should generate in its response.
+   * This varies per model but will fallback to 4096 if the model is not found.
+   * @param {string} modelName - The name of the model to fetch the max tokens for
+   * @returns {Promise<number>} The maximum output tokens limit for API calls.
+   */
+  static async fetchModelMaxTokens(
+    modelName = process.env.ANTHROPIC_MODEL_PREF
+  ) {
+    try {
+      const AnthropicAI = require("@anthropic-ai/sdk");
+      /** @type {import("@anthropic-ai/sdk").Anthropic} */
+      const anthropic = new AnthropicAI({
+        apiKey: process.env.ANTHROPIC_API_KEY,
+      });
+      const model = await anthropic.models.retrieve(modelName);
+      return Number(model.max_tokens ?? 4096);
+    } catch (error) {
+      console.error(`Error fetching model max tokens for ${modelName}:`, error);
+      return 4096;
+    }
+  }
+
   /**
    * Parses the cache control ENV variable
    *
@@ -152,12 +187,13 @@ class AnthropicLLM {
   }
 
   async getChatCompletion(messages = null, { temperature = 0.7 }) {
+    await this.assertModelMaxTokens();
     try {
       const systemContent = messages[0].content;
       const result = await LLMPerformanceMonitor.measureAsyncFunction(
         this.anthropic.messages.create({
           model: this.model,
-          max_tokens: 4096,
+          max_tokens: this.maxTokens,
           system: this.#buildSystemPrompt(systemContent),
           messages: messages.slice(1), // Pop off the system message
           temperature: Number(temperature ?? this.defaultTemp),
@@ -187,11 +223,12 @@ class AnthropicLLM {
   }
 
   async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
+    await this.assertModelMaxTokens();
     const systemContent = messages[0].content;
     const measuredStreamRequest = await LLMPerformanceMonitor.measureStream({
       func: this.anthropic.messages.stream({
         model: this.model,
-        max_tokens: 4096,
+        max_tokens: this.maxTokens,
         system: this.#buildSystemPrompt(systemContent),
         messages: messages.slice(1), // Pop off the system message
        temperature: Number(temperature ?? this.defaultTemp),
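
The hunks above cover the AnthropicLLM provider: the constructor kicks off a non-blocking fetchModelMaxTokens() lookup, every completion path awaits assertModelMaxTokens() before hitting the API, and the hard-coded max_tokens: 4096 becomes the per-model value. As a minimal standalone sketch of the same pattern, assuming only @anthropic-ai/sdk and an ANTHROPIC_API_KEY in the environment (resolveMaxTokens is a hypothetical name, not the repo's actual module):

// Sketch of the dynamic max_tokens lookup this commit introduces.
const Anthropic = require("@anthropic-ai/sdk");

const FALLBACK_MAX_TOKENS = 4096; // used whenever the limit cannot be fetched

async function resolveMaxTokens(modelName) {
  try {
    const client = new Anthropic({ apiKey: process.env.ANTHROPIC_API_KEY });
    // The SDK's Models API returns model metadata; max_tokens may be absent
    // on the response, which is why the commit keeps the ?? 4096 fallback.
    const model = await client.models.retrieve(modelName);
    return Number(model.max_tokens ?? FALLBACK_MAX_TOKENS);
  } catch {
    return FALLBACK_MAX_TOKENS; // unknown model, bad key, network error, etc.
  }
}

// Usage: resolve once, cache on the instance, await before each request.
// resolveMaxTokens("claude-3-5-haiku-latest").then(console.log);

The remaining hunks update the agent-side AnthropicProvider to reuse the same static helper: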
@@ -1,4 +1,5 @@
 const Anthropic = require("@anthropic-ai/sdk");
+const { AnthropicLLM } = require("../../../AiProviders/anthropic");
 const { RetryError } = require("../error.js");
 const Provider = require("./ai-provider.js");
 const { v4 } = require("uuid");
@@ -11,6 +12,7 @@ const { getAnythingLLMUserAgent } = require("../../../../endpoints/utils");
  */
 class AnthropicProvider extends Provider {
   model;
+  maxTokens = null;
 
   constructor(config = {}) {
     const {
@@ -39,6 +41,17 @@ class AnthropicProvider extends Provider {
     return true;
   }
 
+  /**
+   * Fetches the maximum number of tokens the model should generate in its response.
+   * This varies per model but will fallback to 4096 if the model is not found.
+   * @returns {Promise<number>} The maximum output tokens limit for API calls.
+   */
+  async assertModelMaxTokens() {
+    if (this.maxTokens) return this.maxTokens;
+    this.maxTokens = await AnthropicLLM.fetchModelMaxTokens(this.model);
+    return this.maxTokens;
+  }
+
   /**
    * Parses the cache control ENV variable
    *
@@ -227,6 +240,7 @@ class AnthropicProvider extends Provider {
    * @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>} - The result of the chat completion.
    */
   async stream(messages, functions = [], eventHandler = null) {
+    await this.assertModelMaxTokens();
     this.resetUsage();
 
     try {
@@ -235,7 +249,7 @@ class AnthropicProvider extends Provider {
       const response = await this.client.messages.create(
         {
           model: this.model,
-          max_tokens: 4096,
+          max_tokens: this.maxTokens,
           system: this.#buildSystemPrompt(systemPrompt),
           messages: chats,
           stream: true,
@@ -374,6 +388,7 @@ class AnthropicProvider extends Provider {
    * @returns The completion.
    */
   async complete(messages, functions = []) {
+    await this.assertModelMaxTokens();
     this.resetUsage();
 
     try {
@@ -381,7 +396,7 @@ class AnthropicProvider extends Provider {
       const response = await this.client.messages.create(
         {
           model: this.model,
-          max_tokens: 4096,
+          max_tokens: this.maxTokens,
           system: this.#buildSystemPrompt(systemPrompt),
           messages: chats,
           stream: false,
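
Both classes now guard every messages.create / messages.stream call with await this.assertModelMaxTokens(), so the limit is resolved before any request and cached on the instance afterwards. A hedged call-site sketch, assuming an instance shaped like the diff above (hypothetical wiring, not AnythingLLM's actual entry point):

// Shows that the first call resolves the limit and later calls reuse it.
async function ask(llm, messages) {
  await llm.assertModelMaxTokens(); // no-op once this.maxTokens is set
  return llm.anthropic.messages.create({
    model: llm.model,
    max_tokens: llm.maxTokens, // dynamic per-model limit, 4096 fallback
    system: messages[0].content,
    messages: messages.slice(1), // pop off the system message
  });
}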