Commit 1eac56c3 authored by Stergios Papadopoulos's avatar Stergios Papadopoulos
Browse files

-Added _set_embedding_func, delete_collections, count, collection_exists, list_collections methods

-Improved vectorize method
-Checked that the above work
parent c0ce9b0d
Loading
Loading
Loading
Loading
+91 −19
Original line number Diff line number Diff line
@@ -28,15 +28,17 @@ class Embedder(Consts):
        self.gpt_client = OpenAI()
        self.chroma_client = chromadb.PersistentClient(path=db_name)

        self._embedding_method = None

    ## PRIVATE METHODS
    def _gpt_embedding(self, gpt_model="text-embedding-3-small"):
        """
        Sets the embedding method to an OpenAI embedding function.
        The function is stored in self._embedding_method and also returned, so callers
        that use the return value and callers that read the attribute both work.
        :param gpt_model: The gpt embedding model.
        :return: The embedding function that was stored in self._embedding_method
        :raises KeyError: If the OPENAI_API_KEY environment variable is not set.
        """
        # The key is read from the environment instead of being written in the source.
        self._embedding_method = embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.environ["OPENAI_API_KEY"],
            model_name=gpt_model,
        )
        return self._embedding_method

@@ -52,6 +54,22 @@ class Embedder(Consts):
            color = chunk.metadata.get("color", "black") #TODO put the hexadecimal value if needed
            yield _id, title, text, color

    def _set_embedding_func(self, model="text-embedding-3-small"):
        """
        Sets the embedding function for the vectorization by changing the self._embedding_method.
        :param model: The model that will be used for embedding.
        :return: None
        """

        # Check if model is an openai model
        openai_models = ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
        if model in openai_models:
            self._gpt_embedding(gpt_model=model)
            print(f"Embedding function set successfully! Embedding model is {model}")
        else:
            print(f"Currently the only supported embedding models are {', '.join(openai_models)}")


    ## CALLABLE METHODS ##
    def chunk_docs(self, chunking_type=None, color=None):
        """
@@ -109,32 +127,27 @@ class Embedder(Consts):
        else:
            raise Exception("You need to chunk the loaded documents first!!! Use chunk_docs() method")

    def vectorize(self, ef=None, collection_name="collection"):
    def vectorize(self, embedding_model="text-embedding-3-small", collection_name="collection"):
        """
        Creates a collection named as specified and adds the chunked data.
        Creates a collection named as specified and embedds the chunked data.
        If the name is the same as an existing collection nothing will happen.
        :param ef: The embedding function to be used.
        :param embedding_model: The embedding model to be used. Default is gpt embeddings using text-embedding-3-small model.
        :param collection_name: The name of the collection in the vector db.
        :return: The path of the vector DB created
        """

        # Check if ef is defined else use gpt embeddings.
        if not ef:
            ef = self._gpt_embedding()
        # Set embedding model
        self._set_embedding_func(model=embedding_model)

        # Check if the specified collection name already exists if not create it.
        try:
            collection = self.chroma_client.get_collection(name=collection_name)
        except :
            pass
        else:
            print("collection already exists!")
        if self.collection_exists(collection_name):
            print("Collection exists!")
            return None


        collection = self.chroma_client.create_collection(
            name=collection_name,
            embedding_function=ef,
            embedding_function=self._embedding_method,
            metadata={
                "hnsw:space": "cosine"
            }
@@ -156,8 +169,33 @@ class Embedder(Consts):
    def add_to_vectordb(self):
        """
        Placeholder: the body is empty, so calling it does nothing and returns None.
        """
        pass

    def delete_vectordb(self):
        """
        Placeholder: the body is empty, so calling it does nothing and returns None.
        """
        pass
    def delete_collections(self, collections_to_delete=None) -> list[str]:
        """
        Deletes the specified collections from the vector DB.
        :param collections_to_delete: A single collection name, an iterable of collection names,
            or the string 'all' to delete every existing collection.
        :return: list with the names of the collections that were actually deleted
        :raises ValueError: If nothing was specified for deletion.
        """
        if not collections_to_delete:
            raise ValueError("You need to specify the collection to delete. Give 'all' to delete all collections.")

        # Depending on the chromadb version the listing holds plain names or collection
        # objects; reduce both to names. The str check comes first so attribute access
        # is never attempted on a name.
        listed = self.list_collections() or []
        existing = [entry if isinstance(entry, str) else entry.name for entry in listed]

        if collections_to_delete == "all":
            requested = existing
        elif isinstance(collections_to_delete, str):
            # A bare name would otherwise be iterated character by character.
            requested = [collections_to_delete]
        else:
            requested = list(collections_to_delete)

        # Names are removed from this set once deleted, so a duplicated request
        # does not try to delete the same collection twice.
        remaining = set(existing)
        deleted_collections = []
        for name in requested:
            if name in remaining:
                self.chroma_client.delete_collection(name)
                remaining.discard(name)
                deleted_collections.append(name)
            else:
                print(f"Collection: {name} does not exist.")
        return deleted_collections


    def search_vectordb(self):
        """
        Placeholder: the body is empty, so calling it does nothing and returns None.
        """
        pass
@@ -165,6 +203,38 @@ class Embedder(Consts):
    def similarity_check(self):
        """
        Placeholder: the body is empty, so calling it does nothing and returns None.
        """
        pass

    def count(self, collection_name) -> int:
        """
        Reports how many items the specified collection holds.
        :param collection_name: The name of the collection.
        :return: The item count, or None when the collection does not exist.
        """
        # Guard clause: a missing collection yields None rather than an error.
        if not self.collection_exists(collection_name):
            return None

        target = self.chroma_client.get_collection(collection_name)
        return target.count()

    def collection_exists(self, collection_name) -> bool:
        """
        Checks if a collection exists by trying to fetch it from the chroma client.
        :param collection_name: The name of the collection to look up.
        :return: True if the lookup succeeds, False if it raises an error.
        """
        try:
            self.chroma_client.get_collection(name=collection_name)
        except Exception:
            # Any lookup error is reported as "not found". Exception (not a bare except)
            # lets KeyboardInterrupt and SystemExit propagate instead of being swallowed.
            return False
        return True

    def list_collections(self) -> list[str]:
        """
        Lists all collection names.
        Entries the client returns as strings are kept as they are; any other entry is
        replaced by its `name` attribute, so the result is always a list of names.
        :return: A list of collection names.
        """
        # The str check comes first so attribute access is never attempted on a name.
        return [
            entry if isinstance(entry, str) else entry.name
            for entry in self.chroma_client.list_collections()
        ]



    def has_loaded_docs(self):
        """
        Checks if documents have been loaded or not.
@@ -175,6 +245,8 @@ class Embedder(Consts):

# Ad-hoc driver for the Embedder API. These statements sit at module level with no
# `if __name__ == "__main__":` guard, so they also run whenever the module is imported.
embedder = Embedder()
# Load one PDF, passing Embedder.ByChar as the chunking type.
embedder.load_docs(chunking_type=Embedder.ByChar, directory="aiani dedomena/2009-04-22-14-52-16.pdf")
# Build two collections with the default embedding model.
embedder.vectorize(collection_name="nco")
embedder.vectorize(collection_name="fgfg")
# count() yields None instead of a number when the collection is missing.
print(embedder.count(collection_name="fgfg"))
# NOTE(review): 'all' removes every collection in the DB, not only the two created above.
print(embedder.delete_collections(collections_to_delete="all"))