Spaces:
Paused
Paused
lanny xu
commited on
Commit
·
2d46508
1
Parent(s):
9cce495
add async
Browse files- document_processor.py +113 -0
- main.py +19 -14
- workflow_nodes.py +8 -8
document_processor.py
CHANGED
|
@@ -307,6 +307,119 @@ class DocumentProcessor:
|
|
| 307 |
# 返回doc_splits用于GraphRAG索引
|
| 308 |
return vectorstore, retriever, doc_splits
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
def expand_query(self, query: str) -> List[str]:
|
| 311 |
"""扩展查询,生成相关查询"""
|
| 312 |
if not self.query_expansion_model:
|
|
|
|
| 307 |
# 返回doc_splits用于GraphRAG索引
|
| 308 |
return vectorstore, retriever, doc_splits
|
| 309 |
|
| 310 |
+
async def async_expand_query(self, query: str) -> List[str]:
|
| 311 |
+
"""异步扩展查询"""
|
| 312 |
+
if not self.query_expansion_model:
|
| 313 |
+
return [query]
|
| 314 |
+
|
| 315 |
+
try:
|
| 316 |
+
# 使用LLM生成扩展查询
|
| 317 |
+
prompt = QUERY_EXPANSION_PROMPT.format(query=query)
|
| 318 |
+
expanded_queries_text = await self.query_expansion_model.ainvoke(prompt)
|
| 319 |
+
|
| 320 |
+
# 解析扩展查询
|
| 321 |
+
expanded_queries = [query] # 包含原始查询
|
| 322 |
+
for line in expanded_queries_text.strip().split('\n'):
|
| 323 |
+
line = line.strip()
|
| 324 |
+
if line and not line.startswith('#') and not line.startswith('//'):
|
| 325 |
+
# 移除可能的编号前缀
|
| 326 |
+
if line[0].isdigit() and '.' in line[:5]:
|
| 327 |
+
line = line.split('.', 1)[1].strip()
|
| 328 |
+
expanded_queries.append(line)
|
| 329 |
+
|
| 330 |
+
# 限制扩展查询数量
|
| 331 |
+
return expanded_queries[:MAX_EXPANDED_QUERIES + 1]
|
| 332 |
+
except Exception as e:
|
| 333 |
+
print(f"⚠️ 异步查询扩展失败: {e}")
|
| 334 |
+
return [query]
|
| 335 |
+
|
| 336 |
+
async def async_hybrid_retrieve(self, query: str, top_k: int = 5) -> List:
|
| 337 |
+
"""异步混合检索"""
|
| 338 |
+
if not ENABLE_HYBRID_SEARCH or not self.ensemble_retriever:
|
| 339 |
+
return await self.retriever.ainvoke(query)
|
| 340 |
+
|
| 341 |
+
try:
|
| 342 |
+
results = await self.ensemble_retriever.ainvoke(query)
|
| 343 |
+
return results[:top_k]
|
| 344 |
+
except Exception as e:
|
| 345 |
+
print(f"⚠️ 异步混合检索失败: {e}")
|
| 346 |
+
print("回退到向量检索")
|
| 347 |
+
return await self.retriever.ainvoke(query)
|
| 348 |
+
|
| 349 |
+
async def async_enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20,
|
| 350 |
+
image_paths: List[str] = None, use_query_expansion: bool = None):
|
| 351 |
+
"""异步增强检索"""
|
| 352 |
+
import asyncio
|
| 353 |
+
|
| 354 |
+
# 确定是否使用查询扩展
|
| 355 |
+
if use_query_expansion is None:
|
| 356 |
+
use_query_expansion = ENABLE_QUERY_EXPANSION
|
| 357 |
+
|
| 358 |
+
# 如果启用查询扩展,生成扩展查询
|
| 359 |
+
if use_query_expansion:
|
| 360 |
+
expanded_queries = await self.async_expand_query(query)
|
| 361 |
+
print(f"查询扩展: {len(expanded_queries)} 个查询")
|
| 362 |
+
else:
|
| 363 |
+
expanded_queries = [query]
|
| 364 |
+
|
| 365 |
+
# 多模态检索(暂时保持同步,使用线程池)
|
| 366 |
+
if image_paths and ENABLE_MULTIMODAL:
|
| 367 |
+
loop = asyncio.get_running_loop()
|
| 368 |
+
return await loop.run_in_executor(None, self.multimodal_retrieve, query, image_paths, top_k)
|
| 369 |
+
|
| 370 |
+
# 混合检索或向量检索
|
| 371 |
+
all_candidate_docs = []
|
| 372 |
+
|
| 373 |
+
async def retrieve_single(q):
|
| 374 |
+
if ENABLE_HYBRID_SEARCH:
|
| 375 |
+
docs = await self.async_hybrid_retrieve(q, rerank_candidates)
|
| 376 |
+
else:
|
| 377 |
+
docs = await self.retriever.ainvoke(q)
|
| 378 |
+
if len(docs) > rerank_candidates:
|
| 379 |
+
docs = docs[:rerank_candidates]
|
| 380 |
+
return docs
|
| 381 |
+
|
| 382 |
+
# 并发执行所有查询的检索
|
| 383 |
+
results = await asyncio.gather(*[retrieve_single(q) for q in expanded_queries])
|
| 384 |
+
|
| 385 |
+
for docs in results:
|
| 386 |
+
all_candidate_docs.extend(docs)
|
| 387 |
+
|
| 388 |
+
# 去重(基于文档内容)
|
| 389 |
+
unique_docs = []
|
| 390 |
+
seen_content = set()
|
| 391 |
+
for doc in all_candidate_docs:
|
| 392 |
+
content = doc.page_content
|
| 393 |
+
if content not in seen_content:
|
| 394 |
+
seen_content.add(content)
|
| 395 |
+
unique_docs.append(doc)
|
| 396 |
+
|
| 397 |
+
print(f"检索获得 {len(unique_docs)} 个候选文档")
|
| 398 |
+
|
| 399 |
+
# 重排(如果重排器可用)
|
| 400 |
+
# 注意:重排通常是计算密集型,建议放入线程池
|
| 401 |
+
if self.reranker and len(unique_docs) > top_k:
|
| 402 |
+
try:
|
| 403 |
+
loop = asyncio.get_running_loop()
|
| 404 |
+
# rerank 方法内部可能也比较耗时
|
| 405 |
+
reranked_results = await loop.run_in_executor(
|
| 406 |
+
None,
|
| 407 |
+
self.reranker.rerank,
|
| 408 |
+
query, unique_docs, top_k
|
| 409 |
+
)
|
| 410 |
+
final_docs = [doc for doc, score in reranked_results]
|
| 411 |
+
scores = [score for doc, score in reranked_results]
|
| 412 |
+
|
| 413 |
+
print(f"重排后返回 {len(final_docs)} 个文档")
|
| 414 |
+
print(f"重排分数范围: {min(scores):.4f} - {max(scores):.4f}")
|
| 415 |
+
|
| 416 |
+
return final_docs
|
| 417 |
+
except Exception as e:
|
| 418 |
+
print(f"⚠️ 重排失败: {e},使用原始检索结果")
|
| 419 |
+
return unique_docs[:top_k]
|
| 420 |
+
else:
|
| 421 |
+
return unique_docs[:top_k]
|
| 422 |
+
|
| 423 |
def expand_query(self, query: str) -> List[str]:
|
| 424 |
"""扩展查询,生成相关查询"""
|
| 425 |
if not self.query_expansion_model:
|
main.py
CHANGED
|
@@ -159,9 +159,9 @@ class AdaptiveRAGSystem:
|
|
| 159 |
debug=False
|
| 160 |
)
|
| 161 |
|
| 162 |
-
def query(self, question: str, verbose: bool = True):
|
| 163 |
"""
|
| 164 |
-
处理查询
|
| 165 |
|
| 166 |
Args:
|
| 167 |
question (str): 用户问题
|
|
@@ -170,6 +170,7 @@ class AdaptiveRAGSystem:
|
|
| 170 |
Returns:
|
| 171 |
dict: 包含最终答案和评估指标的字典
|
| 172 |
"""
|
|
|
|
| 173 |
print(f"\n🔍 处理问题: {question}")
|
| 174 |
print("=" * 50)
|
| 175 |
|
|
@@ -181,24 +182,19 @@ class AdaptiveRAGSystem:
|
|
| 181 |
config = {"recursion_limit": 50} # 增加到 50,默认是 25
|
| 182 |
|
| 183 |
print("\n🤖 思考过程:")
|
| 184 |
-
for output in self.app.
|
| 185 |
for key, value in output.items():
|
| 186 |
if verbose:
|
| 187 |
# 简单的节点执行提示,模拟流式感
|
| 188 |
print(f" ↳ 执行节点: {key}...", end="\r")
|
| 189 |
-
|
|
|
|
| 190 |
print(f" ✅ 完成节点: {key} ")
|
| 191 |
|
| 192 |
-
# pprint(f"节点 '{key}':")
|
| 193 |
-
# 可选:在每个节点打印完整状态
|
| 194 |
-
# pprint(value, indent=2, width=80, depth=None)
|
| 195 |
final_generation = value.get("generation", final_generation)
|
| 196 |
# 保存检索评估指标
|
| 197 |
if "retrieval_metrics" in value:
|
| 198 |
retrieval_metrics = value["retrieval_metrics"]
|
| 199 |
-
if verbose:
|
| 200 |
-
# pprint("\n---\n")
|
| 201 |
-
pass
|
| 202 |
|
| 203 |
print("\n" + "=" * 50)
|
| 204 |
print("🎯 最终答案:")
|
|
@@ -207,11 +203,11 @@ class AdaptiveRAGSystem:
|
|
| 207 |
# 模拟流式输出效果 (打字机效果)
|
| 208 |
if final_generation:
|
| 209 |
import sys
|
| 210 |
-
import time
|
| 211 |
for char in final_generation:
|
| 212 |
sys.stdout.write(char)
|
| 213 |
sys.stdout.flush()
|
| 214 |
-
|
|
|
|
| 215 |
print() # 换行
|
| 216 |
else:
|
| 217 |
print("未生成答案")
|
|
@@ -226,6 +222,7 @@ class AdaptiveRAGSystem:
|
|
| 226 |
|
| 227 |
def interactive_mode(self):
|
| 228 |
"""交互模式,允许用户持续提问"""
|
|
|
|
| 229 |
print("\n🤖 欢迎使用自适应RAG系统!")
|
| 230 |
print("💡 输入问题开始对话,输入 'quit' 或 'exit' 退出")
|
| 231 |
print("-" * 50)
|
|
@@ -242,7 +239,8 @@ class AdaptiveRAGSystem:
|
|
| 242 |
print("⚠️ 请输入一个有效的问题")
|
| 243 |
continue
|
| 244 |
|
| 245 |
-
|
|
|
|
| 246 |
|
| 247 |
# 显示检索评估摘要
|
| 248 |
if result.get("retrieval_metrics"):
|
|
@@ -259,11 +257,14 @@ class AdaptiveRAGSystem:
|
|
| 259 |
break
|
| 260 |
except Exception as e:
|
| 261 |
print(f"❌ 发生错误: {e}")
|
|
|
|
|
|
|
| 262 |
print("请重试或输入 'quit' 退出")
|
| 263 |
|
| 264 |
|
| 265 |
def main():
|
| 266 |
"""主函数"""
|
|
|
|
| 267 |
try:
|
| 268 |
# 初始化系统
|
| 269 |
rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
|
|
@@ -272,7 +273,9 @@ def main():
|
|
| 272 |
# test_question = "AlphaCodium论文讲的是什么?"
|
| 273 |
test_question = "LangGraph的作者目前在哪家公司工作?"
|
| 274 |
# test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
|
| 275 |
-
|
|
|
|
|
|
|
| 276 |
|
| 277 |
# 显示测试查询的检索评估摘要
|
| 278 |
if result.get("retrieval_metrics"):
|
|
@@ -289,6 +292,8 @@ def main():
|
|
| 289 |
|
| 290 |
except Exception as e:
|
| 291 |
print(f"❌ 系统初始化失败: {e}")
|
|
|
|
|
|
|
| 292 |
print("请检查配置和依赖是否正确安装")
|
| 293 |
|
| 294 |
|
|
|
|
| 159 |
debug=False
|
| 160 |
)
|
| 161 |
|
| 162 |
+
async def query(self, question: str, verbose: bool = True):
|
| 163 |
"""
|
| 164 |
+
处理查询 (异步版本)
|
| 165 |
|
| 166 |
Args:
|
| 167 |
question (str): 用户问题
|
|
|
|
| 170 |
Returns:
|
| 171 |
dict: 包含最终答案和评估指标的字典
|
| 172 |
"""
|
| 173 |
+
import asyncio
|
| 174 |
print(f"\n🔍 处理问题: {question}")
|
| 175 |
print("=" * 50)
|
| 176 |
|
|
|
|
| 182 |
config = {"recursion_limit": 50} # 增加到 50,默认是 25
|
| 183 |
|
| 184 |
print("\n🤖 思考过程:")
|
| 185 |
+
async for output in self.app.astream(inputs, config=config):
|
| 186 |
for key, value in output.items():
|
| 187 |
if verbose:
|
| 188 |
# 简单的节点执行提示,模拟流式感
|
| 189 |
print(f" ↳ 执行节点: {key}...", end="\r")
|
| 190 |
+
# 异步暂停
|
| 191 |
+
await asyncio.sleep(0.1)
|
| 192 |
print(f" ✅ 完成节点: {key} ")
|
| 193 |
|
|
|
|
|
|
|
|
|
|
| 194 |
final_generation = value.get("generation", final_generation)
|
| 195 |
# 保存检索评估指标
|
| 196 |
if "retrieval_metrics" in value:
|
| 197 |
retrieval_metrics = value["retrieval_metrics"]
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
print("\n" + "=" * 50)
|
| 200 |
print("🎯 最终答案:")
|
|
|
|
| 203 |
# 模拟流式输出效果 (打字机效果)
|
| 204 |
if final_generation:
|
| 205 |
import sys
|
|
|
|
| 206 |
for char in final_generation:
|
| 207 |
sys.stdout.write(char)
|
| 208 |
sys.stdout.flush()
|
| 209 |
+
# 异步暂停
|
| 210 |
+
await asyncio.sleep(0.01) # 控制打字速度
|
| 211 |
print() # 换行
|
| 212 |
else:
|
| 213 |
print("未生成答案")
|
|
|
|
| 222 |
|
| 223 |
def interactive_mode(self):
|
| 224 |
"""交互模式,允许用户持续提问"""
|
| 225 |
+
import asyncio
|
| 226 |
print("\n🤖 欢迎使用自适应RAG系统!")
|
| 227 |
print("💡 输入问题开始对话,输入 'quit' 或 'exit' 退出")
|
| 228 |
print("-" * 50)
|
|
|
|
| 239 |
print("⚠️ 请输入一个有效的问题")
|
| 240 |
continue
|
| 241 |
|
| 242 |
+
# 使用 asyncio.run 执行异步查询
|
| 243 |
+
result = asyncio.run(self.query(question))
|
| 244 |
|
| 245 |
# 显示检索评估摘要
|
| 246 |
if result.get("retrieval_metrics"):
|
|
|
|
| 257 |
break
|
| 258 |
except Exception as e:
|
| 259 |
print(f"❌ 发生错误: {e}")
|
| 260 |
+
import traceback
|
| 261 |
+
traceback.print_exc()
|
| 262 |
print("请重试或输入 'quit' 退出")
|
| 263 |
|
| 264 |
|
| 265 |
def main():
|
| 266 |
"""主函数"""
|
| 267 |
+
import asyncio
|
| 268 |
try:
|
| 269 |
# 初始化系统
|
| 270 |
rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
|
|
|
|
| 273 |
# test_question = "AlphaCodium论文讲的是什么?"
|
| 274 |
test_question = "LangGraph的作者目前在哪家公司工作?"
|
| 275 |
# test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
|
| 276 |
+
|
| 277 |
+
# 使用 asyncio.run 执行异步查询
|
| 278 |
+
result = asyncio.run(rag_system.query(test_question))
|
| 279 |
|
| 280 |
# 显示测试查询的检索评估摘要
|
| 281 |
if result.get("retrieval_metrics"):
|
|
|
|
| 292 |
|
| 293 |
except Exception as e:
|
| 294 |
print(f"❌ 系统初始化失败: {e}")
|
| 295 |
+
import traceback
|
| 296 |
+
traceback.print_exc()
|
| 297 |
print("请检查配置和依赖是否正确安装")
|
| 298 |
|
| 299 |
|
workflow_nodes.py
CHANGED
|
@@ -118,9 +118,9 @@ class WorkflowNodes:
|
|
| 118 |
"retry_count": 0
|
| 119 |
}
|
| 120 |
|
| 121 |
-
def retrieve(self, state):
|
| 122 |
"""
|
| 123 |
-
检索文档
|
| 124 |
|
| 125 |
Args:
|
| 126 |
state (dict): 当前图状态
|
|
@@ -138,8 +138,8 @@ class WorkflowNodes:
|
|
| 138 |
# 检查是否有图像路径(多模态检索)
|
| 139 |
image_paths = state.get("image_paths", None)
|
| 140 |
|
| 141 |
-
#
|
| 142 |
-
documents = self.doc_processor.
|
| 143 |
question,
|
| 144 |
top_k=5,
|
| 145 |
rerank_candidates=20,
|
|
@@ -157,15 +157,15 @@ class WorkflowNodes:
|
|
| 157 |
|
| 158 |
except Exception as e:
|
| 159 |
print(f"⚠️ 增强检索失败: {e},回退到基本检索")
|
| 160 |
-
# 回退到基本检索
|
| 161 |
try:
|
| 162 |
if self.retriever is not None:
|
| 163 |
-
documents = self.retriever.
|
| 164 |
elif hasattr(self.doc_processor, 'vector_retriever') and self.doc_processor.vector_retriever is not None:
|
| 165 |
-
documents = self.doc_processor.vector_retriever.
|
| 166 |
print(" 使用 vector_retriever 作为备选")
|
| 167 |
elif hasattr(self.doc_processor, 'retriever') and self.doc_processor.retriever is not None:
|
| 168 |
-
documents = self.doc_processor.retriever.
|
| 169 |
print(" 使用 doc_processor.retriever 作为备选")
|
| 170 |
else:
|
| 171 |
print("❌ 检索器未正确初始化,返回空文档列表")
|
|
|
|
| 118 |
"retry_count": 0
|
| 119 |
}
|
| 120 |
|
| 121 |
+
async def retrieve(self, state):
|
| 122 |
"""
|
| 123 |
+
检索文档 (异步版本)
|
| 124 |
|
| 125 |
Args:
|
| 126 |
state (dict): 当前图状态
|
|
|
|
| 138 |
# 检查是否有图像路径(多模态检索)
|
| 139 |
image_paths = state.get("image_paths", None)
|
| 140 |
|
| 141 |
+
# 使用异步增强检索
|
| 142 |
+
documents = await self.doc_processor.async_enhanced_retrieve(
|
| 143 |
question,
|
| 144 |
top_k=5,
|
| 145 |
rerank_candidates=20,
|
|
|
|
| 157 |
|
| 158 |
except Exception as e:
|
| 159 |
print(f"⚠️ 增强检索失败: {e},回退到基本检索")
|
| 160 |
+
# 回退到基本检索 (同步回退,如果需要也可以改为异步)
|
| 161 |
try:
|
| 162 |
if self.retriever is not None:
|
| 163 |
+
documents = await self.retriever.ainvoke(question)
|
| 164 |
elif hasattr(self.doc_processor, 'vector_retriever') and self.doc_processor.vector_retriever is not None:
|
| 165 |
+
documents = await self.doc_processor.vector_retriever.ainvoke(question)
|
| 166 |
print(" 使用 vector_retriever 作为备选")
|
| 167 |
elif hasattr(self.doc_processor, 'retriever') and self.doc_processor.retriever is not None:
|
| 168 |
+
documents = await self.doc_processor.retriever.ainvoke(question)
|
| 169 |
print(" 使用 doc_processor.retriever 作为备选")
|
| 170 |
else:
|
| 171 |
print("❌ 检索器未正确初始化,返回空文档列表")
|