import os
import getpass
from typing import Union, List
from premai import Prem
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
def get_embeddings(
    project_id: int,
    embedding_model: str,
    documents: Union[str, List[str]],
    client: Union["Prem", None] = None,
) -> List[List[float]]:
    """
    Helper function to get the embeddings from the premai sdk.

    Args:
        project_id (int): The project id from the Prem SaaS platform.
        embedding_model (str): The embedding model alias to use.
        documents (Union[str, List[str]]): A single text or a list of texts
            to embed. A single string is treated as a one-element list.
        client (Prem, optional): The client whose ``embeddings.create`` is
            called. When omitted, the module-level ``prem_client`` is used,
            so existing callers keep working unchanged.

    Returns:
        List[List[float]]: One embedding vector per entry of the response's
        ``data``, in the order the API returned them. An empty list of
        documents yields an empty list without calling the API.

    Raises:
        NameError: If ``client`` is omitted and no module-level
            ``prem_client`` has been defined.
    """
    texts = [documents] if isinstance(documents, str) else documents
    if not texts:
        # Nothing to embed: skip the network round trip entirely.
        return []

    if client is None:
        # Backward-compatible fallback to the client created by the script.
        client = prem_client

    response = client.embeddings.create(
        project_id=project_id,
        model=embedding_model,
        input=texts,
    )
    return [item.embedding for item in response.data]
if __name__ == "__main__":
    # Example script: embed a few documents with PremAI, store them in a
    # Qdrant collection, then run a similarity search against that collection.

    # Prompt for the API key only when it is not already in the environment.
    if os.environ.get("PREMAI_API_KEY") is None:
        os.environ["PREMAI_API_KEY"] = getpass.getpass("PremAI API Key:")

    # NOTE: 123 is a dummy project id. Replace it with a real project id,
    # otherwise the embedding request will throw an error.
    PROJECT_ID = 123
    EMBEDDING_MODEL = "text-embedding-3-large"
    COLLECTION_NAME = "prem-collection-py"
    QDRANT_SERVER_URL = "http://127.0.0.1:6333"
    DOCUMENTS = [
        "This is a sample python document",
        "We will be using qdrant and premai python sdk",
    ]

    api_key = os.environ["PREMAI_API_KEY"]
    # Kept as a module-level name: get_embeddings looks up `prem_client`.
    prem_client = Prem(api_key=api_key)
    qdrant_client = QdrantClient(url=QDRANT_SERVER_URL)

    # Get the embeddings and create one Qdrant point per document, keeping
    # the original text in the payload.
    embeddings = get_embeddings(
        project_id=PROJECT_ID,
        embedding_model=EMBEDDING_MODEL,
        documents=DOCUMENTS,
    )
    points = [
        PointStruct(
            id=idx,
            vector=embedding,
            payload={"text": text},
        )
        for idx, (embedding, text) in enumerate(zip(embeddings, DOCUMENTS))
    ]

    # Create the collection only if it is missing, so the script can be
    # re-run. The vector size is taken from the embeddings themselves
    # rather than hard-coded, so it follows EMBEDDING_MODEL.
    if not qdrant_client.collection_exists(collection_name=COLLECTION_NAME):
        qdrant_client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(
                size=len(embeddings[0]),
                distance=Distance.DOT,
            ),
        )

    # Upload all the documents to the collection.
    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=points,
    )

    # Query the collection with the embedding of a single question.
    query = "what is the extension of python document"
    query_embedding = get_embeddings(
        project_id=PROJECT_ID,
        embedding_model=EMBEDDING_MODEL,
        documents=query,
    )
    results = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding[0],
    )
    # Show the hits; previously the search result was silently discarded.
    print(results)