add Dynamic max_tokens retrieval for Anthropic models (#5255)
This commit is contained in:
parent
9d242bc053
commit
1b0add0318
@ -36,11 +36,17 @@ class AnthropicLLM {
|
|||||||
user: this.promptWindowLimit() * 0.7,
|
user: this.promptWindowLimit() * 0.7,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
this.maxTokens = null;
|
||||||
this.embedder = embedder ?? new NativeEmbedder();
|
this.embedder = embedder ?? new NativeEmbedder();
|
||||||
this.defaultTemp = 0.7;
|
this.defaultTemp = 0.7;
|
||||||
this.log(
|
this.log(
|
||||||
`Initialized with ${this.model}. Cache ${this.cacheControl ? `enabled (${this.cacheControl.ttl})` : "disabled"}`
|
`Initialized with ${this.model}. Cache ${this.cacheControl ? `enabled (${this.cacheControl.ttl})` : "disabled"}`
|
||||||
);
|
);
|
||||||
|
|
||||||
|
AnthropicLLM.fetchModelMaxTokens(this.model).then((maxTokens) => {
|
||||||
|
this.maxTokens = maxTokens;
|
||||||
|
this.log(`Model ${this.model} max tokens: ${this.maxTokens}`);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
log(text, ...args) {
|
log(text, ...args) {
|
||||||
@ -63,6 +69,35 @@ class AnthropicLLM {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async assertModelMaxTokens() {
|
||||||
|
if (this.maxTokens) return this.maxTokens;
|
||||||
|
this.maxTokens = await AnthropicLLM.fetchModelMaxTokens(this.model);
|
||||||
|
return this.maxTokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetches the maximum number of tokens the model should generate in its response.
|
||||||
|
* This varies per model but will fallback to 4096 if the model is not found.
|
||||||
|
* @param {string} modelName - The name of the model to fetch the max tokens for
|
||||||
|
* @returns {Promise<number>} The maximum output tokens limit for API calls.
|
||||||
|
*/
|
||||||
|
static async fetchModelMaxTokens(
|
||||||
|
modelName = process.env.ANTHROPIC_MODEL_PREF
|
||||||
|
) {
|
||||||
|
try {
|
||||||
|
const AnthropicAI = require("@anthropic-ai/sdk");
|
||||||
|
/** @type {import("@anthropic-ai/sdk").Anthropic} */
|
||||||
|
const anthropic = new AnthropicAI({
|
||||||
|
apiKey: process.env.ANTHROPIC_API_KEY,
|
||||||
|
});
|
||||||
|
const model = await anthropic.models.retrieve(modelName);
|
||||||
|
return Number(model.max_tokens ?? 4096);
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error fetching model max tokens for ${modelName}:`, error);
|
||||||
|
return 4096;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parses the cache control ENV variable
|
* Parses the cache control ENV variable
|
||||||
*
|
*
|
||||||
@ -152,12 +187,13 @@ class AnthropicLLM {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async getChatCompletion(messages = null, { temperature = 0.7 }) {
|
async getChatCompletion(messages = null, { temperature = 0.7 }) {
|
||||||
|
await this.assertModelMaxTokens();
|
||||||
try {
|
try {
|
||||||
const systemContent = messages[0].content;
|
const systemContent = messages[0].content;
|
||||||
const result = await LLMPerformanceMonitor.measureAsyncFunction(
|
const result = await LLMPerformanceMonitor.measureAsyncFunction(
|
||||||
this.anthropic.messages.create({
|
this.anthropic.messages.create({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
max_tokens: 4096,
|
max_tokens: this.maxTokens,
|
||||||
system: this.#buildSystemPrompt(systemContent),
|
system: this.#buildSystemPrompt(systemContent),
|
||||||
messages: messages.slice(1), // Pop off the system message
|
messages: messages.slice(1), // Pop off the system message
|
||||||
temperature: Number(temperature ?? this.defaultTemp),
|
temperature: Number(temperature ?? this.defaultTemp),
|
||||||
@ -187,11 +223,12 @@ class AnthropicLLM {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
|
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
|
||||||
|
await this.assertModelMaxTokens();
|
||||||
const systemContent = messages[0].content;
|
const systemContent = messages[0].content;
|
||||||
const measuredStreamRequest = await LLMPerformanceMonitor.measureStream({
|
const measuredStreamRequest = await LLMPerformanceMonitor.measureStream({
|
||||||
func: this.anthropic.messages.stream({
|
func: this.anthropic.messages.stream({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
max_tokens: 4096,
|
max_tokens: this.maxTokens,
|
||||||
system: this.#buildSystemPrompt(systemContent),
|
system: this.#buildSystemPrompt(systemContent),
|
||||||
messages: messages.slice(1), // Pop off the system message
|
messages: messages.slice(1), // Pop off the system message
|
||||||
temperature: Number(temperature ?? this.defaultTemp),
|
temperature: Number(temperature ?? this.defaultTemp),
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
const Anthropic = require("@anthropic-ai/sdk");
|
const Anthropic = require("@anthropic-ai/sdk");
|
||||||
|
const { AnthropicLLM } = require("../../../AiProviders/anthropic");
|
||||||
const { RetryError } = require("../error.js");
|
const { RetryError } = require("../error.js");
|
||||||
const Provider = require("./ai-provider.js");
|
const Provider = require("./ai-provider.js");
|
||||||
const { v4 } = require("uuid");
|
const { v4 } = require("uuid");
|
||||||
@ -11,6 +12,7 @@ const { getAnythingLLMUserAgent } = require("../../../../endpoints/utils");
|
|||||||
*/
|
*/
|
||||||
class AnthropicProvider extends Provider {
|
class AnthropicProvider extends Provider {
|
||||||
model;
|
model;
|
||||||
|
maxTokens = null;
|
||||||
|
|
||||||
constructor(config = {}) {
|
constructor(config = {}) {
|
||||||
const {
|
const {
|
||||||
@ -39,6 +41,17 @@ class AnthropicProvider extends Provider {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetches the maximum number of tokens the model should generate in its response.
|
||||||
|
* This varies per model but will fallback to 4096 if the model is not found.
|
||||||
|
* @returns {Promise<number>} The maximum output tokens limit for API calls.
|
||||||
|
*/
|
||||||
|
async assertModelMaxTokens() {
|
||||||
|
if (this.maxTokens) return this.maxTokens;
|
||||||
|
this.maxTokens = await AnthropicLLM.fetchModelMaxTokens(this.model);
|
||||||
|
return this.maxTokens;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parses the cache control ENV variable
|
* Parses the cache control ENV variable
|
||||||
*
|
*
|
||||||
@ -227,6 +240,7 @@ class AnthropicProvider extends Provider {
|
|||||||
* @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>} - The result of the chat completion.
|
* @returns {Promise<{ functionCall: any, textResponse: string, uuid: string }>} - The result of the chat completion.
|
||||||
*/
|
*/
|
||||||
async stream(messages, functions = [], eventHandler = null) {
|
async stream(messages, functions = [], eventHandler = null) {
|
||||||
|
await this.assertModelMaxTokens();
|
||||||
this.resetUsage();
|
this.resetUsage();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -235,7 +249,7 @@ class AnthropicProvider extends Provider {
|
|||||||
const response = await this.client.messages.create(
|
const response = await this.client.messages.create(
|
||||||
{
|
{
|
||||||
model: this.model,
|
model: this.model,
|
||||||
max_tokens: 4096,
|
max_tokens: this.maxTokens,
|
||||||
system: this.#buildSystemPrompt(systemPrompt),
|
system: this.#buildSystemPrompt(systemPrompt),
|
||||||
messages: chats,
|
messages: chats,
|
||||||
stream: true,
|
stream: true,
|
||||||
@ -374,6 +388,7 @@ class AnthropicProvider extends Provider {
|
|||||||
* @returns The completion.
|
* @returns The completion.
|
||||||
*/
|
*/
|
||||||
async complete(messages, functions = []) {
|
async complete(messages, functions = []) {
|
||||||
|
await this.assertModelMaxTokens();
|
||||||
this.resetUsage();
|
this.resetUsage();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -381,7 +396,7 @@ class AnthropicProvider extends Provider {
|
|||||||
const response = await this.client.messages.create(
|
const response = await this.client.messages.create(
|
||||||
{
|
{
|
||||||
model: this.model,
|
model: this.model,
|
||||||
max_tokens: 4096,
|
max_tokens: this.maxTokens,
|
||||||
system: this.#buildSystemPrompt(systemPrompt),
|
system: this.#buildSystemPrompt(systemPrompt),
|
||||||
messages: chats,
|
messages: chats,
|
||||||
stream: false,
|
stream: false,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user