Overhaul AWS Bedrock provider (#3537)
* Patch AWS Bedrock provider for newer models and performance * patch prompt constructor
This commit is contained in:
parent
c29f7669c4
commit
78c83383d8
@ -20,6 +20,7 @@
|
|||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@anthropic-ai/sdk": "^0.39.0",
|
"@anthropic-ai/sdk": "^0.39.0",
|
||||||
|
"@aws-sdk/client-bedrock-runtime": "^3.775.0",
|
||||||
"@azure/openai": "1.0.0-beta.10",
|
"@azure/openai": "1.0.0-beta.10",
|
||||||
"@datastax/astra-db-ts": "^0.1.3",
|
"@datastax/astra-db-ts": "^0.1.3",
|
||||||
"@google/generative-ai": "^0.7.1",
|
"@google/generative-ai": "^0.7.1",
|
||||||
|
|||||||
@ -1,15 +1,18 @@
|
|||||||
const { StringOutputParser } = require("@langchain/core/output_parsers");
|
const {
|
||||||
|
BedrockRuntimeClient,
|
||||||
|
ConverseCommand,
|
||||||
|
ConverseStreamCommand,
|
||||||
|
} = require("@aws-sdk/client-bedrock-runtime");
|
||||||
const {
|
const {
|
||||||
writeResponseChunk,
|
writeResponseChunk,
|
||||||
clientAbortedHandler,
|
clientAbortedHandler,
|
||||||
formatChatHistory,
|
|
||||||
} = require("../../helpers/chat/responses");
|
} = require("../../helpers/chat/responses");
|
||||||
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
|
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
|
||||||
const {
|
const {
|
||||||
LLMPerformanceMonitor,
|
LLMPerformanceMonitor,
|
||||||
} = require("../../helpers/chat/LLMPerformanceMonitor");
|
} = require("../../helpers/chat/LLMPerformanceMonitor");
|
||||||
|
const { v4: uuidv4 } = require("uuid");
|
||||||
|
|
||||||
// Docs: https://js.langchain.com/v0.2/docs/integrations/chat/bedrock_converse
|
|
||||||
class AWSBedrockLLM {
|
class AWSBedrockLLM {
|
||||||
/**
|
/**
|
||||||
* These models do not support system prompts
|
* These models do not support system prompts
|
||||||
@ -23,6 +26,7 @@ class AWSBedrockLLM {
|
|||||||
"amazon.titan-text-lite-v1",
|
"amazon.titan-text-lite-v1",
|
||||||
"cohere.command-text-v14",
|
"cohere.command-text-v14",
|
||||||
"cohere.command-light-text-v14",
|
"cohere.command-light-text-v14",
|
||||||
|
"us.deepseek.r1-v1:0",
|
||||||
];
|
];
|
||||||
|
|
||||||
constructor(embedder = null, modelPreference = null) {
|
constructor(embedder = null, modelPreference = null) {
|
||||||
@ -51,6 +55,17 @@ class AWSBedrockLLM {
|
|||||||
user: this.promptWindowLimit() * 0.7,
|
user: this.promptWindowLimit() * 0.7,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
this.bedrockClient = new BedrockRuntimeClient({
|
||||||
|
region: process.env.AWS_BEDROCK_LLM_REGION,
|
||||||
|
credentials: {
|
||||||
|
accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
|
||||||
|
secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
|
||||||
|
...(this.authMethod === "sessionToken"
|
||||||
|
? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
|
||||||
|
: {}),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
this.embedder = embedder ?? new NativeEmbedder();
|
this.embedder = embedder ?? new NativeEmbedder();
|
||||||
this.defaultTemp = 0.7;
|
this.defaultTemp = 0.7;
|
||||||
this.#log(
|
this.#log(
|
||||||
@ -69,62 +84,6 @@ class AWSBedrockLLM {
|
|||||||
return method;
|
return method;
|
||||||
}
|
}
|
||||||
|
|
||||||
#bedrockClient({ temperature = 0.7 }) {
|
|
||||||
const { ChatBedrockConverse } = require("@langchain/aws");
|
|
||||||
return new ChatBedrockConverse({
|
|
||||||
model: this.model,
|
|
||||||
region: process.env.AWS_BEDROCK_LLM_REGION,
|
|
||||||
credentials: {
|
|
||||||
accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
|
|
||||||
secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
|
|
||||||
...(this.authMethod === "sessionToken"
|
|
||||||
? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
|
|
||||||
: {}),
|
|
||||||
},
|
|
||||||
temperature,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// For streaming we use Langchain's wrapper to handle weird chunks
|
|
||||||
// or otherwise absorb headaches that can arise from Bedrock models
|
|
||||||
#convertToLangchainPrototypes(chats = []) {
|
|
||||||
const {
|
|
||||||
HumanMessage,
|
|
||||||
SystemMessage,
|
|
||||||
AIMessage,
|
|
||||||
} = require("@langchain/core/messages");
|
|
||||||
const langchainChats = [];
|
|
||||||
const roleToMessageMap = {
|
|
||||||
system: SystemMessage,
|
|
||||||
user: HumanMessage,
|
|
||||||
assistant: AIMessage,
|
|
||||||
};
|
|
||||||
|
|
||||||
for (const chat of chats) {
|
|
||||||
if (!roleToMessageMap.hasOwnProperty(chat.role)) continue;
|
|
||||||
|
|
||||||
// When a model does not support system prompts, we need to handle it.
|
|
||||||
// We will add a new message that simulates the system prompt via a user message and AI response.
|
|
||||||
// This will allow the model to respond without crashing but we can still inject context.
|
|
||||||
if (
|
|
||||||
this.noSystemPromptModels.includes(this.model) &&
|
|
||||||
chat.role === "system"
|
|
||||||
) {
|
|
||||||
this.#log(
|
|
||||||
`Model does not support system prompts! Simulating system prompt via Human/AI message pairs.`
|
|
||||||
);
|
|
||||||
langchainChats.push(new HumanMessage({ content: chat.content }));
|
|
||||||
langchainChats.push(new AIMessage({ content: "Okay." }));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const MessageClass = roleToMessageMap[chat.role];
|
|
||||||
langchainChats.push(new MessageClass({ content: chat.content }));
|
|
||||||
}
|
|
||||||
|
|
||||||
return langchainChats;
|
|
||||||
}
|
|
||||||
|
|
||||||
#appendContext(contextTexts = []) {
|
#appendContext(contextTexts = []) {
|
||||||
if (!contextTexts || !contextTexts.length) return "";
|
if (!contextTexts || !contextTexts.length) return "";
|
||||||
return (
|
return (
|
||||||
@ -167,22 +126,21 @@ class AWSBedrockLLM {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Generates appropriate content array for a message + attachments.
|
* Generates appropriate content array for a message + attachments.
|
||||||
|
* TODO: Implement this - attachments are not supported yet for Bedrock.
|
||||||
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
|
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
|
||||||
* @returns {string|object[]}
|
* @returns {string|object[]}
|
||||||
*/
|
*/
|
||||||
#generateContent({ userPrompt, attachments = [] }) {
|
#generateContent({ userPrompt, attachments = [] }) {
|
||||||
if (!attachments.length) {
|
if (!attachments.length) return [{ text: userPrompt }];
|
||||||
return { content: userPrompt };
|
|
||||||
}
|
|
||||||
|
|
||||||
const content = [{ type: "text", text: userPrompt }];
|
// const content = [{ type: "text", text: userPrompt }];
|
||||||
for (let attachment of attachments) {
|
// for (let attachment of attachments) {
|
||||||
content.push({
|
// content.push({
|
||||||
type: "image_url",
|
// type: "image_url",
|
||||||
image_url: attachment.contentString,
|
// image_url: attachment.contentString,
|
||||||
});
|
// });
|
||||||
}
|
// }
|
||||||
return { content: content.flat() };
|
// return { content: content.flat() };
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -195,72 +153,125 @@ class AWSBedrockLLM {
|
|||||||
contextTexts = [],
|
contextTexts = [],
|
||||||
chatHistory = [],
|
chatHistory = [],
|
||||||
userPrompt = "",
|
userPrompt = "",
|
||||||
attachments = [],
|
_attachments = [],
|
||||||
}) {
|
}) {
|
||||||
// AWS Mistral models do not support system prompts
|
let prompt = [
|
||||||
if (this.model.startsWith("mistral"))
|
{
|
||||||
return [
|
role: "system",
|
||||||
...formatChatHistory(chatHistory, this.#generateContent, "spread"),
|
content: [
|
||||||
|
{ text: `${systemPrompt}${this.#appendContext(contextTexts)}` },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
// If the model does not support system prompts, we need to add a user message and assistant message
|
||||||
|
if (this.noSystemPromptModels.includes(this.model)) {
|
||||||
|
prompt = [
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
...this.#generateContent({ userPrompt, attachments }),
|
content: [
|
||||||
|
{ text: `${systemPrompt}${this.#appendContext(contextTexts)}` },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: "assistant",
|
||||||
|
content: [{ text: "Okay." }],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
}
|
||||||
|
|
||||||
const prompt = {
|
|
||||||
role: "system",
|
|
||||||
content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
|
|
||||||
};
|
|
||||||
return [
|
return [
|
||||||
prompt,
|
...prompt,
|
||||||
...formatChatHistory(chatHistory, this.#generateContent, "spread"),
|
...chatHistory.map((msg) => ({
|
||||||
|
role: msg.role,
|
||||||
|
content: this.#generateContent({
|
||||||
|
userPrompt: msg.content,
|
||||||
|
attachments: msg.attachments,
|
||||||
|
}),
|
||||||
|
})),
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
...this.#generateContent({ userPrompt, attachments }),
|
content: this.#generateContent({
|
||||||
|
userPrompt: userPrompt,
|
||||||
|
attachments: [],
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses and prepends reasoning from the response and returns the full text response.
|
||||||
|
* @param {Object} response
|
||||||
|
* @returns {string}
|
||||||
|
*/
|
||||||
|
#parseReasoningFromResponse({ content = [] }) {
|
||||||
|
let textResponse = content[0]?.text;
|
||||||
|
|
||||||
|
if (
|
||||||
|
!!content?.[1]?.reasoningContent &&
|
||||||
|
content?.[1]?.reasoningContent?.reasoningText?.text?.trim().length > 0
|
||||||
|
)
|
||||||
|
textResponse = `<think>${content?.[1]?.reasoningContent?.reasoningText?.text}</think>${textResponse}`;
|
||||||
|
|
||||||
|
return textResponse;
|
||||||
|
}
|
||||||
|
|
||||||
async getChatCompletion(messages = null, { temperature = 0.7 }) {
|
async getChatCompletion(messages = null, { temperature = 0.7 }) {
|
||||||
const model = this.#bedrockClient({ temperature });
|
const hasSystem = messages[0]?.role === "system";
|
||||||
|
const [system, ...history] = hasSystem ? messages : [null, ...messages];
|
||||||
|
|
||||||
const result = await LLMPerformanceMonitor.measureAsyncFunction(
|
const result = await LLMPerformanceMonitor.measureAsyncFunction(
|
||||||
model
|
this.bedrockClient
|
||||||
.pipe(new StringOutputParser())
|
.send(
|
||||||
.invoke(this.#convertToLangchainPrototypes(messages))
|
new ConverseCommand({
|
||||||
|
modelId: this.model,
|
||||||
|
messages: history,
|
||||||
|
inferenceConfig: {
|
||||||
|
maxTokens: this.promptWindowLimit(),
|
||||||
|
temperature,
|
||||||
|
},
|
||||||
|
system: !!system ? system.content : undefined,
|
||||||
|
})
|
||||||
|
)
|
||||||
.catch((e) => {
|
.catch((e) => {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`AWSBedrock::getChatCompletion failed to communicate with Bedrock client. ${e.message}`
|
`AWSBedrock::getChatCompletion failed to communicate with Bedrock client. ${e.message}`
|
||||||
);
|
);
|
||||||
})
|
}),
|
||||||
|
messages,
|
||||||
|
false
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!result.output || result.output.length === 0) return null;
|
const response = result.output;
|
||||||
|
if (!response || !response?.output) return null;
|
||||||
// Langchain does not return the usage metrics in the response so we estimate them
|
|
||||||
const promptTokens = LLMPerformanceMonitor.countTokens(messages);
|
|
||||||
const completionTokens = LLMPerformanceMonitor.countTokens([
|
|
||||||
{ content: result.output },
|
|
||||||
]);
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
textResponse: result.output,
|
textResponse: this.#parseReasoningFromResponse(response.output?.message),
|
||||||
metrics: {
|
metrics: {
|
||||||
prompt_tokens: promptTokens,
|
prompt_tokens: response?.usage?.inputTokens,
|
||||||
completion_tokens: completionTokens,
|
completion_tokens: response?.usage?.outputTokens,
|
||||||
total_tokens: promptTokens + completionTokens,
|
total_tokens: response?.usage?.totalTokens,
|
||||||
outputTps: completionTokens / result.duration,
|
outputTps:
|
||||||
|
response?.usage?.outputTokens / (response?.metrics?.latencyMs / 1000),
|
||||||
duration: result.duration,
|
duration: result.duration,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
|
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
|
||||||
const model = this.#bedrockClient({ temperature });
|
const hasSystem = messages[0]?.role === "system";
|
||||||
|
const [system, ...history] = hasSystem ? messages : [null, ...messages];
|
||||||
|
|
||||||
const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
|
const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
|
||||||
model
|
this.bedrockClient.send(
|
||||||
.pipe(new StringOutputParser())
|
new ConverseStreamCommand({
|
||||||
.stream(this.#convertToLangchainPrototypes(messages)),
|
modelId: this.model,
|
||||||
messages
|
messages: history,
|
||||||
|
inferenceConfig: { maxTokens: this.promptWindowLimit(), temperature },
|
||||||
|
system: !!system ? system.content : undefined,
|
||||||
|
})
|
||||||
|
),
|
||||||
|
messages,
|
||||||
|
false
|
||||||
);
|
);
|
||||||
return measuredStreamRequest;
|
return measuredStreamRequest;
|
||||||
}
|
}
|
||||||
@ -275,12 +286,15 @@ class AWSBedrockLLM {
|
|||||||
*/
|
*/
|
||||||
handleStream(response, stream, responseProps) {
|
handleStream(response, stream, responseProps) {
|
||||||
const { uuid = uuidv4(), sources = [] } = responseProps;
|
const { uuid = uuidv4(), sources = [] } = responseProps;
|
||||||
|
let hasUsageMetrics = false;
|
||||||
|
let usage = {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
};
|
||||||
|
|
||||||
return new Promise(async (resolve) => {
|
return new Promise(async (resolve) => {
|
||||||
let fullText = "";
|
let fullText = "";
|
||||||
let usage = {
|
let reasoningText = "";
|
||||||
completion_tokens: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Establish listener to early-abort a streaming response
|
// Establish listener to early-abort a streaming response
|
||||||
// in case things go sideways or the user does not like the response.
|
// in case things go sideways or the user does not like the response.
|
||||||
@ -293,25 +307,82 @@ class AWSBedrockLLM {
|
|||||||
response.on("close", handleAbort);
|
response.on("close", handleAbort);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream.stream) {
|
||||||
if (chunk === undefined)
|
if (chunk === undefined)
|
||||||
throw new Error(
|
throw new Error(
|
||||||
"Stream returned undefined chunk. Aborting reply - check model provider logs."
|
"Stream returned undefined chunk. Aborting reply - check model provider logs."
|
||||||
);
|
);
|
||||||
|
|
||||||
const content = chunk.hasOwnProperty("content")
|
const action = Object.keys(chunk)[0];
|
||||||
? chunk.content
|
if (action === "metadata") {
|
||||||
: chunk;
|
hasUsageMetrics = true;
|
||||||
fullText += content;
|
usage.prompt_tokens = chunk.metadata?.usage?.inputTokens ?? 0;
|
||||||
if (!!content) usage.completion_tokens++; // Dont count empty chunks
|
usage.completion_tokens = chunk.metadata?.usage?.outputTokens ?? 0;
|
||||||
writeResponseChunk(response, {
|
usage.total_tokens = chunk.metadata?.usage?.totalTokens ?? 0;
|
||||||
uuid,
|
}
|
||||||
sources: [],
|
|
||||||
type: "textResponseChunk",
|
if (action === "contentBlockDelta") {
|
||||||
textResponse: content,
|
const token = chunk.contentBlockDelta?.delta?.text;
|
||||||
close: false,
|
const reasoningToken =
|
||||||
error: false,
|
chunk.contentBlockDelta?.delta?.reasoningContent?.text;
|
||||||
});
|
|
||||||
|
// Reasoning models will always return the reasoning text before the token text.
|
||||||
|
if (reasoningToken) {
|
||||||
|
// If the reasoning text is empty (''), we need to initialize it
|
||||||
|
// and send the first chunk of reasoning text.
|
||||||
|
if (reasoningText.length === 0) {
|
||||||
|
writeResponseChunk(response, {
|
||||||
|
uuid,
|
||||||
|
sources: [],
|
||||||
|
type: "textResponseChunk",
|
||||||
|
textResponse: `<think>${reasoningToken}`,
|
||||||
|
close: false,
|
||||||
|
error: false,
|
||||||
|
});
|
||||||
|
reasoningText += `<think>${reasoningToken}`;
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
writeResponseChunk(response, {
|
||||||
|
uuid,
|
||||||
|
sources: [],
|
||||||
|
type: "textResponseChunk",
|
||||||
|
textResponse: reasoningToken,
|
||||||
|
close: false,
|
||||||
|
error: false,
|
||||||
|
});
|
||||||
|
reasoningText += reasoningToken;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the reasoning text is not empty, but the reasoning token is empty
|
||||||
|
// and the token text is not empty we need to close the reasoning text and begin sending the token text.
|
||||||
|
if (!!reasoningText && !reasoningToken && token) {
|
||||||
|
writeResponseChunk(response, {
|
||||||
|
uuid,
|
||||||
|
sources: [],
|
||||||
|
type: "textResponseChunk",
|
||||||
|
textResponse: `</think>`,
|
||||||
|
close: false,
|
||||||
|
error: false,
|
||||||
|
});
|
||||||
|
fullText += `${reasoningText}</think>`;
|
||||||
|
reasoningText = "";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (token) {
|
||||||
|
fullText += token;
|
||||||
|
// If we never saw a usage metric, we can estimate them by number of completion chunks
|
||||||
|
if (!hasUsageMetrics) usage.completion_tokens++;
|
||||||
|
writeResponseChunk(response, {
|
||||||
|
uuid,
|
||||||
|
sources: [],
|
||||||
|
type: "textResponseChunk",
|
||||||
|
textResponse: token,
|
||||||
|
close: false,
|
||||||
|
error: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
writeResponseChunk(response, {
|
writeResponseChunk(response, {
|
||||||
@ -326,18 +397,18 @@ class AWSBedrockLLM {
|
|||||||
stream?.endMeasurement(usage);
|
stream?.endMeasurement(usage);
|
||||||
resolve(fullText);
|
resolve(fullText);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
console.log(`\x1b[43m\x1b[34m[STREAMING ERROR]\x1b[0m ${e.message}`);
|
||||||
writeResponseChunk(response, {
|
writeResponseChunk(response, {
|
||||||
uuid,
|
uuid,
|
||||||
|
type: "abort",
|
||||||
|
textResponse: null,
|
||||||
sources: [],
|
sources: [],
|
||||||
type: "textResponseChunk",
|
|
||||||
textResponse: "",
|
|
||||||
close: true,
|
close: true,
|
||||||
error: `AWSBedrock:streaming - could not stream chat. ${
|
error: `AWSBedrock:streaming - could not stream chat. ${error?.cause ?? error.message}`,
|
||||||
error?.cause ?? error.message
|
|
||||||
}`,
|
|
||||||
});
|
});
|
||||||
response.removeListener("close", handleAbort);
|
response.removeListener("close", handleAbort);
|
||||||
stream?.endMeasurement(usage);
|
stream?.endMeasurement(usage);
|
||||||
|
resolve(fullText); // Return what we currently have - if anything.
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
848
server/yarn.lock
848
server/yarn.lock
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user