lanny xu commited on
Commit
2d46508
·
1 Parent(s): 9cce495
Files changed (3) hide show
  1. document_processor.py +113 -0
  2. main.py +19 -14
  3. workflow_nodes.py +8 -8
document_processor.py CHANGED
@@ -307,6 +307,119 @@ class DocumentProcessor:
307
  # 返回doc_splits用于GraphRAG索引
308
  return vectorstore, retriever, doc_splits
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  def expand_query(self, query: str) -> List[str]:
311
  """扩展查询,生成相关查询"""
312
  if not self.query_expansion_model:
 
307
  # 返回doc_splits用于GraphRAG索引
308
  return vectorstore, retriever, doc_splits
309
 
310
+ async def async_expand_query(self, query: str) -> List[str]:
311
+ """异步扩展查询"""
312
+ if not self.query_expansion_model:
313
+ return [query]
314
+
315
+ try:
316
+ # 使用LLM生成扩展查询
317
+ prompt = QUERY_EXPANSION_PROMPT.format(query=query)
318
+ expanded_queries_text = await self.query_expansion_model.ainvoke(prompt)
319
+
320
+ # 解析扩展查询
321
+ expanded_queries = [query] # 包含原始查询
322
+ for line in expanded_queries_text.strip().split('\n'):
323
+ line = line.strip()
324
+ if line and not line.startswith('#') and not line.startswith('//'):
325
+ # 移除可能的编号前缀
326
+ if line[0].isdigit() and '.' in line[:5]:
327
+ line = line.split('.', 1)[1].strip()
328
+ expanded_queries.append(line)
329
+
330
+ # 限制扩展查询数量
331
+ return expanded_queries[:MAX_EXPANDED_QUERIES + 1]
332
+ except Exception as e:
333
+ print(f"⚠️ 异步查询扩展失败: {e}")
334
+ return [query]
335
+
336
+ async def async_hybrid_retrieve(self, query: str, top_k: int = 5) -> List:
337
+ """异步混合检索"""
338
+ if not ENABLE_HYBRID_SEARCH or not self.ensemble_retriever:
339
+ return await self.retriever.ainvoke(query)
340
+
341
+ try:
342
+ results = await self.ensemble_retriever.ainvoke(query)
343
+ return results[:top_k]
344
+ except Exception as e:
345
+ print(f"⚠️ 异步混合检索失败: {e}")
346
+ print("回退到向量检索")
347
+ return await self.retriever.ainvoke(query)
348
+
349
+ async def async_enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20,
350
+ image_paths: List[str] = None, use_query_expansion: bool = None):
351
+ """异步增强检索"""
352
+ import asyncio
353
+
354
+ # 确定是否使用查询扩展
355
+ if use_query_expansion is None:
356
+ use_query_expansion = ENABLE_QUERY_EXPANSION
357
+
358
+ # 如果启用查询扩展,生成扩展查询
359
+ if use_query_expansion:
360
+ expanded_queries = await self.async_expand_query(query)
361
+ print(f"查询扩展: {len(expanded_queries)} 个查询")
362
+ else:
363
+ expanded_queries = [query]
364
+
365
+ # 多模态检索(暂时保持同步,使用线程池)
366
+ if image_paths and ENABLE_MULTIMODAL:
367
+ loop = asyncio.get_running_loop()
368
+ return await loop.run_in_executor(None, self.multimodal_retrieve, query, image_paths, top_k)
369
+
370
+ # 混合检索或向量检索
371
+ all_candidate_docs = []
372
+
373
+ async def retrieve_single(q):
374
+ if ENABLE_HYBRID_SEARCH:
375
+ docs = await self.async_hybrid_retrieve(q, rerank_candidates)
376
+ else:
377
+ docs = await self.retriever.ainvoke(q)
378
+ if len(docs) > rerank_candidates:
379
+ docs = docs[:rerank_candidates]
380
+ return docs
381
+
382
+ # 并发执行所有查询的检索
383
+ results = await asyncio.gather(*[retrieve_single(q) for q in expanded_queries])
384
+
385
+ for docs in results:
386
+ all_candidate_docs.extend(docs)
387
+
388
+ # 去重(基于文档内容)
389
+ unique_docs = []
390
+ seen_content = set()
391
+ for doc in all_candidate_docs:
392
+ content = doc.page_content
393
+ if content not in seen_content:
394
+ seen_content.add(content)
395
+ unique_docs.append(doc)
396
+
397
+ print(f"检索获得 {len(unique_docs)} 个候选文档")
398
+
399
+ # 重排(如果重排器可用)
400
+ # 注意:重排通常是计算密集型,建议放入线程池
401
+ if self.reranker and len(unique_docs) > top_k:
402
+ try:
403
+ loop = asyncio.get_running_loop()
404
+ # rerank 方法内部可能也比较耗时
405
+ reranked_results = await loop.run_in_executor(
406
+ None,
407
+ self.reranker.rerank,
408
+ query, unique_docs, top_k
409
+ )
410
+ final_docs = [doc for doc, score in reranked_results]
411
+ scores = [score for doc, score in reranked_results]
412
+
413
+ print(f"重排后返回 {len(final_docs)} 个文档")
414
+ print(f"重排分数范围: {min(scores):.4f} - {max(scores):.4f}")
415
+
416
+ return final_docs
417
+ except Exception as e:
418
+ print(f"⚠️ 重排失败: {e},使用原始检索结果")
419
+ return unique_docs[:top_k]
420
+ else:
421
+ return unique_docs[:top_k]
422
+
423
  def expand_query(self, query: str) -> List[str]:
424
  """扩展查询,生成相关查询"""
425
  if not self.query_expansion_model:
main.py CHANGED
@@ -159,9 +159,9 @@ class AdaptiveRAGSystem:
159
  debug=False
160
  )
161
 
162
- def query(self, question: str, verbose: bool = True):
163
  """
164
- 处理查询
165
 
166
  Args:
167
  question (str): 用户问题
@@ -170,6 +170,7 @@ class AdaptiveRAGSystem:
170
  Returns:
171
  dict: 包含最终答案和评估指标的字典
172
  """
 
173
  print(f"\n🔍 处理问题: {question}")
174
  print("=" * 50)
175
 
@@ -181,24 +182,19 @@ class AdaptiveRAGSystem:
181
  config = {"recursion_limit": 50} # 增加到 50,默认是 25
182
 
183
  print("\n🤖 思考过程:")
184
- for output in self.app.stream(inputs, config=config):
185
  for key, value in output.items():
186
  if verbose:
187
  # 简单的节点执行提示,模拟流式感
188
  print(f" ↳ 执行节点: {key}...", end="\r")
189
- time.sleep(0.1) # 视觉暂停
 
190
  print(f" ✅ 完成节点: {key} ")
191
 
192
- # pprint(f"节点 '{key}':")
193
- # 可选:在每个节点打印完整状态
194
- # pprint(value, indent=2, width=80, depth=None)
195
  final_generation = value.get("generation", final_generation)
196
  # 保存检索评估指标
197
  if "retrieval_metrics" in value:
198
  retrieval_metrics = value["retrieval_metrics"]
199
- if verbose:
200
- # pprint("\n---\n")
201
- pass
202
 
203
  print("\n" + "=" * 50)
204
  print("🎯 最终答案:")
@@ -207,11 +203,11 @@ class AdaptiveRAGSystem:
207
  # 模拟流式输出效果 (打字机效果)
208
  if final_generation:
209
  import sys
210
- import time
211
  for char in final_generation:
212
  sys.stdout.write(char)
213
  sys.stdout.flush()
214
- time.sleep(0.01) # 控制打字速度
 
215
  print() # 换行
216
  else:
217
  print("未生成答案")
@@ -226,6 +222,7 @@ class AdaptiveRAGSystem:
226
 
227
  def interactive_mode(self):
228
  """交互模式,允许用户持续提问"""
 
229
  print("\n🤖 欢迎使用自适应RAG系统!")
230
  print("💡 输入问题开始对话,输入 'quit' 或 'exit' 退出")
231
  print("-" * 50)
@@ -242,7 +239,8 @@ class AdaptiveRAGSystem:
242
  print("⚠️ 请输入一个有效的问题")
243
  continue
244
 
245
- result = self.query(question)
 
246
 
247
  # 显示检索评估摘要
248
  if result.get("retrieval_metrics"):
@@ -259,11 +257,14 @@ class AdaptiveRAGSystem:
259
  break
260
  except Exception as e:
261
  print(f"❌ 发生错误: {e}")
 
 
262
  print("请重试或输入 'quit' 退出")
263
 
264
 
265
  def main():
266
  """主函数"""
 
267
  try:
268
  # 初始化系统
269
  rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
@@ -272,7 +273,9 @@ def main():
272
  # test_question = "AlphaCodium论文讲的是什么?"
273
  test_question = "LangGraph的作者目前在哪家公司工作?"
274
  # test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
275
- result = rag_system.query(test_question)
 
 
276
 
277
  # 显示测试查询的检索评估摘要
278
  if result.get("retrieval_metrics"):
@@ -289,6 +292,8 @@ def main():
289
 
290
  except Exception as e:
291
  print(f"❌ 系统初始化失败: {e}")
 
 
292
  print("请检查配置和依赖是否正确安装")
293
 
294
 
 
159
  debug=False
160
  )
161
 
162
+ async def query(self, question: str, verbose: bool = True):
163
  """
164
+ 处理查询 (异步版本)
165
 
166
  Args:
167
  question (str): 用户问题
 
170
  Returns:
171
  dict: 包含最终答案和评估指标的字典
172
  """
173
+ import asyncio
174
  print(f"\n🔍 处理问题: {question}")
175
  print("=" * 50)
176
 
 
182
  config = {"recursion_limit": 50} # 增加到 50,默认是 25
183
 
184
  print("\n🤖 思考过程:")
185
+ async for output in self.app.astream(inputs, config=config):
186
  for key, value in output.items():
187
  if verbose:
188
  # 简单的节点执行提示,模拟流式感
189
  print(f" ↳ 执行节点: {key}...", end="\r")
190
+ # 异步暂停
191
+ await asyncio.sleep(0.1)
192
  print(f" ✅ 完成节点: {key} ")
193
 
 
 
 
194
  final_generation = value.get("generation", final_generation)
195
  # 保存检索评估指标
196
  if "retrieval_metrics" in value:
197
  retrieval_metrics = value["retrieval_metrics"]
 
 
 
198
 
199
  print("\n" + "=" * 50)
200
  print("🎯 最终答案:")
 
203
  # 模拟流式输出效果 (打字机效果)
204
  if final_generation:
205
  import sys
 
206
  for char in final_generation:
207
  sys.stdout.write(char)
208
  sys.stdout.flush()
209
+ # 异步暂停
210
+ await asyncio.sleep(0.01) # 控制打字速度
211
  print() # 换行
212
  else:
213
  print("未生成答案")
 
222
 
223
  def interactive_mode(self):
224
  """交互模式,允许用户持续提问"""
225
+ import asyncio
226
  print("\n🤖 欢迎使用自适应RAG系统!")
227
  print("💡 输入问题开始对话,输入 'quit' 或 'exit' 退出")
228
  print("-" * 50)
 
239
  print("⚠️ 请输入一个有效的问题")
240
  continue
241
 
242
+ # 使用 asyncio.run 执行异步查询
243
+ result = asyncio.run(self.query(question))
244
 
245
  # 显示检索评估摘要
246
  if result.get("retrieval_metrics"):
 
257
  break
258
  except Exception as e:
259
  print(f"❌ 发生错误: {e}")
260
+ import traceback
261
+ traceback.print_exc()
262
  print("请重试或输入 'quit' 退出")
263
 
264
 
265
  def main():
266
  """主函数"""
267
+ import asyncio
268
  try:
269
  # 初始化系统
270
  rag_system: AdaptiveRAGSystem = AdaptiveRAGSystem()
 
273
  # test_question = "AlphaCodium论文讲的是什么?"
274
  test_question = "LangGraph的作者目前在哪家公司工作?"
275
  # test_question = "解释embedding嵌入的原理,最好列举实现过程的具体步骤"
276
+
277
+ # 使用 asyncio.run 执行异步查询
278
+ result = asyncio.run(rag_system.query(test_question))
279
 
280
  # 显示测试查询的检索评估摘要
281
  if result.get("retrieval_metrics"):
 
292
 
293
  except Exception as e:
294
  print(f"❌ 系统初始化失败: {e}")
295
+ import traceback
296
+ traceback.print_exc()
297
  print("请检查配置和依赖是否正确安装")
298
 
299
 
workflow_nodes.py CHANGED
@@ -118,9 +118,9 @@ class WorkflowNodes:
118
  "retry_count": 0
119
  }
120
 
121
- def retrieve(self, state):
122
  """
123
- 检索文档
124
 
125
  Args:
126
  state (dict): 当前图状态
@@ -138,8 +138,8 @@ class WorkflowNodes:
138
  # 检查是否有图像路径(多模态检索)
139
  image_paths = state.get("image_paths", None)
140
 
141
- # 使用增强检索
142
- documents = self.doc_processor.enhanced_retrieve(
143
  question,
144
  top_k=5,
145
  rerank_candidates=20,
@@ -157,15 +157,15 @@ class WorkflowNodes:
157
 
158
  except Exception as e:
159
  print(f"⚠️ 增强检索失败: {e},回退到基本检索")
160
- # 回退到基本检索
161
  try:
162
  if self.retriever is not None:
163
- documents = self.retriever.invoke(question)
164
  elif hasattr(self.doc_processor, 'vector_retriever') and self.doc_processor.vector_retriever is not None:
165
- documents = self.doc_processor.vector_retriever.invoke(question)
166
  print(" 使用 vector_retriever 作为备选")
167
  elif hasattr(self.doc_processor, 'retriever') and self.doc_processor.retriever is not None:
168
- documents = self.doc_processor.retriever.invoke(question)
169
  print(" 使用 doc_processor.retriever 作为备选")
170
  else:
171
  print("❌ 检索器未正确初始化,返回空文档列表")
 
118
  "retry_count": 0
119
  }
120
 
121
+ async def retrieve(self, state):
122
  """
123
+ 检索文档 (异步版本)
124
 
125
  Args:
126
  state (dict): 当前图状态
 
138
  # 检查是否有图像路径(多模态检索)
139
  image_paths = state.get("image_paths", None)
140
 
141
+ # 使用异步增强检索
142
+ documents = await self.doc_processor.async_enhanced_retrieve(
143
  question,
144
  top_k=5,
145
  rerank_candidates=20,
 
157
 
158
  except Exception as e:
159
  print(f"⚠️ 增强检索失败: {e},回退到基本检索")
160
+ # 回退到基本检索 (同步回退,如果需要也可以改为异步)
161
  try:
162
  if self.retriever is not None:
163
+ documents = await self.retriever.ainvoke(question)
164
  elif hasattr(self.doc_processor, 'vector_retriever') and self.doc_processor.vector_retriever is not None:
165
+ documents = await self.doc_processor.vector_retriever.ainvoke(question)
166
  print(" 使用 vector_retriever 作为备选")
167
  elif hasattr(self.doc_processor, 'retriever') and self.doc_processor.retriever is not None:
168
+ documents = await self.doc_processor.retriever.ainvoke(question)
169
  print(" 使用 doc_processor.retriever 作为备选")
170
  else:
171
  print("❌ 检索器未正确初始化,返回空文档列表")