# QA_Bot / app.py
import os
import gradio as gr
import PyPDF2
import torch
from transformers import AutoTokenizer, AutoModel
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
import cohere
from uuid import uuid4
# -------------------- CONFIG --------------------
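# NOTE: in practice these credentials should come from environment variables / Space
# secrets rather than being committed to the repo; the literals below are kept only as
# placeholder fallbacks.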
# Qdrant Cloud config
QDRANT_URL = os.getenv("QDRANT_URL", "https://9ecb3b08-e6fa-482c-a03f-e8cb3313a6c0.sa-east-1-0.aws.cloud.qdrant.io")  # e.g. "https://xxxxxx-xxxxx-xxxxx-xxxx-xxxxxxxxx.us-east.aws.cloud.qdrant.io:6333"
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "<your-qdrant-api-key>")
COLLECTION_NAME = "Document"

# Cohere config
COHERE_API_KEY = os.getenv("COHERE_API_KEY", "<your-cohere-api-key>")
# -------------------- INITIALIZE CLIENTS --------------------
qdrant = QdrantClient(
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
)
cohere_client = cohere.Client(COHERE_API_KEY)
# -------------------- LOAD EMBEDDING MODEL --------------------
# Using MiniLM (384-dim)
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
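# Inference-only usage: switch off dropout (Hugging Face models load in train mode by default).
model.eval()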
# -------------------- VECTOR COLLECTION SETUP --------------------
def setup_collection():
    """Create collection in Qdrant if it doesn't exist."""
    collections = qdrant.get_collections().collections
    existing_names = [c.name for c in collections]
    if COLLECTION_NAME not in existing_names:
        qdrant.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(
                size=384,  # MiniLM-L6-v2 output size
                distance=Distance.COSINE
            )
        )
setup_collection()
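# NOTE: newer qdrant-client releases also expose qdrant.collection_exists(COLLECTION_NAME);
# listing collections as above keeps this compatible with older client versions.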
# -------------------- UTILITY FUNCTIONS --------------------
def load_pdf(file):
    """Extract raw text from a PDF file."""
    # Gradio's File component passes a tempfile path or file object
    reader = PyPDF2.PdfReader(file)
    text = ""
    for page in reader.pages:
        page_text = page.extract_text()
        if page_text:
            text += page_text
    return text
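# NOTE: pages are concatenated with no separator, so the last word of one page can run
# into the first word of the next; appending "\n" per page would be a small, safe refinement.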
def get_embeddings(text):
    """Compute mean-pooled embeddings using MiniLM."""
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean pooling over sequence length
    embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return embeddings
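# NOTE: the mean above averages over every token position, padding included. For a single
# string (no padding is added) that matches attention-masked pooling; if batching is added
# later, a mask-weighted mean is the more faithful sentence-transformers pooling, e.g.:
#   mask = inputs["attention_mask"].unsqueeze(-1).float()
#   embeddings = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1)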
def upload_document_chunks(chunks):
    """
    Insert document chunks into Qdrant as points.
    Each chunk becomes one point with a vector and payload.
    """
    points = []
    for chunk in chunks:
        try:
            embedding = get_embeddings(chunk)
            points.append(
                PointStruct(
                    id=str(uuid4()),  # unique ID per chunk
                    vector=embedding.tolist(),
                    payload={"content": chunk}
                )
            )
        except Exception as e:
            print(f"⚠️ Skipped chunk due to error: {e}")
    if points:
        qdrant.upsert(
            collection_name=COLLECTION_NAME,
            points=points
        )
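# NOTE: points accumulate in the collection across uploads; chunks from earlier PDFs are
# never removed, so retrieval can mix documents. Tagging points with a document ID in the
# payload (and filtering on it at query time), or recreating the collection per upload,
# would keep answers scoped to the current file.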
def query_answer(query):
    """Search Qdrant for the most relevant chunks to the query."""
    query_embedding = get_embeddings(query)
    hits = qdrant.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding.tolist(),
        limit=3
    )
    return hits
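# NOTE: recent qdrant-client versions deprecate search() in favour of query_points();
# search() still works here, but a future client upgrade may require switching.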
def generate_response(context, query):
    """Use Cohere to generate a natural language answer based on context."""
    prompt = f"""You are a helpful assistant answering questions based only on the given context.
Context:
{context}
Question: {query}
Answer:"""
    response = cohere_client.generate(
        model="command",
        prompt=prompt,
        max_tokens=200,
        temperature=0.3
    )
    return response.generations[0].text.strip()
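# NOTE: generate() is Cohere's legacy endpoint; the SDK also offers cohere_client.chat(),
# which Cohere now recommends for new work. The legacy call is kept here unchanged.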
def qa_pipeline(pdf_file, query):
    """
    Full QA pipeline:
    1. Read PDF
    2. Chunk text
    3. Store chunks in Qdrant
    4. Search relevant chunks for query
    5. Generate answer via Cohere
    """
    if pdf_file is None:
        return "⚠️ Please upload a PDF first.", ""
    if not query or query.strip() == "":
        return "⚠️ Please enter a question.", ""

    # 1. Extract text
    document_text = load_pdf(pdf_file)
    if not document_text.strip():
        return "⚠️ No extractable text found in the PDF.", ""

    # 2. Simple character-based chunking
    chunk_size = 500
    document_chunks = [
        document_text[i:i + chunk_size]
        for i in range(0, len(document_text), chunk_size)
    ]
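    # NOTE: fixed-size character chunks can cut sentences mid-way; a small overlap between
    # consecutive chunks (e.g. 50-100 characters) is a common, low-cost refinement.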
    # 3. Upload chunks to Qdrant
    upload_document_chunks(document_chunks)

    # 4. Search relevant chunks
    hits = query_answer(query)
    if not hits:
        return "⚠️ No relevant document segments found.", "I couldn't find an answer based on the document."
    context = " ".join([hit.payload.get("content", "") for hit in hits])

    # 5. Generate answer
    answer = generate_response(context, query)
    return context, answer
# -------------------- GRADIO UI --------------------
with gr.Blocks(theme="compact") as demo:
    gr.Markdown("""
<div style="text-align: center; font-size: 28px; font-weight: bold; margin-bottom: 20px; color: #2D3748;">
📄 Interactive PDF QA Bot (Qdrant + Cohere) 🔍
</div>
<p style="text-align: center; font-size: 16px; color: #4A5568;">
Upload a PDF document, ask a question, and get answers grounded in the document content.
</p>
<hr style="border: 1px solid #CBD5E0; margin: 20px 0;">
""")

    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="📁 Upload PDF", file_types=[".pdf"])
            query_input = gr.Textbox(
                label="❓ Ask a Question",
                placeholder="Enter your question here..."
            )
            submit_button = gr.Button("🔍 Submit")
        with gr.Column(scale=2):
            doc_segments_output = gr.Textbox(
                label="📜 Retrieved Document Segments",
                lines=10
            )
            answer_output = gr.Textbox(
                label="💬 Answer",
                lines=3
            )

    submit_button.click(
        fn=qa_pipeline,
        inputs=[pdf_input, query_input],
        outputs=[doc_segments_output, answer_output]
    )
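    # NOTE: depending on the Gradio version, raw <style> tags passed through gr.Markdown may
    # be sanitized and ignored; supplying the CSS via gr.Blocks(css=...) is the more reliable
    # route if the styling below does not take effect.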
    gr.Markdown("""
<style>
body {
    background-color: #EDF2F7;
}
input[type="file"] {
    background-color: #3182CE;
    color: white;
    padding: 8px;
    border-radius: 5px;
}
button {
    background-color: #3182CE;
    color: white;
    padding: 10px;
    font-size: 16px;
    border-radius: 5px;
    cursor: pointer;
    border: none;
}
button:hover {
    background-color: #2B6CB0;
}
textarea {
    border: 2px solid #CBD5E0;
    border-radius: 8px;
    padding: 10px;
    background-color: #FAFAFA;
}
</style>
""")
demo.launch(share=True)
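# NOTE: share=True requests a temporary public gradio.live tunnel; on Hugging Face Spaces
# the flag is ignored and a plain demo.launch() is enough.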