Commit c0ce9b0d authored by Stergios Papadopoulos's avatar Stergios Papadopoulos
Browse files

-Switches to chroma db.

-Getting error in the vectorize method.
parent 5dc21511
Loading
Loading
Loading
Loading
+52 −37
Original line number Diff line number Diff line
@@ -4,13 +4,18 @@ from constantscls import Consts
from openai import OpenAI
from dotenv import load_dotenv
import os
from pymilvus import MilvusClient

import chromadb
import chromadb.utils.embedding_functions as embedding_functions

class Embedder(Consts):

    GPTembed = "text-embedding-3-small"
    def __init__(self):
    def __init__(self, db_name="chromadb"):
        """
        In call a chroma db is created named as specified.
        Call vectorize() to create a collection and add the data stored in _chunks list.
        :param db_name: The name of the database.
        """

        # Load env variables
        load_dotenv()
@@ -20,30 +25,32 @@ class Embedder(Consts):
        self._docs: list[MyDoc] = []
        self._chunks = []
        self._vectorDBs = []
        self.client = OpenAI()
        self.gpt_client = OpenAI()
        self.chroma_client = chromadb.PersistentClient(path=db_name)

    ## PRIVATE METHODS
    def _get_embedding_gpt(self, text, dimensions=1024):
    def _gpt_embedding(self, gpt_model="text-embedding-3-small"):
        """
        Creates the vector using the specified embedding model.
        :param text: The text to embed.
        :param dimensions: The dimensions of the vector
        Creates the vector using the specified gpt embedding model.
        :param gpt_model: The gpt embedding model.
        :return: The vector representing the given text
        """
        return self.client.embeddings.create(input=[text], model=Embedder.GPTembed, dimensions=dimensions).data[0].embedding
        return embedding_functions.OpenAIEmbeddingFunction(
                api_key="YOUR_API_KEY",
                model_name=gpt_model
            )

    def _vectors_generator(self, dimensions=1024):
    def _vectors_generator(self):
        """
        A generator that returns vector representations if the loaded chunks.
        A generator that returns the metadata of each saved chunk.
        """

        for chunk in self.get_chunks():
            vector = self._get_embedding_gpt(chunk.page_content, dimensions)
            _id = chunk.metadata["_id"]
            title = chunk.metadata["title"]
            text  = chunk.metadata["page_content"]
            text = chunk.page_content
            color = chunk.metadata.get("color", "black") #TODO put the hexadecimal value if needed
            yield _id, vector, title, text, color
            yield _id, title, text, color

    ## CALLABLE METHODS ##
    def chunk_docs(self, chunking_type=None, color=None):
@@ -102,41 +109,49 @@ class Embedder(Consts):
        else:
            raise Exception("You need to chunk the loaded documents first!!! Use chunk_docs() method")

    def vectorize(self, vectordb_name=None, collection_name="collection", dimensions=1024):
    def vectorize(self, ef=None, collection_name="collection"):
        """
        Creates the VectoDB named as specified.
        If the name is the same as an existing vectorDB the new one will have an auto increasing number.
        :param dimensions: The dimensions that the vectors will be represented.
        Creates a collection named as specified and adds the chunked data.
        If the name is the same as an existing collection nothing will happen.
        :param ef: The embedding function to be used.
        :param collection_name: The name of the collection in the vector db.
        :param vectordb_name: The name of the vector database. Make sure to include .db in the end
        :return: The path of the vector DB created
        """

        if ".db" not in vectordb_name:
            raise Exception("Please include the '.db' at the end of the vector database name!!!")
        # Check if ef is defined else use gpt embeddings.
        if not ef:
            ef = self._gpt_embedding()

        # Check if the specified collection name already exists if not create it.
        try:
            collection = self.chroma_client.get_collection(name=collection_name)
        except :
            pass
        else:
            print("collection already exists!")
            return None

        ## TODO CREATE THE CODE FOR SAME VECTOR DB NAMES

        client = MilvusClient(vectordb_name) # Create the vectorDB named as...
        client.create_collection(
            collection_name=collection_name,
            dimension=dimensions
        collection = self.chroma_client.create_collection(
            name=collection_name,
            embedding_function=ef,
            metadata={
                "hnsw:space": "cosine"
            }
        )

        # Prepare the data to for saving in the vector DB
        data = []
        for _id, vector, title, text, color in self._vectors_generator(dimensions):
            data.append(
        # Add the data
        collection.add(
            documents=[text for _, _, text, _ in self._vectors_generator()],
            metadatas=[
                {
                    "id": _id,
                    "vector": vector,
                    "title": title,
                    "text": text,
                    "color": color,
                    "subject": title
                }
                } for _id, title, text, color in self._vectors_generator()
            ],
            ids=[_id for _id, _, _, _ in self._vectors_generator()]
        )
        res = client.insert(collection_name=collection_name, data=data)

    def add_to_vectordb(self):
        pass
@@ -160,6 +175,6 @@ class Embedder(Consts):

embedder = Embedder()
embedder.load_docs(chunking_type=Embedder.ByChar, directory="aiani dedomena/2009-04-22-14-52-16.pdf")
embedder.vectorize("vectordb.db")
embedder.vectorize(collection_name="nco")