lanny xu committed on
Commit a576aa9 · 1 Parent(s): c33bb69

resolve conflict

Files changed (3)
  1. entity_extractor.py +125 -3
  2. graph_indexer.py +81 -22
  3. requirements_graphrag.txt +3 -0
entity_extractor.py CHANGED
@@ -5,6 +5,9 @@
 
 from typing import List, Dict, Tuple
 import time
+import asyncio
+import aiohttp
+import json
 try:
     from langchain_core.prompts import PromptTemplate
 except ImportError:
@@ -16,14 +19,15 @@ from config import LOCAL_LLM
 
 
 class EntityExtractor:
-    """Entity extractor - uses an LLM to extract entities from text"""
+    """Entity extractor - uses an LLM to extract entities from text (supports async batching)"""
 
-    def __init__(self, timeout: int = 60, max_retries: int = 3):
+    def __init__(self, timeout: int = 60, max_retries: int = 3, enable_async: bool = True):
         """Initialize the entity extractor
 
         Args:
             timeout: LLM call timeout in seconds
             max_retries: number of retries on failure
+            enable_async: whether to enable async processing (enabled by default)
         """
         self.llm = ChatOllama(
             model=LOCAL_LLM,
@@ -32,6 +36,8 @@ class EntityExtractor:
             timeout=timeout  # set a timeout
         )
         self.max_retries = max_retries
+        self.enable_async = enable_async
+        self.ollama_url = "http://localhost:11434/api/generate"
 
         # entity-extraction prompt template
         self.entity_prompt = PromptTemplate(
@@ -175,9 +181,124 @@
                 return []
         return []
 
+    async def _async_llm_call(self, prompt: str, session: aiohttp.ClientSession, attempt: int = 0) -> Dict:
+        """Call the Ollama API asynchronously"""
+        try:
+            async with session.post(
+                self.ollama_url,
+                json={
+                    "model": LOCAL_LLM,
+                    "prompt": prompt,
+                    "format": "json",
+                    "stream": False,
+                    "options": {"temperature": 0}
+                },
+                timeout=aiohttp.ClientTimeout(total=self.llm.timeout if hasattr(self.llm, 'timeout') else 60)
+            ) as response:
+                if response.status == 200:
+                    result = await response.json()
+                    return json.loads(result.get('response', '{}'))
+                else:
+                    raise Exception(f"API returned an error: {response.status}")
+        except asyncio.TimeoutError:
+            if attempt < self.max_retries - 1:
+                await asyncio.sleep((attempt + 1) * 2)
+                return await self._async_llm_call(prompt, session, attempt + 1)
+            raise
+        except Exception as e:
+            if attempt < self.max_retries - 1:
+                await asyncio.sleep(1)
+                return await self._async_llm_call(prompt, session, attempt + 1)
+            raise
+
+    async def _extract_entities_async(self, text: str, doc_index: int, session: aiohttp.ClientSession) -> List[Dict]:
+        """Extract entities asynchronously"""
+        prompt = self.entity_prompt.format(text=text[:2000])
+
+        for attempt in range(self.max_retries):
+            try:
+                print(f"  [Document #{doc_index + 1}] 🔄 Extracting entities (attempt {attempt + 1}/{self.max_retries})...", end="")
+                result = await self._async_llm_call(prompt, session, attempt)
+                entities = result.get("entities", [])
+                print(f" ✅ {len(entities)} entities")
+                return entities
+            except Exception as e:
+                print(f" ❌ {str(e)[:50]}")
+                if attempt == self.max_retries - 1:
+                    return []
+        return []
+
+    async def _extract_relations_async(self, text: str, entities: List[Dict], doc_index: int, session: aiohttp.ClientSession) -> List[Dict]:
+        """Extract relations asynchronously"""
+        if not entities:
+            return []
+
+        entity_names = [e["name"] for e in entities]
+        prompt = self.relation_prompt.format(
+            text=text[:2000],
+            entities=", ".join(entity_names)
+        )
+
+        for attempt in range(self.max_retries):
+            try:
+                print(f"  [Document #{doc_index + 1}] 🔄 Extracting relations (attempt {attempt + 1}/{self.max_retries})...", end="")
+                result = await self._async_llm_call(prompt, session, attempt)
+                relations = result.get("relations", [])
+                print(f" ✅ {len(relations)} relations")
+                return relations
+            except Exception as e:
+                print(f" ❌ {str(e)[:50]}")
+                if attempt == self.max_retries - 1:
+                    return []
+        return []
+
+    async def _extract_from_document_async(self, document_text: str, doc_index: int, session: aiohttp.ClientSession) -> Dict:
+        """Process a single document asynchronously"""
+        print(f"\n🔍 [Document #{doc_index + 1}] Starting async extraction...")
+
+        # entities first, then relations (sequential within a single document;
+        # concurrency happens across documents)
+        entities = await self._extract_entities_async(document_text, doc_index, session)
+        relations = await self._extract_relations_async(document_text, entities, doc_index, session)
+
+        print(f"📊 [Document #{doc_index + 1}] Done: {len(entities)} entities, {len(relations)} relations")
+
+        return {
+            "entities": entities,
+            "relations": relations
+        }
+
+    async def extract_batch_async(self, documents: List[Tuple[str, int]]) -> List[Dict]:
+        """Process multiple documents asynchronously in one batch
+
+        Args:
+            documents: list of (document_text, doc_index) tuples
+
+        Returns:
+            list of extraction results
+        """
+        async with aiohttp.ClientSession() as session:
+            tasks = [
+                self._extract_from_document_async(doc_text, doc_idx, session)
+                for doc_text, doc_idx in documents
+            ]
+
+            # run all tasks concurrently
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            # replace exception results with empty ones
+            processed_results = []
+            for i, result in enumerate(results):
+                if isinstance(result, Exception):
+                    print(f"⚠️ Document #{documents[i][1] + 1} failed: {result}")
+                    processed_results.append({"entities": [], "relations": []})
+                else:
+                    processed_results.append(result)
+
+            return processed_results
+
     def extract_from_document(self, document_text: str, doc_index: int = 0) -> Dict:
         """
-        Extract entities and relations from a single document
+        Extract entities and relations from a single document (sync interface, kept for backward compatibility)
 
         Args:
             document_text: document text
@@ -186,6 +307,7 @@
         Returns:
             dict containing the entities and relations
         """
+        # synchronous path (kept for backward compatibility)
        print(f"\n🔍 Document #{doc_index + 1}: starting extraction...")
 
        entities = self.extract_entities(document_text)
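
For reference, the new extract_batch_async entry point can be exercised on its own. A minimal sketch, assuming a running local Ollama server and the EntityExtractor above; the sample texts are illustrative only:

import asyncio

from entity_extractor import EntityExtractor

# Hypothetical driver: feed (document_text, doc_index) tuples to the async
# batch API and read back one {"entities": ..., "relations": ...} dict per
# document, in input order.
extractor = EntityExtractor(timeout=60, max_retries=3, enable_async=True)

docs = [
    ("Microsoft acquired GitHub in 2018.", 0),
    ("GraphRAG builds a knowledge graph from source documents.", 1),
]

results = asyncio.run(extractor.extract_batch_async(docs))
for (_, idx), result in zip(docs, results):
    print(f"doc {idx}: {len(result['entities'])} entities, {len(result['relations'])} relations")

Documents that fail all retries come back as empty entity/relation lists rather than raising, so result indices stay aligned with document indices.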
graph_indexer.py CHANGED
@@ -4,6 +4,7 @@ GraphRAG indexer
 """
 
 from typing import List, Dict, Optional
+import asyncio
 try:
     from langchain_core.documents import Document
 except ImportError:
@@ -16,17 +17,26 @@ from knowledge_graph import KnowledgeGraph, CommunitySummarizer
 class GraphRAGIndexer:
     """GraphRAG indexer - implements the Microsoft GraphRAG indexing pipeline"""
 
-    def __init__(self):
+    def __init__(self, enable_async: bool = True, async_batch_size: int = 5):
+        """Initialize the GraphRAG indexer
+
+        Args:
+            enable_async: whether to enable async processing (enabled by default)
+            async_batch_size: async concurrency (default: 5 documents in flight)
+        """
         print("🚀 Initializing GraphRAG indexer...")
 
-        self.entity_extractor = EntityExtractor()
+        self.entity_extractor = EntityExtractor(enable_async=enable_async)
         self.entity_deduplicator = EntityDeduplicator()
         self.knowledge_graph = KnowledgeGraph()
         self.community_summarizer = CommunitySummarizer()
 
+        self.enable_async = enable_async
+        self.async_batch_size = async_batch_size
         self.indexed = False
 
-        print("✅ GraphRAG indexer initialized")
+        mode = "async mode" if enable_async else "sync mode"
+        print(f"✅ GraphRAG indexer initialized ({mode}, concurrency={async_batch_size})")
 
     def index_documents(self, documents: List[Document],
                         batch_size: int = 10,
@@ -58,27 +68,35 @@
         # Step 1: entity and relation extraction
         print("📍 Step 1/5: entity and relation extraction")
         extraction_results = []
-        total_batches = (len(documents) - 1) // batch_size + 1
 
-        for i in range(0, len(documents), batch_size):
-            batch = documents[i:i+batch_size]
-            batch_num = i // batch_size + 1
-            print(f"\n⚙️ === Batch {batch_num}/{total_batches} (documents {i+1}-{min(i+batch_size, len(documents))}) ===")
-
-            for idx, doc in enumerate(batch):
-                doc_global_index = i + idx
-                try:
-                    result = self.entity_extractor.extract_from_document(
-                        doc.page_content,
-                        doc_index=doc_global_index
-                    )
-                    extraction_results.append(result)
-                except Exception as e:
-                    print(f"   ❌ Document #{doc_global_index + 1} failed: {e}")
-                    # append an empty result to keep indices aligned
-                    extraction_results.append({"entities": [], "relations": []})
-
-            print(f"✅ Batch {batch_num}/{total_batches} done")
+        if self.enable_async:
+            # async batch-processing mode
+            print(f"🚀 Using async mode, concurrency={self.async_batch_size}")
+            extraction_results = self._extract_async(documents)
+        else:
+            # sync mode (original logic)
+            print("🔄 Using sync mode")
+            total_batches = (len(documents) - 1) // batch_size + 1
+
+            for i in range(0, len(documents), batch_size):
+                batch = documents[i:i+batch_size]
+                batch_num = i // batch_size + 1
+                print(f"\n⚙️ === Batch {batch_num}/{total_batches} (documents {i+1}-{min(i+batch_size, len(documents))}) ===")
+
+                for idx, doc in enumerate(batch):
+                    doc_global_index = i + idx
+                    try:
+                        result = self.entity_extractor.extract_from_document(
+                            doc.page_content,
+                            doc_index=doc_global_index
+                        )
+                        extraction_results.append(result)
+                    except Exception as e:
+                        print(f"   ❌ Document #{doc_global_index + 1} failed: {e}")
+                        # append an empty result to keep indices aligned
+                        extraction_results.append({"entities": [], "relations": []})
+
+                print(f"✅ Batch {batch_num}/{total_batches} done")
 
         # Step 2: entity deduplication
         print("\n📍 Step 2/5: entity deduplication and merging")
@@ -142,6 +160,47 @@
 
         return self.knowledge_graph
 
+    def _extract_async(self, documents: List[Document]) -> List[Dict]:
+        """Extract entities and relations in async batches
+
+        Args:
+            documents: list of documents
+
+        Returns:
+            list of extraction results
+        """
+        total_docs = len(documents)
+        extraction_results = []
+
+        # split the documents into async batches
+        for i in range(0, total_docs, self.async_batch_size):
+            batch_end = min(i + self.async_batch_size, total_docs)
+            batch_num = i // self.async_batch_size + 1
+            total_batches = (total_docs - 1) // self.async_batch_size + 1
+
+            print(f"\n⚡ === Async batch {batch_num}/{total_batches} (documents {i+1}-{batch_end}) ===")
+
+            # prepare the batch data
+            async_batch = [
+                (documents[idx].page_content, idx)
+                for idx in range(i, batch_end)
+            ]
+
+            # run the current batch asynchronously
+            try:
+                batch_results = asyncio.run(
+                    self.entity_extractor.extract_batch_async(async_batch)
+                )
+                extraction_results.extend(batch_results)
+                print(f"✅ Async batch {batch_num}/{total_batches} done")
+            except Exception as e:
+                print(f"❌ Async batch {batch_num} failed: {e}")
+                # pad with empty results to keep indices aligned
+                for _ in range(len(async_batch)):
+                    extraction_results.append({"entities": [], "relations": []})
+
+        return extraction_results
+
     def get_graph(self) -> KnowledgeGraph:
         """Get the knowledge graph"""
        if not self.indexed:
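
At the call site the switch is just a constructor flag. A minimal usage sketch, assuming langchain_core's Document and the modules in this repo; the document texts are placeholders:

from langchain_core.documents import Document

from graph_indexer import GraphRAGIndexer

docs = [Document(page_content=text) for text in ["first sample text", "second sample text"]]

# async path: up to 5 documents extracted concurrently per batch
indexer = GraphRAGIndexer(enable_async=True, async_batch_size=5)
graph = indexer.index_documents(docs)

# fallback: the original synchronous loop
sync_indexer = GraphRAGIndexer(enable_async=False)

Note that _extract_async calls asyncio.run once per batch, starting a fresh event loop each time; this keeps index_documents synchronous for existing callers at the cost of some per-batch loop setup.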
requirements_graphrag.txt CHANGED
@@ -35,3 +35,6 @@ plotly>=5.18.0
 # caching and performance optimization
 diskcache>=5.6.0
 joblib>=1.3.0
+
+# async HTTP requests (for concurrent processing)
+aiohttp>=3.9.0
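
aiohttp is the only new runtime dependency. A quick optional sanity check, assuming a default local Ollama install (GET /api/tags is Ollama's model-listing endpoint):

import asyncio

import aiohttp

async def ping_ollama(url: str = "http://localhost:11434/api/tags") -> int:
    # a 200 here means the server the async extractor talks to is reachable
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            return resp.status

if __name__ == "__main__":
    print(asyncio.run(ping_ollama()))  # expect 200 when Ollama is running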