#!/usr/bin/env python3
"""Final Optimized MANIT RAG Chatbot"""
from typing import List, Dict
import gradio as gr
import numpy as np
import faiss
import pickle
import os
import time
from sentence_transformers import SentenceTransformer
from src.retrieval.semantic_retriever import OptimizedSemanticRetriever  # Updated import
from src.generation.response_generator import ResponseGenerator
from config.settings import config

num_cores = os.cpu_count()
print(f"Number of CPU cores: {num_cores}")


class OptimizedMANITChatbot:
    """Performance-optimized chatbot class"""

    def __init__(self):
        self.initialized = False
        self.initialization_status = "Starting initialization..."
        self.setup_components()

    def setup_components(self):
        """Initialize components with performance monitoring"""
        try:
            print("=== MANIT Chatbot Initialization ===")
            self.initialization_status = "Loading vector store files..."
            load_start = time.time()

            # Load vector store components
            self.embeddings = np.load(os.path.join(config.VECTOR_STORE_PATH, "embeddings.npy"))
            self.faiss_index = faiss.read_index(os.path.join(config.VECTOR_STORE_PATH, "faiss_index.bin"))
            with open(os.path.join(config.VECTOR_STORE_PATH, "chunks.pkl"), "rb") as f:
                self.chunks = pickle.load(f)
            with open(os.path.join(config.VECTOR_STORE_PATH, "bm25.pkl"), "rb") as f:
                self.bm25 = pickle.load(f)
            with open(os.path.join(config.VECTOR_STORE_PATH, "relationships.pkl"), "rb") as f:
                self.relationships = pickle.load(f)
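            # Security note: pickle.load executes code embedded in the file, so
            # these artifacts must come from a trusted source (here, presumably
            # the project's own vector-store build step).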

            load_time = time.time() - load_start
            print(f"Vector store loaded in {load_time:.2f}s")
| self.initialization_status = "Loading embedding model..." | |
| model_start = time.time() | |
| # Initialize embedding model | |
| self.embedding_model = SentenceTransformer(config.EMBEDDING_MODEL, device='cpu') | |
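            # The first run may download the model from the Hugging Face Hub
            # unless it is already cached; device='cpu' forces CPU inference,
            # which matches a CPU-only Space.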
            model_time = time.time() - model_start
            print(f"Embedding model loaded in {model_time:.2f}s")
| self.initialization_status = "Initializing retrieval components..." | |
| # Initialize optimized retriever | |
| self.retriever = OptimizedSemanticRetriever( | |
| embedding_model=self.embedding_model, | |
| faiss_index=self.faiss_index, | |
| chunks=self.chunks, | |
| bm25_index=self.bm25, | |
| relationships=self.relationships | |
| ) | |
| # Initialize response generator | |
| self.generator = ResponseGenerator() | |
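            # OptimizedSemanticRetriever and ResponseGenerator are project-local
            # classes; from their use in this file they are assumed to expose
            # retrieve(query) plus needs_web_search, web_search, and
            # generate_response_stream.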
| self.initialization_status = "Warming up system..." | |
| # Warm up with a test query | |
| warmup_start = time.time() | |
| test_chunks = self.retriever.retrieve("test warmup query") | |
| warmup_time = time.time() - warmup_start | |
| print(f"System warmup completed in {warmup_time:.2f}s") | |

            # Total elapsed time since the start of initialization
            total_time = time.time() - load_start
            self.initialization_status = "Ready!"
            self.initialized = True
            print(f"=== Initialization Complete in {total_time:.2f}s ===")
            print(f"Performance Mode: {config.PERFORMANCE_MODE}")
            print(f"Retrieval K: {config.retrieval_k}")
            print(f"Using Reranker: {config.use_reranker}")

        except Exception as e:
            print(f"Initialization failed: {e}")
            self.initialization_status = f"Error: {str(e)}"

    def process_query_stream(self, query: str):
        """Stream response with performance monitoring"""
        if not self.initialized:
            yield f"System Error: {self.initialization_status}"
            return
        if not query.strip():
            yield "Please enter a question about MANIT Bhopal."
            return

        try:
            print(f"\n--- Processing Query: {query} ---")
            total_start = time.time()

            # Retrieve relevant documents
            retrieval_start = time.time()
            retrieved_chunks = self.retriever.retrieve(query)
            retrieval_time = time.time() - retrieval_start
            if not retrieved_chunks:
                yield "I couldn't find relevant information about this topic. Please try another question."
                return
            print(f"Retrieved {len(retrieved_chunks)} chunks in {retrieval_time:.2f}s")

            # Format context
            context = self._format_context(retrieved_chunks)

            # Check if web search is needed
            web_context = ""
            if self.generator.needs_web_search(query, context):
                web_search_start = time.time()
                web_results = self.generator.web_search(query)
                web_search_time = time.time() - web_search_start
                print(f"Web search completed in {web_search_time:.2f}s")
                if web_results:
                    web_context = "\n\n".join(web_results)

            # Stream the response
            generation_start = time.time()
            response_chunks = 0
            for chunk in self.generator.generate_response_stream(query, context, web_context):
                response_chunks += 1
                yield chunk

            generation_time = time.time() - generation_start
            total_time = time.time() - total_start
            print(f"Response generated in {generation_time:.2f}s ({response_chunks} chunks)")
            print(f"Total query time: {total_time:.2f}s")

        except Exception as e:
            print(f"Error processing query: {e}")
            yield "I encountered an error processing your question. Please try again."

    def _format_context(self, chunks: List[Dict]) -> str:
        """Format context for the prompt"""
        context_parts = []
        for chunk in chunks:
            source = chunk['metadata']['source']
            content = chunk['content']
            context_parts.append(f"Source: {source}\nContent: {content}")
        return "\n\n---\n\n".join(context_parts)


def create_interface():
    """Create the performance-optimized Gradio interface"""
    print("Initializing MANIT Chatbot Interface...")
    chatbot_instance = OptimizedMANITChatbot()

    def chat_fn(message, history):
        """Optimized chat function with better error handling"""
        if not chatbot_instance.initialized:
            error_msg = f"⚠️ System Status: {chatbot_instance.initialization_status}"
            history.append([message, error_msg])
            # chat_fn is a generator (it yields below), so this error path must
            # yield rather than return a value, which Gradio would never see.
            yield history, ""
            return

        # Add the user's message to history with an empty reply to fill in
        history.append([message, ""])
        try:
            # Stream the response
            for chunk in chatbot_instance.process_query_stream(message):
                history[-1][1] += chunk
                yield history, ""
        except Exception as e:
            print(f"Chat function error: {e}")
            history[-1][1] = "I encountered an error. Please try again."
            yield history, ""

    with gr.Blocks(
        title="MANIT Bhopal Expert Assistant - Optimized",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>🎓 MANIT Bhopal Assistant</h1>
            <p>Ask questions about programs, admissions, faculty, facilities, research, and more.</p>
        </div>
        """)

        chatbot_ui = gr.Chatbot(
            height=500,
            show_label=False,
            avatar_images=[None, "🎓"],
            show_copy_button=True,
            placeholder="Hi! I'm your assistant."
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Question",
                placeholder="Ask about MANIT Bhopal...",
                scale=8,
                lines=2
            )
            submit = gr.Button("Send", scale=1, variant="primary")

        gr.Examples(
            examples=[
                "Who is the director of MANIT?",
                "What are the dispensary timings?",
                "Tell me about the computer science department",
                "What research facilities are available?",
                "What are the guest house prices?"
            ],
            inputs=msg,
            label="Example Questions"
        )

        gr.HTML("""
        <div class="performance-info" style="text-align: center; margin-top: 10px;">
            <p>Optimized for faster response times while maintaining accuracy</p>
        </div>
        """)

        # Event handlers
        msg.submit(chat_fn, [msg, chatbot_ui], [chatbot_ui, msg])
        submit.click(chat_fn, [msg, chatbot_ui], [chatbot_ui, msg])
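        # Because chat_fn is a generator, both the Enter key and the Send button
        # stream partial responses into the Chatbot as they are yielded, and the
        # textbox is cleared via the second output.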

    return demo


if __name__ == "__main__":
    demo = create_interface()
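    # server_name="0.0.0.0" and port 7860 match the defaults a Hugging Face
    # Space expects; adjust for local use if needed.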
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )