Overhaul AWS Bedrock provider (#3537)

* Patch AWS Bedrock provider for newer models and performance

* patch prompt constructor
Timothy Carambat 2025-03-25 15:58:16 -07:00 committed by GitHub
parent c29f7669c4
commit 78c83383d8
3 changed files with 1048 additions and 128 deletions


@@ -20,6 +20,7 @@
},
"dependencies": {
"@anthropic-ai/sdk": "^0.39.0",
"@aws-sdk/client-bedrock-runtime": "^3.775.0",
"@azure/openai": "1.0.0-beta.10",
"@datastax/astra-db-ts": "^0.1.3",
"@google/generative-ai": "^0.7.1",


@@ -1,15 +1,18 @@
const { StringOutputParser } = require("@langchain/core/output_parsers");
const {
BedrockRuntimeClient,
ConverseCommand,
ConverseStreamCommand,
} = require("@aws-sdk/client-bedrock-runtime");
const {
writeResponseChunk,
clientAbortedHandler,
formatChatHistory,
} = require("../../helpers/chat/responses");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const { v4: uuidv4 } = require("uuid");
// Docs: https://js.langchain.com/v0.2/docs/integrations/chat/bedrock_converse
class AWSBedrockLLM {
/**
* These models do not support system prompts
@@ -23,6 +26,7 @@ class AWSBedrockLLM {
"amazon.titan-text-lite-v1",
"cohere.command-text-v14",
"cohere.command-light-text-v14",
"us.deepseek.r1-v1:0",
];
constructor(embedder = null, modelPreference = null) {
@@ -51,6 +55,17 @@ class AWSBedrockLLM {
user: this.promptWindowLimit() * 0.7,
};
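// Native AWS SDK Bedrock runtime client, shared by the Converse and ConverseStream calls below.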
this.bedrockClient = new BedrockRuntimeClient({
region: process.env.AWS_BEDROCK_LLM_REGION,
credentials: {
accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
...(this.authMethod === "sessionToken"
? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
: {}),
},
});
this.embedder = embedder ?? new NativeEmbedder();
this.defaultTemp = 0.7;
this.#log(
@@ -69,62 +84,6 @@ class AWSBedrockLLM {
return method;
}
#bedrockClient({ temperature = 0.7 }) {
const { ChatBedrockConverse } = require("@langchain/aws");
return new ChatBedrockConverse({
model: this.model,
region: process.env.AWS_BEDROCK_LLM_REGION,
credentials: {
accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
...(this.authMethod === "sessionToken"
? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
: {}),
},
temperature,
});
}
// For streaming we use Langchain's wrapper to handle weird chunks
// or otherwise absorb headaches that can arise from Bedrock models
#convertToLangchainPrototypes(chats = []) {
const {
HumanMessage,
SystemMessage,
AIMessage,
} = require("@langchain/core/messages");
const langchainChats = [];
const roleToMessageMap = {
system: SystemMessage,
user: HumanMessage,
assistant: AIMessage,
};
for (const chat of chats) {
if (!roleToMessageMap.hasOwnProperty(chat.role)) continue;
// When a model does not support system prompts, we need to handle it.
// We will add a new message that simulates the system prompt via a user message and AI response.
// This will allow the model to respond without crashing but we can still inject context.
if (
this.noSystemPromptModels.includes(this.model) &&
chat.role === "system"
) {
this.#log(
`Model does not support system prompts! Simulating system prompt via Human/AI message pairs.`
);
langchainChats.push(new HumanMessage({ content: chat.content }));
langchainChats.push(new AIMessage({ content: "Okay." }));
continue;
}
const MessageClass = roleToMessageMap[chat.role];
langchainChats.push(new MessageClass({ content: chat.content }));
}
return langchainChats;
}
#appendContext(contextTexts = []) {
if (!contextTexts || !contextTexts.length) return "";
return (
@@ -167,22 +126,21 @@ class AWSBedrockLLM {
/**
* Generates appropriate content array for a message + attachments.
* TODO: Implement this - attachments are not supported yet for Bedrock.
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
* @returns {string|object[]}
*/
#generateContent({ userPrompt, attachments = [] }) {
if (!attachments.length) {
return { content: userPrompt };
}
if (!attachments.length) return [{ text: userPrompt }];
const content = [{ type: "text", text: userPrompt }];
for (let attachment of attachments) {
content.push({
type: "image_url",
image_url: attachment.contentString,
});
}
return { content: content.flat() };
// const content = [{ type: "text", text: userPrompt }];
// for (let attachment of attachments) {
// content.push({
// type: "image_url",
// image_url: attachment.contentString,
// });
// }
// return { content: content.flat() };
}
/**
@@ -195,72 +153,125 @@ class AWSBedrockLLM {
contextTexts = [],
chatHistory = [],
userPrompt = "",
attachments = [],
_attachments = [],
}) {
// AWS Mistral models do not support system prompts
if (this.model.startsWith("mistral"))
return [
...formatChatHistory(chatHistory, this.#generateContent, "spread"),
let prompt = [
{
role: "user",
...this.#generateContent({ userPrompt, attachments }),
role: "system",
content: [
{ text: `${systemPrompt}${this.#appendContext(contextTexts)}` },
],
},
];
const prompt = {
role: "system",
content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
};
return [
prompt,
...formatChatHistory(chatHistory, this.#generateContent, "spread"),
// If the model does not support system prompts, we need to add a user message and assistant message
if (this.noSystemPromptModels.includes(this.model)) {
prompt = [
{
role: "user",
...this.#generateContent({ userPrompt, attachments }),
content: [
{ text: `${systemPrompt}${this.#appendContext(contextTexts)}` },
],
},
{
role: "assistant",
content: [{ text: "Okay." }],
},
];
}
return [
...prompt,
...chatHistory.map((msg) => ({
role: msg.role,
content: this.#generateContent({
userPrompt: msg.content,
attachments: msg.attachments,
}),
})),
{
role: "user",
content: this.#generateContent({
userPrompt: userPrompt,
attachments: [],
}),
},
];
}
/**
* Parses and prepends reasoning from the response and returns the full text response.
* @param {Object} response
* @returns {string}
*/
#parseReasoningFromResponse({ content = [] }) {
let textResponse = content[0]?.text;
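// Some reasoning-capable models (eg: us.deepseek.r1-v1:0) return a second content block
// with reasoningContent; when present, wrap it in <think> tags and prepend it to the reply.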
if (
!!content?.[1]?.reasoningContent &&
content?.[1]?.reasoningContent?.reasoningText?.text?.trim().length > 0
)
textResponse = `<think>${content?.[1]?.reasoningContent?.reasoningText?.text}</think>${textResponse}`;
return textResponse;
}
async getChatCompletion(messages = null, { temperature = 0.7 }) {
const model = this.#bedrockClient({ temperature });
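// The Converse API accepts the system prompt as a separate parameter, so split it off from the chat history.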
const hasSystem = messages[0]?.role === "system";
const [system, ...history] = hasSystem ? messages : [null, ...messages];
const result = await LLMPerformanceMonitor.measureAsyncFunction(
model
.pipe(new StringOutputParser())
.invoke(this.#convertToLangchainPrototypes(messages))
this.bedrockClient
.send(
new ConverseCommand({
modelId: this.model,
messages: history,
inferenceConfig: {
maxTokens: this.promptWindowLimit(),
temperature,
},
system: !!system ? system.content : undefined,
})
)
.catch((e) => {
throw new Error(
`AWSBedrock::getChatCompletion failed to communicate with Bedrock client. ${e.message}`
);
})
}),
messages,
false
);
if (!result.output || result.output.length === 0) return null;
// Langchain does not return the usage metrics in the response so we estimate them
const promptTokens = LLMPerformanceMonitor.countTokens(messages);
const completionTokens = LLMPerformanceMonitor.countTokens([
{ content: result.output },
]);
const response = result.output;
if (!response || !response?.output) return null;
return {
textResponse: result.output,
textResponse: this.#parseReasoningFromResponse(response.output?.message),
metrics: {
prompt_tokens: promptTokens,
completion_tokens: completionTokens,
total_tokens: promptTokens + completionTokens,
outputTps: completionTokens / result.duration,
prompt_tokens: response?.usage?.inputTokens,
completion_tokens: response?.usage?.outputTokens,
total_tokens: response?.usage?.totalTokens,
outputTps:
response?.usage?.outputTokens / (response?.metrics?.latencyMs / 1000),
duration: result.duration,
},
};
}
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
const model = this.#bedrockClient({ temperature });
const hasSystem = messages[0]?.role === "system";
const [system, ...history] = hasSystem ? messages : [null, ...messages];
const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
model
.pipe(new StringOutputParser())
.stream(this.#convertToLangchainPrototypes(messages)),
messages
this.bedrockClient.send(
new ConverseStreamCommand({
modelId: this.model,
messages: history,
inferenceConfig: { maxTokens: this.promptWindowLimit(), temperature },
system: !!system ? system.content : undefined,
})
),
messages,
false
);
return measuredStreamRequest;
}
@@ -275,12 +286,15 @@ class AWSBedrockLLM {
*/
handleStream(response, stream, responseProps) {
const { uuid = uuidv4(), sources = [] } = responseProps;
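// Track token usage reported by the stream's metadata event; if it never arrives, fall back to counting completion chunks.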
let hasUsageMetrics = false;
let usage = {
prompt_tokens: 0,
completion_tokens: 0,
};
return new Promise(async (resolve) => {
let fullText = "";
let usage = {
completion_tokens: 0,
};
let reasoningText = "";
// Establish listener to early-abort a streaming response
// in case things go sideways or the user does not like the response.
@@ -293,25 +307,82 @@ class AWSBedrockLLM {
response.on("close", handleAbort);
try {
for await (const chunk of stream) {
for await (const chunk of stream.stream) {
if (chunk === undefined)
throw new Error(
"Stream returned undefined chunk. Aborting reply - check model provider logs."
);
const content = chunk.hasOwnProperty("content")
? chunk.content
: chunk;
fullText += content;
if (!!content) usage.completion_tokens++; // Dont count empty chunks
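// Each ConverseStream event is an object keyed by its event type (eg: contentBlockDelta, metadata).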
const action = Object.keys(chunk)[0];
if (action === "metadata") {
hasUsageMetrics = true;
usage.prompt_tokens = chunk.metadata?.usage?.inputTokens ?? 0;
usage.completion_tokens = chunk.metadata?.usage?.outputTokens ?? 0;
usage.total_tokens = chunk.metadata?.usage?.totalTokens ?? 0;
}
if (action === "contentBlockDelta") {
const token = chunk.contentBlockDelta?.delta?.text;
const reasoningToken =
chunk.contentBlockDelta?.delta?.reasoningContent?.text;
// Reasoning models will always return the reasoning text before the token text.
if (reasoningToken) {
// If the reasoning text is empty (''), we need to initialize it
// and send the first chunk of reasoning text.
if (reasoningText.length === 0) {
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: content,
textResponse: `<think>${reasoningToken}`,
close: false,
error: false,
});
reasoningText += `<think>${reasoningToken}`;
continue;
} else {
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: reasoningToken,
close: false,
error: false,
});
reasoningText += reasoningToken;
}
}
// If the reasoning text is not empty, but the reasoning token is empty
// and the token text is not empty we need to close the reasoning text and begin sending the token text.
if (!!reasoningText && !reasoningToken && token) {
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: `</think>`,
close: false,
error: false,
});
fullText += `${reasoningText}</think>`;
reasoningText = "";
}
if (token) {
fullText += token;
// If we never saw a usage metric, we can estimate them by number of completion chunks
if (!hasUsageMetrics) usage.completion_tokens++;
writeResponseChunk(response, {
uuid,
sources: [],
type: "textResponseChunk",
textResponse: token,
close: false,
error: false,
});
}
}
}
writeResponseChunk(response, {
@@ -326,18 +397,18 @@ class AWSBedrockLLM {
stream?.endMeasurement(usage);
resolve(fullText);
} catch (error) {
console.log(`\x1b[43m\x1b[34m[STREAMING ERROR]\x1b[0m ${error.message}`);
writeResponseChunk(response, {
uuid,
type: "abort",
textResponse: null,
sources: [],
type: "textResponseChunk",
textResponse: "",
close: true,
error: `AWSBedrock:streaming - could not stream chat. ${
error?.cause ?? error.message
}`,
error: `AWSBedrock:streaming - could not stream chat. ${error?.cause ?? error.message}`,
});
response.removeListener("close", handleAbort);
stream?.endMeasurement(usage);
resolve(fullText); // Return what we currently have - if anything.
}
});
}

File diff suppressed because it is too large.