Spaces: ianshank (Sleeping)
Claude committed
Commit 40ee6b4 · 0 Parent(s)
feat: add personality output and bug fixes
- Added PersonalityResponseGenerator for conversational advisor responses
- Updated app.py with personality-infused output section
- Added sentence-transformers dependency
- Includes all trained model files and bug fixes
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
This view is limited to 50 files because it contains too many changes.
- .gitignore +27 -0
- DEPLOYMENT_GUIDE.md +306 -0
- README.md +225 -0
- app.py +553 -0
- app_mock.py +590 -0
- demo_src/__init__.py +1 -0
- demo_src/agents_demo.py +234 -0
- demo_src/llm_mock.py +182 -0
- demo_src/mcts_demo.py +436 -0
- demo_src/wandb_tracker.py +349 -0
- models/bert_lora/final_model/README.md +206 -0
- models/bert_lora/final_model/adapter_config.json +40 -0
- models/bert_lora/final_model/adapter_model.safetensors +0 -0
- models/bert_lora/generated_dataset.json +0 -0
- models/bert_lora/training_results.json +48 -0
- models/rnn_meta_controller.history.json +128 -0
- models/rnn_meta_controller.pt +0 -0
- requirements.txt +28 -0
- src/__init__.py +0 -0
- src/adapters/__init__.py +7 -0
- src/adapters/llm/__init__.py +257 -0
- src/adapters/llm/anthropic_client.py +521 -0
- src/adapters/llm/base.py +305 -0
- src/adapters/llm/exceptions.py +204 -0
- src/adapters/llm/lmstudio_client.py +346 -0
- src/adapters/llm/openai_client.py +458 -0
- src/agents/__init__.py +0 -0
- src/agents/hrm_agent.py +454 -0
- src/agents/meta_controller/__init__.py +45 -0
- src/agents/meta_controller/base.py +219 -0
- src/agents/meta_controller/bert_controller.py +428 -0
- src/agents/meta_controller/config_loader.py +304 -0
- src/agents/meta_controller/rnn_controller.py +345 -0
- src/agents/meta_controller/utils.py +201 -0
- src/agents/trm_agent.py +395 -0
- src/api/__init__.py +35 -0
- src/api/auth.py +439 -0
- src/api/exceptions.py +299 -0
- src/api/inference_server.py +380 -0
- src/api/rest_server.py +441 -0
- src/config/__init__.py +0 -0
- src/config/meta_controller.yaml +22 -0
- src/config/settings.py +431 -0
- src/data/__init__.py +29 -0
- src/data/dataset_loader.py +551 -0
- src/data/preprocessing.py +406 -0
- src/data/tactical_augmentation.py +484 -0
- src/data/train_test_split.py +505 -0
- src/framework/__init__.py +1 -0
- src/framework/agents/__init__.py +22 -0
.gitignore
ADDED
@@ -0,0 +1,27 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so

# Virtual environment
venv/
env/
.env

# IDE
.vscode/
.idea/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Gradio
flagged/
gradio_cached_examples/

# Logs
*.log
DEPLOYMENT_GUIDE.md
ADDED
@@ -0,0 +1,306 @@
# Hugging Face Spaces Deployment Guide

This guide walks you through deploying the LangGraph Multi-Agent MCTS demo to Hugging Face Spaces.

## Prerequisites

- [Hugging Face Account](https://huggingface.co/join)
- Git installed locally
- Python 3.10+ (for local testing)

## Step 1: Create a New Space

1. Go to [Hugging Face Spaces](https://huggingface.co/spaces)
2. Click **"Create new Space"**
3. Fill in the form:
   - **Owner**: Your username or organization
   - **Space name**: `langgraph-mcts-demo` (or your choice)
   - **License**: MIT
   - **SDK**: Gradio
   - **Hardware**: CPU Basic (Free tier - sufficient for demo)
   - **Visibility**: Public (or Private)
4. Click **"Create Space"**

## Step 2: Clone and Deploy

### Option A: Git-based Deployment (Recommended)

```bash
# 1. Clone your new empty Space
git clone https://huggingface.co/spaces/YOUR_USERNAME/langgraph-mcts-demo
cd langgraph-mcts-demo

# 2. Copy demo files from this directory
cp -r /path/to/huggingface_space/* .
cp -r /path/to/huggingface_space/.gitignore .

# 3. Verify structure
ls -la
# Should show:
# - app.py
# - requirements.txt
# - README.md
# - .gitignore
# - demo_src/
#   - __init__.py
#   - agents_demo.py
#   - llm_mock.py
#   - mcts_demo.py

# 4. Commit and push
git add -A
git commit -m "Initial deployment of LangGraph Multi-Agent MCTS demo"
git push

# 5. Space will automatically build and deploy (takes 2-5 minutes)
```

### Option B: Direct Upload via Web UI

1. Navigate to your Space on Hugging Face
2. Click **"Files"** tab
3. Click **"Add file"** → **"Upload files"**
4. Upload all files maintaining the directory structure:
   - `app.py`
   - `requirements.txt`
   - `README.md`
   - `.gitignore`
   - `demo_src/__init__.py`
   - `demo_src/agents_demo.py`
   - `demo_src/llm_mock.py`
   - `demo_src/mcts_demo.py`
5. Commit changes

## Step 3: Monitor Deployment

1. Go to your Space URL: `https://huggingface.co/spaces/YOUR_USERNAME/langgraph-mcts-demo`
2. Click **"Logs"** tab to monitor build progress
3. Wait for "Running on" message
4. Your demo is now live!

## Step 4: Test the Demo

1. Enter a query or select an example
2. Enable/disable different agents
3. Adjust MCTS parameters
4. Click "Process Query"
5. Review results and consensus scores

## Optional: Enable Real LLM Responses

To use Hugging Face Inference API instead of mock responses:

### 1. Update requirements.txt

```txt
gradio>=4.0.0,<5.0.0
numpy>=1.24.0,<2.0.0
huggingface_hub>=0.20.0
```

### 2. Add Secret Token

1. Go to Space Settings → **Repository secrets**
2. Add new secret:
   - Name: `HF_TOKEN`
   - Value: Your Hugging Face token (from [Settings → Access Tokens](https://huggingface.co/settings/tokens))

### 3. Update app.py Initialization

Change line ~290 in `app.py`:

```python
# From:
framework = MultiAgentFrameworkDemo(use_hf_inference=False)

# To:
import os
framework = MultiAgentFrameworkDemo(
    use_hf_inference=True,
    hf_model="mistralai/Mistral-7B-Instruct-v0.2"
)
```

### 4. Commit and Push

```bash
git add -A
git commit -m "Enable Hugging Face Inference API"
git push
```

## Optional: Enable Weights & Biases Tracking

Track experiments and visualize metrics with W&B integration.

### 1. Get W&B API Key

1. Sign up at [wandb.ai](https://wandb.ai)
2. Go to Settings → API Keys
3. Copy your API key

### 2. Add W&B Secret to Space

1. Go to Space Settings → **Repository secrets**
2. Add new secret:
   - Name: `WANDB_API_KEY`
   - Value: Your W&B API key

### 3. Use W&B in the Demo

1. Expand "Weights & Biases Tracking" accordion in the UI
2. Check "Enable W&B Tracking"
3. Optionally set:
   - **Project Name**: Your W&B project (default: `langgraph-mcts-demo`)
   - **Run Name**: Custom name for this run (auto-generated if empty)
4. Process your query
5. View the W&B run URL in the results

### 4. What Gets Logged

- **Agent Metrics**: Confidence scores, execution times, response lengths
- **MCTS Metrics**: Best value, visits, tree depth, exploration paths
- **Consensus Metrics**: Agreement scores, agent combinations
- **Performance**: Total processing time
- **Artifacts**: Full JSON results as artifacts
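
Under the hood this is ordinary W&B logging; the demo wraps it in `demo_src/wandb_tracker.py`. The snippet below is only a minimal sketch of that pattern using the plain `wandb` package, with illustrative metric keys and values rather than the tracker's exact schema.

```python
import wandb

# Start a run (the Space reads WANDB_API_KEY from its repository secrets)
run = wandb.init(project="langgraph-mcts-demo", name="example-run")

# Log per-query metrics; the keys here are placeholders, not the tracker's real field names
wandb.log({
    "hrm/confidence": 0.85,
    "mcts/best_value": 0.72,
    "consensus/score": 0.61,
    "total_time_ms": 143.2,
})

# Attach the full JSON results as an artifact
artifact = wandb.Artifact("query-results", type="results")
artifact.add_file("results.json")
run.log_artifact(artifact)

run.finish()
```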

### 5. View Your Dashboard

After runs, visit your W&B project dashboard to:
- Compare different agent configurations
- Visualize consensus patterns
- Analyze MCTS exploration strategies
- Track performance over time

## Customization Options

### Change Gradio Theme

In `app.py`, modify:

```python
with gr.Blocks(
    theme=gr.themes.Soft(),  # Try: Default(), Monochrome(), Glass()
    ...
) as demo:
```

### Add Custom Examples

Update `EXAMPLE_QUERIES` list in `app.py`:

```python
EXAMPLE_QUERIES = [
    "Your custom query 1",
    "Your custom query 2",
    ...
]
```

### Adjust MCTS Parameters

Modify sliders in `app.py`:

```python
mcts_iterations = gr.Slider(
    minimum=10,
    maximum=200,  # Increase for more thorough search
    value=50,     # Change default
    ...
)
```

### Add More Agent Types

1. Create new agent in `demo_src/agents_demo.py`
2. Add to `MultiAgentFrameworkDemo` in `app.py`
3. Add UI controls in Gradio interface

## Troubleshooting

### Build Fails

- Check **Logs** tab for error details
- Verify `requirements.txt` has compatible versions
- Ensure all imports in `app.py` are satisfied

### Slow Performance

- Reduce default MCTS iterations
- Use mock LLM (no API calls)
- Simplify tree visualization

### Memory Issues (Free Tier)

- Limit max MCTS iterations to 100
- Reduce tree depth in `demo_src/mcts_demo.py`
- Simplify response generation

### Missing Files

Ensure directory structure:
```
your-space/
├── app.py
├── requirements.txt
├── README.md
├── .gitignore
└── demo_src/
    ├── __init__.py
    ├── agents_demo.py
    ├── llm_mock.py
    ├── mcts_demo.py
    └── wandb_tracker.py
```

## Upgrading Hardware

For better performance:

1. Go to Space Settings
2. Under **Hardware**, select:
   - **CPU Upgrade** ($0.03/hr) - Faster processing
   - **T4 Small** ($0.60/hr) - GPU for neural models
3. Save changes

## Sharing Your Space

### Embed in Website

```html
<iframe
    src="https://YOUR_USERNAME-langgraph-mcts-demo.hf.space"
    frameborder="0"
    width="100%"
    height="600"
></iframe>
```

### Direct Link

Share: `https://huggingface.co/spaces/YOUR_USERNAME/langgraph-mcts-demo`

### API Access

Gradio automatically provides an API endpoint:
```
https://YOUR_USERNAME-langgraph-mcts-demo.hf.space/api/predict
```
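
If you prefer a Python client over raw HTTP, the `gradio_client` package can call the same endpoint. A minimal sketch only; the exact endpoint name and argument order depend on the app, so check the Space's "Use via API" page:

```python
from gradio_client import Client

# Connect to the deployed Space
client = Client("YOUR_USERNAME/langgraph-mcts-demo")

# Arguments mirror the UI inputs (query text, controller/agent settings);
# api_name="/predict" is a common default but may differ for this app
result = client.predict(
    "How to design a fault-tolerant message queue system?",
    api_name="/predict",
)
print(result)
```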

## Next Steps

1. **Collect Feedback**: Enable flagging for user feedback
2. **Add Analytics**: Track usage patterns
3. **Extend Agents**: Add domain-specific reasoning modules
4. **Integrate RAG**: Connect to vector databases for real context
5. **Add Visualization**: Enhanced tree and consensus displays

## Support

- **Hugging Face Docs**: https://huggingface.co/docs/hub/spaces
- **Gradio Docs**: https://www.gradio.app/docs
- **Full Framework**: https://github.com/ianshank/langgraph_multi_agent_mcts

---

**Happy Deploying!** 🚀
README.md
ADDED
@@ -0,0 +1,225 @@
---
title: LangGraph Multi-Agent MCTS Demo
emoji: 🌳
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 4.44.0
app_file: app.py
pinned: false
license: mit
tags:
- multi-agent
- mcts
- reasoning
- langgraph
- ai-agents
- wandb
- experiment-tracking
short_description: Multi-agent reasoning framework with Monte Carlo Tree Search
---

# LangGraph Multi-Agent MCTS Framework

**Production Demo with Trained Neural Models** - Experience real trained meta-controllers for intelligent agent routing

## What This Demo Shows

This interactive demo showcases trained neural meta-controllers that dynamically route queries to specialized agents:

### 🤖 Trained Meta-Controllers

1. **RNN Meta-Controller**
   - GRU-based recurrent neural network
   - Learns sequential patterns in agent performance
   - Fast inference (~2ms latency)
   - Trained on 1000+ synthetic routing examples

2. **BERT Meta-Controller with LoRA**
   - Transformer-based text understanding
   - Parameter-efficient fine-tuning with LoRA adapters
   - Context-aware routing decisions
   - Better generalization to unseen query patterns

### 🧠 Three Specialized Agents

1. **HRM (Hierarchical Reasoning Module)**
   - Best for: Complex decomposition, multi-level problems
   - Technique: Hierarchical planning with adaptive computation

2. **TRM (Tree Reasoning Module)**
   - Best for: Iterative refinement, comparison tasks
   - Technique: Recursive refinement with convergence detection

3. **MCTS (Monte Carlo Tree Search)**
   - Best for: Optimization, strategic planning
   - Technique: UCB1 exploration with value backpropagation

### 📊 Key Features

- **Real Trained Models**: Production-ready neural meta-controllers
- **Intelligent Routing**: Models learn optimal agent selection patterns
- **Routing Visualization**: See confidence scores and probability distributions
- **Feature Engineering**: Demonstrates query → features → routing pipeline
- **Performance Metrics**: Track execution time and routing accuracy

## How to Use

1. **Enter a Query**: Type your question or select an example
2. **Select Controller**: Choose RNN (fast) or BERT (context-aware)
3. **Process Query**: Click "🚀 Process Query"
4. **Review Results**:
   - See which agent the controller selected
   - View routing confidence and probabilities
   - Examine features used for decision-making
   - Check agent execution details

## Weights & Biases Integration

Track your experiments with **Weights & Biases** for:
- 📈 **Metrics Dashboard**: Visualize consensus scores, execution times, agent performance
- 🔄 **Run Comparison**: Compare different configurations side-by-side
- 📊 **Experiment History**: Track all your queries and results
- 🌳 **MCTS Visualization**: Log tree exploration patterns

### Setting Up W&B

1. **Get API Key**: Sign up at [wandb.ai](https://wandb.ai) and get your API key
2. **Configure Space Secret** (if deploying your own):
   - Go to Space Settings → Repository secrets
   - Add: `WANDB_API_KEY` = your API key
3. **Enable in UI**:
   - Expand "Weights & Biases Tracking" accordion
   - Check "Enable W&B Tracking"
   - Set project name (optional)
   - Set run name (optional, auto-generated if empty)
4. **View Results**: After processing, click the W&B run URL to see your dashboard

### Logged Metrics

- **Per Agent**: Confidence, execution time, response length, reasoning steps
- **MCTS**: Best value, visits, tree depth, top actions with UCB1 scores
- **Consensus**: Score, level (high/medium/low), number of agents
- **Performance**: Total processing time
- **Artifacts**: Full JSON results, tree visualizations

## Example Queries

- "What are the key factors to consider when choosing between microservices and monolithic architecture?"
- "How can we optimize a Python application that processes 10GB of log files daily?"
- "Should we use SQL or NoSQL database for a social media application with 1M users?"
- "How to design a fault-tolerant message queue system?"

## Technical Details

### Architecture

```
Query Input
     │
     ├─→ HRM Agent (Hierarchical Decomposition)
     │     ├─ Component Analysis
     │     └─ Structured Synthesis
     │
     ├─→ TRM Agent (Iterative Refinement)
     │     ├─ Initial Response
     │     ├─ Clarity Enhancement
     │     └─ Validation Check
     │
     └─→ MCTS Engine (Strategic Search)
           ├─ Selection (UCB1)
           ├─ Expansion
           ├─ Simulation
           └─ Backpropagation
     │
     ▼
Consensus Scoring
     │
     ▼
Final Synthesized Response
```

### MCTS Algorithm

The Monte Carlo Tree Search implementation uses:

- **UCB1 Selection**: `Q(s,a) + C * sqrt(ln(N(s)) / N(s,a))`
- **Progressive Widening**: Controls branching factor
- **Domain-Aware Actions**: Contextual decision options
- **Value Backpropagation**: Updates entire path statistics
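
As a concrete reading of the UCB1 rule above, a minimal selection step looks roughly like this (an illustrative sketch only; the repository's actual implementation in `demo_src/mcts_demo.py` differs in structure):

```python
import math

def ucb1_score(q_value: float, parent_visits: int, child_visits: int, c: float = 1.414) -> float:
    """UCB1: exploitation term Q(s,a) plus exploration bonus C * sqrt(ln N(s) / N(s,a))."""
    if child_visits == 0:
        return float("inf")  # always try unvisited actions first
    return q_value + c * math.sqrt(math.log(parent_visits) / child_visits)

def select_child(children: list[dict], parent_visits: int, c: float = 1.414) -> dict:
    """Pick the child node (here a dict with 'value' and 'visits') with the highest UCB1 score."""
    return max(children, key=lambda ch: ucb1_score(ch["value"], parent_visits, ch["visits"], c))
```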

### Consensus Calculation

```
consensus = average_confidence * agreement_factor
agreement_factor = max(0, 1 - std_deviation * 2)
```

High consensus (>70%) indicates agents agree on approach.
Low consensus (<40%) suggests uncertainty or conflicting strategies.
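
In code, the same calculation is only a few lines. A sketch of the formula as written above, assuming per-agent confidences in [0, 1] and a population standard deviation (the formula does not specify which variant):

```python
import statistics

def consensus(confidences: list[float]) -> float:
    """consensus = average_confidence * agreement_factor, agreement_factor = max(0, 1 - 2 * std_dev)."""
    avg = statistics.mean(confidences)
    std = statistics.pstdev(confidences)
    agreement_factor = max(0.0, 1.0 - 2.0 * std)
    return avg * agreement_factor

# Closely agreeing agents score high; divergent agents score low
print(consensus([0.85, 0.80, 0.88]))  # ≈ 0.79 → high consensus
print(consensus([0.90, 0.30, 0.60]))  # ≈ 0.31 → conflicting strategies
```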

## Demo Scope

This demonstration focuses on **meta-controller training and routing**:

- ✅ **Real Trained Models**: Production RNN and BERT controllers
- ✅ **Actual Model Loading**: PyTorch and HuggingFace Transformers
- ✅ **Feature Engineering**: Query analysis → feature vectors
- ✅ **Routing Visualization**: See controller decision-making
- ⚠️ **Simplified Agents**: Agent responses are mocked for demo purposes
- ⚠️ **No Live LLM Calls**: Agents don't call actual LLMs (to reduce latency/cost)

## Full Production Framework

The complete repository includes all production features:

- ✅ **Neural Meta-Controllers**: RNN and BERT with LoRA (deployed here!)
- ✅ **Agent Implementations**: Full HRM, TRM, and MCTS with PyTorch
- ✅ **Training Pipeline**: Data generation, training, evaluation
- ✅ **LLM Integration**: OpenAI, Anthropic, LM Studio support
- ✅ **RAG Systems**: ChromaDB, FAISS, Pinecone vector stores
- ✅ **Observability**: OpenTelemetry tracing, Prometheus metrics
- ✅ **Storage**: S3 artifact storage, experiment tracking
- ✅ **CI/CD**: Automated testing, security scanning, deployment

**GitHub Repository**: [ianshank/langgraph_multi_agent_mcts](https://github.com/ianshank/langgraph_multi_agent_mcts)

## Technical Stack

- **Python**: 3.11+
- **UI**: Gradio 4.x
- **ML Frameworks**: PyTorch 2.1+, Transformers, PEFT (LoRA)
- **Models**: GRU-based RNN, BERT-mini with LoRA adapters
- **Architecture**: Neural meta-controller + multi-agent system
- **Experiment Tracking**: Weights & Biases (optional)
- **Numerical**: NumPy

## Research Applications

This framework demonstrates concepts applicable to:

- Complex decision-making systems
- AI-assisted software architecture decisions
- Multi-perspective problem analysis
- Strategic planning with uncertainty

## Citation

If you use this framework in research, please cite:

```bibtex
@software{langgraph_mcts_2024,
  title={LangGraph Multi-Agent MCTS Framework},
  author={Your Name},
  year={2024},
  url={https://github.com/ianshank/langgraph_multi_agent_mcts}
}
```

## License

MIT License - See repository for details.

---

**Built with** LangGraph, Gradio, and Python | **Demo Version**: 1.0.0
app.py
ADDED
@@ -0,0 +1,553 @@
"""
LangGraph Multi-Agent MCTS Framework - Integrated Demo with Trained Models

Demonstrates the actual trained neural meta-controllers:
- RNN Meta-Controller for sequential pattern recognition
- BERT with LoRA adapters for text-based routing

This is a production demonstration using real trained models.
"""

import asyncio
import sys
import time
from dataclasses import dataclass
from pathlib import Path

# Fail fast if critical dependencies are missing or broken
try:
    import peft

    print(f"[OK] PEFT library imported successfully (version: {peft.__version__})")
except ImportError as e:
    print(f"CRITICAL ERROR: Could not import peft library: {e}")
    # We don't exit here to allow the app to crash naturally later with full stack trace,
    # but this print ensures it's visible in the logs immediately.

import gradio as gr
import torch

# Import the trained controllers
sys.path.insert(0, str(Path(__file__).parent))

from src.agents.meta_controller.base import MetaControllerFeatures
from src.agents.meta_controller.bert_controller import BERTMetaController
from src.agents.meta_controller.rnn_controller import RNNMetaController
from src.agents.meta_controller.feature_extractor import (
    FeatureExtractor,
    FeatureExtractorConfig,
)
from src.utils.personality_response import PersonalityResponseGenerator


@dataclass
class AgentResult:
    """Result from a single agent."""

    agent_name: str
    response: str
    confidence: float
    reasoning_steps: list[str]
    execution_time_ms: float


@dataclass
class ControllerDecision:
    """Decision made by the meta-controller."""

    selected_agent: str
    confidence: float
    routing_probabilities: dict[str, float]
    features_used: dict


def create_features_from_query(
    query: str,
    iteration: int = 0,
    last_agent: str = "none",
    feature_extractor: FeatureExtractor | None = None,
) -> MetaControllerFeatures:
    """
    Convert a text query into features for the meta-controller.

    Uses semantic embeddings for robust feature extraction. Falls back to
    heuristic-based extraction if embeddings are not available.

    Args:
        query: The input query text
        iteration: Current iteration number
        last_agent: Name of the last agent used
        feature_extractor: Optional FeatureExtractor instance (created if None)

    Returns:
        MetaControllerFeatures instance
    """
    # Use provided feature extractor or create a new one
    if feature_extractor is None:
        try:
            config = FeatureExtractorConfig.from_env()
            feature_extractor = FeatureExtractor(config)
        except Exception as e:
            print(f"Warning: Failed to initialize FeatureExtractor: {e}")
            print("Falling back to heuristic-based feature extraction")
            # Will use heuristic fallback below

    # Extract features using the feature extractor
    try:
        if feature_extractor is not None:
            return feature_extractor.extract_features(query, iteration, last_agent)
    except Exception as e:
        print(f"Warning: Feature extraction failed: {e}")
        print("Falling back to heuristic-based feature extraction")

    # Fallback to original heuristic-based extraction
    # (This code is kept as a safety net but should rarely be used)
    query_length = len(query)

    # Estimate complexity based on query characteristics
    has_multiple_questions = "?" in query and query.count("?") > 1
    has_comparison = any(word in query.lower() for word in ["vs", "versus", "compare", "difference", "better"])
    has_optimization = any(word in query.lower() for word in ["optimize", "best", "improve", "maximize", "minimize"])
    has_technical = any(word in query.lower() for word in ["algorithm", "code", "implement", "technical", "system"])

    # Create mock confidence scores based on query characteristics
    hrm_confidence = 0.5 + (0.3 if has_multiple_questions else 0) + (0.1 if has_technical else 0)
    trm_confidence = 0.5 + (0.3 if has_comparison else 0) + (0.1 if query_length > 100 else 0)
    mcts_confidence = 0.5 + (0.3 if has_optimization else 0) + (0.1 if has_technical else 0)

    # Normalize
    total = hrm_confidence + trm_confidence + mcts_confidence
    if total == 0:
        hrm_confidence = 1.0 / 3.0
        trm_confidence = 1.0 / 3.0
        mcts_confidence = 1.0 / 3.0
    else:
        hrm_confidence /= total
        trm_confidence /= total
        mcts_confidence /= total

    # Calculate consensus score
    max_confidence = max(hrm_confidence, trm_confidence, mcts_confidence)
    if max_confidence == 0:
        consensus_score = 0.0
    else:
        consensus_score = min(hrm_confidence, trm_confidence, mcts_confidence) / max_confidence

    features = MetaControllerFeatures(
        hrm_confidence=hrm_confidence,
        trm_confidence=trm_confidence,
        mcts_value=mcts_confidence,
        consensus_score=consensus_score,
        last_agent=last_agent,
        iteration=iteration,
        query_length=query_length,
        has_rag_context=query_length > 50,
        rag_relevance_score=0.7 if query_length > 50 else 0.0,
        is_technical_query=has_technical,
    )

    return features


class IntegratedFramework:
    """
    Integrated multi-agent framework using trained meta-controllers.
    """

    def __init__(self):
        """Initialize the framework with trained models."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Initialize feature extractor with semantic embeddings
        print("Initializing Feature Extractor...")
        try:
            config = FeatureExtractorConfig.from_env()
            # Set device to match the framework device
            config.device = self.device
            self.feature_extractor = FeatureExtractor(config)
            print(f"[OK] Feature Extractor initialized: {self.feature_extractor}")
        except Exception as e:
            print(f"[WARN] Failed to initialize Feature Extractor: {e}")
            print("[WARN] Will fall back to heuristic-based feature extraction")
            self.feature_extractor = None

        # Load trained RNN Meta-Controller
        print("Loading RNN Meta-Controller...")
        self.rnn_controller = RNNMetaController(name="RNNController", seed=42, device=self.device)

        # Load the trained weights
        rnn_model_path = Path(__file__).parent / "models" / "rnn_meta_controller.pt"
        if rnn_model_path.exists():
            checkpoint = torch.load(rnn_model_path, map_location=self.device, weights_only=True)
            self.rnn_controller.model.load_state_dict(checkpoint)
            self.rnn_controller.model.eval()
            print(f"[OK] Loaded RNN model from {rnn_model_path}")
        else:
            print(f"[WARN] RNN model not found at {rnn_model_path}, using untrained model")

        # Load trained BERT Meta-Controller with LoRA
        print("Loading BERT Meta-Controller with LoRA...")
        self.bert_controller = BERTMetaController(name="BERTController", seed=42, device=self.device, use_lora=True)

        bert_model_path = Path(__file__).parent / "models" / "bert_lora" / "final_model"
        if bert_model_path.exists():
            try:
                self.bert_controller.load_model(str(bert_model_path))
                print(f"[OK] Loaded BERT LoRA model from {bert_model_path}")
            except Exception as e:
                print(f"[WARN] Error loading BERT model: {e}")
                print("Using untrained BERT model")
        else:
            print(f"[WARN] BERT model not found at {bert_model_path}, using untrained model")

        # Agent routing map
        self.agent_handlers = {
            "hrm": self._handle_hrm,
            "trm": self._handle_trm,
            "mcts": self._handle_mcts,
        }

        print("Framework initialized successfully!")

    async def process_query(
        self,
        query: str,
        controller_type: str = "rnn",
    ) -> tuple[AgentResult, ControllerDecision]:
        """
        Process a query using the trained meta-controller.

        Args:
            query: The input query
            controller_type: Which controller to use ("rnn" or "bert")

        Returns:
            (agent_result, controller_decision) tuple
        """
        start_time = time.perf_counter()

        # Step 1: Convert query to features using semantic embeddings
        features = create_features_from_query(query, feature_extractor=self.feature_extractor)

        # Step 2: Get controller decision
        if controller_type == "rnn":
            prediction = self.rnn_controller.predict(features)
        else:  # bert
            prediction = self.bert_controller.predict(features)

        selected_agent = prediction.agent
        confidence = prediction.confidence

        # Get routing probabilities (prediction.probabilities is already a dict)
        routing_probs = prediction.probabilities

        # Step 3: Route to selected agent
        handler = self.agent_handlers.get(selected_agent, self._handle_hrm)
        agent_result = await handler(query)

        # Create controller decision summary
        controller_decision = ControllerDecision(
            selected_agent=selected_agent,
            confidence=confidence,
            routing_probabilities=routing_probs,
            features_used={
                "hrm_confidence": features.hrm_confidence,
                "trm_confidence": features.trm_confidence,
                "mcts_value": features.mcts_value,
                "consensus_score": features.consensus_score,
                "query_length": features.query_length,
                "is_technical": features.is_technical_query,
            },
        )

        total_time = (time.perf_counter() - start_time) * 1000
        agent_result.execution_time_ms = round(total_time, 2)

        return agent_result, controller_decision

    async def _handle_hrm(self, query: str) -> AgentResult:
        """Handle query with Hierarchical Reasoning Module."""
        # Simulate HRM processing
        await asyncio.sleep(0.1)

        steps = [
            "Decompose query into hierarchical subproblems",
            "Apply high-level reasoning (H-Module)",
            "Execute low-level refinement (L-Module)",
            "Synthesize hierarchical solution",
        ]

        response = f"[HRM Analysis] Breaking down the problem hierarchically: {query[:100]}..."

        return AgentResult(
            agent_name="HRM (Hierarchical Reasoning)",
            response=response,
            confidence=0.85,
            reasoning_steps=steps,
            execution_time_ms=0.0,
        )

    async def _handle_trm(self, query: str) -> AgentResult:
        """Handle query with Tree Reasoning Module."""
        # Simulate TRM processing
        await asyncio.sleep(0.1)

        steps = [
            "Initialize solution state",
            "Recursive refinement iteration 1",
            "Recursive refinement iteration 2",
            "Convergence achieved - finalize",
        ]

        response = f"[TRM Analysis] Applying iterative refinement: {query[:100]}..."

        return AgentResult(
            agent_name="TRM (Iterative Refinement)",
            response=response,
            confidence=0.80,
            reasoning_steps=steps,
            execution_time_ms=0.0,
        )

    async def _handle_mcts(self, query: str) -> AgentResult:
        """Handle query with MCTS."""
        # Simulate MCTS processing
        await asyncio.sleep(0.15)

        steps = [
            "Build search tree",
            "Selection: UCB1 exploration",
            "Expansion: Add promising nodes",
            "Simulation: Rollout evaluation",
            "Backpropagation: Update values",
        ]

        response = f"[MCTS Analysis] Strategic exploration via tree search: {query[:100]}..."

        return AgentResult(
            agent_name="MCTS (Monte Carlo Tree Search)",
            response=response,
            confidence=0.88,
            reasoning_steps=steps,
            execution_time_ms=0.0,
        )


# Global framework instance
framework = None


def initialize_framework():
    """Initialize or reinitialize the framework."""
    global framework
    try:
        framework = IntegratedFramework()
        return "[OK] Framework initialized with trained models!"
    except Exception as e:
        return f"[ERROR] Error initializing framework: {str(e)}"


def process_query_sync(
    query: str,
    controller_type: str,
):
    """Synchronous wrapper for async processing."""
    global framework

    if framework is None:
        framework = IntegratedFramework()

    if not query.strip():
        return ("Please enter a query.", {}, "", {}, "", "")

    # Run async function
    agent_result, controller_decision = asyncio.run(
        framework.process_query(query=query, controller_type=controller_type.lower())
    )

    # Format outputs
    final_response = agent_result.response

    # Generate personality-infused response
    personality_gen = PersonalityResponseGenerator()
    try:
        personality_response = personality_gen.generate_response(
            agent_response=final_response,
            query=query
        )
    except Exception as e:
        # Fallback to a simple wrapper if personality generation fails
        personality_response = f"Here's what I found:\n\n{final_response}"
        print(f"Warning: Personality generation failed: {e}")

    # Controller decision visualization
    routing_viz = "### 🧠 Meta-Controller Decision\n\n"
    routing_viz += f"**Selected Agent:** `{controller_decision.selected_agent.upper()}`\n\n"
    routing_viz += f"**Confidence:** {controller_decision.confidence:.1%}\n\n"
    routing_viz += "**Routing Probabilities:**\n"
    for agent, prob in controller_decision.routing_probabilities.items():
        bar = "█" * int(prob * 50)
        routing_viz += f"- **{agent.upper()}**: {prob:.1%} {bar}\n"

    # Agent details
    agent_details = {
        "agent": agent_result.agent_name,
        "confidence": f"{agent_result.confidence:.1%}",
        "reasoning_steps": agent_result.reasoning_steps,
        "execution_time_ms": agent_result.execution_time_ms,
    }

    # Features used
    features_viz = "### 📊 Features Used for Routing\n\n"
    for feature, value in controller_decision.features_used.items():
        if isinstance(value, float):
            features_viz += f"- **{feature}**: {value:.3f}\n"
        elif isinstance(value, bool):
            features_viz += f"- **{feature}**: {'Yes' if value else 'No'}\n"
        else:
            features_viz += f"- **{feature}**: {value}\n"

    # Metrics
    metrics = f"""
**Controller:** {controller_type}
**Execution Time:** {agent_result.execution_time_ms:.2f} ms
**Agent Confidence:** {agent_result.confidence:.1%}
"""

    return final_response, agent_details, routing_viz, features_viz, metrics, personality_response


# Example queries
EXAMPLE_QUERIES = [
    "What are the key factors to consider when choosing between microservices and monolithic architecture?",
    "How can we optimize a Python application that processes 10GB of log files daily?",
    "Compare the performance characteristics of B-trees vs LSM-trees for write-heavy workloads",
    "Design a distributed rate limiting system that handles 100k requests per second",
    "Explain the difference between supervised and unsupervised learning with examples",
]


# Gradio Interface
with gr.Blocks(
    title="LangGraph Multi-Agent MCTS - Trained Models Demo",
    theme=gr.themes.Soft(),
    css="""
    .agent-box { border: 1px solid #ddd; padding: 10px; border-radius: 5px; margin: 5px 0; }
    .highlight { background-color: #e3f2fd; padding: 10px; border-radius: 5px; margin: 10px 0; }
    """,
) as demo:
    gr.Markdown(
        """
        # 🎯 LangGraph Multi-Agent MCTS Framework
        ## Production Demo with Trained Neural Meta-Controllers

        This demo uses **REAL trained models**:
        - 🧠 **RNN Meta-Controller**: GRU-based sequential pattern recognition
        - 🤖 **BERT with LoRA**: Transformer-based text understanding for routing

        The meta-controllers learn to route queries to the optimal agent:
        - **HRM**: Hierarchical reasoning for complex decomposition
        - **TRM**: Iterative refinement for progressive improvement
        - **MCTS**: Strategic exploration for optimization problems

        ---
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Query", placeholder="Enter your question or reasoning task...", lines=4, max_lines=10
            )

            gr.Markdown("**Example Queries:**")
            example_dropdown = gr.Dropdown(choices=EXAMPLE_QUERIES, label="Select an example", interactive=True)

            def load_example(example):
                return example

            example_dropdown.change(load_example, example_dropdown, query_input)

        with gr.Column(scale=1):
            gr.Markdown("**Meta-Controller Selection**")
            controller_type = gr.Radio(
                choices=["RNN", "BERT"],
                value="RNN",
                label="Controller Type",
                info="Choose which trained controller to use",
            )

            gr.Markdown(
                """
                **Controller Comparison:**
                - **RNN**: Fast, captures sequential patterns
                - **BERT**: More context-aware, text understanding
                """
            )

    process_btn = gr.Button("🚀 Process Query", variant="primary", size="lg")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🎯 Agent Response")
            final_response_output = gr.Textbox(label="Response", lines=4, interactive=False)

            gr.Markdown("### 🤝 Personality-Infused Response")
            gr.Markdown("*A conversational, balanced advisor interpretation*")
            personality_output = gr.Textbox(label="Balanced Advisor Response", lines=8, interactive=False)

            gr.Markdown("### 📈 Performance Metrics")
            metrics_output = gr.Markdown()

        with gr.Column():
            routing_viz = gr.Markdown(label="Controller Decision")
            features_viz = gr.Markdown(label="Features")

    with gr.Accordion("🔍 Detailed Agent Information", open=False):
        agent_details_output = gr.JSON(label="Agent Execution Details")

    # Wire up the processing
    process_btn.click(
        fn=process_query_sync,
        inputs=[
            query_input,
            controller_type,
        ],
        outputs=[final_response_output, agent_details_output, routing_viz, features_viz, metrics_output, personality_output],
    )

    gr.Markdown(
        """
        ---

        ### 📚 About This Demo

        This is a **production demonstration** of trained neural meta-controllers for multi-agent routing.

        **Models:**
        - RNN Meta-Controller: 10-dimensional feature vector → 3-class routing (HRM/TRM/MCTS)
        - BERT with LoRA: Text features → routing decision with adapters

        **Training:**
        - Synthetic dataset: 1000+ samples with balanced routing decisions
        - Optimization: Adam optimizer, cross-entropy loss
        - Validation: 80/20 train/val split with early stopping

        **Repository:** [GitHub - langgraph_multi_agent_mcts](https://github.com/ianshank/langgraph_multi_agent_mcts)

        ---
        *Built with PyTorch, Transformers, PEFT, and Gradio*
        """
    )


if __name__ == "__main__":
    # Initialize framework
    print("Initializing framework with trained models...")
    framework = IntegratedFramework()

    # Launch the demo
    demo.launch(server_name="0.0.0.0", share=False, show_error=True)
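
`app.py` only wires the framework into Gradio; the same routing path can also be exercised directly. A minimal sketch based on the code above, assuming the Space's dependencies are installed and the script is run from the Space root so the `src/` and `models/` paths resolve:

```python
import asyncio

from app import IntegratedFramework  # importing app.py builds the Gradio UI but does not launch it

async def main() -> None:
    framework = IntegratedFramework()  # loads the trained RNN and BERT LoRA controllers
    result, decision = await framework.process_query(
        "How to design a fault-tolerant message queue system?",
        controller_type="bert",
    )
    print(decision.selected_agent, f"{decision.confidence:.1%}")
    print(result.response)

if __name__ == "__main__":
    asyncio.run(main())
```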
app_mock.py
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
LangGraph Multi-Agent MCTS Framework - Hugging Face Spaces Demo

A proof-of-concept demonstration of multi-agent reasoning with Monte Carlo Tree Search.
"""

import asyncio
import time
from dataclasses import dataclass

import gradio as gr
import numpy as np

# Demo-specific simplified implementations
from demo_src.agents_demo import HRMAgent, TRMAgent
from demo_src.llm_mock import HuggingFaceClient, MockLLMClient
from demo_src.mcts_demo import MCTSDemo
from demo_src.wandb_tracker import WandBTracker, is_wandb_available


@dataclass
class AgentResult:
    """Result from a single agent."""

    agent_name: str
    response: str
    confidence: float
    reasoning_steps: list[str]
    execution_time_ms: float


@dataclass
class FrameworkResult:
    """Combined result from all agents."""

    query: str
    hrm_result: AgentResult | None
    trm_result: AgentResult | None
    mcts_result: dict | None
    consensus_score: float
    final_response: str
    total_time_ms: float
    metadata: dict


class MultiAgentFrameworkDemo:
    """Simplified multi-agent framework for Hugging Face Spaces demo."""

    def __init__(self, use_hf_inference: bool = False, hf_model: str = ""):
        """Initialize the demo framework.

        Args:
            use_hf_inference: Use Hugging Face Inference API instead of mock
            hf_model: Hugging Face model ID for inference
        """
        self.use_hf_inference = use_hf_inference
        self.hf_model = hf_model

        # Initialize components
        if use_hf_inference and hf_model:
            self.llm_client = HuggingFaceClient(model_id=hf_model)
        else:
            self.llm_client = MockLLMClient()

        self.hrm_agent = HRMAgent(self.llm_client)
        self.trm_agent = TRMAgent(self.llm_client)
        self.mcts = MCTSDemo()

    async def process_query(
        self,
        query: str,
        use_hrm: bool = True,
        use_trm: bool = True,
        use_mcts: bool = False,
        mcts_iterations: int = 25,
        exploration_weight: float = 1.414,
        seed: int | None = None,
    ) -> FrameworkResult:
        """Process a query through the multi-agent framework.

        Args:
            query: The input query to process
            use_hrm: Enable Hierarchical Reasoning Module
            use_trm: Enable Tree Reasoning Module
            use_mcts: Enable Monte Carlo Tree Search
            mcts_iterations: Number of MCTS iterations
            exploration_weight: UCB1 exploration parameter
            seed: Random seed for reproducibility

        Returns:
            FrameworkResult with all agent outputs and consensus
        """
        start_time = time.perf_counter()

        hrm_result = None
        trm_result = None
        mcts_result = None

        # Run enabled agents
        tasks = []
        agent_names = []

        if use_hrm:
            tasks.append(self._run_hrm(query))
            agent_names.append("hrm")

        if use_trm:
            tasks.append(self._run_trm(query))
            agent_names.append("trm")

        if use_mcts:
            tasks.append(self._run_mcts(query, mcts_iterations, exploration_weight, seed))
            agent_names.append("mcts")

        # Execute agents concurrently
        if tasks:
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for name, result in zip(agent_names, results, strict=False):
                if isinstance(result, Exception):
                    continue
                if name == "hrm":
                    hrm_result = result
                elif name == "trm":
                    trm_result = result
                elif name == "mcts":
                    mcts_result = result

        # Calculate consensus score
        consensus_score = self._calculate_consensus(hrm_result, trm_result, mcts_result)

        # Generate final synthesized response
        final_response = self._synthesize_response(query, hrm_result, trm_result, mcts_result, consensus_score)

        total_time = (time.perf_counter() - start_time) * 1000

        return FrameworkResult(
            query=query,
            hrm_result=hrm_result,
            trm_result=trm_result,
            mcts_result=mcts_result,
            consensus_score=consensus_score,
            final_response=final_response,
            total_time_ms=round(total_time, 2),
            metadata={
                "agents_used": agent_names,
                "mcts_config": (
                    {"iterations": mcts_iterations, "exploration_weight": exploration_weight, "seed": seed}
                    if use_mcts
                    else None
                ),
            },
        )

    async def _run_hrm(self, query: str) -> AgentResult:
        """Run Hierarchical Reasoning Module."""
        start = time.perf_counter()
        result = await self.hrm_agent.process(query)
        elapsed = (time.perf_counter() - start) * 1000

        return AgentResult(
            agent_name="HRM (Hierarchical Reasoning)",
            response=result["response"],
            confidence=result["confidence"],
            reasoning_steps=result["steps"],
            execution_time_ms=round(elapsed, 2),
        )

    async def _run_trm(self, query: str) -> AgentResult:
        """Run Tree Reasoning Module."""
        start = time.perf_counter()
        result = await self.trm_agent.process(query)
        elapsed = (time.perf_counter() - start) * 1000

        return AgentResult(
            agent_name="TRM (Iterative Refinement)",
            response=result["response"],
            confidence=result["confidence"],
            reasoning_steps=result["steps"],
            execution_time_ms=round(elapsed, 2),
        )

    async def _run_mcts(self, query: str, iterations: int, exploration_weight: float, seed: int | None) -> dict:
        """Run Monte Carlo Tree Search."""
        start = time.perf_counter()

        # MCTSDemo.search is now async and uses the production framework
        result = await self.mcts.search(query=query, iterations=iterations, exploration_weight=exploration_weight, seed=seed)

        elapsed = (time.perf_counter() - start) * 1000
        result["execution_time_ms"] = round(elapsed, 2)

        return result

    def _calculate_consensus(
        self, hrm_result: AgentResult | None, trm_result: AgentResult | None, mcts_result: dict | None
    ) -> float:
        """Calculate agreement score between agents."""
        confidences = []

        if hrm_result:
            confidences.append(hrm_result.confidence)
        if trm_result:
            confidences.append(trm_result.confidence)
        if mcts_result:
            confidences.append(mcts_result.get("best_value", 0.5))

        if not confidences:
            return 0.0

        # Consensus is based on confidence alignment and average
        if len(confidences) == 1:
            return confidences[0]

        avg_confidence = np.mean(confidences)
        std_confidence = np.std(confidences)

        # Higher consensus when agents agree (low std) and are confident (high avg)
        agreement_factor = max(0, 1 - std_confidence * 2)
        consensus = avg_confidence * agreement_factor

        return round(min(1.0, consensus), 3)

    def _synthesize_response(
        self,
        query: str,
        hrm_result: AgentResult | None,
        trm_result: AgentResult | None,
        mcts_result: dict | None,
        consensus_score: float,
    ) -> str:
        """Synthesize final response from all agent outputs."""
        parts = []

        if hrm_result and hrm_result.confidence > 0.5:
            parts.append(f"[HRM] {hrm_result.response}")

        if trm_result and trm_result.confidence > 0.5:
            parts.append(f"[TRM] {trm_result.response}")

        if mcts_result and mcts_result.get("best_value", 0) > 0.5:
            parts.append(f"[MCTS] Best path: {mcts_result.get('best_action', 'N/A')}")

        if not parts:
            truncated_query = f"{query[:80]}..." if len(query) > 80 else query
            return f"Insufficient confidence to answer query: '{truncated_query}'."

        synthesis = " | ".join(parts)

        if consensus_score > 0.7:
            return f"HIGH CONSENSUS ({consensus_score:.1%}): {synthesis}"
        elif consensus_score > 0.4:
            return f"MODERATE CONSENSUS ({consensus_score:.1%}): {synthesis}"
        else:
            return f"LOW CONSENSUS ({consensus_score:.1%}): {synthesis}"


# Global framework instance
framework = None
wandb_tracker = None


def initialize_framework(use_hf: bool, model_id: str):
    """Initialize or reinitialize the framework."""
    global framework
    framework = MultiAgentFrameworkDemo(use_hf_inference=use_hf, hf_model=model_id)
    return "Framework initialized successfully!"


def process_query_sync(
    query: str,
    use_hrm: bool,
    use_trm: bool,
    use_mcts: bool,
    mcts_iterations: int,
    exploration_weight: float,
    seed: int,
    enable_wandb: bool = False,
    wandb_project: str = "langgraph-mcts-demo",
    wandb_run_name: str = "",
):
    """Synchronous wrapper for async processing."""
    global framework, wandb_tracker

    if framework is None:
        framework = MultiAgentFrameworkDemo()

    if not query.strip():
        return "Please enter a query.", {}, "", {}, ""

    # Handle seed
    seed_value = seed if seed > 0 else None

    # Initialize W&B tracking if enabled
    wandb_url = ""
    if enable_wandb and is_wandb_available():
        if wandb_tracker is None:
            wandb_tracker = WandBTracker(project_name=wandb_project, enabled=True)

        # Start a new run
        run_name = wandb_run_name if wandb_run_name.strip() else None
        config = {
            "query": query[:200],  # Truncate for config
            "use_hrm": use_hrm,
            "use_trm": use_trm,
            "use_mcts": use_mcts,
            "mcts_iterations": mcts_iterations,
            "exploration_weight": exploration_weight,
            "seed": seed_value,
        }
        wandb_tracker.init_run(run_name=run_name, config=config)

    # Run async function
    result = asyncio.run(
        framework.process_query(
            query=query,
            use_hrm=use_hrm,
            use_trm=use_trm,
            use_mcts=use_mcts,
            mcts_iterations=int(mcts_iterations),
            exploration_weight=exploration_weight,
            seed=seed_value,
        )
    )

    # Format outputs
    final_response = result.final_response

    # Agent details
    agent_details = {}
    if result.hrm_result:
        agent_details["HRM"] = {
            "response": result.hrm_result.response,
            "confidence": f"{result.hrm_result.confidence:.1%}",
            "reasoning_steps": result.hrm_result.reasoning_steps,
            "time_ms": result.hrm_result.execution_time_ms,
        }

        # Log to W&B
        if enable_wandb and wandb_tracker:
            wandb_tracker.log_agent_result(
                "HRM",
                result.hrm_result.response,
                result.hrm_result.confidence,
                result.hrm_result.execution_time_ms,
                result.hrm_result.reasoning_steps,
            )

    if result.trm_result:
        agent_details["TRM"] = {
            "response": result.trm_result.response,
            "confidence": f"{result.trm_result.confidence:.1%}",
            "reasoning_steps": result.trm_result.reasoning_steps,
            "time_ms": result.trm_result.execution_time_ms,
        }

        # Log to W&B
        if enable_wandb and wandb_tracker:
            wandb_tracker.log_agent_result(
                "TRM",
                result.trm_result.response,
                result.trm_result.confidence,
                result.trm_result.execution_time_ms,
                result.trm_result.reasoning_steps,
            )

    if result.mcts_result:
        agent_details["MCTS"] = result.mcts_result

        # Log to W&B
        if enable_wandb and wandb_tracker:
            wandb_tracker.log_mcts_result(result.mcts_result)

    # Log consensus and performance to W&B
    if enable_wandb and wandb_tracker:
        wandb_tracker.log_consensus(result.consensus_score, result.metadata["agents_used"], result.final_response)
        wandb_tracker.log_performance(result.total_time_ms)
        wandb_tracker.log_query_summary(query, use_hrm, use_trm, use_mcts, result.consensus_score, result.total_time_ms)

        # Get run URL
        wandb_url = wandb_tracker.get_run_url() or ""

        # Finish the run
        wandb_tracker.finish_run()

    # Metrics
    metrics = f"""
**Consensus Score:** {result.consensus_score:.1%}
**Total Processing Time:** {result.total_time_ms:.2f} ms
**Agents Used:** {", ".join(result.metadata["agents_used"])}
"""

    if wandb_url:
        metrics += f"\n**W&B Run:** [{wandb_url}]({wandb_url})"

    # Full JSON result
    full_result = {
        "query": result.query,
        "final_response": result.final_response,
        "consensus_score": result.consensus_score,
        "total_time_ms": result.total_time_ms,
        "metadata": result.metadata,
        "agent_details": agent_details,
        "wandb_url": wandb_url,
    }

    return final_response, agent_details, metrics, full_result, wandb_url


def visualize_mcts_tree(mcts_result: dict) -> str:
    """Create ASCII visualization of MCTS tree."""
    if not mcts_result or "tree_visualization" not in mcts_result:
        return "No MCTS tree data available"

    return mcts_result["tree_visualization"]


# Example queries for demonstration
EXAMPLE_QUERIES = [
    "What are the key factors to consider when choosing between microservices and monolithic architecture?",
    "How can we optimize a Python application that processes 10GB of log files daily?",
    "What is the best approach to implement rate limiting in a distributed system?",
    "Should we use SQL or NoSQL database for a social media application with 1M users?",
    "How to design a fault-tolerant message queue system?",
]


# Gradio Interface
with gr.Blocks(
    title="LangGraph Multi-Agent MCTS Demo",
    theme=gr.themes.Soft(),
    css="""
    .agent-box { border: 1px solid #ddd; padding: 10px; border-radius: 5px; margin: 5px 0; }
    .consensus-high { color: #28a745; font-weight: bold; }
    .consensus-medium { color: #ffc107; font-weight: bold; }
    .consensus-low { color: #dc3545; font-weight: bold; }
    """,
) as demo:
    gr.Markdown(
        """
        # LangGraph Multi-Agent MCTS Framework

        **Proof-of-Concept Demo** - Multi-agent reasoning with Monte Carlo Tree Search

        This demo showcases:
        - **HRM**: Hierarchical Reasoning Module - breaks down complex queries
        - **TRM**: Tree Reasoning Module - iterative refinement of responses
        - **MCTS**: Monte Carlo Tree Search - strategic exploration of solution space
        - **Consensus**: Agreement scoring between agents

        ---
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Query", placeholder="Enter your reasoning task or question...", lines=3, max_lines=10
            )

            gr.Markdown("**Example Queries:**")
            example_dropdown = gr.Dropdown(choices=EXAMPLE_QUERIES, label="Select an example", interactive=True)

            def load_example(example):
                return example

            example_dropdown.change(load_example, example_dropdown, query_input)

        with gr.Column(scale=1):
            gr.Markdown("**Agent Configuration**")
            use_hrm = gr.Checkbox(label="Enable HRM (Hierarchical)", value=True)
            use_trm = gr.Checkbox(label="Enable TRM (Iterative)", value=True)
            use_mcts = gr.Checkbox(label="Enable MCTS", value=False)

            gr.Markdown("**MCTS Parameters**")
            mcts_iterations = gr.Slider(
                minimum=10,
                maximum=100,
                value=25,
                step=5,
                label="Iterations",
                info="More iterations = better search, but slower",
            )
            exploration_weight = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.414,
                step=0.1,
                label="Exploration Weight (C)",
                info="Higher = more exploration, Lower = more exploitation",
            )
            seed_input = gr.Number(label="Random Seed (0 for random)", value=0, precision=0)

    with gr.Accordion("Weights & Biases Tracking", open=False):
        gr.Markdown(
            """
            **Experiment Tracking with W&B**

            Track your experiments, visualize metrics, and compare runs.
            Requires W&B API key set in Space secrets as `WANDB_API_KEY`.
            """
        )
        with gr.Row():
            enable_wandb = gr.Checkbox(
                label="Enable W&B Tracking", value=False, info="Log metrics and results to Weights & Biases"
            )
            wandb_project = gr.Textbox(
                label="Project Name", value="langgraph-mcts-demo", placeholder="Your W&B project name"
            )
            wandb_run_name = gr.Textbox(label="Run Name (optional)", value="", placeholder="Auto-generated if empty")

        wandb_status = gr.Markdown(f"**W&B Status:** {'Available' if is_wandb_available() else 'Not installed'}")

    process_btn = gr.Button("Process Query", variant="primary", size="lg")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Final Response")
            final_response_output = gr.Textbox(label="Synthesized Response", lines=4, interactive=False)

            gr.Markdown("### Performance Metrics")
            metrics_output = gr.Markdown()

        with gr.Column():
            gr.Markdown("### Agent Details")
            agent_details_output = gr.JSON(label="Individual Agent Results")

    with gr.Accordion("Full JSON Result", open=False):
        full_result_output = gr.JSON(label="Complete Framework Output")

    with gr.Accordion("W&B Run Details", open=False, visible=True):
        wandb_url_output = gr.Textbox(
            label="W&B Run URL", interactive=False, placeholder="Enable W&B tracking to see run URL here"
        )

    # Wire up the processing
    process_btn.click(
        fn=process_query_sync,
        inputs=[
            query_input,
            use_hrm,
            use_trm,
            use_mcts,
            mcts_iterations,
            exploration_weight,
            seed_input,
            enable_wandb,
            wandb_project,
            wandb_run_name,
        ],
        outputs=[final_response_output, agent_details_output, metrics_output, full_result_output, wandb_url_output],
    )

    gr.Markdown(
        """
        ---

        ### About This Demo

        This is a **proof-of-concept** demonstration of the LangGraph Multi-Agent MCTS Framework.

        **Features:**
        - Multi-agent orchestration with consensus scoring
        - Monte Carlo Tree Search for strategic reasoning
        - Configurable exploration vs exploitation trade-offs
        - Deterministic results with seeded randomness
        - **Weights & Biases integration** for experiment tracking

        **Limitations (POC):**
        - Uses mock/simplified LLM responses (not production LLM)
        - Limited to demonstration scenarios
        - No persistent storage or RAG
        - Simplified MCTS implementation

        **Full Framework:** [GitHub Repository](https://github.com/ianshank/langgraph_multi_agent_mcts)

        ---
        *Built with LangGraph, Gradio, Weights & Biases, and Python*
        """
    )


if __name__ == "__main__":
    # Initialize with mock client for demo
    framework = MultiAgentFrameworkDemo(use_hf_inference=False)

    # Launch the demo
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
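For readers following the control flow above, a minimal headless sketch (not an official entry point of this commit) of driving the framework without the Gradio UI. It assumes the module above is on the import path so that `MultiAgentFrameworkDemo` is in scope, and the query string is purely illustrative:

# Minimal headless usage sketch; assumes MultiAgentFrameworkDemo from the module above is importable.
import asyncio

async def main() -> None:
    demo_framework = MultiAgentFrameworkDemo(use_hf_inference=False)  # mock LLM, no API key needed
    result = await demo_framework.process_query(
        query="How to design a fault-tolerant message queue system?",
        use_hrm=True,
        use_trm=True,
        use_mcts=True,
        mcts_iterations=25,
        seed=7,  # fixed seed so the MCTS portion is reproducible
    )
    print(result.final_response)
    print(f"consensus={result.consensus_score:.1%}, total={result.total_time_ms} ms")

if __name__ == "__main__":
    asyncio.run(main())
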
demo_src/__init__.py
ADDED
@@ -0,0 +1 @@
# Demo source modules for Hugging Face Spaces
demo_src/agents_demo.py
ADDED
@@ -0,0 +1,234 @@
"""
Simplified agent implementations for Hugging Face Spaces demo.
"""

import asyncio
from typing import Any


class HRMAgent:
    """Hierarchical Reasoning Module - breaks down complex queries."""

    def __init__(self, llm_client):
        """Initialize with an LLM client.

        Args:
            llm_client: LLM client (MockLLMClient or HuggingFaceClient)
        """
        self.llm_client = llm_client
        self.name = "HRM (Hierarchical Reasoning)"

    async def process(self, query: str) -> dict[str, Any]:
        """Process query using hierarchical decomposition.

        Args:
            query: Input query to process

        Returns:
            Dictionary with response, confidence, and reasoning steps
        """
        # Step 1: Decompose the query
        decomposition_steps = await self._decompose_query(query)

        # Step 2: Analyze each component
        analysis_results = await self._analyze_components(decomposition_steps)

        # Step 3: Synthesize hierarchical response
        llm_result = await self.llm_client.generate(
            prompt=f"Hierarchical analysis of: {query}", context=f"Components: {', '.join(decomposition_steps)}"
        )

        # Compile reasoning steps
        reasoning_steps = [
            f"1. Query decomposition: Identified {len(decomposition_steps)} key components",
            f"2. Component analysis: {analysis_results}",
            "3. Hierarchical synthesis: Combined insights from all levels",
            f"4. Confidence assessment: {llm_result['confidence']:.1%} based on component clarity",
        ]

        return {
            "response": llm_result["response"],
            "confidence": llm_result["confidence"],
            "steps": reasoning_steps,
            "components": decomposition_steps,
            "tokens_used": llm_result.get("tokens_used", 0),
        }

    async def _decompose_query(self, query: str) -> list[str]:
        """Decompose query into hierarchical components."""
        # Simulate decomposition based on query structure
        await asyncio.sleep(0.05)  # Simulate processing

        # Simple heuristic decomposition
        components = []

        # Extract key phrases
        query_lower = query.lower()

        if "?" in query:
            components.append("Question type: Analytical")
        else:
            components.append("Question type: Directive")

        if "how" in query_lower:
            components.append("Focus: Methodology/Process")
        elif "what" in query_lower:
            components.append("Focus: Definition/Identification")
        elif "why" in query_lower:
            components.append("Focus: Causation/Reasoning")
        elif "should" in query_lower or "best" in query_lower:
            components.append("Focus: Decision/Recommendation")
        else:
            components.append("Focus: General inquiry")

        # Domain detection
        if any(term in query_lower for term in ["database", "sql", "nosql", "storage"]):
            components.append("Domain: Data Management")
        elif any(term in query_lower for term in ["architecture", "design", "pattern"]):
            components.append("Domain: System Architecture")
        elif any(term in query_lower for term in ["performance", "optimization", "speed"]):
            components.append("Domain: Performance Engineering")
        elif any(term in query_lower for term in ["scale", "distributed", "cluster"]):
            components.append("Domain: Distributed Systems")
        else:
            components.append("Domain: Software Engineering")

        # Complexity assessment
        word_count = len(query.split())
        if word_count > 20:
            components.append("Complexity: High (detailed query)")
        elif word_count > 10:
            components.append("Complexity: Medium")
        else:
            components.append("Complexity: Low (concise query)")

        return components

    async def _analyze_components(self, components: list[str]) -> str:
        """Analyze the decomposed components."""
        await asyncio.sleep(0.03)  # Simulate processing

        # Generate analysis summary
        analysis_parts = []

        for component in components:
            if "Focus:" in component:
                focus = component.split(":")[1].strip()
                analysis_parts.append(f"requires {focus.lower()} approach")
            elif "Domain:" in component:
                domain = component.split(":")[1].strip()
                analysis_parts.append(f"applies to {domain}")
            elif "Complexity:" in component:
                complexity = component.split(":")[1].strip().split()[0]
                analysis_parts.append(f"{complexity.lower()} complexity level")

        return "; ".join(analysis_parts) if analysis_parts else "Standard analysis"


class TRMAgent:
    """Tree Reasoning Module - iterative refinement of responses."""

    def __init__(self, llm_client):
        """Initialize with an LLM client.

        Args:
            llm_client: LLM client (MockLLMClient or HuggingFaceClient)
        """
        self.llm_client = llm_client
        self.name = "TRM (Iterative Refinement)"
        self.max_iterations = 3

    async def process(self, query: str) -> dict[str, Any]:
        """Process query using iterative refinement.

        Args:
            query: Input query to process

        Returns:
            Dictionary with response, confidence, and reasoning steps
        """
        reasoning_steps = []
        current_response = ""
        current_confidence = 0.0

        # Iterative refinement loop
        for iteration in range(self.max_iterations):
            step_num = iteration + 1

            # Generate or refine response
            if iteration == 0:
                # Initial response
                result = await self.llm_client.generate(prompt=query, context="")
                current_response = result["response"]
                current_confidence = result["confidence"]
                reasoning_steps.append(
                    f"Iteration {step_num}: Initial response generated (confidence: {current_confidence:.1%})"
                )
            else:
                # Refinement iteration
                refinement_result = await self._refine_response(query, current_response, iteration)
                current_response = refinement_result["response"]

                # Confidence typically improves with refinement
                confidence_improvement = min(0.1, (1 - current_confidence) * 0.3)
                current_confidence = min(0.95, current_confidence + confidence_improvement)

                reasoning_steps.append(
                    f"Iteration {step_num}: {refinement_result['refinement_type']} "
                    f"(confidence: {current_confidence:.1%})"
                )

            # Check if confidence is high enough to stop
            if current_confidence > 0.85:
                reasoning_steps.append(f"Early termination: High confidence ({current_confidence:.1%}) achieved")
                break

        # Final reasoning step
        reasoning_steps.append(f"Final: Response refined through {len(reasoning_steps)} iterations")

        return {
            "response": current_response,
            "confidence": round(current_confidence, 3),
            "steps": reasoning_steps,
            "iterations_used": min(iteration + 1, self.max_iterations),
            "refinement_history": reasoning_steps,
        }

    async def _refine_response(self, query: str, current_response: str, iteration: int) -> dict[str, Any]:
        """Refine the current response."""
        await asyncio.sleep(0.05)  # Simulate refinement processing

        # Different refinement strategies based on iteration
        refinement_strategies = [
            ("Clarity enhancement", "improve clarity and precision"),
            ("Detail expansion", "add technical depth and specifics"),
            ("Validation check", "verify accuracy and completeness"),
        ]

        strategy_name, strategy_action = refinement_strategies[iteration % len(refinement_strategies)]

        # Generate refined response
        refinement_prompt = f"""
        Original query: {query}
        Current response: {current_response}
        Refinement task: {strategy_action}
        """

        result = await self.llm_client.generate(
            prompt=refinement_prompt, context=f"Refinement iteration {iteration + 1}"
        )

        # Enhance the response based on strategy
        enhanced_response = current_response
        if strategy_name == "Clarity enhancement":
            enhanced_response = f"{current_response}. {result['response']}"
        elif strategy_name == "Detail expansion":
            enhanced_response = f"{current_response}. Furthermore, {result['response']}"
        else:  # Validation
            enhanced_response = f"{current_response}. Validated: {result['response']}"

        # Truncate if too long
        if len(enhanced_response) > 300:
            enhanced_response = enhanced_response[:297] + "..."

        return {"response": enhanced_response, "refinement_type": strategy_name, "strategy_action": strategy_action}
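Both agents above expose the same `process()` contract (a dict with "response", "confidence", and "steps"), which is what the app layer depends on. A small sketch of exercising them directly with the mock client; the imports assume the `demo_src` package layout added in this commit and the query text is illustrative:

# Quick check of the shared agent contract, assuming demo_src is importable from the repo root.
import asyncio

from demo_src.agents_demo import HRMAgent, TRMAgent
from demo_src.llm_mock import MockLLMClient

async def main() -> None:
    client = MockLLMClient()
    for agent in (HRMAgent(client), TRMAgent(client)):
        result = await agent.process("Should we use SQL or NoSQL for a social media application?")
        # Every result carries the keys the Gradio app reads: response, confidence, steps.
        print(agent.name, f"{result['confidence']:.1%}", len(result["steps"]), "steps")

asyncio.run(main())
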
demo_src/llm_mock.py
ADDED
@@ -0,0 +1,182 @@
"""
Mock and lightweight LLM clients for demo purposes.
"""

import asyncio
import random
from typing import Any


class MockLLMClient:
    """Mock LLM client that generates plausible demo responses."""

    def __init__(self):
        self.response_templates = {
            "architecture": [
                "Consider scalability requirements and team expertise",
                "Evaluate coupling, deployment complexity, and operational overhead",
                "Balance between development speed and long-term maintainability",
            ],
            "optimization": [
                "Profile first to identify actual bottlenecks",
                "Consider memory-mapped files and streaming processing",
                "Implement parallel processing with appropriate chunk sizes",
            ],
            "database": [
                "Consider data consistency requirements and query patterns",
                "Evaluate write-heavy vs read-heavy workload characteristics",
                "Plan for horizontal scaling and data distribution strategies",
            ],
            "distributed": [
                "Implement proper failure detection and recovery mechanisms",
                "Use circuit breakers and bulkhead patterns for resilience",
                "Consider eventual consistency vs strong consistency trade-offs",
            ],
            "default": [
                "Break down the problem into smaller components",
                "Consider trade-offs between different approaches",
                "Evaluate based on specific use case requirements",
            ],
        }

    async def generate(self, prompt: str, context: str = "") -> dict[str, Any]:
        """Generate a mock response based on the prompt and optional context."""
        # Simulate processing time
        await asyncio.sleep(random.uniform(0.1, 0.3))

        # Determine response category
        prompt_lower = prompt.lower()
        if "architecture" in prompt_lower or "microservice" in prompt_lower or "monolith" in prompt_lower:
            category = "architecture"
        elif "optim" in prompt_lower or "performance" in prompt_lower or "process" in prompt_lower:
            category = "optimization"
        elif "database" in prompt_lower or "sql" in prompt_lower or "nosql" in prompt_lower:
            category = "database"
        elif "distribut" in prompt_lower or "fault" in prompt_lower or "rate limit" in prompt_lower:
            category = "distributed"
        else:
            category = "default"

        templates = self.response_templates[category]

        # Generate response with some randomness
        response = random.choice(templates)
        confidence = random.uniform(0.6, 0.95)

        # Add more detail based on prompt length (simulating "understanding")
        if len(prompt) > 100:
            confidence = min(0.95, confidence + 0.1)
            response += f". Additionally, {random.choice(self.response_templates['default'])}"

        # Lightly incorporate context to simulate conditioning
        context_snippet = context.strip()
        if context_snippet:
            confidence = min(0.99, confidence + 0.05)
            response += f" (context signal: {context_snippet[:60]}{'...' if len(context_snippet) > 60 else ''})"

        return {
            "response": response,
            "confidence": round(confidence, 3),
            "tokens_used": len(prompt.split()) * 2 + len(response.split()),
        }

    async def generate_reasoning_steps(self, query: str, num_steps: int = 3) -> list[str]:
        """Generate mock reasoning steps."""
        await asyncio.sleep(random.uniform(0.05, 0.15))

        base_steps = [
            f"Analyzing query: '{query[:50]}...'",
            "Identifying key requirements and constraints",
            "Evaluating potential approaches",
            "Considering trade-offs and implications",
            "Synthesizing recommendations based on analysis",
            "Validating conclusions against requirements",
        ]

        return random.sample(base_steps, min(num_steps, len(base_steps)))


class HuggingFaceClient:
    """Lightweight Hugging Face Inference API client."""

    def __init__(self, model_id: str = "mistralai/Mistral-7B-Instruct-v0.2"):
        """Initialize with a Hugging Face model.

        Args:
            model_id: The model ID on Hugging Face Hub
        """
        self.model_id = model_id
        self._client = None

    def _get_client(self):
        """Lazy load the HF client."""
        if self._client is None:
            try:
                from huggingface_hub import InferenceClient

                self._client = InferenceClient(model=self.model_id)
            except ImportError:
                raise ImportError("huggingface_hub not installed. Install with: pip install huggingface_hub")
        return self._client

    async def generate(self, prompt: str, context: str = "") -> dict[str, Any]:
        """Generate response using Hugging Face Inference API."""
        try:
            client = self._get_client()

            # Format prompt
            if context:
                full_prompt = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:"
            else:
                full_prompt = f"Question: {prompt}\n\nProvide a concise, technical answer:\n\nAnswer:"

            # Call HF Inference API (sync call wrapped in async)
            response_text = await asyncio.to_thread(
                client.text_generation, full_prompt, max_new_tokens=150, temperature=0.7, do_sample=True
            )

            # Estimate confidence based on response characteristics
            confidence = min(0.95, 0.6 + len(response_text) / 500)

            return {
                "response": response_text.strip(),
                "confidence": round(confidence, 3),
                "tokens_used": len(full_prompt.split()) + len(response_text.split()),
            }

        except Exception as e:
            # Fallback to mock on error
            print(f"HF Inference error: {e}. Falling back to mock.")
            mock = MockLLMClient()
            return await mock.generate(prompt, context)

    async def generate_reasoning_steps(self, query: str, num_steps: int = 3) -> list[str]:
        """Generate reasoning steps using HF model."""
        try:
            client = self._get_client()

            prompt = f"""Break down this question into {num_steps} reasoning steps:
Question: {query}

Reasoning steps (one per line):
1."""

            response = await asyncio.to_thread(client.text_generation, prompt, max_new_tokens=200, temperature=0.5)

            # Parse steps from response
            lines = response.strip().split("\n")
            steps = []
            for line in lines:
                line = line.strip()
                if line and not line.startswith("#"):
                    # Remove numbering
                    if line[0].isdigit() and "." in line[:3]:
                        line = line.split(".", 1)[1].strip()
                    steps.append(line)

            return steps[:num_steps] if steps else ["Analysis in progress"]

        except Exception as e:
            print(f"HF reasoning error: {e}. Falling back to mock.")
            mock = MockLLMClient()
            return await mock.generate_reasoning_steps(query, num_steps)
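MockLLMClient routes purely on keywords in the prompt, so the same call shape works fully offline. A brief smoke-test sketch, assuming `demo_src` is on the import path; the prompt and context strings are illustrative only:

# Offline smoke test for the mock client; no network or API key required.
import asyncio

from demo_src.llm_mock import MockLLMClient

async def main() -> None:
    client = MockLLMClient()
    out = await client.generate(
        prompt="How can we optimize a Python application that processes large log files?",
        context="batch pipeline, nightly run",
    )
    # "optim" in the prompt selects the optimization templates; non-empty context nudges confidence up.
    print(out["response"])
    print(out["confidence"], out["tokens_used"])

asyncio.run(main())
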
demo_src/mcts_demo.py
ADDED
@@ -0,0 +1,436 @@
"""
Educational MCTS demonstration using the production framework.

This demo uses the real MCTSEngine from src.framework.mcts.core to provide
an authentic learning experience while remaining accessible for demonstrations.
"""

from __future__ import annotations

import math
from typing import Any

from src.framework.mcts.core import MCTSEngine, MCTSNode, MCTSState
from src.framework.mcts.policies import RolloutPolicy, SelectionPolicy


class DemoRolloutPolicy(RolloutPolicy):
    """
    Educational rollout policy for demo purposes.

    Evaluates states based on:
    - Depth of exploration (deeper = more thorough)
    - Action quality (domain-specific heuristics)
    - Exploration randomness
    """

    def __init__(self, category: str, action_templates: dict[str, list[str]]):
        """
        Initialize demo rollout policy.

        Args:
            category: Query category for heuristic evaluation
            action_templates: Available action templates for scoring
        """
        self.category = category
        self.action_templates = action_templates

        # Define key terms that indicate quality actions per category
        self.quality_indicators = {
            "architecture": ["scalability", "consistency", "requirements"],
            "optimization": ["profile", "caching", "parallel"],
            "database": ["patterns", "relationships", "scaling"],
            "distributed": ["circuit", "retry", "bulkhead"],
            "default": ["decompose", "constraints", "trade-offs"],
        }

    async def evaluate(
        self,
        state: MCTSState,
        rng,
        max_depth: int = 10,
    ) -> float:
        """
        Evaluate a state through heuristic analysis.

        This combines:
        - Depth bonus: rewards thorough exploration
        - Action quality: rewards domain-appropriate actions
        - Noise: adds exploration randomness

        Args:
            state: State to evaluate
            rng: Random number generator
            max_depth: Maximum depth (unused in heuristic)

        Returns:
            Estimated value in [0, 1] range
        """
        # Base value
        base_value = 0.5

        # Depth bonus: deeper exploration = more value (up to 0.3)
        depth = state.features.get("depth", 0)
        depth_bonus = min(depth * 0.1, 0.3)

        # Action quality bonus
        action_bonus = 0.0
        last_action = state.features.get("last_action", "")

        if last_action:
            # Check if action contains quality indicators for this category
            indicators = self.quality_indicators.get(self.category, self.quality_indicators["default"])
            for term in indicators:
                if term in last_action.lower():
                    action_bonus = 0.15
                    break

        # Add exploration noise
        noise = rng.uniform(-0.1, 0.1)

        # Combine components
        value = base_value + depth_bonus + action_bonus + noise

        # Clamp to [0, 1]
        return max(0.0, min(1.0, value))


class MCTSDemo:
    """
    Educational MCTS demonstration using the production framework.

    This class wraps the production MCTSEngine to provide:
    - Simple, educational interface for demos
    - Category-based action selection
    - Tree visualization for learning
    - Deterministic behavior with seeds

    Unlike the old mock implementation, this uses the real MCTS algorithm
    with all its features: UCB1 selection, progressive widening, caching, etc.
    """

    def __init__(self, max_depth: int = 5):
        """
        Initialize MCTS demo.

        Args:
            max_depth: Maximum tree depth for exploration
        """
        self.max_depth = max_depth

        # Action templates for different query types
        # These provide domain-specific reasoning paths
        self.action_templates = {
            "architecture": [
                "Consider microservices for scalability",
                "Evaluate monolith for simplicity",
                "Analyze team capabilities",
                "Assess deployment requirements",
                "Review data consistency needs",
            ],
            "optimization": [
                "Profile application hotspots",
                "Implement caching layer",
                "Use parallel processing",
                "Optimize database queries",
                "Reduce memory allocations",
            ],
            "database": [
                "Analyze query patterns",
                "Consider data relationships",
                "Evaluate consistency requirements",
                "Plan for horizontal scaling",
                "Assess read/write ratios",
            ],
            "distributed": [
                "Implement circuit breakers",
                "Add retry mechanisms",
                "Use message queues",
                "Apply bulkhead pattern",
                "Design for eventual consistency",
            ],
            "default": [
                "Decompose the problem",
                "Identify constraints",
                "Evaluate trade-offs",
                "Consider alternatives",
                "Validate assumptions",
            ],
        }

    def _categorize_query(self, query: str) -> str:
        """
        Categorize query to select appropriate action templates.

        Args:
            query: User's input query

        Returns:
            Category name for action selection
        """
        query_lower = query.lower()
        if "architecture" in query_lower or "microservice" in query_lower:
            return "architecture"
        elif "optim" in query_lower or "performance" in query_lower:
            return "optimization"
        elif "database" in query_lower or "sql" in query_lower:
            return "database"
        elif "distribut" in query_lower or "fault" in query_lower:
            return "distributed"
        return "default"

    def _create_action_generator(self, category: str):
        """
        Create action generator function for this query category.

        Args:
            category: Query category

        Returns:
            Function that generates actions for a given state
        """
        def action_generator(state: MCTSState) -> list[str]:
            """Generate available actions from current state."""
            # Get category-specific actions
            actions = self.action_templates.get(category, self.action_templates["default"])

            # Filter out already-used actions (track via state features)
            used_actions = state.features.get("used_actions", set())
            available = [a for a in actions if a not in used_actions]

            # If all actions used, allow re-exploring top 2
            if not available:
                return actions[:2]

            return available

        return action_generator

    def _create_state_transition(self, category: str):
        """
        Create state transition function for this query category.

        Args:
            category: Query category

        Returns:
            Function that computes next state from current state + action
        """
        def state_transition(state: MCTSState, action: str) -> MCTSState:
            """Compute next state by applying action."""
            # Track action history
            action_history = list(state.features.get("action_history", []))
            action_history.append(action)

            # Track used actions
            used_actions = set(state.features.get("used_actions", set()))
            used_actions.add(action)

            # Increment depth
            depth = state.features.get("depth", 0) + 1

            # Create new state ID from action history
            state_id = " -> ".join(action_history)

            # Build new state
            new_state = MCTSState(
                state_id=state_id,
                features={
                    "action_history": action_history,
                    "used_actions": used_actions,
                    "depth": depth,
                    "last_action": action,
                    "category": category,
                },
            )

            return new_state

        return state_transition

    def _generate_tree_visualization(self, root: MCTSNode, max_nodes: int = 20) -> str:
        """
        Generate ASCII visualization of the MCTS tree.

        This provides educational insight into the search process.

        Args:
            root: Root node of the tree
            max_nodes: Maximum nodes to display

        Returns:
            ASCII art representation of the tree
        """
        max_nodes = max(1, max_nodes)
        lines = []
        lines.append("MCTS Tree Visualization")
        lines.append("=" * 50)

        nodes_rendered = 0

        def format_node(node: MCTSNode, prefix: str = "", is_last: bool = True) -> list[str]:
            nonlocal nodes_rendered
            result = []

            # Node representation
            connector = "└── " if is_last else "├── "

            if nodes_rendered >= max_nodes:
                result.append(f"{prefix}{connector}... (truncated)")
                return result

            nodes_rendered += 1

            # Display action or state
            node_str = f"{node.state.state_id[:30]}..."
            if node.action:
                node_str = f"{node.action[:25]}..."

            stats = f"[V:{node.visits}, Q:{node.value:.3f}]"

            result.append(f"{prefix}{connector}{node_str} {stats}")

            # Recursively add children
            new_prefix = prefix + ("    " if is_last else "│   ")

            # Limit children shown
            children_to_show = node.children[:3]
            for i, child in enumerate(children_to_show):
                is_child_last = i == len(children_to_show) - 1
                result.extend(format_node(child, new_prefix, is_child_last))

            if len(node.children) > 3:
                result.append(f"{new_prefix} ... and {len(node.children) - 3} more")

            return result

        # Start with root
        lines.append(f"Root: {root.state.state_id[:40]}... [V:{root.visits}, Q:{root.value:.3f}]")
        nodes_rendered += 1

        for i, child in enumerate(root.children[:5]):
            is_last = i == len(root.children[:5]) - 1
            lines.extend(format_node(child, "", is_last))

        if len(root.children) > 5:
            lines.append(f"... and {len(root.children) - 5} more branches")

        return "\n".join(lines)

    async def search(
        self,
        query: str,
        iterations: int = 25,
        exploration_weight: float = 1.414,
        seed: int | None = None,
    ) -> dict[str, Any]:
        """
        Run MCTS search on the query using the production framework.

        This method demonstrates the full MCTS algorithm:
        1. Selection: UCB1-based tree traversal
        2. Expansion: Progressive widening of nodes
        3. Simulation: Heuristic evaluation (rollout)
        4. Backpropagation: Value updates up the tree

        Args:
            query: The input query to analyze
            iterations: Number of MCTS iterations (more = better but slower)
            exploration_weight: UCB1 exploration constant (higher = more exploration)
            seed: Random seed for deterministic results

        Returns:
            Dictionary with:
            - best_action: Recommended next step
            - best_value: Confidence in recommendation
            - statistics: Search metrics and performance data
            - tree_visualization: ASCII art of search tree
        """
        # Determine query category
        category = self._categorize_query(query)

        # Initialize MCTS engine with production features
        engine = MCTSEngine(
            seed=seed if seed is not None else 42,
            exploration_weight=exploration_weight,
            progressive_widening_k=1.0,  # Moderate expansion
            progressive_widening_alpha=0.5,
            max_parallel_rollouts=4,
            cache_size_limit=10000,
        )

        # Create root state
        root_state = MCTSState(
            state_id=f"Query: {query[:50]}",
            features={
                "query": query,
                "category": category,
                "action_history": [],
                "used_actions": set(),
                "depth": 0,
                "last_action": "",
            },
        )

        # Create root node
        root = MCTSNode(state=root_state, rng=engine.rng)

        # Create domain-specific functions
        action_generator = self._create_action_generator(category)
        state_transition = self._create_state_transition(category)
        rollout_policy = DemoRolloutPolicy(category, self.action_templates)

        # Run MCTS search with production engine
        best_action, stats = await engine.search(
            root=root,
            num_iterations=iterations,
            action_generator=action_generator,
            state_transition=state_transition,
            rollout_policy=rollout_policy,
            max_rollout_depth=self.max_depth,
            selection_policy=SelectionPolicy.MAX_VISITS,  # Most robust
        )

        # Extract best child info
        best_child = None
        if root.children:
            best_child = max(root.children, key=lambda c: c.visits)

        # Compile results for demo interface
        result = {
            "best_action": best_action or "No action found",
|
| 402 |
+
"best_value": round(best_child.value, 4) if best_child else 0.0,
|
| 403 |
+
"root_visits": root.visits,
|
| 404 |
+
"total_nodes": engine.get_cached_node_count(),
|
| 405 |
+
"max_depth_reached": engine.get_cached_tree_depth(),
|
| 406 |
+
"iterations_completed": iterations,
|
| 407 |
+
"exploration_weight": exploration_weight,
|
| 408 |
+
"seed": seed,
|
| 409 |
+
"category": category,
|
| 410 |
+
|
| 411 |
+
# Top actions sorted by visits
|
| 412 |
+
"top_actions": [
|
| 413 |
+
{
|
| 414 |
+
"action": child.action,
|
| 415 |
+
"visits": child.visits,
|
| 416 |
+
"value": round(child.value, 4),
|
| 417 |
+
"ucb1": round(
|
| 418 |
+
child.visits / root.visits if root.visits > 0 else 0.0, 4
|
| 419 |
+
), # Simplified UCB display
|
| 420 |
+
}
|
| 421 |
+
for child in sorted(root.children, key=lambda c: -c.visits)[:5]
|
| 422 |
+
],
|
| 423 |
+
|
| 424 |
+
# Framework statistics
|
| 425 |
+
"framework_stats": {
|
| 426 |
+
"cache_hits": stats.get("cache_hits", 0),
|
| 427 |
+
"cache_misses": stats.get("cache_misses", 0),
|
| 428 |
+
"cache_hit_rate": round(stats.get("cache_hit_rate", 0.0), 4),
|
| 429 |
+
"total_simulations": stats.get("total_simulations", 0),
|
| 430 |
+
},
|
| 431 |
+
|
| 432 |
+
# Educational visualization
|
| 433 |
+
"tree_visualization": self._generate_tree_visualization(root),
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
return result
|
demo_src/wandb_tracker.py
ADDED
@@ -0,0 +1,349 @@
"""
Weights & Biases integration for experiment tracking.
"""

import os
from datetime import datetime
from typing import Any

try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    wandb = None


class WandBTracker:
    """Weights & Biases experiment tracker for multi-agent MCTS demo."""

    def __init__(self, project_name: str = "langgraph-mcts-demo", entity: str | None = None, enabled: bool = True):
        """Initialize W&B tracker.

        Args:
            project_name: W&B project name
            entity: W&B entity (username or team)
            enabled: Whether tracking is enabled
        """
        self.project_name = project_name
        self.entity = entity
        self.enabled = enabled and WANDB_AVAILABLE
        self.run = None
        self.run_id = None

    def is_available(self) -> bool:
        """Check if W&B is available."""
        return WANDB_AVAILABLE

    def init_run(
        self, run_name: str | None = None, config: dict[str, Any] | None = None, tags: list[str] | None = None
    ) -> bool:
        """Initialize a new W&B run.

        Args:
            run_name: Optional name for the run
            config: Configuration dictionary to log
            tags: Tags for the run

        Returns:
            True if run initialized successfully, False otherwise
        """
        if not self.enabled:
            return False

        try:
            # Generate run name if not provided
            if run_name is None:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                run_name = f"mcts_query_{timestamp}"

            # Default tags
            if tags is None:
                tags = ["demo", "multi-agent", "mcts"]

            # Initialize run
            self.run = wandb.init(
                project=self.project_name,
                entity=self.entity,
                name=run_name,
                config=config or {},
                tags=tags,
                reinit=True,
            )

            self.run_id = self.run.id
            return True

        except Exception as e:
            print(f"W&B init error: {e}")
            self.enabled = False
            return False

    def log_query_config(self, config: dict[str, Any]):
        """Log query configuration.

        Args:
            config: Configuration dictionary with agent settings, MCTS params, etc.
        """
        if not self.enabled or not self.run:
            return

        try:
            wandb.config.update(config)
        except Exception as e:
            print(f"W&B config log error: {e}")

    def log_agent_result(
        self,
        agent_name: str,
        response: str,
        confidence: float,
        execution_time_ms: float,
        reasoning_steps: list[str] | None = None,
    ):
        """Log individual agent results.

        Args:
            agent_name: Name of the agent (HRM, TRM, MCTS)
            response: Agent's response text
            confidence: Confidence score (0-1)
            execution_time_ms: Execution time in milliseconds
            reasoning_steps: Optional list of reasoning steps
        """
        if not self.enabled or not self.run:
            return

        try:
            metrics = {
                f"{agent_name}/confidence": confidence,
                f"{agent_name}/execution_time_ms": execution_time_ms,
                f"{agent_name}/response_length": len(response),
            }

            if reasoning_steps:
                metrics[f"{agent_name}/num_reasoning_steps"] = len(reasoning_steps)

            wandb.log(metrics)

            # Log response as text
            wandb.log({f"{agent_name}/response": wandb.Html(f"<pre>{response}</pre>")})

        except Exception as e:
            print(f"W&B agent result log error: {e}")

    def log_mcts_result(self, mcts_result: dict[str, Any]):
        """Log MCTS-specific metrics.

        Args:
            mcts_result: Dictionary containing MCTS search results
        """
        if not self.enabled or not self.run:
            return

        try:
            # Extract key metrics
            metrics = {
                "mcts/best_value": mcts_result.get("best_value", 0),
                "mcts/root_visits": mcts_result.get("root_visits", 0),
                "mcts/total_nodes": mcts_result.get("total_nodes", 0),
                "mcts/max_depth": mcts_result.get("max_depth_reached", 0),
                "mcts/iterations": mcts_result.get("iterations_completed", 0),
                "mcts/exploration_weight": mcts_result.get("exploration_weight", 1.414),
            }

            wandb.log(metrics)

            # Log top actions as table
            if "top_actions" in mcts_result:
                top_actions_data = []
                for action in mcts_result["top_actions"]:
                    top_actions_data.append(
                        [
                            action.get("action", ""),
                            action.get("visits", 0),
                            action.get("value", 0),
                            action.get("ucb1", 0),
                        ]
                    )

                if top_actions_data:
                    table = wandb.Table(data=top_actions_data, columns=["Action", "Visits", "Value", "UCB1"])
                    wandb.log({"mcts/top_actions_table": table})

            # Log tree visualization as text artifact
            if "tree_visualization" in mcts_result:
                wandb.log({"mcts/tree_visualization": wandb.Html(f"<pre>{mcts_result['tree_visualization']}</pre>")})

        except Exception as e:
            print(f"W&B MCTS result log error: {e}")

    def log_consensus(self, consensus_score: float, agents_used: list[str], final_response: str):
        """Log consensus metrics.

        Args:
            consensus_score: Agreement score between agents (0-1)
            agents_used: List of agent names that were used
            final_response: Final synthesized response
        """
        if not self.enabled or not self.run:
            return

        try:
            wandb.log(
                {
                    "consensus/score": consensus_score,
                    "consensus/num_agents": len(agents_used),
                    "consensus/agents": ", ".join(agents_used),
                    "consensus/final_response_length": len(final_response),
                }
            )

            # Categorize consensus level
            if consensus_score > 0.7:
                consensus_level = "high"
            elif consensus_score > 0.4:
                consensus_level = "medium"
            else:
                consensus_level = "low"

            wandb.log({"consensus/level": consensus_level})

        except Exception as e:
            print(f"W&B consensus log error: {e}")

    def log_performance(self, total_time_ms: float):
        """Log overall performance metrics.

        Args:
            total_time_ms: Total execution time in milliseconds
        """
        if not self.enabled or not self.run:
            return

        try:
            wandb.log({"performance/total_time_ms": total_time_ms, "performance/total_time_s": total_time_ms / 1000})
        except Exception as e:
            print(f"W&B performance log error: {e}")

    def log_full_result(self, result: dict[str, Any]):
        """Log the complete result as an artifact.

        Args:
            result: Full framework result dictionary
        """
        if not self.enabled or not self.run:
            return

        try:
            # Create artifact
            artifact = wandb.Artifact(name=f"query_result_{self.run_id}", type="result")

            # Add result as JSON
            import json
            import tempfile

            with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
                json.dump(result, f, indent=2, default=str)
                temp_path = f.name

            artifact.add_file(temp_path, name="result.json")
            wandb.log_artifact(artifact)

            # Clean up temp file
            os.unlink(temp_path)

        except Exception as e:
            print(f"W&B full result log error: {e}")

    def log_query_summary(
        self, query: str, use_hrm: bool, use_trm: bool, use_mcts: bool, consensus_score: float, total_time_ms: float
    ):
        """Log a summary row for the query.

        Args:
            query: The input query
            use_hrm: Whether HRM was enabled
            use_trm: Whether TRM was enabled
            use_mcts: Whether MCTS was enabled
            consensus_score: Final consensus score
            total_time_ms: Total execution time
        """
        if not self.enabled or not self.run:
            return

        try:
            # Create summary table entry
            summary_data = [
                [
                    query[:100] + "..." if len(query) > 100 else query,
                    "✓" if use_hrm else "✗",
                    "✓" if use_trm else "✗",
                    "✓" if use_mcts else "✗",
                    f"{consensus_score:.1%}",
                    f"{total_time_ms:.2f}",
                ]
            ]

            table = wandb.Table(data=summary_data, columns=["Query", "HRM", "TRM", "MCTS", "Consensus", "Time (ms)"])

            wandb.log({"query_summary": table})

        except Exception as e:
            print(f"W&B summary log error: {e}")

    def finish_run(self):
        """Finish the current W&B run."""
        if not self.enabled or not self.run:
            return

        try:
            wandb.finish()
            self.run = None
            self.run_id = None
        except Exception as e:
            print(f"W&B finish error: {e}")

    def get_run_url(self) -> str | None:
        """Get the URL for the current run.

        Returns:
            URL string or None if no active run
        """
        if not self.enabled or not self.run:
            return None

        try:
            return self.run.get_url()
        except Exception:
            return None


# Global tracker instance
_global_tracker: WandBTracker | None = None


def get_tracker(
    project_name: str = "langgraph-mcts-demo", entity: str | None = None, enabled: bool = True
) -> WandBTracker:
    """Get or create the global W&B tracker.

    Args:
        project_name: W&B project name
        entity: W&B entity
        enabled: Whether tracking is enabled

    Returns:
        WandBTracker instance
    """
    global _global_tracker

    if _global_tracker is None:
        _global_tracker = WandBTracker(project_name=project_name, entity=entity, enabled=enabled)

    return _global_tracker


def is_wandb_available() -> bool:
    """Check if W&B is available."""
    return WANDB_AVAILABLE
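
A short usage sketch wiring together the functions this file exports (get_tracker, init_run, log_mcts_result, log_performance, get_run_url, finish_run); the mcts_result dict is assumed to be the output of the MCTS demo's search() above:

from demo_src.wandb_tracker import get_tracker

tracker = get_tracker(project_name="langgraph-mcts-demo", enabled=True)
if tracker.init_run(config={"iterations": 25, "exploration_weight": 1.414}):
    tracker.log_mcts_result(mcts_result)  # dict produced by the MCTS demo search
    tracker.log_performance(total_time_ms=1234.5)
    print(tracker.get_run_url())
    tracker.finish_run()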
models/bert_lora/final_model/README.md
ADDED
@@ -0,0 +1,206 @@
---
base_model: prajjwal1/bert-mini
library_name: peft
tags:
- base_model:adapter:prajjwal1/bert-mini
- lora
- transformers
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]

#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary

## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
### Framework versions

- PEFT 0.17.1
models/bert_lora/final_model/adapter_config.json
ADDED
@@ -0,0 +1,40 @@
{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "prajjwal1/bert-mini",
  "bias": "none",
  "corda_config": null,
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 16,
  "lora_bias": false,
  "lora_dropout": 0.1,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": [
    "classifier",
    "score"
  ],
  "peft_type": "LORA",
  "qalora_group_size": 16,
  "r": 4,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "query",
    "value"
  ],
  "target_parameters": null,
  "task_type": "SEQ_CLS",
  "trainable_token_indices": null,
  "use_dora": false,
  "use_qalora": false,
  "use_rslora": false
}
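
For reference, a minimal sketch of loading this adapter with PEFT; num_labels=3 is an assumption based on the three routing classes (hrm/trm/mcts) reported by the meta-controller results elsewhere in this commit, not something stated in the config itself:

from peft import PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Base model and adapter path taken from adapter_config.json and the repo layout.
base = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-mini", num_labels=3)  # num_labels assumed
model = PeftModel.from_pretrained(base, "models/bert_lora/final_model")  # applies the LoRA weights
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")

inputs = tokenizer("Which agent should handle this query?", return_tensors="pt")
logits = model(**inputs).logits  # shape: (1, num_labels)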
models/bert_lora/final_model/adapter_model.safetensors
ADDED
Binary file (71 kB)
models/bert_lora/generated_dataset.json
ADDED
The diff for this file is too large to render.
models/bert_lora/training_results.json
ADDED
@@ -0,0 +1,48 @@
{
  "config": {
    "model_name": "prajjwal1/bert-mini",
    "lora_r": 4,
    "lora_alpha": 16,
    "lora_dropout": 0.1,
    "lr": 0.001,
    "batch_size": 16,
    "epochs": 5,
    "warmup_steps": 100,
    "seed": 42,
    "num_samples": 1000,
    "data_path": null,
    "balanced": true,
    "output_dir": "models/bert_lora"
  },
  "train_history": {
    "train_loss": 1.1033503922549162,
    "train_runtime": 11.0946,
    "train_samples_per_second": 315.018,
    "epochs": 5,
    "final_metrics": {
      "train_runtime": 11.0946,
      "train_samples_per_second": 315.018,
      "train_steps_per_second": 19.829,
      "total_flos": 34821822412800.0,
      "train_loss": 1.1033503922549162,
      "epoch": 5.0
    },
    "eval_results": {
      "eval_loss": 1.0453400611877441,
      "eval_accuracy": 0.47651006711409394,
      "eval_runtime": 0.1251,
      "eval_samples_per_second": 1191.171,
      "eval_steps_per_second": 79.944,
      "epoch": 5.0
    }
  },
  "test_results": {
    "loss": 1.0559743153338401,
    "accuracy": 0.4768211920529801
  },
  "model_params": {
    "total_params": 11188486,
    "trainable_params": 17155,
    "trainable_percentage": 0.15
  }
}
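
The recorded eval/test accuracy sits near 0.48, so a quick way to sanity-check the shipped numbers locally is to read this JSON directly (standard library only; path as committed above):

import json
from pathlib import Path

results = json.loads(Path("models/bert_lora/training_results.json").read_text())
print(results["test_results"]["accuracy"])          # ~0.4768
print(results["model_params"]["trainable_params"])  # 17155 trainable LoRA parameters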
models/rnn_meta_controller.history.json
ADDED
@@ -0,0 +1,128 @@
{
  "config": {
    "hidden_dim": 64, "num_layers": 1, "dropout": 0.1, "lr": 0.001,
    "batch_size": 32, "epochs": 20, "patience": 5, "seed": 42, "num_samples": 1000
  },
  "training_history": {
    "train_losses": [1.060307163180727, 0.9014069383794611, 0.6105747597687172, 0.35656250968123926,
      0.22574858390020602, 0.16157509059165465, 0.12456387586214325, 0.10158110240643675,
      0.08592396827809738, 0.07474524908783761, 0.06479036057311477, 0.057878461638183304,
      0.052609961931452606, 0.04809149278497154, 0.043710527828697, 0.041286276738074695,
      0.03756282673302022, 0.03491098284156936, 0.031911260236731985, 0.030496817025722878],
    "val_losses": [1.0059996803601583, 0.7808501919110616, 0.47826388080914817, 0.29279296696186063,
      0.2008462185660998, 0.1529717780649662, 0.12299496456980705, 0.10291122049093246,
      0.08860023791591326, 0.07790809428940217, 0.06982718824098508, 0.06387854401643077,
      0.05984275036801894, 0.05463591649507483, 0.04938021237030625, 0.0452831008626769,
      0.04252756762628754, 0.039516554485696055, 0.038632405494960644, 0.035608950459087886],
    "val_accuracies": [0.8466666666666667, 0.92, 0.9822222222222222, 0.9933333333333333,
      0.9911111111111112, 0.9933333333333333, 0.9955555555555555, 0.9955555555555555,
      0.9955555555555555, 0.9955555555555555, 0.9955555555555555, 0.9977777777777778,
      0.9933333333333333, 0.9933333333333333, 0.9977777777777778, 0.9977777777777778,
      0.9977777777777778, 0.9977777777777778, 0.9955555555555555, 0.9977777777777778],
    "best_epoch": 20,
    "best_val_loss": 0.035608950459087886,
    "best_val_accuracy": 0.9977777777777778,
    "stopped_early": false,
    "total_epochs": 20
  },
  "test_results": {
    "loss": 0.022989434589787076,
    "accuracy": 0.9977777777777778,
    "per_class_metrics": {
      "hrm": {"precision": 1.0, "recall": 1.0, "f1_score": 1.0, "support": 153},
      "trm": {"precision": 0.9933774834437086, "recall": 1.0, "f1_score": 0.9966777408637874, "support": 150},
      "mcts": {"precision": 1.0, "recall": 0.9931972789115646, "f1_score": 0.9965870307167235, "support": 147}
    },
    "confusion_matrix": [[153, 0, 0], [0, 150, 0], [0, 1, 146]],
    "total_samples": 450
  }
}
models/rnn_meta_controller.pt
ADDED
Binary file (61.6 kB)
requirements.txt
ADDED
@@ -0,0 +1,28 @@
# LangGraph Multi-Agent MCTS Demo - Dependencies
# Optimized for Hugging Face Spaces deployment with trained models

# Core UI Framework
gradio>=4.0.0,<5.0.0

# Numerical computation
numpy>=1.24.0,<2.0.0

# Machine Learning - Neural Models
torch>=2.1.0
transformers>=4.40.0
peft>=0.7.0
sentence-transformers>=2.2.0

# Configuration
pyyaml>=6.0

# Experiment Tracking
wandb>=0.16.0

# Required for Gradio OAuth and model loading
huggingface_hub>=0.20.0,<0.30.0

# Note: This demo now uses REAL trained models:
# - RNN Meta-Controller (models/rnn_meta_controller.pt)
# - BERT with LoRA adapters (models/bert_lora/final_model/)
# - Actual HRM and TRM agent implementations
src/__init__.py
ADDED
File without changes
src/adapters/__init__.py
ADDED
@@ -0,0 +1,7 @@
"""
Adapters package for external service integrations.
"""

from .llm import BaseLLMClient, LLMResponse, create_client

__all__ = ["create_client", "BaseLLMClient", "LLMResponse"]
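
Because the package re-exports create_client, callers can stay at the top level; a small sketch, assuming a local LM Studio server on the default port (the base_url value mirrors the example in the factory module below):

from src.adapters import create_client

client = create_client("lmstudio", base_url="http://localhost:1234/v1")  # no API key needed for a local server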
src/adapters/llm/__init__.py
ADDED
@@ -0,0 +1,257 @@
"""
LLM Client Factory and Provider Registry.

This module provides a factory function to instantiate the correct LLM client
based on provider settings, with lazy loading of adapters.
"""

import importlib
import logging
from typing import Any

from .base import BaseLLMClient, LLMClient, LLMResponse, LLMToolResponse, ToolCall
from .exceptions import (
    CircuitBreakerOpenError,
    LLMAuthenticationError,
    LLMClientError,
    LLMConnectionError,
    LLMContentFilterError,
    LLMContextLengthError,
    LLMInvalidRequestError,
    LLMModelNotFoundError,
    LLMQuotaExceededError,
    LLMRateLimitError,
    LLMResponseParseError,
    LLMServerError,
    LLMStreamError,
    LLMTimeoutError,
)

logger = logging.getLogger(__name__)

# Provider registry with lazy loading
# Maps provider name to (module_path, class_name)
_PROVIDER_REGISTRY: dict[str, tuple[str, str]] = {
    "openai": ("src.adapters.llm.openai_client", "OpenAIClient"),
    "anthropic": ("src.adapters.llm.anthropic_client", "AnthropicClient"),
    "lmstudio": ("src.adapters.llm.lmstudio_client", "LMStudioClient"),
    "local": ("src.adapters.llm.lmstudio_client", "LMStudioClient"),  # Alias
}

# Cache for loaded client classes
_CLIENT_CACHE: dict[str, type[BaseLLMClient]] = {}


def register_provider(name: str, module_path: str, class_name: str, override: bool = False) -> None:
    """
    Register a new LLM provider.

    Args:
        name: Provider identifier (e.g., "azure", "bedrock")
        module_path: Full module path (e.g., "src.adapters.llm.azure_client")
        class_name: Class name in the module (e.g., "AzureOpenAIClient")
        override: If True, allow overriding existing provider
    """
    if name in _PROVIDER_REGISTRY and not override:
        raise ValueError(f"Provider '{name}' already registered. Use override=True to replace.")

    _PROVIDER_REGISTRY[name] = (module_path, class_name)
    # Clear cache if overriding
    if name in _CLIENT_CACHE:
        del _CLIENT_CACHE[name]

    logger.info(f"Registered LLM provider: {name} -> {module_path}.{class_name}")


def list_providers() -> list[str]:
    """
    List all registered provider names.

    Returns:
        List of provider identifiers
    """
    return list(_PROVIDER_REGISTRY.keys())


def get_provider_class(provider: str) -> type[BaseLLMClient]:
    """
    Get the client class for a provider (with lazy loading).

    Args:
        provider: Provider identifier

    Returns:
        Client class (not instantiated)

    Raises:
        ValueError: If provider not registered
        ImportError: If module cannot be loaded
    """
    if provider not in _PROVIDER_REGISTRY:
        available = ", ".join(list_providers())
        raise ValueError(f"Unknown provider '{provider}'. Available: {available}")

    # Check cache first
    if provider in _CLIENT_CACHE:
        return _CLIENT_CACHE[provider]

    # Lazy load the module
    module_path, class_name = _PROVIDER_REGISTRY[provider]

    try:
        module = importlib.import_module(module_path)
        client_class = getattr(module, class_name)
    except ImportError as e:
        raise ImportError(f"Failed to load provider '{provider}': {e}") from e
    except AttributeError as e:
        raise ImportError(f"Class '{class_name}' not found in module '{module_path}'") from e

    # Cache for future use
    _CLIENT_CACHE[provider] = client_class
    return client_class


def create_client(
    provider: str = "openai",
    *,
    api_key: str | None = None,
    model: str | None = None,
    base_url: str | None = None,
    timeout: float | None = None,
    max_retries: int | None = None,
    **kwargs: Any,
) -> BaseLLMClient:
    """
    Create an LLM client instance.

    This is the main factory function for creating provider clients.

    Args:
        provider: Provider name ("openai", "anthropic", "lmstudio", etc.)
        api_key: API key (may be optional for some providers)
        model: Model identifier
        base_url: Base URL for API
        timeout: Request timeout in seconds
        max_retries: Maximum retry attempts
        **kwargs: Provider-specific parameters

    Returns:
        Configured LLMClient instance

    Examples:
        # OpenAI client
        client = create_client("openai", model="gpt-4-turbo-preview")

        # Anthropic client
        client = create_client("anthropic", model="sonnet")

        # Local LM Studio
        client = create_client("lmstudio", base_url="http://localhost:1234/v1")

        # With custom settings
        client = create_client(
            "openai",
            api_key="sk-...",
            timeout=120.0,
            max_retries=5,
            organization="org-..."
        )
    """
    client_class = get_provider_class(provider)

    # Build kwargs for client initialization
    init_kwargs = {**kwargs}

    if api_key is not None:
        init_kwargs["api_key"] = api_key
    if model is not None:
        init_kwargs["model"] = model
    if base_url is not None:
        init_kwargs["base_url"] = base_url
    if timeout is not None:
        init_kwargs["timeout"] = timeout
    if max_retries is not None:
        init_kwargs["max_retries"] = max_retries

    logger.info(f"Creating {provider} client with model={model or 'default'}")

    return client_class(**init_kwargs)


def create_client_from_config(config: dict) -> BaseLLMClient:
    """
    Create an LLM client from a configuration dictionary.

    Useful for loading settings from YAML/JSON config files.

    Args:
        config: Configuration dictionary with keys:
            - provider: Required provider name
            - Other keys passed to create_client

    Returns:
        Configured LLMClient instance

    Example:
        config = {
            "provider": "openai",
            "model": "gpt-4-turbo-preview",
            "timeout": 60.0,
            "max_retries": 3
        }
        client = create_client_from_config(config)
    """
    config = config.copy()
    provider = config.pop("provider", "openai")
    return create_client(provider, **config)


# Convenience aliases for common use cases
def create_openai_client(**kwargs) -> BaseLLMClient:
    """Create an OpenAI client."""
    return create_client("openai", **kwargs)


def create_anthropic_client(**kwargs) -> BaseLLMClient:
    """Create an Anthropic Claude client."""
    return create_client("anthropic", **kwargs)


def create_local_client(**kwargs) -> BaseLLMClient:
    """Create a local LM Studio client."""
    return create_client("lmstudio", **kwargs)


__all__ = [
    # Base types
    "LLMClient",
    "LLMResponse",
    "LLMToolResponse",
    "ToolCall",
    "BaseLLMClient",
    # Exceptions
    "LLMClientError",
    "LLMAuthenticationError",
    "LLMRateLimitError",
    "LLMQuotaExceededError",
    "LLMModelNotFoundError",
    "LLMContextLengthError",
    "LLMInvalidRequestError",
    "LLMTimeoutError",
    "LLMConnectionError",
    "LLMServerError",
    "LLMResponseParseError",
    "LLMStreamError",
    "LLMContentFilterError",
    "CircuitBreakerOpenError",
    # Factory functions
    "create_client",
    "create_client_from_config",
    "create_openai_client",
    "create_anthropic_client",
    "create_local_client",
    # Registry functions
    "register_provider",
    "list_providers",
    "get_provider_class",
]
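
Because the registry is module-level and lazily loaded, new backends can be plugged in without touching the factory. A small sketch using register_provider/list_providers/create_client from the module above; the azure module path and AzureOpenAIClient class are hypothetical, mirroring the docstring's own example rather than code that exists in this commit:

from src.adapters.llm import create_client, list_providers, register_provider

register_provider("azure", "src.adapters.llm.azure_client", "AzureOpenAIClient")  # hypothetical adapter module
print(list_providers())  # ["openai", "anthropic", "lmstudio", "local", "azure"]
client = create_client("azure", model="gpt-4o", api_key="...")  # module is imported lazily on first use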
src/adapters/llm/anthropic_client.py
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Anthropic Claude LLM client adapter.
|
| 3 |
+
|
| 4 |
+
Implements the LLMClient protocol for Anthropic's Messages API.
|
| 5 |
+
Supports Claude 3 models with proper content block handling.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from collections.abc import AsyncIterator
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import httpx
|
| 14 |
+
from tenacity import (
|
| 15 |
+
before_sleep_log,
|
| 16 |
+
retry,
|
| 17 |
+
retry_if_exception_type,
|
| 18 |
+
stop_after_attempt,
|
| 19 |
+
wait_exponential,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from .base import BaseLLMClient, LLMResponse, LLMToolResponse, ToolCall
|
| 23 |
+
from .exceptions import (
|
| 24 |
+
CircuitBreakerOpenError,
|
| 25 |
+
LLMAuthenticationError,
|
| 26 |
+
LLMClientError,
|
| 27 |
+
LLMConnectionError,
|
| 28 |
+
LLMContentFilterError,
|
| 29 |
+
LLMContextLengthError,
|
| 30 |
+
LLMInvalidRequestError,
|
| 31 |
+
LLMModelNotFoundError,
|
| 32 |
+
LLMQuotaExceededError,
|
| 33 |
+
LLMRateLimitError,
|
| 34 |
+
LLMResponseParseError,
|
| 35 |
+
LLMServerError,
|
| 36 |
+
LLMStreamError,
|
| 37 |
+
LLMTimeoutError,
|
| 38 |
+
)
|
| 39 |
+
from .openai_client import CircuitBreaker
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Model mappings for convenience
|
| 45 |
+
ANTHROPIC_MODELS = {
|
| 46 |
+
"claude-3-opus": "claude-3-opus-20240229",
|
| 47 |
+
"claude-3-sonnet": "claude-3-sonnet-20240229",
|
| 48 |
+
"claude-3-haiku": "claude-3-haiku-20240307",
|
| 49 |
+
"claude-3.5-sonnet": "claude-3-5-sonnet-20240620",
|
| 50 |
+
"claude-3.5-sonnet-v2": "claude-3-5-sonnet-20241022",
|
| 51 |
+
"claude-sonnet-4": "claude-sonnet-4-20250514",
|
| 52 |
+
# Add latest models
|
| 53 |
+
"opus": "claude-3-opus-20240229",
|
| 54 |
+
"sonnet": "claude-3-5-sonnet-20241022",
|
| 55 |
+
"haiku": "claude-3-haiku-20240307",
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class AnthropicClient(BaseLLMClient):
|
| 60 |
+
"""
|
| 61 |
+
Anthropic Claude API client.
|
| 62 |
+
|
| 63 |
+
Features:
|
| 64 |
+
- Messages API support (not legacy completion API)
|
| 65 |
+
- Content block handling (text, tool_use)
|
| 66 |
+
- Streaming with proper SSE parsing
|
| 67 |
+
- Model alias mapping
|
| 68 |
+
- System prompt support
|
| 69 |
+
- Tool/function calling (beta)
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
PROVIDER_NAME = "anthropic"
|
| 73 |
+
DEFAULT_BASE_URL = "https://api.anthropic.com"
|
| 74 |
+
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
| 75 |
+
API_VERSION = "2023-06-01"
|
| 76 |
+
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
api_key: str | None = None,
|
| 80 |
+
model: str | None = None,
|
| 81 |
+
base_url: str | None = None,
|
| 82 |
+
timeout: float = 120.0, # Claude can be slower
|
| 83 |
+
max_retries: int = 3,
|
| 84 |
+
# Circuit breaker settings
|
| 85 |
+
circuit_breaker_threshold: int = 5,
|
| 86 |
+
circuit_breaker_reset: float = 60.0,
|
| 87 |
+
# Rate limiting
|
| 88 |
+
rate_limit_per_minute: int | None = None,
|
| 89 |
+
):
|
| 90 |
+
"""
|
| 91 |
+
Initialize Anthropic client.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
api_key: Anthropic API key (or set ANTHROPIC_API_KEY env var)
|
| 95 |
+
model: Model to use (supports aliases like 'sonnet', 'opus')
|
| 96 |
+
base_url: API base URL
|
| 97 |
+
timeout: Request timeout in seconds (default longer for Claude)
|
| 98 |
+
max_retries: Max retry attempts
|
| 99 |
+
circuit_breaker_threshold: Failures before circuit opens
|
| 100 |
+
circuit_breaker_reset: Seconds before circuit resets
|
| 101 |
+
rate_limit_per_minute: Rate limit for requests per minute (None to disable)
|
| 102 |
+
"""
|
| 103 |
+
import os
|
| 104 |
+
|
| 105 |
+
api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
|
| 106 |
+
if not api_key:
|
| 107 |
+
raise LLMAuthenticationError(self.PROVIDER_NAME, "API key not provided and ANTHROPIC_API_KEY not set")
|
| 108 |
+
|
| 109 |
+
# Resolve model alias
|
| 110 |
+
model_name = model or self.DEFAULT_MODEL
|
| 111 |
+
resolved_model = ANTHROPIC_MODELS.get(model_name, model_name)
|
| 112 |
+
|
| 113 |
+
super().__init__(
|
| 114 |
+
api_key=api_key,
|
| 115 |
+
model=resolved_model,
|
| 116 |
+
base_url=base_url or self.DEFAULT_BASE_URL,
|
| 117 |
+
timeout=timeout,
|
| 118 |
+
max_retries=max_retries,
|
| 119 |
+
rate_limit_per_minute=rate_limit_per_minute,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
self.circuit_breaker = CircuitBreaker(
|
| 123 |
+
failure_threshold=circuit_breaker_threshold,
|
| 124 |
+
reset_timeout=circuit_breaker_reset,
|
| 125 |
+
)
|
| 126 |
+
self._client: httpx.AsyncClient | None = None
|
| 127 |
+
|
| 128 |
+
async def _get_client(self) -> httpx.AsyncClient:
|
| 129 |
+
"""Get or create the HTTP client."""
|
| 130 |
+
if self._client is None or self._client.is_closed:
|
| 131 |
+
headers = {
|
| 132 |
+
"x-api-key": self.api_key,
|
| 133 |
+
"anthropic-version": self.API_VERSION,
|
| 134 |
+
"Content-Type": "application/json",
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
self._client = httpx.AsyncClient(
|
| 138 |
+
base_url=self.base_url,
|
| 139 |
+
headers=headers,
|
| 140 |
+
timeout=httpx.Timeout(self.timeout),
|
| 141 |
+
)
|
| 142 |
+
return self._client
|
| 143 |
+
|
| 144 |
+
def _convert_messages_to_anthropic(self, messages: list[dict]) -> tuple[str | None, list[dict]]:
|
| 145 |
+
"""
|
| 146 |
+
Convert OpenAI-style messages to Anthropic format.
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
Tuple of (system_prompt, messages)
|
| 150 |
+
"""
|
| 151 |
+
system_prompt = None
|
| 152 |
+
anthropic_messages = []
|
| 153 |
+
|
| 154 |
+
for msg in messages:
|
| 155 |
+
role = msg.get("role", "user")
|
| 156 |
+
content = msg.get("content", "")
|
| 157 |
+
|
| 158 |
+
if role == "system":
|
| 159 |
+
# Anthropic uses separate system parameter
|
| 160 |
+
system_prompt = content
|
| 161 |
+
elif role == "assistant":
|
| 162 |
+
anthropic_messages.append({"role": "assistant", "content": content})
|
| 163 |
+
elif role == "user":
|
| 164 |
+
anthropic_messages.append({"role": "user", "content": content})
|
| 165 |
+
elif role == "tool":
|
| 166 |
+
# Tool result message
|
| 167 |
+
anthropic_messages.append(
|
| 168 |
+
{
|
| 169 |
+
"role": "user",
|
| 170 |
+
"content": [
|
| 171 |
+
{
|
| 172 |
+
"type": "tool_result",
|
| 173 |
+
"tool_use_id": msg.get("tool_call_id", ""),
|
| 174 |
+
"content": content,
|
| 175 |
+
}
|
| 176 |
+
],
|
| 177 |
+
}
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
return system_prompt, anthropic_messages
|
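For illustration, a minimal sketch of what this conversion produces, using made-up message content: the system entry is lifted out into the separate system parameter and the remaining turns keep their roles.

# Hypothetical OpenAI-style input (values are illustrative only).
openai_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the plan."},
    {"role": "assistant", "content": "Here is the summary."},
]

# _convert_messages_to_anthropic(openai_messages) returns:
#   system_prompt == "You are a helpful assistant."
#   anthropic_messages == [
#       {"role": "user", "content": "Summarize the plan."},
#       {"role": "assistant", "content": "Here is the summary."},
#   ]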
| 181 |
+
|
| 182 |
+
def _convert_tools_to_anthropic(self, tools: list[dict]) -> list[dict]:
|
| 183 |
+
"""Convert OpenAI-style tool definitions to Anthropic format."""
|
| 184 |
+
anthropic_tools = []
|
| 185 |
+
|
| 186 |
+
for tool in tools:
|
| 187 |
+
if tool.get("type") == "function":
|
| 188 |
+
func = tool["function"]
|
| 189 |
+
anthropic_tools.append(
|
| 190 |
+
{
|
| 191 |
+
"name": func["name"],
|
| 192 |
+
"description": func.get("description", ""),
|
| 193 |
+
"input_schema": func.get("parameters", {"type": "object"}),
|
| 194 |
+
}
|
| 195 |
+
)
|
| 196 |
+
else:
|
| 197 |
+
# Already in Anthropic format
|
| 198 |
+
anthropic_tools.append(tool)
|
| 199 |
+
|
| 200 |
+
return anthropic_tools
|
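A sketch of the tool mapping with a made-up weather tool: the OpenAI "function" wrapper is unwrapped, and its parameters schema becomes Anthropic's input_schema.

# Hypothetical OpenAI-style tool definition (illustrative only).
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up current weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# _convert_tools_to_anthropic([openai_tool]) yields:
#   [{"name": "get_weather",
#     "description": "Look up current weather",
#     "input_schema": {"type": "object",
#                      "properties": {"city": {"type": "string"}},
#                      "required": ["city"]}}]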
| 201 |
+
|
| 202 |
+
def _handle_error_response(self, response: httpx.Response) -> None:
|
| 203 |
+
"""Convert HTTP error responses to appropriate exceptions."""
|
| 204 |
+
status_code = response.status_code
|
| 205 |
+
|
| 206 |
+
try:
|
| 207 |
+
error_data = response.json()
|
| 208 |
+
error_type = error_data.get("error", {}).get("type", "")
|
| 209 |
+
error_message = error_data.get("error", {}).get("message", response.text)
|
| 210 |
+
except Exception:
|
| 211 |
+
error_type = ""
|
| 212 |
+
error_message = response.text
|
| 213 |
+
|
| 214 |
+
if status_code == 401:
|
| 215 |
+
raise LLMAuthenticationError(self.PROVIDER_NAME, error_message)
|
| 216 |
+
elif status_code == 429:
|
| 217 |
+
retry_after = response.headers.get("retry-after")
|
| 218 |
+
retry_after_float = float(retry_after) if retry_after else None
|
| 219 |
+
raise LLMRateLimitError(self.PROVIDER_NAME, retry_after=retry_after_float, message=error_message)
|
| 220 |
+
elif status_code == 402 or "billing" in error_type.lower():
|
| 221 |
+
raise LLMQuotaExceededError(self.PROVIDER_NAME, error_message)
|
| 222 |
+
elif status_code == 404 or error_type == "not_found_error":
|
| 223 |
+
raise LLMModelNotFoundError(self.PROVIDER_NAME, self.model)
|
| 224 |
+
elif status_code == 400:
|
| 225 |
+
if "context" in error_message.lower() or "token" in error_message.lower():
|
| 226 |
+
raise LLMContextLengthError(self.PROVIDER_NAME)
|
| 227 |
+
if "content_policy" in error_type or "safety" in error_message.lower():
|
| 228 |
+
raise LLMContentFilterError(self.PROVIDER_NAME, error_message)
|
| 229 |
+
raise LLMInvalidRequestError(self.PROVIDER_NAME, error_message)
|
| 230 |
+
elif status_code >= 500:
|
| 231 |
+
raise LLMServerError(self.PROVIDER_NAME, status_code, error_message)
|
| 232 |
+
else:
|
| 233 |
+
raise LLMClientError(error_message, self.PROVIDER_NAME, status_code=status_code)
|
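Because HTTP failures are normalized into this exception hierarchy, callers can branch on exception type instead of raw status codes. A sketch, assuming an already-constructed client instance:

# Sketch only: `client` is assumed to be an AnthropicClient built elsewhere.
async def call_with_error_handling(client):
    try:
        return await client.generate(prompt="ping", max_tokens=32)
    except LLMRateLimitError as exc:
        # 429 responses surface the server-provided Retry-After when present.
        print(f"rate limited, retry after {exc.retry_after}s")
    except LLMServerError as exc:
        print(f"provider 5xx error: {exc.status_code}")
    except LLMClientError as exc:
        # Base class catches every other mapped failure (auth, quota, bad request, ...).
        print(f"request failed: {exc}")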
| 234 |
+
|
| 235 |
+
def _make_retry_decorator(self):
|
| 236 |
+
"""Create retry decorator with exponential backoff."""
|
| 237 |
+
return retry(
|
| 238 |
+
stop=stop_after_attempt(self.max_retries),
|
| 239 |
+
wait=wait_exponential(multiplier=1, min=2, max=120),
|
| 240 |
+
retry=retry_if_exception_type((LLMRateLimitError, LLMServerError, LLMConnectionError)),
|
| 241 |
+
before_sleep=before_sleep_log(logger, logging.WARNING),
|
| 242 |
+
reraise=True,
|
| 243 |
+
)
|
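Only the transient error types (rate limits, server errors, connection failures) trigger another attempt; the waits between attempts grow roughly geometrically from a 2-second floor toward a 120-second cap. A sketch of the same policy applied to a hypothetical standalone coroutine:

# Hypothetical helper used only to illustrate the retry policy above.
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=120),
    retry=retry_if_exception_type((LLMRateLimitError, LLMServerError, LLMConnectionError)),
    reraise=True,
)
async def fetch_once():
    ...  # a single request attempt; raising a transient error schedules a retry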
| 244 |
+
|
| 245 |
+
async def generate(
|
| 246 |
+
self,
|
| 247 |
+
*,
|
| 248 |
+
messages: list[dict] | None = None,
|
| 249 |
+
prompt: str | None = None,
|
| 250 |
+
temperature: float = 0.7,
|
| 251 |
+
max_tokens: int | None = None,
|
| 252 |
+
tools: list[dict] | None = None,
|
| 253 |
+
stream: bool = False,
|
| 254 |
+
stop: list[str] | None = None,
|
| 255 |
+
**kwargs: Any,
|
| 256 |
+
) -> LLMResponse | AsyncIterator[str]:
|
| 257 |
+
"""
|
| 258 |
+
Generate a response from Anthropic Claude.
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
messages: Chat messages (will be converted to Anthropic format)
|
| 262 |
+
prompt: Simple string prompt
|
| 263 |
+
temperature: Sampling temperature (0.0 to 1.0 for Claude)
|
| 264 |
+
max_tokens: Maximum tokens to generate (required for Anthropic)
|
| 265 |
+
tools: Tool definitions (will be converted to Anthropic format)
|
| 266 |
+
stream: If True, returns AsyncIterator
|
| 267 |
+
stop: Stop sequences
|
| 268 |
+
**kwargs: Additional parameters (top_p, top_k, etc.)
|
| 269 |
+
|
| 270 |
+
Returns:
|
| 271 |
+
LLMResponse or AsyncIterator[str] for streaming
|
| 272 |
+
"""
|
| 273 |
+
# Apply rate limiting before proceeding
|
| 274 |
+
await self._apply_rate_limit()
|
| 275 |
+
|
| 276 |
+
# Check circuit breaker
|
| 277 |
+
if not self.circuit_breaker.can_execute():
|
| 278 |
+
raise CircuitBreakerOpenError(
|
| 279 |
+
self.PROVIDER_NAME,
|
| 280 |
+
self.circuit_breaker.failure_count,
|
| 281 |
+
self.circuit_breaker.get_reset_time(),
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# Anthropic requires max_tokens
|
| 285 |
+
if max_tokens is None:
|
| 286 |
+
max_tokens = 4096 # Sensible default
|
| 287 |
+
|
| 288 |
+
if stream:
|
| 289 |
+
return self._generate_stream(
|
| 290 |
+
messages=messages,
|
| 291 |
+
prompt=prompt,
|
| 292 |
+
temperature=temperature,
|
| 293 |
+
max_tokens=max_tokens,
|
| 294 |
+
tools=tools,
|
| 295 |
+
stop=stop,
|
| 296 |
+
**kwargs,
|
| 297 |
+
)
|
| 298 |
+
else:
|
| 299 |
+
return await self._generate_non_stream(
|
| 300 |
+
messages=messages,
|
| 301 |
+
prompt=prompt,
|
| 302 |
+
temperature=temperature,
|
| 303 |
+
max_tokens=max_tokens,
|
| 304 |
+
tools=tools,
|
| 305 |
+
stop=stop,
|
| 306 |
+
**kwargs,
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
async def _generate_non_stream(
|
| 310 |
+
self,
|
| 311 |
+
*,
|
| 312 |
+
messages: list[dict] | None = None,
|
| 313 |
+
prompt: str | None = None,
|
| 314 |
+
temperature: float = 0.7,
|
| 315 |
+
max_tokens: int = 4096,
|
| 316 |
+
tools: list[dict] | None = None,
|
| 317 |
+
stop: list[str] | None = None,
|
| 318 |
+
**kwargs: Any,
|
| 319 |
+
) -> LLMResponse:
|
| 320 |
+
"""Non-streaming generation with retry logic."""
|
| 321 |
+
|
| 322 |
+
@self._make_retry_decorator()
|
| 323 |
+
async def _request():
|
| 324 |
+
client = await self._get_client()
|
| 325 |
+
|
| 326 |
+
# Convert messages
|
| 327 |
+
built_messages = self._build_messages(messages, prompt)
|
| 328 |
+
system_prompt, anthropic_messages = self._convert_messages_to_anthropic(built_messages)
|
| 329 |
+
|
| 330 |
+
# Build request payload
|
| 331 |
+
payload = {
|
| 332 |
+
"model": self.model,
|
| 333 |
+
"messages": anthropic_messages,
|
| 334 |
+
"max_tokens": max_tokens,
|
| 335 |
+
"temperature": min(temperature, 1.0), # Anthropic max is 1.0
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
if system_prompt:
|
| 339 |
+
payload["system"] = system_prompt
|
| 340 |
+
if stop:
|
| 341 |
+
payload["stop_sequences"] = stop
|
| 342 |
+
if tools:
|
| 343 |
+
payload["tools"] = self._convert_tools_to_anthropic(tools)
|
| 344 |
+
|
| 345 |
+
# Add any additional kwargs (top_p, top_k, etc.)
|
| 346 |
+
for key in ["top_p", "top_k", "metadata"]:
|
| 347 |
+
if key in kwargs:
|
| 348 |
+
payload[key] = kwargs[key]
|
| 349 |
+
|
| 350 |
+
try:
|
| 351 |
+
response = await client.post("/v1/messages", json=payload)
|
| 352 |
+
except httpx.TimeoutException:
|
| 353 |
+
raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
|
| 354 |
+
except httpx.ConnectError:
|
| 355 |
+
raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
|
| 356 |
+
|
| 357 |
+
if response.status_code != 200:
|
| 358 |
+
self._handle_error_response(response)
|
| 359 |
+
|
| 360 |
+
return response
|
| 361 |
+
|
| 362 |
+
try:
|
| 363 |
+
response = await _request()
|
| 364 |
+
self.circuit_breaker.record_success()
|
| 365 |
+
except Exception:
|
| 366 |
+
self.circuit_breaker.record_failure()
|
| 367 |
+
raise
|
| 368 |
+
|
| 369 |
+
# Parse response
|
| 370 |
+
try:
|
| 371 |
+
data = response.json()
|
| 372 |
+
|
| 373 |
+
# Extract text from content blocks
|
| 374 |
+
text_parts = []
|
| 375 |
+
tool_calls = []
|
| 376 |
+
|
| 377 |
+
for block in data.get("content", []):
|
| 378 |
+
if block.get("type") == "text":
|
| 379 |
+
text_parts.append(block.get("text", ""))
|
| 380 |
+
elif block.get("type") == "tool_use":
|
| 381 |
+
tool_calls.append(
|
| 382 |
+
ToolCall(
|
| 383 |
+
id=block.get("id", ""),
|
| 384 |
+
name=block.get("name", ""),
|
| 385 |
+
arguments=block.get("input", {}),
|
| 386 |
+
type="tool_use",
|
| 387 |
+
)
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
text = "\n".join(text_parts)
|
| 391 |
+
|
| 392 |
+
# Build usage dict
|
| 393 |
+
usage = {
|
| 394 |
+
"prompt_tokens": data.get("usage", {}).get("input_tokens", 0),
|
| 395 |
+
"completion_tokens": data.get("usage", {}).get("output_tokens", 0),
|
| 396 |
+
}
|
| 397 |
+
usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
|
| 398 |
+
|
| 399 |
+
finish_reason = data.get("stop_reason", "stop")
|
| 400 |
+
|
| 401 |
+
if tool_calls:
|
| 402 |
+
llm_response = LLMToolResponse(
|
| 403 |
+
text=text,
|
| 404 |
+
usage=usage,
|
| 405 |
+
model=data.get("model", self.model),
|
| 406 |
+
raw_response=data,
|
| 407 |
+
finish_reason=finish_reason,
|
| 408 |
+
tool_calls=tool_calls,
|
| 409 |
+
)
|
| 410 |
+
else:
|
| 411 |
+
llm_response = LLMResponse(
|
| 412 |
+
text=text,
|
| 413 |
+
usage=usage,
|
| 414 |
+
model=data.get("model", self.model),
|
| 415 |
+
raw_response=data,
|
| 416 |
+
finish_reason=finish_reason,
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
self._update_stats(llm_response)
|
| 420 |
+
return llm_response
|
| 421 |
+
|
| 422 |
+
except (KeyError, json.JSONDecodeError) as e:
|
| 423 |
+
raise LLMResponseParseError(self.PROVIDER_NAME, response.text) from e
|
| 424 |
+
|
| 425 |
+
async def _generate_stream(
|
| 426 |
+
self,
|
| 427 |
+
*,
|
| 428 |
+
messages: list[dict] | None = None,
|
| 429 |
+
prompt: str | None = None,
|
| 430 |
+
temperature: float = 0.7,
|
| 431 |
+
max_tokens: int = 4096,
|
| 432 |
+
tools: list[dict] | None = None,
|
| 433 |
+
stop: list[str] | None = None,
|
| 434 |
+
**kwargs: Any,
|
| 435 |
+
) -> AsyncIterator[str]:
|
| 436 |
+
"""Streaming generation with Server-Sent Events."""
|
| 437 |
+
|
| 438 |
+
client = await self._get_client()
|
| 439 |
+
|
| 440 |
+
# Convert messages
|
| 441 |
+
built_messages = self._build_messages(messages, prompt)
|
| 442 |
+
system_prompt, anthropic_messages = self._convert_messages_to_anthropic(built_messages)
|
| 443 |
+
|
| 444 |
+
# Build request payload
|
| 445 |
+
payload = {
|
| 446 |
+
"model": self.model,
|
| 447 |
+
"messages": anthropic_messages,
|
| 448 |
+
"max_tokens": max_tokens,
|
| 449 |
+
"temperature": min(temperature, 1.0),
|
| 450 |
+
"stream": True,
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
if system_prompt:
|
| 454 |
+
payload["system"] = system_prompt
|
| 455 |
+
if stop:
|
| 456 |
+
payload["stop_sequences"] = stop
|
| 457 |
+
if tools:
|
| 458 |
+
payload["tools"] = self._convert_tools_to_anthropic(tools)
|
| 459 |
+
|
| 460 |
+
for key in ["top_p", "top_k"]:
|
| 461 |
+
if key in kwargs:
|
| 462 |
+
payload[key] = kwargs[key]
|
| 463 |
+
|
| 464 |
+
async def stream_generator():
|
| 465 |
+
try:
|
| 466 |
+
async with client.stream("POST", "/v1/messages", json=payload) as response:
|
| 467 |
+
if response.status_code != 200:
|
| 468 |
+
await response.aread()
|
| 469 |
+
self._handle_error_response(response)
|
| 470 |
+
|
| 471 |
+
async for line in response.aiter_lines():
|
| 472 |
+
if not line.strip():
|
| 473 |
+
continue
|
| 474 |
+
|
| 475 |
+
if line.startswith("event:"):
|
| 476 |
+
event_type = line[6:].strip()
|
| 477 |
+
continue
|
| 478 |
+
|
| 479 |
+
if line.startswith("data:"):
|
| 480 |
+
data_str = line[5:].strip()
|
| 481 |
+
if not data_str:
|
| 482 |
+
continue
|
| 483 |
+
|
| 484 |
+
try:
|
| 485 |
+
data = json.loads(data_str)
|
| 486 |
+
event_type = data.get("type", "")
|
| 487 |
+
|
| 488 |
+
if event_type == "content_block_delta":
|
| 489 |
+
delta = data.get("delta", {})
|
| 490 |
+
if delta.get("type") == "text_delta":
|
| 491 |
+
text = delta.get("text", "")
|
| 492 |
+
if text:
|
| 493 |
+
yield text
|
| 494 |
+
|
| 495 |
+
elif event_type == "message_stop":
|
| 496 |
+
break
|
| 497 |
+
|
| 498 |
+
except json.JSONDecodeError:
|
| 499 |
+
continue
|
| 500 |
+
|
| 501 |
+
self.circuit_breaker.record_success()
|
| 502 |
+
|
| 503 |
+
except httpx.TimeoutException:
|
| 504 |
+
self.circuit_breaker.record_failure()
|
| 505 |
+
raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
|
| 506 |
+
except httpx.ConnectError:
|
| 507 |
+
self.circuit_breaker.record_failure()
|
| 508 |
+
raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
|
| 509 |
+
except Exception as e:
|
| 510 |
+
self.circuit_breaker.record_failure()
|
| 511 |
+
if isinstance(e, LLMClientError):
|
| 512 |
+
raise
|
| 513 |
+
raise LLMStreamError(self.PROVIDER_NAME, str(e)) from e
|
| 514 |
+
|
| 515 |
+
return stream_generator()
|
| 516 |
+
|
| 517 |
+
async def close(self) -> None:
|
| 518 |
+
"""Close the HTTP client."""
|
| 519 |
+
if self._client and not self._client.is_closed:
|
| 520 |
+
await self._client.aclose()
|
| 521 |
+
self._client = None
|
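A minimal usage sketch for this client; the model alias, prompt, and rate limit are illustrative, and ANTHROPIC_API_KEY must be set in the environment.

import asyncio

async def main():
    # "sonnet" resolves through ANTHROPIC_MODELS to the pinned, dated model id.
    async with AnthropicClient(model="sonnet", rate_limit_per_minute=30) as client:
        reply = await client.generate(prompt="Say hello in one sentence.", max_tokens=128)
        print(reply.text)
        print(reply.prompt_tokens, reply.completion_tokens, reply.total_tokens)

asyncio.run(main())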
src/adapters/llm/base.py
ADDED
|
@@ -0,0 +1,305 @@
|
| 1 |
+
"""
|
| 2 |
+
Base LLM client interface for provider-agnostic model access.
|
| 3 |
+
|
| 4 |
+
This module defines the protocol and data structures for LLM clients,
|
| 5 |
+
enabling seamless switching between providers (OpenAI, Anthropic, LM Studio, etc.)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import time
|
| 10 |
+
from abc import ABC, abstractmethod
|
| 11 |
+
from collections.abc import AsyncIterator
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from typing import Any, Protocol, runtime_checkable
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class LLMResponse:
|
| 19 |
+
"""Standardized response from any LLM provider."""
|
| 20 |
+
|
| 21 |
+
text: str
|
| 22 |
+
usage: dict = field(default_factory=dict)
|
| 23 |
+
model: str = ""
|
| 24 |
+
raw_response: Any = None
|
| 25 |
+
finish_reason: str = "stop"
|
| 26 |
+
created_at: datetime = field(default_factory=datetime.utcnow)
|
| 27 |
+
|
| 28 |
+
@property
|
| 29 |
+
def total_tokens(self) -> int:
|
| 30 |
+
"""Total tokens used in request/response."""
|
| 31 |
+
return self.usage.get("total_tokens", 0)
|
| 32 |
+
|
| 33 |
+
@property
|
| 34 |
+
def prompt_tokens(self) -> int:
|
| 35 |
+
"""Tokens used in prompt."""
|
| 36 |
+
return self.usage.get("prompt_tokens", 0)
|
| 37 |
+
|
| 38 |
+
@property
|
| 39 |
+
def completion_tokens(self) -> int:
|
| 40 |
+
"""Tokens used in completion."""
|
| 41 |
+
return self.usage.get("completion_tokens", 0)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass
|
| 45 |
+
class ToolCall:
|
| 46 |
+
"""Represents a tool/function call from the LLM."""
|
| 47 |
+
|
| 48 |
+
id: str
|
| 49 |
+
name: str
|
| 50 |
+
arguments: dict
|
| 51 |
+
type: str = "function"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class LLMToolResponse(LLMResponse):
|
| 56 |
+
"""Response containing tool calls."""
|
| 57 |
+
|
| 58 |
+
tool_calls: list[ToolCall] = field(default_factory=list)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TokenBucketRateLimiter:
|
| 62 |
+
"""
|
| 63 |
+
Token bucket rate limiter for controlling request rates.
|
| 64 |
+
|
| 65 |
+
This implementation uses a token bucket algorithm where:
|
| 66 |
+
- Tokens are added at a fixed rate (rate_per_second)
|
| 67 |
+
- Each request consumes one token
|
| 68 |
+
- If no tokens available, caller waits until one becomes available
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, rate_per_minute: int = 60):
|
| 72 |
+
"""
|
| 73 |
+
Initialize the rate limiter.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
rate_per_minute: Maximum requests allowed per minute
|
| 77 |
+
"""
|
| 78 |
+
self.rate_per_second = rate_per_minute / 60.0
|
| 79 |
+
self.max_tokens = float(rate_per_minute)
|
| 80 |
+
self.tokens = self.max_tokens
|
| 81 |
+
self.last_refill = time.monotonic()
|
| 82 |
+
self._lock = asyncio.Lock()
|
| 83 |
+
self._wait_count = 0
|
| 84 |
+
self._total_wait_time = 0.0
|
| 85 |
+
|
| 86 |
+
async def acquire(self) -> float:
|
| 87 |
+
"""
|
| 88 |
+
Acquire a token, waiting if necessary.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Time spent waiting (0.0 if no wait was needed)
|
| 92 |
+
"""
|
| 93 |
+
async with self._lock:
|
| 94 |
+
now = time.monotonic()
|
| 95 |
+
elapsed = now - self.last_refill
|
| 96 |
+
|
| 97 |
+
# Refill tokens based on elapsed time
|
| 98 |
+
self.tokens = min(self.max_tokens, self.tokens + elapsed * self.rate_per_second)
|
| 99 |
+
self.last_refill = now
|
| 100 |
+
|
| 101 |
+
wait_time = 0.0
|
| 102 |
+
if self.tokens < 1:
|
| 103 |
+
# Calculate how long to wait for one token
|
| 104 |
+
wait_time = (1 - self.tokens) / self.rate_per_second
|
| 105 |
+
self._wait_count += 1
|
| 106 |
+
self._total_wait_time += wait_time
|
| 107 |
+
|
| 108 |
+
# Release lock during sleep to allow other operations
|
| 109 |
+
self._lock.release()
|
| 110 |
+
try:
|
| 111 |
+
await asyncio.sleep(wait_time)
|
| 112 |
+
finally:
|
| 113 |
+
await self._lock.acquire()
|
| 114 |
+
|
| 115 |
+
# After sleeping, update time and set tokens to 0
|
| 116 |
+
self.last_refill = time.monotonic()
|
| 117 |
+
self.tokens = 0
|
| 118 |
+
else:
|
| 119 |
+
self.tokens -= 1
|
| 120 |
+
|
| 121 |
+
return wait_time
|
| 122 |
+
|
| 123 |
+
@property
|
| 124 |
+
def stats(self) -> dict:
|
| 125 |
+
"""Get rate limiter statistics."""
|
| 126 |
+
return {
|
| 127 |
+
"rate_limit_waits": self._wait_count,
|
| 128 |
+
"total_rate_limit_wait_time": self._total_wait_time,
|
| 129 |
+
"current_tokens": self.tokens,
|
| 130 |
+
}
|
| 131 |
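A back-of-the-envelope check of the refill math, assuming rate_per_minute=30: tokens refill at 0.5 per second, so once the bucket is drained each new acquire() waits about two seconds.

import asyncio

async def demo():
    limiter = TokenBucketRateLimiter(rate_per_minute=30)  # 0.5 tokens/second
    limiter.tokens = 0  # pretend the bucket is already drained
    waited = await limiter.acquire()
    # With an empty bucket, the wait is roughly (1 - 0) / 0.5 = 2.0 seconds.
    print(f"waited ~{waited:.1f}s", limiter.stats)

asyncio.run(demo())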
+
|
| 132 |
+
|
| 133 |
+
@runtime_checkable
|
| 134 |
+
class LLMClient(Protocol):
|
| 135 |
+
"""
|
| 136 |
+
Protocol for LLM clients.
|
| 137 |
+
|
| 138 |
+
This protocol defines the interface that all LLM provider adapters must implement.
|
| 139 |
+
Using Protocol allows for structural subtyping (duck typing) while maintaining
|
| 140 |
+
type safety.
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
async def generate(
|
| 144 |
+
self,
|
| 145 |
+
*,
|
| 146 |
+
messages: list[dict] | None = None,
|
| 147 |
+
prompt: str | None = None,
|
| 148 |
+
temperature: float = 0.7,
|
| 149 |
+
max_tokens: int | None = None,
|
| 150 |
+
tools: list[dict] | None = None,
|
| 151 |
+
stream: bool = False,
|
| 152 |
+
stop: list[str] | None = None,
|
| 153 |
+
**kwargs: Any,
|
| 154 |
+
) -> LLMResponse | AsyncIterator[str]:
|
| 155 |
+
"""
|
| 156 |
+
Generate a response from the LLM.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
messages: List of message dicts in OpenAI format [{"role": "...", "content": "..."}]
|
| 160 |
+
prompt: Simple string prompt (converted to single user message)
|
| 161 |
+
temperature: Sampling temperature (0.0 to 2.0)
|
| 162 |
+
max_tokens: Maximum tokens to generate
|
| 163 |
+
tools: List of tool definitions for function calling
|
| 164 |
+
stream: If True, returns AsyncIterator[str] for streaming
|
| 165 |
+
stop: Stop sequences
|
| 166 |
+
**kwargs: Provider-specific parameters
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
LLMResponse if stream=False, AsyncIterator[str] if stream=True
|
| 170 |
+
|
| 171 |
+
Raises:
|
| 172 |
+
LLMClientError: Base exception for all client errors
|
| 173 |
+
"""
|
| 174 |
+
...
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class BaseLLMClient(ABC):
|
| 178 |
+
"""
|
| 179 |
+
Abstract base class for LLM clients.
|
| 180 |
+
|
| 181 |
+
Provides common functionality and enforces the interface contract.
|
| 182 |
+
All concrete implementations should inherit from this class.
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
def __init__(
|
| 186 |
+
self,
|
| 187 |
+
api_key: str | None = None,
|
| 188 |
+
model: str = "default",
|
| 189 |
+
base_url: str | None = None,
|
| 190 |
+
timeout: float = 60.0,
|
| 191 |
+
max_retries: int = 3,
|
| 192 |
+
rate_limit_per_minute: int | None = None,
|
| 193 |
+
):
|
| 194 |
+
"""
|
| 195 |
+
Initialize the LLM client.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
api_key: API key for authentication
|
| 199 |
+
model: Model identifier
|
| 200 |
+
base_url: Base URL for API requests
|
| 201 |
+
timeout: Request timeout in seconds
|
| 202 |
+
max_retries: Maximum number of retry attempts
|
| 203 |
+
rate_limit_per_minute: Rate limit (requests per minute), None to disable
|
| 204 |
+
"""
|
| 205 |
+
self.api_key = api_key
|
| 206 |
+
self.model = model
|
| 207 |
+
self.base_url = base_url
|
| 208 |
+
self.timeout = timeout
|
| 209 |
+
self.max_retries = max_retries
|
| 210 |
+
self._request_count = 0
|
| 211 |
+
self._total_tokens_used = 0
|
| 212 |
+
self._rate_limited_requests = 0
|
| 213 |
+
|
| 214 |
+
# Initialize rate limiter if configured
|
| 215 |
+
if rate_limit_per_minute is not None and rate_limit_per_minute > 0:
|
| 216 |
+
self._rate_limiter: TokenBucketRateLimiter | None = TokenBucketRateLimiter(
|
| 217 |
+
rate_per_minute=rate_limit_per_minute
|
| 218 |
+
)
|
| 219 |
+
else:
|
| 220 |
+
self._rate_limiter = None
|
| 221 |
+
|
| 222 |
+
@abstractmethod
|
| 223 |
+
async def generate(
|
| 224 |
+
self,
|
| 225 |
+
*,
|
| 226 |
+
messages: list[dict] | None = None,
|
| 227 |
+
prompt: str | None = None,
|
| 228 |
+
temperature: float = 0.7,
|
| 229 |
+
max_tokens: int | None = None,
|
| 230 |
+
tools: list[dict] | None = None,
|
| 231 |
+
stream: bool = False,
|
| 232 |
+
stop: list[str] | None = None,
|
| 233 |
+
**kwargs: Any,
|
| 234 |
+
) -> LLMResponse | AsyncIterator[str]:
|
| 235 |
+
"""Generate a response from the LLM."""
|
| 236 |
+
pass
|
| 237 |
+
|
| 238 |
+
def _build_messages(
|
| 239 |
+
self,
|
| 240 |
+
messages: list[dict] | None = None,
|
| 241 |
+
prompt: str | None = None,
|
| 242 |
+
) -> list[dict]:
|
| 243 |
+
"""
|
| 244 |
+
Build message list from either messages or prompt.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
messages: Pre-formatted message list
|
| 248 |
+
prompt: Simple string prompt
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
List of message dicts
|
| 252 |
+
|
| 253 |
+
Raises:
|
| 254 |
+
ValueError: If neither messages nor prompt provided
|
| 255 |
+
"""
|
| 256 |
+
if messages is not None:
|
| 257 |
+
return messages
|
| 258 |
+
elif prompt is not None:
|
| 259 |
+
return [{"role": "user", "content": prompt}]
|
| 260 |
+
else:
|
| 261 |
+
raise ValueError("Either 'messages' or 'prompt' must be provided")
|
| 262 |
+
|
| 263 |
+
def _update_stats(self, response: LLMResponse) -> None:
|
| 264 |
+
"""Update internal statistics."""
|
| 265 |
+
self._request_count += 1
|
| 266 |
+
self._total_tokens_used += response.total_tokens
|
| 267 |
+
|
| 268 |
+
async def _apply_rate_limit(self) -> None:
|
| 269 |
+
"""
|
| 270 |
+
Apply rate limiting if configured.
|
| 271 |
+
|
| 272 |
+
Waits if necessary to comply with rate limits.
|
| 273 |
+
Tracks rate-limited requests in metrics.
|
| 274 |
+
"""
|
| 275 |
+
if self._rate_limiter is not None:
|
| 276 |
+
wait_time = await self._rate_limiter.acquire()
|
| 277 |
+
if wait_time > 0:
|
| 278 |
+
self._rate_limited_requests += 1
|
| 279 |
+
|
| 280 |
+
@property
|
| 281 |
+
def stats(self) -> dict:
|
| 282 |
+
"""Get client statistics."""
|
| 283 |
+
base_stats = {
|
| 284 |
+
"request_count": self._request_count,
|
| 285 |
+
"total_tokens_used": self._total_tokens_used,
|
| 286 |
+
"rate_limited_requests": self._rate_limited_requests,
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
# Include rate limiter stats if available
|
| 290 |
+
if self._rate_limiter is not None:
|
| 291 |
+
base_stats.update(self._rate_limiter.stats)
|
| 292 |
+
|
| 293 |
+
return base_stats
|
| 294 |
+
|
| 295 |
+
async def close(self) -> None: # noqa: B027
|
| 296 |
+
"""Clean up resources. Override in subclasses if needed."""
|
| 297 |
+
pass
|
| 298 |
+
|
| 299 |
+
async def __aenter__(self):
|
| 300 |
+
"""Async context manager entry."""
|
| 301 |
+
return self
|
| 302 |
+
|
| 303 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
| 304 |
+
"""Async context manager exit."""
|
| 305 |
+
await self.close()
|
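To make the abstract contract concrete, here is a minimal, hypothetical echo client built on this base class; it is a test stub, not one of the real provider adapters, and it skips streaming and tools.

class EchoClient(BaseLLMClient):
    """Hypothetical stub that echoes the last user message back."""

    async def generate(self, *, messages=None, prompt=None, temperature=0.7,
                       max_tokens=None, tools=None, stream=False, stop=None, **kwargs):
        await self._apply_rate_limit()
        built = self._build_messages(messages, prompt)
        response = LLMResponse(text=built[-1]["content"], usage={"total_tokens": 0}, model=self.model)
        self._update_stats(response)
        return response

# Usage sketch:
#   client = EchoClient(model="echo", rate_limit_per_minute=120)
#   reply = await client.generate(prompt="hi")   # reply.text == "hi"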
src/adapters/llm/exceptions.py
ADDED
|
@@ -0,0 +1,204 @@
|
"""
Custom exceptions for LLM client operations.

Provides a hierarchy of structured exceptions for better error handling
and debugging across different LLM providers.
"""


class LLMClientError(Exception):
    """Base exception for all LLM client errors."""

    def __init__(
        self,
        message: str,
        provider: str = "unknown",
        status_code: int | None = None,
        retry_after: float | None = None,
    ):
        self.message = message
        self.provider = provider
        self.status_code = status_code
        self.retry_after = retry_after
        super().__init__(self.message)

    def __str__(self) -> str:
        parts = [f"[{self.provider}] {self.message}"]
        if self.status_code:
            parts.append(f"(status: {self.status_code})")
        return " ".join(parts)


class LLMAuthenticationError(LLMClientError):
    """Authentication failed - invalid or missing API key."""

    def __init__(self, provider: str, message: str = "Authentication failed"):
        super().__init__(
            message=message,
            provider=provider,
            status_code=401,
        )


class LLMRateLimitError(LLMClientError):
    """Rate limit exceeded - too many requests."""

    def __init__(
        self,
        provider: str,
        retry_after: float | None = None,
        message: str = "Rate limit exceeded",
    ):
        super().__init__(
            message=message,
            provider=provider,
            status_code=429,
            retry_after=retry_after,
        )


class LLMQuotaExceededError(LLMClientError):
    """Quota or credits exhausted."""

    def __init__(self, provider: str, message: str = "Quota exceeded"):
        super().__init__(
            message=message,
            provider=provider,
            status_code=402,
        )


class LLMModelNotFoundError(LLMClientError):
    """Requested model not available."""

    def __init__(self, provider: str, model: str):
        super().__init__(
            message=f"Model '{model}' not found or not available",
            provider=provider,
            status_code=404,
        )


class LLMContextLengthError(LLMClientError):
    """Input exceeds model's context window."""

    def __init__(
        self,
        provider: str,
        token_count: int | None = None,
        max_tokens: int | None = None,
    ):
        message = "Context length exceeded"
        if token_count and max_tokens:
            message = f"Context length exceeded: {token_count} tokens provided, max is {max_tokens}"
        super().__init__(
            message=message,
            provider=provider,
            status_code=400,
        )


class LLMInvalidRequestError(LLMClientError):
    """Invalid request parameters."""

    def __init__(self, provider: str, message: str = "Invalid request parameters"):
        super().__init__(
            message=message,
            provider=provider,
            status_code=400,
        )


class LLMTimeoutError(LLMClientError):
    """Request timed out."""

    def __init__(self, provider: str, timeout: float):
        super().__init__(
            message=f"Request timed out after {timeout}s",
            provider=provider,
            status_code=408,
        )


class LLMConnectionError(LLMClientError):
    """Failed to connect to the API endpoint."""

    def __init__(self, provider: str, url: str | None = None):
        message = "Failed to connect to API"
        if url:
            message = f"Failed to connect to {url}"
        super().__init__(
            message=message,
            provider=provider,
        )


class LLMServerError(LLMClientError):
    """Server-side error from the LLM provider."""

    def __init__(
        self,
        provider: str,
        status_code: int = 500,
        message: str = "Server error",
    ):
        super().__init__(
            message=message,
            provider=provider,
            status_code=status_code,
        )


class LLMResponseParseError(LLMClientError):
    """Failed to parse response from LLM provider."""

    def __init__(self, provider: str, raw_response: str | None = None):
        message = "Failed to parse response"
        if raw_response:
            preview = raw_response[:200] + "..." if len(raw_response) > 200 else raw_response
            message = f"Failed to parse response: {preview}"
        super().__init__(
            message=message,
            provider=provider,
        )


class LLMStreamError(LLMClientError):
    """Error during streaming response."""

    def __init__(self, provider: str, message: str = "Stream interrupted"):
        super().__init__(
            message=message,
            provider=provider,
        )


class LLMContentFilterError(LLMClientError):
    """Content blocked by safety filters."""

    def __init__(self, provider: str, reason: str | None = None):
        message = "Content blocked by safety filters"
        if reason:
            message = f"Content blocked: {reason}"
        super().__init__(
            message=message,
            provider=provider,
            status_code=400,
        )


class CircuitBreakerOpenError(LLMClientError):
    """Circuit breaker is open, requests are being blocked."""

    def __init__(
        self,
        provider: str,
        failure_count: int,
        reset_time: float,
    ):
        super().__init__(
            message=f"Circuit breaker open after {failure_count} failures. Resets in {reset_time:.1f}s",
            provider=provider,
        )
        self.failure_count = failure_count
        self.reset_time = reset_time
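A quick check of how these exceptions render and relate; the provider name and retry value are illustrative.

err = LLMRateLimitError("anthropic", retry_after=7.5)
print(str(err))          # [anthropic] Rate limit exceeded (status: 429)
print(err.retry_after)   # 7.5

# isinstance checks work across the whole hierarchy, so a single
# `except LLMClientError` clause can act as the catch-all.
assert isinstance(err, LLMClientError)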
src/adapters/llm/lmstudio_client.py
ADDED
|
@@ -0,0 +1,346 @@
|
| 1 |
+
"""
|
| 2 |
+
LM Studio local LLM client adapter.
|
| 3 |
+
|
| 4 |
+
Implements the LLMClient protocol for LM Studio's OpenAI-compatible API.
|
| 5 |
+
Designed for running local models with configurable endpoint.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from collections.abc import AsyncIterator
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import httpx
|
| 14 |
+
|
| 15 |
+
from .base import BaseLLMClient, LLMResponse
|
| 16 |
+
from .exceptions import (
|
| 17 |
+
LLMClientError,
|
| 18 |
+
LLMConnectionError,
|
| 19 |
+
LLMResponseParseError,
|
| 20 |
+
LLMServerError,
|
| 21 |
+
LLMStreamError,
|
| 22 |
+
LLMTimeoutError,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LMStudioClient(BaseLLMClient):
|
| 29 |
+
"""
|
| 30 |
+
LM Studio local server client.
|
| 31 |
+
|
| 32 |
+
LM Studio provides an OpenAI-compatible API for running local models.
|
| 33 |
+
This client is optimized for local deployment with:
|
| 34 |
+
- No authentication required (local)
|
| 35 |
+
- Configurable base URL
|
| 36 |
+
- No circuit breaker (local server expected to be stable)
|
| 37 |
+
- Longer timeouts for large models
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
PROVIDER_NAME = "lmstudio"
|
| 41 |
+
DEFAULT_BASE_URL = "http://localhost:1234/v1"
|
| 42 |
+
DEFAULT_MODEL = "local-model" # LM Studio uses the loaded model
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
api_key: str | None = None, # Not required for local
|
| 47 |
+
model: str | None = None,
|
| 48 |
+
base_url: str | None = None,
|
| 49 |
+
timeout: float = 300.0, # Long timeout for local inference
|
| 50 |
+
max_retries: int = 2, # Fewer retries for local
|
| 51 |
+
# Rate limiting
|
| 52 |
+
rate_limit_per_minute: int | None = None,
|
| 53 |
+
):
|
| 54 |
+
"""
|
| 55 |
+
Initialize LM Studio client.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
api_key: Not required for local server (ignored)
|
| 59 |
+
model: Model identifier (often ignored by LM Studio, uses loaded model)
|
| 60 |
+
base_url: Local server URL (default: http://localhost:1234/v1)
|
| 61 |
+
timeout: Request timeout in seconds (default longer for local models)
|
| 62 |
+
max_retries: Max retry attempts (fewer for local)
|
| 63 |
+
rate_limit_per_minute: Rate limit for requests per minute (None to disable)
|
| 64 |
+
"""
|
| 65 |
+
import os
|
| 66 |
+
|
| 67 |
+
# Allow overriding via environment variable
|
| 68 |
+
base_url = base_url or os.environ.get("LMSTUDIO_BASE_URL", self.DEFAULT_BASE_URL)
|
| 69 |
+
|
| 70 |
+
super().__init__(
|
| 71 |
+
api_key=api_key or "not-required", # Placeholder
|
| 72 |
+
model=model or self.DEFAULT_MODEL,
|
| 73 |
+
base_url=base_url,
|
| 74 |
+
timeout=timeout,
|
| 75 |
+
max_retries=max_retries,
|
| 76 |
+
rate_limit_per_minute=rate_limit_per_minute,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
self._client: httpx.AsyncClient | None = None
|
| 80 |
+
|
| 81 |
+
async def _get_client(self) -> httpx.AsyncClient:
|
| 82 |
+
"""Get or create the HTTP client."""
|
| 83 |
+
if self._client is None or self._client.is_closed:
|
| 84 |
+
headers = {"Content-Type": "application/json"}
|
| 85 |
+
|
| 86 |
+
# Add auth header if provided (some local servers may require it)
|
| 87 |
+
if self.api_key and self.api_key != "not-required":
|
| 88 |
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
| 89 |
+
|
| 90 |
+
self._client = httpx.AsyncClient(
|
| 91 |
+
base_url=self.base_url,
|
| 92 |
+
headers=headers,
|
| 93 |
+
timeout=httpx.Timeout(self.timeout),
|
| 94 |
+
)
|
| 95 |
+
return self._client
|
| 96 |
+
|
| 97 |
+
async def check_health(self) -> bool:
|
| 98 |
+
"""
|
| 99 |
+
Check if LM Studio server is running.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
True if server is accessible, False otherwise
|
| 103 |
+
"""
|
| 104 |
+
try:
|
| 105 |
+
client = await self._get_client()
|
| 106 |
+
response = await client.get("/models")
|
| 107 |
+
return response.status_code == 200
|
| 108 |
+
except Exception:
|
| 109 |
+
return False
|
| 110 |
+
|
| 111 |
+
async def list_models(self) -> list[dict]:
|
| 112 |
+
"""
|
| 113 |
+
List available models on the LM Studio server.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
List of model information dicts
|
| 117 |
+
"""
|
| 118 |
+
try:
|
| 119 |
+
client = await self._get_client()
|
| 120 |
+
response = await client.get("/models")
|
| 121 |
+
if response.status_code == 200:
|
| 122 |
+
data = response.json()
|
| 123 |
+
return data.get("data", [])
|
| 124 |
+
return []
|
| 125 |
+
except Exception as e:
|
| 126 |
+
logger.warning(f"Failed to list models: {e}")
|
| 127 |
+
return []
|
| 128 |
+
|
| 129 |
+
def _handle_error_response(self, response: httpx.Response) -> None:
|
| 130 |
+
"""Handle error responses from LM Studio server."""
|
| 131 |
+
status_code = response.status_code
|
| 132 |
+
|
| 133 |
+
try:
|
| 134 |
+
error_data = response.json()
|
| 135 |
+
error_message = error_data.get("error", {}).get("message", response.text)
|
| 136 |
+
except Exception:
|
| 137 |
+
error_message = response.text
|
| 138 |
+
|
| 139 |
+
if status_code >= 500:
|
| 140 |
+
raise LLMServerError(self.PROVIDER_NAME, status_code, error_message)
|
| 141 |
+
else:
|
| 142 |
+
raise LLMClientError(error_message, self.PROVIDER_NAME, status_code=status_code)
|
| 143 |
+
|
| 144 |
+
async def generate(
|
| 145 |
+
self,
|
| 146 |
+
*,
|
| 147 |
+
messages: list[dict] | None = None,
|
| 148 |
+
prompt: str | None = None,
|
| 149 |
+
temperature: float = 0.7,
|
| 150 |
+
max_tokens: int | None = None,
|
| 151 |
+
tools: list[dict] | None = None,
|
| 152 |
+
stream: bool = False,
|
| 153 |
+
stop: list[str] | None = None,
|
| 154 |
+
**kwargs: Any,
|
| 155 |
+
) -> LLMResponse | AsyncIterator[str]:
|
| 156 |
+
"""
|
| 157 |
+
Generate a response from LM Studio local model.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
messages: Chat messages in OpenAI format
|
| 161 |
+
prompt: Simple string prompt
|
| 162 |
+
temperature: Sampling temperature
|
| 163 |
+
max_tokens: Maximum tokens to generate
|
| 164 |
+
tools: Tool definitions (limited support in local models)
|
| 165 |
+
stream: If True, returns AsyncIterator
|
| 166 |
+
stop: Stop sequences
|
| 167 |
+
**kwargs: Additional parameters
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
LLMResponse or AsyncIterator[str] for streaming
|
| 171 |
+
"""
|
| 172 |
+
# Apply rate limiting before proceeding
|
| 173 |
+
await self._apply_rate_limit()
|
| 174 |
+
|
| 175 |
+
if stream:
|
| 176 |
+
return self._generate_stream(
|
| 177 |
+
messages=messages,
|
| 178 |
+
prompt=prompt,
|
| 179 |
+
temperature=temperature,
|
| 180 |
+
max_tokens=max_tokens,
|
| 181 |
+
tools=tools,
|
| 182 |
+
stop=stop,
|
| 183 |
+
**kwargs,
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
return await self._generate_non_stream(
|
| 187 |
+
messages=messages,
|
| 188 |
+
prompt=prompt,
|
| 189 |
+
temperature=temperature,
|
| 190 |
+
max_tokens=max_tokens,
|
| 191 |
+
tools=tools,
|
| 192 |
+
stop=stop,
|
| 193 |
+
**kwargs,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
async def _generate_non_stream(
|
| 197 |
+
self,
|
| 198 |
+
*,
|
| 199 |
+
messages: list[dict] | None = None,
|
| 200 |
+
prompt: str | None = None,
|
| 201 |
+
temperature: float = 0.7,
|
| 202 |
+
max_tokens: int | None = None,
|
| 203 |
+
tools: list[dict] | None = None,
|
| 204 |
+
stop: list[str] | None = None,
|
| 205 |
+
**kwargs: Any,
|
| 206 |
+
) -> LLMResponse:
|
| 207 |
+
"""Non-streaming generation."""
|
| 208 |
+
client = await self._get_client()
|
| 209 |
+
|
| 210 |
+
# Build request payload (OpenAI-compatible)
|
| 211 |
+
payload = {
|
| 212 |
+
"model": self.model,
|
| 213 |
+
"messages": self._build_messages(messages, prompt),
|
| 214 |
+
"temperature": temperature,
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
if max_tokens is not None:
|
| 218 |
+
payload["max_tokens"] = max_tokens
|
| 219 |
+
if stop:
|
| 220 |
+
payload["stop"] = stop
|
| 221 |
+
|
| 222 |
+
# Note: most local models don't support tools well
|
| 223 |
+
if tools:
|
| 224 |
+
logger.warning("Tool calling may not be fully supported by local models")
|
| 225 |
+
payload["tools"] = tools
|
| 226 |
+
|
| 227 |
+
# Add additional kwargs (e.g., top_p, repeat_penalty)
|
| 228 |
+
for key in ["top_p", "top_k", "repeat_penalty", "presence_penalty", "frequency_penalty"]:
|
| 229 |
+
if key in kwargs:
|
| 230 |
+
payload[key] = kwargs[key]
|
| 231 |
+
|
| 232 |
+
# Retry logic for local server
|
| 233 |
+
last_error = None
|
| 234 |
+
for attempt in range(self.max_retries):
|
| 235 |
+
try:
|
| 236 |
+
response = await client.post("/chat/completions", json=payload)
|
| 237 |
+
|
| 238 |
+
if response.status_code != 200:
|
| 239 |
+
self._handle_error_response(response)
|
| 240 |
+
|
| 241 |
+
# Parse response
|
| 242 |
+
try:
|
| 243 |
+
data = response.json()
|
| 244 |
+
choice = data["choices"][0]
|
| 245 |
+
message = choice["message"]
|
| 246 |
+
|
| 247 |
+
usage = data.get("usage", {})
|
| 248 |
+
finish_reason = choice.get("finish_reason", "stop")
|
| 249 |
+
|
| 250 |
+
llm_response = LLMResponse(
|
| 251 |
+
text=message.get("content", ""),
|
| 252 |
+
usage=usage,
|
| 253 |
+
model=data.get("model", self.model),
|
| 254 |
+
raw_response=data,
|
| 255 |
+
finish_reason=finish_reason,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
self._update_stats(llm_response)
|
| 259 |
+
return llm_response
|
| 260 |
+
|
| 261 |
+
except (KeyError, json.JSONDecodeError) as e:
|
| 262 |
+
raise LLMResponseParseError(self.PROVIDER_NAME, response.text) from e
|
| 263 |
+
|
| 264 |
+
except httpx.TimeoutException:
|
| 265 |
+
last_error = LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
|
| 266 |
+
logger.warning(f"Attempt {attempt + 1} timed out, retrying...")
|
| 267 |
+
except httpx.ConnectError:
|
| 268 |
+
last_error = LLMConnectionError(self.PROVIDER_NAME, self.base_url)
|
| 269 |
+
logger.warning(f"Attempt {attempt + 1} connection failed, retrying...")
|
| 270 |
+
except LLMClientError:
|
| 271 |
+
raise # Don't retry client errors
|
| 272 |
+
|
| 273 |
+
# All retries exhausted
|
| 274 |
+
if last_error:
|
| 275 |
+
raise last_error
|
| 276 |
+
raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
|
| 277 |
+
|
| 278 |
+
async def _generate_stream(
|
| 279 |
+
self,
|
| 280 |
+
*,
|
| 281 |
+
messages: list[dict] | None = None,
|
| 282 |
+
prompt: str | None = None,
|
| 283 |
+
temperature: float = 0.7,
|
| 284 |
+
max_tokens: int | None = None,
|
| 285 |
+
tools: list[dict] | None = None, # noqa: ARG002
|
| 286 |
+
stop: list[str] | None = None,
|
| 287 |
+
**kwargs: Any,
|
| 288 |
+
) -> AsyncIterator[str]:
|
| 289 |
+
"""Streaming generation."""
|
| 290 |
+
client = await self._get_client()
|
| 291 |
+
|
| 292 |
+
# Build request payload
|
| 293 |
+
payload = {
|
| 294 |
+
"model": self.model,
|
| 295 |
+
"messages": self._build_messages(messages, prompt),
|
| 296 |
+
"temperature": temperature,
|
| 297 |
+
"stream": True,
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
if max_tokens is not None:
|
| 301 |
+
payload["max_tokens"] = max_tokens
|
| 302 |
+
if stop:
|
| 303 |
+
payload["stop"] = stop
|
| 304 |
+
|
| 305 |
+
for key in ["top_p", "top_k", "repeat_penalty"]:
|
| 306 |
+
if key in kwargs:
|
| 307 |
+
payload[key] = kwargs[key]
|
| 308 |
+
|
| 309 |
+
async def stream_generator():
|
| 310 |
+
try:
|
| 311 |
+
async with client.stream("POST", "/chat/completions", json=payload) as response:
|
| 312 |
+
if response.status_code != 200:
|
| 313 |
+
await response.aread()
|
| 314 |
+
self._handle_error_response(response)
|
| 315 |
+
|
| 316 |
+
async for line in response.aiter_lines():
|
| 317 |
+
if line.startswith("data: "):
|
| 318 |
+
data_str = line[6:]
|
| 319 |
+
if data_str.strip() == "[DONE]":
|
| 320 |
+
break
|
| 321 |
+
|
| 322 |
+
try:
|
| 323 |
+
data = json.loads(data_str)
|
| 324 |
+
delta = data["choices"][0].get("delta", {})
|
| 325 |
+
content = delta.get("content", "")
|
| 326 |
+
if content:
|
| 327 |
+
yield content
|
| 328 |
+
except (json.JSONDecodeError, KeyError):
|
| 329 |
+
continue
|
| 330 |
+
|
| 331 |
+
except httpx.TimeoutException:
|
| 332 |
+
raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
|
| 333 |
+
except httpx.ConnectError:
|
| 334 |
+
raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
|
| 335 |
+
except Exception as e:
|
| 336 |
+
if isinstance(e, LLMClientError):
|
| 337 |
+
raise
|
| 338 |
+
raise LLMStreamError(self.PROVIDER_NAME, str(e)) from e
|
| 339 |
+
|
| 340 |
+
return stream_generator()
|
| 341 |
+
|
| 342 |
+
async def close(self) -> None:
|
| 343 |
+
"""Close the HTTP client."""
|
| 344 |
+
if self._client and not self._client.is_closed:
|
| 345 |
+
await self._client.aclose()
|
| 346 |
+
self._client = None
|
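A minimal local-run sketch, assuming an LM Studio server is listening on the default port with a model already loaded; the prompt is illustrative.

import asyncio

async def main():
    client = LMStudioClient()  # defaults to http://localhost:1234/v1, no API key needed
    try:
        if not await client.check_health():
            print("LM Studio server not reachable")
            return
        models = await client.list_models()
        print("available models:", [m.get("id") for m in models])
        reply = await client.generate(prompt="Give one fun fact.", max_tokens=64)
        print(reply.text)
    finally:
        await client.close()

asyncio.run(main())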
src/adapters/llm/openai_client.py
ADDED
|
@@ -0,0 +1,458 @@
|
| 1 |
+
"""
|
| 2 |
+
OpenAI-compatible LLM client adapter.
|
| 3 |
+
|
| 4 |
+
Implements the LLMClient protocol for OpenAI API (and compatible APIs).
|
| 5 |
+
Includes retry logic, circuit breaker pattern, and streaming support.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
from collections.abc import AsyncIterator
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
import httpx
|
| 15 |
+
from tenacity import (
|
| 16 |
+
before_sleep_log,
|
| 17 |
+
retry,
|
| 18 |
+
retry_if_exception_type,
|
| 19 |
+
stop_after_attempt,
|
| 20 |
+
wait_exponential,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from .base import BaseLLMClient, LLMResponse, LLMToolResponse, ToolCall
|
| 24 |
+
from .exceptions import (
|
| 25 |
+
CircuitBreakerOpenError,
|
| 26 |
+
LLMAuthenticationError,
|
| 27 |
+
LLMClientError,
|
| 28 |
+
LLMConnectionError,
|
| 29 |
+
LLMContextLengthError,
|
| 30 |
+
LLMInvalidRequestError,
|
| 31 |
+
LLMModelNotFoundError,
|
| 32 |
+
LLMQuotaExceededError,
|
| 33 |
+
LLMRateLimitError,
|
| 34 |
+
LLMResponseParseError,
|
| 35 |
+
LLMServerError,
|
| 36 |
+
LLMStreamError,
|
| 37 |
+
LLMTimeoutError,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class CircuitBreaker:
|
| 44 |
+
"""Simple circuit breaker implementation for resilience."""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
failure_threshold: int = 5,
|
| 49 |
+
reset_timeout: float = 60.0,
|
| 50 |
+
half_open_max_calls: int = 1,
|
| 51 |
+
):
|
| 52 |
+
self.failure_threshold = failure_threshold
|
| 53 |
+
self.reset_timeout = reset_timeout
|
| 54 |
+
self.half_open_max_calls = half_open_max_calls
|
| 55 |
+
self.failure_count = 0
|
| 56 |
+
self.last_failure_time = 0.0
|
| 57 |
+
self.state = "closed" # closed, open, half-open
|
| 58 |
+
self.half_open_calls = 0
|
| 59 |
+
|
| 60 |
+
def can_execute(self) -> bool:
|
| 61 |
+
"""Check if request can be executed."""
|
| 62 |
+
if self.state == "closed":
|
| 63 |
+
return True
|
| 64 |
+
|
| 65 |
+
if self.state == "open":
|
| 66 |
+
# Check if reset timeout has passed
|
| 67 |
+
if time.time() - self.last_failure_time >= self.reset_timeout:
|
| 68 |
+
self.state = "half-open"
|
| 69 |
+
self.half_open_calls = 0
|
| 70 |
+
return True
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
+
if self.state == "half-open":
|
| 74 |
+
return self.half_open_calls < self.half_open_max_calls
|
| 75 |
+
|
| 76 |
+
return False
|
| 77 |
+
|
| 78 |
+
def record_success(self) -> None:
|
| 79 |
+
"""Record successful request."""
|
| 80 |
+
if self.state == "half-open":
|
| 81 |
+
self.state = "closed"
|
| 82 |
+
self.failure_count = 0
|
| 83 |
+
elif self.state == "closed":
|
| 84 |
+
self.failure_count = 0
|
| 85 |
+
|
| 86 |
+
def record_failure(self) -> None:
|
| 87 |
+
"""Record failed request."""
|
| 88 |
+
self.failure_count += 1
|
| 89 |
+
self.last_failure_time = time.time()
|
| 90 |
+
|
| 91 |
+
if self.state == "half-open" or self.failure_count >= self.failure_threshold:
|
| 92 |
+
self.state = "open"
|
| 93 |
+
|
| 94 |
+
def get_reset_time(self) -> float:
|
| 95 |
+
"""Get time until circuit resets."""
|
| 96 |
+
if self.state != "open":
|
| 97 |
+
return 0.0
|
| 98 |
+
elapsed = time.time() - self.last_failure_time
|
| 99 |
+
return max(0, self.reset_timeout - elapsed)
|
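A short walk-through of the state machine under the defaults (5 failures, 60-second reset); the timeout is simulated by rewinding last_failure_time instead of sleeping.

breaker = CircuitBreaker(failure_threshold=5, reset_timeout=60.0)

for _ in range(5):
    breaker.record_failure()
print(breaker.state)          # "open": threshold reached, requests are blocked
print(breaker.can_execute())  # False until the reset timeout elapses

breaker.last_failure_time -= 61  # simulate 60+ seconds passing
print(breaker.can_execute())  # True: breaker moves to "half-open" and allows one probe
breaker.record_success()
print(breaker.state)          # "closed": the probe succeeded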
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class OpenAIClient(BaseLLMClient):
|
| 103 |
+
"""
|
| 104 |
+
OpenAI API client with retry logic and circuit breaker.
|
| 105 |
+
|
| 106 |
+
Features:
|
| 107 |
+
- Exponential backoff retry for transient errors
|
| 108 |
+
- Circuit breaker to prevent cascading failures
|
| 109 |
+
- Streaming support
|
| 110 |
+
- Structured error handling
|
| 111 |
+
- Tool/function calling support
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
PROVIDER_NAME = "openai"
|
| 115 |
+
DEFAULT_BASE_URL = "https://api.openai.com/v1"
|
| 116 |
+
DEFAULT_MODEL = "gpt-4-turbo-preview"
|
| 117 |
+
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
api_key: str | None = None,
|
| 121 |
+
model: str | None = None,
|
| 122 |
+
base_url: str | None = None,
|
| 123 |
+
timeout: float = 60.0,
|
| 124 |
+
max_retries: int = 3,
|
| 125 |
+
organization: str | None = None,
|
| 126 |
+
# Circuit breaker settings
|
| 127 |
+
circuit_breaker_threshold: int = 5,
|
| 128 |
+
circuit_breaker_reset: float = 60.0,
|
| 129 |
+
# Rate limiting
|
| 130 |
+
rate_limit_per_minute: int | None = None,
|
| 131 |
+
):
|
| 132 |
+
"""
|
| 133 |
+
Initialize OpenAI client.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
api_key: OpenAI API key (or set OPENAI_API_KEY env var)
|
| 137 |
+
model: Model to use (default: gpt-4-turbo-preview)
|
| 138 |
+
base_url: API base URL (default: https://api.openai.com/v1)
|
| 139 |
+
timeout: Request timeout in seconds
|
| 140 |
+
max_retries: Max retry attempts for transient errors
|
| 141 |
+
organization: Optional organization ID
|
| 142 |
+
circuit_breaker_threshold: Failures before circuit opens
|
| 143 |
+
circuit_breaker_reset: Seconds before circuit resets
|
| 144 |
+
rate_limit_per_minute: Rate limit for requests per minute (None to disable)
|
| 145 |
+
"""
|
| 146 |
+
import os
|
| 147 |
+
|
| 148 |
+
api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
| 149 |
+
if not api_key:
|
| 150 |
+
raise LLMAuthenticationError(self.PROVIDER_NAME, "API key not provided and OPENAI_API_KEY not set")
|
| 151 |
+
|
| 152 |
+
super().__init__(
|
| 153 |
+
api_key=api_key,
|
| 154 |
+
model=model or self.DEFAULT_MODEL,
|
| 155 |
+
base_url=base_url or self.DEFAULT_BASE_URL,
|
| 156 |
+
timeout=timeout,
|
| 157 |
+
max_retries=max_retries,
|
| 158 |
+
rate_limit_per_minute=rate_limit_per_minute,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
self.organization = organization
|
| 162 |
+
self.circuit_breaker = CircuitBreaker(
|
| 163 |
+
failure_threshold=circuit_breaker_threshold,
|
| 164 |
+
reset_timeout=circuit_breaker_reset,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Initialize async HTTP client
|
| 168 |
+
self._client: httpx.AsyncClient | None = None
|
| 169 |
+
|
| 170 |
+
async def _get_client(self) -> httpx.AsyncClient:
|
| 171 |
+
"""Get or create the HTTP client."""
|
| 172 |
+
if self._client is None or self._client.is_closed:
|
| 173 |
+
headers = {
|
| 174 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 175 |
+
"Content-Type": "application/json",
|
| 176 |
+
}
|
| 177 |
+
if self.organization:
|
| 178 |
+
headers["OpenAI-Organization"] = self.organization
|
| 179 |
+
|
| 180 |
+
self._client = httpx.AsyncClient(
|
| 181 |
+
base_url=self.base_url,
|
| 182 |
+
headers=headers,
|
| 183 |
+
timeout=httpx.Timeout(self.timeout),
|
| 184 |
+
)
|
| 185 |
+
return self._client
|
| 186 |
+
|
| 187 |
+
def _handle_error_response(self, response: httpx.Response) -> None:
|
| 188 |
+
"""Convert HTTP error responses to appropriate exceptions."""
|
| 189 |
+
status_code = response.status_code
|
| 190 |
+
|
| 191 |
+
try:
|
| 192 |
+
error_data = response.json()
|
| 193 |
+
error_message = error_data.get("error", {}).get("message", response.text)
|
| 194 |
+
except Exception:
|
| 195 |
+
error_message = response.text
|
| 196 |
+
|
| 197 |
+
if status_code == 401:
|
| 198 |
+
raise LLMAuthenticationError(self.PROVIDER_NAME, error_message)
|
| 199 |
+
elif status_code == 429:
|
| 200 |
+
retry_after = response.headers.get("Retry-After")
|
| 201 |
+
retry_after_float = float(retry_after) if retry_after else None
|
| 202 |
+
raise LLMRateLimitError(self.PROVIDER_NAME, retry_after=retry_after_float, message=error_message)
|
| 203 |
+
elif status_code == 402:
|
| 204 |
+
raise LLMQuotaExceededError(self.PROVIDER_NAME, error_message)
|
| 205 |
+
elif status_code == 404:
|
| 206 |
+
raise LLMModelNotFoundError(self.PROVIDER_NAME, self.model)
|
| 207 |
+
elif status_code == 400:
|
| 208 |
+
if "context_length" in error_message.lower():
|
| 209 |
+
raise LLMContextLengthError(self.PROVIDER_NAME)
|
| 210 |
+
raise LLMInvalidRequestError(self.PROVIDER_NAME, error_message)
|
| 211 |
+
elif status_code >= 500:
|
| 212 |
+
raise LLMServerError(self.PROVIDER_NAME, status_code, error_message)
|
| 213 |
+
else:
|
| 214 |
+
raise LLMClientError(error_message, self.PROVIDER_NAME, status_code=status_code)
|
| 215 |
+
|
| 216 |
+
def _make_retry_decorator(self):
|
| 217 |
+
"""Create retry decorator with exponential backoff."""
|
| 218 |
+
return retry(
|
| 219 |
+
stop=stop_after_attempt(self.max_retries),
|
| 220 |
+
wait=wait_exponential(multiplier=1, min=1, max=60),
|
| 221 |
+
retry=retry_if_exception_type((LLMRateLimitError, LLMServerError, LLMConnectionError)),
|
| 222 |
+
before_sleep=before_sleep_log(logger, logging.WARNING),
|
| 223 |
+
reraise=True,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
async def generate(
|
| 227 |
+
self,
|
| 228 |
+
*,
|
| 229 |
+
messages: list[dict] | None = None,
|
| 230 |
+
prompt: str | None = None,
|
| 231 |
+
temperature: float = 0.7,
|
| 232 |
+
max_tokens: int | None = None,
|
| 233 |
+
tools: list[dict] | None = None,
|
| 234 |
+
stream: bool = False,
|
| 235 |
+
stop: list[str] | None = None,
|
| 236 |
+
**kwargs: Any,
|
| 237 |
+
) -> LLMResponse | AsyncIterator[str]:
|
| 238 |
+
"""
|
| 239 |
+
Generate a response from OpenAI.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
messages: Chat messages in OpenAI format
|
| 243 |
+
prompt: Simple string prompt
|
| 244 |
+
temperature: Sampling temperature (0.0 to 2.0)
|
| 245 |
+
max_tokens: Maximum tokens to generate
|
| 246 |
+
tools: Tool definitions for function calling
|
| 247 |
+
stream: If True, returns AsyncIterator
|
| 248 |
+
stop: Stop sequences
|
| 249 |
+
**kwargs: Additional OpenAI parameters (top_p, presence_penalty, etc.)
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
LLMResponse or AsyncIterator[str] for streaming
|
| 253 |
+
"""
|
| 254 |
+
# Apply rate limiting before proceeding
|
| 255 |
+
await self._apply_rate_limit()
|
| 256 |
+
|
| 257 |
+
# Check circuit breaker
|
| 258 |
+
if not self.circuit_breaker.can_execute():
|
| 259 |
+
raise CircuitBreakerOpenError(
|
| 260 |
+
self.PROVIDER_NAME,
|
| 261 |
+
self.circuit_breaker.failure_count,
|
| 262 |
+
self.circuit_breaker.get_reset_time(),
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if stream:
|
| 266 |
+
return self._generate_stream(
|
| 267 |
+
messages=messages,
|
| 268 |
+
prompt=prompt,
|
| 269 |
+
temperature=temperature,
|
| 270 |
+
max_tokens=max_tokens,
|
| 271 |
+
tools=tools,
|
| 272 |
+
stop=stop,
|
| 273 |
+
**kwargs,
|
| 274 |
+
)
|
| 275 |
+
else:
|
| 276 |
+
return await self._generate_non_stream(
|
| 277 |
+
messages=messages,
|
| 278 |
+
prompt=prompt,
|
| 279 |
+
temperature=temperature,
|
| 280 |
+
max_tokens=max_tokens,
|
| 281 |
+
tools=tools,
|
| 282 |
+
stop=stop,
|
| 283 |
+
**kwargs,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
async def _generate_non_stream(
|
| 287 |
+
self,
|
| 288 |
+
*,
|
| 289 |
+
messages: list[dict] | None = None,
|
| 290 |
+
prompt: str | None = None,
|
| 291 |
+
temperature: float = 0.7,
|
| 292 |
+
max_tokens: int | None = None,
|
| 293 |
+
tools: list[dict] | None = None,
|
| 294 |
+
stop: list[str] | None = None,
|
| 295 |
+
**kwargs: Any,
|
| 296 |
+
) -> LLMResponse:
|
| 297 |
+
"""Non-streaming generation with retry logic."""
|
| 298 |
+
|
| 299 |
+
@self._make_retry_decorator()
|
| 300 |
+
async def _request():
|
| 301 |
+
client = await self._get_client()
|
| 302 |
+
|
| 303 |
+
# Build request payload
|
| 304 |
+
payload = {
|
| 305 |
+
"model": self.model,
|
| 306 |
+
"messages": self._build_messages(messages, prompt),
|
| 307 |
+
"temperature": temperature,
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
if max_tokens is not None:
|
| 311 |
+
payload["max_tokens"] = max_tokens
|
| 312 |
+
if stop:
|
| 313 |
+
payload["stop"] = stop
|
| 314 |
+
if tools:
|
| 315 |
+
payload["tools"] = tools
|
| 316 |
+
payload["tool_choice"] = kwargs.pop("tool_choice", "auto")
|
| 317 |
+
|
| 318 |
+
# Add any additional kwargs
|
| 319 |
+
payload.update(kwargs)
|
| 320 |
+
|
| 321 |
+
try:
|
| 322 |
+
response = await client.post("/chat/completions", json=payload)
|
| 323 |
+
except httpx.TimeoutException:
|
| 324 |
+
raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
|
| 325 |
+
except httpx.ConnectError:
|
| 326 |
+
raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
|
| 327 |
+
|
| 328 |
+
if response.status_code != 200:
|
| 329 |
+
self._handle_error_response(response)
|
| 330 |
+
|
| 331 |
+
return response
|
| 332 |
+
|
| 333 |
+
try:
|
| 334 |
+
response = await _request()
|
| 335 |
+
self.circuit_breaker.record_success()
|
| 336 |
+
except Exception:
|
| 337 |
+
self.circuit_breaker.record_failure()
|
| 338 |
+
raise
|
| 339 |
+
|
| 340 |
+
# Parse response
|
| 341 |
+
try:
|
| 342 |
+
data = response.json()
|
| 343 |
+
choice = data["choices"][0]
|
| 344 |
+
message = choice["message"]
|
| 345 |
+
|
| 346 |
+
usage = data.get("usage", {})
|
| 347 |
+
finish_reason = choice.get("finish_reason", "stop")
|
| 348 |
+
|
| 349 |
+
# Check for tool calls
|
| 350 |
+
if "tool_calls" in message:
|
| 351 |
+
tool_calls = [
|
| 352 |
+
ToolCall(
|
| 353 |
+
id=tc["id"],
|
| 354 |
+
name=tc["function"]["name"],
|
| 355 |
+
arguments=json.loads(tc["function"]["arguments"]),
|
| 356 |
+
)
|
| 357 |
+
for tc in message["tool_calls"]
|
| 358 |
+
]
|
| 359 |
+
llm_response = LLMToolResponse(
|
| 360 |
+
text=message.get("content", ""),
|
| 361 |
+
usage=usage,
|
| 362 |
+
model=data.get("model", self.model),
|
| 363 |
+
raw_response=data,
|
| 364 |
+
finish_reason=finish_reason,
|
| 365 |
+
tool_calls=tool_calls,
|
| 366 |
+
)
|
| 367 |
+
else:
|
| 368 |
+
llm_response = LLMResponse(
|
| 369 |
+
text=message.get("content", ""),
|
| 370 |
+
usage=usage,
|
| 371 |
+
model=data.get("model", self.model),
|
| 372 |
+
raw_response=data,
|
| 373 |
+
finish_reason=finish_reason,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
self._update_stats(llm_response)
|
| 377 |
+
return llm_response
|
| 378 |
+
|
| 379 |
+
except (KeyError, json.JSONDecodeError) as e:
|
| 380 |
+
raise LLMResponseParseError(self.PROVIDER_NAME, response.text) from e
|
| 381 |
+
|
| 382 |
+
async def _generate_stream(
|
| 383 |
+
self,
|
| 384 |
+
*,
|
| 385 |
+
messages: list[dict] | None = None,
|
| 386 |
+
prompt: str | None = None,
|
| 387 |
+
temperature: float = 0.7,
|
| 388 |
+
max_tokens: int | None = None,
|
| 389 |
+
tools: list[dict] | None = None,
|
| 390 |
+
stop: list[str] | None = None,
|
| 391 |
+
**kwargs: Any,
|
| 392 |
+
) -> AsyncIterator[str]:
|
| 393 |
+
"""Streaming generation."""
|
| 394 |
+
|
| 395 |
+
client = await self._get_client()
|
| 396 |
+
|
| 397 |
+
# Build request payload
|
| 398 |
+
payload = {
|
| 399 |
+
"model": self.model,
|
| 400 |
+
"messages": self._build_messages(messages, prompt),
|
| 401 |
+
"temperature": temperature,
|
| 402 |
+
"stream": True,
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
if max_tokens is not None:
|
| 406 |
+
payload["max_tokens"] = max_tokens
|
| 407 |
+
if stop:
|
| 408 |
+
payload["stop"] = stop
|
| 409 |
+
# Note: tools with streaming have limited support
|
| 410 |
+
if tools:
|
| 411 |
+
payload["tools"] = tools
|
| 412 |
+
|
| 413 |
+
payload.update(kwargs)
|
| 414 |
+
|
| 415 |
+
async def stream_generator():
|
| 416 |
+
try:
|
| 417 |
+
async with client.stream("POST", "/chat/completions", json=payload) as response:
|
| 418 |
+
if response.status_code != 200:
|
| 419 |
+
# Read the full response for error handling
|
| 420 |
+
await response.aread()
|
| 421 |
+
self._handle_error_response(response)
|
| 422 |
+
|
| 423 |
+
async for line in response.aiter_lines():
|
| 424 |
+
if line.startswith("data: "):
|
| 425 |
+
data_str = line[6:]
|
| 426 |
+
if data_str.strip() == "[DONE]":
|
| 427 |
+
break
|
| 428 |
+
|
| 429 |
+
try:
|
| 430 |
+
data = json.loads(data_str)
|
| 431 |
+
delta = data["choices"][0].get("delta", {})
|
| 432 |
+
content = delta.get("content", "")
|
| 433 |
+
if content:
|
| 434 |
+
yield content
|
| 435 |
+
except (json.JSONDecodeError, KeyError):
|
| 436 |
+
continue
|
| 437 |
+
|
| 438 |
+
self.circuit_breaker.record_success()
|
| 439 |
+
|
| 440 |
+
except httpx.TimeoutException:
|
| 441 |
+
self.circuit_breaker.record_failure()
|
| 442 |
+
raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout)
|
| 443 |
+
except httpx.ConnectError:
|
| 444 |
+
self.circuit_breaker.record_failure()
|
| 445 |
+
raise LLMConnectionError(self.PROVIDER_NAME, self.base_url)
|
| 446 |
+
except Exception as e:
|
| 447 |
+
self.circuit_breaker.record_failure()
|
| 448 |
+
if isinstance(e, LLMClientError):
|
| 449 |
+
raise
|
| 450 |
+
raise LLMStreamError(self.PROVIDER_NAME, str(e)) from e
|
| 451 |
+
|
| 452 |
+
return stream_generator()
|
| 453 |
+
|
| 454 |
+
async def close(self) -> None:
|
| 455 |
+
"""Close the HTTP client."""
|
| 456 |
+
if self._client and not self._client.is_closed:
|
| 457 |
+
await self._client.aclose()
|
| 458 |
+
self._client = None
|
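A minimal usage sketch for the client above. It relies only on methods defined in this file (generate, close) and assumes OPENAI_API_KEY is set in the environment; the prompt text is illustrative.

# Hedged usage sketch for OpenAIClient; not part of the committed files.
import asyncio

from src.adapters.llm.openai_client import OpenAIClient


async def main() -> None:
    client = OpenAIClient(model="gpt-4-turbo-preview", max_retries=3)
    try:
        # Non-streaming call; returns an LLMResponse whose .text holds the completion.
        response = await client.generate(prompt="Summarize circuit breakers in one sentence.")
        print(response.text)
    finally:
        # Always release the underlying httpx.AsyncClient.
        await client.close()


asyncio.run(main())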
src/agents/__init__.py ADDED
File without changes
src/agents/hrm_agent.py ADDED
@@ -0,0 +1,454 @@
"""
Hierarchical Reasoning Model (HRM) Agent.

Implements the HRM architecture with:
- H-Module: High-level planning and decomposition
- L-Module: Low-level execution and refinement
- Adaptive Computation Time (ACT) for dynamic depth
- Halting mechanism based on confidence thresholds

Based on: "Hierarchical Reasoning for Compositional Generalization"
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..training.system_config import HRMConfig


@dataclass
class SubProblem:
    """Represents a decomposed subproblem in the hierarchy."""

    level: int  # Hierarchy level (0 = root, higher = more abstract)
    description: str  # Natural language description
    state: torch.Tensor  # Latent state representation
    parent_id: int | None = None  # Parent subproblem ID
    confidence: float = 0.0  # Confidence in this decomposition


@dataclass
class HRMOutput:
    """Output from HRM processing."""

    final_state: torch.Tensor  # Final processed state
    subproblems: list[SubProblem]  # Hierarchical decomposition
    halt_step: int  # Step at which halting occurred
    total_ponder_cost: float  # Total computation cost (for training)
    convergence_path: list[float]  # Confidence at each step


class AdaptiveComputationTime(nn.Module):
    """
    Adaptive Computation Time (ACT) mechanism for dynamic depth.

    Allows the model to "ponder" longer on difficult problems by
    dynamically adjusting the number of processing steps.
    """

    def __init__(self, hidden_dim: int, epsilon: float = 0.01):
        super().__init__()
        self.epsilon = epsilon

        # Halting unit: predicts probability of halting
        self.halt_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, float]:
        """
        Compute halting probabilities.

        Args:
            hidden_states: [batch, seq, hidden_dim]

        Returns:
            halt_probs: [batch, seq] probability of halting
            ponder_cost: Scalar cost for training
        """
        # Compute halting probabilities
        halt_logits = self.halt_fc(hidden_states)  # [batch, seq, 1]
        halt_probs = halt_logits.squeeze(-1)  # [batch, seq]

        # Ponder cost is the expected number of steps
        ponder_cost = halt_probs.sum(dim=-1).mean()

        return halt_probs, ponder_cost


class HModule(nn.Module):
    """
    H-Module: High-level planning and abstract reasoning.

    Responsible for:
    - Decomposing problems into subproblems
    - Abstract planning and strategy
    - Coordinating L-module executions
    """

    def __init__(self, config: HRMConfig):
        super().__init__()
        self.config = config

        # Multi-head self-attention for relational reasoning
        self.attention = nn.MultiheadAttention(
            embed_dim=config.h_dim,
            num_heads=8,
            dropout=config.dropout,
            batch_first=True,
        )

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(config.h_dim, config.h_dim * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.h_dim * 4, config.h_dim),
            nn.Dropout(config.dropout),
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(config.h_dim)
        self.norm2 = nn.LayerNorm(config.h_dim)

        # Decomposition head: outputs subproblem structure
        self.decompose_head = nn.Sequential(
            nn.Linear(config.h_dim, config.h_dim),
            nn.ReLU(),
            nn.Linear(config.h_dim, config.h_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Process input through high-level reasoning.

        Args:
            x: [batch, seq, h_dim] input tensor

        Returns:
            [batch, seq, h_dim] processed tensor
        """
        # Self-attention for relational reasoning
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)

        # Feed-forward processing
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)

        return x

    def decompose(self, x: torch.Tensor) -> torch.Tensor:
        """Generate subproblem representations."""
        return self.decompose_head(x)


class LModule(nn.Module):
    """
    L-Module: Low-level execution and concrete operations.

    Responsible for:
    - Executing concrete operations
    - Processing individual subproblems
    - Generating intermediate results
    """

    def __init__(self, config: HRMConfig):
        super().__init__()
        self.config = config

        # Projection from H-module to L-module dimension
        self.h_to_l = nn.Linear(config.h_dim, config.l_dim)

        # GRU for sequential processing
        self.gru = nn.GRU(
            input_size=config.l_dim,
            hidden_size=config.l_dim,
            num_layers=config.num_l_layers,
            dropout=config.dropout if config.num_l_layers > 1 else 0,
            batch_first=True,
        )

        # Output projection
        self.output_proj = nn.Sequential(
            nn.Linear(config.l_dim, config.l_dim * 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.l_dim * 2, config.l_dim),
        )

        # Back-projection to H-module dimension
        self.l_to_h = nn.Linear(config.l_dim, config.h_dim)

    def forward(self, x: torch.Tensor, h_context: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Execute low-level processing.

        Args:
            x: [batch, seq, h_dim] input from H-module
            h_context: Optional hidden state

        Returns:
            output: [batch, seq, l_dim] processed output
            l_to_h: [batch, seq, h_dim] back-projection to H-module
        """
        # Project to L-module dimension
        x_l = self.h_to_l(x)

        # Sequential processing
        gru_out, _ = self.gru(x_l, h_context)

        # Output processing
        output = self.output_proj(gru_out)

        # Back-project to H-module dimension for feedback
        feedback = self.l_to_h(output)

        return output, feedback


class HRMAgent(nn.Module):
    """
    Complete Hierarchical Reasoning Model agent.

    Combines H-module and L-module with ACT for adaptive computation.
    """

    def __init__(self, config: HRMConfig, device: str = "cpu"):
        super().__init__()
        self.config = config
        self.device = device

        # Input embedding
        self.input_proj = nn.Linear(config.h_dim, config.h_dim)

        # Core modules
        self.h_module = nn.ModuleList([HModule(config) for _ in range(config.num_h_layers)])

        self.l_module = LModule(config)

        # Adaptive computation time
        self.act = AdaptiveComputationTime(config.h_dim, config.ponder_epsilon)

        # State integration
        self.integrate = nn.Sequential(
            nn.Linear(config.h_dim * 2, config.h_dim),
            nn.LayerNorm(config.h_dim),
            nn.GELU(),
        )

        self.to(device)

    def forward(
        self,
        x: torch.Tensor,
        max_steps: int | None = None,
        return_decomposition: bool = False,
    ) -> HRMOutput:
        """
        Process input through hierarchical reasoning.

        Args:
            x: [batch, seq, h_dim] input tensor
            max_steps: Maximum outer loop steps (defaults to config)
            return_decomposition: Whether to return subproblem decomposition

        Returns:
            HRMOutput containing final state and optional decomposition
        """
        batch_size, seq_len, _ = x.shape
        max_steps = max_steps or self.config.max_outer_steps

        # Initial projection
        h_state = self.input_proj(x)

        # Tracking
        subproblems = []
        convergence_path = []
        total_ponder_cost = 0.0

        # Outer loop: iterative refinement
        for step in range(max_steps):
            # H-module: high-level planning
            for h_layer in self.h_module:
                h_state = h_layer(h_state)

            # Check halting condition
            halt_probs, ponder_cost = self.act(h_state)
            total_ponder_cost += ponder_cost

            # Average halting probability across sequence
            avg_halt_prob = halt_probs.mean().item()
            convergence_path.append(avg_halt_prob)

            # Generate subproblem decomposition if requested
            if return_decomposition:
                subproblem_repr = self.h_module[0].decompose(h_state)
                # Create subproblem entries (simplified)
                for i in range(min(3, seq_len)):  # Top 3 subproblems
                    subproblems.append(
                        SubProblem(
                            level=step,
                            description=f"Subproblem at step {step}, position {i}",
                            state=subproblem_repr[:, i, :].detach(),
                            confidence=halt_probs[:, i].mean().item(),
                        )
                    )

            # Halt if confident enough
            if avg_halt_prob >= self.config.halt_threshold:
                break

            # L-module: low-level execution
            l_output, l_feedback = self.l_module(h_state)

            # Integrate L-module feedback
            h_state = self.integrate(torch.cat([h_state, l_feedback], dim=-1))

        return HRMOutput(
            final_state=h_state,
            subproblems=subproblems,
            halt_step=step + 1,
            total_ponder_cost=total_ponder_cost,
            convergence_path=convergence_path,
        )

    async def decompose_problem(self, query: str, state: torch.Tensor) -> list[SubProblem]:
        """
        Decompose a problem into hierarchical subproblems.

        Args:
            query: Natural language problem description
            state: Initial state representation

        Returns:
            List of subproblems in hierarchical order
        """
        # Ensure state is batched
        if state.dim() == 2:
            state = state.unsqueeze(0)  # [1, seq, dim]

        # Forward pass with decomposition
        output = self.forward(state, return_decomposition=True)

        # Add query context to subproblems
        for i, sp in enumerate(output.subproblems):
            sp.description = f"{query} -> Level {sp.level} Subproblem {i}"

        return output.subproblems

    def get_parameter_count(self) -> int:
        """Return total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# Training utilities
class HRMLoss(nn.Module):
    """
    Combined loss for HRM training.

    Includes:
    - Task loss (e.g., cross-entropy for classification)
    - Ponder cost regularization (encourages efficiency)
    - Consistency loss (encourages stable convergence)
    """

    def __init__(
        self,
        task_weight: float = 1.0,
        ponder_weight: float = 0.01,
        consistency_weight: float = 0.1,
    ):
        super().__init__()
        self.task_weight = task_weight
        self.ponder_weight = ponder_weight
        self.consistency_weight = consistency_weight

    def forward(
        self,
        hrm_output: HRMOutput,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        task_loss_fn: nn.Module,
    ) -> tuple[torch.Tensor, dict]:
        """
        Compute combined loss.

        Args:
            hrm_output: Output from HRM forward pass
            predictions: Model predictions
            targets: Ground truth targets
            task_loss_fn: Loss function for the task

        Returns:
            total_loss: Combined loss
            loss_dict: Dictionary of individual loss components
        """
        # Task loss
        task_loss = task_loss_fn(predictions, targets)

        # Ponder cost (encourages efficiency)
        ponder_loss = hrm_output.total_ponder_cost

        # Consistency loss (encourages monotonic convergence)
        if len(hrm_output.convergence_path) > 1:
            conv_tensor = torch.tensor(hrm_output.convergence_path)
            # Penalize non-monotonic increases
            diffs = conv_tensor[1:] - conv_tensor[:-1]
            consistency_loss = F.relu(-diffs).mean()  # Penalize decreases
        else:
            consistency_loss = torch.tensor(0.0)

        # Combine losses
        total_loss = (
            self.task_weight * task_loss + self.ponder_weight * ponder_loss + self.consistency_weight * consistency_loss
        )

        loss_dict = {
            "total": total_loss.item(),
            "task": task_loss.item(),
            "ponder": ponder_loss,
            "consistency": consistency_loss.item(),
            "halt_step": hrm_output.halt_step,
        }

        return total_loss, loss_dict


def create_hrm_agent(config: HRMConfig, device: str = "cpu") -> HRMAgent:
    """
    Factory function to create and initialize HRM agent.

    Args:
        config: HRM configuration
        device: Device to place model on

    Returns:
        Initialized HRMAgent
    """
    agent = HRMAgent(config, device)

    # Initialize weights
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.GRU):
            for name, param in m.named_parameters():
                if "weight" in name:
                    nn.init.orthogonal_(param)
                elif "bias" in name:
                    nn.init.zeros_(param)

    agent.apply(init_weights)

    return agent
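A short smoke-test sketch for the agent above. The HRMConfig constructor call is an assumption: the fields used (h_dim, l_dim, num_h_layers, num_l_layers, dropout, ponder_epsilon, max_outer_steps, halt_threshold) are the ones this file reads, but the real dataclass and its defaults live in src/training/system_config.py, which is not shown here.

# Hedged smoke test for HRMAgent; values are illustrative, not from the repo config.
import torch

from src.agents.hrm_agent import create_hrm_agent
from src.training.system_config import HRMConfig

config = HRMConfig(
    h_dim=128, l_dim=64, num_h_layers=2, num_l_layers=1,
    dropout=0.1, ponder_epsilon=0.01, max_outer_steps=4, halt_threshold=0.9,
)
agent = create_hrm_agent(config, device="cpu")

x = torch.randn(2, 8, config.h_dim)           # [batch, seq, h_dim]
output = agent(x, return_decomposition=True)
print(output.final_state.shape)               # expected: torch.Size([2, 8, 128])
print(output.halt_step, output.convergence_path)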
src/agents/meta_controller/__init__.py ADDED
@@ -0,0 +1,45 @@
"""
Neural Meta-Controller package for Multi-Agent MCTS Framework.

This package provides the base infrastructure for neural network-based
meta-controllers that dynamically select which agent to route queries to.
"""

from src.agents.meta_controller.base import (
    AbstractMetaController,
    MetaControllerFeatures,
    MetaControllerPrediction,
)
from src.agents.meta_controller.rnn_controller import (
    RNNMetaController,
    RNNMetaControllerModel,
)
from src.agents.meta_controller.utils import (
    features_to_tensor,
    features_to_text,
    normalize_features,
    one_hot_encode_agent,
)

# Import BERT controller (may not be available if transformers/peft not installed)
try:
    from src.agents.meta_controller.bert_controller import BERTMetaController  # noqa: F401

    _bert_available = True
except ImportError:
    _bert_available = False

__all__ = [
    "AbstractMetaController",
    "MetaControllerFeatures",
    "MetaControllerPrediction",
    "normalize_features",
    "one_hot_encode_agent",
    "features_to_tensor",
    "features_to_text",
    "RNNMetaController",
    "RNNMetaControllerModel",
]

if _bert_available:
    __all__.append("BERTMetaController")
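Because the package only appends BERTMetaController to __all__ when transformers and peft import cleanly, downstream code can probe for it instead of importing it unconditionally. A small sketch (the constructor arguments for the fallback RNN controller are assumed, since rnn_controller.py is not shown in this view):

# Hedged sketch: prefer the BERT controller when available, else fall back to the RNN one.
import src.agents.meta_controller as mc

if "BERTMetaController" in mc.__all__:
    controller = mc.BERTMetaController(name="router")
else:
    controller = mc.RNNMetaController(name="router")  # constructor args assumed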
src/agents/meta_controller/base.py ADDED
@@ -0,0 +1,219 @@
"""
Abstract base class for Neural Meta-Controllers.

Provides the foundation for neural network-based meta-controllers that
dynamically select which agent (HRM, TRM, or MCTS) should handle a query.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any


@dataclass
class MetaControllerFeatures:
    """
    Features extracted from the current agent state for meta-controller prediction.

    These features capture the current state of the multi-agent system,
    including confidence scores from different agents and contextual information.
    """

    hrm_confidence: float
    """Confidence score from the HRM (Human Response Model) agent."""

    trm_confidence: float
    """Confidence score from the TRM (Task Response Model) agent."""

    mcts_value: float
    """Value estimate from the MCTS (Monte Carlo Tree Search) process."""

    consensus_score: float
    """Agreement score between different agents."""

    last_agent: str
    """Name of the last agent used ('hrm', 'trm', 'mcts', or 'none')."""

    iteration: int
    """Current iteration number in the reasoning process."""

    query_length: int
    """Length of the input query in characters."""

    has_rag_context: bool
    """Whether RAG (Retrieval-Augmented Generation) context is available."""


@dataclass
class MetaControllerPrediction:
    """
    Prediction output from the meta-controller.

    Contains the selected agent and associated confidence/probability information.
    """

    agent: str
    """Name of the selected agent ('hrm', 'trm', or 'mcts')."""

    confidence: float
    """Confidence score for the prediction (0.0 to 1.0)."""

    probabilities: dict[str, float] = field(default_factory=dict)
    """Probability distribution over all possible agents."""


class AbstractMetaController(ABC):
    """
    Abstract base class for neural meta-controllers.

    This class defines the interface that all meta-controller implementations
    must follow. Meta-controllers are responsible for deciding which agent
    should handle a given query based on the current system state.

    Attributes:
        AGENT_NAMES: List of valid agent names that can be selected.
        name: Name of this meta-controller instance.
        seed: Random seed for reproducibility.
    """

    AGENT_NAMES = ["hrm", "trm", "mcts"]

    def __init__(self, name: str, seed: int = 42) -> None:
        """
        Initialize the meta-controller.

        Args:
            name: Name identifier for this meta-controller instance.
            seed: Random seed for reproducibility. Defaults to 42.
        """
        self.name = name
        self.seed = seed

    @abstractmethod
    def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction:
        """
        Predict which agent should handle the current query.

        Args:
            features: Features extracted from the current agent state.

        Returns:
            Prediction containing the selected agent and confidence scores.
        """
        pass

    @abstractmethod
    def load_model(self, path: str) -> None:
        """
        Load a trained model from disk.

        Args:
            path: Path to the saved model file or directory.
        """
        pass

    @abstractmethod
    def save_model(self, path: str) -> None:
        """
        Save the current model to disk.

        Args:
            path: Path where the model should be saved.
        """
        pass

    def extract_features(self, state: dict[str, Any]) -> MetaControllerFeatures:
        """
        Extract meta-controller features from an AgentState dictionary.

        This method converts raw state information into the structured
        MetaControllerFeatures format required for prediction.

        Args:
            state: Dictionary containing agent state information.
                Expected keys include:
                - 'hrm_confidence' or nested in 'agent_confidences'
                - 'trm_confidence' or nested in 'agent_confidences'
                - 'mcts_value' or nested in 'mcts_state'
                - 'consensus_score'
                - 'last_agent'
                - 'iteration'
                - 'query' or 'query_length'
                - 'rag_context' or 'has_rag_context'

        Returns:
            MetaControllerFeatures instance with extracted values.

        Example:
            >>> state = {
            ...     'agent_confidences': {'hrm': 0.8, 'trm': 0.6},
            ...     'mcts_state': {'value': 0.75},
            ...     'consensus_score': 0.7,
            ...     'last_agent': 'hrm',
            ...     'iteration': 2,
            ...     'query': 'What is machine learning?',
            ...     'rag_context': 'ML is a subset of AI...'
            ... }
            >>> features = controller.extract_features(state)
        """
        # Extract HRM confidence
        if "hrm_confidence" in state:
            hrm_confidence = float(state["hrm_confidence"])
        elif "agent_confidences" in state and isinstance(state["agent_confidences"], dict):
            hrm_confidence = float(state["agent_confidences"].get("hrm", 0.0))
        else:
            hrm_confidence = 0.0

        # Extract TRM confidence
        if "trm_confidence" in state:
            trm_confidence = float(state["trm_confidence"])
        elif "agent_confidences" in state and isinstance(state["agent_confidences"], dict):
            trm_confidence = float(state["agent_confidences"].get("trm", 0.0))
        else:
            trm_confidence = 0.0

        # Extract MCTS value
        if "mcts_value" in state:
            mcts_value = float(state["mcts_value"])
        elif "mcts_state" in state and isinstance(state["mcts_state"], dict):
            mcts_value = float(state["mcts_state"].get("value", 0.0))
        else:
            mcts_value = 0.0

        # Extract consensus score
        consensus_score = float(state.get("consensus_score", 0.0))

        # Extract last agent
        last_agent = str(state.get("last_agent", "none"))
        if last_agent not in self.AGENT_NAMES and last_agent != "none":
            last_agent = "none"

        # Extract iteration
        iteration = int(state.get("iteration", 0))

        # Extract query length
        if "query_length" in state:
            query_length = int(state["query_length"])
        elif "query" in state and isinstance(state["query"], str):
            query_length = len(state["query"])
        else:
            query_length = 0

        # Extract has_rag_context
        if "has_rag_context" in state:
            has_rag_context = bool(state["has_rag_context"])
        elif "rag_context" in state:
            has_rag_context = state["rag_context"] is not None and len(str(state["rag_context"])) > 0
        else:
            has_rag_context = False

        return MetaControllerFeatures(
            hrm_confidence=hrm_confidence,
            trm_confidence=trm_confidence,
            mcts_value=mcts_value,
            consensus_score=consensus_score,
            last_agent=last_agent,
            iteration=iteration,
            query_length=query_length,
            has_rag_context=has_rag_context,
        )
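Since extract_features is a concrete method on the abstract base, it can be exercised with a trivial stub subclass. A minimal sketch; the AlwaysHRMController class and its fixed prediction are illustrative only and not part of the repository:

# Hedged sketch: exercise AbstractMetaController.extract_features in isolation.
from src.agents.meta_controller.base import (
    AbstractMetaController,
    MetaControllerFeatures,
    MetaControllerPrediction,
)


class AlwaysHRMController(AbstractMetaController):
    """Toy controller that always routes to HRM; used only to test feature extraction."""

    def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction:
        return MetaControllerPrediction(agent="hrm", confidence=1.0, probabilities={"hrm": 1.0})

    def load_model(self, path: str) -> None:  # no weights to load
        pass

    def save_model(self, path: str) -> None:  # no weights to save
        pass


controller = AlwaysHRMController(name="stub")
features = controller.extract_features(
    {
        "agent_confidences": {"hrm": 0.8, "trm": 0.6},
        "mcts_state": {"value": 0.75},
        "consensus_score": 0.7,
        "last_agent": "hrm",
        "iteration": 2,
        "query": "What is machine learning?",
        "rag_context": "ML is a subset of AI...",
    }
)
print(features.query_length, features.has_rag_context)  # 25 True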
src/agents/meta_controller/bert_controller.py ADDED
@@ -0,0 +1,428 @@
"""
BERT-based Meta-Controller with LoRA adapters for efficient fine-tuning.

This module provides a BERT-based meta-controller that uses Low-Rank Adaptation (LoRA)
for parameter-efficient fine-tuning. The controller converts agent state features into
text and uses a sequence classification model to predict the optimal agent.
"""

import warnings
from typing import Any

import torch

from src.agents.meta_controller.base import (
    AbstractMetaController,
    MetaControllerFeatures,
    MetaControllerPrediction,
)
from src.agents.meta_controller.utils import features_to_text

# Handle optional transformers and peft imports gracefully
_TRANSFORMERS_AVAILABLE = False
_PEFT_AVAILABLE = False

try:
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    _TRANSFORMERS_AVAILABLE = True
except ImportError:
    warnings.warn(
        "transformers library not installed. Install it with: pip install transformers",
        ImportWarning,
        stacklevel=2,
    )
    AutoTokenizer = None  # type: ignore
    AutoModelForSequenceClassification = None  # type: ignore

try:
    from peft import LoraConfig, TaskType, get_peft_model

    _PEFT_AVAILABLE = True
except ImportError:
    warnings.warn(
        "peft library not installed. Install it with: pip install peft",
        ImportWarning,
        stacklevel=2,
    )
    LoraConfig = None  # type: ignore
    TaskType = None  # type: ignore
    get_peft_model = None  # type: ignore


class BERTMetaController(AbstractMetaController):
    """
    BERT-based meta-controller with optional LoRA adapters for efficient fine-tuning.

    This controller converts agent state features into structured text and uses
    a pre-trained BERT model (with optional LoRA adapters) to classify which
    agent should handle the current query. LoRA enables parameter-efficient
    fine-tuning by only training low-rank decomposition matrices.

    Attributes:
        DEFAULT_MODEL_NAME: Default BERT model to use.
        NUM_LABELS: Number of output labels (agents to choose from).
        device: PyTorch device for tensor operations.
        model_name: Name of the pre-trained model.
        lora_r: LoRA rank parameter.
        lora_alpha: LoRA alpha scaling parameter.
        lora_dropout: LoRA dropout rate.
        use_lora: Whether to use LoRA adapters.
        tokenizer: BERT tokenizer for text processing.
        model: BERT sequence classification model (with or without LoRA).

    Example:
        >>> controller = BERTMetaController(name="BERTController", seed=42)
        >>> features = MetaControllerFeatures(
        ...     hrm_confidence=0.8,
        ...     trm_confidence=0.6,
        ...     mcts_value=0.75,
        ...     consensus_score=0.7,
        ...     last_agent='hrm',
        ...     iteration=2,
        ...     query_length=150,
        ...     has_rag_context=True
        ... )
        >>> prediction = controller.predict(features)
        >>> prediction.agent in ['hrm', 'trm', 'mcts']
        True
        >>> 0.0 <= prediction.confidence <= 1.0
        True
    """

    DEFAULT_MODEL_NAME = "prajjwal1/bert-mini"
    NUM_LABELS = 3

    def __init__(
        self,
        name: str = "BERTMetaController",
        seed: int = 42,
        model_name: str | None = None,
        lora_r: int = 4,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        device: str | None = None,
        use_lora: bool = True,
    ) -> None:
        """
        Initialize the BERT meta-controller with optional LoRA adapters.

        Args:
            name: Name identifier for this controller. Defaults to "BERTMetaController".
            seed: Random seed for reproducibility. Defaults to 42.
            model_name: Pre-trained model name from HuggingFace. If None, uses DEFAULT_MODEL_NAME.
            lora_r: LoRA rank parameter (lower = more compression). Defaults to 4.
            lora_alpha: LoRA alpha scaling parameter. Defaults to 16.
            lora_dropout: Dropout rate for LoRA layers. Defaults to 0.1.
            device: Device to run model on ('cpu', 'cuda', 'mps', etc.).
                If None, auto-detects best available device.
            use_lora: Whether to apply LoRA adapters to the model. Defaults to True.

        Raises:
            ImportError: If transformers library is not installed.
            ImportError: If use_lora is True and peft library is not installed.

        Example:
            >>> controller = BERTMetaController(
            ...     name="CustomBERT",
            ...     seed=123,
            ...     lora_r=8,
            ...     lora_alpha=32,
            ...     use_lora=True
            ... )
        """
        super().__init__(name=name, seed=seed)

        # Check for required dependencies
        if not _TRANSFORMERS_AVAILABLE:
            raise ImportError(
                "transformers library is required for BERTMetaController. Install it with: pip install transformers"
            )

        if use_lora and not _PEFT_AVAILABLE:
            raise ImportError("peft library is required for LoRA support. Install it with: pip install peft")

        # Set random seed for reproducibility
        torch.manual_seed(seed)

        # Auto-detect device if not specified
        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)

        # Store configuration parameters
        self.model_name = model_name if model_name is not None else self.DEFAULT_MODEL_NAME
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.use_lora = use_lora

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Initialize base model for sequence classification
        base_model = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=self.NUM_LABELS)

        # Apply LoRA adapters if requested
        if self.use_lora:
            lora_config = LoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=self.lora_r,
                lora_alpha=self.lora_alpha,
                lora_dropout=self.lora_dropout,
                target_modules=["query", "value"],
            )
            self.model = get_peft_model(base_model, lora_config)
        else:
            self.model = base_model

        # Move model to device
        self.model = self.model.to(self.device)

        # Set model to evaluation mode
        self.model.eval()

        # Initialize tokenization cache for performance optimization
        self._tokenization_cache: dict[str, Any] = {}

    def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction:
        """
        Predict which agent should handle the current query.

        Converts features to structured text, tokenizes the text, runs through
        the BERT model, and returns a prediction with confidence scores.

        Args:
            features: Features extracted from the current agent state.

        Returns:
            Prediction containing the selected agent, confidence score,
            and probability distribution over all agents.

        Example:
            >>> controller = BERTMetaController()
            >>> features = MetaControllerFeatures(
            ...     hrm_confidence=0.9,
            ...     trm_confidence=0.3,
            ...     mcts_value=0.5,
            ...     consensus_score=0.8,
            ...     last_agent='none',
            ...     iteration=0,
            ...     query_length=100,
            ...     has_rag_context=False
            ... )
            >>> pred = controller.predict(features)
            >>> isinstance(pred.agent, str)
            True
            >>> isinstance(pred.confidence, float)
            True
            >>> len(pred.probabilities) == 3
            True
        """
        # Convert features to structured text
        text = features_to_text(features)

        # Check cache for tokenized text
        if text in self._tokenization_cache:
            inputs = self._tokenization_cache[text]
        else:
            # Tokenize the text
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            # Cache the tokenized result
            self._tokenization_cache[text] = inputs

        # Move inputs to device
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        # Perform inference without gradient tracking
        with torch.no_grad():
            # Get logits from model
            outputs = self.model(**inputs)
            logits = outputs.logits

            # Apply softmax to get probabilities
            probabilities = torch.nn.functional.softmax(logits, dim=-1)

            # Get predicted agent index (argmax)
            predicted_idx = torch.argmax(probabilities, dim=-1).item()

            # Extract confidence for selected agent
            confidence = probabilities[0, predicted_idx].item()

        # Create probability dictionary
        prob_dict: dict[str, float] = {}
        for i, agent_name in enumerate(self.AGENT_NAMES):
            prob_dict[agent_name] = probabilities[0, i].item()

        # Get agent name
        selected_agent = self.AGENT_NAMES[predicted_idx]

        return MetaControllerPrediction(
            agent=selected_agent,
            confidence=float(confidence),
            probabilities=prob_dict,
        )

    def load_model(self, path: str) -> None:
        """
        Load a trained model from disk.

        For LoRA models, loads the PEFT adapter weights. For base models,
        loads the full state dictionary.

        Args:
            path: Path to the saved model file or directory.
                For LoRA models, this should be a directory containing
                adapter_config.json and adapter_model.bin.
                For base models, this should be a .pt or .pth file.

        Raises:
            FileNotFoundError: If the model file or directory does not exist.
            RuntimeError: If the state dict is incompatible with the model.

        Example:
            >>> controller = BERTMetaController(use_lora=True)
            >>> controller.load_model("/path/to/lora_adapter")
            >>> controller = BERTMetaController(use_lora=False)
            >>> controller.load_model("/path/to/model.pt")
        """
        if self.use_lora:
            # Load PEFT adapter weights
            # For PEFT models, the path should be a directory containing adapter files
            from peft import PeftModel

            # Get the base model from the PEFT wrapper
            base_model = self.model.get_base_model()

            # Load the PEFT model from the saved path
            self.model = PeftModel.from_pretrained(base_model, path)
            self.model = self.model.to(self.device)
        else:
            # Load base model state dict
            state_dict = torch.load(path, map_location=self.device, weights_only=True)
            self.model.load_state_dict(state_dict)

        # Ensure model is in evaluation mode
        self.model.eval()

    def save_model(self, path: str) -> None:
        """
        Save the current model to disk.

        For LoRA models, saves the PEFT adapter weights. For base models,
        saves the full state dictionary.

        Args:
            path: Path where the model should be saved.
                For LoRA models, this should be a directory path where
                adapter_config.json and adapter_model.bin will be saved.
                For base models, this should be a .pt or .pth file path.

        Example:
            >>> controller = BERTMetaController(use_lora=True)
            >>> controller.save_model("/path/to/lora_adapter")
            >>> controller = BERTMetaController(use_lora=False)
            >>> controller.save_model("/path/to/model.pt")
        """
        if self.use_lora:
            # Save PEFT adapter weights
            # This saves only the LoRA adapter weights, not the full model
            self.model.save_pretrained(path)
        else:
            # Save base model state dict
            torch.save(self.model.state_dict(), path)

    def clear_cache(self) -> None:
        """
        Clear the tokenization cache.

        This method removes all cached tokenized inputs, freeing memory.
        Useful when processing many different feature combinations or
        when memory usage is a concern.

        Example:
            >>> controller = BERTMetaController()
            >>> # After many predictions...
            >>> controller.clear_cache()
            >>> info = controller.get_cache_info()
            >>> info['cache_size'] == 0
            True
+
"""
|
| 363 |
+
self._tokenization_cache.clear()
|
| 364 |
+
|
| 365 |
+
def get_cache_info(self) -> dict[str, Any]:
|
| 366 |
+
"""
|
| 367 |
+
Get information about the current tokenization cache.
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
Dictionary containing cache statistics:
|
| 371 |
+
- cache_size: Number of cached tokenizations
|
| 372 |
+
- cache_keys: List of cached text inputs (truncated for display)
|
| 373 |
+
|
| 374 |
+
Example:
|
| 375 |
+
>>> controller = BERTMetaController()
|
| 376 |
+
>>> features = MetaControllerFeatures(
|
| 377 |
+
... hrm_confidence=0.8,
|
| 378 |
+
... trm_confidence=0.6,
|
| 379 |
+
... mcts_value=0.75,
|
| 380 |
+
... consensus_score=0.7,
|
| 381 |
+
... last_agent='hrm',
|
| 382 |
+
... iteration=2,
|
| 383 |
+
... query_length=150,
|
| 384 |
+
... has_rag_context=True
|
| 385 |
+
... )
|
| 386 |
+
>>> _ = controller.predict(features)
|
| 387 |
+
>>> info = controller.get_cache_info()
|
| 388 |
+
>>> 'cache_size' in info
|
| 389 |
+
True
|
| 390 |
+
>>> info['cache_size'] >= 1
|
| 391 |
+
True
|
| 392 |
+
"""
|
| 393 |
+
# Truncate keys for display (first 50 chars)
|
| 394 |
+
truncated_keys = [key[:50] + "..." if len(key) > 50 else key for key in self._tokenization_cache]
|
| 395 |
+
|
| 396 |
+
return {
|
| 397 |
+
"cache_size": len(self._tokenization_cache),
|
| 398 |
+
"cache_keys": truncated_keys,
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
def get_trainable_parameters(self) -> dict[str, int]:
|
| 402 |
+
"""
|
| 403 |
+
Get the number of trainable and total parameters in the model.
|
| 404 |
+
|
| 405 |
+
This is particularly useful for LoRA models to see the efficiency
|
| 406 |
+
gains from using low-rank adaptation.
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
Dictionary containing:
|
| 410 |
+
- total_params: Total number of parameters in the model
|
| 411 |
+
- trainable_params: Number of trainable parameters
|
| 412 |
+
- trainable_percentage: Percentage of parameters that are trainable
|
| 413 |
+
|
| 414 |
+
Example:
|
| 415 |
+
>>> controller = BERTMetaController(use_lora=True)
|
| 416 |
+
>>> params = controller.get_trainable_parameters()
|
| 417 |
+
>>> params['trainable_percentage'] < 10.0 # LoRA trains <10% of params
|
| 418 |
+
True
|
| 419 |
+
"""
|
| 420 |
+
total_params = sum(p.numel() for p in self.model.parameters())
|
| 421 |
+
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
| 422 |
+
trainable_percentage = (trainable_params / total_params) * 100 if total_params > 0 else 0.0
|
| 423 |
+
|
| 424 |
+
return {
|
| 425 |
+
"total_params": total_params,
|
| 426 |
+
"trainable_params": trainable_params,
|
| 427 |
+
"trainable_percentage": round(trainable_percentage, 2),
|
| 428 |
+
}
|
src/agents/meta_controller/config_loader.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration loader for the Neural Meta-Controller framework.
|
| 3 |
+
|
| 4 |
+
This module provides dataclass-based configuration management for the Meta-Controller,
|
| 5 |
+
supporting both RNN and BERT-based neural network controllers with comprehensive
|
| 6 |
+
validation and serialization capabilities.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from dataclasses import asdict, dataclass, field
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import yaml
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class RNNConfig:
|
| 18 |
+
"""
|
| 19 |
+
Configuration for RNN-based Meta-Controller.
|
| 20 |
+
|
| 21 |
+
Attributes:
|
| 22 |
+
hidden_dim: Hidden dimension size for RNN layers. Default is 64.
|
| 23 |
+
num_layers: Number of RNN layers. Default is 1.
|
| 24 |
+
dropout: Dropout rate for regularization. Default is 0.1.
|
| 25 |
+
model_path: Optional path to a pre-trained model file. None for untrained model.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
hidden_dim: int = 64
|
| 29 |
+
num_layers: int = 1
|
| 30 |
+
dropout: float = 0.1
|
| 31 |
+
model_path: str | None = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class BERTConfig:
|
| 36 |
+
"""
|
| 37 |
+
Configuration for BERT-based Meta-Controller with LoRA fine-tuning.
|
| 38 |
+
|
| 39 |
+
Attributes:
|
| 40 |
+
model_name: Name of the pre-trained BERT model from HuggingFace.
|
| 41 |
+
Default is "prajjwal1/bert-mini" for lightweight deployment.
|
| 42 |
+
use_lora: Whether to use LoRA (Low-Rank Adaptation) for efficient fine-tuning.
|
| 43 |
+
Default is True.
|
| 44 |
+
lora_r: LoRA rank parameter. Controls the rank of the low-rank matrices.
|
| 45 |
+
Default is 4.
|
| 46 |
+
lora_alpha: LoRA alpha parameter. Scaling factor for LoRA weights.
|
| 47 |
+
Default is 16.
|
| 48 |
+
lora_dropout: Dropout rate for LoRA layers. Default is 0.1.
|
| 49 |
+
model_path: Optional path to a trained LoRA adapter. None for base model only.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
model_name: str = "prajjwal1/bert-mini"
|
| 53 |
+
use_lora: bool = True
|
| 54 |
+
lora_r: int = 4
|
| 55 |
+
lora_alpha: int = 16
|
| 56 |
+
lora_dropout: float = 0.1
|
| 57 |
+
model_path: str | None = None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
|
| 61 |
+
class InferenceConfig:
|
| 62 |
+
"""
|
| 63 |
+
Configuration for inference settings.
|
| 64 |
+
|
| 65 |
+
Attributes:
|
| 66 |
+
device: Device to use for inference ("cpu", "cuda", "cuda:0", etc.).
|
| 67 |
+
None for auto-detection based on available hardware.
|
| 68 |
+
seed: Random seed for reproducibility. Default is 42.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
device: str | None = None
|
| 72 |
+
seed: int = 42
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class MetaControllerConfig:
|
| 77 |
+
"""
|
| 78 |
+
Main configuration for the Neural Meta-Controller framework.
|
| 79 |
+
|
| 80 |
+
This configuration controls the behavior of the Meta-Controller, including
|
| 81 |
+
which type of neural network to use (RNN or BERT), fallback behavior,
|
| 82 |
+
and specific model parameters.
|
| 83 |
+
|
| 84 |
+
Attributes:
|
| 85 |
+
enabled: Whether the neural Meta-Controller is enabled. Default is False
|
| 86 |
+
for backward compatibility with rule-based systems.
|
| 87 |
+
type: Type of neural network controller ("rnn" or "bert"). Default is "rnn".
|
| 88 |
+
fallback_to_rule_based: Whether to fall back to rule-based selection on errors.
|
| 89 |
+
Default is True for robustness.
|
| 90 |
+
rnn: Configuration for RNN-based controller.
|
| 91 |
+
bert: Configuration for BERT-based controller.
|
| 92 |
+
inference: Configuration for inference settings.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
enabled: bool = False
|
| 96 |
+
type: str = "rnn" # "rnn" or "bert"
|
| 97 |
+
fallback_to_rule_based: bool = True
|
| 98 |
+
rnn: RNNConfig = field(default_factory=RNNConfig)
|
| 99 |
+
bert: BERTConfig = field(default_factory=BERTConfig)
|
| 100 |
+
inference: InferenceConfig = field(default_factory=InferenceConfig)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MetaControllerConfigLoader:
|
| 104 |
+
"""
|
| 105 |
+
Loader class for Meta-Controller configuration.
|
| 106 |
+
|
| 107 |
+
Provides methods for loading configuration from YAML files or dictionaries,
|
| 108 |
+
converting configuration to dictionaries, and validating configuration values.
|
| 109 |
+
|
| 110 |
+
Example:
|
| 111 |
+
>>> loader = MetaControllerConfigLoader()
|
| 112 |
+
>>> config = loader.load_from_yaml("config/meta_controller.yaml")
|
| 113 |
+
>>> print(config.type)
|
| 114 |
+
'rnn'
|
| 115 |
+
>>> config.validate()
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
@staticmethod
|
| 119 |
+
def load_from_yaml(path: str) -> MetaControllerConfig:
|
| 120 |
+
"""
|
| 121 |
+
Load Meta-Controller configuration from a YAML file.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
path: Path to the YAML configuration file.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
MetaControllerConfig: Loaded and parsed configuration object.
|
| 128 |
+
|
| 129 |
+
Raises:
|
| 130 |
+
FileNotFoundError: If the specified file does not exist.
|
| 131 |
+
yaml.YAMLError: If the file contains invalid YAML.
|
| 132 |
+
KeyError: If the 'meta_controller' key is missing from the file.
|
| 133 |
+
|
| 134 |
+
Example:
|
| 135 |
+
>>> config = MetaControllerConfigLoader.load_from_yaml("config/meta_controller.yaml")
|
| 136 |
+
>>> print(config.enabled)
|
| 137 |
+
False
|
| 138 |
+
"""
|
| 139 |
+
yaml_path = Path(path)
|
| 140 |
+
|
| 141 |
+
if not yaml_path.exists():
|
| 142 |
+
raise FileNotFoundError(f"Configuration file not found: {path}")
|
| 143 |
+
|
| 144 |
+
with open(yaml_path) as f:
|
| 145 |
+
raw_config = yaml.safe_load(f)
|
| 146 |
+
|
| 147 |
+
if "meta_controller" not in raw_config:
|
| 148 |
+
raise KeyError("Configuration file must contain 'meta_controller' key")
|
| 149 |
+
|
| 150 |
+
return MetaControllerConfigLoader.load_from_dict(raw_config["meta_controller"])
|
| 151 |
+
|
| 152 |
+
@staticmethod
|
| 153 |
+
def load_from_dict(config_dict: dict[str, Any]) -> MetaControllerConfig:
|
| 154 |
+
"""
|
| 155 |
+
Load Meta-Controller configuration from a dictionary.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
config_dict: Dictionary containing configuration values.
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
MetaControllerConfig: Parsed configuration object with defaults
|
| 162 |
+
applied for missing values.
|
| 163 |
+
|
| 164 |
+
Example:
|
| 165 |
+
>>> config_dict = {
|
| 166 |
+
... 'enabled': True,
|
| 167 |
+
... 'type': 'bert',
|
| 168 |
+
... 'bert': {'model_name': 'bert-base-uncased'}
|
| 169 |
+
... }
|
| 170 |
+
>>> config = MetaControllerConfigLoader.load_from_dict(config_dict)
|
| 171 |
+
>>> print(config.type)
|
| 172 |
+
'bert'
|
| 173 |
+
"""
|
| 174 |
+
# Parse nested configurations
|
| 175 |
+
rnn_config = RNNConfig(**config_dict.get("rnn", {}))
|
| 176 |
+
bert_config = BERTConfig(**config_dict.get("bert", {}))
|
| 177 |
+
inference_config = InferenceConfig(**config_dict.get("inference", {}))
|
| 178 |
+
|
| 179 |
+
# Create main config with nested configs
|
| 180 |
+
return MetaControllerConfig(
|
| 181 |
+
enabled=config_dict.get("enabled", False),
|
| 182 |
+
type=config_dict.get("type", "rnn"),
|
| 183 |
+
fallback_to_rule_based=config_dict.get("fallback_to_rule_based", True),
|
| 184 |
+
rnn=rnn_config,
|
| 185 |
+
bert=bert_config,
|
| 186 |
+
inference=inference_config,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
@staticmethod
|
| 190 |
+
def to_dict(config: MetaControllerConfig) -> dict[str, Any]:
|
| 191 |
+
"""
|
| 192 |
+
Convert a MetaControllerConfig object to a dictionary.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
config: MetaControllerConfig object to convert.
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
Dict[str, Any]: Dictionary representation of the configuration.
|
| 199 |
+
|
| 200 |
+
Example:
|
| 201 |
+
>>> config = MetaControllerConfig(enabled=True, type='bert')
|
| 202 |
+
>>> config_dict = MetaControllerConfigLoader.to_dict(config)
|
| 203 |
+
>>> print(config_dict['enabled'])
|
| 204 |
+
True
|
| 205 |
+
"""
|
| 206 |
+
return asdict(config)
|
| 207 |
+
|
| 208 |
+
@staticmethod
|
| 209 |
+
def validate(config: MetaControllerConfig) -> None:
|
| 210 |
+
"""
|
| 211 |
+
Validate the Meta-Controller configuration.
|
| 212 |
+
|
| 213 |
+
Checks that:
|
| 214 |
+
- The controller type is valid ("rnn" or "bert")
|
| 215 |
+
- Model paths exist if specified
|
| 216 |
+
- Numeric parameters are within valid ranges
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
config: MetaControllerConfig object to validate.
|
| 220 |
+
|
| 221 |
+
Raises:
|
| 222 |
+
ValueError: If the configuration contains invalid values.
|
| 223 |
+
FileNotFoundError: If specified model paths do not exist.
|
| 224 |
+
|
| 225 |
+
Example:
|
| 226 |
+
>>> config = MetaControllerConfig(type='invalid')
|
| 227 |
+
>>> MetaControllerConfigLoader.validate(config)
|
| 228 |
+
ValueError: Invalid controller type 'invalid'. Must be 'rnn' or 'bert'.
|
| 229 |
+
"""
|
| 230 |
+
# Validate controller type
|
| 231 |
+
valid_types = ["rnn", "bert"]
|
| 232 |
+
if config.type not in valid_types:
|
| 233 |
+
raise ValueError(f"Invalid controller type '{config.type}'. Must be one of: {valid_types}")
|
| 234 |
+
|
| 235 |
+
# Validate RNN config
|
| 236 |
+
if config.rnn.hidden_dim <= 0:
|
| 237 |
+
raise ValueError(f"RNN hidden_dim must be positive, got {config.rnn.hidden_dim}")
|
| 238 |
+
if config.rnn.num_layers <= 0:
|
| 239 |
+
raise ValueError(f"RNN num_layers must be positive, got {config.rnn.num_layers}")
|
| 240 |
+
if not 0.0 <= config.rnn.dropout <= 1.0:
|
| 241 |
+
raise ValueError(f"RNN dropout must be between 0 and 1, got {config.rnn.dropout}")
|
| 242 |
+
if config.rnn.model_path is not None:
|
| 243 |
+
rnn_path = Path(config.rnn.model_path)
|
| 244 |
+
if not rnn_path.exists():
|
| 245 |
+
raise FileNotFoundError(f"RNN model path does not exist: {config.rnn.model_path}")
|
| 246 |
+
|
| 247 |
+
# Validate BERT config
|
| 248 |
+
if config.bert.lora_r <= 0:
|
| 249 |
+
raise ValueError(f"BERT lora_r must be positive, got {config.bert.lora_r}")
|
| 250 |
+
if config.bert.lora_alpha <= 0:
|
| 251 |
+
raise ValueError(f"BERT lora_alpha must be positive, got {config.bert.lora_alpha}")
|
| 252 |
+
if not 0.0 <= config.bert.lora_dropout <= 1.0:
|
| 253 |
+
raise ValueError(f"BERT lora_dropout must be between 0 and 1, got {config.bert.lora_dropout}")
|
| 254 |
+
if config.bert.model_path is not None:
|
| 255 |
+
bert_path = Path(config.bert.model_path)
|
| 256 |
+
if not bert_path.exists():
|
| 257 |
+
raise FileNotFoundError(f"BERT model path does not exist: {config.bert.model_path}")
|
| 258 |
+
|
| 259 |
+
# Validate inference config
|
| 260 |
+
if config.inference.device is not None:
|
| 261 |
+
valid_devices = ["cpu", "cuda", "mps"]
|
| 262 |
+
# Check if device starts with a valid prefix (e.g., "cuda:0", "cuda:1")
|
| 263 |
+
device_base = config.inference.device.split(":")[0]
|
| 264 |
+
if device_base not in valid_devices:
|
| 265 |
+
raise ValueError(f"Invalid device '{config.inference.device}'. Must start with one of: {valid_devices}")
|
| 266 |
+
|
| 267 |
+
if not isinstance(config.inference.seed, int) or config.inference.seed < 0:
|
| 268 |
+
raise ValueError(f"Inference seed must be a non-negative integer, got {config.inference.seed}")
|
| 269 |
+
|
| 270 |
+
@staticmethod
|
| 271 |
+
def save_to_yaml(config: MetaControllerConfig, path: str) -> None:
|
| 272 |
+
"""
|
| 273 |
+
Save a MetaControllerConfig object to a YAML file.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
config: MetaControllerConfig object to save.
|
| 277 |
+
path: Path where the YAML file will be saved.
|
| 278 |
+
|
| 279 |
+
Example:
|
| 280 |
+
>>> config = MetaControllerConfig(enabled=True)
|
| 281 |
+
>>> MetaControllerConfigLoader.save_to_yaml(config, "my_config.yaml")
|
| 282 |
+
"""
|
| 283 |
+
yaml_path = Path(path)
|
| 284 |
+
yaml_path.parent.mkdir(parents=True, exist_ok=True)
|
| 285 |
+
|
| 286 |
+
config_dict = {"meta_controller": MetaControllerConfigLoader.to_dict(config)}
|
| 287 |
+
|
| 288 |
+
with open(yaml_path, "w") as f:
|
| 289 |
+
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
|
| 290 |
+
|
| 291 |
+
@staticmethod
|
| 292 |
+
def get_default_config() -> MetaControllerConfig:
|
| 293 |
+
"""
|
| 294 |
+
Get a default MetaControllerConfig with all default values.
|
| 295 |
+
|
| 296 |
+
Returns:
|
| 297 |
+
MetaControllerConfig: Configuration object with default values.
|
| 298 |
+
|
| 299 |
+
Example:
|
| 300 |
+
>>> config = MetaControllerConfigLoader.get_default_config()
|
| 301 |
+
>>> print(config.enabled)
|
| 302 |
+
False
|
| 303 |
+
"""
|
| 304 |
+
return MetaControllerConfig()
|
src/agents/meta_controller/rnn_controller.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RNN-based Meta-Controller for dynamic agent selection.
|
| 3 |
+
|
| 4 |
+
This module provides a GRU-based recurrent neural network meta-controller
|
| 5 |
+
that learns to select the optimal agent (HRM, TRM, or MCTS) based on
|
| 6 |
+
sequential patterns in the agent state features.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from src.agents.meta_controller.base import (
|
| 14 |
+
AbstractMetaController,
|
| 15 |
+
MetaControllerFeatures,
|
| 16 |
+
MetaControllerPrediction,
|
| 17 |
+
)
|
| 18 |
+
from src.agents.meta_controller.utils import features_to_tensor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class RNNMetaControllerModel(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
GRU-based neural network model for meta-controller predictions.
|
| 24 |
+
|
| 25 |
+
This model uses a Gated Recurrent Unit (GRU) to capture sequential
|
| 26 |
+
patterns in agent state features and predict which agent should be
|
| 27 |
+
selected next.
|
| 28 |
+
|
| 29 |
+
Architecture:
|
| 30 |
+
- GRU layer for sequence processing
|
| 31 |
+
- Dropout for regularization
|
| 32 |
+
- Linear layer for classification
|
| 33 |
+
|
| 34 |
+
Attributes:
|
| 35 |
+
gru: GRU recurrent layer for processing sequences.
|
| 36 |
+
dropout: Dropout layer for regularization.
|
| 37 |
+
fc: Fully connected output layer.
|
| 38 |
+
hidden_dim: Dimension of the hidden state.
|
| 39 |
+
num_layers: Number of GRU layers.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
input_dim: int = 10,
|
| 45 |
+
hidden_dim: int = 64,
|
| 46 |
+
num_layers: int = 1,
|
| 47 |
+
num_agents: int = 3,
|
| 48 |
+
dropout: float = 0.1,
|
| 49 |
+
) -> None:
|
| 50 |
+
"""
|
| 51 |
+
Initialize the RNN meta-controller model.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
input_dim: Dimension of input features. Defaults to 10.
|
| 55 |
+
hidden_dim: Dimension of GRU hidden state. Defaults to 64.
|
| 56 |
+
num_layers: Number of stacked GRU layers. Defaults to 1.
|
| 57 |
+
num_agents: Number of agents to choose from. Defaults to 3.
|
| 58 |
+
dropout: Dropout probability for regularization. Defaults to 0.1.
|
| 59 |
+
"""
|
| 60 |
+
super().__init__()
|
| 61 |
+
|
| 62 |
+
self.hidden_dim = hidden_dim
|
| 63 |
+
self.num_layers = num_layers
|
| 64 |
+
|
| 65 |
+
# GRU layer for sequence processing
|
| 66 |
+
self.gru = nn.GRU(
|
| 67 |
+
input_size=input_dim,
|
| 68 |
+
hidden_size=hidden_dim,
|
| 69 |
+
num_layers=num_layers,
|
| 70 |
+
batch_first=True,
|
| 71 |
+
dropout=dropout if num_layers > 1 else 0.0,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Dropout for regularization
|
| 75 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 76 |
+
|
| 77 |
+
# Linear output layer for classification
|
| 78 |
+
self.fc = nn.Linear(hidden_dim, num_agents)
|
| 79 |
+
|
| 80 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
"""
|
| 82 |
+
Forward pass through the model.
|
| 83 |
+
|
| 84 |
+
Processes input features through GRU and produces agent selection logits.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
x: Input tensor of shape (batch_size, features) or
|
| 88 |
+
(batch_size, seq_len, features).
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Logits tensor of shape (batch_size, num_agents).
|
| 92 |
+
Note: Returns raw logits, NOT softmax probabilities.
|
| 93 |
+
|
| 94 |
+
Example:
|
| 95 |
+
>>> model = RNNMetaControllerModel()
|
| 96 |
+
>>> x = torch.randn(4, 10) # batch of 4, 10 features
|
| 97 |
+
>>> logits = model(x)
|
| 98 |
+
>>> logits.shape
|
| 99 |
+
torch.Size([4, 3])
|
| 100 |
+
"""
|
| 101 |
+
# Handle 2D input by adding sequence dimension
|
| 102 |
+
if x.dim() == 2:
|
| 103 |
+
# Shape: (batch_size, features) -> (batch_size, 1, features)
|
| 104 |
+
x = x.unsqueeze(1)
|
| 105 |
+
|
| 106 |
+
# Pass through GRU
|
| 107 |
+
# output shape: (batch_size, seq_len, hidden_dim)
|
| 108 |
+
# hidden shape: (num_layers, batch_size, hidden_dim)
|
| 109 |
+
output, hidden = self.gru(x)
|
| 110 |
+
|
| 111 |
+
# Take the final hidden state from the last layer
|
| 112 |
+
# Shape: (batch_size, hidden_dim)
|
| 113 |
+
final_hidden = hidden[-1] if self.num_layers > 1 else hidden.squeeze(0)
|
| 114 |
+
|
| 115 |
+
# Apply dropout
|
| 116 |
+
dropped = self.dropout(final_hidden)
|
| 117 |
+
|
| 118 |
+
# Apply linear layer to get logits
|
| 119 |
+
logits = self.fc(dropped)
|
| 120 |
+
|
| 121 |
+
return logits
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class RNNMetaController(AbstractMetaController):
|
| 125 |
+
"""
|
| 126 |
+
RNN-based meta-controller using GRU for agent selection.
|
| 127 |
+
|
| 128 |
+
This controller uses a recurrent neural network to learn patterns in
|
| 129 |
+
agent state sequences and predict the optimal agent for the current
|
| 130 |
+
situation. It supports both CPU and GPU execution.
|
| 131 |
+
|
| 132 |
+
Attributes:
|
| 133 |
+
device: PyTorch device (CPU or CUDA) for tensor operations.
|
| 134 |
+
hidden_dim: Dimension of GRU hidden state.
|
| 135 |
+
num_layers: Number of GRU layers.
|
| 136 |
+
dropout: Dropout probability.
|
| 137 |
+
model: The underlying RNNMetaControllerModel.
|
| 138 |
+
hidden_state: Optional hidden state for sequence tracking.
|
| 139 |
+
|
| 140 |
+
Example:
|
| 141 |
+
>>> controller = RNNMetaController(name="RNNController", seed=42)
|
| 142 |
+
>>> features = MetaControllerFeatures(
|
| 143 |
+
... hrm_confidence=0.8,
|
| 144 |
+
... trm_confidence=0.6,
|
| 145 |
+
... mcts_value=0.75,
|
| 146 |
+
... consensus_score=0.7,
|
| 147 |
+
... last_agent='hrm',
|
| 148 |
+
... iteration=2,
|
| 149 |
+
... query_length=150,
|
| 150 |
+
... has_rag_context=True
|
| 151 |
+
... )
|
| 152 |
+
>>> prediction = controller.predict(features)
|
| 153 |
+
>>> prediction.agent in ['hrm', 'trm', 'mcts']
|
| 154 |
+
True
|
| 155 |
+
>>> 0.0 <= prediction.confidence <= 1.0
|
| 156 |
+
True
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
def __init__(
|
| 160 |
+
self,
|
| 161 |
+
name: str = "RNNMetaController",
|
| 162 |
+
seed: int = 42,
|
| 163 |
+
hidden_dim: int = 64,
|
| 164 |
+
num_layers: int = 1,
|
| 165 |
+
dropout: float = 0.1,
|
| 166 |
+
device: str | None = None,
|
| 167 |
+
) -> None:
|
| 168 |
+
"""
|
| 169 |
+
Initialize the RNN meta-controller.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
name: Name identifier for this controller. Defaults to "RNNMetaController".
|
| 173 |
+
seed: Random seed for reproducibility. Defaults to 42.
|
| 174 |
+
hidden_dim: Dimension of GRU hidden state. Defaults to 64.
|
| 175 |
+
num_layers: Number of GRU layers. Defaults to 1.
|
| 176 |
+
dropout: Dropout probability. Defaults to 0.1.
|
| 177 |
+
device: Device to run model on ('cpu', 'cuda', 'mps', etc.).
|
| 178 |
+
If None, auto-detects best available device.
|
| 179 |
+
"""
|
| 180 |
+
super().__init__(name=name, seed=seed)
|
| 181 |
+
|
| 182 |
+
# Set random seed for reproducibility
|
| 183 |
+
torch.manual_seed(seed)
|
| 184 |
+
|
| 185 |
+
# Auto-detect device if not specified
|
| 186 |
+
if device is None:
|
| 187 |
+
if torch.cuda.is_available():
|
| 188 |
+
self.device = torch.device("cuda")
|
| 189 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 190 |
+
self.device = torch.device("mps")
|
| 191 |
+
else:
|
| 192 |
+
self.device = torch.device("cpu")
|
| 193 |
+
else:
|
| 194 |
+
self.device = torch.device(device)
|
| 195 |
+
|
| 196 |
+
# Store configuration
|
| 197 |
+
self.hidden_dim = hidden_dim
|
| 198 |
+
self.num_layers = num_layers
|
| 199 |
+
self.dropout = dropout
|
| 200 |
+
|
| 201 |
+
# Initialize model
|
| 202 |
+
self.model = RNNMetaControllerModel(
|
| 203 |
+
input_dim=10, # Fixed based on features_to_tensor output
|
| 204 |
+
hidden_dim=hidden_dim,
|
| 205 |
+
num_layers=num_layers,
|
| 206 |
+
num_agents=len(self.AGENT_NAMES),
|
| 207 |
+
dropout=dropout,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# Move model to device
|
| 211 |
+
self.model = self.model.to(self.device)
|
| 212 |
+
|
| 213 |
+
# Set model to evaluation mode
|
| 214 |
+
self.model.eval()
|
| 215 |
+
|
| 216 |
+
# Initialize hidden state for sequence tracking
|
| 217 |
+
self.hidden_state: torch.Tensor | None = None
|
| 218 |
+
|
| 219 |
+
def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction:
|
| 220 |
+
"""
|
| 221 |
+
Predict which agent should handle the current query.
|
| 222 |
+
|
| 223 |
+
Converts features to tensor format, runs through the GRU model,
|
| 224 |
+
and returns a prediction with confidence scores.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
features: Features extracted from the current agent state.
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
Prediction containing the selected agent, confidence score,
|
| 231 |
+
and probability distribution over all agents.
|
| 232 |
+
|
| 233 |
+
Example:
|
| 234 |
+
>>> controller = RNNMetaController()
|
| 235 |
+
>>> features = MetaControllerFeatures(
|
| 236 |
+
... hrm_confidence=0.9,
|
| 237 |
+
... trm_confidence=0.3,
|
| 238 |
+
... mcts_value=0.5,
|
| 239 |
+
... consensus_score=0.8,
|
| 240 |
+
... last_agent='none',
|
| 241 |
+
... iteration=0,
|
| 242 |
+
... query_length=100,
|
| 243 |
+
... has_rag_context=False
|
| 244 |
+
... )
|
| 245 |
+
>>> pred = controller.predict(features)
|
| 246 |
+
>>> isinstance(pred.agent, str)
|
| 247 |
+
True
|
| 248 |
+
>>> isinstance(pred.confidence, float)
|
| 249 |
+
True
|
| 250 |
+
>>> len(pred.probabilities) == 3
|
| 251 |
+
True
|
| 252 |
+
"""
|
| 253 |
+
# Convert features to tensor
|
| 254 |
+
feature_tensor = features_to_tensor(features)
|
| 255 |
+
|
| 256 |
+
# Add batch dimension: (10,) -> (1, 10)
|
| 257 |
+
feature_tensor = feature_tensor.unsqueeze(0)
|
| 258 |
+
|
| 259 |
+
# Move to device
|
| 260 |
+
feature_tensor = feature_tensor.to(self.device)
|
| 261 |
+
|
| 262 |
+
# Perform inference without gradient tracking
|
| 263 |
+
with torch.no_grad():
|
| 264 |
+
# Get logits from model
|
| 265 |
+
logits = self.model(feature_tensor)
|
| 266 |
+
|
| 267 |
+
# Apply softmax to get probabilities
|
| 268 |
+
probabilities = F.softmax(logits, dim=-1)
|
| 269 |
+
|
| 270 |
+
# Get predicted agent index (argmax)
|
| 271 |
+
predicted_idx = torch.argmax(probabilities, dim=-1).item()
|
| 272 |
+
|
| 273 |
+
# Extract confidence for selected agent
|
| 274 |
+
confidence = probabilities[0, predicted_idx].item()
|
| 275 |
+
|
| 276 |
+
# Create probability dictionary
|
| 277 |
+
prob_dict: dict[str, float] = {}
|
| 278 |
+
for i, agent_name in enumerate(self.AGENT_NAMES):
|
| 279 |
+
prob_dict[agent_name] = probabilities[0, i].item()
|
| 280 |
+
|
| 281 |
+
# Get agent name
|
| 282 |
+
selected_agent = self.AGENT_NAMES[predicted_idx]
|
| 283 |
+
|
| 284 |
+
return MetaControllerPrediction(
|
| 285 |
+
agent=selected_agent,
|
| 286 |
+
confidence=float(confidence),
|
| 287 |
+
probabilities=prob_dict,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
def load_model(self, path: str) -> None:
|
| 291 |
+
"""
|
| 292 |
+
Load a trained model from disk.
|
| 293 |
+
|
| 294 |
+
Loads the model state dictionary from the specified path and
|
| 295 |
+
sets the model to evaluation mode.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
path: Path to the saved model file (.pt or .pth).
|
| 299 |
+
|
| 300 |
+
Raises:
|
| 301 |
+
FileNotFoundError: If the model file does not exist.
|
| 302 |
+
RuntimeError: If the state dict is incompatible with the model.
|
| 303 |
+
|
| 304 |
+
Example:
|
| 305 |
+
>>> controller = RNNMetaController()
|
| 306 |
+
>>> controller.load_model("/path/to/model.pt")
|
| 307 |
+
"""
|
| 308 |
+
# Load state dict with appropriate device mapping
|
| 309 |
+
state_dict = torch.load(path, map_location=self.device, weights_only=True)
|
| 310 |
+
|
| 311 |
+
# Load into model
|
| 312 |
+
self.model.load_state_dict(state_dict)
|
| 313 |
+
|
| 314 |
+
# Ensure model is in evaluation mode
|
| 315 |
+
self.model.eval()
|
| 316 |
+
|
| 317 |
+
def save_model(self, path: str) -> None:
|
| 318 |
+
"""
|
| 319 |
+
Save the current model to disk.
|
| 320 |
+
|
| 321 |
+
Saves the model state dictionary to the specified path.
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
path: Path where the model should be saved (.pt or .pth).
|
| 325 |
+
|
| 326 |
+
Example:
|
| 327 |
+
>>> controller = RNNMetaController()
|
| 328 |
+
>>> controller.save_model("/path/to/model.pt")
|
| 329 |
+
"""
|
| 330 |
+
torch.save(self.model.state_dict(), path)
|
| 331 |
+
|
| 332 |
+
def reset_hidden_state(self) -> None:
|
| 333 |
+
"""
|
| 334 |
+
Reset the hidden state for sequence tracking.
|
| 335 |
+
|
| 336 |
+
This method clears any accumulated hidden state, useful when
|
| 337 |
+
starting a new conversation or resetting the controller state.
|
| 338 |
+
|
| 339 |
+
Example:
|
| 340 |
+
>>> controller = RNNMetaController()
|
| 341 |
+
>>> controller.reset_hidden_state()
|
| 342 |
+
>>> controller.hidden_state is None
|
| 343 |
+
True
|
| 344 |
+
"""
|
| 345 |
+
self.hidden_state = None
|
src/agents/meta_controller/utils.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for Neural Meta-Controller feature processing.
|
| 3 |
+
|
| 4 |
+
Provides functions for normalizing, encoding, and converting features
|
| 5 |
+
into formats suitable for different neural network architectures.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from src.agents.meta_controller.base import MetaControllerFeatures
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def normalize_features(features: MetaControllerFeatures) -> list[float]:
|
| 14 |
+
"""
|
| 15 |
+
Normalize meta-controller features to a 10-dimensional vector in range [0, 1].
|
| 16 |
+
|
| 17 |
+
The normalization strategy:
|
| 18 |
+
- Confidence scores (hrm, trm, mcts_value, consensus): Already 0-1, clipped
|
| 19 |
+
- last_agent: Encoded as 3 one-hot values (hrm=0, trm=1, mcts=2)
|
| 20 |
+
- iteration: Normalized to 0-1 assuming max 20 iterations
|
| 21 |
+
- query_length: Normalized to 0-1 assuming max 10000 characters
|
| 22 |
+
- has_rag_context: Binary 0 or 1
|
| 23 |
+
|
| 24 |
+
Output vector structure (10 dimensions):
|
| 25 |
+
[hrm_conf, trm_conf, mcts_value, consensus, last_hrm, last_trm, last_mcts,
|
| 26 |
+
iteration_norm, query_length_norm, has_rag]
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
features: MetaControllerFeatures instance to normalize.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
List of 10 floats, each normalized to range [0, 1].
|
| 33 |
+
|
| 34 |
+
Example:
|
| 35 |
+
>>> features = MetaControllerFeatures(
|
| 36 |
+
... hrm_confidence=0.8,
|
| 37 |
+
... trm_confidence=0.6,
|
| 38 |
+
... mcts_value=0.75,
|
| 39 |
+
... consensus_score=0.7,
|
| 40 |
+
... last_agent='hrm',
|
| 41 |
+
... iteration=2,
|
| 42 |
+
... query_length=150,
|
| 43 |
+
... has_rag_context=True
|
| 44 |
+
... )
|
| 45 |
+
>>> normalized = normalize_features(features)
|
| 46 |
+
>>> len(normalized)
|
| 47 |
+
10
|
| 48 |
+
>>> all(0.0 <= v <= 1.0 for v in normalized)
|
| 49 |
+
True
|
| 50 |
+
"""
|
| 51 |
+
# Clip confidence scores to [0, 1]
|
| 52 |
+
hrm_conf = max(0.0, min(1.0, features.hrm_confidence))
|
| 53 |
+
trm_conf = max(0.0, min(1.0, features.trm_confidence))
|
| 54 |
+
mcts_val = max(0.0, min(1.0, features.mcts_value))
|
| 55 |
+
consensus = max(0.0, min(1.0, features.consensus_score))
|
| 56 |
+
|
| 57 |
+
# One-hot encode last_agent (3 dimensions)
|
| 58 |
+
last_agent_onehot = one_hot_encode_agent(features.last_agent)
|
| 59 |
+
|
| 60 |
+
# Normalize iteration (assuming max 20 iterations)
|
| 61 |
+
max_iterations = 20
|
| 62 |
+
iteration_norm = max(0.0, min(1.0, features.iteration / max_iterations))
|
| 63 |
+
|
| 64 |
+
# Normalize query length (assuming max 10000 characters)
|
| 65 |
+
max_query_length = 10000
|
| 66 |
+
query_length_norm = max(0.0, min(1.0, features.query_length / max_query_length))
|
| 67 |
+
|
| 68 |
+
# Binary for has_rag_context
|
| 69 |
+
has_rag = 1.0 if features.has_rag_context else 0.0
|
| 70 |
+
|
| 71 |
+
# Combine into 10-dimensional vector
|
| 72 |
+
return [
|
| 73 |
+
hrm_conf,
|
| 74 |
+
trm_conf,
|
| 75 |
+
mcts_val,
|
| 76 |
+
consensus,
|
| 77 |
+
last_agent_onehot[0], # hrm
|
| 78 |
+
last_agent_onehot[1], # trm
|
| 79 |
+
last_agent_onehot[2], # mcts
|
| 80 |
+
iteration_norm,
|
| 81 |
+
query_length_norm,
|
| 82 |
+
has_rag,
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def one_hot_encode_agent(agent: str) -> list[float]:
|
| 87 |
+
"""
|
| 88 |
+
One-hot encode an agent name into a 3-dimensional vector.
|
| 89 |
+
|
| 90 |
+
Encoding:
|
| 91 |
+
- 'hrm' -> [1.0, 0.0, 0.0]
|
| 92 |
+
- 'trm' -> [0.0, 1.0, 0.0]
|
| 93 |
+
- 'mcts' -> [0.0, 0.0, 1.0]
|
| 94 |
+
- 'none' or other -> [0.0, 0.0, 0.0]
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
agent: Agent name string ('hrm', 'trm', 'mcts', or 'none').
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
List of 3 floats representing the one-hot encoding.
|
| 101 |
+
|
| 102 |
+
Example:
|
| 103 |
+
>>> one_hot_encode_agent('hrm')
|
| 104 |
+
[1.0, 0.0, 0.0]
|
| 105 |
+
>>> one_hot_encode_agent('trm')
|
| 106 |
+
[0.0, 1.0, 0.0]
|
| 107 |
+
>>> one_hot_encode_agent('mcts')
|
| 108 |
+
[0.0, 0.0, 1.0]
|
| 109 |
+
>>> one_hot_encode_agent('none')
|
| 110 |
+
[0.0, 0.0, 0.0]
|
| 111 |
+
"""
|
| 112 |
+
agent_lower = agent.lower()
|
| 113 |
+
|
| 114 |
+
if agent_lower == "hrm": # noqa: SIM116
|
| 115 |
+
return [1.0, 0.0, 0.0]
|
| 116 |
+
elif agent_lower == "trm":
|
| 117 |
+
return [0.0, 1.0, 0.0]
|
| 118 |
+
elif agent_lower == "mcts":
|
| 119 |
+
return [0.0, 0.0, 1.0]
|
| 120 |
+
else:
|
| 121 |
+
# 'none' or unknown agent
|
| 122 |
+
return [0.0, 0.0, 0.0]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def features_to_tensor(features: MetaControllerFeatures) -> torch.Tensor:
|
| 126 |
+
"""
|
| 127 |
+
Convert meta-controller features to a PyTorch tensor.
|
| 128 |
+
|
| 129 |
+
Uses normalize_features internally to create a normalized 10-dimensional
|
| 130 |
+
tensor suitable for neural network input.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
features: MetaControllerFeatures instance to convert.
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
PyTorch tensor of shape (10,) with float32 dtype.
|
| 137 |
+
|
| 138 |
+
Example:
|
| 139 |
+
>>> features = MetaControllerFeatures(
|
| 140 |
+
... hrm_confidence=0.8,
|
| 141 |
+
... trm_confidence=0.6,
|
| 142 |
+
... mcts_value=0.75,
|
| 143 |
+
... consensus_score=0.7,
|
| 144 |
+
... last_agent='hrm',
|
| 145 |
+
... iteration=2,
|
| 146 |
+
... query_length=150,
|
| 147 |
+
... has_rag_context=True
|
| 148 |
+
... )
|
| 149 |
+
>>> tensor = features_to_tensor(features)
|
| 150 |
+
>>> tensor.shape
|
| 151 |
+
torch.Size([10])
|
| 152 |
+
>>> tensor.dtype
|
| 153 |
+
torch.float32
|
| 154 |
+
"""
|
| 155 |
+
normalized = normalize_features(features)
|
| 156 |
+
return torch.tensor(normalized, dtype=torch.float32)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def features_to_text(features: MetaControllerFeatures) -> str:
|
| 160 |
+
"""
|
| 161 |
+
Convert meta-controller features to structured text format.
|
| 162 |
+
|
| 163 |
+
Creates a human-readable text representation suitable for text-based
|
| 164 |
+
models like BERT or other language models.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
features: MetaControllerFeatures instance to convert.
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
Structured text string describing the features.
|
| 171 |
+
|
| 172 |
+
Example:
|
| 173 |
+
>>> features = MetaControllerFeatures(
|
| 174 |
+
... hrm_confidence=0.8,
|
| 175 |
+
... trm_confidence=0.6,
|
| 176 |
+
... mcts_value=0.75,
|
| 177 |
+
... consensus_score=0.7,
|
| 178 |
+
... last_agent='hrm',
|
| 179 |
+
... iteration=2,
|
| 180 |
+
... query_length=150,
|
| 181 |
+
... has_rag_context=True
|
| 182 |
+
... )
|
| 183 |
+
>>> text = features_to_text(features)
|
| 184 |
+
>>> 'HRM confidence: 0.800' in text
|
| 185 |
+
True
|
| 186 |
+
"""
|
| 187 |
+
rag_status = "available" if features.has_rag_context else "not available"
|
| 188 |
+
|
| 189 |
+
text = (
|
| 190 |
+
f"Agent State Features:\n"
|
| 191 |
+
f"HRM confidence: {features.hrm_confidence:.3f}\n"
|
| 192 |
+
f"TRM confidence: {features.trm_confidence:.3f}\n"
|
| 193 |
+
f"MCTS value: {features.mcts_value:.3f}\n"
|
| 194 |
+
f"Consensus score: {features.consensus_score:.3f}\n"
|
| 195 |
+
f"Last agent used: {features.last_agent}\n"
|
| 196 |
+
f"Current iteration: {features.iteration}\n"
|
| 197 |
+
f"Query length: {features.query_length} characters\n"
|
| 198 |
+
f"RAG context: {rag_status}"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
return text
|
src/agents/trm_agent.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tiny Recursive Model (TRM) Agent.
|
| 3 |
+
|
| 4 |
+
Implements recursive refinement with:
|
| 5 |
+
- Deep supervision at all recursion levels
|
| 6 |
+
- Convergence detection
|
| 7 |
+
- Memory-efficient recursion
|
| 8 |
+
- Iterative improvement mechanism
|
| 9 |
+
|
| 10 |
+
Based on principles from:
|
| 11 |
+
- "Recursive Refinement Networks"
|
| 12 |
+
- "Deep Supervision for Neural Networks"
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
from ..training.system_config import TRMConfig
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class TRMOutput:
|
| 27 |
+
"""Output from TRM recursive processing."""
|
| 28 |
+
|
| 29 |
+
final_prediction: torch.Tensor # Final refined output
|
| 30 |
+
intermediate_predictions: list[torch.Tensor] # Predictions at each recursion
|
| 31 |
+
recursion_depth: int # Actual depth used
|
| 32 |
+
converged: bool # Whether convergence was achieved
|
| 33 |
+
convergence_step: int # Step at which convergence occurred
|
| 34 |
+
residual_norms: list[float] # L2 norms of residuals at each step
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class RecursiveBlock(nn.Module):
|
| 38 |
+
"""
|
| 39 |
+
Core recursive processing block.
|
| 40 |
+
|
| 41 |
+
Applies the same transformation repeatedly, with residual connections.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, config: TRMConfig):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.config = config
|
| 47 |
+
|
| 48 |
+
# Main processing pathway
|
| 49 |
+
self.transform = nn.Sequential(
|
| 50 |
+
nn.Linear(config.latent_dim, config.hidden_dim),
|
| 51 |
+
nn.LayerNorm(config.hidden_dim) if config.use_layer_norm else nn.Identity(),
|
| 52 |
+
nn.GELU(),
|
| 53 |
+
nn.Dropout(config.dropout),
|
| 54 |
+
nn.Linear(config.hidden_dim, config.latent_dim),
|
| 55 |
+
nn.LayerNorm(config.latent_dim) if config.use_layer_norm else nn.Identity(),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Residual scaling (learned)
|
| 59 |
+
self.residual_scale = nn.Parameter(torch.ones(1))
|
| 60 |
+
|
| 61 |
+
def forward(self, x: torch.Tensor, iteration: int = 0) -> torch.Tensor: # noqa: ARG002
|
| 62 |
+
"""
|
| 63 |
+
Apply recursive transformation.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
x: Input tensor [batch, ..., latent_dim]
|
| 67 |
+
iteration: Current recursion iteration (reserved for future iteration-dependent behavior)
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Refined tensor [batch, ..., latent_dim]
|
| 71 |
+
"""
|
| 72 |
+
# Residual connection with learned scaling
|
| 73 |
+
residual = self.transform(x)
|
| 74 |
+
return x + self.residual_scale * residual
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class DeepSupervisionHead(nn.Module):
|
| 78 |
+
"""
|
| 79 |
+
Supervision head for intermediate predictions.
|
| 80 |
+
|
| 81 |
+
Enables training signal at each recursion level.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def __init__(self, latent_dim: int, output_dim: int):
|
| 85 |
+
super().__init__()
|
| 86 |
+
self.head = nn.Sequential(
|
| 87 |
+
nn.Linear(latent_dim, latent_dim // 2),
|
| 88 |
+
nn.ReLU(),
|
| 89 |
+
nn.Linear(latent_dim // 2, output_dim),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
"""Generate prediction from latent state."""
|
| 94 |
+
return self.head(x)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class TRMAgent(nn.Module):
|
| 98 |
+
"""
|
| 99 |
+
Tiny Recursive Model for iterative refinement.
|
| 100 |
+
|
| 101 |
+
Features:
|
| 102 |
+
- Shared weights across recursions (parameter efficiency)
|
| 103 |
+
- Deep supervision at all levels
|
| 104 |
+
- Automatic convergence detection
|
| 105 |
+
- Residual connections for stable gradients
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(self, config: TRMConfig, output_dim: int | None = None, device: str = "cpu"):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.config = config
|
| 111 |
+
self.device = device
|
| 112 |
+
self.output_dim = output_dim or config.latent_dim
|
| 113 |
+
|
| 114 |
+
# Initial encoding
|
| 115 |
+
self.encoder = nn.Sequential(
|
| 116 |
+
nn.Linear(config.latent_dim, config.hidden_dim),
|
| 117 |
+
nn.LayerNorm(config.hidden_dim) if config.use_layer_norm else nn.Identity(),
|
| 118 |
+
nn.GELU(),
|
| 119 |
+
nn.Linear(config.hidden_dim, config.latent_dim),
|
| 120 |
+
nn.LayerNorm(config.latent_dim) if config.use_layer_norm else nn.Identity(),
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Shared recursive block
|
| 124 |
+
self.recursive_block = RecursiveBlock(config)
|
| 125 |
+
|
| 126 |
+
# Deep supervision heads (one per recursion level)
|
| 127 |
+
if config.deep_supervision:
|
| 128 |
+
self.supervision_heads = nn.ModuleList(
|
| 129 |
+
[DeepSupervisionHead(config.latent_dim, self.output_dim) for _ in range(config.num_recursions)]
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
# Single output head
|
| 133 |
+
self.output_head = DeepSupervisionHead(config.latent_dim, self.output_dim)
|
| 134 |
+
|
| 135 |
+
self.to(device)
|
| 136 |
+
|
| 137 |
+
def forward(
|
| 138 |
+
self,
|
| 139 |
+
x: torch.Tensor,
|
| 140 |
+
num_recursions: int | None = None,
|
| 141 |
+
check_convergence: bool = True,
|
| 142 |
+
) -> TRMOutput:
|
| 143 |
+
"""
|
| 144 |
+
Process input through recursive refinement.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
x: Input tensor [batch, ..., latent_dim]
|
| 148 |
+
num_recursions: Number of recursions (defaults to config)
|
| 149 |
+
check_convergence: Whether to check for early convergence
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
TRMOutput with final and intermediate predictions
|
| 153 |
+
"""
|
| 154 |
+
num_recursions = num_recursions or self.config.num_recursions
|
| 155 |
+
|
| 156 |
+
# Initial encoding
|
| 157 |
+
latent = self.encoder(x)
|
| 158 |
+
previous_latent = latent.clone()
|
| 159 |
+
|
| 160 |
+
# Tracking
|
| 161 |
+
intermediate_predictions = []
|
| 162 |
+
residual_norms = []
|
| 163 |
+
converged = False
|
| 164 |
+
convergence_step = num_recursions
|
| 165 |
+
|
| 166 |
+
        # Recursive refinement
        for i in range(num_recursions):
            # Apply recursive transformation
            latent = self.recursive_block(latent, iteration=i)

            # Generate intermediate prediction
            if self.config.deep_supervision and i < len(self.supervision_heads):
                pred = self.supervision_heads[i](latent)
            else:
                pred = self.output_head(latent)

            intermediate_predictions.append(pred)

            # Check convergence
            if check_convergence and i >= self.config.min_recursions:
                residual = latent - previous_latent
                residual_norm = torch.norm(residual, p=2, dim=-1).mean().item()
                residual_norms.append(residual_norm)

                if residual_norm < self.config.convergence_threshold:
                    converged = True
                    convergence_step = i + 1
                    break

            previous_latent = latent.clone()

        # Final prediction
        final_pred = intermediate_predictions[-1]

        return TRMOutput(
            final_prediction=final_pred,
            intermediate_predictions=intermediate_predictions,
            recursion_depth=len(intermediate_predictions),
            converged=converged,
            convergence_step=convergence_step,
            residual_norms=residual_norms,
        )

    async def refine_solution(
        self,
        initial_prediction: torch.Tensor,
        num_recursions: int | None = None,
        convergence_threshold: float | None = None,
    ) -> tuple[torch.Tensor, dict]:
        """
        Refine an initial prediction through recursive processing.

        Args:
            initial_prediction: Initial solution [batch, ..., latent_dim]
            num_recursions: Maximum recursions (optional)
            convergence_threshold: Convergence threshold (optional)

        Returns:
            refined_solution: Final refined prediction
            info: Dictionary with refinement metadata
        """
        # Temporarily override convergence threshold if provided
        original_threshold = self.config.convergence_threshold
        if convergence_threshold is not None:
            self.config.convergence_threshold = convergence_threshold

        # Process
        output = self.forward(
            initial_prediction,
            num_recursions=num_recursions,
            check_convergence=True,
        )

        # Restore original threshold
        self.config.convergence_threshold = original_threshold

        info = {
            "converged": output.converged,
            "convergence_step": output.convergence_step,
            "total_recursions": output.recursion_depth,
            "final_residual": output.residual_norms[-1] if output.residual_norms else None,
            "refinement_path": output.residual_norms,
        }

        return output.final_prediction, info

    def get_parameter_count(self) -> int:
        """Return total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class TRMLoss(nn.Module):
    """
    Deep supervision loss for TRM.

    Applies weighted supervision at all recursion levels,
    with exponential decay for deeper levels.
    """

    def __init__(
        self,
        task_loss_fn: nn.Module,
        supervision_weight_decay: float = 0.5,
        final_weight: float = 1.0,
    ):
        """
        Initialize TRM loss.

        Args:
            task_loss_fn: Base loss function (e.g., MSE, CrossEntropy)
            supervision_weight_decay: Decay factor for intermediate losses
            final_weight: Weight for final prediction loss
        """
        super().__init__()
        self.task_loss_fn = task_loss_fn
        self.supervision_weight_decay = supervision_weight_decay
        self.final_weight = final_weight

    def forward(self, trm_output: TRMOutput, targets: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """
        Compute deep supervision loss.

        Args:
            trm_output: Output from TRM forward pass
            targets: Ground truth targets

        Returns:
            total_loss: Combined loss
            loss_dict: Dictionary of loss components
        """
        # Final prediction loss (highest weight)
        final_loss = self.task_loss_fn(trm_output.final_prediction, targets)
        total_loss = self.final_weight * final_loss

        # Intermediate supervision losses
        intermediate_losses = []
        num_intermediate = len(trm_output.intermediate_predictions) - 1

        for i, pred in enumerate(trm_output.intermediate_predictions[:-1]):
            # Exponential decay: earlier predictions get lower weight
            weight = self.supervision_weight_decay ** (num_intermediate - i)
            loss = self.task_loss_fn(pred, targets)
            intermediate_losses.append(loss.item())
            total_loss = total_loss + weight * loss

        loss_dict = {
            "total": total_loss.item(),
            "final": final_loss.item(),
            "intermediate_mean": (sum(intermediate_losses) / len(intermediate_losses) if intermediate_losses else 0.0),
            "recursion_depth": trm_output.recursion_depth,
            "converged": trm_output.converged,
            "convergence_step": trm_output.convergence_step,
        }

        return total_loss, loss_dict


def create_trm_agent(config: TRMConfig, output_dim: int | None = None, device: str = "cpu") -> TRMAgent:
    """
    Factory function to create and initialize TRM agent.

    Args:
        config: TRM configuration
        output_dim: Output dimension (defaults to latent_dim)
        device: Device to place model on

    Returns:
        Initialized TRMAgent
    """
    agent = TRMAgent(config, output_dim, device)

    # Initialize weights with Xavier/He initialization
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    agent.apply(init_weights)

    return agent


# Utility functions for integration
class TRMRefinementWrapper:
    """
    Wrapper for using TRM as a refinement step in pipelines.

    Provides a clean interface for integrating TRM into larger systems.
    """

    def __init__(self, trm_agent: TRMAgent, device: str = "cpu"):
        self.trm_agent = trm_agent
        self.device = device
        self.trm_agent.eval()

    @torch.no_grad()
    async def refine(
        self,
        predictions: torch.Tensor,
        num_iterations: int = 10,
        return_path: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Refine predictions using TRM.

        Args:
            predictions: Initial predictions to refine
            num_iterations: Number of refinement iterations
            return_path: Whether to return intermediate predictions

        Returns:
            refined_predictions or (refined_predictions, refinement_path)
        """
        # Ensure predictions are on correct device
        predictions = predictions.to(self.device)

        # Run TRM
        output = self.trm_agent(predictions, num_recursions=num_iterations, check_convergence=True)

        if return_path:
            return output.final_prediction, output.intermediate_predictions
        return output.final_prediction

    def get_refinement_stats(self, predictions: torch.Tensor) -> dict:
        """Get statistics about the refinement process."""
        with torch.no_grad():
            output = self.trm_agent(predictions, check_convergence=True)

        return {
            "converged": output.converged,
            "steps_to_convergence": output.convergence_step,
            "final_residual": (output.residual_norms[-1] if output.residual_norms else None),
            "total_refinement_iterations": output.recursion_depth,
        }
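A minimal usage sketch for the factory, loss, and wrapper above. This is illustrative only and not part of the commit; it assumes TRMConfig can be constructed with default field values and that a latent dimension of 128 matches the configured output shape.

import asyncio
import torch
import torch.nn as nn

from src.agents.trm_agent import TRMConfig, TRMLoss, TRMRefinementWrapper, create_trm_agent

config = TRMConfig()  # assumption: defaults are valid
agent = create_trm_agent(config, device="cpu")

# Deep-supervision loss over the final and intermediate predictions
criterion = TRMLoss(task_loss_fn=nn.MSELoss())
prediction = torch.randn(4, 128)  # assumed latent_dim = 128
targets = torch.randn(4, 128)
loss, loss_dict = criterion(agent(prediction), targets)

# Inference-time refinement through the wrapper (refine is async)
wrapper = TRMRefinementWrapper(agent, device="cpu")
refined = asyncio.run(wrapper.refine(prediction, num_iterations=10))
print(loss_dict["recursion_depth"], refined.shape)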
src/api/__init__.py
ADDED
@@ -0,0 +1,35 @@
"""
API module for LangGraph Multi-Agent MCTS Framework.

Provides:
- Authentication and authorization
- Rate limiting
- Error handling
- REST API endpoints
"""

from src.api.exceptions import (
    AuthenticationError,
    AuthorizationError,
    ConfigurationError,
    FrameworkError,
    LLMError,
    MCTSError,
    RAGError,
    RateLimitError,
    TimeoutError,
    ValidationError,
)

__all__ = [
    "FrameworkError",
    "ValidationError",
    "AuthenticationError",
    "AuthorizationError",
    "RateLimitError",
    "LLMError",
    "MCTSError",
    "RAGError",
    "TimeoutError",
    "ConfigurationError",
]
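For illustration only (not part of the commit), one way a call site could consume these package-level re-exports, assuming a hypothetical handle() function at an API boundary:

from src.api import FrameworkError, ValidationError

def handle(query: str) -> dict:
    try:
        if not query:
            raise ValidationError(field_name="query")
        return {"ok": True}
    except FrameworkError as exc:
        # The sanitized payload is safe to return to the caller
        return exc.to_user_response()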
src/api/auth.py
ADDED
@@ -0,0 +1,439 @@
"""
Authentication and authorization layer for LangGraph Multi-Agent MCTS Framework.

Provides:
- API key authentication with secure hashing
- JWT token support (optional)
- Rate limiting per client
- Role-based access control
"""

import hashlib
import secrets
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta

from src.api.exceptions import (
    AuthenticationError,
    AuthorizationError,
    RateLimitError,
)


@dataclass
class ClientInfo:
    """Information about an authenticated client."""

    client_id: str
    roles: set[str] = field(default_factory=lambda: {"user"})
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_access: datetime = field(default_factory=datetime.utcnow)
    request_count: int = 0


@dataclass
class RateLimitConfig:
    """Rate limiting configuration."""

    requests_per_minute: int = 60
    requests_per_hour: int = 1000
    requests_per_day: int = 10000
    burst_limit: int = 100  # Max requests in 1 second


class APIKeyAuthenticator:
    """
    API key-based authentication with secure hashing.

    Keys are stored as SHA-256 hashes to prevent exposure.
    """

    def __init__(
        self,
        valid_keys: list[str] | None = None,
        rate_limit_config: RateLimitConfig | None = None,
    ):
        """
        Initialize authenticator.

        Args:
            valid_keys: List of valid API keys (will be hashed)
            rate_limit_config: Rate limiting configuration
        """
        self._key_to_client: dict[str, ClientInfo] = {}
        self._rate_limits: dict[str, list[float]] = defaultdict(list)
        self.rate_limit_config = rate_limit_config or RateLimitConfig()

        # Hash and store initial keys
        if valid_keys:
            for i, key in enumerate(valid_keys):
                client_id = f"client_{i}"
                self._add_key(key, client_id)

    def _hash_key(self, api_key: str) -> str:
        """
        Securely hash an API key.

        Uses SHA-256 with consistent encoding.
        """
        return hashlib.sha256(api_key.encode("utf-8")).hexdigest()

    def _add_key(self, api_key: str, client_id: str, roles: set[str] | None = None) -> None:
        """
        Add a new API key.

        Args:
            api_key: Raw API key
            client_id: Client identifier
            roles: Set of roles (defaults to {"user"})
        """
        key_hash = self._hash_key(api_key)
        self._key_to_client[key_hash] = ClientInfo(
            client_id=client_id,
            roles=roles or {"user"},
        )

    def authenticate(self, api_key: str | None) -> ClientInfo:
        """
        Authenticate an API key.

        Args:
            api_key: API key to validate

        Returns:
            ClientInfo for the authenticated client

        Raises:
            AuthenticationError: If authentication fails
        """
        if not api_key:
            raise AuthenticationError(
                user_message="API key is required",
                internal_details="No API key provided in request",
            )

        # Look up the key by its SHA-256 hash; raw keys are never stored or compared directly
        key_hash = self._hash_key(api_key)

        if key_hash not in self._key_to_client:
            raise AuthenticationError(
                user_message="Invalid API key",
                internal_details=f"API key hash not found: {key_hash[:16]}...",
            )

        client_info = self._key_to_client[key_hash]
        client_info.last_access = datetime.utcnow()
        client_info.request_count += 1

        # Check rate limits
        self._check_rate_limit(client_info.client_id)

        return client_info

    def _check_rate_limit(self, client_id: str) -> None:
        """
        Check if client has exceeded rate limits.

        Args:
            client_id: Client identifier

        Raises:
            RateLimitError: If rate limit exceeded
        """
        now = time.time()
        request_times = self._rate_limits[client_id]

        # Clean old entries
        one_day_ago = now - 86400
        request_times = [t for t in request_times if t > one_day_ago]
        self._rate_limits[client_id] = request_times

        # Check burst limit (1 second window)
        one_second_ago = now - 1
        burst_count = sum(1 for t in request_times if t > one_second_ago)
        if burst_count >= self.rate_limit_config.burst_limit:
            raise RateLimitError(
                user_message="Too many requests. Please slow down.",
                internal_details=f"Client {client_id} exceeded burst limit: {burst_count}/{self.rate_limit_config.burst_limit}",
                retry_after_seconds=1,
            )

        # Check per-minute limit
        one_minute_ago = now - 60
        minute_count = sum(1 for t in request_times if t > one_minute_ago)
        if minute_count >= self.rate_limit_config.requests_per_minute:
            raise RateLimitError(
                user_message="Rate limit exceeded. Please wait a minute.",
                internal_details=f"Client {client_id} exceeded minute limit: {minute_count}/{self.rate_limit_config.requests_per_minute}",
                retry_after_seconds=60,
            )

        # Check per-hour limit
        one_hour_ago = now - 3600
        hour_count = sum(1 for t in request_times if t > one_hour_ago)
        if hour_count >= self.rate_limit_config.requests_per_hour:
            raise RateLimitError(
                user_message="Hourly rate limit exceeded. Please try again later.",
                internal_details=f"Client {client_id} exceeded hour limit: {hour_count}/{self.rate_limit_config.requests_per_hour}",
                retry_after_seconds=3600,
            )

        # Check per-day limit
        day_count = len(request_times)
        if day_count >= self.rate_limit_config.requests_per_day:
            raise RateLimitError(
                user_message="Daily rate limit exceeded. Please try again tomorrow.",
                internal_details=f"Client {client_id} exceeded day limit: {day_count}/{self.rate_limit_config.requests_per_day}",
                retry_after_seconds=86400,
            )

        # Record this request
        request_times.append(now)

    def require_auth(self, api_key: str | None) -> ClientInfo:
        """
        Require authentication for a request.

        Convenience method that raises on failure.

        Args:
            api_key: API key to validate

        Returns:
            ClientInfo for authenticated client

        Raises:
            AuthenticationError: If authentication fails
        """
        return self.authenticate(api_key)

    def require_role(self, client_info: ClientInfo, required_role: str) -> None:
        """
        Require a specific role for an operation.

        Args:
            client_info: Authenticated client info
            required_role: Role that is required

        Raises:
            AuthorizationError: If client doesn't have required role
        """
        if required_role not in client_info.roles:
            raise AuthorizationError(
                user_message="You do not have permission for this operation",
                internal_details=f"Client {client_info.client_id} missing role: {required_role}",
                required_permission=required_role,
            )

    def generate_api_key(self) -> str:
        """
        Generate a secure random API key.

        Returns:
            New API key (32 bytes hex = 64 characters)
        """
        return secrets.token_hex(32)

    def revoke_key(self, api_key: str) -> bool:
        """
        Revoke an API key.

        Args:
            api_key: Key to revoke

        Returns:
            True if key was revoked, False if not found
        """
        key_hash = self._hash_key(api_key)
        if key_hash in self._key_to_client:
            del self._key_to_client[key_hash]
            return True
        return False

    def add_client(
        self,
        client_id: str,
        roles: set[str] | None = None,
    ) -> str:
        """
        Add a new client and generate their API key.

        Args:
            client_id: Unique client identifier
            roles: Set of roles for the client

        Returns:
            Generated API key (save this securely!)
        """
        api_key = self.generate_api_key()
        self._add_key(api_key, client_id, roles)
        return api_key

    def get_client_stats(self, client_id: str) -> dict:
        """
        Get statistics for a client.

        Args:
            client_id: Client identifier

        Returns:
            Dictionary with client statistics
        """
        now = time.time()
        request_times = self._rate_limits.get(client_id, [])

        return {
            "total_requests_today": len([t for t in request_times if t > now - 86400]),
            "requests_last_hour": len([t for t in request_times if t > now - 3600]),
            "requests_last_minute": len([t for t in request_times if t > now - 60]),
        }


class JWTAuthenticator:
    """
    JWT token-based authentication.

    Note: Requires PyJWT library for full functionality.
    This is a placeholder for JWT support.
    """

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """
        Initialize JWT authenticator.

        Args:
            secret_key: Secret key for signing tokens
            algorithm: JWT signing algorithm
        """
        self.secret_key = secret_key
        self.algorithm = algorithm
        self._token_blacklist: set[str] = set()

    def create_token(
        self,
        client_id: str,
        roles: set[str],
        expires_in_hours: int = 24,
    ) -> str:
        """
        Create a JWT token.

        Args:
            client_id: Client identifier
            roles: Client roles
            expires_in_hours: Token validity period

        Returns:
            JWT token string
        """
        try:
            import jwt
        except ImportError:
            raise ImportError("PyJWT library required for JWT authentication. Install with: pip install PyJWT")

        now = datetime.utcnow()
        payload = {
            "sub": client_id,
            "roles": list(roles),
            "iat": now,
            "exp": now + timedelta(hours=expires_in_hours),
            "jti": secrets.token_hex(16),  # Unique token ID
        }

        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> ClientInfo:
        """
        Verify a JWT token.

        Args:
            token: JWT token string

        Returns:
            ClientInfo from token claims

        Raises:
            AuthenticationError: If token is invalid
        """
        try:
            import jwt
        except ImportError:
            raise ImportError("PyJWT library required for JWT authentication")

        if token in self._token_blacklist:
            raise AuthenticationError(
                user_message="Token has been revoked",
                internal_details="Token found in blacklist",
            )

        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
            )

            return ClientInfo(
                client_id=payload["sub"],
                roles=set(payload.get("roles", ["user"])),
            )
        except jwt.ExpiredSignatureError:
            raise AuthenticationError(
                user_message="Token has expired",
                internal_details="JWT signature expired",
            )
        except jwt.InvalidTokenError as e:
            raise AuthenticationError(
                user_message="Invalid token",
                internal_details=f"JWT validation failed: {str(e)}",
            )

    def revoke_token(self, token: str) -> None:
        """
        Revoke a JWT token by adding to blacklist.

        Args:
            token: Token to revoke
        """
        self._token_blacklist.add(token)


# Default authenticator instance
_default_authenticator: APIKeyAuthenticator | None = None


def get_authenticator() -> APIKeyAuthenticator:
    """
    Get or create the default authenticator instance.

    Returns:
        APIKeyAuthenticator instance
    """
    global _default_authenticator
    if _default_authenticator is None:
        _default_authenticator = APIKeyAuthenticator()
    return _default_authenticator


def set_authenticator(authenticator: APIKeyAuthenticator) -> None:
    """
    Set the default authenticator instance.

    Args:
        authenticator: Authenticator to use
    """
    global _default_authenticator
    _default_authenticator = authenticator


# Exports
__all__ = [
    "APIKeyAuthenticator",
    "JWTAuthenticator",
    "ClientInfo",
    "RateLimitConfig",
    "get_authenticator",
    "set_authenticator",
]
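A short, hedged usage sketch of the authenticator above (illustrative; the client name and printed messages are placeholders, not part of the commit):

from src.api.auth import APIKeyAuthenticator, RateLimitConfig
from src.api.exceptions import AuthenticationError, RateLimitError

auth = APIKeyAuthenticator(rate_limit_config=RateLimitConfig(requests_per_minute=30))
api_key = auth.add_client("demo_client", roles={"user", "admin"})  # store the returned key securely

try:
    client = auth.require_auth(api_key)
    auth.require_role(client, "admin")
except RateLimitError as exc:
    print(f"Back off for {exc.retry_after_seconds}s")
except AuthenticationError as exc:
    print(exc.user_message)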
src/api/exceptions.py
ADDED
@@ -0,0 +1,299 @@
"""
Custom exception hierarchy for LangGraph Multi-Agent MCTS Framework.

Provides:
- Sanitized error messages for production
- Structured error information for logging
- Clear separation between user-facing and internal errors
"""

import re
from datetime import datetime
from typing import Any


class FrameworkError(Exception):
    """
    Base exception for all framework errors.

    Provides sanitized user-facing messages while preserving
    internal details for logging.
    """

    def __init__(
        self,
        user_message: str,
        internal_details: str | None = None,
        error_code: str | None = None,
        context: dict[str, Any] | None = None,
    ):
        """
        Initialize framework error.

        Args:
            user_message: Safe message to show to users
            internal_details: Detailed information for logs (may contain sensitive data)
            error_code: Machine-readable error code
            context: Additional context for debugging
        """
        self.user_message = user_message
        self.internal_details = internal_details or user_message
        self.error_code = error_code or self.__class__.__name__.upper()
        self.context = context or {}
        self.timestamp = datetime.utcnow()

        super().__init__(user_message)

    def sanitize_details(self) -> str:
        """
        Remove sensitive information from internal details.

        Sanitizes:
        - File paths
        - API keys
        - Passwords
        - Connection strings
        - IP addresses
        """
        sanitized = self.internal_details

        # Remove file paths (Unix and Windows)
        sanitized = re.sub(r"/[\w/.-]+", "/***", sanitized)
        sanitized = re.sub(r"[A-Za-z]:\\[\w\\.-]+", "C:\\***", sanitized)

        # Remove API keys and secrets
        sanitized = re.sub(
            r"(api[_-]?key|secret|password|token|credential)[\s=:]+[\S]+", r"\1=***", sanitized, flags=re.IGNORECASE
        )

        # Remove connection strings
        sanitized = re.sub(r"(mongodb|postgresql|mysql|redis)://[^\s]+", r"\1://***", sanitized, flags=re.IGNORECASE)

        # Remove IP addresses
        sanitized = re.sub(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b", "***.***.***", sanitized)

        # Remove email addresses
        sanitized = re.sub(r"\b[\w.-]+@[\w.-]+\.\w+\b", "***@***", sanitized)

        return sanitized

    def to_log_dict(self) -> dict[str, Any]:
        """
        Convert exception to dictionary for structured logging.

        Returns sanitized version safe for logs.
        """
        return {
            "error_type": self.__class__.__name__,
            "error_code": self.error_code,
            "user_message": self.user_message,
            "sanitized_details": self.sanitize_details(),
            "timestamp": self.timestamp.isoformat(),
            "context": {k: str(v) for k, v in self.context.items()},
        }

    def to_user_response(self) -> dict[str, Any]:
        """
        Convert exception to safe user-facing response.
        """
        return {
            "error": True,
            "error_code": self.error_code,
            "message": self.user_message,
            "timestamp": self.timestamp.isoformat(),
        }


class ValidationError(FrameworkError):
    """Raised when input validation fails."""

    def __init__(
        self,
        user_message: str = "Invalid input provided",
        internal_details: str | None = None,
        field_name: str | None = None,
        **kwargs,
    ):
        context = kwargs.pop("context", {})
        if field_name:
            context["field_name"] = field_name
        super().__init__(
            user_message=user_message,
            internal_details=internal_details,
            error_code="VALIDATION_ERROR",
            context=context,
            **kwargs,
        )
        self.field_name = field_name


class AuthenticationError(FrameworkError):
    """Raised when authentication fails."""

    def __init__(self, user_message: str = "Authentication failed", internal_details: str | None = None, **kwargs):
        super().__init__(
            user_message=user_message, internal_details=internal_details, error_code="AUTH_ERROR", **kwargs
        )


class AuthorizationError(FrameworkError):
    """Raised when authorization fails."""

    def __init__(
        self,
        user_message: str = "Access denied",
        internal_details: str | None = None,
        required_permission: str | None = None,
        **kwargs,
    ):
        context = kwargs.pop("context", {})
        if required_permission:
            context["required_permission"] = required_permission
        super().__init__(
            user_message=user_message,
            internal_details=internal_details,
            error_code="AUTHZ_ERROR",
            context=context,
            **kwargs,
        )


class RateLimitError(FrameworkError):
    """Raised when rate limit is exceeded."""

    def __init__(
        self,
        user_message: str = "Rate limit exceeded. Please try again later.",
        internal_details: str | None = None,
        retry_after_seconds: int | None = None,
        **kwargs,
    ):
        context = kwargs.pop("context", {})
        if retry_after_seconds:
            context["retry_after_seconds"] = retry_after_seconds
        super().__init__(
            user_message=user_message,
            internal_details=internal_details,
            error_code="RATE_LIMIT",
            context=context,
            **kwargs,
        )
        self.retry_after_seconds = retry_after_seconds


class LLMError(FrameworkError):
    """Raised when LLM operations fail."""

    def __init__(
        self,
        user_message: str = "Language model service temporarily unavailable",
        internal_details: str | None = None,
        provider: str | None = None,
        **kwargs,
    ):
        context = kwargs.pop("context", {})
        if provider:
            context["provider"] = provider
        super().__init__(
            user_message=user_message,
            internal_details=internal_details,
            error_code="LLM_ERROR",
            context=context,
            **kwargs,
        )


class MCTSError(FrameworkError):
    """Raised when MCTS simulation fails."""

    def __init__(
        self,
        user_message: str = "Tactical simulation failed",
        internal_details: str | None = None,
        iteration: int | None = None,
        **kwargs,
    ):
        context = kwargs.pop("context", {})
        if iteration is not None:
            context["iteration"] = iteration
        super().__init__(
            user_message=user_message,
            internal_details=internal_details,
            error_code="MCTS_ERROR",
            context=context,
            **kwargs,
        )


class RAGError(FrameworkError):
    """Raised when RAG retrieval fails."""

    def __init__(self, user_message: str = "Context retrieval failed", internal_details: str | None = None, **kwargs):
        super().__init__(user_message=user_message, internal_details=internal_details, error_code="RAG_ERROR", **kwargs)


class TimeoutError(FrameworkError):
    """Raised when operation times out."""

    def __init__(
        self,
        user_message: str = "Operation timed out",
        internal_details: str | None = None,
        operation: str | None = None,
        timeout_seconds: float | None = None,
        **kwargs,
    ):
        context = kwargs.pop("context", {})
        if operation:
            context["operation"] = operation
        if timeout_seconds:
            context["timeout_seconds"] = timeout_seconds
        super().__init__(
            user_message=user_message,
            internal_details=internal_details,
            error_code="TIMEOUT",
            context=context,
            **kwargs,
        )


class ConfigurationError(FrameworkError):
    """Raised when configuration is invalid."""

    def __init__(
        self,
        user_message: str = "System configuration error",
        internal_details: str | None = None,
        config_key: str | None = None,
        **kwargs,
    ):
        context = kwargs.pop("context", {})
        if config_key:
            context["config_key"] = config_key
        super().__init__(
            user_message=user_message,
            internal_details=internal_details,
            error_code="CONFIG_ERROR",
            context=context,
            **kwargs,
        )


# Convenience function for wrapping exceptions
def wrap_exception(
    exc: Exception, user_message: str = "An unexpected error occurred", error_class: type = FrameworkError, **kwargs
) -> FrameworkError:
    """
    Wrap a standard exception in a FrameworkError with sanitized details.

    Args:
        exc: Original exception
        user_message: Safe user-facing message
        error_class: FrameworkError subclass to use
        **kwargs: Additional context

    Returns:
        FrameworkError instance with sanitized details
    """
    internal_details = f"{type(exc).__name__}: {str(exc)}"
    return error_class(user_message=user_message, internal_details=internal_details, **kwargs)
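For illustration only (not part of the commit), a hedged sketch of wrap_exception together with the sanitized logging and user-facing views; the provider URL and key in the raised error are fabricated placeholders:

from src.api.exceptions import LLMError, wrap_exception

try:
    raise ConnectionError("LLM provider at 10.0.0.5 refused /v1/chat with api_key=sk-placeholder")
except Exception as exc:
    err = wrap_exception(
        exc,
        user_message="Language model service temporarily unavailable",
        error_class=LLMError,
    )
    print(err.to_log_dict()["sanitized_details"])  # paths, keys, and IPs are masked
    print(err.to_user_response())                  # safe payload for the client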
src/api/inference_server.py
ADDED
@@ -0,0 +1,380 @@
"""
FastAPI Inference Server for LangGraph Multi-Agent MCTS.

Provides REST API for:
- Problem solving with HRM+MCTS+TRM
- Policy-value network inference
- Health checks and monitoring
"""

import time
from typing import Any

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from ..framework.mcts.neural_mcts import NeuralMCTS
from ..training.performance_monitor import PerformanceMonitor
from ..training.system_config import SystemConfig


# Request/Response Models
class InferenceRequest(BaseModel):
    """Request for problem inference."""

    state: list[list[float]]  # State representation
    query: str | None = "Solve this problem"
    max_thinking_time: float = Field(default=10.0, ge=0.1, le=60.0)
    use_mcts: bool = True
    num_simulations: int | None = None
    use_hrm_decomposition: bool = False
    use_trm_refinement: bool = False
    temperature: float = Field(default=0.1, ge=0.0, le=2.0)


class PolicyValueRequest(BaseModel):
    """Request for policy-value evaluation."""

    state: list[list[float]]  # State representation


class InferenceResponse(BaseModel):
    """Response with inference results."""

    success: bool
    action_probabilities: dict[str, float] | None = None
    best_action: str | None = None
    value_estimate: float | None = None
    subproblems: list[dict[str, Any]] | None = None
    refinement_info: dict[str, Any] | None = None
    performance_stats: dict[str, float]
    error: str | None = None


class PolicyValueResponse(BaseModel):
    """Response with policy-value predictions."""

    policy_probs: list[float]
    value: float
    inference_time_ms: float


class HealthResponse(BaseModel):
    """Health check response."""

    status: str
    device: str
    model_loaded: bool
    gpu_available: bool
    gpu_memory_gb: float | None = None
    uptime_seconds: float


# Inference Server
class InferenceServer:
    """
    Production inference server with comprehensive features.

    Features:
    - FastAPI REST endpoints
    - Performance monitoring
    - Health checks
    - CORS support
    - Error handling
    """

    def __init__(
        self,
        checkpoint_path: str,
        config: SystemConfig | None = None,
        host: str = "0.0.0.0",
        port: int = 8000,
    ):
        """
        Initialize inference server.

        Args:
            checkpoint_path: Path to model checkpoint
            config: System configuration (loaded from checkpoint if None)
            host: Server host
            port: Server port
        """
        self.checkpoint_path = checkpoint_path
        self.host = host
        self.port = port
        self.start_time = time.time()

        # Load models
        self.config, self.models = self._load_models(checkpoint_path, config)
        self.device = self.config.device

        # Performance monitoring
        self.monitor = PerformanceMonitor(window_size=100, enable_gpu_monitoring=(self.device != "cpu"))

        # Setup FastAPI app
        self.app = FastAPI(
            title="LangGraph Multi-Agent MCTS API",
            description="Neural-guided MCTS with HRM and TRM agents",
            version="1.0.0",
        )

        # CORS middleware
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # Setup routes
        self._setup_routes()

    def _load_models(
        self, checkpoint_path: str, config: SystemConfig | None
    ) -> tuple[SystemConfig, dict[str, torch.nn.Module]]:
        """Load models from checkpoint."""
        print(f"Loading models from {checkpoint_path}...")

        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

        # Load config
        if config is None:
            config_dict = checkpoint.get("config", {})
            config = SystemConfig.from_dict(config_dict)

        device = config.device

        # Load models
        models = {}

        # Policy-Value Network
        from ..models.policy_value_net import create_policy_value_network

        models["policy_value_net"] = create_policy_value_network(config.neural_net, board_size=19, device=device)
        models["policy_value_net"].load_state_dict(checkpoint["policy_value_net"])
        models["policy_value_net"].eval()

        # HRM Agent
        from ..agents.hrm_agent import create_hrm_agent

        models["hrm_agent"] = create_hrm_agent(config.hrm, device)
        models["hrm_agent"].load_state_dict(checkpoint["hrm_agent"])
        models["hrm_agent"].eval()

        # TRM Agent
        from ..agents.trm_agent import create_trm_agent

        models["trm_agent"] = create_trm_agent(config.trm, output_dim=config.neural_net.action_size, device=device)
        models["trm_agent"].load_state_dict(checkpoint["trm_agent"])
        models["trm_agent"].eval()

        # MCTS
        models["mcts"] = NeuralMCTS(
            policy_value_network=models["policy_value_net"],
            config=config.mcts,
            device=device,
        )

        print(f"✓ Models loaded successfully on {device}")

        return config, models

    def _setup_routes(self):
        """Setup API routes."""

        @self.app.get("/", response_model=dict[str, str])
        async def root():
            """Root endpoint."""
            return {
                "message": "LangGraph Multi-Agent MCTS API",
                "version": "1.0.0",
                "docs": "/docs",
            }

        @self.app.get("/health", response_model=HealthResponse)
        async def health():
            """Health check endpoint."""
            gpu_memory = None
            if torch.cuda.is_available():
                gpu_memory = torch.cuda.memory_allocated() / (1024**3)

            return HealthResponse(
                status="healthy",
                device=self.device,
                model_loaded=True,
                gpu_available=torch.cuda.is_available(),
                gpu_memory_gb=gpu_memory,
                uptime_seconds=time.time() - self.start_time,
            )

        @self.app.post("/inference", response_model=InferenceResponse)
        async def inference(request: InferenceRequest):
            """
            Main inference endpoint.

            Processes a problem using the full pipeline:
            1. Optional HRM decomposition
            2. MCTS search
            3. Optional TRM refinement
            """
            try:
                start_time = time.perf_counter()

                # Convert state to tensor
                state_tensor = torch.tensor(request.state, dtype=torch.float32).unsqueeze(0)
                state_tensor = state_tensor.to(self.device)

                results = {}

                # HRM Decomposition (if requested)
                if request.use_hrm_decomposition:
                    with torch.no_grad():
                        hrm_output = self.models["hrm_agent"](state_tensor)
                        results["subproblems"] = [
                            {
                                "level": sp.level,
                                "description": sp.description,
                                "confidence": sp.confidence,
                            }
                            for sp in hrm_output.subproblems
                        ]

                # MCTS Search (if requested)
                if request.use_mcts:
                    # Note: This is a simplified version
                    # In production, you'd need to convert request.state to GameState
                    results["action_probabilities"] = {"action_0": 0.5, "action_1": 0.3, "action_2": 0.2}
                    results["best_action"] = "action_0"
                    results["value_estimate"] = 0.75

                # TRM Refinement (if requested)
                if request.use_trm_refinement and results.get("best_action"):
                    with torch.no_grad():
                        # Simplified: just run TRM on the state
                        trm_output = self.models["trm_agent"](state_tensor)
                        results["refinement_info"] = {
                            "converged": trm_output.converged,
                            "convergence_step": trm_output.convergence_step,
                            "recursion_depth": trm_output.recursion_depth,
                        }

                # Performance stats
                elapsed_ms = (time.perf_counter() - start_time) * 1000
                self.monitor.log_inference(elapsed_ms)

                perf_stats = {
                    "inference_time_ms": elapsed_ms,
                    "device": self.device,
                }

                return InferenceResponse(
                    success=True,
                    action_probabilities=results.get("action_probabilities"),
                    best_action=results.get("best_action"),
                    value_estimate=results.get("value_estimate"),
                    subproblems=results.get("subproblems"),
                    refinement_info=results.get("refinement_info"),
                    performance_stats=perf_stats,
                )

            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")

        @self.app.post("/policy-value", response_model=PolicyValueResponse)
        async def policy_value(request: PolicyValueRequest):
            """
            Get policy and value predictions for a state.

            This is a direct neural network evaluation without MCTS.
            """
            try:
                start_time = time.perf_counter()

                # Convert state to tensor
                state_tensor = torch.tensor(request.state, dtype=torch.float32).unsqueeze(0)
                state_tensor = state_tensor.to(self.device)

                # Get predictions
                with torch.no_grad():
                    policy_log_probs, value = self.models["policy_value_net"](state_tensor)
                    policy_probs = torch.exp(policy_log_probs).squeeze(0)

                elapsed_ms = (time.perf_counter() - start_time) * 1000

                return PolicyValueResponse(
                    policy_probs=policy_probs.cpu().tolist(),
                    value=value.item(),
                    inference_time_ms=elapsed_ms,
                )

            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Policy-value inference failed: {str(e)}")

        @self.app.get("/stats")
        async def stats():
            """Get performance statistics."""
            return self.monitor.get_stats()

        @self.app.post("/reset-stats")
        async def reset_stats():
            """Reset performance statistics."""
            self.monitor.reset()
            return {"message": "Statistics reset successfully"}

    def run(self):
        """Start the inference server."""
        print(f"\n{'=' * 80}")
        print("Starting LangGraph Multi-Agent MCTS Inference Server")
        print(f"{'=' * 80}")
        print(f"Host: {self.host}:{self.port}")
        print(f"Device: {self.device}")
        print(f"Checkpoint: {self.checkpoint_path}")
        print(f"{'=' * 80}\n")

        uvicorn.run(self.app, host=self.host, port=self.port)


def main():
    """Main entry point for inference server."""
    import argparse

    parser = argparse.ArgumentParser(description="LangGraph MCTS Inference Server")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to model checkpoint",
    )
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device (cpu, cuda, mps)",
    )

    args = parser.parse_args()

    # Load config and override device if specified
    config = None
    if args.device:
        config = SystemConfig()
        config.device = args.device

    server = InferenceServer(
        checkpoint_path=args.checkpoint,
        config=config,
        host=args.host,
        port=args.port,
    )

    server.run()


if __name__ == "__main__":
    main()
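A hedged client-side sketch for the /inference endpoint above (illustrative; the checkpoint path is a placeholder, the 19x19 state layout is assumed from board_size=19, and the requests library is an extra dependency not listed in the commit):

# Start the server first, e.g.:
#   python -m src.api.inference_server --checkpoint checkpoints/model.pt  # hypothetical path
import requests

payload = {
    "state": [[0.0] * 19 for _ in range(19)],  # assumed 19x19 state representation
    "use_mcts": True,
    "use_trm_refinement": True,
}
resp = requests.post("http://localhost:8000/inference", json=payload, timeout=30)
resp.raise_for_status()
body = resp.json()
print(body["best_action"], body["performance_stats"]["inference_time_ms"])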
src/api/rest_server.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Production REST API server for LangGraph Multi-Agent MCTS Framework.

Provides:
- OpenAPI/Swagger documentation
- Authentication via API keys
- Rate limiting
- Health and readiness endpoints
- Request validation with Pydantic
- Prometheus metrics exposure
"""

import asyncio
import time
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any

from fastapi import Depends, FastAPI, Header, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# Import framework components
try:
    from src.adapters.llm import create_client  # noqa: F401
    from src.api.auth import (
        APIKeyAuthenticator,
        ClientInfo,
        RateLimitConfig,
        get_authenticator,
        set_authenticator,
    )
    from src.api.exceptions import (
        AuthenticationError,
        AuthorizationError,  # noqa: F401
        FrameworkError,
        RateLimitError,
        ValidationError,  # noqa: F401
    )
    from src.models.validation import MCTSConfig, QueryInput  # noqa: F401

    IMPORTS_AVAILABLE = True
except ImportError as e:
    IMPORTS_AVAILABLE = False
    import_error = str(e)

# Prometheus metrics (optional)
try:
    from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, generate_latest

    PROMETHEUS_AVAILABLE = True

    # Define metrics
    REQUEST_COUNT = Counter("mcts_requests_total", "Total number of requests", ["method", "endpoint", "status"])
    REQUEST_LATENCY = Histogram("mcts_request_duration_seconds", "Request latency in seconds", ["method", "endpoint"])
    ACTIVE_REQUESTS = Gauge("mcts_active_requests", "Number of active requests")
    ERROR_COUNT = Counter("mcts_errors_total", "Total number of errors", ["error_type"])
except ImportError:
    PROMETHEUS_AVAILABLE = False


# Request/Response Models
class QueryRequest(BaseModel):
    """Request model for query processing."""

    query: str = Field(
        ...,
        min_length=1,
        max_length=10000,
        description="User query to process",
        json_schema_extra={"example": "Recommend defensive positions for night attack scenario"},
    )
    use_mcts: bool = Field(default=True, description="Enable MCTS tactical simulation")
    use_rag: bool = Field(default=True, description="Enable RAG context retrieval")
    mcts_iterations: int | None = Field(default=None, ge=1, le=10000, description="Override default MCTS iterations")
    thread_id: str | None = Field(
        default=None,
        max_length=100,
        pattern=r"^[a-zA-Z0-9_-]+$",
        description="Conversation thread ID for state persistence",
    )

    class Config:
        json_schema_extra = {
            "example": {
                "query": "Recommend defensive positions for night attack",
                "use_mcts": True,
                "use_rag": True,
                "mcts_iterations": 200,
                "thread_id": "session_123",
            }
        }


class QueryResponse(BaseModel):
    """Response model for query results."""

    response: str = Field(..., description="Final synthesized response")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence score")
    agents_used: list[str] = Field(..., description="List of agents that contributed")
    mcts_stats: dict[str, Any] | None = Field(default=None, description="MCTS simulation statistics")
    processing_time_ms: float = Field(..., description="Total processing time in milliseconds")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")


class HealthResponse(BaseModel):
    """Health check response."""

    status: str = Field(..., description="Service status")
    timestamp: str = Field(..., description="Current timestamp")
    version: str = Field(default="1.0.0", description="API version")
    uptime_seconds: float = Field(..., description="Service uptime")


class ReadinessResponse(BaseModel):
    """Readiness check response."""

    ready: bool = Field(..., description="Whether service is ready")
    checks: dict[str, bool] = Field(..., description="Individual check results")


class ErrorResponse(BaseModel):
    """Error response model."""

    error: bool = Field(default=True)
    error_code: str = Field(..., description="Machine-readable error code")
    message: str = Field(..., description="Human-readable error message")
    timestamp: str = Field(..., description="Error timestamp")


# Application startup
start_time = time.time()
framework_instance = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager."""
    global framework_instance

    # Startup
    print("Starting MCTS Framework API server...")

    # Initialize authenticator with demo key (replace in production)
    authenticator = APIKeyAuthenticator(
        valid_keys=["demo-api-key-replace-in-production"],
        rate_limit_config=RateLimitConfig(
            requests_per_minute=60,
            requests_per_hour=1000,
            requests_per_day=10000,
        ),
    )
    set_authenticator(authenticator)

    # Initialize framework (lazy loading)
    # framework_instance = create_framework()

    print("API server started successfully")

    yield

    # Shutdown
    print("Shutting down API server...")


# Create FastAPI app
app = FastAPI(
    title="LangGraph Multi-Agent MCTS API",
    description="""
## Multi-Agent Reasoning API with MCTS Tactical Simulation

This API provides access to a sophisticated multi-agent reasoning framework that combines:
- **HRM Agent**: Hierarchical decomposition of complex queries
- **TRM Agent**: Iterative refinement for response quality
- **MCTS Engine**: Monte Carlo Tree Search for tactical simulation
- **RAG Integration**: Context retrieval from vector stores

### Features
- Secure API key authentication
- Rate limiting per client
- Real-time metrics (Prometheus)
- Distributed tracing (OpenTelemetry)
- Production-grade error handling

### Quick Start
1. Obtain an API key
2. Include `X-API-Key` header in requests
3. Send queries to `/query` endpoint
4. Monitor health via `/health` endpoint
""",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_tags=[
        {"name": "query", "description": "Query processing operations"},
        {"name": "health", "description": "Health and readiness checks"},
        {"name": "metrics", "description": "Observability endpoints"},
    ],
    lifespan=lifespan,
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Middleware for metrics
@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
    """Track request metrics."""
    if PROMETHEUS_AVAILABLE:
        ACTIVE_REQUESTS.inc()

    start = time.perf_counter()

    try:
        response = await call_next(request)
        status = response.status_code
    except Exception:
        status = 500
        raise
    finally:
        if PROMETHEUS_AVAILABLE:
            ACTIVE_REQUESTS.dec()
            elapsed = time.perf_counter() - start
            REQUEST_COUNT.labels(method=request.method, endpoint=request.url.path, status=str(status)).inc()
            REQUEST_LATENCY.labels(method=request.method, endpoint=request.url.path).observe(elapsed)

    return response


# Authentication dependency
async def verify_api_key(x_api_key: str = Header(..., description="API key for authentication")):
    """Verify API key and return client info."""
    if not IMPORTS_AVAILABLE:
        raise HTTPException(status_code=500, detail="Authentication module not available")

    try:
        authenticator = get_authenticator()
        client_info = authenticator.require_auth(x_api_key)
        return client_info
    except AuthenticationError as e:
        if PROMETHEUS_AVAILABLE:
            ERROR_COUNT.labels(error_type="authentication").inc()
        raise HTTPException(status_code=401, detail=e.user_message)
    except RateLimitError as e:
        if PROMETHEUS_AVAILABLE:
            ERROR_COUNT.labels(error_type="rate_limit").inc()
        raise HTTPException(
            status_code=429, detail=e.user_message, headers={"Retry-After": str(e.retry_after_seconds or 60)}
        )


# Exception handlers
@app.exception_handler(FrameworkError)
async def framework_error_handler(request: Request, exc: FrameworkError):
    """Handle framework-specific errors."""
    if PROMETHEUS_AVAILABLE:
        ERROR_COUNT.labels(error_type=exc.error_code).inc()

    return JSONResponse(status_code=500, content=exc.to_user_response())


@app.exception_handler(ValidationError)
async def validation_error_handler(request: Request, exc: ValidationError):
    """Handle validation errors."""
    if PROMETHEUS_AVAILABLE:
        ERROR_COUNT.labels(error_type="validation").inc()

    return JSONResponse(status_code=400, content=exc.to_user_response())


# Endpoints
@app.get("/health", response_model=HealthResponse, tags=["health"])
async def health_check():
    """
    Health check endpoint.

    Returns basic service health status. Use this for load balancer health checks.
    """
    return HealthResponse(
        status="healthy",
        timestamp=datetime.utcnow().isoformat(),
        version="1.0.0",
        uptime_seconds=time.time() - start_time,
    )


@app.get("/ready", response_model=ReadinessResponse, tags=["health"])
async def readiness_check():
    """
    Readiness check endpoint.

    Verifies all dependencies are available. Use this for Kubernetes readiness probes.
    """
    checks = {
        "imports_available": IMPORTS_AVAILABLE,
        "authenticator_configured": True,
        "llm_client_available": True,  # Would check actual client
        "prometheus_available": PROMETHEUS_AVAILABLE,
    }

    # Check if all critical services are available
    all_ready = all(
        [
            checks["imports_available"],
            checks["authenticator_configured"],
        ]
    )

    if not all_ready:
        raise HTTPException(status_code=503, detail="Service not ready")

    return ReadinessResponse(ready=all_ready, checks=checks)


@app.get("/metrics", tags=["metrics"])
async def prometheus_metrics():
    """
    Prometheus metrics endpoint.

    Returns metrics in Prometheus text format for scraping.
    """
    if not PROMETHEUS_AVAILABLE:
        raise HTTPException(status_code=501, detail="Prometheus metrics not available")

    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)


@app.post(
    "/query",
    response_model=QueryResponse,
    tags=["query"],
    responses={
        401: {"model": ErrorResponse, "description": "Authentication failed"},
        429: {"model": ErrorResponse, "description": "Rate limit exceeded"},
        400: {"model": ErrorResponse, "description": "Invalid input"},
        500: {"model": ErrorResponse, "description": "Internal server error"},
    },
)
async def process_query(request: QueryRequest, client_info: ClientInfo = Depends(verify_api_key)):
    """
    Process a query using the multi-agent MCTS framework.

    This endpoint:
    1. Validates the input query
    2. Optionally retrieves context via RAG
    3. Processes through HRM and TRM agents
    4. Optionally runs MCTS simulation
    5. Synthesizes a final response

    **Authentication**: Requires valid API key in X-API-Key header.

    **Rate Limiting**: Subject to rate limits per client.
    """
    start_time = time.perf_counter()

    # Validate input using validation models
    if IMPORTS_AVAILABLE:
        try:
            QueryInput(
                query=request.query,
                use_rag=request.use_rag,
                use_mcts=request.use_mcts,
                thread_id=request.thread_id,
            )
        except Exception as e:
            if PROMETHEUS_AVAILABLE:
                ERROR_COUNT.labels(error_type="validation").inc()
            raise HTTPException(status_code=400, detail=f"Validation failed: {str(e)}")

    # Process query (mock implementation for demo)
    # In production, this would call the actual framework
    await asyncio.sleep(0.1)  # Simulate processing

    processing_time = (time.perf_counter() - start_time) * 1000

    # Mock response
    return QueryResponse(
        response=f"Processed query: {request.query[:100]}...",
        confidence=0.85,
        agents_used=["hrm", "trm"] + (["mcts"] if request.use_mcts else []),
        mcts_stats=(
            {
                "iterations": request.mcts_iterations or 100,
                "best_action": "recommended_action",
                "root_visits": request.mcts_iterations or 100,
            }
            if request.use_mcts
            else None
        ),
        processing_time_ms=processing_time,
        metadata={
            "client_id": client_info.client_id,
            "thread_id": request.thread_id,
            "rag_enabled": request.use_rag,
        },
    )


@app.get("/stats", tags=["metrics"])
async def get_stats(client_info: ClientInfo = Depends(verify_api_key)):
    """
    Get usage statistics for the authenticated client.

    Returns request counts and rate limit information.
    """
    authenticator = get_authenticator()
    stats = authenticator.get_client_stats(client_info.client_id)

    return {
        "client_id": client_info.client_id,
        "roles": list(client_info.roles),
        **stats,
        "rate_limits": {
            "per_minute": authenticator.rate_limit_config.requests_per_minute,
            "per_hour": authenticator.rate_limit_config.requests_per_hour,
            "per_day": authenticator.rate_limit_config.requests_per_day,
        },
    }


# Entry point
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "src.api.rest_server:app",
        host="0.0.0.0",
        port=8000,
        reload=False,
        workers=4,
        log_level="info",
        access_log=True,
    )
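
A quick usage sketch, not part of the commit: assuming the server above is running locally on port 8000 and still configured with the demo key from the lifespan hook, a client call to the /query endpoint might look like the following. The path, the X-API-Key header, and the request fields come from rest_server.py; the host, port, and variable names are illustrative assumptions.

import requests  # third-party HTTP client, assumed installed

API_URL = "http://localhost:8000/query"          # assumed local deployment
API_KEY = "demo-api-key-replace-in-production"   # demo key defined in the lifespan hook above

payload = {
    "query": "Recommend defensive positions for night attack",
    "use_mcts": True,
    "use_rag": True,
    "mcts_iterations": 200,
    "thread_id": "session_123",
}

# The server authenticates via the X-API-Key header and validates the body
# against QueryRequest before returning a QueryResponse JSON object.
resp = requests.post(API_URL, json=payload, headers={"X-API-Key": API_KEY}, timeout=30)
resp.raise_for_status()
result = resp.json()
print(result["response"], result["confidence"], result["agents_used"])
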
src/config/__init__.py
ADDED
|
File without changes
|
src/config/meta_controller.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
meta_controller:
  enabled: false  # Disabled by default for backward compatibility
  type: "rnn"  # "rnn" or "bert"
  fallback_to_rule_based: true  # Fallback on errors

  rnn:
    hidden_dim: 64
    num_layers: 1
    dropout: 0.1
    model_path: null  # Path to trained model (null for untrained)

  bert:
    model_name: "prajjwal1/bert-mini"
    use_lora: true
    lora_r: 4
    lora_alpha: 16
    lora_dropout: 0.1
    model_path: null  # Path to trained LoRA adapter

  inference:
    device: null  # Auto-detect if null
    seed: 42
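
A minimal loading sketch, not part of the commit: the repository ships its own loader in src/agents/meta_controller/config_loader.py, whose API is not shown here, so the snippet below simply reads the file with PyYAML to illustrate the structure above. The file path, variable names, and the assumption that rnn/bert/inference are nested under meta_controller are illustrative.

from pathlib import Path

import yaml  # PyYAML, assumed installed

# Path is an assumption; adjust to wherever the config lives in your checkout.
config_path = Path("src/config/meta_controller.yaml")

with config_path.open("r", encoding="utf-8") as fh:
    cfg = yaml.safe_load(fh)["meta_controller"]

# Mirrors the keys defined above: controller type plus per-backend settings.
if cfg["enabled"]:
    backend = cfg["type"]        # "rnn" or "bert"
    backend_cfg = cfg[backend]   # e.g. hidden_dim for rnn, lora_r for bert
    print(f"Meta-controller enabled: {backend} -> {backend_cfg}")
else:
    print("Meta-controller disabled; falling back to rule-based routing.")
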
src/config/settings.py
ADDED
|
@@ -0,0 +1,431 @@
|
"""
Pydantic Settings v2 configuration management for LangGraph Multi-Agent MCTS.

Provides:
- Secure configuration loading from environment variables and .env files
- Type-safe settings with validation
- Secrets protection using SecretStr
- MCTS parameter bounds validation
- Support for multiple LLM providers
"""

from enum import Enum

from pydantic import (
    Field,
    SecretStr,
    field_validator,
    model_validator,
)
from pydantic_settings import BaseSettings, SettingsConfigDict


class LLMProvider(str, Enum):
    """Supported LLM providers."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    LMSTUDIO = "lmstudio"


class LogLevel(str, Enum):
    """Supported log levels."""

    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"


class MCTSImplementation(str, Enum):
    """MCTS implementation variants."""

    BASELINE = "baseline"  # Original MCTS core
    NEURAL = "neural"  # Neural-guided AlphaZero-style MCTS


class Settings(BaseSettings):
    """
    Application settings with security-first configuration.

    All sensitive values use SecretStr to prevent accidental exposure in logs.
    Configuration is loaded from environment variables with .env file support.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=True,
        extra="ignore",
        validate_default=True,
    )

    # LLM Provider Configuration
    LLM_PROVIDER: LLMProvider = Field(
        default=LLMProvider.OPENAI, description="LLM provider to use (openai, anthropic, lmstudio)"
    )

    # API Keys (Secrets)
    OPENAI_API_KEY: SecretStr | None = Field(
        default=None, description="OpenAI API key (required if using OpenAI provider)"
    )

    ANTHROPIC_API_KEY: SecretStr | None = Field(
        default=None, description="Anthropic API key (required if using Anthropic provider)"
    )

    BRAINTRUST_API_KEY: SecretStr | None = Field(
        default=None, description="Braintrust API key for experiment tracking (optional)"
    )

    PINECONE_API_KEY: SecretStr | None = Field(
        default=None, description="Pinecone API key for vector storage (optional)"
    )

    PINECONE_HOST: str | None = Field(
        default=None, description="Pinecone host URL (e.g., https://index.svc.environment.pinecone.io)"
    )

    # Local LLM Configuration
    LMSTUDIO_BASE_URL: str | None = Field(
        default="http://localhost:1234/v1", description="LM Studio API base URL for local inference"
    )

    LMSTUDIO_MODEL: str | None = Field(default=None, description="LM Studio model identifier (e.g., liquid/lfm2-1.2b)")

    # MCTS Configuration with bounds validation
    MCTS_ENABLED: bool = Field(default=True, description="Enable MCTS for agent decision-making")

    MCTS_IMPL: MCTSImplementation = Field(
        default=MCTSImplementation.BASELINE, description="MCTS implementation variant to use"
    )

    MCTS_ITERATIONS: int = Field(default=100, ge=1, le=10000, description="Number of MCTS iterations (1-10000)")

    MCTS_C: float = Field(
        default=1.414, ge=0.0, le=10.0, description="MCTS exploration weight (UCB1 constant, 0.0-10.0)"
    )

    # Random seed for reproducibility
    SEED: int | None = Field(default=None, ge=0, description="Random seed for reproducibility (optional)")

    # LangSmith Configuration for tracing and evaluation
    LANGSMITH_API_KEY: SecretStr | None = Field(
        default=None, description="LangSmith API key for tracing and evaluation (optional)"
    )

    LANGSMITH_PROJECT: str = Field(default="langgraph-mcts", description="LangSmith project name")

    LANGCHAIN_TRACING_V2: bool = Field(default=False, description="Enable LangChain tracing v2")

    LANGCHAIN_ENDPOINT: str = Field(default="https://api.smith.langchain.com", description="LangChain API endpoint")

    # Weights & Biases Configuration for experiment tracking
    WANDB_API_KEY: SecretStr | None = Field(
        default=None, description="Weights & Biases API key for experiment tracking (optional)"
    )

    WANDB_PROJECT: str = Field(default="langgraph-mcts", description="W&B project name")

    WANDB_ENTITY: str | None = Field(default=None, description="W&B entity (username or team name)")

    WANDB_MODE: str = Field(default="online", description="W&B mode: online, offline, or disabled")

    # Logging Configuration
    LOG_LEVEL: LogLevel = Field(default=LogLevel.INFO, description="Application log level")

    # OpenTelemetry Configuration
    OTEL_EXPORTER_OTLP_ENDPOINT: str | None = Field(
        default=None, description="OpenTelemetry OTLP exporter endpoint URL"
    )

    # S3 Storage Configuration
    S3_BUCKET: str | None = Field(default=None, description="S3 bucket name for artifact storage")

    S3_PREFIX: str = Field(default="mcts-artifacts", description="S3 key prefix for stored artifacts")

    S3_REGION: str = Field(default="us-east-1", description="AWS region for S3 bucket")

    # Network Configuration (security)
    HTTP_TIMEOUT_SECONDS: int = Field(default=30, ge=1, le=300, description="HTTP request timeout in seconds")

    HTTP_MAX_RETRIES: int = Field(default=3, ge=0, le=10, description="Maximum HTTP request retries")

    # Security Settings
    MAX_QUERY_LENGTH: int = Field(
        default=10000, ge=1, le=100000, description="Maximum allowed query length in characters"
    )

    RATE_LIMIT_REQUESTS_PER_MINUTE: int = Field(
        default=60, ge=1, le=1000, description="Rate limit for API requests per minute"
    )

    @field_validator("OPENAI_API_KEY")
    @classmethod
    def validate_openai_key_format(cls, v: SecretStr | None) -> SecretStr | None:
        """Validate OpenAI API key format without exposing the value."""
        if v is not None:
            secret_value = v.get_secret_value()
            # Check for obviously invalid patterns
            if secret_value in ("", "your-api-key-here", "sk-xxx", "REPLACE_ME"):
                raise ValueError("OpenAI API key appears to be a placeholder value")
            if not secret_value.startswith("sk-"):
                raise ValueError("OpenAI API key should start with 'sk-'")
            if len(secret_value) < 20:
                raise ValueError("OpenAI API key appears to be too short")
        return v

    @field_validator("ANTHROPIC_API_KEY")
    @classmethod
    def validate_anthropic_key_format(cls, v: SecretStr | None) -> SecretStr | None:
        """Validate Anthropic API key format without exposing the value."""
        if v is not None:
            secret_value = v.get_secret_value()
            # Check for obviously invalid patterns
            if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
                raise ValueError("Anthropic API key appears to be a placeholder value")
            if len(secret_value) < 20:
                raise ValueError("Anthropic API key appears to be too short")
        return v

    @field_validator("BRAINTRUST_API_KEY")
    @classmethod
    def validate_braintrust_key_format(cls, v: SecretStr | None) -> SecretStr | None:
        """Validate Braintrust API key format without exposing the value."""
        if v is not None:
            secret_value = v.get_secret_value()
            # Check for obviously invalid patterns
            if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
                raise ValueError("Braintrust API key appears to be a placeholder value")
            if len(secret_value) < 20:
                raise ValueError("Braintrust API key appears to be too short")
        return v

    @field_validator("PINECONE_API_KEY")
    @classmethod
    def validate_pinecone_key_format(cls, v: SecretStr | None) -> SecretStr | None:
        """Validate Pinecone API key format without exposing the value."""
        if v is not None:
            secret_value = v.get_secret_value()
            # Check for obviously invalid patterns
            if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
                raise ValueError("Pinecone API key appears to be a placeholder value")
            if len(secret_value) < 20:
                raise ValueError("Pinecone API key appears to be too short")
        return v

    @field_validator("LANGSMITH_API_KEY")
    @classmethod
    def validate_langsmith_key_format(cls, v: SecretStr | None) -> SecretStr | None:
        """Validate LangSmith API key format without exposing the value."""
        if v is not None:
            secret_value = v.get_secret_value()
            if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
                raise ValueError("LangSmith API key appears to be a placeholder value")
            if len(secret_value) < 20:
                raise ValueError("LangSmith API key appears to be too short")
        return v

    @field_validator("WANDB_API_KEY")
    @classmethod
    def validate_wandb_key_format(cls, v: SecretStr | None) -> SecretStr | None:
        """Validate Weights & Biases API key format without exposing the value."""
        if v is not None:
            secret_value = v.get_secret_value()
            if secret_value in ("", "your-api-key-here", "REPLACE_ME"):
                raise ValueError("W&B API key appears to be a placeholder value")
            if len(secret_value) < 20:
                raise ValueError("W&B API key appears to be too short")
        return v

    @field_validator("PINECONE_HOST")
    @classmethod
    def validate_pinecone_host(cls, v: str | None) -> str | None:
        """Validate Pinecone host URL format."""
        if v is not None and v != "":
            if not v.startswith("https://"):
                raise ValueError("Pinecone host must start with https://")
            if "pinecone.io" not in v:
                raise ValueError("Pinecone host should be a valid pinecone.io URL")
        return v

    @field_validator("LMSTUDIO_BASE_URL")
    @classmethod
    def validate_lmstudio_url(cls, v: str | None) -> str | None:
        """Validate LM Studio base URL format."""
        if v is not None:
            if not v.startswith(("http://", "https://")):
                raise ValueError("LM Studio base URL must start with http:// or https://")
            # Warn if not localhost (potential security concern)
            if not any(host in v for host in ("localhost", "127.0.0.1", "::1")):
                import warnings

                warnings.warn(
                    "LM Studio URL points to non-localhost address. Ensure this is intentional and secure.",
                    UserWarning,
                    stacklevel=2,
                )
        return v

    @field_validator("OTEL_EXPORTER_OTLP_ENDPOINT")
    @classmethod
    def validate_otel_endpoint(cls, v: str | None) -> str | None:
        """Validate OpenTelemetry endpoint URL."""
        if v is not None and v != "" and not v.startswith(("http://", "https://", "grpc://")):
            raise ValueError("OpenTelemetry endpoint must start with http://, https://, or grpc://")
        return v

    @field_validator("S3_BUCKET")
    @classmethod
    def validate_s3_bucket_name(cls, v: str | None) -> str | None:
        """Validate S3 bucket name format."""
        if v is not None:
            # S3 bucket naming rules
            if len(v) < 3 or len(v) > 63:
                raise ValueError("S3 bucket name must be 3-63 characters long")
            if not v.replace("-", "").replace(".", "").isalnum():
                raise ValueError("S3 bucket name can only contain lowercase letters, numbers, hyphens, and periods")
            if v.startswith("-") or v.endswith("-"):
                raise ValueError("S3 bucket name cannot start or end with a hyphen")
        return v

    @model_validator(mode="after")
    def validate_provider_credentials(self) -> "Settings":
        """Ensure required API keys are provided for the selected provider."""
        if self.LLM_PROVIDER == LLMProvider.OPENAI:
            if self.OPENAI_API_KEY is None:
                raise ValueError(
                    "OPENAI_API_KEY is required when using OpenAI provider. "
                    "Set the OPENAI_API_KEY environment variable."
                )
        elif self.LLM_PROVIDER == LLMProvider.ANTHROPIC:
            if self.ANTHROPIC_API_KEY is None:
                raise ValueError(
                    "ANTHROPIC_API_KEY is required when using Anthropic provider. "
                    "Set the ANTHROPIC_API_KEY environment variable."
                )
        elif self.LLM_PROVIDER == LLMProvider.LMSTUDIO and self.LMSTUDIO_BASE_URL is None:
            raise ValueError("LMSTUDIO_BASE_URL is required when using LM Studio provider.")
        return self

    def get_api_key(self) -> str | None:
        """
        Get the API key for the current provider.

        Returns the secret value - use with caution to avoid logging.
        """
        if self.LLM_PROVIDER == LLMProvider.OPENAI and self.OPENAI_API_KEY:
            return self.OPENAI_API_KEY.get_secret_value()
        elif self.LLM_PROVIDER == LLMProvider.ANTHROPIC and self.ANTHROPIC_API_KEY:
            return self.ANTHROPIC_API_KEY.get_secret_value()
        return None

    def safe_dict(self) -> dict:
        """
        Return settings as dictionary with secrets masked.

        Safe for logging and display purposes.
        """
        data = self.model_dump()
        # Mask all sensitive fields
        secret_fields = [
            "OPENAI_API_KEY",
            "ANTHROPIC_API_KEY",
            "BRAINTRUST_API_KEY",
            "PINECONE_API_KEY",
            "LANGSMITH_API_KEY",
            "WANDB_API_KEY",
        ]
        for field in secret_fields:
            if field in data and data[field]:
                data[field] = "***MASKED***"
        return data

    def get_braintrust_api_key(self) -> str | None:
        """
        Get the Braintrust API key if configured.

        Returns the secret value - use with caution to avoid logging.
        """
        if self.BRAINTRUST_API_KEY:
            return self.BRAINTRUST_API_KEY.get_secret_value()
        return None

    def get_pinecone_api_key(self) -> str | None:
        """
        Get the Pinecone API key if configured.

        Returns the secret value - use with caution to avoid logging.
        """
        if self.PINECONE_API_KEY:
            return self.PINECONE_API_KEY.get_secret_value()
        return None

    def get_langsmith_api_key(self) -> str | None:
        """
        Get the LangSmith API key if configured.

        Returns the secret value - use with caution to avoid logging.
        """
        if self.LANGSMITH_API_KEY:
            return self.LANGSMITH_API_KEY.get_secret_value()
        return None

    def get_wandb_api_key(self) -> str | None:
        """
        Get the Weights & Biases API key if configured.

        Returns the secret value - use with caution to avoid logging.
        """
        if self.WANDB_API_KEY:
            return self.WANDB_API_KEY.get_secret_value()
        return None

    def __repr__(self) -> str:
        """Safe string representation that doesn't expose secrets."""
        return f"Settings(LLM_PROVIDER={self.LLM_PROVIDER}, MCTS_ENABLED={self.MCTS_ENABLED}, MCTS_IMPL={self.MCTS_IMPL}, LOG_LEVEL={self.LOG_LEVEL})"


# Global settings instance (lazily loaded)
_settings: Settings | None = None


def get_settings() -> Settings:
    """
    Get the global settings instance.

    Settings are loaded once and cached. To reload, call reset_settings() first.

    Returns:
        Settings: Application configuration instance

    Raises:
        ValidationError: If configuration is invalid
    """
    global _settings
    if _settings is None:
        _settings = Settings()
    return _settings


def reset_settings() -> None:
    """
    Reset the global settings instance.

    Forces settings to be reloaded from environment on next get_settings() call.
    Useful for testing.
    """
    global _settings
    _settings = None


# Type exports for external use
__all__ = [
    "Settings",
    "LLMProvider",
    "LogLevel",
    "MCTSImplementation",
    "get_settings",
    "reset_settings",
]

src/data/__init__.py
ADDED
@@ -0,0 +1,29 @@
"""
Dataset Integration Module for Multi-Agent MCTS Training.

This module provides utilities for loading, preprocessing, and managing
open-source datasets for training HRM/TRM agents and neural meta-controllers.

Supported Datasets:
- DABStep: Multi-step reasoning tasks (CC-BY-4.0)
- PRIMUS-Seed: Cybersecurity domain knowledge (ODC-BY)
- PRIMUS-Instruct: Instruction fine-tuning data (ODC-BY)
"""

from .dataset_loader import DABStepLoader, DatasetLoader, PRIMUSLoader
from .preprocessing import TextPreprocessor, TokenizerWrapper
from .tactical_augmentation import TacticalAugmenter
from .train_test_split import DataSplitter, StratifiedSplitter

__all__ = [
    "DatasetLoader",
    "DABStepLoader",
    "PRIMUSLoader",
    "TextPreprocessor",
    "TokenizerWrapper",
    "TacticalAugmenter",
    "DataSplitter",
    "StratifiedSplitter",
]

__version__ = "1.0.0"

src/data/dataset_loader.py
ADDED
@@ -0,0 +1,551 @@
| 1 |
+
"""
|
| 2 |
+
Dataset Loading Module for Open-Source Training Data.
|
| 3 |
+
|
| 4 |
+
Provides unified loading interfaces for:
|
| 5 |
+
- DABStep: Multi-step data analysis reasoning
|
| 6 |
+
- PRIMUS: Cybersecurity domain knowledge
|
| 7 |
+
- Custom tactical datasets
|
| 8 |
+
|
| 9 |
+
License Attribution:
|
| 10 |
+
- DABStep: CC-BY-4.0 (Creative Commons Attribution)
|
| 11 |
+
- PRIMUS: ODC-BY (Open Data Commons Attribution)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
from abc import ABC, abstractmethod
|
| 16 |
+
from collections.abc import Iterator
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class DatasetSample:
|
| 26 |
+
"""Standardized representation of a dataset sample."""
|
| 27 |
+
|
| 28 |
+
id: str
|
| 29 |
+
text: str
|
| 30 |
+
metadata: dict[str, Any] = field(default_factory=dict)
|
| 31 |
+
labels: list[str] | None = None
|
| 32 |
+
difficulty: str | None = None
|
| 33 |
+
domain: str | None = None
|
| 34 |
+
reasoning_steps: list[str] | None = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class DatasetStatistics:
|
| 39 |
+
"""Statistics about a loaded dataset."""
|
| 40 |
+
|
| 41 |
+
total_samples: int
|
| 42 |
+
domains: dict[str, int]
|
| 43 |
+
avg_text_length: float
|
| 44 |
+
difficulty_distribution: dict[str, int]
|
| 45 |
+
total_tokens: int = 0
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DatasetLoader(ABC):
|
| 49 |
+
"""Abstract base class for dataset loaders."""
|
| 50 |
+
|
| 51 |
+
def __init__(self, cache_dir: str | None = None):
|
| 52 |
+
"""
|
| 53 |
+
Initialize dataset loader.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
cache_dir: Directory to cache downloaded datasets
|
| 57 |
+
"""
|
| 58 |
+
self.cache_dir = cache_dir or str(Path.home() / ".cache" / "mcts_datasets")
|
| 59 |
+
self._dataset = None
|
| 60 |
+
self._statistics = None
|
| 61 |
+
|
| 62 |
+
@abstractmethod
|
| 63 |
+
def load(self, split: str = "train") -> list[DatasetSample]:
|
| 64 |
+
"""Load dataset split."""
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
@abstractmethod
|
| 68 |
+
def get_statistics(self) -> DatasetStatistics:
|
| 69 |
+
"""Get dataset statistics."""
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
@abstractmethod
|
| 73 |
+
def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]:
|
| 74 |
+
"""Iterate over samples in batches."""
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class DABStepLoader(DatasetLoader):
|
| 79 |
+
"""
|
| 80 |
+
Loader for DABStep Multi-Step Reasoning Dataset.
|
| 81 |
+
|
| 82 |
+
DABStep contains 450+ data analysis tasks requiring sequential,
|
| 83 |
+
iterative problem-solving. Perfect for training HRM/TRM agents.
|
| 84 |
+
|
| 85 |
+
License: CC-BY-4.0 (Attribution required)
|
| 86 |
+
Source: huggingface.co/datasets/adyen/DABstep
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
DATASET_NAME = "adyen/DABstep"
|
| 90 |
+
DIFFICULTIES = ["easy", "medium", "hard"]
|
| 91 |
+
|
| 92 |
+
def __init__(self, cache_dir: str | None = None):
|
| 93 |
+
"""Initialize DABStep loader."""
|
| 94 |
+
super().__init__(cache_dir)
|
| 95 |
+
self._loaded_samples: list[DatasetSample] = []
|
| 96 |
+
|
| 97 |
+
def load(self, split: str = "train", difficulty: str | None = None) -> list[DatasetSample]:
|
| 98 |
+
"""
|
| 99 |
+
Load DABStep dataset.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
split: Dataset split ('train', 'validation', 'test')
|
| 103 |
+
difficulty: Filter by difficulty ('easy', 'medium', 'hard')
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
List of DatasetSample objects
|
| 107 |
+
"""
|
| 108 |
+
try:
|
| 109 |
+
from datasets import load_dataset
|
| 110 |
+
|
| 111 |
+
logger.info(f"Loading DABStep dataset (split={split})")
|
| 112 |
+
|
| 113 |
+
dataset = load_dataset(
|
| 114 |
+
self.DATASET_NAME,
|
| 115 |
+
cache_dir=self.cache_dir,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
if split not in dataset:
|
| 119 |
+
available_splits = list(dataset.keys())
|
| 120 |
+
logger.warning(f"Split '{split}' not found. Available: {available_splits}")
|
| 121 |
+
split = available_splits[0] if available_splits else "train"
|
| 122 |
+
|
| 123 |
+
samples = []
|
| 124 |
+
for idx, item in enumerate(dataset[split]):
|
| 125 |
+
sample = DatasetSample(
|
| 126 |
+
id=f"dabstep_{split}_{idx}",
|
| 127 |
+
text=str(item.get("question", item.get("text", ""))),
|
| 128 |
+
metadata={
|
| 129 |
+
"source": "DABStep",
|
| 130 |
+
"license": "CC-BY-4.0",
|
| 131 |
+
"split": split,
|
| 132 |
+
"original_data": item,
|
| 133 |
+
},
|
| 134 |
+
difficulty=item.get("difficulty", "medium"),
|
| 135 |
+
domain="data_analysis",
|
| 136 |
+
reasoning_steps=item.get("steps", []),
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
if difficulty and sample.difficulty != difficulty:
|
| 140 |
+
continue
|
| 141 |
+
|
| 142 |
+
samples.append(sample)
|
| 143 |
+
|
| 144 |
+
self._loaded_samples = samples
|
| 145 |
+
logger.info(f"Loaded {len(samples)} DABStep samples")
|
| 146 |
+
return samples
|
| 147 |
+
|
| 148 |
+
except ImportError:
|
| 149 |
+
logger.error("datasets library not installed. Run: pip install datasets")
|
| 150 |
+
raise
|
| 151 |
+
except Exception as e:
|
| 152 |
+
logger.error(f"Failed to load DABStep: {e}")
|
| 153 |
+
raise
|
| 154 |
+
|
| 155 |
+
def get_statistics(self) -> DatasetStatistics:
|
| 156 |
+
"""Get statistics about loaded DABStep data."""
|
| 157 |
+
if not self._loaded_samples:
|
| 158 |
+
raise ValueError("No samples loaded. Call load() first.")
|
| 159 |
+
|
| 160 |
+
difficulty_dist = {}
|
| 161 |
+
total_length = 0
|
| 162 |
+
|
| 163 |
+
for sample in self._loaded_samples:
|
| 164 |
+
diff = sample.difficulty or "unknown"
|
| 165 |
+
difficulty_dist[diff] = difficulty_dist.get(diff, 0) + 1
|
| 166 |
+
total_length += len(sample.text)
|
| 167 |
+
|
| 168 |
+
return DatasetStatistics(
|
| 169 |
+
total_samples=len(self._loaded_samples),
|
| 170 |
+
domains={"data_analysis": len(self._loaded_samples)},
|
| 171 |
+
avg_text_length=total_length / len(self._loaded_samples),
|
| 172 |
+
difficulty_distribution=difficulty_dist,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]:
|
| 176 |
+
"""Iterate over samples in batches."""
|
| 177 |
+
if not self._loaded_samples:
|
| 178 |
+
raise ValueError("No samples loaded. Call load() first.")
|
| 179 |
+
|
| 180 |
+
for i in range(0, len(self._loaded_samples), batch_size):
|
| 181 |
+
yield self._loaded_samples[i : i + batch_size]
|
| 182 |
+
|
| 183 |
+
def get_reasoning_tasks(self) -> list[DatasetSample]:
|
| 184 |
+
"""Get only samples with explicit reasoning steps."""
|
| 185 |
+
return [s for s in self._loaded_samples if s.reasoning_steps]
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class PRIMUSLoader(DatasetLoader):
|
| 189 |
+
"""
|
| 190 |
+
Loader for PRIMUS Cybersecurity Dataset Suite.
|
| 191 |
+
|
| 192 |
+
PRIMUS contains:
|
| 193 |
+
- Seed: 674,848 cybersecurity documents (190M tokens)
|
| 194 |
+
- Instruct: 835 instruction-tuning samples
|
| 195 |
+
- Reasoning: Self-reflection data for reasoning
|
| 196 |
+
|
| 197 |
+
License: ODC-BY (Open Data Commons Attribution)
|
| 198 |
+
Source: huggingface.co/datasets/trendmicro-ailab/Primus-Seed
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
SEED_DATASET = "trendmicro-ailab/Primus-Seed"
|
| 202 |
+
INSTRUCT_DATASET = "trendmicro-ailab/Primus-Instruct"
|
| 203 |
+
|
| 204 |
+
DOMAINS = [
|
| 205 |
+
"mitre_attack",
|
| 206 |
+
"wikipedia",
|
| 207 |
+
"company_sites",
|
| 208 |
+
"threat_intelligence",
|
| 209 |
+
"vulnerability_db",
|
| 210 |
+
]
|
| 211 |
+
|
| 212 |
+
def __init__(self, cache_dir: str | None = None):
|
| 213 |
+
"""Initialize PRIMUS loader."""
|
| 214 |
+
super().__init__(cache_dir)
|
| 215 |
+
self._seed_samples: list[DatasetSample] = []
|
| 216 |
+
self._instruct_samples: list[DatasetSample] = []
|
| 217 |
+
|
| 218 |
+
def load(
|
| 219 |
+
self,
|
| 220 |
+
split: str = "train",
|
| 221 |
+
dataset_type: str = "seed",
|
| 222 |
+
domains: list[str] | None = None,
|
| 223 |
+
max_samples: int | None = None,
|
| 224 |
+
streaming: bool = True,
|
| 225 |
+
) -> list[DatasetSample]:
|
| 226 |
+
"""
|
| 227 |
+
Load PRIMUS dataset.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
split: Dataset split ('train', 'validation', 'test')
|
| 231 |
+
dataset_type: 'seed' for knowledge base, 'instruct' for fine-tuning
|
| 232 |
+
domains: Filter by specific domains
|
| 233 |
+
max_samples: Limit number of samples (useful for large datasets)
|
| 234 |
+
streaming: Use streaming mode for large datasets (default True)
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
List of DatasetSample objects
|
| 238 |
+
"""
|
| 239 |
+
try:
|
| 240 |
+
from datasets import load_dataset
|
| 241 |
+
|
| 242 |
+
dataset_name = self.SEED_DATASET if dataset_type == "seed" else self.INSTRUCT_DATASET
|
| 243 |
+
|
| 244 |
+
logger.info(f"Loading PRIMUS {dataset_type} dataset")
|
| 245 |
+
|
| 246 |
+
# Use streaming for large seed dataset to avoid download issues
|
| 247 |
+
use_streaming = streaming and dataset_type == "seed" and max_samples is not None
|
| 248 |
+
|
| 249 |
+
if use_streaming:
|
| 250 |
+
logger.info(f"Using streaming mode (max_samples={max_samples})")
|
| 251 |
+
dataset = load_dataset(
|
| 252 |
+
dataset_name,
|
| 253 |
+
"default",
|
| 254 |
+
streaming=True,
|
| 255 |
+
cache_dir=self.cache_dir,
|
| 256 |
+
)
|
| 257 |
+
# For streaming, iterate the first available split
|
| 258 |
+
data_iter = iter(dataset["train"]) if "train" in dataset else iter(dataset[list(dataset.keys())[0]])
|
| 259 |
+
else:
|
| 260 |
+
dataset = load_dataset(
|
| 261 |
+
dataset_name,
|
| 262 |
+
cache_dir=self.cache_dir,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if split not in dataset:
|
| 266 |
+
available_splits = list(dataset.keys())
|
| 267 |
+
logger.warning(f"Split '{split}' not found. Using: {available_splits[0]}")
|
| 268 |
+
split = available_splits[0]
|
| 269 |
+
|
| 270 |
+
data_iter = iter(dataset[split])
|
| 271 |
+
|
| 272 |
+
samples = []
|
| 273 |
+
count = 0
|
| 274 |
+
|
| 275 |
+
for idx, item in enumerate(data_iter):
|
| 276 |
+
if max_samples and count >= max_samples:
|
| 277 |
+
break
|
| 278 |
+
|
| 279 |
+
domain = item.get("domain", item.get("source", "unknown"))
|
| 280 |
+
|
| 281 |
+
if domains and domain not in domains:
|
| 282 |
+
continue
|
| 283 |
+
|
| 284 |
+
if dataset_type == "instruct":
|
| 285 |
+
text = f"Instruction: {item.get('instruction', '')}\nResponse: {item.get('response', '')}"
|
| 286 |
+
else:
|
| 287 |
+
text = str(item.get("text", item.get("content", "")))
|
| 288 |
+
|
| 289 |
+
sample = DatasetSample(
|
| 290 |
+
id=f"primus_{dataset_type}_{split}_{idx}",
|
| 291 |
+
text=text,
|
| 292 |
+
metadata={
|
| 293 |
+
"source": f"PRIMUS-{dataset_type.capitalize()}",
|
| 294 |
+
"license": "ODC-BY",
|
| 295 |
+
"split": split,
|
| 296 |
+
"original_domain": domain,
|
| 297 |
+
},
|
| 298 |
+
domain=domain,
|
| 299 |
+
labels=item.get("labels", item.get("tags", [])),
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
samples.append(sample)
|
| 303 |
+
count += 1
|
| 304 |
+
|
| 305 |
+
if dataset_type == "seed":
|
| 306 |
+
self._seed_samples = samples
|
| 307 |
+
else:
|
| 308 |
+
self._instruct_samples = samples
|
| 309 |
+
|
| 310 |
+
logger.info(f"Loaded {len(samples)} PRIMUS {dataset_type} samples")
|
| 311 |
+
return samples
|
| 312 |
+
|
| 313 |
+
except ImportError:
|
| 314 |
+
logger.error("datasets library not installed. Run: pip install datasets")
|
| 315 |
+
raise
|
| 316 |
+
except Exception as e:
|
| 317 |
+
if "gated dataset" in str(e):
|
| 318 |
+
logger.error(
|
| 319 |
+
f"PRIMUS is a gated dataset. Please authenticate with HuggingFace:\n"
|
| 320 |
+
f"1. Create account at https://huggingface.co/\n"
|
| 321 |
+
f"2. Accept dataset terms at https://huggingface.co/datasets/{dataset_name}\n"
|
| 322 |
+
f"3. Create token at https://huggingface.co/settings/tokens\n"
|
| 323 |
+
f"4. Run: huggingface-cli login"
|
| 324 |
+
)
|
| 325 |
+
else:
|
| 326 |
+
logger.error(f"Failed to load PRIMUS: {e}")
|
| 327 |
+
raise
|
| 328 |
+
|
| 329 |
+
def load_seed(self, max_samples: int | None = None) -> list[DatasetSample]:
|
| 330 |
+
"""Load PRIMUS-Seed knowledge base."""
|
| 331 |
+
return self.load(dataset_type="seed", max_samples=max_samples)
|
| 332 |
+
|
| 333 |
+
def load_instruct(self) -> list[DatasetSample]:
|
| 334 |
+
"""Load PRIMUS-Instruct fine-tuning data."""
|
| 335 |
+
return self.load(dataset_type="instruct", streaming=False)
|
| 336 |
+
|
| 337 |
+
def get_statistics(self) -> DatasetStatistics:
|
| 338 |
+
"""Get statistics about loaded PRIMUS data."""
|
| 339 |
+
all_samples = self._seed_samples + self._instruct_samples
|
| 340 |
+
|
| 341 |
+
if not all_samples:
|
| 342 |
+
raise ValueError("No samples loaded. Call load() first.")
|
| 343 |
+
|
| 344 |
+
domain_dist = {}
|
| 345 |
+
total_length = 0
|
| 346 |
+
|
| 347 |
+
for sample in all_samples:
|
| 348 |
+
domain = sample.domain or "unknown"
|
| 349 |
+
domain_dist[domain] = domain_dist.get(domain, 0) + 1
|
| 350 |
+
total_length += len(sample.text)
|
| 351 |
+
|
| 352 |
+
return DatasetStatistics(
|
| 353 |
+
total_samples=len(all_samples),
|
| 354 |
+
domains=domain_dist,
|
| 355 |
+
avg_text_length=total_length / len(all_samples),
|
| 356 |
+
difficulty_distribution={"cybersecurity": len(all_samples)},
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]:
|
| 360 |
+
"""Iterate over all loaded samples in batches."""
|
| 361 |
+
all_samples = self._seed_samples + self._instruct_samples
|
| 362 |
+
|
| 363 |
+
if not all_samples:
|
| 364 |
+
raise ValueError("No samples loaded. Call load() first.")
|
| 365 |
+
|
| 366 |
+
for i in range(0, len(all_samples), batch_size):
|
| 367 |
+
yield all_samples[i : i + batch_size]
|
| 368 |
+
|
| 369 |
+
def get_mitre_attack_samples(self) -> list[DatasetSample]:
|
| 370 |
+
"""Get samples specifically from MITRE ATT&CK."""
|
| 371 |
+
return [s for s in self._seed_samples if "mitre" in (s.domain or "").lower()]
|
| 372 |
+
|
| 373 |
+
def get_threat_intelligence_samples(self) -> list[DatasetSample]:
|
| 374 |
+
"""Get threat intelligence related samples."""
|
| 375 |
+
return [
|
| 376 |
+
s
|
| 377 |
+
for s in self._seed_samples
|
| 378 |
+
if any(kw in (s.domain or "").lower() for kw in ["threat", "cti", "intelligence"])
|
| 379 |
+
]
|
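For orientation, a minimal usage sketch of the PRIMUS loader above. It assumes the module is importable as `src.data.dataset_loader`, that the gated PRIMUS datasets are accessible (see the authentication steps in the error handler), and that the positional cache-directory argument matches how `CombinedDatasetLoader` constructs this class below.

```python
# Minimal sketch, not part of the commit: exercise PRIMUSLoader end to end.
from src.data.dataset_loader import PRIMUSLoader

loader = PRIMUSLoader("/tmp/hf_cache")  # cache dir; pass None to use the default HF cache

# Load a capped number of PRIMUS-Seed samples and summarize them.
seed_samples = loader.load_seed(max_samples=1000)
stats = loader.get_statistics()
print(stats.total_samples, stats.domains, round(stats.avg_text_length, 1))

# Domain-specific subsets exposed by the loader.
print(len(loader.get_mitre_attack_samples()), "MITRE ATT&CK samples")
print(len(loader.get_threat_intelligence_samples()), "threat-intel samples")
```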
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class CombinedDatasetLoader:
|
| 383 |
+
"""
|
| 384 |
+
Unified loader for combining multiple datasets.
|
| 385 |
+
|
| 386 |
+
Provides a single interface for loading and managing:
|
| 387 |
+
- DABStep (multi-step reasoning)
|
| 388 |
+
- PRIMUS (cybersecurity knowledge)
|
| 389 |
+
- Custom tactical datasets
|
| 390 |
+
"""
|
| 391 |
+
|
| 392 |
+
def __init__(self, cache_dir: str | None = None):
|
| 393 |
+
"""Initialize combined loader."""
|
| 394 |
+
self.cache_dir = cache_dir
|
| 395 |
+
self.dabstep_loader = DABStepLoader(cache_dir)
|
| 396 |
+
self.primus_loader = PRIMUSLoader(cache_dir)
|
| 397 |
+
self._all_samples: list[DatasetSample] = []
|
| 398 |
+
|
| 399 |
+
def load_all(
|
| 400 |
+
self,
|
| 401 |
+
dabstep_split: str = "train",
|
| 402 |
+
primus_max_samples: int | None = 10000,
|
| 403 |
+
include_instruct: bool = True,
|
| 404 |
+
) -> list[DatasetSample]:
|
| 405 |
+
"""
|
| 406 |
+
Load all datasets.
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
dabstep_split: Split for DABStep
|
| 410 |
+
primus_max_samples: Max samples from PRIMUS-Seed (None for all)
|
| 411 |
+
include_instruct: Whether to include PRIMUS-Instruct
|
| 412 |
+
|
| 413 |
+
Returns:
|
| 414 |
+
Combined list of all samples
|
| 415 |
+
"""
|
| 416 |
+
logger.info("Loading combined datasets")
|
| 417 |
+
|
| 418 |
+
# Load DABStep
|
| 419 |
+
dabstep_samples = self.dabstep_loader.load(split=dabstep_split)
|
| 420 |
+
logger.info(f"DABStep: {len(dabstep_samples)} samples")
|
| 421 |
+
|
| 422 |
+
# Load PRIMUS-Seed
|
| 423 |
+
primus_seed = self.primus_loader.load_seed(max_samples=primus_max_samples)
|
| 424 |
+
logger.info(f"PRIMUS-Seed: {len(primus_seed)} samples")
|
| 425 |
+
|
| 426 |
+
# Load PRIMUS-Instruct
|
| 427 |
+
primus_instruct = []
|
| 428 |
+
if include_instruct:
|
| 429 |
+
primus_instruct = self.primus_loader.load_instruct()
|
| 430 |
+
logger.info(f"PRIMUS-Instruct: {len(primus_instruct)} samples")
|
| 431 |
+
|
| 432 |
+
self._all_samples = dabstep_samples + primus_seed + primus_instruct
|
| 433 |
+
logger.info(f"Total combined samples: {len(self._all_samples)}")
|
| 434 |
+
|
| 435 |
+
return self._all_samples
|
| 436 |
+
|
| 437 |
+
def get_domain_distribution(self) -> dict[str, int]:
|
| 438 |
+
"""Get distribution of samples across domains."""
|
| 439 |
+
dist = {}
|
| 440 |
+
for sample in self._all_samples:
|
| 441 |
+
domain = sample.domain or "unknown"
|
| 442 |
+
dist[domain] = dist.get(domain, 0) + 1
|
| 443 |
+
return dist
|
| 444 |
+
|
| 445 |
+
def filter_by_domain(self, domain: str) -> list[DatasetSample]:
|
| 446 |
+
"""Filter samples by domain."""
|
| 447 |
+
return [s for s in self._all_samples if s.domain == domain]
|
| 448 |
+
|
| 449 |
+
def get_multi_step_reasoning_samples(self) -> list[DatasetSample]:
|
| 450 |
+
"""Get samples suitable for multi-step reasoning training."""
|
| 451 |
+
return [
|
| 452 |
+
s
|
| 453 |
+
for s in self._all_samples
|
| 454 |
+
if s.reasoning_steps or s.domain == "data_analysis" or "instruct" in s.metadata.get("source", "").lower()
|
| 455 |
+
]
|
| 456 |
+
|
| 457 |
+
def export_for_training(self, output_path: str, format: str = "jsonl") -> str:
|
| 458 |
+
"""
|
| 459 |
+
Export dataset for training.
|
| 460 |
+
|
| 461 |
+
Args:
|
| 462 |
+
output_path: Path to save exported data
|
| 463 |
+
format: Export format ('jsonl', 'csv', 'parquet')
|
| 464 |
+
|
| 465 |
+
Returns:
|
| 466 |
+
Path to exported file
|
| 467 |
+
"""
|
| 468 |
+
import json
|
| 469 |
+
|
| 470 |
+
output_file = Path(output_path)
|
| 471 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
| 472 |
+
|
| 473 |
+
if format == "jsonl":
|
| 474 |
+
with open(output_file, "w", encoding="utf-8") as f:
|
| 475 |
+
for sample in self._all_samples:
|
| 476 |
+
record = {
|
| 477 |
+
"id": sample.id,
|
| 478 |
+
"text": sample.text,
|
| 479 |
+
"domain": sample.domain,
|
| 480 |
+
"difficulty": sample.difficulty,
|
| 481 |
+
"labels": sample.labels,
|
| 482 |
+
"metadata": sample.metadata,
|
| 483 |
+
}
|
| 484 |
+
f.write(json.dumps(record) + "\n")
|
| 485 |
+
else:
|
| 486 |
+
raise NotImplementedError(f"Format {format} not yet supported")
|
| 487 |
+
|
| 488 |
+
logger.info(f"Exported {len(self._all_samples)} samples to {output_file}")
|
| 489 |
+
return str(output_file)
|
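A hedged sketch of how the combined loader is meant to be driven, assuming the same import path as above; actual sample counts depend on dataset availability and HuggingFace authentication.

```python
# Minimal sketch: load everything, inspect the domain mix, export JSONL for training.
from src.data.dataset_loader import CombinedDatasetLoader

combined = CombinedDatasetLoader(cache_dir="/tmp/hf_cache")
samples = combined.load_all(
    dabstep_split="train",
    primus_max_samples=5000,
    include_instruct=True,
)

print(combined.get_domain_distribution())

# Only the "jsonl" format is implemented; anything else raises NotImplementedError.
out_path = combined.export_for_training("exports/combined.jsonl", format="jsonl")
print(f"Wrote {len(samples)} samples to {out_path}")
```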
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def load_dataset(
|
| 493 |
+
dataset_name: str,
|
| 494 |
+
split: str = "train",
|
| 495 |
+
cache_dir: str | None = None,
|
| 496 |
+
**kwargs,
|
| 497 |
+
) -> Any:
|
| 498 |
+
"""
|
| 499 |
+
Unified interface for loading datasets from HuggingFace.
|
| 500 |
+
|
| 501 |
+
This function provides compatibility with the standard HuggingFace datasets API.
|
| 502 |
+
It wraps the underlying load_dataset function from the datasets library.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
dataset_name: HuggingFace dataset identifier (e.g., "adyen/DABstep")
|
| 506 |
+
split: Dataset split ("train", "validation", "test"); currently logged only and not forwarded to datasets.load_dataset, so the full DatasetDict is returned
|
| 507 |
+
cache_dir: Optional directory for caching downloaded datasets
|
| 508 |
+
**kwargs: Additional arguments passed to datasets.load_dataset
|
| 509 |
+
|
| 510 |
+
Returns:
|
| 511 |
+
HuggingFace Dataset object or dict of Dataset objects
|
| 512 |
+
|
| 513 |
+
Raises:
|
| 514 |
+
ImportError: If datasets library is not installed
|
| 515 |
+
Exception: If dataset loading fails
|
| 516 |
+
|
| 517 |
+
Examples:
|
| 518 |
+
>>> # Load DABStep dataset
|
| 519 |
+
>>> dataset = load_dataset("adyen/DABstep")
|
| 520 |
+
>>> samples = dataset["train"]
|
| 521 |
+
|
| 522 |
+
>>> # Load PRIMUS-Seed with custom cache
|
| 523 |
+
>>> dataset = load_dataset("trendmicro-ailab/Primus-Seed", cache_dir="/tmp/cache")
|
| 524 |
+
|
| 525 |
+
License Attribution:
|
| 526 |
+
- DABStep: CC-BY-4.0 (Creative Commons Attribution 4.0)
|
| 527 |
+
- PRIMUS: ODC-BY (Open Data Commons Attribution)
|
| 528 |
+
"""
|
| 529 |
+
try:
|
| 530 |
+
from datasets import load_dataset as hf_load_dataset
|
| 531 |
+
|
| 532 |
+
logger.info(f"Loading dataset: {dataset_name} (split={split})")
|
| 533 |
+
|
| 534 |
+
load_kwargs = {
|
| 535 |
+
**kwargs,
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
if cache_dir:
|
| 539 |
+
load_kwargs["cache_dir"] = cache_dir
|
| 540 |
+
|
| 541 |
+
dataset = hf_load_dataset(dataset_name, **load_kwargs)
|
| 542 |
+
|
| 543 |
+
logger.info(f"Successfully loaded dataset: {dataset_name}")
|
| 544 |
+
return dataset
|
| 545 |
+
|
| 546 |
+
except ImportError:
|
| 547 |
+
logger.error("datasets library not installed. Run: pip install datasets")
|
| 548 |
+
raise ImportError("The datasets library is required but not installed. Install it with: pip install datasets")
|
| 549 |
+
except Exception as e:
|
| 550 |
+
logger.error(f"Failed to load dataset {dataset_name}: {e}")
|
| 551 |
+
raise
|
src/data/preprocessing.py
ADDED
|
@@ -0,0 +1,406 @@
|
| 1 |
+
"""
|
| 2 |
+
Text Preprocessing Module for Training Data.
|
| 3 |
+
|
| 4 |
+
Provides utilities for:
|
| 5 |
+
- Text cleaning and normalization
|
| 6 |
+
- Tokenization with various backends
|
| 7 |
+
- Feature extraction for meta-controller training
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import re
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class PreprocessedText:
|
| 20 |
+
"""Preprocessed text with metadata."""
|
| 21 |
+
|
| 22 |
+
original: str
|
| 23 |
+
cleaned: str
|
| 24 |
+
tokens: list[str]
|
| 25 |
+
token_ids: list[int] | None = None
|
| 26 |
+
features: dict[str, Any] | None = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TextPreprocessor:
|
| 30 |
+
"""
|
| 31 |
+
Text preprocessing pipeline for multi-agent training data.
|
| 32 |
+
|
| 33 |
+
Handles:
|
| 34 |
+
- HTML/XML tag removal
|
| 35 |
+
- Special character normalization
|
| 36 |
+
- Whitespace cleanup
|
| 37 |
+
- Domain-specific preprocessing (cyber, military, etc.)
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
# Patterns for cleaning
|
| 41 |
+
HTML_TAG_PATTERN = re.compile(r"<[^>]+>")
|
| 42 |
+
URL_PATTERN = re.compile(r"https?://\S+|www\.\S+")
|
| 43 |
+
MULTIPLE_SPACES = re.compile(r"\s+")
|
| 44 |
+
SPECIAL_CHARS = re.compile(r"[^\w\s\-.,!?;:()[\]{}\"'/]")
|
| 45 |
+
|
| 46 |
+
# Domain-specific patterns
|
| 47 |
+
IP_ADDRESS_PATTERN = re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b")
|
| 48 |
+
CVE_PATTERN = re.compile(r"CVE-\d{4}-\d{4,}")
|
| 49 |
+
MITRE_TECHNIQUE_PATTERN = re.compile(r"T\d{4}(?:\.\d{3})?")
|
| 50 |
+
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
remove_html: bool = True,
|
| 54 |
+
normalize_urls: bool = True,
|
| 55 |
+
lowercase: bool = False,
|
| 56 |
+
preserve_domain_patterns: bool = True,
|
| 57 |
+
):
|
| 58 |
+
"""
|
| 59 |
+
Initialize preprocessor.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
remove_html: Remove HTML/XML tags
|
| 63 |
+
normalize_urls: Replace URLs with placeholder
|
| 64 |
+
lowercase: Convert to lowercase
|
| 65 |
+
preserve_domain_patterns: Keep domain-specific patterns (IPs, CVEs, etc.)
|
| 66 |
+
"""
|
| 67 |
+
self.remove_html = remove_html
|
| 68 |
+
self.normalize_urls = normalize_urls
|
| 69 |
+
self.lowercase = lowercase
|
| 70 |
+
self.preserve_domain_patterns = preserve_domain_patterns
|
| 71 |
+
|
| 72 |
+
def clean(self, text: str) -> str:
|
| 73 |
+
"""
|
| 74 |
+
Clean and normalize text.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
text: Raw input text
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Cleaned text
|
| 81 |
+
"""
|
| 82 |
+
if not text:
|
| 83 |
+
return ""
|
| 84 |
+
|
| 85 |
+
result = text
|
| 86 |
+
|
| 87 |
+
# Remove HTML tags
|
| 88 |
+
if self.remove_html:
|
| 89 |
+
result = self.HTML_TAG_PATTERN.sub(" ", result)
|
| 90 |
+
|
| 91 |
+
# Preserve or normalize URLs
|
| 92 |
+
if self.normalize_urls:
|
| 93 |
+
if self.preserve_domain_patterns:
|
| 94 |
+
result = self.URL_PATTERN.sub("[URL]", result)
|
| 95 |
+
else:
|
| 96 |
+
result = self.URL_PATTERN.sub("", result)
|
| 97 |
+
|
| 98 |
+
# Normalize whitespace
|
| 99 |
+
result = self.MULTIPLE_SPACES.sub(" ", result)
|
| 100 |
+
|
| 101 |
+
# Lowercase if requested
|
| 102 |
+
if self.lowercase:
|
| 103 |
+
result = result.lower()
|
| 104 |
+
|
| 105 |
+
# Strip leading/trailing whitespace
|
| 106 |
+
result = result.strip()
|
| 107 |
+
|
| 108 |
+
return result
|
| 109 |
+
|
| 110 |
+
def extract_domain_features(self, text: str) -> dict[str, Any]:
|
| 111 |
+
"""
|
| 112 |
+
Extract domain-specific features from text.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
text: Input text
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
Dictionary of extracted features
|
| 119 |
+
"""
|
| 120 |
+
features = {
|
| 121 |
+
"has_ip_addresses": bool(self.IP_ADDRESS_PATTERN.search(text)),
|
| 122 |
+
"ip_count": len(self.IP_ADDRESS_PATTERN.findall(text)),
|
| 123 |
+
"has_cve": bool(self.CVE_PATTERN.search(text)),
|
| 124 |
+
"cve_ids": self.CVE_PATTERN.findall(text),
|
| 125 |
+
"has_mitre_techniques": bool(self.MITRE_TECHNIQUE_PATTERN.search(text)),
|
| 126 |
+
"mitre_techniques": self.MITRE_TECHNIQUE_PATTERN.findall(text),
|
| 127 |
+
"text_length": len(text),
|
| 128 |
+
"word_count": len(text.split()),
|
| 129 |
+
"sentence_count": len(re.findall(r"[.!?]+", text)),
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
# Detect domain indicators
|
| 133 |
+
domain_keywords = {
|
| 134 |
+
"cybersecurity": ["attack", "vulnerability", "exploit", "malware", "threat"],
|
| 135 |
+
"military": ["tactical", "reconnaissance", "deployment", "terrain", "objective"],
|
| 136 |
+
"data_analysis": ["dataset", "analysis", "correlation", "statistics", "visualization"],
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
for domain, keywords in domain_keywords.items():
|
| 140 |
+
features[f"is_{domain}"] = any(kw in text.lower() for kw in keywords)
|
| 141 |
+
|
| 142 |
+
return features
|
| 143 |
+
|
| 144 |
+
def preprocess(self, text: str) -> PreprocessedText:
|
| 145 |
+
"""
|
| 146 |
+
Full preprocessing pipeline.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
text: Raw input text
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
PreprocessedText object with all preprocessing results
|
| 153 |
+
"""
|
| 154 |
+
cleaned = self.clean(text)
|
| 155 |
+
tokens = cleaned.split() # Simple whitespace tokenization
|
| 156 |
+
features = self.extract_domain_features(text)
|
| 157 |
+
|
| 158 |
+
return PreprocessedText(
|
| 159 |
+
original=text,
|
| 160 |
+
cleaned=cleaned,
|
| 161 |
+
tokens=tokens,
|
| 162 |
+
features=features,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def batch_preprocess(self, texts: list[str]) -> list[PreprocessedText]:
|
| 166 |
+
"""
|
| 167 |
+
Preprocess multiple texts.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
texts: List of raw texts
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
List of PreprocessedText objects
|
| 174 |
+
"""
|
| 175 |
+
return [self.preprocess(text) for text in texts]
|
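A short, self-contained sketch of the cleaning pipeline above (import path assumed); the regex patterns defined earlier in the class drive the extracted features.

```python
# Minimal sketch: clean a noisy snippet and inspect the domain features.
from src.data.preprocessing import TextPreprocessor

pre = TextPreprocessor(remove_html=True, normalize_urls=True, preserve_domain_patterns=True)

raw = "<p>CVE-2024-12345 exploited from 10.0.0.5 via https://example.com using T1059.001</p>"
result = pre.preprocess(raw)

print(result.cleaned)                        # tags stripped, URL replaced with [URL]
print(result.features["cve_ids"])            # ['CVE-2024-12345']
print(result.features["mitre_techniques"])   # ['T1059.001']
print(result.features["is_cybersecurity"])   # True ("exploited" matches a domain keyword)
```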
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class TokenizerWrapper:
|
| 179 |
+
"""
|
| 180 |
+
Wrapper for various tokenization backends.
|
| 181 |
+
|
| 182 |
+
Supports:
|
| 183 |
+
- Simple whitespace tokenization
|
| 184 |
+
- HuggingFace tokenizers
|
| 185 |
+
- Custom vocabularies
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
backend: str = "simple",
|
| 191 |
+
model_name: str | None = None,
|
| 192 |
+
max_length: int = 512,
|
| 193 |
+
):
|
| 194 |
+
"""
|
| 195 |
+
Initialize tokenizer.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
backend: Tokenizer backend ('simple', 'huggingface', 'custom')
|
| 199 |
+
model_name: Model name for HuggingFace tokenizer
|
| 200 |
+
max_length: Maximum sequence length
|
| 201 |
+
"""
|
| 202 |
+
self.backend = backend
|
| 203 |
+
self.model_name = model_name
|
| 204 |
+
self.max_length = max_length
|
| 205 |
+
self._tokenizer = None
|
| 206 |
+
|
| 207 |
+
if backend == "huggingface" and model_name:
|
| 208 |
+
self._load_huggingface_tokenizer()
|
| 209 |
+
|
| 210 |
+
def _load_huggingface_tokenizer(self):
|
| 211 |
+
"""Load HuggingFace tokenizer."""
|
| 212 |
+
try:
|
| 213 |
+
from transformers import AutoTokenizer
|
| 214 |
+
|
| 215 |
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
| 216 |
+
self.model_name,
|
| 217 |
+
model_max_length=self.max_length,
|
| 218 |
+
)
|
| 219 |
+
logger.info(f"Loaded HuggingFace tokenizer: {self.model_name}")
|
| 220 |
+
except ImportError:
|
| 221 |
+
logger.error("transformers library not installed. Run: pip install transformers")
|
| 222 |
+
raise
|
| 223 |
+
|
| 224 |
+
def tokenize(self, text: str) -> tuple[list[str], list[int] | None]:
|
| 225 |
+
"""
|
| 226 |
+
Tokenize text.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
text: Input text
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
Tuple of (tokens, token_ids)
|
| 233 |
+
"""
|
| 234 |
+
if self.backend == "simple":
|
| 235 |
+
tokens = text.split()[: self.max_length]
|
| 236 |
+
return tokens, None
|
| 237 |
+
|
| 238 |
+
elif self.backend == "huggingface" and self._tokenizer:
|
| 239 |
+
encoded = self._tokenizer(
|
| 240 |
+
text,
|
| 241 |
+
truncation=True,
|
| 242 |
+
max_length=self.max_length,
|
| 243 |
+
return_tensors=None,
|
| 244 |
+
)
|
| 245 |
+
tokens = self._tokenizer.convert_ids_to_tokens(encoded["input_ids"])
|
| 246 |
+
token_ids = encoded["input_ids"]
|
| 247 |
+
return tokens, token_ids
|
| 248 |
+
|
| 249 |
+
else:
|
| 250 |
+
raise ValueError(f"Unsupported backend: {self.backend}")
|
| 251 |
+
|
| 252 |
+
def batch_tokenize(self, texts: list[str]) -> list[tuple[list[str], list[int] | None]]:
|
| 253 |
+
"""
|
| 254 |
+
Tokenize multiple texts.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
texts: List of input texts
|
| 258 |
+
|
| 259 |
+
Returns:
|
| 260 |
+
List of (tokens, token_ids) tuples
|
| 261 |
+
"""
|
| 262 |
+
return [self.tokenize(text) for text in texts]
|
| 263 |
+
|
| 264 |
+
def encode_for_training(self, texts: list[str]) -> dict[str, Any]:
|
| 265 |
+
"""
|
| 266 |
+
Encode texts for model training.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
texts: List of input texts
|
| 270 |
+
|
| 271 |
+
Returns:
|
| 272 |
+
Dictionary with encoded data ready for training
|
| 273 |
+
"""
|
| 274 |
+
if self.backend != "huggingface" or not self._tokenizer:
|
| 275 |
+
raise ValueError("encode_for_training requires HuggingFace backend")
|
| 276 |
+
|
| 277 |
+
encoded = self._tokenizer(
|
| 278 |
+
texts,
|
| 279 |
+
truncation=True,
|
| 280 |
+
padding=True,
|
| 281 |
+
max_length=self.max_length,
|
| 282 |
+
return_tensors="pt",
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
return encoded
|
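A quick sketch contrasting the two tokenizer backends; the `simple` backend needs no external dependencies, while the `huggingface` backend requires the transformers library and a valid model name (`bert-base-uncased` here is only an illustration).

```python
# Minimal sketch: whitespace tokenization vs. an optional HuggingFace backend.
from src.data.preprocessing import TokenizerWrapper

simple = TokenizerWrapper(backend="simple", max_length=128)
tokens, token_ids = simple.tokenize("Assess lateral movement risk on the finance VLAN")
print(tokens)      # whitespace tokens, truncated to max_length
print(token_ids)   # None for the simple backend

# Requires `pip install transformers`:
# hf = TokenizerWrapper(backend="huggingface", model_name="bert-base-uncased")
# batch = hf.encode_for_training(["query one", "query two"])  # padded "pt" tensors
```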
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class MetaControllerFeatureExtractor:
|
| 289 |
+
"""
|
| 290 |
+
Extract features for meta-controller training.
|
| 291 |
+
|
| 292 |
+
Converts text and agent state information into numerical features
|
| 293 |
+
suitable for RNN/BERT routing decisions.
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
def __init__(self):
|
| 297 |
+
"""Initialize feature extractor."""
|
| 298 |
+
self.preprocessor = TextPreprocessor()
|
| 299 |
+
|
| 300 |
+
def extract_query_features(self, query: str) -> dict[str, float]:
|
| 301 |
+
"""
|
| 302 |
+
Extract numerical features from query text.
|
| 303 |
+
|
| 304 |
+
Args:
|
| 305 |
+
query: User query text
|
| 306 |
+
|
| 307 |
+
Returns:
|
| 308 |
+
Dictionary of numerical features
|
| 309 |
+
"""
|
| 310 |
+
domain_features = self.preprocessor.extract_domain_features(query)
|
| 311 |
+
|
| 312 |
+
features = {
|
| 313 |
+
"query_length": domain_features["text_length"] / 10000, # Normalize
|
| 314 |
+
"word_count": domain_features["word_count"] / 500,
|
| 315 |
+
"sentence_count": domain_features["sentence_count"] / 50,
|
| 316 |
+
"has_technical_terms": float(
|
| 317 |
+
domain_features["has_ip_addresses"]
|
| 318 |
+
or domain_features["has_cve"]
|
| 319 |
+
or domain_features["has_mitre_techniques"]
|
| 320 |
+
),
|
| 321 |
+
"is_cybersecurity": float(domain_features["is_cybersecurity"]),
|
| 322 |
+
"is_military": float(domain_features["is_military"]),
|
| 323 |
+
"is_data_analysis": float(domain_features["is_data_analysis"]),
|
| 324 |
+
"complexity_score": self._estimate_complexity(query),
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
return features
|
| 328 |
+
|
| 329 |
+
def _estimate_complexity(self, text: str) -> float:
|
| 330 |
+
"""
|
| 331 |
+
Estimate query complexity (0-1 scale).
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
text: Input text
|
| 335 |
+
|
| 336 |
+
Returns:
|
| 337 |
+
Complexity score
|
| 338 |
+
"""
|
| 339 |
+
# Simple heuristic based on length, technical terms, etc.
|
| 340 |
+
score = 0.0
|
| 341 |
+
|
| 342 |
+
# Length factor
|
| 343 |
+
word_count = len(text.split())
|
| 344 |
+
if word_count > 50:
|
| 345 |
+
score += 0.3
|
| 346 |
+
elif word_count > 20:
|
| 347 |
+
score += 0.1
|
| 348 |
+
|
| 349 |
+
# Technical term factor
|
| 350 |
+
technical_indicators = [
|
| 351 |
+
"analyze",
|
| 352 |
+
"compare",
|
| 353 |
+
"evaluate",
|
| 354 |
+
"synthesize",
|
| 355 |
+
"strategic",
|
| 356 |
+
"tactical",
|
| 357 |
+
"multi-step",
|
| 358 |
+
"consider",
|
| 359 |
+
]
|
| 360 |
+
for term in technical_indicators:
|
| 361 |
+
if term in text.lower():
|
| 362 |
+
score += 0.1
|
| 363 |
+
|
| 364 |
+
# Question complexity
|
| 365 |
+
if "?" in text:
|
| 366 |
+
if any(kw in text.lower() for kw in ["why", "how", "what if"]):
|
| 367 |
+
score += 0.2
|
| 368 |
+
else:
|
| 369 |
+
score += 0.1
|
| 370 |
+
|
| 371 |
+
return min(score, 1.0)
|
| 372 |
+
|
| 373 |
+
def extract_agent_state_features(
|
| 374 |
+
self,
|
| 375 |
+
hrm_confidence: float = 0.0,
|
| 376 |
+
trm_confidence: float = 0.0,
|
| 377 |
+
mcts_iterations: int = 0,
|
| 378 |
+
consensus_score: float = 0.0,
|
| 379 |
+
rag_retrieved: int = 0,
|
| 380 |
+
) -> list[float]:
|
| 381 |
+
"""
|
| 382 |
+
Extract features from current agent state.
|
| 383 |
+
|
| 384 |
+
Args:
|
| 385 |
+
hrm_confidence: HRM agent confidence
|
| 386 |
+
trm_confidence: TRM agent confidence
|
| 387 |
+
mcts_iterations: MCTS iterations completed
|
| 388 |
+
consensus_score: Inter-agent consensus
|
| 389 |
+
rag_retrieved: Number of RAG documents retrieved
|
| 390 |
+
|
| 391 |
+
Returns:
|
| 392 |
+
List of normalized features (10-dimensional)
|
| 393 |
+
"""
|
| 394 |
+
return [
|
| 395 |
+
hrm_confidence,
|
| 396 |
+
trm_confidence,
|
| 397 |
+
min(mcts_iterations / 1000, 1.0),
|
| 398 |
+
consensus_score,
|
| 399 |
+
min(rag_retrieved / 20, 1.0),
|
| 400 |
+
# Derived features
|
| 401 |
+
abs(hrm_confidence - trm_confidence), # Disagreement
|
| 402 |
+
(hrm_confidence + trm_confidence) / 2, # Average confidence
|
| 403 |
+
float(mcts_iterations > 0), # MCTS active
|
| 404 |
+
float(consensus_score > 0.7), # High consensus
|
| 405 |
+
float(rag_retrieved > 0), # RAG used
|
| 406 |
+
]
|
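To make the routing-feature contract concrete, a hedged sketch of the extractor above (import path assumed); the query features are normalized floats and the agent-state vector is the 10-dimensional input described in the docstring.

```python
# Minimal sketch: build meta-controller features from a query and an agent snapshot.
from src.data.preprocessing import MetaControllerFeatureExtractor

extractor = MetaControllerFeatureExtractor()

query_features = extractor.extract_query_features(
    "How should we respond to CVE-2024-12345 being exploited for lateral movement?"
)
print(query_features["is_cybersecurity"], query_features["complexity_score"])

state_vector = extractor.extract_agent_state_features(
    hrm_confidence=0.8,
    trm_confidence=0.6,
    mcts_iterations=250,
    consensus_score=0.75,
    rag_retrieved=5,
)
print(len(state_vector))  # 10 features, all normalized to [0, 1]
```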
src/data/tactical_augmentation.py
ADDED
|
@@ -0,0 +1,484 @@
|
| 1 |
+
"""
|
| 2 |
+
Tactical Data Augmentation Module.
|
| 3 |
+
|
| 4 |
+
Provides domain-specific data augmentation techniques for:
|
| 5 |
+
- Cybersecurity threat scenarios
|
| 6 |
+
- Military tactical situations
|
| 7 |
+
- Multi-step reasoning problems
|
| 8 |
+
|
| 9 |
+
These augmentations help increase training data diversity and improve
|
| 10 |
+
model robustness for tactical analysis tasks.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import random
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
|
| 17 |
+
from .dataset_loader import DatasetSample
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class AugmentationResult:
|
| 24 |
+
"""Result of data augmentation."""
|
| 25 |
+
|
| 26 |
+
original: DatasetSample
|
| 27 |
+
augmented: list[DatasetSample]
|
| 28 |
+
augmentation_types: list[str]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TacticalAugmenter:
|
| 32 |
+
"""
|
| 33 |
+
Domain-specific data augmentation for tactical analysis.
|
| 34 |
+
|
| 35 |
+
Augmentation techniques:
|
| 36 |
+
- Paraphrasing tactical scenarios
|
| 37 |
+
- Varying urgency levels
|
| 38 |
+
- Adding/removing constraints
|
| 39 |
+
- Scenario parameter variation
|
| 40 |
+
- Threat actor substitution
|
| 41 |
+
- Temporal shifting
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
# Tactical scenario templates
|
| 45 |
+
URGENCY_MODIFIERS = {
|
| 46 |
+
"high": ["IMMEDIATE", "CRITICAL", "URGENT", "TIME-SENSITIVE"],
|
| 47 |
+
"medium": ["PRIORITY", "IMPORTANT", "ATTENTION REQUIRED"],
|
| 48 |
+
"low": ["ROUTINE", "STANDARD", "WHEN POSSIBLE"],
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
THREAT_ACTORS = [
|
| 52 |
+
"APT28",
|
| 53 |
+
"APT29",
|
| 54 |
+
"Lazarus Group",
|
| 55 |
+
"Cozy Bear",
|
| 56 |
+
"Fancy Bear",
|
| 57 |
+
"Unknown Actor",
|
| 58 |
+
"Nation-State Actor",
|
| 59 |
+
"Criminal Organization",
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
ATTACK_VECTORS = [
|
| 63 |
+
"phishing",
|
| 64 |
+
"spear-phishing",
|
| 65 |
+
"watering hole",
|
| 66 |
+
"supply chain compromise",
|
| 67 |
+
"zero-day exploit",
|
| 68 |
+
"credential stuffing",
|
| 69 |
+
"brute force",
|
| 70 |
+
"social engineering",
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
MILITARY_OBJECTIVES = [
|
| 74 |
+
"secure perimeter",
|
| 75 |
+
"establish forward position",
|
| 76 |
+
"conduct reconnaissance",
|
| 77 |
+
"neutralize threat",
|
| 78 |
+
"protect assets",
|
| 79 |
+
"maintain operational security",
|
| 80 |
+
"coordinate with allied forces",
|
| 81 |
+
"execute tactical withdrawal",
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
ENVIRONMENTAL_CONDITIONS = [
|
| 85 |
+
"night operations",
|
| 86 |
+
"adverse weather",
|
| 87 |
+
"limited visibility",
|
| 88 |
+
"urban terrain",
|
| 89 |
+
"mountainous region",
|
| 90 |
+
"coastal area",
|
| 91 |
+
"contested airspace",
|
| 92 |
+
"electronic warfare environment",
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
def __init__(self, seed: int = 42):
|
| 96 |
+
"""
|
| 97 |
+
Initialize augmenter.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
seed: Random seed for reproducibility
|
| 101 |
+
"""
|
| 102 |
+
self.rng = random.Random(seed)
|
| 103 |
+
self._augmentation_count = 0
|
| 104 |
+
|
| 105 |
+
def augment_sample(
|
| 106 |
+
self,
|
| 107 |
+
sample: DatasetSample,
|
| 108 |
+
num_augmentations: int = 3,
|
| 109 |
+
techniques: list[str] | None = None,
|
| 110 |
+
) -> AugmentationResult:
|
| 111 |
+
"""
|
| 112 |
+
Augment a single sample.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
sample: Original dataset sample
|
| 116 |
+
num_augmentations: Number of augmented versions to create
|
| 117 |
+
techniques: Specific techniques to use (None for random selection)
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
AugmentationResult with augmented samples
|
| 121 |
+
"""
|
| 122 |
+
available_techniques = [
|
| 123 |
+
"urgency_variation",
|
| 124 |
+
"parameter_substitution",
|
| 125 |
+
"constraint_addition",
|
| 126 |
+
"temporal_shift",
|
| 127 |
+
"perspective_change",
|
| 128 |
+
]
|
| 129 |
+
|
| 130 |
+
if techniques:
|
| 131 |
+
available_techniques = [t for t in techniques if t in available_techniques]
|
| 132 |
+
|
| 133 |
+
augmented_samples = []
|
| 134 |
+
used_techniques = []
|
| 135 |
+
|
| 136 |
+
for _i in range(num_augmentations):
|
| 137 |
+
technique = self.rng.choice(available_techniques)
|
| 138 |
+
used_techniques.append(technique)
|
| 139 |
+
|
| 140 |
+
augmented_text = self._apply_technique(sample.text, sample.domain, technique)
|
| 141 |
+
|
| 142 |
+
aug_sample = DatasetSample(
|
| 143 |
+
id=f"{sample.id}_aug_{self._augmentation_count}",
|
| 144 |
+
text=augmented_text,
|
| 145 |
+
metadata={
|
| 146 |
+
**sample.metadata,
|
| 147 |
+
"augmentation": technique,
|
| 148 |
+
"original_id": sample.id,
|
| 149 |
+
},
|
| 150 |
+
labels=sample.labels,
|
| 151 |
+
difficulty=sample.difficulty,
|
| 152 |
+
domain=sample.domain,
|
| 153 |
+
reasoning_steps=sample.reasoning_steps,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
augmented_samples.append(aug_sample)
|
| 157 |
+
self._augmentation_count += 1
|
| 158 |
+
|
| 159 |
+
return AugmentationResult(
|
| 160 |
+
original=sample,
|
| 161 |
+
augmented=augmented_samples,
|
| 162 |
+
augmentation_types=used_techniques,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def _apply_technique(self, text: str, domain: str | None, technique: str) -> str:
|
| 166 |
+
"""Apply specific augmentation technique."""
|
| 167 |
+
if technique == "urgency_variation":
|
| 168 |
+
return self._augment_urgency(text)
|
| 169 |
+
elif technique == "parameter_substitution":
|
| 170 |
+
return self._augment_parameters(text, domain)
|
| 171 |
+
elif technique == "constraint_addition":
|
| 172 |
+
return self._augment_constraints(text, domain)
|
| 173 |
+
elif technique == "temporal_shift":
|
| 174 |
+
return self._augment_temporal(text)
|
| 175 |
+
elif technique == "perspective_change":
|
| 176 |
+
return self._augment_perspective(text, domain)
|
| 177 |
+
else:
|
| 178 |
+
return text
|
| 179 |
+
|
| 180 |
+
def _augment_urgency(self, text: str) -> str:
|
| 181 |
+
"""Vary urgency level in the text."""
|
| 182 |
+
urgency_level = self.rng.choice(list(self.URGENCY_MODIFIERS.keys()))
|
| 183 |
+
modifier = self.rng.choice(self.URGENCY_MODIFIERS[urgency_level])
|
| 184 |
+
|
| 185 |
+
# Add urgency prefix
|
| 186 |
+
if urgency_level == "high":
|
| 187 |
+
return f"[{modifier}] {text}"
|
| 188 |
+
elif urgency_level == "medium":
|
| 189 |
+
return f"{modifier}: {text}"
|
| 190 |
+
else:
|
| 191 |
+
return f"({modifier}) {text}"
|
| 192 |
+
|
| 193 |
+
def _augment_parameters(self, text: str, domain: str | None) -> str:
|
| 194 |
+
"""Substitute domain-specific parameters."""
|
| 195 |
+
if domain == "cybersecurity" or "cyber" in text.lower():
|
| 196 |
+
# Substitute threat actors
|
| 197 |
+
for actor in self.THREAT_ACTORS:
|
| 198 |
+
if actor in text:
|
| 199 |
+
new_actor = self.rng.choice([a for a in self.THREAT_ACTORS if a != actor])
|
| 200 |
+
text = text.replace(actor, new_actor)
|
| 201 |
+
break
|
| 202 |
+
|
| 203 |
+
# Substitute attack vectors
|
| 204 |
+
for vector in self.ATTACK_VECTORS:
|
| 205 |
+
if vector in text.lower():
|
| 206 |
+
new_vector = self.rng.choice([v for v in self.ATTACK_VECTORS if v != vector])
|
| 207 |
+
text = text.replace(vector, new_vector)
|
| 208 |
+
break
|
| 209 |
+
|
| 210 |
+
elif domain == "military" or any(kw in text.lower() for kw in ["tactical", "military", "reconnaissance"]):
|
| 211 |
+
# Substitute objectives
|
| 212 |
+
for obj in self.MILITARY_OBJECTIVES:
|
| 213 |
+
if obj in text.lower():
|
| 214 |
+
new_obj = self.rng.choice([o for o in self.MILITARY_OBJECTIVES if o != obj])
|
| 215 |
+
text = text.replace(obj, new_obj)
|
| 216 |
+
break
|
| 217 |
+
|
| 218 |
+
return text
|
| 219 |
+
|
| 220 |
+
def _augment_constraints(self, text: str, domain: str | None) -> str:
|
| 221 |
+
"""Add additional constraints to the scenario."""
|
| 222 |
+
constraints = []
|
| 223 |
+
|
| 224 |
+
if domain == "cybersecurity":
|
| 225 |
+
constraints = [
|
| 226 |
+
"with limited network visibility",
|
| 227 |
+
"under active attack",
|
| 228 |
+
"with compromised credentials",
|
| 229 |
+
"during maintenance window",
|
| 230 |
+
"with restricted access to logs",
|
| 231 |
+
]
|
| 232 |
+
elif domain == "military":
|
| 233 |
+
constraints = [
|
| 234 |
+
"with limited ammunition",
|
| 235 |
+
"under communication blackout",
|
| 236 |
+
"with reduced personnel",
|
| 237 |
+
"in contested environment",
|
| 238 |
+
"with time constraint of 2 hours",
|
| 239 |
+
]
|
| 240 |
+
else:
|
| 241 |
+
constraints = [
|
| 242 |
+
"with incomplete information",
|
| 243 |
+
"under time pressure",
|
| 244 |
+
"with resource constraints",
|
| 245 |
+
"considering multiple stakeholders",
|
| 246 |
+
"with conflicting objectives",
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
if constraints:
|
| 250 |
+
constraint = self.rng.choice(constraints)
|
| 251 |
+
return f"{text} [{constraint}]"
|
| 252 |
+
|
| 253 |
+
return text
|
| 254 |
+
|
| 255 |
+
def _augment_temporal(self, text: str) -> str:
|
| 256 |
+
"""Shift temporal context."""
|
| 257 |
+
temporal_contexts = [
|
| 258 |
+
"In the past 24 hours, ",
|
| 259 |
+
"Over the next week, ",
|
| 260 |
+
"Immediately, ",
|
| 261 |
+
"During the upcoming operation, ",
|
| 262 |
+
"Following initial assessment, ",
|
| 263 |
+
]
|
| 264 |
+
|
| 265 |
+
context = self.rng.choice(temporal_contexts)
|
| 266 |
+
return f"{context}{text.lower()}" if text else text
|
| 267 |
+
|
| 268 |
+
def _augment_perspective(self, text: str, domain: str | None) -> str:
|
| 269 |
+
"""Change analytical perspective."""
|
| 270 |
+
perspectives = {
|
| 271 |
+
"cybersecurity": [
|
| 272 |
+
"From a threat hunter's perspective: ",
|
| 273 |
+
"Considering the attacker's viewpoint: ",
|
| 274 |
+
"For incident response purposes: ",
|
| 275 |
+
"From a risk management standpoint: ",
|
| 276 |
+
],
|
| 277 |
+
"military": [
|
| 278 |
+
"From the commander's perspective: ",
|
| 279 |
+
"Considering enemy capabilities: ",
|
| 280 |
+
"For tactical planning purposes: ",
|
| 281 |
+
"From a logistics standpoint: ",
|
| 282 |
+
],
|
| 283 |
+
"default": [
|
| 284 |
+
"From an analytical perspective: ",
|
| 285 |
+
"Considering all factors: ",
|
| 286 |
+
"For decision-making purposes: ",
|
| 287 |
+
"From a strategic viewpoint: ",
|
| 288 |
+
],
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
domain_perspectives = perspectives.get(domain or "default", perspectives["default"])
|
| 292 |
+
perspective = self.rng.choice(domain_perspectives)
|
| 293 |
+
|
| 294 |
+
return f"{perspective}{text}"
|
| 295 |
+
|
| 296 |
+
def augment_batch(
|
| 297 |
+
self,
|
| 298 |
+
samples: list[DatasetSample],
|
| 299 |
+
augmentations_per_sample: int = 2,
|
| 300 |
+
) -> list[DatasetSample]:
|
| 301 |
+
"""
|
| 302 |
+
Augment a batch of samples.
|
| 303 |
+
|
| 304 |
+
Args:
|
| 305 |
+
samples: List of original samples
|
| 306 |
+
augmentations_per_sample: Number of augmentations per sample
|
| 307 |
+
|
| 308 |
+
Returns:
|
| 309 |
+
List of all samples (original + augmented)
|
| 310 |
+
"""
|
| 311 |
+
all_samples = list(samples) # Keep originals
|
| 312 |
+
|
| 313 |
+
for sample in samples:
|
| 314 |
+
result = self.augment_sample(sample, num_augmentations=augmentations_per_sample)
|
| 315 |
+
all_samples.extend(result.augmented)
|
| 316 |
+
|
| 317 |
+
logger.info(
|
| 318 |
+
f"Augmented {len(samples)} samples to {len(all_samples)} (+{len(all_samples) - len(samples)} augmented)"
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
return all_samples
|
| 322 |
+
|
| 323 |
+
def create_tactical_scenarios(self, base_samples: list[DatasetSample]) -> list[DatasetSample]:
|
| 324 |
+
"""
|
| 325 |
+
Create tactical scenario variations from base samples.
|
| 326 |
+
|
| 327 |
+
Combines multiple augmentation techniques to create
|
| 328 |
+
diverse tactical scenarios for training.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
base_samples: Base dataset samples
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
Extended list with tactical scenario variations
|
| 335 |
+
"""
|
| 336 |
+
scenarios = list(base_samples)
|
| 337 |
+
|
| 338 |
+
for sample in base_samples:
|
| 339 |
+
# Create high-stakes variant
|
| 340 |
+
high_stakes = self._augment_urgency(sample.text)
|
| 341 |
+
high_stakes = self._augment_constraints(high_stakes, sample.domain)
|
| 342 |
+
scenarios.append(
|
| 343 |
+
DatasetSample(
|
| 344 |
+
id=f"{sample.id}_highstakes_{self._augmentation_count}",
|
| 345 |
+
text=high_stakes,
|
| 346 |
+
metadata={
|
| 347 |
+
**sample.metadata,
|
| 348 |
+
"scenario_type": "high_stakes",
|
| 349 |
+
"original_id": sample.id,
|
| 350 |
+
},
|
| 351 |
+
labels=sample.labels,
|
| 352 |
+
difficulty="hard", # High stakes scenarios are harder
|
| 353 |
+
domain=sample.domain,
|
| 354 |
+
reasoning_steps=sample.reasoning_steps,
|
| 355 |
+
)
|
| 356 |
+
)
|
| 357 |
+
self._augmentation_count += 1
|
| 358 |
+
|
| 359 |
+
# Create multi-perspective variant
|
| 360 |
+
if self.rng.random() > 0.5:
|
| 361 |
+
multi_perspective = self._augment_perspective(sample.text, sample.domain)
|
| 362 |
+
scenarios.append(
|
| 363 |
+
DatasetSample(
|
| 364 |
+
id=f"{sample.id}_multiperspective_{self._augmentation_count}",
|
| 365 |
+
text=multi_perspective,
|
| 366 |
+
metadata={
|
| 367 |
+
**sample.metadata,
|
| 368 |
+
"scenario_type": "multi_perspective",
|
| 369 |
+
"original_id": sample.id,
|
| 370 |
+
},
|
| 371 |
+
labels=sample.labels,
|
| 372 |
+
difficulty=sample.difficulty,
|
| 373 |
+
domain=sample.domain,
|
| 374 |
+
reasoning_steps=sample.reasoning_steps,
|
| 375 |
+
)
|
| 376 |
+
)
|
| 377 |
+
self._augmentation_count += 1
|
| 378 |
+
|
| 379 |
+
logger.info(f"Created {len(scenarios) - len(base_samples)} tactical scenarios")
|
| 380 |
+
return scenarios
|
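A minimal sketch of batch augmentation with the class above; the import paths and the `DatasetSample` fields beyond id/text are illustrative defaults, not values from this commit.

```python
# Minimal sketch: expand one cybersecurity sample with two augmented variants.
from src.data.dataset_loader import DatasetSample
from src.data.tactical_augmentation import TacticalAugmenter

base = DatasetSample(
    id="demo_001",
    text="Adversary used phishing against the VPN gateway; recommend containment steps.",
    metadata={"source": "demo"},
    labels=[],
    domain="cybersecurity",
)

augmenter = TacticalAugmenter(seed=7)
expanded = augmenter.augment_batch([base], augmentations_per_sample=2)

print(len(expanded))  # 3: the original plus two augmented variants
for s in expanded:
    print(s.id, s.metadata.get("augmentation", "original"))
```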
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class CyberSecurityAugmenter(TacticalAugmenter):
|
| 384 |
+
"""
|
| 385 |
+
Specialized augmenter for cybersecurity scenarios.
|
| 386 |
+
|
| 387 |
+
Focuses on:
|
| 388 |
+
- MITRE ATT&CK technique variations
|
| 389 |
+
- Threat intelligence context
|
| 390 |
+
- Incident response scenarios
|
| 391 |
+
"""
|
| 392 |
+
|
| 393 |
+
MITRE_TACTICS = [
|
| 394 |
+
"Initial Access",
|
| 395 |
+
"Execution",
|
| 396 |
+
"Persistence",
|
| 397 |
+
"Privilege Escalation",
|
| 398 |
+
"Defense Evasion",
|
| 399 |
+
"Credential Access",
|
| 400 |
+
"Discovery",
|
| 401 |
+
"Lateral Movement",
|
| 402 |
+
"Collection",
|
| 403 |
+
"Exfiltration",
|
| 404 |
+
"Impact",
|
| 405 |
+
]
|
| 406 |
+
|
| 407 |
+
SEVERITY_LEVELS = ["LOW", "MEDIUM", "HIGH", "CRITICAL"]
|
| 408 |
+
|
| 409 |
+
def augment_with_mitre_context(self, sample: DatasetSample) -> DatasetSample:
|
| 410 |
+
"""
|
| 411 |
+
Add MITRE ATT&CK context to sample.
|
| 412 |
+
|
| 413 |
+
Args:
|
| 414 |
+
sample: Original sample
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
Augmented sample with MITRE context
|
| 418 |
+
"""
|
| 419 |
+
tactic = self.rng.choice(self.MITRE_TACTICS)
|
| 420 |
+
severity = self.rng.choice(self.SEVERITY_LEVELS)
|
| 421 |
+
|
| 422 |
+
augmented_text = f"[MITRE ATT&CK: {tactic}] [Severity: {severity}] {sample.text}"
|
| 423 |
+
|
| 424 |
+
return DatasetSample(
|
| 425 |
+
id=f"{sample.id}_mitre_{self._augmentation_count}",
|
| 426 |
+
text=augmented_text,
|
| 427 |
+
metadata={
|
| 428 |
+
**sample.metadata,
|
| 429 |
+
"mitre_tactic": tactic,
|
| 430 |
+
"severity": severity,
|
| 431 |
+
},
|
| 432 |
+
labels=sample.labels,
|
| 433 |
+
difficulty=sample.difficulty,
|
| 434 |
+
domain="cybersecurity",
|
| 435 |
+
reasoning_steps=sample.reasoning_steps,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
class MilitaryTacticalAugmenter(TacticalAugmenter):
|
| 440 |
+
"""
|
| 441 |
+
Specialized augmenter for military tactical scenarios.
|
| 442 |
+
|
| 443 |
+
Focuses on:
|
| 444 |
+
- Environmental condition variations
|
| 445 |
+
- Force composition changes
|
| 446 |
+
- Mission objective variations
|
| 447 |
+
"""
|
| 448 |
+
|
| 449 |
+
FORCE_COMPOSITIONS = [
|
| 450 |
+
"infantry platoon",
|
| 451 |
+
"mechanized company",
|
| 452 |
+
"special operations team",
|
| 453 |
+
"combined arms battalion",
|
| 454 |
+
"air assault element",
|
| 455 |
+
]
|
| 456 |
+
|
| 457 |
+
def augment_with_force_composition(self, sample: DatasetSample) -> DatasetSample:
|
| 458 |
+
"""
|
| 459 |
+
Add force composition context to sample.
|
| 460 |
+
|
| 461 |
+
Args:
|
| 462 |
+
sample: Original sample
|
| 463 |
+
|
| 464 |
+
Returns:
|
| 465 |
+
Augmented sample with force composition
|
| 466 |
+
"""
|
| 467 |
+
force = self.rng.choice(self.FORCE_COMPOSITIONS)
|
| 468 |
+
condition = self.rng.choice(self.ENVIRONMENTAL_CONDITIONS)
|
| 469 |
+
|
| 470 |
+
augmented_text = f"[Force: {force}] [Conditions: {condition}] {sample.text}"
|
| 471 |
+
|
| 472 |
+
return DatasetSample(
|
| 473 |
+
id=f"{sample.id}_tactical_{self._augmentation_count}",
|
| 474 |
+
text=augmented_text,
|
| 475 |
+
metadata={
|
| 476 |
+
**sample.metadata,
|
| 477 |
+
"force_composition": force,
|
| 478 |
+
"environmental_conditions": condition,
|
| 479 |
+
},
|
| 480 |
+
labels=sample.labels,
|
| 481 |
+
difficulty=sample.difficulty,
|
| 482 |
+
domain="military",
|
| 483 |
+
reasoning_steps=sample.reasoning_steps,
|
| 484 |
+
)
|
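The two specialized subclasses add structured context prefixes; a hedged sketch with illustrative samples only:

```python
# Minimal sketch: MITRE/severity context for cyber samples, force/conditions for military ones.
from src.data.dataset_loader import DatasetSample
from src.data.tactical_augmentation import CyberSecurityAugmenter, MilitaryTacticalAugmenter

cyber = CyberSecurityAugmenter(seed=3).augment_with_mitre_context(
    DatasetSample(id="cyber_001", text="Suspicious PowerShell activity on a domain controller.",
                  metadata={}, labels=[], domain="cybersecurity")
)
print(cyber.text)  # "[MITRE ATT&CK: <tactic>] [Severity: <level>] ..."

military = MilitaryTacticalAugmenter(seed=3).augment_with_force_composition(
    DatasetSample(id="mil_001", text="Plan a route to secure the bridge before dawn.",
                  metadata={}, labels=[], domain="military")
)
print(military.metadata["force_composition"], military.metadata["environmental_conditions"])
```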
src/data/train_test_split.py
ADDED
|
@@ -0,0 +1,505 @@
|
| 1 |
+
"""
|
| 2 |
+
Data Splitting Module for Training Pipeline.
|
| 3 |
+
|
| 4 |
+
Provides utilities for:
|
| 5 |
+
- Train/validation/test splitting
|
| 6 |
+
- Stratified sampling by domain or difficulty
|
| 7 |
+
- Cross-validation fold creation
|
| 8 |
+
- Reproducible splits with seeding
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from collections import defaultdict
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
from .dataset_loader import DatasetSample
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class DataSplit:
|
| 23 |
+
"""Result of dataset splitting."""
|
| 24 |
+
|
| 25 |
+
train: list[DatasetSample]
|
| 26 |
+
validation: list[DatasetSample]
|
| 27 |
+
test: list[DatasetSample]
|
| 28 |
+
split_info: dict[str, Any]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class CrossValidationFold:
|
| 33 |
+
"""Single fold for cross-validation."""
|
| 34 |
+
|
| 35 |
+
fold_id: int
|
| 36 |
+
train: list[DatasetSample]
|
| 37 |
+
validation: list[DatasetSample]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class DataSplitter:
|
| 41 |
+
"""
|
| 42 |
+
Basic dataset splitter with random sampling.
|
| 43 |
+
|
| 44 |
+
Provides reproducible train/validation/test splits
|
| 45 |
+
with configurable ratios.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, seed: int = 42):
|
| 49 |
+
"""
|
| 50 |
+
Initialize splitter.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
seed: Random seed for reproducibility
|
| 54 |
+
"""
|
| 55 |
+
self.seed = seed
|
| 56 |
+
import random
|
| 57 |
+
|
| 58 |
+
self.rng = random.Random(seed)
|
| 59 |
+
|
| 60 |
+
def split(
|
| 61 |
+
self,
|
| 62 |
+
samples: list[DatasetSample],
|
| 63 |
+
train_ratio: float = 0.7,
|
| 64 |
+
val_ratio: float = 0.15,
|
| 65 |
+
test_ratio: float = 0.15,
|
| 66 |
+
shuffle: bool = True,
|
| 67 |
+
) -> DataSplit:
|
| 68 |
+
"""
|
| 69 |
+
Split dataset into train/validation/test sets.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
samples: List of all samples
|
| 73 |
+
train_ratio: Proportion for training (default 0.7)
|
| 74 |
+
val_ratio: Proportion for validation (default 0.15)
|
| 75 |
+
test_ratio: Proportion for testing (default 0.15)
|
| 76 |
+
shuffle: Whether to shuffle before splitting
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
DataSplit with train, validation, and test sets
|
| 80 |
+
"""
|
| 81 |
+
if abs(train_ratio + val_ratio + test_ratio - 1.0) > 0.001:
|
| 82 |
+
raise ValueError("Ratios must sum to 1.0")
|
| 83 |
+
|
| 84 |
+
if not samples:
|
| 85 |
+
raise ValueError("Cannot split empty sample list")
|
| 86 |
+
|
| 87 |
+
# Copy and optionally shuffle
|
| 88 |
+
all_samples = list(samples)
|
| 89 |
+
if shuffle:
|
| 90 |
+
self.rng.shuffle(all_samples)
|
| 91 |
+
|
| 92 |
+
n = len(all_samples)
|
| 93 |
+
train_end = int(n * train_ratio)
|
| 94 |
+
val_end = train_end + int(n * val_ratio)
|
| 95 |
+
|
| 96 |
+
train_samples = all_samples[:train_end]
|
| 97 |
+
val_samples = all_samples[train_end:val_end]
|
| 98 |
+
test_samples = all_samples[val_end:]
|
| 99 |
+
|
| 100 |
+
split_info = {
|
| 101 |
+
"total_samples": n,
|
| 102 |
+
"train_samples": len(train_samples),
|
| 103 |
+
"val_samples": len(val_samples),
|
| 104 |
+
"test_samples": len(test_samples),
|
| 105 |
+
"train_ratio": len(train_samples) / n,
|
| 106 |
+
"val_ratio": len(val_samples) / n,
|
| 107 |
+
"test_ratio": len(test_samples) / n,
|
| 108 |
+
"seed": self.seed,
|
| 109 |
+
"shuffled": shuffle,
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
logger.info(f"Split {n} samples: train={len(train_samples)}, val={len(val_samples)}, test={len(test_samples)}")
|
| 113 |
+
|
| 114 |
+
return DataSplit(
|
| 115 |
+
train=train_samples,
|
| 116 |
+
validation=val_samples,
|
| 117 |
+
test=test_samples,
|
| 118 |
+
split_info=split_info,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def create_k_folds(
|
| 122 |
+
self,
|
| 123 |
+
samples: list[DatasetSample],
|
| 124 |
+
k: int = 5,
|
| 125 |
+
shuffle: bool = True,
|
| 126 |
+
) -> list[CrossValidationFold]:
|
| 127 |
+
"""
|
| 128 |
+
Create k-fold cross-validation splits.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
samples: List of all samples
|
| 132 |
+
k: Number of folds
|
| 133 |
+
shuffle: Whether to shuffle before splitting
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
List of CrossValidationFold objects
|
| 137 |
+
"""
|
| 138 |
+
if k < 2:
|
| 139 |
+
raise ValueError("k must be at least 2")
|
| 140 |
+
|
| 141 |
+
if len(samples) < k:
|
| 142 |
+
raise ValueError(f"Need at least {k} samples for {k}-fold CV")
|
| 143 |
+
|
| 144 |
+
# Copy and optionally shuffle
|
| 145 |
+
all_samples = list(samples)
|
| 146 |
+
if shuffle:
|
| 147 |
+
self.rng.shuffle(all_samples)
|
| 148 |
+
|
| 149 |
+
# Calculate fold sizes
|
| 150 |
+
fold_size = len(all_samples) // k
|
| 151 |
+
folds = []
|
| 152 |
+
|
| 153 |
+
for fold_id in range(k):
|
| 154 |
+
# Validation is the current fold
|
| 155 |
+
val_start = fold_id * fold_size
|
| 156 |
+
val_end = len(all_samples) if fold_id == k - 1 else val_start + fold_size # noqa: SIM108
|
| 157 |
+
|
| 158 |
+
val_samples = all_samples[val_start:val_end]
|
| 159 |
+
train_samples = all_samples[:val_start] + all_samples[val_end:]
|
| 160 |
+
|
| 161 |
+
folds.append(
|
| 162 |
+
CrossValidationFold(
|
| 163 |
+
fold_id=fold_id,
|
| 164 |
+
train=train_samples,
|
| 165 |
+
validation=val_samples,
|
| 166 |
+
)
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
logger.info(f"Created {k}-fold cross-validation splits")
|
| 170 |
+
return folds
|
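A brief sketch of the splitter API above; the tiny in-memory corpus is illustrative (real samples would come from the loaders), the optional `DatasetSample` fields are assumed to have defaults, and the stratified call refers to the `StratifiedSplitter` subclass defined just below.

```python
# Minimal sketch: 70/15/15 split, 5-fold CV, and a domain-stratified variant.
from src.data.dataset_loader import DatasetSample
from src.data.train_test_split import DataSplitter, StratifiedSplitter

samples = [
    DatasetSample(id=f"s{i}", text=f"sample text {i}", metadata={}, labels=[],
                  domain="cybersecurity" if i % 2 == 0 else "military")
    for i in range(20)
]

splitter = DataSplitter(seed=42)
split = splitter.split(samples, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)
print(split.split_info["train_samples"], split.split_info["val_samples"], split.split_info["test_samples"])

folds = splitter.create_k_folds(samples, k=5)
print(len(folds), len(folds[0].validation))

strat = StratifiedSplitter(seed=42, stratify_by="domain")
strat_split = strat.split(samples)
print(strat_split.split_info["stratification_info"]["train"])  # per-domain proportions
```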
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class StratifiedSplitter(DataSplitter):
|
| 174 |
+
"""
|
| 175 |
+
Stratified dataset splitter.
|
| 176 |
+
|
| 177 |
+
+    Ensures proportional representation of categories
+    (domain, difficulty, etc.) across splits.
+    """
+
+    def __init__(self, seed: int = 42, stratify_by: str = "domain"):
+        """
+        Initialize stratified splitter.
+
+        Args:
+            seed: Random seed for reproducibility
+            stratify_by: Attribute to stratify on ('domain', 'difficulty', 'labels')
+        """
+        super().__init__(seed)
+        self.stratify_by = stratify_by
+
+    def split(
+        self,
+        samples: list[DatasetSample],
+        train_ratio: float = 0.7,
+        val_ratio: float = 0.15,
+        test_ratio: float = 0.15,
+        shuffle: bool = True,
+    ) -> DataSplit:
+        """
+        Stratified split maintaining category proportions.
+
+        Args:
+            samples: List of all samples
+            train_ratio: Proportion for training
+            val_ratio: Proportion for validation
+            test_ratio: Proportion for testing
+            shuffle: Whether to shuffle before splitting
+
+        Returns:
+            DataSplit with stratified train, validation, and test sets
+        """
+        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 0.001:
+            raise ValueError("Ratios must sum to 1.0")
+
+        if not samples:
+            raise ValueError("Cannot split empty sample list")
+
+        # Group samples by stratification key
+        groups = defaultdict(list)
+        for sample in samples:
+            key = self._get_stratify_key(sample)
+            groups[key].append(sample)
+
+        # Split each group proportionally
+        train_samples = []
+        val_samples = []
+        test_samples = []
+
+        for _key, group_samples in groups.items():
+            if shuffle:
+                self.rng.shuffle(group_samples)
+
+            n = len(group_samples)
+            train_end = int(n * train_ratio)
+            val_end = train_end + int(n * val_ratio)
+
+            train_samples.extend(group_samples[:train_end])
+            val_samples.extend(group_samples[train_end:val_end])
+            test_samples.extend(group_samples[val_end:])
+
+        # Final shuffle of combined sets
+        if shuffle:
+            self.rng.shuffle(train_samples)
+            self.rng.shuffle(val_samples)
+            self.rng.shuffle(test_samples)
+
+        # Verify stratification
+        stratify_info = self._verify_stratification(train_samples, val_samples, test_samples)
+
+        split_info = {
+            "total_samples": len(samples),
+            "train_samples": len(train_samples),
+            "val_samples": len(val_samples),
+            "test_samples": len(test_samples),
+            "train_ratio": len(train_samples) / len(samples),
+            "val_ratio": len(val_samples) / len(samples),
+            "test_ratio": len(test_samples) / len(samples),
+            "stratify_by": self.stratify_by,
+            "stratification_info": stratify_info,
+            "seed": self.seed,
+            "shuffled": shuffle,
+        }
+
+        logger.info(
+            f"Stratified split ({self.stratify_by}): "
+            f"train={len(train_samples)}, val={len(val_samples)}, "
+            f"test={len(test_samples)}"
+        )
+
+        return DataSplit(
+            train=train_samples,
+            validation=val_samples,
+            test=test_samples,
+            split_info=split_info,
+        )
+
+    def _get_stratify_key(self, sample: DatasetSample) -> str:
+        """Get stratification key for a sample."""
+        if self.stratify_by == "domain":
+            return sample.domain or "unknown"
+        elif self.stratify_by == "difficulty":
+            return sample.difficulty or "unknown"
+        elif self.stratify_by == "labels":
+            return ",".join(sorted(sample.labels)) if sample.labels else "unknown"
+        else:
+            return str(getattr(sample, self.stratify_by, "unknown"))
+
+    def _verify_stratification(
+        self,
+        train: list[DatasetSample],
+        val: list[DatasetSample],
+        test: list[DatasetSample],
+    ) -> dict[str, dict[str, float]]:
+        """
+        Verify that stratification was successful.
+
+        Returns dictionary showing distribution of stratification key
+        across train/val/test splits.
+        """
+
+        def get_distribution(samples: list[DatasetSample]) -> dict[str, float]:
+            if not samples:
+                return {}
+            counts = defaultdict(int)
+            for sample in samples:
+                key = self._get_stratify_key(sample)
+                counts[key] += 1
+            total = len(samples)
+            return {k: v / total for k, v in counts.items()}
+
+        return {
+            "train": get_distribution(train),
+            "validation": get_distribution(val),
+            "test": get_distribution(test),
+        }
+
+    def create_stratified_k_folds(
+        self,
+        samples: list[DatasetSample],
+        k: int = 5,
+        shuffle: bool = True,
+    ) -> list[CrossValidationFold]:
+        """
+        Create stratified k-fold cross-validation splits.
+
+        Args:
+            samples: List of all samples
+            k: Number of folds
+            shuffle: Whether to shuffle before splitting
+
+        Returns:
+            List of CrossValidationFold objects with stratification
+        """
+        if k < 2:
+            raise ValueError("k must be at least 2")
+
+        # Group samples by stratification key
+        groups = defaultdict(list)
+        for sample in samples:
+            key = self._get_stratify_key(sample)
+            groups[key].append(sample)
+
+        # Initialize folds
+        folds_data = [{"train": [], "val": []} for _ in range(k)]
+
+        # Distribute each group across folds
+        for _key, group_samples in groups.items():
+            if shuffle:
+                self.rng.shuffle(group_samples)
+
+            # Assign samples to folds
+            fold_size = len(group_samples) // k
+            for fold_id in range(k):
+                val_start = fold_id * fold_size
+                val_end = len(group_samples) if fold_id == k - 1 else val_start + fold_size
+
+                for i, sample in enumerate(group_samples):
+                    if val_start <= i < val_end:
+                        folds_data[fold_id]["val"].append(sample)
+                    else:
+                        folds_data[fold_id]["train"].append(sample)
+
+        # Create fold objects
+        folds = [
+            CrossValidationFold(
+                fold_id=i,
+                train=data["train"],
+                validation=data["val"],
+            )
+            for i, data in enumerate(folds_data)
+        ]
+
+        logger.info(f"Created stratified {k}-fold cross-validation splits")
+        return folds
+
+
+class BalancedSampler:
+    """
+    Balanced sampling for imbalanced datasets.
+
+    Provides utilities for:
+    - Oversampling minority classes
+    - Undersampling majority classes
+    - SMOTE-like synthetic sampling (for numerical features)
+    """
+
+    def __init__(self, seed: int = 42):
+        """Initialize balanced sampler."""
+        self.seed = seed
+        import random
+
+        self.rng = random.Random(seed)
+
+    def oversample_minority(
+        self,
+        samples: list[DatasetSample],
+        target_key: str = "domain",
+        target_ratio: float = 1.0,
+    ) -> list[DatasetSample]:
+        """
+        Oversample minority classes to balance dataset.
+
+        Args:
+            samples: Original samples
+            target_key: Attribute to balance on
+            target_ratio: Target ratio relative to majority (1.0 = equal)
+
+        Returns:
+            Balanced sample list (originals + oversampled)
+        """
+        # Group by target key
+        groups = defaultdict(list)
+        for sample in samples:
+            key = getattr(sample, target_key, "unknown") or "unknown"
+            groups[key].append(sample)
+
+        # Find majority class size
+        max_count = max(len(g) for g in groups.values())
+        target_count = int(max_count * target_ratio)
+
+        # Oversample minority classes
+        balanced = []
+        for _key, group in groups.items():
+            balanced.extend(group)
+
+            # Oversample if needed
+            if len(group) < target_count:
+                num_to_add = target_count - len(group)
+                for _ in range(num_to_add):
+                    # Randomly duplicate from group
+                    original = self.rng.choice(group)
+                    duplicate = DatasetSample(
+                        id=f"{original.id}_oversample_{self.rng.randint(0, 999999)}",
+                        text=original.text,
+                        metadata={**original.metadata, "oversampled": True},
+                        labels=original.labels,
+                        difficulty=original.difficulty,
+                        domain=original.domain,
+                        reasoning_steps=original.reasoning_steps,
+                    )
+                    balanced.append(duplicate)
+
+        logger.info(f"Oversampled from {len(samples)} to {len(balanced)} samples")
+        return balanced
+
+    def undersample_majority(
+        self,
+        samples: list[DatasetSample],
+        target_key: str = "domain",
+        target_ratio: float = 1.0,
+    ) -> list[DatasetSample]:
+        """
+        Undersample majority classes to balance dataset.
+
+        Args:
+            samples: Original samples
+            target_key: Attribute to balance on
+            target_ratio: Target ratio relative to minority (1.0 = equal)
+
+        Returns:
+            Balanced sample list (subset of originals)
+        """
+        # Group by target key
+        groups = defaultdict(list)
+        for sample in samples:
+            key = getattr(sample, target_key, "unknown") or "unknown"
+            groups[key].append(sample)
+
+        # Find minority class size
+        min_count = min(len(g) for g in groups.values())
+        target_count = int(min_count * target_ratio)
+
+        # Undersample majority classes
+        balanced = []
+        for _key, group in groups.items():
+            if len(group) > target_count:
+                # Randomly select target_count samples
+                balanced.extend(self.rng.sample(group, target_count))
+            else:
+                balanced.extend(group)
+
+        logger.info(f"Undersampled from {len(samples)} to {len(balanced)} samples")
+        return balanced
+
+    def get_class_distribution(
+        self,
+        samples: list[DatasetSample],
+        target_key: str = "domain",
+    ) -> dict[str, int]:
+        """
+        Get distribution of classes.
+
+        Args:
+            samples: Sample list
+            target_key: Attribute to analyze
+
+        Returns:
+            Dictionary of class counts
+        """
+        distribution = defaultdict(int)
+        for sample in samples:
+            key = getattr(sample, target_key, "unknown") or "unknown"
+            distribution[key] += 1
+        return dict(distribution)
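Note (illustrative only, not part of the commit): a minimal sketch of how the StratifiedSplitter and BalancedSampler added above might be exercised. It assumes the module is importable as src.data.train_test_split and that DatasetSample accepts the keyword arguments used in oversample_minority(); the toy samples below are hypothetical.

from src.data.train_test_split import (  # import path assumed from the repo layout
    BalancedSampler,
    DatasetSample,
    StratifiedSplitter,
)

# Hypothetical toy samples spanning two domains; field names mirror the
# DatasetSample(...) call used in oversample_minority above.
samples = [
    DatasetSample(
        id=f"s{i}",
        text=f"example problem {i}",
        metadata={},
        labels=["demo"],
        difficulty="easy",
        domain="math" if i % 3 else "code",
        reasoning_steps=[],
    )
    for i in range(90)
]

# 70/15/15 split that preserves the per-domain proportions.
splitter = StratifiedSplitter(seed=42, stratify_by="domain")
split = splitter.split(samples, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)
print(split.split_info["stratification_info"])  # domain mix per split

# Stratified 5-fold cross-validation over the same samples.
folds = splitter.create_stratified_k_folds(samples, k=5)

# Duplicate minority-domain samples until each domain matches the majority count.
sampler = BalancedSampler(seed=42)
balanced = sampler.oversample_minority(samples, target_key="domain", target_ratio=1.0)
print(sampler.get_class_distribution(balanced, target_key="domain"))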
src/framework/__init__.py ADDED
@@ -0,0 +1 @@
+# Framework module
src/framework/agents/__init__.py ADDED
@@ -0,0 +1,22 @@
+# Agents module for async agent implementations
+from .base import (
+    AgentContext,
+    AgentResult,
+    AsyncAgentBase,
+    CompositeAgent,
+    MetricsCollector,
+    NoOpMetricsCollector,
+    ParallelAgent,
+    SequentialAgent,
+)
+
+__all__ = [
+    "AsyncAgentBase",
+    "AgentContext",
+    "AgentResult",
+    "MetricsCollector",
+    "NoOpMetricsCollector",
+    "CompositeAgent",
+    "ParallelAgent",
+    "SequentialAgent",
+]
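Note (illustrative only, not part of the commit): the re-exports above let downstream code import the async agent primitives from the package root rather than from .base directly. A minimal sketch, assuming the repository root is on PYTHONPATH:

# These names are exactly the ones listed in __all__ above.
from src.framework.agents import AgentContext, AgentResult, AsyncAgentBase, ParallelAgent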