diff --git a/gptcache/manager/scalar_data/mongo.py b/gptcache/manager/scalar_data/mongo.py index 89f9fc96..557a4b4b 100644 --- a/gptcache/manager/scalar_data/mongo.py +++ b/gptcache/manager/scalar_data/mongo.py @@ -324,4 +324,4 @@ def report_cache( def close(self): me.disconnect() - self.con.close() + self.con.close() \ No newline at end of file diff --git a/gptcache/manager/vector_data/manager.py b/gptcache/manager/vector_data/manager.py index 86616b5b..eb0e943a 100644 --- a/gptcache/manager/vector_data/manager.py +++ b/gptcache/manager/vector_data/manager.py @@ -257,6 +257,32 @@ def get(name, **kwargs): flush_interval_sec=flush_interval_sec, index_params=index_params, ) + elif name == "weaviate": + from gptcache.manager.vector_data.weaviate import Weaviate + url = kwargs.get("url", None) + auth_client_secret = kwargs.get('auth_client_secret', None), + timeout_config = kwargs.get("timeout_config", (10, 60)) + proxies = kwargs.get("proxies", None) + trust_env = kwargs.get("trust_env", False) + additional_headers = kwargs.get("additional_headers", None) + startup_period = kwargs.get("startup_period", 5) + embedded_options = kwargs.get("embedded_options", None) + additional_config = kwargs.get("additional_config", None) + class_name = kwargs.get("class_name", "Gptcache") + top_k = kwargs.get("top_k", 1) + vector_base = Weaviate( + url= url, + auth_client_secret = auth_client_secret, + timeout_config = timeout_config, + proxies = proxies, + trust_env = trust_env, + additional_headers = additional_headers, + startup_period = startup_period, + embedded_options = embedded_options, + additional_config = additional_config, + class_name = class_name, + top_k = top_k, + ) else: raise NotFoundError("vector store", name) return vector_base diff --git a/gptcache/manager/vector_data/weaviate.py b/gptcache/manager/vector_data/weaviate.py new file mode 100644 index 00000000..d321870b --- /dev/null +++ b/gptcache/manager/vector_data/weaviate.py @@ -0,0 +1,131 @@ +from typing import List, Optional, Union + +import numpy as np + +from gptcache.manager.vector_data.base import VectorBase, VectorData +from gptcache.utils import import_weaviate +from gptcache.utils.log import gptcache_log + +import_weaviate() + +from weaviate import Client, EmbeddedOptions, Config + + +class Weaviate(VectorBase): + """Weaviate Vector store""" + def __init__( + self, + url: str = None, + auth_client_secret = None, + timeout_config = (10, 60), + proxies: Optional[Union[dict, str]] = None, + trust_env: bool = False, + additional_headers: Optional[dict] = None, + startup_period: Optional[int] = 5, + embedded_options = None, + additional_config = None, + top_k: int = 1, + distance: str = "cosine", + class_name: str = "Gptcache", + ): + self.class_name = class_name + self.top_k = top_k + self.distance = distance + if not url: + self.client = Client( + embedded_options = EmbeddedOptions(), + startup_period = startup_period, + timeout_config = timeout_config, + additional_config = additional_config + ) + else: + self.client = Client( + url, + auth_client_secret, + timeout_config, + proxies, + trust_env, + additional_headers, + startup_period, + embedded_options, + additional_config, + ) + + def _create_collection(self, class_name: str): + if not class_name: + class_name = self.class_name + if self.client.schema.exists(class_name): + gptcache_log.info( + "The %s already exists, and it will be used directly", class_name + ) + else: + gptcache_class_schema = { + "class": class_name, + "description": "caching LLM responses", + "properties": [ + { + "name": "id_", + "dataType": ["int"], + } + ], + 'vectorIndexConfig': + { + "distance": self.distance + } + } + self.client.schema.create_class(gptcache_class_schema) + + def mul_add(self, datas: List[VectorData]): + with self.client.batch( + batch_size=len(datas) + ) as batch: + # Batch import + for data in datas: + properties = { + "id_": data.id, + } + self.client.batch.add_data_object( + properties, + self.class_name, + vector = data.data.tolist() + ) + + def search(self, data: np.ndarray, top_k: int = -1): + if not self.client.schema.exists(self.class_name): + self._create_collection(self.class_name) + if top_k==-1: + top_k = self.top_k + result = self.client.query.get(class_name = self.class_name, properties = ['id_']).\ + with_near_vector(content={"vector": data.tolist()}).\ + with_additional(['distance']).\ + with_limit(top_k).do() + return list(map(lambda x: (x['_additional']['distance'], x['id_']), result['data']['Get'][self.class_name])) + + def get_uuids(self, ids: List[str]): + uuid_list = [] + for id_ in ids: + res = self.client.query.get(class_name=self.class_name, properties=['id_']).\ + with_where({"path": ["id_"], "operator":"Equal", "valueNumber":id_}).\ + with_additional(["id"]).do() + uuid_list.append(res['data']['Get'][self.class_name][0]['_additional']['id']) + return uuid_list + + def delete(self, ids: List[str]): + uuids = self.get_uuids(ids) + for uuid_ in uuids: + self.client.data_object.delete(class_name = self.class_name, uuid=uuid_) + + def rebuild(self, ids=None) : + return + + def flush(self): + return True + + def close(self): + pass + + + + + + diff --git a/gptcache/utils/__init__.py b/gptcache/utils/__init__.py index 3a09cf23..cfd34fb6 100644 --- a/gptcache/utils/__init__.py +++ b/gptcache/utils/__init__.py @@ -41,6 +41,7 @@ "import_fastapi", "import_redis", "import_qdrant", + "import_weaviate" ] import importlib.util @@ -116,7 +117,7 @@ def import_hnswlib(): def import_chromadb(): - _check_library("chromadb") + _check_library("chromadb", package="chromadb==0.3.26") def import_sqlalchemy(): @@ -260,5 +261,10 @@ def import_redis(): _check_library("redis_om") +def import_weaviate(): + _check_library("weaviate-client") + + def import_starlette(): _check_library("starlette") + diff --git a/tests/requirements.txt b/tests/requirements.txt index 605b90f2..2d04d876 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -26,3 +26,5 @@ grpcio==1.53.0 protobuf==3.20.0 milvus==2.2.8 pymilvus==2.2.8 +pymongo +mongoengine diff --git a/tests/unit_tests/manager/test_weaviate.py b/tests/unit_tests/manager/test_weaviate.py new file mode 100644 index 00000000..ee77070f --- /dev/null +++ b/tests/unit_tests/manager/test_weaviate.py @@ -0,0 +1,30 @@ +import unittest + +import numpy as np + +from gptcache.manager.vector_data import VectorBase +from gptcache.manager.vector_data.base import VectorData + + +class TestUSearchDB(unittest.TestCase): + def test_normal(self): + size = 1000 + dim = 512 + top_k = 10 + weaviate = VectorBase( + "weaviate", + top_k = top_k + ) + data = np.random.randn(size, dim).astype(np.float32) + weaviate.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))]) + search_result = weaviate.search(data[0], top_k) + self.assertEqual(len(search_result), top_k) + weaviate.mul_add([VectorData(id=size, data=data[0])]) + ret = weaviate.search(data[0]) + self.assertIn(ret[0][1], [0, size]) + self.assertIn(ret[1][1], [0, size]) + weaviate.delete([0, 1, 2, 3, 4, 5, size]) + ret = weaviate.search(data[0]) + self.assertNotIn(ret[0][1], [0, size]) + weaviate.rebuild() + weaviate.close() \ No newline at end of file