support AWS bedrock agents with streaming (#4850)
* Support AWS Bedrock agents with streaming
* Add back error handlers from previous fix
This commit is contained in:
parent
133b62f9f6
commit
7c3b7906e7
@ -146,10 +146,13 @@ class AWSBedrockLLM {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Indicates if the provider supports streaming responses.
|
* Some Bedrock models (Titan, Cohere) don't support streaming.
|
||||||
* @returns {boolean} True.
|
* Set AWS_BEDROCK_STREAMING_DISABLED to any non-empty value to disable streaming when using those models.
|
||||||
|
* Since this can be any model, even a custom model, we leave it to the user to disable streaming if needed.
|
||||||
|
* @returns {boolean} True if streaming is supported, false otherwise.
|
||||||
*/
|
*/
|
||||||
streamingEnabled() {
|
streamingEnabled() {
|
||||||
|
if (!!process.env.AWS_BEDROCK_STREAMING_DISABLED) return false;
|
||||||
return "streamGetChatCompletion" in this;
|
return "streamGetChatCompletion" in this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -34,8 +34,15 @@ class AWSBedrockProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
this.verbose = true;
|
this.verbose = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Some Bedrock models (Titan, Cohere) don't support streaming.
|
||||||
|
* Set AWS_BEDROCK_STREAMING_DISABLED to any non-empty value to disable streaming when using those models.
|
||||||
|
* Since this can be any model, even a custom model, we leave it to the user to disable streaming if needed.
|
||||||
|
* @returns {boolean} True if streaming is supported, false otherwise.
|
||||||
|
*/
|
||||||
get supportsAgentStreaming() {
|
get supportsAgentStreaming() {
|
||||||
return false;
|
if (!!process.env.AWS_BEDROCK_STREAMING_DISABLED) return false;
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -90,6 +97,73 @@ class AWSBedrockProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
return response?.content;
|
return response?.content;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a streaming response from the Langchain Bedrock client and convert
|
||||||
|
* it to OpenAI-compatible format expected by UnTooled.
|
||||||
|
* @param {Object} options - The options object containing messages.
|
||||||
|
* @param {Array} options.messages - The messages to send to the LLM.
|
||||||
|
* @returns {AsyncGenerator} An async iterable yielding OpenAI-compatible chunks.
|
||||||
|
*/
|
||||||
|
async #handleFunctionCallStream({ messages = [] }) {
|
||||||
|
const langchainMessages = this.#convertToLangchainPrototypes(messages);
|
||||||
|
const stream = await this.client.stream(langchainMessages);
|
||||||
|
|
||||||
|
// Wrap Langchain stream to OpenAI format expected by UnTooled
|
||||||
|
const self = this;
|
||||||
|
return {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
try {
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
// Langchain chunks have .content property directly
|
||||||
|
const content =
|
||||||
|
typeof chunk.content === "string" ? chunk.content : "";
|
||||||
|
if (content) {
|
||||||
|
yield {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
content: content,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
AWSBedrockLLM.errorToHumanReadable(e, {
|
||||||
|
method: "stream",
|
||||||
|
model: self.model,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stream a chat completion from the Bedrock LLM with tool calling.
|
||||||
|
*
|
||||||
|
* @param {any[]} messages - The messages to send to the LLM.
|
||||||
|
* @param {any[]} functions - The functions to use in the LLM.
|
||||||
|
* @param {function} eventHandler - The event handler to use to report stream events.
|
||||||
|
* @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion.
|
||||||
|
*/
|
||||||
|
async stream(messages, functions = [], eventHandler = null) {
|
||||||
|
return await UnTooled.prototype.stream
|
||||||
|
.call(
|
||||||
|
this,
|
||||||
|
messages,
|
||||||
|
functions,
|
||||||
|
this.#handleFunctionCallStream.bind(this),
|
||||||
|
eventHandler
|
||||||
|
)
|
||||||
|
.catch((e) => {
|
||||||
|
AWSBedrockLLM.errorToHumanReadable(e, {
|
||||||
|
method: "stream",
|
||||||
|
model: this.model,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a completion based on the received messages.
|
* Create a completion based on the received messages.
|
||||||
*
|
*
|
||||||
@ -98,54 +172,14 @@ class AWSBedrockProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
* @returns The completion.
|
* @returns The completion.
|
||||||
*/
|
*/
|
||||||
async complete(messages, functions = []) {
|
async complete(messages, functions = []) {
|
||||||
try {
|
return await UnTooled.prototype.complete
|
||||||
let completion;
|
.call(this, messages, functions, this.#handleFunctionCallChat.bind(this))
|
||||||
if (functions.length > 0) {
|
.catch((e) => {
|
||||||
const { toolCall, text } = await this.functionCall(
|
AWSBedrockLLM.errorToHumanReadable(e, {
|
||||||
messages,
|
method: "complete",
|
||||||
functions,
|
model: this.model,
|
||||||
this.#handleFunctionCallChat.bind(this)
|
});
|
||||||
);
|
|
||||||
|
|
||||||
if (toolCall !== null) {
|
|
||||||
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
|
|
||||||
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
|
|
||||||
return {
|
|
||||||
result: null,
|
|
||||||
functionCall: {
|
|
||||||
name: toolCall.name,
|
|
||||||
arguments: toolCall.arguments,
|
|
||||||
},
|
|
||||||
cost: 0,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
completion = { content: text };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!completion?.content) {
|
|
||||||
this.providerLog(
|
|
||||||
"Will assume chat completion without tool call inputs."
|
|
||||||
);
|
|
||||||
const response = await this.client.invoke(
|
|
||||||
this.#convertToLangchainPrototypes(this.cleanMsgs(messages))
|
|
||||||
);
|
|
||||||
completion = response;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
|
|
||||||
// from calling the exact same function over and over in a loop within a single chat exchange
|
|
||||||
// _but_ we should enable it to call previously used tools in a new chat interaction.
|
|
||||||
this.deduplicator.reset("runs");
|
|
||||||
return {
|
|
||||||
result: completion.content,
|
|
||||||
cost: 0,
|
|
||||||
};
|
|
||||||
} catch (error) {
|
|
||||||
AWSBedrockLLM.errorToHumanReadable(error, {
|
|
||||||
method: "complete",
|
|
||||||
model: this.model,
|
|
||||||
});
|
});
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1251,6 +1251,9 @@ function dumpENV() {
|
|||||||
|
|
||||||
// Allow disabling of MCP tool cooldown
|
// Allow disabling of MCP tool cooldown
|
||||||
"MCP_NO_COOLDOWN",
|
"MCP_NO_COOLDOWN",
|
||||||
|
|
||||||
|
// Allow disabling of streaming for AWS Bedrock
|
||||||
|
"AWS_BEDROCK_STREAMING_DISABLED",
|
||||||
];
|
];
|
||||||
|
|
||||||
// Simple sanitization of each value to prevent ENV injection via newline or quote escaping.
|
// Simple sanitization of each value to prevent ENV injection via newline or quote escaping.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user