Implement full chat and @agent chat user indentificiation for OpenRouter (#4668)
Implmenet chat and agentic chat user-id for OpenRouter resolves #4553 closes #4482
This commit is contained in:
parent
05df4ac72b
commit
cf76bad452
@ -238,7 +238,7 @@ class OpenRouterLLM {
|
|||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
async getChatCompletion(messages = null, { temperature = 0.7 }) {
|
async getChatCompletion(messages = null, { temperature = 0.7, user = null }) {
|
||||||
if (!(await this.isValidChatCompletionModel(this.model)))
|
if (!(await this.isValidChatCompletionModel(this.model)))
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`OpenRouter chat: ${this.model} is not valid for chat completion!`
|
`OpenRouter chat: ${this.model} is not valid for chat completion!`
|
||||||
@ -253,6 +253,7 @@ class OpenRouterLLM {
|
|||||||
// This is an OpenRouter specific option that allows us to get the reasoning text
|
// This is an OpenRouter specific option that allows us to get the reasoning text
|
||||||
// before the token text.
|
// before the token text.
|
||||||
include_reasoning: true,
|
include_reasoning: true,
|
||||||
|
user: user?.id ? `user_${user.id}` : "",
|
||||||
})
|
})
|
||||||
.catch((e) => {
|
.catch((e) => {
|
||||||
throw new Error(e.message);
|
throw new Error(e.message);
|
||||||
@ -279,7 +280,10 @@ class OpenRouterLLM {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
|
async streamGetChatCompletion(
|
||||||
|
messages = null,
|
||||||
|
{ temperature = 0.7, user = null }
|
||||||
|
) {
|
||||||
if (!(await this.isValidChatCompletionModel(this.model)))
|
if (!(await this.isValidChatCompletionModel(this.model)))
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`OpenRouter chat: ${this.model} is not valid for chat completion!`
|
`OpenRouter chat: ${this.model} is not valid for chat completion!`
|
||||||
@ -294,6 +298,7 @@ class OpenRouterLLM {
|
|||||||
// This is an OpenRouter specific option that allows us to get the reasoning text
|
// This is an OpenRouter specific option that allows us to get the reasoning text
|
||||||
// before the token text.
|
// before the token text.
|
||||||
include_reasoning: true,
|
include_reasoning: true,
|
||||||
|
user: user?.id ? `user_${user.id}` : "",
|
||||||
}),
|
}),
|
||||||
messages
|
messages
|
||||||
// We have to manually count the tokens
|
// We have to manually count the tokens
|
||||||
|
|||||||
@ -474,6 +474,8 @@ class AIbitat {
|
|||||||
...this.defaultProvider,
|
...this.defaultProvider,
|
||||||
...channelConfig,
|
...channelConfig,
|
||||||
});
|
});
|
||||||
|
provider.attachHandlerProps(this.handlerProps);
|
||||||
|
|
||||||
const history = this.getHistory({ to: channel });
|
const history = this.getHistory({ to: channel });
|
||||||
|
|
||||||
// build the messages to send to the provider
|
// build the messages to send to the provider
|
||||||
@ -594,6 +596,7 @@ ${this.getHistory({ to: route.to })
|
|||||||
...this.defaultProvider,
|
...this.defaultProvider,
|
||||||
...fromConfig,
|
...fromConfig,
|
||||||
});
|
});
|
||||||
|
provider.attachHandlerProps(this.handlerProps);
|
||||||
|
|
||||||
let content;
|
let content;
|
||||||
if (provider.supportsAgentStreaming) {
|
if (provider.supportsAgentStreaming) {
|
||||||
@ -911,11 +914,10 @@ ${this.getHistory({ to: route.to })
|
|||||||
* If the provider is a string, it will return the default provider for that string.
|
* If the provider is a string, it will return the default provider for that string.
|
||||||
*
|
*
|
||||||
* @param config The provider configuration.
|
* @param config The provider configuration.
|
||||||
|
* @returns {Providers.OpenAIProvider} The provider instance.
|
||||||
*/
|
*/
|
||||||
getProviderForConfig(config) {
|
getProviderForConfig(config) {
|
||||||
if (typeof config.provider === "object") {
|
if (typeof config.provider === "object") return config.provider;
|
||||||
return config.provider;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (config.provider) {
|
switch (config.provider) {
|
||||||
case "openai":
|
case "openai":
|
||||||
|
|||||||
@ -28,6 +28,22 @@ const DEFAULT_WORKSPACE_PROMPT =
|
|||||||
|
|
||||||
class Provider {
|
class Provider {
|
||||||
_client;
|
_client;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The invocation object containing the user ID and other invocation details.
|
||||||
|
* @type {import("@prisma/client").workspace_agent_invocations}
|
||||||
|
*/
|
||||||
|
invocation = {};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The user ID for the chat completion to send to the LLM provider for user tracking.
|
||||||
|
* In order for this to be set, the handler props must be attached to the provider after instantiation.
|
||||||
|
* ex: this.attachHandlerProps({ ..., invocation: { ..., user_id: 123 } });
|
||||||
|
* eg: `user_123`
|
||||||
|
* @type {string}
|
||||||
|
*/
|
||||||
|
executingUserId = "";
|
||||||
|
|
||||||
constructor(client) {
|
constructor(client) {
|
||||||
if (this.constructor == Provider) {
|
if (this.constructor == Provider) {
|
||||||
return;
|
return;
|
||||||
@ -42,6 +58,19 @@ class Provider {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attaches handler props to the provider for reuse in the provider.
|
||||||
|
* - Explicitly sets the invocation object.
|
||||||
|
* - Explicitly sets the executing user ID from the invocation object.
|
||||||
|
* @param {Object} handlerProps - The handler props to attach to the provider.
|
||||||
|
*/
|
||||||
|
attachHandlerProps(handlerProps = {}) {
|
||||||
|
this.invocation = handlerProps?.invocation || {};
|
||||||
|
this.executingUserId = this.invocation?.user_id
|
||||||
|
? `user_${this.invocation.user_id}`
|
||||||
|
: "";
|
||||||
|
}
|
||||||
|
|
||||||
get client() {
|
get client() {
|
||||||
return this._client;
|
return this._client;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,6 +5,8 @@ const UnTooled = require("./helpers/untooled.js");
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* The agent provider for the OpenRouter provider.
|
* The agent provider for the OpenRouter provider.
|
||||||
|
* @extends {Provider}
|
||||||
|
* @extends {UnTooled}
|
||||||
*/
|
*/
|
||||||
class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
|
class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
|
||||||
model;
|
model;
|
||||||
@ -40,6 +42,7 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
.create({
|
.create({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
messages,
|
messages,
|
||||||
|
user: this.executingUserId,
|
||||||
})
|
})
|
||||||
.then((result) => {
|
.then((result) => {
|
||||||
if (!result.hasOwnProperty("choices"))
|
if (!result.hasOwnProperty("choices"))
|
||||||
@ -58,6 +61,7 @@ class OpenRouterProvider extends InheritMultiple([Provider, UnTooled]) {
|
|||||||
model: this.model,
|
model: this.model,
|
||||||
stream: true,
|
stream: true,
|
||||||
messages,
|
messages,
|
||||||
|
user: this.executingUserId,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -307,6 +307,7 @@ async function chatSync({
|
|||||||
const { textResponse, metrics: performanceMetrics } =
|
const { textResponse, metrics: performanceMetrics } =
|
||||||
await LLMConnector.getChatCompletion(messages, {
|
await LLMConnector.getChatCompletion(messages, {
|
||||||
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
||||||
|
user: user,
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!textResponse) {
|
if (!textResponse) {
|
||||||
@ -649,6 +650,7 @@ async function streamChat({
|
|||||||
const { textResponse, metrics: performanceMetrics } =
|
const { textResponse, metrics: performanceMetrics } =
|
||||||
await LLMConnector.getChatCompletion(messages, {
|
await LLMConnector.getChatCompletion(messages, {
|
||||||
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
||||||
|
user: user,
|
||||||
});
|
});
|
||||||
completeText = textResponse;
|
completeText = textResponse;
|
||||||
metrics = performanceMetrics;
|
metrics = performanceMetrics;
|
||||||
@ -664,6 +666,7 @@ async function streamChat({
|
|||||||
} else {
|
} else {
|
||||||
const stream = await LLMConnector.streamGetChatCompletion(messages, {
|
const stream = await LLMConnector.streamGetChatCompletion(messages, {
|
||||||
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
||||||
|
user: user,
|
||||||
});
|
});
|
||||||
completeText = await LLMConnector.handleStream(response, stream, { uuid });
|
completeText = await LLMConnector.handleStream(response, stream, { uuid });
|
||||||
metrics = stream.metrics;
|
metrics = stream.metrics;
|
||||||
|
|||||||
@ -248,6 +248,7 @@ async function streamChatWithWorkspace(
|
|||||||
const { textResponse, metrics: performanceMetrics } =
|
const { textResponse, metrics: performanceMetrics } =
|
||||||
await LLMConnector.getChatCompletion(messages, {
|
await LLMConnector.getChatCompletion(messages, {
|
||||||
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
||||||
|
user: user,
|
||||||
});
|
});
|
||||||
|
|
||||||
completeText = textResponse;
|
completeText = textResponse;
|
||||||
@ -264,6 +265,7 @@ async function streamChatWithWorkspace(
|
|||||||
} else {
|
} else {
|
||||||
const stream = await LLMConnector.streamGetChatCompletion(messages, {
|
const stream = await LLMConnector.streamGetChatCompletion(messages, {
|
||||||
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
temperature: workspace?.openAiTemp ?? LLMConnector.defaultTemp,
|
||||||
|
user: user,
|
||||||
});
|
});
|
||||||
completeText = await LLMConnector.handleStream(response, stream, {
|
completeText = await LLMConnector.handleStream(response, stream, {
|
||||||
uuid,
|
uuid,
|
||||||
|
|||||||
@ -24,6 +24,7 @@
|
|||||||
*
|
*
|
||||||
* @typedef {Object} ChatCompletionOptions
|
* @typedef {Object} ChatCompletionOptions
|
||||||
* @property {number} temperature - The sampling temperature for the LLM response
|
* @property {number} temperature - The sampling temperature for the LLM response
|
||||||
|
* @property {import("@prisma/client").users} user - The user object for the chat completion to send to the LLM provider for user tracking (optional)
|
||||||
*
|
*
|
||||||
* @typedef {function(Array<ChatMessage>, ChatCompletionOptions): Promise<ChatCompletionResponse>} getChatCompletionFunction
|
* @typedef {function(Array<ChatMessage>, ChatCompletionOptions): Promise<ChatCompletionResponse>} getChatCompletionFunction
|
||||||
*
|
*
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user