ianshank Claude committed on
Commit
40ee6b4
·
0 Parent(s):

feat: add personality output and bug fixes


- Added PersonalityResponseGenerator for conversational advisor responses
- Updated app.py with personality-infused output section
- Added sentence-transformers dependency
- Includes all trained model files and bug fixes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +27 -0
  2. DEPLOYMENT_GUIDE.md +306 -0
  3. README.md +225 -0
  4. app.py +553 -0
  5. app_mock.py +590 -0
  6. demo_src/__init__.py +1 -0
  7. demo_src/agents_demo.py +234 -0
  8. demo_src/llm_mock.py +182 -0
  9. demo_src/mcts_demo.py +436 -0
  10. demo_src/wandb_tracker.py +349 -0
  11. models/bert_lora/final_model/README.md +206 -0
  12. models/bert_lora/final_model/adapter_config.json +40 -0
  13. models/bert_lora/final_model/adapter_model.safetensors +0 -0
  14. models/bert_lora/generated_dataset.json +0 -0
  15. models/bert_lora/training_results.json +48 -0
  16. models/rnn_meta_controller.history.json +128 -0
  17. models/rnn_meta_controller.pt +0 -0
  18. requirements.txt +28 -0
  19. src/__init__.py +0 -0
  20. src/adapters/__init__.py +7 -0
  21. src/adapters/llm/__init__.py +257 -0
  22. src/adapters/llm/anthropic_client.py +521 -0
  23. src/adapters/llm/base.py +305 -0
  24. src/adapters/llm/exceptions.py +204 -0
  25. src/adapters/llm/lmstudio_client.py +346 -0
  26. src/adapters/llm/openai_client.py +458 -0
  27. src/agents/__init__.py +0 -0
  28. src/agents/hrm_agent.py +454 -0
  29. src/agents/meta_controller/__init__.py +45 -0
  30. src/agents/meta_controller/base.py +219 -0
  31. src/agents/meta_controller/bert_controller.py +428 -0
  32. src/agents/meta_controller/config_loader.py +304 -0
  33. src/agents/meta_controller/rnn_controller.py +345 -0
  34. src/agents/meta_controller/utils.py +201 -0
  35. src/agents/trm_agent.py +395 -0
  36. src/api/__init__.py +35 -0
  37. src/api/auth.py +439 -0
  38. src/api/exceptions.py +299 -0
  39. src/api/inference_server.py +380 -0
  40. src/api/rest_server.py +441 -0
  41. src/config/__init__.py +0 -0
  42. src/config/meta_controller.yaml +22 -0
  43. src/config/settings.py +431 -0
  44. src/data/__init__.py +29 -0
  45. src/data/dataset_loader.py +551 -0
  46. src/data/preprocessing.py +406 -0
  47. src/data/tactical_augmentation.py +484 -0
  48. src/data/train_test_split.py +505 -0
  49. src/framework/__init__.py +1 -0
  50. src/framework/agents/__init__.py +22 -0
.gitignore ADDED
@@ -0,0 +1,27 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+
7
+ # Virtual environment
8
+ venv/
9
+ env/
10
+ .env
11
+
12
+ # IDE
13
+ .vscode/
14
+ .idea/
15
+ *.swp
16
+ *.swo
17
+
18
+ # OS
19
+ .DS_Store
20
+ Thumbs.db
21
+
22
+ # Gradio
23
+ flagged/
24
+ gradio_cached_examples/
25
+
26
+ # Logs
27
+ *.log
DEPLOYMENT_GUIDE.md ADDED
@@ -0,0 +1,306 @@
1
+ # Hugging Face Spaces Deployment Guide
2
+
3
+ This guide walks you through deploying the LangGraph Multi-Agent MCTS demo to Hugging Face Spaces.
4
+
5
+ ## Prerequisites
6
+
7
+ - [Hugging Face Account](https://huggingface.co/join)
8
+ - Git installed locally
9
+ - Python 3.10+ (for local testing)
10
+
11
+ ## Step 1: Create a New Space
12
+
13
+ 1. Go to [Hugging Face Spaces](https://huggingface.co/spaces)
14
+ 2. Click **"Create new Space"**
15
+ 3. Fill in the form:
16
+ - **Owner**: Your username or organization
17
+ - **Space name**: `langgraph-mcts-demo` (or your choice)
18
+ - **License**: MIT
19
+ - **SDK**: Gradio
20
+ - **Hardware**: CPU Basic (Free tier - sufficient for demo)
21
+ - **Visibility**: Public (or Private)
22
+ 4. Click **"Create Space"**
23
+
24
+ ## Step 2: Clone and Deploy
25
+
26
+ ### Option A: Git-based Deployment (Recommended)
27
+
28
+ ```bash
29
+ # 1. Clone your new empty Space
30
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/langgraph-mcts-demo
31
+ cd langgraph-mcts-demo
32
+
33
+ # 2. Copy demo files from this directory
34
+ cp -r /path/to/huggingface_space/* .
35
+ cp -r /path/to/huggingface_space/.gitignore .
36
+
37
+ # 3. Verify structure
38
+ ls -la
39
+ # Should show:
40
+ # - app.py
41
+ # - requirements.txt
42
+ # - README.md
43
+ # - .gitignore
44
+ # - demo_src/
45
+ # - __init__.py
46
+ # - agents_demo.py
47
+ # - llm_mock.py
48
+ # - mcts_demo.py
49
+
50
+ # 4. Commit and push
51
+ git add -A
52
+ git commit -m "Initial deployment of LangGraph Multi-Agent MCTS demo"
53
+ git push
54
+
55
+ # 5. Space will automatically build and deploy (takes 2-5 minutes)
56
+ ```
57
+
58
+ ### Option B: Direct Upload via Web UI
59
+
60
+ 1. Navigate to your Space on Hugging Face
61
+ 2. Click **"Files"** tab
62
+ 3. Click **"Add file"** → **"Upload files"**
63
+ 4. Upload all files maintaining the directory structure:
64
+ - `app.py`
65
+ - `requirements.txt`
66
+ - `README.md`
67
+ - `.gitignore`
68
+ - `demo_src/__init__.py`
69
+ - `demo_src/agents_demo.py`
70
+ - `demo_src/llm_mock.py`
71
+ - `demo_src/mcts_demo.py`
72
+ 5. Commit changes
73
+
74
+ ## Step 3: Monitor Deployment
75
+
76
+ 1. Go to your Space URL: `https://huggingface.co/spaces/YOUR_USERNAME/langgraph-mcts-demo`
77
+ 2. Click **"Logs"** tab to monitor build progress
78
+ 3. Wait for "Running on" message
79
+ 4. Your demo is now live!
80
+
81
+ ## Step 4: Test the Demo
82
+
83
+ 1. Enter a query or select an example
84
+ 2. Enable/disable different agents
85
+ 3. Adjust MCTS parameters
86
+ 4. Click "Process Query"
87
+ 5. Review results and consensus scores
88
+
89
+ ## Optional: Enable Real LLM Responses
90
+
91
+ To use the Hugging Face Inference API instead of mock responses:
92
+
93
+ ### 1. Update requirements.txt
94
+
95
+ ```txt
96
+ gradio>=4.0.0,<5.0.0
97
+ numpy>=1.24.0,<2.0.0
98
+ huggingface_hub>=0.20.0
99
+ ```
100
+
101
+ ### 2. Add Secret Token
102
+
103
+ 1. Go to Space Settings → **Repository secrets**
104
+ 2. Add new secret:
105
+ - Name: `HF_TOKEN`
106
+ - Value: Your Hugging Face token (from [Settings → Access Tokens](https://huggingface.co/settings/tokens))
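+
+ Space secrets are exposed to the running app as environment variables. A minimal sketch of picking the token up at startup and handing it to a `huggingface_hub` client (the demo's own `HuggingFaceClient` wrapper in `demo_src/llm_mock.py` may read it differently):
+
+ ```python
+ import os
+ from huggingface_hub import InferenceClient
+
+ # None if the secret is not configured; the app can then fall back to mock responses.
+ hf_token = os.environ.get("HF_TOKEN")
+ client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2", token=hf_token)
+ ```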
107
+
108
+ ### 3. Update app.py Initialization
109
+
110
+ Change line ~290 in `app.py`:
111
+
112
+ ```python
113
+ # From:
114
+ framework = MultiAgentFrameworkDemo(use_hf_inference=False)
115
+
116
+ # To:
117
+ import os
118
+ framework = MultiAgentFrameworkDemo(
119
+ use_hf_inference=True,
120
+ hf_model="mistralai/Mistral-7B-Instruct-v0.2"
121
+ )
122
+ ```
123
+
124
+ ### 4. Commit and Push
125
+
126
+ ```bash
127
+ git add -A
128
+ git commit -m "Enable Hugging Face Inference API"
129
+ git push
130
+ ```
131
+
132
+ ## Optional: Enable Weights & Biases Tracking
133
+
134
+ Track experiments and visualize metrics with W&B integration.
135
+
136
+ ### 1. Get W&B API Key
137
+
138
+ 1. Sign up at [wandb.ai](https://wandb.ai)
139
+ 2. Go to Settings → API Keys
140
+ 3. Copy your API key
141
+
142
+ ### 2. Add W&B Secret to Space
143
+
144
+ 1. Go to Space Settings → **Repository secrets**
145
+ 2. Add new secret:
146
+ - Name: `WANDB_API_KEY`
147
+ - Value: Your W&B API key
148
+
149
+ ### 3. Use W&B in the Demo
150
+
151
+ 1. Expand "Weights & Biases Tracking" accordion in the UI
152
+ 2. Check "Enable W&B Tracking"
153
+ 3. Optionally set:
154
+ - **Project Name**: Your W&B project (default: `langgraph-mcts-demo`)
155
+ - **Run Name**: Custom name for this run (auto-generated if empty)
156
+ 4. Process your query
157
+ 5. View the W&B run URL in the results
158
+
159
+ ### 4. What Gets Logged
160
+
161
+ - **Agent Metrics**: Confidence scores, execution times, response lengths
162
+ - **MCTS Metrics**: Best value, visits, tree depth, exploration paths
163
+ - **Consensus Metrics**: Agreement scores, agent combinations
164
+ - **Performance**: Total processing time
165
+ - **Artifacts**: Full JSON results as artifacts
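+
+ Under the hood this is the standard `wandb` API. A minimal sketch of the kind of calls the demo's `WandBTracker` (in `demo_src/wandb_tracker.py`) wraps; the metric names below are illustrative, not the tracker's exact keys:
+
+ ```python
+ import wandb
+
+ # One run per processed query; the config captures the UI settings.
+ run = wandb.init(
+     project="langgraph-mcts-demo",
+     name="example-run",
+     config={"use_hrm": True, "use_trm": True, "use_mcts": False},
+ )
+
+ # Metrics are logged as flat key/value pairs.
+ wandb.log({"hrm/confidence": 0.85, "hrm/execution_time_ms": 120.0})
+ wandb.log({"consensus/score": 0.72, "performance/total_time_ms": 310.5})
+
+ wandb.finish()
+ ```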
166
+
167
+ ### 5. View Your Dashboard
168
+
169
+ After runs, visit your W&B project dashboard to:
170
+ - Compare different agent configurations
171
+ - Visualize consensus patterns
172
+ - Analyze MCTS exploration strategies
173
+ - Track performance over time
174
+
175
+ ## Customization Options
176
+
177
+ ### Change Gradio Theme
178
+
179
+ In `app.py`, modify:
180
+
181
+ ```python
182
+ with gr.Blocks(
183
+ theme=gr.themes.Soft(), # Try: Default(), Monochrome(), Glass()
184
+ ...
185
+ ) as demo:
186
+ ```
187
+
188
+ ### Add Custom Examples
189
+
190
+ Update `EXAMPLE_QUERIES` list in `app.py`:
191
+
192
+ ```python
193
+ EXAMPLE_QUERIES = [
194
+ "Your custom query 1",
195
+ "Your custom query 2",
196
+ ...
197
+ ]
198
+ ```
199
+
200
+ ### Adjust MCTS Parameters
201
+
202
+ Modify sliders in `app.py`:
203
+
204
+ ```python
205
+ mcts_iterations = gr.Slider(
206
+ minimum=10,
207
+ maximum=200, # Increase for more thorough search
208
+ value=50, # Change default
209
+ ...
210
+ )
211
+ ```
212
+
213
+ ### Add More Agent Types
214
+
215
+ 1. Create new agent in `demo_src/agents_demo.py`
216
+ 2. Add to `MultiAgentFrameworkDemo` in `app.py`
217
+ 3. Add UI controls in Gradio interface
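+
+ As a starting point, here is a minimal sketch of a new demo agent, assuming the same interface the existing demo agents expose (an async `process()` returning `response`, `confidence`, and `steps`, which is what `MultiAgentFrameworkDemo` consumes). The class and file names are hypothetical:
+
+ ```python
+ # demo_src/planner_demo.py (hypothetical)
+ class PlannerAgent:
+     """Example agent that sketches a step-by-step plan for the query."""
+
+     def __init__(self, llm_client):
+         self.llm_client = llm_client
+
+     async def process(self, query: str) -> dict:
+         steps = ["Restate the goal", "List constraints", "Draft a plan"]
+         response = f"[Planner] Proposed plan for: {query[:80]}..."
+         return {"response": response, "confidence": 0.75, "steps": steps}
+ ```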
218
+
219
+ ## Troubleshooting
220
+
221
+ ### Build Fails
222
+
223
+ - Check **Logs** tab for error details
224
+ - Verify `requirements.txt` has compatible versions
225
+ - Ensure all imports in `app.py` are satisfied
226
+
227
+ ### Slow Performance
228
+
229
+ - Reduce default MCTS iterations
230
+ - Use mock LLM (no API calls)
231
+ - Simplify tree visualization
232
+
233
+ ### Memory Issues (Free Tier)
234
+
235
+ - Limit max MCTS iterations to 100
236
+ - Reduce tree depth in `demo_src/mcts_demo.py`
237
+ - Simplify response generation
238
+
239
+ ### Missing Files
240
+
241
+ Ensure directory structure:
242
+ ```
243
+ your-space/
244
+ ├── app.py
245
+ ├── requirements.txt
246
+ ├── README.md
247
+ ├── .gitignore
248
+ └── demo_src/
249
+ ├── __init__.py
250
+ ├── agents_demo.py
251
+ ├── llm_mock.py
252
+ ├── mcts_demo.py
253
+ └── wandb_tracker.py
254
+ ```
255
+
256
+ ## Upgrading Hardware
257
+
258
+ For better performance:
259
+
260
+ 1. Go to Space Settings
261
+ 2. Under **Hardware**, select:
262
+ - **CPU Upgrade** ($0.03/hr) - Faster processing
263
+ - **T4 Small** ($0.60/hr) - GPU for neural models
264
+ 3. Save changes
265
+
266
+ ## Sharing Your Space
267
+
268
+ ### Embed in Website
269
+
270
+ ```html
271
+ <iframe
272
+ src="https://YOUR_USERNAME-langgraph-mcts-demo.hf.space"
273
+ frameborder="0"
274
+ width="100%"
275
+ height="600"
276
+ ></iframe>
277
+ ```
278
+
279
+ ### Direct Link
280
+
281
+ Share: `https://huggingface.co/spaces/YOUR_USERNAME/langgraph-mcts-demo`
282
+
283
+ ### API Access
284
+
285
+ Gradio automatically provides an API endpoint:
286
+ ```
287
+ https://YOUR_USERNAME-langgraph-mcts-demo.hf.space/api/predict
288
+ ```
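+
+ A hedged sketch of calling that endpoint from Python with the `gradio_client` package (the exact argument list and `api_name` depend on how the app wires its inputs, so check the Space's "Use via API" panel for the authoritative signature):
+
+ ```python
+ # pip install gradio_client
+ from gradio_client import Client
+
+ # Connect to the deployed Space (replace with your own Space ID).
+ client = Client("YOUR_USERNAME/langgraph-mcts-demo")
+
+ # Illustrative call: query text plus controller choice, mirroring the demo UI.
+ result = client.predict(
+     "How to design a fault-tolerant message queue system?",
+     "RNN",
+     api_name="/predict",
+ )
+ print(result)
+ ```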
289
+
290
+ ## Next Steps
291
+
292
+ 1. **Collect Feedback**: Enable flagging for user feedback
293
+ 2. **Add Analytics**: Track usage patterns
294
+ 3. **Extend Agents**: Add domain-specific reasoning modules
295
+ 4. **Integrate RAG**: Connect to vector databases for real context
296
+ 5. **Add Visualization**: Enhanced tree and consensus displays
297
+
298
+ ## Support
299
+
300
+ - **Hugging Face Docs**: https://huggingface.co/docs/hub/spaces
301
+ - **Gradio Docs**: https://www.gradio.app/docs
302
+ - **Full Framework**: https://github.com/ianshank/langgraph_multi_agent_mcts
303
+
304
+ ---
305
+
306
+ **Happy Deploying!** 🚀
README.md ADDED
@@ -0,0 +1,225 @@
1
+ ---
2
+ title: LangGraph Multi-Agent MCTS Demo
3
+ emoji: 🌳
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ tags:
12
+ - multi-agent
13
+ - mcts
14
+ - reasoning
15
+ - langgraph
16
+ - ai-agents
17
+ - wandb
18
+ - experiment-tracking
19
+ short_description: Multi-agent reasoning framework with Monte Carlo Tree Search
20
+ ---
21
+
22
+ # LangGraph Multi-Agent MCTS Framework
23
+
24
+ **Production Demo with Trained Neural Models** - Experience real trained meta-controllers for intelligent agent routing
25
+
26
+ ## What This Demo Shows
27
+
28
+ This interactive demo showcases trained neural meta-controllers that dynamically route queries to specialized agents:
29
+
30
+ ### 🤖 Trained Meta-Controllers
31
+
32
+ 1. **RNN Meta-Controller**
33
+ - GRU-based recurrent neural network
34
+ - Learns sequential patterns in agent performance
35
+ - Fast inference (~2ms latency)
36
+ - Trained on 1000+ synthetic routing examples
37
+
38
+ 2. **BERT Meta-Controller with LoRA**
39
+ - Transformer-based text understanding
40
+ - Parameter-efficient fine-tuning with LoRA adapters
41
+ - Context-aware routing decisions
42
+ - Better generalization to unseen query patterns
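+
+ For orientation, attaching a LoRA adapter such as the one shipped in `models/bert_lora/final_model` typically looks like the sketch below; the base checkpoint name is illustrative, and the demo's `BERTMetaController.load_model` wraps the actual loading:
+
+ ```python
+ from transformers import AutoModelForSequenceClassification
+ from peft import PeftModel
+
+ # Hypothetical base checkpoint; the demo uses a BERT-mini-sized encoder with 3 routing classes.
+ base = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-mini", num_labels=3)
+
+ # Load the LoRA adapter weights on top of the (frozen) base model.
+ model = PeftModel.from_pretrained(base, "models/bert_lora/final_model")
+ model.eval()
+ ```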
43
+
44
+ ### 🧠 Three Specialized Agents
45
+
46
+ 1. **HRM (Hierarchical Reasoning Module)**
47
+ - Best for: Complex decomposition, multi-level problems
48
+ - Technique: Hierarchical planning with adaptive computation
49
+
50
+ 2. **TRM (Tree Reasoning Module)**
51
+ - Best for: Iterative refinement, comparison tasks
52
+ - Technique: Recursive refinement with convergence detection
53
+
54
+ 3. **MCTS (Monte Carlo Tree Search)**
55
+ - Best for: Optimization, strategic planning
56
+ - Technique: UCB1 exploration with value backpropagation
57
+
58
+ ### 📊 Key Features
59
+
60
+ - **Real Trained Models**: Production-ready neural meta-controllers
61
+ - **Intelligent Routing**: Models learn optimal agent selection patterns
62
+ - **Routing Visualization**: See confidence scores and probability distributions
63
+ - **Feature Engineering**: Demonstrates query → features → routing pipeline
64
+ - **Performance Metrics**: Track execution time and routing accuracy
65
+
66
+ ## How to Use
67
+
68
+ 1. **Enter a Query**: Type your question or select an example
69
+ 2. **Select Controller**: Choose RNN (fast) or BERT (context-aware)
70
+ 3. **Process Query**: Click "🚀 Process Query"
71
+ 4. **Review Results**:
72
+ - See which agent the controller selected
73
+ - View routing confidence and probabilities
74
+ - Examine features used for decision-making
75
+ - Check agent execution details
76
+
77
+ ## Weights & Biases Integration
78
+
79
+ Track your experiments with **Weights & Biases** for:
80
+ - 📈 **Metrics Dashboard**: Visualize consensus scores, execution times, agent performance
81
+ - 🔄 **Run Comparison**: Compare different configurations side-by-side
82
+ - 📊 **Experiment History**: Track all your queries and results
83
+ - 🌳 **MCTS Visualization**: Log tree exploration patterns
84
+
85
+ ### Setting Up W&B
86
+
87
+ 1. **Get API Key**: Sign up at [wandb.ai](https://wandb.ai) and get your API key
88
+ 2. **Configure Space Secret** (if deploying your own):
89
+ - Go to Space Settings → Repository secrets
90
+ - Add: `WANDB_API_KEY` = your API key
91
+ 3. **Enable in UI**:
92
+ - Expand "Weights & Biases Tracking" accordion
93
+ - Check "Enable W&B Tracking"
94
+ - Set project name (optional)
95
+ - Set run name (optional, auto-generated if empty)
96
+ 4. **View Results**: After processing, click the W&B run URL to see your dashboard
97
+
98
+ ### Logged Metrics
99
+
100
+ - **Per Agent**: Confidence, execution time, response length, reasoning steps
101
+ - **MCTS**: Best value, visits, tree depth, top actions with UCB1 scores
102
+ - **Consensus**: Score, level (high/medium/low), number of agents
103
+ - **Performance**: Total processing time
104
+ - **Artifacts**: Full JSON results, tree visualizations
105
+
106
+ ## Example Queries
107
+
108
+ - "What are the key factors to consider when choosing between microservices and monolithic architecture?"
109
+ - "How can we optimize a Python application that processes 10GB of log files daily?"
110
+ - "Should we use SQL or NoSQL database for a social media application with 1M users?"
111
+ - "How to design a fault-tolerant message queue system?"
112
+
113
+ ## Technical Details
114
+
115
+ ### Architecture
116
+
117
+ ```
118
+ Query Input
+   │
+   ├─→ HRM Agent (Hierarchical Decomposition)
+   │    ├─ Component Analysis
+   │    └─ Structured Synthesis
+   │
+   ├─→ TRM Agent (Iterative Refinement)
+   │    ├─ Initial Response
+   │    ├─ Clarity Enhancement
+   │    └─ Validation Check
+   │
+   └─→ MCTS Engine (Strategic Search)
+        ├─ Selection (UCB1)
+        ├─ Expansion
+        ├─ Simulation
+        └─ Backpropagation
+
+   ↓
+ Consensus Scoring
+   ↓
+ Final Synthesized Response
140
+ ```
141
+
142
+ ### MCTS Algorithm
143
+
144
+ The Monte Carlo Tree Search implementation uses:
145
+
146
+ - **UCB1 Selection**: `Q(s,a) + C * sqrt(ln(N(s)) / N(s,a))`
147
+ - **Progressive Widening**: Controls branching factor
148
+ - **Domain-Aware Actions**: Contextual decision options
149
+ - **Value Backpropagation**: Updates entire path statistics
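+
+ For concreteness, the selection score above as a small Python helper (a sketch of the formula, not the framework's exact implementation):
+
+ ```python
+ import math
+
+ def ucb1(q_value: float, parent_visits: int, child_visits: int, c: float = 1.414) -> float:
+     """UCB1: exploit high-value children, explore rarely visited ones."""
+     if child_visits == 0:
+         return float("inf")  # unvisited children are tried first
+     return q_value + c * math.sqrt(math.log(parent_visits) / child_visits)
+ ```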
150
+
151
+ ### Consensus Calculation
152
+
153
+ ```
154
+ consensus = average_confidence * agreement_factor
155
+ agreement_factor = max(0, 1 - std_deviation * 2)
156
+ ```
157
+
158
+ High consensus (>70%) indicates the agents agree on an approach.
159
+ Low consensus (<40%) suggests uncertainty or conflicting strategies.
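+
+ The same rule as a small Python sketch (it mirrors the `_calculate_consensus` helper in `app_mock.py`):
+
+ ```python
+ import numpy as np
+
+ def consensus(confidences: list[float]) -> float:
+     """Combine per-agent confidences into a single agreement score."""
+     if not confidences:
+         return 0.0
+     if len(confidences) == 1:
+         return confidences[0]
+     avg, std = float(np.mean(confidences)), float(np.std(confidences))
+     agreement_factor = max(0.0, 1.0 - std * 2)  # low spread between agents => high agreement
+     return round(min(1.0, avg * agreement_factor), 3)
+
+ # Example: two agents broadly agree, the MCTS value dissents a little.
+ print(consensus([0.85, 0.80, 0.55]))
+ ```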
160
+
161
+ ## Demo Scope
162
+
163
+ This demonstration focuses on **meta-controller training and routing**:
164
+
165
+ - ✅ **Real Trained Models**: Production RNN and BERT controllers
166
+ - ✅ **Actual Model Loading**: PyTorch and HuggingFace Transformers
167
+ - ✅ **Feature Engineering**: Query analysis → feature vectors
168
+ - ✅ **Routing Visualization**: See controller decision-making
169
+ - ⚠️ **Simplified Agents**: Agent responses are mocked for demo purposes
170
+ - ⚠️ **No Live LLM Calls**: Agents don't call actual LLMs (to reduce latency/cost)
171
+
172
+ ## Full Production Framework
173
+
174
+ The complete repository includes all production features:
175
+
176
+ - ✅ **Neural Meta-Controllers**: RNN and BERT with LoRA (deployed here!)
177
+ - ✅ **Agent Implementations**: Full HRM, TRM, and MCTS with PyTorch
178
+ - ✅ **Training Pipeline**: Data generation, training, evaluation
179
+ - ✅ **LLM Integration**: OpenAI, Anthropic, LM Studio support
180
+ - ✅ **RAG Systems**: ChromaDB, FAISS, Pinecone vector stores
181
+ - ✅ **Observability**: OpenTelemetry tracing, Prometheus metrics
182
+ - ✅ **Storage**: S3 artifact storage, experiment tracking
183
+ - ✅ **CI/CD**: Automated testing, security scanning, deployment
184
+
185
+ **GitHub Repository**: [ianshank/langgraph_multi_agent_mcts](https://github.com/ianshank/langgraph_multi_agent_mcts)
186
+
187
+ ## Technical Stack
188
+
189
+ - **Python**: 3.11+
190
+ - **UI**: Gradio 4.x
191
+ - **ML Frameworks**: PyTorch 2.1+, Transformers, PEFT (LoRA)
192
+ - **Models**: GRU-based RNN, BERT-mini with LoRA adapters
193
+ - **Architecture**: Neural meta-controller + multi-agent system
194
+ - **Experiment Tracking**: Weights & Biases (optional)
195
+ - **Numerical**: NumPy
196
+
197
+ ## Research Applications
198
+
199
+ This framework demonstrates concepts applicable to:
200
+
201
+ - Complex decision-making systems
202
+ - AI-assisted software architecture decisions
203
+ - Multi-perspective problem analysis
204
+ - Strategic planning with uncertainty
205
+
206
+ ## Citation
207
+
208
+ If you use this framework in research, please cite:
209
+
210
+ ```bibtex
211
+ @software{langgraph_mcts_2024,
212
+ title={LangGraph Multi-Agent MCTS Framework},
213
+ author={Your Name},
214
+ year={2024},
215
+ url={https://github.com/ianshank/langgraph_multi_agent_mcts}
216
+ }
217
+ ```
218
+
219
+ ## License
220
+
221
+ MIT License - See repository for details.
222
+
223
+ ---
224
+
225
+ **Built with** LangGraph, Gradio, and Python | **Demo Version**: 1.0.0
app.py ADDED
@@ -0,0 +1,553 @@
1
+ """
2
+ LangGraph Multi-Agent MCTS Framework - Integrated Demo with Trained Models
3
+
4
+ Demonstrates the actual trained neural meta-controllers:
5
+ - RNN Meta-Controller for sequential pattern recognition
6
+ - BERT with LoRA adapters for text-based routing
7
+
8
+ This is a production demonstration using real trained models.
9
+ """
10
+
11
+ import asyncio
12
+ import sys
13
+ import time
14
+ from dataclasses import dataclass
15
+ from pathlib import Path
16
+
17
+ # Surface missing or broken critical dependencies early so the failure is visible in the Space logs
18
+ try:
19
+ import peft
20
+
21
+ print(f"[OK] PEFT library imported successfully (version: {peft.__version__})")
22
+ except ImportError as e:
23
+ print(f"CRITICAL ERROR: Could not import peft library: {e}")
24
+ # We don't exit here to allow the app to crash naturally later with full stack trace,
25
+ # but this print ensures it's visible in the logs immediately.
26
+
27
+ import gradio as gr
28
+ import torch
29
+
30
+ # Import the trained controllers
31
+ sys.path.insert(0, str(Path(__file__).parent))
32
+
33
+ from src.agents.meta_controller.base import MetaControllerFeatures
34
+ from src.agents.meta_controller.bert_controller import BERTMetaController
35
+ from src.agents.meta_controller.rnn_controller import RNNMetaController
36
+ from src.agents.meta_controller.feature_extractor import (
37
+ FeatureExtractor,
38
+ FeatureExtractorConfig,
39
+ )
40
+ from src.utils.personality_response import PersonalityResponseGenerator
41
+
42
+
43
+ @dataclass
44
+ class AgentResult:
45
+ """Result from a single agent."""
46
+
47
+ agent_name: str
48
+ response: str
49
+ confidence: float
50
+ reasoning_steps: list[str]
51
+ execution_time_ms: float
52
+
53
+
54
+ @dataclass
55
+ class ControllerDecision:
56
+ """Decision made by the meta-controller."""
57
+
58
+ selected_agent: str
59
+ confidence: float
60
+ routing_probabilities: dict[str, float]
61
+ features_used: dict
62
+
63
+
64
+ def create_features_from_query(
65
+ query: str,
66
+ iteration: int = 0,
67
+ last_agent: str = "none",
68
+ feature_extractor: FeatureExtractor | None = None,
69
+ ) -> MetaControllerFeatures:
70
+ """
71
+ Convert a text query into features for the meta-controller.
72
+
73
+ Uses semantic embeddings for robust feature extraction. Falls back to
74
+ heuristic-based extraction if embeddings are not available.
75
+
76
+ Args:
77
+ query: The input query text
78
+ iteration: Current iteration number
79
+ last_agent: Name of the last agent used
80
+ feature_extractor: Optional FeatureExtractor instance (created if None)
81
+
82
+ Returns:
83
+ MetaControllerFeatures instance
84
+ """
85
+ # Use provided feature extractor or create a new one
86
+ if feature_extractor is None:
87
+ try:
88
+ config = FeatureExtractorConfig.from_env()
89
+ feature_extractor = FeatureExtractor(config)
90
+ except Exception as e:
91
+ print(f"Warning: Failed to initialize FeatureExtractor: {e}")
92
+ print("Falling back to heuristic-based feature extraction")
93
+ # Will use heuristic fallback below
94
+
95
+ # Extract features using the feature extractor
96
+ try:
97
+ if feature_extractor is not None:
98
+ return feature_extractor.extract_features(query, iteration, last_agent)
99
+ except Exception as e:
100
+ print(f"Warning: Feature extraction failed: {e}")
101
+ print("Falling back to heuristic-based feature extraction")
102
+
103
+ # Fallback to original heuristic-based extraction
104
+ # (This code is kept as a safety net but should rarely be used)
105
+ query_length = len(query)
106
+
107
+ # Estimate complexity based on query characteristics
108
+ has_multiple_questions = "?" in query and query.count("?") > 1
109
+ has_comparison = any(word in query.lower() for word in ["vs", "versus", "compare", "difference", "better"])
110
+ has_optimization = any(word in query.lower() for word in ["optimize", "best", "improve", "maximize", "minimize"])
111
+ has_technical = any(word in query.lower() for word in ["algorithm", "code", "implement", "technical", "system"])
112
+
113
+ # Create mock confidence scores based on query characteristics
114
+ hrm_confidence = 0.5 + (0.3 if has_multiple_questions else 0) + (0.1 if has_technical else 0)
115
+ trm_confidence = 0.5 + (0.3 if has_comparison else 0) + (0.1 if query_length > 100 else 0)
116
+ mcts_confidence = 0.5 + (0.3 if has_optimization else 0) + (0.1 if has_technical else 0)
117
+
118
+ # Normalize
119
+ total = hrm_confidence + trm_confidence + mcts_confidence
120
+ if total == 0:
121
+ hrm_confidence = 1.0 / 3.0
122
+ trm_confidence = 1.0 / 3.0
123
+ mcts_confidence = 1.0 / 3.0
124
+ else:
125
+ hrm_confidence /= total
126
+ trm_confidence /= total
127
+ mcts_confidence /= total
128
+
129
+ # Calculate consensus score
130
+ max_confidence = max(hrm_confidence, trm_confidence, mcts_confidence)
131
+ if max_confidence == 0:
132
+ consensus_score = 0.0
133
+ else:
134
+ consensus_score = min(hrm_confidence, trm_confidence, mcts_confidence) / max_confidence
135
+
136
+ features = MetaControllerFeatures(
137
+ hrm_confidence=hrm_confidence,
138
+ trm_confidence=trm_confidence,
139
+ mcts_value=mcts_confidence,
140
+ consensus_score=consensus_score,
141
+ last_agent=last_agent,
142
+ iteration=iteration,
143
+ query_length=query_length,
144
+ has_rag_context=query_length > 50,
145
+ rag_relevance_score=0.7 if query_length > 50 else 0.0,
146
+ is_technical_query=has_technical,
147
+ )
148
+
149
+ return features
150
+
151
+
152
+ class IntegratedFramework:
153
+ """
154
+ Integrated multi-agent framework using trained meta-controllers.
155
+ """
156
+
157
+ def __init__(self):
158
+ """Initialize the framework with trained models."""
159
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
160
+ print(f"Using device: {self.device}")
161
+
162
+ # Initialize feature extractor with semantic embeddings
163
+ print("Initializing Feature Extractor...")
164
+ try:
165
+ config = FeatureExtractorConfig.from_env()
166
+ # Set device to match the framework device
167
+ config.device = self.device
168
+ self.feature_extractor = FeatureExtractor(config)
169
+ print(f"[OK] Feature Extractor initialized: {self.feature_extractor}")
170
+ except Exception as e:
171
+ print(f"[WARN] Failed to initialize Feature Extractor: {e}")
172
+ print("[WARN] Will fall back to heuristic-based feature extraction")
173
+ self.feature_extractor = None
174
+
175
+ # Load trained RNN Meta-Controller
176
+ print("Loading RNN Meta-Controller...")
177
+ self.rnn_controller = RNNMetaController(name="RNNController", seed=42, device=self.device)
178
+
179
+ # Load the trained weights
180
+ rnn_model_path = Path(__file__).parent / "models" / "rnn_meta_controller.pt"
181
+ if rnn_model_path.exists():
182
+ checkpoint = torch.load(rnn_model_path, map_location=self.device, weights_only=True)
183
+ self.rnn_controller.model.load_state_dict(checkpoint)
184
+ self.rnn_controller.model.eval()
185
+ print(f"[OK] Loaded RNN model from {rnn_model_path}")
186
+ else:
187
+ print(f"[WARN] RNN model not found at {rnn_model_path}, using untrained model")
188
+
189
+ # Load trained BERT Meta-Controller with LoRA
190
+ print("Loading BERT Meta-Controller with LoRA...")
191
+ self.bert_controller = BERTMetaController(name="BERTController", seed=42, device=self.device, use_lora=True)
192
+
193
+ bert_model_path = Path(__file__).parent / "models" / "bert_lora" / "final_model"
194
+ if bert_model_path.exists():
195
+ try:
196
+ self.bert_controller.load_model(str(bert_model_path))
197
+ print(f"[OK] Loaded BERT LoRA model from {bert_model_path}")
198
+ except Exception as e:
199
+ print(f"[WARN] Error loading BERT model: {e}")
200
+ print("Using untrained BERT model")
201
+ else:
202
+ print(f"[WARN] BERT model not found at {bert_model_path}, using untrained model")
203
+
204
+ # Agent routing map
205
+ self.agent_handlers = {
206
+ "hrm": self._handle_hrm,
207
+ "trm": self._handle_trm,
208
+ "mcts": self._handle_mcts,
209
+ }
210
+
211
+ print("Framework initialized successfully!")
212
+
213
+ async def process_query(
214
+ self,
215
+ query: str,
216
+ controller_type: str = "rnn",
217
+ ) -> tuple[AgentResult, ControllerDecision]:
218
+ """
219
+ Process a query using the trained meta-controller.
220
+
221
+ Args:
222
+ query: The input query
223
+ controller_type: Which controller to use ("rnn" or "bert")
224
+
225
+ Returns:
226
+ (agent_result, controller_decision) tuple
227
+ """
228
+ start_time = time.perf_counter()
229
+
230
+ # Step 1: Convert query to features using semantic embeddings
231
+ features = create_features_from_query(query, feature_extractor=self.feature_extractor)
232
+
233
+ # Step 2: Get controller decision
234
+ if controller_type == "rnn":
235
+ prediction = self.rnn_controller.predict(features)
236
+ else: # bert
237
+ prediction = self.bert_controller.predict(features)
238
+
239
+ selected_agent = prediction.agent
240
+ confidence = prediction.confidence
241
+
242
+ # Get routing probabilities (prediction.probabilities is already a dict)
243
+ routing_probs = prediction.probabilities
244
+
245
+ # Step 3: Route to selected agent
246
+ handler = self.agent_handlers.get(selected_agent, self._handle_hrm)
247
+ agent_result = await handler(query)
248
+
249
+ # Create controller decision summary
250
+ controller_decision = ControllerDecision(
251
+ selected_agent=selected_agent,
252
+ confidence=confidence,
253
+ routing_probabilities=routing_probs,
254
+ features_used={
255
+ "hrm_confidence": features.hrm_confidence,
256
+ "trm_confidence": features.trm_confidence,
257
+ "mcts_value": features.mcts_value,
258
+ "consensus_score": features.consensus_score,
259
+ "query_length": features.query_length,
260
+ "is_technical": features.is_technical_query,
261
+ },
262
+ )
263
+
264
+ total_time = (time.perf_counter() - start_time) * 1000
265
+ agent_result.execution_time_ms = round(total_time, 2)
266
+
267
+ return agent_result, controller_decision
268
+
269
+ async def _handle_hrm(self, query: str) -> AgentResult:
270
+ """Handle query with Hierarchical Reasoning Module."""
271
+ # Simulate HRM processing
272
+ await asyncio.sleep(0.1)
273
+
274
+ steps = [
275
+ "Decompose query into hierarchical subproblems",
276
+ "Apply high-level reasoning (H-Module)",
277
+ "Execute low-level refinement (L-Module)",
278
+ "Synthesize hierarchical solution",
279
+ ]
280
+
281
+ response = f"[HRM Analysis] Breaking down the problem hierarchically: {query[:100]}..."
282
+
283
+ return AgentResult(
284
+ agent_name="HRM (Hierarchical Reasoning)",
285
+ response=response,
286
+ confidence=0.85,
287
+ reasoning_steps=steps,
288
+ execution_time_ms=0.0,
289
+ )
290
+
291
+ async def _handle_trm(self, query: str) -> AgentResult:
292
+ """Handle query with Tree Reasoning Module."""
293
+ # Simulate TRM processing
294
+ await asyncio.sleep(0.1)
295
+
296
+ steps = [
297
+ "Initialize solution state",
298
+ "Recursive refinement iteration 1",
299
+ "Recursive refinement iteration 2",
300
+ "Convergence achieved - finalize",
301
+ ]
302
+
303
+ response = f"[TRM Analysis] Applying iterative refinement: {query[:100]}..."
304
+
305
+ return AgentResult(
306
+ agent_name="TRM (Iterative Refinement)",
307
+ response=response,
308
+ confidence=0.80,
309
+ reasoning_steps=steps,
310
+ execution_time_ms=0.0,
311
+ )
312
+
313
+ async def _handle_mcts(self, query: str) -> AgentResult:
314
+ """Handle query with MCTS."""
315
+ # Simulate MCTS processing
316
+ await asyncio.sleep(0.15)
317
+
318
+ steps = [
319
+ "Build search tree",
320
+ "Selection: UCB1 exploration",
321
+ "Expansion: Add promising nodes",
322
+ "Simulation: Rollout evaluation",
323
+ "Backpropagation: Update values",
324
+ ]
325
+
326
+ response = f"[MCTS Analysis] Strategic exploration via tree search: {query[:100]}..."
327
+
328
+ return AgentResult(
329
+ agent_name="MCTS (Monte Carlo Tree Search)",
330
+ response=response,
331
+ confidence=0.88,
332
+ reasoning_steps=steps,
333
+ execution_time_ms=0.0,
334
+ )
335
+
336
+
337
+ # Global framework instance
338
+ framework = None
339
+
340
+
341
+ def initialize_framework():
342
+ """Initialize or reinitialize the framework."""
343
+ global framework
344
+ try:
345
+ framework = IntegratedFramework()
346
+ return "[OK] Framework initialized with trained models!"
347
+ except Exception as e:
348
+ return f"[ERROR] Error initializing framework: {str(e)}"
349
+
350
+
351
+ def process_query_sync(
352
+ query: str,
353
+ controller_type: str,
354
+ ):
355
+ """Synchronous wrapper for async processing."""
356
+ global framework
357
+
358
+ if framework is None:
359
+ framework = IntegratedFramework()
360
+
361
+ if not query.strip():
362
+ return ("Please enter a query.", {}, "", {}, "", "")
363
+
364
+ # Run async function
365
+ agent_result, controller_decision = asyncio.run(
366
+ framework.process_query(query=query, controller_type=controller_type.lower())
367
+ )
368
+
369
+ # Format outputs
370
+ final_response = agent_result.response
371
+
372
+ # Generate personality-infused response
373
+ personality_gen = PersonalityResponseGenerator()
374
+ try:
375
+ personality_response = personality_gen.generate_response(
376
+ agent_response=final_response,
377
+ query=query
378
+ )
379
+ except Exception as e:
380
+ # Fallback to a simple wrapper if personality generation fails
381
+ personality_response = f"Here's what I found:\n\n{final_response}"
382
+ print(f"Warning: Personality generation failed: {e}")
383
+
384
+ # Controller decision visualization
385
+ routing_viz = "### 🧠 Meta-Controller Decision\n\n"
386
+ routing_viz += f"**Selected Agent:** `{controller_decision.selected_agent.upper()}`\n\n"
387
+ routing_viz += f"**Confidence:** {controller_decision.confidence:.1%}\n\n"
388
+ routing_viz += "**Routing Probabilities:**\n"
389
+ for agent, prob in controller_decision.routing_probabilities.items():
390
+ bar = "█" * int(prob * 50)
391
+ routing_viz += f"- **{agent.upper()}**: {prob:.1%} {bar}\n"
392
+
393
+ # Agent details
394
+ agent_details = {
395
+ "agent": agent_result.agent_name,
396
+ "confidence": f"{agent_result.confidence:.1%}",
397
+ "reasoning_steps": agent_result.reasoning_steps,
398
+ "execution_time_ms": agent_result.execution_time_ms,
399
+ }
400
+
401
+ # Features used
402
+ features_viz = "### 📊 Features Used for Routing\n\n"
403
+ for feature, value in controller_decision.features_used.items():
404
+ if isinstance(value, float):
405
+ features_viz += f"- **{feature}**: {value:.3f}\n"
406
+ elif isinstance(value, bool):
407
+ features_viz += f"- **{feature}**: {'Yes' if value else 'No'}\n"
408
+ else:
409
+ features_viz += f"- **{feature}**: {value}\n"
410
+
411
+ # Metrics
412
+ metrics = f"""
413
+ **Controller:** {controller_type}
414
+ **Execution Time:** {agent_result.execution_time_ms:.2f} ms
415
+ **Agent Confidence:** {agent_result.confidence:.1%}
416
+ """
417
+
418
+ return final_response, agent_details, routing_viz, features_viz, metrics, personality_response
419
+
420
+
421
+ # Example queries
422
+ EXAMPLE_QUERIES = [
423
+ "What are the key factors to consider when choosing between microservices and monolithic architecture?",
424
+ "How can we optimize a Python application that processes 10GB of log files daily?",
425
+ "Compare the performance characteristics of B-trees vs LSM-trees for write-heavy workloads",
426
+ "Design a distributed rate limiting system that handles 100k requests per second",
427
+ "Explain the difference between supervised and unsupervised learning with examples",
428
+ ]
429
+
430
+
431
+ # Gradio Interface
432
+ with gr.Blocks(
433
+ title="LangGraph Multi-Agent MCTS - Trained Models Demo",
434
+ theme=gr.themes.Soft(),
435
+ css="""
436
+ .agent-box { border: 1px solid #ddd; padding: 10px; border-radius: 5px; margin: 5px 0; }
437
+ .highlight { background-color: #e3f2fd; padding: 10px; border-radius: 5px; margin: 10px 0; }
438
+ """,
439
+ ) as demo:
440
+ gr.Markdown(
441
+ """
442
+ # 🎯 LangGraph Multi-Agent MCTS Framework
443
+ ## Production Demo with Trained Neural Meta-Controllers
444
+
445
+ This demo uses **REAL trained models**:
446
+ - 🧠 **RNN Meta-Controller**: GRU-based sequential pattern recognition
447
+ - 🤖 **BERT with LoRA**: Transformer-based text understanding for routing
448
+
449
+ The meta-controllers learn to route queries to the optimal agent:
450
+ - **HRM**: Hierarchical reasoning for complex decomposition
451
+ - **TRM**: Iterative refinement for progressive improvement
452
+ - **MCTS**: Strategic exploration for optimization problems
453
+
454
+ ---
455
+ """
456
+ )
457
+
458
+ with gr.Row():
459
+ with gr.Column(scale=2):
460
+ query_input = gr.Textbox(
461
+ label="Query", placeholder="Enter your question or reasoning task...", lines=4, max_lines=10
462
+ )
463
+
464
+ gr.Markdown("**Example Queries:**")
465
+ example_dropdown = gr.Dropdown(choices=EXAMPLE_QUERIES, label="Select an example", interactive=True)
466
+
467
+ def load_example(example):
468
+ return example
469
+
470
+ example_dropdown.change(load_example, example_dropdown, query_input)
471
+
472
+ with gr.Column(scale=1):
473
+ gr.Markdown("**Meta-Controller Selection**")
474
+ controller_type = gr.Radio(
475
+ choices=["RNN", "BERT"],
476
+ value="RNN",
477
+ label="Controller Type",
478
+ info="Choose which trained controller to use",
479
+ )
480
+
481
+ gr.Markdown(
482
+ """
483
+ **Controller Comparison:**
484
+ - **RNN**: Fast, captures sequential patterns
485
+ - **BERT**: More context-aware, text understanding
486
+ """
487
+ )
488
+
489
+ process_btn = gr.Button("🚀 Process Query", variant="primary", size="lg")
490
+
491
+ gr.Markdown("---")
492
+
493
+ with gr.Row():
494
+ with gr.Column():
495
+ gr.Markdown("### 🎯 Agent Response")
496
+ final_response_output = gr.Textbox(label="Response", lines=4, interactive=False)
497
+
498
+ gr.Markdown("### 🤝 Personality-Infused Response")
499
+ gr.Markdown("*A conversational, balanced advisor interpretation*")
500
+ personality_output = gr.Textbox(label="Balanced Advisor Response", lines=8, interactive=False)
501
+
502
+ gr.Markdown("### 📈 Performance Metrics")
503
+ metrics_output = gr.Markdown()
504
+
505
+ with gr.Column():
506
+ routing_viz = gr.Markdown(label="Controller Decision")
507
+ features_viz = gr.Markdown(label="Features")
508
+
509
+ with gr.Accordion("🔍 Detailed Agent Information", open=False):
510
+ agent_details_output = gr.JSON(label="Agent Execution Details")
511
+
512
+ # Wire up the processing
513
+ process_btn.click(
514
+ fn=process_query_sync,
515
+ inputs=[
516
+ query_input,
517
+ controller_type,
518
+ ],
519
+ outputs=[final_response_output, agent_details_output, routing_viz, features_viz, metrics_output, personality_output],
520
+ )
521
+
522
+ gr.Markdown(
523
+ """
524
+ ---
525
+
526
+ ### 📚 About This Demo
527
+
528
+ This is a **production demonstration** of trained neural meta-controllers for multi-agent routing.
529
+
530
+ **Models:**
531
+ - RNN Meta-Controller: 10-dimensional feature vector → 3-class routing (HRM/TRM/MCTS)
532
+ - BERT with LoRA: Text features → routing decision with adapters
533
+
534
+ **Training:**
535
+ - Synthetic dataset: 1000+ samples with balanced routing decisions
536
+ - Optimization: Adam optimizer, cross-entropy loss
537
+ - Validation: 80/20 train/val split with early stopping
538
+
539
+ **Repository:** [GitHub - langgraph_multi_agent_mcts](https://github.com/ianshank/langgraph_multi_agent_mcts)
540
+
541
+ ---
542
+ *Built with PyTorch, Transformers, PEFT, and Gradio*
543
+ """
544
+ )
545
+
546
+
547
+ if __name__ == "__main__":
548
+ # Initialize framework
549
+ print("Initializing framework with trained models...")
550
+ framework = IntegratedFramework()
551
+
552
+ # Launch the demo
553
+ demo.launch(server_name="0.0.0.0", share=False, show_error=True)
app_mock.py ADDED
@@ -0,0 +1,590 @@
1
+ """
2
+ LangGraph Multi-Agent MCTS Framework - Hugging Face Spaces Demo
3
+
4
+ A proof-of-concept demonstration of multi-agent reasoning with Monte Carlo Tree Search.
5
+ """
6
+
7
+ import asyncio
8
+ import time
9
+ from dataclasses import dataclass
10
+
11
+ import gradio as gr
12
+ import numpy as np
13
+
14
+ # Demo-specific simplified implementations
15
+ from demo_src.agents_demo import HRMAgent, TRMAgent
16
+ from demo_src.llm_mock import HuggingFaceClient, MockLLMClient
17
+ from demo_src.mcts_demo import MCTSDemo
18
+ from demo_src.wandb_tracker import WandBTracker, is_wandb_available
19
+
20
+
21
+ @dataclass
22
+ class AgentResult:
23
+ """Result from a single agent."""
24
+
25
+ agent_name: str
26
+ response: str
27
+ confidence: float
28
+ reasoning_steps: list[str]
29
+ execution_time_ms: float
30
+
31
+
32
+ @dataclass
33
+ class FrameworkResult:
34
+ """Combined result from all agents."""
35
+
36
+ query: str
37
+ hrm_result: AgentResult | None
38
+ trm_result: AgentResult | None
39
+ mcts_result: dict | None
40
+ consensus_score: float
41
+ final_response: str
42
+ total_time_ms: float
43
+ metadata: dict
44
+
45
+
46
+ class MultiAgentFrameworkDemo:
47
+ """Simplified multi-agent framework for Hugging Face Spaces demo."""
48
+
49
+ def __init__(self, use_hf_inference: bool = False, hf_model: str = ""):
50
+ """Initialize the demo framework.
51
+
52
+ Args:
53
+ use_hf_inference: Use Hugging Face Inference API instead of mock
54
+ hf_model: Hugging Face model ID for inference
55
+ """
56
+ self.use_hf_inference = use_hf_inference
57
+ self.hf_model = hf_model
58
+
59
+ # Initialize components
60
+ if use_hf_inference and hf_model:
61
+ self.llm_client = HuggingFaceClient(model_id=hf_model)
62
+ else:
63
+ self.llm_client = MockLLMClient()
64
+
65
+ self.hrm_agent = HRMAgent(self.llm_client)
66
+ self.trm_agent = TRMAgent(self.llm_client)
67
+ self.mcts = MCTSDemo()
68
+
69
+ async def process_query(
70
+ self,
71
+ query: str,
72
+ use_hrm: bool = True,
73
+ use_trm: bool = True,
74
+ use_mcts: bool = False,
75
+ mcts_iterations: int = 25,
76
+ exploration_weight: float = 1.414,
77
+ seed: int | None = None,
78
+ ) -> FrameworkResult:
79
+ """Process a query through the multi-agent framework.
80
+
81
+ Args:
82
+ query: The input query to process
83
+ use_hrm: Enable Hierarchical Reasoning Module
84
+ use_trm: Enable Tree Reasoning Module
85
+ use_mcts: Enable Monte Carlo Tree Search
86
+ mcts_iterations: Number of MCTS iterations
87
+ exploration_weight: UCB1 exploration parameter
88
+ seed: Random seed for reproducibility
89
+
90
+ Returns:
91
+ FrameworkResult with all agent outputs and consensus
92
+ """
93
+ start_time = time.perf_counter()
94
+
95
+ hrm_result = None
96
+ trm_result = None
97
+ mcts_result = None
98
+
99
+ # Run enabled agents
100
+ tasks = []
101
+ agent_names = []
102
+
103
+ if use_hrm:
104
+ tasks.append(self._run_hrm(query))
105
+ agent_names.append("hrm")
106
+
107
+ if use_trm:
108
+ tasks.append(self._run_trm(query))
109
+ agent_names.append("trm")
110
+
111
+ if use_mcts:
112
+ tasks.append(self._run_mcts(query, mcts_iterations, exploration_weight, seed))
113
+ agent_names.append("mcts")
114
+
115
+ # Execute agents concurrently
116
+ if tasks:
117
+ results = await asyncio.gather(*tasks, return_exceptions=True)
118
+
119
+ for name, result in zip(agent_names, results, strict=False):
120
+ if isinstance(result, Exception):
121
+ continue
122
+ if name == "hrm":
123
+ hrm_result = result
124
+ elif name == "trm":
125
+ trm_result = result
126
+ elif name == "mcts":
127
+ mcts_result = result
128
+
129
+ # Calculate consensus score
130
+ consensus_score = self._calculate_consensus(hrm_result, trm_result, mcts_result)
131
+
132
+ # Generate final synthesized response
133
+ final_response = self._synthesize_response(query, hrm_result, trm_result, mcts_result, consensus_score)
134
+
135
+ total_time = (time.perf_counter() - start_time) * 1000
136
+
137
+ return FrameworkResult(
138
+ query=query,
139
+ hrm_result=hrm_result,
140
+ trm_result=trm_result,
141
+ mcts_result=mcts_result,
142
+ consensus_score=consensus_score,
143
+ final_response=final_response,
144
+ total_time_ms=round(total_time, 2),
145
+ metadata={
146
+ "agents_used": agent_names,
147
+ "mcts_config": (
148
+ {"iterations": mcts_iterations, "exploration_weight": exploration_weight, "seed": seed}
149
+ if use_mcts
150
+ else None
151
+ ),
152
+ },
153
+ )
154
+
155
+ async def _run_hrm(self, query: str) -> AgentResult:
156
+ """Run Hierarchical Reasoning Module."""
157
+ start = time.perf_counter()
158
+ result = await self.hrm_agent.process(query)
159
+ elapsed = (time.perf_counter() - start) * 1000
160
+
161
+ return AgentResult(
162
+ agent_name="HRM (Hierarchical Reasoning)",
163
+ response=result["response"],
164
+ confidence=result["confidence"],
165
+ reasoning_steps=result["steps"],
166
+ execution_time_ms=round(elapsed, 2),
167
+ )
168
+
169
+ async def _run_trm(self, query: str) -> AgentResult:
170
+ """Run Tree Reasoning Module."""
171
+ start = time.perf_counter()
172
+ result = await self.trm_agent.process(query)
173
+ elapsed = (time.perf_counter() - start) * 1000
174
+
175
+ return AgentResult(
176
+ agent_name="TRM (Iterative Refinement)",
177
+ response=result["response"],
178
+ confidence=result["confidence"],
179
+ reasoning_steps=result["steps"],
180
+ execution_time_ms=round(elapsed, 2),
181
+ )
182
+
183
+ async def _run_mcts(self, query: str, iterations: int, exploration_weight: float, seed: int | None) -> dict:
184
+ """Run Monte Carlo Tree Search."""
185
+ start = time.perf_counter()
186
+
187
+ # MCTSDemo.search is now async and uses the production framework
188
+ result = await self.mcts.search(query=query, iterations=iterations, exploration_weight=exploration_weight, seed=seed)
189
+
190
+ elapsed = (time.perf_counter() - start) * 1000
191
+ result["execution_time_ms"] = round(elapsed, 2)
192
+
193
+ return result
194
+
195
+ def _calculate_consensus(
196
+ self, hrm_result: AgentResult | None, trm_result: AgentResult | None, mcts_result: dict | None
197
+ ) -> float:
198
+ """Calculate agreement score between agents."""
199
+ confidences = []
200
+
201
+ if hrm_result:
202
+ confidences.append(hrm_result.confidence)
203
+ if trm_result:
204
+ confidences.append(trm_result.confidence)
205
+ if mcts_result:
206
+ confidences.append(mcts_result.get("best_value", 0.5))
207
+
208
+ if not confidences:
209
+ return 0.0
210
+
211
+ # Consensus is based on confidence alignment and average
212
+ if len(confidences) == 1:
213
+ return confidences[0]
214
+
215
+ avg_confidence = np.mean(confidences)
216
+ std_confidence = np.std(confidences)
217
+
218
+ # Higher consensus when agents agree (low std) and are confident (high avg)
219
+ agreement_factor = max(0, 1 - std_confidence * 2)
220
+ consensus = avg_confidence * agreement_factor
221
+
222
+ return round(min(1.0, consensus), 3)
223
+
224
+ def _synthesize_response(
225
+ self,
226
+ query: str,
227
+ hrm_result: AgentResult | None,
228
+ trm_result: AgentResult | None,
229
+ mcts_result: dict | None,
230
+ consensus_score: float,
231
+ ) -> str:
232
+ """Synthesize final response from all agent outputs."""
233
+ parts = []
234
+
235
+ if hrm_result and hrm_result.confidence > 0.5:
236
+ parts.append(f"[HRM] {hrm_result.response}")
237
+
238
+ if trm_result and trm_result.confidence > 0.5:
239
+ parts.append(f"[TRM] {trm_result.response}")
240
+
241
+ if mcts_result and mcts_result.get("best_value", 0) > 0.5:
242
+ parts.append(f"[MCTS] Best path: {mcts_result.get('best_action', 'N/A')}")
243
+
244
+ if not parts:
245
+ truncated_query = f"{query[:80]}..." if len(query) > 80 else query
246
+ return f"Insufficient confidence to answer query: '{truncated_query}'."
247
+
248
+ synthesis = " | ".join(parts)
249
+
250
+ if consensus_score > 0.7:
251
+ return f"HIGH CONSENSUS ({consensus_score:.1%}): {synthesis}"
252
+ elif consensus_score > 0.4:
253
+ return f"MODERATE CONSENSUS ({consensus_score:.1%}): {synthesis}"
254
+ else:
255
+ return f"LOW CONSENSUS ({consensus_score:.1%}): {synthesis}"
256
+
257
+
258
+ # Global framework instance
259
+ framework = None
260
+ wandb_tracker = None
261
+
262
+
263
+ def initialize_framework(use_hf: bool, model_id: str):
264
+ """Initialize or reinitialize the framework."""
265
+ global framework
266
+ framework = MultiAgentFrameworkDemo(use_hf_inference=use_hf, hf_model=model_id)
267
+ return "Framework initialized successfully!"
268
+
269
+
270
+ def process_query_sync(
271
+ query: str,
272
+ use_hrm: bool,
273
+ use_trm: bool,
274
+ use_mcts: bool,
275
+ mcts_iterations: int,
276
+ exploration_weight: float,
277
+ seed: int,
278
+ enable_wandb: bool = False,
279
+ wandb_project: str = "langgraph-mcts-demo",
280
+ wandb_run_name: str = "",
281
+ ):
282
+ """Synchronous wrapper for async processing."""
283
+ global framework, wandb_tracker
284
+
285
+ if framework is None:
286
+ framework = MultiAgentFrameworkDemo()
287
+
288
+ if not query.strip():
289
+ return "Please enter a query.", {}, "", {}, ""
290
+
291
+ # Handle seed
292
+ seed_value = seed if seed > 0 else None
293
+
294
+ # Initialize W&B tracking if enabled
295
+ wandb_url = ""
296
+ if enable_wandb and is_wandb_available():
297
+ if wandb_tracker is None:
298
+ wandb_tracker = WandBTracker(project_name=wandb_project, enabled=True)
299
+
300
+ # Start a new run
301
+ run_name = wandb_run_name if wandb_run_name.strip() else None
302
+ config = {
303
+ "query": query[:200], # Truncate for config
304
+ "use_hrm": use_hrm,
305
+ "use_trm": use_trm,
306
+ "use_mcts": use_mcts,
307
+ "mcts_iterations": mcts_iterations,
308
+ "exploration_weight": exploration_weight,
309
+ "seed": seed_value,
310
+ }
311
+ wandb_tracker.init_run(run_name=run_name, config=config)
312
+
313
+ # Run async function
314
+ result = asyncio.run(
315
+ framework.process_query(
316
+ query=query,
317
+ use_hrm=use_hrm,
318
+ use_trm=use_trm,
319
+ use_mcts=use_mcts,
320
+ mcts_iterations=int(mcts_iterations),
321
+ exploration_weight=exploration_weight,
322
+ seed=seed_value,
323
+ )
324
+ )
325
+
326
+ # Format outputs
327
+ final_response = result.final_response
328
+
329
+ # Agent details
330
+ agent_details = {}
331
+ if result.hrm_result:
332
+ agent_details["HRM"] = {
333
+ "response": result.hrm_result.response,
334
+ "confidence": f"{result.hrm_result.confidence:.1%}",
335
+ "reasoning_steps": result.hrm_result.reasoning_steps,
336
+ "time_ms": result.hrm_result.execution_time_ms,
337
+ }
338
+
339
+ # Log to W&B
340
+ if enable_wandb and wandb_tracker:
341
+ wandb_tracker.log_agent_result(
342
+ "HRM",
343
+ result.hrm_result.response,
344
+ result.hrm_result.confidence,
345
+ result.hrm_result.execution_time_ms,
346
+ result.hrm_result.reasoning_steps,
347
+ )
348
+
349
+ if result.trm_result:
350
+ agent_details["TRM"] = {
351
+ "response": result.trm_result.response,
352
+ "confidence": f"{result.trm_result.confidence:.1%}",
353
+ "reasoning_steps": result.trm_result.reasoning_steps,
354
+ "time_ms": result.trm_result.execution_time_ms,
355
+ }
356
+
357
+ # Log to W&B
358
+ if enable_wandb and wandb_tracker:
359
+ wandb_tracker.log_agent_result(
360
+ "TRM",
361
+ result.trm_result.response,
362
+ result.trm_result.confidence,
363
+ result.trm_result.execution_time_ms,
364
+ result.trm_result.reasoning_steps,
365
+ )
366
+
367
+ if result.mcts_result:
368
+ agent_details["MCTS"] = result.mcts_result
369
+
370
+ # Log to W&B
371
+ if enable_wandb and wandb_tracker:
372
+ wandb_tracker.log_mcts_result(result.mcts_result)
373
+
374
+ # Log consensus and performance to W&B
375
+ if enable_wandb and wandb_tracker:
376
+ wandb_tracker.log_consensus(result.consensus_score, result.metadata["agents_used"], result.final_response)
377
+ wandb_tracker.log_performance(result.total_time_ms)
378
+ wandb_tracker.log_query_summary(query, use_hrm, use_trm, use_mcts, result.consensus_score, result.total_time_ms)
379
+
380
+ # Get run URL
381
+ wandb_url = wandb_tracker.get_run_url() or ""
382
+
383
+ # Finish the run
384
+ wandb_tracker.finish_run()
385
+
386
+ # Metrics
387
+ metrics = f"""
388
+ **Consensus Score:** {result.consensus_score:.1%}
389
+ **Total Processing Time:** {result.total_time_ms:.2f} ms
390
+ **Agents Used:** {", ".join(result.metadata["agents_used"])}
391
+ """
392
+
393
+ if wandb_url:
394
+ metrics += f"\n**W&B Run:** [{wandb_url}]({wandb_url})"
395
+
396
+ # Full JSON result
397
+ full_result = {
398
+ "query": result.query,
399
+ "final_response": result.final_response,
400
+ "consensus_score": result.consensus_score,
401
+ "total_time_ms": result.total_time_ms,
402
+ "metadata": result.metadata,
403
+ "agent_details": agent_details,
404
+ "wandb_url": wandb_url,
405
+ }
406
+
407
+ return final_response, agent_details, metrics, full_result, wandb_url
408
+
409
+
410
+ def visualize_mcts_tree(mcts_result: dict) -> str:
411
+ """Create ASCII visualization of MCTS tree."""
412
+ if not mcts_result or "tree_visualization" not in mcts_result:
413
+ return "No MCTS tree data available"
414
+
415
+ return mcts_result["tree_visualization"]
416
+
417
+
418
+ # Example queries for demonstration
419
+ EXAMPLE_QUERIES = [
420
+ "What are the key factors to consider when choosing between microservices and monolithic architecture?",
421
+ "How can we optimize a Python application that processes 10GB of log files daily?",
422
+ "What is the best approach to implement rate limiting in a distributed system?",
423
+ "Should we use SQL or NoSQL database for a social media application with 1M users?",
424
+ "How to design a fault-tolerant message queue system?",
425
+ ]
426
+
427
+
428
+ # Gradio Interface
429
+ with gr.Blocks(
430
+ title="LangGraph Multi-Agent MCTS Demo",
431
+ theme=gr.themes.Soft(),
432
+ css="""
433
+ .agent-box { border: 1px solid #ddd; padding: 10px; border-radius: 5px; margin: 5px 0; }
434
+ .consensus-high { color: #28a745; font-weight: bold; }
435
+ .consensus-medium { color: #ffc107; font-weight: bold; }
436
+ .consensus-low { color: #dc3545; font-weight: bold; }
437
+ """,
438
+ ) as demo:
439
+ gr.Markdown(
440
+ """
441
+ # LangGraph Multi-Agent MCTS Framework
442
+
443
+ **Proof-of-Concept Demo** - Multi-agent reasoning with Monte Carlo Tree Search
444
+
445
+ This demo showcases:
446
+ - **HRM**: Hierarchical Reasoning Module - breaks down complex queries
447
+ - **TRM**: Tree Reasoning Module - iterative refinement of responses
448
+ - **MCTS**: Monte Carlo Tree Search - strategic exploration of solution space
449
+ - **Consensus**: Agreement scoring between agents
450
+
451
+ ---
452
+ """
453
+ )
454
+
455
+ with gr.Row():
456
+ with gr.Column(scale=2):
457
+ query_input = gr.Textbox(
458
+ label="Query", placeholder="Enter your reasoning task or question...", lines=3, max_lines=10
459
+ )
460
+
461
+ gr.Markdown("**Example Queries:**")
462
+ example_dropdown = gr.Dropdown(choices=EXAMPLE_QUERIES, label="Select an example", interactive=True)
463
+
464
+ def load_example(example):
465
+ return example
466
+
467
+ example_dropdown.change(load_example, example_dropdown, query_input)
468
+
469
+ with gr.Column(scale=1):
470
+ gr.Markdown("**Agent Configuration**")
471
+ use_hrm = gr.Checkbox(label="Enable HRM (Hierarchical)", value=True)
472
+ use_trm = gr.Checkbox(label="Enable TRM (Iterative)", value=True)
473
+ use_mcts = gr.Checkbox(label="Enable MCTS", value=False)
474
+
475
+ gr.Markdown("**MCTS Parameters**")
476
+ mcts_iterations = gr.Slider(
477
+ minimum=10,
478
+ maximum=100,
479
+ value=25,
480
+ step=5,
481
+ label="Iterations",
482
+ info="More iterations = better search, but slower",
483
+ )
484
+ exploration_weight = gr.Slider(
485
+ minimum=0.1,
486
+ maximum=3.0,
487
+ value=1.414,
488
+ step=0.1,
489
+ label="Exploration Weight (C)",
490
+ info="Higher = more exploration, Lower = more exploitation",
491
+ )
492
+ seed_input = gr.Number(label="Random Seed (0 for random)", value=0, precision=0)
493
+
494
+ with gr.Accordion("Weights & Biases Tracking", open=False):
495
+ gr.Markdown(
496
+ """
497
+ **Experiment Tracking with W&B**
498
+
499
+ Track your experiments, visualize metrics, and compare runs.
500
+ Requires W&B API key set in Space secrets as `WANDB_API_KEY`.
501
+ """
502
+ )
503
+ with gr.Row():
504
+ enable_wandb = gr.Checkbox(
505
+ label="Enable W&B Tracking", value=False, info="Log metrics and results to Weights & Biases"
506
+ )
507
+ wandb_project = gr.Textbox(
508
+ label="Project Name", value="langgraph-mcts-demo", placeholder="Your W&B project name"
509
+ )
510
+ wandb_run_name = gr.Textbox(label="Run Name (optional)", value="", placeholder="Auto-generated if empty")
511
+
512
+ wandb_status = gr.Markdown(f"**W&B Status:** {'Available' if is_wandb_available() else 'Not installed'}")
513
+
514
+ process_btn = gr.Button("Process Query", variant="primary", size="lg")
515
+
516
+ gr.Markdown("---")
517
+
518
+ with gr.Row():
519
+ with gr.Column():
520
+ gr.Markdown("### Final Response")
521
+ final_response_output = gr.Textbox(label="Synthesized Response", lines=4, interactive=False)
522
+
523
+ gr.Markdown("### Performance Metrics")
524
+ metrics_output = gr.Markdown()
525
+
526
+ with gr.Column():
527
+ gr.Markdown("### Agent Details")
528
+ agent_details_output = gr.JSON(label="Individual Agent Results")
529
+
530
+ with gr.Accordion("Full JSON Result", open=False):
531
+ full_result_output = gr.JSON(label="Complete Framework Output")
532
+
533
+ with gr.Accordion("W&B Run Details", open=False, visible=True):
534
+ wandb_url_output = gr.Textbox(
535
+ label="W&B Run URL", interactive=False, placeholder="Enable W&B tracking to see run URL here"
536
+ )
537
+
538
+ # Wire up the processing
539
+ process_btn.click(
540
+ fn=process_query_sync,
541
+ inputs=[
542
+ query_input,
543
+ use_hrm,
544
+ use_trm,
545
+ use_mcts,
546
+ mcts_iterations,
547
+ exploration_weight,
548
+ seed_input,
549
+ enable_wandb,
550
+ wandb_project,
551
+ wandb_run_name,
552
+ ],
553
+ outputs=[final_response_output, agent_details_output, metrics_output, full_result_output, wandb_url_output],
554
+ )
555
+
556
+ gr.Markdown(
557
+ """
558
+ ---
559
+
560
+ ### About This Demo
561
+
562
+ This is a **proof-of-concept** demonstration of the LangGraph Multi-Agent MCTS Framework.
563
+
564
+ **Features:**
565
+ - Multi-agent orchestration with consensus scoring
566
+ - Monte Carlo Tree Search for strategic reasoning
567
+ - Configurable exploration vs exploitation trade-offs
568
+ - Deterministic results with seeded randomness
569
+ - **Weights & Biases integration** for experiment tracking
570
+
571
+ **Limitations (POC):**
572
+ - Uses mock/simplified LLM responses (not production LLM)
573
+ - Limited to demonstration scenarios
574
+ - No persistent storage or RAG
575
+ - Simplified MCTS implementation
576
+
577
+ **Full Framework:** [GitHub Repository](https://github.com/ianshank/langgraph_multi_agent_mcts)
578
+
579
+ ---
580
+ *Built with LangGraph, Gradio, Weights & Biases, and Python*
581
+ """
582
+ )
583
+
584
+
585
+ if __name__ == "__main__":
586
+ # Initialize with mock client for demo
587
+ framework = MultiAgentFrameworkDemo(use_hf_inference=False)
588
+
589
+ # Launch the demo
590
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
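The consensus score surfaced in the metrics above is computed inside the framework (not shown in this part of the diff). A minimal sketch of one way such an agreement score can be derived, assuming sentence-transformers embeddings; the model name and helper below are illustrative choices, not the repository's implementation:

```python
# Illustrative sketch only: average pairwise embedding similarity as a consensus proxy.
# Assumes sentence-transformers (listed in requirements.txt); "all-MiniLM-L6-v2" and
# consensus_score are hypothetical choices, not taken from app.py.
from itertools import combinations

from sentence_transformers import SentenceTransformer, util


def consensus_score(responses: list[str]) -> float:
    """Average pairwise cosine similarity between agent responses, clamped to [0, 1]."""
    if len(responses) < 2:
        return 1.0  # A single agent trivially agrees with itself
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(responses, convert_to_tensor=True)
    sims = [
        float(util.cos_sim(embeddings[i], embeddings[j]))
        for i, j in combinations(range(len(responses)), 2)
    ]
    return max(0.0, min(1.0, sum(sims) / len(sims)))
```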
demo_src/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Demo source modules for Hugging Face Spaces
demo_src/agents_demo.py ADDED
@@ -0,0 +1,234 @@
1
+ """
2
+ Simplified agent implementations for Hugging Face Spaces demo.
3
+ """
4
+
5
+ import asyncio
6
+ from typing import Any
7
+
8
+
9
+ class HRMAgent:
10
+ """Hierarchical Reasoning Module - breaks down complex queries."""
11
+
12
+ def __init__(self, llm_client):
13
+ """Initialize with an LLM client.
14
+
15
+ Args:
16
+ llm_client: LLM client (MockLLMClient or HuggingFaceClient)
17
+ """
18
+ self.llm_client = llm_client
19
+ self.name = "HRM (Hierarchical Reasoning)"
20
+
21
+ async def process(self, query: str) -> dict[str, Any]:
22
+ """Process query using hierarchical decomposition.
23
+
24
+ Args:
25
+ query: Input query to process
26
+
27
+ Returns:
28
+ Dictionary with response, confidence, and reasoning steps
29
+ """
30
+ # Step 1: Decompose the query
31
+ decomposition_steps = await self._decompose_query(query)
32
+
33
+ # Step 2: Analyze each component
34
+ analysis_results = await self._analyze_components(decomposition_steps)
35
+
36
+ # Step 3: Synthesize hierarchical response
37
+ llm_result = await self.llm_client.generate(
38
+ prompt=f"Hierarchical analysis of: {query}", context=f"Components: {', '.join(decomposition_steps)}"
39
+ )
40
+
41
+ # Compile reasoning steps
42
+ reasoning_steps = [
43
+ f"1. Query decomposition: Identified {len(decomposition_steps)} key components",
44
+ f"2. Component analysis: {analysis_results}",
45
+ "3. Hierarchical synthesis: Combined insights from all levels",
46
+ f"4. Confidence assessment: {llm_result['confidence']:.1%} based on component clarity",
47
+ ]
48
+
49
+ return {
50
+ "response": llm_result["response"],
51
+ "confidence": llm_result["confidence"],
52
+ "steps": reasoning_steps,
53
+ "components": decomposition_steps,
54
+ "tokens_used": llm_result.get("tokens_used", 0),
55
+ }
56
+
57
+ async def _decompose_query(self, query: str) -> list[str]:
58
+ """Decompose query into hierarchical components."""
59
+ # Simulate decomposition based on query structure
60
+ await asyncio.sleep(0.05) # Simulate processing
61
+
62
+ # Simple heuristic decomposition
63
+ components = []
64
+
65
+ # Extract key phrases
66
+ query_lower = query.lower()
67
+
68
+ if "?" in query:
69
+ components.append("Question type: Analytical")
70
+ else:
71
+ components.append("Question type: Directive")
72
+
73
+ if "how" in query_lower:
74
+ components.append("Focus: Methodology/Process")
75
+ elif "what" in query_lower:
76
+ components.append("Focus: Definition/Identification")
77
+ elif "why" in query_lower:
78
+ components.append("Focus: Causation/Reasoning")
79
+ elif "should" in query_lower or "best" in query_lower:
80
+ components.append("Focus: Decision/Recommendation")
81
+ else:
82
+ components.append("Focus: General inquiry")
83
+
84
+ # Domain detection
85
+ if any(term in query_lower for term in ["database", "sql", "nosql", "storage"]):
86
+ components.append("Domain: Data Management")
87
+ elif any(term in query_lower for term in ["architecture", "design", "pattern"]):
88
+ components.append("Domain: System Architecture")
89
+ elif any(term in query_lower for term in ["performance", "optimization", "speed"]):
90
+ components.append("Domain: Performance Engineering")
91
+ elif any(term in query_lower for term in ["scale", "distributed", "cluster"]):
92
+ components.append("Domain: Distributed Systems")
93
+ else:
94
+ components.append("Domain: Software Engineering")
95
+
96
+ # Complexity assessment
97
+ word_count = len(query.split())
98
+ if word_count > 20:
99
+ components.append("Complexity: High (detailed query)")
100
+ elif word_count > 10:
101
+ components.append("Complexity: Medium")
102
+ else:
103
+ components.append("Complexity: Low (concise query)")
104
+
105
+ return components
106
+
107
+ async def _analyze_components(self, components: list[str]) -> str:
108
+ """Analyze the decomposed components."""
109
+ await asyncio.sleep(0.03) # Simulate processing
110
+
111
+ # Generate analysis summary
112
+ analysis_parts = []
113
+
114
+ for component in components:
115
+ if "Focus:" in component:
116
+ focus = component.split(":")[1].strip()
117
+ analysis_parts.append(f"requires {focus.lower()} approach")
118
+ elif "Domain:" in component:
119
+ domain = component.split(":")[1].strip()
120
+ analysis_parts.append(f"applies to {domain}")
121
+ elif "Complexity:" in component:
122
+ complexity = component.split(":")[1].strip().split()[0]
123
+ analysis_parts.append(f"{complexity.lower()} complexity level")
124
+
125
+ return "; ".join(analysis_parts) if analysis_parts else "Standard analysis"
126
+
127
+
128
+ class TRMAgent:
129
+ """Tree Reasoning Module - iterative refinement of responses."""
130
+
131
+ def __init__(self, llm_client):
132
+ """Initialize with an LLM client.
133
+
134
+ Args:
135
+ llm_client: LLM client (MockLLMClient or HuggingFaceClient)
136
+ """
137
+ self.llm_client = llm_client
138
+ self.name = "TRM (Iterative Refinement)"
139
+ self.max_iterations = 3
140
+
141
+ async def process(self, query: str) -> dict[str, Any]:
142
+ """Process query using iterative refinement.
143
+
144
+ Args:
145
+ query: Input query to process
146
+
147
+ Returns:
148
+ Dictionary with response, confidence, and reasoning steps
149
+ """
150
+ reasoning_steps = []
151
+ current_response = ""
152
+ current_confidence = 0.0
153
+
154
+ # Iterative refinement loop
155
+ for iteration in range(self.max_iterations):
156
+ step_num = iteration + 1
157
+
158
+ # Generate or refine response
159
+ if iteration == 0:
160
+ # Initial response
161
+ result = await self.llm_client.generate(prompt=query, context="")
162
+ current_response = result["response"]
163
+ current_confidence = result["confidence"]
164
+ reasoning_steps.append(
165
+ f"Iteration {step_num}: Initial response generated (confidence: {current_confidence:.1%})"
166
+ )
167
+ else:
168
+ # Refinement iteration
169
+ refinement_result = await self._refine_response(query, current_response, iteration)
170
+ current_response = refinement_result["response"]
171
+
172
+ # Confidence typically improves with refinement
173
+ confidence_improvement = min(0.1, (1 - current_confidence) * 0.3)
174
+ current_confidence = min(0.95, current_confidence + confidence_improvement)
175
+
176
+ reasoning_steps.append(
177
+ f"Iteration {step_num}: {refinement_result['refinement_type']} "
178
+ f"(confidence: {current_confidence:.1%})"
179
+ )
180
+
181
+ # Check if confidence is high enough to stop
182
+ if current_confidence > 0.85:
183
+ reasoning_steps.append(f"Early termination: High confidence ({current_confidence:.1%}) achieved")
184
+ break
185
+
186
+ # Final reasoning step
187
+ reasoning_steps.append(f"Final: Response refined through {len(reasoning_steps)} iterations")
188
+
189
+ return {
190
+ "response": current_response,
191
+ "confidence": round(current_confidence, 3),
192
+ "steps": reasoning_steps,
193
+ "iterations_used": min(iteration + 1, self.max_iterations),
194
+ "refinement_history": reasoning_steps,
195
+ }
196
+
197
+ async def _refine_response(self, query: str, current_response: str, iteration: int) -> dict[str, Any]:
198
+ """Refine the current response."""
199
+ await asyncio.sleep(0.05) # Simulate refinement processing
200
+
201
+ # Different refinement strategies based on iteration
202
+ refinement_strategies = [
203
+ ("Clarity enhancement", "improve clarity and precision"),
204
+ ("Detail expansion", "add technical depth and specifics"),
205
+ ("Validation check", "verify accuracy and completeness"),
206
+ ]
207
+
208
+ strategy_name, strategy_action = refinement_strategies[iteration % len(refinement_strategies)]
209
+
210
+ # Generate refined response
211
+ refinement_prompt = f"""
212
+ Original query: {query}
213
+ Current response: {current_response}
214
+ Refinement task: {strategy_action}
215
+ """
216
+
217
+ result = await self.llm_client.generate(
218
+ prompt=refinement_prompt, context=f"Refinement iteration {iteration + 1}"
219
+ )
220
+
221
+ # Enhance the response based on strategy
222
+ enhanced_response = current_response
223
+ if strategy_name == "Clarity enhancement":
224
+ enhanced_response = f"{current_response}. {result['response']}"
225
+ elif strategy_name == "Detail expansion":
226
+ enhanced_response = f"{current_response}. Furthermore, {result['response']}"
227
+ else: # Validation
228
+ enhanced_response = f"{current_response}. Validated: {result['response']}"
229
+
230
+ # Truncate if too long
231
+ if len(enhanced_response) > 300:
232
+ enhanced_response = enhanced_response[:297] + "..."
233
+
234
+ return {"response": enhanced_response, "refinement_type": strategy_name, "strategy_action": strategy_action}
demo_src/llm_mock.py ADDED
@@ -0,0 +1,182 @@
1
+ """
2
+ Mock and lightweight LLM clients for demo purposes.
3
+ """
4
+
5
+ import asyncio
6
+ import random
7
+ from typing import Any
8
+
9
+
10
+ class MockLLMClient:
11
+ """Mock LLM client that generates plausible demo responses."""
12
+
13
+ def __init__(self):
14
+ self.response_templates = {
15
+ "architecture": [
16
+ "Consider scalability requirements and team expertise",
17
+ "Evaluate coupling, deployment complexity, and operational overhead",
18
+ "Balance between development speed and long-term maintainability",
19
+ ],
20
+ "optimization": [
21
+ "Profile first to identify actual bottlenecks",
22
+ "Consider memory-mapped files and streaming processing",
23
+ "Implement parallel processing with appropriate chunk sizes",
24
+ ],
25
+ "database": [
26
+ "Consider data consistency requirements and query patterns",
27
+ "Evaluate write-heavy vs read-heavy workload characteristics",
28
+ "Plan for horizontal scaling and data distribution strategies",
29
+ ],
30
+ "distributed": [
31
+ "Implement proper failure detection and recovery mechanisms",
32
+ "Use circuit breakers and bulkhead patterns for resilience",
33
+ "Consider eventual consistency vs strong consistency trade-offs",
34
+ ],
35
+ "default": [
36
+ "Break down the problem into smaller components",
37
+ "Consider trade-offs between different approaches",
38
+ "Evaluate based on specific use case requirements",
39
+ ],
40
+ }
41
+
42
+ async def generate(self, prompt: str, context: str = "") -> dict[str, Any]:
43
+ """Generate a mock response based on the prompt and optional context."""
44
+ # Simulate processing time
45
+ await asyncio.sleep(random.uniform(0.1, 0.3))
46
+
47
+ # Determine response category
48
+ prompt_lower = prompt.lower()
49
+ if "architecture" in prompt_lower or "microservice" in prompt_lower or "monolith" in prompt_lower:
50
+ category = "architecture"
51
+ elif "optim" in prompt_lower or "performance" in prompt_lower or "process" in prompt_lower:
52
+ category = "optimization"
53
+ elif "database" in prompt_lower or "sql" in prompt_lower or "nosql" in prompt_lower:
54
+ category = "database"
55
+ elif "distribut" in prompt_lower or "fault" in prompt_lower or "rate limit" in prompt_lower:
56
+ category = "distributed"
57
+ else:
58
+ category = "default"
59
+
60
+ templates = self.response_templates[category]
61
+
62
+ # Generate response with some randomness
63
+ response = random.choice(templates)
64
+ confidence = random.uniform(0.6, 0.95)
65
+
66
+ # Add more detail based on prompt length (simulating "understanding")
67
+ if len(prompt) > 100:
68
+ confidence = min(0.95, confidence + 0.1)
69
+ response += f". Additionally, {random.choice(self.response_templates['default'])}"
70
+
71
+ # Lightly incorporate context to simulate conditioning
72
+ context_snippet = context.strip()
73
+ if context_snippet:
74
+ confidence = min(0.99, confidence + 0.05)
75
+ response += f" (context signal: {context_snippet[:60]}{'...' if len(context_snippet) > 60 else ''})"
76
+
77
+ return {
78
+ "response": response,
79
+ "confidence": round(confidence, 3),
80
+ "tokens_used": len(prompt.split()) * 2 + len(response.split()),
81
+ }
82
+
83
+ async def generate_reasoning_steps(self, query: str, num_steps: int = 3) -> list[str]:
84
+ """Generate mock reasoning steps."""
85
+ await asyncio.sleep(random.uniform(0.05, 0.15))
86
+
87
+ base_steps = [
88
+ f"Analyzing query: '{query[:50]}...'",
89
+ "Identifying key requirements and constraints",
90
+ "Evaluating potential approaches",
91
+ "Considering trade-offs and implications",
92
+ "Synthesizing recommendations based on analysis",
93
+ "Validating conclusions against requirements",
94
+ ]
95
+
96
+ return random.sample(base_steps, min(num_steps, len(base_steps)))
97
+
98
+
99
+ class HuggingFaceClient:
100
+ """Lightweight Hugging Face Inference API client."""
101
+
102
+ def __init__(self, model_id: str = "mistralai/Mistral-7B-Instruct-v0.2"):
103
+ """Initialize with a Hugging Face model.
104
+
105
+ Args:
106
+ model_id: The model ID on Hugging Face Hub
107
+ """
108
+ self.model_id = model_id
109
+ self._client = None
110
+
111
+ def _get_client(self):
112
+ """Lazy load the HF client."""
113
+ if self._client is None:
114
+ try:
115
+ from huggingface_hub import InferenceClient
116
+
117
+ self._client = InferenceClient(model=self.model_id)
118
+ except ImportError:
119
+ raise ImportError("huggingface_hub not installed. Install with: pip install huggingface_hub")
120
+ return self._client
121
+
122
+ async def generate(self, prompt: str, context: str = "") -> dict[str, Any]:
123
+ """Generate response using Hugging Face Inference API."""
124
+ try:
125
+ client = self._get_client()
126
+
127
+ # Format prompt
128
+ if context:
129
+ full_prompt = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"
130
+ else:
131
+ full_prompt = f"Question: {prompt}\n\nProvide a concise, technical answer:\n\nAnswer:"
132
+
133
+ # Call HF Inference API (sync call wrapped in async)
134
+ response_text = await asyncio.to_thread(
135
+ client.text_generation, full_prompt, max_new_tokens=150, temperature=0.7, do_sample=True
136
+ )
137
+
138
+ # Estimate confidence based on response characteristics
139
+ confidence = min(0.95, 0.6 + len(response_text) / 500)
140
+
141
+ return {
142
+ "response": response_text.strip(),
143
+ "confidence": round(confidence, 3),
144
+ "tokens_used": len(full_prompt.split()) + len(response_text.split()),
145
+ }
146
+
147
+ except Exception as e:
148
+ # Fallback to mock on error
149
+ print(f"HF Inference error: {e}. Falling back to mock.")
150
+ mock = MockLLMClient()
151
+ return await mock.generate(prompt, context)
152
+
153
+ async def generate_reasoning_steps(self, query: str, num_steps: int = 3) -> list[str]:
154
+ """Generate reasoning steps using HF model."""
155
+ try:
156
+ client = self._get_client()
157
+
158
+ prompt = f"""Break down this question into {num_steps} reasoning steps:
159
+ Question: {query}
160
+
161
+ Reasoning steps (one per line):
162
+ 1."""
163
+
164
+ response = await asyncio.to_thread(client.text_generation, prompt, max_new_tokens=200, temperature=0.5)
165
+
166
+ # Parse steps from response
167
+ lines = response.strip().split("\n")
168
+ steps = []
169
+ for line in lines:
170
+ line = line.strip()
171
+ if line and not line.startswith("#"):
172
+ # Remove numbering
173
+ if line[0].isdigit() and "." in line[:3]:
174
+ line = line.split(".", 1)[1].strip()
175
+ steps.append(line)
176
+
177
+ return steps[:num_steps] if steps else ["Analysis in progress"]
178
+
179
+ except Exception as e:
180
+ print(f"HF reasoning error: {e}. Falling back to mock.")
181
+ mock = MockLLMClient()
182
+ return await mock.generate_reasoning_steps(query, num_steps)
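Because `HuggingFaceClient` falls back to `MockLLMClient` on any inference error, callers can use it the same way whether or not the Hugging Face Inference API is reachable. A minimal sketch (hitting the real API would additionally require valid Hugging Face credentials in the environment):

```python
# Minimal sketch: the HF-backed client degrades gracefully to the mock client on errors.
import asyncio

from demo_src.llm_mock import HuggingFaceClient


async def main() -> None:
    client = HuggingFaceClient(model_id="mistralai/Mistral-7B-Instruct-v0.2")
    result = await client.generate(
        "What is the best approach to implement rate limiting in a distributed system?"
    )
    # On API or auth errors this still returns a mock response with the same schema.
    print(result["response"])
    print("confidence:", result["confidence"], "| tokens:", result["tokens_used"])


if __name__ == "__main__":
    asyncio.run(main())
```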
demo_src/mcts_demo.py ADDED
@@ -0,0 +1,436 @@
1
+ """
2
+ Educational MCTS demonstration using the production framework.
3
+
4
+ This demo uses the real MCTSEngine from src.framework.mcts.core to provide
5
+ an authentic learning experience while remaining accessible for demonstrations.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import math
11
+ from typing import Any
12
+
13
+ from src.framework.mcts.core import MCTSEngine, MCTSNode, MCTSState
14
+ from src.framework.mcts.policies import RolloutPolicy, SelectionPolicy
15
+
16
+
17
+ class DemoRolloutPolicy(RolloutPolicy):
18
+ """
19
+ Educational rollout policy for demo purposes.
20
+
21
+ Evaluates states based on:
22
+ - Depth of exploration (deeper = more thorough)
23
+ - Action quality (domain-specific heuristics)
24
+ - Exploration randomness
25
+ """
26
+
27
+ def __init__(self, category: str, action_templates: dict[str, list[str]]):
28
+ """
29
+ Initialize demo rollout policy.
30
+
31
+ Args:
32
+ category: Query category for heuristic evaluation
33
+ action_templates: Available action templates for scoring
34
+ """
35
+ self.category = category
36
+ self.action_templates = action_templates
37
+
38
+ # Define key terms that indicate quality actions per category
39
+ self.quality_indicators = {
40
+ "architecture": ["scalability", "consistency", "requirements"],
41
+ "optimization": ["profile", "caching", "parallel"],
42
+ "database": ["patterns", "relationships", "scaling"],
43
+ "distributed": ["circuit", "retry", "bulkhead"],
44
+ "default": ["decompose", "constraints", "trade-offs"],
45
+ }
46
+
47
+ async def evaluate(
48
+ self,
49
+ state: MCTSState,
50
+ rng,
51
+ max_depth: int = 10,
52
+ ) -> float:
53
+ """
54
+ Evaluate a state through heuristic analysis.
55
+
56
+ This combines:
57
+ - Depth bonus: rewards thorough exploration
58
+ - Action quality: rewards domain-appropriate actions
59
+ - Noise: adds exploration randomness
60
+
61
+ Args:
62
+ state: State to evaluate
63
+ rng: Random number generator
64
+ max_depth: Maximum depth (unused in heuristic)
65
+
66
+ Returns:
67
+ Estimated value in [0, 1] range
68
+ """
69
+ # Base value
70
+ base_value = 0.5
71
+
72
+ # Depth bonus: deeper exploration = more value (up to 0.3)
73
+ depth = state.features.get("depth", 0)
74
+ depth_bonus = min(depth * 0.1, 0.3)
75
+
76
+ # Action quality bonus
77
+ action_bonus = 0.0
78
+ last_action = state.features.get("last_action", "")
79
+
80
+ if last_action:
81
+ # Check if action contains quality indicators for this category
82
+ indicators = self.quality_indicators.get(self.category, self.quality_indicators["default"])
83
+ for term in indicators:
84
+ if term in last_action.lower():
85
+ action_bonus = 0.15
86
+ break
87
+
88
+ # Add exploration noise
89
+ noise = rng.uniform(-0.1, 0.1)
90
+
91
+ # Combine components
92
+ value = base_value + depth_bonus + action_bonus + noise
93
+
94
+ # Clamp to [0, 1]
95
+ return max(0.0, min(1.0, value))
96
+
97
+
98
+ class MCTSDemo:
99
+ """
100
+ Educational MCTS demonstration using the production framework.
101
+
102
+ This class wraps the production MCTSEngine to provide:
103
+ - Simple, educational interface for demos
104
+ - Category-based action selection
105
+ - Tree visualization for learning
106
+ - Deterministic behavior with seeds
107
+
108
+ Unlike the old mock implementation, this uses the real MCTS algorithm
109
+ with all its features: UCB1 selection, progressive widening, caching, etc.
110
+ """
111
+
112
+ def __init__(self, max_depth: int = 5):
113
+ """
114
+ Initialize MCTS demo.
115
+
116
+ Args:
117
+ max_depth: Maximum tree depth for exploration
118
+ """
119
+ self.max_depth = max_depth
120
+
121
+ # Action templates for different query types
122
+ # These provide domain-specific reasoning paths
123
+ self.action_templates = {
124
+ "architecture": [
125
+ "Consider microservices for scalability",
126
+ "Evaluate monolith for simplicity",
127
+ "Analyze team capabilities",
128
+ "Assess deployment requirements",
129
+ "Review data consistency needs",
130
+ ],
131
+ "optimization": [
132
+ "Profile application hotspots",
133
+ "Implement caching layer",
134
+ "Use parallel processing",
135
+ "Optimize database queries",
136
+ "Reduce memory allocations",
137
+ ],
138
+ "database": [
139
+ "Analyze query patterns",
140
+ "Consider data relationships",
141
+ "Evaluate consistency requirements",
142
+ "Plan for horizontal scaling",
143
+ "Assess read/write ratios",
144
+ ],
145
+ "distributed": [
146
+ "Implement circuit breakers",
147
+ "Add retry mechanisms",
148
+ "Use message queues",
149
+ "Apply bulkhead pattern",
150
+ "Design for eventual consistency",
151
+ ],
152
+ "default": [
153
+ "Decompose the problem",
154
+ "Identify constraints",
155
+ "Evaluate trade-offs",
156
+ "Consider alternatives",
157
+ "Validate assumptions",
158
+ ],
159
+ }
160
+
161
+ def _categorize_query(self, query: str) -> str:
162
+ """
163
+ Categorize query to select appropriate action templates.
164
+
165
+ Args:
166
+ query: User's input query
167
+
168
+ Returns:
169
+ Category name for action selection
170
+ """
171
+ query_lower = query.lower()
172
+ if "architecture" in query_lower or "microservice" in query_lower:
173
+ return "architecture"
174
+ elif "optim" in query_lower or "performance" in query_lower:
175
+ return "optimization"
176
+ elif "database" in query_lower or "sql" in query_lower:
177
+ return "database"
178
+ elif "distribut" in query_lower or "fault" in query_lower:
179
+ return "distributed"
180
+ return "default"
181
+
182
+ def _create_action_generator(self, category: str):
183
+ """
184
+ Create action generator function for this query category.
185
+
186
+ Args:
187
+ category: Query category
188
+
189
+ Returns:
190
+ Function that generates actions for a given state
191
+ """
192
+ def action_generator(state: MCTSState) -> list[str]:
193
+ """Generate available actions from current state."""
194
+ # Get category-specific actions
195
+ actions = self.action_templates.get(category, self.action_templates["default"])
196
+
197
+ # Filter out already-used actions (track via state features)
198
+ used_actions = state.features.get("used_actions", set())
199
+ available = [a for a in actions if a not in used_actions]
200
+
201
+ # If all actions used, allow re-exploring top 2
202
+ if not available:
203
+ return actions[:2]
204
+
205
+ return available
206
+
207
+ return action_generator
208
+
209
+ def _create_state_transition(self, category: str):
210
+ """
211
+ Create state transition function for this query category.
212
+
213
+ Args:
214
+ category: Query category
215
+
216
+ Returns:
217
+ Function that computes next state from current state + action
218
+ """
219
+ def state_transition(state: MCTSState, action: str) -> MCTSState:
220
+ """Compute next state by applying action."""
221
+ # Track action history
222
+ action_history = list(state.features.get("action_history", []))
223
+ action_history.append(action)
224
+
225
+ # Track used actions
226
+ used_actions = set(state.features.get("used_actions", set()))
227
+ used_actions.add(action)
228
+
229
+ # Increment depth
230
+ depth = state.features.get("depth", 0) + 1
231
+
232
+ # Create new state ID from action history
233
+ state_id = " -> ".join(action_history)
234
+
235
+ # Build new state
236
+ new_state = MCTSState(
237
+ state_id=state_id,
238
+ features={
239
+ "action_history": action_history,
240
+ "used_actions": used_actions,
241
+ "depth": depth,
242
+ "last_action": action,
243
+ "category": category,
244
+ },
245
+ )
246
+
247
+ return new_state
248
+
249
+ return state_transition
250
+
251
+ def _generate_tree_visualization(self, root: MCTSNode, max_nodes: int = 20) -> str:
252
+ """
253
+ Generate ASCII visualization of the MCTS tree.
254
+
255
+ This provides educational insight into the search process.
256
+
257
+ Args:
258
+ root: Root node of the tree
259
+ max_nodes: Maximum nodes to display
260
+
261
+ Returns:
262
+ ASCII art representation of the tree
263
+ """
264
+ max_nodes = max(1, max_nodes)
265
+ lines = []
266
+ lines.append("MCTS Tree Visualization")
267
+ lines.append("=" * 50)
268
+
269
+ nodes_rendered = 0
270
+
271
+ def format_node(node: MCTSNode, prefix: str = "", is_last: bool = True) -> list[str]:
272
+ nonlocal nodes_rendered
273
+ result = []
274
+
275
+ # Node representation
276
+ connector = "└── " if is_last else "├── "
277
+
278
+ if nodes_rendered >= max_nodes:
279
+ result.append(f"{prefix}{connector}... (truncated)")
280
+ return result
281
+
282
+ nodes_rendered += 1
283
+
284
+ # Display action or state
285
+ node_str = f"{node.state.state_id[:30]}..."
286
+ if node.action:
287
+ node_str = f"{node.action[:25]}..."
288
+
289
+ stats = f"[V:{node.visits}, Q:{node.value:.3f}]"
290
+
291
+ result.append(f"{prefix}{connector}{node_str} {stats}")
292
+
293
+ # Recursively add children
294
+ new_prefix = prefix + (" " if is_last else "│ ")
295
+
296
+ # Limit children shown
297
+ children_to_show = node.children[:3]
298
+ for i, child in enumerate(children_to_show):
299
+ is_child_last = i == len(children_to_show) - 1
300
+ result.extend(format_node(child, new_prefix, is_child_last))
301
+
302
+ if len(node.children) > 3:
303
+ result.append(f"{new_prefix} ... and {len(node.children) - 3} more")
304
+
305
+ return result
306
+
307
+ # Start with root
308
+ lines.append(f"Root: {root.state.state_id[:40]}... [V:{root.visits}, Q:{root.value:.3f}]")
309
+ nodes_rendered += 1
310
+
311
+ for i, child in enumerate(root.children[:5]):
312
+ is_last = i == len(root.children[:5]) - 1
313
+ lines.extend(format_node(child, "", is_last))
314
+
315
+ if len(root.children) > 5:
316
+ lines.append(f"... and {len(root.children) - 5} more branches")
317
+
318
+ return "\n".join(lines)
319
+
320
+ async def search(
321
+ self,
322
+ query: str,
323
+ iterations: int = 25,
324
+ exploration_weight: float = 1.414,
325
+ seed: int | None = None,
326
+ ) -> dict[str, Any]:
327
+ """
328
+ Run MCTS search on the query using the production framework.
329
+
330
+ This method demonstrates the full MCTS algorithm:
331
+ 1. Selection: UCB1-based tree traversal
332
+ 2. Expansion: Progressive widening of nodes
333
+ 3. Simulation: Heuristic evaluation (rollout)
334
+ 4. Backpropagation: Value updates up the tree
335
+
336
+ Args:
337
+ query: The input query to analyze
338
+ iterations: Number of MCTS iterations (more = better but slower)
339
+ exploration_weight: UCB1 exploration constant (higher = more exploration)
340
+ seed: Random seed for deterministic results
341
+
342
+ Returns:
343
+ Dictionary with:
344
+ - best_action: Recommended next step
345
+ - best_value: Confidence in recommendation
346
+ - statistics: Search metrics and performance data
347
+ - tree_visualization: ASCII art of search tree
348
+ """
349
+ # Determine query category
350
+ category = self._categorize_query(query)
351
+
352
+ # Initialize MCTS engine with production features
353
+ engine = MCTSEngine(
354
+ seed=seed if seed is not None else 42,
355
+ exploration_weight=exploration_weight,
356
+ progressive_widening_k=1.0, # Moderate expansion
357
+ progressive_widening_alpha=0.5,
358
+ max_parallel_rollouts=4,
359
+ cache_size_limit=10000,
360
+ )
361
+
362
+ # Create root state
363
+ root_state = MCTSState(
364
+ state_id=f"Query: {query[:50]}",
365
+ features={
366
+ "query": query,
367
+ "category": category,
368
+ "action_history": [],
369
+ "used_actions": set(),
370
+ "depth": 0,
371
+ "last_action": "",
372
+ },
373
+ )
374
+
375
+ # Create root node
376
+ root = MCTSNode(state=root_state, rng=engine.rng)
377
+
378
+ # Create domain-specific functions
379
+ action_generator = self._create_action_generator(category)
380
+ state_transition = self._create_state_transition(category)
381
+ rollout_policy = DemoRolloutPolicy(category, self.action_templates)
382
+
383
+ # Run MCTS search with production engine
384
+ best_action, stats = await engine.search(
385
+ root=root,
386
+ num_iterations=iterations,
387
+ action_generator=action_generator,
388
+ state_transition=state_transition,
389
+ rollout_policy=rollout_policy,
390
+ max_rollout_depth=self.max_depth,
391
+ selection_policy=SelectionPolicy.MAX_VISITS, # Most robust
392
+ )
393
+
394
+ # Extract best child info
395
+ best_child = None
396
+ if root.children:
397
+ best_child = max(root.children, key=lambda c: c.visits)
398
+
399
+ # Compile results for demo interface
400
+ result = {
401
+ "best_action": best_action or "No action found",
402
+ "best_value": round(best_child.value, 4) if best_child else 0.0,
403
+ "root_visits": root.visits,
404
+ "total_nodes": engine.get_cached_node_count(),
405
+ "max_depth_reached": engine.get_cached_tree_depth(),
406
+ "iterations_completed": iterations,
407
+ "exploration_weight": exploration_weight,
408
+ "seed": seed,
409
+ "category": category,
410
+
411
+ # Top actions sorted by visits
412
+ "top_actions": [
413
+ {
414
+ "action": child.action,
415
+ "visits": child.visits,
416
+ "value": round(child.value, 4),
417
+ "ucb1": round(
418
+ child.visits / root.visits if root.visits > 0 else 0.0, 4
419
+ ), # Simplified UCB display
420
+ }
421
+ for child in sorted(root.children, key=lambda c: -c.visits)[:5]
422
+ ],
423
+
424
+ # Framework statistics
425
+ "framework_stats": {
426
+ "cache_hits": stats.get("cache_hits", 0),
427
+ "cache_misses": stats.get("cache_misses", 0),
428
+ "cache_hit_rate": round(stats.get("cache_hit_rate", 0.0), 4),
429
+ "total_simulations": stats.get("total_simulations", 0),
430
+ },
431
+
432
+ # Educational visualization
433
+ "tree_visualization": self._generate_tree_visualization(root),
434
+ }
435
+
436
+ return result
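The UCB1 selection mentioned in the docstrings balances a node's observed value against how rarely it has been visited, scaled by the exploration weight exposed in the UI (default 1.414). A sketch of the textbook formula for reference; this is not the `MCTSEngine` internals, just the standard rule it is described as using:

```python
# Textbook UCB1 rule for tree selection (illustrative, not the MCTSEngine implementation).
import math


def ucb1(total_value: float, visits: int, parent_visits: int, c: float = 1.414) -> float:
    """Exploitation term (mean value) plus exploration bonus for rarely visited nodes."""
    if visits == 0:
        return float("inf")  # Unvisited children are tried first
    exploitation = total_value / visits
    exploration = c * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration
```

A higher `c` inflates the exploration bonus (the slider's "more exploration"), while a lower `c` makes the search commit to the best-looking branch sooner.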
demo_src/wandb_tracker.py ADDED
@@ -0,0 +1,349 @@
1
+ """
2
+ Weights & Biases integration for experiment tracking.
3
+ """
4
+
5
+ import os
6
+ from datetime import datetime
7
+ from typing import Any
8
+
9
+ try:
10
+ import wandb
11
+
12
+ WANDB_AVAILABLE = True
13
+ except ImportError:
14
+ WANDB_AVAILABLE = False
15
+ wandb = None
16
+
17
+
18
+ class WandBTracker:
19
+ """Weights & Biases experiment tracker for multi-agent MCTS demo."""
20
+
21
+ def __init__(self, project_name: str = "langgraph-mcts-demo", entity: str | None = None, enabled: bool = True):
22
+ """Initialize W&B tracker.
23
+
24
+ Args:
25
+ project_name: W&B project name
26
+ entity: W&B entity (username or team)
27
+ enabled: Whether tracking is enabled
28
+ """
29
+ self.project_name = project_name
30
+ self.entity = entity
31
+ self.enabled = enabled and WANDB_AVAILABLE
32
+ self.run = None
33
+ self.run_id = None
34
+
35
+ def is_available(self) -> bool:
36
+ """Check if W&B is available."""
37
+ return WANDB_AVAILABLE
38
+
39
+ def init_run(
40
+ self, run_name: str | None = None, config: dict[str, Any] | None = None, tags: list[str] | None = None
41
+ ) -> bool:
42
+ """Initialize a new W&B run.
43
+
44
+ Args:
45
+ run_name: Optional name for the run
46
+ config: Configuration dictionary to log
47
+ tags: Tags for the run
48
+
49
+ Returns:
50
+ True if run initialized successfully, False otherwise
51
+ """
52
+ if not self.enabled:
53
+ return False
54
+
55
+ try:
56
+ # Generate run name if not provided
57
+ if run_name is None:
58
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
59
+ run_name = f"mcts_query_{timestamp}"
60
+
61
+ # Default tags
62
+ if tags is None:
63
+ tags = ["demo", "multi-agent", "mcts"]
64
+
65
+ # Initialize run
66
+ self.run = wandb.init(
67
+ project=self.project_name,
68
+ entity=self.entity,
69
+ name=run_name,
70
+ config=config or {},
71
+ tags=tags,
72
+ reinit=True,
73
+ )
74
+
75
+ self.run_id = self.run.id
76
+ return True
77
+
78
+ except Exception as e:
79
+ print(f"W&B init error: {e}")
80
+ self.enabled = False
81
+ return False
82
+
83
+ def log_query_config(self, config: dict[str, Any]):
84
+ """Log query configuration.
85
+
86
+ Args:
87
+ config: Configuration dictionary with agent settings, MCTS params, etc.
88
+ """
89
+ if not self.enabled or not self.run:
90
+ return
91
+
92
+ try:
93
+ wandb.config.update(config)
94
+ except Exception as e:
95
+ print(f"W&B config log error: {e}")
96
+
97
+ def log_agent_result(
98
+ self,
99
+ agent_name: str,
100
+ response: str,
101
+ confidence: float,
102
+ execution_time_ms: float,
103
+ reasoning_steps: list[str] | None = None,
104
+ ):
105
+ """Log individual agent results.
106
+
107
+ Args:
108
+ agent_name: Name of the agent (HRM, TRM, MCTS)
109
+ response: Agent's response text
110
+ confidence: Confidence score (0-1)
111
+ execution_time_ms: Execution time in milliseconds
112
+ reasoning_steps: Optional list of reasoning steps
113
+ """
114
+ if not self.enabled or not self.run:
115
+ return
116
+
117
+ try:
118
+ metrics = {
119
+ f"{agent_name}/confidence": confidence,
120
+ f"{agent_name}/execution_time_ms": execution_time_ms,
121
+ f"{agent_name}/response_length": len(response),
122
+ }
123
+
124
+ if reasoning_steps:
125
+ metrics[f"{agent_name}/num_reasoning_steps"] = len(reasoning_steps)
126
+
127
+ wandb.log(metrics)
128
+
129
+ # Log response as text
130
+ wandb.log({f"{agent_name}/response": wandb.Html(f"<pre>{response}</pre>")})
131
+
132
+ except Exception as e:
133
+ print(f"W&B agent result log error: {e}")
134
+
135
+ def log_mcts_result(self, mcts_result: dict[str, Any]):
136
+ """Log MCTS-specific metrics.
137
+
138
+ Args:
139
+ mcts_result: Dictionary containing MCTS search results
140
+ """
141
+ if not self.enabled or not self.run:
142
+ return
143
+
144
+ try:
145
+ # Extract key metrics
146
+ metrics = {
147
+ "mcts/best_value": mcts_result.get("best_value", 0),
148
+ "mcts/root_visits": mcts_result.get("root_visits", 0),
149
+ "mcts/total_nodes": mcts_result.get("total_nodes", 0),
150
+ "mcts/max_depth": mcts_result.get("max_depth_reached", 0),
151
+ "mcts/iterations": mcts_result.get("iterations_completed", 0),
152
+ "mcts/exploration_weight": mcts_result.get("exploration_weight", 1.414),
153
+ }
154
+
155
+ wandb.log(metrics)
156
+
157
+ # Log top actions as table
158
+ if "top_actions" in mcts_result:
159
+ top_actions_data = []
160
+ for action in mcts_result["top_actions"]:
161
+ top_actions_data.append(
162
+ [
163
+ action.get("action", ""),
164
+ action.get("visits", 0),
165
+ action.get("value", 0),
166
+ action.get("ucb1", 0),
167
+ ]
168
+ )
169
+
170
+ if top_actions_data:
171
+ table = wandb.Table(data=top_actions_data, columns=["Action", "Visits", "Value", "UCB1"])
172
+ wandb.log({"mcts/top_actions_table": table})
173
+
174
+ # Log tree visualization as text artifact
175
+ if "tree_visualization" in mcts_result:
176
+ wandb.log({"mcts/tree_visualization": wandb.Html(f"<pre>{mcts_result['tree_visualization']}</pre>")})
177
+
178
+ except Exception as e:
179
+ print(f"W&B MCTS result log error: {e}")
180
+
181
+ def log_consensus(self, consensus_score: float, agents_used: list[str], final_response: str):
182
+ """Log consensus metrics.
183
+
184
+ Args:
185
+ consensus_score: Agreement score between agents (0-1)
186
+ agents_used: List of agent names that were used
187
+ final_response: Final synthesized response
188
+ """
189
+ if not self.enabled or not self.run:
190
+ return
191
+
192
+ try:
193
+ wandb.log(
194
+ {
195
+ "consensus/score": consensus_score,
196
+ "consensus/num_agents": len(agents_used),
197
+ "consensus/agents": ", ".join(agents_used),
198
+ "consensus/final_response_length": len(final_response),
199
+ }
200
+ )
201
+
202
+ # Categorize consensus level
203
+ if consensus_score > 0.7:
204
+ consensus_level = "high"
205
+ elif consensus_score > 0.4:
206
+ consensus_level = "medium"
207
+ else:
208
+ consensus_level = "low"
209
+
210
+ wandb.log({"consensus/level": consensus_level})
211
+
212
+ except Exception as e:
213
+ print(f"W&B consensus log error: {e}")
214
+
215
+ def log_performance(self, total_time_ms: float):
216
+ """Log overall performance metrics.
217
+
218
+ Args:
219
+ total_time_ms: Total execution time in milliseconds
220
+ """
221
+ if not self.enabled or not self.run:
222
+ return
223
+
224
+ try:
225
+ wandb.log({"performance/total_time_ms": total_time_ms, "performance/total_time_s": total_time_ms / 1000})
226
+ except Exception as e:
227
+ print(f"W&B performance log error: {e}")
228
+
229
+ def log_full_result(self, result: dict[str, Any]):
230
+ """Log the complete result as an artifact.
231
+
232
+ Args:
233
+ result: Full framework result dictionary
234
+ """
235
+ if not self.enabled or not self.run:
236
+ return
237
+
238
+ try:
239
+ # Create artifact
240
+ artifact = wandb.Artifact(name=f"query_result_{self.run_id}", type="result")
241
+
242
+ # Add result as JSON
243
+ import json
244
+ import tempfile
245
+
246
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
247
+ json.dump(result, f, indent=2, default=str)
248
+ temp_path = f.name
249
+
250
+ artifact.add_file(temp_path, name="result.json")
251
+ wandb.log_artifact(artifact)
252
+
253
+ # Clean up temp file
254
+ os.unlink(temp_path)
255
+
256
+ except Exception as e:
257
+ print(f"W&B full result log error: {e}")
258
+
259
+ def log_query_summary(
260
+ self, query: str, use_hrm: bool, use_trm: bool, use_mcts: bool, consensus_score: float, total_time_ms: float
261
+ ):
262
+ """Log a summary row for the query.
263
+
264
+ Args:
265
+ query: The input query
266
+ use_hrm: Whether HRM was enabled
267
+ use_trm: Whether TRM was enabled
268
+ use_mcts: Whether MCTS was enabled
269
+ consensus_score: Final consensus score
270
+ total_time_ms: Total execution time
271
+ """
272
+ if not self.enabled or not self.run:
273
+ return
274
+
275
+ try:
276
+ # Create summary table entry
277
+ summary_data = [
278
+ [
279
+ query[:100] + "..." if len(query) > 100 else query,
280
+ "✓" if use_hrm else "✗",
281
+ "✓" if use_trm else "✗",
282
+ "✓" if use_mcts else "✗",
283
+ f"{consensus_score:.1%}",
284
+ f"{total_time_ms:.2f}",
285
+ ]
286
+ ]
287
+
288
+ table = wandb.Table(data=summary_data, columns=["Query", "HRM", "TRM", "MCTS", "Consensus", "Time (ms)"])
289
+
290
+ wandb.log({"query_summary": table})
291
+
292
+ except Exception as e:
293
+ print(f"W&B summary log error: {e}")
294
+
295
+ def finish_run(self):
296
+ """Finish the current W&B run."""
297
+ if not self.enabled or not self.run:
298
+ return
299
+
300
+ try:
301
+ wandb.finish()
302
+ self.run = None
303
+ self.run_id = None
304
+ except Exception as e:
305
+ print(f"W&B finish error: {e}")
306
+
307
+ def get_run_url(self) -> str | None:
308
+ """Get the URL for the current run.
309
+
310
+ Returns:
311
+ URL string or None if no active run
312
+ """
313
+ if not self.enabled or not self.run:
314
+ return None
315
+
316
+ try:
317
+ return self.run.get_url()
318
+ except Exception:
319
+ return None
320
+
321
+
322
+ # Global tracker instance
323
+ _global_tracker: WandBTracker | None = None
324
+
325
+
326
+ def get_tracker(
327
+ project_name: str = "langgraph-mcts-demo", entity: str | None = None, enabled: bool = True
328
+ ) -> WandBTracker:
329
+ """Get or create the global W&B tracker.
330
+
331
+ Args:
332
+ project_name: W&B project name
333
+ entity: W&B entity
334
+ enabled: Whether tracking is enabled
335
+
336
+ Returns:
337
+ WandBTracker instance
338
+ """
339
+ global _global_tracker
340
+
341
+ if _global_tracker is None:
342
+ _global_tracker = WandBTracker(project_name=project_name, entity=entity, enabled=enabled)
343
+
344
+ return _global_tracker
345
+
346
+
347
+ def is_wandb_available() -> bool:
348
+ """Check if W&B is available."""
349
+ return WANDB_AVAILABLE
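A minimal end-to-end sketch of how the tracker above is meant to be used for one query; the metric values are placeholders, not real run output:

```python
# Minimal sketch: one tracked query using the WandBTracker helpers defined above.
from demo_src.wandb_tracker import get_tracker, is_wandb_available

tracker = get_tracker(project_name="langgraph-mcts-demo", enabled=is_wandb_available())

if tracker.init_run(run_name="example_run", config={"use_hrm": True, "use_trm": True}):
    tracker.log_agent_result(
        agent_name="HRM",
        response="Consider scalability requirements and team expertise",
        confidence=0.82,
        execution_time_ms=120.5,
        reasoning_steps=["decompose", "analyze", "synthesize"],
    )
    tracker.log_consensus(0.76, ["HRM", "TRM"], "Synthesized response text")
    tracker.log_performance(total_time_ms=250.0)
    print("Run URL:", tracker.get_run_url())
    tracker.finish_run()
```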
models/bert_lora/final_model/README.md ADDED
@@ -0,0 +1,206 @@
1
+ ---
2
+ base_model: prajjwal1/bert-mini
3
+ library_name: peft
4
+ tags:
5
+ - base_model:adapter:prajjwal1/bert-mini
6
+ - lora
7
+ - transformers
8
+ ---
9
+
10
+ # Model Card for Model ID
11
+
12
+ <!-- Provide a quick summary of what the model is/does. -->
13
+
14
+
15
+
16
+ ## Model Details
17
+
18
+ ### Model Description
19
+
20
+ <!-- Provide a longer summary of what this model is. -->
21
+
22
+
23
+
24
+ - **Developed by:** [More Information Needed]
25
+ - **Funded by [optional]:** [More Information Needed]
26
+ - **Shared by [optional]:** [More Information Needed]
27
+ - **Model type:** [More Information Needed]
28
+ - **Language(s) (NLP):** [More Information Needed]
29
+ - **License:** [More Information Needed]
30
+ - **Finetuned from model [optional]:** [More Information Needed]
31
+
32
+ ### Model Sources [optional]
33
+
34
+ <!-- Provide the basic links for the model. -->
35
+
36
+ - **Repository:** [More Information Needed]
37
+ - **Paper [optional]:** [More Information Needed]
38
+ - **Demo [optional]:** [More Information Needed]
39
+
40
+ ## Uses
41
+
42
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
43
+
44
+ ### Direct Use
45
+
46
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
47
+
48
+ [More Information Needed]
49
+
50
+ ### Downstream Use [optional]
51
+
52
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
53
+
54
+ [More Information Needed]
55
+
56
+ ### Out-of-Scope Use
57
+
58
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
59
+
60
+ [More Information Needed]
61
+
62
+ ## Bias, Risks, and Limitations
63
+
64
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
65
+
66
+ [More Information Needed]
67
+
68
+ ### Recommendations
69
+
70
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
71
+
72
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
73
+
74
+ ## How to Get Started with the Model
75
+
76
+ Use the code below to get started with the model.
77
+
78
+ [More Information Needed]
79
+
80
+ ## Training Details
81
+
82
+ ### Training Data
83
+
84
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
85
+
86
+ [More Information Needed]
87
+
88
+ ### Training Procedure
89
+
90
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
91
+
92
+ #### Preprocessing [optional]
93
+
94
+ [More Information Needed]
95
+
96
+
97
+ #### Training Hyperparameters
98
+
99
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
100
+
101
+ #### Speeds, Sizes, Times [optional]
102
+
103
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
104
+
105
+ [More Information Needed]
106
+
107
+ ## Evaluation
108
+
109
+ <!-- This section describes the evaluation protocols and provides the results. -->
110
+
111
+ ### Testing Data, Factors & Metrics
112
+
113
+ #### Testing Data
114
+
115
+ <!-- This should link to a Dataset Card if possible. -->
116
+
117
+ [More Information Needed]
118
+
119
+ #### Factors
120
+
121
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
122
+
123
+ [More Information Needed]
124
+
125
+ #### Metrics
126
+
127
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
128
+
129
+ [More Information Needed]
130
+
131
+ ### Results
132
+
133
+ [More Information Needed]
134
+
135
+ #### Summary
136
+
137
+
138
+
139
+ ## Model Examination [optional]
140
+
141
+ <!-- Relevant interpretability work for the model goes here -->
142
+
143
+ [More Information Needed]
144
+
145
+ ## Environmental Impact
146
+
147
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
148
+
149
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
150
+
151
+ - **Hardware Type:** [More Information Needed]
152
+ - **Hours used:** [More Information Needed]
153
+ - **Cloud Provider:** [More Information Needed]
154
+ - **Compute Region:** [More Information Needed]
155
+ - **Carbon Emitted:** [More Information Needed]
156
+
157
+ ## Technical Specifications [optional]
158
+
159
+ ### Model Architecture and Objective
160
+
161
+ [More Information Needed]
162
+
163
+ ### Compute Infrastructure
164
+
165
+ [More Information Needed]
166
+
167
+ #### Hardware
168
+
169
+ [More Information Needed]
170
+
171
+ #### Software
172
+
173
+ [More Information Needed]
174
+
175
+ ## Citation [optional]
176
+
177
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
178
+
179
+ **BibTeX:**
180
+
181
+ [More Information Needed]
182
+
183
+ **APA:**
184
+
185
+ [More Information Needed]
186
+
187
+ ## Glossary [optional]
188
+
189
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
190
+
191
+ [More Information Needed]
192
+
193
+ ## More Information [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Authors [optional]
198
+
199
+ [More Information Needed]
200
+
201
+ ## Model Card Contact
202
+
203
+ [More Information Needed]
204
+ ### Framework versions
205
+
206
+ - PEFT 0.17.1
models/bert_lora/final_model/adapter_config.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "prajjwal1/bert-mini",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 16,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.1,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": [
22
+ "classifier",
23
+ "score"
24
+ ],
25
+ "peft_type": "LORA",
26
+ "qalora_group_size": 16,
27
+ "r": 4,
28
+ "rank_pattern": {},
29
+ "revision": null,
30
+ "target_modules": [
31
+ "query",
32
+ "value"
33
+ ],
34
+ "target_parameters": null,
35
+ "task_type": "SEQ_CLS",
36
+ "trainable_token_indices": null,
37
+ "use_dora": false,
38
+ "use_qalora": false,
39
+ "use_rslora": false
40
+ }
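The adapter config above targets the `query` and `value` projections of `prajjwal1/bert-mini` with a rank-4 LoRA and a trained sequence-classification head. A minimal loading sketch with PEFT; `num_labels=3` is an assumption based on the three controller classes (hrm/trm/mcts) reported in the training results, not a value stored in this file:

```python
# Minimal sketch: loading the LoRA adapter on top of the bert-mini base model.
# num_labels=3 is an assumption (hrm / trm / mcts routing classes), not part of this config.
from peft import PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

base = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-mini", num_labels=3)
model = PeftModel.from_pretrained(base, "models/bert_lora/final_model")
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")

inputs = tokenizer("How should we scale this service?", return_tensors="pt")
logits = model(**inputs).logits  # One logit per meta-controller class
```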
models/bert_lora/final_model/adapter_model.safetensors ADDED
Binary file (71 kB).
 
models/bert_lora/generated_dataset.json ADDED
The diff for this file is too large to render. See raw diff
 
models/bert_lora/training_results.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "config": {
3
+ "model_name": "prajjwal1/bert-mini",
4
+ "lora_r": 4,
5
+ "lora_alpha": 16,
6
+ "lora_dropout": 0.1,
7
+ "lr": 0.001,
8
+ "batch_size": 16,
9
+ "epochs": 5,
10
+ "warmup_steps": 100,
11
+ "seed": 42,
12
+ "num_samples": 1000,
13
+ "data_path": null,
14
+ "balanced": true,
15
+ "output_dir": "models/bert_lora"
16
+ },
17
+ "train_history": {
18
+ "train_loss": 1.1033503922549162,
19
+ "train_runtime": 11.0946,
20
+ "train_samples_per_second": 315.018,
21
+ "epochs": 5,
22
+ "final_metrics": {
23
+ "train_runtime": 11.0946,
24
+ "train_samples_per_second": 315.018,
25
+ "train_steps_per_second": 19.829,
26
+ "total_flos": 34821822412800.0,
27
+ "train_loss": 1.1033503922549162,
28
+ "epoch": 5.0
29
+ },
30
+ "eval_results": {
31
+ "eval_loss": 1.0453400611877441,
32
+ "eval_accuracy": 0.47651006711409394,
33
+ "eval_runtime": 0.1251,
34
+ "eval_samples_per_second": 1191.171,
35
+ "eval_steps_per_second": 79.944,
36
+ "epoch": 5.0
37
+ }
38
+ },
39
+ "test_results": {
40
+ "loss": 1.0559743153338401,
41
+ "accuracy": 0.4768211920529801
42
+ },
43
+ "model_params": {
44
+ "total_params": 11188486,
45
+ "trainable_params": 17155,
46
+ "trainable_percentage": 0.15
47
+ }
48
+ }
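The reported 17,155 trainable parameters are consistent with the LoRA settings above if bert-mini's usual dimensions are assumed (4 transformer layers, hidden size 256; an assumption, not data stored in this file). A back-of-the-envelope check:

```python
# Back-of-the-envelope check of trainable_params; the 4-layer / 256-hidden shape of
# bert-mini and the 3-class head are assumptions, not values from this results file.
hidden, layers, r, num_classes = 256, 4, 4, 3

per_matrix = 2 * r * hidden                      # LoRA A (r x hidden) + B (hidden x r) = 2,048
per_layer = 2 * per_matrix                       # query + value projections            = 4,096
lora_total = layers * per_layer                  # all encoder layers                   = 16,384
classifier = hidden * num_classes + num_classes  # fully trained classification head    = 771

print(lora_total + classifier)                   # 17,155 -- matches trainable_params
```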
models/rnn_meta_controller.history.json ADDED
@@ -0,0 +1,128 @@
1
+ {
2
+ "config": {
3
+ "hidden_dim": 64,
4
+ "num_layers": 1,
5
+ "dropout": 0.1,
6
+ "lr": 0.001,
7
+ "batch_size": 32,
8
+ "epochs": 20,
9
+ "patience": 5,
10
+ "seed": 42,
11
+ "num_samples": 1000
12
+ },
13
+ "training_history": {
14
+ "train_losses": [
15
+ 1.060307163180727,
16
+ 0.9014069383794611,
17
+ 0.6105747597687172,
18
+ 0.35656250968123926,
19
+ 0.22574858390020602,
20
+ 0.16157509059165465,
21
+ 0.12456387586214325,
22
+ 0.10158110240643675,
23
+ 0.08592396827809738,
24
+ 0.07474524908783761,
25
+ 0.06479036057311477,
26
+ 0.057878461638183304,
27
+ 0.052609961931452606,
28
+ 0.04809149278497154,
29
+ 0.043710527828697,
30
+ 0.041286276738074695,
31
+ 0.03756282673302022,
32
+ 0.03491098284156936,
33
+ 0.031911260236731985,
34
+ 0.030496817025722878
35
+ ],
36
+ "val_losses": [
37
+ 1.0059996803601583,
38
+ 0.7808501919110616,
39
+ 0.47826388080914817,
40
+ 0.29279296696186063,
41
+ 0.2008462185660998,
42
+ 0.1529717780649662,
43
+ 0.12299496456980705,
44
+ 0.10291122049093246,
45
+ 0.08860023791591326,
46
+ 0.07790809428940217,
47
+ 0.06982718824098508,
48
+ 0.06387854401643077,
49
+ 0.05984275036801894,
50
+ 0.05463591649507483,
51
+ 0.04938021237030625,
52
+ 0.0452831008626769,
53
+ 0.04252756762628754,
54
+ 0.039516554485696055,
55
+ 0.038632405494960644,
56
+ 0.035608950459087886
57
+ ],
58
+ "val_accuracies": [
59
+ 0.8466666666666667,
60
+ 0.92,
61
+ 0.9822222222222222,
62
+ 0.9933333333333333,
63
+ 0.9911111111111112,
64
+ 0.9933333333333333,
65
+ 0.9955555555555555,
66
+ 0.9955555555555555,
67
+ 0.9955555555555555,
68
+ 0.9955555555555555,
69
+ 0.9955555555555555,
70
+ 0.9977777777777778,
71
+ 0.9933333333333333,
72
+ 0.9933333333333333,
73
+ 0.9977777777777778,
74
+ 0.9977777777777778,
75
+ 0.9977777777777778,
76
+ 0.9977777777777778,
77
+ 0.9955555555555555,
78
+ 0.9977777777777778
79
+ ],
80
+ "best_epoch": 20,
81
+ "best_val_loss": 0.035608950459087886,
82
+ "best_val_accuracy": 0.9977777777777778,
83
+ "stopped_early": false,
84
+ "total_epochs": 20
85
+ },
86
+ "test_results": {
87
+ "loss": 0.022989434589787076,
88
+ "accuracy": 0.9977777777777778,
89
+ "per_class_metrics": {
90
+ "hrm": {
91
+ "precision": 1.0,
92
+ "recall": 1.0,
93
+ "f1_score": 1.0,
94
+ "support": 153
95
+ },
96
+ "trm": {
97
+ "precision": 0.9933774834437086,
98
+ "recall": 1.0,
99
+ "f1_score": 0.9966777408637874,
100
+ "support": 150
101
+ },
102
+ "mcts": {
103
+ "precision": 1.0,
104
+ "recall": 0.9931972789115646,
105
+ "f1_score": 0.9965870307167235,
106
+ "support": 147
107
+ }
108
+ },
109
+ "confusion_matrix": [
110
+ [
111
+ 153,
112
+ 0,
113
+ 0
114
+ ],
115
+ [
116
+ 0,
117
+ 150,
118
+ 0
119
+ ],
120
+ [
121
+ 0,
122
+ 1,
123
+ 146
124
+ ]
125
+ ],
126
+ "total_samples": 450
127
+ }
128
+ }
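The per-class metrics above follow directly from the confusion matrix; for example, the "trm" column (a quick check, not part of the commit):

```python
# rows = true class (hrm, trm, mcts), columns = predicted class
cm = [
    [153, 0, 0],
    [0, 150, 0],
    [0, 1, 146],
]

tp = cm[1][1]               # 150 true positives for "trm"
fp = cm[0][1] + cm[2][1]    # 1 false positive (an "mcts" sample predicted as "trm")
fn = cm[1][0] + cm[1][2]    # 0 false negatives

precision = tp / (tp + fp)                          # 150/151 ~= 0.9934
recall = tp / (tp + fn)                             # 1.0
f1 = 2 * precision * recall / (precision + recall)  # ~= 0.9967
accuracy = (153 + 150 + 146) / 450                  # ~= 0.9978
print(precision, recall, f1, accuracy)
```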
models/rnn_meta_controller.pt ADDED
Binary file (61.6 kB).
 
requirements.txt ADDED
@@ -0,0 +1,28 @@
1
+ # LangGraph Multi-Agent MCTS Demo - Dependencies
2
+ # Optimized for Hugging Face Spaces deployment with trained models
3
+
4
+ # Core UI Framework
5
+ gradio>=4.0.0,<5.0.0
6
+
7
+ # Numerical computation
8
+ numpy>=1.24.0,<2.0.0
9
+
10
+ # Machine Learning - Neural Models
11
+ torch>=2.1.0
12
+ transformers>=4.40.0
13
+ peft>=0.7.0
14
+ sentence-transformers>=2.2.0
15
+
16
+ # Configuration
17
+ pyyaml>=6.0
18
+
19
+ # Experiment Tracking
20
+ wandb>=0.16.0
21
+
22
+ # Required for Gradio OAuth and model loading
23
+ huggingface_hub>=0.20.0,<0.30.0
24
+
25
+ # Note: This demo now uses REAL trained models:
26
+ # - RNN Meta-Controller (models/rnn_meta_controller.pt)
27
+ # - BERT with LoRA adapters (models/bert_lora/final_model/)
28
+ # - Actual HRM and TRM agent implementations
src/__init__.py ADDED
File without changes
src/adapters/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """
2
+ Adapters package for external service integrations.
3
+ """
4
+
5
+ from .llm import BaseLLMClient, LLMResponse, create_client
6
+
7
+ __all__ = ["create_client", "BaseLLMClient", "LLMResponse"]
src/adapters/llm/__init__.py ADDED
@@ -0,0 +1,257 @@
1
+ """
2
+ LLM Client Factory and Provider Registry.
3
+
4
+ This module provides a factory function to instantiate the correct LLM client
5
+ based on provider settings, with lazy loading of adapters.
6
+ """
7
+
8
+ import importlib
9
+ import logging
10
+ from typing import Any
11
+
12
+ from .base import BaseLLMClient, LLMClient, LLMResponse, LLMToolResponse, ToolCall
13
+ from .exceptions import (
14
+ CircuitBreakerOpenError,
15
+ LLMAuthenticationError,
16
+ LLMClientError,
17
+ LLMConnectionError,
18
+ LLMContentFilterError,
19
+ LLMContextLengthError,
20
+ LLMInvalidRequestError,
21
+ LLMModelNotFoundError,
22
+ LLMQuotaExceededError,
23
+ LLMRateLimitError,
24
+ LLMResponseParseError,
25
+ LLMServerError,
26
+ LLMStreamError,
27
+ LLMTimeoutError,
28
+ )
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Provider registry with lazy loading
33
+ # Maps provider name to (module_path, class_name)
34
+ _PROVIDER_REGISTRY: dict[str, tuple[str, str]] = {
35
+ "openai": ("src.adapters.llm.openai_client", "OpenAIClient"),
36
+ "anthropic": ("src.adapters.llm.anthropic_client", "AnthropicClient"),
37
+ "lmstudio": ("src.adapters.llm.lmstudio_client", "LMStudioClient"),
38
+ "local": ("src.adapters.llm.lmstudio_client", "LMStudioClient"), # Alias
39
+ }
40
+
41
+ # Cache for loaded client classes
42
+ _CLIENT_CACHE: dict[str, type[BaseLLMClient]] = {}
43
+
44
+
45
+ def register_provider(name: str, module_path: str, class_name: str, override: bool = False) -> None:
46
+ """
47
+ Register a new LLM provider.
48
+
49
+ Args:
50
+ name: Provider identifier (e.g., "azure", "bedrock")
51
+ module_path: Full module path (e.g., "src.adapters.llm.azure_client")
52
+ class_name: Class name in the module (e.g., "AzureOpenAIClient")
53
+ override: If True, allow overriding existing provider
54
+ """
55
+ if name in _PROVIDER_REGISTRY and not override:
56
+ raise ValueError(f"Provider '{name}' already registered. Use override=True to replace.")
57
+
58
+ _PROVIDER_REGISTRY[name] = (module_path, class_name)
59
+ # Clear cache if overriding
60
+ if name in _CLIENT_CACHE:
61
+ del _CLIENT_CACHE[name]
62
+
63
+ logger.info(f"Registered LLM provider: {name} -> {module_path}.{class_name}")
64
+
65
+
66
+ def list_providers() -> list[str]:
67
+ """
68
+ List all registered provider names.
69
+
70
+ Returns:
71
+ List of provider identifiers
72
+ """
73
+ return list(_PROVIDER_REGISTRY.keys())
74
+
75
+
76
+ def get_provider_class(provider: str) -> type[BaseLLMClient]:
77
+ """
78
+ Get the client class for a provider (with lazy loading).
79
+
80
+ Args:
81
+ provider: Provider identifier
82
+
83
+ Returns:
84
+ Client class (not instantiated)
85
+
86
+ Raises:
87
+ ValueError: If provider not registered
88
+ ImportError: If module cannot be loaded
89
+ """
90
+ if provider not in _PROVIDER_REGISTRY:
91
+ available = ", ".join(list_providers())
92
+ raise ValueError(f"Unknown provider '{provider}'. Available: {available}")
93
+
94
+ # Check cache first
95
+ if provider in _CLIENT_CACHE:
96
+ return _CLIENT_CACHE[provider]
97
+
98
+ # Lazy load the module
99
+ module_path, class_name = _PROVIDER_REGISTRY[provider]
100
+
101
+ try:
102
+ module = importlib.import_module(module_path)
103
+ client_class = getattr(module, class_name)
104
+ except ImportError as e:
105
+ raise ImportError(f"Failed to load provider '{provider}': {e}") from e
106
+ except AttributeError as e:
107
+ raise ImportError(f"Class '{class_name}' not found in module '{module_path}'") from e
108
+
109
+ # Cache for future use
110
+ _CLIENT_CACHE[provider] = client_class
111
+ return client_class
112
+
113
+
114
+ def create_client(
115
+ provider: str = "openai",
116
+ *,
117
+ api_key: str | None = None,
118
+ model: str | None = None,
119
+ base_url: str | None = None,
120
+ timeout: float | None = None,
121
+ max_retries: int | None = None,
122
+ **kwargs: Any,
123
+ ) -> BaseLLMClient:
124
+ """
125
+ Create an LLM client instance.
126
+
127
+ This is the main factory function for creating provider clients.
128
+
129
+ Args:
130
+ provider: Provider name ("openai", "anthropic", "lmstudio", etc.)
131
+ api_key: API key (may be optional for some providers)
132
+ model: Model identifier
133
+ base_url: Base URL for API
134
+ timeout: Request timeout in seconds
135
+ max_retries: Maximum retry attempts
136
+ **kwargs: Provider-specific parameters
137
+
138
+ Returns:
139
+ Configured LLMClient instance
140
+
141
+ Examples:
142
+ # OpenAI client
143
+ client = create_client("openai", model="gpt-4-turbo-preview")
144
+
145
+ # Anthropic client
146
+ client = create_client("anthropic", model="sonnet")
147
+
148
+ # Local LM Studio
149
+ client = create_client("lmstudio", base_url="http://localhost:1234/v1")
150
+
151
+ # With custom settings
152
+ client = create_client(
153
+ "openai",
154
+ api_key="sk-...",
155
+ timeout=120.0,
156
+ max_retries=5,
157
+ organization="org-..."
158
+ )
159
+ """
160
+ client_class = get_provider_class(provider)
161
+
162
+ # Build kwargs for client initialization
163
+ init_kwargs = {**kwargs}
164
+
165
+ if api_key is not None:
166
+ init_kwargs["api_key"] = api_key
167
+ if model is not None:
168
+ init_kwargs["model"] = model
169
+ if base_url is not None:
170
+ init_kwargs["base_url"] = base_url
171
+ if timeout is not None:
172
+ init_kwargs["timeout"] = timeout
173
+ if max_retries is not None:
174
+ init_kwargs["max_retries"] = max_retries
175
+
176
+ logger.info(f"Creating {provider} client with model={model or 'default'}")
177
+
178
+ return client_class(**init_kwargs)
179
+
180
+
181
+ def create_client_from_config(config: dict) -> BaseLLMClient:
182
+ """
183
+ Create an LLM client from a configuration dictionary.
184
+
185
+ Useful for loading settings from YAML/JSON config files.
186
+
187
+ Args:
188
+ config: Configuration dictionary with keys:
189
+ - provider: Required provider name
190
+ - Other keys passed to create_client
191
+
192
+ Returns:
193
+ Configured LLMClient instance
194
+
195
+ Example:
196
+ config = {
197
+ "provider": "openai",
198
+ "model": "gpt-4-turbo-preview",
199
+ "timeout": 60.0,
200
+ "max_retries": 3
201
+ }
202
+ client = create_client_from_config(config)
203
+ """
204
+ config = config.copy()
205
+ provider = config.pop("provider", "openai")
206
+ return create_client(provider, **config)
207
+
208
+
209
+ # Convenience aliases for common use cases
210
+ def create_openai_client(**kwargs) -> BaseLLMClient:
211
+ """Create an OpenAI client."""
212
+ return create_client("openai", **kwargs)
213
+
214
+
215
+ def create_anthropic_client(**kwargs) -> BaseLLMClient:
216
+ """Create an Anthropic Claude client."""
217
+ return create_client("anthropic", **kwargs)
218
+
219
+
220
+ def create_local_client(**kwargs) -> BaseLLMClient:
221
+ """Create a local LM Studio client."""
222
+ return create_client("lmstudio", **kwargs)
223
+
224
+
225
+ __all__ = [
226
+ # Base types
227
+ "LLMClient",
228
+ "LLMResponse",
229
+ "LLMToolResponse",
230
+ "ToolCall",
231
+ "BaseLLMClient",
232
+ # Exceptions
233
+ "LLMClientError",
234
+ "LLMAuthenticationError",
235
+ "LLMRateLimitError",
236
+ "LLMQuotaExceededError",
237
+ "LLMModelNotFoundError",
238
+ "LLMContextLengthError",
239
+ "LLMInvalidRequestError",
240
+ "LLMTimeoutError",
241
+ "LLMConnectionError",
242
+ "LLMServerError",
243
+ "LLMResponseParseError",
244
+ "LLMStreamError",
245
+ "LLMContentFilterError",
246
+ "CircuitBreakerOpenError",
247
+ # Factory functions
248
+ "create_client",
249
+ "create_client_from_config",
250
+ "create_openai_client",
251
+ "create_anthropic_client",
252
+ "create_local_client",
253
+ # Registry functions
254
+ "register_provider",
255
+ "list_providers",
256
+ "get_provider_class",
257
+ ]
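A short usage sketch for the registry and factory defined above (the azure module path is the same hypothetical example used in the register_provider docstring; it is not part of this package):

```python
from src.adapters.llm import create_client, list_providers, register_provider

print(list_providers())  # ['openai', 'anthropic', 'lmstudio', 'local']

# Registration only wires up lazy loading; the module is imported on first use,
# so this path/class pair is purely illustrative.
register_provider("azure", "src.adapters.llm.azure_client", "AzureOpenAIClient")

# Built-in providers are instantiated through the same factory.
client = create_client("lmstudio", base_url="http://localhost:1234/v1")
```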
src/adapters/llm/anthropic_client.py ADDED
@@ -0,0 +1,521 @@
1
+ """
2
+ Anthropic Claude LLM client adapter.
3
+
4
+ Implements the LLMClient protocol for Anthropic's Messages API.
5
+ Supports Claude 3 models with proper content block handling.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from collections.abc import AsyncIterator
11
+ from typing import Any
12
+
13
+ import httpx
14
+ from tenacity import (
15
+ before_sleep_log,
16
+ retry,
17
+ retry_if_exception_type,
18
+ stop_after_attempt,
19
+ wait_exponential,
20
+ )
21
+
22
+ from .base import BaseLLMClient, LLMResponse, LLMToolResponse, ToolCall
23
+ from .exceptions import (
24
+ CircuitBreakerOpenError,
25
+ LLMAuthenticationError,
26
+ LLMClientError,
27
+ LLMConnectionError,
28
+ LLMContentFilterError,
29
+ LLMContextLengthError,
30
+ LLMInvalidRequestError,
31
+ LLMModelNotFoundError,
32
+ LLMQuotaExceededError,
33
+ LLMRateLimitError,
34
+ LLMResponseParseError,
35
+ LLMServerError,
36
+ LLMStreamError,
37
+ LLMTimeoutError,
38
+ )
39
+ from .openai_client import CircuitBreaker
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ # Model mappings for convenience
45
+ ANTHROPIC_MODELS = {
46
+ "claude-3-opus": "claude-3-opus-20240229",
47
+ "claude-3-sonnet": "claude-3-sonnet-20240229",
48
+ "claude-3-haiku": "claude-3-haiku-20240307",
49
+ "claude-3.5-sonnet": "claude-3-5-sonnet-20240620",
50
+ "claude-3.5-sonnet-v2": "claude-3-5-sonnet-20241022",
51
+ "claude-sonnet-4": "claude-sonnet-4-20250514",
52
+ # Add latest models
53
+ "opus": "claude-3-opus-20240229",
54
+ "sonnet": "claude-3-5-sonnet-20241022",
55
+ "haiku": "claude-3-haiku-20240307",
56
+ }
57
+
58
+
59
+ class AnthropicClient(BaseLLMClient):
60
+ """
61
+ Anthropic Claude API client.
62
+
63
+ Features:
64
+ - Messages API support (not legacy completion API)
65
+ - Content block handling (text, tool_use)
66
+ - Streaming with proper SSE parsing
67
+ - Model alias mapping
68
+ - System prompt support
69
+ - Tool/function calling (beta)
70
+ """
71
+
72
+ PROVIDER_NAME = "anthropic"
73
+ DEFAULT_BASE_URL = "https://api.anthropic.com"
74
+ DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
75
+ API_VERSION = "2023-06-01"
76
+
77
+ def __init__(
78
+ self,
79
+ api_key: str | None = None,
80
+ model: str | None = None,
81
+ base_url: str | None = None,
82
+ timeout: float = 120.0, # Claude can be slower
83
+ max_retries: int = 3,
84
+ # Circuit breaker settings
85
+ circuit_breaker_threshold: int = 5,
86
+ circuit_breaker_reset: float = 60.0,
87
+ # Rate limiting
88
+ rate_limit_per_minute: int | None = None,
89
+ ):
90
+ """
91
+ Initialize Anthropic client.
92
+
93
+ Args:
94
+ api_key: Anthropic API key (or set ANTHROPIC_API_KEY env var)
95
+ model: Model to use (supports aliases like 'sonnet', 'opus')
96
+ base_url: API base URL
97
+ timeout: Request timeout in seconds (default longer for Claude)
98
+ max_retries: Max retry attempts
99
+ circuit_breaker_threshold: Failures before circuit opens
100
+ circuit_breaker_reset: Seconds before circuit resets
101
+ rate_limit_per_minute: Rate limit for requests per minute (None to disable)
102
+ """
103
+ import os
104
+
105
+ api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
106
+ if not api_key:
107
+ raise LLMAuthenticationError(self.PROVIDER_NAME, "API key not provided and ANTHROPIC_API_KEY not set")
108
+
109
+ # Resolve model alias
110
+ model_name = model or self.DEFAULT_MODEL
111
+ resolved_model = ANTHROPIC_MODELS.get(model_name, model_name)
112
+
113
+ super().__init__(
114
+ api_key=api_key,
115
+ model=resolved_model,
116
+ base_url=base_url or self.DEFAULT_BASE_URL,
117
+ timeout=timeout,
118
+ max_retries=max_retries,
119
+ rate_limit_per_minute=rate_limit_per_minute,
120
+ )
121
+
122
+ self.circuit_breaker = CircuitBreaker(
123
+ failure_threshold=circuit_breaker_threshold,
124
+ reset_timeout=circuit_breaker_reset,
125
+ )
126
+ self._client: httpx.AsyncClient | None = None
127
+
128
+ async def _get_client(self) -> httpx.AsyncClient:
129
+ """Get or create the HTTP client."""
130
+ if self._client is None or self._client.is_closed:
131
+ headers = {
132
+ "x-api-key": self.api_key,
133
+ "anthropic-version": self.API_VERSION,
134
+ "Content-Type": "application/json",
135
+ }
136
+
137
+ self._client = httpx.AsyncClient(
138
+ base_url=self.base_url,
139
+ headers=headers,
140
+ timeout=httpx.Timeout(self.timeout),
141
+ )
142
+ return self._client
143
+
144
+ def _convert_messages_to_anthropic(self, messages: list[dict]) -> tuple[str | None, list[dict]]:
145
+ """
146
+ Convert OpenAI-style messages to Anthropic format.
147
+
148
+ Returns:
149
+ Tuple of (system_prompt, messages)
150
+ """
151
+ system_prompt = None
152
+ anthropic_messages = []
153
+
154
+ for msg in messages:
155
+ role = msg.get("role", "user")
156
+ content = msg.get("content", "")
157
+
158
+ if role == "system":
159
+ # Anthropic uses separate system parameter
160
+ system_prompt = content
161
+ elif role == "assistant":
162
+ anthropic_messages.append({"role": "assistant", "content": content})
163
+ elif role == "user":
164
+ anthropic_messages.append({"role": "user", "content": content})
165
+ elif role == "tool":
166
+ # Tool result message
167
+ anthropic_messages.append(
168
+ {
169
+ "role": "user",
170
+ "content": [
171
+ {
172
+ "type": "tool_result",
173
+ "tool_use_id": msg.get("tool_call_id", ""),
174
+ "content": content,
175
+ }
176
+ ],
177
+ }
178
+ )
179
+
180
+ return system_prompt, anthropic_messages
181
+
182
+ def _convert_tools_to_anthropic(self, tools: list[dict]) -> list[dict]:
183
+ """Convert OpenAI-style tool definitions to Anthropic format."""
184
+ anthropic_tools = []
185
+
186
+ for tool in tools:
187
+ if tool.get("type") == "function":
188
+ func = tool["function"]
189
+ anthropic_tools.append(
190
+ {
191
+ "name": func["name"],
192
+ "description": func.get("description", ""),
193
+ "input_schema": func.get("parameters", {"type": "object"}),
194
+ }
195
+ )
196
+ else:
197
+ # Already in Anthropic format
198
+ anthropic_tools.append(tool)
199
+
200
+ return anthropic_tools
201
+
202
+ def _handle_error_response(self, response: httpx.Response) -> None:
203
+ """Convert HTTP error responses to appropriate exceptions."""
204
+ status_code = response.status_code
205
+
206
+ try:
207
+ error_data = response.json()
208
+ error_type = error_data.get("error", {}).get("type", "")
209
+ error_message = error_data.get("error", {}).get("message", response.text)
210
+ except Exception:
211
+ error_type = ""
212
+ error_message = response.text
213
+
214
+ if status_code == 401:
215
+ raise LLMAuthenticationError(self.PROVIDER_NAME, error_message)
216
+ elif status_code == 429:
217
+ retry_after = response.headers.get("retry-after")
218
+ retry_after_float = float(retry_after) if retry_after else None
219
+ raise LLMRateLimitError(self.PROVIDER_NAME, retry_after=retry_after_float, message=error_message)
220
+ elif status_code == 402 or "billing" in error_type.lower():
221
+ raise LLMQuotaExceededError(self.PROVIDER_NAME, error_message)
222
+ elif status_code == 404 or error_type == "not_found_error":
223
+ raise LLMModelNotFoundError(self.PROVIDER_NAME, self.model)
224
+ elif status_code == 400:
225
+ if "context" in error_message.lower() or "token" in error_message.lower():
226
+ raise LLMContextLengthError(self.PROVIDER_NAME)
227
+ if "content_policy" in error_type or "safety" in error_message.lower():
228
+ raise LLMContentFilterError(self.PROVIDER_NAME, error_message)
229
+ raise LLMInvalidRequestError(self.PROVIDER_NAME, error_message)
230
+ elif status_code >= 500:
231
+ raise LLMServerError(self.PROVIDER_NAME, status_code, error_message)
232
+ else:
233
+ raise LLMClientError(error_message, self.PROVIDER_NAME, status_code=status_code)
234
+
235
+ def _make_retry_decorator(self):
236
+ """Create retry decorator with exponential backoff."""
237
+ return retry(
238
+ stop=stop_after_attempt(self.max_retries),
239
+ wait=wait_exponential(multiplier=1, min=2, max=120),
240
+ retry=retry_if_exception_type((LLMRateLimitError, LLMServerError, LLMConnectionError)),
241
+ before_sleep=before_sleep_log(logger, logging.WARNING),
242
+ reraise=True,
243
+ )
244
+
245
+ async def generate(
246
+ self,
247
+ *,
248
+ messages: list[dict] | None = None,
249
+ prompt: str | None = None,
250
+ temperature: float = 0.7,
251
+ max_tokens: int | None = None,
252
+ tools: list[dict] | None = None,
253
+ stream: bool = False,
254
+ stop: list[str] | None = None,
255
+ **kwargs: Any,
256
+ ) -> LLMResponse | AsyncIterator[str]:
257
+ """
258
+ Generate a response from Anthropic Claude.
259
+
260
+ Args:
261
+ messages: Chat messages (will be converted to Anthropic format)
262
+ prompt: Simple string prompt
263
+ temperature: Sampling temperature (0.0 to 1.0 for Claude)
264
+ max_tokens: Maximum tokens to generate (required for Anthropic)
265
+ tools: Tool definitions (will be converted to Anthropic format)
266
+ stream: If True, returns AsyncIterator
267
+ stop: Stop sequences
268
+ **kwargs: Additional parameters (top_p, top_k, etc.)
269
+
270
+ Returns:
271
+ LLMResponse or AsyncIterator[str] for streaming
272
+ """
273
+ # Apply rate limiting before proceeding
274
+ await self._apply_rate_limit()
275
+
276
+ # Check circuit breaker
277
+ if not self.circuit_breaker.can_execute():
278
+ raise CircuitBreakerOpenError(
279
+ self.PROVIDER_NAME,
280
+ self.circuit_breaker.failure_count,
281
+ self.circuit_breaker.get_reset_time(),
282
+ )
283
+
284
+ # Anthropic requires max_tokens
285
+ if max_tokens is None:
286
+ max_tokens = 4096 # Sensible default
287
+
288
+ if stream:
289
+ return await self._generate_stream(
290
+ messages=messages,
291
+ prompt=prompt,
292
+ temperature=temperature,
293
+ max_tokens=max_tokens,
294
+ tools=tools,
295
+ stop=stop,
296
+ **kwargs,
297
+ )
298
+ else:
299
+ return await self._generate_non_stream(
300
+ messages=messages,
301
+ prompt=prompt,
302
+ temperature=temperature,
303
+ max_tokens=max_tokens,
304
+ tools=tools,
305
+ stop=stop,
306
+ **kwargs,
307
+ )
308
+
309
+ async def _generate_non_stream(
310
+ self,
311
+ *,
312
+ messages: list[dict] | None = None,
313
+ prompt: str | None = None,
314
+ temperature: float = 0.7,
315
+ max_tokens: int = 4096,
316
+ tools: list[dict] | None = None,
317
+ stop: list[str] | None = None,
318
+ **kwargs: Any,
319
+ ) -> LLMResponse:
320
+ """Non-streaming generation with retry logic."""
321
+
322
+ @self._make_retry_decorator()
323
+ async def _request():
324
+ client = await self._get_client()
325
+
326
+ # Convert messages
327
+ built_messages = self._build_messages(messages, prompt)
328
+ system_prompt, anthropic_messages = self._convert_messages_to_anthropic(built_messages)
329
+
330
+ # Build request payload
331
+ payload = {
332
+ "model": self.model,
333
+ "messages": anthropic_messages,
334
+ "max_tokens": max_tokens,
335
+ "temperature": min(temperature, 1.0), # Anthropic max is 1.0
336
+ }
337
+
338
+ if system_prompt:
339
+ payload["system"] = system_prompt
340
+ if stop:
341
+ payload["stop_sequences"] = stop
342
+ if tools:
343
+ payload["tools"] = self._convert_tools_to_anthropic(tools)
344
+
345
+ # Add any additional kwargs (top_p, top_k, etc.)
346
+ for key in ["top_p", "top_k", "metadata"]:
347
+ if key in kwargs:
348
+ payload[key] = kwargs[key]
349
+
350
+ try:
351
+ response = await client.post("/v1/messages", json=payload)
352
+ except httpx.TimeoutException:
353
+ raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
354
+ except httpx.ConnectError:
355
+ raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
356
+
357
+ if response.status_code != 200:
358
+ self._handle_error_response(response)
359
+
360
+ return response
361
+
362
+ try:
363
+ response = await _request()
364
+ self.circuit_breaker.record_success()
365
+ except Exception:
366
+ self.circuit_breaker.record_failure()
367
+ raise
368
+
369
+ # Parse response
370
+ try:
371
+ data = response.json()
372
+
373
+ # Extract text from content blocks
374
+ text_parts = []
375
+ tool_calls = []
376
+
377
+ for block in data.get("content", []):
378
+ if block.get("type") == "text":
379
+ text_parts.append(block.get("text", ""))
380
+ elif block.get("type") == "tool_use":
381
+ tool_calls.append(
382
+ ToolCall(
383
+ id=block.get("id", ""),
384
+ name=block.get("name", ""),
385
+ arguments=block.get("input", {}),
386
+ type="tool_use",
387
+ )
388
+ )
389
+
390
+ text = "\n".join(text_parts)
391
+
392
+ # Build usage dict
393
+ usage = {
394
+ "prompt_tokens": data.get("usage", {}).get("input_tokens", 0),
395
+ "completion_tokens": data.get("usage", {}).get("output_tokens", 0),
396
+ }
397
+ usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
398
+
399
+ finish_reason = data.get("stop_reason", "stop")
400
+
401
+ if tool_calls:
402
+ llm_response = LLMToolResponse(
403
+ text=text,
404
+ usage=usage,
405
+ model=data.get("model", self.model),
406
+ raw_response=data,
407
+ finish_reason=finish_reason,
408
+ tool_calls=tool_calls,
409
+ )
410
+ else:
411
+ llm_response = LLMResponse(
412
+ text=text,
413
+ usage=usage,
414
+ model=data.get("model", self.model),
415
+ raw_response=data,
416
+ finish_reason=finish_reason,
417
+ )
418
+
419
+ self._update_stats(llm_response)
420
+ return llm_response
421
+
422
+ except (KeyError, json.JSONDecodeError) as e:
423
+ raise LLMResponseParseError(self.PROVIDER_NAME, response.text) from e
424
+
425
+ async def _generate_stream(
426
+ self,
427
+ *,
428
+ messages: list[dict] | None = None,
429
+ prompt: str | None = None,
430
+ temperature: float = 0.7,
431
+ max_tokens: int = 4096,
432
+ tools: list[dict] | None = None,
433
+ stop: list[str] | None = None,
434
+ **kwargs: Any,
435
+ ) -> AsyncIterator[str]:
436
+ """Streaming generation with Server-Sent Events."""
437
+
438
+ client = await self._get_client()
439
+
440
+ # Convert messages
441
+ built_messages = self._build_messages(messages, prompt)
442
+ system_prompt, anthropic_messages = self._convert_messages_to_anthropic(built_messages)
443
+
444
+ # Build request payload
445
+ payload = {
446
+ "model": self.model,
447
+ "messages": anthropic_messages,
448
+ "max_tokens": max_tokens,
449
+ "temperature": min(temperature, 1.0),
450
+ "stream": True,
451
+ }
452
+
453
+ if system_prompt:
454
+ payload["system"] = system_prompt
455
+ if stop:
456
+ payload["stop_sequences"] = stop
457
+ if tools:
458
+ payload["tools"] = self._convert_tools_to_anthropic(tools)
459
+
460
+ for key in ["top_p", "top_k"]:
461
+ if key in kwargs:
462
+ payload[key] = kwargs[key]
463
+
464
+ async def stream_generator():
465
+ try:
466
+ async with client.stream("POST", "/v1/messages", json=payload) as response:
467
+ if response.status_code != 200:
468
+ await response.aread()
469
+ self._handle_error_response(response)
470
+
471
+ async for line in response.aiter_lines():
472
+ if not line.strip():
473
+ continue
474
+
475
+ if line.startswith("event:"):
476
+ event_type = line[6:].strip()
477
+ continue
478
+
479
+ if line.startswith("data:"):
480
+ data_str = line[5:].strip()
481
+ if not data_str:
482
+ continue
483
+
484
+ try:
485
+ data = json.loads(data_str)
486
+ event_type = data.get("type", "")
487
+
488
+ if event_type == "content_block_delta":
489
+ delta = data.get("delta", {})
490
+ if delta.get("type") == "text_delta":
491
+ text = delta.get("text", "")
492
+ if text:
493
+ yield text
494
+
495
+ elif event_type == "message_stop":
496
+ break
497
+
498
+ except json.JSONDecodeError:
499
+ continue
500
+
501
+ self.circuit_breaker.record_success()
502
+
503
+ except httpx.TimeoutException:
504
+ self.circuit_breaker.record_failure()
505
+ raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
506
+ except httpx.ConnectError:
507
+ self.circuit_breaker.record_failure()
508
+ raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
509
+ except Exception as e:
510
+ self.circuit_breaker.record_failure()
511
+ if isinstance(e, LLMClientError):
512
+ raise
513
+ raise LLMStreamError(self.PROVIDER_NAME, str(e)) from e
514
+
515
+ return stream_generator()
516
+
517
+ async def close(self) -> None:
518
+ """Close the HTTP client."""
519
+ if self._client and not self._client.is_closed:
520
+ await self._client.aclose()
521
+ self._client = None
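A minimal non-streaming usage sketch for the client above (assumes ANTHROPIC_API_KEY is set in the environment):

```python
import asyncio

from src.adapters.llm.anthropic_client import AnthropicClient

async def main() -> None:
    # "sonnet" is resolved to a dated model id via the ANTHROPIC_MODELS alias map
    async with AnthropicClient(model="sonnet") as client:
        response = await client.generate(
            prompt="Summarize Monte Carlo tree search in two sentences.",
            max_tokens=256,
        )
        print(response.text)
        print(response.usage)  # input/output token counts mapped to prompt/completion tokens

asyncio.run(main())
```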
src/adapters/llm/base.py ADDED
@@ -0,0 +1,305 @@
1
+ """
2
+ Base LLM client interface for provider-agnostic model access.
3
+
4
+ This module defines the protocol and data structures for LLM clients,
5
+ enabling seamless switching between providers (OpenAI, Anthropic, LM Studio, etc.)
6
+ """
7
+
8
+ import asyncio
9
+ import time
10
+ from abc import ABC, abstractmethod
11
+ from collections.abc import AsyncIterator
12
+ from dataclasses import dataclass, field
13
+ from datetime import datetime
14
+ from typing import Any, Protocol, runtime_checkable
15
+
16
+
17
+ @dataclass
18
+ class LLMResponse:
19
+ """Standardized response from any LLM provider."""
20
+
21
+ text: str
22
+ usage: dict = field(default_factory=dict)
23
+ model: str = ""
24
+ raw_response: Any = None
25
+ finish_reason: str = "stop"
26
+ created_at: datetime = field(default_factory=datetime.utcnow)
27
+
28
+ @property
29
+ def total_tokens(self) -> int:
30
+ """Total tokens used in request/response."""
31
+ return self.usage.get("total_tokens", 0)
32
+
33
+ @property
34
+ def prompt_tokens(self) -> int:
35
+ """Tokens used in prompt."""
36
+ return self.usage.get("prompt_tokens", 0)
37
+
38
+ @property
39
+ def completion_tokens(self) -> int:
40
+ """Tokens used in completion."""
41
+ return self.usage.get("completion_tokens", 0)
42
+
43
+
44
+ @dataclass
45
+ class ToolCall:
46
+ """Represents a tool/function call from the LLM."""
47
+
48
+ id: str
49
+ name: str
50
+ arguments: dict
51
+ type: str = "function"
52
+
53
+
54
+ @dataclass
55
+ class LLMToolResponse(LLMResponse):
56
+ """Response containing tool calls."""
57
+
58
+ tool_calls: list[ToolCall] = field(default_factory=list)
59
+
60
+
61
+ class TokenBucketRateLimiter:
62
+ """
63
+ Token bucket rate limiter for controlling request rates.
64
+
65
+ This implementation uses a token bucket algorithm where:
66
+ - Tokens are added at a fixed rate (rate_per_second)
67
+ - Each request consumes one token
68
+ - If no tokens available, caller waits until one becomes available
69
+ """
70
+
71
+ def __init__(self, rate_per_minute: int = 60):
72
+ """
73
+ Initialize the rate limiter.
74
+
75
+ Args:
76
+ rate_per_minute: Maximum requests allowed per minute
77
+ """
78
+ self.rate_per_second = rate_per_minute / 60.0
79
+ self.max_tokens = float(rate_per_minute)
80
+ self.tokens = self.max_tokens
81
+ self.last_refill = time.monotonic()
82
+ self._lock = asyncio.Lock()
83
+ self._wait_count = 0
84
+ self._total_wait_time = 0.0
85
+
86
+ async def acquire(self) -> float:
87
+ """
88
+ Acquire a token, waiting if necessary.
89
+
90
+ Returns:
91
+ Time spent waiting (0.0 if no wait was needed)
92
+ """
93
+ async with self._lock:
94
+ now = time.monotonic()
95
+ elapsed = now - self.last_refill
96
+
97
+ # Refill tokens based on elapsed time
98
+ self.tokens = min(self.max_tokens, self.tokens + elapsed * self.rate_per_second)
99
+ self.last_refill = now
100
+
101
+ wait_time = 0.0
102
+ if self.tokens < 1:
103
+ # Calculate how long to wait for one token
104
+ wait_time = (1 - self.tokens) / self.rate_per_second
105
+ self._wait_count += 1
106
+ self._total_wait_time += wait_time
107
+
108
+ # Release lock during sleep to allow other operations
109
+ self._lock.release()
110
+ try:
111
+ await asyncio.sleep(wait_time)
112
+ finally:
113
+ await self._lock.acquire()
114
+
115
+ # After sleeping, update time and set tokens to 0
116
+ self.last_refill = time.monotonic()
117
+ self.tokens = 0
118
+ else:
119
+ self.tokens -= 1
120
+
121
+ return wait_time
122
+
123
+ @property
124
+ def stats(self) -> dict:
125
+ """Get rate limiter statistics."""
126
+ return {
127
+ "rate_limit_waits": self._wait_count,
128
+ "total_rate_limit_wait_time": self._total_wait_time,
129
+ "current_tokens": self.tokens,
130
+ }
131
+
132
+
133
+ @runtime_checkable
134
+ class LLMClient(Protocol):
135
+ """
136
+ Protocol for LLM clients.
137
+
138
+ This protocol defines the interface that all LLM provider adapters must implement.
139
+ Using Protocol allows for structural subtyping (duck typing) while maintaining
140
+ type safety.
141
+ """
142
+
143
+ async def generate(
144
+ self,
145
+ *,
146
+ messages: list[dict] | None = None,
147
+ prompt: str | None = None,
148
+ temperature: float = 0.7,
149
+ max_tokens: int | None = None,
150
+ tools: list[dict] | None = None,
151
+ stream: bool = False,
152
+ stop: list[str] | None = None,
153
+ **kwargs: Any,
154
+ ) -> LLMResponse | AsyncIterator[str]:
155
+ """
156
+ Generate a response from the LLM.
157
+
158
+ Args:
159
+ messages: List of message dicts in OpenAI format [{"role": "...", "content": "..."}]
160
+ prompt: Simple string prompt (converted to single user message)
161
+ temperature: Sampling temperature (0.0 to 2.0)
162
+ max_tokens: Maximum tokens to generate
163
+ tools: List of tool definitions for function calling
164
+ stream: If True, returns AsyncIterator[str] for streaming
165
+ stop: Stop sequences
166
+ **kwargs: Provider-specific parameters
167
+
168
+ Returns:
169
+ LLMResponse if stream=False, AsyncIterator[str] if stream=True
170
+
171
+ Raises:
172
+ LLMClientError: Base exception for all client errors
173
+ """
174
+ ...
175
+
176
+
177
+ class BaseLLMClient(ABC):
178
+ """
179
+ Abstract base class for LLM clients.
180
+
181
+ Provides common functionality and enforces the interface contract.
182
+ All concrete implementations should inherit from this class.
183
+ """
184
+
185
+ def __init__(
186
+ self,
187
+ api_key: str | None = None,
188
+ model: str = "default",
189
+ base_url: str | None = None,
190
+ timeout: float = 60.0,
191
+ max_retries: int = 3,
192
+ rate_limit_per_minute: int | None = None,
193
+ ):
194
+ """
195
+ Initialize the LLM client.
196
+
197
+ Args:
198
+ api_key: API key for authentication
199
+ model: Model identifier
200
+ base_url: Base URL for API requests
201
+ timeout: Request timeout in seconds
202
+ max_retries: Maximum number of retry attempts
203
+ rate_limit_per_minute: Rate limit (requests per minute), None to disable
204
+ """
205
+ self.api_key = api_key
206
+ self.model = model
207
+ self.base_url = base_url
208
+ self.timeout = timeout
209
+ self.max_retries = max_retries
210
+ self._request_count = 0
211
+ self._total_tokens_used = 0
212
+ self._rate_limited_requests = 0
213
+
214
+ # Initialize rate limiter if configured
215
+ if rate_limit_per_minute is not None and rate_limit_per_minute > 0:
216
+ self._rate_limiter: TokenBucketRateLimiter | None = TokenBucketRateLimiter(
217
+ rate_per_minute=rate_limit_per_minute
218
+ )
219
+ else:
220
+ self._rate_limiter = None
221
+
222
+ @abstractmethod
223
+ async def generate(
224
+ self,
225
+ *,
226
+ messages: list[dict] | None = None,
227
+ prompt: str | None = None,
228
+ temperature: float = 0.7,
229
+ max_tokens: int | None = None,
230
+ tools: list[dict] | None = None,
231
+ stream: bool = False,
232
+ stop: list[str] | None = None,
233
+ **kwargs: Any,
234
+ ) -> LLMResponse | AsyncIterator[str]:
235
+ """Generate a response from the LLM."""
236
+ pass
237
+
238
+ def _build_messages(
239
+ self,
240
+ messages: list[dict] | None = None,
241
+ prompt: str | None = None,
242
+ ) -> list[dict]:
243
+ """
244
+ Build message list from either messages or prompt.
245
+
246
+ Args:
247
+ messages: Pre-formatted message list
248
+ prompt: Simple string prompt
249
+
250
+ Returns:
251
+ List of message dicts
252
+
253
+ Raises:
254
+ ValueError: If neither messages nor prompt provided
255
+ """
256
+ if messages is not None:
257
+ return messages
258
+ elif prompt is not None:
259
+ return [{"role": "user", "content": prompt}]
260
+ else:
261
+ raise ValueError("Either 'messages' or 'prompt' must be provided")
262
+
263
+ def _update_stats(self, response: LLMResponse) -> None:
264
+ """Update internal statistics."""
265
+ self._request_count += 1
266
+ self._total_tokens_used += response.total_tokens
267
+
268
+ async def _apply_rate_limit(self) -> None:
269
+ """
270
+ Apply rate limiting if configured.
271
+
272
+ Waits if necessary to comply with rate limits.
273
+ Tracks rate-limited requests in metrics.
274
+ """
275
+ if self._rate_limiter is not None:
276
+ wait_time = await self._rate_limiter.acquire()
277
+ if wait_time > 0:
278
+ self._rate_limited_requests += 1
279
+
280
+ @property
281
+ def stats(self) -> dict:
282
+ """Get client statistics."""
283
+ base_stats = {
284
+ "request_count": self._request_count,
285
+ "total_tokens_used": self._total_tokens_used,
286
+ "rate_limited_requests": self._rate_limited_requests,
287
+ }
288
+
289
+ # Include rate limiter stats if available
290
+ if self._rate_limiter is not None:
291
+ base_stats.update(self._rate_limiter.stats)
292
+
293
+ return base_stats
294
+
295
+ async def close(self) -> None: # noqa: B027
296
+ """Clean up resources. Override in subclasses if needed."""
297
+ pass
298
+
299
+ async def __aenter__(self):
300
+ """Async context manager entry."""
301
+ return self
302
+
303
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
304
+ """Async context manager exit."""
305
+ await self.close()
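To illustrate the contract above, a minimal concrete subclass (a fake "echo" backend useful in tests; a sketch, not part of the commit):

```python
import asyncio

from src.adapters.llm.base import BaseLLMClient, LLMResponse

class EchoClient(BaseLLMClient):
    """Fake client that returns the last user message unchanged."""

    async def generate(self, *, messages=None, prompt=None, temperature=0.7,
                       max_tokens=None, tools=None, stream=False, stop=None, **kwargs):
        await self._apply_rate_limit()                 # honors rate_limit_per_minute if set
        msgs = self._build_messages(messages, prompt)  # prompt -> [{"role": "user", ...}]
        response = LLMResponse(text=msgs[-1]["content"], usage={"total_tokens": 0}, model=self.model)
        self._update_stats(response)                   # feeds the .stats property
        return response

async def demo() -> None:
    async with EchoClient(model="echo", rate_limit_per_minute=60) as client:
        print((await client.generate(prompt="hello")).text)  # -> "hello"
        print(client.stats)

asyncio.run(demo())
```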
src/adapters/llm/exceptions.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ Custom exceptions for LLM client operations.
3
+
4
+ Provides a hierarchy of structured exceptions for better error handling
5
+ and debugging across different LLM providers.
6
+ """
7
+
8
+
9
+ class LLMClientError(Exception):
10
+ """Base exception for all LLM client errors."""
11
+
12
+ def __init__(
13
+ self,
14
+ message: str,
15
+ provider: str = "unknown",
16
+ status_code: int | None = None,
17
+ retry_after: float | None = None,
18
+ ):
19
+ self.message = message
20
+ self.provider = provider
21
+ self.status_code = status_code
22
+ self.retry_after = retry_after
23
+ super().__init__(self.message)
24
+
25
+ def __str__(self) -> str:
26
+ parts = [f"[{self.provider}] {self.message}"]
27
+ if self.status_code:
28
+ parts.append(f"(status: {self.status_code})")
29
+ return " ".join(parts)
30
+
31
+
32
+ class LLMAuthenticationError(LLMClientError):
33
+ """Authentication failed - invalid or missing API key."""
34
+
35
+ def __init__(self, provider: str, message: str = "Authentication failed"):
36
+ super().__init__(
37
+ message=message,
38
+ provider=provider,
39
+ status_code=401,
40
+ )
41
+
42
+
43
+ class LLMRateLimitError(LLMClientError):
44
+ """Rate limit exceeded - too many requests."""
45
+
46
+ def __init__(
47
+ self,
48
+ provider: str,
49
+ retry_after: float | None = None,
50
+ message: str = "Rate limit exceeded",
51
+ ):
52
+ super().__init__(
53
+ message=message,
54
+ provider=provider,
55
+ status_code=429,
56
+ retry_after=retry_after,
57
+ )
58
+
59
+
60
+ class LLMQuotaExceededError(LLMClientError):
61
+ """Quota or credits exhausted."""
62
+
63
+ def __init__(self, provider: str, message: str = "Quota exceeded"):
64
+ super().__init__(
65
+ message=message,
66
+ provider=provider,
67
+ status_code=402,
68
+ )
69
+
70
+
71
+ class LLMModelNotFoundError(LLMClientError):
72
+ """Requested model not available."""
73
+
74
+ def __init__(self, provider: str, model: str):
75
+ super().__init__(
76
+ message=f"Model '{model}' not found or not available",
77
+ provider=provider,
78
+ status_code=404,
79
+ )
80
+
81
+
82
+ class LLMContextLengthError(LLMClientError):
83
+ """Input exceeds model's context window."""
84
+
85
+ def __init__(
86
+ self,
87
+ provider: str,
88
+ token_count: int | None = None,
89
+ max_tokens: int | None = None,
90
+ ):
91
+ message = "Context length exceeded"
92
+ if token_count and max_tokens:
93
+ message = f"Context length exceeded: {token_count} tokens provided, max is {max_tokens}"
94
+ super().__init__(
95
+ message=message,
96
+ provider=provider,
97
+ status_code=400,
98
+ )
99
+
100
+
101
+ class LLMInvalidRequestError(LLMClientError):
102
+ """Invalid request parameters."""
103
+
104
+ def __init__(self, provider: str, message: str = "Invalid request parameters"):
105
+ super().__init__(
106
+ message=message,
107
+ provider=provider,
108
+ status_code=400,
109
+ )
110
+
111
+
112
+ class LLMTimeoutError(LLMClientError):
113
+ """Request timed out."""
114
+
115
+ def __init__(self, provider: str, timeout: float):
116
+ super().__init__(
117
+ message=f"Request timed out after {timeout}s",
118
+ provider=provider,
119
+ status_code=408,
120
+ )
121
+
122
+
123
+ class LLMConnectionError(LLMClientError):
124
+ """Failed to connect to the API endpoint."""
125
+
126
+ def __init__(self, provider: str, url: str | None = None):
127
+ message = "Failed to connect to API"
128
+ if url:
129
+ message = f"Failed to connect to {url}"
130
+ super().__init__(
131
+ message=message,
132
+ provider=provider,
133
+ )
134
+
135
+
136
+ class LLMServerError(LLMClientError):
137
+ """Server-side error from the LLM provider."""
138
+
139
+ def __init__(
140
+ self,
141
+ provider: str,
142
+ status_code: int = 500,
143
+ message: str = "Server error",
144
+ ):
145
+ super().__init__(
146
+ message=message,
147
+ provider=provider,
148
+ status_code=status_code,
149
+ )
150
+
151
+
152
+ class LLMResponseParseError(LLMClientError):
153
+ """Failed to parse response from LLM provider."""
154
+
155
+ def __init__(self, provider: str, raw_response: str | None = None):
156
+ message = "Failed to parse response"
157
+ if raw_response:
158
+ preview = raw_response[:200] + "..." if len(raw_response) > 200 else raw_response
159
+ message = f"Failed to parse response: {preview}"
160
+ super().__init__(
161
+ message=message,
162
+ provider=provider,
163
+ )
164
+
165
+
166
+ class LLMStreamError(LLMClientError):
167
+ """Error during streaming response."""
168
+
169
+ def __init__(self, provider: str, message: str = "Stream interrupted"):
170
+ super().__init__(
171
+ message=message,
172
+ provider=provider,
173
+ )
174
+
175
+
176
+ class LLMContentFilterError(LLMClientError):
177
+ """Content blocked by safety filters."""
178
+
179
+ def __init__(self, provider: str, reason: str | None = None):
180
+ message = "Content blocked by safety filters"
181
+ if reason:
182
+ message = f"Content blocked: {reason}"
183
+ super().__init__(
184
+ message=message,
185
+ provider=provider,
186
+ status_code=400,
187
+ )
188
+
189
+
190
+ class CircuitBreakerOpenError(LLMClientError):
191
+ """Circuit breaker is open, requests are being blocked."""
192
+
193
+ def __init__(
194
+ self,
195
+ provider: str,
196
+ failure_count: int,
197
+ reset_time: float,
198
+ ):
199
+ super().__init__(
200
+ message=f"Circuit breaker open after {failure_count} failures. Resets in {reset_time:.1f}s",
201
+ provider=provider,
202
+ )
203
+ self.failure_count = failure_count
204
+ self.reset_time = reset_time
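A small sketch of how callers are expected to use this hierarchy: catch the specific subclasses first and fall back to LLMClientError, which carries provider, status_code, and retry_after:

```python
from src.adapters.llm.exceptions import LLMClientError, LLMRateLimitError

async def safe_generate(client, prompt: str):
    try:
        return await client.generate(prompt=prompt)
    except LLMRateLimitError as err:
        wait = err.retry_after or 1.0  # providers may supply a Retry-After hint
        print(f"rate limited by {err.provider}; back off {wait:.1f}s before retrying")
    except LLMClientError as err:
        print(f"[{err.provider}] request failed (status {err.status_code}): {err.message}")
    return None
```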
src/adapters/llm/lmstudio_client.py ADDED
@@ -0,0 +1,346 @@
1
+ """
2
+ LM Studio local LLM client adapter.
3
+
4
+ Implements the LLMClient protocol for LM Studio's OpenAI-compatible API.
5
+ Designed for running local models with configurable endpoint.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from collections.abc import AsyncIterator
11
+ from typing import Any
12
+
13
+ import httpx
14
+
15
+ from .base import BaseLLMClient, LLMResponse
16
+ from .exceptions import (
17
+ LLMClientError,
18
+ LLMConnectionError,
19
+ LLMResponseParseError,
20
+ LLMServerError,
21
+ LLMStreamError,
22
+ LLMTimeoutError,
23
+ )
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class LMStudioClient(BaseLLMClient):
29
+ """
30
+ LM Studio local server client.
31
+
32
+ LM Studio provides an OpenAI-compatible API for running local models.
33
+ This client is optimized for local deployment with:
34
+ - No authentication required (local)
35
+ - Configurable base URL
36
+ - No circuit breaker (local server expected to be stable)
37
+ - Longer timeouts for large models
38
+ """
39
+
40
+ PROVIDER_NAME = "lmstudio"
41
+ DEFAULT_BASE_URL = "http://localhost:1234/v1"
42
+ DEFAULT_MODEL = "local-model" # LM Studio uses the loaded model
43
+
44
+ def __init__(
45
+ self,
46
+ api_key: str | None = None, # Not required for local
47
+ model: str | None = None,
48
+ base_url: str | None = None,
49
+ timeout: float = 300.0, # Long timeout for local inference
50
+ max_retries: int = 2, # Fewer retries for local
51
+ # Rate limiting
52
+ rate_limit_per_minute: int | None = None,
53
+ ):
54
+ """
55
+ Initialize LM Studio client.
56
+
57
+ Args:
58
+ api_key: Not required for local server (ignored)
59
+ model: Model identifier (often ignored by LM Studio, uses loaded model)
60
+ base_url: Local server URL (default: http://localhost:1234/v1)
61
+ timeout: Request timeout in seconds (default longer for local models)
62
+ max_retries: Max retry attempts (fewer for local)
63
+ rate_limit_per_minute: Rate limit for requests per minute (None to disable)
64
+ """
65
+ import os
66
+
67
+ # Allow overriding via environment variable
68
+ base_url = base_url or os.environ.get("LMSTUDIO_BASE_URL", self.DEFAULT_BASE_URL)
69
+
70
+ super().__init__(
71
+ api_key=api_key or "not-required", # Placeholder
72
+ model=model or self.DEFAULT_MODEL,
73
+ base_url=base_url,
74
+ timeout=timeout,
75
+ max_retries=max_retries,
76
+ rate_limit_per_minute=rate_limit_per_minute,
77
+ )
78
+
79
+ self._client: httpx.AsyncClient | None = None
80
+
81
+ async def _get_client(self) -> httpx.AsyncClient:
82
+ """Get or create the HTTP client."""
83
+ if self._client is None or self._client.is_closed:
84
+ headers = {"Content-Type": "application/json"}
85
+
86
+ # Add auth header if provided (some local servers may require it)
87
+ if self.api_key and self.api_key != "not-required":
88
+ headers["Authorization"] = f"Bearer {self.api_key}"
89
+
90
+ self._client = httpx.AsyncClient(
91
+ base_url=self.base_url,
92
+ headers=headers,
93
+ timeout=httpx.Timeout(self.timeout),
94
+ )
95
+ return self._client
96
+
97
+ async def check_health(self) -> bool:
98
+ """
99
+ Check if LM Studio server is running.
100
+
101
+ Returns:
102
+ True if server is accessible, False otherwise
103
+ """
104
+ try:
105
+ client = await self._get_client()
106
+ response = await client.get("/models")
107
+ return response.status_code == 200
108
+ except Exception:
109
+ return False
110
+
111
+ async def list_models(self) -> list[dict]:
112
+ """
113
+ List available models on the LM Studio server.
114
+
115
+ Returns:
116
+ List of model information dicts
117
+ """
118
+ try:
119
+ client = await self._get_client()
120
+ response = await client.get("/models")
121
+ if response.status_code == 200:
122
+ data = response.json()
123
+ return data.get("data", [])
124
+ return []
125
+ except Exception as e:
126
+ logger.warning(f"Failed to list models: {e}")
127
+ return []
128
+
129
+ def _handle_error_response(self, response: httpx.Response) -> None:
130
+ """Handle error responses from LM Studio server."""
131
+ status_code = response.status_code
132
+
133
+ try:
134
+ error_data = response.json()
135
+ error_message = error_data.get("error", {}).get("message", response.text)
136
+ except Exception:
137
+ error_message = response.text
138
+
139
+ if status_code >= 500:
140
+ raise LLMServerError(self.PROVIDER_NAME, status_code, error_message)
141
+ else:
142
+ raise LLMClientError(error_message, self.PROVIDER_NAME, status_code=status_code)
143
+
144
+ async def generate(
145
+ self,
146
+ *,
147
+ messages: list[dict] | None = None,
148
+ prompt: str | None = None,
149
+ temperature: float = 0.7,
150
+ max_tokens: int | None = None,
151
+ tools: list[dict] | None = None,
152
+ stream: bool = False,
153
+ stop: list[str] | None = None,
154
+ **kwargs: Any,
155
+ ) -> LLMResponse | AsyncIterator[str]:
156
+ """
157
+ Generate a response from LM Studio local model.
158
+
159
+ Args:
160
+ messages: Chat messages in OpenAI format
161
+ prompt: Simple string prompt
162
+ temperature: Sampling temperature
163
+ max_tokens: Maximum tokens to generate
164
+ tools: Tool definitions (limited support in local models)
165
+ stream: If True, returns AsyncIterator
166
+ stop: Stop sequences
167
+ **kwargs: Additional parameters
168
+
169
+ Returns:
170
+ LLMResponse or AsyncIterator[str] for streaming
171
+ """
172
+ # Apply rate limiting before proceeding
173
+ await self._apply_rate_limit()
174
+
175
+ if stream:
176
+ return await self._generate_stream(
177
+ messages=messages,
178
+ prompt=prompt,
179
+ temperature=temperature,
180
+ max_tokens=max_tokens,
181
+ tools=tools,
182
+ stop=stop,
183
+ **kwargs,
184
+ )
185
+ else:
186
+ return await self._generate_non_stream(
187
+ messages=messages,
188
+ prompt=prompt,
189
+ temperature=temperature,
190
+ max_tokens=max_tokens,
191
+ tools=tools,
192
+ stop=stop,
193
+ **kwargs,
194
+ )
195
+
196
+ async def _generate_non_stream(
197
+ self,
198
+ *,
199
+ messages: list[dict] | None = None,
200
+ prompt: str | None = None,
201
+ temperature: float = 0.7,
202
+ max_tokens: int | None = None,
203
+ tools: list[dict] | None = None,
204
+ stop: list[str] | None = None,
205
+ **kwargs: Any,
206
+ ) -> LLMResponse:
207
+ """Non-streaming generation."""
208
+ client = await self._get_client()
209
+
210
+ # Build request payload (OpenAI-compatible)
211
+ payload = {
212
+ "model": self.model,
213
+ "messages": self._build_messages(messages, prompt),
214
+ "temperature": temperature,
215
+ }
216
+
217
+ if max_tokens is not None:
218
+ payload["max_tokens"] = max_tokens
219
+ if stop:
220
+ payload["stop"] = stop
221
+
222
+ # Note: most local models don't support tools well
223
+ if tools:
224
+ logger.warning("Tool calling may not be fully supported by local models")
225
+ payload["tools"] = tools
226
+
227
+ # Add additional kwargs (e.g., top_p, repeat_penalty)
228
+ for key in ["top_p", "top_k", "repeat_penalty", "presence_penalty", "frequency_penalty"]:
229
+ if key in kwargs:
230
+ payload[key] = kwargs[key]
231
+
232
+ # Retry logic for local server
233
+ last_error = None
234
+ for attempt in range(self.max_retries):
235
+ try:
236
+ response = await client.post("/chat/completions", json=payload)
237
+
238
+ if response.status_code != 200:
239
+ self._handle_error_response(response)
240
+
241
+ # Parse response
242
+ try:
243
+ data = response.json()
244
+ choice = data["choices"][0]
245
+ message = choice["message"]
246
+
247
+ usage = data.get("usage", {})
248
+ finish_reason = choice.get("finish_reason", "stop")
249
+
250
+ llm_response = LLMResponse(
251
+ text=message.get("content", ""),
252
+ usage=usage,
253
+ model=data.get("model", self.model),
254
+ raw_response=data,
255
+ finish_reason=finish_reason,
256
+ )
257
+
258
+ self._update_stats(llm_response)
259
+ return llm_response
260
+
261
+ except (KeyError, json.JSONDecodeError) as e:
262
+ raise LLMResponseParseError(self.PROVIDER_NAME, response.text) from e
263
+
264
+ except httpx.TimeoutException:
265
+ last_error = LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
266
+ logger.warning(f"Attempt {attempt + 1} timed out, retrying...")
267
+ except httpx.ConnectError:
268
+ last_error = LLMConnectionError(self.PROVIDER_NAME, self.base_url)
269
+ logger.warning(f"Attempt {attempt + 1} connection failed, retrying...")
270
+ except LLMClientError:
271
+ raise # Don't retry client errors
272
+
273
+ # All retries exhausted
274
+ if last_error:
275
+ raise last_error
276
+ raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
277
+
278
+ async def _generate_stream(
279
+ self,
280
+ *,
281
+ messages: list[dict] | None = None,
282
+ prompt: str | None = None,
283
+ temperature: float = 0.7,
284
+ max_tokens: int | None = None,
285
+ tools: list[dict] | None = None, # noqa: ARG002
286
+ stop: list[str] | None = None,
287
+ **kwargs: Any,
288
+ ) -> AsyncIterator[str]:
289
+ """Streaming generation."""
290
+ client = await self._get_client()
291
+
292
+ # Build request payload
293
+ payload = {
294
+ "model": self.model,
295
+ "messages": self._build_messages(messages, prompt),
296
+ "temperature": temperature,
297
+ "stream": True,
298
+ }
299
+
300
+ if max_tokens is not None:
301
+ payload["max_tokens"] = max_tokens
302
+ if stop:
303
+ payload["stop"] = stop
304
+
305
+ for key in ["top_p", "top_k", "repeat_penalty"]:
306
+ if key in kwargs:
307
+ payload[key] = kwargs[key]
308
+
309
+ async def stream_generator():
310
+ try:
311
+ async with client.stream("POST", "/chat/completions", json=payload) as response:
312
+ if response.status_code != 200:
313
+ await response.aread()
314
+ self._handle_error_response(response)
315
+
316
+ async for line in response.aiter_lines():
317
+ if line.startswith("data: "):
318
+ data_str = line[6:]
319
+ if data_str.strip() == "[DONE]":
320
+ break
321
+
322
+ try:
323
+ data = json.loads(data_str)
324
+ delta = data["choices"][0].get("delta", {})
325
+ content = delta.get("content", "")
326
+ if content:
327
+ yield content
328
+ except (json.JSONDecodeError, KeyError):
329
+ continue
330
+
331
+ except httpx.TimeoutException:
332
+ raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
333
+ except httpx.ConnectError:
334
+ raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
335
+ except Exception as e:
336
+ if isinstance(e, LLMClientError):
337
+ raise
338
+ raise LLMStreamError(self.PROVIDER_NAME, str(e)) from e
339
+
340
+ return stream_generator()
341
+
342
+ async def close(self) -> None:
343
+ """Close the HTTP client."""
344
+ if self._client and not self._client.is_closed:
345
+ await self._client.aclose()
346
+ self._client = None
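A usage sketch for the local client above (assumes LM Studio is serving its OpenAI-compatible API on localhost:1234 with a model loaded):

```python
import asyncio

from src.adapters.llm.lmstudio_client import LMStudioClient

async def main() -> None:
    async with LMStudioClient() as client:  # defaults to http://localhost:1234/v1
        if not await client.check_health():
            print("LM Studio server is not reachable")
            return
        print([m.get("id") for m in await client.list_models()])
        reply = await client.generate(prompt="Say hi in five words.", max_tokens=32)
        print(reply.text)

asyncio.run(main())
```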
src/adapters/llm/openai_client.py ADDED
@@ -0,0 +1,458 @@
1
+ """
2
+ OpenAI-compatible LLM client adapter.
3
+
4
+ Implements the LLMClient protocol for OpenAI API (and compatible APIs).
5
+ Includes retry logic, circuit breaker pattern, and streaming support.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import time
11
+ from collections.abc import AsyncIterator
12
+ from typing import Any
13
+
14
+ import httpx
15
+ from tenacity import (
16
+ before_sleep_log,
17
+ retry,
18
+ retry_if_exception_type,
19
+ stop_after_attempt,
20
+ wait_exponential,
21
+ )
22
+
23
+ from .base import BaseLLMClient, LLMResponse, LLMToolResponse, ToolCall
24
+ from .exceptions import (
25
+ CircuitBreakerOpenError,
26
+ LLMAuthenticationError,
27
+ LLMClientError,
28
+ LLMConnectionError,
29
+ LLMContextLengthError,
30
+ LLMInvalidRequestError,
31
+ LLMModelNotFoundError,
32
+ LLMQuotaExceededError,
33
+ LLMRateLimitError,
34
+ LLMResponseParseError,
35
+ LLMServerError,
36
+ LLMStreamError,
37
+ LLMTimeoutError,
38
+ )
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
+ class CircuitBreaker:
44
+ """Simple circuit breaker implementation for resilience."""
45
+
46
+ def __init__(
47
+ self,
48
+ failure_threshold: int = 5,
49
+ reset_timeout: float = 60.0,
50
+ half_open_max_calls: int = 1,
51
+ ):
52
+ self.failure_threshold = failure_threshold
53
+ self.reset_timeout = reset_timeout
54
+ self.half_open_max_calls = half_open_max_calls
55
+ self.failure_count = 0
56
+ self.last_failure_time = 0.0
57
+ self.state = "closed" # closed, open, half-open
58
+ self.half_open_calls = 0
59
+
60
+ def can_execute(self) -> bool:
61
+ """Check if request can be executed."""
62
+ if self.state == "closed":
63
+ return True
64
+
65
+ if self.state == "open":
66
+ # Check if reset timeout has passed
67
+ if time.time() - self.last_failure_time >= self.reset_timeout:
68
+ self.state = "half-open"
69
+ self.half_open_calls = 0
70
+ return True
71
+ return False
72
+
73
+ if self.state == "half-open":
74
+ return self.half_open_calls < self.half_open_max_calls
75
+
76
+ return False
77
+
78
+ def record_success(self) -> None:
79
+ """Record successful request."""
80
+ if self.state == "half-open":
81
+ self.state = "closed"
82
+ self.failure_count = 0
83
+ elif self.state == "closed":
84
+ self.failure_count = 0
85
+
86
+ def record_failure(self) -> None:
87
+ """Record failed request."""
88
+ self.failure_count += 1
89
+ self.last_failure_time = time.time()
90
+
91
+ if self.state == "half-open" or self.failure_count >= self.failure_threshold:
92
+ self.state = "open"
93
+
94
+ def get_reset_time(self) -> float:
95
+ """Get time until circuit resets."""
96
+ if self.state != "open":
97
+ return 0.0
98
+ elapsed = time.time() - self.last_failure_time
99
+ return max(0, self.reset_timeout - elapsed)
100
+
101
+
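
A quick illustration of the state machine above (closed -> open -> half-open); a minimal sketch with arbitrary thresholds, separate from how the client wires record_success/record_failure around each request.

# Sketch of the CircuitBreaker state transitions defined above (thresholds are arbitrary).
breaker = CircuitBreaker(failure_threshold=2, reset_timeout=30.0)

assert breaker.can_execute()           # "closed": requests flow normally
breaker.record_failure()
breaker.record_failure()               # threshold reached -> state becomes "open"
assert not breaker.can_execute()       # "open": requests are rejected immediately
print(f"circuit open, retry in ~{breaker.get_reset_time():.0f}s")
# Once reset_timeout elapses, can_execute() moves the state to "half-open" and
# admits a single probe request; record_success() then closes the circuit again.
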
102
+ class OpenAIClient(BaseLLMClient):
103
+ """
104
+ OpenAI API client with retry logic and circuit breaker.
105
+
106
+ Features:
107
+ - Exponential backoff retry for transient errors
108
+ - Circuit breaker to prevent cascading failures
109
+ - Streaming support
110
+ - Structured error handling
111
+ - Tool/function calling support
112
+ """
113
+
114
+ PROVIDER_NAME = "openai"
115
+ DEFAULT_BASE_URL = "https://api.openai.com/v1"
116
+ DEFAULT_MODEL = "gpt-4-turbo-preview"
117
+
118
+ def __init__(
119
+ self,
120
+ api_key: str | None = None,
121
+ model: str | None = None,
122
+ base_url: str | None = None,
123
+ timeout: float = 60.0,
124
+ max_retries: int = 3,
125
+ organization: str | None = None,
126
+ # Circuit breaker settings
127
+ circuit_breaker_threshold: int = 5,
128
+ circuit_breaker_reset: float = 60.0,
129
+ # Rate limiting
130
+ rate_limit_per_minute: int | None = None,
131
+ ):
132
+ """
133
+ Initialize OpenAI client.
134
+
135
+ Args:
136
+ api_key: OpenAI API key (or set OPENAI_API_KEY env var)
137
+ model: Model to use (default: gpt-4-turbo-preview)
138
+ base_url: API base URL (default: https://api.openai.com/v1)
139
+ timeout: Request timeout in seconds
140
+ max_retries: Max retry attempts for transient errors
141
+ organization: Optional organization ID
142
+ circuit_breaker_threshold: Failures before circuit opens
143
+ circuit_breaker_reset: Seconds before circuit resets
144
+ rate_limit_per_minute: Rate limit for requests per minute (None to disable)
145
+ """
146
+ import os
147
+
148
+ api_key = api_key or os.environ.get("OPENAI_API_KEY")
149
+ if not api_key:
150
+ raise LLMAuthenticationError(self.PROVIDER_NAME, "API key not provided and OPENAI_API_KEY not set")
151
+
152
+ super().__init__(
153
+ api_key=api_key,
154
+ model=model or self.DEFAULT_MODEL,
155
+ base_url=base_url or self.DEFAULT_BASE_URL,
156
+ timeout=timeout,
157
+ max_retries=max_retries,
158
+ rate_limit_per_minute=rate_limit_per_minute,
159
+ )
160
+
161
+ self.organization = organization
162
+ self.circuit_breaker = CircuitBreaker(
163
+ failure_threshold=circuit_breaker_threshold,
164
+ reset_timeout=circuit_breaker_reset,
165
+ )
166
+
167
+ # Initialize async HTTP client
168
+ self._client: httpx.AsyncClient | None = None
169
+
170
+ async def _get_client(self) -> httpx.AsyncClient:
171
+ """Get or create the HTTP client."""
172
+ if self._client is None or self._client.is_closed:
173
+ headers = {
174
+ "Authorization": f"Bearer {self.api_key}",
175
+ "Content-Type": "application/json",
176
+ }
177
+ if self.organization:
178
+ headers["OpenAI-Organization"] = self.organization
179
+
180
+ self._client = httpx.AsyncClient(
181
+ base_url=self.base_url,
182
+ headers=headers,
183
+ timeout=httpx.Timeout(self.timeout),
184
+ )
185
+ return self._client
186
+
187
+ def _handle_error_response(self, response: httpx.Response) -> None:
188
+ """Convert HTTP error responses to appropriate exceptions."""
189
+ status_code = response.status_code
190
+
191
+ try:
192
+ error_data = response.json()
193
+ error_message = error_data.get("error", {}).get("message", response.text)
194
+ except Exception:
195
+ error_message = response.text
196
+
197
+ if status_code == 401:
198
+ raise LLMAuthenticationError(self.PROVIDER_NAME, error_message)
199
+ elif status_code == 429:
200
+ retry_after = response.headers.get("Retry-After")
201
+ retry_after_float = float(retry_after) if retry_after else None
202
+ raise LLMRateLimitError(self.PROVIDER_NAME, retry_after=retry_after_float, message=error_message)
203
+ elif status_code == 402:
204
+ raise LLMQuotaExceededError(self.PROVIDER_NAME, error_message)
205
+ elif status_code == 404:
206
+ raise LLMModelNotFoundError(self.PROVIDER_NAME, self.model)
207
+ elif status_code == 400:
208
+ if "context_length" in error_message.lower():
209
+ raise LLMContextLengthError(self.PROVIDER_NAME)
210
+ raise LLMInvalidRequestError(self.PROVIDER_NAME, error_message)
211
+ elif status_code >= 500:
212
+ raise LLMServerError(self.PROVIDER_NAME, status_code, error_message)
213
+ else:
214
+ raise LLMClientError(error_message, self.PROVIDER_NAME, status_code=status_code)
215
+
216
+ def _make_retry_decorator(self):
217
+ """Create retry decorator with exponential backoff."""
218
+ return retry(
219
+ stop=stop_after_attempt(self.max_retries),
220
+ wait=wait_exponential(multiplier=1, min=1, max=60),
221
+ retry=retry_if_exception_type((LLMRateLimitError, LLMServerError, LLMConnectionError)),
222
+ before_sleep=before_sleep_log(logger, logging.WARNING),
223
+ reraise=True,
224
+ )
225
+
226
+ async def generate(
227
+ self,
228
+ *,
229
+ messages: list[dict] | None = None,
230
+ prompt: str | None = None,
231
+ temperature: float = 0.7,
232
+ max_tokens: int | None = None,
233
+ tools: list[dict] | None = None,
234
+ stream: bool = False,
235
+ stop: list[str] | None = None,
236
+ **kwargs: Any,
237
+ ) -> LLMResponse | AsyncIterator[str]:
238
+ """
239
+ Generate a response from OpenAI.
240
+
241
+ Args:
242
+ messages: Chat messages in OpenAI format
243
+ prompt: Simple string prompt
244
+ temperature: Sampling temperature (0.0 to 2.0)
245
+ max_tokens: Maximum tokens to generate
246
+ tools: Tool definitions for function calling
247
+ stream: If True, returns AsyncIterator
248
+ stop: Stop sequences
249
+ **kwargs: Additional OpenAI parameters (top_p, presence_penalty, etc.)
250
+
251
+ Returns:
252
+ LLMResponse or AsyncIterator[str] for streaming
253
+ """
254
+ # Apply rate limiting before proceeding
255
+ await self._apply_rate_limit()
256
+
257
+ # Check circuit breaker
258
+ if not self.circuit_breaker.can_execute():
259
+ raise CircuitBreakerOpenError(
260
+ self.PROVIDER_NAME,
261
+ self.circuit_breaker.failure_count,
262
+ self.circuit_breaker.get_reset_time(),
263
+ )
264
+
265
+ if stream:
266
+ return await self._generate_stream(
267
+ messages=messages,
268
+ prompt=prompt,
269
+ temperature=temperature,
270
+ max_tokens=max_tokens,
271
+ tools=tools,
272
+ stop=stop,
273
+ **kwargs,
274
+ )
275
+ else:
276
+ return await self._generate_non_stream(
277
+ messages=messages,
278
+ prompt=prompt,
279
+ temperature=temperature,
280
+ max_tokens=max_tokens,
281
+ tools=tools,
282
+ stop=stop,
283
+ **kwargs,
284
+ )
285
+
286
+ async def _generate_non_stream(
287
+ self,
288
+ *,
289
+ messages: list[dict] | None = None,
290
+ prompt: str | None = None,
291
+ temperature: float = 0.7,
292
+ max_tokens: int | None = None,
293
+ tools: list[dict] | None = None,
294
+ stop: list[str] | None = None,
295
+ **kwargs: Any,
296
+ ) -> LLMResponse:
297
+ """Non-streaming generation with retry logic."""
298
+
299
+ @self._make_retry_decorator()
300
+ async def _request():
301
+ client = await self._get_client()
302
+
303
+ # Build request payload
304
+ payload = {
305
+ "model": self.model,
306
+ "messages": self._build_messages(messages, prompt),
307
+ "temperature": temperature,
308
+ }
309
+
310
+ if max_tokens is not None:
311
+ payload["max_tokens"] = max_tokens
312
+ if stop:
313
+ payload["stop"] = stop
314
+ if tools:
315
+ payload["tools"] = tools
316
+ payload["tool_choice"] = kwargs.pop("tool_choice", "auto")
317
+
318
+ # Add any additional kwargs
319
+ payload.update(kwargs)
320
+
321
+ try:
322
+ response = await client.post("/chat/completions", json=payload)
323
+ except httpx.TimeoutException:
324
+ raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
325
+ except httpx.ConnectError:
326
+ raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
327
+
328
+ if response.status_code != 200:
329
+ self._handle_error_response(response)
330
+
331
+ return response
332
+
333
+ try:
334
+ response = await _request()
335
+ self.circuit_breaker.record_success()
336
+ except Exception:
337
+ self.circuit_breaker.record_failure()
338
+ raise
339
+
340
+ # Parse response
341
+ try:
342
+ data = response.json()
343
+ choice = data["choices"][0]
344
+ message = choice["message"]
345
+
346
+ usage = data.get("usage", {})
347
+ finish_reason = choice.get("finish_reason", "stop")
348
+
349
+ # Check for tool calls
350
+ if "tool_calls" in message:
351
+ tool_calls = [
352
+ ToolCall(
353
+ id=tc["id"],
354
+ name=tc["function"]["name"],
355
+ arguments=json.loads(tc["function"]["arguments"]),
356
+ )
357
+ for tc in message["tool_calls"]
358
+ ]
359
+ llm_response = LLMToolResponse(
360
+ text=message.get("content", ""),
361
+ usage=usage,
362
+ model=data.get("model", self.model),
363
+ raw_response=data,
364
+ finish_reason=finish_reason,
365
+ tool_calls=tool_calls,
366
+ )
367
+ else:
368
+ llm_response = LLMResponse(
369
+ text=message.get("content", ""),
370
+ usage=usage,
371
+ model=data.get("model", self.model),
372
+ raw_response=data,
373
+ finish_reason=finish_reason,
374
+ )
375
+
376
+ self._update_stats(llm_response)
377
+ return llm_response
378
+
379
+ except (KeyError, json.JSONDecodeError) as e:
380
+ raise LLMResponseParseError(self.PROVIDER_NAME, response.text) from e
381
+
382
+ async def _generate_stream(
383
+ self,
384
+ *,
385
+ messages: list[dict] | None = None,
386
+ prompt: str | None = None,
387
+ temperature: float = 0.7,
388
+ max_tokens: int | None = None,
389
+ tools: list[dict] | None = None,
390
+ stop: list[str] | None = None,
391
+ **kwargs: Any,
392
+ ) -> AsyncIterator[str]:
393
+ """Streaming generation."""
394
+
395
+ client = await self._get_client()
396
+
397
+ # Build request payload
398
+ payload = {
399
+ "model": self.model,
400
+ "messages": self._build_messages(messages, prompt),
401
+ "temperature": temperature,
402
+ "stream": True,
403
+ }
404
+
405
+ if max_tokens is not None:
406
+ payload["max_tokens"] = max_tokens
407
+ if stop:
408
+ payload["stop"] = stop
409
+ # Note: tools with streaming have limited support
410
+ if tools:
411
+ payload["tools"] = tools
412
+
413
+ payload.update(kwargs)
414
+
415
+ async def stream_generator():
416
+ try:
417
+ async with client.stream("POST", "/chat/completions", json=payload) as response:
418
+ if response.status_code != 200:
419
+ # Read the full response for error handling
420
+ await response.aread()
421
+ self._handle_error_response(response)
422
+
423
+ async for line in response.aiter_lines():
424
+ if line.startswith("data: "):
425
+ data_str = line[6:]
426
+ if data_str.strip() == "[DONE]":
427
+ break
428
+
429
+ try:
430
+ data = json.loads(data_str)
431
+ delta = data["choices"][0].get("delta", {})
432
+ content = delta.get("content", "")
433
+ if content:
434
+ yield content
435
+ except (json.JSONDecodeError, KeyError):
436
+ continue
437
+
438
+ self.circuit_breaker.record_success()
439
+
440
+ except httpx.TimeoutException:
441
+ self.circuit_breaker.record_failure()
442
+ raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
443
+ except httpx.ConnectError:
444
+ self.circuit_breaker.record_failure()
445
+ raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
446
+ except Exception as e:
447
+ self.circuit_breaker.record_failure()
448
+ if isinstance(e, LLMClientError):
449
+ raise
450
+ raise LLMStreamError(self.PROVIDER_NAME, str(e)) from e
451
+
452
+ return stream_generator()
453
+
454
+ async def close(self) -> None:
455
+ """Close the HTTP client."""
456
+ if self._client and not self._client.is_closed:
457
+ await self._client.aclose()
458
+ self._client = None
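
A minimal usage sketch for the client defined above; the model name, prompt, and token budget are placeholders, and OPENAI_API_KEY is assumed to be set in the environment. Retries and the circuit breaker are applied inside generate(), so the caller only sees the final response or a raised LLM error.

# Usage sketch for OpenAIClient (assumes OPENAI_API_KEY is exported).
import asyncio

from src.adapters.llm.openai_client import OpenAIClient


async def main() -> None:
    client = OpenAIClient(model="gpt-4-turbo-preview", max_retries=3)
    try:
        response = await client.generate(
            messages=[{"role": "user", "content": "Summarize MCTS in two sentences."}],
            temperature=0.2,
            max_tokens=128,
        )
        print(response.text)
        print(response.usage)   # token accounting reported by the API
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())
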
src/agents/__init__.py ADDED
File without changes
src/agents/hrm_agent.py ADDED
@@ -0,0 +1,454 @@
1
+ """
2
+ Hierarchical Reasoning Model (HRM) Agent.
3
+
4
+ Implements the HRM architecture with:
5
+ - H-Module: High-level planning and decomposition
6
+ - L-Module: Low-level execution and refinement
7
+ - Adaptive Computation Time (ACT) for dynamic depth
8
+ - Halting mechanism based on confidence thresholds
9
+
10
+ Based on: "Hierarchical Reasoning for Compositional Generalization"
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from dataclasses import dataclass
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+ from ..training.system_config import HRMConfig
22
+
23
+
24
+ @dataclass
25
+ class SubProblem:
26
+ """Represents a decomposed subproblem in the hierarchy."""
27
+
28
+ level: int # Hierarchy level (0 = root, higher = more abstract)
29
+ description: str # Natural language description
30
+ state: torch.Tensor # Latent state representation
31
+ parent_id: int | None = None # Parent subproblem ID
32
+ confidence: float = 0.0 # Confidence in this decomposition
33
+
34
+
35
+ @dataclass
36
+ class HRMOutput:
37
+ """Output from HRM processing."""
38
+
39
+ final_state: torch.Tensor # Final processed state
40
+ subproblems: list[SubProblem] # Hierarchical decomposition
41
+ halt_step: int # Step at which halting occurred
42
+ total_ponder_cost: float # Total computation cost (for training)
43
+ convergence_path: list[float] # Confidence at each step
44
+
45
+
46
+ class AdaptiveComputationTime(nn.Module):
47
+ """
48
+ Adaptive Computation Time (ACT) mechanism for dynamic depth.
49
+
50
+ Allows the model to "ponder" longer on difficult problems by
51
+ dynamically adjusting the number of processing steps.
52
+ """
53
+
54
+ def __init__(self, hidden_dim: int, epsilon: float = 0.01):
55
+ super().__init__()
56
+ self.epsilon = epsilon
57
+
58
+ # Halting unit: predicts probability of halting
59
+ self.halt_fc = nn.Sequential(
60
+ nn.Linear(hidden_dim, hidden_dim // 2),
61
+ nn.ReLU(),
62
+ nn.Linear(hidden_dim // 2, 1),
63
+ nn.Sigmoid(),
64
+ )
65
+
66
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, float]:
67
+ """
68
+ Compute halting probabilities.
69
+
70
+ Args:
71
+ hidden_states: [batch, seq, hidden_dim]
72
+
73
+ Returns:
74
+ halt_probs: [batch, seq] probability of halting
75
+ ponder_cost: Scalar cost for training
76
+ """
77
+ # Compute halting probabilities
78
+ halt_logits = self.halt_fc(hidden_states) # [batch, seq, 1]
79
+ halt_probs = halt_logits.squeeze(-1) # [batch, seq]
80
+
81
+ # Ponder cost is the expected number of steps
82
+ ponder_cost = halt_probs.sum(dim=-1).mean()
83
+
84
+ return halt_probs, ponder_cost
85
+
86
+
87
+ class HModule(nn.Module):
88
+ """
89
+ H-Module: High-level planning and abstract reasoning.
90
+
91
+ Responsible for:
92
+ - Decomposing problems into subproblems
93
+ - Abstract planning and strategy
94
+ - Coordinating L-module executions
95
+ """
96
+
97
+ def __init__(self, config: HRMConfig):
98
+ super().__init__()
99
+ self.config = config
100
+
101
+ # Multi-head self-attention for relational reasoning
102
+ self.attention = nn.MultiheadAttention(
103
+ embed_dim=config.h_dim,
104
+ num_heads=8,
105
+ dropout=config.dropout,
106
+ batch_first=True,
107
+ )
108
+
109
+ # Feed-forward network
110
+ self.ffn = nn.Sequential(
111
+ nn.Linear(config.h_dim, config.h_dim * 4),
112
+ nn.GELU(),
113
+ nn.Dropout(config.dropout),
114
+ nn.Linear(config.h_dim * 4, config.h_dim),
115
+ nn.Dropout(config.dropout),
116
+ )
117
+
118
+ # Layer normalization
119
+ self.norm1 = nn.LayerNorm(config.h_dim)
120
+ self.norm2 = nn.LayerNorm(config.h_dim)
121
+
122
+ # Decomposition head: outputs subproblem structure
123
+ self.decompose_head = nn.Sequential(
124
+ nn.Linear(config.h_dim, config.h_dim),
125
+ nn.ReLU(),
126
+ nn.Linear(config.h_dim, config.h_dim),
127
+ )
128
+
129
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
130
+ """
131
+ Process input through high-level reasoning.
132
+
133
+ Args:
134
+ x: [batch, seq, h_dim] input tensor
135
+
136
+ Returns:
137
+ [batch, seq, h_dim] processed tensor
138
+ """
139
+ # Self-attention for relational reasoning
140
+ attn_out, _ = self.attention(x, x, x)
141
+ x = self.norm1(x + attn_out)
142
+
143
+ # Feed-forward processing
144
+ ffn_out = self.ffn(x)
145
+ x = self.norm2(x + ffn_out)
146
+
147
+ return x
148
+
149
+ def decompose(self, x: torch.Tensor) -> torch.Tensor:
150
+ """Generate subproblem representations."""
151
+ return self.decompose_head(x)
152
+
153
+
154
+ class LModule(nn.Module):
155
+ """
156
+ L-Module: Low-level execution and concrete operations.
157
+
158
+ Responsible for:
159
+ - Executing concrete operations
160
+ - Processing individual subproblems
161
+ - Generating intermediate results
162
+ """
163
+
164
+ def __init__(self, config: HRMConfig):
165
+ super().__init__()
166
+ self.config = config
167
+
168
+ # Projection from H-module to L-module dimension
169
+ self.h_to_l = nn.Linear(config.h_dim, config.l_dim)
170
+
171
+ # GRU for sequential processing
172
+ self.gru = nn.GRU(
173
+ input_size=config.l_dim,
174
+ hidden_size=config.l_dim,
175
+ num_layers=config.num_l_layers,
176
+ dropout=config.dropout if config.num_l_layers > 1 else 0,
177
+ batch_first=True,
178
+ )
179
+
180
+ # Output projection
181
+ self.output_proj = nn.Sequential(
182
+ nn.Linear(config.l_dim, config.l_dim * 2),
183
+ nn.ReLU(),
184
+ nn.Dropout(config.dropout),
185
+ nn.Linear(config.l_dim * 2, config.l_dim),
186
+ )
187
+
188
+ # Back-projection to H-module dimension
189
+ self.l_to_h = nn.Linear(config.l_dim, config.h_dim)
190
+
191
+ def forward(self, x: torch.Tensor, h_context: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
192
+ """
193
+ Execute low-level processing.
194
+
195
+ Args:
196
+ x: [batch, seq, h_dim] input from H-module
197
+ h_context: Optional hidden state
198
+
199
+ Returns:
200
+ output: [batch, seq, l_dim] processed output
201
+ l_to_h: [batch, seq, h_dim] back-projection to H-module
202
+ """
203
+ # Project to L-module dimension
204
+ x_l = self.h_to_l(x)
205
+
206
+ # Sequential processing
207
+ gru_out, _ = self.gru(x_l, h_context)
208
+
209
+ # Output processing
210
+ output = self.output_proj(gru_out)
211
+
212
+ # Back-project to H-module dimension for feedback
213
+ feedback = self.l_to_h(output)
214
+
215
+ return output, feedback
216
+
217
+
218
+ class HRMAgent(nn.Module):
219
+ """
220
+ Complete Hierarchical Reasoning Model agent.
221
+
222
+ Combines H-module and L-module with ACT for adaptive computation.
223
+ """
224
+
225
+ def __init__(self, config: HRMConfig, device: str = "cpu"):
226
+ super().__init__()
227
+ self.config = config
228
+ self.device = device
229
+
230
+ # Input embedding
231
+ self.input_proj = nn.Linear(config.h_dim, config.h_dim)
232
+
233
+ # Core modules
234
+ self.h_module = nn.ModuleList([HModule(config) for _ in range(config.num_h_layers)])
235
+
236
+ self.l_module = LModule(config)
237
+
238
+ # Adaptive computation time
239
+ self.act = AdaptiveComputationTime(config.h_dim, config.ponder_epsilon)
240
+
241
+ # State integration
242
+ self.integrate = nn.Sequential(
243
+ nn.Linear(config.h_dim * 2, config.h_dim),
244
+ nn.LayerNorm(config.h_dim),
245
+ nn.GELU(),
246
+ )
247
+
248
+ self.to(device)
249
+
250
+ def forward(
251
+ self,
252
+ x: torch.Tensor,
253
+ max_steps: int | None = None,
254
+ return_decomposition: bool = False,
255
+ ) -> HRMOutput:
256
+ """
257
+ Process input through hierarchical reasoning.
258
+
259
+ Args:
260
+ x: [batch, seq, h_dim] input tensor
261
+ max_steps: Maximum outer loop steps (defaults to config)
262
+ return_decomposition: Whether to return subproblem decomposition
263
+
264
+ Returns:
265
+ HRMOutput containing final state and optional decomposition
266
+ """
267
+ batch_size, seq_len, _ = x.shape
268
+ max_steps = max_steps or self.config.max_outer_steps
269
+
270
+ # Initial projection
271
+ h_state = self.input_proj(x)
272
+
273
+ # Tracking
274
+ subproblems = []
275
+ convergence_path = []
276
+ total_ponder_cost = 0.0
277
+
278
+ # Outer loop: iterative refinement
279
+ for step in range(max_steps):
280
+ # H-module: high-level planning
281
+ for h_layer in self.h_module:
282
+ h_state = h_layer(h_state)
283
+
284
+ # Check halting condition
285
+ halt_probs, ponder_cost = self.act(h_state)
286
+ total_ponder_cost += ponder_cost
287
+
288
+ # Average halting probability across sequence
289
+ avg_halt_prob = halt_probs.mean().item()
290
+ convergence_path.append(avg_halt_prob)
291
+
292
+ # Generate subproblem decomposition if requested
293
+ if return_decomposition:
294
+ subproblem_repr = self.h_module[0].decompose(h_state)
295
+ # Create subproblem entries (simplified)
296
+ for i in range(min(3, seq_len)): # Top 3 subproblems
297
+ subproblems.append(
298
+ SubProblem(
299
+ level=step,
300
+ description=f"Subproblem at step {step}, position {i}",
301
+ state=subproblem_repr[:, i, :].detach(),
302
+ confidence=halt_probs[:, i].mean().item(),
303
+ )
304
+ )
305
+
306
+ # Halt if confident enough
307
+ if avg_halt_prob >= self.config.halt_threshold:
308
+ break
309
+
310
+ # L-module: low-level execution
311
+ l_output, l_feedback = self.l_module(h_state)
312
+
313
+ # Integrate L-module feedback
314
+ h_state = self.integrate(torch.cat([h_state, l_feedback], dim=-1))
315
+
316
+ return HRMOutput(
317
+ final_state=h_state,
318
+ subproblems=subproblems,
319
+ halt_step=step + 1,
320
+ total_ponder_cost=total_ponder_cost,
321
+ convergence_path=convergence_path,
322
+ )
323
+
324
+ async def decompose_problem(self, query: str, state: torch.Tensor) -> list[SubProblem]:
325
+ """
326
+ Decompose a problem into hierarchical subproblems.
327
+
328
+ Args:
329
+ query: Natural language problem description
330
+ state: Initial state representation
331
+
332
+ Returns:
333
+ List of subproblems in hierarchical order
334
+ """
335
+ # Ensure state is batched
336
+ if state.dim() == 2:
337
+ state = state.unsqueeze(0) # [1, seq, dim]
338
+
339
+ # Forward pass with decomposition
340
+ output = self.forward(state, return_decomposition=True)
341
+
342
+ # Add query context to subproblems
343
+ for i, sp in enumerate(output.subproblems):
344
+ sp.description = f"{query} -> Level {sp.level} Subproblem {i}"
345
+
346
+ return output.subproblems
347
+
348
+ def get_parameter_count(self) -> int:
349
+ """Return total number of trainable parameters."""
350
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
351
+
352
+
353
+ # Training utilities
354
+ class HRMLoss(nn.Module):
355
+ """
356
+ Combined loss for HRM training.
357
+
358
+ Includes:
359
+ - Task loss (e.g., cross-entropy for classification)
360
+ - Ponder cost regularization (encourages efficiency)
361
+ - Consistency loss (encourages stable convergence)
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ task_weight: float = 1.0,
367
+ ponder_weight: float = 0.01,
368
+ consistency_weight: float = 0.1,
369
+ ):
370
+ super().__init__()
371
+ self.task_weight = task_weight
372
+ self.ponder_weight = ponder_weight
373
+ self.consistency_weight = consistency_weight
374
+
375
+ def forward(
376
+ self,
377
+ hrm_output: HRMOutput,
378
+ predictions: torch.Tensor,
379
+ targets: torch.Tensor,
380
+ task_loss_fn: nn.Module,
381
+ ) -> tuple[torch.Tensor, dict]:
382
+ """
383
+ Compute combined loss.
384
+
385
+ Args:
386
+ hrm_output: Output from HRM forward pass
387
+ predictions: Model predictions
388
+ targets: Ground truth targets
389
+ task_loss_fn: Loss function for the task
390
+
391
+ Returns:
392
+ total_loss: Combined loss
393
+ loss_dict: Dictionary of individual loss components
394
+ """
395
+ # Task loss
396
+ task_loss = task_loss_fn(predictions, targets)
397
+
398
+ # Ponder cost (encourages efficiency)
399
+ ponder_loss = hrm_output.total_ponder_cost
400
+
401
+ # Consistency loss (encourages monotonic convergence)
402
+ if len(hrm_output.convergence_path) > 1:
403
+ conv_tensor = torch.tensor(hrm_output.convergence_path)
404
+ # Penalize non-monotonic increases
405
+ diffs = conv_tensor[1:] - conv_tensor[:-1]
406
+ consistency_loss = F.relu(-diffs).mean() # Penalize decreases
407
+ else:
408
+ consistency_loss = torch.tensor(0.0)
409
+
410
+ # Combine losses
411
+ total_loss = (
412
+ self.task_weight * task_loss + self.ponder_weight * ponder_loss + self.consistency_weight * consistency_loss
413
+ )
414
+
415
+ loss_dict = {
416
+ "total": total_loss.item(),
417
+ "task": task_loss.item(),
418
+ "ponder": ponder_loss,
419
+ "consistency": consistency_loss.item(),
420
+ "halt_step": hrm_output.halt_step,
421
+ }
422
+
423
+ return total_loss, loss_dict
424
+
425
+
426
+ def create_hrm_agent(config: HRMConfig, device: str = "cpu") -> HRMAgent:
427
+ """
428
+ Factory function to create and initialize HRM agent.
429
+
430
+ Args:
431
+ config: HRM configuration
432
+ device: Device to place model on
433
+
434
+ Returns:
435
+ Initialized HRMAgent
436
+ """
437
+ agent = HRMAgent(config, device)
438
+
439
+ # Initialize weights
440
+ def init_weights(m):
441
+ if isinstance(m, nn.Linear):
442
+ nn.init.xavier_uniform_(m.weight)
443
+ if m.bias is not None:
444
+ nn.init.zeros_(m.bias)
445
+ elif isinstance(m, nn.GRU):
446
+ for name, param in m.named_parameters():
447
+ if "weight" in name:
448
+ nn.init.orthogonal_(param)
449
+ elif "bias" in name:
450
+ nn.init.zeros_(param)
451
+
452
+ agent.apply(init_weights)
453
+
454
+ return agent
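
A smoke-test sketch for the HRM agent above. HRMConfig comes from src.training.system_config (not shown in this diff); the field names below are the ones referenced by HModule, LModule, and HRMAgent, while the concrete values and keyword-style constructor are illustrative assumptions.

# Illustrative smoke test for HRMAgent (config values are assumptions).
import torch

from src.agents.hrm_agent import create_hrm_agent
from src.training.system_config import HRMConfig

config = HRMConfig(
    h_dim=128,            # must be divisible by the 8 attention heads in HModule
    l_dim=64,
    num_h_layers=2,
    num_l_layers=1,
    dropout=0.1,
    ponder_epsilon=0.01,
    max_outer_steps=4,
    halt_threshold=0.9,
)

agent = create_hrm_agent(config, device="cpu")
x = torch.randn(2, 16, config.h_dim)          # [batch, seq, h_dim] dummy input
output = agent(x, return_decomposition=True)

print("halt step:", output.halt_step)
print("convergence path:", [round(p, 3) for p in output.convergence_path])
print("trainable parameters:", agent.get_parameter_count())
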
src/agents/meta_controller/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ Neural Meta-Controller package for Multi-Agent MCTS Framework.
3
+
4
+ This package provides the base infrastructure for neural network-based
5
+ meta-controllers that dynamically select which agent to route queries to.
6
+ """
7
+
8
+ from src.agents.meta_controller.base import (
9
+ AbstractMetaController,
10
+ MetaControllerFeatures,
11
+ MetaControllerPrediction,
12
+ )
13
+ from src.agents.meta_controller.rnn_controller import (
14
+ RNNMetaController,
15
+ RNNMetaControllerModel,
16
+ )
17
+ from src.agents.meta_controller.utils import (
18
+ features_to_tensor,
19
+ features_to_text,
20
+ normalize_features,
21
+ one_hot_encode_agent,
22
+ )
23
+
24
+ # Import BERT controller (may not be available if transformers/peft not installed)
25
+ try:
26
+ from src.agents.meta_controller.bert_controller import BERTMetaController # noqa: F401
27
+
28
+ _bert_available = True
29
+ except ImportError:
30
+ _bert_available = False
31
+
32
+ __all__ = [
33
+ "AbstractMetaController",
34
+ "MetaControllerFeatures",
35
+ "MetaControllerPrediction",
36
+ "normalize_features",
37
+ "one_hot_encode_agent",
38
+ "features_to_tensor",
39
+ "features_to_text",
40
+ "RNNMetaController",
41
+ "RNNMetaControllerModel",
42
+ ]
43
+
44
+ if _bert_available:
45
+ __all__.append("BERTMetaController")
src/agents/meta_controller/base.py ADDED
@@ -0,0 +1,219 @@
1
+ """
2
+ Abstract base class for Neural Meta-Controllers.
3
+
4
+ Provides the foundation for neural network-based meta-controllers that
5
+ dynamically select which agent (HRM, TRM, or MCTS) should handle a query.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from dataclasses import dataclass, field
10
+ from typing import Any
11
+
12
+
13
+ @dataclass
14
+ class MetaControllerFeatures:
15
+ """
16
+ Features extracted from the current agent state for meta-controller prediction.
17
+
18
+ These features capture the current state of the multi-agent system,
19
+ including confidence scores from different agents and contextual information.
20
+ """
21
+
22
+ hrm_confidence: float
23
+ """Confidence score from the HRM (Human Response Model) agent."""
24
+
25
+ trm_confidence: float
26
+ """Confidence score from the TRM (Task Response Model) agent."""
27
+
28
+ mcts_value: float
29
+ """Value estimate from the MCTS (Monte Carlo Tree Search) process."""
30
+
31
+ consensus_score: float
32
+ """Agreement score between different agents."""
33
+
34
+ last_agent: str
35
+ """Name of the last agent used ('hrm', 'trm', 'mcts', or 'none')."""
36
+
37
+ iteration: int
38
+ """Current iteration number in the reasoning process."""
39
+
40
+ query_length: int
41
+ """Length of the input query in characters."""
42
+
43
+ has_rag_context: bool
44
+ """Whether RAG (Retrieval-Augmented Generation) context is available."""
45
+
46
+
47
+ @dataclass
48
+ class MetaControllerPrediction:
49
+ """
50
+ Prediction output from the meta-controller.
51
+
52
+ Contains the selected agent and associated confidence/probability information.
53
+ """
54
+
55
+ agent: str
56
+ """Name of the selected agent ('hrm', 'trm', or 'mcts')."""
57
+
58
+ confidence: float
59
+ """Confidence score for the prediction (0.0 to 1.0)."""
60
+
61
+ probabilities: dict[str, float] = field(default_factory=dict)
62
+ """Probability distribution over all possible agents."""
63
+
64
+
65
+ class AbstractMetaController(ABC):
66
+ """
67
+ Abstract base class for neural meta-controllers.
68
+
69
+ This class defines the interface that all meta-controller implementations
70
+ must follow. Meta-controllers are responsible for deciding which agent
71
+ should handle a given query based on the current system state.
72
+
73
+ Attributes:
74
+ AGENT_NAMES: List of valid agent names that can be selected.
75
+ name: Name of this meta-controller instance.
76
+ seed: Random seed for reproducibility.
77
+ """
78
+
79
+ AGENT_NAMES = ["hrm", "trm", "mcts"]
80
+
81
+ def __init__(self, name: str, seed: int = 42) -> None:
82
+ """
83
+ Initialize the meta-controller.
84
+
85
+ Args:
86
+ name: Name identifier for this meta-controller instance.
87
+ seed: Random seed for reproducibility. Defaults to 42.
88
+ """
89
+ self.name = name
90
+ self.seed = seed
91
+
92
+ @abstractmethod
93
+ def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction:
94
+ """
95
+ Predict which agent should handle the current query.
96
+
97
+ Args:
98
+ features: Features extracted from the current agent state.
99
+
100
+ Returns:
101
+ Prediction containing the selected agent and confidence scores.
102
+ """
103
+ pass
104
+
105
+ @abstractmethod
106
+ def load_model(self, path: str) -> None:
107
+ """
108
+ Load a trained model from disk.
109
+
110
+ Args:
111
+ path: Path to the saved model file or directory.
112
+ """
113
+ pass
114
+
115
+ @abstractmethod
116
+ def save_model(self, path: str) -> None:
117
+ """
118
+ Save the current model to disk.
119
+
120
+ Args:
121
+ path: Path where the model should be saved.
122
+ """
123
+ pass
124
+
125
+ def extract_features(self, state: dict[str, Any]) -> MetaControllerFeatures:
126
+ """
127
+ Extract meta-controller features from an AgentState dictionary.
128
+
129
+ This method converts raw state information into the structured
130
+ MetaControllerFeatures format required for prediction.
131
+
132
+ Args:
133
+ state: Dictionary containing agent state information.
134
+ Expected keys include:
135
+ - 'hrm_confidence' or nested in 'agent_confidences'
136
+ - 'trm_confidence' or nested in 'agent_confidences'
137
+ - 'mcts_value' or nested in 'mcts_state'
138
+ - 'consensus_score'
139
+ - 'last_agent'
140
+ - 'iteration'
141
+ - 'query' or 'query_length'
142
+ - 'rag_context' or 'has_rag_context'
143
+
144
+ Returns:
145
+ MetaControllerFeatures instance with extracted values.
146
+
147
+ Example:
148
+ >>> state = {
149
+ ... 'agent_confidences': {'hrm': 0.8, 'trm': 0.6},
150
+ ... 'mcts_state': {'value': 0.75},
151
+ ... 'consensus_score': 0.7,
152
+ ... 'last_agent': 'hrm',
153
+ ... 'iteration': 2,
154
+ ... 'query': 'What is machine learning?',
155
+ ... 'rag_context': 'ML is a subset of AI...'
156
+ ... }
157
+ >>> features = controller.extract_features(state)
158
+ """
159
+ # Extract HRM confidence
160
+ if "hrm_confidence" in state:
161
+ hrm_confidence = float(state["hrm_confidence"])
162
+ elif "agent_confidences" in state and isinstance(state["agent_confidences"], dict):
163
+ hrm_confidence = float(state["agent_confidences"].get("hrm", 0.0))
164
+ else:
165
+ hrm_confidence = 0.0
166
+
167
+ # Extract TRM confidence
168
+ if "trm_confidence" in state:
169
+ trm_confidence = float(state["trm_confidence"])
170
+ elif "agent_confidences" in state and isinstance(state["agent_confidences"], dict):
171
+ trm_confidence = float(state["agent_confidences"].get("trm", 0.0))
172
+ else:
173
+ trm_confidence = 0.0
174
+
175
+ # Extract MCTS value
176
+ if "mcts_value" in state:
177
+ mcts_value = float(state["mcts_value"])
178
+ elif "mcts_state" in state and isinstance(state["mcts_state"], dict):
179
+ mcts_value = float(state["mcts_state"].get("value", 0.0))
180
+ else:
181
+ mcts_value = 0.0
182
+
183
+ # Extract consensus score
184
+ consensus_score = float(state.get("consensus_score", 0.0))
185
+
186
+ # Extract last agent
187
+ last_agent = str(state.get("last_agent", "none"))
188
+ if last_agent not in self.AGENT_NAMES and last_agent != "none":
189
+ last_agent = "none"
190
+
191
+ # Extract iteration
192
+ iteration = int(state.get("iteration", 0))
193
+
194
+ # Extract query length
195
+ if "query_length" in state:
196
+ query_length = int(state["query_length"])
197
+ elif "query" in state and isinstance(state["query"], str):
198
+ query_length = len(state["query"])
199
+ else:
200
+ query_length = 0
201
+
202
+ # Extract has_rag_context
203
+ if "has_rag_context" in state:
204
+ has_rag_context = bool(state["has_rag_context"])
205
+ elif "rag_context" in state:
206
+ has_rag_context = state["rag_context"] is not None and len(str(state["rag_context"])) > 0
207
+ else:
208
+ has_rag_context = False
209
+
210
+ return MetaControllerFeatures(
211
+ hrm_confidence=hrm_confidence,
212
+ trm_confidence=trm_confidence,
213
+ mcts_value=mcts_value,
214
+ consensus_score=consensus_score,
215
+ last_agent=last_agent,
216
+ iteration=iteration,
217
+ query_length=query_length,
218
+ has_rag_context=has_rag_context,
219
+ )
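
To make the contract concrete, an illustrative-only subclass that satisfies AbstractMetaController with a simple rule-based policy (pick the agent with the highest score). It is not part of the framework; load_model and save_model are no-ops because there is nothing to train.

# Hypothetical rule-based implementation of the AbstractMetaController interface.
from src.agents.meta_controller.base import (
    AbstractMetaController,
    MetaControllerFeatures,
    MetaControllerPrediction,
)


class RuleBasedMetaController(AbstractMetaController):
    def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction:
        scores = {
            "hrm": features.hrm_confidence,
            "trm": features.trm_confidence,
            "mcts": features.mcts_value,
        }
        total = sum(scores.values()) or 1.0          # avoid division by zero
        probabilities = {name: score / total for name, score in scores.items()}
        agent = max(scores, key=scores.get)
        return MetaControllerPrediction(
            agent=agent,
            confidence=probabilities[agent],
            probabilities=probabilities,
        )

    def load_model(self, path: str) -> None:         # nothing to load for a rule-based policy
        pass

    def save_model(self, path: str) -> None:         # nothing to persist
        pass
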
src/agents/meta_controller/bert_controller.py ADDED
@@ -0,0 +1,428 @@
1
+ """
2
+ BERT-based Meta-Controller with LoRA adapters for efficient fine-tuning.
3
+
4
+ This module provides a BERT-based meta-controller that uses Low-Rank Adaptation (LoRA)
5
+ for parameter-efficient fine-tuning. The controller converts agent state features into
6
+ text and uses a sequence classification model to predict the optimal agent.
7
+ """
8
+
9
+ import warnings
10
+ from typing import Any
11
+
12
+ import torch
13
+
14
+ from src.agents.meta_controller.base import (
15
+ AbstractMetaController,
16
+ MetaControllerFeatures,
17
+ MetaControllerPrediction,
18
+ )
19
+ from src.agents.meta_controller.utils import features_to_text
20
+
21
+ # Handle optional transformers and peft imports gracefully
22
+ _TRANSFORMERS_AVAILABLE = False
23
+ _PEFT_AVAILABLE = False
24
+
25
+ try:
26
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
27
+
28
+ _TRANSFORMERS_AVAILABLE = True
29
+ except ImportError:
30
+ warnings.warn(
31
+ "transformers library not installed. Install it with: pip install transformers",
32
+ ImportWarning,
33
+ stacklevel=2,
34
+ )
35
+ AutoTokenizer = None # type: ignore
36
+ AutoModelForSequenceClassification = None # type: ignore
37
+
38
+ try:
39
+ from peft import LoraConfig, TaskType, get_peft_model
40
+
41
+ _PEFT_AVAILABLE = True
42
+ except ImportError:
43
+ warnings.warn(
44
+ "peft library not installed. Install it with: pip install peft",
45
+ ImportWarning,
46
+ stacklevel=2,
47
+ )
48
+ LoraConfig = None # type: ignore
49
+ TaskType = None # type: ignore
50
+ get_peft_model = None # type: ignore
51
+
52
+
53
+ class BERTMetaController(AbstractMetaController):
54
+ """
55
+ BERT-based meta-controller with optional LoRA adapters for efficient fine-tuning.
56
+
57
+ This controller converts agent state features into structured text and uses
58
+ a pre-trained BERT model (with optional LoRA adapters) to classify which
59
+ agent should handle the current query. LoRA enables parameter-efficient
60
+ fine-tuning by only training low-rank decomposition matrices.
61
+
62
+ Attributes:
63
+ DEFAULT_MODEL_NAME: Default BERT model to use.
64
+ NUM_LABELS: Number of output labels (agents to choose from).
65
+ device: PyTorch device for tensor operations.
66
+ model_name: Name of the pre-trained model.
67
+ lora_r: LoRA rank parameter.
68
+ lora_alpha: LoRA alpha scaling parameter.
69
+ lora_dropout: LoRA dropout rate.
70
+ use_lora: Whether to use LoRA adapters.
71
+ tokenizer: BERT tokenizer for text processing.
72
+ model: BERT sequence classification model (with or without LoRA).
73
+
74
+ Example:
75
+ >>> controller = BERTMetaController(name="BERTController", seed=42)
76
+ >>> features = MetaControllerFeatures(
77
+ ... hrm_confidence=0.8,
78
+ ... trm_confidence=0.6,
79
+ ... mcts_value=0.75,
80
+ ... consensus_score=0.7,
81
+ ... last_agent='hrm',
82
+ ... iteration=2,
83
+ ... query_length=150,
84
+ ... has_rag_context=True
85
+ ... )
86
+ >>> prediction = controller.predict(features)
87
+ >>> prediction.agent in ['hrm', 'trm', 'mcts']
88
+ True
89
+ >>> 0.0 <= prediction.confidence <= 1.0
90
+ True
91
+ """
92
+
93
+ DEFAULT_MODEL_NAME = "prajjwal1/bert-mini"
94
+ NUM_LABELS = 3
95
+
96
+ def __init__(
97
+ self,
98
+ name: str = "BERTMetaController",
99
+ seed: int = 42,
100
+ model_name: str | None = None,
101
+ lora_r: int = 4,
102
+ lora_alpha: int = 16,
103
+ lora_dropout: float = 0.1,
104
+ device: str | None = None,
105
+ use_lora: bool = True,
106
+ ) -> None:
107
+ """
108
+ Initialize the BERT meta-controller with optional LoRA adapters.
109
+
110
+ Args:
111
+ name: Name identifier for this controller. Defaults to "BERTMetaController".
112
+ seed: Random seed for reproducibility. Defaults to 42.
113
+ model_name: Pre-trained model name from HuggingFace. If None, uses DEFAULT_MODEL_NAME.
114
+ lora_r: LoRA rank parameter (lower = more compression). Defaults to 4.
115
+ lora_alpha: LoRA alpha scaling parameter. Defaults to 16.
116
+ lora_dropout: Dropout rate for LoRA layers. Defaults to 0.1.
117
+ device: Device to run model on ('cpu', 'cuda', 'mps', etc.).
118
+ If None, auto-detects best available device.
119
+ use_lora: Whether to apply LoRA adapters to the model. Defaults to True.
120
+
121
+ Raises:
122
+ ImportError: If transformers library is not installed.
123
+ ImportError: If use_lora is True and peft library is not installed.
124
+
125
+ Example:
126
+ >>> controller = BERTMetaController(
127
+ ... name="CustomBERT",
128
+ ... seed=123,
129
+ ... lora_r=8,
130
+ ... lora_alpha=32,
131
+ ... use_lora=True
132
+ ... )
133
+ """
134
+ super().__init__(name=name, seed=seed)
135
+
136
+ # Check for required dependencies
137
+ if not _TRANSFORMERS_AVAILABLE:
138
+ raise ImportError(
139
+ "transformers library is required for BERTMetaController. Install it with: pip install transformers"
140
+ )
141
+
142
+ if use_lora and not _PEFT_AVAILABLE:
143
+ raise ImportError("peft library is required for LoRA support. Install it with: pip install peft")
144
+
145
+ # Set random seed for reproducibility
146
+ torch.manual_seed(seed)
147
+
148
+ # Auto-detect device if not specified
149
+ if device is None:
150
+ if torch.cuda.is_available():
151
+ self.device = torch.device("cuda")
152
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
153
+ self.device = torch.device("mps")
154
+ else:
155
+ self.device = torch.device("cpu")
156
+ else:
157
+ self.device = torch.device(device)
158
+
159
+ # Store configuration parameters
160
+ self.model_name = model_name if model_name is not None else self.DEFAULT_MODEL_NAME
161
+ self.lora_r = lora_r
162
+ self.lora_alpha = lora_alpha
163
+ self.lora_dropout = lora_dropout
164
+ self.use_lora = use_lora
165
+
166
+ # Initialize tokenizer
167
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
168
+
169
+ # Initialize base model for sequence classification
170
+ base_model = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=self.NUM_LABELS)
171
+
172
+ # Apply LoRA adapters if requested
173
+ if self.use_lora:
174
+ lora_config = LoraConfig(
175
+ task_type=TaskType.SEQ_CLS,
176
+ r=self.lora_r,
177
+ lora_alpha=self.lora_alpha,
178
+ lora_dropout=self.lora_dropout,
179
+ target_modules=["query", "value"],
180
+ )
181
+ self.model = get_peft_model(base_model, lora_config)
182
+ else:
183
+ self.model = base_model
184
+
185
+ # Move model to device
186
+ self.model = self.model.to(self.device)
187
+
188
+ # Set model to evaluation mode
189
+ self.model.eval()
190
+
191
+ # Initialize tokenization cache for performance optimization
192
+ self._tokenization_cache: dict[str, Any] = {}
193
+
194
+ def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction:
195
+ """
196
+ Predict which agent should handle the current query.
197
+
198
+ Converts features to structured text, tokenizes the text, runs through
199
+ the BERT model, and returns a prediction with confidence scores.
200
+
201
+ Args:
202
+ features: Features extracted from the current agent state.
203
+
204
+ Returns:
205
+ Prediction containing the selected agent, confidence score,
206
+ and probability distribution over all agents.
207
+
208
+ Example:
209
+ >>> controller = BERTMetaController()
210
+ >>> features = MetaControllerFeatures(
211
+ ... hrm_confidence=0.9,
212
+ ... trm_confidence=0.3,
213
+ ... mcts_value=0.5,
214
+ ... consensus_score=0.8,
215
+ ... last_agent='none',
216
+ ... iteration=0,
217
+ ... query_length=100,
218
+ ... has_rag_context=False
219
+ ... )
220
+ >>> pred = controller.predict(features)
221
+ >>> isinstance(pred.agent, str)
222
+ True
223
+ >>> isinstance(pred.confidence, float)
224
+ True
225
+ >>> len(pred.probabilities) == 3
226
+ True
227
+ """
228
+ # Convert features to structured text
229
+ text = features_to_text(features)
230
+
231
+ # Check cache for tokenized text
232
+ if text in self._tokenization_cache:
233
+ inputs = self._tokenization_cache[text]
234
+ else:
235
+ # Tokenize the text
236
+ inputs = self.tokenizer(
237
+ text,
238
+ return_tensors="pt",
239
+ padding=True,
240
+ truncation=True,
241
+ max_length=512,
242
+ )
243
+ # Cache the tokenized result
244
+ self._tokenization_cache[text] = inputs
245
+
246
+ # Move inputs to device
247
+ inputs = {key: value.to(self.device) for key, value in inputs.items()}
248
+
249
+ # Perform inference without gradient tracking
250
+ with torch.no_grad():
251
+ # Get logits from model
252
+ outputs = self.model(**inputs)
253
+ logits = outputs.logits
254
+
255
+ # Apply softmax to get probabilities
256
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
257
+
258
+ # Get predicted agent index (argmax)
259
+ predicted_idx = torch.argmax(probabilities, dim=-1).item()
260
+
261
+ # Extract confidence for selected agent
262
+ confidence = probabilities[0, predicted_idx].item()
263
+
264
+ # Create probability dictionary
265
+ prob_dict: dict[str, float] = {}
266
+ for i, agent_name in enumerate(self.AGENT_NAMES):
267
+ prob_dict[agent_name] = probabilities[0, i].item()
268
+
269
+ # Get agent name
270
+ selected_agent = self.AGENT_NAMES[predicted_idx]
271
+
272
+ return MetaControllerPrediction(
273
+ agent=selected_agent,
274
+ confidence=float(confidence),
275
+ probabilities=prob_dict,
276
+ )
277
+
278
+ def load_model(self, path: str) -> None:
279
+ """
280
+ Load a trained model from disk.
281
+
282
+ For LoRA models, loads the PEFT adapter weights. For base models,
283
+ loads the full state dictionary.
284
+
285
+ Args:
286
+ path: Path to the saved model file or directory.
287
+ For LoRA models, this should be a directory containing
288
+ adapter_config.json and adapter_model.bin.
289
+ For base models, this should be a .pt or .pth file.
290
+
291
+ Raises:
292
+ FileNotFoundError: If the model file or directory does not exist.
293
+ RuntimeError: If the state dict is incompatible with the model.
294
+
295
+ Example:
296
+ >>> controller = BERTMetaController(use_lora=True)
297
+ >>> controller.load_model("/path/to/lora_adapter")
298
+ >>> controller = BERTMetaController(use_lora=False)
299
+ >>> controller.load_model("/path/to/model.pt")
300
+ """
301
+ if self.use_lora:
302
+ # Load PEFT adapter weights
303
+ # For PEFT models, the path should be a directory containing adapter files
304
+ from peft import PeftModel
305
+
306
+ # Get the base model from the PEFT wrapper
307
+ base_model = self.model.get_base_model()
308
+
309
+ # Load the PEFT model from the saved path
310
+ self.model = PeftModel.from_pretrained(base_model, path)
311
+ self.model = self.model.to(self.device)
312
+ else:
313
+ # Load base model state dict
314
+ state_dict = torch.load(path, map_location=self.device, weights_only=True)
315
+ self.model.load_state_dict(state_dict)
316
+
317
+ # Ensure model is in evaluation mode
318
+ self.model.eval()
319
+
320
+ def save_model(self, path: str) -> None:
321
+ """
322
+ Save the current model to disk.
323
+
324
+ For LoRA models, saves the PEFT adapter weights. For base models,
325
+ saves the full state dictionary.
326
+
327
+ Args:
328
+ path: Path where the model should be saved.
329
+ For LoRA models, this should be a directory path where
330
+ adapter_config.json and adapter_model.bin will be saved.
331
+ For base models, this should be a .pt or .pth file path.
332
+
333
+ Example:
334
+ >>> controller = BERTMetaController(use_lora=True)
335
+ >>> controller.save_model("/path/to/lora_adapter")
336
+ >>> controller = BERTMetaController(use_lora=False)
337
+ >>> controller.save_model("/path/to/model.pt")
338
+ """
339
+ if self.use_lora:
340
+ # Save PEFT adapter weights
341
+ # This saves only the LoRA adapter weights, not the full model
342
+ self.model.save_pretrained(path)
343
+ else:
344
+ # Save base model state dict
345
+ torch.save(self.model.state_dict(), path)
346
+
347
+ def clear_cache(self) -> None:
348
+ """
349
+ Clear the tokenization cache.
350
+
351
+ This method removes all cached tokenized inputs, freeing memory.
352
+ Useful when processing many different feature combinations or
353
+ when memory usage is a concern.
354
+
355
+ Example:
356
+ >>> controller = BERTMetaController()
357
+ >>> # After many predictions...
358
+ >>> controller.clear_cache()
359
+ >>> info = controller.get_cache_info()
360
+ >>> info['cache_size'] == 0
361
+ True
362
+ """
363
+ self._tokenization_cache.clear()
364
+
365
+ def get_cache_info(self) -> dict[str, Any]:
366
+ """
367
+ Get information about the current tokenization cache.
368
+
369
+ Returns:
370
+ Dictionary containing cache statistics:
371
+ - cache_size: Number of cached tokenizations
372
+ - cache_keys: List of cached text inputs (truncated for display)
373
+
374
+ Example:
375
+ >>> controller = BERTMetaController()
376
+ >>> features = MetaControllerFeatures(
377
+ ... hrm_confidence=0.8,
378
+ ... trm_confidence=0.6,
379
+ ... mcts_value=0.75,
380
+ ... consensus_score=0.7,
381
+ ... last_agent='hrm',
382
+ ... iteration=2,
383
+ ... query_length=150,
384
+ ... has_rag_context=True
385
+ ... )
386
+ >>> _ = controller.predict(features)
387
+ >>> info = controller.get_cache_info()
388
+ >>> 'cache_size' in info
389
+ True
390
+ >>> info['cache_size'] >= 1
391
+ True
392
+ """
393
+ # Truncate keys for display (first 50 chars)
394
+ truncated_keys = [key[:50] + "..." if len(key) > 50 else key for key in self._tokenization_cache]
395
+
396
+ return {
397
+ "cache_size": len(self._tokenization_cache),
398
+ "cache_keys": truncated_keys,
399
+ }
400
+
401
+ def get_trainable_parameters(self) -> dict[str, int]:
402
+ """
403
+ Get the number of trainable and total parameters in the model.
404
+
405
+ This is particularly useful for LoRA models to see the efficiency
406
+ gains from using low-rank adaptation.
407
+
408
+ Returns:
409
+ Dictionary containing:
410
+ - total_params: Total number of parameters in the model
411
+ - trainable_params: Number of trainable parameters
412
+ - trainable_percentage: Percentage of parameters that are trainable
413
+
414
+ Example:
415
+ >>> controller = BERTMetaController(use_lora=True)
416
+ >>> params = controller.get_trainable_parameters()
417
+ >>> params['trainable_percentage'] < 10.0 # LoRA trains <10% of params
418
+ True
419
+ """
420
+ total_params = sum(p.numel() for p in self.model.parameters())
421
+ trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
422
+ trainable_percentage = (trainable_params / total_params) * 100 if total_params > 0 else 0.0
423
+
424
+ return {
425
+ "total_params": total_params,
426
+ "trainable_params": trainable_params,
427
+ "trainable_percentage": round(trainable_percentage, 2),
428
+ }
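
An end-to-end sketch tying the pieces above together: build features from a raw state dict via the inherited extract_features, predict, inspect LoRA efficiency, and persist the adapter. The state values and the output directory are placeholders; transformers and peft must be installed.

# Usage sketch for BERTMetaController (state values and paths are placeholders).
from src.agents.meta_controller.bert_controller import BERTMetaController

controller = BERTMetaController(seed=42, use_lora=True)

state = {
    "agent_confidences": {"hrm": 0.82, "trm": 0.41},
    "mcts_state": {"value": 0.63},
    "consensus_score": 0.7,
    "last_agent": "hrm",
    "iteration": 2,
    "query": "Plan a multi-step route through the supply network.",
    "rag_context": "Relevant retrieved passages...",
}

features = controller.extract_features(state)
prediction = controller.predict(features)
print(prediction.agent, round(prediction.confidence, 3), prediction.probabilities)

print(controller.get_trainable_parameters())         # LoRA keeps the trainable share small
controller.save_model("models/bert_lora/adapter")     # placeholder output directory
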
src/agents/meta_controller/config_loader.py ADDED
@@ -0,0 +1,304 @@
1
+ """
2
+ Configuration loader for the Neural Meta-Controller framework.
3
+
4
+ This module provides dataclass-based configuration management for the Meta-Controller,
5
+ supporting both RNN and BERT-based neural network controllers with comprehensive
6
+ validation and serialization capabilities.
7
+ """
8
+
9
+ from dataclasses import asdict, dataclass, field
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import yaml
14
+
15
+
16
+ @dataclass
17
+ class RNNConfig:
18
+ """
19
+ Configuration for RNN-based Meta-Controller.
20
+
21
+ Attributes:
22
+ hidden_dim: Hidden dimension size for RNN layers. Default is 64.
23
+ num_layers: Number of RNN layers. Default is 1.
24
+ dropout: Dropout rate for regularization. Default is 0.1.
25
+ model_path: Optional path to a pre-trained model file. None for untrained model.
26
+ """
27
+
28
+ hidden_dim: int = 64
29
+ num_layers: int = 1
30
+ dropout: float = 0.1
31
+ model_path: str | None = None
32
+
33
+
34
+ @dataclass
35
+ class BERTConfig:
36
+ """
37
+ Configuration for BERT-based Meta-Controller with LoRA fine-tuning.
38
+
39
+ Attributes:
40
+ model_name: Name of the pre-trained BERT model from HuggingFace.
41
+ Default is "prajjwal1/bert-mini" for lightweight deployment.
42
+ use_lora: Whether to use LoRA (Low-Rank Adaptation) for efficient fine-tuning.
43
+ Default is True.
44
+ lora_r: LoRA rank parameter. Controls the rank of the low-rank matrices.
45
+ Default is 4.
46
+ lora_alpha: LoRA alpha parameter. Scaling factor for LoRA weights.
47
+ Default is 16.
48
+ lora_dropout: Dropout rate for LoRA layers. Default is 0.1.
49
+ model_path: Optional path to a trained LoRA adapter. None for base model only.
50
+ """
51
+
52
+ model_name: str = "prajjwal1/bert-mini"
53
+ use_lora: bool = True
54
+ lora_r: int = 4
55
+ lora_alpha: int = 16
56
+ lora_dropout: float = 0.1
57
+ model_path: str | None = None
58
+
59
+
60
+ @dataclass
61
+ class InferenceConfig:
62
+ """
63
+ Configuration for inference settings.
64
+
65
+ Attributes:
66
+ device: Device to use for inference ("cpu", "cuda", "cuda:0", etc.).
67
+ None for auto-detection based on available hardware.
68
+ seed: Random seed for reproducibility. Default is 42.
69
+ """
70
+
71
+ device: str | None = None
72
+ seed: int = 42
73
+
74
+
75
+ @dataclass
76
+ class MetaControllerConfig:
77
+ """
78
+ Main configuration for the Neural Meta-Controller framework.
79
+
80
+ This configuration controls the behavior of the Meta-Controller, including
81
+ which type of neural network to use (RNN or BERT), fallback behavior,
82
+ and specific model parameters.
83
+
84
+ Attributes:
85
+ enabled: Whether the neural Meta-Controller is enabled. Default is False
86
+ for backward compatibility with rule-based systems.
87
+ type: Type of neural network controller ("rnn" or "bert"). Default is "rnn".
88
+ fallback_to_rule_based: Whether to fall back to rule-based selection on errors.
89
+ Default is True for robustness.
90
+ rnn: Configuration for RNN-based controller.
91
+ bert: Configuration for BERT-based controller.
92
+ inference: Configuration for inference settings.
93
+ """
94
+
95
+ enabled: bool = False
96
+ type: str = "rnn" # "rnn" or "bert"
97
+ fallback_to_rule_based: bool = True
98
+ rnn: RNNConfig = field(default_factory=RNNConfig)
99
+ bert: BERTConfig = field(default_factory=BERTConfig)
100
+ inference: InferenceConfig = field(default_factory=InferenceConfig)
101
+
102
+
103
+ class MetaControllerConfigLoader:
104
+ """
105
+ Loader class for Meta-Controller configuration.
106
+
107
+ Provides methods for loading configuration from YAML files or dictionaries,
108
+ converting configuration to dictionaries, and validating configuration values.
109
+
110
+ Example:
111
+ >>> loader = MetaControllerConfigLoader()
112
+ >>> config = loader.load_from_yaml("config/meta_controller.yaml")
113
+ >>> print(config.type)
114
+ 'rnn'
115
+ >>> config.validate()
116
+ """
117
+
118
+ @staticmethod
119
+ def load_from_yaml(path: str) -> MetaControllerConfig:
120
+ """
121
+ Load Meta-Controller configuration from a YAML file.
122
+
123
+ Args:
124
+ path: Path to the YAML configuration file.
125
+
126
+ Returns:
127
+ MetaControllerConfig: Loaded and parsed configuration object.
128
+
129
+ Raises:
130
+ FileNotFoundError: If the specified file does not exist.
131
+ yaml.YAMLError: If the file contains invalid YAML.
132
+ KeyError: If the 'meta_controller' key is missing from the file.
133
+
134
+ Example:
135
+ >>> config = MetaControllerConfigLoader.load_from_yaml("config/meta_controller.yaml")
136
+ >>> print(config.enabled)
137
+ False
138
+ """
139
+ yaml_path = Path(path)
140
+
141
+ if not yaml_path.exists():
142
+ raise FileNotFoundError(f"Configuration file not found: {path}")
143
+
144
+ with open(yaml_path) as f:
145
+ raw_config = yaml.safe_load(f)
146
+
147
+ if "meta_controller" not in raw_config:
148
+ raise KeyError("Configuration file must contain 'meta_controller' key")
149
+
150
+ return MetaControllerConfigLoader.load_from_dict(raw_config["meta_controller"])
151
+
152
+ @staticmethod
153
+ def load_from_dict(config_dict: dict[str, Any]) -> MetaControllerConfig:
154
+ """
155
+ Load Meta-Controller configuration from a dictionary.
156
+
157
+ Args:
158
+ config_dict: Dictionary containing configuration values.
159
+
160
+ Returns:
161
+ MetaControllerConfig: Parsed configuration object with defaults
162
+ applied for missing values.
163
+
164
+ Example:
165
+ >>> config_dict = {
166
+ ... 'enabled': True,
167
+ ... 'type': 'bert',
168
+ ... 'bert': {'model_name': 'bert-base-uncased'}
169
+ ... }
170
+ >>> config = MetaControllerConfigLoader.load_from_dict(config_dict)
171
+ >>> print(config.type)
172
+ 'bert'
173
+ """
174
+ # Parse nested configurations
175
+ rnn_config = RNNConfig(**config_dict.get("rnn", {}))
176
+ bert_config = BERTConfig(**config_dict.get("bert", {}))
177
+ inference_config = InferenceConfig(**config_dict.get("inference", {}))
178
+
179
+ # Create main config with nested configs
180
+ return MetaControllerConfig(
181
+ enabled=config_dict.get("enabled", False),
182
+ type=config_dict.get("type", "rnn"),
183
+ fallback_to_rule_based=config_dict.get("fallback_to_rule_based", True),
184
+ rnn=rnn_config,
185
+ bert=bert_config,
186
+ inference=inference_config,
187
+ )
188
+
189
+ @staticmethod
190
+ def to_dict(config: MetaControllerConfig) -> dict[str, Any]:
191
+ """
192
+ Convert a MetaControllerConfig object to a dictionary.
193
+
194
+ Args:
195
+ config: MetaControllerConfig object to convert.
196
+
197
+ Returns:
198
+ Dict[str, Any]: Dictionary representation of the configuration.
199
+
200
+ Example:
201
+ >>> config = MetaControllerConfig(enabled=True, type='bert')
202
+ >>> config_dict = MetaControllerConfigLoader.to_dict(config)
203
+ >>> print(config_dict['enabled'])
204
+ True
205
+ """
206
+ return asdict(config)
207
+
208
+ @staticmethod
209
+ def validate(config: MetaControllerConfig) -> None:
210
+ """
211
+ Validate the Meta-Controller configuration.
212
+
213
+ Checks that:
214
+ - The controller type is valid ("rnn" or "bert")
215
+ - Model paths exist if specified
216
+ - Numeric parameters are within valid ranges
217
+
218
+ Args:
219
+ config: MetaControllerConfig object to validate.
220
+
221
+ Raises:
222
+ ValueError: If the configuration contains invalid values.
223
+ FileNotFoundError: If specified model paths do not exist.
224
+
225
+ Example:
226
+ >>> config = MetaControllerConfig(type='invalid')
227
+ >>> MetaControllerConfigLoader.validate(config)
228
+ ValueError: Invalid controller type 'invalid'. Must be one of: ['rnn', 'bert']
229
+ """
230
+ # Validate controller type
231
+ valid_types = ["rnn", "bert"]
232
+ if config.type not in valid_types:
233
+ raise ValueError(f"Invalid controller type '{config.type}'. Must be one of: {valid_types}")
234
+
235
+ # Validate RNN config
236
+ if config.rnn.hidden_dim <= 0:
237
+ raise ValueError(f"RNN hidden_dim must be positive, got {config.rnn.hidden_dim}")
238
+ if config.rnn.num_layers <= 0:
239
+ raise ValueError(f"RNN num_layers must be positive, got {config.rnn.num_layers}")
240
+ if not 0.0 <= config.rnn.dropout <= 1.0:
241
+ raise ValueError(f"RNN dropout must be between 0 and 1, got {config.rnn.dropout}")
242
+ if config.rnn.model_path is not None:
243
+ rnn_path = Path(config.rnn.model_path)
244
+ if not rnn_path.exists():
245
+ raise FileNotFoundError(f"RNN model path does not exist: {config.rnn.model_path}")
246
+
247
+ # Validate BERT config
248
+ if config.bert.lora_r <= 0:
249
+ raise ValueError(f"BERT lora_r must be positive, got {config.bert.lora_r}")
250
+ if config.bert.lora_alpha <= 0:
251
+ raise ValueError(f"BERT lora_alpha must be positive, got {config.bert.lora_alpha}")
252
+ if not 0.0 <= config.bert.lora_dropout <= 1.0:
253
+ raise ValueError(f"BERT lora_dropout must be between 0 and 1, got {config.bert.lora_dropout}")
254
+ if config.bert.model_path is not None:
255
+ bert_path = Path(config.bert.model_path)
256
+ if not bert_path.exists():
257
+ raise FileNotFoundError(f"BERT model path does not exist: {config.bert.model_path}")
258
+
259
+ # Validate inference config
260
+ if config.inference.device is not None:
261
+ valid_devices = ["cpu", "cuda", "mps"]
262
+ # Check if device starts with a valid prefix (e.g., "cuda:0", "cuda:1")
263
+ device_base = config.inference.device.split(":")[0]
264
+ if device_base not in valid_devices:
265
+ raise ValueError(f"Invalid device '{config.inference.device}'. Must start with one of: {valid_devices}")
266
+
267
+ if not isinstance(config.inference.seed, int) or config.inference.seed < 0:
268
+ raise ValueError(f"Inference seed must be a non-negative integer, got {config.inference.seed}")
269
+
270
+ @staticmethod
271
+ def save_to_yaml(config: MetaControllerConfig, path: str) -> None:
272
+ """
273
+ Save a MetaControllerConfig object to a YAML file.
274
+
275
+ Args:
276
+ config: MetaControllerConfig object to save.
277
+ path: Path where the YAML file will be saved.
278
+
279
+ Example:
280
+ >>> config = MetaControllerConfig(enabled=True)
281
+ >>> MetaControllerConfigLoader.save_to_yaml(config, "my_config.yaml")
282
+ """
283
+ yaml_path = Path(path)
284
+ yaml_path.parent.mkdir(parents=True, exist_ok=True)
285
+
286
+ config_dict = {"meta_controller": MetaControllerConfigLoader.to_dict(config)}
287
+
288
+ with open(yaml_path, "w") as f:
289
+ yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
290
+
291
+ @staticmethod
292
+ def get_default_config() -> MetaControllerConfig:
293
+ """
294
+ Get a default MetaControllerConfig with all default values.
295
+
296
+ Returns:
297
+ MetaControllerConfig: Configuration object with default values.
298
+
299
+ Example:
300
+ >>> config = MetaControllerConfigLoader.get_default_config()
301
+ >>> print(config.enabled)
302
+ False
303
+ """
304
+ return MetaControllerConfig()
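Illustrative usage of the loader added above (a minimal sketch, not part of this commit; the YAML path and the specific values are assumptions):

```python
# Hypothetical round-trip through MetaControllerConfigLoader; path and values are assumptions.
from src.agents.meta_controller.config_loader import MetaControllerConfigLoader

config = MetaControllerConfigLoader.load_from_dict(
    {
        "enabled": True,
        "type": "bert",
        "bert": {"model_name": "prajjwal1/bert-mini", "lora_r": 8},
    }
)
MetaControllerConfigLoader.validate(config)  # raises ValueError / FileNotFoundError on bad values

MetaControllerConfigLoader.save_to_yaml(config, "configs/meta_controller.yaml")
reloaded = MetaControllerConfigLoader.load_from_yaml("configs/meta_controller.yaml")
assert reloaded.type == "bert" and reloaded.bert.lora_r == 8
```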
src/agents/meta_controller/rnn_controller.py ADDED
@@ -0,0 +1,345 @@
1
+ """
2
+ RNN-based Meta-Controller for dynamic agent selection.
3
+
4
+ This module provides a GRU-based recurrent neural network meta-controller
5
+ that learns to select the optimal agent (HRM, TRM, or MCTS) based on
6
+ sequential patterns in the agent state features.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from src.agents.meta_controller.base import (
14
+ AbstractMetaController,
15
+ MetaControllerFeatures,
16
+ MetaControllerPrediction,
17
+ )
18
+ from src.agents.meta_controller.utils import features_to_tensor
19
+
20
+
21
+ class RNNMetaControllerModel(nn.Module):
22
+ """
23
+ GRU-based neural network model for meta-controller predictions.
24
+
25
+ This model uses a Gated Recurrent Unit (GRU) to capture sequential
26
+ patterns in agent state features and predict which agent should be
27
+ selected next.
28
+
29
+ Architecture:
30
+ - GRU layer for sequence processing
31
+ - Dropout for regularization
32
+ - Linear layer for classification
33
+
34
+ Attributes:
35
+ gru: GRU recurrent layer for processing sequences.
36
+ dropout: Dropout layer for regularization.
37
+ fc: Fully connected output layer.
38
+ hidden_dim: Dimension of the hidden state.
39
+ num_layers: Number of GRU layers.
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ input_dim: int = 10,
45
+ hidden_dim: int = 64,
46
+ num_layers: int = 1,
47
+ num_agents: int = 3,
48
+ dropout: float = 0.1,
49
+ ) -> None:
50
+ """
51
+ Initialize the RNN meta-controller model.
52
+
53
+ Args:
54
+ input_dim: Dimension of input features. Defaults to 10.
55
+ hidden_dim: Dimension of GRU hidden state. Defaults to 64.
56
+ num_layers: Number of stacked GRU layers. Defaults to 1.
57
+ num_agents: Number of agents to choose from. Defaults to 3.
58
+ dropout: Dropout probability for regularization. Defaults to 0.1.
59
+ """
60
+ super().__init__()
61
+
62
+ self.hidden_dim = hidden_dim
63
+ self.num_layers = num_layers
64
+
65
+ # GRU layer for sequence processing
66
+ self.gru = nn.GRU(
67
+ input_size=input_dim,
68
+ hidden_size=hidden_dim,
69
+ num_layers=num_layers,
70
+ batch_first=True,
71
+ dropout=dropout if num_layers > 1 else 0.0,
72
+ )
73
+
74
+ # Dropout for regularization
75
+ self.dropout = nn.Dropout(p=dropout)
76
+
77
+ # Linear output layer for classification
78
+ self.fc = nn.Linear(hidden_dim, num_agents)
79
+
80
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
81
+ """
82
+ Forward pass through the model.
83
+
84
+ Processes input features through GRU and produces agent selection logits.
85
+
86
+ Args:
87
+ x: Input tensor of shape (batch_size, features) or
88
+ (batch_size, seq_len, features).
89
+
90
+ Returns:
91
+ Logits tensor of shape (batch_size, num_agents).
92
+ Note: Returns raw logits, NOT softmax probabilities.
93
+
94
+ Example:
95
+ >>> model = RNNMetaControllerModel()
96
+ >>> x = torch.randn(4, 10) # batch of 4, 10 features
97
+ >>> logits = model(x)
98
+ >>> logits.shape
99
+ torch.Size([4, 3])
100
+ """
101
+ # Handle 2D input by adding sequence dimension
102
+ if x.dim() == 2:
103
+ # Shape: (batch_size, features) -> (batch_size, 1, features)
104
+ x = x.unsqueeze(1)
105
+
106
+ # Pass through GRU
107
+ # output shape: (batch_size, seq_len, hidden_dim)
108
+ # hidden shape: (num_layers, batch_size, hidden_dim)
109
+ output, hidden = self.gru(x)
110
+
111
+ # Take the final hidden state from the last layer
112
+ # Shape: (batch_size, hidden_dim)
113
+ final_hidden = hidden[-1] if self.num_layers > 1 else hidden.squeeze(0)
114
+
115
+ # Apply dropout
116
+ dropped = self.dropout(final_hidden)
117
+
118
+ # Apply linear layer to get logits
119
+ logits = self.fc(dropped)
120
+
121
+ return logits
122
+
123
+
124
+ class RNNMetaController(AbstractMetaController):
125
+ """
126
+ RNN-based meta-controller using GRU for agent selection.
127
+
128
+ This controller uses a recurrent neural network to learn patterns in
129
+ agent state sequences and predict the optimal agent for the current
130
+ situation. It supports both CPU and GPU execution.
131
+
132
+ Attributes:
133
+ device: PyTorch device (CPU or CUDA) for tensor operations.
134
+ hidden_dim: Dimension of GRU hidden state.
135
+ num_layers: Number of GRU layers.
136
+ dropout: Dropout probability.
137
+ model: The underlying RNNMetaControllerModel.
138
+ hidden_state: Optional hidden state for sequence tracking.
139
+
140
+ Example:
141
+ >>> controller = RNNMetaController(name="RNNController", seed=42)
142
+ >>> features = MetaControllerFeatures(
143
+ ... hrm_confidence=0.8,
144
+ ... trm_confidence=0.6,
145
+ ... mcts_value=0.75,
146
+ ... consensus_score=0.7,
147
+ ... last_agent='hrm',
148
+ ... iteration=2,
149
+ ... query_length=150,
150
+ ... has_rag_context=True
151
+ ... )
152
+ >>> prediction = controller.predict(features)
153
+ >>> prediction.agent in ['hrm', 'trm', 'mcts']
154
+ True
155
+ >>> 0.0 <= prediction.confidence <= 1.0
156
+ True
157
+ """
158
+
159
+ def __init__(
160
+ self,
161
+ name: str = "RNNMetaController",
162
+ seed: int = 42,
163
+ hidden_dim: int = 64,
164
+ num_layers: int = 1,
165
+ dropout: float = 0.1,
166
+ device: str | None = None,
167
+ ) -> None:
168
+ """
169
+ Initialize the RNN meta-controller.
170
+
171
+ Args:
172
+ name: Name identifier for this controller. Defaults to "RNNMetaController".
173
+ seed: Random seed for reproducibility. Defaults to 42.
174
+ hidden_dim: Dimension of GRU hidden state. Defaults to 64.
175
+ num_layers: Number of GRU layers. Defaults to 1.
176
+ dropout: Dropout probability. Defaults to 0.1.
177
+ device: Device to run model on ('cpu', 'cuda', 'mps', etc.).
178
+ If None, auto-detects best available device.
179
+ """
180
+ super().__init__(name=name, seed=seed)
181
+
182
+ # Set random seed for reproducibility
183
+ torch.manual_seed(seed)
184
+
185
+ # Auto-detect device if not specified
186
+ if device is None:
187
+ if torch.cuda.is_available():
188
+ self.device = torch.device("cuda")
189
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
190
+ self.device = torch.device("mps")
191
+ else:
192
+ self.device = torch.device("cpu")
193
+ else:
194
+ self.device = torch.device(device)
195
+
196
+ # Store configuration
197
+ self.hidden_dim = hidden_dim
198
+ self.num_layers = num_layers
199
+ self.dropout = dropout
200
+
201
+ # Initialize model
202
+ self.model = RNNMetaControllerModel(
203
+ input_dim=10, # Fixed based on features_to_tensor output
204
+ hidden_dim=hidden_dim,
205
+ num_layers=num_layers,
206
+ num_agents=len(self.AGENT_NAMES),
207
+ dropout=dropout,
208
+ )
209
+
210
+ # Move model to device
211
+ self.model = self.model.to(self.device)
212
+
213
+ # Set model to evaluation mode
214
+ self.model.eval()
215
+
216
+ # Initialize hidden state for sequence tracking
217
+ self.hidden_state: torch.Tensor | None = None
218
+
219
+ def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction:
220
+ """
221
+ Predict which agent should handle the current query.
222
+
223
+ Converts features to tensor format, runs through the GRU model,
224
+ and returns a prediction with confidence scores.
225
+
226
+ Args:
227
+ features: Features extracted from the current agent state.
228
+
229
+ Returns:
230
+ Prediction containing the selected agent, confidence score,
231
+ and probability distribution over all agents.
232
+
233
+ Example:
234
+ >>> controller = RNNMetaController()
235
+ >>> features = MetaControllerFeatures(
236
+ ... hrm_confidence=0.9,
237
+ ... trm_confidence=0.3,
238
+ ... mcts_value=0.5,
239
+ ... consensus_score=0.8,
240
+ ... last_agent='none',
241
+ ... iteration=0,
242
+ ... query_length=100,
243
+ ... has_rag_context=False
244
+ ... )
245
+ >>> pred = controller.predict(features)
246
+ >>> isinstance(pred.agent, str)
247
+ True
248
+ >>> isinstance(pred.confidence, float)
249
+ True
250
+ >>> len(pred.probabilities) == 3
251
+ True
252
+ """
253
+ # Convert features to tensor
254
+ feature_tensor = features_to_tensor(features)
255
+
256
+ # Add batch dimension: (10,) -> (1, 10)
257
+ feature_tensor = feature_tensor.unsqueeze(0)
258
+
259
+ # Move to device
260
+ feature_tensor = feature_tensor.to(self.device)
261
+
262
+ # Perform inference without gradient tracking
263
+ with torch.no_grad():
264
+ # Get logits from model
265
+ logits = self.model(feature_tensor)
266
+
267
+ # Apply softmax to get probabilities
268
+ probabilities = F.softmax(logits, dim=-1)
269
+
270
+ # Get predicted agent index (argmax)
271
+ predicted_idx = torch.argmax(probabilities, dim=-1).item()
272
+
273
+ # Extract confidence for selected agent
274
+ confidence = probabilities[0, predicted_idx].item()
275
+
276
+ # Create probability dictionary
277
+ prob_dict: dict[str, float] = {}
278
+ for i, agent_name in enumerate(self.AGENT_NAMES):
279
+ prob_dict[agent_name] = probabilities[0, i].item()
280
+
281
+ # Get agent name
282
+ selected_agent = self.AGENT_NAMES[predicted_idx]
283
+
284
+ return MetaControllerPrediction(
285
+ agent=selected_agent,
286
+ confidence=float(confidence),
287
+ probabilities=prob_dict,
288
+ )
289
+
290
+ def load_model(self, path: str) -> None:
291
+ """
292
+ Load a trained model from disk.
293
+
294
+ Loads the model state dictionary from the specified path and
295
+ sets the model to evaluation mode.
296
+
297
+ Args:
298
+ path: Path to the saved model file (.pt or .pth).
299
+
300
+ Raises:
301
+ FileNotFoundError: If the model file does not exist.
302
+ RuntimeError: If the state dict is incompatible with the model.
303
+
304
+ Example:
305
+ >>> controller = RNNMetaController()
306
+ >>> controller.load_model("/path/to/model.pt")
307
+ """
308
+ # Load state dict with appropriate device mapping
309
+ state_dict = torch.load(path, map_location=self.device, weights_only=True)
310
+
311
+ # Load into model
312
+ self.model.load_state_dict(state_dict)
313
+
314
+ # Ensure model is in evaluation mode
315
+ self.model.eval()
316
+
317
+ def save_model(self, path: str) -> None:
318
+ """
319
+ Save the current model to disk.
320
+
321
+ Saves the model state dictionary to the specified path.
322
+
323
+ Args:
324
+ path: Path where the model should be saved (.pt or .pth).
325
+
326
+ Example:
327
+ >>> controller = RNNMetaController()
328
+ >>> controller.save_model("/path/to/model.pt")
329
+ """
330
+ torch.save(self.model.state_dict(), path)
331
+
332
+ def reset_hidden_state(self) -> None:
333
+ """
334
+ Reset the hidden state for sequence tracking.
335
+
336
+ This method clears any accumulated hidden state, useful when
337
+ starting a new conversation or resetting the controller state.
338
+
339
+ Example:
340
+ >>> controller = RNNMetaController()
341
+ >>> controller.reset_hidden_state()
342
+ >>> controller.hidden_state is None
343
+ True
344
+ """
345
+ self.hidden_state = None
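A short usage sketch for the controller above (illustrative only; the feature values and model path are made up for the example):

```python
# Hypothetical end-to-end use of RNNMetaController; feature values and path are illustrative.
from src.agents.meta_controller.base import MetaControllerFeatures
from src.agents.meta_controller.rnn_controller import RNNMetaController

controller = RNNMetaController(seed=42, hidden_dim=64)

features = MetaControllerFeatures(
    hrm_confidence=0.8,
    trm_confidence=0.6,
    mcts_value=0.75,
    consensus_score=0.7,
    last_agent="hrm",
    iteration=2,
    query_length=150,
    has_rag_context=True,
)

prediction = controller.predict(features)
print(prediction.agent, prediction.confidence, prediction.probabilities)

# Persist the weights and reload them into a fresh controller.
controller.save_model("models/rnn_meta_controller.pt")
restored = RNNMetaController()
restored.load_model("models/rnn_meta_controller.pt")
```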
src/agents/meta_controller/utils.py ADDED
@@ -0,0 +1,201 @@
1
+ """
2
+ Utility functions for Neural Meta-Controller feature processing.
3
+
4
+ Provides functions for normalizing, encoding, and converting features
5
+ into formats suitable for different neural network architectures.
6
+ """
7
+
8
+ import torch
9
+
10
+ from src.agents.meta_controller.base import MetaControllerFeatures
11
+
12
+
13
+ def normalize_features(features: MetaControllerFeatures) -> list[float]:
14
+ """
15
+ Normalize meta-controller features to a 10-dimensional vector in range [0, 1].
16
+
17
+ The normalization strategy:
18
+ - Confidence scores (hrm, trm, mcts_value, consensus): Already 0-1, clipped
19
+ - last_agent: Encoded as 3 one-hot values (hrm=0, trm=1, mcts=2)
20
+ - iteration: Normalized to 0-1 assuming max 20 iterations
21
+ - query_length: Normalized to 0-1 assuming max 10000 characters
22
+ - has_rag_context: Binary 0 or 1
23
+
24
+ Output vector structure (10 dimensions):
25
+ [hrm_conf, trm_conf, mcts_value, consensus, last_hrm, last_trm, last_mcts,
26
+ iteration_norm, query_length_norm, has_rag]
27
+
28
+ Args:
29
+ features: MetaControllerFeatures instance to normalize.
30
+
31
+ Returns:
32
+ List of 10 floats, each normalized to range [0, 1].
33
+
34
+ Example:
35
+ >>> features = MetaControllerFeatures(
36
+ ... hrm_confidence=0.8,
37
+ ... trm_confidence=0.6,
38
+ ... mcts_value=0.75,
39
+ ... consensus_score=0.7,
40
+ ... last_agent='hrm',
41
+ ... iteration=2,
42
+ ... query_length=150,
43
+ ... has_rag_context=True
44
+ ... )
45
+ >>> normalized = normalize_features(features)
46
+ >>> len(normalized)
47
+ 10
48
+ >>> all(0.0 <= v <= 1.0 for v in normalized)
49
+ True
50
+ """
51
+ # Clip confidence scores to [0, 1]
52
+ hrm_conf = max(0.0, min(1.0, features.hrm_confidence))
53
+ trm_conf = max(0.0, min(1.0, features.trm_confidence))
54
+ mcts_val = max(0.0, min(1.0, features.mcts_value))
55
+ consensus = max(0.0, min(1.0, features.consensus_score))
56
+
57
+ # One-hot encode last_agent (3 dimensions)
58
+ last_agent_onehot = one_hot_encode_agent(features.last_agent)
59
+
60
+ # Normalize iteration (assuming max 20 iterations)
61
+ max_iterations = 20
62
+ iteration_norm = max(0.0, min(1.0, features.iteration / max_iterations))
63
+
64
+ # Normalize query length (assuming max 10000 characters)
65
+ max_query_length = 10000
66
+ query_length_norm = max(0.0, min(1.0, features.query_length / max_query_length))
67
+
68
+ # Binary for has_rag_context
69
+ has_rag = 1.0 if features.has_rag_context else 0.0
70
+
71
+ # Combine into 10-dimensional vector
72
+ return [
73
+ hrm_conf,
74
+ trm_conf,
75
+ mcts_val,
76
+ consensus,
77
+ last_agent_onehot[0], # hrm
78
+ last_agent_onehot[1], # trm
79
+ last_agent_onehot[2], # mcts
80
+ iteration_norm,
81
+ query_length_norm,
82
+ has_rag,
83
+ ]
84
+
85
+
86
+ def one_hot_encode_agent(agent: str) -> list[float]:
87
+ """
88
+ One-hot encode an agent name into a 3-dimensional vector.
89
+
90
+ Encoding:
91
+ - 'hrm' -> [1.0, 0.0, 0.0]
92
+ - 'trm' -> [0.0, 1.0, 0.0]
93
+ - 'mcts' -> [0.0, 0.0, 1.0]
94
+ - 'none' or other -> [0.0, 0.0, 0.0]
95
+
96
+ Args:
97
+ agent: Agent name string ('hrm', 'trm', 'mcts', or 'none').
98
+
99
+ Returns:
100
+ List of 3 floats representing the one-hot encoding.
101
+
102
+ Example:
103
+ >>> one_hot_encode_agent('hrm')
104
+ [1.0, 0.0, 0.0]
105
+ >>> one_hot_encode_agent('trm')
106
+ [0.0, 1.0, 0.0]
107
+ >>> one_hot_encode_agent('mcts')
108
+ [0.0, 0.0, 1.0]
109
+ >>> one_hot_encode_agent('none')
110
+ [0.0, 0.0, 0.0]
111
+ """
112
+ agent_lower = agent.lower()
113
+
114
+ if agent_lower == "hrm": # noqa: SIM116
115
+ return [1.0, 0.0, 0.0]
116
+ elif agent_lower == "trm":
117
+ return [0.0, 1.0, 0.0]
118
+ elif agent_lower == "mcts":
119
+ return [0.0, 0.0, 1.0]
120
+ else:
121
+ # 'none' or unknown agent
122
+ return [0.0, 0.0, 0.0]
123
+
124
+
125
+ def features_to_tensor(features: MetaControllerFeatures) -> torch.Tensor:
126
+ """
127
+ Convert meta-controller features to a PyTorch tensor.
128
+
129
+ Uses normalize_features internally to create a normalized 10-dimensional
130
+ tensor suitable for neural network input.
131
+
132
+ Args:
133
+ features: MetaControllerFeatures instance to convert.
134
+
135
+ Returns:
136
+ PyTorch tensor of shape (10,) with float32 dtype.
137
+
138
+ Example:
139
+ >>> features = MetaControllerFeatures(
140
+ ... hrm_confidence=0.8,
141
+ ... trm_confidence=0.6,
142
+ ... mcts_value=0.75,
143
+ ... consensus_score=0.7,
144
+ ... last_agent='hrm',
145
+ ... iteration=2,
146
+ ... query_length=150,
147
+ ... has_rag_context=True
148
+ ... )
149
+ >>> tensor = features_to_tensor(features)
150
+ >>> tensor.shape
151
+ torch.Size([10])
152
+ >>> tensor.dtype
153
+ torch.float32
154
+ """
155
+ normalized = normalize_features(features)
156
+ return torch.tensor(normalized, dtype=torch.float32)
157
+
158
+
159
+ def features_to_text(features: MetaControllerFeatures) -> str:
160
+ """
161
+ Convert meta-controller features to structured text format.
162
+
163
+ Creates a human-readable text representation suitable for text-based
164
+ models like BERT or other language models.
165
+
166
+ Args:
167
+ features: MetaControllerFeatures instance to convert.
168
+
169
+ Returns:
170
+ Structured text string describing the features.
171
+
172
+ Example:
173
+ >>> features = MetaControllerFeatures(
174
+ ... hrm_confidence=0.8,
175
+ ... trm_confidence=0.6,
176
+ ... mcts_value=0.75,
177
+ ... consensus_score=0.7,
178
+ ... last_agent='hrm',
179
+ ... iteration=2,
180
+ ... query_length=150,
181
+ ... has_rag_context=True
182
+ ... )
183
+ >>> text = features_to_text(features)
184
+ >>> 'HRM confidence: 0.800' in text
185
+ True
186
+ """
187
+ rag_status = "available" if features.has_rag_context else "not available"
188
+
189
+ text = (
190
+ f"Agent State Features:\n"
191
+ f"HRM confidence: {features.hrm_confidence:.3f}\n"
192
+ f"TRM confidence: {features.trm_confidence:.3f}\n"
193
+ f"MCTS value: {features.mcts_value:.3f}\n"
194
+ f"Consensus score: {features.consensus_score:.3f}\n"
195
+ f"Last agent used: {features.last_agent}\n"
196
+ f"Current iteration: {features.iteration}\n"
197
+ f"Query length: {features.query_length} characters\n"
198
+ f"RAG context: {rag_status}"
199
+ )
200
+
201
+ return text
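The helpers above compose directly with the GRU model; a small batching sketch (values are illustrative):

```python
# Sketch: stack several normalized feature vectors into a batch for RNNMetaControllerModel.
import torch

from src.agents.meta_controller.base import MetaControllerFeatures
from src.agents.meta_controller.rnn_controller import RNNMetaControllerModel
from src.agents.meta_controller.utils import features_to_tensor

snapshots = [
    MetaControllerFeatures(
        hrm_confidence=0.9, trm_confidence=0.3, mcts_value=0.5, consensus_score=0.8,
        last_agent="none", iteration=0, query_length=100, has_rag_context=False,
    ),
    MetaControllerFeatures(
        hrm_confidence=0.4, trm_confidence=0.7, mcts_value=0.6, consensus_score=0.5,
        last_agent="trm", iteration=3, query_length=800, has_rag_context=True,
    ),
]

batch = torch.stack([features_to_tensor(f) for f in snapshots])  # shape (2, 10)
logits = RNNMetaControllerModel()(batch)                         # shape (2, 3), raw logits
```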
src/agents/trm_agent.py ADDED
@@ -0,0 +1,395 @@
1
+ """
2
+ Tiny Recursive Model (TRM) Agent.
3
+
4
+ Implements recursive refinement with:
5
+ - Deep supervision at all recursion levels
6
+ - Convergence detection
7
+ - Memory-efficient recursion
8
+ - Iterative improvement mechanism
9
+
10
+ Based on principles from:
11
+ - "Recursive Refinement Networks"
12
+ - "Deep Supervision for Neural Networks"
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from dataclasses import dataclass
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from ..training.system_config import TRMConfig
23
+
24
+
25
+ @dataclass
26
+ class TRMOutput:
27
+ """Output from TRM recursive processing."""
28
+
29
+ final_prediction: torch.Tensor # Final refined output
30
+ intermediate_predictions: list[torch.Tensor] # Predictions at each recursion
31
+ recursion_depth: int # Actual depth used
32
+ converged: bool # Whether convergence was achieved
33
+ convergence_step: int # Step at which convergence occurred
34
+ residual_norms: list[float] # L2 norms of residuals at each step
35
+
36
+
37
+ class RecursiveBlock(nn.Module):
38
+ """
39
+ Core recursive processing block.
40
+
41
+ Applies the same transformation repeatedly, with residual connections.
42
+ """
43
+
44
+ def __init__(self, config: TRMConfig):
45
+ super().__init__()
46
+ self.config = config
47
+
48
+ # Main processing pathway
49
+ self.transform = nn.Sequential(
50
+ nn.Linear(config.latent_dim, config.hidden_dim),
51
+ nn.LayerNorm(config.hidden_dim) if config.use_layer_norm else nn.Identity(),
52
+ nn.GELU(),
53
+ nn.Dropout(config.dropout),
54
+ nn.Linear(config.hidden_dim, config.latent_dim),
55
+ nn.LayerNorm(config.latent_dim) if config.use_layer_norm else nn.Identity(),
56
+ )
57
+
58
+ # Residual scaling (learned)
59
+ self.residual_scale = nn.Parameter(torch.ones(1))
60
+
61
+ def forward(self, x: torch.Tensor, iteration: int = 0) -> torch.Tensor: # noqa: ARG002
62
+ """
63
+ Apply recursive transformation.
64
+
65
+ Args:
66
+ x: Input tensor [batch, ..., latent_dim]
67
+ iteration: Current recursion iteration (reserved for future iteration-dependent behavior)
68
+
69
+ Returns:
70
+ Refined tensor [batch, ..., latent_dim]
71
+ """
72
+ # Residual connection with learned scaling
73
+ residual = self.transform(x)
74
+ return x + self.residual_scale * residual
75
+
76
+
77
+ class DeepSupervisionHead(nn.Module):
78
+ """
79
+ Supervision head for intermediate predictions.
80
+
81
+ Enables training signal at each recursion level.
82
+ """
83
+
84
+ def __init__(self, latent_dim: int, output_dim: int):
85
+ super().__init__()
86
+ self.head = nn.Sequential(
87
+ nn.Linear(latent_dim, latent_dim // 2),
88
+ nn.ReLU(),
89
+ nn.Linear(latent_dim // 2, output_dim),
90
+ )
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ """Generate prediction from latent state."""
94
+ return self.head(x)
95
+
96
+
97
+ class TRMAgent(nn.Module):
98
+ """
99
+ Tiny Recursive Model for iterative refinement.
100
+
101
+ Features:
102
+ - Shared weights across recursions (parameter efficiency)
103
+ - Deep supervision at all levels
104
+ - Automatic convergence detection
105
+ - Residual connections for stable gradients
106
+ """
107
+
108
+ def __init__(self, config: TRMConfig, output_dim: int | None = None, device: str = "cpu"):
109
+ super().__init__()
110
+ self.config = config
111
+ self.device = device
112
+ self.output_dim = output_dim or config.latent_dim
113
+
114
+ # Initial encoding
115
+ self.encoder = nn.Sequential(
116
+ nn.Linear(config.latent_dim, config.hidden_dim),
117
+ nn.LayerNorm(config.hidden_dim) if config.use_layer_norm else nn.Identity(),
118
+ nn.GELU(),
119
+ nn.Linear(config.hidden_dim, config.latent_dim),
120
+ nn.LayerNorm(config.latent_dim) if config.use_layer_norm else nn.Identity(),
121
+ )
122
+
123
+ # Shared recursive block
124
+ self.recursive_block = RecursiveBlock(config)
125
+
126
+ # Deep supervision heads (one per recursion level)
127
+ if config.deep_supervision:
128
+ self.supervision_heads = nn.ModuleList(
129
+ [DeepSupervisionHead(config.latent_dim, self.output_dim) for _ in range(config.num_recursions)]
130
+ )
131
+ else:
132
+ # Single output head
133
+ self.output_head = DeepSupervisionHead(config.latent_dim, self.output_dim)
134
+
135
+ self.to(device)
136
+
137
+ def forward(
138
+ self,
139
+ x: torch.Tensor,
140
+ num_recursions: int | None = None,
141
+ check_convergence: bool = True,
142
+ ) -> TRMOutput:
143
+ """
144
+ Process input through recursive refinement.
145
+
146
+ Args:
147
+ x: Input tensor [batch, ..., latent_dim]
148
+ num_recursions: Number of recursions (defaults to config)
149
+ check_convergence: Whether to check for early convergence
150
+
151
+ Returns:
152
+ TRMOutput with final and intermediate predictions
153
+ """
154
+ num_recursions = num_recursions or self.config.num_recursions
155
+
156
+ # Initial encoding
157
+ latent = self.encoder(x)
158
+ previous_latent = latent.clone()
159
+
160
+ # Tracking
161
+ intermediate_predictions = []
162
+ residual_norms = []
163
+ converged = False
164
+ convergence_step = num_recursions
165
+
166
+ # Recursive refinement
167
+ for i in range(num_recursions):
168
+ # Apply recursive transformation
169
+ latent = self.recursive_block(latent, iteration=i)
170
+
171
+ # Generate intermediate prediction
172
+ if self.config.deep_supervision:
+ # Reuse the last supervision head if the requested depth exceeds the number of heads
+ head_idx = min(i, len(self.supervision_heads) - 1)
+ pred = self.supervision_heads[head_idx](latent)
+ else:
+ pred = self.output_head(latent)
176
+
177
+ intermediate_predictions.append(pred)
178
+
179
+ # Check convergence
180
+ if check_convergence and i >= self.config.min_recursions:
181
+ residual = latent - previous_latent
182
+ residual_norm = torch.norm(residual, p=2, dim=-1).mean().item()
183
+ residual_norms.append(residual_norm)
184
+
185
+ if residual_norm < self.config.convergence_threshold:
186
+ converged = True
187
+ convergence_step = i + 1
188
+ break
189
+
190
+ previous_latent = latent.clone()
191
+
192
+ # Final prediction
193
+ final_pred = intermediate_predictions[-1]
194
+
195
+ return TRMOutput(
196
+ final_prediction=final_pred,
197
+ intermediate_predictions=intermediate_predictions,
198
+ recursion_depth=len(intermediate_predictions),
199
+ converged=converged,
200
+ convergence_step=convergence_step,
201
+ residual_norms=residual_norms,
202
+ )
203
+
204
+ async def refine_solution(
205
+ self,
206
+ initial_prediction: torch.Tensor,
207
+ num_recursions: int | None = None,
208
+ convergence_threshold: float | None = None,
209
+ ) -> tuple[torch.Tensor, dict]:
210
+ """
211
+ Refine an initial prediction through recursive processing.
212
+
213
+ Args:
214
+ initial_prediction: Initial solution [batch, ..., latent_dim]
215
+ num_recursions: Maximum recursions (optional)
216
+ convergence_threshold: Convergence threshold (optional)
217
+
218
+ Returns:
219
+ refined_solution: Final refined prediction
220
+ info: Dictionary with refinement metadata
221
+ """
222
+ # Temporarily override convergence threshold if provided
223
+ original_threshold = self.config.convergence_threshold
224
+ if convergence_threshold is not None:
225
+ self.config.convergence_threshold = convergence_threshold
226
+
227
+ # Process
228
+ output = self.forward(
229
+ initial_prediction,
230
+ num_recursions=num_recursions,
231
+ check_convergence=True,
232
+ )
233
+
234
+ # Restore original threshold
235
+ self.config.convergence_threshold = original_threshold
236
+
237
+ info = {
238
+ "converged": output.converged,
239
+ "convergence_step": output.convergence_step,
240
+ "total_recursions": output.recursion_depth,
241
+ "final_residual": output.residual_norms[-1] if output.residual_norms else None,
242
+ "refinement_path": output.residual_norms,
243
+ }
244
+
245
+ return output.final_prediction, info
246
+
247
+ def get_parameter_count(self) -> int:
248
+ """Return total number of trainable parameters."""
249
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
250
+
251
+
252
+ class TRMLoss(nn.Module):
253
+ """
254
+ Deep supervision loss for TRM.
255
+
256
+ Applies weighted supervision at all recursion levels,
257
+ with exponential decay for deeper levels.
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ task_loss_fn: nn.Module,
263
+ supervision_weight_decay: float = 0.5,
264
+ final_weight: float = 1.0,
265
+ ):
266
+ """
267
+ Initialize TRM loss.
268
+
269
+ Args:
270
+ task_loss_fn: Base loss function (e.g., MSE, CrossEntropy)
271
+ supervision_weight_decay: Decay factor for intermediate losses
272
+ final_weight: Weight for final prediction loss
273
+ """
274
+ super().__init__()
275
+ self.task_loss_fn = task_loss_fn
276
+ self.supervision_weight_decay = supervision_weight_decay
277
+ self.final_weight = final_weight
278
+
279
+ def forward(self, trm_output: TRMOutput, targets: torch.Tensor) -> tuple[torch.Tensor, dict]:
280
+ """
281
+ Compute deep supervision loss.
282
+
283
+ Args:
284
+ trm_output: Output from TRM forward pass
285
+ targets: Ground truth targets
286
+
287
+ Returns:
288
+ total_loss: Combined loss
289
+ loss_dict: Dictionary of loss components
290
+ """
291
+ # Final prediction loss (highest weight)
292
+ final_loss = self.task_loss_fn(trm_output.final_prediction, targets)
293
+ total_loss = self.final_weight * final_loss
294
+
295
+ # Intermediate supervision losses
296
+ intermediate_losses = []
297
+ num_intermediate = len(trm_output.intermediate_predictions) - 1
298
+
299
+ for i, pred in enumerate(trm_output.intermediate_predictions[:-1]):
300
+ # Exponential decay: earlier predictions get lower weight
301
+ weight = self.supervision_weight_decay ** (num_intermediate - i)
302
+ loss = self.task_loss_fn(pred, targets)
303
+ intermediate_losses.append(loss.item())
304
+ total_loss = total_loss + weight * loss
305
+
306
+ loss_dict = {
307
+ "total": total_loss.item(),
308
+ "final": final_loss.item(),
309
+ "intermediate_mean": (sum(intermediate_losses) / len(intermediate_losses) if intermediate_losses else 0.0),
310
+ "recursion_depth": trm_output.recursion_depth,
311
+ "converged": trm_output.converged,
312
+ "convergence_step": trm_output.convergence_step,
313
+ }
314
+
315
+ return total_loss, loss_dict
316
+
317
+
318
+ def create_trm_agent(config: TRMConfig, output_dim: int | None = None, device: str = "cpu") -> TRMAgent:
319
+ """
320
+ Factory function to create and initialize TRM agent.
321
+
322
+ Args:
323
+ config: TRM configuration
324
+ output_dim: Output dimension (defaults to latent_dim)
325
+ device: Device to place model on
326
+
327
+ Returns:
328
+ Initialized TRMAgent
329
+ """
330
+ agent = TRMAgent(config, output_dim, device)
331
+
332
+ # Initialize weights with Xavier/He initialization
333
+ def init_weights(m):
334
+ if isinstance(m, nn.Linear):
335
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
336
+ if m.bias is not None:
337
+ nn.init.zeros_(m.bias)
338
+
339
+ agent.apply(init_weights)
340
+
341
+ return agent
342
+
343
+
344
+ # Utility functions for integration
345
+ class TRMRefinementWrapper:
346
+ """
347
+ Wrapper for using TRM as a refinement step in pipelines.
348
+
349
+ Provides a clean interface for integrating TRM into larger systems.
350
+ """
351
+
352
+ def __init__(self, trm_agent: TRMAgent, device: str = "cpu"):
353
+ self.trm_agent = trm_agent
354
+ self.device = device
355
+ self.trm_agent.eval()
356
+
357
+ @torch.no_grad()
358
+ async def refine(
359
+ self,
360
+ predictions: torch.Tensor,
361
+ num_iterations: int = 10,
362
+ return_path: bool = False,
363
+ ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
364
+ """
365
+ Refine predictions using TRM.
366
+
367
+ Args:
368
+ predictions: Initial predictions to refine
369
+ num_iterations: Number of refinement iterations
370
+ return_path: Whether to return intermediate predictions
371
+
372
+ Returns:
373
+ refined_predictions or (refined_predictions, refinement_path)
374
+ """
375
+ # Ensure predictions are on correct device
376
+ predictions = predictions.to(self.device)
377
+
378
+ # Run TRM
379
+ output = self.trm_agent(predictions, num_recursions=num_iterations, check_convergence=True)
380
+
381
+ if return_path:
382
+ return output.final_prediction, output.intermediate_predictions
383
+ return output.final_prediction
384
+
385
+ def get_refinement_stats(self, predictions: torch.Tensor) -> dict:
386
+ """Get statistics about the refinement process."""
387
+ with torch.no_grad():
388
+ output = self.trm_agent(predictions, check_convergence=True)
389
+
390
+ return {
391
+ "converged": output.converged,
392
+ "steps_to_convergence": output.convergence_step,
393
+ "final_residual": (output.residual_norms[-1] if output.residual_norms else None),
394
+ "total_refinement_iterations": output.recursion_depth,
395
+ }
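A hedged sketch of a single deep-supervision training step with the classes above. TRMConfig lives in src/training/system_config and is not part of this diff, so the constructor arguments below are assumptions inferred from how TRMAgent reads the config:

```python
# Assumed TRMConfig fields (latent_dim, hidden_dim, num_recursions, ...); verify against system_config.
import torch
import torch.nn as nn

from src.agents.trm_agent import TRMLoss, create_trm_agent
from src.training.system_config import TRMConfig

config = TRMConfig(
    latent_dim=128,
    hidden_dim=256,
    num_recursions=6,
    min_recursions=2,
    convergence_threshold=1e-3,
    dropout=0.1,
    use_layer_norm=True,
    deep_supervision=True,
)

agent = create_trm_agent(config, output_dim=128, device="cpu")
criterion = TRMLoss(task_loss_fn=nn.MSELoss(), supervision_weight_decay=0.5)
optimizer = torch.optim.AdamW(agent.parameters(), lr=1e-3)

x = torch.randn(8, 128)        # batch of latent inputs
targets = torch.randn(8, 128)  # regression targets

optimizer.zero_grad()
output = agent(x, check_convergence=False)    # keep every recursion level during training
loss, loss_dict = criterion(output, targets)  # weighted loss across all levels
loss.backward()
optimizer.step()
print(loss_dict["final"], loss_dict["recursion_depth"])
```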
src/api/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ """
2
+ API module for LangGraph Multi-Agent MCTS Framework.
3
+
4
+ Provides:
5
+ - Authentication and authorization
6
+ - Rate limiting
7
+ - Error handling
8
+ - REST API endpoints
9
+ """
10
+
11
+ from src.api.exceptions import (
12
+ AuthenticationError,
13
+ AuthorizationError,
14
+ ConfigurationError,
15
+ FrameworkError,
16
+ LLMError,
17
+ MCTSError,
18
+ RAGError,
19
+ RateLimitError,
20
+ TimeoutError,
21
+ ValidationError,
22
+ )
23
+
24
+ __all__ = [
25
+ "FrameworkError",
26
+ "ValidationError",
27
+ "AuthenticationError",
28
+ "AuthorizationError",
29
+ "RateLimitError",
30
+ "LLMError",
31
+ "MCTSError",
32
+ "RAGError",
33
+ "TimeoutError",
34
+ "ConfigurationError",
35
+ ]
src/api/auth.py ADDED
@@ -0,0 +1,439 @@
1
+ """
2
+ Authentication and authorization layer for LangGraph Multi-Agent MCTS Framework.
3
+
4
+ Provides:
5
+ - API key authentication with secure hashing
6
+ - JWT token support (optional)
7
+ - Rate limiting per client
8
+ - Role-based access control
9
+ """
10
+
11
+ import hashlib
12
+ import secrets
13
+ import time
14
+ from collections import defaultdict
15
+ from dataclasses import dataclass, field
16
+ from datetime import datetime, timedelta
17
+
18
+ from src.api.exceptions import (
19
+ AuthenticationError,
20
+ AuthorizationError,
21
+ RateLimitError,
22
+ )
23
+
24
+
25
+ @dataclass
26
+ class ClientInfo:
27
+ """Information about an authenticated client."""
28
+
29
+ client_id: str
30
+ roles: set[str] = field(default_factory=lambda: {"user"})
31
+ created_at: datetime = field(default_factory=datetime.utcnow)
32
+ last_access: datetime = field(default_factory=datetime.utcnow)
33
+ request_count: int = 0
34
+
35
+
36
+ @dataclass
37
+ class RateLimitConfig:
38
+ """Rate limiting configuration."""
39
+
40
+ requests_per_minute: int = 60
41
+ requests_per_hour: int = 1000
42
+ requests_per_day: int = 10000
43
+ burst_limit: int = 100 # Max requests in 1 second
44
+
45
+
46
+ class APIKeyAuthenticator:
47
+ """
48
+ API key-based authentication with secure hashing.
49
+
50
+ Keys are stored as SHA-256 hashes to prevent exposure.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ valid_keys: list[str] | None = None,
56
+ rate_limit_config: RateLimitConfig | None = None,
57
+ ):
58
+ """
59
+ Initialize authenticator.
60
+
61
+ Args:
62
+ valid_keys: List of valid API keys (will be hashed)
63
+ rate_limit_config: Rate limiting configuration
64
+ """
65
+ self._key_to_client: dict[str, ClientInfo] = {}
66
+ self._rate_limits: dict[str, list[float]] = defaultdict(list)
67
+ self.rate_limit_config = rate_limit_config or RateLimitConfig()
68
+
69
+ # Hash and store initial keys
70
+ if valid_keys:
71
+ for i, key in enumerate(valid_keys):
72
+ client_id = f"client_{i}"
73
+ self._add_key(key, client_id)
74
+
75
+ def _hash_key(self, api_key: str) -> str:
76
+ """
77
+ Securely hash an API key.
78
+
79
+ Uses SHA-256 with consistent encoding.
80
+ """
81
+ return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
82
+
83
+ def _add_key(self, api_key: str, client_id: str, roles: set[str] | None = None) -> None:
84
+ """
85
+ Add a new API key.
86
+
87
+ Args:
88
+ api_key: Raw API key
89
+ client_id: Client identifier
90
+ roles: Set of roles (defaults to {"user"})
91
+ """
92
+ key_hash = self._hash_key(api_key)
93
+ self._key_to_client[key_hash] = ClientInfo(
94
+ client_id=client_id,
95
+ roles=roles or {"user"},
96
+ )
97
+
98
+ def authenticate(self, api_key: str | None) -> ClientInfo:
99
+ """
100
+ Authenticate an API key.
101
+
102
+ Args:
103
+ api_key: API key to validate
104
+
105
+ Returns:
106
+ ClientInfo for the authenticated client
107
+
108
+ Raises:
109
+ AuthenticationError: If authentication fails
110
+ """
111
+ if not api_key:
112
+ raise AuthenticationError(
113
+ user_message="API key is required",
114
+ internal_details="No API key provided in request",
115
+ )
116
+
117
+ # Constant-time comparison to prevent timing attacks
118
+ key_hash = self._hash_key(api_key)
119
+
120
+ if key_hash not in self._key_to_client:
121
+ raise AuthenticationError(
122
+ user_message="Invalid API key",
123
+ internal_details=f"API key hash not found: {key_hash[:16]}...",
124
+ )
125
+
126
+ client_info = self._key_to_client[key_hash]
127
+ client_info.last_access = datetime.utcnow()
128
+ client_info.request_count += 1
129
+
130
+ # Check rate limits
131
+ self._check_rate_limit(client_info.client_id)
132
+
133
+ return client_info
134
+
135
+ def _check_rate_limit(self, client_id: str) -> None:
136
+ """
137
+ Check if client has exceeded rate limits.
138
+
139
+ Args:
140
+ client_id: Client identifier
141
+
142
+ Raises:
143
+ RateLimitError: If rate limit exceeded
144
+ """
145
+ now = time.time()
146
+ request_times = self._rate_limits[client_id]
147
+
148
+ # Clean old entries
149
+ one_day_ago = now - 86400
150
+ request_times = [t for t in request_times if t > one_day_ago]
151
+ self._rate_limits[client_id] = request_times
152
+
153
+ # Check burst limit (1 second window)
154
+ one_second_ago = now - 1
155
+ burst_count = sum(1 for t in request_times if t > one_second_ago)
156
+ if burst_count >= self.rate_limit_config.burst_limit:
157
+ raise RateLimitError(
158
+ user_message="Too many requests. Please slow down.",
159
+ internal_details=f"Client {client_id} exceeded burst limit: {burst_count}/{self.rate_limit_config.burst_limit}",
160
+ retry_after_seconds=1,
161
+ )
162
+
163
+ # Check per-minute limit
164
+ one_minute_ago = now - 60
165
+ minute_count = sum(1 for t in request_times if t > one_minute_ago)
166
+ if minute_count >= self.rate_limit_config.requests_per_minute:
167
+ raise RateLimitError(
168
+ user_message="Rate limit exceeded. Please wait a minute.",
169
+ internal_details=f"Client {client_id} exceeded minute limit: {minute_count}/{self.rate_limit_config.requests_per_minute}",
170
+ retry_after_seconds=60,
171
+ )
172
+
173
+ # Check per-hour limit
174
+ one_hour_ago = now - 3600
175
+ hour_count = sum(1 for t in request_times if t > one_hour_ago)
176
+ if hour_count >= self.rate_limit_config.requests_per_hour:
177
+ raise RateLimitError(
178
+ user_message="Hourly rate limit exceeded. Please try again later.",
179
+ internal_details=f"Client {client_id} exceeded hour limit: {hour_count}/{self.rate_limit_config.requests_per_hour}",
180
+ retry_after_seconds=3600,
181
+ )
182
+
183
+ # Check per-day limit
184
+ day_count = len(request_times)
185
+ if day_count >= self.rate_limit_config.requests_per_day:
186
+ raise RateLimitError(
187
+ user_message="Daily rate limit exceeded. Please try again tomorrow.",
188
+ internal_details=f"Client {client_id} exceeded day limit: {day_count}/{self.rate_limit_config.requests_per_day}",
189
+ retry_after_seconds=86400,
190
+ )
191
+
192
+ # Record this request
193
+ request_times.append(now)
194
+
195
+ def require_auth(self, api_key: str | None) -> ClientInfo:
196
+ """
197
+ Require authentication for a request.
198
+
199
+ Convenience method that raises on failure.
200
+
201
+ Args:
202
+ api_key: API key to validate
203
+
204
+ Returns:
205
+ ClientInfo for authenticated client
206
+
207
+ Raises:
208
+ AuthenticationError: If authentication fails
209
+ """
210
+ return self.authenticate(api_key)
211
+
212
+ def require_role(self, client_info: ClientInfo, required_role: str) -> None:
213
+ """
214
+ Require a specific role for an operation.
215
+
216
+ Args:
217
+ client_info: Authenticated client info
218
+ required_role: Role that is required
219
+
220
+ Raises:
221
+ AuthorizationError: If client doesn't have required role
222
+ """
223
+ if required_role not in client_info.roles:
224
+ raise AuthorizationError(
225
+ user_message="You do not have permission for this operation",
226
+ internal_details=f"Client {client_info.client_id} missing role: {required_role}",
227
+ required_permission=required_role,
228
+ )
229
+
230
+ def generate_api_key(self) -> str:
231
+ """
232
+ Generate a secure random API key.
233
+
234
+ Returns:
235
+ New API key (32 bytes hex = 64 characters)
236
+ """
237
+ return secrets.token_hex(32)
238
+
239
+ def revoke_key(self, api_key: str) -> bool:
240
+ """
241
+ Revoke an API key.
242
+
243
+ Args:
244
+ api_key: Key to revoke
245
+
246
+ Returns:
247
+ True if key was revoked, False if not found
248
+ """
249
+ key_hash = self._hash_key(api_key)
250
+ if key_hash in self._key_to_client:
251
+ del self._key_to_client[key_hash]
252
+ return True
253
+ return False
254
+
255
+ def add_client(
256
+ self,
257
+ client_id: str,
258
+ roles: set[str] | None = None,
259
+ ) -> str:
260
+ """
261
+ Add a new client and generate their API key.
262
+
263
+ Args:
264
+ client_id: Unique client identifier
265
+ roles: Set of roles for the client
266
+
267
+ Returns:
268
+ Generated API key (save this securely!)
269
+ """
270
+ api_key = self.generate_api_key()
271
+ self._add_key(api_key, client_id, roles)
272
+ return api_key
273
+
274
+ def get_client_stats(self, client_id: str) -> dict:
275
+ """
276
+ Get statistics for a client.
277
+
278
+ Args:
279
+ client_id: Client identifier
280
+
281
+ Returns:
282
+ Dictionary with client statistics
283
+ """
284
+ now = time.time()
285
+ request_times = self._rate_limits.get(client_id, [])
286
+
287
+ return {
288
+ "total_requests_today": len([t for t in request_times if t > now - 86400]),
289
+ "requests_last_hour": len([t for t in request_times if t > now - 3600]),
290
+ "requests_last_minute": len([t for t in request_times if t > now - 60]),
291
+ }
292
+
293
+
294
+ class JWTAuthenticator:
295
+ """
296
+ JWT token-based authentication.
297
+
298
+ Note: Requires PyJWT library for full functionality.
299
+ This is a placeholder for JWT support.
300
+ """
301
+
302
+ def __init__(self, secret_key: str, algorithm: str = "HS256"):
303
+ """
304
+ Initialize JWT authenticator.
305
+
306
+ Args:
307
+ secret_key: Secret key for signing tokens
308
+ algorithm: JWT signing algorithm
309
+ """
310
+ self.secret_key = secret_key
311
+ self.algorithm = algorithm
312
+ self._token_blacklist: set[str] = set()
313
+
314
+ def create_token(
315
+ self,
316
+ client_id: str,
317
+ roles: set[str],
318
+ expires_in_hours: int = 24,
319
+ ) -> str:
320
+ """
321
+ Create a JWT token.
322
+
323
+ Args:
324
+ client_id: Client identifier
325
+ roles: Client roles
326
+ expires_in_hours: Token validity period
327
+
328
+ Returns:
329
+ JWT token string
330
+ """
331
+ try:
332
+ import jwt
333
+ except ImportError:
334
+ raise ImportError("PyJWT library required for JWT authentication. Install with: pip install PyJWT")
335
+
336
+ now = datetime.utcnow()
337
+ payload = {
338
+ "sub": client_id,
339
+ "roles": list(roles),
340
+ "iat": now,
341
+ "exp": now + timedelta(hours=expires_in_hours),
342
+ "jti": secrets.token_hex(16), # Unique token ID
343
+ }
344
+
345
+ return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
346
+
347
+ def verify_token(self, token: str) -> ClientInfo:
348
+ """
349
+ Verify a JWT token.
350
+
351
+ Args:
352
+ token: JWT token string
353
+
354
+ Returns:
355
+ ClientInfo from token claims
356
+
357
+ Raises:
358
+ AuthenticationError: If token is invalid
359
+ """
360
+ try:
361
+ import jwt
362
+ except ImportError:
363
+ raise ImportError("PyJWT library required for JWT authentication")
364
+
365
+ if token in self._token_blacklist:
366
+ raise AuthenticationError(
367
+ user_message="Token has been revoked",
368
+ internal_details="Token found in blacklist",
369
+ )
370
+
371
+ try:
372
+ payload = jwt.decode(
373
+ token,
374
+ self.secret_key,
375
+ algorithms=[self.algorithm],
376
+ )
377
+
378
+ return ClientInfo(
379
+ client_id=payload["sub"],
380
+ roles=set(payload.get("roles", ["user"])),
381
+ )
382
+ except jwt.ExpiredSignatureError:
383
+ raise AuthenticationError(
384
+ user_message="Token has expired",
385
+ internal_details="JWT signature expired",
386
+ )
387
+ except jwt.InvalidTokenError as e:
388
+ raise AuthenticationError(
389
+ user_message="Invalid token",
390
+ internal_details=f"JWT validation failed: {str(e)}",
391
+ )
392
+
393
+ def revoke_token(self, token: str) -> None:
394
+ """
395
+ Revoke a JWT token by adding to blacklist.
396
+
397
+ Args:
398
+ token: Token to revoke
399
+ """
400
+ self._token_blacklist.add(token)
401
+
402
+
403
+ # Default authenticator instance
404
+ _default_authenticator: APIKeyAuthenticator | None = None
405
+
406
+
407
+ def get_authenticator() -> APIKeyAuthenticator:
408
+ """
409
+ Get or create the default authenticator instance.
410
+
411
+ Returns:
412
+ APIKeyAuthenticator instance
413
+ """
414
+ global _default_authenticator
415
+ if _default_authenticator is None:
416
+ _default_authenticator = APIKeyAuthenticator()
417
+ return _default_authenticator
418
+
419
+
420
+ def set_authenticator(authenticator: APIKeyAuthenticator) -> None:
421
+ """
422
+ Set the default authenticator instance.
423
+
424
+ Args:
425
+ authenticator: Authenticator to use
426
+ """
427
+ global _default_authenticator
428
+ _default_authenticator = authenticator
429
+
430
+
431
+ # Exports
432
+ __all__ = [
433
+ "APIKeyAuthenticator",
434
+ "JWTAuthenticator",
435
+ "ClientInfo",
436
+ "RateLimitConfig",
437
+ "get_authenticator",
438
+ "set_authenticator",
439
+ ]
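A minimal wiring sketch for the authenticator above (illustrative; the client name and limits are made up):

```python
# Hypothetical setup of APIKeyAuthenticator with custom rate limits.
from src.api.auth import APIKeyAuthenticator, RateLimitConfig, set_authenticator
from src.api.exceptions import AuthenticationError, RateLimitError

auth = APIKeyAuthenticator(rate_limit_config=RateLimitConfig(requests_per_minute=30))
api_key = auth.add_client("demo-client", roles={"user", "admin"})  # store the returned key securely
set_authenticator(auth)

try:
    client = auth.authenticate(api_key)   # also enforces the rate limits
    auth.require_role(client, "admin")
except RateLimitError as exc:
    print("throttled:", exc.user_message)
except AuthenticationError as exc:
    print("rejected:", exc.user_message)
```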
src/api/exceptions.py ADDED
@@ -0,0 +1,299 @@
1
+ """
2
+ Custom exception hierarchy for LangGraph Multi-Agent MCTS Framework.
3
+
4
+ Provides:
5
+ - Sanitized error messages for production
6
+ - Structured error information for logging
7
+ - Clear separation between user-facing and internal errors
8
+ """
9
+
10
+ import re
11
+ from datetime import datetime
12
+ from typing import Any
13
+
14
+
15
+ class FrameworkError(Exception):
16
+ """
17
+ Base exception for all framework errors.
18
+
19
+ Provides sanitized user-facing messages while preserving
20
+ internal details for logging.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ user_message: str,
26
+ internal_details: str | None = None,
27
+ error_code: str | None = None,
28
+ context: dict[str, Any] | None = None,
29
+ ):
30
+ """
31
+ Initialize framework error.
32
+
33
+ Args:
34
+ user_message: Safe message to show to users
35
+ internal_details: Detailed information for logs (may contain sensitive data)
36
+ error_code: Machine-readable error code
37
+ context: Additional context for debugging
38
+ """
39
+ self.user_message = user_message
40
+ self.internal_details = internal_details or user_message
41
+ self.error_code = error_code or self.__class__.__name__.upper()
42
+ self.context = context or {}
43
+ self.timestamp = datetime.utcnow()
44
+
45
+ super().__init__(user_message)
46
+
47
+ def sanitize_details(self) -> str:
48
+ """
49
+ Remove sensitive information from internal details.
50
+
51
+ Sanitizes:
52
+ - File paths
53
+ - API keys
54
+ - Passwords
55
+ - Connection strings
56
+ - IP addresses
57
+ """
58
+ sanitized = self.internal_details
59
+
60
+ # Remove file paths (Unix and Windows)
61
+ sanitized = re.sub(r"/[\w/.-]+", "/***", sanitized)
62
+ sanitized = re.sub(r"[A-Za-z]:\\[\w\\.-]+", "C:\\***", sanitized)
63
+
64
+ # Remove API keys and secrets
65
+ sanitized = re.sub(
66
+ r"(api[_-]?key|secret|password|token|credential)[\s=:]+[\S]+", r"\1=***", sanitized, flags=re.IGNORECASE
67
+ )
68
+
69
+ # Remove connection strings
70
+ sanitized = re.sub(r"(mongodb|postgresql|mysql|redis)://[^\s]+", r"\1://***", sanitized, flags=re.IGNORECASE)
71
+
72
+ # Remove IP addresses
73
+ sanitized = re.sub(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b", "***.***.***.***", sanitized)
74
+
75
+ # Remove email addresses
76
+ sanitized = re.sub(r"\b[\w.-]+@[\w.-]+\.\w+\b", "***@***", sanitized)
77
+
78
+ return sanitized
79
+
80
+ def to_log_dict(self) -> dict[str, Any]:
81
+ """
82
+ Convert exception to dictionary for structured logging.
83
+
84
+ Returns sanitized version safe for logs.
85
+ """
86
+ return {
87
+ "error_type": self.__class__.__name__,
88
+ "error_code": self.error_code,
89
+ "user_message": self.user_message,
90
+ "sanitized_details": self.sanitize_details(),
91
+ "timestamp": self.timestamp.isoformat(),
92
+ "context": {k: str(v) for k, v in self.context.items()},
93
+ }
94
+
95
+ def to_user_response(self) -> dict[str, Any]:
96
+ """
97
+ Convert exception to safe user-facing response.
98
+ """
99
+ return {
100
+ "error": True,
101
+ "error_code": self.error_code,
102
+ "message": self.user_message,
103
+ "timestamp": self.timestamp.isoformat(),
104
+ }
105
+
106
+
107
+ class ValidationError(FrameworkError):
108
+ """Raised when input validation fails."""
109
+
110
+ def __init__(
111
+ self,
112
+ user_message: str = "Invalid input provided",
113
+ internal_details: str | None = None,
114
+ field_name: str | None = None,
115
+ **kwargs,
116
+ ):
117
+ context = kwargs.pop("context", {})
118
+ if field_name:
119
+ context["field_name"] = field_name
120
+ super().__init__(
121
+ user_message=user_message,
122
+ internal_details=internal_details,
123
+ error_code="VALIDATION_ERROR",
124
+ context=context,
125
+ **kwargs,
126
+ )
127
+ self.field_name = field_name
128
+
129
+
130
+ class AuthenticationError(FrameworkError):
131
+ """Raised when authentication fails."""
132
+
133
+ def __init__(self, user_message: str = "Authentication failed", internal_details: str | None = None, **kwargs):
134
+ super().__init__(
135
+ user_message=user_message, internal_details=internal_details, error_code="AUTH_ERROR", **kwargs
136
+ )
137
+
138
+
139
+ class AuthorizationError(FrameworkError):
140
+ """Raised when authorization fails."""
141
+
142
+ def __init__(
143
+ self,
144
+ user_message: str = "Access denied",
145
+ internal_details: str | None = None,
146
+ required_permission: str | None = None,
147
+ **kwargs,
148
+ ):
149
+ context = kwargs.pop("context", {})
150
+ if required_permission:
151
+ context["required_permission"] = required_permission
152
+ super().__init__(
153
+ user_message=user_message,
154
+ internal_details=internal_details,
155
+ error_code="AUTHZ_ERROR",
156
+ context=context,
157
+ **kwargs,
158
+ )
159
+
160
+
161
+ class RateLimitError(FrameworkError):
162
+ """Raised when rate limit is exceeded."""
163
+
164
+ def __init__(
165
+ self,
166
+ user_message: str = "Rate limit exceeded. Please try again later.",
167
+ internal_details: str | None = None,
168
+ retry_after_seconds: int | None = None,
169
+ **kwargs,
170
+ ):
171
+ context = kwargs.pop("context", {})
172
+ if retry_after_seconds:
173
+ context["retry_after_seconds"] = retry_after_seconds
174
+ super().__init__(
175
+ user_message=user_message,
176
+ internal_details=internal_details,
177
+ error_code="RATE_LIMIT",
178
+ context=context,
179
+ **kwargs,
180
+ )
181
+ self.retry_after_seconds = retry_after_seconds
182
+
183
+
184
+ class LLMError(FrameworkError):
185
+ """Raised when LLM operations fail."""
186
+
187
+ def __init__(
188
+ self,
189
+ user_message: str = "Language model service temporarily unavailable",
190
+ internal_details: str | None = None,
191
+ provider: str | None = None,
192
+ **kwargs,
193
+ ):
194
+ context = kwargs.pop("context", {})
195
+ if provider:
196
+ context["provider"] = provider
197
+ super().__init__(
198
+ user_message=user_message,
199
+ internal_details=internal_details,
200
+ error_code="LLM_ERROR",
201
+ context=context,
202
+ **kwargs,
203
+ )
204
+
205
+
206
+ class MCTSError(FrameworkError):
207
+ """Raised when MCTS simulation fails."""
208
+
209
+ def __init__(
210
+ self,
211
+ user_message: str = "Tactical simulation failed",
212
+ internal_details: str | None = None,
213
+ iteration: int | None = None,
214
+ **kwargs,
215
+ ):
216
+ context = kwargs.pop("context", {})
217
+ if iteration is not None:
218
+ context["iteration"] = iteration
219
+ super().__init__(
220
+ user_message=user_message,
221
+ internal_details=internal_details,
222
+ error_code="MCTS_ERROR",
223
+ context=context,
224
+ **kwargs,
225
+ )
226
+
227
+
228
+ class RAGError(FrameworkError):
229
+ """Raised when RAG retrieval fails."""
230
+
231
+ def __init__(self, user_message: str = "Context retrieval failed", internal_details: str | None = None, **kwargs):
232
+ super().__init__(user_message=user_message, internal_details=internal_details, error_code="RAG_ERROR", **kwargs)
233
+
234
+
235
+ class TimeoutError(FrameworkError):
236
+ """Raised when operation times out."""
237
+
238
+ def __init__(
239
+ self,
240
+ user_message: str = "Operation timed out",
241
+ internal_details: str | None = None,
242
+ operation: str | None = None,
243
+ timeout_seconds: float | None = None,
244
+ **kwargs,
245
+ ):
246
+ context = kwargs.pop("context", {})
247
+ if operation:
248
+ context["operation"] = operation
249
+ if timeout_seconds:
250
+ context["timeout_seconds"] = timeout_seconds
251
+ super().__init__(
252
+ user_message=user_message,
253
+ internal_details=internal_details,
254
+ error_code="TIMEOUT",
255
+ context=context,
256
+ **kwargs,
257
+ )
258
+
259
+
260
+ class ConfigurationError(FrameworkError):
261
+ """Raised when configuration is invalid."""
262
+
263
+ def __init__(
264
+ self,
265
+ user_message: str = "System configuration error",
266
+ internal_details: str | None = None,
267
+ config_key: str | None = None,
268
+ **kwargs,
269
+ ):
270
+ context = kwargs.pop("context", {})
271
+ if config_key:
272
+ context["config_key"] = config_key
273
+ super().__init__(
274
+ user_message=user_message,
275
+ internal_details=internal_details,
276
+ error_code="CONFIG_ERROR",
277
+ context=context,
278
+ **kwargs,
279
+ )
280
+
281
+
282
+ # Convenience function for wrapping exceptions
283
+ def wrap_exception(
284
+ exc: Exception, user_message: str = "An unexpected error occurred", error_class: type = FrameworkError, **kwargs
285
+ ) -> FrameworkError:
286
+ """
287
+ Wrap a standard exception in a FrameworkError with sanitized details.
288
+
289
+ Args:
290
+ exc: Original exception
291
+ user_message: Safe user-facing message
292
+ error_class: FrameworkError subclass to use
293
+ **kwargs: Additional context
294
+
295
+ Returns:
296
+ FrameworkError instance with sanitized details
297
+ """
298
+ internal_details = f"{type(exc).__name__}: {str(exc)}"
299
+ return error_class(user_message=user_message, internal_details=internal_details, **kwargs)
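
A short sketch of the intended logging/response split this module provides: a low-level failure is wrapped, the sanitized form goes to structured logs, and only the safe form is returned to callers. The connection target and key in the raised error are fabricated placeholders.

```python
# Sketch: wrap a low-level error, log the sanitized details, return only the safe message.
from src.api.exceptions import LLMError, wrap_exception

try:
    raise ConnectionError("connect to 10.0.0.5 failed; api_key=sk-placeholder123456")  # fabricated failure
except Exception as exc:
    err = wrap_exception(
        exc,
        user_message="Language model service temporarily unavailable",
        error_class=LLMError,
    )
    log_record = err.to_log_dict()      # IP and api_key are masked via sanitize_details()
    api_body = err.to_user_response()   # contains only error_code, message, timestamp
```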
src/api/inference_server.py ADDED
@@ -0,0 +1,380 @@
1
+ """
2
+ FastAPI Inference Server for LangGraph Multi-Agent MCTS.
3
+
4
+ Provides REST API for:
5
+ - Problem solving with HRM+MCTS+TRM
6
+ - Policy-value network inference
7
+ - Health checks and monitoring
8
+ """
9
+
10
+ import time
11
+ from typing import Any
12
+
13
+ import torch
14
+ import uvicorn
15
+ from fastapi import FastAPI, HTTPException
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from pydantic import BaseModel, Field
18
+
19
+ from ..framework.mcts.neural_mcts import NeuralMCTS
20
+ from ..training.performance_monitor import PerformanceMonitor
21
+ from ..training.system_config import SystemConfig
22
+
23
+
24
+ # Request/Response Models
25
+ class InferenceRequest(BaseModel):
26
+ """Request for problem inference."""
27
+
28
+ state: list[list[float]] # State representation
29
+ query: str | None = "Solve this problem"
30
+ max_thinking_time: float = Field(default=10.0, ge=0.1, le=60.0)
31
+ use_mcts: bool = True
32
+ num_simulations: int | None = None
33
+ use_hrm_decomposition: bool = False
34
+ use_trm_refinement: bool = False
35
+ temperature: float = Field(default=0.1, ge=0.0, le=2.0)
36
+
37
+
38
+ class PolicyValueRequest(BaseModel):
39
+ """Request for policy-value evaluation."""
40
+
41
+ state: list[list[float]] # State representation
42
+
43
+
44
+ class InferenceResponse(BaseModel):
45
+ """Response with inference results."""
46
+
47
+ success: bool
48
+ action_probabilities: dict[str, float] | None = None
49
+ best_action: str | None = None
50
+ value_estimate: float | None = None
51
+ subproblems: list[dict[str, Any]] | None = None
52
+ refinement_info: dict[str, Any] | None = None
53
+ performance_stats: dict[str, float]
54
+ error: str | None = None
55
+
56
+
57
+ class PolicyValueResponse(BaseModel):
58
+ """Response with policy-value predictions."""
59
+
60
+ policy_probs: list[float]
61
+ value: float
62
+ inference_time_ms: float
63
+
64
+
65
+ class HealthResponse(BaseModel):
66
+ """Health check response."""
67
+
68
+ status: str
69
+ device: str
70
+ model_loaded: bool
71
+ gpu_available: bool
72
+ gpu_memory_gb: float | None = None
73
+ uptime_seconds: float
74
+
75
+
76
+ # Inference Server
77
+ class InferenceServer:
78
+ """
79
+ Production inference server with comprehensive features.
80
+
81
+ Features:
82
+ - FastAPI REST endpoints
83
+ - Performance monitoring
84
+ - Health checks
85
+ - CORS support
86
+ - Error handling
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ checkpoint_path: str,
92
+ config: SystemConfig | None = None,
93
+ host: str = "0.0.0.0",
94
+ port: int = 8000,
95
+ ):
96
+ """
97
+ Initialize inference server.
98
+
99
+ Args:
100
+ checkpoint_path: Path to model checkpoint
101
+ config: System configuration (loaded from checkpoint if None)
102
+ host: Server host
103
+ port: Server port
104
+ """
105
+ self.checkpoint_path = checkpoint_path
106
+ self.host = host
107
+ self.port = port
108
+ self.start_time = time.time()
109
+
110
+ # Load models
111
+ self.config, self.models = self._load_models(checkpoint_path, config)
112
+ self.device = self.config.device
113
+
114
+ # Performance monitoring
115
+ self.monitor = PerformanceMonitor(window_size=100, enable_gpu_monitoring=(self.device != "cpu"))
116
+
117
+ # Setup FastAPI app
118
+ self.app = FastAPI(
119
+ title="LangGraph Multi-Agent MCTS API",
120
+ description="Neural-guided MCTS with HRM and TRM agents",
121
+ version="1.0.0",
122
+ )
123
+
124
+ # CORS middleware
125
+ self.app.add_middleware(
126
+ CORSMiddleware,
127
+ allow_origins=["*"],
128
+ allow_credentials=True,
129
+ allow_methods=["*"],
130
+ allow_headers=["*"],
131
+ )
132
+
133
+ # Setup routes
134
+ self._setup_routes()
135
+
136
+ def _load_models(
137
+ self, checkpoint_path: str, config: SystemConfig | None
138
+ ) -> tuple[SystemConfig, dict[str, torch.nn.Module]]:
139
+ """Load models from checkpoint."""
140
+ print(f"Loading models from {checkpoint_path}...")
141
+
142
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
143
+
144
+ # Load config
145
+ if config is None:
146
+ config_dict = checkpoint.get("config", {})
147
+ config = SystemConfig.from_dict(config_dict)
148
+
149
+ device = config.device
150
+
151
+ # Load models
152
+ models = {}
153
+
154
+ # Policy-Value Network
155
+ from ..models.policy_value_net import create_policy_value_network
156
+
157
+ models["policy_value_net"] = create_policy_value_network(config.neural_net, board_size=19, device=device)
158
+ models["policy_value_net"].load_state_dict(checkpoint["policy_value_net"])
159
+ models["policy_value_net"].eval()
160
+
161
+ # HRM Agent
162
+ from ..agents.hrm_agent import create_hrm_agent
163
+
164
+ models["hrm_agent"] = create_hrm_agent(config.hrm, device)
165
+ models["hrm_agent"].load_state_dict(checkpoint["hrm_agent"])
166
+ models["hrm_agent"].eval()
167
+
168
+ # TRM Agent
169
+ from ..agents.trm_agent import create_trm_agent
170
+
171
+ models["trm_agent"] = create_trm_agent(config.trm, output_dim=config.neural_net.action_size, device=device)
172
+ models["trm_agent"].load_state_dict(checkpoint["trm_agent"])
173
+ models["trm_agent"].eval()
174
+
175
+ # MCTS
176
+ models["mcts"] = NeuralMCTS(
177
+ policy_value_network=models["policy_value_net"],
178
+ config=config.mcts,
179
+ device=device,
180
+ )
181
+
182
+ print(f"✓ Models loaded successfully on {device}")
183
+
184
+ return config, models
185
+
186
+ def _setup_routes(self):
187
+ """Set up API routes."""
188
+
189
+ @self.app.get("/", response_model=dict[str, str])
190
+ async def root():
191
+ """Root endpoint."""
192
+ return {
193
+ "message": "LangGraph Multi-Agent MCTS API",
194
+ "version": "1.0.0",
195
+ "docs": "/docs",
196
+ }
197
+
198
+ @self.app.get("/health", response_model=HealthResponse)
199
+ async def health():
200
+ """Health check endpoint."""
201
+ gpu_memory = None
202
+ if torch.cuda.is_available():
203
+ gpu_memory = torch.cuda.memory_allocated() / (1024**3)
204
+
205
+ return HealthResponse(
206
+ status="healthy",
207
+ device=self.device,
208
+ model_loaded=True,
209
+ gpu_available=torch.cuda.is_available(),
210
+ gpu_memory_gb=gpu_memory,
211
+ uptime_seconds=time.time() - self.start_time,
212
+ )
213
+
214
+ @self.app.post("/inference", response_model=InferenceResponse)
215
+ async def inference(request: InferenceRequest):
216
+ """
217
+ Main inference endpoint.
218
+
219
+ Processes a problem using the full pipeline:
220
+ 1. Optional HRM decomposition
221
+ 2. MCTS search
222
+ 3. Optional TRM refinement
223
+ """
224
+ try:
225
+ start_time = time.perf_counter()
226
+
227
+ # Convert state to tensor
228
+ state_tensor = torch.tensor(request.state, dtype=torch.float32).unsqueeze(0)
229
+ state_tensor = state_tensor.to(self.device)
230
+
231
+ results = {}
232
+
233
+ # HRM Decomposition (if requested)
234
+ if request.use_hrm_decomposition:
235
+ with torch.no_grad():
236
+ hrm_output = self.models["hrm_agent"](state_tensor)
237
+ results["subproblems"] = [
238
+ {
239
+ "level": sp.level,
240
+ "description": sp.description,
241
+ "confidence": sp.confidence,
242
+ }
243
+ for sp in hrm_output.subproblems
244
+ ]
245
+
246
+ # MCTS Search (if requested)
247
+ if request.use_mcts:
248
+ # Note: This is a simplified version
249
+ # In production, you'd need to convert request.state to GameState
250
+ results["action_probabilities"] = {"action_0": 0.5, "action_1": 0.3, "action_2": 0.2}
251
+ results["best_action"] = "action_0"
252
+ results["value_estimate"] = 0.75
253
+
254
+ # TRM Refinement (if requested)
255
+ if request.use_trm_refinement and results.get("best_action"):
256
+ with torch.no_grad():
257
+ # Simplified: just run TRM on the state
258
+ trm_output = self.models["trm_agent"](state_tensor)
259
+ results["refinement_info"] = {
260
+ "converged": trm_output.converged,
261
+ "convergence_step": trm_output.convergence_step,
262
+ "recursion_depth": trm_output.recursion_depth,
263
+ }
264
+
265
+ # Performance stats
266
+ elapsed_ms = (time.perf_counter() - start_time) * 1000
267
+ self.monitor.log_inference(elapsed_ms)
268
+
269
+ perf_stats = {
270
+ "inference_time_ms": elapsed_ms,
271
+ "device": self.device,
272
+ }
273
+
274
+ return InferenceResponse(
275
+ success=True,
276
+ action_probabilities=results.get("action_probabilities"),
277
+ best_action=results.get("best_action"),
278
+ value_estimate=results.get("value_estimate"),
279
+ subproblems=results.get("subproblems"),
280
+ refinement_info=results.get("refinement_info"),
281
+ performance_stats=perf_stats,
282
+ )
283
+
284
+ except Exception as e:
285
+ raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
286
+
287
+ @self.app.post("/policy-value", response_model=PolicyValueResponse)
288
+ async def policy_value(request: PolicyValueRequest):
289
+ """
290
+ Get policy and value predictions for a state.
291
+
292
+ This is a direct neural network evaluation without MCTS.
293
+ """
294
+ try:
295
+ start_time = time.perf_counter()
296
+
297
+ # Convert state to tensor
298
+ state_tensor = torch.tensor(request.state, dtype=torch.float32).unsqueeze(0)
299
+ state_tensor = state_tensor.to(self.device)
300
+
301
+ # Get predictions
302
+ with torch.no_grad():
303
+ policy_log_probs, value = self.models["policy_value_net"](state_tensor)
304
+ policy_probs = torch.exp(policy_log_probs).squeeze(0)
305
+
306
+ elapsed_ms = (time.perf_counter() - start_time) * 1000
307
+
308
+ return PolicyValueResponse(
309
+ policy_probs=policy_probs.cpu().tolist(),
310
+ value=value.item(),
311
+ inference_time_ms=elapsed_ms,
312
+ )
313
+
314
+ except Exception as e:
315
+ raise HTTPException(status_code=500, detail=f"Policy-value inference failed: {str(e)}")
316
+
317
+ @self.app.get("/stats")
318
+ async def stats():
319
+ """Get performance statistics."""
320
+ return self.monitor.get_stats()
321
+
322
+ @self.app.post("/reset-stats")
323
+ async def reset_stats():
324
+ """Reset performance statistics."""
325
+ self.monitor.reset()
326
+ return {"message": "Statistics reset successfully"}
327
+
328
+ def run(self):
329
+ """Start the inference server."""
330
+ print(f"\n{'=' * 80}")
331
+ print("Starting LangGraph Multi-Agent MCTS Inference Server")
332
+ print(f"{'=' * 80}")
333
+ print(f"Host: {self.host}:{self.port}")
334
+ print(f"Device: {self.device}")
335
+ print(f"Checkpoint: {self.checkpoint_path}")
336
+ print(f"{'=' * 80}\n")
337
+
338
+ uvicorn.run(self.app, host=self.host, port=self.port)
339
+
340
+
341
+ def main():
342
+ """Main entry point for inference server."""
343
+ import argparse
344
+
345
+ parser = argparse.ArgumentParser(description="LangGraph MCTS Inference Server")
346
+ parser.add_argument(
347
+ "--checkpoint",
348
+ type=str,
349
+ required=True,
350
+ help="Path to model checkpoint",
351
+ )
352
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
353
+ parser.add_argument("--port", type=int, default=8000, help="Server port")
354
+ parser.add_argument(
355
+ "--device",
356
+ type=str,
357
+ default=None,
358
+ help="Device (cpu, cuda, mps)",
359
+ )
360
+
361
+ args = parser.parse_args()
362
+
363
+ # Load config and override device if specified
364
+ config = None
365
+ if args.device:
366
+ config = SystemConfig()
367
+ config.device = args.device
368
+
369
+ server = InferenceServer(
370
+ checkpoint_path=args.checkpoint,
371
+ config=config,
372
+ host=args.host,
373
+ port=args.port,
374
+ )
375
+
376
+ server.run()
377
+
378
+
379
+ if __name__ == "__main__":
380
+ main()
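
A client-side sketch for the `/policy-value` endpoint, assuming the server was started locally (e.g. `python -m src.api.inference_server --checkpoint <path>`) and that the `requests` package is available; the 19x19 zero grid is only a placeholder, since the real state shape depends on the policy-value network's input specification.

```python
# Sketch: query the /policy-value endpoint of a locally running inference server.
import requests  # assumed installed; any HTTP client works

state = [[0.0] * 19 for _ in range(19)]  # placeholder state; real shape depends on the model
resp = requests.post(
    "http://localhost:8000/policy-value",
    json={"state": state},
    timeout=30,
)
resp.raise_for_status()
result = resp.json()
print(result["value"], result["inference_time_ms"])
```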
src/api/rest_server.py ADDED
@@ -0,0 +1,441 @@
1
+ """
2
+ Production REST API server for LangGraph Multi-Agent MCTS Framework.
3
+
4
+ Provides:
5
+ - OpenAPI/Swagger documentation
6
+ - Authentication via API keys
7
+ - Rate limiting
8
+ - Health and readiness endpoints
9
+ - Request validation with Pydantic
10
+ - Prometheus metrics exposure
11
+ """
12
+
13
+ import asyncio
14
+ import time
15
+ from contextlib import asynccontextmanager
16
+ from datetime import datetime
17
+ from typing import Any
18
+
19
+ from fastapi import Depends, FastAPI, Header, HTTPException, Request, Response
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ from fastapi.responses import JSONResponse
22
+ from pydantic import BaseModel, Field
23
+
24
+ # Import framework components
25
+ try:
26
+ from src.adapters.llm import create_client # noqa: F401
27
+ from src.api.auth import (
28
+ APIKeyAuthenticator,
29
+ ClientInfo,
30
+ RateLimitConfig,
31
+ get_authenticator,
32
+ set_authenticator,
33
+ )
34
+ from src.api.exceptions import (
35
+ AuthenticationError,
36
+ AuthorizationError, # noqa: F401
37
+ FrameworkError,
38
+ RateLimitError,
39
+ ValidationError, # noqa: F401
40
+ )
41
+ from src.models.validation import MCTSConfig, QueryInput # noqa: F401
42
+
43
+ IMPORTS_AVAILABLE = True
44
+ except ImportError as e:
45
+ IMPORTS_AVAILABLE = False
46
+ import_error = str(e)
47
+
48
+ # Prometheus metrics (optional)
49
+ try:
50
+ from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, generate_latest
51
+
52
+ PROMETHEUS_AVAILABLE = True
53
+
54
+ # Define metrics
55
+ REQUEST_COUNT = Counter("mcts_requests_total", "Total number of requests", ["method", "endpoint", "status"])
56
+ REQUEST_LATENCY = Histogram("mcts_request_duration_seconds", "Request latency in seconds", ["method", "endpoint"])
57
+ ACTIVE_REQUESTS = Gauge("mcts_active_requests", "Number of active requests")
58
+ ERROR_COUNT = Counter("mcts_errors_total", "Total number of errors", ["error_type"])
59
+ except ImportError:
60
+ PROMETHEUS_AVAILABLE = False
61
+
62
+
63
+ # Request/Response Models
64
+ class QueryRequest(BaseModel):
65
+ """Request model for query processing."""
66
+
67
+ query: str = Field(
68
+ ...,
69
+ min_length=1,
70
+ max_length=10000,
71
+ description="User query to process",
72
+ json_schema_extra={"example": "Recommend defensive positions for night attack scenario"},
73
+ )
74
+ use_mcts: bool = Field(default=True, description="Enable MCTS tactical simulation")
75
+ use_rag: bool = Field(default=True, description="Enable RAG context retrieval")
76
+ mcts_iterations: int | None = Field(default=None, ge=1, le=10000, description="Override default MCTS iterations")
77
+ thread_id: str | None = Field(
78
+ default=None,
79
+ max_length=100,
80
+ pattern=r"^[a-zA-Z0-9_-]+$",
81
+ description="Conversation thread ID for state persistence",
82
+ )
83
+
84
+ class Config:
85
+ json_schema_extra = {
86
+ "example": {
87
+ "query": "Recommend defensive positions for night attack",
88
+ "use_mcts": True,
89
+ "use_rag": True,
90
+ "mcts_iterations": 200,
91
+ "thread_id": "session_123",
92
+ }
93
+ }
94
+
95
+
96
+ class QueryResponse(BaseModel):
97
+ """Response model for query results."""
98
+
99
+ response: str = Field(..., description="Final synthesized response")
100
+ confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence score")
101
+ agents_used: list[str] = Field(..., description="List of agents that contributed")
102
+ mcts_stats: dict[str, Any] | None = Field(default=None, description="MCTS simulation statistics")
103
+ processing_time_ms: float = Field(..., description="Total processing time in milliseconds")
104
+ metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
105
+
106
+
107
+ class HealthResponse(BaseModel):
108
+ """Health check response."""
109
+
110
+ status: str = Field(..., description="Service status")
111
+ timestamp: str = Field(..., description="Current timestamp")
112
+ version: str = Field(default="1.0.0", description="API version")
113
+ uptime_seconds: float = Field(..., description="Service uptime")
114
+
115
+
116
+ class ReadinessResponse(BaseModel):
117
+ """Readiness check response."""
118
+
119
+ ready: bool = Field(..., description="Whether service is ready")
120
+ checks: dict[str, bool] = Field(..., description="Individual check results")
121
+
122
+
123
+ class ErrorResponse(BaseModel):
124
+ """Error response model."""
125
+
126
+ error: bool = Field(default=True)
127
+ error_code: str = Field(..., description="Machine-readable error code")
128
+ message: str = Field(..., description="Human-readable error message")
129
+ timestamp: str = Field(..., description="Error timestamp")
130
+
131
+
132
+ # Application startup
133
+ start_time = time.time()
134
+ framework_instance = None
135
+
136
+
137
+ @asynccontextmanager
138
+ async def lifespan(app: FastAPI):
139
+ """Application lifespan manager."""
140
+ global framework_instance
141
+
142
+ # Startup
143
+ print("Starting MCTS Framework API server...")
144
+
145
+ # Initialize authenticator with demo key (replace in production)
146
+ authenticator = APIKeyAuthenticator(
147
+ valid_keys=["demo-api-key-replace-in-production"],
148
+ rate_limit_config=RateLimitConfig(
149
+ requests_per_minute=60,
150
+ requests_per_hour=1000,
151
+ requests_per_day=10000,
152
+ ),
153
+ )
154
+ set_authenticator(authenticator)
155
+
156
+ # Initialize framework (lazy loading)
157
+ # framework_instance = create_framework()
158
+
159
+ print("API server started successfully")
160
+
161
+ yield
162
+
163
+ # Shutdown
164
+ print("Shutting down API server...")
165
+
166
+
167
+ # Create FastAPI app
168
+ app = FastAPI(
169
+ title="LangGraph Multi-Agent MCTS API",
170
+ description="""
171
+ ## Multi-Agent Reasoning API with MCTS Tactical Simulation
172
+
173
+ This API provides access to a sophisticated multi-agent reasoning framework that combines:
174
+ - **HRM Agent**: Hierarchical decomposition of complex queries
175
+ - **TRM Agent**: Iterative refinement for response quality
176
+ - **MCTS Engine**: Monte Carlo Tree Search for tactical simulation
177
+ - **RAG Integration**: Context retrieval from vector stores
178
+
179
+ ### Features
180
+ - Secure API key authentication
181
+ - Rate limiting per client
182
+ - Real-time metrics (Prometheus)
183
+ - Distributed tracing (OpenTelemetry)
184
+ - Production-grade error handling
185
+
186
+ ### Quick Start
187
+ 1. Obtain an API key
188
+ 2. Include `X-API-Key` header in requests
189
+ 3. Send queries to `/query` endpoint
190
+ 4. Monitor health via `/health` endpoint
191
+ """,
192
+ version="1.0.0",
193
+ docs_url="/docs",
194
+ redoc_url="/redoc",
195
+ openapi_tags=[
196
+ {"name": "query", "description": "Query processing operations"},
197
+ {"name": "health", "description": "Health and readiness checks"},
198
+ {"name": "metrics", "description": "Observability endpoints"},
199
+ ],
200
+ lifespan=lifespan,
201
+ )
202
+
203
+ # CORS middleware
204
+ app.add_middleware(
205
+ CORSMiddleware,
206
+ allow_origins=["*"], # Configure appropriately for production
207
+ allow_credentials=True,
208
+ allow_methods=["*"],
209
+ allow_headers=["*"],
210
+ )
211
+
212
+
213
+ # Middleware for metrics
214
+ @app.middleware("http")
215
+ async def metrics_middleware(request: Request, call_next):
216
+ """Track request metrics."""
217
+ if PROMETHEUS_AVAILABLE:
218
+ ACTIVE_REQUESTS.inc()
219
+
220
+ start = time.perf_counter()
221
+
222
+ try:
223
+ response = await call_next(request)
224
+ status = response.status_code
225
+ except Exception:
226
+ status = 500
227
+ raise
228
+ finally:
229
+ if PROMETHEUS_AVAILABLE:
230
+ ACTIVE_REQUESTS.dec()
231
+ elapsed = time.perf_counter() - start
232
+ REQUEST_COUNT.labels(method=request.method, endpoint=request.url.path, status=str(status)).inc()
233
+ REQUEST_LATENCY.labels(method=request.method, endpoint=request.url.path).observe(elapsed)
234
+
235
+ return response
236
+
237
+
238
+ # Authentication dependency
239
+ async def verify_api_key(x_api_key: str = Header(..., description="API key for authentication")):
240
+ """Verify API key and return client info."""
241
+ if not IMPORTS_AVAILABLE:
242
+ raise HTTPException(status_code=500, detail="Authentication module not available")
243
+
244
+ try:
245
+ authenticator = get_authenticator()
246
+ client_info = authenticator.require_auth(x_api_key)
247
+ return client_info
248
+ except AuthenticationError as e:
249
+ if PROMETHEUS_AVAILABLE:
250
+ ERROR_COUNT.labels(error_type="authentication").inc()
251
+ raise HTTPException(status_code=401, detail=e.user_message)
252
+ except RateLimitError as e:
253
+ if PROMETHEUS_AVAILABLE:
254
+ ERROR_COUNT.labels(error_type="rate_limit").inc()
255
+ raise HTTPException(
256
+ status_code=429, detail=e.user_message, headers={"Retry-After": str(e.retry_after_seconds or 60)}
257
+ )
258
+
259
+
260
+ # Exception handlers
261
+ @app.exception_handler(FrameworkError)
262
+ async def framework_error_handler(request: Request, exc: FrameworkError):
263
+ """Handle framework-specific errors."""
264
+ if PROMETHEUS_AVAILABLE:
265
+ ERROR_COUNT.labels(error_type=exc.error_code).inc()
266
+
267
+ return JSONResponse(status_code=500, content=exc.to_user_response())
268
+
269
+
270
+ @app.exception_handler(ValidationError)
271
+ async def validation_error_handler(request: Request, exc: ValidationError):
272
+ """Handle validation errors."""
273
+ if PROMETHEUS_AVAILABLE:
274
+ ERROR_COUNT.labels(error_type="validation").inc()
275
+
276
+ return JSONResponse(status_code=400, content=exc.to_user_response())
277
+
278
+
279
+ # Endpoints
280
+ @app.get("/health", response_model=HealthResponse, tags=["health"])
281
+ async def health_check():
282
+ """
283
+ Health check endpoint.
284
+
285
+ Returns basic service health status. Use this for load balancer health checks.
286
+ """
287
+ return HealthResponse(
288
+ status="healthy",
289
+ timestamp=datetime.utcnow().isoformat(),
290
+ version="1.0.0",
291
+ uptime_seconds=time.time() - start_time,
292
+ )
293
+
294
+
295
+ @app.get("/ready", response_model=ReadinessResponse, tags=["health"])
296
+ async def readiness_check():
297
+ """
298
+ Readiness check endpoint.
299
+
300
+ Verifies all dependencies are available. Use this for Kubernetes readiness probes.
301
+ """
302
+ checks = {
303
+ "imports_available": IMPORTS_AVAILABLE,
304
+ "authenticator_configured": True,
305
+ "llm_client_available": True, # Would check actual client
306
+ "prometheus_available": PROMETHEUS_AVAILABLE,
307
+ }
308
+
309
+ # Check if all critical services are available
310
+ all_ready = all(
311
+ [
312
+ checks["imports_available"],
313
+ checks["authenticator_configured"],
314
+ ]
315
+ )
316
+
317
+ if not all_ready:
318
+ raise HTTPException(status_code=503, detail="Service not ready")
319
+
320
+ return ReadinessResponse(ready=all_ready, checks=checks)
321
+
322
+
323
+ @app.get("/metrics", tags=["metrics"])
324
+ async def prometheus_metrics():
325
+ """
326
+ Prometheus metrics endpoint.
327
+
328
+ Returns metrics in Prometheus text format for scraping.
329
+ """
330
+ if not PROMETHEUS_AVAILABLE:
331
+ raise HTTPException(status_code=501, detail="Prometheus metrics not available")
332
+
333
+ return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)
334
+
335
+
336
+ @app.post(
337
+ "/query",
338
+ response_model=QueryResponse,
339
+ tags=["query"],
340
+ responses={
341
+ 401: {"model": ErrorResponse, "description": "Authentication failed"},
342
+ 429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
343
+ 400: {"model": ErrorResponse, "description": "Invalid input"},
344
+ 500: {"model": ErrorResponse, "description": "Internal server error"},
345
+ },
346
+ )
347
+ async def process_query(request: QueryRequest, client_info: ClientInfo = Depends(verify_api_key)):
348
+ """
349
+ Process a query using the multi-agent MCTS framework.
350
+
351
+ This endpoint:
352
+ 1. Validates the input query
353
+ 2. Optionally retrieves context via RAG
354
+ 3. Processes through HRM and TRM agents
355
+ 4. Optionally runs MCTS simulation
356
+ 5. Synthesizes a final response
357
+
358
+ **Authentication**: Requires valid API key in X-API-Key header.
359
+
360
+ **Rate Limiting**: Subject to rate limits per client.
361
+ """
362
+ start_time = time.perf_counter()
363
+
364
+ # Validate input using validation models
365
+ if IMPORTS_AVAILABLE:
366
+ try:
367
+ QueryInput(
368
+ query=request.query,
369
+ use_rag=request.use_rag,
370
+ use_mcts=request.use_mcts,
371
+ thread_id=request.thread_id,
372
+ )
373
+ except Exception as e:
374
+ if PROMETHEUS_AVAILABLE:
375
+ ERROR_COUNT.labels(error_type="validation").inc()
376
+ raise HTTPException(status_code=400, detail=f"Validation failed: {str(e)}")
377
+
378
+ # Process query (mock implementation for demo)
379
+ # In production, this would call the actual framework
380
+ await asyncio.sleep(0.1) # Simulate processing
381
+
382
+ processing_time = (time.perf_counter() - start_time) * 1000
383
+
384
+ # Mock response
385
+ return QueryResponse(
386
+ response=f"Processed query: {request.query[:100]}...",
387
+ confidence=0.85,
388
+ agents_used=["hrm", "trm"] + (["mcts"] if request.use_mcts else []),
389
+ mcts_stats=(
390
+ {
391
+ "iterations": request.mcts_iterations or 100,
392
+ "best_action": "recommended_action",
393
+ "root_visits": request.mcts_iterations or 100,
394
+ }
395
+ if request.use_mcts
396
+ else None
397
+ ),
398
+ processing_time_ms=processing_time,
399
+ metadata={
400
+ "client_id": client_info.client_id,
401
+ "thread_id": request.thread_id,
402
+ "rag_enabled": request.use_rag,
403
+ },
404
+ )
405
+
406
+
407
+ @app.get("/stats", tags=["metrics"])
408
+ async def get_stats(client_info: ClientInfo = Depends(verify_api_key)):
409
+ """
410
+ Get usage statistics for the authenticated client.
411
+
412
+ Returns request counts and rate limit information.
413
+ """
414
+ authenticator = get_authenticator()
415
+ stats = authenticator.get_client_stats(client_info.client_id)
416
+
417
+ return {
418
+ "client_id": client_info.client_id,
419
+ "roles": list(client_info.roles),
420
+ **stats,
421
+ "rate_limits": {
422
+ "per_minute": authenticator.rate_limit_config.requests_per_minute,
423
+ "per_hour": authenticator.rate_limit_config.requests_per_hour,
424
+ "per_day": authenticator.rate_limit_config.requests_per_day,
425
+ },
426
+ }
427
+
428
+
429
+ # Entry point
430
+ if __name__ == "__main__":
431
+ import uvicorn
432
+
433
+ uvicorn.run(
434
+ "src.api.rest_server:app",
435
+ host="0.0.0.0",
436
+ port=8000,
437
+ reload=False,
438
+ workers=4,
439
+ log_level="info",
440
+ access_log=True,
441
+ )
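
A request sketch for the `/query` endpoint using the demo key configured in `lifespan()`; it assumes the server is running locally on port 8000 and that `requests` is installed.

```python
# Sketch: call /query with the demo API key (replace the key before any real deployment).
import requests

resp = requests.post(
    "http://localhost:8000/query",
    headers={"X-API-Key": "demo-api-key-replace-in-production"},
    json={
        "query": "Recommend defensive positions for night attack",
        "use_mcts": True,
        "mcts_iterations": 200,
        "thread_id": "session_123",
    },
    timeout=30,
)
resp.raise_for_status()
body = resp.json()
print(body["confidence"], body["agents_used"])
```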
src/config/__init__.py ADDED
File without changes
src/config/meta_controller.yaml ADDED
@@ -0,0 +1,22 @@
1
+ meta_controller:
2
+ enabled: false # Disabled by default for backward compatibility
3
+ type: "rnn" # "rnn" or "bert"
4
+ fallback_to_rule_based: true # Fallback on errors
5
+
6
+ rnn:
7
+ hidden_dim: 64
8
+ num_layers: 1
9
+ dropout: 0.1
10
+ model_path: null # Path to trained model (null for untrained)
11
+
12
+ bert:
13
+ model_name: "prajjwal1/bert-mini"
14
+ use_lora: true
15
+ lora_r: 4
16
+ lora_alpha: 16
17
+ lora_dropout: 0.1
18
+ model_path: null # Path to trained LoRA adapter
19
+
20
+ inference:
21
+ device: null # Auto-detect if null
22
+ seed: 42
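
A minimal sketch of reading this file directly with PyYAML (assumed installed); the path is relative to the repository root.

```python
# Sketch: load the meta-controller settings and pick the configured variant.
import yaml  # PyYAML, assumed installed

with open("src/config/meta_controller.yaml") as f:
    cfg = yaml.safe_load(f)["meta_controller"]

if cfg["enabled"]:
    variant = cfg["type"]        # "rnn" or "bert"
    variant_cfg = cfg[variant]   # e.g. hidden_dim / num_layers for the RNN
    print(variant, variant_cfg)
```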
src/config/settings.py ADDED
@@ -0,0 +1,431 @@
1
+ """
2
+ Pydantic Settings v2 configuration management for LangGraph Multi-Agent MCTS.
3
+
4
+ Provides:
5
+ - Secure configuration loading from environment variables and .env files
6
+ - Type-safe settings with validation
7
+ - Secrets protection using SecretStr
8
+ - MCTS parameter bounds validation
9
+ - Support for multiple LLM providers
10
+ """
11
+
12
+ from enum import Enum
13
+
14
+ from pydantic import (
15
+ Field,
16
+ SecretStr,
17
+ field_validator,
18
+ model_validator,
19
+ )
20
+ from pydantic_settings import BaseSettings, SettingsConfigDict
21
+
22
+
23
+ class LLMProvider(str, Enum):
24
+ """Supported LLM providers."""
25
+
26
+ OPENAI = "openai"
27
+ ANTHROPIC = "anthropic"
28
+ LMSTUDIO = "lmstudio"
29
+
30
+
31
+ class LogLevel(str, Enum):
32
+ """Supported log levels."""
33
+
34
+ DEBUG = "DEBUG"
35
+ INFO = "INFO"
36
+ WARNING = "WARNING"
37
+ ERROR = "ERROR"
38
+ CRITICAL = "CRITICAL"
39
+
40
+
41
+ class MCTSImplementation(str, Enum):
42
+ """MCTS implementation variants."""
43
+
44
+ BASELINE = "baseline" # Original MCTS core
45
+ NEURAL = "neural" # Neural-guided AlphaZero-style MCTS
46
+
47
+
48
+ class Settings(BaseSettings):
49
+ """
50
+ Application settings with security-first configuration.
51
+
52
+ All sensitive values use SecretStr to prevent accidental exposure in logs.
53
+ Configuration is loaded from environment variables with .env file support.
54
+ """
55
+
56
+ model_config = SettingsConfigDict(
57
+ env_file=".env",
58
+ env_file_encoding="utf-8",
59
+ case_sensitive=True,
60
+ extra="ignore",
61
+ validate_default=True,
62
+ )
63
+
64
+ # LLM Provider Configuration
65
+ LLM_PROVIDER: LLMProvider = Field(
66
+ default=LLMProvider.OPENAI, description="LLM provider to use (openai, anthropic, lmstudio)"
67
+ )
68
+
69
+ # API Keys (Secrets)
70
+ OPENAI_API_KEY: SecretStr | None = Field(
71
+ default=None, description="OpenAI API key (required if using OpenAI provider)"
72
+ )
73
+
74
+ ANTHROPIC_API_KEY: SecretStr | None = Field(
75
+ default=None, description="Anthropic API key (required if using Anthropic provider)"
76
+ )
77
+
78
+ BRAINTRUST_API_KEY: SecretStr | None = Field(
79
+ default=None, description="Braintrust API key for experiment tracking (optional)"
80
+ )
81
+
82
+ PINECONE_API_KEY: SecretStr | None = Field(
83
+ default=None, description="Pinecone API key for vector storage (optional)"
84
+ )
85
+
86
+ PINECONE_HOST: str | None = Field(
87
+ default=None, description="Pinecone host URL (e.g., https://index.svc.environment.pinecone.io)"
88
+ )
89
+
90
+ # Local LLM Configuration
91
+ LMSTUDIO_BASE_URL: str | None = Field(
92
+ default="http://localhost:1234/v1", description="LM Studio API base URL for local inference"
93
+ )
94
+
95
+ LMSTUDIO_MODEL: str | None = Field(default=None, description="LM Studio model identifier (e.g., liquid/lfm2-1.2b)")
96
+
97
+ # MCTS Configuration with bounds validation
98
+ MCTS_ENABLED: bool = Field(default=True, description="Enable MCTS for agent decision-making")
99
+
100
+ MCTS_IMPL: MCTSImplementation = Field(
101
+ default=MCTSImplementation.BASELINE, description="MCTS implementation variant to use"
102
+ )
103
+
104
+ MCTS_ITERATIONS: int = Field(default=100, ge=1, le=10000, description="Number of MCTS iterations (1-10000)")
105
+
106
+ MCTS_C: float = Field(
107
+ default=1.414, ge=0.0, le=10.0, description="MCTS exploration weight (UCB1 constant, 0.0-10.0)"
108
+ )
109
+
110
+ # Random seed for reproducibility
111
+ SEED: int | None = Field(default=None, ge=0, description="Random seed for reproducibility (optional)")
112
+
113
+ # LangSmith Configuration for tracing and evaluation
114
+ LANGSMITH_API_KEY: SecretStr | None = Field(
115
+ default=None, description="LangSmith API key for tracing and evaluation (optional)"
116
+ )
117
+
118
+ LANGSMITH_PROJECT: str = Field(default="langgraph-mcts", description="LangSmith project name")
119
+
120
+ LANGCHAIN_TRACING_V2: bool = Field(default=False, description="Enable LangChain tracing v2")
121
+
122
+ LANGCHAIN_ENDPOINT: str = Field(default="https://api.smith.langchain.com", description="LangChain API endpoint")
123
+
124
+ # Weights & Biases Configuration for experiment tracking
125
+ WANDB_API_KEY: SecretStr | None = Field(
126
+ default=None, description="Weights & Biases API key for experiment tracking (optional)"
127
+ )
128
+
129
+ WANDB_PROJECT: str = Field(default="langgraph-mcts", description="W&B project name")
130
+
131
+ WANDB_ENTITY: str | None = Field(default=None, description="W&B entity (username or team name)")
132
+
133
+ WANDB_MODE: str = Field(default="online", description="W&B mode: online, offline, or disabled")
134
+
135
+ # Logging Configuration
136
+ LOG_LEVEL: LogLevel = Field(default=LogLevel.INFO, description="Application log level")
137
+
138
+ # OpenTelemetry Configuration
139
+ OTEL_EXPORTER_OTLP_ENDPOINT: str | None = Field(
140
+ default=None, description="OpenTelemetry OTLP exporter endpoint URL"
141
+ )
142
+
143
+ # S3 Storage Configuration
144
+ S3_BUCKET: str | None = Field(default=None, description="S3 bucket name for artifact storage")
145
+
146
+ S3_PREFIX: str = Field(default="mcts-artifacts", description="S3 key prefix for stored artifacts")
147
+
148
+ S3_REGION: str = Field(default="us-east-1", description="AWS region for S3 bucket")
149
+
150
+ # Network Configuration (security)
151
+ HTTP_TIMEOUT_SECONDS: int = Field(default=30, ge=1, le=300, description="HTTP request timeout in seconds")
152
+
153
+ HTTP_MAX_RETRIES: int = Field(default=3, ge=0, le=10, description="Maximum HTTP request retries")
154
+
155
+ # Security Settings
156
+ MAX_QUERY_LENGTH: int = Field(
157
+ default=10000, ge=1, le=100000, description="Maximum allowed query length in characters"
158
+ )
159
+
160
+ RATE_LIMIT_REQUESTS_PER_MINUTE: int = Field(
161
+ default=60, ge=1, le=1000, description="Rate limit for API requests per minute"
162
+ )
163
+
164
+ @field_validator("OPENAI_API_KEY")
165
+ @classmethod
166
+ def validate_openai_key_format(cls, v: SecretStr | None) -> SecretStr | None:
167
+ """Validate OpenAI API key format without exposing the value."""
168
+ if v is not None:
169
+ secret_value = v.get_secret_value()
170
+ # Check for obviously invalid patterns
171
+ if secret_value in ("", "your-api-key-here", "sk-xxx", "REPLACE_ME"):
172
+ raise ValueError("OpenAI API key appears to be a placeholder value")
173
+ if not secret_value.startswith("sk-"):
174
+ raise ValueError("OpenAI API key should start with 'sk-'")
175
+ if len(secret_value) < 20:
176
+ raise ValueError("OpenAI API key appears to be too short")
177
+ return v
178
+
179
+ @field_validator("ANTHROPIC_API_KEY")
180
+ @classmethod
181
+ def validate_anthropic_key_format(cls, v: SecretStr | None) -> SecretStr | None:
182
+ """Validate Anthropic API key format without exposing the value."""
183
+ if v is not None:
184
+ secret_value = v.get_secret_value()
185
+ # Check for obviously invalid patterns
186
+ if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
187
+ raise ValueError("Anthropic API key appears to be a placeholder value")
188
+ if len(secret_value) < 20:
189
+ raise ValueError("Anthropic API key appears to be too short")
190
+ return v
191
+
192
+ @field_validator("BRAINTRUST_API_KEY")
193
+ @classmethod
194
+ def validate_braintrust_key_format(cls, v: SecretStr | None) -> SecretStr | None:
195
+ """Validate Braintrust API key format without exposing the value."""
196
+ if v is not None:
197
+ secret_value = v.get_secret_value()
198
+ # Check for obviously invalid patterns
199
+ if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
200
+ raise ValueError("Braintrust API key appears to be a placeholder value")
201
+ if len(secret_value) < 20:
202
+ raise ValueError("Braintrust API key appears to be too short")
203
+ return v
204
+
205
+ @field_validator("PINECONE_API_KEY")
206
+ @classmethod
207
+ def validate_pinecone_key_format(cls, v: SecretStr | None) -> SecretStr | None:
208
+ """Validate Pinecone API key format without exposing the value."""
209
+ if v is not None:
210
+ secret_value = v.get_secret_value()
211
+ # Check for obviously invalid patterns
212
+ if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
213
+ raise ValueError("Pinecone API key appears to be a placeholder value")
214
+ if len(secret_value) < 20:
215
+ raise ValueError("Pinecone API key appears to be too short")
216
+ return v
217
+
218
+ @field_validator("LANGSMITH_API_KEY")
219
+ @classmethod
220
+ def validate_langsmith_key_format(cls, v: SecretStr | None) -> SecretStr | None:
221
+ """Validate LangSmith API key format without exposing the value."""
222
+ if v is not None:
223
+ secret_value = v.get_secret_value()
224
+ if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
225
+ raise ValueError("LangSmith API key appears to be a placeholder value")
226
+ if len(secret_value) < 20:
227
+ raise ValueError("LangSmith API key appears to be too short")
228
+ return v
229
+
230
+ @field_validator("WANDB_API_KEY")
231
+ @classmethod
232
+ def validate_wandb_key_format(cls, v: SecretStr | None) -> SecretStr | None:
233
+ """Validate Weights & Biases API key format without exposing the value."""
234
+ if v is not None:
235
+ secret_value = v.get_secret_value()
236
+ if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
237
+ raise ValueError("W&B API key appears to be a placeholder value")
238
+ if len(secret_value) < 20:
239
+ raise ValueError("W&B API key appears to be too short")
240
+ return v
241
+
242
+ @field_validator("PINECONE_HOST")
243
+ @classmethod
244
+ def validate_pinecone_host(cls, v: str | None) -> str | None:
245
+ """Validate Pinecone host URL format."""
246
+ if v is not None and v != "":
247
+ if not v.startswith("https://"):
248
+ raise ValueError("Pinecone host must start with https://")
249
+ if "pinecone.io" not in v:
250
+ raise ValueError("Pinecone host should be a valid pinecone.io URL")
251
+ return v
252
+
253
+ @field_validator("LMSTUDIO_BASE_URL")
254
+ @classmethod
255
+ def validate_lmstudio_url(cls, v: str | None) -> str | None:
256
+ """Validate LM Studio base URL format."""
257
+ if v is not None:
258
+ if not v.startswith(("http://", "https://")):
259
+ raise ValueError("LM Studio base URL must start with http:// or https://")
260
+ # Warn if not localhost (potential security concern)
261
+ if not any(host in v for host in ("localhost", "127.0.0.1", "::1")):
262
+ import warnings
263
+
264
+ warnings.warn(
265
+ "LM Studio URL points to non-localhost address. Ensure this is intentional and secure.",
266
+ UserWarning,
267
+ stacklevel=2,
268
+ )
269
+ return v
270
+
271
+ @field_validator("OTEL_EXPORTER_OTLP_ENDPOINT")
272
+ @classmethod
273
+ def validate_otel_endpoint(cls, v: str | None) -> str | None:
274
+ """Validate OpenTelemetry endpoint URL."""
275
+ if v is not None and v != "" and not v.startswith(("http://", "https://", "grpc://")):
276
+ raise ValueError("OpenTelemetry endpoint must start with http://, https://, or grpc://")
277
+ return v
278
+
279
+ @field_validator("S3_BUCKET")
280
+ @classmethod
281
+ def validate_s3_bucket_name(cls, v: str | None) -> str | None:
282
+ """Validate S3 bucket name format."""
283
+ if v is not None:
284
+ # S3 bucket naming rules
285
+ if len(v) < 3 or len(v) > 63:
286
+ raise ValueError("S3 bucket name must be 3-63 characters long")
287
+ if not v.replace("-", "").replace(".", "").isalnum():
288
+ raise ValueError("S3 bucket name can only contain lowercase letters, numbers, hyphens, and periods")
289
+ if v.startswith("-") or v.endswith("-"):
290
+ raise ValueError("S3 bucket name cannot start or end with a hyphen")
291
+ return v
292
+
293
+ @model_validator(mode="after")
294
+ def validate_provider_credentials(self) -> "Settings":
295
+ """Ensure required API keys are provided for the selected provider."""
296
+ if self.LLM_PROVIDER == LLMProvider.OPENAI:
297
+ if self.OPENAI_API_KEY is None:
298
+ raise ValueError(
299
+ "OPENAI_API_KEY is required when using OpenAI provider. "
300
+ "Set the OPENAI_API_KEY environment variable."
301
+ )
302
+ elif self.LLM_PROVIDER == LLMProvider.ANTHROPIC:
303
+ if self.ANTHROPIC_API_KEY is None:
304
+ raise ValueError(
305
+ "ANTHROPIC_API_KEY is required when using Anthropic provider. "
306
+ "Set the ANTHROPIC_API_KEY environment variable."
307
+ )
308
+ elif self.LLM_PROVIDER == LLMProvider.LMSTUDIO and self.LMSTUDIO_BASE_URL is None:
309
+ raise ValueError("LMSTUDIO_BASE_URL is required when using LM Studio provider.")
310
+ return self
311
+
312
+ def get_api_key(self) -> str | None:
313
+ """
314
+ Get the API key for the current provider.
315
+
316
+ Returns the secret value - use with caution to avoid logging.
317
+ """
318
+ if self.LLM_PROVIDER == LLMProvider.OPENAI and self.OPENAI_API_KEY:
319
+ return self.OPENAI_API_KEY.get_secret_value()
320
+ elif self.LLM_PROVIDER == LLMProvider.ANTHROPIC and self.ANTHROPIC_API_KEY:
321
+ return self.ANTHROPIC_API_KEY.get_secret_value()
322
+ return None
323
+
324
+ def safe_dict(self) -> dict:
325
+ """
326
+ Return settings as dictionary with secrets masked.
327
+
328
+ Safe for logging and display purposes.
329
+ """
330
+ data = self.model_dump()
331
+ # Mask all sensitive fields
332
+ secret_fields = [
333
+ "OPENAI_API_KEY",
334
+ "ANTHROPIC_API_KEY",
335
+ "BRAINTRUST_API_KEY",
336
+ "PINECONE_API_KEY",
337
+ "LANGSMITH_API_KEY",
338
+ "WANDB_API_KEY",
339
+ ]
340
+ for field in secret_fields:
341
+ if field in data and data[field]:
342
+ data[field] = "***MASKED***"
343
+ return data
344
+
345
+ def get_braintrust_api_key(self) -> str | None:
346
+ """
347
+ Get the Braintrust API key if configured.
348
+
349
+ Returns the secret value - use with caution to avoid logging.
350
+ """
351
+ if self.BRAINTRUST_API_KEY:
352
+ return self.BRAINTRUST_API_KEY.get_secret_value()
353
+ return None
354
+
355
+ def get_pinecone_api_key(self) -> str | None:
356
+ """
357
+ Get the Pinecone API key if configured.
358
+
359
+ Returns the secret value - use with caution to avoid logging.
360
+ """
361
+ if self.PINECONE_API_KEY:
362
+ return self.PINECONE_API_KEY.get_secret_value()
363
+ return None
364
+
365
+ def get_langsmith_api_key(self) -> str | None:
366
+ """
367
+ Get the LangSmith API key if configured.
368
+
369
+ Returns the secret value - use with caution to avoid logging.
370
+ """
371
+ if self.LANGSMITH_API_KEY:
372
+ return self.LANGSMITH_API_KEY.get_secret_value()
373
+ return None
374
+
375
+ def get_wandb_api_key(self) -> str | None:
376
+ """
377
+ Get the Weights & Biases API key if configured.
378
+
379
+ Returns the secret value - use with caution to avoid logging.
380
+ """
381
+ if self.WANDB_API_KEY:
382
+ return self.WANDB_API_KEY.get_secret_value()
383
+ return None
384
+
385
+ def __repr__(self) -> str:
386
+ """Safe string representation that doesn't expose secrets."""
387
+ return f"Settings(LLM_PROVIDER={self.LLM_PROVIDER}, MCTS_ENABLED={self.MCTS_ENABLED}, MCTS_IMPL={self.MCTS_IMPL}, LOG_LEVEL={self.LOG_LEVEL})"
388
+
389
+
390
+ # Global settings instance (lazily loaded)
391
+ _settings: Settings | None = None
392
+
393
+
394
+ def get_settings() -> Settings:
395
+ """
396
+ Get the global settings instance.
397
+
398
+ Settings are loaded once and cached. To reload, call reset_settings() first.
399
+
400
+ Returns:
401
+ Settings: Application configuration instance
402
+
403
+ Raises:
404
+ ValidationError: If configuration is invalid
405
+ """
406
+ global _settings
407
+ if _settings is None:
408
+ _settings = Settings()
409
+ return _settings
410
+
411
+
412
+ def reset_settings() -> None:
413
+ """
414
+ Reset the global settings instance.
415
+
416
+ Forces settings to be reloaded from environment on next get_settings() call.
417
+ Useful for testing.
418
+ """
419
+ global _settings
420
+ _settings = None
421
+
422
+
423
+ # Type exports for external use
424
+ __all__ = [
425
+ "Settings",
426
+ "LLMProvider",
427
+ "LogLevel",
428
+ "MCTSImplementation",
429
+ "get_settings",
430
+ "reset_settings",
431
+ ]
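
A short sketch of the intended usage: load settings once, log only the masked view, and read validated MCTS parameters. It assumes the provider-appropriate API key (e.g. `OPENAI_API_KEY`) is present in the environment or `.env`, since provider validation fails otherwise.

```python
# Sketch: load settings from the environment and log them without exposing secrets.
import logging

from src.config.settings import get_settings

logging.basicConfig(level=logging.INFO)
settings = get_settings()                          # cached after the first call
logging.info("config: %s", settings.safe_dict())   # secret fields show as ***MASKED***

if settings.MCTS_ENABLED:
    iterations = settings.MCTS_ITERATIONS   # already bounds-checked (1-10000)
    exploration = settings.MCTS_C
```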
src/data/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ """
2
+ Dataset Integration Module for Multi-Agent MCTS Training.
3
+
4
+ This module provides utilities for loading, preprocessing, and managing
5
+ open-source datasets for training HRM/TRM agents and neural meta-controllers.
6
+
7
+ Supported Datasets:
8
+ - DABStep: Multi-step reasoning tasks (CC-BY-4.0)
9
+ - PRIMUS-Seed: Cybersecurity domain knowledge (ODC-BY)
10
+ - PRIMUS-Instruct: Instruction fine-tuning data (ODC-BY)
11
+ """
12
+
13
+ from .dataset_loader import DABStepLoader, DatasetLoader, PRIMUSLoader
14
+ from .preprocessing import TextPreprocessor, TokenizerWrapper
15
+ from .tactical_augmentation import TacticalAugmenter
16
+ from .train_test_split import DataSplitter, StratifiedSplitter
17
+
18
+ __all__ = [
19
+ "DatasetLoader",
20
+ "DABStepLoader",
21
+ "PRIMUSLoader",
22
+ "TextPreprocessor",
23
+ "TokenizerWrapper",
24
+ "TacticalAugmenter",
25
+ "DataSplitter",
26
+ "StratifiedSplitter",
27
+ ]
28
+
29
+ __version__ = "1.0.0"
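
A minimal sketch of loading a filtered DABStep split through the package's public API; it assumes the `datasets` library is installed and the Hugging Face dataset is reachable.

```python
# Sketch: load easy DABStep samples via the src.data package exports.
from src.data import DABStepLoader

loader = DABStepLoader(cache_dir="./.cache/mcts_datasets")  # cache path is arbitrary
samples = loader.load(split="train", difficulty="easy")
print(f"{len(samples)} samples loaded")
if samples:
    print(samples[0].domain, samples[0].difficulty)
```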
src/data/dataset_loader.py ADDED
@@ -0,0 +1,551 @@
1
+ """
2
+ Dataset Loading Module for Open-Source Training Data.
3
+
4
+ Provides unified loading interfaces for:
5
+ - DABStep: Multi-step data analysis reasoning
6
+ - PRIMUS: Cybersecurity domain knowledge
7
+ - Custom tactical datasets
8
+
9
+ License Attribution:
10
+ - DABStep: CC-BY-4.0 (Creative Commons Attribution)
11
+ - PRIMUS: ODC-BY (Open Data Commons Attribution)
12
+ """
13
+
14
+ import logging
15
+ from abc import ABC, abstractmethod
16
+ from collections.abc import Iterator
17
+ from dataclasses import dataclass, field
18
+ from pathlib import Path
19
+ from typing import Any
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @dataclass
25
+ class DatasetSample:
26
+ """Standardized representation of a dataset sample."""
27
+
28
+ id: str
29
+ text: str
30
+ metadata: dict[str, Any] = field(default_factory=dict)
31
+ labels: list[str] | None = None
32
+ difficulty: str | None = None
33
+ domain: str | None = None
34
+ reasoning_steps: list[str] | None = None
35
+
36
+
37
+ @dataclass
38
+ class DatasetStatistics:
39
+ """Statistics about a loaded dataset."""
40
+
41
+ total_samples: int
42
+ domains: dict[str, int]
43
+ avg_text_length: float
44
+ difficulty_distribution: dict[str, int]
45
+ total_tokens: int = 0
46
+
47
+
48
+ class DatasetLoader(ABC):
49
+ """Abstract base class for dataset loaders."""
50
+
51
+ def __init__(self, cache_dir: str | None = None):
52
+ """
53
+ Initialize dataset loader.
54
+
55
+ Args:
56
+ cache_dir: Directory to cache downloaded datasets
57
+ """
58
+ self.cache_dir = cache_dir or str(Path.home() / ".cache" / "mcts_datasets")
59
+ self._dataset = None
60
+ self._statistics = None
61
+
62
+ @abstractmethod
63
+ def load(self, split: str = "train") -> list[DatasetSample]:
64
+ """Load dataset split."""
65
+ pass
66
+
67
+ @abstractmethod
68
+ def get_statistics(self) -> DatasetStatistics:
69
+ """Get dataset statistics."""
70
+ pass
71
+
72
+ @abstractmethod
73
+ def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]:
74
+ """Iterate over samples in batches."""
75
+ pass
76
+
77
+
78
+ class DABStepLoader(DatasetLoader):
79
+ """
80
+ Loader for DABStep Multi-Step Reasoning Dataset.
81
+
82
+ DABStep contains 450+ data analysis tasks requiring sequential,
+ iterative problem-solving, making it well suited for training HRM/TRM agents.
84
+
85
+ License: CC-BY-4.0 (Attribution required)
86
+ Source: huggingface.co/datasets/adyen/DABstep
87
+ """
88
+
89
+ DATASET_NAME = "adyen/DABstep"
90
+ DIFFICULTIES = ["easy", "medium", "hard"]
91
+
92
+ def __init__(self, cache_dir: str | None = None):
93
+ """Initialize DABStep loader."""
94
+ super().__init__(cache_dir)
95
+ self._loaded_samples: list[DatasetSample] = []
96
+
97
+ def load(self, split: str = "train", difficulty: str | None = None) -> list[DatasetSample]:
98
+ """
99
+ Load DABStep dataset.
100
+
101
+ Args:
102
+ split: Dataset split ('train', 'validation', 'test')
103
+ difficulty: Filter by difficulty ('easy', 'medium', 'hard')
104
+
105
+ Returns:
106
+ List of DatasetSample objects
107
+ """
108
+ try:
109
+ from datasets import load_dataset
110
+
111
+ logger.info(f"Loading DABStep dataset (split={split})")
112
+
113
+ dataset = load_dataset(
114
+ self.DATASET_NAME,
115
+ cache_dir=self.cache_dir,
116
+ )
117
+
118
+ if split not in dataset:
119
+ available_splits = list(dataset.keys())
120
+ logger.warning(f"Split '{split}' not found. Available: {available_splits}")
121
+ split = available_splits[0] if available_splits else "train"
122
+
123
+ samples = []
124
+ for idx, item in enumerate(dataset[split]):
125
+ sample = DatasetSample(
126
+ id=f"dabstep_{split}_{idx}",
127
+ text=str(item.get("question", item.get("text", ""))),
128
+ metadata={
129
+ "source": "DABStep",
130
+ "license": "CC-BY-4.0",
131
+ "split": split,
132
+ "original_data": item,
133
+ },
134
+ difficulty=item.get("difficulty", "medium"),
135
+ domain="data_analysis",
136
+ reasoning_steps=item.get("steps", []),
137
+ )
138
+
139
+ if difficulty and sample.difficulty != difficulty:
140
+ continue
141
+
142
+ samples.append(sample)
143
+
144
+ self._loaded_samples = samples
145
+ logger.info(f"Loaded {len(samples)} DABStep samples")
146
+ return samples
147
+
148
+ except ImportError:
149
+ logger.error("datasets library not installed. Run: pip install datasets")
150
+ raise
151
+ except Exception as e:
152
+ logger.error(f"Failed to load DABStep: {e}")
153
+ raise
154
+
155
+ def get_statistics(self) -> DatasetStatistics:
156
+ """Get statistics about loaded DABStep data."""
157
+ if not self._loaded_samples:
158
+ raise ValueError("No samples loaded. Call load() first.")
159
+
160
+ difficulty_dist = {}
161
+ total_length = 0
162
+
163
+ for sample in self._loaded_samples:
164
+ diff = sample.difficulty or "unknown"
165
+ difficulty_dist[diff] = difficulty_dist.get(diff, 0) + 1
166
+ total_length += len(sample.text)
167
+
168
+ return DatasetStatistics(
169
+ total_samples=len(self._loaded_samples),
170
+ domains={"data_analysis": len(self._loaded_samples)},
171
+ avg_text_length=total_length / len(self._loaded_samples),
172
+ difficulty_distribution=difficulty_dist,
173
+ )
174
+
175
+ def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]:
176
+ """Iterate over samples in batches."""
177
+ if not self._loaded_samples:
178
+ raise ValueError("No samples loaded. Call load() first.")
179
+
180
+ for i in range(0, len(self._loaded_samples), batch_size):
181
+ yield self._loaded_samples[i : i + batch_size]
182
+
183
+ def get_reasoning_tasks(self) -> list[DatasetSample]:
184
+ """Get only samples with explicit reasoning steps."""
185
+ return [s for s in self._loaded_samples if s.reasoning_steps]
186
+
187
+
188
+ class PRIMUSLoader(DatasetLoader):
189
+ """
190
+ Loader for PRIMUS Cybersecurity Dataset Suite.
191
+
192
+ PRIMUS contains:
193
+ - Seed: 674,848 cybersecurity documents (190M tokens)
194
+ - Instruct: 835 instruction-tuning samples
195
+ - Reasoning: Self-reflection data for reasoning
196
+
197
+ License: ODC-BY (Open Data Commons Attribution)
198
+ Source: huggingface.co/datasets/trendmicro-ailab/Primus-Seed
199
+ """
200
+
201
+ SEED_DATASET = "trendmicro-ailab/Primus-Seed"
202
+ INSTRUCT_DATASET = "trendmicro-ailab/Primus-Instruct"
203
+
204
+ DOMAINS = [
205
+ "mitre_attack",
206
+ "wikipedia",
207
+ "company_sites",
208
+ "threat_intelligence",
209
+ "vulnerability_db",
210
+ ]
211
+
212
+ def __init__(self, cache_dir: str | None = None):
213
+ """Initialize PRIMUS loader."""
214
+ super().__init__(cache_dir)
215
+ self._seed_samples: list[DatasetSample] = []
216
+ self._instruct_samples: list[DatasetSample] = []
217
+
218
+ def load(
219
+ self,
220
+ split: str = "train",
221
+ dataset_type: str = "seed",
222
+ domains: list[str] | None = None,
223
+ max_samples: int | None = None,
224
+ streaming: bool = True,
225
+ ) -> list[DatasetSample]:
226
+ """
227
+ Load PRIMUS dataset.
228
+
229
+ Args:
230
+ split: Dataset split ('train', 'validation', 'test')
231
+ dataset_type: 'seed' for knowledge base, 'instruct' for fine-tuning
232
+ domains: Filter by specific domains
233
+ max_samples: Limit number of samples (useful for large datasets)
234
+ streaming: Use streaming mode for large datasets (default True)
235
+
236
+ Returns:
237
+ List of DatasetSample objects
238
+ """
239
+ try:
240
+ from datasets import load_dataset
241
+
242
+ dataset_name = self.SEED_DATASET if dataset_type == "seed" else self.INSTRUCT_DATASET
243
+
244
+ logger.info(f"Loading PRIMUS {dataset_type} dataset")
245
+
246
+ # Use streaming for large seed dataset to avoid download issues
247
+ use_streaming = streaming and dataset_type == "seed" and max_samples is not None
248
+
249
+ if use_streaming:
250
+ logger.info(f"Using streaming mode (max_samples={max_samples})")
251
+ dataset = load_dataset(
252
+ dataset_name,
253
+ "default",
254
+ streaming=True,
255
+ cache_dir=self.cache_dir,
256
+ )
257
+ # For streaming, iterate the first available split
258
+ data_iter = iter(dataset["train"]) if "train" in dataset else iter(dataset[list(dataset.keys())[0]])
259
+ else:
260
+ dataset = load_dataset(
261
+ dataset_name,
262
+ cache_dir=self.cache_dir,
263
+ )
264
+
265
+ if split not in dataset:
266
+ available_splits = list(dataset.keys())
267
+ logger.warning(f"Split '{split}' not found. Using: {available_splits[0]}")
268
+ split = available_splits[0]
269
+
270
+ data_iter = iter(dataset[split])
271
+
272
+ samples = []
273
+ count = 0
274
+
275
+ for idx, item in enumerate(data_iter):
276
+ if max_samples and count >= max_samples:
277
+ break
278
+
279
+ domain = item.get("domain", item.get("source", "unknown"))
280
+
281
+ if domains and domain not in domains:
282
+ continue
283
+
284
+ if dataset_type == "instruct":
285
+ text = f"Instruction: {item.get('instruction', '')}\nResponse: {item.get('response', '')}"
286
+ else:
287
+ text = str(item.get("text", item.get("content", "")))
288
+
289
+ sample = DatasetSample(
290
+ id=f"primus_{dataset_type}_{split}_{idx}",
291
+ text=text,
292
+ metadata={
293
+ "source": f"PRIMUS-{dataset_type.capitalize()}",
294
+ "license": "ODC-BY",
295
+ "split": split,
296
+ "original_domain": domain,
297
+ },
298
+ domain=domain,
299
+ labels=item.get("labels", item.get("tags", [])),
300
+ )
301
+
302
+ samples.append(sample)
303
+ count += 1
304
+
305
+ if dataset_type == "seed":
306
+ self._seed_samples = samples
307
+ else:
308
+ self._instruct_samples = samples
309
+
310
+ logger.info(f"Loaded {len(samples)} PRIMUS {dataset_type} samples")
311
+ return samples
312
+
313
+ except ImportError:
314
+ logger.error("datasets library not installed. Run: pip install datasets")
315
+ raise
316
+ except Exception as e:
317
+ if "gated dataset" in str(e):
318
+ logger.error(
319
+ f"PRIMUS is a gated dataset. Please authenticate with HuggingFace:\n"
320
+ f"1. Create account at https://huggingface.co/\n"
321
+ f"2. Accept dataset terms at https://huggingface.co/datasets/{dataset_name}\n"
322
+ f"3. Create token at https://huggingface.co/settings/tokens\n"
323
+ f"4. Run: huggingface-cli login"
324
+ )
325
+ else:
326
+ logger.error(f"Failed to load PRIMUS: {e}")
327
+ raise
328
+
329
+ def load_seed(self, max_samples: int | None = None) -> list[DatasetSample]:
330
+ """Load PRIMUS-Seed knowledge base."""
331
+ return self.load(dataset_type="seed", max_samples=max_samples)
332
+
333
+ def load_instruct(self) -> list[DatasetSample]:
334
+ """Load PRIMUS-Instruct fine-tuning data."""
335
+ return self.load(dataset_type="instruct", streaming=False)
336
+
337
+ def get_statistics(self) -> DatasetStatistics:
338
+ """Get statistics about loaded PRIMUS data."""
339
+ all_samples = self._seed_samples + self._instruct_samples
340
+
341
+ if not all_samples:
342
+ raise ValueError("No samples loaded. Call load() first.")
343
+
344
+ domain_dist = {}
345
+ total_length = 0
346
+
347
+ for sample in all_samples:
348
+ domain = sample.domain or "unknown"
349
+ domain_dist[domain] = domain_dist.get(domain, 0) + 1
350
+ total_length += len(sample.text)
351
+
352
+ return DatasetStatistics(
353
+ total_samples=len(all_samples),
354
+ domains=domain_dist,
355
+ avg_text_length=total_length / len(all_samples),
356
+ difficulty_distribution={"cybersecurity": len(all_samples)},
357
+ )
358
+
359
+ def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]:
360
+ """Iterate over all loaded samples in batches."""
361
+ all_samples = self._seed_samples + self._instruct_samples
362
+
363
+ if not all_samples:
364
+ raise ValueError("No samples loaded. Call load() first.")
365
+
366
+ for i in range(0, len(all_samples), batch_size):
367
+ yield all_samples[i : i + batch_size]
368
+
369
+ def get_mitre_attack_samples(self) -> list[DatasetSample]:
370
+ """Get samples specifically from MITRE ATT&CK."""
371
+ return [s for s in self._seed_samples if "mitre" in (s.domain or "").lower()]
372
+
373
+ def get_threat_intelligence_samples(self) -> list[DatasetSample]:
374
+ """Get threat intelligence related samples."""
375
+ return [
376
+ s
377
+ for s in self._seed_samples
378
+ if any(kw in (s.domain or "").lower() for kw in ["threat", "cti", "intelligence"])
379
+ ]
380
+
381
+
382
+ class CombinedDatasetLoader:
383
+ """
384
+ Unified loader for combining multiple datasets.
385
+
386
+ Provides a single interface for loading and managing:
387
+ - DABStep (multi-step reasoning)
388
+ - PRIMUS (cybersecurity knowledge)
389
+ - Custom tactical datasets
390
+ """
391
+
392
+ def __init__(self, cache_dir: str | None = None):
393
+ """Initialize combined loader."""
394
+ self.cache_dir = cache_dir
395
+ self.dabstep_loader = DABStepLoader(cache_dir)
396
+ self.primus_loader = PRIMUSLoader(cache_dir)
397
+ self._all_samples: list[DatasetSample] = []
398
+
399
+ def load_all(
400
+ self,
401
+ dabstep_split: str = "train",
402
+ primus_max_samples: int | None = 10000,
403
+ include_instruct: bool = True,
404
+ ) -> list[DatasetSample]:
405
+ """
406
+ Load all datasets.
407
+
408
+ Args:
409
+ dabstep_split: Split for DABStep
410
+ primus_max_samples: Max samples from PRIMUS-Seed (None for all)
411
+ include_instruct: Whether to include PRIMUS-Instruct
412
+
413
+ Returns:
414
+ Combined list of all samples
415
+ """
416
+ logger.info("Loading combined datasets")
417
+
418
+ # Load DABStep
419
+ dabstep_samples = self.dabstep_loader.load(split=dabstep_split)
420
+ logger.info(f"DABStep: {len(dabstep_samples)} samples")
421
+
422
+ # Load PRIMUS-Seed
423
+ primus_seed = self.primus_loader.load_seed(max_samples=primus_max_samples)
424
+ logger.info(f"PRIMUS-Seed: {len(primus_seed)} samples")
425
+
426
+ # Load PRIMUS-Instruct
427
+ primus_instruct = []
428
+ if include_instruct:
429
+ primus_instruct = self.primus_loader.load_instruct()
430
+ logger.info(f"PRIMUS-Instruct: {len(primus_instruct)} samples")
431
+
432
+ self._all_samples = dabstep_samples + primus_seed + primus_instruct
433
+ logger.info(f"Total combined samples: {len(self._all_samples)}")
434
+
435
+ return self._all_samples
436
+
437
+ def get_domain_distribution(self) -> dict[str, int]:
438
+ """Get distribution of samples across domains."""
439
+ dist = {}
440
+ for sample in self._all_samples:
441
+ domain = sample.domain or "unknown"
442
+ dist[domain] = dist.get(domain, 0) + 1
443
+ return dist
444
+
445
+ def filter_by_domain(self, domain: str) -> list[DatasetSample]:
446
+ """Filter samples by domain."""
447
+ return [s for s in self._all_samples if s.domain == domain]
448
+
449
+ def get_multi_step_reasoning_samples(self) -> list[DatasetSample]:
450
+ """Get samples suitable for multi-step reasoning training."""
451
+ return [
452
+ s
453
+ for s in self._all_samples
454
+ if s.reasoning_steps or s.domain == "data_analysis" or "instruct" in s.metadata.get("source", "").lower()
455
+ ]
456
+
457
+ def export_for_training(self, output_path: str, format: str = "jsonl") -> str:
458
+ """
459
+ Export dataset for training.
460
+
461
+ Args:
462
+ output_path: Path to save exported data
463
+ format: Export format ('jsonl', 'csv', 'parquet')
464
+
465
+ Returns:
466
+ Path to exported file
467
+ """
468
+ import json
469
+
470
+ output_file = Path(output_path)
471
+ output_file.parent.mkdir(parents=True, exist_ok=True)
472
+
473
+ if format == "jsonl":
474
+ with open(output_file, "w", encoding="utf-8") as f:
475
+ for sample in self._all_samples:
476
+ record = {
477
+ "id": sample.id,
478
+ "text": sample.text,
479
+ "domain": sample.domain,
480
+ "difficulty": sample.difficulty,
481
+ "labels": sample.labels,
482
+ "metadata": sample.metadata,
483
+ }
484
+ f.write(json.dumps(record) + "\n")
485
+ else:
486
+ raise NotImplementedError(f"Format {format} not yet supported")
487
+
488
+ logger.info(f"Exported {len(self._all_samples)} samples to {output_file}")
489
+ return str(output_file)
490
+
491
+
492
+ def load_dataset(
493
+ dataset_name: str,
494
+ split: str = "train",
495
+ cache_dir: str | None = None,
496
+ **kwargs,
497
+ ) -> Any:
498
+ """
499
+ Unified interface for loading datasets from HuggingFace.
500
+
501
+ This function provides compatibility with the standard HuggingFace datasets API.
502
+ It wraps the underlying load_dataset function from the datasets library.
503
+
504
+ Args:
505
+ dataset_name: HuggingFace dataset identifier (e.g., "adyen/DABstep")
506
+ split: Dataset split to load ("train", "validation", "test")
507
+ cache_dir: Optional directory for caching downloaded datasets
508
+ **kwargs: Additional arguments passed to datasets.load_dataset
509
+
510
+ Returns:
511
+ HuggingFace Dataset object or dict of Dataset objects
512
+
513
+ Raises:
514
+ ImportError: If datasets library is not installed
515
+ Exception: If dataset loading fails
516
+
517
+ Examples:
518
+ >>> # Load DABStep dataset
519
+ >>> dataset = load_dataset("adyen/DABstep")
520
+ >>> samples = dataset["train"]
521
+
522
+ >>> # Load PRIMUS-Seed with custom cache
523
+ >>> dataset = load_dataset("trendmicro-ailab/Primus-Seed", cache_dir="/tmp/cache")
524
+
525
+ License Attribution:
526
+ - DABStep: CC-BY-4.0 (Creative Commons Attribution 4.0)
527
+ - PRIMUS: ODC-BY (Open Data Commons Attribution)
528
+ """
529
+ try:
530
+ from datasets import load_dataset as hf_load_dataset
531
+
532
+ logger.info(f"Loading dataset: {dataset_name} (split={split})")
533
+
534
+ load_kwargs = {
535
+ **kwargs,
536
+ }
537
+
538
+ if cache_dir:
539
+ load_kwargs["cache_dir"] = cache_dir
540
+
541
+ dataset = hf_load_dataset(dataset_name, **load_kwargs)
542
+
543
+ logger.info(f"Successfully loaded dataset: {dataset_name}")
544
+ return dataset
545
+
546
+ except ImportError:
547
+ logger.error("datasets library not installed. Run: pip install datasets")
548
+ raise ImportError("The datasets library is required but not installed. Install it with: pip install datasets")
549
+ except Exception as e:
550
+ logger.error(f"Failed to load dataset {dataset_name}: {e}")
551
+ raise
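A hedged sketch of how the loaders above are meant to compose (the HuggingFace dataset IDs come from the class constants; PRIMUS access additionally requires accepting the gated-dataset terms, as the error message above notes; the output path is invented):

from src.data.dataset_loader import CombinedDatasetLoader

combined = CombinedDatasetLoader(cache_dir="/tmp/mcts_datasets")
all_samples = combined.load_all(primus_max_samples=5000, include_instruct=True)
print(combined.get_domain_distribution())                   # counts per domain across all loaded samples
combined.export_for_training("exports/combined.jsonl", format="jsonl")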
src/data/preprocessing.py ADDED
@@ -0,0 +1,406 @@
1
+ """
2
+ Text Preprocessing Module for Training Data.
3
+
4
+ Provides utilities for:
5
+ - Text cleaning and normalization
6
+ - Tokenization with various backends
7
+ - Feature extraction for meta-controller training
8
+ """
9
+
10
+ import logging
11
+ import re
12
+ from dataclasses import dataclass
13
+ from typing import Any
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ @dataclass
19
+ class PreprocessedText:
20
+ """Preprocessed text with metadata."""
21
+
22
+ original: str
23
+ cleaned: str
24
+ tokens: list[str]
25
+ token_ids: list[int] | None = None
26
+ features: dict[str, Any] | None = None
27
+
28
+
29
+ class TextPreprocessor:
30
+ """
31
+ Text preprocessing pipeline for multi-agent training data.
32
+
33
+ Handles:
34
+ - HTML/XML tag removal
35
+ - Special character normalization
36
+ - Whitespace cleanup
37
+ - Domain-specific preprocessing (cyber, military, etc.)
38
+ """
39
+
40
+ # Patterns for cleaning
41
+ HTML_TAG_PATTERN = re.compile(r"<[^>]+>")
42
+ URL_PATTERN = re.compile(r"https?://\S+|www\.\S+")
43
+ MULTIPLE_SPACES = re.compile(r"\s+")
44
+ SPECIAL_CHARS = re.compile(r"[^\w\s\-.,!?;:()[\]{}\"'/]")
45
+
46
+ # Domain-specific patterns
47
+ IP_ADDRESS_PATTERN = re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b")
48
+ CVE_PATTERN = re.compile(r"CVE-\d{4}-\d{4,}")
49
+ MITRE_TECHNIQUE_PATTERN = re.compile(r"T\d{4}(?:\.\d{3})?")
50
+
51
+ def __init__(
52
+ self,
53
+ remove_html: bool = True,
54
+ normalize_urls: bool = True,
55
+ lowercase: bool = False,
56
+ preserve_domain_patterns: bool = True,
57
+ ):
58
+ """
59
+ Initialize preprocessor.
60
+
61
+ Args:
62
+ remove_html: Remove HTML/XML tags
63
+ normalize_urls: Replace URLs with placeholder
64
+ lowercase: Convert to lowercase
65
+ preserve_domain_patterns: Keep domain-specific patterns (IPs, CVEs, etc.)
66
+ """
67
+ self.remove_html = remove_html
68
+ self.normalize_urls = normalize_urls
69
+ self.lowercase = lowercase
70
+ self.preserve_domain_patterns = preserve_domain_patterns
71
+
72
+ def clean(self, text: str) -> str:
73
+ """
74
+ Clean and normalize text.
75
+
76
+ Args:
77
+ text: Raw input text
78
+
79
+ Returns:
80
+ Cleaned text
81
+ """
82
+ if not text:
83
+ return ""
84
+
85
+ result = text
86
+
87
+ # Remove HTML tags
88
+ if self.remove_html:
89
+ result = self.HTML_TAG_PATTERN.sub(" ", result)
90
+
91
+ # Preserve or normalize URLs
92
+ if self.normalize_urls:
93
+ if self.preserve_domain_patterns:
94
+ result = self.URL_PATTERN.sub("[URL]", result)
95
+ else:
96
+ result = self.URL_PATTERN.sub("", result)
97
+
98
+ # Normalize whitespace
99
+ result = self.MULTIPLE_SPACES.sub(" ", result)
100
+
101
+ # Lowercase if requested
102
+ if self.lowercase:
103
+ result = result.lower()
104
+
105
+ # Strip leading/trailing whitespace
106
+ result = result.strip()
107
+
108
+ return result
109
+
110
+ def extract_domain_features(self, text: str) -> dict[str, Any]:
111
+ """
112
+ Extract domain-specific features from text.
113
+
114
+ Args:
115
+ text: Input text
116
+
117
+ Returns:
118
+ Dictionary of extracted features
119
+ """
120
+ features = {
121
+ "has_ip_addresses": bool(self.IP_ADDRESS_PATTERN.search(text)),
122
+ "ip_count": len(self.IP_ADDRESS_PATTERN.findall(text)),
123
+ "has_cve": bool(self.CVE_PATTERN.search(text)),
124
+ "cve_ids": self.CVE_PATTERN.findall(text),
125
+ "has_mitre_techniques": bool(self.MITRE_TECHNIQUE_PATTERN.search(text)),
126
+ "mitre_techniques": self.MITRE_TECHNIQUE_PATTERN.findall(text),
127
+ "text_length": len(text),
128
+ "word_count": len(text.split()),
129
+ "sentence_count": len(re.findall(r"[.!?]+", text)),
130
+ }
131
+
132
+ # Detect domain indicators
133
+ domain_keywords = {
134
+ "cybersecurity": ["attack", "vulnerability", "exploit", "malware", "threat"],
135
+ "military": ["tactical", "reconnaissance", "deployment", "terrain", "objective"],
136
+ "data_analysis": ["dataset", "analysis", "correlation", "statistics", "visualization"],
137
+ }
138
+
139
+ for domain, keywords in domain_keywords.items():
140
+ features[f"is_{domain}"] = any(kw in text.lower() for kw in keywords)
141
+
142
+ return features
143
+
144
+ def preprocess(self, text: str) -> PreprocessedText:
145
+ """
146
+ Full preprocessing pipeline.
147
+
148
+ Args:
149
+ text: Raw input text
150
+
151
+ Returns:
152
+ PreprocessedText object with all preprocessing results
153
+ """
154
+ cleaned = self.clean(text)
155
+ tokens = cleaned.split() # Simple whitespace tokenization
156
+ features = self.extract_domain_features(text)
157
+
158
+ return PreprocessedText(
159
+ original=text,
160
+ cleaned=cleaned,
161
+ tokens=tokens,
162
+ features=features,
163
+ )
164
+
165
+ def batch_preprocess(self, texts: list[str]) -> list[PreprocessedText]:
166
+ """
167
+ Preprocess multiple texts.
168
+
169
+ Args:
170
+ texts: List of raw texts
171
+
172
+ Returns:
173
+ List of PreprocessedText objects
174
+ """
175
+ return [self.preprocess(text) for text in texts]
176
+
177
+
178
+ class TokenizerWrapper:
179
+ """
180
+ Wrapper for various tokenization backends.
181
+
182
+ Supports:
183
+ - Simple whitespace tokenization
184
+ - HuggingFace tokenizers
185
+ - Custom vocabularies
186
+ """
187
+
188
+ def __init__(
189
+ self,
190
+ backend: str = "simple",
191
+ model_name: str | None = None,
192
+ max_length: int = 512,
193
+ ):
194
+ """
195
+ Initialize tokenizer.
196
+
197
+ Args:
198
+ backend: Tokenizer backend ('simple', 'huggingface', 'custom')
199
+ model_name: Model name for HuggingFace tokenizer
200
+ max_length: Maximum sequence length
201
+ """
202
+ self.backend = backend
203
+ self.model_name = model_name
204
+ self.max_length = max_length
205
+ self._tokenizer = None
206
+
207
+ if backend == "huggingface" and model_name:
208
+ self._load_huggingface_tokenizer()
209
+
210
+ def _load_huggingface_tokenizer(self):
211
+ """Load HuggingFace tokenizer."""
212
+ try:
213
+ from transformers import AutoTokenizer
214
+
215
+ self._tokenizer = AutoTokenizer.from_pretrained(
216
+ self.model_name,
217
+ model_max_length=self.max_length,
218
+ )
219
+ logger.info(f"Loaded HuggingFace tokenizer: {self.model_name}")
220
+ except ImportError:
221
+ logger.error("transformers library not installed. Run: pip install transformers")
222
+ raise
223
+
224
+ def tokenize(self, text: str) -> tuple[list[str], list[int] | None]:
225
+ """
226
+ Tokenize text.
227
+
228
+ Args:
229
+ text: Input text
230
+
231
+ Returns:
232
+ Tuple of (tokens, token_ids)
233
+ """
234
+ if self.backend == "simple":
235
+ tokens = text.split()[: self.max_length]
236
+ return tokens, None
237
+
238
+ elif self.backend == "huggingface" and self._tokenizer:
239
+ encoded = self._tokenizer(
240
+ text,
241
+ truncation=True,
242
+ max_length=self.max_length,
243
+ return_tensors=None,
244
+ )
245
+ tokens = self._tokenizer.convert_ids_to_tokens(encoded["input_ids"])
246
+ token_ids = encoded["input_ids"]
247
+ return tokens, token_ids
248
+
249
+ else:
250
+ raise ValueError(f"Unsupported backend: {self.backend}")
251
+
252
+ def batch_tokenize(self, texts: list[str]) -> list[tuple[list[str], list[int] | None]]:
253
+ """
254
+ Tokenize multiple texts.
255
+
256
+ Args:
257
+ texts: List of input texts
258
+
259
+ Returns:
260
+ List of (tokens, token_ids) tuples
261
+ """
262
+ return [self.tokenize(text) for text in texts]
263
+
264
+ def encode_for_training(self, texts: list[str]) -> dict[str, Any]:
265
+ """
266
+ Encode texts for model training.
267
+
268
+ Args:
269
+ texts: List of input texts
270
+
271
+ Returns:
272
+ Dictionary with encoded data ready for training
273
+ """
274
+ if self.backend != "huggingface" or not self._tokenizer:
275
+ raise ValueError("encode_for_training requires HuggingFace backend")
276
+
277
+ encoded = self._tokenizer(
278
+ texts,
279
+ truncation=True,
280
+ padding=True,
281
+ max_length=self.max_length,
282
+ return_tensors="pt",
283
+ )
284
+
285
+ return encoded
286
+
287
+
288
+ class MetaControllerFeatureExtractor:
289
+ """
290
+ Extract features for meta-controller training.
291
+
292
+ Converts text and agent state information into numerical features
293
+ suitable for RNN/BERT routing decisions.
294
+ """
295
+
296
+ def __init__(self):
297
+ """Initialize feature extractor."""
298
+ self.preprocessor = TextPreprocessor()
299
+
300
+ def extract_query_features(self, query: str) -> dict[str, float]:
301
+ """
302
+ Extract numerical features from query text.
303
+
304
+ Args:
305
+ query: User query text
306
+
307
+ Returns:
308
+ Dictionary of numerical features
309
+ """
310
+ domain_features = self.preprocessor.extract_domain_features(query)
311
+
312
+ features = {
313
+ "query_length": domain_features["text_length"] / 10000, # Normalize
314
+ "word_count": domain_features["word_count"] / 500,
315
+ "sentence_count": domain_features["sentence_count"] / 50,
316
+ "has_technical_terms": float(
317
+ domain_features["has_ip_addresses"]
318
+ or domain_features["has_cve"]
319
+ or domain_features["has_mitre_techniques"]
320
+ ),
321
+ "is_cybersecurity": float(domain_features["is_cybersecurity"]),
322
+ "is_military": float(domain_features["is_military"]),
323
+ "is_data_analysis": float(domain_features["is_data_analysis"]),
324
+ "complexity_score": self._estimate_complexity(query),
325
+ }
326
+
327
+ return features
328
+
329
+ def _estimate_complexity(self, text: str) -> float:
330
+ """
331
+ Estimate query complexity (0-1 scale).
332
+
333
+ Args:
334
+ text: Input text
335
+
336
+ Returns:
337
+ Complexity score
338
+ """
339
+ # Simple heuristic based on length, technical terms, etc.
340
+ score = 0.0
341
+
342
+ # Length factor
343
+ word_count = len(text.split())
344
+ if word_count > 50:
345
+ score += 0.3
346
+ elif word_count > 20:
347
+ score += 0.1
348
+
349
+ # Technical term factor
350
+ technical_indicators = [
351
+ "analyze",
352
+ "compare",
353
+ "evaluate",
354
+ "synthesize",
355
+ "strategic",
356
+ "tactical",
357
+ "multi-step",
358
+ "consider",
359
+ ]
360
+ for term in technical_indicators:
361
+ if term in text.lower():
362
+ score += 0.1
363
+
364
+ # Question complexity
365
+ if "?" in text:
366
+ if any(kw in text.lower() for kw in ["why", "how", "what if"]):
367
+ score += 0.2
368
+ else:
369
+ score += 0.1
370
+
371
+ return min(score, 1.0)
372
+
373
+ def extract_agent_state_features(
374
+ self,
375
+ hrm_confidence: float = 0.0,
376
+ trm_confidence: float = 0.0,
377
+ mcts_iterations: int = 0,
378
+ consensus_score: float = 0.0,
379
+ rag_retrieved: int = 0,
380
+ ) -> list[float]:
381
+ """
382
+ Extract features from current agent state.
383
+
384
+ Args:
385
+ hrm_confidence: HRM agent confidence
386
+ trm_confidence: TRM agent confidence
387
+ mcts_iterations: MCTS iterations completed
388
+ consensus_score: Inter-agent consensus
389
+ rag_retrieved: Number of RAG documents retrieved
390
+
391
+ Returns:
392
+ List of normalized features (10-dimensional)
393
+ """
394
+ return [
395
+ hrm_confidence,
396
+ trm_confidence,
397
+ min(mcts_iterations / 1000, 1.0),
398
+ consensus_score,
399
+ min(rag_retrieved / 20, 1.0),
400
+ # Derived features
401
+ abs(hrm_confidence - trm_confidence), # Disagreement
402
+ (hrm_confidence + trm_confidence) / 2, # Average confidence
403
+ float(mcts_iterations > 0), # MCTS active
404
+ float(consensus_score > 0.7), # High consensus
405
+ float(rag_retrieved > 0), # RAG used
406
+ ]
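A short, assumed example of the preprocessing and feature-extraction pipeline defined above (the feature names and the 10-dimensional state vector follow the docstrings; the input strings are made up for illustration):

from src.data.preprocessing import TextPreprocessor, MetaControllerFeatureExtractor

pre = TextPreprocessor(remove_html=True, lowercase=False)
result = pre.preprocess("Analyze CVE-2024-1234 exploitation via phishing on 10.0.0.5")
print(result.features["has_cve"], result.features["is_cybersecurity"])

extractor = MetaControllerFeatureExtractor()
query_features = extractor.extract_query_features("How should we respond to T1566 activity?")
print(query_features["complexity_score"])
state = extractor.extract_agent_state_features(hrm_confidence=0.8, trm_confidence=0.6)
assert len(state) == 10                                     # matches the documented feature dimensionality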
src/data/tactical_augmentation.py ADDED
@@ -0,0 +1,484 @@
1
+ """
2
+ Tactical Data Augmentation Module.
3
+
4
+ Provides domain-specific data augmentation techniques for:
5
+ - Cybersecurity threat scenarios
6
+ - Military tactical situations
7
+ - Multi-step reasoning problems
8
+
9
+ These augmentations help increase training data diversity and improve
10
+ model robustness for tactical analysis tasks.
11
+ """
12
+
13
+ import logging
14
+ import random
15
+ from dataclasses import dataclass
16
+
17
+ from .dataset_loader import DatasetSample
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ @dataclass
23
+ class AugmentationResult:
24
+ """Result of data augmentation."""
25
+
26
+ original: DatasetSample
27
+ augmented: list[DatasetSample]
28
+ augmentation_types: list[str]
29
+
30
+
31
+ class TacticalAugmenter:
32
+ """
33
+ Domain-specific data augmentation for tactical analysis.
34
+
35
+ Augmentation techniques:
36
+ - Paraphrasing tactical scenarios
37
+ - Varying urgency levels
38
+ - Adding/removing constraints
39
+ - Scenario parameter variation
40
+ - Threat actor substitution
41
+ - Temporal shifting
42
+ """
43
+
44
+ # Tactical scenario templates
45
+ URGENCY_MODIFIERS = {
46
+ "high": ["IMMEDIATE", "CRITICAL", "URGENT", "TIME-SENSITIVE"],
47
+ "medium": ["PRIORITY", "IMPORTANT", "ATTENTION REQUIRED"],
48
+ "low": ["ROUTINE", "STANDARD", "WHEN POSSIBLE"],
49
+ }
50
+
51
+ THREAT_ACTORS = [
52
+ "APT28",
53
+ "APT29",
54
+ "Lazarus Group",
55
+ "Cozy Bear",
56
+ "Fancy Bear",
57
+ "Unknown Actor",
58
+ "Nation-State Actor",
59
+ "Criminal Organization",
60
+ ]
61
+
62
+ ATTACK_VECTORS = [
63
+ "phishing",
64
+ "spear-phishing",
65
+ "watering hole",
66
+ "supply chain compromise",
67
+ "zero-day exploit",
68
+ "credential stuffing",
69
+ "brute force",
70
+ "social engineering",
71
+ ]
72
+
73
+ MILITARY_OBJECTIVES = [
74
+ "secure perimeter",
75
+ "establish forward position",
76
+ "conduct reconnaissance",
77
+ "neutralize threat",
78
+ "protect assets",
79
+ "maintain operational security",
80
+ "coordinate with allied forces",
81
+ "execute tactical withdrawal",
82
+ ]
83
+
84
+ ENVIRONMENTAL_CONDITIONS = [
85
+ "night operations",
86
+ "adverse weather",
87
+ "limited visibility",
88
+ "urban terrain",
89
+ "mountainous region",
90
+ "coastal area",
91
+ "contested airspace",
92
+ "electronic warfare environment",
93
+ ]
94
+
95
+ def __init__(self, seed: int = 42):
96
+ """
97
+ Initialize augmenter.
98
+
99
+ Args:
100
+ seed: Random seed for reproducibility
101
+ """
102
+ self.rng = random.Random(seed)
103
+ self._augmentation_count = 0
104
+
105
+ def augment_sample(
106
+ self,
107
+ sample: DatasetSample,
108
+ num_augmentations: int = 3,
109
+ techniques: list[str] | None = None,
110
+ ) -> AugmentationResult:
111
+ """
112
+ Augment a single sample.
113
+
114
+ Args:
115
+ sample: Original dataset sample
116
+ num_augmentations: Number of augmented versions to create
117
+ techniques: Specific techniques to use (None for random selection)
118
+
119
+ Returns:
120
+ AugmentationResult with augmented samples
121
+ """
122
+ available_techniques = [
123
+ "urgency_variation",
124
+ "parameter_substitution",
125
+ "constraint_addition",
126
+ "temporal_shift",
127
+ "perspective_change",
128
+ ]
129
+
130
+ if techniques:
131
+ available_techniques = [t for t in techniques if t in available_techniques]
132
+
133
+ augmented_samples = []
134
+ used_techniques = []
135
+
136
+ for _i in range(num_augmentations):
137
+ technique = self.rng.choice(available_techniques)
138
+ used_techniques.append(technique)
139
+
140
+ augmented_text = self._apply_technique(sample.text, sample.domain, technique)
141
+
142
+ aug_sample = DatasetSample(
143
+ id=f"{sample.id}_aug_{self._augmentation_count}",
144
+ text=augmented_text,
145
+ metadata={
146
+ **sample.metadata,
147
+ "augmentation": technique,
148
+ "original_id": sample.id,
149
+ },
150
+ labels=sample.labels,
151
+ difficulty=sample.difficulty,
152
+ domain=sample.domain,
153
+ reasoning_steps=sample.reasoning_steps,
154
+ )
155
+
156
+ augmented_samples.append(aug_sample)
157
+ self._augmentation_count += 1
158
+
159
+ return AugmentationResult(
160
+ original=sample,
161
+ augmented=augmented_samples,
162
+ augmentation_types=used_techniques,
163
+ )
164
+
165
+ def _apply_technique(self, text: str, domain: str | None, technique: str) -> str:
166
+ """Apply specific augmentation technique."""
167
+ if technique == "urgency_variation":
168
+ return self._augment_urgency(text)
169
+ elif technique == "parameter_substitution":
170
+ return self._augment_parameters(text, domain)
171
+ elif technique == "constraint_addition":
172
+ return self._augment_constraints(text, domain)
173
+ elif technique == "temporal_shift":
174
+ return self._augment_temporal(text)
175
+ elif technique == "perspective_change":
176
+ return self._augment_perspective(text, domain)
177
+ else:
178
+ return text
179
+
180
+ def _augment_urgency(self, text: str) -> str:
181
+ """Vary urgency level in the text."""
182
+ urgency_level = self.rng.choice(list(self.URGENCY_MODIFIERS.keys()))
183
+ modifier = self.rng.choice(self.URGENCY_MODIFIERS[urgency_level])
184
+
185
+ # Add urgency prefix
186
+ if urgency_level == "high":
187
+ return f"[{modifier}] {text}"
188
+ elif urgency_level == "medium":
189
+ return f"{modifier}: {text}"
190
+ else:
191
+ return f"({modifier}) {text}"
192
+
193
+ def _augment_parameters(self, text: str, domain: str | None) -> str:
194
+ """Substitute domain-specific parameters."""
195
+ if domain == "cybersecurity" or "cyber" in text.lower():
196
+ # Substitute threat actors
197
+ for actor in self.THREAT_ACTORS:
198
+ if actor in text:
199
+ new_actor = self.rng.choice([a for a in self.THREAT_ACTORS if a != actor])
200
+ text = text.replace(actor, new_actor)
201
+ break
202
+
203
+ # Substitute attack vectors
204
+ for vector in self.ATTACK_VECTORS:
205
+ if vector in text.lower():
206
+ new_vector = self.rng.choice([v for v in self.ATTACK_VECTORS if v != vector])
207
+ text = text.replace(vector, new_vector)
208
+ break
209
+
210
+ elif domain == "military" or any(kw in text.lower() for kw in ["tactical", "military", "reconnaissance"]):
211
+ # Substitute objectives
212
+ for obj in self.MILITARY_OBJECTIVES:
213
+ if obj in text.lower():
214
+ new_obj = self.rng.choice([o for o in self.MILITARY_OBJECTIVES if o != obj])
215
+ text = text.replace(obj, new_obj)
216
+ break
217
+
218
+ return text
219
+
220
+ def _augment_constraints(self, text: str, domain: str | None) -> str:
221
+ """Add additional constraints to the scenario."""
222
+ constraints = []
223
+
224
+ if domain == "cybersecurity":
225
+ constraints = [
226
+ "with limited network visibility",
227
+ "under active attack",
228
+ "with compromised credentials",
229
+ "during maintenance window",
230
+ "with restricted access to logs",
231
+ ]
232
+ elif domain == "military":
233
+ constraints = [
234
+ "with limited ammunition",
235
+ "under communication blackout",
236
+ "with reduced personnel",
237
+ "in contested environment",
238
+ "with time constraint of 2 hours",
239
+ ]
240
+ else:
241
+ constraints = [
242
+ "with incomplete information",
243
+ "under time pressure",
244
+ "with resource constraints",
245
+ "considering multiple stakeholders",
246
+ "with conflicting objectives",
247
+ ]
248
+
249
+ if constraints:
250
+ constraint = self.rng.choice(constraints)
251
+ return f"{text} [{constraint}]"
252
+
253
+ return text
254
+
255
+ def _augment_temporal(self, text: str) -> str:
256
+ """Shift temporal context."""
257
+ temporal_contexts = [
258
+ "In the past 24 hours, ",
259
+ "Over the next week, ",
260
+ "Immediately, ",
261
+ "During the upcoming operation, ",
262
+ "Following initial assessment, ",
263
+ ]
264
+
265
+ context = self.rng.choice(temporal_contexts)
266
+ return f"{context}{text.lower()}" if text else text
267
+
268
+ def _augment_perspective(self, text: str, domain: str | None) -> str:
269
+ """Change analytical perspective."""
270
+ perspectives = {
271
+ "cybersecurity": [
272
+ "From a threat hunter's perspective: ",
273
+ "Considering the attacker's viewpoint: ",
274
+ "For incident response purposes: ",
275
+ "From a risk management standpoint: ",
276
+ ],
277
+ "military": [
278
+ "From the commander's perspective: ",
279
+ "Considering enemy capabilities: ",
280
+ "For tactical planning purposes: ",
281
+ "From a logistics standpoint: ",
282
+ ],
283
+ "default": [
284
+ "From an analytical perspective: ",
285
+ "Considering all factors: ",
286
+ "For decision-making purposes: ",
287
+ "From a strategic viewpoint: ",
288
+ ],
289
+ }
290
+
291
+ domain_perspectives = perspectives.get(domain or "default", perspectives["default"])
292
+ perspective = self.rng.choice(domain_perspectives)
293
+
294
+ return f"{perspective}{text}"
295
+
296
+ def augment_batch(
297
+ self,
298
+ samples: list[DatasetSample],
299
+ augmentations_per_sample: int = 2,
300
+ ) -> list[DatasetSample]:
301
+ """
302
+ Augment a batch of samples.
303
+
304
+ Args:
305
+ samples: List of original samples
306
+ augmentations_per_sample: Number of augmentations per sample
307
+
308
+ Returns:
309
+ List of all samples (original + augmented)
310
+ """
311
+ all_samples = list(samples) # Keep originals
312
+
313
+ for sample in samples:
314
+ result = self.augment_sample(sample, num_augmentations=augmentations_per_sample)
315
+ all_samples.extend(result.augmented)
316
+
317
+ logger.info(
318
+ f"Augmented {len(samples)} samples to {len(all_samples)} (+{len(all_samples) - len(samples)} augmented)"
319
+ )
320
+
321
+ return all_samples
322
+
323
+ def create_tactical_scenarios(self, base_samples: list[DatasetSample]) -> list[DatasetSample]:
324
+ """
325
+ Create tactical scenario variations from base samples.
326
+
327
+ Combines multiple augmentation techniques to create
328
+ diverse tactical scenarios for training.
329
+
330
+ Args:
331
+ base_samples: Base dataset samples
332
+
333
+ Returns:
334
+ Extended list with tactical scenario variations
335
+ """
336
+ scenarios = list(base_samples)
337
+
338
+ for sample in base_samples:
339
+ # Create high-stakes variant
340
+ high_stakes = self._augment_urgency(sample.text)
341
+ high_stakes = self._augment_constraints(high_stakes, sample.domain)
342
+ scenarios.append(
343
+ DatasetSample(
344
+ id=f"{sample.id}_highstakes_{self._augmentation_count}",
345
+ text=high_stakes,
346
+ metadata={
347
+ **sample.metadata,
348
+ "scenario_type": "high_stakes",
349
+ "original_id": sample.id,
350
+ },
351
+ labels=sample.labels,
352
+ difficulty="hard", # High stakes scenarios are harder
353
+ domain=sample.domain,
354
+ reasoning_steps=sample.reasoning_steps,
355
+ )
356
+ )
357
+ self._augmentation_count += 1
358
+
359
+ # Create multi-perspective variant
360
+ if self.rng.random() > 0.5:
361
+ multi_perspective = self._augment_perspective(sample.text, sample.domain)
362
+ scenarios.append(
363
+ DatasetSample(
364
+ id=f"{sample.id}_multiperspective_{self._augmentation_count}",
365
+ text=multi_perspective,
366
+ metadata={
367
+ **sample.metadata,
368
+ "scenario_type": "multi_perspective",
369
+ "original_id": sample.id,
370
+ },
371
+ labels=sample.labels,
372
+ difficulty=sample.difficulty,
373
+ domain=sample.domain,
374
+ reasoning_steps=sample.reasoning_steps,
375
+ )
376
+ )
377
+ self._augmentation_count += 1
378
+
379
+ logger.info(f"Created {len(scenarios) - len(base_samples)} tactical scenarios")
380
+ return scenarios
381
+
382
+
383
+ class CyberSecurityAugmenter(TacticalAugmenter):
384
+ """
385
+ Specialized augmenter for cybersecurity scenarios.
386
+
387
+ Focuses on:
388
+ - MITRE ATT&CK technique variations
389
+ - Threat intelligence context
390
+ - Incident response scenarios
391
+ """
392
+
393
+ MITRE_TACTICS = [
394
+ "Initial Access",
395
+ "Execution",
396
+ "Persistence",
397
+ "Privilege Escalation",
398
+ "Defense Evasion",
399
+ "Credential Access",
400
+ "Discovery",
401
+ "Lateral Movement",
402
+ "Collection",
403
+ "Exfiltration",
404
+ "Impact",
405
+ ]
406
+
407
+ SEVERITY_LEVELS = ["LOW", "MEDIUM", "HIGH", "CRITICAL"]
408
+
409
+ def augment_with_mitre_context(self, sample: DatasetSample) -> DatasetSample:
410
+ """
411
+ Add MITRE ATT&CK context to sample.
412
+
413
+ Args:
414
+ sample: Original sample
415
+
416
+ Returns:
417
+ Augmented sample with MITRE context
418
+ """
419
+ tactic = self.rng.choice(self.MITRE_TACTICS)
420
+ severity = self.rng.choice(self.SEVERITY_LEVELS)
421
+
422
+ augmented_text = f"[MITRE ATT&CK: {tactic}] [Severity: {severity}] {sample.text}"
423
+
424
+ return DatasetSample(
425
+ id=f"{sample.id}_mitre_{self._augmentation_count}",
426
+ text=augmented_text,
427
+ metadata={
428
+ **sample.metadata,
429
+ "mitre_tactic": tactic,
430
+ "severity": severity,
431
+ },
432
+ labels=sample.labels,
433
+ difficulty=sample.difficulty,
434
+ domain="cybersecurity",
435
+ reasoning_steps=sample.reasoning_steps,
436
+ )
437
+
438
+
439
+ class MilitaryTacticalAugmenter(TacticalAugmenter):
440
+ """
441
+ Specialized augmenter for military tactical scenarios.
442
+
443
+ Focuses on:
444
+ - Environmental condition variations
445
+ - Force composition changes
446
+ - Mission objective variations
447
+ """
448
+
449
+ FORCE_COMPOSITIONS = [
450
+ "infantry platoon",
451
+ "mechanized company",
452
+ "special operations team",
453
+ "combined arms battalion",
454
+ "air assault element",
455
+ ]
456
+
457
+ def augment_with_force_composition(self, sample: DatasetSample) -> DatasetSample:
458
+ """
459
+ Add force composition context to sample.
460
+
461
+ Args:
462
+ sample: Original sample
463
+
464
+ Returns:
465
+ Augmented sample with force composition
466
+ """
467
+ force = self.rng.choice(self.FORCE_COMPOSITIONS)
468
+ condition = self.rng.choice(self.ENVIRONMENTAL_CONDITIONS)
469
+
470
+ augmented_text = f"[Force: {force}] [Conditions: {condition}] {sample.text}"
471
+
472
+ return DatasetSample(
473
+ id=f"{sample.id}_tactical_{self._augmentation_count}",
474
+ text=augmented_text,
475
+ metadata={
476
+ **sample.metadata,
477
+ "force_composition": force,
478
+ "environmental_conditions": condition,
479
+ },
480
+ labels=sample.labels,
481
+ difficulty=sample.difficulty,
482
+ domain="military",
483
+ reasoning_steps=sample.reasoning_steps,
484
+ )
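An illustrative sketch of the augmenters above (seeded for reproducibility; the exact augmented text depends on the random draws, and the sample text is invented):

from src.data.dataset_loader import DatasetSample
from src.data.tactical_augmentation import TacticalAugmenter, CyberSecurityAugmenter

sample = DatasetSample(id="demo_1", text="Assess a phishing campaign against the finance team", domain="cybersecurity")
augmenter = TacticalAugmenter(seed=42)
result = augmenter.augment_sample(sample, num_augmentations=3)
print(result.augmentation_types)                            # e.g. ['urgency_variation', 'temporal_shift', ...]

cyber = CyberSecurityAugmenter(seed=7)
mitre_variant = cyber.augment_with_mitre_context(sample)    # prefixes a MITRE tactic and severity tag
print(mitre_variant.text[:60])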
src/data/train_test_split.py ADDED
@@ -0,0 +1,505 @@
1
+ """
2
+ Data Splitting Module for Training Pipeline.
3
+
4
+ Provides utilities for:
5
+ - Train/validation/test splitting
6
+ - Stratified sampling by domain or difficulty
7
+ - Cross-validation fold creation
8
+ - Reproducible splits with seeding
9
+ """
10
+
11
+ import logging
12
+ from collections import defaultdict
13
+ from dataclasses import dataclass
14
+ from typing import Any
15
+
16
+ from .dataset_loader import DatasetSample
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class DataSplit:
23
+ """Result of dataset splitting."""
24
+
25
+ train: list[DatasetSample]
26
+ validation: list[DatasetSample]
27
+ test: list[DatasetSample]
28
+ split_info: dict[str, Any]
29
+
30
+
31
+ @dataclass
32
+ class CrossValidationFold:
33
+ """Single fold for cross-validation."""
34
+
35
+ fold_id: int
36
+ train: list[DatasetSample]
37
+ validation: list[DatasetSample]
38
+
39
+
40
+ class DataSplitter:
41
+ """
42
+ Basic dataset splitter with random sampling.
43
+
44
+ Provides reproducible train/validation/test splits
45
+ with configurable ratios.
46
+ """
47
+
48
+ def __init__(self, seed: int = 42):
49
+ """
50
+ Initialize splitter.
51
+
52
+ Args:
53
+ seed: Random seed for reproducibility
54
+ """
55
+ self.seed = seed
56
+ import random
57
+
58
+ self.rng = random.Random(seed)
59
+
60
+ def split(
61
+ self,
62
+ samples: list[DatasetSample],
63
+ train_ratio: float = 0.7,
64
+ val_ratio: float = 0.15,
65
+ test_ratio: float = 0.15,
66
+ shuffle: bool = True,
67
+ ) -> DataSplit:
68
+ """
69
+ Split dataset into train/validation/test sets.
70
+
71
+ Args:
72
+ samples: List of all samples
73
+ train_ratio: Proportion for training (default 0.7)
74
+ val_ratio: Proportion for validation (default 0.15)
75
+ test_ratio: Proportion for testing (default 0.15)
76
+ shuffle: Whether to shuffle before splitting
77
+
78
+ Returns:
79
+ DataSplit with train, validation, and test sets
80
+ """
81
+ if abs(train_ratio + val_ratio + test_ratio - 1.0) > 0.001:
82
+ raise ValueError("Ratios must sum to 1.0")
83
+
84
+ if not samples:
85
+ raise ValueError("Cannot split empty sample list")
86
+
87
+ # Copy and optionally shuffle
88
+ all_samples = list(samples)
89
+ if shuffle:
90
+ self.rng.shuffle(all_samples)
91
+
92
+ n = len(all_samples)
93
+ train_end = int(n * train_ratio)
94
+ val_end = train_end + int(n * val_ratio)
95
+
96
+ train_samples = all_samples[:train_end]
97
+ val_samples = all_samples[train_end:val_end]
98
+ test_samples = all_samples[val_end:]
99
+
100
+ split_info = {
101
+ "total_samples": n,
102
+ "train_samples": len(train_samples),
103
+ "val_samples": len(val_samples),
104
+ "test_samples": len(test_samples),
105
+ "train_ratio": len(train_samples) / n,
106
+ "val_ratio": len(val_samples) / n,
107
+ "test_ratio": len(test_samples) / n,
108
+ "seed": self.seed,
109
+ "shuffled": shuffle,
110
+ }
111
+
112
+ logger.info(f"Split {n} samples: train={len(train_samples)}, val={len(val_samples)}, test={len(test_samples)}")
113
+
114
+ return DataSplit(
115
+ train=train_samples,
116
+ validation=val_samples,
117
+ test=test_samples,
118
+ split_info=split_info,
119
+ )
120
+
121
+ def create_k_folds(
122
+ self,
123
+ samples: list[DatasetSample],
124
+ k: int = 5,
125
+ shuffle: bool = True,
126
+ ) -> list[CrossValidationFold]:
127
+ """
128
+ Create k-fold cross-validation splits.
129
+
130
+ Args:
131
+ samples: List of all samples
132
+ k: Number of folds
133
+ shuffle: Whether to shuffle before splitting
134
+
135
+ Returns:
136
+ List of CrossValidationFold objects
137
+ """
138
+ if k < 2:
139
+ raise ValueError("k must be at least 2")
140
+
141
+ if len(samples) < k:
142
+ raise ValueError(f"Need at least {k} samples for {k}-fold CV")
143
+
144
+ # Copy and optionally shuffle
145
+ all_samples = list(samples)
146
+ if shuffle:
147
+ self.rng.shuffle(all_samples)
148
+
149
+ # Calculate fold sizes
150
+ fold_size = len(all_samples) // k
151
+ folds = []
152
+
153
+ for fold_id in range(k):
154
+ # Validation is the current fold
155
+ val_start = fold_id * fold_size
156
+ val_end = len(all_samples) if fold_id == k - 1 else val_start + fold_size # noqa: SIM108
157
+
158
+ val_samples = all_samples[val_start:val_end]
159
+ train_samples = all_samples[:val_start] + all_samples[val_end:]
160
+
161
+ folds.append(
162
+ CrossValidationFold(
163
+ fold_id=fold_id,
164
+ train=train_samples,
165
+ validation=val_samples,
166
+ )
167
+ )
168
+
169
+ logger.info(f"Created {k}-fold cross-validation splits")
170
+ return folds
171
+
172
+
173
+ class StratifiedSplitter(DataSplitter):
174
+ """
175
+ Stratified dataset splitter.
176
+
177
+ Ensures proportional representation of categories
178
+ (domain, difficulty, etc.) across splits.
179
+ """
180
+
181
+ def __init__(self, seed: int = 42, stratify_by: str = "domain"):
182
+ """
183
+ Initialize stratified splitter.
184
+
185
+ Args:
186
+ seed: Random seed for reproducibility
187
+ stratify_by: Attribute to stratify on ('domain', 'difficulty', 'labels')
188
+ """
189
+ super().__init__(seed)
190
+ self.stratify_by = stratify_by
191
+
192
+ def split(
193
+ self,
194
+ samples: list[DatasetSample],
195
+ train_ratio: float = 0.7,
196
+ val_ratio: float = 0.15,
197
+ test_ratio: float = 0.15,
198
+ shuffle: bool = True,
199
+ ) -> DataSplit:
200
+ """
201
+ Stratified split maintaining category proportions.
202
+
203
+ Args:
204
+ samples: List of all samples
205
+ train_ratio: Proportion for training
206
+ val_ratio: Proportion for validation
207
+ test_ratio: Proportion for testing
208
+ shuffle: Whether to shuffle before splitting
209
+
210
+ Returns:
211
+ DataSplit with stratified train, validation, and test sets
212
+ """
213
+ if abs(train_ratio + val_ratio + test_ratio - 1.0) > 0.001:
214
+ raise ValueError("Ratios must sum to 1.0")
215
+
216
+ if not samples:
217
+ raise ValueError("Cannot split empty sample list")
218
+
219
+ # Group samples by stratification key
220
+ groups = defaultdict(list)
221
+ for sample in samples:
222
+ key = self._get_stratify_key(sample)
223
+ groups[key].append(sample)
224
+
225
+ # Split each group proportionally
226
+ train_samples = []
227
+ val_samples = []
228
+ test_samples = []
229
+
230
+ for _key, group_samples in groups.items():
231
+ if shuffle:
232
+ self.rng.shuffle(group_samples)
233
+
234
+ n = len(group_samples)
235
+ train_end = int(n * train_ratio)
236
+ val_end = train_end + int(n * val_ratio)
237
+
238
+ train_samples.extend(group_samples[:train_end])
239
+ val_samples.extend(group_samples[train_end:val_end])
240
+ test_samples.extend(group_samples[val_end:])
241
+
242
+ # Final shuffle of combined sets
243
+ if shuffle:
244
+ self.rng.shuffle(train_samples)
245
+ self.rng.shuffle(val_samples)
246
+ self.rng.shuffle(test_samples)
247
+
248
+ # Verify stratification
249
+ stratify_info = self._verify_stratification(train_samples, val_samples, test_samples)
250
+
251
+ split_info = {
252
+ "total_samples": len(samples),
253
+ "train_samples": len(train_samples),
254
+ "val_samples": len(val_samples),
255
+ "test_samples": len(test_samples),
256
+ "train_ratio": len(train_samples) / len(samples),
257
+ "val_ratio": len(val_samples) / len(samples),
258
+ "test_ratio": len(test_samples) / len(samples),
259
+ "stratify_by": self.stratify_by,
260
+ "stratification_info": stratify_info,
261
+ "seed": self.seed,
262
+ "shuffled": shuffle,
263
+ }
264
+
265
+ logger.info(
266
+ f"Stratified split ({self.stratify_by}): "
267
+ f"train={len(train_samples)}, val={len(val_samples)}, "
268
+ f"test={len(test_samples)}"
269
+ )
270
+
271
+ return DataSplit(
272
+ train=train_samples,
273
+ validation=val_samples,
274
+ test=test_samples,
275
+ split_info=split_info,
276
+ )
277
+
278
+ def _get_stratify_key(self, sample: DatasetSample) -> str:
279
+ """Get stratification key for a sample."""
280
+ if self.stratify_by == "domain":
281
+ return sample.domain or "unknown"
282
+ elif self.stratify_by == "difficulty":
283
+ return sample.difficulty or "unknown"
284
+ elif self.stratify_by == "labels":
285
+ return ",".join(sorted(sample.labels)) if sample.labels else "unknown"
286
+ else:
287
+ return str(getattr(sample, self.stratify_by, "unknown"))
288
+
289
+ def _verify_stratification(
290
+ self,
291
+ train: list[DatasetSample],
292
+ val: list[DatasetSample],
293
+ test: list[DatasetSample],
294
+ ) -> dict[str, dict[str, float]]:
295
+ """
296
+ Verify that stratification was successful.
297
+
298
+ Returns dictionary showing distribution of stratification key
299
+ across train/val/test splits.
300
+ """
301
+
302
+ def get_distribution(samples: list[DatasetSample]) -> dict[str, float]:
303
+ if not samples:
304
+ return {}
305
+ counts = defaultdict(int)
306
+ for sample in samples:
307
+ key = self._get_stratify_key(sample)
308
+ counts[key] += 1
309
+ total = len(samples)
310
+ return {k: v / total for k, v in counts.items()}
311
+
312
+ return {
313
+ "train": get_distribution(train),
314
+ "validation": get_distribution(val),
315
+ "test": get_distribution(test),
316
+ }
317
+
318
+ def create_stratified_k_folds(
319
+ self,
320
+ samples: list[DatasetSample],
321
+ k: int = 5,
322
+ shuffle: bool = True,
323
+ ) -> list[CrossValidationFold]:
324
+ """
325
+ Create stratified k-fold cross-validation splits.
326
+
327
+ Args:
328
+ samples: List of all samples
329
+ k: Number of folds
330
+ shuffle: Whether to shuffle before splitting
331
+
332
+ Returns:
333
+ List of CrossValidationFold objects with stratification
334
+ """
335
+ if k < 2:
336
+ raise ValueError("k must be at least 2")
337
+
338
+ # Group samples by stratification key
339
+ groups = defaultdict(list)
340
+ for sample in samples:
341
+ key = self._get_stratify_key(sample)
342
+ groups[key].append(sample)
343
+
344
+ # Initialize folds
345
+ folds_data = [{"train": [], "val": []} for _ in range(k)]
346
+
347
+ # Distribute each group across folds
348
+ for _key, group_samples in groups.items():
349
+ if shuffle:
350
+ self.rng.shuffle(group_samples)
351
+
352
+ # Assign samples to folds
353
+ fold_size = len(group_samples) // k
354
+ for fold_id in range(k):
355
+ val_start = fold_id * fold_size
356
+ val_end = len(group_samples) if fold_id == k - 1 else val_start + fold_size
357
+
358
+ for i, sample in enumerate(group_samples):
359
+ if val_start <= i < val_end:
360
+ folds_data[fold_id]["val"].append(sample)
361
+ else:
362
+ folds_data[fold_id]["train"].append(sample)
363
+
364
+ # Create fold objects
365
+ folds = [
366
+ CrossValidationFold(
367
+ fold_id=i,
368
+ train=data["train"],
369
+ validation=data["val"],
370
+ )
371
+ for i, data in enumerate(folds_data)
372
+ ]
373
+
374
+ logger.info(f"Created stratified {k}-fold cross-validation splits")
375
+ return folds
376
+
377
+
378
+ class BalancedSampler:
+     """
+     Balanced sampling for imbalanced datasets.
+
+     Provides utilities for:
+     - Oversampling minority classes
+     - Undersampling majority classes
+     - SMOTE-like synthetic sampling (for numerical features)
+     """
+
+     def __init__(self, seed: int = 42):
+         """Initialize balanced sampler."""
+         self.seed = seed
+         import random
+
+         self.rng = random.Random(seed)
+
+     def oversample_minority(
+         self,
+         samples: list[DatasetSample],
+         target_key: str = "domain",
+         target_ratio: float = 1.0,
+     ) -> list[DatasetSample]:
+         """
+         Oversample minority classes to balance dataset.
+
+         Args:
+             samples: Original samples
+             target_key: Attribute to balance on
+             target_ratio: Target ratio relative to majority (1.0 = equal)
+
+         Returns:
+             Balanced sample list (originals + oversampled)
+         """
+         # Group by target key
+         groups = defaultdict(list)
+         for sample in samples:
+             key = getattr(sample, target_key, "unknown") or "unknown"
+             groups[key].append(sample)
+
+         # Find majority class size
+         max_count = max(len(g) for g in groups.values())
+         target_count = int(max_count * target_ratio)
+
+         # Oversample minority classes
+         balanced = []
+         for _key, group in groups.items():
+             balanced.extend(group)
+
+             # Oversample if needed
+             if len(group) < target_count:
+                 num_to_add = target_count - len(group)
+                 for _ in range(num_to_add):
+                     # Randomly duplicate from group
+                     original = self.rng.choice(group)
+                     duplicate = DatasetSample(
+                         id=f"{original.id}_oversample_{self.rng.randint(0, 999999)}",
+                         text=original.text,
+                         metadata={**original.metadata, "oversampled": True},
+                         labels=original.labels,
+                         difficulty=original.difficulty,
+                         domain=original.domain,
+                         reasoning_steps=original.reasoning_steps,
+                     )
+                     balanced.append(duplicate)
+
+         logger.info(f"Oversampled from {len(samples)} to {len(balanced)} samples")
+         return balanced
+
+     def undersample_majority(
+         self,
+         samples: list[DatasetSample],
+         target_key: str = "domain",
+         target_ratio: float = 1.0,
+     ) -> list[DatasetSample]:
+         """
+         Undersample majority classes to balance dataset.
+
+         Args:
+             samples: Original samples
+             target_key: Attribute to balance on
+             target_ratio: Target ratio relative to minority (1.0 = equal)
+
+         Returns:
+             Balanced sample list (subset of originals)
+         """
+         # Group by target key
+         groups = defaultdict(list)
+         for sample in samples:
+             key = getattr(sample, target_key, "unknown") or "unknown"
+             groups[key].append(sample)
+
+         # Find minority class size
+         min_count = min(len(g) for g in groups.values())
+         target_count = int(min_count * target_ratio)
+
+         # Undersample majority classes
+         balanced = []
+         for _key, group in groups.items():
+             if len(group) > target_count:
+                 # Randomly select target_count samples
+                 balanced.extend(self.rng.sample(group, target_count))
+             else:
+                 balanced.extend(group)
+
+         logger.info(f"Undersampled from {len(samples)} to {len(balanced)} samples")
+         return balanced
+
+     def get_class_distribution(
+         self,
+         samples: list[DatasetSample],
+         target_key: str = "domain",
+     ) -> dict[str, int]:
+         """
+         Get distribution of classes.
+
+         Args:
+             samples: Sample list
+             target_key: Attribute to analyze
+
+         Returns:
+             Dictionary of class counts
+         """
+         distribution = defaultdict(int)
+         for sample in samples:
+             key = getattr(sample, target_key, "unknown") or "unknown"
+             distribution[key] += 1
+         return dict(distribution)
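A short, hedged sketch of the balancing workflow defined above (not part of the diff). The class and method names are taken from the file; `samples` is assumed to be the same `list[DatasetSample]` as in the previous sketch, and the printed counts are illustrative only.

# Hedged usage sketch for BalancedSampler -- signatures come from the diff above.
sampler = BalancedSampler(seed=42)

print(sampler.get_class_distribution(samples, target_key="domain"))
# e.g. {"chess": 60, "planning": 30}  (illustrative)

upsampled = sampler.oversample_minority(samples, target_key="domain", target_ratio=1.0)
downsampled = sampler.undersample_majority(samples, target_key="domain", target_ratio=1.0)

print(sampler.get_class_distribution(upsampled, target_key="domain"))    # minority raised toward majority count
print(sampler.get_class_distribution(downsampled, target_key="domain"))  # majority trimmed toward minority count

Note that oversampled duplicates receive a fresh id and `metadata["oversampled"] = True`, so downstream deduplication or leakage checks can filter them out if needed.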
src/framework/__init__.py ADDED
@@ -0,0 +1 @@
+ # Framework module
src/framework/agents/__init__.py ADDED
@@ -0,0 +1,22 @@
+ # Agents module for async agent implementations
+ from .base import (
+     AgentContext,
+     AgentResult,
+     AsyncAgentBase,
+     CompositeAgent,
+     MetricsCollector,
+     NoOpMetricsCollector,
+     ParallelAgent,
+     SequentialAgent,
+ )
+
+ __all__ = [
+     "AsyncAgentBase",
+     "AgentContext",
+     "AgentResult",
+     "MetricsCollector",
+     "NoOpMetricsCollector",
+     "CompositeAgent",
+     "ParallelAgent",
+     "SequentialAgent",
+ ]
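A minimal import sketch for the new package (not part of the diff). The names come from `__all__` above; the `AsyncAgentBase` interface itself lives in src/framework/agents/base.py and is not shown here, so no subclass is attempted.

# Hedged import sketch: exercises only what __all__ exports.
from src.framework.agents import (
    AgentContext,
    AgentResult,
    AsyncAgentBase,
    CompositeAgent,
    ParallelAgent,
    SequentialAgent,
)

print([cls.__name__ for cls in (AsyncAgentBase, CompositeAgent, ParallelAgent, SequentialAgent)])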