Ollama agents (#1270)

* add generic LMStudio agent support; "works" with non-tool-calling LLMs, highly dependent on system specs

* add comments

* enable few-shot prompting per function for OSS models (see the sketch after this summary)

* Add Agent support for Ollama models

* improve JSON parsing for Ollama text responses
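
For context on the few-shot bullet above, a sketch of a per-function definition carrying few-shot examples. The `examples` field and its { prompt, call } shape are an illustrative assumption, not something this diff confirms:

  // Hypothetical plugin definition; `examples` is an assumed shape for
  // the per-function few-shot prompts shown to OSS models.
  const webSearch = {
    name: "web-search",
    description: "Searches the web and returns result snippets.",
    parameters: {
      type: "object",
      properties: {
        query: { type: "string", description: "The search query." },
      },
    },
    examples: [
      {
        prompt: "What is the weather in Tokyo right now?",
        call: JSON.stringify({
          name: "web-search",
          arguments: { query: "weather in Tokyo" },
        }),
      },
    ],
  };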
Timothy Carambat 2024-05-07 18:06:31 -07:00 committed by GitHub
parent 1b4559f57f
commit 331d3741c9
10 changed files with 185 additions and 42 deletions


@@ -5,8 +5,8 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
 import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react";
 import AgentModelSelection from "../AgentModelSelection";
-const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio"];
-const WARN_PERFORMANCE = ["lmstudio"];
+const ENABLED_PROVIDERS = ["openai", "anthropic", "lmstudio", "ollama"];
+const WARN_PERFORMANCE = ["lmstudio", "ollama"];
 const LLM_DEFAULT = {
   name: "Please make a selection",


@@ -46,6 +46,7 @@
     "dotenv": "^16.0.3",
     "express": "^4.18.2",
     "express-ws": "^5.0.2",
+    "extract-json-from-string": "^1.0.1",
     "extract-zip": "^2.0.1",
     "graphql": "^16.7.1",
     "joi": "^17.11.0",
@@ -59,6 +60,7 @@
     "multer": "^1.4.5-lts.1",
     "node-html-markdown": "^1.3.0",
     "node-llama-cpp": "^2.8.0",
+    "ollama": "^0.5.0",
     "openai": "4.38.5",
     "pinecone-client": "^1.1.0",
     "pluralize": "^8.0.0",


@@ -741,6 +741,8 @@ ${this.getHistory({ to: route.to })
         return new Providers.AnthropicProvider({ model: config.model });
       case "lmstudio":
         return new Providers.LMStudioProvider({});
+      case "ollama":
+        return new Providers.OllamaProvider({ model: config.model });
       default:
         throw new Error(


@@ -102,16 +102,12 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
     return { valid: true, reason: null };
   }
 
-  async functionCall(messages, functions) {
+  async functionCall(messages, functions, chatCb = null) {
     const history = [...messages].filter((msg) =>
       ["user", "assistant"].includes(msg.role)
     );
     if (history[history.length - 1].role !== "user") return null;
-    const response = await this.client.chat.completions
-      .create({
-        model: this.model,
-        temperature: 0,
+    const response = await chatCb({
       messages: [
         {
           content: `You are a program which picks the most optimal function and parameters to call.
@@ -133,16 +129,6 @@ Now pick a function if there is an appropriate one to use given the last user me
           role: "system",
         },
         ...history,
       ],
-      })
-      .then((result) => {
-        if (!result.hasOwnProperty("choices"))
-          throw new Error("LMStudio chat: No results!");
-        if (result.choices.length === 0)
-          throw new Error("LMStudio chat: No results length!");
-        return result.choices[0].message.content;
-      })
-      .catch((_) => {
-        return null;
-      });
+    });
     const call = safeJsonParse(response, null);
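
The chatCb contract above is deliberately small; a minimal sketch of a conforming callback (the name is hypothetical, and the two providers below supply the real implementations):

  // Hypothetical stand-in for a provider-supplied chat callback: it
  // receives the assembled prompt messages and must resolve to the
  // model's raw text reply, or null on failure.
  const echoChatCb = async ({ messages }) => {
    const last = messages[messages.length - 1];
    return last?.content ?? null;
  };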


@@ -1,9 +1,11 @@
 const OpenAIProvider = require("./openai.js");
 const AnthropicProvider = require("./anthropic.js");
 const LMStudioProvider = require("./lmstudio.js");
+const OllamaProvider = require("./ollama.js");
 
 module.exports = {
   OpenAIProvider,
   AnthropicProvider,
   LMStudioProvider,
+  OllamaProvider,
 };


@@ -27,6 +27,25 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
     return this._client;
   }
 
+  async #handleFunctionCallChat({ messages = [] }) {
+    return await this.client.chat.completions
+      .create({
+        model: this.model,
+        temperature: 0,
+        messages,
+      })
+      .then((result) => {
+        if (!result.hasOwnProperty("choices"))
+          throw new Error("LMStudio chat: No results!");
+        if (result.choices.length === 0)
+          throw new Error("LMStudio chat: No results length!");
+        return result.choices[0].message.content;
+      })
+      .catch((_) => {
+        return null;
+      });
+  }
+
   /**
    * Create a completion based on the received messages.
    *
@@ -38,7 +57,11 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
     try {
       let completion;
       if (functions.length > 0) {
-        const { toolCall, text } = await this.functionCall(messages, functions);
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );
         if (toolCall !== null) {
           this.providerLog(`Valid tool call found - running ${toolCall.name}.`);


@@ -0,0 +1,107 @@
+const Provider = require("./ai-provider.js");
+const InheritMultiple = require("./helpers/classes.js");
+const UnTooled = require("./helpers/untooled.js");
+const { Ollama } = require("ollama");
+
+/**
+ * The agent provider for Ollama.
+ */
+class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
+  model;
+
+  constructor(config = {}) {
+    const {
+      // options = {},
+      model = null,
+    } = config;
+
+    super();
+    this._client = new Ollama({ host: process.env.OLLAMA_BASE_PATH });
+    this.model = model;
+    this.verbose = true;
+  }
+
+  get client() {
+    return this._client;
+  }
+
+  async #handleFunctionCallChat({ messages = [] }) {
+    const response = await this.client.chat({
+      model: this.model,
+      messages,
+      options: {
+        temperature: 0,
+      },
+    });
+    return response?.message?.content || null;
+  }
+
+  /**
+   * Create a completion based on the received messages.
+   *
+   * @param messages A list of messages to send to the API.
+   * @param functions
+   * @returns The completion.
+   */
+  async complete(messages, functions = null) {
+    try {
+      let completion;
+      if (functions.length > 0) {
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );
+
+        if (toolCall !== null) {
+          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
+          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
+          return {
+            result: null,
+            functionCall: {
+              name: toolCall.name,
+              arguments: toolCall.arguments,
+            },
+            cost: 0,
+          };
+        }
+        completion = { content: text };
+      }
+
+      if (!completion?.content) {
+        this.providerLog(
+          "Will assume chat completion without tool call inputs."
+        );
+        const response = await this.client.chat({
+          model: this.model,
+          messages: this.cleanMsgs(messages),
+          options: {
+            use_mlock: true,
+            temperature: 0.5,
+          },
+        });
+        completion = response.message;
+      }
+
+      return {
+        result: completion.content,
+        cost: 0,
+      };
+    } catch (error) {
+      throw error;
+    }
+  }
+
+  /**
+   * Get the cost of the completion.
+   *
+   * @param _usage The completion to get the cost for.
+   * @returns The cost of the completion.
+   * Stubbed since Ollama has no cost basis.
+   */
+  getCost(_usage) {
+    return 0;
+  }
+}
+
+module.exports = OllamaProvider;
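
A minimal usage sketch for the new provider, assuming a local Ollama server is reachable at OLLAMA_BASE_PATH and llama3:latest has been pulled (both values are examples):

  // Sketch only: not part of the diff. Requires `ollama serve` running
  // and the model pulled locally.
  process.env.OLLAMA_BASE_PATH = "http://127.0.0.1:11434";
  const OllamaProvider = require("./ollama.js");

  (async () => {
    const provider = new OllamaProvider({ model: "llama3:latest" });
    const { result } = await provider.complete(
      [{ role: "user", content: "Summarize today's agenda." }],
      [] // no function definitions, so this runs as a plain chat completion
    );
    console.log(result);
  })();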


@@ -79,7 +79,11 @@ class AgentHandler {
         break;
       case "lmstudio":
         if (!process.env.LMSTUDIO_BASE_PATH)
-          throw new Error("LMStudio bash path must be provided to use agents.");
+          throw new Error("LMStudio base path must be provided to use agents.");
         break;
+      case "ollama":
+        if (!process.env.OLLAMA_BASE_PATH)
+          throw new Error("Ollama base path must be provided to use agents.");
+        break;
       default:
         throw new Error("No provider found to power agent cluster.");
@@ -94,6 +98,8 @@ class AgentHandler {
         return "claude-3-sonnet-20240229";
       case "lmstudio":
         return "server-default";
+      case "ollama":
+        return "llama3:latest";
       default:
         return "unknown";
     }
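
For reference, the environment variables checked above could be set like so; the hosts are common local defaults, not values mandated by this diff:

  # .env (example values only)
  LMSTUDIO_BASE_PATH="http://localhost:1234/v1"
  OLLAMA_BASE_PATH="http://127.0.0.1:11434"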


@@ -4,6 +4,7 @@ process.env.NODE_ENV === "development"
 const JWT = require("jsonwebtoken");
 const { User } = require("../../models/user");
 const { jsonrepair } = require("jsonrepair");
+const extract = require("extract-json-from-string");
 
 function reqBody(request) {
   return typeof request.body === "string"
@@ -67,8 +68,6 @@ function safeJsonParse(jsonString, fallback = null) {
     return JSON.parse(jsonString);
   } catch {}
 
-  // If the jsonString does not look like an Obj or Array, dont attempt
-  // to repair it.
   if (jsonString?.startsWith("[") || jsonString?.startsWith("{")) {
     try {
       const repairedJson = jsonrepair(jsonString);
@@ -76,6 +75,10 @@ function safeJsonParse(jsonString, fallback = null) {
     } catch {}
   }
 
+  try {
+    return extract(jsonString)[0];
+  } catch {}
+
   return fallback;
 }
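
To illustrate the new fallback, a sketch with a made-up model reply that wraps JSON in prose:

  const extract = require("extract-json-from-string");

  // JSON.parse fails on this string, and the jsonrepair branch is skipped
  // because the text does not start with "{" or "[".
  const reply =
    'Sure! Here is the call: {"name":"web-search","arguments":{"query":"tokyo weather"}}';

  // extract() returns an array of every parseable JSON value found in the
  // string; safeJsonParse now returns the first one before falling back.
  console.log(extract(reply)[0]);
  // → { name: 'web-search', arguments: { query: 'tokyo weather' } }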


@@ -2678,6 +2678,11 @@ extract-files@^9.0.0:
   resolved "https://registry.yarnpkg.com/extract-files/-/extract-files-9.0.0.tgz#8a7744f2437f81f5ed3250ed9f1550de902fe54a"
   integrity sha512-CvdFfHkC95B4bBBk36hcEmvdR2awOdhhVUYH6S/zrVj3477zven/fJMYg7121h4T1xHZC+tetUpubpAhxwI7hQ==
 
+extract-json-from-string@^1.0.1:
+  version "1.0.1"
+  resolved "https://registry.yarnpkg.com/extract-json-from-string/-/extract-json-from-string-1.0.1.tgz#5001f17e6c905826dcd5989564e130959de60c96"
+  integrity sha512-xfQOSFYbELVs9QVkKsV9FZAjlAmXQ2SLR6FpfFX1kpn4QAvaGBJlrnVOblMLwrLPYc26H+q9qxo6JTd4E7AwgQ==
+
 extract-zip@^2.0.1:
   version "2.0.1"
   resolved "https://registry.yarnpkg.com/extract-zip/-/extract-zip-2.0.1.tgz#663dca56fe46df890d5f131ef4a06d22bb8ba13a"
@@ -4560,6 +4565,13 @@ octokit@^3.1.0:
     "@octokit/request-error" "^5.0.0"
     "@octokit/types" "^12.0.0"
 
+ollama@^0.5.0:
+  version "0.5.0"
+  resolved "https://registry.yarnpkg.com/ollama/-/ollama-0.5.0.tgz#cb9bc709d4d3278c9f484f751b0d9b98b06f4859"
+  integrity sha512-CRtRzsho210EGdK52GrUMohA2pU+7NbgEaBG3DcYeRmvQthDO7E2LHOkLlUUeaYUlNmEd8icbjC02ug9meSYnw==
+  dependencies:
+    whatwg-fetch "^3.6.20"
+
 on-finished@2.4.1:
   version "2.4.1"
   resolved "https://registry.yarnpkg.com/on-finished/-/on-finished-2.4.1.tgz#58c8c44116e54845ad57f14ab10b03533184ac3f"
@@ -5980,7 +5992,7 @@ webidl-conversions@^3.0.0:
   resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871"
   integrity sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==
 
-whatwg-fetch@^3.4.1:
+whatwg-fetch@^3.4.1, whatwg-fetch@^3.6.20:
   version "3.6.20"
   resolved "https://registry.yarnpkg.com/whatwg-fetch/-/whatwg-fetch-3.6.20.tgz#580ce6d791facec91d37c72890995a0b48d31c70"
   integrity sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg==