diff --git a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx index 020c5016..c96531e2 100644 --- a/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx +++ b/frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection/index.jsx @@ -36,8 +36,8 @@ const ENABLED_PROVIDERS = [ "foundry", "zai", "giteeai", + "cohere", // TODO: More agent support. - // "cohere", // Has tool calling and will need to build explicit support // "huggingface" // Can be done but already has issues with no-chat templated. Needs to be tested. ]; const WARN_PERFORMANCE = [ diff --git a/server/endpoints/utils.js b/server/endpoints/utils.js index fc5b4133..327b58f8 100644 --- a/server/endpoints/utils.js +++ b/server/endpoints/utils.js @@ -151,6 +151,9 @@ function getModelTag() { case "giteeai": model = process.env.GITEE_AI_MODEL_PREF; break; + case "cohere": + model = process.env.COHERE_MODEL_PREF; + break; default: model = "--"; break; diff --git a/server/package.json b/server/package.json index fa5ac5d5..45636f24 100644 --- a/server/package.json +++ b/server/package.json @@ -26,6 +26,7 @@ "@lancedb/lancedb": "0.15.0", "@langchain/anthropic": "0.1.16", "@langchain/aws": "^0.0.5", + "@langchain/cohere": "0.0.11", "@langchain/community": "0.0.53", "@langchain/core": "0.1.61", "@langchain/openai": "0.0.28", diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index add1adb1..cc15c123 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -990,6 +990,8 @@ ${this.getHistory({ to: route.to }) return new Providers.FoundryProvider({ model: config.model }); case "giteeai": return new Providers.GiteeAIProvider({ model: config.model }); + case "cohere": + return new Providers.CohereProvider({ model: config.model }); default: throw new Error( `Unknown provider: ${config.provider}. Please use a valid provider.` diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js index c1a41909..3752b161 100644 --- a/server/utils/agents/aibitat/providers/ai-provider.js +++ b/server/utils/agents/aibitat/providers/ai-provider.js @@ -13,6 +13,7 @@ const { v4 } = require("uuid"); const { ChatOpenAI } = require("@langchain/openai"); const { ChatAnthropic } = require("@langchain/anthropic"); +const { ChatCohere } = require("@langchain/cohere"); const { ChatOllama } = require("@langchain/community/chat_models/ollama"); const { toValidNumber, safeJsonParse } = require("../../../http"); const { getLLMProviderClass } = require("../../../helpers"); @@ -239,6 +240,11 @@ class Provider { apiKey: process.env.GITEE_AI_API_KEY ?? null, ...config, }); + case "cohere": + return new ChatCohere({ + apiKey: process.env.COHERE_API_KEY ?? null, + ...config, + }); // OSS Model Runners // case "anythingllm_ollama": // return new ChatOllama({ @@ -307,7 +313,6 @@ class Provider { ...config, }); } - default: throw new Error(`Unsupported provider ${provider} for this task.`); } diff --git a/server/utils/agents/aibitat/providers/cohere.js b/server/utils/agents/aibitat/providers/cohere.js new file mode 100644 index 00000000..326a84bf --- /dev/null +++ b/server/utils/agents/aibitat/providers/cohere.js @@ -0,0 +1,251 @@ +const { CohereClient } = require("cohere-ai"); +const Provider = require("./ai-provider"); +const InheritMultiple = require("./helpers/classes"); +const UnTooled = require("./helpers/untooled"); +const { v4 } = require("uuid"); +const { safeJsonParse } = require("../../../http"); + +class CohereProvider extends InheritMultiple([Provider, UnTooled]) { + model; + + constructor(config = {}) { + const { model = process.env.COHERE_MODEL_PREF || "command-r-08-2024" } = + config; + super(); + const client = new CohereClient({ + token: process.env.COHERE_API_KEY, + }); + this._client = client; + this.model = model; + this.verbose = true; + } + + get client() { + return this._client; + } + + get supportsAgentStreaming() { + return true; + } + + #convertChatHistoryCohere(chatHistory = []) { + let cohereHistory = []; + chatHistory.forEach((message) => { + switch (message.role) { + case "SYSTEM": + case "system": + cohereHistory.push({ role: "SYSTEM", message: message.content }); + break; + case "USER": + case "user": + cohereHistory.push({ role: "USER", message: message.content }); + break; + case "CHATBOT": + case "assistant": + cohereHistory.push({ role: "CHATBOT", message: message.content }); + break; + } + }); + + return cohereHistory; + } + + async #handleFunctionCallStream({ messages = [] }) { + const userPrompt = messages[messages.length - 1]?.content || ""; + const history = messages.slice(0, -1); + return await this.client.chatStream({ + model: this.model, + chatHistory: this.#convertChatHistoryCohere(history), + message: userPrompt, + }); + } + + async stream(messages, functions = [], eventHandler = null) { + return await UnTooled.prototype.stream.call( + this, + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + } + + async streamingFunctionCall( + messages, + functions, + chatCb = null, + eventHandler = null + ) { + const history = [...messages].filter((msg) => + ["user", "assistant"].includes(msg.role) + ); + if (history[history.length - 1]?.role !== "user") return null; + + const msgUUID = v4(); + let textResponse = ""; + const historyMessages = this.buildToolCallMessages(history, functions); + const stream = await chatCb({ messages: historyMessages }); + + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: v4(), + content: "Agent is thinking...", + }); + + for await (const event of stream) { + if (event.eventType !== "text-generation") continue; + textResponse += event.text; + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: msgUUID, + content: event.text, + }); + } + + const call = safeJsonParse(textResponse, null); + if (call === null) + return { toolCall: null, text: textResponse, uuid: msgUUID }; + + const { valid, reason } = this.validFuncCall(call, functions); + if (!valid) { + this.providerLog(`Invalid function tool call: ${reason}.`); + eventHandler?.("reportStreamEvent", { + type: "removeStatusResponse", + uuid: msgUUID, + content: + "The model attempted to make an invalid function call - it was ignored.", + }); + return { toolCall: null, text: null, uuid: msgUUID }; + } + + const { isDuplicate, reason: duplicateReason } = + this.deduplicator.isDuplicate(call.name, call.arguments); + if (isDuplicate) { + this.providerLog( + `Cannot call ${call.name} again because ${duplicateReason}.` + ); + eventHandler?.("reportStreamEvent", { + type: "removeStatusResponse", + uuid: msgUUID, + content: + "The model tried to call a function with the same arguments as a previous call - it was ignored.", + }); + return { toolCall: null, text: null, uuid: msgUUID }; + } + + eventHandler?.("reportStreamEvent", { + uuid: `${msgUUID}:tool_call_invocation`, + type: "toolCallInvocation", + content: `Parsed Tool Call: ${call.name}(${JSON.stringify(call.arguments)})`, + }); + return { toolCall: call, text: null, uuid: msgUUID }; + } + + /** + * Stream a chat completion from the LLM with tool calling + * Override the inherited `stream` method since Cohere uses a different API format. + * + * @param {any[]} messages - The messages to send to the LLM. + * @param {any[]} functions - The functions to use in the LLM. + * @param {function} eventHandler - The event handler to use to report stream events. + * @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion. + */ + async stream(messages, functions = [], eventHandler = null) { + this.providerLog( + "CohereProvider.stream - will process this chat completion." + ); + try { + let completion = { content: "" }; + if (functions.length > 0) { + const { + toolCall, + text, + uuid: msgUUID, + } = await this.streamingFunctionCall( + messages, + functions, + this.#handleFunctionCallStream.bind(this), + eventHandler + ); + + if (toolCall !== null) { + this.providerLog(`Valid tool call found - running ${toolCall.name}.`); + this.deduplicator.trackRun(toolCall.name, toolCall.arguments, { + cooldown: this.isMCPTool(toolCall, functions), + }); + return { + result: null, + functionCall: { + name: toolCall.name, + arguments: toolCall.arguments, + }, + cost: 0, + }; + } + + if (text) { + this.providerLog( + `No tool call found in the response - will send as a full text response.` + ); + completion.content = text; + eventHandler?.("reportStreamEvent", { + type: "removeStatusResponse", + uuid: msgUUID, + content: "No tool call found in the response", + }); + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: v4(), + content: "Done thinking.", + }); + eventHandler?.("reportStreamEvent", { + type: "fullTextResponse", + uuid: v4(), + content: text, + }); + } + } + + if (!completion?.content) { + eventHandler?.("reportStreamEvent", { + type: "statusResponse", + uuid: v4(), + content: "Done thinking.", + }); + + this.providerLog( + "Will assume chat completion without tool call inputs." + ); + const msgUUID = v4(); + completion = { content: "" }; + const stream = await this.#handleFunctionCallStream({ + messages: this.cleanMsgs(messages), + }); + + for await (const chunk of stream) { + if (chunk.eventType !== "text-generation") continue; + completion.content += chunk.text; + eventHandler?.("reportStreamEvent", { + type: "textResponseChunk", + uuid: msgUUID, + content: chunk.text, + }); + } + } + + this.deduplicator.reset("runs"); + return { + textResponse: completion.content, + cost: 0, + }; + } catch (error) { + throw error; + } + } + + getCost(_usage) { + return 0; + } +} + +module.exports = CohereProvider; diff --git a/server/utils/agents/aibitat/providers/index.js b/server/utils/agents/aibitat/providers/index.js index 9ac8465f..e4a11995 100644 --- a/server/utils/agents/aibitat/providers/index.js +++ b/server/utils/agents/aibitat/providers/index.js @@ -28,6 +28,7 @@ const MoonshotAiProvider = require("./moonshotAi.js"); const CometApiProvider = require("./cometapi.js"); const FoundryProvider = require("./foundry.js"); const GiteeAIProvider = require("./giteeai.js"); +const CohereProvider = require("./cohere.js"); module.exports = { OpenAIProvider, @@ -60,4 +61,5 @@ module.exports = { MoonshotAiProvider, FoundryProvider, GiteeAIProvider, + CohereProvider, }; diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js index b2d95676..3ae8a734 100644 --- a/server/utils/agents/index.js +++ b/server/utils/agents/index.js @@ -219,6 +219,11 @@ class AgentHandler { case "giteeai": if (!process.env.GITEE_AI_API_KEY) throw new Error("GiteeAI API Key must be provided to use agents."); + break; + case "cohere": + if (!process.env.COHERE_API_KEY) + throw new Error("Cohere API key must be provided to use agents."); + break; default: throw new Error( "No workspace agent provider set. Please set your agent provider in the workspace's settings" @@ -297,6 +302,8 @@ class AgentHandler { return process.env.FOUNDRY_MODEL_PREF ?? null; case "giteeai": return process.env.GITEE_AI_MODEL_PREF ?? null; + case "cohere": + return process.env.COHERE_MODEL_PREF ?? "command-r-08-2024"; default: return null; } diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js index ddfa3154..e7e094a1 100644 --- a/server/utils/helpers/customModels.js +++ b/server/utils/helpers/customModels.js @@ -814,7 +814,7 @@ async function getCohereModels(_apiKey = null, type = "chat") { .then((results) => results.models) .then((models) => models.map((model) => ({ - id: model.id, + id: model.name, name: model.name, })) ) diff --git a/server/yarn.lock b/server/yarn.lock index 51b671f5..ccb1f779 100644 --- a/server/yarn.lock +++ b/server/yarn.lock @@ -1952,6 +1952,14 @@ "@langchain/core" ">=0.2.16 <0.3.0" zod-to-json-schema "^3.22.5" +"@langchain/cohere@0.0.11": + version "0.0.11" + resolved "https://registry.yarnpkg.com/@langchain/cohere/-/cohere-0.0.11.tgz#ba34117d589186c33fba4d677784132f6768958c" + integrity sha512-BFLXXyJzomWRIINja43IOckuaaz+FkQY1njaoBQRqUxy+yCIQv2SUTo20AWwUPczsxhspFbu7NEp5qxCjvWqEg== + dependencies: + "@langchain/core" ">0.1.58 <0.3.0" + cohere-ai "^7.9.3" + "@langchain/community@0.0.53", "@langchain/community@~0.0.47": version "0.0.53" resolved "https://registry.npmjs.org/@langchain/community/-/community-0.0.53.tgz" @@ -1984,6 +1992,23 @@ zod "^3.22.4" zod-to-json-schema "^3.22.3" +"@langchain/core@>0.1.58 <0.3.0": + version "0.2.36" + resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.36.tgz#75754c33aa5b9310dcf117047374a1ae011005a4" + integrity sha512-qHLvScqERDeH7y2cLuJaSAlMwg3f/3Oc9nayRSXRU2UuaK/SOhI42cxiPLj1FnuHJSmN0rBQFkrLx02gI4mcVg== + dependencies: + ansi-styles "^5.0.0" + camelcase "6" + decamelize "1.2.0" + js-tiktoken "^1.0.12" + langsmith "^0.1.56-rc.1" + mustache "^4.2.0" + p-queue "^6.6.2" + p-retry "4" + uuid "^10.0.0" + zod "^3.22.4" + zod-to-json-schema "^3.22.3" + "@langchain/core@>=0.2.16 <0.3.0": version "0.2.18" resolved "https://registry.npmjs.org/@langchain/core/-/core-0.2.18.tgz" @@ -3537,6 +3562,11 @@ resolved "https://registry.npmjs.org/@types/triple-beam/-/triple-beam-1.3.5.tgz" integrity sha512-6WaYesThRMCl19iryMYP7/x2OVgCtbIVflDGFpWnb9irXI3UjYE4AzmYuiUKY1AJstGijoY+MgUszMgRxIYTYw== +"@types/uuid@^10.0.0": + version "10.0.0" + resolved "https://registry.yarnpkg.com/@types/uuid/-/uuid-10.0.0.tgz#e9c07fe50da0f53dc24970cca94d619ff03f6f6d" + integrity sha512-7gqG38EyHgyP1S+7+xomFtL+ZNHcKv6DwNaCZmJmo1vgMugyF3TCnXVg4t1uk89mLNwnLtnY3TpOpCOyp1/xHQ== + "@types/uuid@^9.0.1": version "9.0.8" resolved "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz" @@ -4238,6 +4268,22 @@ cohere-ai@^7.19.0: readable-stream "^4.5.2" url-join "4.0.1" +cohere-ai@^7.9.3: + version "7.20.0" + resolved "https://registry.yarnpkg.com/cohere-ai/-/cohere-ai-7.20.0.tgz#5d350747b4a24e8855b66475f1f4908c4908240e" + integrity sha512-h/3h3pcLXRUmkzp/W+/FWViEMcAFtSZ8YayCTFQXpib112uNSj3feApOtJg7v9lreWR1t7gznhE6N9KNCX5FOA== + dependencies: + "@aws-crypto/sha256-js" "^5.2.0" + "@aws-sdk/client-sagemaker" "^3.583.0" + "@aws-sdk/credential-providers" "^3.583.0" + "@smithy/protocol-http" "^5.1.2" + "@smithy/signature-v4" "^5.1.2" + convict "^6.2.4" + form-data "^4.0.4" + form-data-encoder "^4.1.0" + formdata-node "^6.0.3" + readable-stream "^4.7.0" + color-convert@^1.9.3: version "1.9.3" resolved "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz" @@ -5488,7 +5534,7 @@ form-data-encoder@1.7.2: resolved "https://registry.npmjs.org/form-data-encoder/-/form-data-encoder-1.7.2.tgz" integrity sha512-qfqtYan3rxrnCk1VYaA4H+Ms9xdpPqvLZa6xmMgFvhO32x7/3J/ExcTd6qpxM0vH2GdMI+poehyBZvqfMTto8A== -form-data-encoder@^4.0.2: +form-data-encoder@^4.0.2, form-data-encoder@^4.1.0: version "4.1.0" resolved "https://registry.yarnpkg.com/form-data-encoder/-/form-data-encoder-4.1.0.tgz#497cedc94810bd5d53b99b5d4f6c152d5cbc9db2" integrity sha512-G6NsmEW15s0Uw9XnCg+33H3ViYRyiM0hMrMhhqQOR8NFc5GhYrI+6I3u7OTw7b91J2g8rtvMBZJDbcGb2YUniw== @@ -5504,7 +5550,7 @@ form-data@3.0.4, form-data@^3.0.0: hasown "^2.0.2" mime-types "^2.1.35" -form-data@4.0.0, form-data@4.0.4, form-data@^4.0.0: +form-data@4.0.0, form-data@4.0.4, form-data@^4.0.0, form-data@^4.0.4: version "4.0.4" resolved "https://registry.yarnpkg.com/form-data/-/form-data-4.0.4.tgz#784cdcce0669a9d68e94d11ac4eea98088edd2c4" integrity sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow== @@ -6473,6 +6519,18 @@ langchainhub@~0.0.8: resolved "https://registry.npmjs.org/langchainhub/-/langchainhub-0.0.8.tgz" integrity sha512-Woyb8YDHgqqTOZvWIbm2CaFDGfZ4NTSyXV687AG4vXEfoNo7cGQp7nhl7wL3ehenKWmNEmcxCLgOZzW8jE6lOQ== +langsmith@^0.1.56-rc.1: + version "0.1.68" + resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.68.tgz#848332e822fe5e6734a07f1c36b6530cc1798afb" + integrity sha512-otmiysWtVAqzMx3CJ4PrtUBhWRG5Co8Z4o7hSZENPjlit9/j3/vm3TSvbaxpDYakZxtMjhkcJTqrdYFipISEiQ== + dependencies: + "@types/uuid" "^10.0.0" + commander "^10.0.1" + p-queue "^6.6.2" + p-retry "4" + semver "^7.6.3" + uuid "^10.0.0" + langsmith@~0.1.1, langsmith@~0.1.7: version "0.1.21" resolved "https://registry.npmjs.org/langsmith/-/langsmith-0.1.21.tgz" @@ -7713,7 +7771,7 @@ readable-stream@^4.2.0: process "^0.11.10" string_decoder "^1.3.0" -readable-stream@^4.5.2: +readable-stream@^4.5.2, readable-stream@^4.7.0: version "4.7.0" resolved "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz" integrity sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg== @@ -7892,6 +7950,11 @@ semver@^7.3.5, semver@^7.5.4: dependencies: lru-cache "^6.0.0" +semver@^7.6.3: + version "7.7.3" + resolved "https://registry.yarnpkg.com/semver/-/semver-7.7.3.tgz#4b5f4143d007633a8dc671cd0a6ef9147b8bb946" + integrity sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q== + semver@~7.0.0: version "7.0.0" resolved "https://registry.npmjs.org/semver/-/semver-7.0.0.tgz"