Overhaul AWS Bedrock provider (#3537)
* Patch AWS Bedrock provider for newer models and performance
* Patch prompt constructor
This commit is contained in:
parent c29f7669c4
commit 78c83383d8
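
This change swaps LangChain's ChatBedrockConverse wrapper for direct calls to the Bedrock Converse API via @aws-sdk/client-bedrock-runtime. For orientation, below is a minimal standalone sketch of the kind of request the reworked provider now issues; the model id, region, and prompt text are placeholder values for illustration and are not taken from this commit.

// Sketch only: direct Converse API call with @aws-sdk/client-bedrock-runtime.
const {
  BedrockRuntimeClient,
  ConverseCommand,
} = require("@aws-sdk/client-bedrock-runtime");

async function demo() {
  // Placeholder region; credentials resolve from the default AWS credential chain.
  const client = new BedrockRuntimeClient({ region: "us-east-1" });
  const response = await client.send(
    new ConverseCommand({
      modelId: "anthropic.claude-3-haiku-20240307-v1:0", // placeholder model id
      system: [{ text: "You are a helpful assistant." }],
      messages: [{ role: "user", content: [{ text: "Hello!" }] }],
      inferenceConfig: { maxTokens: 512, temperature: 0.7 },
    })
  );
  // Text lives in output.message.content; token usage and latency come back on the response.
  console.log(response.output?.message?.content?.[0]?.text);
  console.log(response.usage, response.metrics?.latencyMs);
}

demo().catch(console.error);
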
@@ -20,6 +20,7 @@
  },
  "dependencies": {
    "@anthropic-ai/sdk": "^0.39.0",
    "@aws-sdk/client-bedrock-runtime": "^3.775.0",
    "@azure/openai": "1.0.0-beta.10",
    "@datastax/astra-db-ts": "^0.1.3",
    "@google/generative-ai": "^0.7.1",
@@ -1,15 +1,18 @@
const { StringOutputParser } = require("@langchain/core/output_parsers");
const {
  BedrockRuntimeClient,
  ConverseCommand,
  ConverseStreamCommand,
} = require("@aws-sdk/client-bedrock-runtime");
const {
  writeResponseChunk,
  clientAbortedHandler,
  formatChatHistory,
} = require("../../helpers/chat/responses");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const { v4: uuidv4 } = require("uuid");

// Docs: https://js.langchain.com/v0.2/docs/integrations/chat/bedrock_converse
class AWSBedrockLLM {
  /**
   * These models do not support system prompts
@@ -23,6 +26,7 @@ class AWSBedrockLLM {
    "amazon.titan-text-lite-v1",
    "cohere.command-text-v14",
    "cohere.command-light-text-v14",
    "us.deepseek.r1-v1:0",
  ];

  constructor(embedder = null, modelPreference = null) {
@@ -51,6 +55,17 @@ class AWSBedrockLLM {
      user: this.promptWindowLimit() * 0.7,
    };

    this.bedrockClient = new BedrockRuntimeClient({
      region: process.env.AWS_BEDROCK_LLM_REGION,
      credentials: {
        accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
        secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
        ...(this.authMethod === "sessionToken"
          ? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
          : {}),
      },
    });

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
    this.#log(
@@ -69,62 +84,6 @@ class AWSBedrockLLM {
    return method;
  }

  #bedrockClient({ temperature = 0.7 }) {
    const { ChatBedrockConverse } = require("@langchain/aws");
    return new ChatBedrockConverse({
      model: this.model,
      region: process.env.AWS_BEDROCK_LLM_REGION,
      credentials: {
        accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
        secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
        ...(this.authMethod === "sessionToken"
          ? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
          : {}),
      },
      temperature,
    });
  }

  // For streaming we use Langchain's wrapper to handle weird chunks
  // or otherwise absorb headaches that can arise from Bedrock models
  #convertToLangchainPrototypes(chats = []) {
    const {
      HumanMessage,
      SystemMessage,
      AIMessage,
    } = require("@langchain/core/messages");
    const langchainChats = [];
    const roleToMessageMap = {
      system: SystemMessage,
      user: HumanMessage,
      assistant: AIMessage,
    };

    for (const chat of chats) {
      if (!roleToMessageMap.hasOwnProperty(chat.role)) continue;

      // When a model does not support system prompts, we need to handle it.
      // We will add a new message that simulates the system prompt via a user message and AI response.
      // This will allow the model to respond without crashing but we can still inject context.
      if (
        this.noSystemPromptModels.includes(this.model) &&
        chat.role === "system"
      ) {
        this.#log(
          `Model does not support system prompts! Simulating system prompt via Human/AI message pairs.`
        );
        langchainChats.push(new HumanMessage({ content: chat.content }));
        langchainChats.push(new AIMessage({ content: "Okay." }));
        continue;
      }

      const MessageClass = roleToMessageMap[chat.role];
      langchainChats.push(new MessageClass({ content: chat.content }));
    }

    return langchainChats;
  }

  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
@@ -167,22 +126,21 @@ class AWSBedrockLLM {

  /**
   * Generates appropriate content array for a message + attachments.
   * TODO: Implement this - attachments are not supported yet for Bedrock.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
   * @returns {string|object[]}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) {
      return { content: userPrompt };
    }
    if (!attachments.length) return [{ text: userPrompt }];

    const content = [{ type: "text", text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        type: "image_url",
        image_url: attachment.contentString,
      });
    }
    return { content: content.flat() };
    // const content = [{ type: "text", text: userPrompt }];
    // for (let attachment of attachments) {
    //   content.push({
    //     type: "image_url",
    //     image_url: attachment.contentString,
    //   });
    // }
    // return { content: content.flat() };
  }

  /**
@@ -195,72 +153,125 @@ class AWSBedrockLLM {
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [],
    _attachments = [],
  }) {
    // AWS Mistral models do not support system prompts
    if (this.model.startsWith("mistral"))
      return [
        ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
    let prompt = [
      {
        role: "system",
        content: [
          { text: `${systemPrompt}${this.#appendContext(contextTexts)}` },
        ],
      },
    ];

    // If the model does not support system prompts, we need to add a user message and assistant message
    if (this.noSystemPromptModels.includes(this.model)) {
      prompt = [
        {
          role: "user",
          ...this.#generateContent({ userPrompt, attachments }),
          content: [
            { text: `${systemPrompt}${this.#appendContext(contextTexts)}` },
          ],
        },
        {
          role: "assistant",
          content: [{ text: "Okay." }],
        },
      ];
    }

    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
      ...prompt,
      ...chatHistory.map((msg) => ({
        role: msg.role,
        content: this.#generateContent({
          userPrompt: msg.content,
          attachments: msg.attachments,
        }),
      })),
      {
        role: "user",
        ...this.#generateContent({ userPrompt, attachments }),
        content: this.#generateContent({
          userPrompt: userPrompt,
          attachments: [],
        }),
      },
    ];
  }

  /**
   * Parses and prepends reasoning from the response and returns the full text response.
   * @param {Object} response
   * @returns {string}
   */
  #parseReasoningFromResponse({ content = [] }) {
    let textResponse = content[0]?.text;

    if (
      !!content?.[1]?.reasoningContent &&
      content?.[1]?.reasoningContent?.reasoningText?.text?.trim().length > 0
    )
      textResponse = `<think>${content?.[1]?.reasoningContent?.reasoningText?.text}</think>${textResponse}`;

    return textResponse;
  }

  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    const model = this.#bedrockClient({ temperature });
    const hasSystem = messages[0]?.role === "system";
    const [system, ...history] = hasSystem ? messages : [null, ...messages];

    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      model
        .pipe(new StringOutputParser())
        .invoke(this.#convertToLangchainPrototypes(messages))
      this.bedrockClient
        .send(
          new ConverseCommand({
            modelId: this.model,
            messages: history,
            inferenceConfig: {
              maxTokens: this.promptWindowLimit(),
              temperature,
            },
            system: !!system ? system.content : undefined,
          })
        )
        .catch((e) => {
          throw new Error(
            `AWSBedrock::getChatCompletion failed to communicate with Bedrock client. ${e.message}`
          );
        })
        }),
      messages,
      false
    );

    if (!result.output || result.output.length === 0) return null;

    // Langchain does not return the usage metrics in the response so we estimate them
    const promptTokens = LLMPerformanceMonitor.countTokens(messages);
    const completionTokens = LLMPerformanceMonitor.countTokens([
      { content: result.output },
    ]);

    const response = result.output;
    if (!response || !response?.output) return null;
    return {
      textResponse: result.output,
      textResponse: this.#parseReasoningFromResponse(response.output?.message),
      metrics: {
        prompt_tokens: promptTokens,
        completion_tokens: completionTokens,
        total_tokens: promptTokens + completionTokens,
        outputTps: completionTokens / result.duration,
        prompt_tokens: response?.usage?.inputTokens,
        completion_tokens: response?.usage?.outputTokens,
        total_tokens: response?.usage?.totalTokens,
        outputTps:
          response?.usage?.outputTokens / (response?.metrics?.latencyMs / 1000),
        duration: result.duration,
      },
    };
  }

  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    const model = this.#bedrockClient({ temperature });
    const hasSystem = messages[0]?.role === "system";
    const [system, ...history] = hasSystem ? messages : [null, ...messages];

    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      model
        .pipe(new StringOutputParser())
        .stream(this.#convertToLangchainPrototypes(messages)),
      messages
      this.bedrockClient.send(
        new ConverseStreamCommand({
          modelId: this.model,
          messages: history,
          inferenceConfig: { maxTokens: this.promptWindowLimit(), temperature },
          system: !!system ? system.content : undefined,
        })
      ),
      messages,
      false
    );
    return measuredStreamRequest;
  }
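
Note on the reworked prompt constructor above: messages are now built in the Converse content-block shape (content: [{ text }]), and for models listed in noSystemPromptModels (e.g. us.deepseek.r1-v1:0) the system prompt is simulated with a user/assistant pair. Roughly, the constructed array looks like the illustrative sketch below; the string values are placeholders, not code from this commit.

// Illustrative only: shape of the prompt array for a model without system prompt support.
const prompt = [
  { role: "user", content: [{ text: "<system prompt + appended context>" }] },
  { role: "assistant", content: [{ text: "Okay." }] },
  // ...chat history mapped to { role, content: [{ text }] } entries...
  { role: "user", content: [{ text: "<latest user prompt>" }] },
];
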
@@ -275,12 +286,15 @@ class AWSBedrockLLM {
   */
  handleStream(response, stream, responseProps) {
    const { uuid = uuidv4(), sources = [] } = responseProps;
    let hasUsageMetrics = false;
    let usage = {
      prompt_tokens: 0,
      completion_tokens: 0,
    };

    return new Promise(async (resolve) => {
      let fullText = "";
      let usage = {
        completion_tokens: 0,
      };
      let reasoningText = "";

      // Establish listener to early-abort a streaming response
      // in case things go sideways or the user does not like the response.
@@ -293,25 +307,82 @@ class AWSBedrockLLM {
      response.on("close", handleAbort);

      try {
        for await (const chunk of stream) {
        for await (const chunk of stream.stream) {
          if (chunk === undefined)
            throw new Error(
              "Stream returned undefined chunk. Aborting reply - check model provider logs."
            );

          const content = chunk.hasOwnProperty("content")
            ? chunk.content
            : chunk;
          fullText += content;
          if (!!content) usage.completion_tokens++; // Dont count empty chunks
          writeResponseChunk(response, {
            uuid,
            sources: [],
            type: "textResponseChunk",
            textResponse: content,
            close: false,
            error: false,
          });
          const action = Object.keys(chunk)[0];
          if (action === "metadata") {
            hasUsageMetrics = true;
            usage.prompt_tokens = chunk.metadata?.usage?.inputTokens ?? 0;
            usage.completion_tokens = chunk.metadata?.usage?.outputTokens ?? 0;
            usage.total_tokens = chunk.metadata?.usage?.totalTokens ?? 0;
          }

          if (action === "contentBlockDelta") {
            const token = chunk.contentBlockDelta?.delta?.text;
            const reasoningToken =
              chunk.contentBlockDelta?.delta?.reasoningContent?.text;

            // Reasoning models will always return the reasoning text before the token text.
            if (reasoningToken) {
              // If the reasoning text is empty (''), we need to initialize it
              // and send the first chunk of reasoning text.
              if (reasoningText.length === 0) {
                writeResponseChunk(response, {
                  uuid,
                  sources: [],
                  type: "textResponseChunk",
                  textResponse: `<think>${reasoningToken}`,
                  close: false,
                  error: false,
                });
                reasoningText += `<think>${reasoningToken}`;
                continue;
              } else {
                writeResponseChunk(response, {
                  uuid,
                  sources: [],
                  type: "textResponseChunk",
                  textResponse: reasoningToken,
                  close: false,
                  error: false,
                });
                reasoningText += reasoningToken;
              }
            }

            // If the reasoning text is not empty, but the reasoning token is empty
            // and the token text is not empty we need to close the reasoning text and begin sending the token text.
            if (!!reasoningText && !reasoningToken && token) {
              writeResponseChunk(response, {
                uuid,
                sources: [],
                type: "textResponseChunk",
                textResponse: `</think>`,
                close: false,
                error: false,
              });
              fullText += `${reasoningText}</think>`;
              reasoningText = "";
            }

            if (token) {
              fullText += token;
              // If we never saw a usage metric, we can estimate them by number of completion chunks
              if (!hasUsageMetrics) usage.completion_tokens++;
              writeResponseChunk(response, {
                uuid,
                sources: [],
                type: "textResponseChunk",
                textResponse: token,
                close: false,
                error: false,
              });
            }
          }
        }
      }

      writeResponseChunk(response, {
@@ -326,18 +397,18 @@ class AWSBedrockLLM {
        stream?.endMeasurement(usage);
        resolve(fullText);
      } catch (error) {
        console.log(`\x1b[43m\x1b[34m[STREAMING ERROR]\x1b[0m ${e.message}`);
        writeResponseChunk(response, {
          uuid,
          type: "abort",
          textResponse: null,
          sources: [],
          type: "textResponseChunk",
          textResponse: "",
          close: true,
          error: `AWSBedrock:streaming - could not stream chat. ${
            error?.cause ?? error.message
          }`,
          error: `AWSBedrock:streaming - could not stream chat. ${error?.cause ?? error.message}`,
        });
        response.removeListener("close", handleAbort);
        stream?.endMeasurement(usage);
        resolve(fullText); // Return what we currently have - if anything.
      }
    });
  }
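
For reference, the new streaming path consumes ConverseStreamCommand events directly: contentBlockDelta events carry text deltas and, for reasoning models, reasoningContent deltas (which the handler above wraps in <think> tags), while the metadata event carries token usage. A minimal standalone sketch of consuming that stream, assuming the event shape shown in this diff; the model id and region are placeholders.

// Sketch only: iterating a ConverseStream response from @aws-sdk/client-bedrock-runtime.
const {
  BedrockRuntimeClient,
  ConverseStreamCommand,
} = require("@aws-sdk/client-bedrock-runtime");

async function streamDemo() {
  const client = new BedrockRuntimeClient({ region: "us-east-1" }); // placeholder region
  const response = await client.send(
    new ConverseStreamCommand({
      modelId: "us.deepseek.r1-v1:0", // placeholder model id
      messages: [{ role: "user", content: [{ text: "Why is the sky blue?" }] }],
      inferenceConfig: { maxTokens: 512, temperature: 0.7 },
    })
  );

  for await (const chunk of response.stream) {
    // Each event object exposes a single key naming the event type.
    if (chunk.contentBlockDelta) {
      const delta = chunk.contentBlockDelta.delta;
      if (delta?.reasoningContent?.text) process.stdout.write(delta.reasoningContent.text);
      if (delta?.text) process.stdout.write(delta.text);
    }
    if (chunk.metadata) {
      console.log("\nusage:", chunk.metadata.usage, "latencyMs:", chunk.metadata.metrics?.latencyMs);
    }
  }
}

streamDemo().catch(console.error);
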
server/yarn.lock: 848 changes (file diff suppressed because it is too large)