Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import torch | |
| import re | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from peft import PeftModel | |
| from text_processing import TextProcessor | |
| import gc | |
| from pathlib import Path | |
| import concurrent.futures | |
| import time | |
| import nltk | |
| from nltk.tokenize import sent_tokenize | |
| from concurrent.futures import ThreadPoolExecutor # Add this import | |
| nltk.download('punkt') | |
# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    # Microscope emoji, written as an escape so it survives re-encoding.
    # The previous value was a two-character mis-decoded form of this emoji,
    # which is not a valid emoji string for the page icon.
    page_icon="\U0001F52C",
    layout="wide",
)
# Initialize session state
# Each key is created only when missing, so values already stored in the
# session are left untouched.
_SESSION_DEFAULTS = {
    'relevant_papers': None,
    'relevance_scores': None,
    'processed_data': None,
    'summaries': None,
    'text_processor': None,
    'processing_started': False,
    'focused_summary_generated': False,
    'current_model': None,
    'current_tokenizer': None,
    'model_type': None,
    'focused_summary': None,
}
for _state_key, _initial_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# TextProcessor class definition
# Use the project's implementation when it can be imported; otherwise define a
# stand-in with the same method that performs no ranking.
# NOTE(review): the import block at the top of this file also imports
# TextProcessor without a guard, so a missing module would fail there before
# this fallback is reached - confirm which behaviour is intended.
try:
    from text_processing import TextProcessor
except ImportError:
    class TextProcessor:
        """Stand-in used when ``text_processing`` cannot be imported."""

        def find_most_relevant_abstracts(self, question, abstracts, top_k=5):
            """Select the leading abstracts without looking at ``question``.

            Returns a dict with ``'top_indices'`` (the first
            ``min(top_k, len(abstracts))`` positions) and ``'scores'`` (a 1.0
            for each selected position).
            """
            selected = min(top_k, len(abstracts))
            return {
                'top_indices': [*range(selected)],
                'scores': [1.0 for _ in range(selected)],
            }
def load_model(model_type):
    """Load the model and tokenizer for ``model_type`` onto the CPU.

    ``"summarize"`` loads the ``pendar02/bart-large-pubmedd`` seq2seq
    checkpoint. Any other value is treated as the question-focused case: the
    ``GanjinZero/biobart-base`` checkpoint is loaded and wrapped with the
    ``pendar02/biobart-finetune`` PEFT adapter (``is_trainable=False``).

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been put in eval mode.

    Raises:
        Exception: any loading failure is shown with ``st.error`` and then
            re-raised unchanged.
    """
    try:
        # Release whatever can be released before the new weights are read.
        gc.collect()
        torch.cuda.empty_cache()

        target_device = "cpu"  # Force CPU usage
        hub_cache = "./models"

        if model_type == "summarize":
            checkpoint = "pendar02/bart-large-pubmedd"
            model = AutoModelForSeq2SeqLM.from_pretrained(
                checkpoint,
                cache_dir=hub_cache,
                torch_dtype=torch.float32
            ).to(target_device)
            tokenizer = AutoTokenizer.from_pretrained(
                checkpoint,
                cache_dir=hub_cache
            )
        else:  # question_focused
            base_checkpoint = "GanjinZero/biobart-base"
            backbone = AutoModelForSeq2SeqLM.from_pretrained(
                base_checkpoint,
                cache_dir=hub_cache,
                torch_dtype=torch.float32
            ).to(target_device)
            model = PeftModel.from_pretrained(
                backbone,
                "pendar02/biobart-finetune",
                is_trainable=False
            ).to(target_device)
            tokenizer = AutoTokenizer.from_pretrained(
                base_checkpoint,
                cache_dir=hub_cache
            )

        model.eval()
        return model, tokenizer
    except Exception as load_error:
        st.error(f"Error loading model: {str(load_error)}")
        raise
def get_model(model_type):
    """Return the session-cached ``(model, tokenizer)`` for ``model_type``.

    The pair stored in ``st.session_state`` is reused when its recorded
    ``model_type`` matches; otherwise a new pair is loaded with ``load_model``
    and stored in its place.

    Returns:
        ``(model, tokenizer)``, or ``(None, None)`` if loading failed. On
        failure the error is shown with ``st.error`` and
        ``processing_started`` is set back to ``False``.
    """
    try:
        state = st.session_state
        if state.current_model is None or state.model_type != model_type:
            if state.current_model is not None:
                # Drop the session's references BEFORE loading the replacement.
                # While session_state still points at the old model, no
                # garbage-collection pass can release it, so both sets of
                # weights would be resident during the load.
                state.current_model = None
                state.current_tokenizer = None
                state.model_type = None
                # Nothing is passed in: cleanup_model is used here only for
                # its collection pass now that the references are gone.
                cleanup_model(None, None)

            # Load new model
            model, tokenizer = load_model(model_type)
            state.current_model = model
            state.current_tokenizer = tokenizer
            state.model_type = model_type
        return state.current_model, state.current_tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.session_state.processing_started = False
        return None, None
def cleanup_model(model, tokenizer):
    """Best-effort release of memory after a model is no longer needed.

    ``del`` on the parameters only removes this function's own names; objects
    the caller (or ``st.session_state``) still references stay alive. The
    useful work is therefore the collection pass that follows.

    Args:
        model: Model object being discarded (may be ``None``).
        tokenizer: Tokenizer object being discarded (may be ``None``).

    Returns:
        None. Never raises for CUDA runtime errors.
    """
    del model
    del tokenizer
    # Collect first so that cyclic garbage is gone before the CUDA caching
    # allocator is asked to return its unused blocks.
    gc.collect()
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
        except RuntimeError:
            # Deliberately ignored: failing to trim the cache must not abort
            # the caller's own error handling or cleanup path.
            pass
def process_excel(uploaded_file, max_papers=5):
    """Read an uploaded Excel file and return its required columns.

    Args:
        uploaded_file: Object accepted by ``pd.read_excel``.
        max_papers: Largest number of rows accepted (default 5).

    Returns:
        A DataFrame restricted to the required columns, or ``None`` when the
        file cannot be read, a required column is missing, there are too many
        rows, or ``validate_excel_structure`` rejects the content. Every
        rejection is reported to the user with ``st.error``.
    """
    required_columns = ['Abstract', 'Article Title', 'Authors',
                        'Source Title', 'Publication Year', 'DOI',
                        'Times Cited, All Databases']

    # Only the read is guarded, so that failures in later steps are not
    # reported to the user as a file-format problem.
    try:
        df = pd.read_excel(uploaded_file)
    except Exception as e:
        st.error(f"β Error reading file: {str(e)}")
        st.error("Please check if your file is in the correct Excel format (.xlsx or .xls)")
        return None

    # Check required columns first
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error("β Missing required columns: " + ", ".join(missing_columns))
        st.error("Please ensure your Excel file contains all required columns.")
        return None

    # Only proceed with validation if all required columns exist
    if len(df) > max_papers:
        st.error(
            f"β Your file contains more than {max_papers} papers. "
            f"Please upload a file with maximum {max_papers} papers."
        )
        return None

    # Now safe to validate structure as we know columns exist
    is_valid, messages = validate_excel_structure(df)
    if not is_valid:
        for msg in messages:
            st.error(f"β {msg}")
        return None

    return df[required_columns]
def validate_excel_structure(df):
    """Validate the content of the uploaded papers table.

    Side effect: ``df['Publication Year']`` is converted in place with
    ``pd.to_numeric(..., errors='coerce')``.

    Args:
        df: DataFrame with at least ``'Publication Year'`` and ``'Abstract'``
            columns.

    Returns:
        ``(is_valid, messages)`` where ``messages`` lists every problem found
        and ``is_valid`` is ``True`` only when that list is empty. Abstracts
        shorter than 50 characters only produce an ``st.warning``.
    """
    # Local import so this fix does not depend on the file's import block.
    from datetime import date

    validation_messages = []

    # Check for minimum content
    if len(df) == 0:
        validation_messages.append("File contains no data")
        return False, validation_messages

    earliest_year = 1900
    # The upper bound follows the calendar instead of a fixed year, so the
    # check does not start rejecting new papers as time passes. One extra year
    # is allowed for records dated ahead of print.
    latest_year = date.today().year + 1

    try:
        # Check publication year format - this is useful for sorting/filtering
        df['Publication Year'] = pd.to_numeric(df['Publication Year'], errors='coerce')
        years = df['Publication Year']
        if years.isna().any():
            validation_messages.append("Some publication years are invalid. Please ensure all years are in numeric format (e.g., 2024)")
        elif years.min() < earliest_year or years.max() > latest_year:
            validation_messages.append(
                f"Publication years must be between {earliest_year} and {latest_year}"
            )

        # For short abstracts - just show a warning
        short_abstracts = df['Abstract'].fillna('').astype(str).str.len() < 50
        if short_abstracts.any():
            st.warning("βΉοΈ Some abstracts are quite short, but will still be processed")
    except Exception as e:
        validation_messages.append(f"Error checking data format: {str(e)}")

    return len(validation_messages) == 0, validation_messages
def preprocess_text(text):
    """Clean biomedical text by handling common formatting issues and standardizing structure.

    Steps, in order: collapse whitespace, turn parenthesised or ``x)`` style
    roman numerals (i-x) into digits, rewrite section keywords to the four
    canonical headers, patch sentence endings after headers, strip bullet
    glyphs, tidy spacing around colons/parentheses/statistics, collapse
    repeated punctuation, and make sure the result ends with a period.

    NOTE(review): the header patterns make the colon optional, so words such
    as "results" or "methods" are rewritten to a header wherever they occur,
    not only at the start of a section.

    Args:
        text: Raw abstract text.

    Returns:
        The cleaned string. Non-string or blank input is returned unchanged.
    """
    if not isinstance(text, str) or not text.strip():
        return text

    # Remove extra whitespace
    text = ' '.join(text.split())

    # Roman numeral conversion
    roman_map = {'i': '1', 'ii': '2', 'iii': '3', 'iv': '4', 'v': '5',
                 'vi': '6', 'vii': '7', 'viii': '8', 'ix': '9', 'x': '10'}

    def replace_roman(match):
        roman = match.group(1).lower()
        return f"({roman_map.get(roman, roman)})"

    text = re.sub(r'\(([ivx]+)\)', replace_roman, text)

    # Clean enumerated lists
    for roman in roman_map:
        text = re.sub(f"\\b{roman}\\)", f"{roman_map[roman]})", text, flags=re.IGNORECASE)

    # Standardize section headers
    section_patterns = {
        r'\b(?:introduction|purpose|background|objectives?|context)\s*:?\s*': 'Background: ',
        r'\b(?:materials?\s+and\s+methods?|methods?|approach|study\s+design)\s*:?\s*': 'Methods: ',
        r'\b(?:results?|findings?|observations?)\s*:?\s*': 'Results: ',
        r'\b(?:conclusions?|summary|final\s+remarks?)\s*:?\s*': 'Conclusions: ',
        r'\b(?:results?\s+and\s+conclusions?)\s*:?\s*(?=.*?:)': '',  # Remove if followed by another section
        r'\b(?:results?\s*:\s*and\s*conclusions?\s*:)': 'Results: '  # Fix malformed combination
    }
    for pattern, replacement in section_patterns.items():
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)

    # Ensure complete sentences in sections
    text = re.sub(r'(?<=:)\s*([^.!?\n]*?)(?=\s*(?:[A-Z][^:]*:|$))',
                  lambda m: f" {m.group(1)}." if m.group(1) and not m.group(1).strip().endswith('.') else m.group(0),
                  text)

    # Fix truncated sentences
    text = re.sub(r'(?<=:)\s*([^.!?\n]*?)\s*(?=[A-Z][^:]*:)',
                  lambda m: f" {m.group(1)}." if m.group(1) else "",
                  text)

    # Clean formatting
    text = re.sub(r'[\r\n]+', ' ', text)
    text = re.sub(r'\s*:\s*', ': ', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', ' ', text)
    # Strip list-bullet glyphs. They are written as escapes because the
    # previous literal was mis-decoded and ended up matching the Greek letter
    # beta, which deleted it from terms such as "TGF-beta" written with the
    # symbol. NOTE(review): the glyph set (bullet, black/white square,
    # black/white circle, asterisk) is reconstructed - confirm it is complete.
    text = re.sub(r'[\u2022\u25a0\u25a1\u25cf\u25cb*]', '', text)
    text = re.sub(r'\\n|\\r', ' ', text)
    text = re.sub(r'\s*\(\s*', ' (', text)
    text = re.sub(r'\s*\)\s*', ') ', text)
    # The line above always leaves a space after ")"; remove it again when
    # punctuation follows, so "(n=5)." does not become "(n=5) .".
    text = re.sub(r'\)\s+(?=[.,;:!?])', ')', text)

    # Fix statistical notations
    text = re.sub(r'p\s*[<=>]\s*0\.\d+', lambda m: m.group().replace(' ', ''), text)
    text = re.sub(r'(?<=\d)\s*%', '%', text)

    # Fix abbreviations spacing
    text = re.sub(r'(?<=\w)vs\.(?=\w)', 'vs. ', text)
    text = re.sub(r'(?<=\w)et\s+al\.(?=\w)', 'et al. ', text)

    # Remove repeated punctuation
    text = re.sub(r'([.!?])\1+', r'\1', text)

    # Final cleanup
    text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', ' ', text)
    text = text.strip()
    if not text.endswith('.'):
        text += '.'
    return text
| # """Enhanced text preprocessing with better section handling and prompt removal.""" | |
| # if not isinstance(text, str) or not text.strip(): | |
| # return text | |
| # # Remove prompt leakage | |
| # prompt_patterns = [ | |
| # r'Generate a structured summary addressing this question:.*?(?=\w+:)', | |
| # r'Focus on key findings and methods\.', | |
| # r'is a structured summary addressing this question:' | |
| # ] | |
| # for pattern in prompt_patterns: | |
| # text = re.sub(pattern, '', text, flags=re.IGNORECASE) | |
| # # Clean section headers more aggressively | |
| # section_patterns = { | |
| # r'\b(?:introduction|purpose|background|objectives?|context)\s*:?\s*': 'Background: ', | |
| # r'\b(?:materials?\s+and\s+methods?|methods?|approach|study\s+design)\s*:?\s*': 'Methods: ', | |
| # r'\b(?:results?|findings?|observations?)\s*:?\s*': 'Results: ', | |
| # r'\b(?:conclusions?|summary|final\s+remarks?)\s*:?\s*': 'Conclusions: ' | |
| # } | |
| # # Apply section normalization | |
| # for pattern, replacement in section_patterns.items(): | |
| # text = re.sub(pattern, replacement, text, flags=re.IGNORECASE) | |
| # # Remove combined section headers | |
| # combined_headers = [ | |
| # r'\bmethods?\s+and\s+conclusions?\b', | |
| # r'\bresults?\s+and\s+conclusions?\b', | |
| # r'\bmaterials?\s+and\s+methods?\b' | |
| # ] | |
| # for pattern in combined_headers: | |
| # text = re.sub(pattern, 'Methods:', text, flags=re.IGNORECASE) | |
| # # Clean up sentences | |
| # sentences = text.split('.') | |
| # cleaned_sentences = [] | |
| # for sentence in sentences: | |
| # # Remove redundant section references | |
| # sentence = re.sub(r'\b(?:first|second|third|fourth|fifth)\s+sections?\b', '', sentence, flags=re.IGNORECASE) | |
| # # Remove comparative phrases about section details | |
| # sentence = re.sub(r'\b(?:more|less)\s+detailed\s+than.*', '', sentence, flags=re.IGNORECASE) | |
| # if sentence.strip(): | |
| # cleaned_sentences.append(sentence.strip()) | |
| # # Rejoin and format | |
| # text = '. '.join(cleaned_sentences) | |
| # text = re.sub(r'\s+', ' ', text) # Remove extra spaces | |
| # text = re.sub(r'\s*:\s*', ': ', text) # Fix spacing around colons | |
| # return text.strip() | |
def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate one structured summary from a list of abstracts.

    Each usable abstract is cleaned with ``preprocess_text``, the results are
    joined with ``" [SEP] "`` and placed in a fixed instruction prompt, and the
    model output is passed through ``post_process_summary``.

    NOTE(review): ``question`` is accepted but is not inserted into the
    prompt, so the output does not depend on it - confirm this is intended.

    Args:
        question: Research question (currently unused, see note).
        abstracts: Iterable of abstract texts; non-string and blank entries
            are skipped.
        model: Seq2seq model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The post-processed summary, or ``""`` when no usable abstract was
        supplied (the model is not run in that case).
    """
    formatted_abstracts = [
        preprocess_text(abstract)
        for abstract in abstracts
        # Spreadsheet cells can be NaN or numeric; only real text is usable.
        if isinstance(abstract, str) and abstract.strip()
    ]
    if not formatted_abstracts:
        return ""

    abstracts_content = " [SEP] ".join(formatted_abstracts)
    # Built line by line so that no source-code indentation ends up in the
    # text sent to the tokenizer.
    prompt = (
        "\n"
        "Provide a factual summary structured as:\n"
        "- Background: Context and origin only if present\n"
        "- Methods: Key procedures and approaches\n"
        "- Results: Specific findings with numbers\n"
        "- Conclusions: Main implications\n"
        "Requirements:\n"
        "- Present sections sequentially\n"
        "- Merge related points within sections\n"
        "- Complete all sentences\n"
        "- Avoid repeating section headers\n"
        "- Use original terminology\n"
        f"Content: {abstracts_content}\n"
    )

    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        # Beam search without sampling. ``temperature`` is intentionally not
        # passed: it only applies when sampling, so with do_sample=False it
        # had no effect and only produced a configuration warning.
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=512,
            min_length=200,
            num_beams=4,
            length_penalty=2.0,
            no_repeat_ngram_size=3,
            do_sample=False,
        )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return post_process_summary(summary)
def post_process_summary(summary):
    """Regroup a generated summary under the four canonical section headers.

    The text is cut at each ``Background:``, ``Methods:``, ``Results:`` or
    ``Conclusions:`` header (case-insensitive). Text belonging to the same
    header is merged, and the sections are emitted in that fixed order. Text
    before the first header is discarded.

    The text is split at headers and not at periods, so decimals such as
    ``0.05`` and sentence boundaries inside a section are preserved.

    Args:
        summary: Raw decoded model output.

    Returns:
        ``"Section: text"`` entries joined by single spaces. If the text
        contains no header at all it is returned stripped but otherwise
        unchanged. Falsy input is returned as-is.
    """
    if not summary:
        return summary

    valid_sections = ['Background', 'Methods', 'Results', 'Conclusions']

    # Pre-clean section headers
    summary = re.sub(r'\b(?:results?\s*:\s*and\s*conclusions?\s*:)', 'Results:', summary, flags=re.IGNORECASE)
    summary = re.sub(r'\bresults?\s*and\s*conclusions?\s*:', 'Results:', summary, flags=re.IGNORECASE)

    # With one capture group, split() yields
    # [preamble, header, body, header, body, ...].
    header_pattern = re.compile(r'\b(' + '|'.join(valid_sections) + r'):\s*', re.IGNORECASE)
    pieces = header_pattern.split(summary)
    if len(pieces) == 1:
        # No recognizable header: keep the text and do not return nothing.
        return summary.strip()

    sections = {}
    for header, body in zip(pieces[1::2], pieces[2::2]):
        body = body.strip()
        if body:
            # A header that appears more than once contributes to one section.
            sections.setdefault(header.capitalize(), []).append(body)

    # Format sections
    formatted_sections = []
    for section in valid_sections:
        chunks = sections.get(section)
        if not chunks:
            continue
        # Close every merged chunk except the last so sentences do not run
        # into each other.
        closed = [c if re.search(r'[.!?]$', c) else c + '.' for c in chunks[:-1]]
        content = ' '.join(closed + [chunks[-1]])
        # Complete truncated sentences
        if not re.search(r'[.!?]$', content):
            if len(content.split()) >= 3:  # Only complete if substantial
                content += '.'
        # Ensure capitalization
        content = content[0].upper() + content[1:]
        # Fix double periods
        content = re.sub(r'\.+', '.', content)
        formatted_sections.append(f"{section}: {content}")

    return ' '.join(formatted_sections)
def process_papers_in_batches(df, model, tokenizer, batch_size=2):
    """Summarise every abstract in ``df['Abstract']`` on a thread pool.

    One ``generate_focused_summary`` call is submitted per abstract to a pool
    of four worker threads. ``batch_size`` is accepted but not used by this
    implementation.

    Returns:
        The summaries as a list, in the same order as the abstracts.
    """
    focus_instruction = "Focus on key findings and methods."
    pending = []
    with ThreadPoolExecutor(max_workers=4) as pool:  # Parallel processing
        for abstract_text in df['Abstract'].tolist():
            job = pool.submit(
                generate_focused_summary,
                focus_instruction,
                [abstract_text],
                model,
                tokenizer,
            )
            pending.append(job)
        # Results are read in submission order, not completion order.
        summaries = [job.result() for job in pending]
    return summaries
def create_filter_controls(df, sort_column):
    """Create appropriate filter controls based on the selected column.

    Renders Streamlit widgets for ``'Publication Year'`` (two number inputs),
    ``'Authors'`` and ``'Source Title'`` (multiselects) and applies the user's
    choice. Any other column gets no widget and no filtering.

    Args:
        df: DataFrame of papers to filter.
        sort_column: Name of the column chosen in the sort/filter selector.

    Returns:
        A filtered copy of ``df``; the input is not modified.
    """
    filtered_df = df.copy()

    if sort_column == 'Publication Year':
        # Year range inputs
        year_min = int(df['Publication Year'].min())
        year_max = int(df['Publication Year'].max())
        col1, col2 = st.columns(2)
        with col1:
            start_year = st.number_input('From Year',
                                         min_value=year_min,
                                         max_value=year_max,
                                         value=year_min)
        with col2:
            end_year = st.number_input('To Year',
                                       min_value=year_min,
                                       max_value=year_max,
                                       value=year_max)
        filtered_df = filtered_df[
            (filtered_df['Publication Year'] >= start_year) &
            (filtered_df['Publication Year'] <= end_year)
        ]

    elif sort_column == 'Authors':
        def split_authors(cell):
            # Missing or non-text cells contribute no authors.
            if not isinstance(cell, str):
                return set()
            return {name.strip() for name in cell.split(';') if name.strip()}

        author_sets = df['Authors'].apply(split_authors)
        unique_authors = sorted(set().union(*author_sets))
        selected_authors = st.multiselect(
            'Select Authors',
            unique_authors
        )
        if selected_authors:
            wanted = set(selected_authors)
            # Whole-name comparison: a substring test would make a selection
            # such as "Li" also match "Liu".
            keep = author_sets.apply(lambda names: not wanted.isdisjoint(names)).astype(bool)
            filtered_df = filtered_df[keep]

    elif sort_column == 'Source Title':
        # Missing titles are dropped first; sorting text mixed with NaN raises.
        unique_sources = sorted(df['Source Title'].dropna().unique())
        selected_sources = st.multiselect(
            'Select Sources',
            unique_sources
        )
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]

    # 'Article Title' and any other column: sorting only, no filtering.
    return filtered_df
def reset_processing_state():
    """Return the workflow to its state before "Start Analysis" was pressed.

    ``main`` calls this on its failure paths, but no definition existed in
    this file, so every such path ended in a NameError. Individual summaries
    that were already generated are kept so that a retry does not recompute
    them.
    """
    st.session_state.processing_started = False
    st.session_state.focused_summary_generated = False
    st.session_state.focused_summary = None
    st.session_state.relevant_papers = None
    st.session_state.relevance_scores = None


def main():
    """Render the page: upload, individual summaries, question-focused summary.

    State that must survive Streamlit reruns (parsed data, summaries, loaded
    model) is kept in ``st.session_state``.

    NOTE(review): several user-facing strings below begin with characters that
    look like mis-decoded emoji; they are reproduced unchanged here.
    """
    # Local import: only needed to escape spreadsheet text placed in HTML.
    import html

    def as_html_text(value):
        # Uploaded cells are untrusted and end up inside unsafe_allow_html
        # markup, so they are always escaped.
        return html.escape(str(value))

    st.title("π¬ Biomedical Papers Analysis")
    st.info(
        "\n"
        "**π File Upload Requirements:**\n"
        "- Excel file (.xlsx or .xls) with **maximum 5 papers**\n"
        "- Must contain these columns:\n"
        "β’ Abstract\n"
        "β’ Article Title\n"
        "β’ Authors\n"
        "β’ Source Title\n"
        "β’ Publication Year\n"
        "β’ DOI\n"
        "β’ Times Cited, All Databases\n"
    )

    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers (max 5 papers)",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )

    # Question input - placeholder reserved here, filled once data is loaded
    question_container = st.empty()
    question = ""

    if uploaded_file is None:
        return

    # Process Excel file
    if st.session_state.processed_data is None:
        with st.spinner("Processing file..."):
            df = process_excel(uploaded_file)
            if df is not None:
                df = df.dropna(subset=["Abstract"])
                if len(df) > 0:
                    st.session_state.processed_data = df
                    st.success(f"β Successfully loaded {len(df)} papers with abstracts")
                else:
                    st.error("β No valid papers found after processing. Please check your file.")

    if st.session_state.processed_data is None:
        return

    df = st.session_state.processed_data
    st.write(f"π Loaded {len(df)} papers with abstracts")

    # Get question before processing
    with question_container:
        question = st.text_input(
            "Enter your research question (optional):",
            help="If provided, a question-focused summary will be generated after individual summaries"
        )

    # Single button for both processes
    if not st.session_state.get('processing_started', False):
        if st.button("Start Analysis"):
            st.session_state.processing_started = True

    # Show processing status and results
    if not st.session_state.get('processing_started', False):
        return

    # Individual Summaries Section
    st.header("π Individual Paper Summaries")

    # Generate summaries if not already done
    if st.session_state.summaries is None:
        try:
            with st.spinner("Generating individual paper summaries..."):
                model, tokenizer = get_model("summarize")
                if model is None or tokenizer is None:
                    reset_processing_state()
                    return
                start_time = time.time()
                st.session_state.summaries = process_papers_in_batches(
                    df, model, tokenizer, batch_size=2
                )
                end_time = time.time()
                st.write(f"Processing time: {end_time - start_time:.2f} seconds")
        except Exception as e:
            st.error(f"Error generating summaries: {str(e)}")
            reset_processing_state()

    # Display summaries with sorting and filtering
    if st.session_state.summaries is not None:
        col1, col2 = st.columns(2)
        with col1:
            sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited']
            sort_column = st.selectbox("Sort/Filter by:", sort_options)
        with col2:
            if sort_column == 'Article Title':
                ascending = st.radio(
                    "Sort order",
                    ["A to Z", "Z to A"],
                    horizontal=True
                ) == "A to Z"
            elif sort_column == 'Times Cited':
                ascending = st.radio(
                    "Sort order",
                    ["Most cited first", "Least cited first"],
                    horizontal=True
                ) == "Least cited first"
            else:
                ascending = True  # Default for other columns

        # Create display dataframe
        display_df = df.copy()
        display_df['Summary'] = st.session_state.summaries
        display_df['Publication Year'] = display_df['Publication Year'].astype(int)
        display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True)
        display_df['Times Cited'] = display_df['Times Cited'].fillna(0).astype(int)

        # Apply filters
        filtered_df = create_filter_controls(display_df, sort_column)

        # Apply sorting (only these two columns are sorted; others only filter)
        if sort_column in ('Times Cited', 'Article Title'):
            sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending)
        else:
            sorted_df = filtered_df

        # Show number of filtered results
        if len(sorted_df) != len(display_df):
            st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")

        # Apply custom styling
        st.markdown(
            "\n".join([
                "<style>",
                ".paper-info {",
                "border: 1px solid #ddd;",
                "padding: 15px;",
                "margin-bottom: 20px;",
                "border-radius: 5px;",
                "}",
                ".paper-section {",
                "margin-bottom: 10px;",
                "}",
                ".section-header {",
                "font-weight: bold;",
                "color: #555;",
                "margin-bottom: 8px;",
                "}",
                ".paper-title {",
                "margin-top: 5px;",
                "margin-bottom: 10px;",
                "}",
                ".paper-meta {",
                "font-size: 0.9em;",
                "color: #666;",
                "}",
                ".doi-link {",
                "color: #0366d6;",
                "}",
                "</style>",
            ]),
            unsafe_allow_html=True,
        )

        # Display papers using the filtered and sorted dataframe
        for _, row in sorted_df.iterrows():
            doi_text = row['DOI'] if pd.notna(row['DOI']) else 'None'
            paper_info_cols = st.columns([1, 1])
            with paper_info_cols[0]:  # PAPER column
                st.markdown('<div class="paper-section"><div class="section-header">PAPER</div>', unsafe_allow_html=True)
                st.markdown(
                    '<div class="paper-info">'
                    f'<div class="paper-title">{as_html_text(row["Article Title"])}</div>'
                    '<div class="paper-meta">'
                    f'<strong>Authors:</strong> {as_html_text(row["Authors"])}<br>'
                    f'<strong>Source:</strong> {as_html_text(row["Source Title"])}<br>'
                    f'<strong>Publication Year:</strong> {as_html_text(row["Publication Year"])}<br>'
                    f'<strong>Times Cited:</strong> {as_html_text(row["Times Cited"])}<br>'
                    f'<strong>DOI:</strong> {as_html_text(doi_text)}'
                    '</div>'
                    '</div>',
                    unsafe_allow_html=True,
                )
            with paper_info_cols[1]:  # SUMMARY column
                st.markdown('<div class="paper-section"><div class="section-header">SUMMARY</div>', unsafe_allow_html=True)
                st.markdown(
                    '<div class="paper-info">'
                    f'{as_html_text(row["Summary"])}'
                    '</div>',
                    unsafe_allow_html=True,
                )
            # Add spacing between papers
            st.markdown("<div style='margin-bottom: 20px;'></div>", unsafe_allow_html=True)

    # Question-focused Summary Section (only if question provided)
    if question.strip():
        st.header("β Question-focused Summary")

        if not st.session_state.get('focused_summary_generated', False):
            # Bound before the try so the finally clause can never hit an
            # unbound local when an early step raises.
            model = tokenizer = None
            try:
                with st.spinner("Analyzing relevant papers..."):
                    if st.session_state.text_processor is None:
                        st.session_state.text_processor = TextProcessor()

                    model, tokenizer = get_model("question_focused")
                    if model is None or tokenizer is None:
                        raise Exception("Failed to load question-focused model")

                    results = st.session_state.text_processor.find_most_relevant_abstracts(
                        question,
                        df['Abstract'].tolist(),
                        top_k=5
                    )
                    if not results['top_indices']:
                        st.warning("No papers found relevant to your question")
                        return

                    # Store relevant papers and scores
                    st.session_state.relevant_papers = df.iloc[results['top_indices']]
                    st.session_state.relevance_scores = results['scores']

                    relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
                    st.session_state.focused_summary = generate_focused_summary(
                        question,
                        relevant_abstracts,
                        model,
                        tokenizer
                    )
                    st.session_state.focused_summary_generated = True
            except Exception as e:
                st.error(f"Error generating focused summary: {str(e)}")
                reset_processing_state()
            finally:
                cleanup_model(model, tokenizer)

        # Display focused summary results
        if st.session_state.get('focused_summary_generated', False):
            st.subheader("Summary")
            st.write(st.session_state.focused_summary)

            st.subheader("Most Relevant Papers")
            relevant_papers = st.session_state.relevant_papers[
                ['Article Title', 'Authors', 'Publication Year', 'DOI']
            ].copy()
            relevant_papers['Relevance Score'] = st.session_state.relevance_scores
            relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
            st.dataframe(relevant_papers, hide_index=True)


if __name__ == "__main__":
    main()