Feat/cohere agent implementation (#4703)

* implement cohere agent support

* run yarn lint

* moderize Cohere
add supported langchain method
redo streaming since it was not working
looping of agent calls was not functioning

* change default model to real model tag
add case statement for model tag

* remove debug

* update default

* only whitelist known labels

---------

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>
This commit is contained in:
Colin Perry 2025-12-12 16:25:58 -08:00 committed by GitHub
parent 62b45a76dc
commit a8bdc00aba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 340 additions and 6 deletions

View File

@ -36,8 +36,8 @@ const ENABLED_PROVIDERS = [
"foundry", "foundry",
"zai", "zai",
"giteeai", "giteeai",
"cohere",
// TODO: More agent support. // TODO: More agent support.
// "cohere", // Has tool calling and will need to build explicit support
// "huggingface" // Can be done but already has issues with no-chat templated. Needs to be tested. // "huggingface" // Can be done but already has issues with no-chat templated. Needs to be tested.
]; ];
const WARN_PERFORMANCE = [ const WARN_PERFORMANCE = [

View File

@ -151,6 +151,9 @@ function getModelTag() {
case "giteeai": case "giteeai":
model = process.env.GITEE_AI_MODEL_PREF; model = process.env.GITEE_AI_MODEL_PREF;
break; break;
case "cohere":
model = process.env.COHERE_MODEL_PREF;
break;
default: default:
model = "--"; model = "--";
break; break;

View File

@ -26,6 +26,7 @@
"@lancedb/lancedb": "0.15.0", "@lancedb/lancedb": "0.15.0",
"@langchain/anthropic": "0.1.16", "@langchain/anthropic": "0.1.16",
"@langchain/aws": "^0.0.5", "@langchain/aws": "^0.0.5",
"@langchain/cohere": "0.0.11",
"@langchain/community": "0.0.53", "@langchain/community": "0.0.53",
"@langchain/core": "0.1.61", "@langchain/core": "0.1.61",
"@langchain/openai": "0.0.28", "@langchain/openai": "0.0.28",

View File

@ -990,6 +990,8 @@ ${this.getHistory({ to: route.to })
return new Providers.FoundryProvider({ model: config.model }); return new Providers.FoundryProvider({ model: config.model });
case "giteeai": case "giteeai":
return new Providers.GiteeAIProvider({ model: config.model }); return new Providers.GiteeAIProvider({ model: config.model });
case "cohere":
return new Providers.CohereProvider({ model: config.model });
default: default:
throw new Error( throw new Error(
`Unknown provider: ${config.provider}. Please use a valid provider.` `Unknown provider: ${config.provider}. Please use a valid provider.`

View File

@ -13,6 +13,7 @@
const { v4 } = require("uuid"); const { v4 } = require("uuid");
const { ChatOpenAI } = require("@langchain/openai"); const { ChatOpenAI } = require("@langchain/openai");
const { ChatAnthropic } = require("@langchain/anthropic"); const { ChatAnthropic } = require("@langchain/anthropic");
const { ChatCohere } = require("@langchain/cohere");
const { ChatOllama } = require("@langchain/community/chat_models/ollama"); const { ChatOllama } = require("@langchain/community/chat_models/ollama");
const { toValidNumber, safeJsonParse } = require("../../../http"); const { toValidNumber, safeJsonParse } = require("../../../http");
const { getLLMProviderClass } = require("../../../helpers"); const { getLLMProviderClass } = require("../../../helpers");
@ -239,6 +240,11 @@ class Provider {
apiKey: process.env.GITEE_AI_API_KEY ?? null, apiKey: process.env.GITEE_AI_API_KEY ?? null,
...config, ...config,
}); });
case "cohere":
return new ChatCohere({
apiKey: process.env.COHERE_API_KEY ?? null,
...config,
});
// OSS Model Runners // OSS Model Runners
// case "anythingllm_ollama": // case "anythingllm_ollama":
// return new ChatOllama({ // return new ChatOllama({
@ -307,7 +313,6 @@ class Provider {
...config, ...config,
}); });
} }
default: default:
throw new Error(`Unsupported provider ${provider} for this task.`); throw new Error(`Unsupported provider ${provider} for this task.`);
} }

View File

@ -0,0 +1,251 @@
const { CohereClient } = require("cohere-ai");
const Provider = require("./ai-provider");
const InheritMultiple = require("./helpers/classes");
const UnTooled = require("./helpers/untooled");
const { v4 } = require("uuid");
const { safeJsonParse } = require("../../../http");
class CohereProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(config = {}) {
const { model = process.env.COHERE_MODEL_PREF || "command-r-08-2024" } =
config;
super();
const client = new CohereClient({
token: process.env.COHERE_API_KEY,
});
this._client = client;
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
get supportsAgentStreaming() {
return true;
}
#convertChatHistoryCohere(chatHistory = []) {
let cohereHistory = [];
chatHistory.forEach((message) => {
switch (message.role) {
case "SYSTEM":
case "system":
cohereHistory.push({ role: "SYSTEM", message: message.content });
break;
case "USER":
case "user":
cohereHistory.push({ role: "USER", message: message.content });
break;
case "CHATBOT":
case "assistant":
cohereHistory.push({ role: "CHATBOT", message: message.content });
break;
}
});
return cohereHistory;
}
async #handleFunctionCallStream({ messages = [] }) {
const userPrompt = messages[messages.length - 1]?.content || "";
const history = messages.slice(0, -1);
return await this.client.chatStream({
model: this.model,
chatHistory: this.#convertChatHistoryCohere(history),
message: userPrompt,
});
}
async stream(messages, functions = [], eventHandler = null) {
return await UnTooled.prototype.stream.call(
this,
messages,
functions,
this.#handleFunctionCallStream.bind(this),
eventHandler
);
}
async streamingFunctionCall(
messages,
functions,
chatCb = null,
eventHandler = null
) {
const history = [...messages].filter((msg) =>
["user", "assistant"].includes(msg.role)
);
if (history[history.length - 1]?.role !== "user") return null;
const msgUUID = v4();
let textResponse = "";
const historyMessages = this.buildToolCallMessages(history, functions);
const stream = await chatCb({ messages: historyMessages });
eventHandler?.("reportStreamEvent", {
type: "statusResponse",
uuid: v4(),
content: "Agent is thinking...",
});
for await (const event of stream) {
if (event.eventType !== "text-generation") continue;
textResponse += event.text;
eventHandler?.("reportStreamEvent", {
type: "statusResponse",
uuid: msgUUID,
content: event.text,
});
}
const call = safeJsonParse(textResponse, null);
if (call === null)
return { toolCall: null, text: textResponse, uuid: msgUUID };
const { valid, reason } = this.validFuncCall(call, functions);
if (!valid) {
this.providerLog(`Invalid function tool call: ${reason}.`);
eventHandler?.("reportStreamEvent", {
type: "removeStatusResponse",
uuid: msgUUID,
content:
"The model attempted to make an invalid function call - it was ignored.",
});
return { toolCall: null, text: null, uuid: msgUUID };
}
const { isDuplicate, reason: duplicateReason } =
this.deduplicator.isDuplicate(call.name, call.arguments);
if (isDuplicate) {
this.providerLog(
`Cannot call ${call.name} again because ${duplicateReason}.`
);
eventHandler?.("reportStreamEvent", {
type: "removeStatusResponse",
uuid: msgUUID,
content:
"The model tried to call a function with the same arguments as a previous call - it was ignored.",
});
return { toolCall: null, text: null, uuid: msgUUID };
}
eventHandler?.("reportStreamEvent", {
uuid: `${msgUUID}:tool_call_invocation`,
type: "toolCallInvocation",
content: `Parsed Tool Call: ${call.name}(${JSON.stringify(call.arguments)})`,
});
return { toolCall: call, text: null, uuid: msgUUID };
}
/**
* Stream a chat completion from the LLM with tool calling
* Override the inherited `stream` method since Cohere uses a different API format.
*
* @param {any[]} messages - The messages to send to the LLM.
* @param {any[]} functions - The functions to use in the LLM.
* @param {function} eventHandler - The event handler to use to report stream events.
* @returns {Promise<{ functionCall: any, textResponse: string }>} - The result of the chat completion.
*/
async stream(messages, functions = [], eventHandler = null) {
this.providerLog(
"CohereProvider.stream - will process this chat completion."
);
try {
let completion = { content: "" };
if (functions.length > 0) {
const {
toolCall,
text,
uuid: msgUUID,
} = await this.streamingFunctionCall(
messages,
functions,
this.#handleFunctionCallStream.bind(this),
eventHandler
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments, {
cooldown: this.isMCPTool(toolCall, functions),
});
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
if (text) {
this.providerLog(
`No tool call found in the response - will send as a full text response.`
);
completion.content = text;
eventHandler?.("reportStreamEvent", {
type: "removeStatusResponse",
uuid: msgUUID,
content: "No tool call found in the response",
});
eventHandler?.("reportStreamEvent", {
type: "statusResponse",
uuid: v4(),
content: "Done thinking.",
});
eventHandler?.("reportStreamEvent", {
type: "fullTextResponse",
uuid: v4(),
content: text,
});
}
}
if (!completion?.content) {
eventHandler?.("reportStreamEvent", {
type: "statusResponse",
uuid: v4(),
content: "Done thinking.",
});
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const msgUUID = v4();
completion = { content: "" };
const stream = await this.#handleFunctionCallStream({
messages: this.cleanMsgs(messages),
});
for await (const chunk of stream) {
if (chunk.eventType !== "text-generation") continue;
completion.content += chunk.text;
eventHandler?.("reportStreamEvent", {
type: "textResponseChunk",
uuid: msgUUID,
content: chunk.text,
});
}
}
this.deduplicator.reset("runs");
return {
textResponse: completion.content,
cost: 0,
};
} catch (error) {
throw error;
}
}
getCost(_usage) {
return 0;
}
}
module.exports = CohereProvider;

View File

@ -28,6 +28,7 @@ const MoonshotAiProvider = require("./moonshotAi.js");
const CometApiProvider = require("./cometapi.js"); const CometApiProvider = require("./cometapi.js");
const FoundryProvider = require("./foundry.js"); const FoundryProvider = require("./foundry.js");
const GiteeAIProvider = require("./giteeai.js"); const GiteeAIProvider = require("./giteeai.js");
const CohereProvider = require("./cohere.js");
module.exports = { module.exports = {
OpenAIProvider, OpenAIProvider,
@ -60,4 +61,5 @@ module.exports = {
MoonshotAiProvider, MoonshotAiProvider,
FoundryProvider, FoundryProvider,
GiteeAIProvider, GiteeAIProvider,
CohereProvider,
}; };

View File

@ -219,6 +219,11 @@ class AgentHandler {
case "giteeai": case "giteeai":
if (!process.env.GITEE_AI_API_KEY) if (!process.env.GITEE_AI_API_KEY)
throw new Error("GiteeAI API Key must be provided to use agents."); throw new Error("GiteeAI API Key must be provided to use agents.");
break;
case "cohere":
if (!process.env.COHERE_API_KEY)
throw new Error("Cohere API key must be provided to use agents.");
break;
default: default:
throw new Error( throw new Error(
"No workspace agent provider set. Please set your agent provider in the workspace's settings" "No workspace agent provider set. Please set your agent provider in the workspace's settings"
@ -297,6 +302,8 @@ class AgentHandler {
return process.env.FOUNDRY_MODEL_PREF ?? null; return process.env.FOUNDRY_MODEL_PREF ?? null;
case "giteeai": case "giteeai":
return process.env.GITEE_AI_MODEL_PREF ?? null; return process.env.GITEE_AI_MODEL_PREF ?? null;
case "cohere":
return process.env.COHERE_MODEL_PREF ?? "command-r-08-2024";
default: default:
return null; return null;
} }

View File

@ -814,7 +814,7 @@ async function getCohereModels(_apiKey = null, type = "chat") {
.then((results) => results.models) .then((results) => results.models)
.then((models) => .then((models) =>
models.map((model) => ({ models.map((model) => ({
id: model.id, id: model.name,
name: model.name, name: model.name,
})) }))
) )

View File

@ -1952,6 +1952,14 @@
"@langchain/core" ">=0.2.16 <0.3.0" "@langchain/core" ">=0.2.16 <0.3.0"
zod-to-json-schema "^3.22.5" zod-to-json-schema "^3.22.5"
"@langchain/cohere@0.0.11":
version "0.0.11"
resolved "https://registry.yarnpkg.com/@langchain/cohere/-/cohere-0.0.11.tgz#ba34117d589186c33fba4d677784132f6768958c"
integrity sha512-BFLXXyJzomWRIINja43IOckuaaz+FkQY1njaoBQRqUxy+yCIQv2SUTo20AWwUPczsxhspFbu7NEp5qxCjvWqEg==
dependencies:
"@langchain/core" ">0.1.58 <0.3.0"
cohere-ai "^7.9.3"
"@langchain/community@0.0.53", "@langchain/community@~0.0.47": "@langchain/community@0.0.53", "@langchain/community@~0.0.47":
version "0.0.53" version "0.0.53"
resolved "https://registry.npmjs.org/@langchain/community/-/community-0.0.53.tgz" resolved "https://registry.npmjs.org/@langchain/community/-/community-0.0.53.tgz"
@ -1984,6 +1992,23 @@
zod "^3.22.4" zod "^3.22.4"
zod-to-json-schema "^3.22.3" zod-to-json-schema "^3.22.3"
"@langchain/core@>0.1.58 <0.3.0":
version "0.2.36"
resolved "https://registry.yarnpkg.com/@langchain/core/-/core-0.2.36.tgz#75754c33aa5b9310dcf117047374a1ae011005a4"
integrity sha512-qHLvScqERDeH7y2cLuJaSAlMwg3f/3Oc9nayRSXRU2UuaK/SOhI42cxiPLj1FnuHJSmN0rBQFkrLx02gI4mcVg==
dependencies:
ansi-styles "^5.0.0"
camelcase "6"
decamelize "1.2.0"
js-tiktoken "^1.0.12"
langsmith "^0.1.56-rc.1"
mustache "^4.2.0"
p-queue "^6.6.2"
p-retry "4"
uuid "^10.0.0"
zod "^3.22.4"
zod-to-json-schema "^3.22.3"
"@langchain/core@>=0.2.16 <0.3.0": "@langchain/core@>=0.2.16 <0.3.0":
version "0.2.18" version "0.2.18"
resolved "https://registry.npmjs.org/@langchain/core/-/core-0.2.18.tgz" resolved "https://registry.npmjs.org/@langchain/core/-/core-0.2.18.tgz"
@ -3537,6 +3562,11 @@
resolved "https://registry.npmjs.org/@types/triple-beam/-/triple-beam-1.3.5.tgz" resolved "https://registry.npmjs.org/@types/triple-beam/-/triple-beam-1.3.5.tgz"
integrity sha512-6WaYesThRMCl19iryMYP7/x2OVgCtbIVflDGFpWnb9irXI3UjYE4AzmYuiUKY1AJstGijoY+MgUszMgRxIYTYw== integrity sha512-6WaYesThRMCl19iryMYP7/x2OVgCtbIVflDGFpWnb9irXI3UjYE4AzmYuiUKY1AJstGijoY+MgUszMgRxIYTYw==
"@types/uuid@^10.0.0":
version "10.0.0"
resolved "https://registry.yarnpkg.com/@types/uuid/-/uuid-10.0.0.tgz#e9c07fe50da0f53dc24970cca94d619ff03f6f6d"
integrity sha512-7gqG38EyHgyP1S+7+xomFtL+ZNHcKv6DwNaCZmJmo1vgMugyF3TCnXVg4t1uk89mLNwnLtnY3TpOpCOyp1/xHQ==
"@types/uuid@^9.0.1": "@types/uuid@^9.0.1":
version "9.0.8" version "9.0.8"
resolved "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz" resolved "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz"
@ -4238,6 +4268,22 @@ cohere-ai@^7.19.0:
readable-stream "^4.5.2" readable-stream "^4.5.2"
url-join "4.0.1" url-join "4.0.1"
cohere-ai@^7.9.3:
version "7.20.0"
resolved "https://registry.yarnpkg.com/cohere-ai/-/cohere-ai-7.20.0.tgz#5d350747b4a24e8855b66475f1f4908c4908240e"
integrity sha512-h/3h3pcLXRUmkzp/W+/FWViEMcAFtSZ8YayCTFQXpib112uNSj3feApOtJg7v9lreWR1t7gznhE6N9KNCX5FOA==
dependencies:
"@aws-crypto/sha256-js" "^5.2.0"
"@aws-sdk/client-sagemaker" "^3.583.0"
"@aws-sdk/credential-providers" "^3.583.0"
"@smithy/protocol-http" "^5.1.2"
"@smithy/signature-v4" "^5.1.2"
convict "^6.2.4"
form-data "^4.0.4"
form-data-encoder "^4.1.0"
formdata-node "^6.0.3"
readable-stream "^4.7.0"
color-convert@^1.9.3: color-convert@^1.9.3:
version "1.9.3" version "1.9.3"
resolved "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz" resolved "https://registry.npmjs.org/color-convert/-/color-convert-1.9.3.tgz"
@ -5488,7 +5534,7 @@ form-data-encoder@1.7.2:
resolved "https://registry.npmjs.org/form-data-encoder/-/form-data-encoder-1.7.2.tgz" resolved "https://registry.npmjs.org/form-data-encoder/-/form-data-encoder-1.7.2.tgz"
integrity sha512-qfqtYan3rxrnCk1VYaA4H+Ms9xdpPqvLZa6xmMgFvhO32x7/3J/ExcTd6qpxM0vH2GdMI+poehyBZvqfMTto8A== integrity sha512-qfqtYan3rxrnCk1VYaA4H+Ms9xdpPqvLZa6xmMgFvhO32x7/3J/ExcTd6qpxM0vH2GdMI+poehyBZvqfMTto8A==
form-data-encoder@^4.0.2: form-data-encoder@^4.0.2, form-data-encoder@^4.1.0:
version "4.1.0" version "4.1.0"
resolved "https://registry.yarnpkg.com/form-data-encoder/-/form-data-encoder-4.1.0.tgz#497cedc94810bd5d53b99b5d4f6c152d5cbc9db2" resolved "https://registry.yarnpkg.com/form-data-encoder/-/form-data-encoder-4.1.0.tgz#497cedc94810bd5d53b99b5d4f6c152d5cbc9db2"
integrity sha512-G6NsmEW15s0Uw9XnCg+33H3ViYRyiM0hMrMhhqQOR8NFc5GhYrI+6I3u7OTw7b91J2g8rtvMBZJDbcGb2YUniw== integrity sha512-G6NsmEW15s0Uw9XnCg+33H3ViYRyiM0hMrMhhqQOR8NFc5GhYrI+6I3u7OTw7b91J2g8rtvMBZJDbcGb2YUniw==
@ -5504,7 +5550,7 @@ form-data@3.0.4, form-data@^3.0.0:
hasown "^2.0.2" hasown "^2.0.2"
mime-types "^2.1.35" mime-types "^2.1.35"
form-data@4.0.0, form-data@4.0.4, form-data@^4.0.0: form-data@4.0.0, form-data@4.0.4, form-data@^4.0.0, form-data@^4.0.4:
version "4.0.4" version "4.0.4"
resolved "https://registry.yarnpkg.com/form-data/-/form-data-4.0.4.tgz#784cdcce0669a9d68e94d11ac4eea98088edd2c4" resolved "https://registry.yarnpkg.com/form-data/-/form-data-4.0.4.tgz#784cdcce0669a9d68e94d11ac4eea98088edd2c4"
integrity sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow== integrity sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==
@ -6473,6 +6519,18 @@ langchainhub@~0.0.8:
resolved "https://registry.npmjs.org/langchainhub/-/langchainhub-0.0.8.tgz" resolved "https://registry.npmjs.org/langchainhub/-/langchainhub-0.0.8.tgz"
integrity sha512-Woyb8YDHgqqTOZvWIbm2CaFDGfZ4NTSyXV687AG4vXEfoNo7cGQp7nhl7wL3ehenKWmNEmcxCLgOZzW8jE6lOQ== integrity sha512-Woyb8YDHgqqTOZvWIbm2CaFDGfZ4NTSyXV687AG4vXEfoNo7cGQp7nhl7wL3ehenKWmNEmcxCLgOZzW8jE6lOQ==
langsmith@^0.1.56-rc.1:
version "0.1.68"
resolved "https://registry.yarnpkg.com/langsmith/-/langsmith-0.1.68.tgz#848332e822fe5e6734a07f1c36b6530cc1798afb"
integrity sha512-otmiysWtVAqzMx3CJ4PrtUBhWRG5Co8Z4o7hSZENPjlit9/j3/vm3TSvbaxpDYakZxtMjhkcJTqrdYFipISEiQ==
dependencies:
"@types/uuid" "^10.0.0"
commander "^10.0.1"
p-queue "^6.6.2"
p-retry "4"
semver "^7.6.3"
uuid "^10.0.0"
langsmith@~0.1.1, langsmith@~0.1.7: langsmith@~0.1.1, langsmith@~0.1.7:
version "0.1.21" version "0.1.21"
resolved "https://registry.npmjs.org/langsmith/-/langsmith-0.1.21.tgz" resolved "https://registry.npmjs.org/langsmith/-/langsmith-0.1.21.tgz"
@ -7713,7 +7771,7 @@ readable-stream@^4.2.0:
process "^0.11.10" process "^0.11.10"
string_decoder "^1.3.0" string_decoder "^1.3.0"
readable-stream@^4.5.2: readable-stream@^4.5.2, readable-stream@^4.7.0:
version "4.7.0" version "4.7.0"
resolved "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz" resolved "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz"
integrity sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg== integrity sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==
@ -7892,6 +7950,11 @@ semver@^7.3.5, semver@^7.5.4:
dependencies: dependencies:
lru-cache "^6.0.0" lru-cache "^6.0.0"
semver@^7.6.3:
version "7.7.3"
resolved "https://registry.yarnpkg.com/semver/-/semver-7.7.3.tgz#4b5f4143d007633a8dc671cd0a6ef9147b8bb946"
integrity sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==
semver@~7.0.0: semver@~7.0.0:
version "7.0.0" version "7.0.0"
resolved "https://registry.npmjs.org/semver/-/semver-7.0.0.tgz" resolved "https://registry.npmjs.org/semver/-/semver-7.0.0.tgz"