From 6c0a80ba171b2d8d4591500d42f17f0f82c92041 Mon Sep 17 00:00:00 2001 From: Parth Patel <41171860+parthvnp@users.noreply.github.com> Date: Thu, 29 Jun 2023 03:08:21 -0400 Subject: [PATCH 01/10] Add support for Qdrant Vector Store (#453) * Add Qdrant vector store client * Add setup for Qdrant vector store * Add lazy import for Qdrant client library * Add import_qdrant to index file * Use models not types for building configs * Add tests for Qdrant vector store Signed-off-by: sunilkumardash9 --- gptcache/manager/vector_data/manager.py | 40 +++++++++ gptcache/manager/vector_data/qdrant.py | 107 ++++++++++++++++++++++++ gptcache/utils/__init__.py | 7 +- tests/unit_tests/manager/test_qdrant.py | 33 ++++++++ 4 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 gptcache/manager/vector_data/qdrant.py create mode 100644 tests/unit_tests/manager/test_qdrant.py diff --git a/gptcache/manager/vector_data/manager.py b/gptcache/manager/vector_data/manager.py index 6453c88e..86616b5b 100644 --- a/gptcache/manager/vector_data/manager.py +++ b/gptcache/manager/vector_data/manager.py @@ -19,6 +19,12 @@ PGVECTOR_URL = "postgresql://postgres:postgres@localhost:5432/postgres" PGVECTOR_INDEX_PARAMS = {"index_type": "L2", "params": {"lists": 100, "probes": 10}} +QDRANT_GRPC_PORT = 6334 +QDRANT_HTTP_PORT = 6333 +QDRANT_INDEX_PARAMS = {"ef_construct": 100, "m": 16} +QDRANT_DEFAULT_LOCATION = "./qdrant_data" +QDRANT_FLUSH_INTERVAL_SEC = 5 + COLLECTION_NAME = "gptcache" @@ -217,6 +223,40 @@ def get(name, **kwargs): collection_name=collection_name, top_k=top_k, ) + elif name == "qdrant": + from gptcache.manager.vector_data.qdrant import QdrantVectorStore + url = kwargs.get("url", None) + port = kwargs.get("port", QDRANT_HTTP_PORT) + grpc_port = kwargs.get("grpc_port", QDRANT_GRPC_PORT) + prefer_grpc = kwargs.get("prefer_grpc", False) + https = kwargs.get("https", False) + api_key = kwargs.get("api_key", None) + prefix = kwargs.get("prefix", None) + timeout = 
kwargs.get("timeout", None) + host = kwargs.get("host", None) + collection_name = kwargs.get("collection_name", COLLECTION_NAME) + location = kwargs.get("location", QDRANT_DEFAULT_LOCATION) + dimension = kwargs.get("dimension", DIMENSION) + top_k: int = kwargs.get("top_k", TOP_K) + flush_interval_sec = kwargs.get("flush_interval_sec", QDRANT_FLUSH_INTERVAL_SEC) + index_params = kwargs.get("index_params", QDRANT_INDEX_PARAMS) + vector_base = QdrantVectorStore( + url=url, + port=port, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + https=https, + api_key=api_key, + prefix=prefix, + timeout=timeout, + host=host, + collection_name=collection_name, + location=location, + dimension=dimension, + top_k=top_k, + flush_interval_sec=flush_interval_sec, + index_params=index_params, + ) else: raise NotFoundError("vector store", name) return vector_base diff --git a/gptcache/manager/vector_data/qdrant.py b/gptcache/manager/vector_data/qdrant.py new file mode 100644 index 00000000..93bcf99e --- /dev/null +++ b/gptcache/manager/vector_data/qdrant.py @@ -0,0 +1,107 @@ +from typing import List, Optional +import numpy as np + +from gptcache.utils import import_qdrant +from gptcache.utils.log import gptcache_log +from gptcache.manager.vector_data.base import VectorBase, VectorData + +import_qdrant() + +from qdrant_client import QdrantClient # pylint: disable=C0413 +from qdrant_client.models import PointStruct, HnswConfigDiff, VectorParams, OptimizersConfigDiff, \ + Distance # pylint: disable=C0413 + + +class QdrantVectorStore(VectorBase): + + def __init__( + self, + url: Optional[str] = None, + port: Optional[int] = 6333, + grpc_port: int = 6334, + prefer_grpc: bool = False, + https: Optional[bool] = None, + api_key: Optional[str] = None, + prefix: Optional[str] = None, + timeout: Optional[float] = None, + host: Optional[str] = None, + collection_name: Optional[str] = "gptcache", + location: Optional[str] = "./qdrant", + dimension: int = 0, + top_k: int = 1, + flush_interval_sec: 
int = 5, + index_params: Optional[dict] = None, + ): + if dimension <= 0: + raise ValueError( + f"invalid `dim` param: {dimension} in the Qdrant vector store." + ) + self._client: QdrantClient + self._collection_name = collection_name + self._in_memory = location == ":memory:" + self.dimension = dimension + self.top_k = top_k + if self._in_memory or location is not None: + self._create_local(location) + else: + self._create_remote(url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https) + self._create_collection(collection_name, flush_interval_sec, index_params) + + def _create_local(self, location): + self._client = QdrantClient(location=location) + + def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https): + self._client = QdrantClient( + url=url, + port=port, + api_key=api_key, + timeout=timeout, + host=host, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + prefix=prefix, + https=https, + ) + + def _create_collection(self, collection_name: str, flush_interval_sec: int, index_params: Optional[dict] = None): + hnsw_config = HnswConfigDiff(**(index_params or {})) + vectors_config = VectorParams(size=self.dimension, distance=Distance.COSINE, + hnsw_config=hnsw_config) + optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000, + flush_interval_sec=flush_interval_sec) + # check if the collection exists + existing_collections = self._client.get_collections() + for existing_collection in existing_collections.collections: + if existing_collection.name == collection_name: + gptcache_log.warning("The %s collection already exists, and it will be used directly.", collection_name) + break + else: + self._client.create_collection(collection_name=collection_name, vectors_config=vectors_config, + optimizers_config=optimizers_config) + + def mul_add(self, datas: List[VectorData]): + points = [PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas] + 
self._client.upsert(collection_name=self._collection_name, points=points, wait=False) + + def search(self, data: np.ndarray, top_k: int = -1): + if top_k == -1: + top_k = self.top_k + reshaped_data = data.reshape(-1).tolist() + search_result = self._client.search(collection_name=self._collection_name, query_vector=reshaped_data, + limit=top_k) + return list(map(lambda x: (x.score, x.id), search_result)) + + def delete(self, ids: List[str]): + self._client.delete(collection_name=self._collection_name, points_selector=ids) + + def rebuild(self, ids=None): # pylint: disable=unused-argument + optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000) + self._client.update_collection(collection_name=self._collection_name, optimizer_config=optimizers_config) + + def flush(self): + # no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant + pass + + + def close(self): + self.flush() diff --git a/gptcache/utils/__init__.py b/gptcache/utils/__init__.py index ae3d7b47..7a919721 100644 --- a/gptcache/utils/__init__.py +++ b/gptcache/utils/__init__.py @@ -38,7 +38,8 @@ "import_paddlenlp", "import_tiktoken", "import_fastapi", - "import_redis" + "import_redis", + "import_qdrant" ] import importlib.util @@ -65,6 +66,10 @@ def import_milvus_lite(): _check_library("milvus") +def import_qdrant(): + _check_library("qdrant_client") + + def import_sbert(): _check_library("sentence_transformers", package="sentence-transformers") diff --git a/tests/unit_tests/manager/test_qdrant.py b/tests/unit_tests/manager/test_qdrant.py new file mode 100644 index 00000000..2dafe06f --- /dev/null +++ b/tests/unit_tests/manager/test_qdrant.py @@ -0,0 +1,33 @@ +import os +import unittest + +import numpy as np + +from gptcache.manager.vector_data import VectorBase +from gptcache.manager.vector_data.base import VectorData + + +class TestQdrant(unittest.TestCase): + def test_normal(self): + size = 10 + dim = 2 + top_k = 
10 + qdrant = VectorBase( + "qdrant", + top_k=top_k, + dimension=dim, + location=":memory:" + ) + data = np.random.randn(size, dim).astype(np.float32) + qdrant.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))]) + search_result = qdrant.search(data[0], top_k) + self.assertEqual(len(search_result), top_k) + qdrant.mul_add([VectorData(id=size, data=data[0])]) + ret = qdrant.search(data[0]) + self.assertIn(ret[0][1], [0, size]) + self.assertIn(ret[1][1], [0, size]) + qdrant.delete([0, 1, 2, 3, 4, 5, size]) + ret = qdrant.search(data[0]) + self.assertNotIn(ret[0][1], [0, size]) + qdrant.rebuild() + qdrant.close() From 3bde5c450f263b491f053ef717f52ba4c6012134 Mon Sep 17 00:00:00 2001 From: SimFG Date: Thu, 29 Jun 2023 17:01:12 +0800 Subject: [PATCH 02/10] Fix the pylint error and add the chromedb test (#457) Signed-off-by: SimFG Signed-off-by: sunilkumardash9 --- gptcache/adapter/adapter.py | 10 +- gptcache/config.py | 4 +- gptcache/manager/vector_data/chroma.py | 14 +- gptcache/manager/vector_data/qdrant.py | 111 ++++++---- .../test_sqlite_milvus_sbert.py | 207 +++++++++--------- tests/unit_tests/embedding/test_sbert.py | 12 +- tests/unit_tests/manager/test_chromadb.py | 3 +- 7 files changed, 211 insertions(+), 150 deletions(-) diff --git a/gptcache/adapter/adapter.py b/gptcache/adapter/adapter.py index fc1350fc..acf16d29 100644 --- a/gptcache/adapter/adapter.py +++ b/gptcache/adapter/adapter.py @@ -1,4 +1,5 @@ import numpy as np + from gptcache import cache from gptcache.processor.post import temperature_softmax from gptcache.utils.error import NotInitError @@ -16,7 +17,6 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg :param kwargs: llm kwargs :return: llm result """ - health_check_flag = kwargs.pop("health_check", False) search_only_flag = kwargs.pop("search_only", False) user_temperature = "temperature" in kwargs user_top_k = "top_k" in kwargs @@ -114,7 +114,7 @@ def adapt(llm_handler, cache_data_convert, 
update_cache_callback, *args, **kwarg continue # cache consistency check - if health_check_flag: + if chat_cache.config.data_check: is_healthy = cache_health_check( chat_cache.data_manager.v, { @@ -202,7 +202,7 @@ def post_process(): kwargs["cache_context"] = context kwargs["cache_skip"] = cache_skip kwargs["cache_factor"] = cache_factor - kwargs["search_only_flag"] = search_only_flag + kwargs["search_only"] = search_only_flag llm_data = adapt( llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs ) @@ -467,8 +467,8 @@ def update_cache_func(handled_llm_data, question=None): llm_data = update_cache_callback( llm_data, update_cache_func, *args, **kwargs ) - except Exception as e: # pylint: disable=W0703 - gptcache_log.warning("failed to save the data to cache, error: %s", e) + except Exception: # pylint: disable=W0703 + gptcache_log.error("failed to save the data to cache", exc_info=True) return llm_data diff --git a/gptcache/config.py b/gptcache/config.py index 7b62aa19..aad2ef18 100644 --- a/gptcache/config.py +++ b/gptcache/config.py @@ -44,7 +44,8 @@ def __init__( enable_token_counter: bool = True, input_summary_len: Optional[int] = None, context_len: Optional[int] = None, - skip_list: List[str] = None + skip_list: List[str] = None, + data_check: bool = False, ): if similarity_threshold < 0 or similarity_threshold > 1: raise CacheError( @@ -61,3 +62,4 @@ def __init__( if skip_list is None: skip_list = ["system", "assistant"] self.skip_list = skip_list + self.data_check = data_check diff --git a/gptcache/manager/vector_data/chroma.py b/gptcache/manager/vector_data/chroma.py index 82494943..2d9eb161 100644 --- a/gptcache/manager/vector_data/chroma.py +++ b/gptcache/manager/vector_data/chroma.py @@ -1,5 +1,7 @@ from typing import List + import numpy as np + from gptcache.manager.vector_data.base import VectorBase, VectorData from gptcache.utils import import_chromadb, import_torch @@ -45,7 +47,7 @@ def __init__( self._collection = 
self._client.get_or_create_collection(name=collection_name) def mul_add(self, datas: List[VectorData]): - data_array, id_array = map(list, zip(*((list(data.data), str(data.id)) for data in datas))) + data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas))) self._collection.add(embeddings=data_array, ids=id_array) def search(self, data, top_k: int = -1): @@ -54,21 +56,21 @@ def search(self, data, top_k: int = -1): if top_k == -1: top_k = self.top_k results = self._collection.query( - query_embeddings=[list(data)], + query_embeddings=[data.tolist()], n_results=top_k, include=["distances"], ) - return list(zip(results["distances"][0], results["ids"][0])) + return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]])) def delete(self, ids): - self._collection.delete(ids) + self._collection.delete([str(x) for x in ids]) def rebuild(self, ids=None): # pylint: disable=unused-argument return True def get_embeddings(self, data_id: str): vec_emb = self._collection.get( - data_id, + str(data_id), include=["embeddings"], )["embeddings"] if vec_emb is None or len(vec_emb) < 1: @@ -78,6 +80,6 @@ def get_embeddings(self, data_id: str): def update_embeddings(self, data_id: str, emb: np.ndarray): self._collection.update( - ids=data_id, + ids=str(data_id), embeddings=emb.tolist(), ) diff --git a/gptcache/manager/vector_data/qdrant.py b/gptcache/manager/vector_data/qdrant.py index 93bcf99e..c12c4df3 100644 --- a/gptcache/manager/vector_data/qdrant.py +++ b/gptcache/manager/vector_data/qdrant.py @@ -1,36 +1,44 @@ from typing import List, Optional + import numpy as np +from gptcache.manager.vector_data.base import VectorBase, VectorData from gptcache.utils import import_qdrant from gptcache.utils.log import gptcache_log -from gptcache.manager.vector_data.base import VectorBase, VectorData import_qdrant() -from qdrant_client import QdrantClient # pylint: disable=C0413 -from qdrant_client.models import PointStruct, HnswConfigDiff, 
VectorParams, OptimizersConfigDiff, \ - Distance # pylint: disable=C0413 +# pylint: disable=C0413 +from qdrant_client import QdrantClient +from qdrant_client.models import ( + PointStruct, + HnswConfigDiff, + VectorParams, + OptimizersConfigDiff, + Distance, +) class QdrantVectorStore(VectorBase): + """Qdrant Vector Store""" def __init__( - self, - url: Optional[str] = None, - port: Optional[int] = 6333, - grpc_port: int = 6334, - prefer_grpc: bool = False, - https: Optional[bool] = None, - api_key: Optional[str] = None, - prefix: Optional[str] = None, - timeout: Optional[float] = None, - host: Optional[str] = None, - collection_name: Optional[str] = "gptcache", - location: Optional[str] = "./qdrant", - dimension: int = 0, - top_k: int = 1, - flush_interval_sec: int = 5, - index_params: Optional[dict] = None, + self, + url: Optional[str] = None, + port: Optional[int] = 6333, + grpc_port: int = 6334, + prefer_grpc: bool = False, + https: Optional[bool] = None, + api_key: Optional[str] = None, + prefix: Optional[str] = None, + timeout: Optional[float] = None, + host: Optional[str] = None, + collection_name: Optional[str] = "gptcache", + location: Optional[str] = "./qdrant", + dimension: int = 0, + top_k: int = 1, + flush_interval_sec: int = 5, + index_params: Optional[dict] = None, ): if dimension <= 0: raise ValueError( @@ -44,13 +52,17 @@ def __init__( if self._in_memory or location is not None: self._create_local(location) else: - self._create_remote(url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https) + self._create_remote( + url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https + ) self._create_collection(collection_name, flush_interval_sec, index_params) def _create_local(self, location): self._client = QdrantClient(location=location) - def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https): + def _create_remote( + self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, 
prefix, https + ): self._client = QdrantClient( url=url, port=port, @@ -63,45 +75,70 @@ def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_gr https=https, ) - def _create_collection(self, collection_name: str, flush_interval_sec: int, index_params: Optional[dict] = None): + def _create_collection( + self, + collection_name: str, + flush_interval_sec: int, + index_params: Optional[dict] = None, + ): hnsw_config = HnswConfigDiff(**(index_params or {})) - vectors_config = VectorParams(size=self.dimension, distance=Distance.COSINE, - hnsw_config=hnsw_config) - optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000, - flush_interval_sec=flush_interval_sec) + vectors_config = VectorParams( + size=self.dimension, distance=Distance.COSINE, hnsw_config=hnsw_config + ) + optimizers_config = OptimizersConfigDiff( + deleted_threshold=0.2, + vacuum_min_vector_number=1000, + flush_interval_sec=flush_interval_sec, + ) # check if the collection exists existing_collections = self._client.get_collections() for existing_collection in existing_collections.collections: if existing_collection.name == collection_name: - gptcache_log.warning("The %s collection already exists, and it will be used directly.", collection_name) + gptcache_log.warning( + "The %s collection already exists, and it will be used directly.", + collection_name, + ) break else: - self._client.create_collection(collection_name=collection_name, vectors_config=vectors_config, - optimizers_config=optimizers_config) + self._client.create_collection( + collection_name=collection_name, + vectors_config=vectors_config, + optimizers_config=optimizers_config, + ) def mul_add(self, datas: List[VectorData]): - points = [PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas] - self._client.upsert(collection_name=self._collection_name, points=points, wait=False) + points = [ + PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas 
+ ] + self._client.upsert( + collection_name=self._collection_name, points=points, wait=False + ) def search(self, data: np.ndarray, top_k: int = -1): if top_k == -1: top_k = self.top_k reshaped_data = data.reshape(-1).tolist() - search_result = self._client.search(collection_name=self._collection_name, query_vector=reshaped_data, - limit=top_k) + search_result = self._client.search( + collection_name=self._collection_name, + query_vector=reshaped_data, + limit=top_k, + ) return list(map(lambda x: (x.score, x.id), search_result)) def delete(self, ids: List[str]): self._client.delete(collection_name=self._collection_name, points_selector=ids) def rebuild(self, ids=None): # pylint: disable=unused-argument - optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000) - self._client.update_collection(collection_name=self._collection_name, optimizer_config=optimizers_config) + optimizers_config = OptimizersConfigDiff( + deleted_threshold=0.2, vacuum_min_vector_number=1000 + ) + self._client.update_collection( + collection_name=self._collection_name, optimizer_config=optimizers_config + ) def flush(self): # no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant pass - def close(self): self.flush() diff --git a/tests/integration_tests/test_sqlite_milvus_sbert.py b/tests/integration_tests/test_sqlite_milvus_sbert.py index 48228d26..f74945d0 100644 --- a/tests/integration_tests/test_sqlite_milvus_sbert.py +++ b/tests/integration_tests/test_sqlite_milvus_sbert.py @@ -1,15 +1,16 @@ import os import shutil -import pytest from tempfile import TemporaryDirectory + +import pytest + from base.client_base import Base from common import common_func as cf from gptcache import cache, Config from gptcache.adapter import openai -from gptcache.embedding import Onnx, SBERT +from gptcache.embedding import SBERT from gptcache.manager import get_data_manager, VectorBase from 
gptcache.similarity_evaluation.distance import SearchDistanceEvaluation -from utils.util_log import test_log as log def get_text_response(response): @@ -37,104 +38,112 @@ def test_cache_health_check(self): expected: cache health detection & correction """ with TemporaryDirectory(dir="./") as root: - if os.path.isfile("./sqlite.db"): - os.remove("./sqlite.db") - if os.path.isdir('./milvus_data'): - shutil.rmtree('./milvus_data') onnx = SBERT() - vector_base = VectorBase( - "milvus", - dimension=onnx.dimension, - local_mode=True, - port="10086", - local_data=str(root), - ) - data_manager = get_data_manager("sqlite", vector_base, max_size=2000) - cache.init( - embedding_func=onnx.to_embeddings, - data_manager=data_manager, - similarity_evaluation=SearchDistanceEvaluation(), - config=Config( - log_time_func=cf.log_time_func, - enable_token_counter=False, + + vector_bases = [ + VectorBase( + "milvus", + dimension=onnx.dimension, + local_mode=True, + port="10086", + local_data=str(root), ), - ) + VectorBase("chromadb"), + ] - question = [ - "what is apple?", - "what is intel?", - "what is openai?", + for vector_base in vector_bases: + if os.path.isfile("./sqlite.db"): + os.remove("./sqlite.db") + if os.path.isdir('./milvus_data'): + shutil.rmtree('./milvus_data') - ] - answer = [ - "apple", - "intel", - "openai" - ] - for q, a in zip(question, answer): - cache.data_manager.save(q, a, cache.embedding_func(q)) - - # let's simulate cache out-of-sync - # situation. - touble_query = "what is google?" 
- cache.data_manager.v.update_embeddings(1, cache.embedding_func(touble_query)) - - # without cache health check - # respons is incorrect - response = openai.ChatCompletion.create( - model="gpt-3.5-turbo", - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": touble_query}, - ], - search_only=True, - stream=True, - ) - # Incorrect response "apple" returned to user - resp_txt = get_text_response(response) - # log.info(f"Inccorect response = {resp_txt} is returned") - assert answer[0] == resp_txt - - # cache health enabled - # stop returning incorrect answer - # and self-heal the trouble cache - # entry. - response = openai.ChatCompletion.create( - model="gpt-3.5-turbo", - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": touble_query}, - ], - search_only=True, - health_check=True, - stream=True, - ) - assert response is None - - # disable cache check, and verify - # cache is now consistent - response = openai.ChatCompletion.create( - model="gpt-3.5-turbo", - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": touble_query}, - ], - search_only=True, - stream=True, - ) - assert response is None - - # verify self-heal took place - response = openai.ChatCompletion.create( - model="gpt-3.5-turbo", - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": question[0]}, - ], - search_only=True, - stream=True, - ) - assert get_text_response(response) == answer[0] - if os.path.isfile("./sqlite.db"): - os.remove("./sqlite.db") + data_manager = get_data_manager("sqlite", vector_base, max_size=2000) + cache.init( + embedding_func=onnx.to_embeddings, + data_manager=data_manager, + similarity_evaluation=SearchDistanceEvaluation(), + config=Config( + log_time_func=cf.log_time_func, + enable_token_counter=False, + ), + ) + + question = [ + "what is apple?", + "what is 
intel?", + "what is openai?", + + ] + answer = [ + "apple", + "intel", + "openai" + ] + for q, a in zip(question, answer): + cache.data_manager.save(q, a, cache.embedding_func(q)) + + # let's simulate cache out-of-sync + # situation. + touble_query = "what is google?" + cache.data_manager.v.update_embeddings(1, cache.embedding_func(touble_query)) + + # without cache health check + # respons is incorrect + response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": touble_query}, + ], + search_only=True, + stream=True, + ) + # Incorrect response "apple" returned to user + resp_txt = get_text_response(response) + # log.info(f"Inccorect response = {resp_txt} is returned") + assert answer[0] == resp_txt + + # cache health enabled + # stop returning incorrect answer + # and self-heal the trouble cache + # entry. + cache.config.data_check = True + response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": touble_query}, + ], + search_only=True, + stream=True, + ) + assert response is None + + # disable cache check, and verify + # cache is now consistent + cache.config.data_check = False + response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": touble_query}, + ], + search_only=True, + stream=True, + ) + assert response is None + + # verify self-heal took place + response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question[0]}, + ], + search_only=True, + stream=True, + ) + assert get_text_response(response) == answer[0] + if os.path.isfile("./sqlite.db"): + os.remove("./sqlite.db") diff --git 
a/tests/unit_tests/embedding/test_sbert.py b/tests/unit_tests/embedding/test_sbert.py index d543d234..2c94628f 100644 --- a/tests/unit_tests/embedding/test_sbert.py +++ b/tests/unit_tests/embedding/test_sbert.py @@ -11,4 +11,14 @@ def test_sbert(): t = _get_model(model_src="sbert", model_config={"model": "all-MiniLM-L6-v2"}) dimension = t.dimension data = t.to_embeddings("foo") - assert len(data) == dimension, f"{len(data)}, {t.dimension}" \ No newline at end of file + assert len(data) == dimension, f"{len(data)}, {t.dimension}" + + question = [ + "what is apple?", + "what is intel?", + "what is openai?", + ] + answer = ["apple", "intel", "openai"] + for q, _ in zip(question, answer): + data = t.to_embeddings(q) + assert len(data) == dimension, f"{len(data)}, {t.dimension}" diff --git a/tests/unit_tests/manager/test_chromadb.py b/tests/unit_tests/manager/test_chromadb.py index 848c4817..823a262c 100644 --- a/tests/unit_tests/manager/test_chromadb.py +++ b/tests/unit_tests/manager/test_chromadb.py @@ -10,6 +10,7 @@ class TestChromadb(unittest.TestCase): def test_normal(self): db = VectorBase("chromadb", client_settings={}, top_k=3) db.mul_add([VectorData(id=i, data=np.random.sample(10)) for i in range(100)]) - self.assertEqual(len(db.search(np.random.sample(10))), 3) + search_res = db.search(np.random.sample(10)) + self.assertEqual(len(search_res), 3) db.delete(["1", "3", "5", "7"]) self.assertEqual(db._collection.count(), 96) From 75b3c69633c86b06d4ac6a91c62b43494e0d7f97 Mon Sep 17 00:00:00 2001 From: Anurag Wagh Date: Thu, 29 Jun 2023 14:31:36 +0530 Subject: [PATCH 03/10] [add] support for mongodb storage (#454) * [add] support for mongodb store * [add] unit test for mongodb store * [mod] using [] instead of list, oid instead of _id, change unit tests Signed-off-by: Anurag * [add] Docker image for mongo, optimized imports Signed-off-by: Anurag * [refactor] disable warnings for `wrong-import-position` Signed-off-by: Anurag --------- Signed-off-by: Anurag 
Co-authored-by: SimFG Signed-off-by: sunilkumardash9 --- .github/workflows/unit_test_main.yaml | 4 + gptcache/manager/scalar_data/manager.py | 10 +- gptcache/manager/scalar_data/mongo.py | 268 ++++++++++++++++++++++++ gptcache/utils/__init__.py | 9 +- tests/unit_tests/manager/test_mongo.py | 153 ++++++++++++++ 5 files changed, 442 insertions(+), 2 deletions(-) create mode 100644 gptcache/manager/scalar_data/mongo.py create mode 100644 tests/unit_tests/manager/test_mongo.py diff --git a/.github/workflows/unit_test_main.yaml b/.github/workflows/unit_test_main.yaml index 8590662e..de6b3d1a 100644 --- a/.github/workflows/unit_test_main.yaml +++ b/.github/workflows/unit_test_main.yaml @@ -40,6 +40,10 @@ jobs: image: redis/redis-stack-server ports: - 6379:6379 + mongo: + image: mongo + ports: + - 27017:27017 steps: - uses: actions/checkout@main diff --git a/gptcache/manager/scalar_data/manager.py b/gptcache/manager/scalar_data/manager.py index 5ed4b81a..e824c0e3 100644 --- a/gptcache/manager/scalar_data/manager.py +++ b/gptcache/manager/scalar_data/manager.py @@ -1,7 +1,7 @@ +from gptcache.manager.scalar_data.mongo import MongoStorage from gptcache.utils import import_sql_client from gptcache.utils.error import NotFoundError - SQL_URL = { "sqlite": "sqlite:///./sqlite.db", "duckdb": "duckdb:///./duck.db", @@ -84,6 +84,14 @@ def get(name, **kwargs): table_name=table_name, table_len_config=table_len_config, ) + elif name == "mongo": + return MongoStorage( + host=kwargs.get("mongo_host", "localhost"), + port=kwargs.get("mongo_port", 27017), + dbname=kwargs.get("dbname", TABLE_NAME), + username=kwargs.get("username"), + password=kwargs.get("password") + ) else: raise NotFoundError("cache store", name) return cache_base diff --git a/gptcache/manager/scalar_data/mongo.py b/gptcache/manager/scalar_data/mongo.py new file mode 100644 index 00000000..bf63784b --- /dev/null +++ b/gptcache/manager/scalar_data/mongo.py @@ -0,0 +1,268 @@ +from datetime import datetime +from typing 
import List, Optional + +import numpy as np + +from gptcache.manager.scalar_data.base import CacheStorage, CacheData, Question, QuestionDep +from gptcache.utils import import_mongodb + +import_mongodb() + +from mongoengine import Document # pylint: disable=wrong-import-position +from mongoengine import fields # pylint: disable=wrong-import-position +import mongoengine as me # pylint: disable=wrong-import-position + + +def get_models(): + class Questions(Document): + """ + questions collection + """ + meta = { + "collection": "questions", + "indexes": [ + "deleted" + ] + } + _id = fields.SequenceField() + question = fields.StringField() + create_on = fields.DateTimeField(default=datetime.now()) + last_access = fields.DateTimeField(default=datetime.now()) + embedding_data = fields.BinaryField() + deleted = fields.IntField(default=0) + + @property + def oid(self): + return self._id + + class Answers(Document): + """ + answer collection + """ + _id = fields.SequenceField() + meta = { + "collection": "answers", + "indexes": [ + "question_id" + ] + } + answer = fields.StringField() + answer_type = fields.IntField() + question_id = fields.IntField() + + @property + def oid(self): + return self._id + + class Sessions(Document): + """ + session collection + """ + meta = { + "collection": "sessions", + "indexes": [ + "question_id" + ] + + } + _id = fields.SequenceField() + session_id = fields.StringField() + session_question = fields.StringField() + question_id = fields.IntField() + + @property + def oid(self): + return self._id + + class QuestionDeps(Document): + """ + Question Dep collection + """ + meta = { + "collection": "question_deps", + "indexes": [ + "question_id" + ] + } + _id = fields.SequenceField() + question_id = fields.IntField() + dep_name = fields.StringField() + dep_data = fields.StringField() + dep_type = fields.IntField() + + @property + def oid(self): + return self._id + + return Questions, Answers, QuestionDeps, Sessions + + +class 
MongoStorage(CacheStorage): + """ + Using mongoengine as ORM to manage mongodb documents. + By default, data is stored in the 'gptcache' database and the following collections are created to store the data + 1. 'sessions' + 2. 'answers' + 3. 'questions' + 4. 'question_deps' + + :param host: mongodb host, default value 'localhost' + :type host: str + + :param port: mongodb port, default value 27017 + :type port: int + + :param dbname: Mongo database name, default value 'gptcache' + :type dbname: str + + :param kwargs: additional keyword arguments forwarded to mongoengine.connect + :type kwargs: dict + + :param username: username for authentication, default value None + :type username: str + + :param password: password for authentication, default value None + :type password: str + """ + + def __init__( + self, + host: str = "localhost", + port: int = 27017, + dbname: str = "gptcache", + username: str = None, + password: str = None, + **kwargs): + self.con = me.connect(host=host, + port=port, + db=dbname, + username=username, + password=password, + **kwargs) + self._ques, self._answer, self._ques_dep, self._session = get_models() + + def create(self): + pass + + def _insert(self, data: CacheData): + ques_data = self._ques( + question=data.question + if isinstance(data.question, str) + else data.question.content, + embedding_data=data.embedding_data.tobytes() + if data.embedding_data is not None + else None + ) + ques_data.save() + if isinstance(data.question, Question) and data.question.deps is not None: + all_deps = [] + for dep in data.question.deps: + all_deps.append( + self._ques_dep( + question_id=ques_data.oid, + dep_name=dep.name, + dep_data=dep.data, + dep_type=dep.dep_type, + ) + ) + self._ques_dep.objects.insert(all_deps) + + answers = data.answers if isinstance(data.answers, list) else [data.answers] + all_data = [] + for answer in answers: + answer_data = self._answer( + question_id=ques_data.oid, + answer=answer.answer, + answer_type=int(answer.answer_type), + ) + all_data.append(answer_data) +
self._answer.objects.insert(all_data) + + if data.session_id: + session_data = self._session( + question_id=ques_data.oid, + session_id=data.session_id, + session_question=data.question + if isinstance(data.question, str) + else data.question.content, + ) + self._session.objects.insert(session_data) + + return ques_data.oid + + def batch_insert(self, all_data: List[CacheData]): + ids = [] + for data in all_data: + ids.append(self._insert(data)) + return ids + + def get_data_by_id(self, key) -> Optional[CacheData]: + qs = self._ques.objects.get(_id=key, deleted=0) + if qs is None: + return None + last_access = qs.last_access + qs.last_access = datetime.now() + qs.save() + answers = self._answer.objects(question_id=qs.oid) + deps = self._ques_dep.objects(question_id=qs.oid) + session_ids = self._session.objects(question_id=qs.oid) + + res_ans = [(item.answer, item.answer_type) for item in answers] + res_deps = [ + QuestionDep(item.dep_name, item.dep_data, item.dep_type) + for item in deps + ] + return CacheData( + question=qs.question if not deps else Question(qs.question, res_deps), + answers=res_ans, + embedding_data=np.frombuffer(qs.embedding_data, dtype=np.float32), + session_id=session_ids, + create_on=qs.create_on, + last_access=last_access, + ) + + def mark_deleted(self, keys): + self._ques.objects(_id__in=keys).update(deleted=-1) + + def clear_deleted_data(self): + questions = self._ques.objects(deleted=-1).only("_id") + q_ids = [obj.oid for obj in questions] + self._answer.objects(question_id__in=q_ids).delete() + self._ques_dep.objects(question_id__in=q_ids).delete() + self._session.objects(question_id__in=q_ids).delete() + questions.delete() + + def get_ids(self, deleted: bool = True): + state = -1 if deleted else 0 + res = [obj.oid for obj in self._ques.objects(deleted=state).only("_id")] + return res + + def count(self, state: int = 0, is_all: bool = False): + if is_all: + return self._ques.objects.count() + return 
self._ques.objects(deleted=state).count() + + def add_session(self, question_id, session_id, session_question): + self._session(question_id=question_id, + session_id=session_id, + session_question=session_question + ).save() + + def list_sessions(self, session_id=None, key=None): + query = {} + if session_id: + query["session_id"] = session_id + if key: + query["_id"] = key + + return self._session.objects(__raw__=query) + + def delete_session(self, keys): + self._session.objects(question_id__in=keys).delete() + + def count_answers(self): + return self._answer.objects.count() + + def close(self): + me.disconnect() + self.con.close() diff --git a/gptcache/utils/__init__.py b/gptcache/utils/__init__.py index 7a919721..258449ad 100644 --- a/gptcache/utils/__init__.py +++ b/gptcache/utils/__init__.py @@ -15,6 +15,7 @@ "import_chromadb", "import_sqlalchemy", "import_sql_client", + "import_mongodb", "import_pydantic", "import_langchain", "import_pillow", @@ -39,7 +40,7 @@ "import_tiktoken", "import_fastapi", "import_redis", - "import_qdrant" + "import_qdrant", ] import importlib.util @@ -159,6 +160,11 @@ def import_sql_client(db_name): import_duckdb() +def import_mongodb(): + _check_library("pymongo") + _check_library("mongoengine") + + def import_pydantic(): _check_library("pydantic") @@ -248,5 +254,6 @@ def import_fastapi(): _check_library("uvicorn", package="'uvicorn[standard]'") _check_library("fastapi") + def import_redis(): _check_library("redis") diff --git a/tests/unit_tests/manager/test_mongo.py b/tests/unit_tests/manager/test_mongo.py new file mode 100644 index 00000000..71ce639e --- /dev/null +++ b/tests/unit_tests/manager/test_mongo.py @@ -0,0 +1,153 @@ +import time + +import numpy as np +from mongoengine import connect, disconnect + +from gptcache.manager.scalar_data.base import CacheData, Question +from gptcache.manager.scalar_data.mongo import MongoStorage + + +def test_mongo(): + test_dbname = "gptcache_test" + _clear_test_db(test_dbname) + 
_inner_test_normal(test_dbname) + + _clear_test_db(test_dbname) + _inner_test_with_deps(test_dbname) + + _clear_test_db(test_dbname) + _test_create_on(dbname=test_dbname) + + _clear_test_db(test_dbname) + _test_session(dbname=test_dbname) + + _clear_test_db(test_dbname) + + +def _clear_test_db(dbname): + con = connect(db=dbname) + con.drop_database(dbname) + disconnect() + + +def _inner_test_normal(dbname: str): + mongo_storage = MongoStorage(dbname=dbname) + data = [] + for i in range(1, 10): + data.append( + CacheData( + "question_" + str(i), + ["answer_" + str(i)] * i, + np.random.rand(5) + ) + ) + mongo_storage.batch_insert(data) + + for i in range(1, 10): + data = mongo_storage.get_data_by_id(i) + assert data.question == f"question_{i}" + assert data.answers[0].answer == f"answer_{i}" + + q_id = mongo_storage.batch_insert( + [CacheData("question_single", "answer_single", np.random.rand(5))] + )[0] + data = mongo_storage.get_data_by_id(q_id) + assert data.question == "question_single" + assert data.answers[0].answer == "answer_single" + + assert len(mongo_storage.get_ids(True)) == 0 + mongo_storage.mark_deleted([1, 2, 3]) + assert mongo_storage.get_ids(True), [1, 2 == 3] + assert mongo_storage.count(is_all=True) == 10 + assert mongo_storage.count() == 7 + assert mongo_storage.count_answers() == 46 + mongo_storage.clear_deleted_data() + assert mongo_storage.count_answers() == 40 + assert mongo_storage.count(is_all=True) == 7 + + +def _inner_test_with_deps(dbname: str): + mongo_storage = MongoStorage(dbname=dbname) + data_id = mongo_storage.batch_insert( + [ + CacheData( + Question.from_dict( + { + "content": "test_question", + "deps": [ + { + "name": "text", + "data": "how many people in this picture", + "dep_type": 0, + }, + { + "name": "image", + "data": "object_name", + "dep_type": 1, + }, + ], + } + ), + "test_answer", + np.random.rand(5), + ) + ] + )[0] + + ret = mongo_storage.get_data_by_id(data_id) + assert ret.question.content == "test_question" + assert 
ret.question.deps[0].name == "text" + assert ret.question.deps[0].data == "how many people in this picture" + assert ret.question.deps[0].dep_type == 0 + assert ret.question.deps[1].name == "image" + assert ret.question.deps[1].data == "object_name" + assert ret.question.deps[1].dep_type == 1 + + +def _test_create_on(dbname): + mongo_storage = MongoStorage(dbname=dbname) + mongo_storage.create() + data = [] + for i in range(1, 10): + data.append( + CacheData( + "question_" + str(i), + ["answer_" + str(i)] * i, + np.random.rand(5), + ) + ) + mongo_storage.batch_insert(data) + data = mongo_storage.get_data_by_id(1) + create_on1 = data.create_on + last_access1 = data.last_access + + time.sleep(1) + + data = mongo_storage.get_data_by_id(1) + create_on2 = data.create_on + last_access2 = data.last_access + + assert create_on1 == create_on2 + assert last_access1 < last_access2 + + +def _test_session(dbname): + mongo_storage = MongoStorage(dbname=dbname) + data = [] + for i in range(1, 11): + data.append( + CacheData( + "question_" + str(i), + ["answer_" + str(i)] * i, + np.random.rand(5), + session_id=str(1 if i <= 5 else 0) + ) + ) + mongo_storage.batch_insert(data) + assert len(mongo_storage.list_sessions()) == 10 + assert len(mongo_storage.list_sessions(session_id="0")) == 5 + assert len(mongo_storage.list_sessions(session_id="1")) == 5 + assert len(mongo_storage.list_sessions(session_id="1", key=1)) == 1 + + mongo_storage.delete_session([1, 2, 3]) + assert len(mongo_storage.list_sessions()) == 7 From 5211378373e7447ae7e39d7e40be62396e0a782f Mon Sep 17 00:00:00 2001 From: SimFG Date: Thu, 29 Jun 2023 21:24:45 +0800 Subject: [PATCH 04/10] Fix the wrong return value of onnx similarity evaluation (#460) Signed-off-by: SimFG Signed-off-by: sunilkumardash9 --- gptcache/adapter/adapter.py | 4 ++-- gptcache/manager/scalar_data/manager.py | 3 ++- gptcache/processor/post.py | 1 + gptcache/similarity_evaluation/distance.py | 1 + gptcache/similarity_evaluation/onnx.py | 6 ++++-- 
gptcache/utils/softmax.py | 2 +- tests/requirements.txt | 2 ++ .../similarity_evaluation/test_evaluation_onnx.py | 1 + 8 files changed, 14 insertions(+), 6 deletions(-) diff --git a/gptcache/adapter/adapter.py b/gptcache/adapter/adapter.py index acf16d29..15df6ffa 100644 --- a/gptcache/adapter/adapter.py +++ b/gptcache/adapter/adapter.py @@ -166,7 +166,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg rank, ) if rank_threshold <= rank: - cache_answers.append((rank, cache_data.answers[0].answer, search_data)) + cache_answers.append((float(rank), cache_data.answers[0].answer, search_data)) chat_cache.data_manager.hit_cache_callback(search_data) cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True) answers_dict = dict((d[1], d[2]) for d in cache_answers) @@ -397,7 +397,7 @@ async def aadapt(llm_handler, cache_data_convert, update_cache_callback, *args, rank, ) if rank_threshold <= rank: - cache_answers.append((rank, cache_data.answers[0].answer, search_data)) + cache_answers.append((float(rank), cache_data.answers[0].answer, search_data)) chat_cache.data_manager.hit_cache_callback(search_data) cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True) answers_dict = dict((d[1], d[2]) for d in cache_answers) diff --git a/gptcache/manager/scalar_data/manager.py b/gptcache/manager/scalar_data/manager.py index e824c0e3..66697606 100644 --- a/gptcache/manager/scalar_data/manager.py +++ b/gptcache/manager/scalar_data/manager.py @@ -1,4 +1,3 @@ -from gptcache.manager.scalar_data.mongo import MongoStorage from gptcache.utils import import_sql_client from gptcache.utils.error import NotFoundError @@ -85,6 +84,8 @@ def get(name, **kwargs): table_len_config=table_len_config, ) elif name == "mongo": + from gptcache.manager.scalar_data.mongo import MongoStorage + return MongoStorage( host=kwargs.get("mongo_host", "localhost"), port=kwargs.get("mongo_port", 27017), diff --git a/gptcache/processor/post.py 
b/gptcache/processor/post.py index c8f812f4..9a1c3a6e 100644 --- a/gptcache/processor/post.py +++ b/gptcache/processor/post.py @@ -1,5 +1,6 @@ import random from typing import List, Any + import numpy from gptcache.utils import softmax diff --git a/gptcache/similarity_evaluation/distance.py b/gptcache/similarity_evaluation/distance.py index 8d1fa17b..8e3a6675 100644 --- a/gptcache/similarity_evaluation/distance.py +++ b/gptcache/similarity_evaluation/distance.py @@ -1,4 +1,5 @@ from typing import Tuple, Dict, Any + from gptcache.similarity_evaluation import SimilarityEvaluation diff --git a/gptcache/similarity_evaluation/onnx.py b/gptcache/similarity_evaluation/onnx.py index aa259a88..236e9955 100644 --- a/gptcache/similarity_evaluation/onnx.py +++ b/gptcache/similarity_evaluation/onnx.py @@ -1,11 +1,13 @@ from typing import Dict, List, Tuple, Any + import numpy as np + +from gptcache.similarity_evaluation import SimilarityEvaluation from gptcache.utils import ( import_onnxruntime, import_huggingface_hub, import_huggingface, ) -from gptcache.similarity_evaluation import SimilarityEvaluation import_onnxruntime() import_huggingface_hub() @@ -130,4 +132,4 @@ def inference(self, reference: str, candidates: List[str]) -> np.ndarray: } ort_outputs = self.ort_session.run(None, ort_inputs) scores = ort_outputs[0][:, 1] - return scores + return float(scores[0]) diff --git a/gptcache/utils/softmax.py b/gptcache/utils/softmax.py index 71efbccd..6cd98252 100644 --- a/gptcache/utils/softmax.py +++ b/gptcache/utils/softmax.py @@ -3,7 +3,7 @@ def softmax(x: list): x = np.array(x) - assert len(x.shape) == 1, f"Expect to get a shape of (len,) but got {x.shape}." + assert len(x.shape) == 1, f"Expect to get a shape of (len,) but got {x.shape}, x value: {x}." 
max_val = x.max() e_x = np.exp(x - max_val) return e_x / e_x.sum() diff --git a/tests/requirements.txt b/tests/requirements.txt index 605b90f2..2d04d876 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -26,3 +26,5 @@ grpcio==1.53.0 protobuf==3.20.0 milvus==2.2.8 pymilvus==2.2.8 +pymongo +mongoengine diff --git a/tests/unit_tests/similarity_evaluation/test_evaluation_onnx.py b/tests/unit_tests/similarity_evaluation/test_evaluation_onnx.py index e5397e75..c2adb8aa 100644 --- a/tests/unit_tests/similarity_evaluation/test_evaluation_onnx.py +++ b/tests/unit_tests/similarity_evaluation/test_evaluation_onnx.py @@ -17,6 +17,7 @@ def _test_evaluation(evaluation): candidate_2 = "how old are you?" score = evaluation.evaluation({"question": query}, {"question": candidate_1}) + assert isinstance(score, float), type(score) assert score > 0.8 score = evaluation.evaluation({"question": query}, {"question": candidate_2}) From d4154d43aa8dc5a697f381a5d00843a94ce6197c Mon Sep 17 00:00:00 2001 From: SimFG Date: Fri, 30 Jun 2023 20:07:25 +0800 Subject: [PATCH 05/10] Update the version to `0.1.34` Signed-off-by: SimFG Signed-off-by: sunilkumardash9 --- docs/release_note.md | 6 ++++++ gptcache/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/release_note.md b/docs/release_note.md index 6f9bdb5b..1406f5a6 100644 --- a/docs/release_note.md +++ b/docs/release_note.md @@ -5,6 +5,12 @@ To read the following content, you need to understand the basic use of GPTCache, - [Readme doc](https://github.com/zilliztech/GPTCache) - [Usage doc](https://github.com/zilliztech/GPTCache/blob/main/docs/usage.md) +## v0.1.34 (2023.6.30) + +1. Add support for the Qdrant vector store +2. Add support for the MongoDB cache store +3. Fix bugs in the Redis vector store and the ONNX similarity evaluation + ## v0.1.33 (2023.6.27) 1.
Fix the eviction error diff --git a/gptcache/__init__.py b/gptcache/__init__.py index 21842da0..428d6e66 100644 --- a/gptcache/__init__.py +++ b/gptcache/__init__.py @@ -1,5 +1,5 @@ """gptcache version""" -__version__ = "0.1.33" +__version__ = "0.1.34" from gptcache.config import Config from gptcache.core import Cache From 90d49916a2e778c713828bb717443597f4160ae8 Mon Sep 17 00:00:00 2001 From: sunilkumardash9 Date: Thu, 6 Jul 2023 23:54:50 +0530 Subject: [PATCH 06/10] added support for weaviate vector-store Signed-off-by: sunilkumardash9 --- gptcache/manager/vector_data/manager.py | 26 +++++ gptcache/manager/vector_data/weaviate.py | 126 ++++++++++++++++++++++ gptcache/utils/__init__.py | 5 + tests/unit_tests/manager/test_weaviate.py | 30 ++++++ 4 files changed, 187 insertions(+) create mode 100644 gptcache/manager/vector_data/weaviate.py create mode 100644 tests/unit_tests/manager/test_weaviate.py diff --git a/gptcache/manager/vector_data/manager.py b/gptcache/manager/vector_data/manager.py index 86616b5b..6aa6b4fe 100644 --- a/gptcache/manager/vector_data/manager.py +++ b/gptcache/manager/vector_data/manager.py @@ -257,6 +257,32 @@ def get(name, **kwargs): flush_interval_sec=flush_interval_sec, index_params=index_params, ) + elif name == "weaviate": + from .. 
vector_data.weaviate import Weaviate + url = kwargs.get("url", None) + auth_client_secret = kwargs.get('auth_client_secret', None), + timeout_config = kwargs.get("timeout_config", (10, 60)) + proxies = kwargs.get("proxies", None) + trust_env = kwargs.get("trust_env", False) + additional_headers = kwargs.get("additional_headers", None) + startup_period = kwargs.get("startup_period", 5) + embedded_options = kwargs.get("embedded_options", None) + additional_config = kwargs.get("additional_config", None) + class_name = kwargs.get("class_name", "Gptcache") + top_k = kwargs.get("top_k", 1) + vector_base = Weaviate( + url= url, + auth_client_secret = auth_client_secret, + timeout_config = timeout_config, + proxies = proxies, + trust_env = trust_env, + additional_headers = additional_headers, + startup_period = startup_period, + embedded_options = embedded_options, + additional_config = additional_config, + class_name = class_name, + top_k = top_k, + ) else: raise NotFoundError("vector store", name) return vector_base diff --git a/gptcache/manager/vector_data/weaviate.py b/gptcache/manager/vector_data/weaviate.py new file mode 100644 index 00000000..64691371 --- /dev/null +++ b/gptcache/manager/vector_data/weaviate.py @@ -0,0 +1,126 @@ +from typing import List + +import numpy as np + +from gptcache.manager.vector_data.base import VectorBase, VectorData +from ... 
utils import import_weaviate +from gptcache.utils.log import gptcache_log + +from weaviate import Client, EmbeddedOptions, Config + +import_weaviate() + +class Weaviate(VectorBase): + """Weaviate Vector store""" + def __init__(self, + url: str | None = None, + auth_client_secret: None = None, + timeout_config = (10, 60), + proxies: dict | str | None = None, + trust_env: bool = False, + additional_headers: dict | None = None, + startup_period: int | None = 5, + embedded_options: None = None, + additional_config: None = None, + top_k: int = 1, + distance: str = "cosine", + collection_name: str = "Gptcache", + ): + self.class_name = collection_name + self.top_k = top_k + self.distance = distance + if embedded_options: + self.client = Client(embedded_options = EmbeddedOptions(), + startup_period = startup_period, + timeout_config = timeout_config, + additional_config=additional_config) + else: + self.client = Client(url, + auth_client_secret, + timeout_config, + proxies, + trust_env, + additional_headers, + startup_period, + embedded_options, + additional_config, + ) + + def _create_collection(self, class_name: str): + if not class_name: + class_name = self.class_name + if self.client.schema.exists(class_name): + gptcache_log.info( + "The %s already exists, and it will be used directly", class_name + ) + else: + gptcache_class_schema = { + "class": class_name, + "description": "caching LLM responses", + "properties": [ + { + "name": "id_", + "dataType": ["int"], + } + ], + 'vectorIndexConfig': + { + "distance": self.distance + } + } + self.client.schema.create_class(gptcache_class_schema) + + def mul_add(self, datas: List[VectorData]): + with self.client.batch( + batch_size=len(datas) + ) as batch: + # Batch import + for data in datas: + properties = { + "id_": data.id, + } + self.client.batch.add_data_object( + properties, + self.class_name, + vector = data.data.tolist() + ) + + def search(self, data: np.ndarray, top_k: int = -1): + if not 
self.client.schema.exists(self.class_name): + self._create_collection(self.class_name) + if top_k==-1: + top_k = self.top_k + result = self.client.query.get(class_name = self.class_name, properties = ['id_']).\ + with_near_vector(content={"vector": data.tolist()}).\ + with_additional(['distance']).\ + with_limit(top_k).do() + return list(map(lambda x: (x['_additional']['distance'], x['id_']), result['data']['Get'][self.class_name])) + + def get_uuids(self, ids: List[str]): + uuid_list = [] + for id_ in ids: + res = self.client.query.get(class_name=self.class_name, properties=['id_']).\ + with_where({"path": ["id_"], "operator":"Equal", "valueNumber":id_}).\ + with_additional(["id"]).do() + uuid_list.append(res['data']['Get'][self.class_name][0]['_additional']['id']) + return uuid_list + + def delete(self, ids: List[str]): + uuids = self.get_uuids(ids) + for uuid_ in uuids: + self.client.data_object.delete(class_name='example', uuid=uuid_) + + def rebuild(self, ids=None) : + return + + def flush(self): + return True + + def close(self): + pass + + + + + + diff --git a/gptcache/utils/__init__.py b/gptcache/utils/__init__.py index 258449ad..9c0aec6c 100644 --- a/gptcache/utils/__init__.py +++ b/gptcache/utils/__init__.py @@ -41,6 +41,7 @@ "import_fastapi", "import_redis", "import_qdrant", + "import_weaviate" ] import importlib.util @@ -257,3 +258,7 @@ def import_fastapi(): def import_redis(): _check_library("redis") + + +def import_weaviate(): + _check_library("weaviate-client") \ No newline at end of file diff --git a/tests/unit_tests/manager/test_weaviate.py b/tests/unit_tests/manager/test_weaviate.py new file mode 100644 index 00000000..ee77070f --- /dev/null +++ b/tests/unit_tests/manager/test_weaviate.py @@ -0,0 +1,30 @@ +import unittest + +import numpy as np + +from gptcache.manager.vector_data import VectorBase +from gptcache.manager.vector_data.base import VectorData + + +class TestUSearchDB(unittest.TestCase): + def test_normal(self): + size = 1000 + dim = 
512 + top_k = 10 + weaviate = VectorBase( + "weaviate", + top_k = top_k + ) + data = np.random.randn(size, dim).astype(np.float32) + weaviate.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))]) + search_result = weaviate.search(data[0], top_k) + self.assertEqual(len(search_result), top_k) + weaviate.mul_add([VectorData(id=size, data=data[0])]) + ret = weaviate.search(data[0]) + self.assertIn(ret[0][1], [0, size]) + self.assertIn(ret[1][1], [0, size]) + weaviate.delete([0, 1, 2, 3, 4, 5, size]) + ret = weaviate.search(data[0]) + self.assertNotIn(ret[0][1], [0, size]) + weaviate.rebuild() + weaviate.close() \ No newline at end of file From b6a189938bea3a3856c8934212b08afc61e67954 Mon Sep 17 00:00:00 2001 From: sunilkumardash9 Date: Fri, 7 Jul 2023 00:09:53 +0530 Subject: [PATCH 07/10] removed relative imports of weaviate class Signed-off-by: sunilkumardash9 --- gptcache/manager/vector_data/manager.py | 2 +- gptcache/manager/vector_data/weaviate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gptcache/manager/vector_data/manager.py b/gptcache/manager/vector_data/manager.py index 6aa6b4fe..eb0e943a 100644 --- a/gptcache/manager/vector_data/manager.py +++ b/gptcache/manager/vector_data/manager.py @@ -258,7 +258,7 @@ def get(name, **kwargs): index_params=index_params, ) elif name == "weaviate": - from .. vector_data.weaviate import Weaviate + from gptcache.manager.vector_data.weaviate import Weaviate url = kwargs.get("url", None) auth_client_secret = kwargs.get('auth_client_secret', None), timeout_config = kwargs.get("timeout_config", (10, 60)) diff --git a/gptcache/manager/vector_data/weaviate.py b/gptcache/manager/vector_data/weaviate.py index 64691371..9226092b 100644 --- a/gptcache/manager/vector_data/weaviate.py +++ b/gptcache/manager/vector_data/weaviate.py @@ -3,7 +3,7 @@ import numpy as np from gptcache.manager.vector_data.base import VectorBase, VectorData -from ... 
utils import import_weaviate +from gptcache.utils import import_weaviate from gptcache.utils.log import gptcache_log from weaviate import Client, EmbeddedOptions, Config From d270e089d7243353fc393dfd63b439bdb3627f77 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Dash <47926185+sunilkumardash9@users.noreply.github.com> Date: Tue, 18 Jul 2023 21:43:16 +0530 Subject: [PATCH 08/10] Update weaviate.py few corrections 1. Placed import_weaviate() before importing its functions 2. changed parameter name from 'collection_name' to 'class_name' 3. Few cosmetic changes --- gptcache/manager/vector_data/weaviate.py | 40 +++++++++++++----------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/gptcache/manager/vector_data/weaviate.py b/gptcache/manager/vector_data/weaviate.py index 9226092b..c206884b 100644 --- a/gptcache/manager/vector_data/weaviate.py +++ b/gptcache/manager/vector_data/weaviate.py @@ -6,9 +6,10 @@ from gptcache.utils import import_weaviate from gptcache.utils.log import gptcache_log +import_weaviate() + from weaviate import Client, EmbeddedOptions, Config -import_weaviate() class Weaviate(VectorBase): """Weaviate Vector store""" @@ -24,27 +25,30 @@ def __init__(self, additional_config: None = None, top_k: int = 1, distance: str = "cosine", - collection_name: str = "Gptcache", + class_name: str = "Gptcache", ): - self.class_name = collection_name + self.class_name = class_name self.top_k = top_k self.distance = distance - if embedded_options: - self.client = Client(embedded_options = EmbeddedOptions(), - startup_period = startup_period, - timeout_config = timeout_config, - additional_config=additional_config) + if not url: + self.client = Client( + embedded_options = EmbeddedOptions(), + startup_period = startup_period, + timeout_config = timeout_config, + additional_config = additional_config + ) else: - self.client = Client(url, - auth_client_secret, - timeout_config, - proxies, - trust_env, - additional_headers, - startup_period, - embedded_options, 
- additional_config, - ) + self.client = Client( + url, + auth_client_secret, + timeout_config, + proxies, + trust_env, + additional_headers, + startup_period, + embedded_options, + additional_config, + ) def _create_collection(self, class_name: str): if not class_name: From abc764dcad7885471b7eed81adce5d35aec4eedd Mon Sep 17 00:00:00 2001 From: Sunil Kumar Dash <47926185+sunilkumardash9@users.noreply.github.com> Date: Wed, 19 Jul 2023 22:57:11 +0530 Subject: [PATCH 09/10] Update weaviate.py 1. Removed a hard-coded class_name in delete method 2. Better type check --- gptcache/manager/vector_data/weaviate.py | 33 ++++++++++++------------ 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/gptcache/manager/vector_data/weaviate.py b/gptcache/manager/vector_data/weaviate.py index c206884b..d321870b 100644 --- a/gptcache/manager/vector_data/weaviate.py +++ b/gptcache/manager/vector_data/weaviate.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional, Union import numpy as np @@ -13,20 +13,21 @@ class Weaviate(VectorBase): """Weaviate Vector store""" - def __init__(self, - url: str | None = None, - auth_client_secret: None = None, - timeout_config = (10, 60), - proxies: dict | str | None = None, - trust_env: bool = False, - additional_headers: dict | None = None, - startup_period: int | None = 5, - embedded_options: None = None, - additional_config: None = None, - top_k: int = 1, - distance: str = "cosine", - class_name: str = "Gptcache", - ): + def __init__( + self, + url: str = None, + auth_client_secret = None, + timeout_config = (10, 60), + proxies: Optional[Union[dict, str]] = None, + trust_env: bool = False, + additional_headers: Optional[dict] = None, + startup_period: Optional[int] = 5, + embedded_options = None, + additional_config = None, + top_k: int = 1, + distance: str = "cosine", + class_name: str = "Gptcache", + ): self.class_name = class_name self.top_k = top_k self.distance = distance @@ -112,7 +113,7 @@ def 
get_uuids(self, ids: List[str]): def delete(self, ids: List[str]): uuids = self.get_uuids(ids) for uuid_ in uuids: - self.client.data_object.delete(class_name='example', uuid=uuid_) + self.client.data_object.delete(class_name = self.class_name, uuid=uuid_) def rebuild(self, ids=None) : return From 3dc0c7394d8153c4669e13ca0b92229f8a204a4d Mon Sep 17 00:00:00 2001 From: Sunil Kumar Dash <47926185+sunilkumardash9@users.noreply.github.com> Date: Fri, 21 Jul 2023 23:36:10 +0530 Subject: [PATCH 10/10] Update __init__.py use old chromadb --- gptcache/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptcache/utils/__init__.py b/gptcache/utils/__init__.py index 33d55ee3..cfd34fb6 100644 --- a/gptcache/utils/__init__.py +++ b/gptcache/utils/__init__.py @@ -117,7 +117,7 @@ def import_hnswlib(): def import_chromadb(): - _check_library("chromadb") + _check_library("chromadb", package="chromadb==0.3.26") def import_sqlalchemy():