Overhaul AWS Bedrock provider (#3537)

* Patch AWS Bedrock provider for newer models and performance

* patch prompt constructor
This commit is contained in:
Timothy Carambat 2025-03-25 15:58:16 -07:00 committed by GitHub
parent c29f7669c4
commit 78c83383d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 1048 additions and 128 deletions

View File

@ -20,6 +20,7 @@
}, },
"dependencies": { "dependencies": {
"@anthropic-ai/sdk": "^0.39.0", "@anthropic-ai/sdk": "^0.39.0",
"@aws-sdk/client-bedrock-runtime": "^3.775.0",
"@azure/openai": "1.0.0-beta.10", "@azure/openai": "1.0.0-beta.10",
"@datastax/astra-db-ts": "^0.1.3", "@datastax/astra-db-ts": "^0.1.3",
"@google/generative-ai": "^0.7.1", "@google/generative-ai": "^0.7.1",

View File

@ -1,15 +1,18 @@
const { StringOutputParser } = require("@langchain/core/output_parsers"); const {
BedrockRuntimeClient,
ConverseCommand,
ConverseStreamCommand,
} = require("@aws-sdk/client-bedrock-runtime");
const { const {
writeResponseChunk, writeResponseChunk,
clientAbortedHandler, clientAbortedHandler,
formatChatHistory,
} = require("../../helpers/chat/responses"); } = require("../../helpers/chat/responses");
const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const { const {
LLMPerformanceMonitor, LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor"); } = require("../../helpers/chat/LLMPerformanceMonitor");
const { v4: uuidv4 } = require("uuid");
// Docs: https://js.langchain.com/v0.2/docs/integrations/chat/bedrock_converse
class AWSBedrockLLM { class AWSBedrockLLM {
/** /**
* These models do not support system prompts * These models do not support system prompts
@ -23,6 +26,7 @@ class AWSBedrockLLM {
"amazon.titan-text-lite-v1", "amazon.titan-text-lite-v1",
"cohere.command-text-v14", "cohere.command-text-v14",
"cohere.command-light-text-v14", "cohere.command-light-text-v14",
"us.deepseek.r1-v1:0",
]; ];
constructor(embedder = null, modelPreference = null) { constructor(embedder = null, modelPreference = null) {
@ -51,6 +55,17 @@ class AWSBedrockLLM {
user: this.promptWindowLimit() * 0.7, user: this.promptWindowLimit() * 0.7,
}; };
this.bedrockClient = new BedrockRuntimeClient({
region: process.env.AWS_BEDROCK_LLM_REGION,
credentials: {
accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
...(this.authMethod === "sessionToken"
? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
: {}),
},
});
this.embedder = embedder ?? new NativeEmbedder(); this.embedder = embedder ?? new NativeEmbedder();
this.defaultTemp = 0.7; this.defaultTemp = 0.7;
this.#log( this.#log(
@ -69,62 +84,6 @@ class AWSBedrockLLM {
return method; return method;
} }
#bedrockClient({ temperature = 0.7 }) {
const { ChatBedrockConverse } = require("@langchain/aws");
return new ChatBedrockConverse({
model: this.model,
region: process.env.AWS_BEDROCK_LLM_REGION,
credentials: {
accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
...(this.authMethod === "sessionToken"
? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
: {}),
},
temperature,
});
}
// For streaming we use Langchain's wrapper to handle weird chunks
// or otherwise absorb headaches that can arise from Bedrock models
#convertToLangchainPrototypes(chats = []) {
const {
HumanMessage,
SystemMessage,
AIMessage,
} = require("@langchain/core/messages");
const langchainChats = [];
const roleToMessageMap = {
system: SystemMessage,
user: HumanMessage,
assistant: AIMessage,
};
for (const chat of chats) {
if (!roleToMessageMap.hasOwnProperty(chat.role)) continue;
// When a model does not support system prompts, we need to handle it.
// We will add a new message that simulates the system prompt via a user message and AI response.
// This will allow the model to respond without crashing but we can still inject context.
if (
this.noSystemPromptModels.includes(this.model) &&
chat.role === "system"
) {
this.#log(
`Model does not support system prompts! Simulating system prompt via Human/AI message pairs.`
);
langchainChats.push(new HumanMessage({ content: chat.content }));
langchainChats.push(new AIMessage({ content: "Okay." }));
continue;
}
const MessageClass = roleToMessageMap[chat.role];
langchainChats.push(new MessageClass({ content: chat.content }));
}
return langchainChats;
}
#appendContext(contextTexts = []) { #appendContext(contextTexts = []) {
if (!contextTexts || !contextTexts.length) return ""; if (!contextTexts || !contextTexts.length) return "";
return ( return (
@ -167,22 +126,21 @@ class AWSBedrockLLM {
/** /**
* Generates appropriate content array for a message + attachments. * Generates appropriate content array for a message + attachments.
* TODO: Implement this - attachments are not supported yet for Bedrock.
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}} * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
* @returns {string|object[]} * @returns {string|object[]}
*/ */
#generateContent({ userPrompt, attachments = [] }) { #generateContent({ userPrompt, attachments = [] }) {
if (!attachments.length) { if (!attachments.length) return [{ text: userPrompt }];
return { content: userPrompt };
}
const content = [{ type: "text", text: userPrompt }]; // const content = [{ type: "text", text: userPrompt }];
for (let attachment of attachments) { // for (let attachment of attachments) {
content.push({ // content.push({
type: "image_url", // type: "image_url",
image_url: attachment.contentString, // image_url: attachment.contentString,
}); // });
} // }
return { content: content.flat() }; // return { content: content.flat() };
} }
/** /**
@ -195,72 +153,125 @@ class AWSBedrockLLM {
contextTexts = [], contextTexts = [],
chatHistory = [], chatHistory = [],
userPrompt = "", userPrompt = "",
attachments = [], _attachments = [],
}) { }) {
// AWS Mistral models do not support system prompts let prompt = [
if (this.model.startsWith("mistral")) {
return [ role: "system",
...formatChatHistory(chatHistory, this.#generateContent, "spread"), content: [
{ text: `${systemPrompt}${this.#appendContext(contextTexts)}` },
],
},
];
// If the model does not support system prompts, we need to add a user message and assistant message
if (this.noSystemPromptModels.includes(this.model)) {
prompt = [
{ {
role: "user", role: "user",
...this.#generateContent({ userPrompt, attachments }), content: [
{ text: `${systemPrompt}${this.#appendContext(contextTexts)}` },
],
},
{
role: "assistant",
content: [{ text: "Okay." }],
}, },
]; ];
}
const prompt = {
role: "system",
content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
};
return [ return [
prompt, ...prompt,
...formatChatHistory(chatHistory, this.#generateContent, "spread"), ...chatHistory.map((msg) => ({
role: msg.role,
content: this.#generateContent({
userPrompt: msg.content,
attachments: msg.attachments,
}),
})),
{ {
role: "user", role: "user",
...this.#generateContent({ userPrompt, attachments }), content: this.#generateContent({
userPrompt: userPrompt,
attachments: [],
}),
}, },
]; ];
} }
/**
* Parses and prepends reasoning from the response and returns the full text response.
* @param {Object} response
* @returns {string}
*/
#parseReasoningFromResponse({ content = [] }) {
let textResponse = content[0]?.text;
if (
!!content?.[1]?.reasoningContent &&
content?.[1]?.reasoningContent?.reasoningText?.text?.trim().length > 0
)
textResponse = `<think>${content?.[1]?.reasoningContent?.reasoningText?.text}</think>${textResponse}`;
return textResponse;
}
async getChatCompletion(messages = null, { temperature = 0.7 }) { async getChatCompletion(messages = null, { temperature = 0.7 }) {
const model = this.#bedrockClient({ temperature }); const hasSystem = messages[0]?.role === "system";
const [system, ...history] = hasSystem ? messages : [null, ...messages];
const result = await LLMPerformanceMonitor.measureAsyncFunction( const result = await LLMPerformanceMonitor.measureAsyncFunction(
model this.bedrockClient
.pipe(new StringOutputParser()) .send(
.invoke(this.#convertToLangchainPrototypes(messages)) new ConverseCommand({
modelId: this.model,
messages: history,
inferenceConfig: {
maxTokens: this.promptWindowLimit(),
temperature,
},
system: !!system ? system.content : undefined,
})
)
.catch((e) => { .catch((e) => {
throw new Error( throw new Error(
`AWSBedrock::getChatCompletion failed to communicate with Bedrock client. ${e.message}` `AWSBedrock::getChatCompletion failed to communicate with Bedrock client. ${e.message}`
); );
}) }),
messages,
false
); );
if (!result.output || result.output.length === 0) return null; const response = result.output;
if (!response || !response?.output) return null;
// Langchain does not return the usage metrics in the response so we estimate them
const promptTokens = LLMPerformanceMonitor.countTokens(messages);
const completionTokens = LLMPerformanceMonitor.countTokens([
{ content: result.output },
]);
return { return {
textResponse: result.output, textResponse: this.#parseReasoningFromResponse(response.output?.message),
metrics: { metrics: {
prompt_tokens: promptTokens, prompt_tokens: response?.usage?.inputTokens,
completion_tokens: completionTokens, completion_tokens: response?.usage?.outputTokens,
total_tokens: promptTokens + completionTokens, total_tokens: response?.usage?.totalTokens,
outputTps: completionTokens / result.duration, outputTps:
response?.usage?.outputTokens / (response?.metrics?.latencyMs / 1000),
duration: result.duration, duration: result.duration,
}, },
}; };
} }
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) { async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
const model = this.#bedrockClient({ temperature }); const hasSystem = messages[0]?.role === "system";
const [system, ...history] = hasSystem ? messages : [null, ...messages];
const measuredStreamRequest = await LLMPerformanceMonitor.measureStream( const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
model this.bedrockClient.send(
.pipe(new StringOutputParser()) new ConverseStreamCommand({
.stream(this.#convertToLangchainPrototypes(messages)), modelId: this.model,
messages messages: history,
inferenceConfig: { maxTokens: this.promptWindowLimit(), temperature },
system: !!system ? system.content : undefined,
})
),
messages,
false
); );
return measuredStreamRequest; return measuredStreamRequest;
} }
@ -275,12 +286,15 @@ class AWSBedrockLLM {
*/ */
handleStream(response, stream, responseProps) { handleStream(response, stream, responseProps) {
const { uuid = uuidv4(), sources = [] } = responseProps; const { uuid = uuidv4(), sources = [] } = responseProps;
let hasUsageMetrics = false;
let usage = {
prompt_tokens: 0,
completion_tokens: 0,
};
return new Promise(async (resolve) => { return new Promise(async (resolve) => {
let fullText = ""; let fullText = "";
let usage = { let reasoningText = "";
completion_tokens: 0,
};
// Establish listener to early-abort a streaming response // Establish listener to early-abort a streaming response
// in case things go sideways or the user does not like the response. // in case things go sideways or the user does not like the response.
@ -293,25 +307,82 @@ class AWSBedrockLLM {
response.on("close", handleAbort); response.on("close", handleAbort);
try { try {
for await (const chunk of stream) { for await (const chunk of stream.stream) {
if (chunk === undefined) if (chunk === undefined)
throw new Error( throw new Error(
"Stream returned undefined chunk. Aborting reply - check model provider logs." "Stream returned undefined chunk. Aborting reply - check model provider logs."
); );
const content = chunk.hasOwnProperty("content") const action = Object.keys(chunk)[0];
? chunk.content if (action === "metadata") {
: chunk; hasUsageMetrics = true;
fullText += content; usage.prompt_tokens = chunk.metadata?.usage?.inputTokens ?? 0;
if (!!content) usage.completion_tokens++; // Dont count empty chunks usage.completion_tokens = chunk.metadata?.usage?.outputTokens ?? 0;
writeResponseChunk(response, { usage.total_tokens = chunk.metadata?.usage?.totalTokens ?? 0;
uuid, }
sources: [],
type: "textResponseChunk", if (action === "contentBlockDelta") {
textResponse: content, const token = chunk.contentBlockDelta?.delta?.text;
close: false, const reasoningToken =
error: false, chunk.contentBlockDelta?.delta?.reasoningContent?.text;
});
// Reasoning models will always return the reasoning text before the token text.
if (reasoningToken) {
// If the reasoning text is empty (''), we need to initialize it
// and send the first chunk of reasoning text.
if (reasoningText.length === 0) {
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: `<think>${reasoningToken}`,
close: false,
error: false,
});
reasoningText += `<think>${reasoningToken}`;
continue;
} else {
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: reasoningToken,
close: false,
error: false,
});
reasoningText += reasoningToken;
}
}
// If the reasoning text is not empty, but the reasoning token is empty
// and the token text is not empty we need to close the reasoning text and begin sending the token text.
if (!!reasoningText && !reasoningToken && token) {
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: `</think>`,
close: false,
error: false,
});
fullText += `${reasoningText}</think>`;
reasoningText = "";
}
if (token) {
fullText += token;
// If we never saw a usage metric, we can estimate them by number of completion chunks
if (!hasUsageMetrics) usage.completion_tokens++;
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: token,
close: false,
error: false,
});
}
}
} }
writeResponseChunk(response, { writeResponseChunk(response, {
@ -326,18 +397,18 @@ class AWSBedrockLLM {
stream?.endMeasurement(usage); stream?.endMeasurement(usage);
resolve(fullText); resolve(fullText);
} catch (error) { } catch (error) {
console.log(`\x1b[43m\x1b[34m[STREAMING ERROR]\x1b[0m ${e.message}`);
writeResponseChunk(response, { writeResponseChunk(response, {
uuid, uuid,
type: "abort",
textResponse: null,
sources: [], sources: [],
type: "textResponseChunk",
textResponse: "",
close: true, close: true,
error: `AWSBedrock:streaming - could not stream chat. ${ error: `AWSBedrock:streaming - could not stream chat. ${error?.cause ?? error.message}`,
error?.cause ?? error.message
}`,
}); });
response.removeListener("close", handleAbort); response.removeListener("close", handleAbort);
stream?.endMeasurement(usage); stream?.endMeasurement(usage);
resolve(fullText); // Return what we currently have - if anything.
} }
}); });
} }

File diff suppressed because it is too large Load Diff