Commit ac24d108 authored by Stergios Papadopoulos's avatar Stergios Papadopoulos
Browse files

-Changed vectorize method to _create_collection and add_data

-changed collection_exists to _get_collection and _get_collection_error
-Repair load_docs color loading
-Added search_similar but needs fixing i must create pickle saving class for saving embedding methods related to colections
-Must set search_similar to class method
parent 1eac56c3
Loading
Loading
Loading
Loading
+102 −59
Original line number Diff line number Diff line
import glob
from typing import Any

from mydoc import MyDoc
from constantscls import Consts
from openai import OpenAI
@@ -69,12 +71,59 @@ class Embedder(Consts):
        else:
            print(f"Currently the only supported embedding models are {', '.join(openai_models)}")

    def _create_collection(self, collection_name, embedding_model) -> chromadb.Collection:
        """
        Creates an empty collection named as specified. If collection exists raises Exception.
        :param collection_name: The name of the collection
        :param embedding_model: The embedding model.
        :return: The collection created.
        """

        # Set embedding model
        self._set_embedding_func(model=embedding_model)

        # Check if the specified collection name already exists.
        if self._get_collection(collection_name):
            raise Exception("Collection already exists.")

        # Creates and returns the collection
        return self.chroma_client.create_collection(
            name=collection_name,
            embedding_function=self._embedding_method,
            metadata={
                "hnsw:space": "cosine"
            }
        )

    def _get_collection(self, collection_name) -> Any | None:
        """
        Checks if collection exists. If exists returns it.
        :param collection_name: collection name
        :return: Collection if exists else None.
        """
        try:
            embedding_function=self._embedding_method
            return self.chroma_client.get_collection(name=collection_name, embedding_function=embedding_function)
        except:
            return None

    def _get_collection_error(self, collection_name):
        """
        Checks if collection exists. If exists returns it else raises error.
        :param collection_name: collection name
        :return: Collection if exists.
        """
        try:
            embedding_function = self._embedding_method
            return self.chroma_client.get_collection(name=collection_name, embedding_function=embedding_function)
        except:
            raise Exception("Collection does not exists!!!")

    ## CALLABLE METHODS ##
    def chunk_docs(self, chunking_type=None, color=None):
        """
        Chunks the documents in a specified type.
        :param chunking_type: Constant from Conts class.
        :param chunking_type: Constant from Consts class.
        :param color: color to add to the chunks if specified.
        :return: None
        """
@@ -91,16 +140,27 @@ class Embedder(Consts):
    def load_docs(self, directory="aiani dedomena/*", chunking_type=None, colors=None) -> None:
        """
        Loads the pdfs in MyDoc parser and saves them in self._docs.
        Also, if specified chunks the documents in the desired method.
        If specified a color metadata will be added.
        :param colors: List of hexadecimal colors to add as metadata in the chunks.
        Also, if chunking type is provided, the documents get chunked in the desired method.
        If specified, a color metadata will be added to the docs.
        :param colors: List of hexadecimal colors to add as metadata in the documents. The length of the list should be equal to the pdfs
        :param chunking_type: The chunking method of the documents.
        :param directory: the path of the directory where the pdfs to load are located, should "smth/*".
        :return: None
        """

        # Load documents
        # Load document's paths
        doc_paths = glob.glob(directory) # Load document paths

        # Check for colors
        if colors:
            colors_len = len(colors)
            doc_paths_len = len(doc_paths)
            if colors_len > doc_paths_len:
                raise Exception("You provided more colors than documents.")
            elif colors_len < doc_paths_len:
                raise Exception("You provided less colors than documents length.")

        # Load documents
        for i in range(len(doc_paths)):
            color = colors[i] if colors else None
            doc_path = doc_paths[i]
@@ -108,7 +168,7 @@ class Embedder(Consts):

            # Load Chunks if specified
        if chunking_type:
            self.chunk_docs()
            self.chunk_docs(chunking_type=chunking_type)

    def get_docs(self) -> list[MyDoc]:
        """
@@ -127,33 +187,18 @@ class Embedder(Consts):
        else:
            raise Exception("You need to chunk the loaded documents first!!! Use chunk_docs() method")

    def vectorize(self, embedding_model="text-embedding-3-small", collection_name="collection"):
    def add_data(self, collection_name, embedding_model="text-embedding-3-small"):
        """
        Creates a collection named as specified and embedds the chunked data.
        If the name is the same as an existing collection nothing will happen.
        :param embedding_model: The embedding model to be used. Default is gpt embeddings using text-embedding-3-small model.
        :param collection_name: The name of the collection in the vector db.
        :return: The path of the vector DB created
        Adds data to the specified collection. The data are the saved chunks.
        :param embedding_model: The embedding model. Default is 'text-embedding-3-small'.
        :param collection_name: The name of the collection to add the data.
        :return: None
        """
        # If collection exists get it else create it.
        collection = self._get_collection(collection_name)
        if not collection:
            collection = self._create_collection(collection_name, embedding_model)

        # Set embedding model
        self._set_embedding_func(model=embedding_model)

        # Check if the specified collection name already exists if not create it.
        if self.collection_exists(collection_name):
            print("Collection exists!")
            return None


        collection = self.chroma_client.create_collection(
            name=collection_name,
            embedding_function=self._embedding_method,
            metadata={
                "hnsw:space": "cosine"
            }
        )

        # Add the data
        collection.add(
            documents=[text for _, _, text, _ in self._vectors_generator()],
            metadatas=[
@@ -166,10 +211,8 @@ class Embedder(Consts):
            ids=[_id for _id, _, _, _ in self._vectors_generator()]
        )

    def add_to_vectordb(self):
        pass

    def delete_collections(self, collections_to_delete=None) -> list[str]:
    def delete_collections(self, collections_to_delete) -> list[str]:
        """
        Deletes the specified collection.
        :param collections_to_delete: list of collection/s to delete. If 'all' is given all the collections will be deleted
@@ -189,7 +232,7 @@ class Embedder(Consts):
                    deleted_collections.append(collection)
            else:
                for collection in collections_to_delete:
                    if self.collection_exists(collection):
                    if self._get_collection(collection):
                        self.chroma_client.delete_collection(collection)
                        deleted_collections.append(collection)
                    else:
@@ -197,34 +240,36 @@ class Embedder(Consts):
        return deleted_collections


    def search_vectordb(self):
        pass
    def search_similar(self, collection_name, *input_text, n_results=3) -> list[str]:
        """
        Searches specified collection for similar text chunks according to given input_text.
        :param n_results: How many results to return.
        :param collection_name: The collection to search to.
        :param input_text: The text chunk/s to search for similar chunks.
        :return: list of results.
        """

        query_text = list(input_text)

    def similarity_check(self):
        pass
        # Set embedding function

        # Get collection or raise error if it doesn't exist.
        collection = self._get_collection_error(collection_name)

        return collection.query(
            query_texts=query_text,
            n_results=n_results,
        )

    def count(self, collection_name) -> int:
        """
        Returns the number of items in the specified collection.
        Counts the number of items in the specified collection.
        :param collection_name: The name of the collection.
        :return: int
        :return: The number of items in the collection
        """
        if self.collection_exists(collection_name):
        if self._get_collection(collection_name):
            return self.chroma_client.get_collection(collection_name).count()

    def collection_exists(self, collection_name) -> bool:
        """
        Checks if collection exists.
        :param collection_name:
        :return: True if collection exists False otherwise.
        """
        try:
            self.chroma_client.get_collection(name=collection_name)
        except:
            return False
        else:
            return True

    def list_collections(self) -> list[str]:
        """
        Lists all collection names.
@@ -244,9 +289,7 @@ class Embedder(Consts):


embedder = Embedder()
embedder.load_docs(chunking_type=Embedder.ByChar, directory="aiani dedomena/2009-04-22-14-52-16.pdf")
embedder.vectorize(collection_name="fgfg")
print(embedder.count(collection_name="fgfg"))
print(embedder.delete_collections(collections_to_delete="all"))
print(embedder.search_similar("Mycollection", "Τι είναι το σπίτι με τις σκάλες?"))