Melika Kheirieh commited on
Commit
552a3c5
·
1 Parent(s): 72c0821

refactor(core): DI-ready Pipeline; add registry + YAML factory + typed trace/result

Browse files
Files changed (1) hide show
  1. nl2sql/pipeline_factory.py +43 -29
nl2sql/pipeline_factory.py CHANGED
@@ -79,7 +79,12 @@ def pipeline_from_config(path: str) -> Pipeline:
79
  if is_pytest:
80
 
81
  class _StubDetector:
82
- def detect(self, *args, **kwargs) -> StageResult:
 
 
 
 
 
83
  return StageResult(
84
  ok=True,
85
  data={"questions": []},
@@ -90,13 +95,15 @@ def pipeline_from_config(path: str) -> Pipeline:
90
  },
91
  )
92
 
93
- def run(self, *args, **kwargs) -> StageResult:
94
- return self.detect(*args, **kwargs)
95
-
96
  class _StubPlanner:
97
  def __init__(self, llm: Any = None) -> None: ...
98
 
99
- def plan(self, *args, **kwargs) -> StageResult:
 
 
 
 
 
100
  return StageResult(
101
  ok=True,
102
  data={"plan": "stub plan"},
@@ -107,69 +114,76 @@ def pipeline_from_config(path: str) -> Pipeline:
107
  },
108
  )
109
 
110
- def run(self, *args, **kwargs) -> StageResult:
111
- return self.plan(*args, **kwargs)
112
-
113
  class _StubGenerator:
114
  def __init__(self, llm: Any = None) -> None: ...
115
 
116
- def generate(self, *args, **kwargs) -> StageResult:
 
 
 
 
 
 
117
  return StageResult(
118
  ok=True,
119
- data={"sql": "SELECT 1;", "rationale": "stub"},
120
  trace={
121
  "stage": "generator",
122
  "duration_ms": 0,
123
- "notes": {"rationale_len": 4},
124
  },
125
  )
126
 
127
- def run(self, *args, **kwargs) -> StageResult:
128
- return self.generate(*args, **kwargs)
129
-
130
  class _StubExecutor:
131
  def __init__(self, db: Any | None = None) -> None: ...
132
 
133
- def execute(self, *args, **kwargs) -> StageResult:
 
134
  rows = [{"x": 1}]
 
 
 
 
 
135
  return StageResult(
136
  ok=True,
137
- data={"rows": rows, "row_count": 1},
138
  trace={
139
  "stage": "executor",
140
  "duration_ms": 0,
141
- "notes": {"row_count": 1},
142
  },
143
  )
144
 
145
- def run(self, *args, **kwargs) -> StageResult:
146
- return self.execute(*args, **kwargs)
147
-
148
  class _StubVerifier:
149
- def verify(self, *args, **kwargs) -> StageResult:
 
 
 
 
 
150
  return StageResult(
151
  ok=True,
152
  data={"verified": True},
153
  trace={"stage": "verifier", "duration_ms": 0, "notes": None},
154
  )
155
 
156
- def run(self, *args, **kwargs) -> StageResult:
157
- return self.verify(*args, **kwargs)
158
-
159
  class _StubRepair:
160
  def __init__(self, llm: Any = None) -> None: ...
161
 
162
- def repair(self, *args, **kwargs) -> StageResult:
163
- sql = kwargs.get("sql") or "SELECT 1;"
 
 
 
 
 
164
  return StageResult(
165
  ok=True,
166
  data={"sql": sql},
167
  trace={"stage": "repair", "duration_ms": 0, "notes": None},
168
  )
169
 
170
- def run(self, *args, **kwargs) -> StageResult:
171
- return self.repair(*args, **kwargs)
172
-
173
  detector = _StubDetector()
174
  planner = _StubPlanner()
175
  generator = _StubGenerator()
 
79
  if is_pytest:
80
 
81
  class _StubDetector:
82
+ # Domain method: return list[str]
83
+ def detect(self, *args, **kwargs) -> list[str]:
84
+ return [] # no ambiguities
85
+
86
+ # Compatibility: return StageResult
87
+ def run(self, *args, **kwargs) -> StageResult:
88
  return StageResult(
89
  ok=True,
90
  data={"questions": []},
 
95
  },
96
  )
97
 
 
 
 
98
  class _StubPlanner:
99
  def __init__(self, llm: Any = None) -> None: ...
100
 
101
+ # Domain: return str (plan text)
102
+ def plan(self, *args, **kwargs) -> str:
103
+ return "stub plan"
104
+
105
+ # Compat: StageResult
106
+ def run(self, *args, **kwargs) -> StageResult:
107
  return StageResult(
108
  ok=True,
109
  data={"plan": "stub plan"},
 
114
  },
115
  )
116
 
 
 
 
117
  class _StubGenerator:
118
  def __init__(self, llm: Any = None) -> None: ...
119
 
120
+ # Domain: return tuple[str, str] → (sql, rationale)
121
+ def generate(self, *args, **kwargs) -> tuple[str, str]:
122
+ return "SELECT 1;", "stub"
123
+
124
+ # Compat: StageResult
125
+ def run(self, *args, **kwargs) -> StageResult:
126
+ sql, rationale = self.generate(*args, **kwargs)
127
  return StageResult(
128
  ok=True,
129
+ data={"sql": sql, "rationale": rationale},
130
  trace={
131
  "stage": "generator",
132
  "duration_ms": 0,
133
+ "notes": {"rationale_len": len(rationale)},
134
  },
135
  )
136
 
 
 
 
137
  class _StubExecutor:
138
  def __init__(self, db: Any | None = None) -> None: ...
139
 
140
+ # Domain: return dict (execution result)
141
+ def execute(self, *args, **kwargs) -> Dict[str, Any]:
142
  rows = [{"x": 1}]
143
+ return {"rows": rows, "row_count": len(rows)}
144
+
145
+ # Compat: StageResult
146
+ def run(self, *args, **kwargs) -> StageResult:
147
+ out = self.execute(*args, **kwargs)
148
  return StageResult(
149
  ok=True,
150
+ data=out,
151
  trace={
152
  "stage": "executor",
153
  "duration_ms": 0,
154
+ "notes": {"row_count": out["row_count"]},
155
  },
156
  )
157
 
 
 
 
158
  class _StubVerifier:
159
+ # Domain: return bool
160
+ def verify(self, *args, **kwargs) -> bool:
161
+ return True
162
+
163
+ # Compat: StageResult
164
+ def run(self, *args, **kwargs) -> StageResult:
165
  return StageResult(
166
  ok=True,
167
  data={"verified": True},
168
  trace={"stage": "verifier", "duration_ms": 0, "notes": None},
169
  )
170
 
 
 
 
171
  class _StubRepair:
172
  def __init__(self, llm: Any = None) -> None: ...
173
 
174
+ # Domain: return str (repaired SQL)
175
+ def repair(self, *args, **kwargs) -> str:
176
+ return kwargs.get("sql") or "SELECT 1;"
177
+
178
+ # Compat: StageResult
179
+ def run(self, *args, **kwargs) -> StageResult:
180
+ sql = self.repair(*args, **kwargs)
181
  return StageResult(
182
  ok=True,
183
  data={"sql": sql},
184
  trace={"stage": "repair", "duration_ms": 0, "notes": None},
185
  )
186
 
 
 
 
187
  detector = _StubDetector()
188
  planner = _StubPlanner()
189
  generator = _StubGenerator()