VectorDB class migration (#4787)

* Migrate Astra to class (#4722)

migrate astra to class

* Migrate LanceDB to class (#4721)

migrate lancedb to class

* Migrate Pinecone to class (#4726)

migrate pinecone to class

* Migrate Zilliz to class (#4729)

migrate zilliz to class

* Migrate Weaviate to class (#4728)

migrate weaviate to class

* Migrate Qdrant to class (#4727)

migrate qdrant to class

* Migrate Milvus to class (#4725)

migrate milvus to class

* Migrate Chroma to class (#4723)

migrate chroma to class

* Migrate Chroma Cloud to class (#4724)

* migrate chroma to class

* migrate chroma cloud to class

* move limits to class field

---------

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>

* Migrate PGVector to class (#4730)

* migrate pgvector to class

* patch pgvector test

* convert connectionString, tableName, and validateConnection to static methods

* move instance properties to class fields

---------

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>

* Refactor Zilliz Cloud vector DB provider (#4749)

simplify zilliz implementation by using milvus as base class

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>

* VectorDatabase base class (#4738)

create generic VectorDatabase base class

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>

* Extend VectorDatabase base class to all providers (#4755)

extend VectorDatabase base class to all providers

* patch lancedb import

* breakout name and add generic logger

* dev tag build

---------

Co-authored-by: Timothy Carambat <rambat1010@gmail.com>
This commit is contained in:
Sean Hatfield 2026-01-13 15:24:42 -08:00 committed by GitHub
parent 7c3b7906e7
commit 5039045f0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 892 additions and 860 deletions

View File

@ -6,7 +6,7 @@ concurrency:
on: on:
push: push:
branches: ['4841-aws-bedrock-api-key'] # put your current branch to create a build. Core team only. branches: ['vectordb-class-migration'] # put your current branch to create a build. Core team only.
paths-ignore: paths-ignore:
- '**.md' - '**.md'
- 'cloud-deployments/*' - 'cloud-deployments/*'

View File

@ -1,4 +1,6 @@
const { PGVector } = require("../../../../utils/vectorDbProviders/pgvector"); const { PGVector: PGVectorClass } = require("../../../../utils/vectorDbProviders/pgvector");
const PGVector = new PGVectorClass();
describe("PGVector.sanitizeForJsonb", () => { describe("PGVector.sanitizeForJsonb", () => {
it("returns null/undefined as-is", () => { it("returns null/undefined as-is", () => {

View File

@ -86,40 +86,40 @@ function getVectorDbClass(getExactly = null) {
switch (vectorSelection) { switch (vectorSelection) {
case "pinecone": case "pinecone":
const { Pinecone } = require("../vectorDbProviders/pinecone"); const { Pinecone } = require("../vectorDbProviders/pinecone");
return Pinecone; return new Pinecone();
case "chroma": case "chroma":
const { Chroma } = require("../vectorDbProviders/chroma"); const { Chroma } = require("../vectorDbProviders/chroma");
return Chroma; return new Chroma();
case "chromacloud": case "chromacloud":
const { ChromaCloud } = require("../vectorDbProviders/chromacloud"); const { ChromaCloud } = require("../vectorDbProviders/chromacloud");
return ChromaCloud; return new ChromaCloud();
case "lancedb": case "lancedb":
const { LanceDb } = require("../vectorDbProviders/lance"); const { LanceDb } = require("../vectorDbProviders/lance");
return LanceDb; return new LanceDb();
case "weaviate": case "weaviate":
const { Weaviate } = require("../vectorDbProviders/weaviate"); const { Weaviate } = require("../vectorDbProviders/weaviate");
return Weaviate; return new Weaviate();
case "qdrant": case "qdrant":
const { QDrant } = require("../vectorDbProviders/qdrant"); const { QDrant } = require("../vectorDbProviders/qdrant");
return QDrant; return new QDrant();
case "milvus": case "milvus":
const { Milvus } = require("../vectorDbProviders/milvus"); const { Milvus } = require("../vectorDbProviders/milvus");
return Milvus; return new Milvus();
case "zilliz": case "zilliz":
const { Zilliz } = require("../vectorDbProviders/zilliz"); const { Zilliz } = require("../vectorDbProviders/zilliz");
return Zilliz; return new Zilliz();
case "astra": case "astra":
const { AstraDB } = require("../vectorDbProviders/astra"); const { AstraDB } = require("../vectorDbProviders/astra");
return AstraDB; return new AstraDB();
case "pgvector": case "pgvector":
const { PGVector } = require("../vectorDbProviders/pgvector"); const { PGVector } = require("../vectorDbProviders/pgvector");
return PGVector; return new PGVector();
default: default:
console.error( console.error(
`\x1b[31m[ENV ERROR]\x1b[0m No VECTOR_DB value found in environment! Falling back to LanceDB` `\x1b[31m[ENV ERROR]\x1b[0m No VECTOR_DB value found in environment! Falling back to LanceDB`
); );
const { LanceDb: DefaultLanceDb } = require("../vectorDbProviders/lance"); const { LanceDb: DefaultLanceDb } = require("../vectorDbProviders/lance");
return DefaultLanceDb; return new DefaultLanceDb();
} }
} }

View File

@ -5,6 +5,7 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const { VectorDatabase } = require("../base");
const sanitizeNamespace = (namespace) => { const sanitizeNamespace = (namespace) => {
// If namespace already starts with ns_, don't add it again // If namespace already starts with ns_, don't add it again
@ -22,14 +23,21 @@ const collectionExists = async function (client, namespace) {
return collections.includes(namespace); return collections.includes(namespace);
} }
} catch (error) { } catch (error) {
console.log("Astra::collectionExists check error", error?.message || error); this.logger("collectionExists check error", error?.message || error);
return false; // Return false for any error to allow creation attempt return false; // Return false for any error to allow creation attempt
} }
}; };
const AstraDB = { class AstraDB extends VectorDatabase {
name: "AstraDB", constructor() {
connect: async function () { super();
}
get name() {
return "AstraDB";
}
async connect() {
if (process.env.VECTOR_DB !== "astra") if (process.env.VECTOR_DB !== "astra")
throw new Error("AstraDB::Invalid ENV settings"); throw new Error("AstraDB::Invalid ENV settings");
@ -38,21 +46,24 @@ const AstraDB = {
process?.env?.ASTRA_DB_ENDPOINT process?.env?.ASTRA_DB_ENDPOINT
); );
return { client }; return { client };
}, }
heartbeat: async function () {
async heartbeat() {
return { heartbeat: Number(new Date()) }; return { heartbeat: Number(new Date()) };
}, }
// Astra interface will return a valid collection object even if the collection // Astra interface will return a valid collection object even if the collection
// does not actually exist. So we run a simple check which will always throw // does not actually exist. So we run a simple check which will always throw
// when the table truly does not exist. Faster than iterating all collections. // when the table truly does not exist. Faster than iterating all collections.
isRealCollection: async function (astraCollection = null) { async isRealCollection(astraCollection = null) {
if (!astraCollection) return false; if (!astraCollection) return false;
return await astraCollection return await astraCollection
.countDocuments() .countDocuments()
.then(() => true) .then(() => true)
.catch(() => false); .catch(() => false);
}, }
totalVectors: async function () {
async totalVectors() {
const { client } = await this.connect(); const { client } = await this.connect();
const collectionNames = await this.allNamespaces(client); const collectionNames = await this.allNamespaces(client);
var totalVectors = 0; var totalVectors = 0;
@ -62,13 +73,15 @@ const AstraDB = {
totalVectors += count ? count : 0; totalVectors += count ? count : 0;
} }
return totalVectors; return totalVectors;
}, }
namespaceCount: async function (_namespace = null) {
async namespaceCount(_namespace = null) {
const { client } = await this.connect(); const { client } = await this.connect();
const namespace = await this.namespace(client, _namespace); const namespace = await this.namespace(client, _namespace);
return namespace?.vectorCount || 0; return namespace?.vectorCount || 0;
}, }
namespace: async function (client, namespace = null) {
async namespace(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const sanitizedNamespace = sanitizeNamespace(namespace); const sanitizedNamespace = sanitizeNamespace(namespace);
const collection = await client const collection = await client
@ -77,7 +90,7 @@ const AstraDB = {
if (!(await this.isRealCollection(collection))) return null; if (!(await this.isRealCollection(collection))) return null;
const count = await collection.countDocuments().catch((e) => { const count = await collection.countDocuments().catch((e) => {
console.error("Astra::namespaceExists", e.message); this.logger("namespaceExists", e.message);
return null; return null;
}); });
@ -86,27 +99,31 @@ const AstraDB = {
...collection, ...collection,
vectorCount: typeof count === "number" ? count : 0, vectorCount: typeof count === "number" ? count : 0,
}; };
}, }
hasNamespace: async function (namespace = null) {
async hasNamespace(namespace = null) {
if (!namespace) return false; if (!namespace) return false;
const { client } = await this.connect(); const { client } = await this.connect();
return await this.namespaceExists(client, namespace); return await this.namespaceExists(client, namespace);
}, }
namespaceExists: async function (client, namespace = null) {
async namespaceExists(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const sanitizedNamespace = sanitizeNamespace(namespace); const sanitizedNamespace = sanitizeNamespace(namespace);
const collection = await client.collection(sanitizedNamespace); const collection = await client.collection(sanitizedNamespace);
return await this.isRealCollection(collection); return await this.isRealCollection(collection);
}, }
deleteVectorsInNamespace: async function (client, namespace = null) {
async deleteVectorsInNamespace(client, namespace = null) {
const sanitizedNamespace = sanitizeNamespace(namespace); const sanitizedNamespace = sanitizeNamespace(namespace);
await client.dropCollection(sanitizedNamespace); await client.dropCollection(sanitizedNamespace);
return true; return true;
}, }
// AstraDB requires a dimension aspect for collection creation // AstraDB requires a dimension aspect for collection creation
// we pass this in from the first chunk to infer the dimensions like other // we pass this in from the first chunk to infer the dimensions like other
// providers do. // providers do.
getOrCreateCollection: async function (client, namespace, dimensions = null) { async getOrCreateCollection(client, namespace, dimensions = null) {
const sanitizedNamespace = sanitizeNamespace(namespace); const sanitizedNamespace = sanitizeNamespace(namespace);
try { try {
const exists = await collectionExists(client, sanitizedNamespace); const exists = await collectionExists(client, sanitizedNamespace);
@ -132,14 +149,12 @@ const AstraDB = {
return await client.collection(sanitizedNamespace); return await client.collection(sanitizedNamespace);
} catch (error) { } catch (error) {
console.error( this.logger("getOrCreateCollection", error?.message || error);
"Astra::getOrCreateCollection error",
error?.message || error
);
throw error; throw error;
} }
}, }
addDocumentToNamespace: async function (
async addDocumentToNamespace(
namespace, namespace,
documentData = {}, documentData = {},
fullFilePath = null, fullFilePath = null,
@ -151,7 +166,7 @@ const AstraDB = {
const { pageContent, docId, ...metadata } = documentData; const { pageContent, docId, ...metadata } = documentData;
if (!pageContent || pageContent.length == 0) return false; if (!pageContent || pageContent.length == 0) return false;
console.log("Adding new vectorized document into namespace", namespace); this.logger("Adding new vectorized document into namespace", namespace);
if (!skipCache) { if (!skipCache) {
const cacheResult = await cachedVectorInformation(fullFilePath); const cacheResult = await cachedVectorInformation(fullFilePath);
if (cacheResult.exists) { if (cacheResult.exists) {
@ -210,7 +225,7 @@ const AstraDB = {
}); });
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Snippets created from document:", textChunks.length); this.logger("Snippets created from document:", textChunks.length);
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await EmbedderEngine.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
@ -246,7 +261,7 @@ const AstraDB = {
if (vectors.length > 0) { if (vectors.length > 0) {
const chunks = []; const chunks = [];
console.log("Inserting vectorized chunks into Astra DB."); this.logger("Inserting vectorized chunks into Astra DB.");
// AstraDB has maximum upsert size of 20 records per-request so we have to use a lower chunk size here // AstraDB has maximum upsert size of 20 records per-request so we have to use a lower chunk size here
// in order to do the queries - this takes a lot more time than other providers but there // in order to do the queries - this takes a lot more time than other providers but there
@ -266,11 +281,12 @@ const AstraDB = {
await DocumentVectors.bulkInsert(documentVectors); await DocumentVectors.bulkInsert(documentVectors);
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} catch (e) { } catch (e) {
console.error("addDocumentToNamespace", e.message); this.logger("addDocumentToNamespace", e.message);
return { vectorized: false, error: e.message }; return { vectorized: false, error: e.message };
} }
}, }
deleteDocumentFromNamespace: async function (namespace, docId) {
async deleteDocumentFromNamespace(namespace, docId) {
const { DocumentVectors } = require("../../../models/vectors"); const { DocumentVectors } = require("../../../models/vectors");
const { client } = await this.connect(); const { client } = await this.connect();
namespace = sanitizeNamespace(namespace); namespace = sanitizeNamespace(namespace);
@ -293,8 +309,9 @@ const AstraDB = {
const indexes = knownDocuments.map((doc) => doc.id); const indexes = knownDocuments.map((doc) => doc.id);
await DocumentVectors.deleteIds(indexes); await DocumentVectors.deleteIds(indexes);
return true; return true;
}, }
performSimilaritySearch: async function ({
async performSimilaritySearch({
namespace = null, namespace = null,
input = "", input = "",
LLMConnector = null, LLMConnector = null,
@ -336,8 +353,9 @@ const AstraDB = {
sources: this.curateSources(sources), sources: this.curateSources(sources),
message: false, message: false,
}; };
}, }
similarityResponse: async function ({
async similarityResponse({
client, client,
namespace, namespace,
queryVector, queryVector,
@ -367,8 +385,8 @@ const AstraDB = {
responses.forEach((response) => { responses.forEach((response) => {
if (response.$similarity < similarityThreshold) return; if (response.$similarity < similarityThreshold) return;
if (filterIdentifiers.includes(sourceIdentifier(response.metadata))) { if (filterIdentifiers.includes(sourceIdentifier(response.metadata))) {
console.log( this.logger(
"AstraDB: A source was filtered from context as it's parent document is pinned." "A source was filtered from context as it's parent document is pinned."
); );
return; return;
} }
@ -380,8 +398,9 @@ const AstraDB = {
result.scores.push(response.$similarity); result.scores.push(response.$similarity);
}); });
return result; return result;
}, }
allNamespaces: async function (client) {
async allNamespaces(client) {
try { try {
let header = new Headers(); let header = new Headers();
header.append("Token", client?.httpClient?.applicationToken); header.append("Token", client?.httpClient?.applicationToken);
@ -403,11 +422,12 @@ const AstraDB = {
const collections = resp ? JSON.parse(resp)?.status?.collections : []; const collections = resp ? JSON.parse(resp)?.status?.collections : [];
return collections; return collections;
} catch (e) { } catch (e) {
console.error("Astra::AllNamespace", e); this.logger("AllNamespace", e);
return []; return [];
} }
}, }
"namespace-stats": async function (reqBody = {}) {
async "namespace-stats"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required"); if (!namespace) throw new Error("namespace required");
const { client } = await this.connect(); const { client } = await this.connect();
@ -417,8 +437,9 @@ const AstraDB = {
return stats return stats
? stats ? stats
: { message: "No stats were able to be fetched from DB for namespace" }; : { message: "No stats were able to be fetched from DB for namespace" };
}, }
"delete-namespace": async function (reqBody = {}) {
async "delete-namespace"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
const { client } = await this.connect(); const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) if (!(await this.namespaceExists(client, namespace)))
@ -431,8 +452,9 @@ const AstraDB = {
details?.vectorCount || "all" details?.vectorCount || "all"
} vectors.`, } vectors.`,
}; };
}, }
curateSources: function (sources = []) {
curateSources(sources = []) {
const documents = []; const documents = [];
for (const source of sources) { for (const source of sources) {
if (Object.keys(source).length > 0) { if (Object.keys(source).length > 0) {
@ -446,7 +468,7 @@ const AstraDB = {
} }
return documents; return documents;
}, }
}; }
module.exports.AstraDB = AstraDB; module.exports.AstraDB = AstraDB;

View File

@ -0,0 +1,201 @@
/**
* Base class for all Vector Database providers.
* All vector database providers should extend this class and implement/override the necessary methods.
*/
class VectorDatabase {
get name() {
return "VectorDatabase";
}
constructor() {
if (this.constructor === VectorDatabase) {
throw new Error("VectorDatabase cannot be instantiated directly");
}
}
/**
* Connect to vector database client
* @returns {Promise<{client: any}>}
*/
async connect() {
throw new Error("Must be implemented by provider");
}
/**
* Heartbeat check for vector database client
* @returns {Promise<{heartbeat: number}>}
*/
async heartbeat() {
throw new Error("Must be implemented by provider");
}
/**
* Get total number of vectors across all namespaces
* @returns {Promise<number>}
*/
async totalVectors() {
throw new Error("Must be implemented by provider");
}
/**
* Get count of vectors in a specific namespace
* @param {string} namespace - Namespace to count vectors in
* @returns {Promise<number>}
*/
async namespaceCount(namespace = null) {
throw new Error("Must be implemented by provider");
}
/**
* Get namespace details
* @param {any} client - Vector database client
* @param {string} namespace - Namespace to get
* @returns {Promise<any>}
*/
async namespace(client, namespace = null) {
throw new Error("Must be implemented by provider");
}
/**
* Check if a namespace exists
* @param {string} namespace - Namespace to check
* @returns {Promise<boolean>}
*/
async hasNamespace(namespace = null) {
throw new Error("Must be implemented by provider");
}
/**
* Check if a namespace exists with a client
* @param {any} client - Vector database client
* @param {string} namespace - Namespace to check
* @returns {Promise<boolean>}
*/
async namespaceExists(client, namespace = null) {
throw new Error("Must be implemented by provider");
}
/**
* Delete all vectors in a namespace
* @param {any} client - Vector database client
* @param {string} namespace - Namespace to delete vectors from
* @returns {Promise<boolean>}
*/
async deleteVectorsInNamespace(client, namespace = null) {
throw new Error("Must be implemented by provider");
}
/**
* Add a document to a namespace
* @param {string} namespace - Namespace to add document to
* @param {Object} documentData - Document data
* @param {string} fullFilePath - Full file path
* @param {boolean} skipCache - Skip cache
* @returns {Promise<{vectorized: boolean, error: string|null}>}
*/
async addDocumentToNamespace(
namespace,
documentData = {},
fullFilePath = null,
skipCache = false
) {
throw new Error("Must be implemented by provider");
}
/**
* Delete a document from namespace
* @param {string} namespace - Namespace to delete document from
* @param {string} docId - Document id
* @returns {Promise<boolean>}
*/
async deleteDocumentFromNamespace(namespace, docId) {
throw new Error("Must be implemented by provider");
}
/**
* Perform a similarity search
* @param {Object} params - Search parameters
* @param {string} params.namespace - Namespace to search in
* @param {string} params.input - Input text to search for
* @param {any} params.LLMConnector - LLM connector for embeddings
* @param {number} params.similarityThreshold - Similarity threshold
* @param {number} params.topN - Number of results to return
* @param {string[]} params.filterIdentifiers - Identifiers to filter out
* @returns {Promise<{contextTexts: string[], sources: any[], message: string|boolean}>}
*/
async performSimilaritySearch({
namespace = null,
input = "",
LLMConnector = null,
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
}) {
throw new Error("Must be implemented by provider");
}
/**
* Perform a similarity search and return raw results
* @param {Object} params - Search parameters
* @param {any} params.client - Vector database client
* @param {string} params.namespace - Namespace to search in
* @param {number[]} params.queryVector - Query vector
* @param {number} params.similarityThreshold - Similarity threshold
* @param {number} params.topN - Number of results to return
* @param {string[]} params.filterIdentifiers - Identifiers to filter out
* @returns {Promise<{contextTexts: string[], sourceDocuments: any[], scores: number[]}>}
*/
async similarityResponse({
client,
namespace,
queryVector,
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
}) {
throw new Error("Must be implemented by provider");
}
/**
* Get namespace stats
* @param {Object} reqBody - Request body
* @param {string} reqBody.namespace - Namespace to get stats for
* @returns {Promise<any>}
*/
async "namespace-stats"(reqBody = {}) {
throw new Error("Must be implemented by provider");
}
/**
* Delete a namespace
* @param {Object} reqBody - Request body
* @param {string} reqBody.namespace - Namespace to delete
* @returns {Promise<{message: string}>}
*/
async "delete-namespace"(reqBody = {}) {
throw new Error("Must be implemented by provider");
}
/**
* Reset vector database (delete all data)
* @returns {Promise<{reset: boolean}>}
*/
async reset() {
throw new Error("Must be implemented by provider");
}
/**
* Curate sources from search results
* @param {any[]} sources - Sources to curate
* @returns {any[]}
*/
curateSources(sources = []) {
throw new Error("Must be implemented by provider");
}
logger(message = null, ...args) {
console.log(`\x1b[36m[VectorDB::${this.name}]\x1b[0m ${message}`, ...args);
}
}
module.exports = { VectorDatabase };

View File

@ -6,12 +6,20 @@ const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { parseAuthHeader } = require("../../http"); const { parseAuthHeader } = require("../../http");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const { VectorDatabase } = require("../base");
const COLLECTION_REGEX = new RegExp( const COLLECTION_REGEX = new RegExp(
/^(?!\d+\.\d+\.\d+\.\d+$)(?!.*\.\.)(?=^[a-zA-Z0-9][a-zA-Z0-9_-]{1,61}[a-zA-Z0-9]$).{3,63}$/ /^(?!\d+\.\d+\.\d+\.\d+$)(?!.*\.\.)(?=^[a-zA-Z0-9][a-zA-Z0-9_-]{1,61}[a-zA-Z0-9]$).{3,63}$/
); );
const Chroma = { class Chroma extends VectorDatabase {
name: "Chroma", constructor() {
super();
}
get name() {
return "Chroma";
}
// Chroma DB has specific requirements for collection names: // Chroma DB has specific requirements for collection names:
// (1) Must contain 3-63 characters // (1) Must contain 3-63 characters
// (2) Must start and end with an alphanumeric character // (2) Must start and end with an alphanumeric character
@ -20,7 +28,7 @@ const Chroma = {
// (5) Cannot be a valid IPv4 address // (5) Cannot be a valid IPv4 address
// We need to enforce these rules by normalizing the collection names // We need to enforce these rules by normalizing the collection names
// before communicating with the Chroma DB. // before communicating with the Chroma DB.
normalize: function (inputString) { normalize(inputString) {
if (COLLECTION_REGEX.test(inputString)) return inputString; if (COLLECTION_REGEX.test(inputString)) return inputString;
let normalized = inputString.replace(/[^a-zA-Z0-9_-]/g, "-"); let normalized = inputString.replace(/[^a-zA-Z0-9_-]/g, "-");
@ -54,8 +62,9 @@ const Chroma = {
} }
return normalized; return normalized;
}, }
connect: async function () {
async connect() {
if (process.env.VECTOR_DB !== "chroma") if (process.env.VECTOR_DB !== "chroma")
throw new Error("Chroma::Invalid ENV settings"); throw new Error("Chroma::Invalid ENV settings");
@ -79,12 +88,14 @@ const Chroma = {
"ChromaDB::Invalid Heartbeat received - is the instance online?" "ChromaDB::Invalid Heartbeat received - is the instance online?"
); );
return { client }; return { client };
}, }
heartbeat: async function () {
async heartbeat() {
const { client } = await this.connect(); const { client } = await this.connect();
return { heartbeat: await client.heartbeat() }; return { heartbeat: await client.heartbeat() };
}, }
totalVectors: async function () {
async totalVectors() {
const { client } = await this.connect(); const { client } = await this.connect();
const collections = await client.listCollections(); const collections = await client.listCollections();
var totalVectors = 0; var totalVectors = 0;
@ -96,19 +107,22 @@ const Chroma = {
totalVectors += await collection.count(); totalVectors += await collection.count();
} }
return totalVectors; return totalVectors;
}, }
distanceToSimilarity: function (distance = null) {
distanceToSimilarity(distance = null) {
if (distance === null || typeof distance !== "number") return 0.0; if (distance === null || typeof distance !== "number") return 0.0;
if (distance >= 1.0) return 1; if (distance >= 1.0) return 1;
if (distance < 0) return 1 - Math.abs(distance); if (distance < 0) return 1 - Math.abs(distance);
return 1 - distance; return 1 - distance;
}, }
namespaceCount: async function (_namespace = null) {
async namespaceCount(_namespace = null) {
const { client } = await this.connect(); const { client } = await this.connect();
const namespace = await this.namespace(client, this.normalize(_namespace)); const namespace = await this.namespace(client, this.normalize(_namespace));
return namespace?.vectorCount || 0; return namespace?.vectorCount || 0;
}, }
similarityResponse: async function ({
async similarityResponse({
client, client,
namespace, namespace,
queryVector, queryVector,
@ -137,8 +151,8 @@ const Chroma = {
if ( if (
filterIdentifiers.includes(sourceIdentifier(response.metadatas[0][i])) filterIdentifiers.includes(sourceIdentifier(response.metadatas[0][i]))
) { ) {
console.log( this.logger(
"Chroma: A source was filtered from context as it's parent document is pinned." "A source was filtered from context as it's parent document is pinned."
); );
return; return;
} }
@ -149,8 +163,9 @@ const Chroma = {
}); });
return result; return result;
}, }
namespace: async function (client, namespace = null) {
async namespace(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const collection = await client const collection = await client
.getCollection({ name: this.normalize(namespace) }) .getCollection({ name: this.normalize(namespace) })
@ -161,27 +176,31 @@ const Chroma = {
...collection, ...collection,
vectorCount: await collection.count(), vectorCount: await collection.count(),
}; };
}, }
hasNamespace: async function (namespace = null) {
async hasNamespace(namespace = null) {
if (!namespace) return false; if (!namespace) return false;
const { client } = await this.connect(); const { client } = await this.connect();
return await this.namespaceExists(client, this.normalize(namespace)); return await this.namespaceExists(client, this.normalize(namespace));
}, }
namespaceExists: async function (client, namespace = null) {
async namespaceExists(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const collection = await client const collection = await client
.getCollection({ name: this.normalize(namespace) }) .getCollection({ name: this.normalize(namespace) })
.catch((e) => { .catch((e) => {
console.error("ChromaDB::namespaceExists", e.message); this.logger("namespaceExists", e.message);
return null; return null;
}); });
return !!collection; return !!collection;
}, }
deleteVectorsInNamespace: async function (client, namespace = null) {
async deleteVectorsInNamespace(client, namespace = null) {
await client.deleteCollection({ name: this.normalize(namespace) }); await client.deleteCollection({ name: this.normalize(namespace) });
return true; return true;
}, }
addDocumentToNamespace: async function (
async addDocumentToNamespace(
namespace, namespace,
documentData = {}, documentData = {},
fullFilePath = null, fullFilePath = null,
@ -192,7 +211,7 @@ const Chroma = {
const { pageContent, docId, ...metadata } = documentData; const { pageContent, docId, ...metadata } = documentData;
if (!pageContent || pageContent.length == 0) return false; if (!pageContent || pageContent.length == 0) return false;
console.log("Adding new vectorized document into namespace", namespace); this.logger("Adding new vectorized document into namespace", namespace);
if (!skipCache) { if (!skipCache) {
const cacheResult = await cachedVectorInformation(fullFilePath); const cacheResult = await cachedVectorInformation(fullFilePath);
if (cacheResult.exists) { if (cacheResult.exists) {
@ -254,7 +273,7 @@ const Chroma = {
}); });
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Snippets created from document:", textChunks.length); this.logger("Snippets created from document:", textChunks.length);
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await EmbedderEngine.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
@ -298,16 +317,16 @@ const Chroma = {
if (vectors.length > 0) { if (vectors.length > 0) {
const chunks = []; const chunks = [];
console.log("Inserting vectorized chunks into Chroma collection."); this.logger("Inserting vectorized chunks into Chroma collection.");
for (const chunk of toChunks(vectors, 500)) chunks.push(chunk); for (const chunk of toChunks(vectors, 500)) chunks.push(chunk);
try { try {
await this.smartAdd(collection, submission); await this.smartAdd(collection, submission);
console.log( this.logger(
`Successfully added ${submission.ids.length} vectors to collection ${this.normalize(namespace)}` `Successfully added ${submission.ids.length} vectors to collection ${this.normalize(namespace)}`
); );
} catch (error) { } catch (error) {
console.error("Error adding to ChromaDB:", error); this.logger("Error adding to ChromaDB:", error);
throw new Error(`Error embedding into ChromaDB: ${error.message}`); throw new Error(`Error embedding into ChromaDB: ${error.message}`);
} }
@ -317,11 +336,12 @@ const Chroma = {
await DocumentVectors.bulkInsert(documentVectors); await DocumentVectors.bulkInsert(documentVectors);
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} catch (e) { } catch (e) {
console.error("addDocumentToNamespace", e.message); this.logger("addDocumentToNamespace", e.message);
return { vectorized: false, error: e.message }; return { vectorized: false, error: e.message };
} }
}, }
deleteDocumentFromNamespace: async function (namespace, docId) {
async deleteDocumentFromNamespace(namespace, docId) {
const { DocumentVectors } = require("../../../models/vectors"); const { DocumentVectors } = require("../../../models/vectors");
const { client } = await this.connect(); const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) return; if (!(await this.namespaceExists(client, namespace))) return;
@ -338,8 +358,9 @@ const Chroma = {
const indexes = knownDocuments.map((doc) => doc.id); const indexes = knownDocuments.map((doc) => doc.id);
await DocumentVectors.deleteIds(indexes); await DocumentVectors.deleteIds(indexes);
return true; return true;
}, }
performSimilaritySearch: async function ({
async performSimilaritySearch({
namespace = null, namespace = null,
input = "", input = "",
LLMConnector = null, LLMConnector = null,
@ -383,8 +404,9 @@ const Chroma = {
sources: this.curateSources(sources), sources: this.curateSources(sources),
message: false, message: false,
}; };
}, }
"namespace-stats": async function (reqBody = {}) {
async "namespace-stats"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required"); if (!namespace) throw new Error("namespace required");
const { client } = await this.connect(); const { client } = await this.connect();
@ -394,8 +416,9 @@ const Chroma = {
return stats return stats
? stats ? stats
: { message: "No stats were able to be fetched from DB for namespace" }; : { message: "No stats were able to be fetched from DB for namespace" };
}, }
"delete-namespace": async function (reqBody = {}) {
async "delete-namespace"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
const { client } = await this.connect(); const { client } = await this.connect();
if (!(await this.namespaceExists(client, this.normalize(namespace)))) if (!(await this.namespaceExists(client, this.normalize(namespace))))
@ -406,13 +429,15 @@ const Chroma = {
return { return {
message: `Namespace ${namespace} was deleted along with ${details?.vectorCount} vectors.`, message: `Namespace ${namespace} was deleted along with ${details?.vectorCount} vectors.`,
}; };
}, }
reset: async function () {
async reset() {
const { client } = await this.connect(); const { client } = await this.connect();
await client.reset(); await client.reset();
return { reset: true }; return { reset: true };
}, }
curateSources: function (sources = []) {
curateSources(sources = []) {
const documents = []; const documents = [];
for (const source of sources) { for (const source of sources) {
const { metadata = {} } = source; const { metadata = {} } = source;
@ -427,7 +452,8 @@ const Chroma = {
} }
return documents; return documents;
}, }
/** /**
* This method is a wrapper around the ChromaCollection.add method. * This method is a wrapper around the ChromaCollection.add method.
* It will return true if the add was successful, false otherwise. * It will return true if the add was successful, false otherwise.
@ -436,10 +462,11 @@ const Chroma = {
* @param {{ids: string[], embeddings: number[], metadatas: Record<string, any>[], documents: string[]}[]} submissions * @param {{ids: string[], embeddings: number[], metadatas: Record<string, any>[], documents: string[]}[]} submissions
* @returns {Promise<boolean>} True if the add was successful, false otherwise. * @returns {Promise<boolean>} True if the add was successful, false otherwise.
*/ */
smartAdd: async function (collection, submissions) { async smartAdd(collection, submissions) {
await collection.add(submissions); await collection.add(submissions);
return true; return true;
}, }
/** /**
* This method is a wrapper around the ChromaCollection.delete method. * This method is a wrapper around the ChromaCollection.delete method.
* It will return the result of the delete method directly. * It will return the result of the delete method directly.
@ -448,10 +475,10 @@ const Chroma = {
* @param {string[]} vectorIds * @param {string[]} vectorIds
* @returns {Promise<boolean>} True if the delete was successful, false otherwise. * @returns {Promise<boolean>} True if the delete was successful, false otherwise.
*/ */
smartDelete: async function (collection, vectorIds) { async smartDelete(collection, vectorIds) {
await collection.delete({ ids: vectorIds }); await collection.delete({ ids: vectorIds });
return true; return true;
}, }
}; }
module.exports.Chroma = Chroma; module.exports.Chroma = Chroma;

View File

@ -6,20 +6,27 @@ const { toChunks } = require("../../helpers");
* ChromaCloud works nearly the same as Chroma so we can just extend the * ChromaCloud works nearly the same as Chroma so we can just extend the
* Chroma class and override the connect method to use the CloudClient for major differences in API functionality. * Chroma class and override the connect method to use the CloudClient for major differences in API functionality.
*/ */
const ChromaCloud = { class ChromaCloud extends Chroma {
...Chroma, constructor() {
name: "ChromaCloud", super();
}
get name() {
return "ChromaCloud";
}
/** /**
* Basic quota/limitations for Chroma Cloud for accounts. Does not lookup client-specific limits. * Basic quota/limitations for Chroma Cloud for accounts. Does not lookup client-specific limits.
* @see https://docs.trychroma.com/cloud/quotas-limits * @see https://docs.trychroma.com/cloud/quotas-limits
*/ */
limits: { limits = {
maxEmbeddingDim: 4_096, maxEmbeddingDim: 4_096,
maxDocumentBytes: 16_384, maxDocumentBytes: 16_384,
maxMetadataBytes: 4_096, maxMetadataBytes: 4_096,
maxRecordsPerWrite: 300, maxRecordsPerWrite: 300,
}, };
connect: async function () {
async connect() {
if (process.env.VECTOR_DB !== "chromacloud") if (process.env.VECTOR_DB !== "chromacloud")
throw new Error("ChromaCloud::Invalid ENV settings"); throw new Error("ChromaCloud::Invalid ENV settings");
@ -35,7 +42,8 @@ const ChromaCloud = {
"ChromaCloud::Invalid Heartbeat received - is the instance online?" "ChromaCloud::Invalid Heartbeat received - is the instance online?"
); );
return { client }; return { client };
}, }
/** /**
* Chroma Cloud has some basic limitations on upserts to protect performance and latency. * Chroma Cloud has some basic limitations on upserts to protect performance and latency.
* Local deployments do not have these limitations since they are self-hosted. * Local deployments do not have these limitations since they are self-hosted.
@ -47,7 +55,7 @@ const ChromaCloud = {
* @returns {Promise<boolean>} True if the upsert was successful, false otherwise. * @returns {Promise<boolean>} True if the upsert was successful, false otherwise.
* If the upsert was not successful, the error message will be returned. * If the upsert was not successful, the error message will be returned.
*/ */
smartAdd: async function (collection, submission) { async smartAdd(collection, submission) {
const testSubmission = { const testSubmission = {
id: submission.ids[0], id: submission.ids[0],
embedding: submission.embeddings[0], embedding: submission.embeddings[0],
@ -77,8 +85,8 @@ const ChromaCloud = {
return true; return true;
} }
console.log( this.logger(
`ChromaCloud::Upsert Payload is too large (max is ${this.limits.maxRecordsPerWrite} records). Splitting into chunks of ${this.limits.maxRecordsPerWrite} records.` `Upsert Payload is too large (max is ${this.limits.maxRecordsPerWrite} records). Splitting into chunks of ${this.limits.maxRecordsPerWrite} records.`
); );
const chunks = []; const chunks = [];
let chunkedSubmission = { let chunkedSubmission = {
@ -93,7 +101,7 @@ const ChromaCloud = {
chunkedSubmission.metadatas.push(submission.metadatas[i]); chunkedSubmission.metadatas.push(submission.metadatas[i]);
chunkedSubmission.documents.push(submission.documents[i]); chunkedSubmission.documents.push(submission.documents[i]);
if (chunkedSubmission.ids.length === this.limits.maxRecordsPerWrite) { if (chunkedSubmission.ids.length === this.limits.maxRecordsPerWrite) {
console.log( this.logger(
`ChromaCloud::Adding chunk payload ${chunks.length + 1} of ${Math.ceil(submission.ids.length / this.limits.maxRecordsPerWrite)}` `ChromaCloud::Adding chunk payload ${chunks.length + 1} of ${Math.ceil(submission.ids.length / this.limits.maxRecordsPerWrite)}`
); );
chunks.push(chunkedSubmission); chunks.push(chunkedSubmission);
@ -114,7 +122,8 @@ const ChromaCloud = {
counter++; counter++;
} }
return true; return true;
}, }
/** /**
* This method is a wrapper around the ChromaCollection.delete method. * This method is a wrapper around the ChromaCollection.delete method.
* It will return the result of the delete method directly. * It will return the result of the delete method directly.
@ -127,22 +136,22 @@ const ChromaCloud = {
* @param {string[]} vectorIds * @param {string[]} vectorIds
* @returns {Promise<boolean>} True if the delete was successful, false otherwise. * @returns {Promise<boolean>} True if the delete was successful, false otherwise.
*/ */
smartDelete: async function (collection, vectorIds) { async smartDelete(collection, vectorIds) {
if (vectorIds.length <= this.limits.maxRecordsPerWrite) if (vectorIds.length <= this.limits.maxRecordsPerWrite)
return await collection.delete({ ids: vectorIds }); return await collection.delete({ ids: vectorIds });
console.log( this.logger(
`ChromaCloud::Delete Payload is too large (max is ${this.limits.maxRecordsPerWrite} records). Splitting into chunks of ${this.limits.maxRecordsPerWrite} records.` `Delete Payload is too large (max is ${this.limits.maxRecordsPerWrite} records). Splitting into chunks of ${this.limits.maxRecordsPerWrite} records.`
); );
const chunks = toChunks(vectorIds, this.limits.maxRecordsPerWrite); const chunks = toChunks(vectorIds, this.limits.maxRecordsPerWrite);
let counter = 1; let counter = 1;
for (const chunk of chunks) { for (const chunk of chunks) {
console.log(`ChromaCloud::Deleting chunk ${counter} of ${chunks.length}`); this.logger(`Deleting chunk ${counter} of ${chunks.length}`);
await collection.delete({ ids: chunk }); await collection.delete({ ids: chunk });
counter++; counter++;
} }
return true; return true;
}, }
}; }
module.exports.ChromaCloud = ChromaCloud; module.exports.ChromaCloud = ChromaCloud;

View File

@ -6,38 +6,54 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native"); const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native");
const { VectorDatabase } = require("../base");
const path = require("path");
/** /**
* LancedDB Client connection object * LancedDB Client connection object
* @typedef {import('@lancedb/lancedb').Connection} LanceClient * @typedef {import('@lancedb/lancedb').Connection} LanceClient
*/ */
const LanceDb = { class LanceDb extends VectorDatabase {
uri: `${ constructor() {
!!process.env.STORAGE_DIR ? `${process.env.STORAGE_DIR}/` : "./storage/" super();
}lancedb`, }
name: "LanceDb",
get uri() {
const basePath = !!process.env.STORAGE_DIR
? process.env.STORAGE_DIR
: path.resolve(__dirname, "../../../storage");
return path.resolve(basePath, "lancedb");
}
get name() {
return "LanceDb";
}
/** @returns {Promise<{client: LanceClient}>} */ /** @returns {Promise<{client: LanceClient}>} */
connect: async function () { async connect() {
const client = await lancedb.connect(this.uri); const client = await lancedb.connect(this.uri);
return { client }; return { client };
}, }
distanceToSimilarity: function (distance = null) {
distanceToSimilarity(distance = null) {
if (distance === null || typeof distance !== "number") return 0.0; if (distance === null || typeof distance !== "number") return 0.0;
if (distance >= 1.0) return 1; if (distance >= 1.0) return 1;
if (distance < 0) return 1 - Math.abs(distance); if (distance < 0) return 1 - Math.abs(distance);
return 1 - distance; return 1 - distance;
}, }
heartbeat: async function () {
async heartbeat() {
await this.connect(); await this.connect();
return { heartbeat: Number(new Date()) }; return { heartbeat: Number(new Date()) };
}, }
tables: async function () {
async tables() {
const { client } = await this.connect(); const { client } = await this.connect();
return await client.tableNames(); return await client.tableNames();
}, }
totalVectors: async function () {
async totalVectors() {
const { client } = await this.connect(); const { client } = await this.connect();
const tables = await client.tableNames(); const tables = await client.tableNames();
let count = 0; let count = 0;
@ -46,15 +62,17 @@ const LanceDb = {
count += await table.countRows(); count += await table.countRows();
} }
return count; return count;
}, }
namespaceCount: async function (_namespace = null) {
async namespaceCount(_namespace = null) {
const { client } = await this.connect(); const { client } = await this.connect();
const exists = await this.namespaceExists(client, _namespace); const exists = await this.namespaceExists(client, _namespace);
if (!exists) return 0; if (!exists) return 0;
const table = await client.openTable(_namespace); const table = await client.openTable(_namespace);
return (await table.countRows()) || 0; return (await table.countRows()) || 0;
}, }
/** /**
* Performs a SimilaritySearch + Reranking on a namespace. * Performs a SimilaritySearch + Reranking on a namespace.
* @param {Object} params - The parameters for the rerankedSimilarityResponse. * @param {Object} params - The parameters for the rerankedSimilarityResponse.
@ -67,7 +85,7 @@ const LanceDb = {
* @param {string[]} params.filterIdentifiers - The identifiers of the documents to filter out. * @param {string[]} params.filterIdentifiers - The identifiers of the documents to filter out.
* @returns * @returns
*/ */
rerankedSimilarityResponse: async function ({ async rerankedSimilarityResponse({
client, client,
namespace, namespace,
query, query,
@ -116,8 +134,8 @@ const LanceDb = {
return; return;
const { vector: _, ...rest } = item; const { vector: _, ...rest } = item;
if (filterIdentifiers.includes(sourceIdentifier(rest))) { if (filterIdentifiers.includes(sourceIdentifier(rest))) {
console.log( this.logger(
"LanceDB: A source was filtered from context as it's parent document is pinned." "A source was filtered from context as it's parent document is pinned."
); );
return; return;
} }
@ -133,12 +151,12 @@ const LanceDb = {
}); });
}) })
.catch((e) => { .catch((e) => {
console.error(e); this.logger(e);
console.error("LanceDB::rerankedSimilarityResponse", e.message); this.logger("rerankedSimilarityResponse", e.message);
}); });
return result; return result;
}, }
/** /**
* Performs a SimilaritySearch on a give LanceDB namespace. * Performs a SimilaritySearch on a give LanceDB namespace.
@ -151,7 +169,7 @@ const LanceDb = {
* @param {string[]} params.filterIdentifiers * @param {string[]} params.filterIdentifiers
* @returns * @returns
*/ */
similarityResponse: async function ({ async similarityResponse({
client, client,
namespace, namespace,
queryVector, queryVector,
@ -177,8 +195,8 @@ const LanceDb = {
return; return;
const { vector: _, ...rest } = item; const { vector: _, ...rest } = item;
if (filterIdentifiers.includes(sourceIdentifier(rest))) { if (filterIdentifiers.includes(sourceIdentifier(rest))) {
console.log( this.logger(
"LanceDB: A source was filtered from context as it's parent document is pinned." "A source was filtered from context as it's parent document is pinned."
); );
return; return;
} }
@ -192,14 +210,15 @@ const LanceDb = {
}); });
return result; return result;
}, }
/** /**
* *
* @param {LanceClient} client * @param {LanceClient} client
* @param {string} namespace * @param {string} namespace
* @returns * @returns
*/ */
namespace: async function (client, namespace = null) { async namespace(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const collection = await client.openTable(namespace).catch(() => false); const collection = await client.openTable(namespace).catch(() => false);
if (!collection) return null; if (!collection) return null;
@ -207,7 +226,8 @@ const LanceDb = {
return { return {
...collection, ...collection,
}; };
}, }
/** /**
* *
* @param {LanceClient} client * @param {LanceClient} client
@ -215,7 +235,7 @@ const LanceDb = {
* @param {string} namespace * @param {string} namespace
* @returns * @returns
*/ */
updateOrCreateCollection: async function (client, data = [], namespace) { async updateOrCreateCollection(client, data = [], namespace) {
const hasNamespace = await this.hasNamespace(namespace); const hasNamespace = await this.hasNamespace(namespace);
if (hasNamespace) { if (hasNamespace) {
const collection = await client.openTable(namespace); const collection = await client.openTable(namespace);
@ -225,40 +245,44 @@ const LanceDb = {
await client.createTable(namespace, data); await client.createTable(namespace, data);
return true; return true;
}, }
hasNamespace: async function (namespace = null) {
async hasNamespace(namespace = null) {
if (!namespace) return false; if (!namespace) return false;
const { client } = await this.connect(); const { client } = await this.connect();
const exists = await this.namespaceExists(client, namespace); const exists = await this.namespaceExists(client, namespace);
return exists; return exists;
}, }
/** /**
* *
* @param {LanceClient} client * @param {LanceClient} client
* @param {string} namespace * @param {string} namespace
* @returns * @returns
*/ */
namespaceExists: async function (client, namespace = null) { async namespaceExists(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const collections = await client.tableNames(); const collections = await client.tableNames();
return collections.includes(namespace); return collections.includes(namespace);
}, }
/** /**
* *
* @param {LanceClient} client * @param {LanceClient} client
* @param {string} namespace * @param {string} namespace
* @returns * @returns
*/ */
deleteVectorsInNamespace: async function (client, namespace = null) { async deleteVectorsInNamespace(client, namespace = null) {
await client.dropTable(namespace); await client.dropTable(namespace);
return true; return true;
}, }
deleteDocumentFromNamespace: async function (namespace, docId) {
async deleteDocumentFromNamespace(namespace, docId) {
const { client } = await this.connect(); const { client } = await this.connect();
const exists = await this.namespaceExists(client, namespace); const exists = await this.namespaceExists(client, namespace);
if (!exists) { if (!exists) {
console.error( this.logger(
`LanceDB:deleteDocumentFromNamespace - namespace ${namespace} does not exist.` `deleteDocumentFromNamespace - namespace ${namespace} does not exist.`
); );
return; return;
} }
@ -272,8 +296,9 @@ const LanceDb = {
if (vectorIds.length === 0) return; if (vectorIds.length === 0) return;
await table.delete(`id IN (${vectorIds.map((v) => `'${v}'`).join(",")})`); await table.delete(`id IN (${vectorIds.map((v) => `'${v}'`).join(",")})`);
return true; return true;
}, }
addDocumentToNamespace: async function (
async addDocumentToNamespace(
namespace, namespace,
documentData = {}, documentData = {},
fullFilePath = null, fullFilePath = null,
@ -284,7 +309,7 @@ const LanceDb = {
const { pageContent, docId, ...metadata } = documentData; const { pageContent, docId, ...metadata } = documentData;
if (!pageContent || pageContent.length == 0) return false; if (!pageContent || pageContent.length == 0) return false;
console.log("Adding new vectorized document into namespace", namespace); this.logger("Adding new vectorized document into namespace", namespace);
if (!skipCache) { if (!skipCache) {
const cacheResult = await cachedVectorInformation(fullFilePath); const cacheResult = await cachedVectorInformation(fullFilePath);
if (cacheResult.exists) { if (cacheResult.exists) {
@ -329,7 +354,7 @@ const LanceDb = {
}); });
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Snippets created from document:", textChunks.length); this.logger("Snippets created from document:", textChunks.length);
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const submissions = []; const submissions = [];
@ -364,7 +389,7 @@ const LanceDb = {
const chunks = []; const chunks = [];
for (const chunk of toChunks(vectors, 500)) chunks.push(chunk); for (const chunk of toChunks(vectors, 500)) chunks.push(chunk);
console.log("Inserting vectorized chunks into LanceDB collection."); this.logger("Inserting vectorized chunks into LanceDB collection.");
const { client } = await this.connect(); const { client } = await this.connect();
await this.updateOrCreateCollection(client, submissions, namespace); await this.updateOrCreateCollection(client, submissions, namespace);
await storeVectorResult(chunks, fullFilePath); await storeVectorResult(chunks, fullFilePath);
@ -373,11 +398,12 @@ const LanceDb = {
await DocumentVectors.bulkInsert(documentVectors); await DocumentVectors.bulkInsert(documentVectors);
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} catch (e) { } catch (e) {
console.error("addDocumentToNamespace", e.message); this.logger("addDocumentToNamespace", e.message);
return { vectorized: false, error: e.message }; return { vectorized: false, error: e.message };
} }
}, }
performSimilaritySearch: async function ({
async performSimilaritySearch({
namespace = null, namespace = null,
input = "", input = "",
LLMConnector = null, LLMConnector = null,
@ -427,8 +453,9 @@ const LanceDb = {
sources: this.curateSources(sources), sources: this.curateSources(sources),
message: false, message: false,
}; };
}, }
"namespace-stats": async function (reqBody = {}) {
async "namespace-stats"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required"); if (!namespace) throw new Error("namespace required");
const { client } = await this.connect(); const { client } = await this.connect();
@ -438,8 +465,9 @@ const LanceDb = {
return stats return stats
? stats ? stats
: { message: "No stats were able to be fetched from DB for namespace" }; : { message: "No stats were able to be fetched from DB for namespace" };
}, }
"delete-namespace": async function (reqBody = {}) {
async "delete-namespace"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
const { client } = await this.connect(); const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) if (!(await this.namespaceExists(client, namespace)))
@ -449,14 +477,16 @@ const LanceDb = {
return { return {
message: `Namespace ${namespace} was deleted.`, message: `Namespace ${namespace} was deleted.`,
}; };
}, }
reset: async function () {
async reset() {
const { client } = await this.connect(); const { client } = await this.connect();
const fs = require("fs"); const fs = require("fs");
fs.rm(`${client.uri}`, { recursive: true }, () => null); fs.rm(`${client.uri}`, { recursive: true }, () => null);
return { reset: true }; return { reset: true };
}, }
curateSources: function (sources = []) {
curateSources(sources = []) {
const documents = []; const documents = [];
for (const source of sources) { for (const source of sources) {
const { text, vector: _v, _distance: _d, ...rest } = source; const { text, vector: _v, _distance: _d, ...rest } = source;
@ -470,7 +500,7 @@ const LanceDb = {
} }
return documents; return documents;
}, }
}; }
module.exports.LanceDb = LanceDb; module.exports.LanceDb = LanceDb;

View File

@ -10,22 +10,31 @@ const { v4: uuidv4 } = require("uuid");
const { storeVectorResult, cachedVectorInformation } = require("../../files"); const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const { VectorDatabase } = require("../base");
class Milvus extends VectorDatabase {
constructor() {
super();
}
get name() {
return "Milvus";
}
const Milvus = {
name: "Milvus",
// Milvus/Zilliz only allows letters, numbers, and underscores in collection names // Milvus/Zilliz only allows letters, numbers, and underscores in collection names
// so we need to enforce that by re-normalizing the names when communicating with // so we need to enforce that by re-normalizing the names when communicating with
// the DB. // the DB.
// If the first char of the collection is not an underscore or letter the collection name will be invalid. // If the first char of the collection is not an underscore or letter the collection name will be invalid.
normalize: function (inputString) { normalize(inputString) {
let normalized = inputString.replace(/[^a-zA-Z0-9_]/g, "_"); let normalized = inputString.replace(/[^a-zA-Z0-9_]/g, "_");
if (new RegExp(/^[a-zA-Z_]/).test(normalized.slice(0, 1))) if (new RegExp(/^[a-zA-Z_]/).test(normalized.slice(0, 1)))
normalized = `anythingllm_${normalized}`; normalized = `anythingllm_${normalized}`;
return normalized; return normalized;
}, }
connect: async function () {
async connect() {
if (process.env.VECTOR_DB !== "milvus") if (process.env.VECTOR_DB !== "milvus")
throw new Error("Milvus::Invalid ENV settings"); throw new Error(`${this.name}::Invalid ENV settings`);
const client = new MilvusClient({ const client = new MilvusClient({
address: process.env.MILVUS_ADDRESS, address: process.env.MILVUS_ADDRESS,
@ -36,16 +45,18 @@ const Milvus = {
const { isHealthy } = await client.checkHealth(); const { isHealthy } = await client.checkHealth();
if (!isHealthy) if (!isHealthy)
throw new Error( throw new Error(
"MilvusDB::Invalid Heartbeat received - is the instance online?" `${this.name}::Invalid Heartbeat received - is the instance online?`
); );
return { client }; return { client };
}, }
heartbeat: async function () {
async heartbeat() {
await this.connect(); await this.connect();
return { heartbeat: Number(new Date()) }; return { heartbeat: Number(new Date()) };
}, }
totalVectors: async function () {
async totalVectors() {
const { client } = await this.connect(); const { client } = await this.connect();
const { collection_names } = await client.listCollections(); const { collection_names } = await client.listCollections();
const total = collection_names.reduce(async (acc, collection_name) => { const total = collection_names.reduce(async (acc, collection_name) => {
@ -55,49 +66,55 @@ const Milvus = {
return Number(acc) + Number(statistics?.data?.row_count ?? 0); return Number(acc) + Number(statistics?.data?.row_count ?? 0);
}, 0); }, 0);
return total; return total;
}, }
namespaceCount: async function (_namespace = null) {
async namespaceCount(_namespace = null) {
const { client } = await this.connect(); const { client } = await this.connect();
const statistics = await client.getCollectionStatistics({ const statistics = await client.getCollectionStatistics({
collection_name: this.normalize(_namespace), collection_name: this.normalize(_namespace),
}); });
return Number(statistics?.data?.row_count ?? 0); return Number(statistics?.data?.row_count ?? 0);
}, }
namespace: async function (client, namespace = null) {
async namespace(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const collection = await client const collection = await client
.getCollectionStatistics({ collection_name: this.normalize(namespace) }) .getCollectionStatistics({ collection_name: this.normalize(namespace) })
.catch(() => null); .catch(() => null);
return collection; return collection;
}, }
hasNamespace: async function (namespace = null) {
async hasNamespace(namespace = null) {
if (!namespace) return false; if (!namespace) return false;
const { client } = await this.connect(); const { client } = await this.connect();
return await this.namespaceExists(client, namespace); return await this.namespaceExists(client, namespace);
}, }
namespaceExists: async function (client, namespace = null) {
async namespaceExists(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const { value } = await client const { value } = await client
.hasCollection({ collection_name: this.normalize(namespace) }) .hasCollection({ collection_name: this.normalize(namespace) })
.catch((e) => { .catch((e) => {
console.error("MilvusDB::namespaceExists", e.message); console.error(`${this.name}::namespaceExists`, e.message);
return { value: false }; return { value: false };
}); });
return value; return value;
}, }
deleteVectorsInNamespace: async function (client, namespace = null) {
async deleteVectorsInNamespace(client, namespace = null) {
await client.dropCollection({ collection_name: this.normalize(namespace) }); await client.dropCollection({ collection_name: this.normalize(namespace) });
return true; return true;
}, }
// Milvus requires a dimension aspect for collection creation // Milvus requires a dimension aspect for collection creation
// we pass this in from the first chunk to infer the dimensions like other // we pass this in from the first chunk to infer the dimensions like other
// providers do. // providers do.
getOrCreateCollection: async function (client, namespace, dimensions = null) { async getOrCreateCollection(client, namespace, dimensions = null) {
const isExists = await this.namespaceExists(client, namespace); const isExists = await this.namespaceExists(client, namespace);
if (!isExists) { if (!isExists) {
if (!dimensions) if (!dimensions)
throw new Error( throw new Error(
`Milvus:getOrCreateCollection Unable to infer vector dimension from input. Open an issue on GitHub for support.` `${this.name}::getOrCreateCollection Unable to infer vector dimension from input. Open an issue on GitHub for support.`
); );
await client.createCollection({ await client.createCollection({
@ -133,8 +150,9 @@ const Milvus = {
collection_name: this.normalize(namespace), collection_name: this.normalize(namespace),
}); });
} }
}, }
addDocumentToNamespace: async function (
async addDocumentToNamespace(
namespace, namespace,
documentData = {}, documentData = {},
fullFilePath = null, fullFilePath = null,
@ -146,7 +164,7 @@ const Milvus = {
const { pageContent, docId, ...metadata } = documentData; const { pageContent, docId, ...metadata } = documentData;
if (!pageContent || pageContent.length == 0) return false; if (!pageContent || pageContent.length == 0) return false;
console.log("Adding new vectorized document into namespace", namespace); this.logger("Adding new vectorized document into namespace", namespace);
if (!skipCache) { if (!skipCache) {
const cacheResult = await cachedVectorInformation(fullFilePath); const cacheResult = await cachedVectorInformation(fullFilePath);
if (cacheResult.exists) { if (cacheResult.exists) {
@ -172,7 +190,7 @@ const Milvus = {
if (insertResult?.status.error_code !== "Success") { if (insertResult?.status.error_code !== "Success") {
throw new Error( throw new Error(
`Error embedding into Milvus! Reason:${insertResult?.status.reason}` `Error embedding into ${this.name}! Reason:${insertResult?.status.reason}`
); );
} }
} }
@ -208,7 +226,7 @@ const Milvus = {
}); });
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Snippets created from document:", textChunks.length); this.logger("Snippets created from document:", textChunks.length);
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await EmbedderEngine.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
@ -238,7 +256,7 @@ const Milvus = {
const { client } = await this.connect(); const { client } = await this.connect();
await this.getOrCreateCollection(client, namespace, vectorDimension); await this.getOrCreateCollection(client, namespace, vectorDimension);
console.log("Inserting vectorized chunks into Milvus."); this.logger(`Inserting vectorized chunks into ${this.name}.`);
for (const chunk of toChunks(vectors, 100)) { for (const chunk of toChunks(vectors, 100)) {
chunks.push(chunk); chunks.push(chunk);
const insertResult = await client.insert({ const insertResult = await client.insert({
@ -252,7 +270,7 @@ const Milvus = {
if (insertResult?.status.error_code !== "Success") { if (insertResult?.status.error_code !== "Success") {
throw new Error( throw new Error(
`Error embedding into Milvus! Reason:${insertResult?.status.reason}` `Error embedding into ${this.name}! Reason:${insertResult?.status.reason}`
); );
} }
} }
@ -265,11 +283,12 @@ const Milvus = {
await DocumentVectors.bulkInsert(documentVectors); await DocumentVectors.bulkInsert(documentVectors);
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} catch (e) { } catch (e) {
console.error("addDocumentToNamespace", e.message); this.logger("addDocumentToNamespace", e.message);
return { vectorized: false, error: e.message }; return { vectorized: false, error: e.message };
} }
}, }
deleteDocumentFromNamespace: async function (namespace, docId) {
async deleteDocumentFromNamespace(namespace, docId) {
const { DocumentVectors } = require("../../../models/vectors"); const { DocumentVectors } = require("../../../models/vectors");
const { client } = await this.connect(); const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) return; if (!(await this.namespaceExists(client, namespace))) return;
@ -291,8 +310,9 @@ const Milvus = {
// on a later call. // on a later call.
await client.flushSync({ collection_names: [this.normalize(namespace)] }); await client.flushSync({ collection_names: [this.normalize(namespace)] });
return true; return true;
}, }
performSimilaritySearch: async function ({
async performSimilaritySearch({
namespace = null, namespace = null,
input = "", input = "",
LLMConnector = null, LLMConnector = null,
@ -331,8 +351,9 @@ const Milvus = {
sources: this.curateSources(sources), sources: this.curateSources(sources),
message: false, message: false,
}; };
}, }
similarityResponse: async function ({
async similarityResponse({
client, client,
namespace, namespace,
queryVector, queryVector,
@ -353,8 +374,8 @@ const Milvus = {
response.results.forEach((match) => { response.results.forEach((match) => {
if (match.score < similarityThreshold) return; if (match.score < similarityThreshold) return;
if (filterIdentifiers.includes(sourceIdentifier(match.metadata))) { if (filterIdentifiers.includes(sourceIdentifier(match.metadata))) {
console.log( this.logger(
"Milvus: A source was filtered from context as it's parent document is pinned." `${this.name}: A source was filtered from context as its parent document is pinned.`
); );
return; return;
} }
@ -367,8 +388,9 @@ const Milvus = {
result.scores.push(match.score); result.scores.push(match.score);
}); });
return result; return result;
}, }
"namespace-stats": async function (reqBody = {}) {
async "namespace-stats"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required"); if (!namespace) throw new Error("namespace required");
const { client } = await this.connect(); const { client } = await this.connect();
@ -378,8 +400,9 @@ const Milvus = {
return stats return stats
? stats ? stats
: { message: "No stats were able to be fetched from DB for namespace" }; : { message: "No stats were able to be fetched from DB for namespace" };
}, }
"delete-namespace": async function (reqBody = {}) {
async "delete-namespace"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
const { client } = await this.connect(); const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) if (!(await this.namespaceExists(client, namespace)))
@ -391,8 +414,9 @@ const Milvus = {
return { return {
message: `Namespace ${namespace} was deleted along with ${vectorCount} vectors.`, message: `Namespace ${namespace} was deleted along with ${vectorCount} vectors.`,
}; };
}, }
curateSources: function (sources = []) {
curateSources(sources = []) {
const documents = []; const documents = [];
for (const source of sources) { for (const source of sources) {
const { metadata = {} } = source; const { metadata = {} } = source;
@ -404,7 +428,7 @@ const Milvus = {
} }
} }
return documents; return documents;
}, }
}; }
module.exports.Milvus = Milvus; module.exports.Milvus = Milvus;

View File

@ -3,6 +3,7 @@ const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { TextSplitter } = require("../../TextSplitter"); const { TextSplitter } = require("../../TextSplitter");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const { VectorDatabase } = require("../base");
/* /*
Embedding Table Schema (table name defined by user) Embedding Table Schema (table name defined by user)
@ -13,44 +14,53 @@ const { sourceIdentifier } = require("../../chats");
- created_at: TIMESTAMP - created_at: TIMESTAMP
*/ */
const PGVector = { class PGVector extends VectorDatabase {
name: "PGVector", constructor() {
connectionTimeout: 30_000, super();
/** }
* Get the table name for the PGVector database.
* - Defaults to "anythingllm_vectors" if no table name is provided.
* @returns {string}
*/
tableName: () => process.env.PGVECTOR_TABLE_NAME || "anythingllm_vectors",
/** get name() {
* Get the connection string for the PGVector database. return "PGVector";
* - Requires a connection string to be present in the environment variables. }
* @returns {string | null}
*/
connectionString: () => process.env.PGVECTOR_CONNECTION_STRING,
connectionTimeout = 30_000;
// Possible for this to be a user-configurable option in the future. // Possible for this to be a user-configurable option in the future.
// Will require a handler per operator to ensure scores are normalized. // Will require a handler per operator to ensure scores are normalized.
operator: { operator = {
l2: "<->", l2: "<->",
innerProduct: "<#>", innerProduct: "<#>",
cosine: "<=>", cosine: "<=>",
l1: "<+>", l1: "<+>",
hamming: "<~>", hamming: "<~>",
jaccard: "<%>", jaccard: "<%>",
}, };
getTablesSql: getTablesSql =
"SELECT * FROM pg_catalog.pg_tables WHERE schemaname = 'public'", "SELECT * FROM pg_catalog.pg_tables WHERE schemaname = 'public'";
getEmbeddingTableSchemaSql: getEmbeddingTableSchemaSql =
"SELECT column_name,data_type FROM information_schema.columns WHERE table_name = $1", "SELECT column_name,data_type FROM information_schema.columns WHERE table_name = $1";
createExtensionSql: "CREATE EXTENSION IF NOT EXISTS vector;", createExtensionSql = "CREATE EXTENSION IF NOT EXISTS vector;";
createTableSql: (dimensions) =>
`CREATE TABLE IF NOT EXISTS "${PGVector.tableName()}" (id UUID PRIMARY KEY, namespace TEXT, embedding vector(${Number(dimensions)}), metadata JSONB, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)`,
log: function (message = null, ...args) { /**
console.log(`\x1b[35m[PGVectorDb]\x1b[0m ${message}`, ...args); * Get the table name for the PGVector database.
}, * - Defaults to "anythingllm_vectors" if no table name is provided.
* @returns {string}
*/
static tableName() {
return process.env.PGVECTOR_TABLE_NAME || "anythingllm_vectors";
}
/**
* Get the connection string for the PGVector database.
* - Requires a connection string to be present in the environment variables.
* @returns {string | null}
*/
static connectionString() {
return process.env.PGVECTOR_CONNECTION_STRING;
}
createTableSql(dimensions) {
return `CREATE TABLE IF NOT EXISTS "${PGVector.tableName()}" (id UUID PRIMARY KEY, namespace TEXT, embedding vector(${Number(dimensions)}), metadata JSONB, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)`;
}
/** /**
* Recursively sanitize values intended for JSONB to prevent Postgres errors * Recursively sanitize values intended for JSONB to prevent Postgres errors
@ -60,7 +70,7 @@ const PGVector = {
* @param {any} value * @param {any} value
* @returns {any} * @returns {any}
*/ */
sanitizeForJsonb: function (value) { sanitizeForJsonb(value) {
// Fast path for null/undefined and primitives that do not need changes // Fast path for null/undefined and primitives that do not need changes
if (value === null || value === undefined) return value; if (value === null || value === undefined) return value;
@ -99,13 +109,13 @@ const PGVector = {
// Numbers, booleans, etc. // Numbers, booleans, etc.
return value; return value;
}, }
client: function (connectionString = null) { client(connectionString = null) {
return new pgsql.Client({ return new pgsql.Client({
connectionString: connectionString || PGVector.connectionString(), connectionString: connectionString || PGVector.connectionString(),
}); });
}, }
/** /**
* Validate the existing embedding table schema. * Validate the existing embedding table schema.
@ -113,7 +123,7 @@ const PGVector = {
* @param {string} tableName * @param {string} tableName
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
validateExistingEmbeddingTableSchema: async function (pgClient, tableName) { async validateExistingEmbeddingTableSchema(pgClient, tableName) {
const result = await pgClient.query(this.getEmbeddingTableSchemaSql, [ const result = await pgClient.query(this.getEmbeddingTableSchemaSql, [
tableName, tableName,
]); ]);
@ -178,11 +188,11 @@ const PGVector = {
); );
} }
this.log( this.logger(
`✅ The pgvector table '${tableName}' was found and meets the minimum expected schema for an embedding table.` `✅ The pgvector table '${tableName}' was found and meets the minimum expected schema for an embedding table.`
); );
return true; return true;
}, }
/** /**
* Validate the connection to the database and verify that the table does not already exist. * Validate the connection to the database and verify that the table does not already exist.
@ -191,35 +201,36 @@ const PGVector = {
* @param {{connectionString: string | null, tableName: string | null}} params * @param {{connectionString: string | null, tableName: string | null}} params
* @returns {Promise<{error: string | null, success: boolean}>} * @returns {Promise<{error: string | null, success: boolean}>}
*/ */
validateConnection: async function ({ static async validateConnection({
connectionString = null, connectionString = null,
tableName = null, tableName = null,
}) { }) {
if (!connectionString) throw new Error("No connection string provided"); if (!connectionString) throw new Error("No connection string provided");
const instance = new PGVector();
try { try {
const timeoutPromise = new Promise((resolve) => { const timeoutPromise = new Promise((resolve) => {
setTimeout(() => { setTimeout(() => {
resolve({ resolve({
error: `Connection timeout (${(PGVector.connectionTimeout / 1000).toFixed(0)}s). Please check your connection string and try again.`, error: `Connection timeout (${(instance.connectionTimeout / 1000).toFixed(0)}s). Please check your connection string and try again.`,
success: false, success: false,
}); });
}, PGVector.connectionTimeout); }, instance.connectionTimeout);
}); });
const connectionPromise = new Promise(async (resolve) => { const connectionPromise = new Promise(async (resolve) => {
let pgClient = null; let pgClient = null;
try { try {
pgClient = this.client(connectionString); pgClient = instance.client(connectionString);
await pgClient.connect(); await pgClient.connect();
const result = await pgClient.query(this.getTablesSql); const result = await pgClient.query(instance.getTablesSql);
if (result.rows.length !== 0 && !!tableName) { if (result.rows.length !== 0 && !!tableName) {
const tableExists = result.rows.some( const tableExists = result.rows.some(
(row) => row.tablename === tableName (row) => row.tablename === tableName
); );
if (tableExists) if (tableExists)
await this.validateExistingEmbeddingTableSchema( await instance.validateExistingEmbeddingTableSchema(
pgClient, pgClient,
tableName tableName
); );
@ -236,7 +247,7 @@ const PGVector = {
const result = await Promise.race([connectionPromise, timeoutPromise]); const result = await Promise.race([connectionPromise, timeoutPromise]);
return result; return result;
} catch (err) { } catch (err) {
this.log("Validation Error:", err.message); instance.logger("Validation Error:", err.message);
let readableError = err.message; let readableError = err.message;
switch (true) { switch (true) {
case err.message.includes("ECONNREFUSED"): case err.message.includes("ECONNREFUSED"):
@ -248,13 +259,13 @@ const PGVector = {
} }
return { error: readableError, success: false }; return { error: readableError, success: false };
} }
}, }
/** /**
* Test the connection to the database directly. * Test the connection to the database directly.
* @returns {{error: string | null, success: boolean}} * @returns {{error: string | null, success: boolean}}
*/ */
testConnectionToDB: async function () { async testConnectionToDB() {
try { try {
const pgClient = await this.connect(); const pgClient = await this.connect();
await pgClient.query(this.getTablesSql); await pgClient.query(this.getTablesSql);
@ -263,14 +274,14 @@ const PGVector = {
} catch (err) { } catch (err) {
return { error: err.message, success: false }; return { error: err.message, success: false };
} }
}, }
/** /**
* Connect to the database. * Connect to the database.
* - Throws an error if the connection string or table name is not provided. * - Throws an error if the connection string or table name is not provided.
* @returns {Promise<pgsql.Client>} * @returns {Promise<pgsql.Client>}
*/ */
connect: async function () { async connect() {
if (!PGVector.connectionString()) if (!PGVector.connectionString())
throw new Error("No connection string provided"); throw new Error("No connection string provided");
if (!PGVector.tableName()) throw new Error("No table name provided"); if (!PGVector.tableName()) throw new Error("No table name provided");
@ -278,21 +289,21 @@ const PGVector = {
const client = this.client(); const client = this.client();
await client.connect(); await client.connect();
return client; return client;
}, }
/** /**
* Test the connection to the database with already set credentials via ENV * Test the connection to the database with already set credentials via ENV
* @returns {{error: string | null, success: boolean}} * @returns {{error: string | null, success: boolean}}
*/ */
heartbeat: async function () { async heartbeat() {
return this.testConnectionToDB(); return this.testConnectionToDB();
}, }
/** /**
* Check if the anythingllm embedding table exists in the database * Check if the anythingllm embedding table exists in the database
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
dbTableExists: async function () { async dbTableExists() {
let connection = null; let connection = null;
try { try {
connection = await this.connect(); connection = await this.connect();
@ -307,9 +318,9 @@ const PGVector = {
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
totalVectors: async function () { async totalVectors() {
if (!(await this.dbTableExists())) return 0; if (!(await this.dbTableExists())) return 0;
let connection = null; let connection = null;
try { try {
@ -323,17 +334,17 @@ const PGVector = {
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
// Distance for cosine is just the distance for pgvector. // Distance for cosine is just the distance for pgvector.
distanceToSimilarity: function (distance = null) { distanceToSimilarity(distance = null) {
if (distance === null || typeof distance !== "number") return 0.0; if (distance === null || typeof distance !== "number") return 0.0;
if (distance >= 1.0) return 1; if (distance >= 1.0) return 1;
if (distance < 0) return 1 - Math.abs(distance); if (distance < 0) return 1 - Math.abs(distance);
return 1 - distance; return 1 - distance;
}, }
namespaceCount: async function (namespace = null) { async namespaceCount(namespace = null) {
if (!(await this.dbTableExists())) return 0; if (!(await this.dbTableExists())) return 0;
let connection = null; let connection = null;
try { try {
@ -348,7 +359,7 @@ const PGVector = {
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
/** /**
* Performs a SimilaritySearch on a given PGVector namespace. * Performs a SimilaritySearch on a given PGVector namespace.
@ -361,7 +372,7 @@ const PGVector = {
* @param {string[]} params.filterIdentifiers * @param {string[]} params.filterIdentifiers
* @returns * @returns
*/ */
similarityResponse: async function ({ async similarityResponse({
client, client,
namespace, namespace,
queryVector, queryVector,
@ -384,7 +395,7 @@ const PGVector = {
if (this.distanceToSimilarity(item._distance) < similarityThreshold) if (this.distanceToSimilarity(item._distance) < similarityThreshold)
return; return;
if (filterIdentifiers.includes(sourceIdentifier(item.metadata))) { if (filterIdentifiers.includes(sourceIdentifier(item.metadata))) {
this.log( this.logger(
"A source was filtered from context as it's parent document is pinned." "A source was filtered from context as it's parent document is pinned."
); );
return; return;
@ -399,15 +410,15 @@ const PGVector = {
}); });
return result; return result;
}, }
normalizeVector: function (vector) { normalizeVector(vector) {
const magnitude = Math.sqrt( const magnitude = Math.sqrt(
vector.reduce((sum, val) => sum + val * val, 0) vector.reduce((sum, val) => sum + val * val, 0)
); );
if (magnitude === 0) return vector; // Avoid division by zero if (magnitude === 0) return vector; // Avoid division by zero
return vector.map((val) => val / magnitude); return vector.map((val) => val / magnitude);
}, }
/** /**
* Update or create a collection in the database * Update or create a collection in the database
@ -418,14 +429,14 @@ const PGVector = {
* @param {number} params.dimensions * @param {number} params.dimensions
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
updateOrCreateCollection: async function ({ async updateOrCreateCollection({
connection, connection,
submissions, submissions,
namespace, namespace,
dimensions = 384, dimensions = 384,
}) { }) {
await this.createTableIfNotExists(connection, dimensions); await this.createTableIfNotExists(connection, dimensions);
this.log(`Updating or creating collection ${namespace}`); this.logger(`Updating or creating collection ${namespace}`);
try { try {
// Create a transaction of all inserts // Create a transaction of all inserts
@ -438,17 +449,17 @@ const PGVector = {
[submission.id, namespace, embedding, sanitizedMetadata] [submission.id, namespace, embedding, sanitizedMetadata]
); );
} }
this.log(`Committing ${submissions.length} vectors to ${namespace}`); this.logger(`Committing ${submissions.length} vectors to ${namespace}`);
await connection.query(`COMMIT`); await connection.query(`COMMIT`);
} catch (err) { } catch (err) {
this.log( this.logger(
`Rolling back ${submissions.length} vectors to ${namespace}`, `Rolling back ${submissions.length} vectors to ${namespace}`,
err err
); );
await connection.query(`ROLLBACK`); await connection.query(`ROLLBACK`);
} }
return true; return true;
}, }
/** /**
* create a table if it doesn't exist * create a table if it doesn't exist
@ -456,12 +467,12 @@ const PGVector = {
* @param {number} dimensions * @param {number} dimensions
* @returns * @returns
*/ */
createTableIfNotExists: async function (connection, dimensions = 384) { async createTableIfNotExists(connection, dimensions = 384) {
this.log(`Creating embedding table with ${dimensions} dimensions`); this.logger(`Creating embedding table with ${dimensions} dimensions`);
await connection.query(this.createExtensionSql); await connection.query(this.createExtensionSql);
await connection.query(this.createTableSql(dimensions)); await connection.query(this.createTableSql(dimensions));
return true; return true;
}, }
/** /**
* Get the namespace from the database * Get the namespace from the database
@ -469,21 +480,21 @@ const PGVector = {
* @param {string} namespace * @param {string} namespace
* @returns {Promise<{name: string, vectorCount: number}>} * @returns {Promise<{name: string, vectorCount: number}>}
*/ */
namespace: async function (connection, namespace = null) { async namespace(connection, namespace = null) {
if (!namespace) throw new Error("No namespace provided"); if (!namespace) throw new Error("No namespace provided");
const result = await connection.query( const result = await connection.query(
`SELECT COUNT(id) FROM "${PGVector.tableName()}" WHERE namespace = $1`, `SELECT COUNT(id) FROM "${PGVector.tableName()}" WHERE namespace = $1`,
[namespace] [namespace]
); );
return { name: namespace, vectorCount: result.rows[0].count }; return { name: namespace, vectorCount: result.rows[0].count };
}, }
/** /**
* Check if the namespace exists in the database * Check if the namespace exists in the database
* @param {string} namespace * @param {string} namespace
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
hasNamespace: async function (namespace = null) { async hasNamespace(namespace = null) {
if (!namespace) throw new Error("No namespace provided"); if (!namespace) throw new Error("No namespace provided");
let connection = null; let connection = null;
try { try {
@ -494,7 +505,7 @@ const PGVector = {
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
/** /**
* Check if the namespace exists in the database * Check if the namespace exists in the database
@ -502,14 +513,14 @@ const PGVector = {
* @param {string} namespace * @param {string} namespace
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
namespaceExists: async function (connection, namespace = null) { async namespaceExists(connection, namespace = null) {
if (!namespace) throw new Error("No namespace provided"); if (!namespace) throw new Error("No namespace provided");
const result = await connection.query( const result = await connection.query(
`SELECT COUNT(id) FROM "${PGVector.tableName()}" WHERE namespace = $1 LIMIT 1`, `SELECT COUNT(id) FROM "${PGVector.tableName()}" WHERE namespace = $1 LIMIT 1`,
[namespace] [namespace]
); );
return result.rows[0].count > 0; return result.rows[0].count > 0;
}, }
/** /**
* Delete all vectors in the namespace * Delete all vectors in the namespace
@ -517,16 +528,16 @@ const PGVector = {
* @param {string} namespace * @param {string} namespace
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
deleteVectorsInNamespace: async function (connection, namespace = null) { async deleteVectorsInNamespace(connection, namespace = null) {
if (!namespace) throw new Error("No namespace provided"); if (!namespace) throw new Error("No namespace provided");
await connection.query( await connection.query(
`DELETE FROM "${PGVector.tableName()}" WHERE namespace = $1`, `DELETE FROM "${PGVector.tableName()}" WHERE namespace = $1`,
[namespace] [namespace]
); );
return true; return true;
}, }
addDocumentToNamespace: async function ( async addDocumentToNamespace(
namespace, namespace,
documentData = {}, documentData = {},
fullFilePath = null, fullFilePath = null,
@ -544,7 +555,7 @@ const PGVector = {
if (!pageContent || pageContent.length == 0) return false; if (!pageContent || pageContent.length == 0) return false;
connection = await this.connect(); connection = await this.connect();
this.log("Adding new vectorized document into namespace", namespace); this.logger("Adding new vectorized document into namespace", namespace);
if (!skipCache) { if (!skipCache) {
const cacheResult = await cachedVectorInformation(fullFilePath); const cacheResult = await cachedVectorInformation(fullFilePath);
let vectorDimensions; let vectorDimensions;
@ -594,7 +605,7 @@ const PGVector = {
}); });
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
this.log("Snippets created from document:", textChunks.length); this.logger("Snippets created from document:", textChunks.length);
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const submissions = []; const submissions = [];
@ -628,7 +639,7 @@ const PGVector = {
const chunks = []; const chunks = [];
for (const chunk of toChunks(vectors, 500)) chunks.push(chunk); for (const chunk of toChunks(vectors, 500)) chunks.push(chunk);
this.log("Inserting vectorized chunks into PGVector collection."); this.logger("Inserting vectorized chunks into PGVector collection.");
await this.updateOrCreateCollection({ await this.updateOrCreateCollection({
connection, connection,
submissions, submissions,
@ -641,12 +652,12 @@ const PGVector = {
await DocumentVectors.bulkInsert(documentVectors); await DocumentVectors.bulkInsert(documentVectors);
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} catch (err) { } catch (err) {
this.log("addDocumentToNamespace", err.message); this.logger("addDocumentToNamespace", err.message);
return { vectorized: false, error: err.message }; return { vectorized: false, error: err.message };
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
/** /**
* Delete a document from the namespace * Delete a document from the namespace
@ -654,7 +665,7 @@ const PGVector = {
* @param {string} docId * @param {string} docId
* @returns {Promise<boolean>} * @returns {Promise<boolean>}
*/ */
deleteDocumentFromNamespace: async function (namespace, docId) { async deleteDocumentFromNamespace(namespace, docId) {
if (!namespace) throw new Error("No namespace provided"); if (!namespace) throw new Error("No namespace provided");
if (!docId) throw new Error("No docId provided"); if (!docId) throw new Error("No docId provided");
@ -686,21 +697,21 @@ const PGVector = {
throw err; throw err;
} }
this.log( this.logger(
`Deleted ${vectorIds.length} vectors from namespace ${namespace}` `Deleted ${vectorIds.length} vectors from namespace ${namespace}`
); );
return true; return true;
} catch (err) { } catch (err) {
this.log( this.logger(
`Error deleting document from namespace ${namespace}: ${err.message}` `Error deleting document from namespace ${namespace}: ${err.message}`
); );
return false; return false;
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
performSimilaritySearch: async function ({ async performSimilaritySearch({
namespace = null, namespace = null,
input = "", input = "",
LLMConnector = null, LLMConnector = null,
@ -716,7 +727,7 @@ const PGVector = {
connection = await this.connect(); connection = await this.connect();
const exists = await this.namespaceExists(connection, namespace); const exists = await this.namespaceExists(connection, namespace);
if (!exists) { if (!exists) {
this.log( this.logger(
`The namespace ${namespace} does not exist or has no vectors. Returning empty results.` `The namespace ${namespace} does not exist or has no vectors. Returning empty results.`
); );
return { return {
@ -750,9 +761,9 @@ const PGVector = {
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
"namespace-stats": async function (reqBody = {}) { async "namespace-stats"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required"); if (!namespace) throw new Error("namespace required");
if (!(await this.dbTableExists())) if (!(await this.dbTableExists()))
@ -774,9 +785,9 @@ const PGVector = {
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
"delete-namespace": async function (reqBody = {}) { async "delete-namespace"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
if (!namespace) throw new Error("No namespace provided"); if (!namespace) throw new Error("No namespace provided");
@ -800,13 +811,13 @@ const PGVector = {
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
/** /**
* Reset the entire vector database table associated with anythingllm * Reset the entire vector database table associated with anythingllm
* @returns {Promise<{reset: boolean}>} * @returns {Promise<{reset: boolean}>}
*/ */
reset: async function () { async reset() {
let connection = null; let connection = null;
try { try {
connection = await this.connect(); connection = await this.connect();
@ -817,9 +828,9 @@ const PGVector = {
} finally { } finally {
if (connection) await connection.end(); if (connection) await connection.end();
} }
}, }
curateSources: function (sources = []) { curateSources(sources = []) {
const documents = []; const documents = [];
for (const source of sources) { for (const source of sources) {
const { text, vector: _v, _distance: _d, ...rest } = source; const { text, vector: _v, _distance: _d, ...rest } = source;
@ -833,7 +844,7 @@ const PGVector = {
} }
return documents; return documents;
}, }
}; }
module.exports.PGVector = PGVector; module.exports.PGVector = PGVector;

View File

@ -5,10 +5,18 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const { VectorDatabase } = require("../base");
const PineconeDB = { class PineconeDB extends VectorDatabase {
name: "Pinecone", constructor() {
connect: async function () { super();
}
get name() {
return "Pinecone";
}
async connect() {
if (process.env.VECTOR_DB !== "pinecone") if (process.env.VECTOR_DB !== "pinecone")
throw new Error("Pinecone::Invalid ENV settings"); throw new Error("Pinecone::Invalid ENV settings");
@ -21,8 +29,9 @@ const PineconeDB = {
if (!status.ready) throw new Error("Pinecone::Index not ready."); if (!status.ready) throw new Error("Pinecone::Index not ready.");
return { client, pineconeIndex, indexName: process.env.PINECONE_INDEX }; return { client, pineconeIndex, indexName: process.env.PINECONE_INDEX };
}, }
totalVectors: async function () {
async totalVectors() {
const { pineconeIndex } = await this.connect(); const { pineconeIndex } = await this.connect();
const { namespaces } = await pineconeIndex.describeIndexStats(); const { namespaces } = await pineconeIndex.describeIndexStats();
@ -30,13 +39,15 @@ const PineconeDB = {
(a, b) => a + (b?.recordCount || 0), (a, b) => a + (b?.recordCount || 0),
0 0
); );
}, }
namespaceCount: async function (_namespace = null) {
async namespaceCount(_namespace = null) {
const { pineconeIndex } = await this.connect(); const { pineconeIndex } = await this.connect();
const namespace = await this.namespace(pineconeIndex, _namespace); const namespace = await this.namespace(pineconeIndex, _namespace);
return namespace?.recordCount || 0; return namespace?.recordCount || 0;
}, }
similarityResponse: async function ({
async similarityResponse({
client, client,
namespace, namespace,
queryVector, queryVector,
@ -60,7 +71,7 @@ const PineconeDB = {
response.matches.forEach((match) => { response.matches.forEach((match) => {
if (match.score < similarityThreshold) return; if (match.score < similarityThreshold) return;
if (filterIdentifiers.includes(sourceIdentifier(match.metadata))) { if (filterIdentifiers.includes(sourceIdentifier(match.metadata))) {
console.log( this.logger(
"Pinecone: A source was filtered from context as it's parent document is pinned." "Pinecone: A source was filtered from context as it's parent document is pinned."
); );
return; return;
@ -75,28 +86,33 @@ const PineconeDB = {
}); });
return result; return result;
}, }
namespace: async function (index, namespace = null) {
async namespace(index, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const { namespaces } = await index.describeIndexStats(); const { namespaces } = await index.describeIndexStats();
return namespaces.hasOwnProperty(namespace) ? namespaces[namespace] : null; return namespaces.hasOwnProperty(namespace) ? namespaces[namespace] : null;
}, }
hasNamespace: async function (namespace = null) {
async hasNamespace(namespace = null) {
if (!namespace) return false; if (!namespace) return false;
const { pineconeIndex } = await this.connect(); const { pineconeIndex } = await this.connect();
return await this.namespaceExists(pineconeIndex, namespace); return await this.namespaceExists(pineconeIndex, namespace);
}, }
namespaceExists: async function (index, namespace = null) {
async namespaceExists(index, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const { namespaces } = await index.describeIndexStats(); const { namespaces } = await index.describeIndexStats();
return namespaces.hasOwnProperty(namespace); return namespaces.hasOwnProperty(namespace);
}, }
deleteVectorsInNamespace: async function (index, namespace = null) {
async deleteVectorsInNamespace(index, namespace = null) {
const pineconeNamespace = index.namespace(namespace); const pineconeNamespace = index.namespace(namespace);
await pineconeNamespace.deleteAll(); await pineconeNamespace.deleteAll();
return true; return true;
}, }
addDocumentToNamespace: async function (
async addDocumentToNamespace(
namespace, namespace,
documentData = {}, documentData = {},
fullFilePath = null, fullFilePath = null,
@ -107,7 +123,7 @@ const PineconeDB = {
const { pageContent, docId, ...metadata } = documentData; const { pageContent, docId, ...metadata } = documentData;
if (!pageContent || pageContent.length == 0) return false; if (!pageContent || pageContent.length == 0) return false;
console.log("Adding new vectorized document into namespace", namespace); this.logger("Adding new vectorized document into namespace", namespace);
if (!skipCache) { if (!skipCache) {
const cacheResult = await cachedVectorInformation(fullFilePath); const cacheResult = await cachedVectorInformation(fullFilePath);
if (cacheResult.exists) { if (cacheResult.exists) {
@ -154,7 +170,7 @@ const PineconeDB = {
}); });
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Snippets created from document:", textChunks.length); this.logger("Snippets created from document:", textChunks.length);
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await EmbedderEngine.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
@ -183,7 +199,7 @@ const PineconeDB = {
const chunks = []; const chunks = [];
const { pineconeIndex } = await this.connect(); const { pineconeIndex } = await this.connect();
const pineconeNamespace = pineconeIndex.namespace(namespace); const pineconeNamespace = pineconeIndex.namespace(namespace);
console.log("Inserting vectorized chunks into Pinecone."); this.logger("Inserting vectorized chunks into Pinecone.");
for (const chunk of toChunks(vectors, 100)) { for (const chunk of toChunks(vectors, 100)) {
chunks.push(chunk); chunks.push(chunk);
await pineconeNamespace.upsert([...chunk]); await pineconeNamespace.upsert([...chunk]);
@ -194,11 +210,12 @@ const PineconeDB = {
await DocumentVectors.bulkInsert(documentVectors); await DocumentVectors.bulkInsert(documentVectors);
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} catch (e) { } catch (e) {
console.error("addDocumentToNamespace", e.message); this.logger("addDocumentToNamespace", e.message);
return { vectorized: false, error: e.message }; return { vectorized: false, error: e.message };
} }
}, }
deleteDocumentFromNamespace: async function (namespace, docId) {
async deleteDocumentFromNamespace(namespace, docId) {
const { DocumentVectors } = require("../../../models/vectors"); const { DocumentVectors } = require("../../../models/vectors");
const { pineconeIndex } = await this.connect(); const { pineconeIndex } = await this.connect();
if (!(await this.namespaceExists(pineconeIndex, namespace))) return; if (!(await this.namespaceExists(pineconeIndex, namespace))) return;
@ -216,8 +233,9 @@ const PineconeDB = {
const indexes = knownDocuments.map((doc) => doc.id); const indexes = knownDocuments.map((doc) => doc.id);
await DocumentVectors.deleteIds(indexes); await DocumentVectors.deleteIds(indexes);
return true; return true;
}, }
"namespace-stats": async function (reqBody = {}) {
async "namespace-stats"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required"); if (!namespace) throw new Error("namespace required");
const { pineconeIndex } = await this.connect(); const { pineconeIndex } = await this.connect();
@ -227,8 +245,9 @@ const PineconeDB = {
return stats return stats
? stats ? stats
: { message: "No stats were able to be fetched from DB" }; : { message: "No stats were able to be fetched from DB" };
}, }
"delete-namespace": async function (reqBody = {}) {
async "delete-namespace"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
const { pineconeIndex } = await this.connect(); const { pineconeIndex } = await this.connect();
if (!(await this.namespaceExists(pineconeIndex, namespace))) if (!(await this.namespaceExists(pineconeIndex, namespace)))
@ -239,8 +258,9 @@ const PineconeDB = {
return { return {
message: `Namespace ${namespace} was deleted along with ${details.vectorCount} vectors.`, message: `Namespace ${namespace} was deleted along with ${details.vectorCount} vectors.`,
}; };
}, }
performSimilaritySearch: async function ({
async performSimilaritySearch({
namespace = null, namespace = null,
input = "", input = "",
LLMConnector = null, LLMConnector = null,
@ -275,8 +295,9 @@ const PineconeDB = {
sources: this.curateSources(sources), sources: this.curateSources(sources),
message: false, message: false,
}; };
}, }
curateSources: function (sources = []) {
curateSources(sources = []) {
const documents = []; const documents = [];
for (const source of sources) { for (const source of sources) {
const { metadata = {} } = source; const { metadata = {} } = source;
@ -290,7 +311,7 @@ const PineconeDB = {
} }
} }
return documents; return documents;
}, }
}; }
module.exports.Pinecone = PineconeDB; module.exports.Pinecone = PineconeDB;

View File

@ -5,10 +5,18 @@ const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid"); const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const { VectorDatabase } = require("../base");
const QDrant = { class QDrant extends VectorDatabase {
name: "QDrant", constructor() {
connect: async function () { super();
}
get name() {
return "QDrant";
}
async connect() {
if (process.env.VECTOR_DB !== "qdrant") if (process.env.VECTOR_DB !== "qdrant")
throw new Error("QDrant::Invalid ENV settings"); throw new Error("QDrant::Invalid ENV settings");
@ -26,12 +34,14 @@ const QDrant = {
); );
return { client }; return { client };
}, }
heartbeat: async function () {
async heartbeat() {
await this.connect(); await this.connect();
return { heartbeat: Number(new Date()) }; return { heartbeat: Number(new Date()) };
}, }
totalVectors: async function () {
async totalVectors() {
const { client } = await this.connect(); const { client } = await this.connect();
const { collections } = await client.getCollections(); const { collections } = await client.getCollections();
var totalVectors = 0; var totalVectors = 0;
@ -41,13 +51,15 @@ const QDrant = {
(await this.namespace(client, collection.name))?.vectorCount || 0; (await this.namespace(client, collection.name))?.vectorCount || 0;
} }
return totalVectors; return totalVectors;
}, }
namespaceCount: async function (_namespace = null) {
async namespaceCount(_namespace = null) {
const { client } = await this.connect(); const { client } = await this.connect();
const namespace = await this.namespace(client, _namespace); const namespace = await this.namespace(client, _namespace);
return namespace?.vectorCount || 0; return namespace?.vectorCount || 0;
}, }
similarityResponse: async function ({
async similarityResponse({
client, client,
namespace, namespace,
queryVector, queryVector,
@ -70,7 +82,7 @@ const QDrant = {
responses.forEach((response) => { responses.forEach((response) => {
if (response.score < similarityThreshold) return; if (response.score < similarityThreshold) return;
if (filterIdentifiers.includes(sourceIdentifier(response?.payload))) { if (filterIdentifiers.includes(sourceIdentifier(response?.payload))) {
console.log( this.logger(
"QDrant: A source was filtered from context as it's parent document is pinned." "QDrant: A source was filtered from context as it's parent document is pinned."
); );
return; return;
@ -86,8 +98,9 @@ const QDrant = {
}); });
return result; return result;
}, }
namespace: async function (client, namespace = null) {
async namespace(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const collection = await client.getCollection(namespace).catch(() => null); const collection = await client.getCollection(namespace).catch(() => null);
if (!collection) return null; if (!collection) return null;
@ -97,28 +110,32 @@ const QDrant = {
...collection, ...collection,
vectorCount: (await client.count(namespace, { exact: true })).count, vectorCount: (await client.count(namespace, { exact: true })).count,
}; };
}, }
hasNamespace: async function (namespace = null) {
async hasNamespace(namespace = null) {
if (!namespace) return false; if (!namespace) return false;
const { client } = await this.connect(); const { client } = await this.connect();
return await this.namespaceExists(client, namespace); return await this.namespaceExists(client, namespace);
}, }
namespaceExists: async function (client, namespace = null) {
async namespaceExists(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const collection = await client.getCollection(namespace).catch((e) => { const collection = await client.getCollection(namespace).catch((e) => {
console.error("QDrant::namespaceExists", e.message); this.logger("namespaceExists", e.message);
return null; return null;
}); });
return !!collection; return !!collection;
}, }
deleteVectorsInNamespace: async function (client, namespace = null) {
async deleteVectorsInNamespace(client, namespace = null) {
await client.deleteCollection(namespace); await client.deleteCollection(namespace);
return true; return true;
}, }
// QDrant requires a dimension aspect for collection creation // QDrant requires a dimension aspect for collection creation
// we pass this in from the first chunk to infer the dimensions like other // we pass this in from the first chunk to infer the dimensions like other
// providers do. // providers do.
getOrCreateCollection: async function (client, namespace, dimensions = null) { async getOrCreateCollection(client, namespace, dimensions = null) {
if (await this.namespaceExists(client, namespace)) { if (await this.namespaceExists(client, namespace)) {
return await client.getCollection(namespace); return await client.getCollection(namespace);
} }
@ -133,8 +150,9 @@ const QDrant = {
}, },
}); });
return await client.getCollection(namespace); return await client.getCollection(namespace);
}, }
addDocumentToNamespace: async function (
async addDocumentToNamespace(
namespace, namespace,
documentData = {}, documentData = {},
fullFilePath = null, fullFilePath = null,
@ -146,7 +164,7 @@ const QDrant = {
const { pageContent, docId, ...metadata } = documentData; const { pageContent, docId, ...metadata } = documentData;
if (!pageContent || pageContent.length == 0) return false; if (!pageContent || pageContent.length == 0) return false;
console.log("Adding new vectorized document into namespace", namespace); this.logger("Adding new vectorized document into namespace", namespace);
if (!skipCache) { if (!skipCache) {
const cacheResult = await cachedVectorInformation(fullFilePath); const cacheResult = await cachedVectorInformation(fullFilePath);
if (cacheResult.exists) { if (cacheResult.exists) {
@ -227,7 +245,7 @@ const QDrant = {
}); });
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Snippets created from document:", textChunks.length); this.logger("Snippets created from document:", textChunks.length);
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await EmbedderEngine.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
@ -276,7 +294,7 @@ const QDrant = {
if (vectors.length > 0) { if (vectors.length > 0) {
const chunks = []; const chunks = [];
console.log("Inserting vectorized chunks into QDrant collection."); this.logger("Inserting vectorized chunks into QDrant collection.");
for (const chunk of toChunks(vectors, 500)) { for (const chunk of toChunks(vectors, 500)) {
const batchIds = [], const batchIds = [],
batchVectors = [], batchVectors = [],
@ -306,11 +324,12 @@ const QDrant = {
await DocumentVectors.bulkInsert(documentVectors); await DocumentVectors.bulkInsert(documentVectors);
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} catch (e) { } catch (e) {
console.error("addDocumentToNamespace", e.message); this.logger("addDocumentToNamespace", e.message);
return { vectorized: false, error: e.message }; return { vectorized: false, error: e.message };
} }
}, }
deleteDocumentFromNamespace: async function (namespace, docId) {
async deleteDocumentFromNamespace(namespace, docId) {
const { DocumentVectors } = require("../../../models/vectors"); const { DocumentVectors } = require("../../../models/vectors");
const { client } = await this.connect(); const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) return; if (!(await this.namespaceExists(client, namespace))) return;
@ -327,8 +346,9 @@ const QDrant = {
const indexes = knownDocuments.map((doc) => doc.id); const indexes = knownDocuments.map((doc) => doc.id);
await DocumentVectors.deleteIds(indexes); await DocumentVectors.deleteIds(indexes);
return true; return true;
}, }
performSimilaritySearch: async function ({
async performSimilaritySearch({
namespace = null, namespace = null,
input = "", input = "",
LLMConnector = null, LLMConnector = null,
@ -366,8 +386,9 @@ const QDrant = {
sources: this.curateSources(sources), sources: this.curateSources(sources),
message: false, message: false,
}; };
}, }
"namespace-stats": async function (reqBody = {}) {
async "namespace-stats"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required"); if (!namespace) throw new Error("namespace required");
const { client } = await this.connect(); const { client } = await this.connect();
@ -377,8 +398,9 @@ const QDrant = {
return stats return stats
? stats ? stats
: { message: "No stats were able to be fetched from DB for namespace" }; : { message: "No stats were able to be fetched from DB for namespace" };
}, }
"delete-namespace": async function (reqBody = {}) {
async "delete-namespace"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
const { client } = await this.connect(); const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) if (!(await this.namespaceExists(client, namespace)))
@ -389,16 +411,18 @@ const QDrant = {
return { return {
message: `Namespace ${namespace} was deleted along with ${details?.vectorCount} vectors.`, message: `Namespace ${namespace} was deleted along with ${details?.vectorCount} vectors.`,
}; };
}, }
reset: async function () {
async reset() {
const { client } = await this.connect(); const { client } = await this.connect();
const response = await client.getCollections(); const response = await client.getCollections();
for (const collection of response.collections) { for (const collection of response.collections) {
await client.deleteCollection(collection.name); await client.deleteCollection(collection.name);
} }
return { reset: true }; return { reset: true };
}, }
curateSources: function (sources = []) {
curateSources(sources = []) {
const documents = []; const documents = [];
for (const source of sources) { for (const source of sources) {
if (Object.keys(source).length > 0) { if (Object.keys(source).length > 0) {
@ -412,7 +436,7 @@ const QDrant = {
} }
return documents; return documents;
}, }
}; }
module.exports.QDrant = QDrant; module.exports.QDrant = QDrant;

View File

@ -6,10 +6,18 @@ const { v4: uuidv4 } = require("uuid");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers"); const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { camelCase } = require("../../helpers/camelcase"); const { camelCase } = require("../../helpers/camelcase");
const { sourceIdentifier } = require("../../chats"); const { sourceIdentifier } = require("../../chats");
const { VectorDatabase } = require("../base");
const Weaviate = { class Weaviate extends VectorDatabase {
name: "Weaviate", constructor() {
connect: async function () { super();
}
get name() {
return "Weaviate";
}
async connect() {
if (process.env.VECTOR_DB !== "weaviate") if (process.env.VECTOR_DB !== "weaviate")
throw new Error("Weaviate::Invalid ENV settings"); throw new Error("Weaviate::Invalid ENV settings");
@ -28,12 +36,14 @@ const Weaviate = {
"Weaviate::Invalid Alive signal received - is the service online?" "Weaviate::Invalid Alive signal received - is the service online?"
); );
return { client }; return { client };
}, }
heartbeat: async function () {
async heartbeat() {
await this.connect(); await this.connect();
return { heartbeat: Number(new Date()) }; return { heartbeat: Number(new Date()) };
}, }
totalVectors: async function () {
async totalVectors() {
const { client } = await this.connect(); const { client } = await this.connect();
const collectionNames = await this.allNamespaces(client); const collectionNames = await this.allNamespaces(client);
var totalVectors = 0; var totalVectors = 0;
@ -41,8 +51,9 @@ const Weaviate = {
totalVectors += await this.namespaceCountWithClient(client, name); totalVectors += await this.namespaceCountWithClient(client, name);
} }
return totalVectors; return totalVectors;
}, }
namespaceCountWithClient: async function (client, namespace) {
async namespaceCountWithClient(client, namespace) {
try { try {
const response = await client.graphql const response = await client.graphql
.aggregate() .aggregate()
@ -53,11 +64,12 @@ const Weaviate = {
response?.data?.Aggregate?.[camelCase(namespace)]?.[0]?.meta?.count || 0 response?.data?.Aggregate?.[camelCase(namespace)]?.[0]?.meta?.count || 0
); );
} catch (e) { } catch (e) {
console.error(`Weaviate:namespaceCountWithClient`, e.message); this.logger(`namespaceCountWithClient`, e.message);
return 0; return 0;
} }
}, }
namespaceCount: async function (namespace = null) {
async namespaceCount(namespace = null) {
try { try {
const { client } = await this.connect(); const { client } = await this.connect();
const response = await client.graphql const response = await client.graphql
@ -70,11 +82,12 @@ const Weaviate = {
response?.data?.Aggregate?.[camelCase(namespace)]?.[0]?.meta?.count || 0 response?.data?.Aggregate?.[camelCase(namespace)]?.[0]?.meta?.count || 0
); );
} catch (e) { } catch (e) {
console.error(`Weaviate:namespaceCountWithClient`, e.message); this.logger(`namespaceCountWithClient`, e.message);
return 0; return 0;
} }
}, }
similarityResponse: async function ({
async similarityResponse({
client, client,
namespace, namespace,
queryVector, queryVector,
@ -109,8 +122,8 @@ const Weaviate = {
} = response; } = response;
if (certainty < similarityThreshold) return; if (certainty < similarityThreshold) return;
if (filterIdentifiers.includes(sourceIdentifier(rest))) { if (filterIdentifiers.includes(sourceIdentifier(rest))) {
console.log( this.logger(
"Weaviate: A source was filtered from context as it's parent document is pinned." "A source was filtered from context as it's parent document is pinned."
); );
return; return;
} }
@ -120,17 +133,19 @@ const Weaviate = {
}); });
return result; return result;
}, }
allNamespaces: async function (client) {
async allNamespaces(client) {
try { try {
const { classes = [] } = await client.schema.getter().do(); const { classes = [] } = await client.schema.getter().do();
return classes.map((classObj) => classObj.class); return classes.map((classObj) => classObj.class);
} catch (e) { } catch (e) {
console.error("Weaviate::AllNamespace", e); this.logger("AllNamespace", e);
return []; return [];
} }
}, }
namespace: async function (client, namespace = null) {
async namespace(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
if (!(await this.namespaceExists(client, namespace))) return null; if (!(await this.namespaceExists(client, namespace))) return null;
@ -143,8 +158,9 @@ const Weaviate = {
...weaviateClass, ...weaviateClass,
vectorCount: await this.namespaceCount(namespace), vectorCount: await this.namespaceCount(namespace),
}; };
}, }
addVectors: async function (client, vectors = []) {
async addVectors(client, vectors = []) {
const response = { success: true, errors: new Set([]) }; const response = { success: true, errors: new Set([]) };
const results = await client.batch const results = await client.batch
.objectsBatcher() .objectsBatcher()
@ -160,23 +176,27 @@ const Weaviate = {
response.errors = [...response.errors]; response.errors = [...response.errors];
return response; return response;
}, }
hasNamespace: async function (namespace = null) {
async hasNamespace(namespace = null) {
if (!namespace) return false; if (!namespace) return false;
const { client } = await this.connect(); const { client } = await this.connect();
const weaviateClasses = await this.allNamespaces(client); const weaviateClasses = await this.allNamespaces(client);
return weaviateClasses.includes(camelCase(namespace)); return weaviateClasses.includes(camelCase(namespace));
}, }
namespaceExists: async function (client, namespace = null) {
async namespaceExists(client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided."); if (!namespace) throw new Error("No namespace value provided.");
const weaviateClasses = await this.allNamespaces(client); const weaviateClasses = await this.allNamespaces(client);
return weaviateClasses.includes(camelCase(namespace)); return weaviateClasses.includes(camelCase(namespace));
}, }
deleteVectorsInNamespace: async function (client, namespace = null) {
async deleteVectorsInNamespace(client, namespace = null) {
await client.schema.classDeleter().withClassName(camelCase(namespace)).do(); await client.schema.classDeleter().withClassName(camelCase(namespace)).do();
return true; return true;
}, }
addDocumentToNamespace: async function (
async addDocumentToNamespace(
namespace, namespace,
documentData = {}, documentData = {},
fullFilePath = null, fullFilePath = null,
@ -192,7 +212,7 @@ const Weaviate = {
} = documentData; } = documentData;
if (!pageContent || pageContent.length == 0) return false; if (!pageContent || pageContent.length == 0) return false;
console.log("Adding new vectorized document into namespace", namespace); this.logger("Adding new vectorized document into namespace", namespace);
if (!skipCache) { if (!skipCache) {
const cacheResult = await cachedVectorInformation(fullFilePath); const cacheResult = await cachedVectorInformation(fullFilePath);
if (cacheResult.exists) { if (cacheResult.exists) {
@ -236,7 +256,7 @@ const Weaviate = {
const { success: additionResult, errors = [] } = const { success: additionResult, errors = [] } =
await this.addVectors(client, vectors); await this.addVectors(client, vectors);
if (!additionResult) { if (!additionResult) {
console.error("Weaviate::addVectors failed to insert", errors); this.logger("addVectors failed to insert", errors);
throw new Error("Error embedding into Weaviate"); throw new Error("Error embedding into Weaviate");
} }
} }
@ -267,7 +287,7 @@ const Weaviate = {
}); });
const textChunks = await textSplitter.splitText(pageContent); const textChunks = await textSplitter.splitText(pageContent);
console.log("Snippets created from document:", textChunks.length); this.logger("Snippets created from document:", textChunks.length);
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const vectorValues = await EmbedderEngine.embedChunks(textChunks); const vectorValues = await EmbedderEngine.embedChunks(textChunks);
@ -322,13 +342,13 @@ const Weaviate = {
const chunks = []; const chunks = [];
for (const chunk of toChunks(vectors, 500)) chunks.push(chunk); for (const chunk of toChunks(vectors, 500)) chunks.push(chunk);
console.log("Inserting vectorized chunks into Weaviate collection."); this.logger("Inserting vectorized chunks into Weaviate collection.");
const { success: additionResult, errors = [] } = await this.addVectors( const { success: additionResult, errors = [] } = await this.addVectors(
client, client,
vectors vectors
); );
if (!additionResult) { if (!additionResult) {
console.error("Weaviate::addVectors failed to insert", errors); this.logger("addVectors failed to insert", errors);
throw new Error("Error embedding into Weaviate"); throw new Error("Error embedding into Weaviate");
} }
await storeVectorResult(chunks, fullFilePath); await storeVectorResult(chunks, fullFilePath);
@ -337,11 +357,12 @@ const Weaviate = {
await DocumentVectors.bulkInsert(documentVectors); await DocumentVectors.bulkInsert(documentVectors);
return { vectorized: true, error: null }; return { vectorized: true, error: null };
} catch (e) { } catch (e) {
console.error("addDocumentToNamespace", e.message); this.logger("addDocumentToNamespace", e.message);
return { vectorized: false, error: e.message }; return { vectorized: false, error: e.message };
} }
}, }
deleteDocumentFromNamespace: async function (namespace, docId) {
async deleteDocumentFromNamespace(namespace, docId) {
const { DocumentVectors } = require("../../../models/vectors"); const { DocumentVectors } = require("../../../models/vectors");
const { client } = await this.connect(); const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) return; if (!(await this.namespaceExists(client, namespace))) return;
@ -360,8 +381,9 @@ const Weaviate = {
const indexes = knownDocuments.map((doc) => doc.id); const indexes = knownDocuments.map((doc) => doc.id);
await DocumentVectors.deleteIds(indexes); await DocumentVectors.deleteIds(indexes);
return true; return true;
}, }
performSimilaritySearch: async function ({
async performSimilaritySearch({
namespace = null, namespace = null,
input = "", input = "",
LLMConnector = null, LLMConnector = null,
@ -399,8 +421,9 @@ const Weaviate = {
sources: this.curateSources(sources), sources: this.curateSources(sources),
message: false, message: false,
}; };
}, }
"namespace-stats": async function (reqBody = {}) {
async "namespace-stats"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required"); if (!namespace) throw new Error("namespace required");
const { client } = await this.connect(); const { client } = await this.connect();
@ -408,8 +431,9 @@ const Weaviate = {
return stats return stats
? stats ? stats
: { message: "No stats were able to be fetched from DB for namespace" }; : { message: "No stats were able to be fetched from DB for namespace" };
}, }
"delete-namespace": async function (reqBody = {}) {
async "delete-namespace"(reqBody = {}) {
const { namespace = null } = reqBody; const { namespace = null } = reqBody;
const { client } = await this.connect(); const { client } = await this.connect();
const details = await this.namespace(client, namespace); const details = await this.namespace(client, namespace);
@ -419,16 +443,18 @@ const Weaviate = {
details?.vectorCount details?.vectorCount
} vectors.`, } vectors.`,
}; };
}, }
reset: async function () {
async reset() {
const { client } = await this.connect(); const { client } = await this.connect();
const weaviateClasses = await this.allNamespaces(client); const weaviateClasses = await this.allNamespaces(client);
for (const weaviateClass of weaviateClasses) { for (const weaviateClass of weaviateClasses) {
await client.schema.classDeleter().withClassName(weaviateClass).do(); await client.schema.classDeleter().withClassName(weaviateClass).do();
} }
return { reset: true }; return { reset: true };
}, }
curateSources: function (sources = []) {
curateSources(sources = []) {
const documents = []; const documents = [];
for (const source of sources) { for (const source of sources) {
if (Object.keys(source).length > 0) { if (Object.keys(source).length > 0) {
@ -440,8 +466,9 @@ const Weaviate = {
} }
return documents; return documents;
}, }
flattenObjectForWeaviate: function (obj = {}) {
flattenObjectForWeaviate(obj = {}) {
// Note this function is not generic, it is designed specifically for Weaviate // Note this function is not generic, it is designed specifically for Weaviate
// https://weaviate.io/developers/weaviate/config-refs/datatypes#introduction // https://weaviate.io/developers/weaviate/config-refs/datatypes#introduction
// Credit to LangchainJS // Credit to LangchainJS
@ -478,7 +505,7 @@ const Weaviate = {
} }
return flattenedObject; return flattenedObject;
}, }
}; }
module.exports.Weaviate = Weaviate; module.exports.Weaviate = Weaviate;

View File

@ -1,33 +1,22 @@
const { const { MilvusClient } = require("@zilliz/milvus2-sdk-node");
DataType, const { Milvus } = require("../milvus");
MetricType,
IndexType,
MilvusClient,
} = require("@zilliz/milvus2-sdk-node");
const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings");
const { v4: uuidv4 } = require("uuid");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");
// Zilliz is basically a copy of Milvus DB class with a different constructor /**
// to connect to the cloud * Zilliz is the cloud version of Milvus so we can just extend the
const Zilliz = { * Milvus class and override the connect method
name: "Zilliz", */
// Milvus/Zilliz only allows letters, numbers, and underscores in collection names class Zilliz extends Milvus {
// so we need to enforce that by re-normalizing the names when communicating with constructor() {
// the DB. super();
// If the first char of the collection is not an underscore or letter the collection name will be invalid. }
normalize: function (inputString) {
let normalized = inputString.replace(/[^a-zA-Z0-9_]/g, "_"); get name() {
if (new RegExp(/^[a-zA-Z_]/).test(normalized.slice(0, 1))) return "Zilliz";
normalized = `anythingllm_${normalized}`; }
return normalized;
}, async connect() {
connect: async function () {
if (process.env.VECTOR_DB !== "zilliz") if (process.env.VECTOR_DB !== "zilliz")
throw new Error("Zilliz::Invalid ENV settings"); throw new Error(`${this.name}::Invalid ENV settings`);
const client = new MilvusClient({ const client = new MilvusClient({
address: process.env.ZILLIZ_ENDPOINT, address: process.env.ZILLIZ_ENDPOINT,
@ -37,366 +26,11 @@ const Zilliz = {
const { isHealthy } = await client.checkHealth(); const { isHealthy } = await client.checkHealth();
if (!isHealthy) if (!isHealthy)
throw new Error( throw new Error(
"Zilliz::Invalid Heartbeat received - is the instance online?" `${this.name}::Invalid Heartbeat received - is the instance online?`
); );
return { client }; return { client };
}, }
heartbeat: async function () { }
await this.connect();
return { heartbeat: Number(new Date()) };
},
totalVectors: async function () {
const { client } = await this.connect();
const { collection_names } = await client.listCollections();
const total = collection_names.reduce(async (acc, collection_name) => {
const statistics = await client.getCollectionStatistics({
collection_name: this.normalize(collection_name),
});
return Number(acc) + Number(statistics?.data?.row_count ?? 0);
}, 0);
return total;
},
namespaceCount: async function (_namespace = null) {
const { client } = await this.connect();
const statistics = await client.getCollectionStatistics({
collection_name: this.normalize(_namespace),
});
return Number(statistics?.data?.row_count ?? 0);
},
namespace: async function (client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided.");
const collection = await client
.getCollectionStatistics({ collection_name: this.normalize(namespace) })
.catch(() => null);
return collection;
},
hasNamespace: async function (namespace = null) {
if (!namespace) return false;
const { client } = await this.connect();
return await this.namespaceExists(client, namespace);
},
namespaceExists: async function (client, namespace = null) {
if (!namespace) throw new Error("No namespace value provided.");
const { value } = await client
.hasCollection({ collection_name: this.normalize(namespace) })
.catch((e) => {
console.error("Zilliz::namespaceExists", e.message);
return { value: false };
});
return value;
},
deleteVectorsInNamespace: async function (client, namespace = null) {
await client.dropCollection({ collection_name: this.normalize(namespace) });
return true;
},
// Zilliz requires a dimension aspect for collection creation
// we pass this in from the first chunk to infer the dimensions like other
// providers do.
getOrCreateCollection: async function (client, namespace, dimensions = null) {
const isExists = await this.namespaceExists(client, namespace);
if (!isExists) {
if (!dimensions)
throw new Error(
`Zilliz:getOrCreateCollection Unable to infer vector dimension from input. Open an issue on GitHub for support.`
);
await client.createCollection({
collection_name: this.normalize(namespace),
fields: [
{
name: "id",
description: "id",
data_type: DataType.VarChar,
max_length: 255,
is_primary_key: true,
},
{
name: "vector",
description: "vector",
data_type: DataType.FloatVector,
dim: dimensions,
},
{
name: "metadata",
description: "metadata",
data_type: DataType.JSON,
},
],
});
await client.createIndex({
collection_name: this.normalize(namespace),
field_name: "vector",
index_type: IndexType.AUTOINDEX,
metric_type: MetricType.COSINE,
});
await client.loadCollectionSync({
collection_name: this.normalize(namespace),
});
}
},
/**
 * Embeds a document's content and stores the resulting vectors in the
 * namespace's collection, reusing the on-disk vector cache when available.
 * @param {string} namespace - workspace namespace to write into.
 * @param {object} documentData - document payload; `pageContent` is the raw
 *   text, `docId` ties vectors back to the document record, and all remaining
 *   keys are stored as per-vector metadata.
 * @param {string|null} fullFilePath - path keying the vector cache file.
 * @param {boolean} skipCache - when true, always re-embed instead of reusing
 *   cached vectors.
 * @returns {Promise<boolean|{vectorized: boolean, error: string|null}>}
 *   `false` when there is no content to embed, otherwise a result object with
 *   any error message captured rather than thrown.
 */
addDocumentToNamespace: async function (
  namespace,
  documentData = {},
  fullFilePath = null,
  skipCache = false
) {
  const { DocumentVectors } = require("../../../models/vectors");
  try {
    let vectorDimension = null;
    const { pageContent, docId, ...metadata } = documentData;
    if (!pageContent || pageContent.length == 0) return false;
    console.log("Adding new vectorized document into namespace", namespace);
    if (!skipCache) {
      const cacheResult = await cachedVectorInformation(fullFilePath);
      if (cacheResult.exists) {
        const { client } = await this.connect();
        const { chunks } = cacheResult;
        const documentVectors = [];
        // Infer the vector dimension from the first cached vector so the
        // collection can be created with a matching schema.
        vectorDimension = chunks[0][0].values.length || null;
        await this.getOrCreateCollection(client, namespace, vectorDimension);
        for (const chunk of chunks) {
          // Before sending to Zilliz and saving the records to our db
          // we need to assign a fresh id to each chunk stored in the cached file.
          const newChunks = chunk.map((chunk) => {
            const id = uuidv4();
            documentVectors.push({ docId, vectorId: id });
            return { id, vector: chunk.values, metadata: chunk.metadata };
          });
          const insertResult = await client.insert({
            collection_name: this.normalize(namespace),
            data: newChunks,
          });
          if (insertResult?.status.error_code !== "Success") {
            throw new Error(
              `Error embedding into Zilliz! Reason:${insertResult?.status.reason}`
            );
          }
        }
        await DocumentVectors.bulkInsert(documentVectors);
        await client.flushSync({
          collection_names: [this.normalize(namespace)],
        });
        return { vectorized: true, error: null };
      }
    }
    // No usable cache: split the document into chunks sized for the active
    // embedder and embed them from scratch.
    const EmbedderEngine = getEmbeddingEngineSelection();
    const textSplitter = new TextSplitter({
      chunkSize: TextSplitter.determineMaxChunkSize(
        await SystemSettings.getValueOrFallback({
          label: "text_splitter_chunk_size",
        }),
        EmbedderEngine?.embeddingMaxChunkLength
      ),
      chunkOverlap: await SystemSettings.getValueOrFallback(
        { label: "text_splitter_chunk_overlap" },
        20
      ),
      chunkHeaderMeta: TextSplitter.buildHeaderMeta(metadata),
      chunkPrefix: EmbedderEngine?.embeddingPrefix,
    });
    const textChunks = await textSplitter.splitText(pageContent);
    console.log("Snippets created from document:", textChunks.length);
    const documentVectors = [];
    const vectors = [];
    const vectorValues = await EmbedderEngine.embedChunks(textChunks);
    if (!!vectorValues && vectorValues.length > 0) {
      for (const [i, vector] of vectorValues.entries()) {
        if (!vectorDimension) vectorDimension = vector.length;
        const vectorRecord = {
          id: uuidv4(),
          values: vector,
          // [DO NOT REMOVE]
          // LangChain will be unable to find your text if you embed manually and dont include the `text` key.
          metadata: { ...metadata, text: textChunks[i] },
        };
        vectors.push(vectorRecord);
        documentVectors.push({ docId, vectorId: vectorRecord.id });
      }
    } else {
      throw new Error(
        "Could not embed document chunks! This document will not be recorded."
      );
    }
    if (vectors.length > 0) {
      const chunks = [];
      const { client } = await this.connect();
      await this.getOrCreateCollection(client, namespace, vectorDimension);
      console.log("Inserting vectorized chunks into Zilliz.");
      // Insert in batches of 100 to keep each request small; the batches are
      // also accumulated so the raw vectors can be written to the cache file.
      for (const chunk of toChunks(vectors, 100)) {
        chunks.push(chunk);
        const insertResult = await client.insert({
          collection_name: this.normalize(namespace),
          data: chunk.map((item) => ({
            id: item.id,
            vector: item.values,
            metadata: item.metadata,
          })),
        });
        if (insertResult?.status.error_code !== "Success") {
          throw new Error(
            `Error embedding into Zilliz! Reason:${insertResult?.status.reason}`
          );
        }
      }
      await storeVectorResult(chunks, fullFilePath);
      await client.flushSync({
        collection_names: [this.normalize(namespace)],
      });
    }
    await DocumentVectors.bulkInsert(documentVectors);
    return { vectorized: true, error: null };
  } catch (e) {
    console.error("addDocumentToNamespace", e.message);
    return { vectorized: false, error: e.message };
  }
},
deleteDocumentFromNamespace: async function (namespace, docId) {
const { DocumentVectors } = require("../../../models/vectors");
const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) return;
const knownDocuments = await DocumentVectors.where({ docId });
if (knownDocuments.length === 0) return;
const vectorIds = knownDocuments.map((doc) => doc.vectorId);
const queryIn = vectorIds.map((v) => `'${v}'`).join(",");
await client.deleteEntities({
collection_name: this.normalize(namespace),
expr: `id in [${queryIn}]`,
});
const indexes = knownDocuments.map((doc) => doc.id);
await DocumentVectors.deleteIds(indexes);
// Even after flushing Zilliz can take some time to re-calc the count
// so all we can hope to do is flushSync so that the count can be correct
// on a later call.
await client.flushSync({ collection_names: [this.normalize(namespace)] });
return true;
},
performSimilaritySearch: async function ({
namespace = null,
input = "",
LLMConnector = null,
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace))) {
return {
contextTexts: [],
sources: [],
message: "Invalid query - no documents found for workspace!",
};
}
const queryVector = await LLMConnector.embedTextInput(input);
const { contextTexts, sourceDocuments } = await this.similarityResponse({
client,
namespace,
queryVector,
similarityThreshold,
topN,
filterIdentifiers,
});
const sources = sourceDocuments.map((doc, i) => {
return { metadata: doc, text: contextTexts[i] };
});
return {
contextTexts,
sources: this.curateSources(sources),
message: false,
};
},
similarityResponse: async function ({
client,
namespace,
queryVector,
similarityThreshold = 0.25,
topN = 4,
filterIdentifiers = [],
}) {
const result = {
contextTexts: [],
sourceDocuments: [],
scores: [],
};
const response = await client.search({
collection_name: this.normalize(namespace),
vectors: queryVector,
limit: topN,
});
response.results.forEach((match) => {
if (match.score < similarityThreshold) return;
if (filterIdentifiers.includes(sourceIdentifier(match.metadata))) {
console.log(
"Zilliz: A source was filtered from context as it's parent document is pinned."
);
return;
}
result.contextTexts.push(match.metadata.text);
result.sourceDocuments.push({
...match.metadata,
score: match.score,
});
result.scores.push(match.score);
});
return result;
},
"namespace-stats": async function (reqBody = {}) {
const { namespace = null } = reqBody;
if (!namespace) throw new Error("namespace required");
const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace)))
throw new Error("Namespace by that name does not exist.");
const stats = await this.namespace(client, namespace);
return stats
? stats
: { message: "No stats were able to be fetched from DB for namespace" };
},
"delete-namespace": async function (reqBody = {}) {
const { namespace = null } = reqBody;
const { client } = await this.connect();
if (!(await this.namespaceExists(client, namespace)))
throw new Error("Namespace by that name does not exist.");
const statistics = await this.namespace(client, namespace);
await this.deleteVectorsInNamespace(client, namespace);
const vectorCount = Number(statistics?.data?.row_count ?? 0);
return {
message: `Namespace ${namespace} was deleted along with ${vectorCount} vectors.`,
};
},
curateSources: function (sources = []) {
const documents = [];
for (const source of sources) {
const { metadata = {} } = source;
if (Object.keys(metadata).length > 0) {
documents.push({
...metadata,
...(source.text ? { text: source.text } : {}),
});
}
}
return documents;
},
};
module.exports.Zilliz = Zilliz; module.exports.Zilliz = Zilliz;