- Chat model{" "}
- ({settings?.LLMProvider})
+ Workspace Chat model
The specific chat model that will be used for this workspace. If
@@ -59,9 +57,6 @@ export default function ChatModelSelection({
}}
className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
>
-
- System default
-
{defaultModels.length > 0 && (
{defaultModels.map((model) => {
diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/WorkspaceLLMItem/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/WorkspaceLLMItem/index.jsx
new file mode 100644
index 00000000..872d2a42
--- /dev/null
+++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/WorkspaceLLMItem/index.jsx
@@ -0,0 +1,151 @@
+// This component differs from the main LLMItem in that it shows whether a provider
+// is "ready for use" and, if not, will hijack the click handler to show a modal
+// of the provider options that must be saved to continue.
+import { createPortal } from "react-dom";
+import ModalWrapper from "@/components/ModalWrapper";
+import { useModal } from "@/hooks/useModal";
+import { X } from "@phosphor-icons/react";
+import System from "@/models/system";
+import showToast from "@/utils/toast";
+
+export default function WorkspaceLLM({
+ llm,
+ availableLLMs,
+ settings,
+ checked,
+ onClick,
+}) {
+ const { isOpen, openModal, closeModal } = useModal();
+ const { name, value, logo, description } = llm;
+
+ function handleProviderSelection() {
+    // Determine if the provider needs additional setup because its minimum
+    // required keys are not yet set in settings.
+ const requiresAdditionalSetup = (llm.requiredConfig || []).some(
+ (key) => !settings[key]
+ );
+ if (requiresAdditionalSetup) {
+ openModal();
+ return;
+ }
+ onClick(value);
+ }
+
+ return (
+ <>
+
+
+
+
+
+
{name}
+
{description}
+
+
+
+
+ >
+ );
+}
+
+function SetupProvider({
+ availableLLMs,
+ isOpen,
+ provider,
+ closeModal,
+ postSubmit,
+}) {
+ if (!isOpen) return null;
+ const LLMOption = availableLLMs.find((llm) => llm.value === provider);
+ if (!LLMOption) return null;
+
+ async function handleUpdate(e) {
+ e.preventDefault();
+ e.stopPropagation();
+ const data = {};
+ const form = new FormData(e.target);
+ for (var [key, value] of form.entries()) data[key] = value;
+ const { error } = await System.updateSystem(data);
+ if (error) {
+ showToast(`Failed to save ${LLMOption.name} settings: ${error}`, "error");
+ return;
+ }
+
+ closeModal();
+ postSubmit();
+ return false;
+ }
+
+  // Nested forms are not allowed and cause all sorts of issues, so we portal this
+  // markup out to the parent container so it is not rendered inside the outer form.
+ return createPortal(
+
+
+
+
+
+ Setup {LLMOption.name}
+
+
+
+
+
+
+
+
+
+ ,
+ document.getElementById("workspace-chat-settings-container")
+ );
+}
diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx
new file mode 100644
index 00000000..07e35596
--- /dev/null
+++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/WorkspaceLLMSelection/index.jsx
@@ -0,0 +1,159 @@
+import React, { useEffect, useRef, useState } from "react";
+import AnythingLLMIcon from "@/media/logo/anything-llm-icon.png";
+import WorkspaceLLMItem from "./WorkspaceLLMItem";
+import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
+import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react";
+import ChatModelSelection from "../ChatModelSelection";
+
+const DISABLED_PROVIDERS = ["azure", "lmstudio", "native"];
+const LLM_DEFAULT = {
+ name: "System default",
+ value: "default",
+ logo: AnythingLLMIcon,
+ options: () => ,
+ description: "Use the system LLM preference for this workspace.",
+ requiredConfig: [],
+};
+
+export default function WorkspaceLLMSelection({
+ settings,
+ workspace,
+ setHasChanges,
+}) {
+ const [filteredLLMs, setFilteredLLMs] = useState([]);
+ const [selectedLLM, setSelectedLLM] = useState(
+ workspace?.chatProvider ?? "default"
+ );
+ const [searchQuery, setSearchQuery] = useState("");
+ const [searchMenuOpen, setSearchMenuOpen] = useState(false);
+ const searchInputRef = useRef(null);
+ const LLMS = [LLM_DEFAULT, ...AVAILABLE_LLM_PROVIDERS].filter(
+ (llm) => !DISABLED_PROVIDERS.includes(llm.value)
+ );
+
+ function updateLLMChoice(selection) {
+ console.log({ selection });
+ setSearchQuery("");
+ setSelectedLLM(selection);
+ setSearchMenuOpen(false);
+ setHasChanges(true);
+ }
+
+ function handleXButton() {
+ if (searchQuery.length > 0) {
+ setSearchQuery("");
+ if (searchInputRef.current) searchInputRef.current.value = "";
+ } else {
+ setSearchMenuOpen(!searchMenuOpen);
+ }
+ }
+
+ useEffect(() => {
+ const filtered = LLMS.filter((llm) =>
+ llm.name.toLowerCase().includes(searchQuery.toLowerCase())
+ );
+ setFilteredLLMs(filtered);
+ }, [LLMS, searchQuery, selectedLLM]);
+
+ const selectedLLMObject = LLMS.find((llm) => llm.value === selectedLLM);
+ return (
+
+
+
+ Workspace LLM Provider
+
+
+ The specific LLM provider & model that will be used for this
+ workspace. By default, it uses the system LLM provider and settings.
+
+
+
+
+
+ {searchMenuOpen && (
+
setSearchMenuOpen(false)}
+ />
+ )}
+ {searchMenuOpen ? (
+
+
+
+
+ setSearchQuery(e.target.value)}
+ ref={searchInputRef}
+ onKeyDown={(e) => {
+ if (e.key === "Enter") e.preventDefault();
+ }}
+ />
+
+
+
+ {filteredLLMs.map((llm) => {
+ return (
+ updateLLMChoice(llm.value)}
+ />
+ );
+ })}
+
+
+
+ ) : (
+
setSearchMenuOpen(true)}
+ >
+
+
+
+
+ {selectedLLMObject.name}
+
+
+ {selectedLLMObject.description}
+
+
+
+
+
+ )}
+
+ {selectedLLM !== "default" && (
+
+
+
+ )}
+
+ );
+}
diff --git a/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx b/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx
index 3004b871..a6bab2c3 100644
--- a/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/ChatSettings/index.jsx
@@ -3,11 +3,11 @@ import Workspace from "@/models/workspace";
import showToast from "@/utils/toast";
import { castToType } from "@/utils/types";
import { useEffect, useRef, useState } from "react";
-import ChatModelSelection from "./ChatModelSelection";
import ChatHistorySettings from "./ChatHistorySettings";
import ChatPromptSettings from "./ChatPromptSettings";
import ChatTemperatureSettings from "./ChatTemperatureSettings";
import ChatModeSelection from "./ChatModeSelection";
+import WorkspaceLLMSelection from "./WorkspaceLLMSelection";
export default function ChatSettings({ workspace }) {
const [settings, setSettings] = useState({});
@@ -44,35 +44,45 @@ export default function ChatSettings({ workspace }) {
if (!workspace) return null;
return (
-
+
+
+
);
}
diff --git a/server/endpoints/workspaces.js b/server/endpoints/workspaces.js
index da9e2ad9..1c87dc36 100644
--- a/server/endpoints/workspaces.js
+++ b/server/endpoints/workspaces.js
@@ -508,7 +508,7 @@ function workspaceEndpoints(app) {
if (fs.existsSync(oldPfpPath)) fs.unlinkSync(oldPfpPath);
}
- const { workspace, message } = await Workspace.update(
+ const { workspace, message } = await Workspace._update(
workspaceRecord.id,
{
pfpFilename: uploadedFileName,
@@ -547,7 +547,7 @@ function workspaceEndpoints(app) {
if (fs.existsSync(oldPfpPath)) fs.unlinkSync(oldPfpPath);
}
- const { workspace, message } = await Workspace.update(
+ const { workspace, message } = await Workspace._update(
workspaceRecord.id,
{
pfpFilename: null,
diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js
index c4529ad9..080a01f0 100644
--- a/server/models/systemSettings.js
+++ b/server/models/systemSettings.js
@@ -57,103 +57,13 @@ const SystemSettings = {
// VectorDB Provider Selection Settings & Configs
// --------------------------------------------------------
VectorDB: vectorDB,
- // Pinecone DB Keys
- PineConeKey: !!process.env.PINECONE_API_KEY,
- PineConeIndex: process.env.PINECONE_INDEX,
-
- // Chroma DB Keys
- ChromaEndpoint: process.env.CHROMA_ENDPOINT,
- ChromaApiHeader: process.env.CHROMA_API_HEADER,
- ChromaApiKey: !!process.env.CHROMA_API_KEY,
-
- // Weaviate DB Keys
- WeaviateEndpoint: process.env.WEAVIATE_ENDPOINT,
- WeaviateApiKey: process.env.WEAVIATE_API_KEY,
-
- // QDrant DB Keys
- QdrantEndpoint: process.env.QDRANT_ENDPOINT,
- QdrantApiKey: process.env.QDRANT_API_KEY,
-
- // Milvus DB Keys
- MilvusAddress: process.env.MILVUS_ADDRESS,
- MilvusUsername: process.env.MILVUS_USERNAME,
- MilvusPassword: !!process.env.MILVUS_PASSWORD,
-
- // Zilliz DB Keys
- ZillizEndpoint: process.env.ZILLIZ_ENDPOINT,
- ZillizApiToken: process.env.ZILLIZ_API_TOKEN,
-
- // AstraDB Keys
- AstraDBApplicationToken: process?.env?.ASTRA_DB_APPLICATION_TOKEN,
- AstraDBEndpoint: process?.env?.ASTRA_DB_ENDPOINT,
+ ...this.vectorDBPreferenceKeys(),
// --------------------------------------------------------
// LLM Provider Selection Settings & Configs
// --------------------------------------------------------
LLMProvider: llmProvider,
- // OpenAI Keys
- OpenAiKey: !!process.env.OPEN_AI_KEY,
- OpenAiModelPref: process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo",
-
- // Azure + OpenAI Keys
- AzureOpenAiEndpoint: process.env.AZURE_OPENAI_ENDPOINT,
- AzureOpenAiKey: !!process.env.AZURE_OPENAI_KEY,
- AzureOpenAiModelPref: process.env.OPEN_MODEL_PREF,
- AzureOpenAiEmbeddingModelPref: process.env.EMBEDDING_MODEL_PREF,
- AzureOpenAiTokenLimit: process.env.AZURE_OPENAI_TOKEN_LIMIT || 4096,
-
- // Anthropic Keys
- AnthropicApiKey: !!process.env.ANTHROPIC_API_KEY,
- AnthropicModelPref: process.env.ANTHROPIC_MODEL_PREF || "claude-2",
-
- // Gemini Keys
- GeminiLLMApiKey: !!process.env.GEMINI_API_KEY,
- GeminiLLMModelPref: process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro",
-
- // LMStudio Keys
- LMStudioBasePath: process.env.LMSTUDIO_BASE_PATH,
- LMStudioTokenLimit: process.env.LMSTUDIO_MODEL_TOKEN_LIMIT,
- LMStudioModelPref: process.env.LMSTUDIO_MODEL_PREF,
-
- // LocalAI Keys
- LocalAiApiKey: !!process.env.LOCAL_AI_API_KEY,
- LocalAiBasePath: process.env.LOCAL_AI_BASE_PATH,
- LocalAiModelPref: process.env.LOCAL_AI_MODEL_PREF,
- LocalAiTokenLimit: process.env.LOCAL_AI_MODEL_TOKEN_LIMIT,
-
- // Ollama LLM Keys
- OllamaLLMBasePath: process.env.OLLAMA_BASE_PATH,
- OllamaLLMModelPref: process.env.OLLAMA_MODEL_PREF,
- OllamaLLMTokenLimit: process.env.OLLAMA_MODEL_TOKEN_LIMIT,
-
- // TogetherAI Keys
- TogetherAiApiKey: !!process.env.TOGETHER_AI_API_KEY,
- TogetherAiModelPref: process.env.TOGETHER_AI_MODEL_PREF,
-
- // Perplexity AI Keys
- PerplexityApiKey: !!process.env.PERPLEXITY_API_KEY,
- PerplexityModelPref: process.env.PERPLEXITY_MODEL_PREF,
-
- // OpenRouter Keys
- OpenRouterApiKey: !!process.env.OPENROUTER_API_KEY,
- OpenRouterModelPref: process.env.OPENROUTER_MODEL_PREF,
-
- // Mistral AI (API) Keys
- MistralApiKey: !!process.env.MISTRAL_API_KEY,
- MistralModelPref: process.env.MISTRAL_MODEL_PREF,
-
- // Groq AI API Keys
- GroqApiKey: !!process.env.GROQ_API_KEY,
- GroqModelPref: process.env.GROQ_MODEL_PREF,
-
- // Native LLM Keys
- NativeLLMModelPref: process.env.NATIVE_LLM_MODEL_PREF,
- NativeLLMTokenLimit: process.env.NATIVE_LLM_MODEL_TOKEN_LIMIT,
-
- // HuggingFace Dedicated Inference
- HuggingFaceLLMEndpoint: process.env.HUGGING_FACE_LLM_ENDPOINT,
- HuggingFaceLLMAccessToken: !!process.env.HUGGING_FACE_LLM_API_KEY,
- HuggingFaceLLMTokenLimit: process.env.HUGGING_FACE_LLM_TOKEN_LIMIT,
+ ...this.llmPreferenceKeys(),
// --------------------------------------------------------
// Whisper (Audio transcription) Selection Settings & Configs
@@ -273,6 +183,108 @@ const SystemSettings = {
return false;
}
},
+
+ vectorDBPreferenceKeys: function () {
+ return {
+ // Pinecone DB Keys
+ PineConeKey: !!process.env.PINECONE_API_KEY,
+ PineConeIndex: process.env.PINECONE_INDEX,
+
+ // Chroma DB Keys
+ ChromaEndpoint: process.env.CHROMA_ENDPOINT,
+ ChromaApiHeader: process.env.CHROMA_API_HEADER,
+ ChromaApiKey: !!process.env.CHROMA_API_KEY,
+
+ // Weaviate DB Keys
+ WeaviateEndpoint: process.env.WEAVIATE_ENDPOINT,
+ WeaviateApiKey: process.env.WEAVIATE_API_KEY,
+
+ // QDrant DB Keys
+ QdrantEndpoint: process.env.QDRANT_ENDPOINT,
+ QdrantApiKey: process.env.QDRANT_API_KEY,
+
+ // Milvus DB Keys
+ MilvusAddress: process.env.MILVUS_ADDRESS,
+ MilvusUsername: process.env.MILVUS_USERNAME,
+ MilvusPassword: !!process.env.MILVUS_PASSWORD,
+
+ // Zilliz DB Keys
+ ZillizEndpoint: process.env.ZILLIZ_ENDPOINT,
+ ZillizApiToken: process.env.ZILLIZ_API_TOKEN,
+
+ // AstraDB Keys
+ AstraDBApplicationToken: process?.env?.ASTRA_DB_APPLICATION_TOKEN,
+ AstraDBEndpoint: process?.env?.ASTRA_DB_ENDPOINT,
+ };
+ },
+
+ llmPreferenceKeys: function () {
+ return {
+ // OpenAI Keys
+ OpenAiKey: !!process.env.OPEN_AI_KEY,
+ OpenAiModelPref: process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo",
+
+ // Azure + OpenAI Keys
+ AzureOpenAiEndpoint: process.env.AZURE_OPENAI_ENDPOINT,
+ AzureOpenAiKey: !!process.env.AZURE_OPENAI_KEY,
+ AzureOpenAiModelPref: process.env.OPEN_MODEL_PREF,
+ AzureOpenAiEmbeddingModelPref: process.env.EMBEDDING_MODEL_PREF,
+ AzureOpenAiTokenLimit: process.env.AZURE_OPENAI_TOKEN_LIMIT || 4096,
+
+ // Anthropic Keys
+ AnthropicApiKey: !!process.env.ANTHROPIC_API_KEY,
+ AnthropicModelPref: process.env.ANTHROPIC_MODEL_PREF || "claude-2",
+
+ // Gemini Keys
+ GeminiLLMApiKey: !!process.env.GEMINI_API_KEY,
+ GeminiLLMModelPref: process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro",
+
+ // LMStudio Keys
+ LMStudioBasePath: process.env.LMSTUDIO_BASE_PATH,
+ LMStudioTokenLimit: process.env.LMSTUDIO_MODEL_TOKEN_LIMIT,
+ LMStudioModelPref: process.env.LMSTUDIO_MODEL_PREF,
+
+ // LocalAI Keys
+ LocalAiApiKey: !!process.env.LOCAL_AI_API_KEY,
+ LocalAiBasePath: process.env.LOCAL_AI_BASE_PATH,
+ LocalAiModelPref: process.env.LOCAL_AI_MODEL_PREF,
+ LocalAiTokenLimit: process.env.LOCAL_AI_MODEL_TOKEN_LIMIT,
+
+ // Ollama LLM Keys
+ OllamaLLMBasePath: process.env.OLLAMA_BASE_PATH,
+ OllamaLLMModelPref: process.env.OLLAMA_MODEL_PREF,
+ OllamaLLMTokenLimit: process.env.OLLAMA_MODEL_TOKEN_LIMIT,
+
+ // TogetherAI Keys
+ TogetherAiApiKey: !!process.env.TOGETHER_AI_API_KEY,
+ TogetherAiModelPref: process.env.TOGETHER_AI_MODEL_PREF,
+
+ // Perplexity AI Keys
+ PerplexityApiKey: !!process.env.PERPLEXITY_API_KEY,
+ PerplexityModelPref: process.env.PERPLEXITY_MODEL_PREF,
+
+ // OpenRouter Keys
+ OpenRouterApiKey: !!process.env.OPENROUTER_API_KEY,
+ OpenRouterModelPref: process.env.OPENROUTER_MODEL_PREF,
+
+ // Mistral AI (API) Keys
+ MistralApiKey: !!process.env.MISTRAL_API_KEY,
+ MistralModelPref: process.env.MISTRAL_MODEL_PREF,
+
+ // Groq AI API Keys
+ GroqApiKey: !!process.env.GROQ_API_KEY,
+ GroqModelPref: process.env.GROQ_MODEL_PREF,
+
+ // Native LLM Keys
+ NativeLLMModelPref: process.env.NATIVE_LLM_MODEL_PREF,
+ NativeLLMTokenLimit: process.env.NATIVE_LLM_MODEL_TOKEN_LIMIT,
+
+ // HuggingFace Dedicated Inference
+ HuggingFaceLLMEndpoint: process.env.HUGGING_FACE_LLM_ENDPOINT,
+ HuggingFaceLLMAccessToken: !!process.env.HUGGING_FACE_LLM_API_KEY,
+ HuggingFaceLLMTokenLimit: process.env.HUGGING_FACE_LLM_TOKEN_LIMIT,
+ };
+ },
};
module.exports.SystemSettings = SystemSettings;
diff --git a/server/models/workspace.js b/server/models/workspace.js
index f061ca20..b905c199 100644
--- a/server/models/workspace.js
+++ b/server/models/workspace.js
@@ -19,6 +19,7 @@ const Workspace = {
"lastUpdatedAt",
"openAiPrompt",
"similarityThreshold",
+ "chatProvider",
"chatModel",
"topN",
"chatMode",
@@ -52,19 +53,42 @@ const Workspace = {
}
},
- update: async function (id = null, data = {}) {
+ update: async function (id = null, updates = {}) {
if (!id) throw new Error("No workspace id provided for update");
- const validKeys = Object.keys(data).filter((key) =>
+ const validFields = Object.keys(updates).filter((key) =>
this.writable.includes(key)
);
- if (validKeys.length === 0)
+
+ Object.entries(updates).forEach(([key]) => {
+ if (validFields.includes(key)) return;
+ delete updates[key];
+ });
+
+ if (Object.keys(updates).length === 0)
return { workspace: { id }, message: "No valid fields to update!" };
+    // If the user reset the chatProvider to the "default" sentinel we also
+    // need to clear the chatModel to prevent an invalid provider/model
+    // pairing during LLM loading.
+ if (updates?.chatProvider === "default") {
+ updates.chatProvider = null;
+ updates.chatModel = null;
+ }
+
+ return this._update(id, updates);
+ },
+
+  // Direct update of workspace settings with no validation of the keys
+  // being written. Only use this method when directly setting a key value
+  // that takes no user input for the keys being modified.
+ _update: async function (id = null, data = {}) {
+ if (!id) throw new Error("No workspace id provided for update");
+
try {
const workspace = await prisma.workspaces.update({
where: { id },
- data, // TODO: strict validation on writables here.
+ data,
});
return { workspace, message: null };
} catch (error) {
@@ -229,47 +253,40 @@ const Workspace = {
}
},
- resetWorkspaceChatModels: async () => {
- try {
- await prisma.workspaces.updateMany({
- data: {
- chatModel: null,
- },
- });
- return { success: true, error: null };
- } catch (error) {
- console.error("Error resetting workspace chat models:", error.message);
- return { success: false, error: error.message };
- }
- },
-
trackChange: async function (prevData, newData, user) {
try {
- const { Telemetry } = require("./telemetry");
- const { EventLogs } = require("./eventLogs");
- if (
- !newData?.openAiPrompt ||
- newData?.openAiPrompt === this.defaultPrompt ||
- newData?.openAiPrompt === prevData?.openAiPrompt
- )
- return;
-
- await Telemetry.sendTelemetry("workspace_prompt_changed");
- await EventLogs.logEvent(
- "workspace_prompt_changed",
- {
- workspaceName: prevData?.name,
- prevSystemPrompt: prevData?.openAiPrompt || this.defaultPrompt,
- newSystemPrompt: newData?.openAiPrompt,
- },
- user?.id
- );
+ await this._trackWorkspacePromptChange(prevData, newData, user);
return;
} catch (error) {
console.error("Error tracking workspace change:", error.message);
return;
}
},
+
+  // We are only tracking this change to determine the need for a prompt library or
+  // prompt assistant feature. If this is something you would like to see, tell us on GitHub!
+ _trackWorkspacePromptChange: async function (prevData, newData, user) {
+ const { Telemetry } = require("./telemetry");
+ const { EventLogs } = require("./eventLogs");
+ if (
+ !newData?.openAiPrompt ||
+ newData?.openAiPrompt === this.defaultPrompt ||
+ newData?.openAiPrompt === prevData?.openAiPrompt
+ )
+ return;
+
+ await Telemetry.sendTelemetry("workspace_prompt_changed");
+ await EventLogs.logEvent(
+ "workspace_prompt_changed",
+ {
+ workspaceName: prevData?.name,
+ prevSystemPrompt: prevData?.openAiPrompt || this.defaultPrompt,
+ newSystemPrompt: newData?.openAiPrompt,
+ },
+ user?.id
+ );
+ return;
+ },
};
module.exports = { Workspace };
diff --git a/server/prisma/migrations/20240405015034_init/migration.sql b/server/prisma/migrations/20240405015034_init/migration.sql
new file mode 100644
index 00000000..54a39d94
--- /dev/null
+++ b/server/prisma/migrations/20240405015034_init/migration.sql
@@ -0,0 +1,2 @@
+-- AlterTable
+ALTER TABLE "workspaces" ADD COLUMN "chatProvider" TEXT;
diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma
index fbb5f61d..1e589b0f 100644
--- a/server/prisma/schema.prisma
+++ b/server/prisma/schema.prisma
@@ -98,6 +98,7 @@ model workspaces {
lastUpdatedAt DateTime @default(now())
openAiPrompt String?
similarityThreshold Float? @default(0.25)
+ chatProvider String?
chatModel String?
topN Int? @default(4)
chatMode String? @default("chat")
diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js
index f748a3a5..497b2c8e 100644
--- a/server/utils/chats/embed.js
+++ b/server/utils/chats/embed.js
@@ -28,7 +28,9 @@ async function streamChatWithForEmbed(
embed.workspace.openAiTemp = parseFloat(temperatureOverride);
const uuid = uuidv4();
- const LLMConnector = getLLMProvider(chatModel ?? embed.workspace?.chatModel);
+ const LLMConnector = getLLMProvider({
+ model: chatModel ?? embed.workspace?.chatModel,
+ });
const VectorDb = getVectorDbClass();
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
if (!safe) {
diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js
index 10df9983..7e40b9a8 100644
--- a/server/utils/chats/index.js
+++ b/server/utils/chats/index.js
@@ -37,7 +37,10 @@ async function chatWithWorkspace(
return await VALID_COMMANDS[command](workspace, message, uuid, user);
}
- const LLMConnector = getLLMProvider(workspace?.chatModel);
+ const LLMConnector = getLLMProvider({
+ provider: workspace?.chatProvider,
+ model: workspace?.chatModel,
+ });
const VectorDb = getVectorDbClass();
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
if (!safe) {
diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js
index f1a335bc..0ec969eb 100644
--- a/server/utils/chats/stream.js
+++ b/server/utils/chats/stream.js
@@ -35,7 +35,10 @@ async function streamChatWithWorkspace(
return;
}
- const LLMConnector = getLLMProvider(workspace?.chatModel);
+ const LLMConnector = getLLMProvider({
+ provider: workspace?.chatProvider,
+ model: workspace?.chatModel,
+ });
const VectorDb = getVectorDbClass();
const { safe, reasons = [] } = await LLMConnector.isSafe(message);
if (!safe) {
diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js
index 78360972..a441bf82 100644
--- a/server/utils/helpers/index.js
+++ b/server/utils/helpers/index.js
@@ -30,52 +30,53 @@ function getVectorDbClass() {
}
}
-function getLLMProvider(modelPreference = null) {
- const vectorSelection = process.env.LLM_PROVIDER || "openai";
+function getLLMProvider({ provider = null, model = null } = {}) {
+ const LLMSelection = provider ?? process.env.LLM_PROVIDER ?? "openai";
const embedder = getEmbeddingEngineSelection();
- switch (vectorSelection) {
+
+ switch (LLMSelection) {
case "openai":
const { OpenAiLLM } = require("../AiProviders/openAi");
- return new OpenAiLLM(embedder, modelPreference);
+ return new OpenAiLLM(embedder, model);
case "azure":
const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi");
- return new AzureOpenAiLLM(embedder, modelPreference);
+ return new AzureOpenAiLLM(embedder, model);
case "anthropic":
const { AnthropicLLM } = require("../AiProviders/anthropic");
- return new AnthropicLLM(embedder, modelPreference);
+ return new AnthropicLLM(embedder, model);
case "gemini":
const { GeminiLLM } = require("../AiProviders/gemini");
- return new GeminiLLM(embedder, modelPreference);
+ return new GeminiLLM(embedder, model);
case "lmstudio":
const { LMStudioLLM } = require("../AiProviders/lmStudio");
- return new LMStudioLLM(embedder, modelPreference);
+ return new LMStudioLLM(embedder, model);
case "localai":
const { LocalAiLLM } = require("../AiProviders/localAi");
- return new LocalAiLLM(embedder, modelPreference);
+ return new LocalAiLLM(embedder, model);
case "ollama":
const { OllamaAILLM } = require("../AiProviders/ollama");
- return new OllamaAILLM(embedder, modelPreference);
+ return new OllamaAILLM(embedder, model);
case "togetherai":
const { TogetherAiLLM } = require("../AiProviders/togetherAi");
- return new TogetherAiLLM(embedder, modelPreference);
+ return new TogetherAiLLM(embedder, model);
case "perplexity":
const { PerplexityLLM } = require("../AiProviders/perplexity");
- return new PerplexityLLM(embedder, modelPreference);
+ return new PerplexityLLM(embedder, model);
case "openrouter":
const { OpenRouterLLM } = require("../AiProviders/openRouter");
- return new OpenRouterLLM(embedder, modelPreference);
+ return new OpenRouterLLM(embedder, model);
case "mistral":
const { MistralLLM } = require("../AiProviders/mistral");
- return new MistralLLM(embedder, modelPreference);
+ return new MistralLLM(embedder, model);
case "native":
const { NativeLLM } = require("../AiProviders/native");
- return new NativeLLM(embedder, modelPreference);
+ return new NativeLLM(embedder, model);
case "huggingface":
const { HuggingFaceLLM } = require("../AiProviders/huggingface");
- return new HuggingFaceLLM(embedder, modelPreference);
+ return new HuggingFaceLLM(embedder, model);
case "groq":
const { GroqLLM } = require("../AiProviders/groq");
- return new GroqLLM(embedder, modelPreference);
+ return new GroqLLM(embedder, model);
default:
throw new Error("ENV: No LLM_PROVIDER value found in environment!");
}
diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js
index 12c45af2..a026fe33 100644
--- a/server/utils/helpers/updateENV.js
+++ b/server/utils/helpers/updateENV.js
@@ -2,7 +2,6 @@ const KEY_MAPPING = {
LLMProvider: {
envKey: "LLM_PROVIDER",
checks: [isNotEmpty, supportedLLM],
- postUpdate: [wipeWorkspaceModelPreference],
},
// OpenAI Settings
OpenAiKey: {
@@ -493,15 +492,6 @@ function validHuggingFaceEndpoint(input = "") {
: null;
}
-// If the LLMProvider has changed we need to reset all workspace model preferences to
-// null since the provider<>model name combination will be invalid for whatever the new
-// provider is.
-async function wipeWorkspaceModelPreference(key, prev, next) {
- if (prev === next) return;
- const { Workspace } = require("../../models/workspace");
- await Workspace.resetWorkspaceChatModels();
-}
-
// This will force update .env variables which for any which reason were not able to be parsed or
// read from an ENV file as this seems to be a complicating step for many so allowing people to write
// to the process will at least alleviate that issue. It does not perform comprehensive validity checks or sanity checks