|
|
import os |
|
|
import gradio as gr |
|
|
import PyPDF2 |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
from qdrant_client import QdrantClient |
|
|
from qdrant_client.models import VectorParams, Distance, PointStruct |
|
|
import cohere |
|
|
from uuid import uuid4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Service configuration. Secrets should come from the environment rather than
# source control; each setting reads an environment variable and falls back
# to the previous hard-coded literal so existing deployments keep working.
QDRANT_URL = os.getenv(
    "QDRANT_URL",
    "https://9ecb3b08-e6fa-482c-a03f-e8cb3313a6c0.sa-east-1-0.aws.cloud.qdrant.io",
)
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "<your-qdrant-api-key>")

# Name of the Qdrant collection that stores the document chunks.
COLLECTION_NAME = os.getenv("QDRANT_COLLECTION", "Document")

COHERE_API_KEY = os.getenv("COHERE_API_KEY", "<your-cohere-api-key>")
|
|
|
|
|
|
|
|
|
|
|
# Shared, module-level service clients used by the functions below:
# Qdrant for vector storage/search and Cohere for answer generation.
qdrant = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
cohere_client = cohere.Client(COHERE_API_KEY)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Embedding model shared by document chunks and queries. Its hidden size is
# the embedding dimension produced by get_embeddings(), which must match the
# Qdrant collection's vector size.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_NAME)
model = AutoModel.from_pretrained(_EMBEDDING_MODEL_NAME)
|
|
|
|
|
|
|
|
|
|
|
def setup_collection():
    """Create the Qdrant collection if it does not already exist.

    The vector size is taken from the loaded embedding model's hidden size
    rather than hard-coded, so the collection always matches the dimension of
    the vectors produced by ``get_embeddings`` (384 for all-MiniLM-L6-v2).
    Similarity is cosine distance.

    An existing collection is left untouched; its configuration is not
    checked against the model.
    """
    existing_names = {c.name for c in qdrant.get_collections().collections}
    if COLLECTION_NAME in existing_names:
        return

    qdrant.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(
            # get_embeddings() mean-pools last_hidden_state, so its output
            # dimension is exactly the model's hidden size.
            size=model.config.hidden_size,
            distance=Distance.COSINE,
        ),
    )
|
|
|
|
|
# Ensure the Qdrant collection exists at import time, before the Gradio app
# below is launched and can upload or search anything.
setup_collection()
|
|
|
|
|
|
|
|
|
|
|
def load_pdf(file):
    """Extract the text of every page in a PDF.

    Args:
        file: Anything ``PyPDF2.PdfReader`` accepts (a path or a binary
            file-like object).

    Returns:
        The text of all pages joined with newlines. Pages for which
        ``extract_text()`` returns nothing are skipped. Returns an empty
        string when no page yields text.
    """
    reader = PyPDF2.PdfReader(file)
    page_texts = (page.extract_text() for page in reader.pages)
    # Separate pages with a newline so the last word of one page does not
    # fuse with the first word of the next; join() also avoids repeated
    # string concatenation.
    return "\n".join(text for text in page_texts if text)
|
|
|
|
|
|
|
|
def get_embeddings(text):
    """Compute mean-pooled MiniLM embeddings.

    Args:
        text: A single string, or a list of strings to embed as one batch.

    Returns:
        A numpy array of shape ``(hidden_size,)`` for a single string (or a
        one-element list), or ``(batch, hidden_size)`` for a longer list.

    Pooling averages the final hidden states over real tokens only, weighting
    by the attention mask. This is the mean pooling described on the
    all-MiniLM-L6-v2 model card. For a single string there is no padding, so
    the result equals a plain mean over all tokens.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    token_embeddings = outputs.last_hidden_state  # (batch, seq, hidden)
    # Zero out padding positions so they don't dilute the average in batches.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
    embeddings = summed / counts

    return embeddings.squeeze(0).cpu().numpy()
|
|
|
|
|
|
|
|
def upload_document_chunks(chunks):
    """Embed document chunks and upsert them into Qdrant as points.

    Each non-blank chunk becomes one point whose vector is the chunk's
    embedding and whose payload is ``{"content": chunk}``.

    Point IDs are derived deterministically from the chunk text (UUIDv5).
    Uploading the same document again, which happens on every question,
    therefore overwrites the existing points instead of inserting duplicates
    that would crowd the top search results.

    A chunk whose embedding fails is skipped with a printed warning; the
    remaining chunks are still uploaded. Nothing is sent when no point was
    produced.
    """
    import uuid  # local import: only uuid4 is imported at module level

    points = []
    for chunk in chunks:
        if not chunk or not chunk.strip():
            continue  # nothing meaningful to embed

        try:
            embedding = get_embeddings(chunk)
        except Exception as e:  # best-effort: keep the rest of the document
            print(f"⚠️ Skipped chunk due to error: {e}")
            continue

        points.append(
            PointStruct(
                id=str(uuid.uuid5(uuid.NAMESPACE_URL, chunk)),
                vector=embedding.tolist(),
                payload={"content": chunk},
            )
        )

    if points:
        qdrant.upsert(collection_name=COLLECTION_NAME, points=points)
|
|
|
|
|
|
|
|
def query_answer(query, top_k=3):
    """Search Qdrant for the chunks most similar to ``query``.

    Args:
        query: The question text; it is embedded with ``get_embeddings``.
        top_k: Maximum number of hits to return. The default of 3 matches the
            previously hard-coded limit.

    Returns:
        The hits returned by ``qdrant.search``. Chunks stored by this app keep
        their text in ``hit.payload["content"]``.
    """
    query_embedding = get_embeddings(query)

    # NOTE(review): newer qdrant-client releases deprecate `search` in favour
    # of `query_points`; confirm against the installed client version.
    return qdrant.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding.tolist(),
        limit=top_k,
    )
|
|
|
|
|
|
|
|
def generate_response(context, query, max_tokens=200, temperature=0.3):
    """Ask Cohere to answer ``query`` using only the retrieved ``context``.

    Args:
        context: Retrieved document text the answer must be grounded in.
        query: The user's question.
        max_tokens: Upper bound on generated tokens (default 200).
        temperature: Sampling temperature (default 0.3).

    Returns:
        The first generation's text with surrounding whitespace stripped, or
        an empty string if Cohere returned no generations.
    """
    prompt = (
        "You are a helpful assistant answering questions based only on the "
        "given context. If the context does not contain the answer, say that "
        "you don't know instead of guessing.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )

    # NOTE(review): Cohere treats the Generate endpoint / "command" model as
    # legacy; consider migrating to `chat` — confirm against the installed SDK.
    response = cohere_client.generate(
        model="command",
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
    )

    if not response.generations:
        return ""
    return response.generations[0].text.strip()
|
|
|
|
|
|
|
|
def _chunk_text(text, chunk_size=500):
    """Split ``text`` into consecutive, non-overlapping ``chunk_size``-character slices."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]


def qa_pipeline(pdf_file, query):
    """Answer ``query`` from the contents of ``pdf_file``.

    Full QA pipeline:
        1. Read the PDF's text.
        2. Split it into fixed-size character chunks.
        3. Embed the chunks and store them in Qdrant.
        4. Search Qdrant for the chunks most relevant to the query.
        5. Generate an answer with Cohere from those chunks.

    Args:
        pdf_file: The uploaded PDF, in whatever form ``load_pdf`` accepts, or
            ``None`` when nothing was uploaded.
        query: The user's question.

    Returns:
        A ``(retrieved_context, answer)`` tuple of strings for the two output
        textboxes. On invalid input or no results, the first element is a
        warning and the second is empty or a fallback message.
    """
    if pdf_file is None:
        return "⚠️ Please upload a PDF first.", ""

    if not query or not query.strip():
        return "⚠️ Please enter a question.", ""

    document_text = load_pdf(pdf_file)
    if not document_text.strip():
        return "⚠️ No extractable text found in the PDF.", ""

    upload_document_chunks(_chunk_text(document_text))

    no_match = (
        "⚠️ No relevant document segments found.",
        "I couldn't find an answer based on the document.",
    )

    hits = query_answer(query)
    if not hits:
        return no_match

    # Hits can repeat the same text (e.g. identical chunks stored earlier), so
    # keep each distinct, non-empty chunk once, in ranking order. Blank lines
    # keep chunks distinguishable in the UI and in the prompt.
    contents = [(hit.payload or {}).get("content", "") for hit in hits]
    context = "\n\n".join(c for c in dict.fromkeys(contents) if c)
    if not context:
        return no_match

    answer = generate_response(context, query)
    return context, answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# App-wide custom styling. Gradio applies custom CSS through the Blocks `css`
# argument; a <style> tag rendered inside gr.Markdown is not a reliable way to
# inject it (Markdown output may be sanitized).
_CUSTOM_CSS = """
body {
    background-color: #EDF2F7;
}
input[type="file"] {
    background-color: #3182CE;
    color: white;
    padding: 8px;
    border-radius: 5px;
}
button {
    background-color: #3182CE;
    color: white;
    padding: 10px;
    font-size: 16px;
    border-radius: 5px;
    cursor: pointer;
    border: none;
}
button:hover {
    background-color: #2B6CB0;
}
textarea {
    border: 2px solid #CBD5E0;
    border-radius: 8px;
    padding: 10px;
    background-color: #FAFAFA;
}
"""

# NOTE(review): "compact" is not one of Gradio 4's built-in theme names; newer
# Gradio versions may warn and fall back to the default theme — confirm.
with gr.Blocks(theme="compact", css=_CUSTOM_CSS) as demo:
    gr.Markdown("""
<div style="text-align: center; font-size: 28px; font-weight: bold; margin-bottom: 20px; color: #2D3748;">
📄 Interactive PDF QA Bot (Qdrant + Cohere) 📄
</div>
<p style="text-align: center; font-size: 16px; color: #4A5568;">
Upload a PDF document, ask a question, and get answers grounded in the document content.
</p>
<hr style="border: 1px solid #CBD5E0; margin: 20px 0;">
""")

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            pdf_input = gr.File(label="📄 Upload PDF", file_types=[".pdf"])
            query_input = gr.Textbox(
                label="❓ Ask a Question",
                placeholder="Enter your question here...",
            )
            submit_button = gr.Button("🚀 Submit")

        # Right column (twice as wide): retrieved context and generated answer.
        with gr.Column(scale=2):
            doc_segments_output = gr.Textbox(
                label="📜 Retrieved Document Segments",
                lines=10,
            )
            answer_output = gr.Textbox(label="💬 Answer", lines=3)

    # qa_pipeline returns (context, answer), matching the outputs order.
    submit_button.click(
        fn=qa_pipeline,
        inputs=[pdf_input, query_input],
        outputs=[doc_segments_output, answer_output],
    )


# Launch only when run as a script, so importing this module (tests, reload
# tooling) does not start a server and a public share link.
if __name__ == "__main__":
    demo.launch(share=True)