282 return relevancy score with similarityresponse (#304)

* include score value in similarityResponse for weaviate

* include score value in si
milarityResponse for qdrant

* include score value in si
milarityResponse for pinecone

* include score value in similarityResponse for chroma

* include score value in similarityResponse for lancedb

* distance to similarity

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
This commit is contained in:
Sean Hatfield 2023-10-30 12:46:38 -07:00 committed by GitHub
parent 26dba59249
commit 669d7a396d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 25 additions and 2 deletions

View File

@ -49,6 +49,12 @@ const Chroma = {
} }
return totalVectors; return totalVectors;
}, },
distanceToSimilarity: function (distance = null) {
if (distance === null || typeof distance !== "number") return 0.0;
if (distance >= 1.0) return 1;
if (distance <= 0) return 0;
return 1 - distance;
},
namespaceCount: async function (_namespace = null) { namespaceCount: async function (_namespace = null) {
const { client } = await this.connect(); const { client } = await this.connect();
const namespace = await this.namespace(client, _namespace); const namespace = await this.namespace(client, _namespace);
@ -59,6 +65,7 @@ const Chroma = {
const result = { const result = {
contextTexts: [], contextTexts: [],
sourceDocuments: [], sourceDocuments: [],
scores: [],
}; };
const response = await collection.query({ const response = await collection.query({
@ -68,6 +75,7 @@ const Chroma = {
response.ids[0].forEach((_, i) => { response.ids[0].forEach((_, i) => {
result.contextTexts.push(response.documents[0][i]); result.contextTexts.push(response.documents[0][i]);
result.sourceDocuments.push(response.metadatas[0][i]); result.sourceDocuments.push(response.metadatas[0][i]);
result.scores.push(this.distanceToSimilarity(response.distances[0][i]));
}); });
return result; return result;

View File

@ -18,6 +18,12 @@ const LanceDb = {
const client = await lancedb.connect(this.uri); const client = await lancedb.connect(this.uri);
return { client }; return { client };
}, },
distanceToSimilarity: function (distance = null) {
if (distance === null || typeof distance !== "number") return 0.0;
if (distance >= 1.0) return 1;
if (distance <= 0) return 0;
return 1 - distance;
},
heartbeat: async function () { heartbeat: async function () {
await this.connect(); await this.connect();
return { heartbeat: Number(new Date()) }; return { heartbeat: Number(new Date()) };
@ -54,6 +60,7 @@ const LanceDb = {
const result = { const result = {
contextTexts: [], contextTexts: [],
sourceDocuments: [], sourceDocuments: [],
scores: [],
}; };
const response = await collection const response = await collection
@ -66,6 +73,7 @@ const LanceDb = {
const { vector: _, ...rest } = item; const { vector: _, ...rest } = item;
result.contextTexts.push(rest.text); result.contextTexts.push(rest.text);
result.sourceDocuments.push(rest); result.sourceDocuments.push(rest);
result.scores.push(this.distanceToSimilarity(item.score));
}); });
return result; return result;

View File

@ -41,6 +41,7 @@ const Pinecone = {
const result = { const result = {
contextTexts: [], contextTexts: [],
sourceDocuments: [], sourceDocuments: [],
scores: [],
}; };
const response = await index.query({ const response = await index.query({
queryRequest: { queryRequest: {
@ -54,6 +55,7 @@ const Pinecone = {
response.matches.forEach((match) => { response.matches.forEach((match) => {
result.contextTexts.push(match.metadata.text); result.contextTexts.push(match.metadata.text);
result.sourceDocuments.push(match); result.sourceDocuments.push(match);
result.scores.push(match.score);
}); });
return result; return result;

View File

@ -51,11 +51,13 @@ const QDrant = {
const result = { const result = {
contextTexts: [], contextTexts: [],
sourceDocuments: [], sourceDocuments: [],
scores: [],
}; };
const responses = await client.search(namespace, { const responses = await client.search(namespace, {
vector: queryVector, vector: queryVector,
limit: 4, limit: 4,
with_payload: true,
}); });
responses.forEach((response) => { responses.forEach((response) => {
@ -64,6 +66,7 @@ const QDrant = {
...(response?.payload || {}), ...(response?.payload || {}),
id: response.id, id: response.id,
}); });
result.scores.push(response.score);
}); });
return result; return result;

View File

@ -77,6 +77,7 @@ const Weaviate = {
const result = { const result = {
contextTexts: [], contextTexts: [],
sourceDocuments: [], sourceDocuments: [],
scores: [],
}; };
const weaviateClass = await this.namespace(client, namespace); const weaviateClass = await this.namespace(client, namespace);
@ -84,7 +85,7 @@ const Weaviate = {
const queryResponse = await client.graphql const queryResponse = await client.graphql
.get() .get()
.withClassName(camelCase(namespace)) .withClassName(camelCase(namespace))
.withFields(`${fields} _additional { id }`) .withFields(`${fields} _additional { id certainty }`)
.withNearVector({ vector: queryVector }) .withNearVector({ vector: queryVector })
.withLimit(4) .withLimit(4)
.do(); .do();
@ -94,11 +95,12 @@ const Weaviate = {
// In Weaviate we have to pluck id from _additional and spread it into the rest // In Weaviate we have to pluck id from _additional and spread it into the rest
// of the properties. // of the properties.
const { const {
_additional: { id }, _additional: { id, certainty },
...rest ...rest
} = response; } = response;
result.contextTexts.push(rest.text); result.contextTexts.push(rest.text);
result.sourceDocuments.push({ ...rest, id }); result.sourceDocuments.push({ ...rest, id });
result.scores.push(certainty);
}); });
return result; return result;