diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..72e5ed2c2f28e5743260e77a37767989a1d1ec71 --- /dev/null +++ b/.gitignore @@ -0,0 +1,27 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so + +# Virtual environment +venv/ +env/ +.env + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# Gradio +flagged/ +gradio_cached_examples/ + +# Logs +*.log diff --git a/DEPLOYMENT_GUIDE.md b/DEPLOYMENT_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..4153b281870b511f6353e2cac491a2821d6ec7f7 --- /dev/null +++ b/DEPLOYMENT_GUIDE.md @@ -0,0 +1,306 @@ +# Hugging Face Spaces Deployment Guide + +This guide walks you through deploying the LangGraph Multi-Agent MCTS demo to Hugging Face Spaces. + +## Prerequisites + +- [Hugging Face Account](https://huggingface.co/join) +- Git installed locally +- Python 3.10+ (for local testing) + +## Step 1: Create a New Space + +1. Go to [Hugging Face Spaces](https://huggingface.co/spaces) +2. Click **"Create new Space"** +3. Fill in the form: + - **Owner**: Your username or organization + - **Space name**: `langgraph-mcts-demo` (or your choice) + - **License**: MIT + - **SDK**: Gradio + - **Hardware**: CPU Basic (Free tier - sufficient for demo) + - **Visibility**: Public (or Private) +4. Click **"Create Space"** + +## Step 2: Clone and Deploy + +### Option A: Git-based Deployment (Recommended) + +```bash +# 1. Clone your new empty Space +git clone https://huggingface.co/spaces/YOUR_USERNAME/langgraph-mcts-demo +cd langgraph-mcts-demo + +# 2. Copy demo files from this directory +cp -r /path/to/huggingface_space/* . +cp -r /path/to/huggingface_space/.gitignore . + +# 3. Verify structure +ls -la +# Should show: +# - app.py +# - requirements.txt +# - README.md +# - .gitignore +# - demo_src/ +# - __init__.py +# - agents_demo.py +# - llm_mock.py +# - mcts_demo.py + +# 4. Commit and push +git add -A +git commit -m "Initial deployment of LangGraph Multi-Agent MCTS demo" +git push + +# 5. Space will automatically build and deploy (takes 2-5 minutes) +``` + +### Option B: Direct Upload via Web UI + +1. Navigate to your Space on Hugging Face +2. Click **"Files"** tab +3. Click **"Add file"** → **"Upload files"** +4. Upload all files maintaining the directory structure: + - `app.py` + - `requirements.txt` + - `README.md` + - `.gitignore` + - `demo_src/__init__.py` + - `demo_src/agents_demo.py` + - `demo_src/llm_mock.py` + - `demo_src/mcts_demo.py` +5. Commit changes + +## Step 3: Monitor Deployment + +1. Go to your Space URL: `https://huggingface.co/spaces/YOUR_USERNAME/langgraph-mcts-demo` +2. Click **"Logs"** tab to monitor build progress +3. Wait for "Running on" message +4. Your demo is now live! + +## Step 4: Test the Demo + +1. Enter a query or select an example +2. Enable/disable different agents +3. Adjust MCTS parameters +4. Click "Process Query" +5. Review results and consensus scores + +## Optional: Enable Real LLM Responses + +To use Hugging Face Inference API instead of mock responses: + +### 1. Update requirements.txt + +```txt +gradio>=4.0.0,<5.0.0 +numpy>=1.24.0,<2.0.0 +huggingface_hub>=0.20.0 +``` + +### 2. Add Secret Token + +1. Go to Space Settings → **Repository secrets** +2. Add new secret: + - Name: `HF_TOKEN` + - Value: Your Hugging Face token (from [Settings → Access Tokens](https://huggingface.co/settings/tokens)) + +### 3. 
Update app.py Initialization + +Change line ~290 in `app.py`: + +```python +# From: +framework = MultiAgentFrameworkDemo(use_hf_inference=False) + +# To: +import os +framework = MultiAgentFrameworkDemo( + use_hf_inference=True, + hf_model="mistralai/Mistral-7B-Instruct-v0.2" +) +``` + +### 4. Commit and Push + +```bash +git add -A +git commit -m "Enable Hugging Face Inference API" +git push +``` + +## Optional: Enable Weights & Biases Tracking + +Track experiments and visualize metrics with W&B integration. + +### 1. Get W&B API Key + +1. Sign up at [wandb.ai](https://wandb.ai) +2. Go to Settings → API Keys +3. Copy your API key + +### 2. Add W&B Secret to Space + +1. Go to Space Settings → **Repository secrets** +2. Add new secret: + - Name: `WANDB_API_KEY` + - Value: Your W&B API key + +### 3. Use W&B in the Demo + +1. Expand "Weights & Biases Tracking" accordion in the UI +2. Check "Enable W&B Tracking" +3. Optionally set: + - **Project Name**: Your W&B project (default: `langgraph-mcts-demo`) + - **Run Name**: Custom name for this run (auto-generated if empty) +4. Process your query +5. View the W&B run URL in the results + +### 4. What Gets Logged + +- **Agent Metrics**: Confidence scores, execution times, response lengths +- **MCTS Metrics**: Best value, visits, tree depth, exploration paths +- **Consensus Metrics**: Agreement scores, agent combinations +- **Performance**: Total processing time +- **Artifacts**: Full JSON results as artifacts + +### 5. View Your Dashboard + +After runs, visit your W&B project dashboard to: +- Compare different agent configurations +- Visualize consensus patterns +- Analyze MCTS exploration strategies +- Track performance over time + +## Customization Options + +### Change Gradio Theme + +In `app.py`, modify: + +```python +with gr.Blocks( + theme=gr.themes.Soft(), # Try: Default(), Monochrome(), Glass() + ... +) as demo: +``` + +### Add Custom Examples + +Update `EXAMPLE_QUERIES` list in `app.py`: + +```python +EXAMPLE_QUERIES = [ + "Your custom query 1", + "Your custom query 2", + ... +] +``` + +### Adjust MCTS Parameters + +Modify sliders in `app.py`: + +```python +mcts_iterations = gr.Slider( + minimum=10, + maximum=200, # Increase for more thorough search + value=50, # Change default + ... +) +``` + +### Add More Agent Types + +1. Create new agent in `demo_src/agents_demo.py` +2. Add to `MultiAgentFrameworkDemo` in `app.py` +3. Add UI controls in Gradio interface + +## Troubleshooting + +### Build Fails + +- Check **Logs** tab for error details +- Verify `requirements.txt` has compatible versions +- Ensure all imports in `app.py` are satisfied + +### Slow Performance + +- Reduce default MCTS iterations +- Use mock LLM (no API calls) +- Simplify tree visualization + +### Memory Issues (Free Tier) + +- Limit max MCTS iterations to 100 +- Reduce tree depth in `demo_src/mcts_demo.py` +- Simplify response generation + +### Missing Files + +Ensure directory structure: +``` +your-space/ +├── app.py +├── requirements.txt +├── README.md +├── .gitignore +└── demo_src/ + ├── __init__.py + ├── agents_demo.py + ├── llm_mock.py + ├── mcts_demo.py + └── wandb_tracker.py +``` + +## Upgrading Hardware + +For better performance: + +1. Go to Space Settings +2. Under **Hardware**, select: + - **CPU Upgrade** ($0.03/hr) - Faster processing + - **T4 Small** ($0.60/hr) - GPU for neural models +3. 
Save changes + +## Sharing Your Space + +### Embed in Website + +```html + +``` + +### Direct Link + +Share: `https://huggingface.co/spaces/YOUR_USERNAME/langgraph-mcts-demo` + +### API Access + +Gradio automatically provides API endpoint: +``` +https://YOUR_USERNAME-langgraph-mcts-demo.hf.space/api/predict +``` + +## Next Steps + +1. **Collect Feedback**: Enable flagging for user feedback +2. **Add Analytics**: Track usage patterns +3. **Extend Agents**: Add domain-specific reasoning modules +4. **Integrate RAG**: Connect to vector databases for real context +5. **Add Visualization**: Enhanced tree and consensus displays + +## Support + +- **Hugging Face Docs**: https://huggingface.co/docs/hub/spaces +- **Gradio Docs**: https://www.gradio.app/docs +- **Full Framework**: https://github.com/ianshank/langgraph_multi_agent_mcts + +--- + +**Happy Deploying!** 🚀 diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4d1bbfbe60a76289200e3169f54d1573b7f27a71 --- /dev/null +++ b/README.md @@ -0,0 +1,225 @@ +--- +title: LangGraph Multi-Agent MCTS Demo +emoji: 🌳 +colorFrom: blue +colorTo: green +sdk: gradio +sdk_version: 4.44.0 +app_file: app.py +pinned: false +license: mit +tags: + - multi-agent + - mcts + - reasoning + - langgraph + - ai-agents + - wandb + - experiment-tracking +short_description: Multi-agent reasoning framework with Monte Carlo Tree Search +--- + +# LangGraph Multi-Agent MCTS Framework + +**Production Demo with Trained Neural Models** - Experience real trained meta-controllers for intelligent agent routing + +## What This Demo Shows + +This interactive demo showcases trained neural meta-controllers that dynamically route queries to specialized agents: + +### 🤖 Trained Meta-Controllers + +1. **RNN Meta-Controller** + - GRU-based recurrent neural network + - Learns sequential patterns in agent performance + - Fast inference (~2ms latency) + - Trained on 1000+ synthetic routing examples + +2. **BERT Meta-Controller with LoRA** + - Transformer-based text understanding + - Parameter-efficient fine-tuning with LoRA adapters + - Context-aware routing decisions + - Better generalization to unseen query patterns + +### 🧠 Three Specialized Agents + +1. **HRM (Hierarchical Reasoning Module)** + - Best for: Complex decomposition, multi-level problems + - Technique: Hierarchical planning with adaptive computation + +2. **TRM (Tree Reasoning Module)** + - Best for: Iterative refinement, comparison tasks + - Technique: Recursive refinement with convergence detection + +3. **MCTS (Monte Carlo Tree Search)** + - Best for: Optimization, strategic planning + - Technique: UCB1 exploration with value backpropagation + +### 📊 Key Features + +- **Real Trained Models**: Production-ready neural meta-controllers +- **Intelligent Routing**: Models learn optimal agent selection patterns +- **Routing Visualization**: See confidence scores and probability distributions +- **Feature Engineering**: Demonstrates query → features → routing pipeline +- **Performance Metrics**: Track execution time and routing accuracy + +## How to Use + +1. **Enter a Query**: Type your question or select an example +2. **Select Controller**: Choose RNN (fast) or BERT (context-aware) +3. **Process Query**: Click "🚀 Process Query" +4. 
**Review Results**: + - See which agent the controller selected + - View routing confidence and probabilities + - Examine features used for decision-making + - Check agent execution details + +## Weights & Biases Integration + +Track your experiments with **Weights & Biases** for: +- 📈 **Metrics Dashboard**: Visualize consensus scores, execution times, agent performance +- 🔄 **Run Comparison**: Compare different configurations side-by-side +- 📊 **Experiment History**: Track all your queries and results +- 🌳 **MCTS Visualization**: Log tree exploration patterns + +### Setting Up W&B + +1. **Get API Key**: Sign up at [wandb.ai](https://wandb.ai) and get your API key +2. **Configure Space Secret** (if deploying your own): + - Go to Space Settings → Repository secrets + - Add: `WANDB_API_KEY` = your API key +3. **Enable in UI**: + - Expand "Weights & Biases Tracking" accordion + - Check "Enable W&B Tracking" + - Set project name (optional) + - Set run name (optional, auto-generated if empty) +4. **View Results**: After processing, click the W&B run URL to see your dashboard + +### Logged Metrics + +- **Per Agent**: Confidence, execution time, response length, reasoning steps +- **MCTS**: Best value, visits, tree depth, top actions with UCB1 scores +- **Consensus**: Score, level (high/medium/low), number of agents +- **Performance**: Total processing time +- **Artifacts**: Full JSON results, tree visualizations + +## Example Queries + +- "What are the key factors to consider when choosing between microservices and monolithic architecture?" +- "How can we optimize a Python application that processes 10GB of log files daily?" +- "Should we use SQL or NoSQL database for a social media application with 1M users?" +- "How to design a fault-tolerant message queue system?" + +## Technical Details + +### Architecture + +``` +Query Input + │ + ├─→ HRM Agent (Hierarchical Decomposition) + │ ├─ Component Analysis + │ └─ Structured Synthesis + │ + ├─→ TRM Agent (Iterative Refinement) + │ ├─ Initial Response + │ ├─ Clarity Enhancement + │ └─ Validation Check + │ + └─→ MCTS Engine (Strategic Search) + ├─ Selection (UCB1) + ├─ Expansion + ├─ Simulation + └─ Backpropagation + │ + ▼ + Consensus Scoring + │ + ▼ + Final Synthesized Response +``` + +### MCTS Algorithm + +The Monte Carlo Tree Search implementation uses: + +- **UCB1 Selection**: `Q(s,a) + C * sqrt(ln(N(s)) / N(s,a))` +- **Progressive Widening**: Controls branching factor +- **Domain-Aware Actions**: Contextual decision options +- **Value Backpropagation**: Updates entire path statistics + +### Consensus Calculation + +``` +consensus = average_confidence * agreement_factor +agreement_factor = max(0, 1 - std_deviation * 2) +``` + +High consensus (>70%) indicates agents agree on approach. +Low consensus (<40%) suggests uncertainty or conflicting strategies. + +## Demo Scope + +This demonstration focuses on **meta-controller training and routing**: + +- ✅ **Real Trained Models**: Production RNN and BERT controllers +- ✅ **Actual Model Loading**: PyTorch and HuggingFace Transformers +- ✅ **Feature Engineering**: Query analysis → feature vectors +- ✅ **Routing Visualization**: See controller decision-making +- ⚠️ **Simplified Agents**: Agent responses are mocked for demo purposes +- ⚠️ **No Live LLM Calls**: Agents don't call actual LLMs (to reduce latency/cost) + +## Full Production Framework + +The complete repository includes all production features: + +- ✅ **Neural Meta-Controllers**: RNN and BERT with LoRA (deployed here!) 
+- ✅ **Agent Implementations**: Full HRM, TRM, and MCTS with PyTorch +- ✅ **Training Pipeline**: Data generation, training, evaluation +- ✅ **LLM Integration**: OpenAI, Anthropic, LM Studio support +- ✅ **RAG Systems**: ChromaDB, FAISS, Pinecone vector stores +- ✅ **Observability**: OpenTelemetry tracing, Prometheus metrics +- ✅ **Storage**: S3 artifact storage, experiment tracking +- ✅ **CI/CD**: Automated testing, security scanning, deployment + +**GitHub Repository**: [ianshank/langgraph_multi_agent_mcts](https://github.com/ianshank/langgraph_multi_agent_mcts) + +## Technical Stack + +- **Python**: 3.11+ +- **UI**: Gradio 4.x +- **ML Frameworks**: PyTorch 2.1+, Transformers, PEFT (LoRA) +- **Models**: GRU-based RNN, BERT-mini with LoRA adapters +- **Architecture**: Neural meta-controller + multi-agent system +- **Experiment Tracking**: Weights & Biases (optional) +- **Numerical**: NumPy + +## Research Applications + +This framework demonstrates concepts applicable to: + +- Complex decision-making systems +- AI-assisted software architecture decisions +- Multi-perspective problem analysis +- Strategic planning with uncertainty + +## Citation + +If you use this framework in research, please cite: + +```bibtex +@software{langgraph_mcts_2024, + title={LangGraph Multi-Agent MCTS Framework}, + author={Your Name}, + year={2024}, + url={https://github.com/ianshank/langgraph_multi_agent_mcts} +} +``` + +## License + +MIT License - See repository for details. + +--- + +**Built with** LangGraph, Gradio, and Python | **Demo Version**: 1.0.0 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4543e9791790e1600dac3a2e611801079ecead0b --- /dev/null +++ b/app.py @@ -0,0 +1,553 @@ +""" +LangGraph Multi-Agent MCTS Framework - Integrated Demo with Trained Models + +Demonstrates the actual trained neural meta-controllers: +- RNN Meta-Controller for sequential pattern recognition +- BERT with LoRA adapters for text-based routing + +This is a production demonstration using real trained models. +""" + +import asyncio +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +# Fail fast if critical dependencies are missing or broken +try: + import peft + + print(f"[OK] PEFT library imported successfully (version: {peft.__version__})") +except ImportError as e: + print(f"CRITICAL ERROR: Could not import peft library: {e}") + # We don't exit here to allow the app to crash naturally later with full stack trace, + # but this print ensures it's visible in the logs immediately. 
+ +import gradio as gr +import torch + +# Import the trained controllers +sys.path.insert(0, str(Path(__file__).parent)) + +from src.agents.meta_controller.base import MetaControllerFeatures +from src.agents.meta_controller.bert_controller import BERTMetaController +from src.agents.meta_controller.rnn_controller import RNNMetaController +from src.agents.meta_controller.feature_extractor import ( + FeatureExtractor, + FeatureExtractorConfig, +) +from src.utils.personality_response import PersonalityResponseGenerator + + +@dataclass +class AgentResult: + """Result from a single agent.""" + + agent_name: str + response: str + confidence: float + reasoning_steps: list[str] + execution_time_ms: float + + +@dataclass +class ControllerDecision: + """Decision made by the meta-controller.""" + + selected_agent: str + confidence: float + routing_probabilities: dict[str, float] + features_used: dict + + +def create_features_from_query( + query: str, + iteration: int = 0, + last_agent: str = "none", + feature_extractor: FeatureExtractor | None = None, +) -> MetaControllerFeatures: + """ + Convert a text query into features for the meta-controller. + + Uses semantic embeddings for robust feature extraction. Falls back to + heuristic-based extraction if embeddings are not available. + + Args: + query: The input query text + iteration: Current iteration number + last_agent: Name of the last agent used + feature_extractor: Optional FeatureExtractor instance (created if None) + + Returns: + MetaControllerFeatures instance + """ + # Use provided feature extractor or create a new one + if feature_extractor is None: + try: + config = FeatureExtractorConfig.from_env() + feature_extractor = FeatureExtractor(config) + except Exception as e: + print(f"Warning: Failed to initialize FeatureExtractor: {e}") + print("Falling back to heuristic-based feature extraction") + # Will use heuristic fallback below + + # Extract features using the feature extractor + try: + if feature_extractor is not None: + return feature_extractor.extract_features(query, iteration, last_agent) + except Exception as e: + print(f"Warning: Feature extraction failed: {e}") + print("Falling back to heuristic-based feature extraction") + + # Fallback to original heuristic-based extraction + # (This code is kept as a safety net but should rarely be used) + query_length = len(query) + + # Estimate complexity based on query characteristics + has_multiple_questions = "?" 
in query and query.count("?") > 1 + has_comparison = any(word in query.lower() for word in ["vs", "versus", "compare", "difference", "better"]) + has_optimization = any(word in query.lower() for word in ["optimize", "best", "improve", "maximize", "minimize"]) + has_technical = any(word in query.lower() for word in ["algorithm", "code", "implement", "technical", "system"]) + + # Create mock confidence scores based on query characteristics + hrm_confidence = 0.5 + (0.3 if has_multiple_questions else 0) + (0.1 if has_technical else 0) + trm_confidence = 0.5 + (0.3 if has_comparison else 0) + (0.1 if query_length > 100 else 0) + mcts_confidence = 0.5 + (0.3 if has_optimization else 0) + (0.1 if has_technical else 0) + + # Normalize + total = hrm_confidence + trm_confidence + mcts_confidence + if total == 0: + hrm_confidence = 1.0 / 3.0 + trm_confidence = 1.0 / 3.0 + mcts_confidence = 1.0 / 3.0 + else: + hrm_confidence /= total + trm_confidence /= total + mcts_confidence /= total + + # Calculate consensus score + max_confidence = max(hrm_confidence, trm_confidence, mcts_confidence) + if max_confidence == 0: + consensus_score = 0.0 + else: + consensus_score = min(hrm_confidence, trm_confidence, mcts_confidence) / max_confidence + + features = MetaControllerFeatures( + hrm_confidence=hrm_confidence, + trm_confidence=trm_confidence, + mcts_value=mcts_confidence, + consensus_score=consensus_score, + last_agent=last_agent, + iteration=iteration, + query_length=query_length, + has_rag_context=query_length > 50, + rag_relevance_score=0.7 if query_length > 50 else 0.0, + is_technical_query=has_technical, + ) + + return features + + +class IntegratedFramework: + """ + Integrated multi-agent framework using trained meta-controllers. + """ + + def __init__(self): + """Initialize the framework with trained models.""" + self.device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {self.device}") + + # Initialize feature extractor with semantic embeddings + print("Initializing Feature Extractor...") + try: + config = FeatureExtractorConfig.from_env() + # Set device to match the framework device + config.device = self.device + self.feature_extractor = FeatureExtractor(config) + print(f"[OK] Feature Extractor initialized: {self.feature_extractor}") + except Exception as e: + print(f"[WARN] Failed to initialize Feature Extractor: {e}") + print("[WARN] Will fall back to heuristic-based feature extraction") + self.feature_extractor = None + + # Load trained RNN Meta-Controller + print("Loading RNN Meta-Controller...") + self.rnn_controller = RNNMetaController(name="RNNController", seed=42, device=self.device) + + # Load the trained weights + rnn_model_path = Path(__file__).parent / "models" / "rnn_meta_controller.pt" + if rnn_model_path.exists(): + checkpoint = torch.load(rnn_model_path, map_location=self.device, weights_only=True) + self.rnn_controller.model.load_state_dict(checkpoint) + self.rnn_controller.model.eval() + print(f"[OK] Loaded RNN model from {rnn_model_path}") + else: + print(f"[WARN] RNN model not found at {rnn_model_path}, using untrained model") + + # Load trained BERT Meta-Controller with LoRA + print("Loading BERT Meta-Controller with LoRA...") + self.bert_controller = BERTMetaController(name="BERTController", seed=42, device=self.device, use_lora=True) + + bert_model_path = Path(__file__).parent / "models" / "bert_lora" / "final_model" + if bert_model_path.exists(): + try: + self.bert_controller.load_model(str(bert_model_path)) + print(f"[OK] Loaded BERT LoRA model 
from {bert_model_path}") + except Exception as e: + print(f"[WARN] Error loading BERT model: {e}") + print("Using untrained BERT model") + else: + print(f"[WARN] BERT model not found at {bert_model_path}, using untrained model") + + # Agent routing map + self.agent_handlers = { + "hrm": self._handle_hrm, + "trm": self._handle_trm, + "mcts": self._handle_mcts, + } + + print("Framework initialized successfully!") + + async def process_query( + self, + query: str, + controller_type: str = "rnn", + ) -> tuple[AgentResult, ControllerDecision]: + """ + Process a query using the trained meta-controller. + + Args: + query: The input query + controller_type: Which controller to use ("rnn" or "bert") + + Returns: + (agent_result, controller_decision) tuple + """ + start_time = time.perf_counter() + + # Step 1: Convert query to features using semantic embeddings + features = create_features_from_query(query, feature_extractor=self.feature_extractor) + + # Step 2: Get controller decision + if controller_type == "rnn": + prediction = self.rnn_controller.predict(features) + else: # bert + prediction = self.bert_controller.predict(features) + + selected_agent = prediction.agent + confidence = prediction.confidence + + # Get routing probabilities (prediction.probabilities is already a dict) + routing_probs = prediction.probabilities + + # Step 3: Route to selected agent + handler = self.agent_handlers.get(selected_agent, self._handle_hrm) + agent_result = await handler(query) + + # Create controller decision summary + controller_decision = ControllerDecision( + selected_agent=selected_agent, + confidence=confidence, + routing_probabilities=routing_probs, + features_used={ + "hrm_confidence": features.hrm_confidence, + "trm_confidence": features.trm_confidence, + "mcts_value": features.mcts_value, + "consensus_score": features.consensus_score, + "query_length": features.query_length, + "is_technical": features.is_technical_query, + }, + ) + + total_time = (time.perf_counter() - start_time) * 1000 + agent_result.execution_time_ms = round(total_time, 2) + + return agent_result, controller_decision + + async def _handle_hrm(self, query: str) -> AgentResult: + """Handle query with Hierarchical Reasoning Module.""" + # Simulate HRM processing + await asyncio.sleep(0.1) + + steps = [ + "Decompose query into hierarchical subproblems", + "Apply high-level reasoning (H-Module)", + "Execute low-level refinement (L-Module)", + "Synthesize hierarchical solution", + ] + + response = f"[HRM Analysis] Breaking down the problem hierarchically: {query[:100]}..." + + return AgentResult( + agent_name="HRM (Hierarchical Reasoning)", + response=response, + confidence=0.85, + reasoning_steps=steps, + execution_time_ms=0.0, + ) + + async def _handle_trm(self, query: str) -> AgentResult: + """Handle query with Tree Reasoning Module.""" + # Simulate TRM processing + await asyncio.sleep(0.1) + + steps = [ + "Initialize solution state", + "Recursive refinement iteration 1", + "Recursive refinement iteration 2", + "Convergence achieved - finalize", + ] + + response = f"[TRM Analysis] Applying iterative refinement: {query[:100]}..." 
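+        # execution_time_ms is a 0.0 placeholder in every handler; process_query() overwrites it with the measured end-to-end latency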
+ + return AgentResult( + agent_name="TRM (Iterative Refinement)", + response=response, + confidence=0.80, + reasoning_steps=steps, + execution_time_ms=0.0, + ) + + async def _handle_mcts(self, query: str) -> AgentResult: + """Handle query with MCTS.""" + # Simulate MCTS processing + await asyncio.sleep(0.15) + + steps = [ + "Build search tree", + "Selection: UCB1 exploration", + "Expansion: Add promising nodes", + "Simulation: Rollout evaluation", + "Backpropagation: Update values", + ] + + response = f"[MCTS Analysis] Strategic exploration via tree search: {query[:100]}..." + + return AgentResult( + agent_name="MCTS (Monte Carlo Tree Search)", + response=response, + confidence=0.88, + reasoning_steps=steps, + execution_time_ms=0.0, + ) + + +# Global framework instance +framework = None + + +def initialize_framework(): + """Initialize or reinitialize the framework.""" + global framework + try: + framework = IntegratedFramework() + return "[OK] Framework initialized with trained models!" + except Exception as e: + return f"[ERROR] Error initializing framework: {str(e)}" + + +def process_query_sync( + query: str, + controller_type: str, +): + """Synchronous wrapper for async processing.""" + global framework + + if framework is None: + framework = IntegratedFramework() + + if not query.strip(): + return ("Please enter a query.", {}, "", {}, "", "") + + # Run async function + agent_result, controller_decision = asyncio.run( + framework.process_query(query=query, controller_type=controller_type.lower()) + ) + + # Format outputs + final_response = agent_result.response + + # Generate personality-infused response + personality_gen = PersonalityResponseGenerator() + try: + personality_response = personality_gen.generate_response( + agent_response=final_response, + query=query + ) + except Exception as e: + # Fallback to a simple wrapper if personality generation fails + personality_response = f"Here's what I found:\n\n{final_response}" + print(f"Warning: Personality generation failed: {e}") + + # Controller decision visualization + routing_viz = "### 🧠 Meta-Controller Decision\n\n" + routing_viz += f"**Selected Agent:** `{controller_decision.selected_agent.upper()}`\n\n" + routing_viz += f"**Confidence:** {controller_decision.confidence:.1%}\n\n" + routing_viz += "**Routing Probabilities:**\n" + for agent, prob in controller_decision.routing_probabilities.items(): + bar = "█" * int(prob * 50) + routing_viz += f"- **{agent.upper()}**: {prob:.1%} {bar}\n" + + # Agent details + agent_details = { + "agent": agent_result.agent_name, + "confidence": f"{agent_result.confidence:.1%}", + "reasoning_steps": agent_result.reasoning_steps, + "execution_time_ms": agent_result.execution_time_ms, + } + + # Features used + features_viz = "### 📊 Features Used for Routing\n\n" + for feature, value in controller_decision.features_used.items(): + if isinstance(value, float): + features_viz += f"- **{feature}**: {value:.3f}\n" + elif isinstance(value, bool): + features_viz += f"- **{feature}**: {'Yes' if value else 'No'}\n" + else: + features_viz += f"- **{feature}**: {value}\n" + + # Metrics + metrics = f""" +**Controller:** {controller_type} +**Execution Time:** {agent_result.execution_time_ms:.2f} ms +**Agent Confidence:** {agent_result.confidence:.1%} +""" + + return final_response, agent_details, routing_viz, features_viz, metrics, personality_response + + +# Example queries +EXAMPLE_QUERIES = [ + "What are the key factors to consider when choosing between microservices and monolithic architecture?", + "How can 
we optimize a Python application that processes 10GB of log files daily?", + "Compare the performance characteristics of B-trees vs LSM-trees for write-heavy workloads", + "Design a distributed rate limiting system that handles 100k requests per second", + "Explain the difference between supervised and unsupervised learning with examples", +] + + +# Gradio Interface +with gr.Blocks( + title="LangGraph Multi-Agent MCTS - Trained Models Demo", + theme=gr.themes.Soft(), + css=""" + .agent-box { border: 1px solid #ddd; padding: 10px; border-radius: 5px; margin: 5px 0; } + .highlight { background-color: #e3f2fd; padding: 10px; border-radius: 5px; margin: 10px 0; } + """, +) as demo: + gr.Markdown( + """ + # 🎯 LangGraph Multi-Agent MCTS Framework + ## Production Demo with Trained Neural Meta-Controllers + + This demo uses **REAL trained models**: + - 🧠 **RNN Meta-Controller**: GRU-based sequential pattern recognition + - 🤖 **BERT with LoRA**: Transformer-based text understanding for routing + + The meta-controllers learn to route queries to the optimal agent: + - **HRM**: Hierarchical reasoning for complex decomposition + - **TRM**: Iterative refinement for progressive improvement + - **MCTS**: Strategic exploration for optimization problems + + --- + """ + ) + + with gr.Row(): + with gr.Column(scale=2): + query_input = gr.Textbox( + label="Query", placeholder="Enter your question or reasoning task...", lines=4, max_lines=10 + ) + + gr.Markdown("**Example Queries:**") + example_dropdown = gr.Dropdown(choices=EXAMPLE_QUERIES, label="Select an example", interactive=True) + + def load_example(example): + return example + + example_dropdown.change(load_example, example_dropdown, query_input) + + with gr.Column(scale=1): + gr.Markdown("**Meta-Controller Selection**") + controller_type = gr.Radio( + choices=["RNN", "BERT"], + value="RNN", + label="Controller Type", + info="Choose which trained controller to use", + ) + + gr.Markdown( + """ + **Controller Comparison:** + - **RNN**: Fast, captures sequential patterns + - **BERT**: More context-aware, text understanding + """ + ) + + process_btn = gr.Button("🚀 Process Query", variant="primary", size="lg") + + gr.Markdown("---") + + with gr.Row(): + with gr.Column(): + gr.Markdown("### 🎯 Agent Response") + final_response_output = gr.Textbox(label="Response", lines=4, interactive=False) + + gr.Markdown("### 🤝 Personality-Infused Response") + gr.Markdown("*A conversational, balanced advisor interpretation*") + personality_output = gr.Textbox(label="Balanced Advisor Response", lines=8, interactive=False) + + gr.Markdown("### 📈 Performance Metrics") + metrics_output = gr.Markdown() + + with gr.Column(): + routing_viz = gr.Markdown(label="Controller Decision") + features_viz = gr.Markdown(label="Features") + + with gr.Accordion("🔍 Detailed Agent Information", open=False): + agent_details_output = gr.JSON(label="Agent Execution Details") + + # Wire up the processing + process_btn.click( + fn=process_query_sync, + inputs=[ + query_input, + controller_type, + ], + outputs=[final_response_output, agent_details_output, routing_viz, features_viz, metrics_output, personality_output], + ) + + gr.Markdown( + """ + --- + + ### 📚 About This Demo + + This is a **production demonstration** of trained neural meta-controllers for multi-agent routing. 
+ + **Models:** + - RNN Meta-Controller: 10-dimensional feature vector → 3-class routing (HRM/TRM/MCTS) + - BERT with LoRA: Text features → routing decision with adapters + + **Training:** + - Synthetic dataset: 1000+ samples with balanced routing decisions + - Optimization: Adam optimizer, cross-entropy loss + - Validation: 80/20 train/val split with early stopping + + **Repository:** [GitHub - langgraph_multi_agent_mcts](https://github.com/ianshank/langgraph_multi_agent_mcts) + + --- + *Built with PyTorch, Transformers, PEFT, and Gradio* + """ + ) + + +if __name__ == "__main__": + # Initialize framework + print("Initializing framework with trained models...") + framework = IntegratedFramework() + + # Launch the demo + demo.launch(server_name="0.0.0.0", share=False, show_error=True) diff --git a/app_mock.py b/app_mock.py new file mode 100644 index 0000000000000000000000000000000000000000..e8433a617be2773f6a62a15d3a511cbc01a2ba87 --- /dev/null +++ b/app_mock.py @@ -0,0 +1,590 @@ +""" +LangGraph Multi-Agent MCTS Framework - Hugging Face Spaces Demo + +A proof-of-concept demonstration of multi-agent reasoning with Monte Carlo Tree Search. +""" + +import asyncio +import time +from dataclasses import dataclass + +import gradio as gr +import numpy as np + +# Demo-specific simplified implementations +from demo_src.agents_demo import HRMAgent, TRMAgent +from demo_src.llm_mock import HuggingFaceClient, MockLLMClient +from demo_src.mcts_demo import MCTSDemo +from demo_src.wandb_tracker import WandBTracker, is_wandb_available + + +@dataclass +class AgentResult: + """Result from a single agent.""" + + agent_name: str + response: str + confidence: float + reasoning_steps: list[str] + execution_time_ms: float + + +@dataclass +class FrameworkResult: + """Combined result from all agents.""" + + query: str + hrm_result: AgentResult | None + trm_result: AgentResult | None + mcts_result: dict | None + consensus_score: float + final_response: str + total_time_ms: float + metadata: dict + + +class MultiAgentFrameworkDemo: + """Simplified multi-agent framework for Hugging Face Spaces demo.""" + + def __init__(self, use_hf_inference: bool = False, hf_model: str = ""): + """Initialize the demo framework. + + Args: + use_hf_inference: Use Hugging Face Inference API instead of mock + hf_model: Hugging Face model ID for inference + """ + self.use_hf_inference = use_hf_inference + self.hf_model = hf_model + + # Initialize components + if use_hf_inference and hf_model: + self.llm_client = HuggingFaceClient(model_id=hf_model) + else: + self.llm_client = MockLLMClient() + + self.hrm_agent = HRMAgent(self.llm_client) + self.trm_agent = TRMAgent(self.llm_client) + self.mcts = MCTSDemo() + + async def process_query( + self, + query: str, + use_hrm: bool = True, + use_trm: bool = True, + use_mcts: bool = False, + mcts_iterations: int = 25, + exploration_weight: float = 1.414, + seed: int | None = None, + ) -> FrameworkResult: + """Process a query through the multi-agent framework. 
+ + Args: + query: The input query to process + use_hrm: Enable Hierarchical Reasoning Module + use_trm: Enable Tree Reasoning Module + use_mcts: Enable Monte Carlo Tree Search + mcts_iterations: Number of MCTS iterations + exploration_weight: UCB1 exploration parameter + seed: Random seed for reproducibility + + Returns: + FrameworkResult with all agent outputs and consensus + """ + start_time = time.perf_counter() + + hrm_result = None + trm_result = None + mcts_result = None + + # Run enabled agents + tasks = [] + agent_names = [] + + if use_hrm: + tasks.append(self._run_hrm(query)) + agent_names.append("hrm") + + if use_trm: + tasks.append(self._run_trm(query)) + agent_names.append("trm") + + if use_mcts: + tasks.append(self._run_mcts(query, mcts_iterations, exploration_weight, seed)) + agent_names.append("mcts") + + # Execute agents concurrently + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) + + for name, result in zip(agent_names, results, strict=False): + if isinstance(result, Exception): + continue + if name == "hrm": + hrm_result = result + elif name == "trm": + trm_result = result + elif name == "mcts": + mcts_result = result + + # Calculate consensus score + consensus_score = self._calculate_consensus(hrm_result, trm_result, mcts_result) + + # Generate final synthesized response + final_response = self._synthesize_response(query, hrm_result, trm_result, mcts_result, consensus_score) + + total_time = (time.perf_counter() - start_time) * 1000 + + return FrameworkResult( + query=query, + hrm_result=hrm_result, + trm_result=trm_result, + mcts_result=mcts_result, + consensus_score=consensus_score, + final_response=final_response, + total_time_ms=round(total_time, 2), + metadata={ + "agents_used": agent_names, + "mcts_config": ( + {"iterations": mcts_iterations, "exploration_weight": exploration_weight, "seed": seed} + if use_mcts + else None + ), + }, + ) + + async def _run_hrm(self, query: str) -> AgentResult: + """Run Hierarchical Reasoning Module.""" + start = time.perf_counter() + result = await self.hrm_agent.process(query) + elapsed = (time.perf_counter() - start) * 1000 + + return AgentResult( + agent_name="HRM (Hierarchical Reasoning)", + response=result["response"], + confidence=result["confidence"], + reasoning_steps=result["steps"], + execution_time_ms=round(elapsed, 2), + ) + + async def _run_trm(self, query: str) -> AgentResult: + """Run Tree Reasoning Module.""" + start = time.perf_counter() + result = await self.trm_agent.process(query) + elapsed = (time.perf_counter() - start) * 1000 + + return AgentResult( + agent_name="TRM (Iterative Refinement)", + response=result["response"], + confidence=result["confidence"], + reasoning_steps=result["steps"], + execution_time_ms=round(elapsed, 2), + ) + + async def _run_mcts(self, query: str, iterations: int, exploration_weight: float, seed: int | None) -> dict: + """Run Monte Carlo Tree Search.""" + start = time.perf_counter() + + # MCTSDemo.search is now async and uses the production framework + result = await self.mcts.search(query=query, iterations=iterations, exploration_weight=exploration_weight, seed=seed) + + elapsed = (time.perf_counter() - start) * 1000 + result["execution_time_ms"] = round(elapsed, 2) + + return result + + def _calculate_consensus( + self, hrm_result: AgentResult | None, trm_result: AgentResult | None, mcts_result: dict | None + ) -> float: + """Calculate agreement score between agents.""" + confidences = [] + + if hrm_result: + confidences.append(hrm_result.confidence) + 
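+        # TRM contributes its explicit confidence; MCTS reports none, so its best node value (appended below) is used as a proxy signal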
if trm_result: + confidences.append(trm_result.confidence) + if mcts_result: + confidences.append(mcts_result.get("best_value", 0.5)) + + if not confidences: + return 0.0 + + # Consensus is based on confidence alignment and average + if len(confidences) == 1: + return confidences[0] + + avg_confidence = np.mean(confidences) + std_confidence = np.std(confidences) + + # Higher consensus when agents agree (low std) and are confident (high avg) + agreement_factor = max(0, 1 - std_confidence * 2) + consensus = avg_confidence * agreement_factor + + return round(min(1.0, consensus), 3) + + def _synthesize_response( + self, + query: str, + hrm_result: AgentResult | None, + trm_result: AgentResult | None, + mcts_result: dict | None, + consensus_score: float, + ) -> str: + """Synthesize final response from all agent outputs.""" + parts = [] + + if hrm_result and hrm_result.confidence > 0.5: + parts.append(f"[HRM] {hrm_result.response}") + + if trm_result and trm_result.confidence > 0.5: + parts.append(f"[TRM] {trm_result.response}") + + if mcts_result and mcts_result.get("best_value", 0) > 0.5: + parts.append(f"[MCTS] Best path: {mcts_result.get('best_action', 'N/A')}") + + if not parts: + truncated_query = f"{query[:80]}..." if len(query) > 80 else query + return f"Insufficient confidence to answer query: '{truncated_query}'." + + synthesis = " | ".join(parts) + + if consensus_score > 0.7: + return f"HIGH CONSENSUS ({consensus_score:.1%}): {synthesis}" + elif consensus_score > 0.4: + return f"MODERATE CONSENSUS ({consensus_score:.1%}): {synthesis}" + else: + return f"LOW CONSENSUS ({consensus_score:.1%}): {synthesis}" + + +# Global framework instance +framework = None +wandb_tracker = None + + +def initialize_framework(use_hf: bool, model_id: str): + """Initialize or reinitialize the framework.""" + global framework + framework = MultiAgentFrameworkDemo(use_hf_inference=use_hf, hf_model=model_id) + return "Framework initialized successfully!" 
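+
+
+# Illustrative sketch only (not called by the demo): the UCB1 rule quoted in the README,
+# Q(s,a) + C * sqrt(ln(N(s)) / N(s,a)), which the "Exploration Weight (C)" slider below feeds into.
+# MCTSDemo is assumed to apply an equivalent formula internally; this helper just documents the trade-off.
+def _ucb1_score_example(total_value: float, visits: int, parent_visits: int, c: float = 1.414) -> float:
+    """Return the UCB1 score for a child node (assumes the parent has at least one visit)."""
+    import math
+
+    if visits == 0:
+        return float("inf")  # unvisited children are explored first
+    exploitation = total_value / visits
+    exploration = c * math.sqrt(math.log(parent_visits) / visits)
+    return exploitation + exploration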
+ + +def process_query_sync( + query: str, + use_hrm: bool, + use_trm: bool, + use_mcts: bool, + mcts_iterations: int, + exploration_weight: float, + seed: int, + enable_wandb: bool = False, + wandb_project: str = "langgraph-mcts-demo", + wandb_run_name: str = "", +): + """Synchronous wrapper for async processing.""" + global framework, wandb_tracker + + if framework is None: + framework = MultiAgentFrameworkDemo() + + if not query.strip(): + return "Please enter a query.", {}, "", {}, "" + + # Handle seed + seed_value = seed if seed > 0 else None + + # Initialize W&B tracking if enabled + wandb_url = "" + if enable_wandb and is_wandb_available(): + if wandb_tracker is None: + wandb_tracker = WandBTracker(project_name=wandb_project, enabled=True) + + # Start a new run + run_name = wandb_run_name if wandb_run_name.strip() else None + config = { + "query": query[:200], # Truncate for config + "use_hrm": use_hrm, + "use_trm": use_trm, + "use_mcts": use_mcts, + "mcts_iterations": mcts_iterations, + "exploration_weight": exploration_weight, + "seed": seed_value, + } + wandb_tracker.init_run(run_name=run_name, config=config) + + # Run async function + result = asyncio.run( + framework.process_query( + query=query, + use_hrm=use_hrm, + use_trm=use_trm, + use_mcts=use_mcts, + mcts_iterations=int(mcts_iterations), + exploration_weight=exploration_weight, + seed=seed_value, + ) + ) + + # Format outputs + final_response = result.final_response + + # Agent details + agent_details = {} + if result.hrm_result: + agent_details["HRM"] = { + "response": result.hrm_result.response, + "confidence": f"{result.hrm_result.confidence:.1%}", + "reasoning_steps": result.hrm_result.reasoning_steps, + "time_ms": result.hrm_result.execution_time_ms, + } + + # Log to W&B + if enable_wandb and wandb_tracker: + wandb_tracker.log_agent_result( + "HRM", + result.hrm_result.response, + result.hrm_result.confidence, + result.hrm_result.execution_time_ms, + result.hrm_result.reasoning_steps, + ) + + if result.trm_result: + agent_details["TRM"] = { + "response": result.trm_result.response, + "confidence": f"{result.trm_result.confidence:.1%}", + "reasoning_steps": result.trm_result.reasoning_steps, + "time_ms": result.trm_result.execution_time_ms, + } + + # Log to W&B + if enable_wandb and wandb_tracker: + wandb_tracker.log_agent_result( + "TRM", + result.trm_result.response, + result.trm_result.confidence, + result.trm_result.execution_time_ms, + result.trm_result.reasoning_steps, + ) + + if result.mcts_result: + agent_details["MCTS"] = result.mcts_result + + # Log to W&B + if enable_wandb and wandb_tracker: + wandb_tracker.log_mcts_result(result.mcts_result) + + # Log consensus and performance to W&B + if enable_wandb and wandb_tracker: + wandb_tracker.log_consensus(result.consensus_score, result.metadata["agents_used"], result.final_response) + wandb_tracker.log_performance(result.total_time_ms) + wandb_tracker.log_query_summary(query, use_hrm, use_trm, use_mcts, result.consensus_score, result.total_time_ms) + + # Get run URL + wandb_url = wandb_tracker.get_run_url() or "" + + # Finish the run + wandb_tracker.finish_run() + + # Metrics + metrics = f""" +**Consensus Score:** {result.consensus_score:.1%} +**Total Processing Time:** {result.total_time_ms:.2f} ms +**Agents Used:** {", ".join(result.metadata["agents_used"])} +""" + + if wandb_url: + metrics += f"\n**W&B Run:** [{wandb_url}]({wandb_url})" + + # Full JSON result + full_result = { + "query": result.query, + "final_response": result.final_response, + 
"consensus_score": result.consensus_score, + "total_time_ms": result.total_time_ms, + "metadata": result.metadata, + "agent_details": agent_details, + "wandb_url": wandb_url, + } + + return final_response, agent_details, metrics, full_result, wandb_url + + +def visualize_mcts_tree(mcts_result: dict) -> str: + """Create ASCII visualization of MCTS tree.""" + if not mcts_result or "tree_visualization" not in mcts_result: + return "No MCTS tree data available" + + return mcts_result["tree_visualization"] + + +# Example queries for demonstration +EXAMPLE_QUERIES = [ + "What are the key factors to consider when choosing between microservices and monolithic architecture?", + "How can we optimize a Python application that processes 10GB of log files daily?", + "What is the best approach to implement rate limiting in a distributed system?", + "Should we use SQL or NoSQL database for a social media application with 1M users?", + "How to design a fault-tolerant message queue system?", +] + + +# Gradio Interface +with gr.Blocks( + title="LangGraph Multi-Agent MCTS Demo", + theme=gr.themes.Soft(), + css=""" + .agent-box { border: 1px solid #ddd; padding: 10px; border-radius: 5px; margin: 5px 0; } + .consensus-high { color: #28a745; font-weight: bold; } + .consensus-medium { color: #ffc107; font-weight: bold; } + .consensus-low { color: #dc3545; font-weight: bold; } + """, +) as demo: + gr.Markdown( + """ + # LangGraph Multi-Agent MCTS Framework + + **Proof-of-Concept Demo** - Multi-agent reasoning with Monte Carlo Tree Search + + This demo showcases: + - **HRM**: Hierarchical Reasoning Module - breaks down complex queries + - **TRM**: Tree Reasoning Module - iterative refinement of responses + - **MCTS**: Monte Carlo Tree Search - strategic exploration of solution space + - **Consensus**: Agreement scoring between agents + + --- + """ + ) + + with gr.Row(): + with gr.Column(scale=2): + query_input = gr.Textbox( + label="Query", placeholder="Enter your reasoning task or question...", lines=3, max_lines=10 + ) + + gr.Markdown("**Example Queries:**") + example_dropdown = gr.Dropdown(choices=EXAMPLE_QUERIES, label="Select an example", interactive=True) + + def load_example(example): + return example + + example_dropdown.change(load_example, example_dropdown, query_input) + + with gr.Column(scale=1): + gr.Markdown("**Agent Configuration**") + use_hrm = gr.Checkbox(label="Enable HRM (Hierarchical)", value=True) + use_trm = gr.Checkbox(label="Enable TRM (Iterative)", value=True) + use_mcts = gr.Checkbox(label="Enable MCTS", value=False) + + gr.Markdown("**MCTS Parameters**") + mcts_iterations = gr.Slider( + minimum=10, + maximum=100, + value=25, + step=5, + label="Iterations", + info="More iterations = better search, but slower", + ) + exploration_weight = gr.Slider( + minimum=0.1, + maximum=3.0, + value=1.414, + step=0.1, + label="Exploration Weight (C)", + info="Higher = more exploration, Lower = more exploitation", + ) + seed_input = gr.Number(label="Random Seed (0 for random)", value=0, precision=0) + + with gr.Accordion("Weights & Biases Tracking", open=False): + gr.Markdown( + """ + **Experiment Tracking with W&B** + + Track your experiments, visualize metrics, and compare runs. + Requires W&B API key set in Space secrets as `WANDB_API_KEY`. 
+ """ + ) + with gr.Row(): + enable_wandb = gr.Checkbox( + label="Enable W&B Tracking", value=False, info="Log metrics and results to Weights & Biases" + ) + wandb_project = gr.Textbox( + label="Project Name", value="langgraph-mcts-demo", placeholder="Your W&B project name" + ) + wandb_run_name = gr.Textbox(label="Run Name (optional)", value="", placeholder="Auto-generated if empty") + + wandb_status = gr.Markdown(f"**W&B Status:** {'Available' if is_wandb_available() else 'Not installed'}") + + process_btn = gr.Button("Process Query", variant="primary", size="lg") + + gr.Markdown("---") + + with gr.Row(): + with gr.Column(): + gr.Markdown("### Final Response") + final_response_output = gr.Textbox(label="Synthesized Response", lines=4, interactive=False) + + gr.Markdown("### Performance Metrics") + metrics_output = gr.Markdown() + + with gr.Column(): + gr.Markdown("### Agent Details") + agent_details_output = gr.JSON(label="Individual Agent Results") + + with gr.Accordion("Full JSON Result", open=False): + full_result_output = gr.JSON(label="Complete Framework Output") + + with gr.Accordion("W&B Run Details", open=False, visible=True): + wandb_url_output = gr.Textbox( + label="W&B Run URL", interactive=False, placeholder="Enable W&B tracking to see run URL here" + ) + + # Wire up the processing + process_btn.click( + fn=process_query_sync, + inputs=[ + query_input, + use_hrm, + use_trm, + use_mcts, + mcts_iterations, + exploration_weight, + seed_input, + enable_wandb, + wandb_project, + wandb_run_name, + ], + outputs=[final_response_output, agent_details_output, metrics_output, full_result_output, wandb_url_output], + ) + + gr.Markdown( + """ + --- + + ### About This Demo + + This is a **proof-of-concept** demonstration of the LangGraph Multi-Agent MCTS Framework. + + **Features:** + - Multi-agent orchestration with consensus scoring + - Monte Carlo Tree Search for strategic reasoning + - Configurable exploration vs exploitation trade-offs + - Deterministic results with seeded randomness + - **Weights & Biases integration** for experiment tracking + + **Limitations (POC):** + - Uses mock/simplified LLM responses (not production LLM) + - Limited to demonstration scenarios + - No persistent storage or RAG + - Simplified MCTS implementation + + **Full Framework:** [GitHub Repository](https://github.com/ianshank/langgraph_multi_agent_mcts) + + --- + *Built with LangGraph, Gradio, Weights & Biases, and Python* + """ + ) + + +if __name__ == "__main__": + # Initialize with mock client for demo + framework = MultiAgentFrameworkDemo(use_hf_inference=False) + + # Launch the demo + demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True) diff --git a/demo_src/__init__.py b/demo_src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..622fa74289205c67a1f404910be05ac6892c9d74 --- /dev/null +++ b/demo_src/__init__.py @@ -0,0 +1 @@ +# Demo source modules for Hugging Face Spaces diff --git a/demo_src/agents_demo.py b/demo_src/agents_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..6335bc777201b0305c898d513e4ee6b6b57aa457 --- /dev/null +++ b/demo_src/agents_demo.py @@ -0,0 +1,234 @@ +""" +Simplified agent implementations for Hugging Face Spaces demo. +""" + +import asyncio +from typing import Any + + +class HRMAgent: + """Hierarchical Reasoning Module - breaks down complex queries.""" + + def __init__(self, llm_client): + """Initialize with an LLM client. 
+ + Args: + llm_client: LLM client (MockLLMClient or HuggingFaceClient) + """ + self.llm_client = llm_client + self.name = "HRM (Hierarchical Reasoning)" + + async def process(self, query: str) -> dict[str, Any]: + """Process query using hierarchical decomposition. + + Args: + query: Input query to process + + Returns: + Dictionary with response, confidence, and reasoning steps + """ + # Step 1: Decompose the query + decomposition_steps = await self._decompose_query(query) + + # Step 2: Analyze each component + analysis_results = await self._analyze_components(decomposition_steps) + + # Step 3: Synthesize hierarchical response + llm_result = await self.llm_client.generate( + prompt=f"Hierarchical analysis of: {query}", context=f"Components: {', '.join(decomposition_steps)}" + ) + + # Compile reasoning steps + reasoning_steps = [ + f"1. Query decomposition: Identified {len(decomposition_steps)} key components", + f"2. Component analysis: {analysis_results}", + "3. Hierarchical synthesis: Combined insights from all levels", + f"4. Confidence assessment: {llm_result['confidence']:.1%} based on component clarity", + ] + + return { + "response": llm_result["response"], + "confidence": llm_result["confidence"], + "steps": reasoning_steps, + "components": decomposition_steps, + "tokens_used": llm_result.get("tokens_used", 0), + } + + async def _decompose_query(self, query: str) -> list[str]: + """Decompose query into hierarchical components.""" + # Simulate decomposition based on query structure + await asyncio.sleep(0.05) # Simulate processing + + # Simple heuristic decomposition + components = [] + + # Extract key phrases + query_lower = query.lower() + + if "?" in query: + components.append("Question type: Analytical") + else: + components.append("Question type: Directive") + + if "how" in query_lower: + components.append("Focus: Methodology/Process") + elif "what" in query_lower: + components.append("Focus: Definition/Identification") + elif "why" in query_lower: + components.append("Focus: Causation/Reasoning") + elif "should" in query_lower or "best" in query_lower: + components.append("Focus: Decision/Recommendation") + else: + components.append("Focus: General inquiry") + + # Domain detection + if any(term in query_lower for term in ["database", "sql", "nosql", "storage"]): + components.append("Domain: Data Management") + elif any(term in query_lower for term in ["architecture", "design", "pattern"]): + components.append("Domain: System Architecture") + elif any(term in query_lower for term in ["performance", "optimization", "speed"]): + components.append("Domain: Performance Engineering") + elif any(term in query_lower for term in ["scale", "distributed", "cluster"]): + components.append("Domain: Distributed Systems") + else: + components.append("Domain: Software Engineering") + + # Complexity assessment + word_count = len(query.split()) + if word_count > 20: + components.append("Complexity: High (detailed query)") + elif word_count > 10: + components.append("Complexity: Medium") + else: + components.append("Complexity: Low (concise query)") + + return components + + async def _analyze_components(self, components: list[str]) -> str: + """Analyze the decomposed components.""" + await asyncio.sleep(0.03) # Simulate processing + + # Generate analysis summary + analysis_parts = [] + + for component in components: + if "Focus:" in component: + focus = component.split(":")[1].strip() + analysis_parts.append(f"requires {focus.lower()} approach") + elif "Domain:" in component: + domain = 
component.split(":")[1].strip() + analysis_parts.append(f"applies to {domain}") + elif "Complexity:" in component: + complexity = component.split(":")[1].strip().split()[0] + analysis_parts.append(f"{complexity.lower()} complexity level") + + return "; ".join(analysis_parts) if analysis_parts else "Standard analysis" + + +class TRMAgent: + """Tree Reasoning Module - iterative refinement of responses.""" + + def __init__(self, llm_client): + """Initialize with an LLM client. + + Args: + llm_client: LLM client (MockLLMClient or HuggingFaceClient) + """ + self.llm_client = llm_client + self.name = "TRM (Iterative Refinement)" + self.max_iterations = 3 + + async def process(self, query: str) -> dict[str, Any]: + """Process query using iterative refinement. + + Args: + query: Input query to process + + Returns: + Dictionary with response, confidence, and reasoning steps + """ + reasoning_steps = [] + current_response = "" + current_confidence = 0.0 + + # Iterative refinement loop + for iteration in range(self.max_iterations): + step_num = iteration + 1 + + # Generate or refine response + if iteration == 0: + # Initial response + result = await self.llm_client.generate(prompt=query, context="") + current_response = result["response"] + current_confidence = result["confidence"] + reasoning_steps.append( + f"Iteration {step_num}: Initial response generated (confidence: {current_confidence:.1%})" + ) + else: + # Refinement iteration + refinement_result = await self._refine_response(query, current_response, iteration) + current_response = refinement_result["response"] + + # Confidence typically improves with refinement + confidence_improvement = min(0.1, (1 - current_confidence) * 0.3) + current_confidence = min(0.95, current_confidence + confidence_improvement) + + reasoning_steps.append( + f"Iteration {step_num}: {refinement_result['refinement_type']} " + f"(confidence: {current_confidence:.1%})" + ) + + # Check if confidence is high enough to stop + if current_confidence > 0.85: + reasoning_steps.append(f"Early termination: High confidence ({current_confidence:.1%}) achieved") + break + + # Final reasoning step + reasoning_steps.append(f"Final: Response refined through {len(reasoning_steps)} iterations") + + return { + "response": current_response, + "confidence": round(current_confidence, 3), + "steps": reasoning_steps, + "iterations_used": min(iteration + 1, self.max_iterations), + "refinement_history": reasoning_steps, + } + + async def _refine_response(self, query: str, current_response: str, iteration: int) -> dict[str, Any]: + """Refine the current response.""" + await asyncio.sleep(0.05) # Simulate refinement processing + + # Different refinement strategies based on iteration + refinement_strategies = [ + ("Clarity enhancement", "improve clarity and precision"), + ("Detail expansion", "add technical depth and specifics"), + ("Validation check", "verify accuracy and completeness"), + ] + + strategy_name, strategy_action = refinement_strategies[iteration % len(refinement_strategies)] + + # Generate refined response + refinement_prompt = f""" + Original query: {query} + Current response: {current_response} + Refinement task: {strategy_action} + """ + + result = await self.llm_client.generate( + prompt=refinement_prompt, context=f"Refinement iteration {iteration + 1}" + ) + + # Enhance the response based on strategy + enhanced_response = current_response + if strategy_name == "Clarity enhancement": + enhanced_response = f"{current_response}. 
{result['response']}" + elif strategy_name == "Detail expansion": + enhanced_response = f"{current_response}. Furthermore, {result['response']}" + else: # Validation + enhanced_response = f"{current_response}. Validated: {result['response']}" + + # Truncate if too long + if len(enhanced_response) > 300: + enhanced_response = enhanced_response[:297] + "..." + + return {"response": enhanced_response, "refinement_type": strategy_name, "strategy_action": strategy_action} diff --git a/demo_src/llm_mock.py b/demo_src/llm_mock.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7334442f2a9767f5017c5703338891ceafe97c --- /dev/null +++ b/demo_src/llm_mock.py @@ -0,0 +1,182 @@ +""" +Mock and lightweight LLM clients for demo purposes. +""" + +import asyncio +import random +from typing import Any + + +class MockLLMClient: + """Mock LLM client that generates plausible demo responses.""" + + def __init__(self): + self.response_templates = { + "architecture": [ + "Consider scalability requirements and team expertise", + "Evaluate coupling, deployment complexity, and operational overhead", + "Balance between development speed and long-term maintainability", + ], + "optimization": [ + "Profile first to identify actual bottlenecks", + "Consider memory-mapped files and streaming processing", + "Implement parallel processing with appropriate chunk sizes", + ], + "database": [ + "Consider data consistency requirements and query patterns", + "Evaluate write-heavy vs read-heavy workload characteristics", + "Plan for horizontal scaling and data distribution strategies", + ], + "distributed": [ + "Implement proper failure detection and recovery mechanisms", + "Use circuit breakers and bulkhead patterns for resilience", + "Consider eventual consistency vs strong consistency trade-offs", + ], + "default": [ + "Break down the problem into smaller components", + "Consider trade-offs between different approaches", + "Evaluate based on specific use case requirements", + ], + } + + async def generate(self, prompt: str, context: str = "") -> dict[str, Any]: + """Generate a mock response based on the prompt and optional context.""" + # Simulate processing time + await asyncio.sleep(random.uniform(0.1, 0.3)) + + # Determine response category + prompt_lower = prompt.lower() + if "architecture" in prompt_lower or "microservice" in prompt_lower or "monolith" in prompt_lower: + category = "architecture" + elif "optim" in prompt_lower or "performance" in prompt_lower or "process" in prompt_lower: + category = "optimization" + elif "database" in prompt_lower or "sql" in prompt_lower or "nosql" in prompt_lower: + category = "database" + elif "distribut" in prompt_lower or "fault" in prompt_lower or "rate limit" in prompt_lower: + category = "distributed" + else: + category = "default" + + templates = self.response_templates[category] + + # Generate response with some randomness + response = random.choice(templates) + confidence = random.uniform(0.6, 0.95) + + # Add more detail based on prompt length (simulating "understanding") + if len(prompt) > 100: + confidence = min(0.95, confidence + 0.1) + response += f". Additionally, {random.choice(self.response_templates['default'])}" + + # Lightly incorporate context to simulate conditioning + context_snippet = context.strip() + if context_snippet: + confidence = min(0.99, confidence + 0.05) + response += f" (context signal: {context_snippet[:60]}{'...' 
if len(context_snippet) > 60 else ''})" + + return { + "response": response, + "confidence": round(confidence, 3), + "tokens_used": len(prompt.split()) * 2 + len(response.split()), + } + + async def generate_reasoning_steps(self, query: str, num_steps: int = 3) -> list[str]: + """Generate mock reasoning steps.""" + await asyncio.sleep(random.uniform(0.05, 0.15)) + + base_steps = [ + f"Analyzing query: '{query[:50]}...'", + "Identifying key requirements and constraints", + "Evaluating potential approaches", + "Considering trade-offs and implications", + "Synthesizing recommendations based on analysis", + "Validating conclusions against requirements", + ] + + return random.sample(base_steps, min(num_steps, len(base_steps))) + + +class HuggingFaceClient: + """Lightweight Hugging Face Inference API client.""" + + def __init__(self, model_id: str = "mistralai/Mistral-7B-Instruct-v0.2"): + """Initialize with a Hugging Face model. + + Args: + model_id: The model ID on Hugging Face Hub + """ + self.model_id = model_id + self._client = None + + def _get_client(self): + """Lazy load the HF client.""" + if self._client is None: + try: + from huggingface_hub import InferenceClient + + self._client = InferenceClient(model=self.model_id) + except ImportError: + raise ImportError("huggingface_hub not installed. Install with: pip install huggingface_hub") + return self._client + + async def generate(self, prompt: str, context: str = "") -> dict[str, Any]: + """Generate response using Hugging Face Inference API.""" + try: + client = self._get_client() + + # Format prompt + if context: + full_prompt = f"Context: {context}\n\nQuestion: {prompt}\n\nAnswer:" + else: + full_prompt = f"Question: {prompt}\n\nProvide a concise, technical answer:\n\nAnswer:" + + # Call HF Inference API (sync call wrapped in async) + response_text = await asyncio.to_thread( + client.text_generation, full_prompt, max_new_tokens=150, temperature=0.7, do_sample=True + ) + + # Estimate confidence based on response characteristics + confidence = min(0.95, 0.6 + len(response_text) / 500) + + return { + "response": response_text.strip(), + "confidence": round(confidence, 3), + "tokens_used": len(full_prompt.split()) + len(response_text.split()), + } + + except Exception as e: + # Fallback to mock on error + print(f"HF Inference error: {e}. Falling back to mock.") + mock = MockLLMClient() + return await mock.generate(prompt, context) + + async def generate_reasoning_steps(self, query: str, num_steps: int = 3) -> list[str]: + """Generate reasoning steps using HF model.""" + try: + client = self._get_client() + + prompt = f"""Break down this question into {num_steps} reasoning steps: +Question: {query} + +Reasoning steps (one per line): +1.""" + + response = await asyncio.to_thread(client.text_generation, prompt, max_new_tokens=200, temperature=0.5) + + # Parse steps from response + lines = response.strip().split("\n") + steps = [] + for line in lines: + line = line.strip() + if line and not line.startswith("#"): + # Remove numbering + if line[0].isdigit() and "." in line[:3]: + line = line.split(".", 1)[1].strip() + steps.append(line) + + return steps[:num_steps] if steps else ["Analysis in progress"] + + except Exception as e: + print(f"HF reasoning error: {e}. 
Falling back to mock.") + mock = MockLLMClient() + return await mock.generate_reasoning_steps(query, num_steps) diff --git a/demo_src/mcts_demo.py b/demo_src/mcts_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..bd3e10c0d6d4a76118e4d5626838890ee19102a7 --- /dev/null +++ b/demo_src/mcts_demo.py @@ -0,0 +1,436 @@ +""" +Educational MCTS demonstration using the production framework. + +This demo uses the real MCTSEngine from src.framework.mcts.core to provide +an authentic learning experience while remaining accessible for demonstrations. +""" + +from __future__ import annotations + +import math +from typing import Any + +from src.framework.mcts.core import MCTSEngine, MCTSNode, MCTSState +from src.framework.mcts.policies import RolloutPolicy, SelectionPolicy + + +class DemoRolloutPolicy(RolloutPolicy): + """ + Educational rollout policy for demo purposes. + + Evaluates states based on: + - Depth of exploration (deeper = more thorough) + - Action quality (domain-specific heuristics) + - Exploration randomness + """ + + def __init__(self, category: str, action_templates: dict[str, list[str]]): + """ + Initialize demo rollout policy. + + Args: + category: Query category for heuristic evaluation + action_templates: Available action templates for scoring + """ + self.category = category + self.action_templates = action_templates + + # Define key terms that indicate quality actions per category + self.quality_indicators = { + "architecture": ["scalability", "consistency", "requirements"], + "optimization": ["profile", "caching", "parallel"], + "database": ["patterns", "relationships", "scaling"], + "distributed": ["circuit", "retry", "bulkhead"], + "default": ["decompose", "constraints", "trade-offs"], + } + + async def evaluate( + self, + state: MCTSState, + rng, + max_depth: int = 10, + ) -> float: + """ + Evaluate a state through heuristic analysis. + + This combines: + - Depth bonus: rewards thorough exploration + - Action quality: rewards domain-appropriate actions + - Noise: adds exploration randomness + + Args: + state: State to evaluate + rng: Random number generator + max_depth: Maximum depth (unused in heuristic) + + Returns: + Estimated value in [0, 1] range + """ + # Base value + base_value = 0.5 + + # Depth bonus: deeper exploration = more value (up to 0.3) + depth = state.features.get("depth", 0) + depth_bonus = min(depth * 0.1, 0.3) + + # Action quality bonus + action_bonus = 0.0 + last_action = state.features.get("last_action", "") + + if last_action: + # Check if action contains quality indicators for this category + indicators = self.quality_indicators.get(self.category, self.quality_indicators["default"]) + for term in indicators: + if term in last_action.lower(): + action_bonus = 0.15 + break + + # Add exploration noise + noise = rng.uniform(-0.1, 0.1) + + # Combine components + value = base_value + depth_bonus + action_bonus + noise + + # Clamp to [0, 1] + return max(0.0, min(1.0, value)) + + +class MCTSDemo: + """ + Educational MCTS demonstration using the production framework. + + This class wraps the production MCTSEngine to provide: + - Simple, educational interface for demos + - Category-based action selection + - Tree visualization for learning + - Deterministic behavior with seeds + + Unlike the old mock implementation, this uses the real MCTS algorithm + with all its features: UCB1 selection, progressive widening, caching, etc. + """ + + def __init__(self, max_depth: int = 5): + """ + Initialize MCTS demo. 
+ + Args: + max_depth: Maximum tree depth for exploration + """ + self.max_depth = max_depth + + # Action templates for different query types + # These provide domain-specific reasoning paths + self.action_templates = { + "architecture": [ + "Consider microservices for scalability", + "Evaluate monolith for simplicity", + "Analyze team capabilities", + "Assess deployment requirements", + "Review data consistency needs", + ], + "optimization": [ + "Profile application hotspots", + "Implement caching layer", + "Use parallel processing", + "Optimize database queries", + "Reduce memory allocations", + ], + "database": [ + "Analyze query patterns", + "Consider data relationships", + "Evaluate consistency requirements", + "Plan for horizontal scaling", + "Assess read/write ratios", + ], + "distributed": [ + "Implement circuit breakers", + "Add retry mechanisms", + "Use message queues", + "Apply bulkhead pattern", + "Design for eventual consistency", + ], + "default": [ + "Decompose the problem", + "Identify constraints", + "Evaluate trade-offs", + "Consider alternatives", + "Validate assumptions", + ], + } + + def _categorize_query(self, query: str) -> str: + """ + Categorize query to select appropriate action templates. + + Args: + query: User's input query + + Returns: + Category name for action selection + """ + query_lower = query.lower() + if "architecture" in query_lower or "microservice" in query_lower: + return "architecture" + elif "optim" in query_lower or "performance" in query_lower: + return "optimization" + elif "database" in query_lower or "sql" in query_lower: + return "database" + elif "distribut" in query_lower or "fault" in query_lower: + return "distributed" + return "default" + + def _create_action_generator(self, category: str): + """ + Create action generator function for this query category. + + Args: + category: Query category + + Returns: + Function that generates actions for a given state + """ + def action_generator(state: MCTSState) -> list[str]: + """Generate available actions from current state.""" + # Get category-specific actions + actions = self.action_templates.get(category, self.action_templates["default"]) + + # Filter out already-used actions (track via state features) + used_actions = state.features.get("used_actions", set()) + available = [a for a in actions if a not in used_actions] + + # If all actions used, allow re-exploring top 2 + if not available: + return actions[:2] + + return available + + return action_generator + + def _create_state_transition(self, category: str): + """ + Create state transition function for this query category. 
+ + Args: + category: Query category + + Returns: + Function that computes next state from current state + action + """ + def state_transition(state: MCTSState, action: str) -> MCTSState: + """Compute next state by applying action.""" + # Track action history + action_history = list(state.features.get("action_history", [])) + action_history.append(action) + + # Track used actions + used_actions = set(state.features.get("used_actions", set())) + used_actions.add(action) + + # Increment depth + depth = state.features.get("depth", 0) + 1 + + # Create new state ID from action history + state_id = " -> ".join(action_history) + + # Build new state + new_state = MCTSState( + state_id=state_id, + features={ + "action_history": action_history, + "used_actions": used_actions, + "depth": depth, + "last_action": action, + "category": category, + }, + ) + + return new_state + + return state_transition + + def _generate_tree_visualization(self, root: MCTSNode, max_nodes: int = 20) -> str: + """ + Generate ASCII visualization of the MCTS tree. + + This provides educational insight into the search process. + + Args: + root: Root node of the tree + max_nodes: Maximum nodes to display + + Returns: + ASCII art representation of the tree + """ + max_nodes = max(1, max_nodes) + lines = [] + lines.append("MCTS Tree Visualization") + lines.append("=" * 50) + + nodes_rendered = 0 + + def format_node(node: MCTSNode, prefix: str = "", is_last: bool = True) -> list[str]: + nonlocal nodes_rendered + result = [] + + # Node representation + connector = "└── " if is_last else "├── " + + if nodes_rendered >= max_nodes: + result.append(f"{prefix}{connector}... (truncated)") + return result + + nodes_rendered += 1 + + # Display action or state + node_str = f"{node.state.state_id[:30]}..." + if node.action: + node_str = f"{node.action[:25]}..." + + stats = f"[V:{node.visits}, Q:{node.value:.3f}]" + + result.append(f"{prefix}{connector}{node_str} {stats}") + + # Recursively add children + new_prefix = prefix + (" " if is_last else "│ ") + + # Limit children shown + children_to_show = node.children[:3] + for i, child in enumerate(children_to_show): + is_child_last = i == len(children_to_show) - 1 + result.extend(format_node(child, new_prefix, is_child_last)) + + if len(node.children) > 3: + result.append(f"{new_prefix} ... and {len(node.children) - 3} more") + + return result + + # Start with root + lines.append(f"Root: {root.state.state_id[:40]}... [V:{root.visits}, Q:{root.value:.3f}]") + nodes_rendered += 1 + + for i, child in enumerate(root.children[:5]): + is_last = i == len(root.children[:5]) - 1 + lines.extend(format_node(child, "", is_last)) + + if len(root.children) > 5: + lines.append(f"... and {len(root.children) - 5} more branches") + + return "\n".join(lines) + + async def search( + self, + query: str, + iterations: int = 25, + exploration_weight: float = 1.414, + seed: int | None = None, + ) -> dict[str, Any]: + """ + Run MCTS search on the query using the production framework. + + This method demonstrates the full MCTS algorithm: + 1. Selection: UCB1-based tree traversal + 2. Expansion: Progressive widening of nodes + 3. Simulation: Heuristic evaluation (rollout) + 4. 
Backpropagation: Value updates up the tree + + Args: + query: The input query to analyze + iterations: Number of MCTS iterations (more = better but slower) + exploration_weight: UCB1 exploration constant (higher = more exploration) + seed: Random seed for deterministic results + + Returns: + Dictionary with: + - best_action: Recommended next step + - best_value: Confidence in recommendation + - statistics: Search metrics and performance data + - tree_visualization: ASCII art of search tree + """ + # Determine query category + category = self._categorize_query(query) + + # Initialize MCTS engine with production features + engine = MCTSEngine( + seed=seed if seed is not None else 42, + exploration_weight=exploration_weight, + progressive_widening_k=1.0, # Moderate expansion + progressive_widening_alpha=0.5, + max_parallel_rollouts=4, + cache_size_limit=10000, + ) + + # Create root state + root_state = MCTSState( + state_id=f"Query: {query[:50]}", + features={ + "query": query, + "category": category, + "action_history": [], + "used_actions": set(), + "depth": 0, + "last_action": "", + }, + ) + + # Create root node + root = MCTSNode(state=root_state, rng=engine.rng) + + # Create domain-specific functions + action_generator = self._create_action_generator(category) + state_transition = self._create_state_transition(category) + rollout_policy = DemoRolloutPolicy(category, self.action_templates) + + # Run MCTS search with production engine + best_action, stats = await engine.search( + root=root, + num_iterations=iterations, + action_generator=action_generator, + state_transition=state_transition, + rollout_policy=rollout_policy, + max_rollout_depth=self.max_depth, + selection_policy=SelectionPolicy.MAX_VISITS, # Most robust + ) + + # Extract best child info + best_child = None + if root.children: + best_child = max(root.children, key=lambda c: c.visits) + + # Compile results for demo interface + result = { + "best_action": best_action or "No action found", + "best_value": round(best_child.value, 4) if best_child else 0.0, + "root_visits": root.visits, + "total_nodes": engine.get_cached_node_count(), + "max_depth_reached": engine.get_cached_tree_depth(), + "iterations_completed": iterations, + "exploration_weight": exploration_weight, + "seed": seed, + "category": category, + + # Top actions sorted by visits + "top_actions": [ + { + "action": child.action, + "visits": child.visits, + "value": round(child.value, 4), + "ucb1": round( + child.visits / root.visits if root.visits > 0 else 0.0, 4 + ), # Simplified UCB display + } + for child in sorted(root.children, key=lambda c: -c.visits)[:5] + ], + + # Framework statistics + "framework_stats": { + "cache_hits": stats.get("cache_hits", 0), + "cache_misses": stats.get("cache_misses", 0), + "cache_hit_rate": round(stats.get("cache_hit_rate", 0.0), 4), + "total_simulations": stats.get("total_simulations", 0), + }, + + # Educational visualization + "tree_visualization": self._generate_tree_visualization(root), + } + + return result diff --git a/demo_src/wandb_tracker.py b/demo_src/wandb_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba8dbcfe5865a4c73905061dc6a4ba440f2ef50 --- /dev/null +++ b/demo_src/wandb_tracker.py @@ -0,0 +1,349 @@ +""" +Weights & Biases integration for experiment tracking. 
+""" + +import os +from datetime import datetime +from typing import Any + +try: + import wandb + + WANDB_AVAILABLE = True +except ImportError: + WANDB_AVAILABLE = False + wandb = None + + +class WandBTracker: + """Weights & Biases experiment tracker for multi-agent MCTS demo.""" + + def __init__(self, project_name: str = "langgraph-mcts-demo", entity: str | None = None, enabled: bool = True): + """Initialize W&B tracker. + + Args: + project_name: W&B project name + entity: W&B entity (username or team) + enabled: Whether tracking is enabled + """ + self.project_name = project_name + self.entity = entity + self.enabled = enabled and WANDB_AVAILABLE + self.run = None + self.run_id = None + + def is_available(self) -> bool: + """Check if W&B is available.""" + return WANDB_AVAILABLE + + def init_run( + self, run_name: str | None = None, config: dict[str, Any] | None = None, tags: list[str] | None = None + ) -> bool: + """Initialize a new W&B run. + + Args: + run_name: Optional name for the run + config: Configuration dictionary to log + tags: Tags for the run + + Returns: + True if run initialized successfully, False otherwise + """ + if not self.enabled: + return False + + try: + # Generate run name if not provided + if run_name is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + run_name = f"mcts_query_{timestamp}" + + # Default tags + if tags is None: + tags = ["demo", "multi-agent", "mcts"] + + # Initialize run + self.run = wandb.init( + project=self.project_name, + entity=self.entity, + name=run_name, + config=config or {}, + tags=tags, + reinit=True, + ) + + self.run_id = self.run.id + return True + + except Exception as e: + print(f"W&B init error: {e}") + self.enabled = False + return False + + def log_query_config(self, config: dict[str, Any]): + """Log query configuration. + + Args: + config: Configuration dictionary with agent settings, MCTS params, etc. + """ + if not self.enabled or not self.run: + return + + try: + wandb.config.update(config) + except Exception as e: + print(f"W&B config log error: {e}") + + def log_agent_result( + self, + agent_name: str, + response: str, + confidence: float, + execution_time_ms: float, + reasoning_steps: list[str] | None = None, + ): + """Log individual agent results. + + Args: + agent_name: Name of the agent (HRM, TRM, MCTS) + response: Agent's response text + confidence: Confidence score (0-1) + execution_time_ms: Execution time in milliseconds + reasoning_steps: Optional list of reasoning steps + """ + if not self.enabled or not self.run: + return + + try: + metrics = { + f"{agent_name}/confidence": confidence, + f"{agent_name}/execution_time_ms": execution_time_ms, + f"{agent_name}/response_length": len(response), + } + + if reasoning_steps: + metrics[f"{agent_name}/num_reasoning_steps"] = len(reasoning_steps) + + wandb.log(metrics) + + # Log response as text + wandb.log({f"{agent_name}/response": wandb.Html(f"
<pre>{response}</pre>
")}) + + except Exception as e: + print(f"W&B agent result log error: {e}") + + def log_mcts_result(self, mcts_result: dict[str, Any]): + """Log MCTS-specific metrics. + + Args: + mcts_result: Dictionary containing MCTS search results + """ + if not self.enabled or not self.run: + return + + try: + # Extract key metrics + metrics = { + "mcts/best_value": mcts_result.get("best_value", 0), + "mcts/root_visits": mcts_result.get("root_visits", 0), + "mcts/total_nodes": mcts_result.get("total_nodes", 0), + "mcts/max_depth": mcts_result.get("max_depth_reached", 0), + "mcts/iterations": mcts_result.get("iterations_completed", 0), + "mcts/exploration_weight": mcts_result.get("exploration_weight", 1.414), + } + + wandb.log(metrics) + + # Log top actions as table + if "top_actions" in mcts_result: + top_actions_data = [] + for action in mcts_result["top_actions"]: + top_actions_data.append( + [ + action.get("action", ""), + action.get("visits", 0), + action.get("value", 0), + action.get("ucb1", 0), + ] + ) + + if top_actions_data: + table = wandb.Table(data=top_actions_data, columns=["Action", "Visits", "Value", "UCB1"]) + wandb.log({"mcts/top_actions_table": table}) + + # Log tree visualization as text artifact + if "tree_visualization" in mcts_result: + wandb.log({"mcts/tree_visualization": wandb.Html(f"
<pre>{mcts_result['tree_visualization']}</pre>
")}) + + except Exception as e: + print(f"W&B MCTS result log error: {e}") + + def log_consensus(self, consensus_score: float, agents_used: list[str], final_response: str): + """Log consensus metrics. + + Args: + consensus_score: Agreement score between agents (0-1) + agents_used: List of agent names that were used + final_response: Final synthesized response + """ + if not self.enabled or not self.run: + return + + try: + wandb.log( + { + "consensus/score": consensus_score, + "consensus/num_agents": len(agents_used), + "consensus/agents": ", ".join(agents_used), + "consensus/final_response_length": len(final_response), + } + ) + + # Categorize consensus level + if consensus_score > 0.7: + consensus_level = "high" + elif consensus_score > 0.4: + consensus_level = "medium" + else: + consensus_level = "low" + + wandb.log({"consensus/level": consensus_level}) + + except Exception as e: + print(f"W&B consensus log error: {e}") + + def log_performance(self, total_time_ms: float): + """Log overall performance metrics. + + Args: + total_time_ms: Total execution time in milliseconds + """ + if not self.enabled or not self.run: + return + + try: + wandb.log({"performance/total_time_ms": total_time_ms, "performance/total_time_s": total_time_ms / 1000}) + except Exception as e: + print(f"W&B performance log error: {e}") + + def log_full_result(self, result: dict[str, Any]): + """Log the complete result as an artifact. + + Args: + result: Full framework result dictionary + """ + if not self.enabled or not self.run: + return + + try: + # Create artifact + artifact = wandb.Artifact(name=f"query_result_{self.run_id}", type="result") + + # Add result as JSON + import json + import tempfile + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(result, f, indent=2, default=str) + temp_path = f.name + + artifact.add_file(temp_path, name="result.json") + wandb.log_artifact(artifact) + + # Clean up temp file + os.unlink(temp_path) + + except Exception as e: + print(f"W&B full result log error: {e}") + + def log_query_summary( + self, query: str, use_hrm: bool, use_trm: bool, use_mcts: bool, consensus_score: float, total_time_ms: float + ): + """Log a summary row for the query. + + Args: + query: The input query + use_hrm: Whether HRM was enabled + use_trm: Whether TRM was enabled + use_mcts: Whether MCTS was enabled + consensus_score: Final consensus score + total_time_ms: Total execution time + """ + if not self.enabled or not self.run: + return + + try: + # Create summary table entry + summary_data = [ + [ + query[:100] + "..." if len(query) > 100 else query, + "✓" if use_hrm else "✗", + "✓" if use_trm else "✗", + "✓" if use_mcts else "✗", + f"{consensus_score:.1%}", + f"{total_time_ms:.2f}", + ] + ] + + table = wandb.Table(data=summary_data, columns=["Query", "HRM", "TRM", "MCTS", "Consensus", "Time (ms)"]) + + wandb.log({"query_summary": table}) + + except Exception as e: + print(f"W&B summary log error: {e}") + + def finish_run(self): + """Finish the current W&B run.""" + if not self.enabled or not self.run: + return + + try: + wandb.finish() + self.run = None + self.run_id = None + except Exception as e: + print(f"W&B finish error: {e}") + + def get_run_url(self) -> str | None: + """Get the URL for the current run. 
+ + Returns: + URL string or None if no active run + """ + if not self.enabled or not self.run: + return None + + try: + return self.run.get_url() + except Exception: + return None + + +# Global tracker instance +_global_tracker: WandBTracker | None = None + + +def get_tracker( + project_name: str = "langgraph-mcts-demo", entity: str | None = None, enabled: bool = True +) -> WandBTracker: + """Get or create the global W&B tracker. + + Args: + project_name: W&B project name + entity: W&B entity + enabled: Whether tracking is enabled + + Returns: + WandBTracker instance + """ + global _global_tracker + + if _global_tracker is None: + _global_tracker = WandBTracker(project_name=project_name, entity=entity, enabled=enabled) + + return _global_tracker + + +def is_wandb_available() -> bool: + """Check if W&B is available.""" + return WANDB_AVAILABLE diff --git a/models/bert_lora/final_model/README.md b/models/bert_lora/final_model/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4df78b2dcbdee07a91a2b3598da2770038a78e0f --- /dev/null +++ b/models/bert_lora/final_model/README.md @@ -0,0 +1,206 @@ +--- +base_model: prajjwal1/bert-mini +library_name: peft +tags: +- base_model:adapter:prajjwal1/bert-mini +- lora +- transformers +--- + +# Model Card for Model ID + + + + + +## Model Details + +### Model Description + + + + + +- **Developed by:** [More Information Needed] +- **Funded by [optional]:** [More Information Needed] +- **Shared by [optional]:** [More Information Needed] +- **Model type:** [More Information Needed] +- **Language(s) (NLP):** [More Information Needed] +- **License:** [More Information Needed] +- **Finetuned from model [optional]:** [More Information Needed] + +### Model Sources [optional] + + + +- **Repository:** [More Information Needed] +- **Paper [optional]:** [More Information Needed] +- **Demo [optional]:** [More Information Needed] + +## Uses + + + +### Direct Use + + + +[More Information Needed] + +### Downstream Use [optional] + + + +[More Information Needed] + +### Out-of-Scope Use + + + +[More Information Needed] + +## Bias, Risks, and Limitations + + + +[More Information Needed] + +### Recommendations + + + +Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations. + +## How to Get Started with the Model + +Use the code below to get started with the model. + +[More Information Needed] + +## Training Details + +### Training Data + + + +[More Information Needed] + +### Training Procedure + + + +#### Preprocessing [optional] + +[More Information Needed] + + +#### Training Hyperparameters + +- **Training regime:** [More Information Needed] + +#### Speeds, Sizes, Times [optional] + + + +[More Information Needed] + +## Evaluation + + + +### Testing Data, Factors & Metrics + +#### Testing Data + + + +[More Information Needed] + +#### Factors + + + +[More Information Needed] + +#### Metrics + + + +[More Information Needed] + +### Results + +[More Information Needed] + +#### Summary + + + +## Model Examination [optional] + + + +[More Information Needed] + +## Environmental Impact + + + +Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). 
+ +- **Hardware Type:** [More Information Needed] +- **Hours used:** [More Information Needed] +- **Cloud Provider:** [More Information Needed] +- **Compute Region:** [More Information Needed] +- **Carbon Emitted:** [More Information Needed] + +## Technical Specifications [optional] + +### Model Architecture and Objective + +[More Information Needed] + +### Compute Infrastructure + +[More Information Needed] + +#### Hardware + +[More Information Needed] + +#### Software + +[More Information Needed] + +## Citation [optional] + + + +**BibTeX:** + +[More Information Needed] + +**APA:** + +[More Information Needed] + +## Glossary [optional] + + + +[More Information Needed] + +## More Information [optional] + +[More Information Needed] + +## Model Card Authors [optional] + +[More Information Needed] + +## Model Card Contact + +[More Information Needed] +### Framework versions + +- PEFT 0.17.1 \ No newline at end of file diff --git a/models/bert_lora/final_model/adapter_config.json b/models/bert_lora/final_model/adapter_config.json new file mode 100644 index 0000000000000000000000000000000000000000..dae7210ba02813d4b1e90f83f2a43a744acb3ce9 --- /dev/null +++ b/models/bert_lora/final_model/adapter_config.json @@ -0,0 +1,40 @@ +{ + "alpha_pattern": {}, + "auto_mapping": null, + "base_model_name_or_path": "prajjwal1/bert-mini", + "bias": "none", + "corda_config": null, + "eva_config": null, + "exclude_modules": null, + "fan_in_fan_out": false, + "inference_mode": true, + "init_lora_weights": true, + "layer_replication": null, + "layers_pattern": null, + "layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 16, + "lora_bias": false, + "lora_dropout": 0.1, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": [ + "classifier", + "score" + ], + "peft_type": "LORA", + "qalora_group_size": 16, + "r": 4, + "rank_pattern": {}, + "revision": null, + "target_modules": [ + "query", + "value" + ], + "target_parameters": null, + "task_type": "SEQ_CLS", + "trainable_token_indices": null, + "use_dora": false, + "use_qalora": false, + "use_rslora": false +} \ No newline at end of file diff --git a/models/bert_lora/final_model/adapter_model.safetensors b/models/bert_lora/final_model/adapter_model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..90ac62deb6311c1d6fd3ed11325a37e0e7827ae1 Binary files /dev/null and b/models/bert_lora/final_model/adapter_model.safetensors differ diff --git a/models/bert_lora/generated_dataset.json b/models/bert_lora/generated_dataset.json new file mode 100644 index 0000000000000000000000000000000000000000..ad76e7af39a0620c2f4e1b07cdb4b28ffc1a502f --- /dev/null +++ b/models/bert_lora/generated_dataset.json @@ -0,0 +1,12993 @@ +{ + "seed": 42, + "num_samples": 999, + "samples": [ + { + "features": { + "hrm_confidence": 0.932186814566789, + "trm_confidence": 0.36522885075930284, + "mcts_value": 0.7145138679647244, + "consensus_score": 0.7101167835754781, + "last_agent": "none", + "iteration": 1, + "query_length": 2637, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9283419105971059, + "trm_confidence": 0.6511300084852987, + "mcts_value": 0.10612189126399749, + "consensus_score": 0.5519417910279142, + "last_agent": "trm", + "iteration": 4, + "query_length": 921, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8931595360241993, + "trm_confidence": 0.6525812194404133, + "mcts_value": 0.351698200208428, + "consensus_score": 
0.5779273962479688, + "last_agent": "none", + "iteration": 6, + "query_length": 4441, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9482893515977746, + "trm_confidence": 0.5358341835586543, + "mcts_value": 0.6430777574912441, + "consensus_score": 0.679972291175198, + "last_agent": "none", + "iteration": 10, + "query_length": 2234, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9335150491221286, + "trm_confidence": 0.1622342921363004, + "mcts_value": 0.3890189803478681, + "consensus_score": 0.4036835270262114, + "last_agent": "trm", + "iteration": 1, + "query_length": 3720, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9234286467723452, + "trm_confidence": 0.7966752297173753, + "mcts_value": 0.2682939337358132, + "consensus_score": 0.6368912112821515, + "last_agent": "none", + "iteration": 5, + "query_length": 3981, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7389764516006414, + "trm_confidence": 0.30396424576879205, + "mcts_value": 0.1449897306915453, + "consensus_score": 0.42993960829016165, + "last_agent": "mcts", + "iteration": 4, + "query_length": 811, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9100795306006747, + "trm_confidence": 0.25304182222607313, + "mcts_value": 0.6741966292520353, + "consensus_score": 0.6733921988589547, + "last_agent": "mcts", + "iteration": 4, + "query_length": 4492, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9047486511924926, + "trm_confidence": 0.112465622685393, + "mcts_value": 0.16087585630415913, + "consensus_score": 0.29416916401088267, + "last_agent": "mcts", + "iteration": 8, + "query_length": 3904, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9115496135879005, + "trm_confidence": 0.6336003434427341, + "mcts_value": 0.3724329203075315, + "consensus_score": 0.6529425316366341, + "last_agent": "none", + "iteration": 1, + "query_length": 1235, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9005208885371415, + "trm_confidence": 0.37712235352817686, + "mcts_value": 0.4524833101935956, + "consensus_score": 0.6297086222361764, + "last_agent": "trm", + "iteration": 6, + "query_length": 2831, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8677621482236241, + "trm_confidence": 0.23336138024133235, + "mcts_value": 0.02366076687148142, + "consensus_score": 0.3622715762919518, + "last_agent": "mcts", + "iteration": 2, + "query_length": 1390, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9560209219804499, + "trm_confidence": 0.2002570943780814, + "mcts_value": 0.0499083666946623, + "consensus_score": 0.35833890608879715, + "last_agent": "mcts", + "iteration": 3, + "query_length": 2177, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8671096457023835, + "trm_confidence": 0.6013358774543538, + "mcts_value": 0.5096013245558383, + "consensus_score": 0.6406263215255393, + "last_agent": "hrm", + "iteration": 8, + "query_length": 1613, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7068136219401581, + "trm_confidence": 0.054642268545230264, + "mcts_value": 0.4383374938777725, + 
"consensus_score": 0.3923065741713311, + "last_agent": "trm", + "iteration": 1, + "query_length": 4505, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7456936308139505, + "trm_confidence": 0.4496096311936752, + "mcts_value": 0.2880802654858253, + "consensus_score": 0.47066542105044684, + "last_agent": "none", + "iteration": 3, + "query_length": 3422, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8085437831660172, + "trm_confidence": 0.062103805426426364, + "mcts_value": 0.08361234832438866, + "consensus_score": 0.410466178548847, + "last_agent": "hrm", + "iteration": 9, + "query_length": 2480, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7797609884378558, + "trm_confidence": 0.6588082922365092, + "mcts_value": 0.5293644842266623, + "consensus_score": 0.6993559594654748, + "last_agent": "mcts", + "iteration": 4, + "query_length": 3687, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7289172886460498, + "trm_confidence": 0.5676622519597796, + "mcts_value": 0.2866455884313108, + "consensus_score": 0.46821438263809284, + "last_agent": "trm", + "iteration": 3, + "query_length": 4046, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7530318348817695, + "trm_confidence": 0.5593963977267774, + "mcts_value": 0.49533740036194435, + "consensus_score": 0.6464811355136844, + "last_agent": "trm", + "iteration": 4, + "query_length": 1550, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8752293906738207, + "trm_confidence": 0.5037801849547963, + "mcts_value": 0.06546371962314065, + "consensus_score": 0.46465257885137445, + "last_agent": "mcts", + "iteration": 0, + "query_length": 905, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7989583636998355, + "trm_confidence": 0.10101639056485813, + "mcts_value": 0.07227436912106161, + "consensus_score": 0.3416119555641275, + "last_agent": "mcts", + "iteration": 1, + "query_length": 1507, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8743183419101185, + "trm_confidence": 0.26858765190610195, + "mcts_value": 0.45755670357289335, + "consensus_score": 0.43804834000231074, + "last_agent": "none", + "iteration": 10, + "query_length": 2097, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9348205681750859, + "trm_confidence": 0.06906470554037497, + "mcts_value": 0.40627238425745194, + "consensus_score": 0.4681939515285418, + "last_agent": "hrm", + "iteration": 10, + "query_length": 813, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8420468203170861, + "trm_confidence": 0.19810844189938528, + "mcts_value": 0.24603972019376524, + "consensus_score": 0.4328661412977197, + "last_agent": "mcts", + "iteration": 4, + "query_length": 1430, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9478875772583073, + "trm_confidence": 0.7598435856691546, + "mcts_value": 0.11891546028371681, + "consensus_score": 0.6196894364448695, + "last_agent": "hrm", + "iteration": 1, + "query_length": 3333, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7843701351517025, + "trm_confidence": 0.4512891576261886, + "mcts_value": 
0.4975334024340737, + "consensus_score": 0.6314603967875196, + "last_agent": "hrm", + "iteration": 1, + "query_length": 3077, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7690641972684642, + "trm_confidence": 0.025031401865777046, + "mcts_value": 0.37123192203583794, + "consensus_score": 0.3626269638291808, + "last_agent": "hrm", + "iteration": 9, + "query_length": 2306, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7951416678468146, + "trm_confidence": 0.6624000747790032, + "mcts_value": 0.20222911121111173, + "consensus_score": 0.5562683771253194, + "last_agent": "mcts", + "iteration": 2, + "query_length": 4957, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7493823452746055, + "trm_confidence": 0.029164163348725774, + "mcts_value": 0.2825443492468585, + "consensus_score": 0.452172065434564, + "last_agent": "none", + "iteration": 9, + "query_length": 463, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9672377472635574, + "trm_confidence": 0.7748306511117701, + "mcts_value": 0.4499735556104071, + "consensus_score": 0.6938664616947368, + "last_agent": "none", + "iteration": 8, + "query_length": 1684, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.812097318662113, + "trm_confidence": 0.0672694610295489, + "mcts_value": 0.5317868798463629, + "consensus_score": 0.4228766563639145, + "last_agent": "none", + "iteration": 10, + "query_length": 4740, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7368273797234458, + "trm_confidence": 0.5292753052442554, + "mcts_value": 0.09761564970867558, + "consensus_score": 0.39042643985700676, + "last_agent": "hrm", + "iteration": 6, + "query_length": 3603, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7589303997143719, + "trm_confidence": 0.20448170182489145, + "mcts_value": 0.5122556808021458, + "consensus_score": 0.5862545459926631, + "last_agent": "hrm", + "iteration": 5, + "query_length": 2174, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7041808863124605, + "trm_confidence": 0.13875378375170938, + "mcts_value": 0.07964446437793447, + "consensus_score": 0.3430581128699396, + "last_agent": "mcts", + "iteration": 1, + "query_length": 1490, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9082787306928659, + "trm_confidence": 0.46970419528561314, + "mcts_value": 0.16147441014713151, + "consensus_score": 0.5739773506116561, + "last_agent": "none", + "iteration": 7, + "query_length": 431, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.739317325467194, + "trm_confidence": 0.07911795076613058, + "mcts_value": 0.5930068093130082, + "consensus_score": 0.44999633394709904, + "last_agent": "mcts", + "iteration": 3, + "query_length": 669, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8988592638290747, + "trm_confidence": 0.763408491622485, + "mcts_value": 0.22883022193361457, + "consensus_score": 0.7153276783241302, + "last_agent": "hrm", + "iteration": 0, + "query_length": 2989, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8901925335043255, + "trm_confidence": 0.08367933776133821, + 
"mcts_value": 0.11089530175490933, + "consensus_score": 0.3454119215367851, + "last_agent": "none", + "iteration": 10, + "query_length": 838, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9799069664800634, + "trm_confidence": 0.7077627732121353, + "mcts_value": 0.4112523272131102, + "consensus_score": 0.7565933788188737, + "last_agent": "mcts", + "iteration": 0, + "query_length": 3321, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9488285844648209, + "trm_confidence": 0.6763611211604272, + "mcts_value": 0.19747211169221357, + "consensus_score": 0.6137078572256313, + "last_agent": "mcts", + "iteration": 6, + "query_length": 4667, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8809321472016177, + "trm_confidence": 0.32219040146723615, + "mcts_value": 0.29221234846655453, + "consensus_score": 0.4836213829865048, + "last_agent": "none", + "iteration": 7, + "query_length": 4814, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8361690646228899, + "trm_confidence": 0.18245181923444714, + "mcts_value": 0.17422351039937214, + "consensus_score": 0.446817654134259, + "last_agent": "mcts", + "iteration": 8, + "query_length": 4797, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7199676570865534, + "trm_confidence": 0.3685296458124635, + "mcts_value": 0.0906226837308878, + "consensus_score": 0.4579728336345698, + "last_agent": "hrm", + "iteration": 3, + "query_length": 4618, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.976291141746235, + "trm_confidence": 0.14505398231091812, + "mcts_value": 0.2494976860299916, + "consensus_score": 0.38767028240079326, + "last_agent": "hrm", + "iteration": 1, + "query_length": 2706, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.716618622749278, + "trm_confidence": 0.10768718328337355, + "mcts_value": 0.032916293776263636, + "consensus_score": 0.3039694631584993, + "last_agent": "trm", + "iteration": 7, + "query_length": 4018, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7953973290856148, + "trm_confidence": 0.35084619767887465, + "mcts_value": 0.6084760997666595, + "consensus_score": 0.6551328675414908, + "last_agent": "trm", + "iteration": 0, + "query_length": 2470, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7710234613313187, + "trm_confidence": 0.1673449143479575, + "mcts_value": 0.3833105111979031, + "consensus_score": 0.42381211409969827, + "last_agent": "hrm", + "iteration": 0, + "query_length": 45, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8571258846152912, + "trm_confidence": 0.07697842942649595, + "mcts_value": 0.6310330448273446, + "consensus_score": 0.4321048262493981, + "last_agent": "hrm", + "iteration": 10, + "query_length": 3450, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9530724854793258, + "trm_confidence": 0.7700285610160599, + "mcts_value": 0.8356447951906794, + "consensus_score": 0.9133204566167403, + "last_agent": "trm", + "iteration": 8, + "query_length": 3143, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9336989063734387, + "trm_confidence": 
0.11217602900417172, + "mcts_value": 0.44691933531130246, + "consensus_score": 0.5004426642677205, + "last_agent": "mcts", + "iteration": 9, + "query_length": 4170, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.815526848830106, + "trm_confidence": 0.45762469209576373, + "mcts_value": 0.1906616579001964, + "consensus_score": 0.41589141513281347, + "last_agent": "hrm", + "iteration": 5, + "query_length": 834, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.769770982158662, + "trm_confidence": 0.24614874601571698, + "mcts_value": 0.24539903097564403, + "consensus_score": 0.3859386992658501, + "last_agent": "hrm", + "iteration": 4, + "query_length": 922, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7890629423713855, + "trm_confidence": 0.653822834874885, + "mcts_value": 0.6314214625791688, + "consensus_score": 0.6876178322684955, + "last_agent": "mcts", + "iteration": 3, + "query_length": 2308, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9545681466320873, + "trm_confidence": 0.5576803541380092, + "mcts_value": 0.6874076335747309, + "consensus_score": 0.7397631666614238, + "last_agent": "mcts", + "iteration": 6, + "query_length": 371, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9204679486930872, + "trm_confidence": 0.16606648132791707, + "mcts_value": 0.5700595955075086, + "consensus_score": 0.6243418221746831, + "last_agent": "mcts", + "iteration": 1, + "query_length": 4763, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7285287244689803, + "trm_confidence": 0.45613311816464464, + "mcts_value": 0.05310641505342561, + "consensus_score": 0.49977738376914194, + "last_agent": "trm", + "iteration": 1, + "query_length": 4801, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.940265252825529, + "trm_confidence": 0.49885035958088786, + "mcts_value": 0.6576118411438995, + "consensus_score": 0.757932118980056, + "last_agent": "mcts", + "iteration": 10, + "query_length": 459, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8770227686074927, + "trm_confidence": 0.07385539064541159, + "mcts_value": 0.4787747781537166, + "consensus_score": 0.410809239954098, + "last_agent": "none", + "iteration": 6, + "query_length": 49, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8397955458943682, + "trm_confidence": 0.3866406596674736, + "mcts_value": 0.5651471213314001, + "consensus_score": 0.6570433856001606, + "last_agent": "trm", + "iteration": 5, + "query_length": 1186, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9793708707070305, + "trm_confidence": 0.10529022997310968, + "mcts_value": 0.10297746470847831, + "consensus_score": 0.31342132417835983, + "last_agent": "hrm", + "iteration": 7, + "query_length": 3647, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9322964248425529, + "trm_confidence": 0.5586635055387557, + "mcts_value": 0.2776855134490763, + "consensus_score": 0.669221790749102, + "last_agent": "mcts", + "iteration": 8, + "query_length": 3176, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8092576053316162, + 
"trm_confidence": 0.2230189473801963, + "mcts_value": 0.11178726047613538, + "consensus_score": 0.31091127890442094, + "last_agent": "none", + "iteration": 10, + "query_length": 2780, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.814995946833223, + "trm_confidence": 0.5217223241664531, + "mcts_value": 0.3953878002827034, + "consensus_score": 0.6645966877951524, + "last_agent": "hrm", + "iteration": 8, + "query_length": 3067, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8129078420442478, + "trm_confidence": 0.7033773655847808, + "mcts_value": 0.5116969009545275, + "consensus_score": 0.766232968195287, + "last_agent": "hrm", + "iteration": 1, + "query_length": 1362, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8911221651973622, + "trm_confidence": 0.0964549421489706, + "mcts_value": 0.4653839426303197, + "consensus_score": 0.5215396230133073, + "last_agent": "trm", + "iteration": 0, + "query_length": 487, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9476198533588963, + "trm_confidence": 0.250352173750384, + "mcts_value": 0.3886744578520911, + "consensus_score": 0.5173449870673691, + "last_agent": "none", + "iteration": 3, + "query_length": 2583, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9343882106294548, + "trm_confidence": 0.09227366636739937, + "mcts_value": 0.8319139640886433, + "consensus_score": 0.6953652852233815, + "last_agent": "trm", + "iteration": 3, + "query_length": 3892, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7319258595614149, + "trm_confidence": 0.631360115596809, + "mcts_value": 0.42066339906103695, + "consensus_score": 0.6246747945181796, + "last_agent": "trm", + "iteration": 0, + "query_length": 3407, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7086998509613408, + "trm_confidence": 0.14659200303695535, + "mcts_value": 0.08705739409814221, + "consensus_score": 0.3694700041705898, + "last_agent": "mcts", + "iteration": 2, + "query_length": 82, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.896880711798247, + "trm_confidence": 0.02881736647956746, + "mcts_value": 0.004326930085425239, + "consensus_score": 0.2203399195239597, + "last_agent": "mcts", + "iteration": 6, + "query_length": 4901, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7715658461882166, + "trm_confidence": 0.5704339683573875, + "mcts_value": 0.03843501631801768, + "consensus_score": 0.5203377143495425, + "last_agent": "mcts", + "iteration": 10, + "query_length": 610, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9094362351935912, + "trm_confidence": 0.6782915533551156, + "mcts_value": 0.032499916743507586, + "consensus_score": 0.4804323239078826, + "last_agent": "none", + "iteration": 1, + "query_length": 3490, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9235564384986533, + "trm_confidence": 0.5188503109313061, + "mcts_value": 0.7009544972153812, + "consensus_score": 0.6454963473699984, + "last_agent": "mcts", + "iteration": 8, + "query_length": 274, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 
0.7812276253953401, + "trm_confidence": 0.4836066911859404, + "mcts_value": 0.6677425780176791, + "consensus_score": 0.6665010193950671, + "last_agent": "trm", + "iteration": 0, + "query_length": 360, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.71270516547405, + "trm_confidence": 0.5417206443861774, + "mcts_value": 0.4347622805893685, + "consensus_score": 0.4976882661010326, + "last_agent": "hrm", + "iteration": 1, + "query_length": 3854, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9940081538662577, + "trm_confidence": 0.40995695340003957, + "mcts_value": 0.7009747611084067, + "consensus_score": 0.7289282912131234, + "last_agent": "mcts", + "iteration": 6, + "query_length": 2238, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9838073360674285, + "trm_confidence": 0.26632882928637897, + "mcts_value": 0.5108558557039758, + "consensus_score": 0.6269525292704077, + "last_agent": "hrm", + "iteration": 7, + "query_length": 2219, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7445316969760174, + "trm_confidence": 0.3276494531736131, + "mcts_value": 0.2604129715008172, + "consensus_score": 0.4390317864466729, + "last_agent": "none", + "iteration": 1, + "query_length": 4868, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7834226637343772, + "trm_confidence": 0.20824203195435526, + "mcts_value": 0.2924387541146086, + "consensus_score": 0.45023199267510605, + "last_agent": "hrm", + "iteration": 6, + "query_length": 4858, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8226349328286091, + "trm_confidence": 0.1572659757712483, + "mcts_value": 0.4251306463016947, + "consensus_score": 0.4317520338972982, + "last_agent": "mcts", + "iteration": 0, + "query_length": 1500, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8422398025270511, + "trm_confidence": 0.16744400594740463, + "mcts_value": 0.424901063235603, + "consensus_score": 0.4913493373230704, + "last_agent": "trm", + "iteration": 7, + "query_length": 2740, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8957299169563094, + "trm_confidence": 0.25162106077135565, + "mcts_value": 0.6265833766096053, + "consensus_score": 0.6011403281656927, + "last_agent": "none", + "iteration": 4, + "query_length": 2653, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8081972003350762, + "trm_confidence": 0.36312049753242803, + "mcts_value": 0.5217329060540332, + "consensus_score": 0.6416307786519208, + "last_agent": "none", + "iteration": 10, + "query_length": 4253, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8560825344221707, + "trm_confidence": 0.6047680473796243, + "mcts_value": 0.23775067596476046, + "consensus_score": 0.6336768917245321, + "last_agent": "hrm", + "iteration": 5, + "query_length": 217, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7216177441218831, + "trm_confidence": 0.5233979204122917, + "mcts_value": 0.034542003153205686, + "consensus_score": 0.3826415097891913, + "last_agent": "trm", + "iteration": 3, + "query_length": 2443, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + 
"hrm_confidence": 0.7941680109405967, + "trm_confidence": 0.5155534217839192, + "mcts_value": 0.010192360309046276, + "consensus_score": 0.505405949248658, + "last_agent": "trm", + "iteration": 9, + "query_length": 2082, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7460838697181904, + "trm_confidence": 0.388193295848214, + "mcts_value": 0.07731850799225726, + "consensus_score": 0.3768490967352044, + "last_agent": "trm", + "iteration": 10, + "query_length": 392, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9316314674063757, + "trm_confidence": 0.25860537680230905, + "mcts_value": 0.5718838939220751, + "consensus_score": 0.6284548524737226, + "last_agent": "mcts", + "iteration": 4, + "query_length": 4629, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7032182934927028, + "trm_confidence": 0.12610740406362417, + "mcts_value": 0.3167428700488906, + "consensus_score": 0.314773116720065, + "last_agent": "trm", + "iteration": 1, + "query_length": 2095, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9967399008050817, + "trm_confidence": 0.49855996974088884, + "mcts_value": 0.7524273072394143, + "consensus_score": 0.8473067254603828, + "last_agent": "trm", + "iteration": 1, + "query_length": 2519, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8177718147525759, + "trm_confidence": 0.05745711959236592, + "mcts_value": 0.5421547088642604, + "consensus_score": 0.459217019867586, + "last_agent": "hrm", + "iteration": 5, + "query_length": 4596, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7542779956692605, + "trm_confidence": 0.5934979397686945, + "mcts_value": 0.029212916442549666, + "consensus_score": 0.4055667409489095, + "last_agent": "none", + "iteration": 3, + "query_length": 980, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8759335519044388, + "trm_confidence": 0.3827602430806626, + "mcts_value": 0.06526791031235166, + "consensus_score": 0.39005405924362707, + "last_agent": "mcts", + "iteration": 9, + "query_length": 4272, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8947447150402823, + "trm_confidence": 0.5326404951932822, + "mcts_value": 0.6063131424699416, + "consensus_score": 0.5895211472454582, + "last_agent": "trm", + "iteration": 4, + "query_length": 3721, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8015369449891978, + "trm_confidence": 0.5924331288738127, + "mcts_value": 0.33854244341010364, + "consensus_score": 0.6312296903162472, + "last_agent": "trm", + "iteration": 9, + "query_length": 1856, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9728656731614049, + "trm_confidence": 0.5124803335859882, + "mcts_value": 0.7421752482233381, + "consensus_score": 0.7106252441085485, + "last_agent": "trm", + "iteration": 5, + "query_length": 4242, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7314939147688211, + "trm_confidence": 0.2516834826978827, + "mcts_value": 0.5792931580045768, + "consensus_score": 0.5469899665633753, + "last_agent": "hrm", + "iteration": 1, + "query_length": 52, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + 
"hrm_confidence": 0.7574809028988099, + "trm_confidence": 0.016320735107701632, + "mcts_value": 0.6097875396601716, + "consensus_score": 0.45083785820693534, + "last_agent": "hrm", + "iteration": 3, + "query_length": 720, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7021943368879047, + "trm_confidence": 0.16742333812465543, + "mcts_value": 0.4233627716865501, + "consensus_score": 0.45774743684688934, + "last_agent": "trm", + "iteration": 10, + "query_length": 4016, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8432517620883198, + "trm_confidence": 0.5659360942979396, + "mcts_value": 0.6714000325807745, + "consensus_score": 0.737668485666015, + "last_agent": "hrm", + "iteration": 10, + "query_length": 3121, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.960040431478578, + "trm_confidence": 0.09813411461423771, + "mcts_value": 0.629905225135514, + "consensus_score": 0.5507109969815835, + "last_agent": "trm", + "iteration": 6, + "query_length": 4808, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9909445350116708, + "trm_confidence": 0.8772044609437818, + "mcts_value": 0.25679538015345144, + "consensus_score": 0.7550654915238845, + "last_agent": "mcts", + "iteration": 8, + "query_length": 3189, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7371609324592663, + "trm_confidence": 0.026089803397021747, + "mcts_value": 0.49529267197729376, + "consensus_score": 0.41745441764124175, + "last_agent": "mcts", + "iteration": 10, + "query_length": 4301, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.993375093722152, + "trm_confidence": 0.36769174953390693, + "mcts_value": 0.7090558656421148, + "consensus_score": 0.6070047574284242, + "last_agent": "none", + "iteration": 6, + "query_length": 2707, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9774105001659665, + "trm_confidence": 0.7217430409667073, + "mcts_value": 0.03243850424508178, + "consensus_score": 0.5517378167205987, + "last_agent": "mcts", + "iteration": 0, + "query_length": 1851, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9025916886002452, + "trm_confidence": 0.5724551001857711, + "mcts_value": 0.6209817892745624, + "consensus_score": 0.7717675022979922, + "last_agent": "trm", + "iteration": 8, + "query_length": 1973, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7146891129469838, + "trm_confidence": 0.1441662035796492, + "mcts_value": 0.38227379065973416, + "consensus_score": 0.48533476332042036, + "last_agent": "hrm", + "iteration": 0, + "query_length": 4760, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9031862198093191, + "trm_confidence": 0.023780166634012397, + "mcts_value": 0.3223632530707673, + "consensus_score": 0.49557018944810566, + "last_agent": "trm", + "iteration": 7, + "query_length": 1280, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9558343388895738, + "trm_confidence": 0.29785724316647905, + "mcts_value": 0.7303216724735169, + "consensus_score": 0.6211264817336376, + "last_agent": "trm", + "iteration": 6, + "query_length": 2250, + "has_rag_context": true + }, + "label": "hrm" + }, + { + 
"features": { + "hrm_confidence": 0.7824475152265201, + "trm_confidence": 0.6050290054094732, + "mcts_value": 0.12802284211179998, + "consensus_score": 0.42212877249641706, + "last_agent": "trm", + "iteration": 3, + "query_length": 2255, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9422294817928845, + "trm_confidence": 0.8411711110154104, + "mcts_value": 0.2496048616995867, + "consensus_score": 0.6592568750622169, + "last_agent": "trm", + "iteration": 1, + "query_length": 605, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9992740111788752, + "trm_confidence": 0.6302840773303063, + "mcts_value": 0.5352594385037045, + "consensus_score": 0.7000796609219875, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1105, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7403100739939185, + "trm_confidence": 0.23395551080682137, + "mcts_value": 0.04300749469983864, + "consensus_score": 0.2794868340379491, + "last_agent": "trm", + "iteration": 0, + "query_length": 1516, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8903620794855391, + "trm_confidence": 0.2713253444375048, + "mcts_value": 0.33225381110472557, + "consensus_score": 0.5898222662448384, + "last_agent": "mcts", + "iteration": 8, + "query_length": 2521, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7853622625335476, + "trm_confidence": 0.6147677554924093, + "mcts_value": 0.16112669194044266, + "consensus_score": 0.48548744942546856, + "last_agent": "mcts", + "iteration": 9, + "query_length": 1537, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9226953848253929, + "trm_confidence": 0.48600301576815713, + "mcts_value": 0.537581421516067, + "consensus_score": 0.6086365990061708, + "last_agent": "mcts", + "iteration": 2, + "query_length": 1505, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7466324692250267, + "trm_confidence": 0.5653600569069068, + "mcts_value": 0.18315666384793758, + "consensus_score": 0.5106809422095007, + "last_agent": "hrm", + "iteration": 8, + "query_length": 2929, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8315158775752578, + "trm_confidence": 0.348389783042449, + "mcts_value": 0.7276401227978485, + "consensus_score": 0.6707680898676567, + "last_agent": "hrm", + "iteration": 8, + "query_length": 3361, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9362769508688018, + "trm_confidence": 0.15486121796165153, + "mcts_value": 0.47013042707773395, + "consensus_score": 0.4408016967656381, + "last_agent": "hrm", + "iteration": 7, + "query_length": 1601, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8538196192900334, + "trm_confidence": 0.32638315973663695, + "mcts_value": 0.02701896922834707, + "consensus_score": 0.49436215388418314, + "last_agent": "trm", + "iteration": 1, + "query_length": 1475, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7738199834514815, + "trm_confidence": 0.04415571589845285, + "mcts_value": 0.3066674970423677, + "consensus_score": 0.3780986257155231, + "last_agent": "mcts", + "iteration": 3, + "query_length": 247, + "has_rag_context": true + }, + "label": "hrm" + 
}, + { + "features": { + "hrm_confidence": 0.733480110474768, + "trm_confidence": 0.24357587584709162, + "mcts_value": 0.03834356971646067, + "consensus_score": 0.3781044638014227, + "last_agent": "hrm", + "iteration": 2, + "query_length": 4886, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8182377331534292, + "trm_confidence": 0.2992247584654667, + "mcts_value": 0.0011927829157792366, + "consensus_score": 0.2952991657112979, + "last_agent": "hrm", + "iteration": 9, + "query_length": 821, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8524502067438625, + "trm_confidence": 0.3683258859978843, + "mcts_value": 0.2506367703176279, + "consensus_score": 0.4767363583824148, + "last_agent": "trm", + "iteration": 8, + "query_length": 2341, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.778104563024575, + "trm_confidence": 0.21868223955202493, + "mcts_value": 0.1644287290937298, + "consensus_score": 0.38304452400032296, + "last_agent": "mcts", + "iteration": 7, + "query_length": 1673, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7992207226164599, + "trm_confidence": 0.6505442046219245, + "mcts_value": 0.033960653943758315, + "consensus_score": 0.4867291147752941, + "last_agent": "mcts", + "iteration": 7, + "query_length": 965, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7142122060016445, + "trm_confidence": 0.08488742619537909, + "mcts_value": 0.5643524206239566, + "consensus_score": 0.356335973971218, + "last_agent": "mcts", + "iteration": 2, + "query_length": 4085, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7331888402686609, + "trm_confidence": 0.3926716075936083, + "mcts_value": 0.15300306038828654, + "consensus_score": 0.440130246169654, + "last_agent": "none", + "iteration": 6, + "query_length": 3441, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7014224178735622, + "trm_confidence": 0.5132355235771408, + "mcts_value": 0.37235209791058166, + "consensus_score": 0.4615520436747531, + "last_agent": "mcts", + "iteration": 8, + "query_length": 4976, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7762790451262471, + "trm_confidence": 0.6213609131888949, + "mcts_value": 0.3073400370129043, + "consensus_score": 0.589011555976013, + "last_agent": "mcts", + "iteration": 10, + "query_length": 406, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9439478037460338, + "trm_confidence": 0.26871223620600315, + "mcts_value": 0.6744944968198383, + "consensus_score": 0.6491982960298203, + "last_agent": "hrm", + "iteration": 2, + "query_length": 2582, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7952906903090685, + "trm_confidence": 0.054308038543778275, + "mcts_value": 0.02074383789783102, + "consensus_score": 0.2594099199983288, + "last_agent": "hrm", + "iteration": 0, + "query_length": 3514, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9175549962552606, + "trm_confidence": 0.5789035275376786, + "mcts_value": 0.6038300850843744, + "consensus_score": 0.6635268371851003, + "last_agent": "none", + "iteration": 9, + "query_length": 748, + "has_rag_context": false + }, + "label": 
"hrm" + }, + { + "features": { + "hrm_confidence": 0.7378079025239879, + "trm_confidence": 0.09166489801650933, + "mcts_value": 0.4420617404990008, + "consensus_score": 0.3584336343015816, + "last_agent": "mcts", + "iteration": 5, + "query_length": 615, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7012022814200018, + "trm_confidence": 0.009967814726302898, + "mcts_value": 0.5970447162996493, + "consensus_score": 0.4529993709754845, + "last_agent": "hrm", + "iteration": 1, + "query_length": 2868, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9641140900383076, + "trm_confidence": 0.4633383320539006, + "mcts_value": 0.5372992540761998, + "consensus_score": 0.6095546636896446, + "last_agent": "trm", + "iteration": 0, + "query_length": 2241, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7884265579608288, + "trm_confidence": 0.4568543529427487, + "mcts_value": 0.5756933887050966, + "consensus_score": 0.5106235419857642, + "last_agent": "none", + "iteration": 6, + "query_length": 789, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9621665493711404, + "trm_confidence": 0.21870105232034467, + "mcts_value": 0.5265855782305323, + "consensus_score": 0.5798592279919281, + "last_agent": "trm", + "iteration": 4, + "query_length": 3761, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9177308904146302, + "trm_confidence": 0.46360254817622953, + "mcts_value": 0.6201852913233449, + "consensus_score": 0.763856325169989, + "last_agent": "none", + "iteration": 4, + "query_length": 4622, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7037409633823639, + "trm_confidence": 0.4805950474667002, + "mcts_value": 0.3140842018588201, + "consensus_score": 0.48106569030807445, + "last_agent": "mcts", + "iteration": 1, + "query_length": 3864, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8185021720121669, + "trm_confidence": 0.49043636088886106, + "mcts_value": 0.10732565579300138, + "consensus_score": 0.5643979888062163, + "last_agent": "hrm", + "iteration": 1, + "query_length": 3413, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9576948581673141, + "trm_confidence": 0.782717527302752, + "mcts_value": 0.1819577406062614, + "consensus_score": 0.6347502196604081, + "last_agent": "mcts", + "iteration": 8, + "query_length": 3226, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8137158085111915, + "trm_confidence": 0.36900498107278173, + "mcts_value": 0.5293747276201115, + "consensus_score": 0.616925701065254, + "last_agent": "trm", + "iteration": 8, + "query_length": 4750, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7313845290414325, + "trm_confidence": 0.5707547505270651, + "mcts_value": 0.5465091100904041, + "consensus_score": 0.6758356170499672, + "last_agent": "trm", + "iteration": 1, + "query_length": 26, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9230801095322956, + "trm_confidence": 0.019046151583096152, + "mcts_value": 0.8059803187685097, + "consensus_score": 0.5581311293657699, + "last_agent": "mcts", + "iteration": 7, + "query_length": 4657, + "has_rag_context": false + }, + 
"label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8183889664450247, + "trm_confidence": 0.2292612586441427, + "mcts_value": 0.437312910421789, + "consensus_score": 0.5111855865865955, + "last_agent": "hrm", + "iteration": 4, + "query_length": 60, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9806144108699484, + "trm_confidence": 0.41181074908830173, + "mcts_value": 0.1732539995949861, + "consensus_score": 0.4973342497610128, + "last_agent": "hrm", + "iteration": 4, + "query_length": 3847, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7488614501663664, + "trm_confidence": 0.44418593967111597, + "mcts_value": 0.22027402789520992, + "consensus_score": 0.5620868176153923, + "last_agent": "trm", + "iteration": 2, + "query_length": 999, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9260513689000773, + "trm_confidence": 0.7277810367370762, + "mcts_value": 0.22973591915308827, + "consensus_score": 0.5682737285925096, + "last_agent": "mcts", + "iteration": 2, + "query_length": 2400, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8405249923740312, + "trm_confidence": 0.19212168542979885, + "mcts_value": 0.033451822114106196, + "consensus_score": 0.3516644347045683, + "last_agent": "none", + "iteration": 10, + "query_length": 1231, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8486519669885855, + "trm_confidence": 0.08286983895231823, + "mcts_value": 0.18898692545925058, + "consensus_score": 0.3324577003816266, + "last_agent": "hrm", + "iteration": 8, + "query_length": 1515, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9704925560699316, + "trm_confidence": 0.8570596444224141, + "mcts_value": 0.8551428849736049, + "consensus_score": 0.9848315217089132, + "last_agent": "trm", + "iteration": 0, + "query_length": 2056, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7913399579428099, + "trm_confidence": 0.38224065571157784, + "mcts_value": 0.06704468949902856, + "consensus_score": 0.4827011592782339, + "last_agent": "trm", + "iteration": 6, + "query_length": 2019, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7495933970101873, + "trm_confidence": 0.16451308392705075, + "mcts_value": 0.10413763288980703, + "consensus_score": 0.40980133616797476, + "last_agent": "hrm", + "iteration": 6, + "query_length": 2625, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7888098588730077, + "trm_confidence": 0.25572247655900815, + "mcts_value": 0.2788680909106773, + "consensus_score": 0.4931364697963714, + "last_agent": "none", + "iteration": 8, + "query_length": 225, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9824616570214297, + "trm_confidence": 0.10647370900903526, + "mcts_value": 0.7907933777490511, + "consensus_score": 0.5466553770564673, + "last_agent": "hrm", + "iteration": 2, + "query_length": 897, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7538476217038262, + "trm_confidence": 0.27041040829999513, + "mcts_value": 0.29413943336981513, + "consensus_score": 0.3885176157291874, + "last_agent": "mcts", + "iteration": 7, + "query_length": 679, + "has_rag_context": 
false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9623725419566715, + "trm_confidence": 0.29262210136103023, + "mcts_value": 0.4577908039206839, + "consensus_score": 0.5206097631859256, + "last_agent": "hrm", + "iteration": 2, + "query_length": 1488, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9820057032422544, + "trm_confidence": 0.7832112112455303, + "mcts_value": 0.6856348186200852, + "consensus_score": 0.8204828819601692, + "last_agent": "hrm", + "iteration": 5, + "query_length": 2263, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8609704040731267, + "trm_confidence": 0.33068964615766583, + "mcts_value": 0.10026152842446387, + "consensus_score": 0.355773296792126, + "last_agent": "none", + "iteration": 10, + "query_length": 456, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9859616574715566, + "trm_confidence": 0.14492284170007116, + "mcts_value": 0.4911918073749651, + "consensus_score": 0.4822355528525055, + "last_agent": "none", + "iteration": 2, + "query_length": 4373, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7356859932220634, + "trm_confidence": 0.5828276906378326, + "mcts_value": 0.2043940190441073, + "consensus_score": 0.5292617433509195, + "last_agent": "none", + "iteration": 5, + "query_length": 4855, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8595658274038205, + "trm_confidence": 0.14222030388897, + "mcts_value": 0.7509097381536648, + "consensus_score": 0.6478935302229244, + "last_agent": "mcts", + "iteration": 8, + "query_length": 1861, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7458631366379007, + "trm_confidence": 0.5946241390199914, + "mcts_value": 0.220489322146777, + "consensus_score": 0.4303468705409431, + "last_agent": "mcts", + "iteration": 3, + "query_length": 294, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8868126234964624, + "trm_confidence": 0.590769276685474, + "mcts_value": 0.6244423384763792, + "consensus_score": 0.643065668435547, + "last_agent": "none", + "iteration": 10, + "query_length": 2660, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8917040947792494, + "trm_confidence": 0.0018278647108377233, + "mcts_value": 0.786453907816844, + "consensus_score": 0.5161745662951414, + "last_agent": "mcts", + "iteration": 0, + "query_length": 1994, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7387090171546125, + "trm_confidence": 0.0972924433311407, + "mcts_value": 0.4038447340742791, + "consensus_score": 0.3918675431673908, + "last_agent": "none", + "iteration": 10, + "query_length": 3242, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9178540385151703, + "trm_confidence": 0.3771120311092349, + "mcts_value": 0.539743836851922, + "consensus_score": 0.6315090401686404, + "last_agent": "none", + "iteration": 5, + "query_length": 3074, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8028140928264867, + "trm_confidence": 0.12456360802715516, + "mcts_value": 0.47657406224828497, + "consensus_score": 0.5371854332955573, + "last_agent": "trm", + "iteration": 0, + "query_length": 4940, + "has_rag_context": 
true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9677463069061979, + "trm_confidence": 0.6504787881584428, + "mcts_value": 0.8606374831518863, + "consensus_score": 0.8325702883916245, + "last_agent": "mcts", + "iteration": 7, + "query_length": 869, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9842606798997977, + "trm_confidence": 0.3239360681824274, + "mcts_value": 0.6530359849135624, + "consensus_score": 0.6343722851557455, + "last_agent": "mcts", + "iteration": 6, + "query_length": 756, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8526802455254062, + "trm_confidence": 0.6827459592509878, + "mcts_value": 0.3155995705425496, + "consensus_score": 0.6464229514016462, + "last_agent": "hrm", + "iteration": 3, + "query_length": 3486, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8320150324700657, + "trm_confidence": 0.09209262859482964, + "mcts_value": 0.06714311190465266, + "consensus_score": 0.3639149089525342, + "last_agent": "none", + "iteration": 7, + "query_length": 3668, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7059299336028074, + "trm_confidence": 0.197864506698318, + "mcts_value": 0.11985890761574668, + "consensus_score": 0.39684806714775983, + "last_agent": "mcts", + "iteration": 9, + "query_length": 3759, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9003547476543958, + "trm_confidence": 0.10353613391771245, + "mcts_value": 0.22021963019442897, + "consensus_score": 0.3357677365021964, + "last_agent": "hrm", + "iteration": 3, + "query_length": 1412, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9730980730555757, + "trm_confidence": 0.17682192410918202, + "mcts_value": 0.47328816085929903, + "consensus_score": 0.5813476062567055, + "last_agent": "mcts", + "iteration": 10, + "query_length": 2916, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7765320793082142, + "trm_confidence": 0.4231560380588548, + "mcts_value": 0.21682255200993333, + "consensus_score": 0.5673959012640091, + "last_agent": "trm", + "iteration": 4, + "query_length": 2207, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7069682747040275, + "trm_confidence": 0.2428246057360107, + "mcts_value": 0.531162558238724, + "consensus_score": 0.51569730177957, + "last_agent": "hrm", + "iteration": 4, + "query_length": 890, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9561953128441687, + "trm_confidence": 0.27587954965212524, + "mcts_value": 0.4662397894445769, + "consensus_score": 0.5640697351987828, + "last_agent": "hrm", + "iteration": 6, + "query_length": 4067, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7732615002297665, + "trm_confidence": 0.6144668090215966, + "mcts_value": 0.622009131556389, + "consensus_score": 0.7385770574216465, + "last_agent": "mcts", + "iteration": 10, + "query_length": 114, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9294001685685999, + "trm_confidence": 0.6324923508225583, + "mcts_value": 0.39861605011890444, + "consensus_score": 0.6447932954062084, + "last_agent": "trm", + "iteration": 2, + "query_length": 3035, + 
"has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7867960408307003, + "trm_confidence": 0.17706179507643613, + "mcts_value": 0.4024359518088873, + "consensus_score": 0.4374277402170165, + "last_agent": "hrm", + "iteration": 2, + "query_length": 3291, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9581709757883128, + "trm_confidence": 0.14391486044111265, + "mcts_value": 0.006442317663250758, + "consensus_score": 0.3367353783367668, + "last_agent": "hrm", + "iteration": 8, + "query_length": 680, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7414987210348147, + "trm_confidence": 0.49673549026026803, + "mcts_value": 0.07422407534739722, + "consensus_score": 0.4154986135787992, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1439, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7461890747732879, + "trm_confidence": 0.27320895570225096, + "mcts_value": 0.4004230314625919, + "consensus_score": 0.5598538437953609, + "last_agent": "hrm", + "iteration": 10, + "query_length": 3167, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7413992552043716, + "trm_confidence": 0.5071787475187043, + "mcts_value": 0.17358326090088597, + "consensus_score": 0.5510235664425928, + "last_agent": "none", + "iteration": 7, + "query_length": 2268, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9497587854772022, + "trm_confidence": 0.14986038328681847, + "mcts_value": 0.359868932238171, + "consensus_score": 0.49693301937926315, + "last_agent": "none", + "iteration": 5, + "query_length": 1482, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8958060204607765, + "trm_confidence": 0.789473000549209, + "mcts_value": 0.7953459723410123, + "consensus_score": 0.830166878310794, + "last_agent": "none", + "iteration": 1, + "query_length": 3420, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9929992234649767, + "trm_confidence": 0.28426521696112633, + "mcts_value": 0.41185917690810053, + "consensus_score": 0.5479965907411591, + "last_agent": "none", + "iteration": 0, + "query_length": 1850, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8060013085205403, + "trm_confidence": 0.5559505259024953, + "mcts_value": 0.5874778619952927, + "consensus_score": 0.7140560561256785, + "last_agent": "mcts", + "iteration": 4, + "query_length": 1979, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9246462339001165, + "trm_confidence": 0.660862778505748, + "mcts_value": 0.4051136030420285, + "consensus_score": 0.7421470757232248, + "last_agent": "hrm", + "iteration": 1, + "query_length": 2854, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7288285698915643, + "trm_confidence": 0.09654123447545868, + "mcts_value": 0.3356938520979582, + "consensus_score": 0.3005249148591434, + "last_agent": "trm", + "iteration": 0, + "query_length": 1141, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8308949979082789, + "trm_confidence": 0.5661687200967872, + "mcts_value": 0.02546317595496685, + "consensus_score": 0.504627152613182, + "last_agent": "trm", + "iteration": 9, + "query_length": 
1717, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7440591204464212, + "trm_confidence": 0.5544067179494655, + "mcts_value": 0.5301409123022616, + "consensus_score": 0.6173239368188868, + "last_agent": "hrm", + "iteration": 9, + "query_length": 226, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8183444553816263, + "trm_confidence": 0.5299927836231273, + "mcts_value": 0.18493918750750507, + "consensus_score": 0.5615579250574635, + "last_agent": "none", + "iteration": 5, + "query_length": 2699, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8394038913571988, + "trm_confidence": 0.2720299008072148, + "mcts_value": 0.1620693596078362, + "consensus_score": 0.3661769188777529, + "last_agent": "trm", + "iteration": 8, + "query_length": 4810, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7257947050332835, + "trm_confidence": 0.11032220277393885, + "mcts_value": 0.10973091841392231, + "consensus_score": 0.37990413330176154, + "last_agent": "none", + "iteration": 4, + "query_length": 3499, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7411633773371836, + "trm_confidence": 0.4190258018660907, + "mcts_value": 0.2877089263480039, + "consensus_score": 0.4612163371898409, + "last_agent": "none", + "iteration": 9, + "query_length": 4766, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9618576876322535, + "trm_confidence": 0.16573581856333194, + "mcts_value": 0.19029638580714042, + "consensus_score": 0.470511406486588, + "last_agent": "mcts", + "iteration": 3, + "query_length": 1267, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8699261952428934, + "trm_confidence": 0.42415909204976393, + "mcts_value": 0.6379199814213259, + "consensus_score": 0.68610831058653, + "last_agent": "hrm", + "iteration": 0, + "query_length": 2657, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8804675357368047, + "trm_confidence": 0.3794571311556796, + "mcts_value": 0.20306567679207435, + "consensus_score": 0.4713946653578132, + "last_agent": "hrm", + "iteration": 8, + "query_length": 528, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8683669203670132, + "trm_confidence": 0.2961051612096069, + "mcts_value": 0.20799443698331926, + "consensus_score": 0.4618725336738305, + "last_agent": "none", + "iteration": 3, + "query_length": 2258, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9029957147528913, + "trm_confidence": 0.05306132320252652, + "mcts_value": 0.0008571698389146299, + "consensus_score": 0.26142683036625936, + "last_agent": "trm", + "iteration": 9, + "query_length": 3870, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7515881055811068, + "trm_confidence": 0.5800854083517826, + "mcts_value": 0.3062177008384156, + "consensus_score": 0.5411820895928338, + "last_agent": "mcts", + "iteration": 10, + "query_length": 2379, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7645534823886155, + "trm_confidence": 0.38927637401496484, + "mcts_value": 0.12991124070271629, + "consensus_score": 0.4638191008772238, + "last_agent": "none", + "iteration": 2, + 
"query_length": 4463, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7620955478922519, + "trm_confidence": 0.27989032203406305, + "mcts_value": 0.1166203836958097, + "consensus_score": 0.31314048952537765, + "last_agent": "trm", + "iteration": 9, + "query_length": 909, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.808227605955476, + "trm_confidence": 0.03914402980123324, + "mcts_value": 0.25309903289406965, + "consensus_score": 0.3197849916812422, + "last_agent": "mcts", + "iteration": 6, + "query_length": 2099, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9665096359188521, + "trm_confidence": 0.8055650852189132, + "mcts_value": 0.08650457300765457, + "consensus_score": 0.541007022639522, + "last_agent": "none", + "iteration": 1, + "query_length": 2305, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7365906892628868, + "trm_confidence": 0.5426186472867254, + "mcts_value": 0.48016222358567484, + "consensus_score": 0.6504189583593672, + "last_agent": "trm", + "iteration": 5, + "query_length": 724, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7132525308832484, + "trm_confidence": 0.24607675484191002, + "mcts_value": 0.19865498207189433, + "consensus_score": 0.47572385567487907, + "last_agent": "none", + "iteration": 6, + "query_length": 4580, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7511023580975557, + "trm_confidence": 0.6371909988083584, + "mcts_value": 0.32466538481186547, + "consensus_score": 0.569720662493389, + "last_agent": "hrm", + "iteration": 10, + "query_length": 4783, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8192575073664978, + "trm_confidence": 0.42058777144530124, + "mcts_value": 0.09210280147437415, + "consensus_score": 0.40819981871109534, + "last_agent": "none", + "iteration": 2, + "query_length": 4910, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.959843381899306, + "trm_confidence": 0.5067325029417107, + "mcts_value": 0.3084024260626239, + "consensus_score": 0.5733933124412227, + "last_agent": "hrm", + "iteration": 4, + "query_length": 4006, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9780308174466981, + "trm_confidence": 0.8190092519450405, + "mcts_value": 0.34374202281113436, + "consensus_score": 0.679446522653845, + "last_agent": "hrm", + "iteration": 4, + "query_length": 4847, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8385453734349244, + "trm_confidence": 0.48651928305460523, + "mcts_value": 0.40944883375396496, + "consensus_score": 0.5423142727533135, + "last_agent": "mcts", + "iteration": 3, + "query_length": 361, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7283919304748165, + "trm_confidence": 0.10435241816234642, + "mcts_value": 0.45439887808312374, + "consensus_score": 0.4068670845445116, + "last_agent": "mcts", + "iteration": 2, + "query_length": 3905, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9266618627420282, + "trm_confidence": 0.21280124956787808, + "mcts_value": 0.6834756638873442, + "consensus_score": 0.6931472306574268, + "last_agent": "none", + 
"iteration": 6, + "query_length": 532, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7157727970969165, + "trm_confidence": 0.5822254280053311, + "mcts_value": 0.24168142263616446, + "consensus_score": 0.5980066293552234, + "last_agent": "mcts", + "iteration": 6, + "query_length": 4895, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7115625997970306, + "trm_confidence": 0.41611869374628807, + "mcts_value": 0.34410671336168863, + "consensus_score": 0.395902489742884, + "last_agent": "hrm", + "iteration": 8, + "query_length": 763, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8489252447317924, + "trm_confidence": 0.3488175050052104, + "mcts_value": 0.004098949373417123, + "consensus_score": 0.4579438422021932, + "last_agent": "mcts", + "iteration": 3, + "query_length": 4405, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8119181762674159, + "trm_confidence": 0.40228923558521373, + "mcts_value": 0.19246123282373728, + "consensus_score": 0.40100721074243934, + "last_agent": "hrm", + "iteration": 8, + "query_length": 757, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8609463771256752, + "trm_confidence": 0.734629772221256, + "mcts_value": 0.733232643276524, + "consensus_score": 0.8473980218816325, + "last_agent": "mcts", + "iteration": 2, + "query_length": 4657, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9634442273966612, + "trm_confidence": 0.3233287242950655, + "mcts_value": 0.0872995353312898, + "consensus_score": 0.5202167427889899, + "last_agent": "trm", + "iteration": 5, + "query_length": 1230, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9946432888953534, + "trm_confidence": 0.5432376312320689, + "mcts_value": 0.3802049553050477, + "consensus_score": 0.6133588320469295, + "last_agent": "none", + "iteration": 4, + "query_length": 652, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9384251675252715, + "trm_confidence": 0.7023628971642141, + "mcts_value": 0.23844571519841753, + "consensus_score": 0.5871705836446232, + "last_agent": "trm", + "iteration": 9, + "query_length": 339, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8561372331301683, + "trm_confidence": 0.3506024719689236, + "mcts_value": 0.5218305460581177, + "consensus_score": 0.5088684237489608, + "last_agent": "trm", + "iteration": 7, + "query_length": 576, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8856144334508926, + "trm_confidence": 0.031499218465408804, + "mcts_value": 0.36273955042497696, + "consensus_score": 0.43106791291023394, + "last_agent": "mcts", + "iteration": 2, + "query_length": 2242, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7565422909036904, + "trm_confidence": 0.22594317178577628, + "mcts_value": 0.06355995720741427, + "consensus_score": 0.4381457657782145, + "last_agent": "none", + "iteration": 4, + "query_length": 1131, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.992325136932654, + "trm_confidence": 0.07204211978374034, + "mcts_value": 0.5815010472099233, + "consensus_score": 0.6229045853236821, + "last_agent": 
"mcts", + "iteration": 6, + "query_length": 1805, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8293229072913815, + "trm_confidence": 0.026181879911408894, + "mcts_value": 0.38989657417045104, + "consensus_score": 0.5138423864817833, + "last_agent": "hrm", + "iteration": 10, + "query_length": 3307, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9117391010888465, + "trm_confidence": 0.2449543416327622, + "mcts_value": 0.732991603678876, + "consensus_score": 0.5389179646086839, + "last_agent": "hrm", + "iteration": 10, + "query_length": 3812, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7352513146757804, + "trm_confidence": 0.038003488930421665, + "mcts_value": 0.5302491411336361, + "consensus_score": 0.3656528491081953, + "last_agent": "trm", + "iteration": 5, + "query_length": 355, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.870234434064797, + "trm_confidence": 0.3059627671516115, + "mcts_value": 0.34276668916021336, + "consensus_score": 0.41894473093298157, + "last_agent": "none", + "iteration": 10, + "query_length": 1451, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7568090475018561, + "trm_confidence": 0.18787845586226595, + "mcts_value": 0.6334090395729941, + "consensus_score": 0.4448893916766401, + "last_agent": "trm", + "iteration": 9, + "query_length": 646, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9490057537477085, + "trm_confidence": 0.8285426268405806, + "mcts_value": 0.7149941480792436, + "consensus_score": 0.9195239798114148, + "last_agent": "trm", + "iteration": 1, + "query_length": 4616, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8637027082243143, + "trm_confidence": 0.4150690578704749, + "mcts_value": 0.6865088423172458, + "consensus_score": 0.7494542998666873, + "last_agent": "hrm", + "iteration": 7, + "query_length": 3824, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8369211573705536, + "trm_confidence": 0.2847031937971334, + "mcts_value": 0.2664699753767556, + "consensus_score": 0.44099947718503385, + "last_agent": "none", + "iteration": 0, + "query_length": 3622, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9339609531616935, + "trm_confidence": 0.3720073749593458, + "mcts_value": 0.7761545231594749, + "consensus_score": 0.7164851521456285, + "last_agent": "none", + "iteration": 0, + "query_length": 1583, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7617459379039022, + "trm_confidence": 0.6456893506063608, + "mcts_value": 0.14472663325826077, + "consensus_score": 0.5205384425137849, + "last_agent": "trm", + "iteration": 10, + "query_length": 4357, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7709081306150634, + "trm_confidence": 0.032953426548556077, + "mcts_value": 0.10305271832695558, + "consensus_score": 0.21497045423736993, + "last_agent": "none", + "iteration": 2, + "query_length": 751, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9737346388886925, + "trm_confidence": 0.46938674734886066, + "mcts_value": 0.546948759270564, + "consensus_score": 0.6037301654912806, + 
"last_agent": "trm", + "iteration": 8, + "query_length": 4317, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9592410487578175, + "trm_confidence": 0.8152849463702759, + "mcts_value": 0.2989306274247526, + "consensus_score": 0.6092680192247032, + "last_agent": "hrm", + "iteration": 2, + "query_length": 2042, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9837080955473421, + "trm_confidence": 0.33431545265646406, + "mcts_value": 0.6811385287471287, + "consensus_score": 0.6744826974177959, + "last_agent": "hrm", + "iteration": 2, + "query_length": 3086, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8959575655068456, + "trm_confidence": 0.5355627499764045, + "mcts_value": 0.346050309650501, + "consensus_score": 0.5382526622436733, + "last_agent": "mcts", + "iteration": 7, + "query_length": 4515, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9810009075595381, + "trm_confidence": 0.16554322366482135, + "mcts_value": 0.09454228904780895, + "consensus_score": 0.41322501389713184, + "last_agent": "trm", + "iteration": 5, + "query_length": 167, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8310002468016519, + "trm_confidence": 0.7272099321804025, + "mcts_value": 0.3554712478770538, + "consensus_score": 0.6334831251010022, + "last_agent": "mcts", + "iteration": 4, + "query_length": 4867, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8778869893448848, + "trm_confidence": 0.17733316400235186, + "mcts_value": 0.49598773000963786, + "consensus_score": 0.427235986841987, + "last_agent": "trm", + "iteration": 10, + "query_length": 96, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9690266114772095, + "trm_confidence": 0.41327531492474356, + "mcts_value": 0.05242908683320221, + "consensus_score": 0.5426863802971464, + "last_agent": "trm", + "iteration": 7, + "query_length": 3741, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.828450959052662, + "trm_confidence": 0.46470563316140123, + "mcts_value": 0.623720838984344, + "consensus_score": 0.66517223230718, + "last_agent": "hrm", + "iteration": 3, + "query_length": 2180, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9015562565842201, + "trm_confidence": 0.769964489790417, + "mcts_value": 0.29730709197590294, + "consensus_score": 0.6412922996196729, + "last_agent": "trm", + "iteration": 8, + "query_length": 1853, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9209719279347199, + "trm_confidence": 0.3774085653223031, + "mcts_value": 0.1769154620345386, + "consensus_score": 0.5408060869867227, + "last_agent": "hrm", + "iteration": 1, + "query_length": 1940, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8880474948352626, + "trm_confidence": 0.5888840940074078, + "mcts_value": 0.7050565479056364, + "consensus_score": 0.6818466783992806, + "last_agent": "trm", + "iteration": 1, + "query_length": 507, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7463269264066427, + "trm_confidence": 0.12775749309271756, + "mcts_value": 0.18829407029016992, + "consensus_score": 0.3600044340684342, + 
"last_agent": "none", + "iteration": 9, + "query_length": 3641, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9124472661527905, + "trm_confidence": 0.1423249734377036, + "mcts_value": 0.29059827613361366, + "consensus_score": 0.44413782379203814, + "last_agent": "mcts", + "iteration": 1, + "query_length": 1951, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7924999611858532, + "trm_confidence": 0.08813876240274264, + "mcts_value": 0.38104096528938425, + "consensus_score": 0.4615655355974986, + "last_agent": "none", + "iteration": 0, + "query_length": 2120, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9346633585803386, + "trm_confidence": 0.6243199532163899, + "mcts_value": 0.6803116479784208, + "consensus_score": 0.7360797798925045, + "last_agent": "hrm", + "iteration": 9, + "query_length": 1391, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8082431684252546, + "trm_confidence": 0.1249026322116057, + "mcts_value": 0.33109273139471257, + "consensus_score": 0.3998248484936129, + "last_agent": "trm", + "iteration": 2, + "query_length": 4551, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.794154031633494, + "trm_confidence": 0.18936822561780658, + "mcts_value": 0.09085805086171304, + "consensus_score": 0.3458911097313942, + "last_agent": "trm", + "iteration": 3, + "query_length": 4304, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7707991787861658, + "trm_confidence": 0.5119240758573433, + "mcts_value": 0.6238454816739297, + "consensus_score": 0.6351092826635906, + "last_agent": "trm", + "iteration": 7, + "query_length": 128, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9996488912000976, + "trm_confidence": 0.5519494252201751, + "mcts_value": 0.8004891104852072, + "consensus_score": 0.865750791612955, + "last_agent": "trm", + "iteration": 8, + "query_length": 1800, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7243462001271698, + "trm_confidence": 0.009868138479946883, + "mcts_value": 0.297938402088473, + "consensus_score": 0.34291382027212014, + "last_agent": "hrm", + "iteration": 4, + "query_length": 2262, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8656488109913281, + "trm_confidence": 0.6405517801265206, + "mcts_value": 0.08319318191851317, + "consensus_score": 0.47483345756656303, + "last_agent": "none", + "iteration": 8, + "query_length": 352, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9984472640862603, + "trm_confidence": 0.7842654129109966, + "mcts_value": 0.7995273282172154, + "consensus_score": 0.8236162569308051, + "last_agent": "none", + "iteration": 4, + "query_length": 3748, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8818898194578146, + "trm_confidence": 0.734042348906971, + "mcts_value": 0.26947120053234996, + "consensus_score": 0.6398681460086686, + "last_agent": "trm", + "iteration": 3, + "query_length": 1085, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7585556589455966, + "trm_confidence": 0.03816023455513779, + "mcts_value": 0.6251626215781484, + "consensus_score": 
0.4545354267278202, + "last_agent": "none", + "iteration": 9, + "query_length": 25, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7575763484233256, + "trm_confidence": 0.11537982339162972, + "mcts_value": 0.5339442365121275, + "consensus_score": 0.42190914489960163, + "last_agent": "mcts", + "iteration": 10, + "query_length": 3850, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9033125843363409, + "trm_confidence": 0.654277114228846, + "mcts_value": 0.18898441735980578, + "consensus_score": 0.5374328446663108, + "last_agent": "hrm", + "iteration": 10, + "query_length": 2813, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8289871149700143, + "trm_confidence": 0.3152982944972967, + "mcts_value": 0.18445907233839906, + "consensus_score": 0.49807304455598805, + "last_agent": "trm", + "iteration": 8, + "query_length": 2898, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9505377373242201, + "trm_confidence": 0.18122827258423413, + "mcts_value": 0.003662224693813079, + "consensus_score": 0.39270955061190405, + "last_agent": "hrm", + "iteration": 10, + "query_length": 4416, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7623994221125381, + "trm_confidence": 0.3437287593091789, + "mcts_value": 0.5365042757664858, + "consensus_score": 0.46535255341107823, + "last_agent": "none", + "iteration": 4, + "query_length": 1698, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8114629625219084, + "trm_confidence": 0.13435011914363088, + "mcts_value": 0.1390880405202477, + "consensus_score": 0.34364346477826113, + "last_agent": "hrm", + "iteration": 5, + "query_length": 3708, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8926757255703981, + "trm_confidence": 0.5472648702307961, + "mcts_value": 0.7791900694236185, + "consensus_score": 0.722078909801799, + "last_agent": "trm", + "iteration": 4, + "query_length": 1223, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7737626761015175, + "trm_confidence": 0.49172227340870756, + "mcts_value": 0.5084032169201707, + "consensus_score": 0.5108536549658871, + "last_agent": "hrm", + "iteration": 5, + "query_length": 663, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7556125818407456, + "trm_confidence": 0.3040915450308808, + "mcts_value": 0.18975375780077433, + "consensus_score": 0.4756980510826412, + "last_agent": "none", + "iteration": 9, + "query_length": 2699, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7798092617864996, + "trm_confidence": 0.18635909980374252, + "mcts_value": 0.17316669164944853, + "consensus_score": 0.30678597029114407, + "last_agent": "none", + "iteration": 10, + "query_length": 2572, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7674068423637621, + "trm_confidence": 0.5650740523405162, + "mcts_value": 0.2664256764332823, + "consensus_score": 0.4399922192677869, + "last_agent": "none", + "iteration": 1, + "query_length": 3936, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9348715822515172, + "trm_confidence": 0.176103126668828, + "mcts_value": 0.7659075071611872, + 
"consensus_score": 0.6958268293709436, + "last_agent": "none", + "iteration": 9, + "query_length": 371, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7151607965950173, + "trm_confidence": 0.5336049555333756, + "mcts_value": 0.18811631230975562, + "consensus_score": 0.5027439643807547, + "last_agent": "trm", + "iteration": 8, + "query_length": 3017, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9342515063557144, + "trm_confidence": 0.7829545983703684, + "mcts_value": 0.6076479123657337, + "consensus_score": 0.7631958887267172, + "last_agent": "none", + "iteration": 9, + "query_length": 944, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9234312783450263, + "trm_confidence": 0.6756302582411827, + "mcts_value": 0.6170405848403161, + "consensus_score": 0.6962716823518058, + "last_agent": "hrm", + "iteration": 1, + "query_length": 4935, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8508624810941059, + "trm_confidence": 0.3793934835125877, + "mcts_value": 0.4374857817217746, + "consensus_score": 0.534934282952401, + "last_agent": "mcts", + "iteration": 8, + "query_length": 4806, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7740175160438414, + "trm_confidence": 0.496429549704044, + "mcts_value": 0.4569649040000525, + "consensus_score": 0.5804220897718424, + "last_agent": "mcts", + "iteration": 2, + "query_length": 1691, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9881878856886498, + "trm_confidence": 0.6411262857398229, + "mcts_value": 0.815232129761202, + "consensus_score": 0.911685820052825, + "last_agent": "hrm", + "iteration": 3, + "query_length": 1784, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8421798212914233, + "trm_confidence": 0.7322559470279331, + "mcts_value": 0.06962227405779058, + "consensus_score": 0.5304943303184939, + "last_agent": "trm", + "iteration": 9, + "query_length": 495, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9363378551916618, + "trm_confidence": 0.3201360493631198, + "mcts_value": 0.8195931501855913, + "consensus_score": 0.7687283188648658, + "last_agent": "mcts", + "iteration": 3, + "query_length": 1255, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8010569631499805, + "trm_confidence": 0.4250948913284999, + "mcts_value": 0.46163252778246694, + "consensus_score": 0.6285138759041234, + "last_agent": "mcts", + "iteration": 0, + "query_length": 2283, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9135753431765736, + "trm_confidence": 0.458468610905465, + "mcts_value": 0.4032120715775476, + "consensus_score": 0.5027939998947915, + "last_agent": "hrm", + "iteration": 3, + "query_length": 4093, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7092678511108096, + "trm_confidence": 0.24153585322317822, + "mcts_value": 0.575848466210233, + "consensus_score": 0.41906020375707653, + "last_agent": "hrm", + "iteration": 8, + "query_length": 2699, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9599826496461017, + "trm_confidence": 0.6416055173615868, + "mcts_value": 0.7616910796400281, + 
"consensus_score": 0.7309447069688255, + "last_agent": "trm", + "iteration": 5, + "query_length": 2053, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7737099228984943, + "trm_confidence": 0.15545607553068916, + "mcts_value": 0.018312873062874588, + "consensus_score": 0.4075293492924431, + "last_agent": "hrm", + "iteration": 7, + "query_length": 548, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9231984168084972, + "trm_confidence": 0.43762291302503004, + "mcts_value": 0.3919526077106061, + "consensus_score": 0.574164427406082, + "last_agent": "none", + "iteration": 7, + "query_length": 1559, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9664176085044609, + "trm_confidence": 0.7033165481696211, + "mcts_value": 0.3028680682544538, + "consensus_score": 0.6732559343386835, + "last_agent": "none", + "iteration": 1, + "query_length": 1323, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.970867806289085, + "trm_confidence": 0.8028508591083362, + "mcts_value": 0.2892701046838016, + "consensus_score": 0.6214898648369789, + "last_agent": "hrm", + "iteration": 2, + "query_length": 674, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7360331890304153, + "trm_confidence": 0.04535205026991548, + "mcts_value": 0.6295603232822491, + "consensus_score": 0.38512672420134136, + "last_agent": "none", + "iteration": 1, + "query_length": 2644, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8740720650228829, + "trm_confidence": 0.5994547247302315, + "mcts_value": 0.48837136070431225, + "consensus_score": 0.7304010642182652, + "last_agent": "hrm", + "iteration": 7, + "query_length": 2206, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9138249081283899, + "trm_confidence": 0.36409936512452107, + "mcts_value": 0.28163535888200514, + "consensus_score": 0.5133837012016812, + "last_agent": "none", + "iteration": 1, + "query_length": 2109, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7782111344007834, + "trm_confidence": 0.1041268042075554, + "mcts_value": 0.16707883336197832, + "consensus_score": 0.41821962101122034, + "last_agent": "trm", + "iteration": 3, + "query_length": 4765, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9569228891967217, + "trm_confidence": 0.06662118460104391, + "mcts_value": 0.6555669235911721, + "consensus_score": 0.5212367590721186, + "last_agent": "mcts", + "iteration": 2, + "query_length": 2268, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9729389868199702, + "trm_confidence": 0.25753376914582776, + "mcts_value": 0.49655720574113554, + "consensus_score": 0.5336434766722095, + "last_agent": "hrm", + "iteration": 6, + "query_length": 1430, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8831006166370539, + "trm_confidence": 0.48829954656903063, + "mcts_value": 0.10578106858074822, + "consensus_score": 0.5311018378462882, + "last_agent": "trm", + "iteration": 7, + "query_length": 2802, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7222211658216119, + "trm_confidence": 0.27803258546131354, + "mcts_value": 
0.2545103167234909, + "consensus_score": 0.3355969860641119, + "last_agent": "mcts", + "iteration": 2, + "query_length": 1660, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9228718929574247, + "trm_confidence": 0.19507172003678855, + "mcts_value": 0.682777879225739, + "consensus_score": 0.6094001545545434, + "last_agent": "none", + "iteration": 8, + "query_length": 4805, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.849758267626667, + "trm_confidence": 0.4646397357959349, + "mcts_value": 0.695753635586873, + "consensus_score": 0.6684130626190737, + "last_agent": "mcts", + "iteration": 5, + "query_length": 3214, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9212545411893177, + "trm_confidence": 0.05900741430404592, + "mcts_value": 0.4235106507204567, + "consensus_score": 0.4642425199124821, + "last_agent": "trm", + "iteration": 10, + "query_length": 1681, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8404961604082992, + "trm_confidence": 0.7125053970548564, + "mcts_value": 0.3336501412668026, + "consensus_score": 0.5650607856410709, + "last_agent": "hrm", + "iteration": 2, + "query_length": 728, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7001704387776833, + "trm_confidence": 0.25924101689077567, + "mcts_value": 0.46167665754423126, + "consensus_score": 0.5203893262975305, + "last_agent": "trm", + "iteration": 0, + "query_length": 1085, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7462849376240428, + "trm_confidence": 0.6082034279986184, + "mcts_value": 0.10816547835548636, + "consensus_score": 0.5128791248655455, + "last_agent": "none", + "iteration": 3, + "query_length": 1982, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8236775789392226, + "trm_confidence": 0.549736644951093, + "mcts_value": 0.16693731092908165, + "consensus_score": 0.5946088821893177, + "last_agent": "mcts", + "iteration": 5, + "query_length": 790, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9549190574203681, + "trm_confidence": 0.8539176146019773, + "mcts_value": 0.19145541734296373, + "consensus_score": 0.6241425125011274, + "last_agent": "hrm", + "iteration": 2, + "query_length": 3414, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8747334224618132, + "trm_confidence": 0.769715104649675, + "mcts_value": 0.7672015316023072, + "consensus_score": 0.80923799917706, + "last_agent": "hrm", + "iteration": 7, + "query_length": 3612, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9798800375360516, + "trm_confidence": 0.17937289642684212, + "mcts_value": 0.46505331630125507, + "consensus_score": 0.6080028276870187, + "last_agent": "none", + "iteration": 3, + "query_length": 4694, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9956863304536023, + "trm_confidence": 0.8822928611113594, + "mcts_value": 0.8115874843332822, + "consensus_score": 0.9562118942650263, + "last_agent": "mcts", + "iteration": 2, + "query_length": 198, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8641848523614624, + "trm_confidence": 0.4285375751873198, + "mcts_value": 
0.38629759737868324, + "consensus_score": 0.5870073976257497, + "last_agent": "hrm", + "iteration": 4, + "query_length": 3396, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8000649779435066, + "trm_confidence": 0.2314953845736733, + "mcts_value": 0.08906644020564398, + "consensus_score": 0.31128470958005844, + "last_agent": "trm", + "iteration": 9, + "query_length": 4928, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7683603359658501, + "trm_confidence": 0.507513472723641, + "mcts_value": 0.370259011369696, + "consensus_score": 0.4529421423043878, + "last_agent": "mcts", + "iteration": 3, + "query_length": 569, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8967618582948522, + "trm_confidence": 0.21913424302132728, + "mcts_value": 0.05174711649358016, + "consensus_score": 0.30020140377612037, + "last_agent": "mcts", + "iteration": 6, + "query_length": 4028, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9546376339408267, + "trm_confidence": 0.1328082263232585, + "mcts_value": 0.5183944175395805, + "consensus_score": 0.524217698203371, + "last_agent": "trm", + "iteration": 4, + "query_length": 1919, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9408256750636647, + "trm_confidence": 0.7509014583599367, + "mcts_value": 0.009263096790295754, + "consensus_score": 0.5027175785440559, + "last_agent": "trm", + "iteration": 9, + "query_length": 1692, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.745011873474463, + "trm_confidence": 0.13080224976016785, + "mcts_value": 0.4256989638627555, + "consensus_score": 0.4788010601485034, + "last_agent": "none", + "iteration": 6, + "query_length": 4799, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8482190810985331, + "trm_confidence": 0.22466917418298077, + "mcts_value": 0.3570747693720694, + "consensus_score": 0.44858760028451655, + "last_agent": "trm", + "iteration": 2, + "query_length": 3672, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8056637480695372, + "trm_confidence": 0.5020157867165614, + "mcts_value": 0.013010889083076554, + "consensus_score": 0.42806584393091207, + "last_agent": "trm", + "iteration": 7, + "query_length": 1246, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9329945307666148, + "trm_confidence": 0.675158490223123, + "mcts_value": 0.5928470563406139, + "consensus_score": 0.6542262316724003, + "last_agent": "mcts", + "iteration": 5, + "query_length": 3108, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7736698116428354, + "trm_confidence": 0.028665874979164216, + "mcts_value": 0.15062604934972768, + "consensus_score": 0.3086327923371035, + "last_agent": "none", + "iteration": 4, + "query_length": 425, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9473246628050478, + "trm_confidence": 0.7843571444504731, + "mcts_value": 0.1706073965477051, + "consensus_score": 0.7200109843811864, + "last_agent": "trm", + "iteration": 0, + "query_length": 1525, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9855786370242724, + "trm_confidence": 0.010672872423246276, + 
"mcts_value": 0.6757135888627713, + "consensus_score": 0.651362013950732, + "last_agent": "none", + "iteration": 9, + "query_length": 1171, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8650970517092917, + "trm_confidence": 0.05979703219928956, + "mcts_value": 0.6501712098150373, + "consensus_score": 0.5477874157363418, + "last_agent": "mcts", + "iteration": 6, + "query_length": 2410, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8715525929563327, + "trm_confidence": 0.11849487677955242, + "mcts_value": 0.5707714821884884, + "consensus_score": 0.43527644544509964, + "last_agent": "trm", + "iteration": 8, + "query_length": 925, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.8417406032693464, + "trm_confidence": 0.3781618832841875, + "mcts_value": 0.2315820031465227, + "consensus_score": 0.5777198799784697, + "last_agent": "none", + "iteration": 8, + "query_length": 126, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7524422158699015, + "trm_confidence": 0.14326299569614181, + "mcts_value": 0.3026824843410715, + "consensus_score": 0.4879762453652392, + "last_agent": "trm", + "iteration": 8, + "query_length": 498, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.9578683400256693, + "trm_confidence": 0.3244462548711551, + "mcts_value": 0.730665495080943, + "consensus_score": 0.6193332829221486, + "last_agent": "trm", + "iteration": 6, + "query_length": 2393, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.7773682027443383, + "trm_confidence": 0.4693664392989025, + "mcts_value": 0.6687741749271744, + "consensus_score": 0.5686225964387391, + "last_agent": "hrm", + "iteration": 0, + "query_length": 3358, + "has_rag_context": false + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.801920784324709, + "trm_confidence": 0.05000797943196162, + "mcts_value": 0.33362689209879987, + "consensus_score": 0.4459393545472813, + "last_agent": "none", + "iteration": 3, + "query_length": 2562, + "has_rag_context": true + }, + "label": "hrm" + }, + { + "features": { + "hrm_confidence": 0.44399049870475243, + "trm_confidence": 0.9537098201940917, + "mcts_value": 0.7194485575615861, + "consensus_score": 0.6949941905822187, + "last_agent": "mcts", + "iteration": 10, + "query_length": 1264, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5618301748473334, + "trm_confidence": 0.7347696595957748, + "mcts_value": 0.3105706986357273, + "consensus_score": 0.4640513046208595, + "last_agent": "hrm", + "iteration": 1, + "query_length": 943, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1760292584653908, + "trm_confidence": 0.7142117520188469, + "mcts_value": 0.08585918085299075, + "consensus_score": 0.3135822163691997, + "last_agent": "trm", + "iteration": 4, + "query_length": 2713, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.021914319138401007, + "trm_confidence": 0.783890898730759, + "mcts_value": 0.4201730665660401, + "consensus_score": 0.3834030287215369, + "last_agent": "hrm", + "iteration": 7, + "query_length": 3801, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.15000280015235107, + "trm_confidence": 0.7041332065882151, 
+ "mcts_value": 0.27581539819966233, + "consensus_score": 0.3554648631471029, + "last_agent": "mcts", + "iteration": 8, + "query_length": 39, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.45055688875567235, + "trm_confidence": 0.8591323538809237, + "mcts_value": 0.016572915728581297, + "consensus_score": 0.44511406149096355, + "last_agent": "mcts", + "iteration": 3, + "query_length": 4273, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.49324445079443197, + "trm_confidence": 0.730086219548575, + "mcts_value": 0.5937959316784491, + "consensus_score": 0.5101626314664054, + "last_agent": "mcts", + "iteration": 6, + "query_length": 1243, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5398238930337239, + "trm_confidence": 0.7854422765949668, + "mcts_value": 0.43344288561357935, + "consensus_score": 0.6720688782016373, + "last_agent": "mcts", + "iteration": 10, + "query_length": 3688, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.40120532900118133, + "trm_confidence": 0.8334813283081018, + "mcts_value": 0.10763342161337451, + "consensus_score": 0.41469719239302494, + "last_agent": "none", + "iteration": 10, + "query_length": 958, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6199851429898666, + "trm_confidence": 0.8357523794573152, + "mcts_value": 0.23831112475494545, + "consensus_score": 0.5606425829517397, + "last_agent": "none", + "iteration": 10, + "query_length": 1462, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.033064328666376076, + "trm_confidence": 0.7392763936439497, + "mcts_value": 0.6043824969206342, + "consensus_score": 0.41206962239866646, + "last_agent": "mcts", + "iteration": 5, + "query_length": 238, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6762765013148828, + "trm_confidence": 0.8270965361651815, + "mcts_value": 0.3653027473934347, + "consensus_score": 0.717118958651632, + "last_agent": "mcts", + "iteration": 2, + "query_length": 4406, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.38453742310423844, + "trm_confidence": 0.9578592375427976, + "mcts_value": 0.23976082519376188, + "consensus_score": 0.5984809012971757, + "last_agent": "mcts", + "iteration": 4, + "query_length": 1653, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.16652177228253073, + "trm_confidence": 0.9929581199245436, + "mcts_value": 0.2691381459097174, + "consensus_score": 0.3824199327195652, + "last_agent": "hrm", + "iteration": 4, + "query_length": 2820, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.14578435320948174, + "trm_confidence": 0.9040104806223674, + "mcts_value": 0.542556525391946, + "consensus_score": 0.4834788720457953, + "last_agent": "trm", + "iteration": 3, + "query_length": 2008, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3802169571295888, + "trm_confidence": 0.9150255250490829, + "mcts_value": 0.2944678247370635, + "consensus_score": 0.6004279794464684, + "last_agent": "none", + "iteration": 6, + "query_length": 1336, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2981151404941536, + "trm_confidence": 
0.9700273834431048, + "mcts_value": 0.0339131258470358, + "consensus_score": 0.43293859472298857, + "last_agent": "trm", + "iteration": 1, + "query_length": 4764, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.10775691844490401, + "trm_confidence": 0.9241332618062232, + "mcts_value": 0.6650295757125843, + "consensus_score": 0.6530213762807613, + "last_agent": "none", + "iteration": 6, + "query_length": 763, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.8082211137175843, + "trm_confidence": 0.9140759829722251, + "mcts_value": 0.2780380965244117, + "consensus_score": 0.5993497079866142, + "last_agent": "none", + "iteration": 5, + "query_length": 4129, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4005895980957989, + "trm_confidence": 0.7533737464888433, + "mcts_value": 0.009228730029908863, + "consensus_score": 0.473264372029028, + "last_agent": "trm", + "iteration": 2, + "query_length": 1301, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.15195141330764292, + "trm_confidence": 0.732474832442457, + "mcts_value": 0.17323875809406383, + "consensus_score": 0.2591170011809336, + "last_agent": "hrm", + "iteration": 3, + "query_length": 1372, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.27485041219425316, + "trm_confidence": 0.8135862782303556, + "mcts_value": 0.24191127190809028, + "consensus_score": 0.4090340974837071, + "last_agent": "hrm", + "iteration": 0, + "query_length": 3025, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4998685572957993, + "trm_confidence": 0.7606342900469399, + "mcts_value": 0.6545629060364406, + "consensus_score": 0.7240604861378592, + "last_agent": "none", + "iteration": 5, + "query_length": 3997, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5374037880661997, + "trm_confidence": 0.7943744707107527, + "mcts_value": 0.5310275204057412, + "consensus_score": 0.6897897312440535, + "last_agent": "hrm", + "iteration": 9, + "query_length": 857, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6693614822574055, + "trm_confidence": 0.9300626704321593, + "mcts_value": 0.2560408179709783, + "consensus_score": 0.7043367723547445, + "last_agent": "hrm", + "iteration": 1, + "query_length": 4636, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.07751066073395808, + "trm_confidence": 0.9310272154476609, + "mcts_value": 0.545344697158897, + "consensus_score": 0.4247967399658241, + "last_agent": "trm", + "iteration": 8, + "query_length": 3699, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6689706808174918, + "trm_confidence": 0.8807230759274343, + "mcts_value": 0.07641973561121224, + "consensus_score": 0.497222599193542, + "last_agent": "mcts", + "iteration": 4, + "query_length": 4261, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.19478995496854004, + "trm_confidence": 0.8344648262713499, + "mcts_value": 0.2884500524339905, + "consensus_score": 0.42425547095642147, + "last_agent": "none", + "iteration": 9, + "query_length": 933, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.369795165737947, + "trm_confidence": 
0.9681229386148816, + "mcts_value": 0.28457740178950947, + "consensus_score": 0.5413342452478086, + "last_agent": "none", + "iteration": 4, + "query_length": 2344, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.11397331765338067, + "trm_confidence": 0.8508037974388201, + "mcts_value": 0.26423033765407056, + "consensus_score": 0.4228271947902419, + "last_agent": "trm", + "iteration": 4, + "query_length": 2422, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6565513876998769, + "trm_confidence": 0.7770165284605498, + "mcts_value": 0.5312697783633291, + "consensus_score": 0.7484645033740798, + "last_agent": "trm", + "iteration": 9, + "query_length": 838, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.12883318501989313, + "trm_confidence": 0.8579380271905042, + "mcts_value": 0.6201656254436124, + "consensus_score": 0.5524780423879958, + "last_agent": "hrm", + "iteration": 2, + "query_length": 838, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.35780216754323463, + "trm_confidence": 0.7345558159942782, + "mcts_value": 0.45754047753151095, + "consensus_score": 0.4284025278228275, + "last_agent": "hrm", + "iteration": 6, + "query_length": 911, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.31772290031448275, + "trm_confidence": 0.7202581943254159, + "mcts_value": 0.16037220844841238, + "consensus_score": 0.43203599133682113, + "last_agent": "hrm", + "iteration": 1, + "query_length": 2005, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1615036406046095, + "trm_confidence": 0.9684185745144196, + "mcts_value": 0.2807300984894874, + "consensus_score": 0.4877444328001141, + "last_agent": "hrm", + "iteration": 7, + "query_length": 2460, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.06737428867692129, + "trm_confidence": 0.7468873965882925, + "mcts_value": 0.646735595995737, + "consensus_score": 0.47806951804234776, + "last_agent": "hrm", + "iteration": 9, + "query_length": 1293, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.07693929077895703, + "trm_confidence": 0.976707291563801, + "mcts_value": 0.09996145642120899, + "consensus_score": 0.47046245046616725, + "last_agent": "trm", + "iteration": 10, + "query_length": 1728, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3196784718584861, + "trm_confidence": 0.8065494660981616, + "mcts_value": 0.2940388128620583, + "consensus_score": 0.4114312934648213, + "last_agent": "trm", + "iteration": 9, + "query_length": 4297, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.05450232063485843, + "trm_confidence": 0.9115817228432954, + "mcts_value": 0.0314498724043145, + "consensus_score": 0.36325739980954025, + "last_agent": "none", + "iteration": 3, + "query_length": 2946, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5967026785359415, + "trm_confidence": 0.728206360650649, + "mcts_value": 0.47022154914127223, + "consensus_score": 0.6736303093811141, + "last_agent": "hrm", + "iteration": 1, + "query_length": 4255, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.38166291205082625, + 
"trm_confidence": 0.8105032492619032, + "mcts_value": 0.21092602355647808, + "consensus_score": 0.37982354339428354, + "last_agent": "hrm", + "iteration": 4, + "query_length": 2722, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.14097898988819313, + "trm_confidence": 0.8157763874389644, + "mcts_value": 0.18745924086930824, + "consensus_score": 0.406795340818609, + "last_agent": "hrm", + "iteration": 4, + "query_length": 1230, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5867465316361682, + "trm_confidence": 0.8601716898167119, + "mcts_value": 0.2880046945944309, + "consensus_score": 0.595811537193605, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1504, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6928280035829012, + "trm_confidence": 0.9729392729552633, + "mcts_value": 0.051257843268573244, + "consensus_score": 0.5581614808960973, + "last_agent": "mcts", + "iteration": 7, + "query_length": 2698, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.06843046817694912, + "trm_confidence": 0.8375128873484537, + "mcts_value": 0.41448900915880876, + "consensus_score": 0.3654568799861959, + "last_agent": "trm", + "iteration": 8, + "query_length": 158, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7102941809372267, + "trm_confidence": 0.917029410513186, + "mcts_value": 0.513869027705941, + "consensus_score": 0.7362187389027544, + "last_agent": "hrm", + "iteration": 8, + "query_length": 2958, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.03135605725763654, + "trm_confidence": 0.9607161358831008, + "mcts_value": 0.4787219630754682, + "consensus_score": 0.43351858069824545, + "last_agent": "none", + "iteration": 9, + "query_length": 3755, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.13657622439010203, + "trm_confidence": 0.8260667437097314, + "mcts_value": 0.24615378011227454, + "consensus_score": 0.4436877883120809, + "last_agent": "hrm", + "iteration": 2, + "query_length": 3462, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6621184400767522, + "trm_confidence": 0.9664086274682653, + "mcts_value": 0.1731040516466299, + "consensus_score": 0.5590060150832142, + "last_agent": "none", + "iteration": 10, + "query_length": 833, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.24092127909679925, + "trm_confidence": 0.8831566012641474, + "mcts_value": 0.23426183125810168, + "consensus_score": 0.3606846276110622, + "last_agent": "none", + "iteration": 8, + "query_length": 3088, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4387600739311695, + "trm_confidence": 0.7014572013331837, + "mcts_value": 0.5707069543934159, + "consensus_score": 0.6216086860980519, + "last_agent": "none", + "iteration": 5, + "query_length": 1208, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.23726287854895656, + "trm_confidence": 0.8308409267947315, + "mcts_value": 0.3812900998175645, + "consensus_score": 0.4693614931486405, + "last_agent": "hrm", + "iteration": 6, + "query_length": 3169, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 
0.07075106842511865, + "trm_confidence": 0.9985028251971431, + "mcts_value": 0.3456391869928632, + "consensus_score": 0.4077722660250702, + "last_agent": "mcts", + "iteration": 3, + "query_length": 748, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4984535248654349, + "trm_confidence": 0.7523611394880365, + "mcts_value": 0.3237594897871654, + "consensus_score": 0.6087071678236574, + "last_agent": "mcts", + "iteration": 0, + "query_length": 4766, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.05787850903063726, + "trm_confidence": 0.8209820748791032, + "mcts_value": 0.18942332955531607, + "consensus_score": 0.45330058643701027, + "last_agent": "hrm", + "iteration": 2, + "query_length": 620, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.38450059625280597, + "trm_confidence": 0.8720922006387127, + "mcts_value": 0.4481518822223316, + "consensus_score": 0.5435952820129928, + "last_agent": "hrm", + "iteration": 9, + "query_length": 1542, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.502503780043091, + "trm_confidence": 0.8285177937524397, + "mcts_value": 0.63252569593037, + "consensus_score": 0.6284300778466518, + "last_agent": "trm", + "iteration": 6, + "query_length": 3271, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.04501556211515551, + "trm_confidence": 0.7510957264335486, + "mcts_value": 0.473920138132877, + "consensus_score": 0.3809558026382885, + "last_agent": "none", + "iteration": 7, + "query_length": 17, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7175235392332914, + "trm_confidence": 0.9587643401224977, + "mcts_value": 0.07636883334349913, + "consensus_score": 0.4889702055245333, + "last_agent": "mcts", + "iteration": 0, + "query_length": 2725, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.367568488911526, + "trm_confidence": 0.7704110467155687, + "mcts_value": 0.28575882241038536, + "consensus_score": 0.5336845579248805, + "last_agent": "mcts", + "iteration": 5, + "query_length": 149, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1346253143879693, + "trm_confidence": 0.9967275818499708, + "mcts_value": 0.7449151342248853, + "consensus_score": 0.5357726259417187, + "last_agent": "hrm", + "iteration": 5, + "query_length": 500, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5478730718912188, + "trm_confidence": 0.8995641896812363, + "mcts_value": 0.21053906941647224, + "consensus_score": 0.5363737352685999, + "last_agent": "none", + "iteration": 4, + "query_length": 1756, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2414694756928368, + "trm_confidence": 0.8386683790431004, + "mcts_value": 0.2256488687693059, + "consensus_score": 0.3854328194994388, + "last_agent": "mcts", + "iteration": 4, + "query_length": 213, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5188049086813686, + "trm_confidence": 0.9616312906010721, + "mcts_value": 0.11527457054933798, + "consensus_score": 0.5706754567857266, + "last_agent": "mcts", + "iteration": 9, + "query_length": 498, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 
0.005559683494021258, + "trm_confidence": 0.7451801783011791, + "mcts_value": 0.030484982048218194, + "consensus_score": 0.31285855094184184, + "last_agent": "none", + "iteration": 10, + "query_length": 1355, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.05924469930123762, + "trm_confidence": 0.9244212877239678, + "mcts_value": 0.4881669060107118, + "consensus_score": 0.5418260208309533, + "last_agent": "mcts", + "iteration": 3, + "query_length": 2810, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7717402512901582, + "trm_confidence": 0.8785568056376277, + "mcts_value": 0.2005697978017373, + "consensus_score": 0.5548538736806177, + "last_agent": "trm", + "iteration": 1, + "query_length": 1961, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6171263197611457, + "trm_confidence": 0.878933529808082, + "mcts_value": 0.5865940232253513, + "consensus_score": 0.7727262565572104, + "last_agent": "hrm", + "iteration": 8, + "query_length": 1025, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.22804455185613456, + "trm_confidence": 0.8667632214627283, + "mcts_value": 0.2528673000162186, + "consensus_score": 0.5356138982910263, + "last_agent": "trm", + "iteration": 7, + "query_length": 4544, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3896818826228778, + "trm_confidence": 0.8250737610156234, + "mcts_value": 0.3531946095675046, + "consensus_score": 0.44410749590674725, + "last_agent": "hrm", + "iteration": 2, + "query_length": 4884, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5777041376588443, + "trm_confidence": 0.7202122915399018, + "mcts_value": 0.07680937377083678, + "consensus_score": 0.39783808360657574, + "last_agent": "mcts", + "iteration": 8, + "query_length": 4191, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.631420389676864, + "trm_confidence": 0.9700834298846228, + "mcts_value": 0.05970587071489856, + "consensus_score": 0.5921059756825924, + "last_agent": "mcts", + "iteration": 6, + "query_length": 1318, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6094912560302467, + "trm_confidence": 0.9484593557468007, + "mcts_value": 0.4973467618396553, + "consensus_score": 0.6495917882552565, + "last_agent": "mcts", + "iteration": 7, + "query_length": 960, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5157020450568663, + "trm_confidence": 0.8762565341586622, + "mcts_value": 0.33932349902157877, + "consensus_score": 0.6283181880276507, + "last_agent": "hrm", + "iteration": 10, + "query_length": 2497, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3564181702938383, + "trm_confidence": 0.7607215838511286, + "mcts_value": 0.3180551149056891, + "consensus_score": 0.4927122465848959, + "last_agent": "none", + "iteration": 0, + "query_length": 2415, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4534609618708469, + "trm_confidence": 0.8748934786196683, + "mcts_value": 0.509363800249543, + "consensus_score": 0.6566159527138855, + "last_agent": "none", + "iteration": 7, + "query_length": 1434, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + 
"hrm_confidence": 0.42324868604069527, + "trm_confidence": 0.8874281904053775, + "mcts_value": 0.477336157029331, + "consensus_score": 0.5162430178424677, + "last_agent": "none", + "iteration": 3, + "query_length": 4223, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5852719319442111, + "trm_confidence": 0.9618961646205122, + "mcts_value": 0.07792237608563882, + "consensus_score": 0.5601518143056683, + "last_agent": "none", + "iteration": 2, + "query_length": 84, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5033887527895373, + "trm_confidence": 0.9482322557526242, + "mcts_value": 0.5016128404374516, + "consensus_score": 0.6187962383020338, + "last_agent": "trm", + "iteration": 4, + "query_length": 4188, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4246367125632823, + "trm_confidence": 0.9169960992901469, + "mcts_value": 0.3008444231972218, + "consensus_score": 0.4575683593412612, + "last_agent": "mcts", + "iteration": 1, + "query_length": 1885, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6722444983329169, + "trm_confidence": 0.8120150629776024, + "mcts_value": 0.35967682625370156, + "consensus_score": 0.5616134951732532, + "last_agent": "mcts", + "iteration": 1, + "query_length": 1137, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3238170141971061, + "trm_confidence": 0.9447191906653171, + "mcts_value": 0.6525504792475115, + "consensus_score": 0.7009032232620207, + "last_agent": "trm", + "iteration": 1, + "query_length": 4619, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.356989558313623, + "trm_confidence": 0.9877288190113332, + "mcts_value": 0.04522453596626121, + "consensus_score": 0.5479748424819816, + "last_agent": "mcts", + "iteration": 9, + "query_length": 3248, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.16212506707963395, + "trm_confidence": 0.8656759361032029, + "mcts_value": 0.6274337243421171, + "consensus_score": 0.5514649168534099, + "last_agent": "none", + "iteration": 0, + "query_length": 3140, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7339216286610662, + "trm_confidence": 0.9851488809408118, + "mcts_value": 0.24536162360623354, + "consensus_score": 0.5654757280560386, + "last_agent": "trm", + "iteration": 10, + "query_length": 3298, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.48798433603059405, + "trm_confidence": 0.8597245959138609, + "mcts_value": 0.35144864419324634, + "consensus_score": 0.5503389321323188, + "last_agent": "mcts", + "iteration": 1, + "query_length": 1278, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5603984069163154, + "trm_confidence": 0.7698196440028562, + "mcts_value": 0.09768133976289063, + "consensus_score": 0.42130619394362945, + "last_agent": "mcts", + "iteration": 6, + "query_length": 4975, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.36557083789533973, + "trm_confidence": 0.7489620054153361, + "mcts_value": 0.23887854732585437, + "consensus_score": 0.47636898738027056, + "last_agent": "hrm", + "iteration": 10, + "query_length": 2177, + "has_rag_context": true + }, + "label": "trm" + }, + { + 
"features": { + "hrm_confidence": 0.788803394620058, + "trm_confidence": 0.9740300979439191, + "mcts_value": 0.7246475016671664, + "consensus_score": 0.8554175148715983, + "last_agent": "trm", + "iteration": 6, + "query_length": 3086, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.013303586213472012, + "trm_confidence": 0.907148665365344, + "mcts_value": 0.7733231465659587, + "consensus_score": 0.5817125516747075, + "last_agent": "none", + "iteration": 8, + "query_length": 725, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3077365828115015, + "trm_confidence": 0.7037191641747388, + "mcts_value": 0.3062003949110987, + "consensus_score": 0.42661657162807, + "last_agent": "trm", + "iteration": 3, + "query_length": 4154, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.28455970154793864, + "trm_confidence": 0.8582634623491548, + "mcts_value": 0.7354639078841835, + "consensus_score": 0.5606811521471453, + "last_agent": "trm", + "iteration": 3, + "query_length": 4690, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3809472177227391, + "trm_confidence": 0.862992752870787, + "mcts_value": 0.49116930984838786, + "consensus_score": 0.6683701142621463, + "last_agent": "hrm", + "iteration": 3, + "query_length": 2749, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.15940416182721573, + "trm_confidence": 0.707512060901555, + "mcts_value": 0.2817153276359186, + "consensus_score": 0.28907614115112334, + "last_agent": "mcts", + "iteration": 6, + "query_length": 3080, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4256674726587186, + "trm_confidence": 0.9240195144514141, + "mcts_value": 0.4976621895912099, + "consensus_score": 0.594996109156145, + "last_agent": "hrm", + "iteration": 0, + "query_length": 3633, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7929585788152211, + "trm_confidence": 0.8964987772604393, + "mcts_value": 0.012587260670319693, + "consensus_score": 0.5788649049472699, + "last_agent": "trm", + "iteration": 8, + "query_length": 2383, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4733439573332464, + "trm_confidence": 0.863561470366792, + "mcts_value": 0.5319923236550607, + "consensus_score": 0.6947591670570525, + "last_agent": "mcts", + "iteration": 1, + "query_length": 698, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5082465744857824, + "trm_confidence": 0.707805221065471, + "mcts_value": 0.04664323192977533, + "consensus_score": 0.3598239848146607, + "last_agent": "hrm", + "iteration": 0, + "query_length": 3014, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.11624612534982769, + "trm_confidence": 0.8003077680152176, + "mcts_value": 0.3159019436862092, + "consensus_score": 0.391292788984855, + "last_agent": "mcts", + "iteration": 8, + "query_length": 3190, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.09900207687467694, + "trm_confidence": 0.8957701755301245, + "mcts_value": 0.2350631845562408, + "consensus_score": 0.4926088050201367, + "last_agent": "hrm", + "iteration": 0, + "query_length": 1857, + "has_rag_context": true + }, + "label": "trm" + }, + { + 
"features": { + "hrm_confidence": 0.839200706210048, + "trm_confidence": 0.9827513857700444, + "mcts_value": 0.3629901743729547, + "consensus_score": 0.7617934664041085, + "last_agent": "trm", + "iteration": 5, + "query_length": 3531, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.08786009173036866, + "trm_confidence": 0.7094743992078996, + "mcts_value": 0.3743384467839459, + "consensus_score": 0.3370680983960174, + "last_agent": "mcts", + "iteration": 7, + "query_length": 267, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7094734398843789, + "trm_confidence": 0.9659936674383465, + "mcts_value": 0.4267944139338113, + "consensus_score": 0.755332857837817, + "last_agent": "none", + "iteration": 2, + "query_length": 684, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5675345712066688, + "trm_confidence": 0.7592446155589834, + "mcts_value": 0.5426028944636934, + "consensus_score": 0.7035572940915822, + "last_agent": "hrm", + "iteration": 9, + "query_length": 4548, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6229253486081775, + "trm_confidence": 0.7685008694895777, + "mcts_value": 0.381856515082727, + "consensus_score": 0.5755932611771244, + "last_agent": "trm", + "iteration": 6, + "query_length": 1734, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.16567819614037949, + "trm_confidence": 0.7865659020882856, + "mcts_value": 0.05659224693895597, + "consensus_score": 0.3355917611519778, + "last_agent": "mcts", + "iteration": 9, + "query_length": 2722, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3241416909846743, + "trm_confidence": 0.856292689826732, + "mcts_value": 0.351223555127338, + "consensus_score": 0.452408114827305, + "last_agent": "none", + "iteration": 9, + "query_length": 1895, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.619377773100635, + "trm_confidence": 0.9149868270181939, + "mcts_value": 0.4685106965131719, + "consensus_score": 0.7402103738682992, + "last_agent": "hrm", + "iteration": 3, + "query_length": 4145, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6540348403488278, + "trm_confidence": 0.9188231378797913, + "mcts_value": 0.5441732431263956, + "consensus_score": 0.7749456771829988, + "last_agent": "none", + "iteration": 10, + "query_length": 895, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6285553989241485, + "trm_confidence": 0.944174705112871, + "mcts_value": 0.38230153859435984, + "consensus_score": 0.7370670141569622, + "last_agent": "none", + "iteration": 4, + "query_length": 2051, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.14264086330260134, + "trm_confidence": 0.7317827070972884, + "mcts_value": 0.531137302466143, + "consensus_score": 0.4478308355211137, + "last_agent": "mcts", + "iteration": 4, + "query_length": 2832, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6023345605360149, + "trm_confidence": 0.8474831002892362, + "mcts_value": 0.6691200500211939, + "consensus_score": 0.6203354566909851, + "last_agent": "trm", + "iteration": 6, + "query_length": 1953, + "has_rag_context": true + }, + "label": "trm" + }, + { + 
"features": { + "hrm_confidence": 0.3507914131324489, + "trm_confidence": 0.8842845551414735, + "mcts_value": 0.3126509759824731, + "consensus_score": 0.5102851239046965, + "last_agent": "hrm", + "iteration": 6, + "query_length": 1140, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6065421070084859, + "trm_confidence": 0.8416433353061938, + "mcts_value": 0.5530048631860084, + "consensus_score": 0.6353391340031346, + "last_agent": "trm", + "iteration": 8, + "query_length": 2918, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3764406418636302, + "trm_confidence": 0.7460476616122571, + "mcts_value": 0.18432514629470334, + "consensus_score": 0.483438947642208, + "last_agent": "hrm", + "iteration": 2, + "query_length": 3714, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.34580787111028655, + "trm_confidence": 0.7765373968172902, + "mcts_value": 0.6417948119039376, + "consensus_score": 0.5643082009684584, + "last_agent": "none", + "iteration": 9, + "query_length": 1183, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1430520289014284, + "trm_confidence": 0.962455254641549, + "mcts_value": 0.7069185384482377, + "consensus_score": 0.5105718196838507, + "last_agent": "mcts", + "iteration": 4, + "query_length": 4681, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4297221008494155, + "trm_confidence": 0.9066256638279759, + "mcts_value": 0.3128610767072595, + "consensus_score": 0.6121956697556076, + "last_agent": "trm", + "iteration": 7, + "query_length": 118, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.19953631197384059, + "trm_confidence": 0.9677453011628979, + "mcts_value": 0.14785643143045163, + "consensus_score": 0.46694010105577965, + "last_agent": "mcts", + "iteration": 10, + "query_length": 377, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.09414792784677961, + "trm_confidence": 0.7749011136801479, + "mcts_value": 0.5807953018467219, + "consensus_score": 0.42856095587348325, + "last_agent": "hrm", + "iteration": 5, + "query_length": 2635, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.33591852705329256, + "trm_confidence": 0.7036422702766131, + "mcts_value": 0.3092393207762463, + "consensus_score": 0.5110895146507483, + "last_agent": "hrm", + "iteration": 3, + "query_length": 3341, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.48181244318081007, + "trm_confidence": 0.7503813941796591, + "mcts_value": 0.5103600743829874, + "consensus_score": 0.5081839680702516, + "last_agent": "none", + "iteration": 4, + "query_length": 2594, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.17949808914376916, + "trm_confidence": 0.7008121821027361, + "mcts_value": 0.3334761664417101, + "consensus_score": 0.37141425116139576, + "last_agent": "trm", + "iteration": 2, + "query_length": 2971, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3466169105373836, + "trm_confidence": 0.7544569324329081, + "mcts_value": 0.46250204482550167, + "consensus_score": 0.5362509531242983, + "last_agent": "mcts", + "iteration": 1, + "query_length": 1516, + "has_rag_context": false + }, + "label": "trm" + 
}, + { + "features": { + "hrm_confidence": 0.49903474535210157, + "trm_confidence": 0.7524933512462464, + "mcts_value": 0.2352663374726087, + "consensus_score": 0.5521504143015393, + "last_agent": "none", + "iteration": 2, + "query_length": 3875, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.024672623991961514, + "trm_confidence": 0.8491077538652446, + "mcts_value": 0.16877198340705082, + "consensus_score": 0.39231990419162194, + "last_agent": "trm", + "iteration": 9, + "query_length": 550, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6895519670085203, + "trm_confidence": 0.9646233128314313, + "mcts_value": 0.7519493168716266, + "consensus_score": 0.8417272431529607, + "last_agent": "trm", + "iteration": 7, + "query_length": 2840, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7314607128206364, + "trm_confidence": 0.8958361673473945, + "mcts_value": 0.07999904434156259, + "consensus_score": 0.532251537072416, + "last_agent": "none", + "iteration": 0, + "query_length": 2961, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.49541895353943594, + "trm_confidence": 0.8335796397998163, + "mcts_value": 0.6170285539474784, + "consensus_score": 0.7468757653775413, + "last_agent": "hrm", + "iteration": 1, + "query_length": 1504, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.042337164682284135, + "trm_confidence": 0.966226869441021, + "mcts_value": 0.6532135052245751, + "consensus_score": 0.5248709106886035, + "last_agent": "trm", + "iteration": 7, + "query_length": 2778, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.09166962753153295, + "trm_confidence": 0.777922675625991, + "mcts_value": 0.07928118116289253, + "consensus_score": 0.2369557329818372, + "last_agent": "none", + "iteration": 7, + "query_length": 1110, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4600966895664224, + "trm_confidence": 0.7649900314617617, + "mcts_value": 0.4567481143485402, + "consensus_score": 0.5684452702636409, + "last_agent": "trm", + "iteration": 10, + "query_length": 2275, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.06392551939140412, + "trm_confidence": 0.9507588913807878, + "mcts_value": 0.6294178813732334, + "consensus_score": 0.6457235801481944, + "last_agent": "mcts", + "iteration": 9, + "query_length": 417, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.27840269980272137, + "trm_confidence": 0.8880298192981482, + "mcts_value": 0.5354546137596964, + "consensus_score": 0.6213747875897874, + "last_agent": "hrm", + "iteration": 10, + "query_length": 58, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3684733249395634, + "trm_confidence": 0.7183912197320327, + "mcts_value": 0.07360348928504164, + "consensus_score": 0.33577568669676794, + "last_agent": "trm", + "iteration": 3, + "query_length": 2862, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1085917895613392, + "trm_confidence": 0.7551424371635501, + "mcts_value": 0.029457983415894704, + "consensus_score": 0.3723787239843709, + "last_agent": "none", + "iteration": 7, + "query_length": 2098, + "has_rag_context": true + }, + "label": 
"trm" + }, + { + "features": { + "hrm_confidence": 0.3399464330694429, + "trm_confidence": 0.7084876916914343, + "mcts_value": 0.17191870566237222, + "consensus_score": 0.42612954689413557, + "last_agent": "hrm", + "iteration": 4, + "query_length": 551, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4661264689151281, + "trm_confidence": 0.7863427044884104, + "mcts_value": 0.037099782425568675, + "consensus_score": 0.5210284254474875, + "last_agent": "mcts", + "iteration": 10, + "query_length": 4645, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1626991710875187, + "trm_confidence": 0.839563741469072, + "mcts_value": 0.6009076245985423, + "consensus_score": 0.5816609984811155, + "last_agent": "none", + "iteration": 2, + "query_length": 4304, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.610584907402012, + "trm_confidence": 0.8388396813450355, + "mcts_value": 0.5052515886628856, + "consensus_score": 0.680192234511322, + "last_agent": "mcts", + "iteration": 7, + "query_length": 2476, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.03248932695593583, + "trm_confidence": 0.8115133779488601, + "mcts_value": 0.5599191457885969, + "consensus_score": 0.45457359326509383, + "last_agent": "trm", + "iteration": 4, + "query_length": 3796, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3896249423489808, + "trm_confidence": 0.7628281844378286, + "mcts_value": 0.15414269285299517, + "consensus_score": 0.34837371019302776, + "last_agent": "none", + "iteration": 8, + "query_length": 3598, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.30799173649505523, + "trm_confidence": 0.8140706751697913, + "mcts_value": 0.5819996073446608, + "consensus_score": 0.5321880521911325, + "last_agent": "trm", + "iteration": 8, + "query_length": 1375, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3445939379059129, + "trm_confidence": 0.8982604587774398, + "mcts_value": 0.5969309674681113, + "consensus_score": 0.6048412620007652, + "last_agent": "mcts", + "iteration": 4, + "query_length": 3424, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.07248691640425747, + "trm_confidence": 0.7953167004427881, + "mcts_value": 0.2884535276545265, + "consensus_score": 0.31019042439305106, + "last_agent": "mcts", + "iteration": 4, + "query_length": 3260, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.0318471571811411, + "trm_confidence": 0.9792133543616941, + "mcts_value": 0.5463945301768476, + "consensus_score": 0.5106987179032431, + "last_agent": "none", + "iteration": 0, + "query_length": 2172, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.137782560158677, + "trm_confidence": 0.7355455757064404, + "mcts_value": 0.17450190345222677, + "consensus_score": 0.32953794684888726, + "last_agent": "hrm", + "iteration": 5, + "query_length": 2690, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.06068349066973206, + "trm_confidence": 0.7485608081886223, + "mcts_value": 0.3564966240656323, + "consensus_score": 0.3543952817215828, + "last_agent": "hrm", + "iteration": 6, + "query_length": 1661, + "has_rag_context": true + }, + 
"label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5716627543553109, + "trm_confidence": 0.7072688062262754, + "mcts_value": 0.06577214345905116, + "consensus_score": 0.3736063176058892, + "last_agent": "none", + "iteration": 1, + "query_length": 4420, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4792595224651311, + "trm_confidence": 0.8852980720509989, + "mcts_value": 0.7748521953486893, + "consensus_score": 0.6308243504765374, + "last_agent": "hrm", + "iteration": 10, + "query_length": 2024, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3305727687781483, + "trm_confidence": 0.8869214510686106, + "mcts_value": 0.3021399182018856, + "consensus_score": 0.592793378378774, + "last_agent": "hrm", + "iteration": 1, + "query_length": 2143, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4630549898545562, + "trm_confidence": 0.8137375183274645, + "mcts_value": 0.6908313485013811, + "consensus_score": 0.5673948289702623, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1973, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.12103597238851165, + "trm_confidence": 0.9113190636932357, + "mcts_value": 0.4327473909675736, + "consensus_score": 0.5230702186281557, + "last_agent": "none", + "iteration": 5, + "query_length": 1658, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5752077479652343, + "trm_confidence": 0.773235722525673, + "mcts_value": 0.023207103765942348, + "consensus_score": 0.49898927377377894, + "last_agent": "mcts", + "iteration": 10, + "query_length": 2778, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.43705698931283243, + "trm_confidence": 0.7483137048459819, + "mcts_value": 0.0388928802268296, + "consensus_score": 0.4260458548997197, + "last_agent": "hrm", + "iteration": 7, + "query_length": 2065, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.49476042604650194, + "trm_confidence": 0.8095554109063201, + "mcts_value": 0.4216848270686829, + "consensus_score": 0.5168609890926189, + "last_agent": "mcts", + "iteration": 6, + "query_length": 3593, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.37745678874719507, + "trm_confidence": 0.900049985707135, + "mcts_value": 0.2572833569043462, + "consensus_score": 0.4287739523404067, + "last_agent": "mcts", + "iteration": 2, + "query_length": 3286, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.19934137811910552, + "trm_confidence": 0.9959877752574893, + "mcts_value": 0.7301188829975164, + "consensus_score": 0.5755449884856995, + "last_agent": "trm", + "iteration": 3, + "query_length": 3725, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5567073453453225, + "trm_confidence": 0.7973216992424815, + "mcts_value": 0.667820394630638, + "consensus_score": 0.7490087098029837, + "last_agent": "mcts", + "iteration": 1, + "query_length": 477, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5885682385532721, + "trm_confidence": 0.7443840739914134, + "mcts_value": 0.33371270642218903, + "consensus_score": 0.5228312239336328, + "last_agent": "trm", + "iteration": 1, + "query_length": 2627, + "has_rag_context": true + 
}, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.616400595471275, + "trm_confidence": 0.7743840974982654, + "mcts_value": 0.5376609440049948, + "consensus_score": 0.736773953874989, + "last_agent": "mcts", + "iteration": 1, + "query_length": 3719, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3676503385665803, + "trm_confidence": 0.7771451422092329, + "mcts_value": 0.2169944690565093, + "consensus_score": 0.41062309332611235, + "last_agent": "none", + "iteration": 6, + "query_length": 4963, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.48208269744353577, + "trm_confidence": 0.8407400310916351, + "mcts_value": 0.09494959287482052, + "consensus_score": 0.5587096279787113, + "last_agent": "mcts", + "iteration": 5, + "query_length": 2008, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.565282606096974, + "trm_confidence": 0.896171629186319, + "mcts_value": 0.5701631380009995, + "consensus_score": 0.7522257705070587, + "last_agent": "none", + "iteration": 6, + "query_length": 4150, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.38184870645861857, + "trm_confidence": 0.7298528928596371, + "mcts_value": 0.041792505404248316, + "consensus_score": 0.33275672524540734, + "last_agent": "hrm", + "iteration": 0, + "query_length": 4818, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3956058611748936, + "trm_confidence": 0.7481014374828604, + "mcts_value": 0.5125608281675542, + "consensus_score": 0.6164655197628086, + "last_agent": "none", + "iteration": 3, + "query_length": 4570, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3636044808504141, + "trm_confidence": 0.8689032280259077, + "mcts_value": 0.5037191858411157, + "consensus_score": 0.5207597230038525, + "last_agent": "mcts", + "iteration": 10, + "query_length": 3767, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5042486778558699, + "trm_confidence": 0.9487663423609357, + "mcts_value": 0.20782446818532185, + "consensus_score": 0.6027481629327691, + "last_agent": "none", + "iteration": 0, + "query_length": 1989, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.0860525064016166, + "trm_confidence": 0.876995566672904, + "mcts_value": 0.06177547004860453, + "consensus_score": 0.3917751143950009, + "last_agent": "hrm", + "iteration": 3, + "query_length": 1594, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.578535731544378, + "trm_confidence": 0.9286219063090528, + "mcts_value": 0.23085907844237044, + "consensus_score": 0.6211278092362573, + "last_agent": "mcts", + "iteration": 10, + "query_length": 1251, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1637657571093312, + "trm_confidence": 0.7978487926010698, + "mcts_value": 0.27442913148171366, + "consensus_score": 0.46048109136816084, + "last_agent": "trm", + "iteration": 5, + "query_length": 572, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.09808957441878606, + "trm_confidence": 0.7942464664235027, + "mcts_value": 0.18531774254454983, + "consensus_score": 0.4311539880440103, + "last_agent": "none", + "iteration": 6, + "query_length": 3206, + 
"has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.053949052911980266, + "trm_confidence": 0.9369717945206762, + "mcts_value": 0.21249967947120593, + "consensus_score": 0.3471608064460887, + "last_agent": "none", + "iteration": 10, + "query_length": 653, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.003914871612376854, + "trm_confidence": 0.8853357809165663, + "mcts_value": 0.1447452394219721, + "consensus_score": 0.37177355591158734, + "last_agent": "mcts", + "iteration": 4, + "query_length": 1187, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.371068074401949, + "trm_confidence": 0.9294172657425274, + "mcts_value": 0.6216606633028641, + "consensus_score": 0.7357961992476911, + "last_agent": "none", + "iteration": 5, + "query_length": 1580, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2177393810241916, + "trm_confidence": 0.9755993845183886, + "mcts_value": 0.7762695635461, + "consensus_score": 0.5760450510508233, + "last_agent": "mcts", + "iteration": 7, + "query_length": 137, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3886885142379816, + "trm_confidence": 0.8843818657757279, + "mcts_value": 0.4842358700709443, + "consensus_score": 0.5161014917841068, + "last_agent": "mcts", + "iteration": 2, + "query_length": 2527, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3443884439931525, + "trm_confidence": 0.8793338021890114, + "mcts_value": 0.19052828786785117, + "consensus_score": 0.5479787111122544, + "last_agent": "hrm", + "iteration": 4, + "query_length": 364, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1315284405602689, + "trm_confidence": 0.8229103229646983, + "mcts_value": 0.5289007220143683, + "consensus_score": 0.4841381476335501, + "last_agent": "hrm", + "iteration": 0, + "query_length": 2661, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6112116300243542, + "trm_confidence": 0.944729083478624, + "mcts_value": 0.73414717924184, + "consensus_score": 0.8447240816360495, + "last_agent": "trm", + "iteration": 3, + "query_length": 45, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.20920003901734632, + "trm_confidence": 0.8826423427616783, + "mcts_value": 0.1737092300537618, + "consensus_score": 0.34463198824139984, + "last_agent": "none", + "iteration": 7, + "query_length": 1519, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2619664334960157, + "trm_confidence": 0.9282858938551011, + "mcts_value": 0.17001383871513442, + "consensus_score": 0.4441448910463415, + "last_agent": "trm", + "iteration": 6, + "query_length": 4000, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7900517104476441, + "trm_confidence": 0.9624469462822163, + "mcts_value": 0.7333061975991593, + "consensus_score": 0.9215178589969355, + "last_agent": "none", + "iteration": 5, + "query_length": 4511, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.24394936531507527, + "trm_confidence": 0.7250686096445434, + "mcts_value": 0.16107211186535897, + "consensus_score": 0.31068481388053937, + "last_agent": "mcts", + "iteration": 10, + "query_length": 
1115, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6754832080856815, + "trm_confidence": 0.9878137816945693, + "mcts_value": 0.5690404010302622, + "consensus_score": 0.77931625562138, + "last_agent": "trm", + "iteration": 3, + "query_length": 3555, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6151419235787853, + "trm_confidence": 0.7560076354843795, + "mcts_value": 0.1979787353524825, + "consensus_score": 0.47408006726392415, + "last_agent": "hrm", + "iteration": 7, + "query_length": 2981, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.46509443130234623, + "trm_confidence": 0.8353896755029916, + "mcts_value": 0.11728955828719136, + "consensus_score": 0.5072568831779007, + "last_agent": "mcts", + "iteration": 2, + "query_length": 4989, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.21559775166278877, + "trm_confidence": 0.7486539804831354, + "mcts_value": 0.21486731881020213, + "consensus_score": 0.3838754859924959, + "last_agent": "none", + "iteration": 2, + "query_length": 3421, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.09827180114003921, + "trm_confidence": 0.8625215714989685, + "mcts_value": 0.525298829132584, + "consensus_score": 0.5841534924297165, + "last_agent": "none", + "iteration": 4, + "query_length": 1177, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.027041610671315285, + "trm_confidence": 0.7294949773377593, + "mcts_value": 0.45409808520273953, + "consensus_score": 0.455749205666904, + "last_agent": "trm", + "iteration": 9, + "query_length": 2817, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3136839190701999, + "trm_confidence": 0.9330721164870491, + "mcts_value": 0.2527021236687303, + "consensus_score": 0.4344241706773192, + "last_agent": "hrm", + "iteration": 10, + "query_length": 361, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5423567683286818, + "trm_confidence": 0.9481024043823998, + "mcts_value": 0.6234234921929721, + "consensus_score": 0.7469709789523773, + "last_agent": "none", + "iteration": 9, + "query_length": 1487, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3592651604263909, + "trm_confidence": 0.7002473850504056, + "mcts_value": 0.21240272293132584, + "consensus_score": 0.3487292223594683, + "last_agent": "hrm", + "iteration": 2, + "query_length": 2419, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.03534708308710173, + "trm_confidence": 0.7210700112660109, + "mcts_value": 0.5689164364087479, + "consensus_score": 0.40419839345802605, + "last_agent": "mcts", + "iteration": 10, + "query_length": 611, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4595328841269329, + "trm_confidence": 0.7427713846981341, + "mcts_value": 0.3861801905176152, + "consensus_score": 0.5676242344019248, + "last_agent": "none", + "iteration": 1, + "query_length": 342, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.37637053417760125, + "trm_confidence": 0.7229915984352644, + "mcts_value": 0.6146379513451925, + "consensus_score": 0.6057302688509411, + "last_agent": "none", + "iteration": 7, + 
"query_length": 2800, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.03228201472915102, + "trm_confidence": 0.9034034662515389, + "mcts_value": 0.6965553049876887, + "consensus_score": 0.5170159200652188, + "last_agent": "trm", + "iteration": 0, + "query_length": 195, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.22673522654323894, + "trm_confidence": 0.9811045002628791, + "mcts_value": 0.43307446871534455, + "consensus_score": 0.4817267760636323, + "last_agent": "none", + "iteration": 0, + "query_length": 4955, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.706902151582988, + "trm_confidence": 0.9181032717884066, + "mcts_value": 0.1172332492359439, + "consensus_score": 0.5384099377645399, + "last_agent": "hrm", + "iteration": 2, + "query_length": 3042, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5361609677126532, + "trm_confidence": 0.8660294109660951, + "mcts_value": 0.49629320506058805, + "consensus_score": 0.610438570706895, + "last_agent": "hrm", + "iteration": 6, + "query_length": 2899, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6022105841966343, + "trm_confidence": 0.7462911156168053, + "mcts_value": 0.492155691889362, + "consensus_score": 0.6981002830659674, + "last_agent": "mcts", + "iteration": 10, + "query_length": 3711, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3660627178418953, + "trm_confidence": 0.9784715944336495, + "mcts_value": 0.0031767683220894345, + "consensus_score": 0.36175125003510994, + "last_agent": "none", + "iteration": 1, + "query_length": 4906, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4947540701137711, + "trm_confidence": 0.8055107172550791, + "mcts_value": 0.5578112066415875, + "consensus_score": 0.6374013195390709, + "last_agent": "trm", + "iteration": 1, + "query_length": 4264, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.009179028716510931, + "trm_confidence": 0.8064405757412797, + "mcts_value": 0.19821918663560692, + "consensus_score": 0.3040663745259947, + "last_agent": "mcts", + "iteration": 3, + "query_length": 2701, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5759017228908051, + "trm_confidence": 0.9279779192782271, + "mcts_value": 0.11509313567569202, + "consensus_score": 0.6012243745561644, + "last_agent": "mcts", + "iteration": 7, + "query_length": 2335, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.38411527252869104, + "trm_confidence": 0.8944281790844549, + "mcts_value": 0.3490320696210835, + "consensus_score": 0.46494980643922096, + "last_agent": "none", + "iteration": 5, + "query_length": 1838, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.08605914289116712, + "trm_confidence": 0.7492760161762799, + "mcts_value": 0.6194796224353809, + "consensus_score": 0.4419111985897207, + "last_agent": "none", + "iteration": 3, + "query_length": 3590, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.48948486094393795, + "trm_confidence": 0.9694751335534431, + "mcts_value": 0.7991866758820104, + "consensus_score": 0.7127005379587096, + "last_agent": "none", + 
"iteration": 1, + "query_length": 4412, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5642706425765245, + "trm_confidence": 0.8036610975370052, + "mcts_value": 0.6188160411145123, + "consensus_score": 0.6567639625907616, + "last_agent": "none", + "iteration": 10, + "query_length": 2564, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.33312734089856844, + "trm_confidence": 0.7812761180158937, + "mcts_value": 0.025418611170304407, + "consensus_score": 0.440740922080869, + "last_agent": "none", + "iteration": 1, + "query_length": 4051, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.03307344027948186, + "trm_confidence": 0.869594964101483, + "mcts_value": 0.06104188617818761, + "consensus_score": 0.36046003504232005, + "last_agent": "hrm", + "iteration": 0, + "query_length": 81, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.16002249463825832, + "trm_confidence": 0.801608690026225, + "mcts_value": 0.13244891523472083, + "consensus_score": 0.41737562537958395, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1632, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5280922234212754, + "trm_confidence": 0.9879043657522382, + "mcts_value": 0.5697330752523868, + "consensus_score": 0.7372284971038164, + "last_agent": "trm", + "iteration": 7, + "query_length": 2780, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.39737379960595426, + "trm_confidence": 0.8450405954043092, + "mcts_value": 0.2412642627233361, + "consensus_score": 0.41561107133448993, + "last_agent": "trm", + "iteration": 7, + "query_length": 2378, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.06432983471771407, + "trm_confidence": 0.7826905153250104, + "mcts_value": 0.13020762131912073, + "consensus_score": 0.35348883970009126, + "last_agent": "none", + "iteration": 6, + "query_length": 692, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.27174780465748594, + "trm_confidence": 0.7270515929468473, + "mcts_value": 0.2211892338025744, + "consensus_score": 0.4507547077068259, + "last_agent": "trm", + "iteration": 6, + "query_length": 2732, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7209489134411255, + "trm_confidence": 0.8460946246914985, + "mcts_value": 0.564267053498185, + "consensus_score": 0.709832775968962, + "last_agent": "mcts", + "iteration": 5, + "query_length": 670, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2817517105607826, + "trm_confidence": 0.982752778995942, + "mcts_value": 0.22593093578389653, + "consensus_score": 0.596693593524408, + "last_agent": "none", + "iteration": 7, + "query_length": 1684, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.01577439131129125, + "trm_confidence": 0.7185656431076729, + "mcts_value": 0.37090365019669297, + "consensus_score": 0.2827274615509427, + "last_agent": "none", + "iteration": 10, + "query_length": 4032, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.47099881769875523, + "trm_confidence": 0.8638162200244239, + "mcts_value": 0.22295367841930777, + "consensus_score": 0.43205074719581693, + "last_agent": 
"none", + "iteration": 8, + "query_length": 4342, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1947221200387112, + "trm_confidence": 0.7256672923055179, + "mcts_value": 0.007799355637569309, + "consensus_score": 0.296203379145196, + "last_agent": "mcts", + "iteration": 7, + "query_length": 717, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6616562017705362, + "trm_confidence": 0.8812921911396764, + "mcts_value": 0.31759699748431286, + "consensus_score": 0.6899362138102705, + "last_agent": "mcts", + "iteration": 5, + "query_length": 2900, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2428662903497912, + "trm_confidence": 0.9832436939981397, + "mcts_value": 0.023142295682260235, + "consensus_score": 0.36495223526773574, + "last_agent": "hrm", + "iteration": 4, + "query_length": 945, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6688805237376906, + "trm_confidence": 0.8574002528330715, + "mcts_value": 0.22141704712565802, + "consensus_score": 0.5867200628239739, + "last_agent": "none", + "iteration": 3, + "query_length": 3712, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.0670001734640036, + "trm_confidence": 0.7472804539865435, + "mcts_value": 0.3255789958687113, + "consensus_score": 0.4143307249261845, + "last_agent": "mcts", + "iteration": 8, + "query_length": 694, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3267956106620368, + "trm_confidence": 0.9321549283957002, + "mcts_value": 0.027350993797126976, + "consensus_score": 0.3519870391985993, + "last_agent": "mcts", + "iteration": 0, + "query_length": 126, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.26675615856256124, + "trm_confidence": 0.8486474023806047, + "mcts_value": 0.4196913407843851, + "consensus_score": 0.5078201302744079, + "last_agent": "trm", + "iteration": 7, + "query_length": 879, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.24076419485499137, + "trm_confidence": 0.7954192441326038, + "mcts_value": 0.4069230558316348, + "consensus_score": 0.500321991851245, + "last_agent": "none", + "iteration": 3, + "query_length": 3450, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5232461696556046, + "trm_confidence": 0.9082135039558441, + "mcts_value": 0.10914988406229703, + "consensus_score": 0.5017448951596606, + "last_agent": "none", + "iteration": 9, + "query_length": 2644, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.13977011727098437, + "trm_confidence": 0.9565491839452641, + "mcts_value": 0.434483402192434, + "consensus_score": 0.4951348535874965, + "last_agent": "trm", + "iteration": 0, + "query_length": 3296, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7027971798793406, + "trm_confidence": 0.8364775682002916, + "mcts_value": 0.2111804508894396, + "consensus_score": 0.6567402210062185, + "last_agent": "hrm", + "iteration": 10, + "query_length": 3329, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4528137893677984, + "trm_confidence": 0.7608033363603153, + "mcts_value": 0.36276653795130237, + "consensus_score": 0.6143671543969972, + 
"last_agent": "trm", + "iteration": 4, + "query_length": 722, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4481637236190501, + "trm_confidence": 0.741486100448232, + "mcts_value": 0.6248777808889668, + "consensus_score": 0.5105112890996961, + "last_agent": "none", + "iteration": 8, + "query_length": 4318, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.06857363256335977, + "trm_confidence": 0.7224344602682048, + "mcts_value": 0.4306979431022975, + "consensus_score": 0.436828020490445, + "last_agent": "mcts", + "iteration": 0, + "query_length": 2642, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6867730459739563, + "trm_confidence": 0.8570271543534677, + "mcts_value": 0.6646629381862011, + "consensus_score": 0.7986559333145749, + "last_agent": "hrm", + "iteration": 5, + "query_length": 1851, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.07578925298603742, + "trm_confidence": 0.8517926511733398, + "mcts_value": 0.4953156926761809, + "consensus_score": 0.4010946427909876, + "last_agent": "trm", + "iteration": 5, + "query_length": 2697, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.16010157541766123, + "trm_confidence": 0.7898537470823938, + "mcts_value": 0.42773108304892316, + "consensus_score": 0.516586144254161, + "last_agent": "trm", + "iteration": 0, + "query_length": 333, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.29374569321155797, + "trm_confidence": 0.9582303084863814, + "mcts_value": 0.5655772628314892, + "consensus_score": 0.586373694517166, + "last_agent": "hrm", + "iteration": 1, + "query_length": 1770, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5498426640572068, + "trm_confidence": 0.725907902350479, + "mcts_value": 0.578520264476097, + "consensus_score": 0.6824129043635416, + "last_agent": "trm", + "iteration": 6, + "query_length": 905, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.22196676364259524, + "trm_confidence": 0.995382212980279, + "mcts_value": 0.06800281593644064, + "consensus_score": 0.4408387822400072, + "last_agent": "mcts", + "iteration": 3, + "query_length": 2222, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.35901101594477813, + "trm_confidence": 0.7021256062509844, + "mcts_value": 0.15898145305339606, + "consensus_score": 0.3442777347215891, + "last_agent": "none", + "iteration": 0, + "query_length": 4410, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2907794325942135, + "trm_confidence": 0.9636906617159372, + "mcts_value": 0.5865100538924193, + "consensus_score": 0.6649694420812382, + "last_agent": "hrm", + "iteration": 3, + "query_length": 4692, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.45162383408007106, + "trm_confidence": 0.9075638979000847, + "mcts_value": 0.0008014781372515946, + "consensus_score": 0.3603478898903641, + "last_agent": "none", + "iteration": 10, + "query_length": 3176, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7951723642940054, + "trm_confidence": 0.981269577631884, + "mcts_value": 0.018015779119747756, + "consensus_score": 0.6565220316400938, 
+ "last_agent": "hrm", + "iteration": 4, + "query_length": 2362, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.17350022601120896, + "trm_confidence": 0.7020147352490155, + "mcts_value": 0.45230798396929145, + "consensus_score": 0.5012118061301566, + "last_agent": "none", + "iteration": 6, + "query_length": 1952, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6700725888611195, + "trm_confidence": 0.7931087773863589, + "mcts_value": 0.30442181656221784, + "consensus_score": 0.5786261709853329, + "last_agent": "mcts", + "iteration": 6, + "query_length": 2346, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.33108310202008395, + "trm_confidence": 0.8658014693792242, + "mcts_value": 0.6330995001890818, + "consensus_score": 0.6918048944891201, + "last_agent": "none", + "iteration": 10, + "query_length": 1477, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5166902976733988, + "trm_confidence": 0.8326445672393485, + "mcts_value": 0.614193224751526, + "consensus_score": 0.6671058116347, + "last_agent": "trm", + "iteration": 7, + "query_length": 1003, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5990841615461594, + "trm_confidence": 0.902918133015186, + "mcts_value": 0.16752512087819482, + "consensus_score": 0.5561840478842063, + "last_agent": "hrm", + "iteration": 6, + "query_length": 4964, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5226621376728493, + "trm_confidence": 0.9567887526441046, + "mcts_value": 0.11064255655673892, + "consensus_score": 0.5933025634730631, + "last_agent": "none", + "iteration": 7, + "query_length": 3743, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6129866001129773, + "trm_confidence": 0.8176127385585569, + "mcts_value": 0.2971114670527823, + "consensus_score": 0.5900054324972238, + "last_agent": "hrm", + "iteration": 0, + "query_length": 3059, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4624508846282009, + "trm_confidence": 0.915452737638381, + "mcts_value": 0.40492024675669264, + "consensus_score": 0.6562619177500998, + "last_agent": "trm", + "iteration": 7, + "query_length": 2326, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6278566491726219, + "trm_confidence": 0.9877220014239958, + "mcts_value": 0.8770151232588164, + "consensus_score": 0.9251168545389185, + "last_agent": "hrm", + "iteration": 10, + "query_length": 4461, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.134780155862351, + "trm_confidence": 0.8828288105490902, + "mcts_value": 0.42535114381200245, + "consensus_score": 0.53406200788668, + "last_agent": "trm", + "iteration": 5, + "query_length": 3053, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.24109591575028186, + "trm_confidence": 0.7328967697575094, + "mcts_value": 0.320223246140162, + "consensus_score": 0.4178703697231394, + "last_agent": "trm", + "iteration": 1, + "query_length": 4663, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.41234914508910964, + "trm_confidence": 0.9512851878566462, + "mcts_value": 0.041476680211027434, + "consensus_score": 0.3716685895144577, 
+ "last_agent": "none", + "iteration": 5, + "query_length": 28, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.645435040183728, + "trm_confidence": 0.9115377837997666, + "mcts_value": 0.1852714048300609, + "consensus_score": 0.5991834652475413, + "last_agent": "mcts", + "iteration": 1, + "query_length": 1162, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.47153643452004507, + "trm_confidence": 0.96663864825443, + "mcts_value": 0.7677453812636504, + "consensus_score": 0.6803563821819433, + "last_agent": "none", + "iteration": 3, + "query_length": 3270, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7422740536411904, + "trm_confidence": 0.8440240225583628, + "mcts_value": 0.08515006658387513, + "consensus_score": 0.6275423880481276, + "last_agent": "hrm", + "iteration": 9, + "query_length": 3237, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.36512558615465096, + "trm_confidence": 0.8032111504777462, + "mcts_value": 0.2056665511746096, + "consensus_score": 0.5178225856236568, + "last_agent": "mcts", + "iteration": 5, + "query_length": 2685, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4046702589584559, + "trm_confidence": 0.8868074033068132, + "mcts_value": 0.5594746319742098, + "consensus_score": 0.6177706909658732, + "last_agent": "none", + "iteration": 9, + "query_length": 4443, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7244777124892807, + "trm_confidence": 0.9920046051237716, + "mcts_value": 0.7677781225671295, + "consensus_score": 0.8246683798676082, + "last_agent": "trm", + "iteration": 9, + "query_length": 1130, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1163748177424833, + "trm_confidence": 0.9285844535692526, + "mcts_value": 0.2544184440051873, + "consensus_score": 0.42227243777758766, + "last_agent": "hrm", + "iteration": 3, + "query_length": 2995, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.21715587039051465, + "trm_confidence": 0.9460314234351094, + "mcts_value": 0.6305312909373195, + "consensus_score": 0.5126486087433115, + "last_agent": "none", + "iteration": 8, + "query_length": 3133, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5205522309687037, + "trm_confidence": 0.8205096326507446, + "mcts_value": 0.6593660351069702, + "consensus_score": 0.7269518942863782, + "last_agent": "mcts", + "iteration": 4, + "query_length": 579, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.24396749425820094, + "trm_confidence": 0.9338072634357416, + "mcts_value": 0.6702832678201668, + "consensus_score": 0.5816335703502248, + "last_agent": "none", + "iteration": 8, + "query_length": 2354, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6118934963390785, + "trm_confidence": 0.713319189073177, + "mcts_value": 0.5332547377561968, + "consensus_score": 0.6960586417493916, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1201, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.08920373448121506, + "trm_confidence": 0.7860506664040456, + "mcts_value": 0.3438548600718579, + "consensus_score": 0.35037941774769443, + 
"last_agent": "hrm", + "iteration": 10, + "query_length": 3864, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.15847387327825596, + "trm_confidence": 0.7429590213229365, + "mcts_value": 0.12792101257159041, + "consensus_score": 0.28827365637353636, + "last_agent": "hrm", + "iteration": 7, + "query_length": 2006, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6984138608003901, + "trm_confidence": 0.884507314571012, + "mcts_value": 0.4727101899991634, + "consensus_score": 0.724117013380798, + "last_agent": "hrm", + "iteration": 4, + "query_length": 1770, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6713096838839204, + "trm_confidence": 0.7961752187349718, + "mcts_value": 0.43288967658355837, + "consensus_score": 0.5889604169542866, + "last_agent": "trm", + "iteration": 4, + "query_length": 747, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4488211404725031, + "trm_confidence": 0.9005330712117321, + "mcts_value": 0.7086021561543003, + "consensus_score": 0.7204733703895366, + "last_agent": "mcts", + "iteration": 7, + "query_length": 724, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6098893298605026, + "trm_confidence": 0.7319166049037223, + "mcts_value": 0.3784285480519917, + "consensus_score": 0.5565481761013581, + "last_agent": "hrm", + "iteration": 6, + "query_length": 1530, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.18342038240570238, + "trm_confidence": 0.946638182171232, + "mcts_value": 0.6811593682621512, + "consensus_score": 0.610614999599084, + "last_agent": "trm", + "iteration": 3, + "query_length": 3571, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.14906470567047758, + "trm_confidence": 0.9447500065979817, + "mcts_value": 0.043762371778592296, + "consensus_score": 0.43036790751419335, + "last_agent": "hrm", + "iteration": 3, + "query_length": 1191, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5791896333506865, + "trm_confidence": 0.7741202761842156, + "mcts_value": 0.20043586225665735, + "consensus_score": 0.4599537712140579, + "last_agent": "mcts", + "iteration": 3, + "query_length": 2706, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.05766527097728564, + "trm_confidence": 0.7952965662415622, + "mcts_value": 0.11022538489293426, + "consensus_score": 0.3044729840498236, + "last_agent": "trm", + "iteration": 9, + "query_length": 2683, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.6029620503797231, + "trm_confidence": 0.7638024120280091, + "mcts_value": 0.5907938874476786, + "consensus_score": 0.6884096166109919, + "last_agent": "hrm", + "iteration": 2, + "query_length": 1257, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.08530128330633818, + "trm_confidence": 0.7633662983313895, + "mcts_value": 0.2520781283405512, + "consensus_score": 0.3820795576734962, + "last_agent": "mcts", + "iteration": 3, + "query_length": 4228, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.23824827363124398, + "trm_confidence": 0.7888764206137916, + "mcts_value": 0.5446763936859894, + "consensus_score": 0.5995966421557799, + 
"last_agent": "hrm", + "iteration": 3, + "query_length": 4627, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.03645149074883067, + "trm_confidence": 0.7834486851676804, + "mcts_value": 0.5042875204423015, + "consensus_score": 0.375596595433501, + "last_agent": "mcts", + "iteration": 4, + "query_length": 3641, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.06196604299688403, + "trm_confidence": 0.7553414752537317, + "mcts_value": 0.228067362442038, + "consensus_score": 0.43891547507428164, + "last_agent": "trm", + "iteration": 6, + "query_length": 1758, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3671711066549766, + "trm_confidence": 0.7116732462322645, + "mcts_value": 0.05643153072067, + "consensus_score": 0.3869341087454756, + "last_agent": "none", + "iteration": 8, + "query_length": 4955, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.03268675557397803, + "trm_confidence": 0.7898234484449573, + "mcts_value": 0.017092359289111358, + "consensus_score": 0.2658728350317865, + "last_agent": "trm", + "iteration": 9, + "query_length": 447, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.52843234883557, + "trm_confidence": 0.8722166517345464, + "mcts_value": 0.046973836607357426, + "consensus_score": 0.4131610208865, + "last_agent": "trm", + "iteration": 3, + "query_length": 1212, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.537218265182701, + "trm_confidence": 0.8317105336684291, + "mcts_value": 0.6348304972281507, + "consensus_score": 0.7195942206794689, + "last_agent": "mcts", + "iteration": 10, + "query_length": 2895, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7122174030713968, + "trm_confidence": 0.9556537585022351, + "mcts_value": 0.6248546442004216, + "consensus_score": 0.824602298589196, + "last_agent": "mcts", + "iteration": 10, + "query_length": 2946, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4075676356567753, + "trm_confidence": 0.8914861071737631, + "mcts_value": 0.6735839520535416, + "consensus_score": 0.6438178188689948, + "last_agent": "trm", + "iteration": 9, + "query_length": 3894, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.22118242784204983, + "trm_confidence": 0.9885658304105771, + "mcts_value": 0.6305436490773224, + "consensus_score": 0.6648145353061471, + "last_agent": "trm", + "iteration": 10, + "query_length": 1067, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.44713971051353746, + "trm_confidence": 0.8720577585125344, + "mcts_value": 0.1291584908205778, + "consensus_score": 0.5777778604336374, + "last_agent": "none", + "iteration": 2, + "query_length": 2380, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3700644311505463, + "trm_confidence": 0.9387385014415938, + "mcts_value": 0.17713630425104762, + "consensus_score": 0.43310515579166653, + "last_agent": "none", + "iteration": 9, + "query_length": 4026, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.44960067505898793, + "trm_confidence": 0.7270410946114799, + "mcts_value": 0.15686753079097213, + "consensus_score": 
0.5197820503039892, + "last_agent": "mcts", + "iteration": 3, + "query_length": 2904, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3199410028080757, + "trm_confidence": 0.7254575330664803, + "mcts_value": 0.28191721973716205, + "consensus_score": 0.4927010141245345, + "last_agent": "trm", + "iteration": 5, + "query_length": 1870, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1887215862478396, + "trm_confidence": 0.7784964746898831, + "mcts_value": 0.6295415304385765, + "consensus_score": 0.46526891754299554, + "last_agent": "none", + "iteration": 10, + "query_length": 2389, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.15580876886389092, + "trm_confidence": 0.7265170706252492, + "mcts_value": 0.03659760596095116, + "consensus_score": 0.3643137689828136, + "last_agent": "mcts", + "iteration": 5, + "query_length": 1120, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5036437005204956, + "trm_confidence": 0.8986148813894688, + "mcts_value": 0.7505958370597642, + "consensus_score": 0.8098702430062522, + "last_agent": "hrm", + "iteration": 7, + "query_length": 3312, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.41157849497836363, + "trm_confidence": 0.9377275301005535, + "mcts_value": 0.5901441016930987, + "consensus_score": 0.5621351771273859, + "last_agent": "mcts", + "iteration": 8, + "query_length": 531, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5229840904691805, + "trm_confidence": 0.7254924259377215, + "mcts_value": 0.5835719402202927, + "consensus_score": 0.5559205968361439, + "last_agent": "hrm", + "iteration": 5, + "query_length": 3444, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.1135682279040384, + "trm_confidence": 0.9880678013568904, + "mcts_value": 0.3099394152532933, + "consensus_score": 0.5101540764015787, + "last_agent": "none", + "iteration": 2, + "query_length": 3436, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.0931857133634963, + "trm_confidence": 0.7069873180014079, + "mcts_value": 0.34132687164971404, + "consensus_score": 0.41111798093237595, + "last_agent": "trm", + "iteration": 7, + "query_length": 2463, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.19816711313021182, + "trm_confidence": 0.8245844241216375, + "mcts_value": 0.01128953256772494, + "consensus_score": 0.28600632181700203, + "last_agent": "none", + "iteration": 3, + "query_length": 2538, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.09542007485860418, + "trm_confidence": 0.7060433414489806, + "mcts_value": 0.4036851991817658, + "consensus_score": 0.42763328845763604, + "last_agent": "trm", + "iteration": 10, + "query_length": 1285, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.10596334580580775, + "trm_confidence": 0.8703268350523573, + "mcts_value": 0.4345196590148425, + "consensus_score": 0.38258895880717203, + "last_agent": "trm", + "iteration": 4, + "query_length": 1557, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2038095686689596, + "trm_confidence": 0.8283441428088825, + "mcts_value": 0.49952058948680605, + 
"consensus_score": 0.5787469937813052, + "last_agent": "trm", + "iteration": 8, + "query_length": 2092, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.023670709999474536, + "trm_confidence": 0.861503439799904, + "mcts_value": 0.5505167267521973, + "consensus_score": 0.48756021343289146, + "last_agent": "hrm", + "iteration": 2, + "query_length": 656, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5498646131944848, + "trm_confidence": 0.9806820359409224, + "mcts_value": 0.41912753188379326, + "consensus_score": 0.6740757416147484, + "last_agent": "mcts", + "iteration": 4, + "query_length": 2834, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.30222002578153945, + "trm_confidence": 0.7870927770123011, + "mcts_value": 0.4710654708838269, + "consensus_score": 0.428114476243183, + "last_agent": "mcts", + "iteration": 8, + "query_length": 2188, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.517244726214594, + "trm_confidence": 0.9047679380297499, + "mcts_value": 0.5252141294078266, + "consensus_score": 0.6707013811352175, + "last_agent": "none", + "iteration": 10, + "query_length": 4448, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5782538078279272, + "trm_confidence": 0.9196957346546054, + "mcts_value": 0.19366752533222478, + "consensus_score": 0.5205335237756176, + "last_agent": "trm", + "iteration": 4, + "query_length": 2537, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.27759100049047153, + "trm_confidence": 0.8100229296789879, + "mcts_value": 0.42875768468675873, + "consensus_score": 0.449435213499085, + "last_agent": "hrm", + "iteration": 8, + "query_length": 731, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.06612097058965101, + "trm_confidence": 0.9118732335049741, + "mcts_value": 0.3815523241631893, + "consensus_score": 0.35882608252140374, + "last_agent": "mcts", + "iteration": 7, + "query_length": 50, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2867772756099233, + "trm_confidence": 0.9847758442317486, + "mcts_value": 0.4988418394014734, + "consensus_score": 0.5416091682199714, + "last_agent": "none", + "iteration": 6, + "query_length": 3350, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.14486151824914825, + "trm_confidence": 0.8111884784516898, + "mcts_value": 0.07931269493394683, + "consensus_score": 0.2690967240029296, + "last_agent": "trm", + "iteration": 10, + "query_length": 3213, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5723089366326094, + "trm_confidence": 0.7005789123620018, + "mcts_value": 0.011999521163016088, + "consensus_score": 0.4120130151634491, + "last_agent": "hrm", + "iteration": 7, + "query_length": 2620, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.37732918644693036, + "trm_confidence": 0.9138658874296303, + "mcts_value": 0.7930651348065357, + "consensus_score": 0.6026172795010194, + "last_agent": "trm", + "iteration": 10, + "query_length": 441, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.22386318014259662, + "trm_confidence": 0.9786790051379077, + "mcts_value": 
0.1251749172286063, + "consensus_score": 0.47523542337757685, + "last_agent": "hrm", + "iteration": 5, + "query_length": 3564, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5440501535428177, + "trm_confidence": 0.9773857535449653, + "mcts_value": 0.2729229773638274, + "consensus_score": 0.578132939434112, + "last_agent": "mcts", + "iteration": 1, + "query_length": 2733, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.5210214009226612, + "trm_confidence": 0.7430050210628615, + "mcts_value": 0.4532371791200645, + "consensus_score": 0.5784902038906994, + "last_agent": "none", + "iteration": 4, + "query_length": 2534, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.7436476874790247, + "trm_confidence": 0.9438716043087283, + "mcts_value": 0.46161146207938814, + "consensus_score": 0.7803539423196538, + "last_agent": "none", + "iteration": 5, + "query_length": 159, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3321384555464614, + "trm_confidence": 0.9596295566866301, + "mcts_value": 0.34033824077037494, + "consensus_score": 0.5050176867248054, + "last_agent": "hrm", + "iteration": 5, + "query_length": 4605, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.34360040070387166, + "trm_confidence": 0.898138343593137, + "mcts_value": 0.1738642953167907, + "consensus_score": 0.5452360923540416, + "last_agent": "none", + "iteration": 3, + "query_length": 1367, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3708734210569285, + "trm_confidence": 0.7622059999079367, + "mcts_value": 0.35534973457636193, + "consensus_score": 0.5618726214324747, + "last_agent": "none", + "iteration": 1, + "query_length": 485, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4451081419245728, + "trm_confidence": 0.8594284516516211, + "mcts_value": 0.38743479820397003, + "consensus_score": 0.47119544529397467, + "last_agent": "trm", + "iteration": 5, + "query_length": 356, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.4561927932840184, + "trm_confidence": 0.8028064541133494, + "mcts_value": 0.34978475478808896, + "consensus_score": 0.45490497659531093, + "last_agent": "hrm", + "iteration": 0, + "query_length": 4638, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.399815127929161, + "trm_confidence": 0.7572655439716671, + "mcts_value": 0.11687132043692414, + "consensus_score": 0.5009945487675382, + "last_agent": "hrm", + "iteration": 2, + "query_length": 3720, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.2547861301271446, + "trm_confidence": 0.9627238049078941, + "mcts_value": 0.3670222299040831, + "consensus_score": 0.6199826651155291, + "last_agent": "trm", + "iteration": 10, + "query_length": 3757, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.3682753491268558, + "trm_confidence": 0.7206816855197664, + "mcts_value": 0.6139025598324387, + "consensus_score": 0.6570031105111498, + "last_agent": "trm", + "iteration": 1, + "query_length": 301, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.09011705955441619, + "trm_confidence": 0.7946610237708518, + 
"mcts_value": 0.4826493062063232, + "consensus_score": 0.4105788027159111, + "last_agent": "mcts", + "iteration": 3, + "query_length": 1686, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.11852945767324931, + "trm_confidence": 0.8493902676335996, + "mcts_value": 0.3238043581681488, + "consensus_score": 0.3592431389123262, + "last_agent": "trm", + "iteration": 10, + "query_length": 2869, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.14902826731190888, + "trm_confidence": 0.8565661461551899, + "mcts_value": 0.057772755366731966, + "consensus_score": 0.29491744671302983, + "last_agent": "mcts", + "iteration": 0, + "query_length": 3267, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.48758750266237133, + "trm_confidence": 0.848045247328965, + "mcts_value": 0.06533413356703832, + "consensus_score": 0.4551926859669786, + "last_agent": "none", + "iteration": 6, + "query_length": 1058, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.0511193831067477, + "trm_confidence": 0.8116235234548076, + "mcts_value": 0.15665045493738797, + "consensus_score": 0.3331483723501024, + "last_agent": "none", + "iteration": 10, + "query_length": 2360, + "has_rag_context": false + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.37819121088269464, + "trm_confidence": 0.9938865043372864, + "mcts_value": 0.07211910142969438, + "consensus_score": 0.4461136263330919, + "last_agent": "none", + "iteration": 2, + "query_length": 1219, + "has_rag_context": true + }, + "label": "trm" + }, + { + "features": { + "hrm_confidence": 0.31723435453434623, + "trm_confidence": 0.23776372466300788, + "mcts_value": 0.6380934441153502, + "consensus_score": 0.33376791876446466, + "last_agent": "trm", + "iteration": 9, + "query_length": 1345, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1686840121841821, + "trm_confidence": 0.3768754430606373, + "mcts_value": 0.9407471422986178, + "consensus_score": 0.4322633759443156, + "last_agent": "none", + "iteration": 7, + "query_length": 3990, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.42855636392941093, + "trm_confidence": 0.1846502295230367, + "mcts_value": 0.6646314497486088, + "consensus_score": 0.36041945971320166, + "last_agent": "none", + "iteration": 10, + "query_length": 2756, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3837846892832857, + "trm_confidence": 0.17542621208214185, + "mcts_value": 0.8260592450783548, + "consensus_score": 0.4523110645038243, + "last_agent": "mcts", + "iteration": 4, + "query_length": 1628, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13536975187858913, + "trm_confidence": 0.22041493479425212, + "mcts_value": 0.6886788593507293, + "consensus_score": 0.4286284993938186, + "last_agent": "trm", + "iteration": 9, + "query_length": 1277, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4640450577224613, + "trm_confidence": 0.3360216995983519, + "mcts_value": 0.8607849384132342, + "consensus_score": 0.6346544965847443, + "last_agent": "hrm", + "iteration": 7, + "query_length": 908, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3822976453806744, + "trm_confidence": 
0.08681387071692682, + "mcts_value": 0.856150003967066, + "consensus_score": 0.42029892280404163, + "last_agent": "trm", + "iteration": 6, + "query_length": 4107, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5750886628472248, + "trm_confidence": 0.30967271573066185, + "mcts_value": 0.7515370227892435, + "consensus_score": 0.5961329131448321, + "last_agent": "trm", + "iteration": 4, + "query_length": 3365, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.18185877193841446, + "trm_confidence": 0.1690427244553648, + "mcts_value": 0.7922239563195673, + "consensus_score": 0.33405727528610984, + "last_agent": "hrm", + "iteration": 6, + "query_length": 424, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4683262010865425, + "trm_confidence": 0.3830163896608705, + "mcts_value": 0.9798412171603175, + "consensus_score": 0.6001141642393371, + "last_agent": "trm", + "iteration": 9, + "query_length": 1590, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6673941740574011, + "trm_confidence": 0.27658075424446743, + "mcts_value": 0.7172150697731274, + "consensus_score": 0.6137456131966376, + "last_agent": "hrm", + "iteration": 10, + "query_length": 3704, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.41029586862175155, + "trm_confidence": 0.08314398616690488, + "mcts_value": 0.6646202905804774, + "consensus_score": 0.44384845643016324, + "last_agent": "trm", + "iteration": 5, + "query_length": 3689, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.12670689546912067, + "trm_confidence": 0.2750004870208167, + "mcts_value": 0.9640358696520795, + "consensus_score": 0.48467150832871886, + "last_agent": "trm", + "iteration": 6, + "query_length": 4998, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.284618700591374, + "trm_confidence": 0.45444631860734425, + "mcts_value": 0.6116286456675187, + "consensus_score": 0.4650280632525767, + "last_agent": "none", + "iteration": 5, + "query_length": 4661, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1919416816053799, + "trm_confidence": 0.5104898707022725, + "mcts_value": 0.7477884184725011, + "consensus_score": 0.4453370629321546, + "last_agent": "hrm", + "iteration": 7, + "query_length": 4404, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4319160371123925, + "trm_confidence": 0.05127219031728598, + "mcts_value": 0.815710392901433, + "consensus_score": 0.36859303552538447, + "last_agent": "hrm", + "iteration": 4, + "query_length": 3731, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13051870523620435, + "trm_confidence": 0.1534902256461637, + "mcts_value": 0.934368432366737, + "consensus_score": 0.37076724564825825, + "last_agent": "none", + "iteration": 8, + "query_length": 750, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6336349293494238, + "trm_confidence": 0.20652809834757668, + "mcts_value": 0.874701583416583, + "consensus_score": 0.6033733348792163, + "last_agent": "hrm", + "iteration": 7, + "query_length": 2360, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.06683838837145224, 
+ "trm_confidence": 0.49840071878243597, + "mcts_value": 0.9629660746374693, + "consensus_score": 0.5549764698683062, + "last_agent": "trm", + "iteration": 10, + "query_length": 2452, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5794184689020143, + "trm_confidence": 0.5395671972826748, + "mcts_value": 0.7286489352876984, + "consensus_score": 0.6159413233746899, + "last_agent": "hrm", + "iteration": 9, + "query_length": 1597, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6056284183515128, + "trm_confidence": 0.22258915323155873, + "mcts_value": 0.8358025321613871, + "consensus_score": 0.5513833644009682, + "last_agent": "trm", + "iteration": 10, + "query_length": 2747, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.26967294262331243, + "trm_confidence": 0.6975536817441912, + "mcts_value": 0.9548156081322211, + "consensus_score": 0.5711938651033375, + "last_agent": "none", + "iteration": 7, + "query_length": 4031, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6360324407605084, + "trm_confidence": 0.5117975638356418, + "mcts_value": 0.9193728763200538, + "consensus_score": 0.7144308741981286, + "last_agent": "none", + "iteration": 6, + "query_length": 716, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.014257931579064364, + "trm_confidence": 0.35213206337937436, + "mcts_value": 0.8037201252899288, + "consensus_score": 0.4612308667842808, + "last_agent": "mcts", + "iteration": 8, + "query_length": 869, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3702812658390706, + "trm_confidence": 0.02343849829797033, + "mcts_value": 0.7561939959074084, + "consensus_score": 0.3599041751153518, + "last_agent": "trm", + "iteration": 7, + "query_length": 2137, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4832856761452664, + "trm_confidence": 0.30456142559419025, + "mcts_value": 0.6124963915962599, + "consensus_score": 0.4499168030676857, + "last_agent": "mcts", + "iteration": 5, + "query_length": 2461, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3317891311203637, + "trm_confidence": 0.4066099439054674, + "mcts_value": 0.7752107125082925, + "consensus_score": 0.40580152529265495, + "last_agent": "hrm", + "iteration": 10, + "query_length": 194, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.14516101266756698, + "trm_confidence": 0.3666476585589127, + "mcts_value": 0.7803698989263979, + "consensus_score": 0.38677221181038673, + "last_agent": "trm", + "iteration": 10, + "query_length": 1590, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.17421580970351525, + "trm_confidence": 0.013584305345812252, + "mcts_value": 0.8555022223073749, + "consensus_score": 0.3285983529082227, + "last_agent": "hrm", + "iteration": 7, + "query_length": 3200, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4947874216922198, + "trm_confidence": 0.6969688604377766, + "mcts_value": 0.984805745527544, + "consensus_score": 0.6845326939486411, + "last_agent": "trm", + "iteration": 9, + "query_length": 651, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 
0.6900972403096294, + "trm_confidence": 0.45779103982578184, + "mcts_value": 0.8447041417793998, + "consensus_score": 0.7439608524389874, + "last_agent": "none", + "iteration": 9, + "query_length": 3863, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.33902091207983737, + "trm_confidence": 0.6431585950876396, + "mcts_value": 0.879769484128925, + "consensus_score": 0.6820165532035585, + "last_agent": "trm", + "iteration": 4, + "query_length": 4575, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.34527141232123215, + "trm_confidence": 0.05654023144514711, + "mcts_value": 0.6278444007194423, + "consensus_score": 0.3879543488102126, + "last_agent": "mcts", + "iteration": 7, + "query_length": 3040, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3633985489127985, + "trm_confidence": 0.08223578013738105, + "mcts_value": 0.7119183154685265, + "consensus_score": 0.4627728820362914, + "last_agent": "trm", + "iteration": 5, + "query_length": 4476, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.564362042106087, + "trm_confidence": 0.07407016937927796, + "mcts_value": 0.6820202394092734, + "consensus_score": 0.4630363159325126, + "last_agent": "none", + "iteration": 9, + "query_length": 4580, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.043136798864328404, + "trm_confidence": 0.3141644686405523, + "mcts_value": 0.6760243546432525, + "consensus_score": 0.284938943098287, + "last_agent": "mcts", + "iteration": 8, + "query_length": 1201, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.07943945317119827, + "trm_confidence": 0.3972852005837342, + "mcts_value": 0.7797415764132528, + "consensus_score": 0.35592691464910986, + "last_agent": "none", + "iteration": 8, + "query_length": 1697, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.47916888501006816, + "trm_confidence": 0.12721069746834832, + "mcts_value": 0.9903578369125293, + "consensus_score": 0.43506473555601977, + "last_agent": "trm", + "iteration": 6, + "query_length": 2501, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6109251809065226, + "trm_confidence": 0.42224326203559465, + "mcts_value": 0.6873645332478184, + "consensus_score": 0.5088025226690918, + "last_agent": "mcts", + "iteration": 7, + "query_length": 2055, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.02293006118458668, + "trm_confidence": 0.5126228955174705, + "mcts_value": 0.9904803457432356, + "consensus_score": 0.5632771727200497, + "last_agent": "trm", + "iteration": 4, + "query_length": 2794, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6077083785703474, + "trm_confidence": 0.3758999568929495, + "mcts_value": 0.753722663641583, + "consensus_score": 0.610171703036418, + "last_agent": "trm", + "iteration": 9, + "query_length": 4911, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3561551971414369, + "trm_confidence": 0.5026040999614456, + "mcts_value": 0.7905138172334725, + "consensus_score": 0.5024261765259104, + "last_agent": "hrm", + "iteration": 9, + "query_length": 1564, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + 
"hrm_confidence": 0.5601399601255522, + "trm_confidence": 0.23818074832487054, + "mcts_value": 0.6783086408838899, + "consensus_score": 0.4221763496498987, + "last_agent": "none", + "iteration": 5, + "query_length": 2296, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.049455501039777416, + "trm_confidence": 0.02133663638725458, + "mcts_value": 0.9384726117379951, + "consensus_score": 0.32945521022237767, + "last_agent": "trm", + "iteration": 10, + "query_length": 4636, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.621376823613001, + "trm_confidence": 0.19819455347125778, + "mcts_value": 0.9836682156268193, + "consensus_score": 0.6644694770882553, + "last_agent": "mcts", + "iteration": 8, + "query_length": 3536, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5708519291478197, + "trm_confidence": 0.5551556028865148, + "mcts_value": 0.7691664572171499, + "consensus_score": 0.68311193027234, + "last_agent": "trm", + "iteration": 10, + "query_length": 2364, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6037348568884349, + "trm_confidence": 0.14812180701757874, + "mcts_value": 0.9227691911279663, + "consensus_score": 0.46662190675527215, + "last_agent": "trm", + "iteration": 6, + "query_length": 4698, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4900071078471968, + "trm_confidence": 0.005143697344983877, + "mcts_value": 0.715641312477804, + "consensus_score": 0.3746702590207091, + "last_agent": "none", + "iteration": 8, + "query_length": 3910, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.10939178398195547, + "trm_confidence": 0.6368554815027963, + "mcts_value": 0.9586427648589736, + "consensus_score": 0.5115745385178366, + "last_agent": "trm", + "iteration": 10, + "query_length": 4515, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.04569406413506189, + "trm_confidence": 0.05244518738162813, + "mcts_value": 0.6337379577753245, + "consensus_score": 0.2619376192221952, + "last_agent": "trm", + "iteration": 9, + "query_length": 2060, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4125013925652396, + "trm_confidence": 0.373128423350449, + "mcts_value": 0.7779343037127447, + "consensus_score": 0.5588003864963985, + "last_agent": "mcts", + "iteration": 7, + "query_length": 1259, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.24520848957413544, + "trm_confidence": 0.008550824040745563, + "mcts_value": 0.9989242970186536, + "consensus_score": 0.3426862126548425, + "last_agent": "hrm", + "iteration": 7, + "query_length": 906, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.43921226307516803, + "trm_confidence": 0.36424863017582754, + "mcts_value": 0.975993037481997, + "consensus_score": 0.6527592104987346, + "last_agent": "none", + "iteration": 5, + "query_length": 3920, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.43377415349500237, + "trm_confidence": 0.6543300057320453, + "mcts_value": 0.6012244419456372, + "consensus_score": 0.6206815241310022, + "last_agent": "none", + "iteration": 9, + "query_length": 3779, + "has_rag_context": true + }, + "label": "mcts" + 
}, + { + "features": { + "hrm_confidence": 0.38558216552118657, + "trm_confidence": 0.4329781655522392, + "mcts_value": 0.9296955784400505, + "consensus_score": 0.4945350377049861, + "last_agent": "mcts", + "iteration": 5, + "query_length": 4454, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5190127646113992, + "trm_confidence": 0.5880399059653311, + "mcts_value": 0.623520423129399, + "consensus_score": 0.6609301701610041, + "last_agent": "mcts", + "iteration": 4, + "query_length": 3238, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3318393511494707, + "trm_confidence": 0.5881757144394681, + "mcts_value": 0.6270480030517179, + "consensus_score": 0.4297499172597767, + "last_agent": "hrm", + "iteration": 4, + "query_length": 4775, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.568117721217664, + "trm_confidence": 0.08919866975081975, + "mcts_value": 0.8045080957744134, + "consensus_score": 0.41219687973020713, + "last_agent": "trm", + "iteration": 7, + "query_length": 428, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6647447717237248, + "trm_confidence": 0.5362233846202891, + "mcts_value": 0.8920854940579495, + "consensus_score": 0.773119623409771, + "last_agent": "trm", + "iteration": 7, + "query_length": 2224, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.047843523276939165, + "trm_confidence": 0.4591671770872801, + "mcts_value": 0.6731612735946729, + "consensus_score": 0.30378115873790856, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1205, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.21972065941926494, + "trm_confidence": 0.3403108164411346, + "mcts_value": 0.9808802901635648, + "consensus_score": 0.5619753283503452, + "last_agent": "trm", + "iteration": 5, + "query_length": 3782, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6851660581129289, + "trm_confidence": 0.5212923059696694, + "mcts_value": 0.8637869100021912, + "consensus_score": 0.6137079280840779, + "last_agent": "hrm", + "iteration": 7, + "query_length": 3905, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3023762398867548, + "trm_confidence": 0.19880407143507609, + "mcts_value": 0.6586492571457411, + "consensus_score": 0.3103753646082625, + "last_agent": "hrm", + "iteration": 7, + "query_length": 427, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1187878525693058, + "trm_confidence": 0.5065164465219747, + "mcts_value": 0.8362149634342894, + "consensus_score": 0.4262214121144573, + "last_agent": "trm", + "iteration": 10, + "query_length": 1139, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2240733361205828, + "trm_confidence": 0.15755356804221354, + "mcts_value": 0.8087989548182943, + "consensus_score": 0.3911291754933931, + "last_agent": "none", + "iteration": 9, + "query_length": 2561, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.10797053546654987, + "trm_confidence": 0.5040983132964151, + "mcts_value": 0.7212994472591383, + "consensus_score": 0.46085584376523214, + "last_agent": "trm", + "iteration": 5, + "query_length": 619, + "has_rag_context": false + }, + 
"label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.45737663965518527, + "trm_confidence": 0.2515373237743721, + "mcts_value": 0.8075721740968564, + "consensus_score": 0.4933148914359465, + "last_agent": "trm", + "iteration": 5, + "query_length": 2215, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5448118163257429, + "trm_confidence": 0.5586054984192101, + "mcts_value": 0.6628045296607167, + "consensus_score": 0.6380608076129142, + "last_agent": "trm", + "iteration": 4, + "query_length": 2990, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.41688106510175704, + "trm_confidence": 0.6434554695583351, + "mcts_value": 0.7873350971167254, + "consensus_score": 0.5910117383513462, + "last_agent": "hrm", + "iteration": 10, + "query_length": 3784, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.0365974075749261, + "trm_confidence": 0.5197425037389355, + "mcts_value": 0.7485121950034508, + "consensus_score": 0.3650812535923534, + "last_agent": "trm", + "iteration": 5, + "query_length": 4709, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.009780903592923794, + "trm_confidence": 0.0015646433231779942, + "mcts_value": 0.813886686969084, + "consensus_score": 0.27720953764304135, + "last_agent": "hrm", + "iteration": 6, + "query_length": 3222, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.15990553020219894, + "trm_confidence": 0.41454500477946354, + "mcts_value": 0.8454375750864443, + "consensus_score": 0.44914922680662356, + "last_agent": "trm", + "iteration": 5, + "query_length": 2849, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.020882475609089554, + "trm_confidence": 0.2156888734253879, + "mcts_value": 0.6947442412509153, + "consensus_score": 0.3908294681314898, + "last_agent": "mcts", + "iteration": 9, + "query_length": 509, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3971523913098409, + "trm_confidence": 0.5801162774355986, + "mcts_value": 0.8303334474327215, + "consensus_score": 0.5497920590294468, + "last_agent": "hrm", + "iteration": 6, + "query_length": 1739, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2202815043024345, + "trm_confidence": 0.6145396436116048, + "mcts_value": 0.6163479959086697, + "consensus_score": 0.5661205774551992, + "last_agent": "trm", + "iteration": 4, + "query_length": 4524, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3425514376363715, + "trm_confidence": 0.38830366684107404, + "mcts_value": 0.7525395606350557, + "consensus_score": 0.39851664037438295, + "last_agent": "none", + "iteration": 9, + "query_length": 2853, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4623165669453584, + "trm_confidence": 0.36958590048149587, + "mcts_value": 0.8613619248014569, + "consensus_score": 0.53127614641898, + "last_agent": "trm", + "iteration": 5, + "query_length": 3363, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5771517973843117, + "trm_confidence": 0.2946381896294308, + "mcts_value": 0.6179765614036828, + "consensus_score": 0.47953969430487003, + "last_agent": "hrm", + "iteration": 6, + "query_length": 1062, + 
"has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.10591547503264781, + "trm_confidence": 0.1158720510815211, + "mcts_value": 0.6699667811082876, + "consensus_score": 0.3289569323109717, + "last_agent": "mcts", + "iteration": 10, + "query_length": 4479, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1950315914457207, + "trm_confidence": 0.4535299440877101, + "mcts_value": 0.9622469588269951, + "consensus_score": 0.44503595517348216, + "last_agent": "mcts", + "iteration": 7, + "query_length": 569, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.16648476941725376, + "trm_confidence": 0.5799478637875408, + "mcts_value": 0.8145279108890514, + "consensus_score": 0.6143474836122634, + "last_agent": "hrm", + "iteration": 7, + "query_length": 4423, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1690423424352558, + "trm_confidence": 0.6051152673501498, + "mcts_value": 0.916165307338229, + "consensus_score": 0.5194811534365443, + "last_agent": "hrm", + "iteration": 9, + "query_length": 899, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3735339537579074, + "trm_confidence": 0.36744546699302055, + "mcts_value": 0.7153546813656679, + "consensus_score": 0.5555309376028053, + "last_agent": "none", + "iteration": 10, + "query_length": 3732, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.25031874715615304, + "trm_confidence": 0.33524288546700354, + "mcts_value": 0.7085552531300148, + "consensus_score": 0.4301542438819369, + "last_agent": "none", + "iteration": 9, + "query_length": 462, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3185152324082315, + "trm_confidence": 0.43265168677086213, + "mcts_value": 0.6075406316151813, + "consensus_score": 0.4449816343500297, + "last_agent": "none", + "iteration": 4, + "query_length": 2300, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13140114571483114, + "trm_confidence": 0.5294149801296674, + "mcts_value": 0.7646622780518945, + "consensus_score": 0.43943960399395965, + "last_agent": "mcts", + "iteration": 8, + "query_length": 3893, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.08541045230398672, + "trm_confidence": 0.0799841057774221, + "mcts_value": 0.9509708316865555, + "consensus_score": 0.33218573041392463, + "last_agent": "trm", + "iteration": 7, + "query_length": 4325, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.17168609017650976, + "trm_confidence": 0.059342319940571306, + "mcts_value": 0.9052748329580975, + "consensus_score": 0.3615277096685557, + "last_agent": "mcts", + "iteration": 9, + "query_length": 4637, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3905206749498208, + "trm_confidence": 0.5069341912244236, + "mcts_value": 0.8933379393453402, + "consensus_score": 0.6332088422339671, + "last_agent": "mcts", + "iteration": 5, + "query_length": 4970, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.22198446331259672, + "trm_confidence": 0.45778120270185196, + "mcts_value": 0.9082437330553728, + "consensus_score": 0.5256325518603668, + "last_agent": "trm", + "iteration": 7, 
+ "query_length": 581, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.605980924748164, + "trm_confidence": 0.09900190161810396, + "mcts_value": 0.8664868661221519, + "consensus_score": 0.49704436327276325, + "last_agent": "mcts", + "iteration": 4, + "query_length": 945, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.34543005581446923, + "trm_confidence": 0.24886331304467244, + "mcts_value": 0.9958633986686976, + "consensus_score": 0.48502767116761686, + "last_agent": "trm", + "iteration": 10, + "query_length": 1037, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.19863066706415367, + "trm_confidence": 0.049342653749843375, + "mcts_value": 0.8902833582950973, + "consensus_score": 0.4160742159399531, + "last_agent": "none", + "iteration": 4, + "query_length": 2372, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5222729624500508, + "trm_confidence": 0.5388251171454546, + "mcts_value": 0.7001506395959293, + "consensus_score": 0.6319043541606525, + "last_agent": "hrm", + "iteration": 4, + "query_length": 1844, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.16983775796793601, + "trm_confidence": 0.19196700055773697, + "mcts_value": 0.9960582456223053, + "consensus_score": 0.5153263112987674, + "last_agent": "none", + "iteration": 5, + "query_length": 2822, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.026012463273550022, + "trm_confidence": 0.3876334659856583, + "mcts_value": 0.9020105510002314, + "consensus_score": 0.4537276117936565, + "last_agent": "trm", + "iteration": 6, + "query_length": 440, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.25493629665410233, + "trm_confidence": 0.5940680810031305, + "mcts_value": 0.6650998877199426, + "consensus_score": 0.44282997311359035, + "last_agent": "hrm", + "iteration": 9, + "query_length": 4586, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.05273146371199239, + "trm_confidence": 0.47453250954306, + "mcts_value": 0.795691626875996, + "consensus_score": 0.4991054978548028, + "last_agent": "mcts", + "iteration": 9, + "query_length": 3727, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5383433396880142, + "trm_confidence": 0.35403079482493055, + "mcts_value": 0.8688062889198347, + "consensus_score": 0.6696075213804851, + "last_agent": "none", + "iteration": 8, + "query_length": 929, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.347716446502645, + "trm_confidence": 0.5108732447446628, + "mcts_value": 0.6941808068055524, + "consensus_score": 0.4581345169301566, + "last_agent": "hrm", + "iteration": 5, + "query_length": 670, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.12727081704664014, + "trm_confidence": 0.23202471738705419, + "mcts_value": 0.7738441442948825, + "consensus_score": 0.35718845589604403, + "last_agent": "hrm", + "iteration": 8, + "query_length": 1700, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4561891312377787, + "trm_confidence": 0.4899456525584073, + "mcts_value": 0.9946360303688364, + "consensus_score": 0.611289438660992, + "last_agent": 
"mcts", + "iteration": 10, + "query_length": 4474, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6812791930683556, + "trm_confidence": 0.6048970737743725, + "mcts_value": 0.642552259724387, + "consensus_score": 0.6036385357633317, + "last_agent": "none", + "iteration": 8, + "query_length": 3610, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2007419475322064, + "trm_confidence": 0.09521921165463454, + "mcts_value": 0.7292512824581884, + "consensus_score": 0.35160109700351094, + "last_agent": "trm", + "iteration": 7, + "query_length": 3222, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6570274695412733, + "trm_confidence": 0.447013837431379, + "mcts_value": 0.6861428669917617, + "consensus_score": 0.6766481894648053, + "last_agent": "trm", + "iteration": 4, + "query_length": 3175, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.23164454587012664, + "trm_confidence": 0.13660021839519046, + "mcts_value": 0.9397847274576645, + "consensus_score": 0.3890140917354511, + "last_agent": "hrm", + "iteration": 5, + "query_length": 766, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1898918508955939, + "trm_confidence": 0.4539586088835965, + "mcts_value": 0.7303854411117326, + "consensus_score": 0.36868510796468457, + "last_agent": "mcts", + "iteration": 6, + "query_length": 339, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.15464950354903867, + "trm_confidence": 0.14046663778904223, + "mcts_value": 0.6703398997124024, + "consensus_score": 0.22200390169644207, + "last_agent": "none", + "iteration": 7, + "query_length": 2176, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6963623972143801, + "trm_confidence": 0.6428773479732098, + "mcts_value": 0.70917592470996, + "consensus_score": 0.6258203849304502, + "last_agent": "mcts", + "iteration": 6, + "query_length": 4246, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5618521589751637, + "trm_confidence": 0.08836900559639752, + "mcts_value": 0.9377654186636974, + "consensus_score": 0.45421863074809, + "last_agent": "trm", + "iteration": 6, + "query_length": 3960, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.07515609478057111, + "trm_confidence": 0.40200810157588607, + "mcts_value": 0.9528339996382009, + "consensus_score": 0.4078011047565753, + "last_agent": "hrm", + "iteration": 8, + "query_length": 4111, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.05815794228443604, + "trm_confidence": 0.40547793712618185, + "mcts_value": 0.7819077169902374, + "consensus_score": 0.39688453282745606, + "last_agent": "trm", + "iteration": 9, + "query_length": 4693, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1915902853443494, + "trm_confidence": 0.11026234512208015, + "mcts_value": 0.9499588801890357, + "consensus_score": 0.37856657901553775, + "last_agent": "mcts", + "iteration": 4, + "query_length": 4286, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.09959808498961463, + "trm_confidence": 0.07795548095061157, + "mcts_value": 0.8640470535528122, + "consensus_score": 
0.2736772647485496, + "last_agent": "trm", + "iteration": 4, + "query_length": 3947, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.03259280621015039, + "trm_confidence": 0.5320502470804781, + "mcts_value": 0.9776460461986113, + "consensus_score": 0.5615267158398527, + "last_agent": "mcts", + "iteration": 10, + "query_length": 1656, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.31715748456687903, + "trm_confidence": 0.5996026240044333, + "mcts_value": 0.7142035966761665, + "consensus_score": 0.5893732969234061, + "last_agent": "none", + "iteration": 7, + "query_length": 3651, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5981688076295373, + "trm_confidence": 0.2799658617472436, + "mcts_value": 0.8908062965073484, + "consensus_score": 0.62336407779224, + "last_agent": "hrm", + "iteration": 6, + "query_length": 849, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5430080698155945, + "trm_confidence": 0.6842541888902517, + "mcts_value": 0.8099641519819655, + "consensus_score": 0.6679487027569171, + "last_agent": "mcts", + "iteration": 6, + "query_length": 4395, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2078571164724059, + "trm_confidence": 0.060513189288164355, + "mcts_value": 0.6535556277500726, + "consensus_score": 0.29288749713939705, + "last_agent": "hrm", + "iteration": 5, + "query_length": 378, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1721493554123952, + "trm_confidence": 0.42832623493944827, + "mcts_value": 0.9174310997566425, + "consensus_score": 0.5862051805525512, + "last_agent": "hrm", + "iteration": 7, + "query_length": 328, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.41325844502884623, + "trm_confidence": 0.3546277978318974, + "mcts_value": 0.7379064775863817, + "consensus_score": 0.43165487407108916, + "last_agent": "hrm", + "iteration": 7, + "query_length": 939, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4287118607378963, + "trm_confidence": 0.04911206963964689, + "mcts_value": 0.7069330857830289, + "consensus_score": 0.3622404120999607, + "last_agent": "trm", + "iteration": 9, + "query_length": 1259, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2870207995972183, + "trm_confidence": 0.22152209969067274, + "mcts_value": 0.8420707400012313, + "consensus_score": 0.5334308840117128, + "last_agent": "none", + "iteration": 6, + "query_length": 3194, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5221610491550703, + "trm_confidence": 0.4772597995887381, + "mcts_value": 0.9637205262561082, + "consensus_score": 0.5642874467466807, + "last_agent": "hrm", + "iteration": 4, + "query_length": 1468, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.15794658994125135, + "trm_confidence": 0.5453969607405275, + "mcts_value": 0.683773134610504, + "consensus_score": 0.41955721819814945, + "last_agent": "trm", + "iteration": 4, + "query_length": 1135, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.43199000437925217, + "trm_confidence": 0.608219663374262, + "mcts_value": 0.7772711948981151, + 
"consensus_score": 0.5279878332237377, + "last_agent": "mcts", + "iteration": 5, + "query_length": 1249, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3377762158718462, + "trm_confidence": 0.6468872542889508, + "mcts_value": 0.7456433448967901, + "consensus_score": 0.5857472691254665, + "last_agent": "hrm", + "iteration": 6, + "query_length": 3167, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6524794138598203, + "trm_confidence": 0.3511632894458381, + "mcts_value": 0.6260862679598145, + "consensus_score": 0.5439942579752518, + "last_agent": "hrm", + "iteration": 10, + "query_length": 4146, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.19074201448203928, + "trm_confidence": 0.4611717367721946, + "mcts_value": 0.7753628362120595, + "consensus_score": 0.4896878806847168, + "last_agent": "trm", + "iteration": 5, + "query_length": 2747, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.27813371894457856, + "trm_confidence": 0.0030806453324853985, + "mcts_value": 0.8576347429346562, + "consensus_score": 0.36893259910581094, + "last_agent": "mcts", + "iteration": 4, + "query_length": 3946, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.057343918929729384, + "trm_confidence": 0.19296580365101634, + "mcts_value": 0.7722439609923336, + "consensus_score": 0.3719495015694868, + "last_agent": "trm", + "iteration": 6, + "query_length": 3265, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4097169332098666, + "trm_confidence": 0.6768629970444513, + "mcts_value": 0.9665208135659924, + "consensus_score": 0.7579703734932746, + "last_agent": "none", + "iteration": 6, + "query_length": 3559, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4791859970651187, + "trm_confidence": 0.3437572785462063, + "mcts_value": 0.6547327136015789, + "consensus_score": 0.3959238509215329, + "last_agent": "mcts", + "iteration": 5, + "query_length": 4657, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.12701811871100846, + "trm_confidence": 0.13127201260054583, + "mcts_value": 0.7263325940353923, + "consensus_score": 0.32405022804403866, + "last_agent": "trm", + "iteration": 9, + "query_length": 2794, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.14610881840587236, + "trm_confidence": 0.03224543007153378, + "mcts_value": 0.7747408030151097, + "consensus_score": 0.31022441439405146, + "last_agent": "hrm", + "iteration": 9, + "query_length": 4839, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4162918703423202, + "trm_confidence": 0.3053725646623523, + "mcts_value": 0.7634403070509002, + "consensus_score": 0.5108381359765598, + "last_agent": "none", + "iteration": 7, + "query_length": 1037, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6584276987172909, + "trm_confidence": 0.2996197271071803, + "mcts_value": 0.8030835374132496, + "consensus_score": 0.49022691502051574, + "last_agent": "none", + "iteration": 5, + "query_length": 3610, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3866693525184374, + "trm_confidence": 0.5001864687890575, + 
"mcts_value": 0.8162287407554648, + "consensus_score": 0.5226244825444922, + "last_agent": "hrm", + "iteration": 8, + "query_length": 1884, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2815183694024254, + "trm_confidence": 0.12524734995951484, + "mcts_value": 0.7645123010483306, + "consensus_score": 0.37322348205449835, + "last_agent": "mcts", + "iteration": 9, + "query_length": 3920, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.20955924034800916, + "trm_confidence": 0.47760280306181196, + "mcts_value": 0.8374452765271861, + "consensus_score": 0.5562316817182031, + "last_agent": "none", + "iteration": 6, + "query_length": 3385, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.007527602320017867, + "trm_confidence": 0.0860795408424368, + "mcts_value": 0.629478332855336, + "consensus_score": 0.17223016157143217, + "last_agent": "none", + "iteration": 10, + "query_length": 2200, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3003525811907735, + "trm_confidence": 0.2763886248904426, + "mcts_value": 0.7099726767881527, + "consensus_score": 0.38919696122072633, + "last_agent": "trm", + "iteration": 4, + "query_length": 601, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.542797116870872, + "trm_confidence": 0.1411816092963175, + "mcts_value": 0.6546417611118621, + "consensus_score": 0.4924810905196795, + "last_agent": "none", + "iteration": 9, + "query_length": 1604, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6647009394575618, + "trm_confidence": 0.32413276472680647, + "mcts_value": 0.6346983073152016, + "consensus_score": 0.567224436161121, + "last_agent": "trm", + "iteration": 4, + "query_length": 2763, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6203528193725827, + "trm_confidence": 0.1900075922950634, + "mcts_value": 0.8021590618237545, + "consensus_score": 0.5925447551355442, + "last_agent": "mcts", + "iteration": 8, + "query_length": 4554, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.16650482550703383, + "trm_confidence": 0.3097693113012998, + "mcts_value": 0.7173500691167987, + "consensus_score": 0.42494158391789427, + "last_agent": "hrm", + "iteration": 7, + "query_length": 2654, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4091297662980429, + "trm_confidence": 0.36351973519098435, + "mcts_value": 0.8783515072913102, + "consensus_score": 0.5612175890826414, + "last_agent": "mcts", + "iteration": 5, + "query_length": 2878, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2544274738110137, + "trm_confidence": 0.10481394578673936, + "mcts_value": 0.7550002142155073, + "consensus_score": 0.3907813448478427, + "last_agent": "mcts", + "iteration": 8, + "query_length": 1784, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5918205291007997, + "trm_confidence": 0.0763380145377019, + "mcts_value": 0.9439400227462136, + "consensus_score": 0.5763738663568887, + "last_agent": "mcts", + "iteration": 6, + "query_length": 3207, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.19636694379014258, + "trm_confidence": 
0.30947440618900923, + "mcts_value": 0.690009451873861, + "consensus_score": 0.35235996368713696, + "last_agent": "trm", + "iteration": 8, + "query_length": 1010, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3544618486405938, + "trm_confidence": 0.6221514364555268, + "mcts_value": 0.6883515339398161, + "consensus_score": 0.6134674023819895, + "last_agent": "trm", + "iteration": 7, + "query_length": 3167, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6148680140308268, + "trm_confidence": 0.4971168583490874, + "mcts_value": 0.6719545858810111, + "consensus_score": 0.6090791467164668, + "last_agent": "hrm", + "iteration": 10, + "query_length": 2568, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.15169861868926238, + "trm_confidence": 0.08440420871119089, + "mcts_value": 0.8729957144008273, + "consensus_score": 0.3628212159351468, + "last_agent": "none", + "iteration": 6, + "query_length": 4551, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5704315813364017, + "trm_confidence": 0.5241789003278413, + "mcts_value": 0.7295918187007265, + "consensus_score": 0.641652207466697, + "last_agent": "trm", + "iteration": 10, + "query_length": 518, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.11266319650705656, + "trm_confidence": 0.19303345284654838, + "mcts_value": 0.7558684743931463, + "consensus_score": 0.4298103324610913, + "last_agent": "trm", + "iteration": 7, + "query_length": 958, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.33947217517053047, + "trm_confidence": 0.5016657593251701, + "mcts_value": 0.6864448220328685, + "consensus_score": 0.46653265726760484, + "last_agent": "trm", + "iteration": 9, + "query_length": 4322, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1756367402124386, + "trm_confidence": 0.4417197013131317, + "mcts_value": 0.9150250250158817, + "consensus_score": 0.4415408780782525, + "last_agent": "mcts", + "iteration": 6, + "query_length": 4388, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13150250773626537, + "trm_confidence": 0.6008398704647704, + "mcts_value": 0.8526673542319743, + "consensus_score": 0.4390067739467064, + "last_agent": "none", + "iteration": 8, + "query_length": 3149, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5604702739715944, + "trm_confidence": 0.41662668123237145, + "mcts_value": 0.8126018927370995, + "consensus_score": 0.6531839456316204, + "last_agent": "mcts", + "iteration": 8, + "query_length": 3290, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5192746395010424, + "trm_confidence": 0.16653692066503153, + "mcts_value": 0.6911244604345537, + "consensus_score": 0.5551067285620601, + "last_agent": "trm", + "iteration": 5, + "query_length": 4692, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.17009010337397337, + "trm_confidence": 0.25263195544389805, + "mcts_value": 0.9434692481935456, + "consensus_score": 0.49832513457294436, + "last_agent": "none", + "iteration": 5, + "query_length": 3510, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.17058772796290878, 
+ "trm_confidence": 0.1649605688983235, + "mcts_value": 0.7945433693066476, + "consensus_score": 0.2785450063396625, + "last_agent": "none", + "iteration": 6, + "query_length": 109, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.43486322474956446, + "trm_confidence": 0.4386586274023478, + "mcts_value": 0.6124623243306682, + "consensus_score": 0.40948815318225856, + "last_agent": "none", + "iteration": 6, + "query_length": 3444, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2883823068083915, + "trm_confidence": 0.32557134992486714, + "mcts_value": 0.8328668272978313, + "consensus_score": 0.423929493564191, + "last_agent": "mcts", + "iteration": 9, + "query_length": 1517, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.15675407418460977, + "trm_confidence": 0.21751818744351983, + "mcts_value": 0.9642656169000174, + "consensus_score": 0.378303847804623, + "last_agent": "none", + "iteration": 6, + "query_length": 1006, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13275246708368724, + "trm_confidence": 0.6435778470574205, + "mcts_value": 0.8532422948258542, + "consensus_score": 0.463810763672716, + "last_agent": "hrm", + "iteration": 4, + "query_length": 1720, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2487519092491989, + "trm_confidence": 0.4168497941370826, + "mcts_value": 0.8815079231064198, + "consensus_score": 0.46041202899874717, + "last_agent": "mcts", + "iteration": 5, + "query_length": 4798, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.43066103537470446, + "trm_confidence": 0.15270442919272204, + "mcts_value": 0.7412751318883843, + "consensus_score": 0.474132405597394, + "last_agent": "hrm", + "iteration": 9, + "query_length": 3198, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.30771116271093707, + "trm_confidence": 0.541816860194706, + "mcts_value": 0.6175433133671664, + "consensus_score": 0.5191797458724317, + "last_agent": "mcts", + "iteration": 8, + "query_length": 4071, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3722451902992266, + "trm_confidence": 0.1430559856244378, + "mcts_value": 0.693978218101228, + "consensus_score": 0.4483388707386439, + "last_agent": "mcts", + "iteration": 5, + "query_length": 330, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.34069695390044646, + "trm_confidence": 0.4205945634849206, + "mcts_value": 0.8026179113292021, + "consensus_score": 0.5087920630048397, + "last_agent": "trm", + "iteration": 10, + "query_length": 1868, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4406615999824062, + "trm_confidence": 0.44450554110256446, + "mcts_value": 0.9516453804946554, + "consensus_score": 0.6341105772697317, + "last_agent": "mcts", + "iteration": 7, + "query_length": 2152, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.017872285996915215, + "trm_confidence": 0.005219501996662123, + "mcts_value": 0.8368011549795481, + "consensus_score": 0.35695934625039155, + "last_agent": "none", + "iteration": 4, + "query_length": 3910, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + 
"hrm_confidence": 0.13871986073570491, + "trm_confidence": 0.6270115867856038, + "mcts_value": 0.845995105179858, + "consensus_score": 0.47526133858339714, + "last_agent": "trm", + "iteration": 10, + "query_length": 4759, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.10552091639326548, + "trm_confidence": 0.22136289713036456, + "mcts_value": 0.6962971814096776, + "consensus_score": 0.3206504448984908, + "last_agent": "hrm", + "iteration": 5, + "query_length": 763, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5814443788879998, + "trm_confidence": 0.11138242296333094, + "mcts_value": 0.6918609344958324, + "consensus_score": 0.5576115038375773, + "last_agent": "mcts", + "iteration": 5, + "query_length": 4467, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.08420000060926532, + "trm_confidence": 0.5717740472353708, + "mcts_value": 0.8234256267071626, + "consensus_score": 0.44807033138676555, + "last_agent": "none", + "iteration": 7, + "query_length": 3520, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.03884019056702028, + "trm_confidence": 0.11803313328209751, + "mcts_value": 0.841056260930314, + "consensus_score": 0.26488164351115007, + "last_agent": "mcts", + "iteration": 6, + "query_length": 2176, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3590681153579105, + "trm_confidence": 0.4365026033918791, + "mcts_value": 0.9687685283097158, + "consensus_score": 0.5596107174327927, + "last_agent": "hrm", + "iteration": 5, + "query_length": 1213, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.24055214801083497, + "trm_confidence": 0.31692561198959995, + "mcts_value": 0.7177291899339334, + "consensus_score": 0.4262157207483272, + "last_agent": "trm", + "iteration": 10, + "query_length": 3072, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.10465821764253724, + "trm_confidence": 0.2449994628955808, + "mcts_value": 0.7951904659530589, + "consensus_score": 0.42507819257167323, + "last_agent": "mcts", + "iteration": 5, + "query_length": 2077, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.036105681083131086, + "trm_confidence": 0.6531526690349112, + "mcts_value": 0.7666804094272028, + "consensus_score": 0.46803354790699225, + "last_agent": "trm", + "iteration": 9, + "query_length": 4934, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2033292341002282, + "trm_confidence": 0.2885614547813878, + "mcts_value": 0.9749441286315138, + "consensus_score": 0.5407177872229005, + "last_agent": "mcts", + "iteration": 6, + "query_length": 2260, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.38893016798324637, + "trm_confidence": 0.4825663175691036, + "mcts_value": 0.9979292693656894, + "consensus_score": 0.6518343651367351, + "last_agent": "trm", + "iteration": 9, + "query_length": 4872, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.20618925247647954, + "trm_confidence": 0.01596017330247862, + "mcts_value": 0.8253024529881239, + "consensus_score": 0.37203127646855805, + "last_agent": "trm", + "iteration": 9, + "query_length": 1101, + "has_rag_context": true + }, + "label": 
"mcts" + }, + { + "features": { + "hrm_confidence": 0.20165745105494864, + "trm_confidence": 0.3569922506583052, + "mcts_value": 0.6658824561963457, + "consensus_score": 0.47551362976539985, + "last_agent": "hrm", + "iteration": 5, + "query_length": 2120, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6552332642343762, + "trm_confidence": 0.08169258417520159, + "mcts_value": 0.9013474455802182, + "consensus_score": 0.6313694248295101, + "last_agent": "none", + "iteration": 8, + "query_length": 4067, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.35430370894841656, + "trm_confidence": 0.43584077185655556, + "mcts_value": 0.6215098838274221, + "consensus_score": 0.4319182228520328, + "last_agent": "none", + "iteration": 7, + "query_length": 1305, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.02930306778478665, + "trm_confidence": 0.30763723660522463, + "mcts_value": 0.8210191759091919, + "consensus_score": 0.3287501243915139, + "last_agent": "mcts", + "iteration": 10, + "query_length": 2544, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.12262977807574112, + "trm_confidence": 0.15365290566384016, + "mcts_value": 0.7198587585408276, + "consensus_score": 0.3275975984973937, + "last_agent": "hrm", + "iteration": 10, + "query_length": 4865, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.027919921899200804, + "trm_confidence": 0.15917366193599397, + "mcts_value": 0.622156824686119, + "consensus_score": 0.338304391695657, + "last_agent": "trm", + "iteration": 8, + "query_length": 202, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6505977835323311, + "trm_confidence": 0.2318247714734927, + "mcts_value": 0.7275878370330569, + "consensus_score": 0.6036590132102162, + "last_agent": "none", + "iteration": 10, + "query_length": 3878, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.19376230117373294, + "trm_confidence": 0.460903398328293, + "mcts_value": 0.9577157077922971, + "consensus_score": 0.4550532188029849, + "last_agent": "mcts", + "iteration": 7, + "query_length": 3258, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.002607855654334512, + "trm_confidence": 0.22732648150856585, + "mcts_value": 0.6251053111722411, + "consensus_score": 0.30291933501906637, + "last_agent": "hrm", + "iteration": 7, + "query_length": 399, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6576675279857744, + "trm_confidence": 0.0879573147862478, + "mcts_value": 0.9964800123105004, + "consensus_score": 0.5637231036945398, + "last_agent": "trm", + "iteration": 8, + "query_length": 1547, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5223972689147434, + "trm_confidence": 0.08690500903747479, + "mcts_value": 0.7465488619637961, + "consensus_score": 0.4183129050235457, + "last_agent": "trm", + "iteration": 9, + "query_length": 4194, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.032799915867008554, + "trm_confidence": 0.12207420247613396, + "mcts_value": 0.934884085944662, + "consensus_score": 0.2729516512169948, + "last_agent": "none", + "iteration": 5, + "query_length": 3021, + 
"has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3698200768977788, + "trm_confidence": 0.4605242469841963, + "mcts_value": 0.8016841206715465, + "consensus_score": 0.5209919709580493, + "last_agent": "trm", + "iteration": 4, + "query_length": 1230, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6063193015724208, + "trm_confidence": 0.015857791395556153, + "mcts_value": 0.621716435099165, + "consensus_score": 0.42072620067370714, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1575, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.47712759383618397, + "trm_confidence": 0.0856861934113269, + "mcts_value": 0.843366698617572, + "consensus_score": 0.5037122487838062, + "last_agent": "none", + "iteration": 10, + "query_length": 556, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.24115339617428277, + "trm_confidence": 0.5273167309174219, + "mcts_value": 0.7772207997433388, + "consensus_score": 0.5238051665640774, + "last_agent": "mcts", + "iteration": 5, + "query_length": 3470, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.32142392562430644, + "trm_confidence": 0.11908708393095507, + "mcts_value": 0.6680921202370974, + "consensus_score": 0.2734781653997501, + "last_agent": "mcts", + "iteration": 4, + "query_length": 1676, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.051226483406671686, + "trm_confidence": 0.3158319118734727, + "mcts_value": 0.6291845364640543, + "consensus_score": 0.2811010337150727, + "last_agent": "none", + "iteration": 10, + "query_length": 3814, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.31411206359138144, + "trm_confidence": 0.2625795571528991, + "mcts_value": 0.7618972994685068, + "consensus_score": 0.35376953006622147, + "last_agent": "hrm", + "iteration": 6, + "query_length": 360, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.39260310105547147, + "trm_confidence": 0.06913540039949356, + "mcts_value": 0.640948132006049, + "consensus_score": 0.2881484759127241, + "last_agent": "hrm", + "iteration": 7, + "query_length": 3999, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.0031143173662041, + "trm_confidence": 0.13508210923749067, + "mcts_value": 0.6550539140529478, + "consensus_score": 0.21159523105528902, + "last_agent": "mcts", + "iteration": 9, + "query_length": 102, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.28351870855068445, + "trm_confidence": 0.4987368986319067, + "mcts_value": 0.7623777589951879, + "consensus_score": 0.5402621918668641, + "last_agent": "trm", + "iteration": 5, + "query_length": 1984, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.39537985421393407, + "trm_confidence": 0.09709532221116647, + "mcts_value": 0.7750447562620203, + "consensus_score": 0.38116134291072923, + "last_agent": "none", + "iteration": 6, + "query_length": 4513, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4571445090238169, + "trm_confidence": 0.3565729185311946, + "mcts_value": 0.7843045065167165, + "consensus_score": 0.4729652904733782, + "last_agent": "trm", + "iteration": 4, 
+ "query_length": 3495, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.04349742889371404, + "trm_confidence": 0.18068504975469057, + "mcts_value": 0.7917218730871143, + "consensus_score": 0.39944153958867057, + "last_agent": "hrm", + "iteration": 9, + "query_length": 2940, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2960341540506385, + "trm_confidence": 0.29864449828811546, + "mcts_value": 0.8408351304006768, + "consensus_score": 0.4173046031450657, + "last_agent": "mcts", + "iteration": 8, + "query_length": 1658, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6199640440317623, + "trm_confidence": 0.25792774231741794, + "mcts_value": 0.6619457645715343, + "consensus_score": 0.4232269367444519, + "last_agent": "trm", + "iteration": 6, + "query_length": 4463, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.34327889208174345, + "trm_confidence": 0.6694379580513925, + "mcts_value": 0.9667826343447138, + "consensus_score": 0.7348815233350119, + "last_agent": "trm", + "iteration": 9, + "query_length": 4262, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5295282585288805, + "trm_confidence": 0.40413730215330446, + "mcts_value": 0.7527501648675294, + "consensus_score": 0.6060753755312994, + "last_agent": "trm", + "iteration": 5, + "query_length": 2839, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2118643256670007, + "trm_confidence": 0.6903129267689389, + "mcts_value": 0.9830369736776687, + "consensus_score": 0.6040799449989437, + "last_agent": "mcts", + "iteration": 5, + "query_length": 3918, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3883777802635978, + "trm_confidence": 0.21499190556342018, + "mcts_value": 0.9139847653152711, + "consensus_score": 0.5274748910061812, + "last_agent": "trm", + "iteration": 6, + "query_length": 3172, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4677165338014153, + "trm_confidence": 0.5336812877933015, + "mcts_value": 0.9450656503012365, + "consensus_score": 0.6655198675925663, + "last_agent": "hrm", + "iteration": 10, + "query_length": 3227, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.064791257116833, + "trm_confidence": 0.3878136514494973, + "mcts_value": 0.7049164642875259, + "consensus_score": 0.40851230489962564, + "last_agent": "mcts", + "iteration": 9, + "query_length": 2647, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.44256067596182397, + "trm_confidence": 0.6990868447642679, + "mcts_value": 0.6049027207460108, + "consensus_score": 0.5410932845820328, + "last_agent": "mcts", + "iteration": 5, + "query_length": 4789, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.17563905026679527, + "trm_confidence": 0.2428720177572709, + "mcts_value": 0.7701661899108189, + "consensus_score": 0.4238671003608342, + "last_agent": "trm", + "iteration": 6, + "query_length": 679, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6425597511360087, + "trm_confidence": 0.28482017194039755, + "mcts_value": 0.6160125051197055, + "consensus_score": 0.4800928459389961, + "last_agent": "hrm", 
+ "iteration": 4, + "query_length": 3975, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.028807240036680212, + "trm_confidence": 0.10169072715622485, + "mcts_value": 0.9455299801055019, + "consensus_score": 0.28795394821004383, + "last_agent": "mcts", + "iteration": 8, + "query_length": 1272, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.24066451373569545, + "trm_confidence": 0.5360884799176838, + "mcts_value": 0.7751073468262039, + "consensus_score": 0.5658005925109914, + "last_agent": "hrm", + "iteration": 5, + "query_length": 2241, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5429223745586418, + "trm_confidence": 0.48654437294048475, + "mcts_value": 0.7431240129661425, + "consensus_score": 0.6240343280929974, + "last_agent": "hrm", + "iteration": 9, + "query_length": 1190, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.02963761542642319, + "trm_confidence": 0.6081220430624733, + "mcts_value": 0.735873490187804, + "consensus_score": 0.4743636632100025, + "last_agent": "mcts", + "iteration": 6, + "query_length": 1890, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5241162636843427, + "trm_confidence": 0.46632505648051575, + "mcts_value": 0.8893638012649838, + "consensus_score": 0.5673579013326266, + "last_agent": "hrm", + "iteration": 10, + "query_length": 2080, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4556607463952512, + "trm_confidence": 0.6330503900641488, + "mcts_value": 0.8631025294335881, + "consensus_score": 0.7362659819170875, + "last_agent": "trm", + "iteration": 8, + "query_length": 1645, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5733551887565057, + "trm_confidence": 0.45122373788403347, + "mcts_value": 0.7454935918237665, + "consensus_score": 0.5054773788721658, + "last_agent": "mcts", + "iteration": 4, + "query_length": 3981, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6662827140882281, + "trm_confidence": 0.2994956771298266, + "mcts_value": 0.9537992592418003, + "consensus_score": 0.6966713983480932, + "last_agent": "none", + "iteration": 9, + "query_length": 126, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.697918453147421, + "trm_confidence": 0.02115344262594493, + "mcts_value": 0.6581350844908569, + "consensus_score": 0.47853617716369623, + "last_agent": "none", + "iteration": 8, + "query_length": 2610, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5051054710091956, + "trm_confidence": 0.15147969429688585, + "mcts_value": 0.7022510975876359, + "consensus_score": 0.518753072114909, + "last_agent": "none", + "iteration": 5, + "query_length": 4041, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13226202532353212, + "trm_confidence": 0.16659054134282192, + "mcts_value": 0.9133653943335903, + "consensus_score": 0.45989340141352153, + "last_agent": "trm", + "iteration": 4, + "query_length": 998, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6511273501213394, + "trm_confidence": 0.3933144844251073, + "mcts_value": 0.782212362475966, + "consensus_score": 0.5864203019824493, + 
"last_agent": "mcts", + "iteration": 10, + "query_length": 39, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.28815102182140645, + "trm_confidence": 0.21852605082475737, + "mcts_value": 0.7734939894749928, + "consensus_score": 0.3918626493212357, + "last_agent": "hrm", + "iteration": 8, + "query_length": 489, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.14084314868811096, + "trm_confidence": 0.268945602243145, + "mcts_value": 0.8727266902585276, + "consensus_score": 0.35681294157532756, + "last_agent": "none", + "iteration": 7, + "query_length": 1968, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.11439160427818405, + "trm_confidence": 0.15855390735182287, + "mcts_value": 0.7418622304225458, + "consensus_score": 0.4018918310559654, + "last_agent": "mcts", + "iteration": 5, + "query_length": 557, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13211278804007376, + "trm_confidence": 0.5091159917392016, + "mcts_value": 0.7324797627224844, + "consensus_score": 0.4516118533956279, + "last_agent": "hrm", + "iteration": 4, + "query_length": 645, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.24738488351716395, + "trm_confidence": 0.2693562444969951, + "mcts_value": 0.6591013780288429, + "consensus_score": 0.38983348186601663, + "last_agent": "mcts", + "iteration": 9, + "query_length": 150, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4521926221421078, + "trm_confidence": 0.6654542154042549, + "mcts_value": 0.9729888205344299, + "consensus_score": 0.7603564601768946, + "last_agent": "trm", + "iteration": 10, + "query_length": 742, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3100398441795226, + "trm_confidence": 0.460622617268879, + "mcts_value": 0.74292359315087, + "consensus_score": 0.48991505681969105, + "last_agent": "hrm", + "iteration": 9, + "query_length": 1338, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.1429320477494029, + "trm_confidence": 0.6607633266673058, + "mcts_value": 0.9319946595939554, + "consensus_score": 0.5096725046576749, + "last_agent": "none", + "iteration": 4, + "query_length": 1889, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.517131593916873, + "trm_confidence": 0.6723713247332932, + "mcts_value": 0.9274807310812058, + "consensus_score": 0.6597691423142072, + "last_agent": "none", + "iteration": 9, + "query_length": 2021, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.31192002079093006, + "trm_confidence": 0.004095015871525065, + "mcts_value": 0.8149594502277363, + "consensus_score": 0.39188393433789537, + "last_agent": "trm", + "iteration": 9, + "query_length": 4475, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.37127967759295266, + "trm_confidence": 0.12477220220576545, + "mcts_value": 0.9726623447572149, + "consensus_score": 0.5187609223686881, + "last_agent": "hrm", + "iteration": 5, + "query_length": 678, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.07748641831384613, + "trm_confidence": 0.07435221386350471, + "mcts_value": 0.8741260969676969, + "consensus_score": 
0.3899769556847119, + "last_agent": "hrm", + "iteration": 4, + "query_length": 1605, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.09897539965443619, + "trm_confidence": 0.36286645309387816, + "mcts_value": 0.9184595034591311, + "consensus_score": 0.379425550262937, + "last_agent": "hrm", + "iteration": 8, + "query_length": 3230, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.14541151475870714, + "trm_confidence": 0.4780262591640515, + "mcts_value": 0.9777923425290989, + "consensus_score": 0.4594016570260412, + "last_agent": "mcts", + "iteration": 5, + "query_length": 492, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.699204268274178, + "trm_confidence": 0.018148359663660893, + "mcts_value": 0.8292384565366064, + "consensus_score": 0.4155932454202992, + "last_agent": "hrm", + "iteration": 6, + "query_length": 4669, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6388077833881113, + "trm_confidence": 0.6163674834515231, + "mcts_value": 0.9535348944394526, + "consensus_score": 0.6648724705235719, + "last_agent": "trm", + "iteration": 6, + "query_length": 3148, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.46426661995738583, + "trm_confidence": 0.1660000042337644, + "mcts_value": 0.7250444271136965, + "consensus_score": 0.4164542020271105, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1780, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.21438169265579282, + "trm_confidence": 0.016469676876294147, + "mcts_value": 0.9857296347539451, + "consensus_score": 0.3565071821891915, + "last_agent": "mcts", + "iteration": 5, + "query_length": 3180, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.43823421407178675, + "trm_confidence": 0.6124293493895462, + "mcts_value": 0.645086426531144, + "consensus_score": 0.5377537730249061, + "last_agent": "none", + "iteration": 5, + "query_length": 2388, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4140413718557169, + "trm_confidence": 0.07080024932932218, + "mcts_value": 0.9143867285802508, + "consensus_score": 0.47384688078716997, + "last_agent": "none", + "iteration": 7, + "query_length": 3661, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6867783549341924, + "trm_confidence": 0.12323032392527893, + "mcts_value": 0.6440642384213293, + "consensus_score": 0.4232395505850826, + "last_agent": "hrm", + "iteration": 5, + "query_length": 3098, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.17469002741858916, + "trm_confidence": 0.3005963681150697, + "mcts_value": 0.7774678297930802, + "consensus_score": 0.4522891119879232, + "last_agent": "none", + "iteration": 5, + "query_length": 4265, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13872532157599726, + "trm_confidence": 0.6681974928109701, + "mcts_value": 0.7225384948710823, + "consensus_score": 0.5719325584429859, + "last_agent": "trm", + "iteration": 6, + "query_length": 2384, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.24272674878274395, + "trm_confidence": 0.58772811267745, + "mcts_value": 0.80752569966152, + 
"consensus_score": 0.5964265337467927, + "last_agent": "hrm", + "iteration": 6, + "query_length": 1102, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.0007363807922716958, + "trm_confidence": 0.6732710189801955, + "mcts_value": 0.7477560894079189, + "consensus_score": 0.483846978274696, + "last_agent": "none", + "iteration": 8, + "query_length": 2180, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2030334913415154, + "trm_confidence": 0.6801032874356393, + "mcts_value": 0.8202565843277612, + "consensus_score": 0.5921314702046956, + "last_agent": "mcts", + "iteration": 8, + "query_length": 680, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.04504428752245054, + "trm_confidence": 0.24516113759752034, + "mcts_value": 0.715144560196313, + "consensus_score": 0.27033614762577307, + "last_agent": "none", + "iteration": 8, + "query_length": 1430, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.23301783778983257, + "trm_confidence": 0.46925515268030815, + "mcts_value": 0.7211019432982609, + "consensus_score": 0.40431259810438114, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1304, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.12227880344444905, + "trm_confidence": 0.6529222532052206, + "mcts_value": 0.6404797783707098, + "consensus_score": 0.42131746934726383, + "last_agent": "trm", + "iteration": 8, + "query_length": 4349, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6154667532299842, + "trm_confidence": 0.10362325407179326, + "mcts_value": 0.6776406769395111, + "consensus_score": 0.36932701209435204, + "last_agent": "trm", + "iteration": 7, + "query_length": 1739, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.08971358945199735, + "trm_confidence": 0.6875615806959464, + "mcts_value": 0.7303964181662499, + "consensus_score": 0.5379049220026138, + "last_agent": "none", + "iteration": 9, + "query_length": 3616, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2691948469916442, + "trm_confidence": 0.22703799524254756, + "mcts_value": 0.879296909595676, + "consensus_score": 0.5019016364555173, + "last_agent": "trm", + "iteration": 9, + "query_length": 2247, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.20623477405137147, + "trm_confidence": 0.3426701595313342, + "mcts_value": 0.8274639271393289, + "consensus_score": 0.48153238358586203, + "last_agent": "hrm", + "iteration": 6, + "query_length": 3062, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.0066050093393553826, + "trm_confidence": 0.3988879217624059, + "mcts_value": 0.8338452848863351, + "consensus_score": 0.43983349585107095, + "last_agent": "trm", + "iteration": 10, + "query_length": 180, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.33249821550006176, + "trm_confidence": 0.37215949992821973, + "mcts_value": 0.8297870637085434, + "consensus_score": 0.4249424071794077, + "last_agent": "hrm", + "iteration": 10, + "query_length": 2264, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.38958886100466183, + "trm_confidence": 0.04028922516076817, + 
"mcts_value": 0.819223843937053, + "consensus_score": 0.4359574126581124, + "last_agent": "trm", + "iteration": 7, + "query_length": 1477, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4976213779106428, + "trm_confidence": 0.2008256766826197, + "mcts_value": 0.6188499666382289, + "consensus_score": 0.4694400731677854, + "last_agent": "hrm", + "iteration": 8, + "query_length": 4359, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5849432813373271, + "trm_confidence": 0.42391415051634174, + "mcts_value": 0.7761809380865067, + "consensus_score": 0.5087604503757408, + "last_agent": "trm", + "iteration": 7, + "query_length": 1394, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4301947589066826, + "trm_confidence": 0.007278671057098373, + "mcts_value": 0.8566582715646531, + "consensus_score": 0.452916335726504, + "last_agent": "hrm", + "iteration": 9, + "query_length": 3093, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.21256089197690825, + "trm_confidence": 0.2948914923860667, + "mcts_value": 0.6235903081559845, + "consensus_score": 0.3129685723986355, + "last_agent": "mcts", + "iteration": 6, + "query_length": 761, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.46808633399039357, + "trm_confidence": 0.18380072281210846, + "mcts_value": 0.6452869132028206, + "consensus_score": 0.47467249283102203, + "last_agent": "hrm", + "iteration": 4, + "query_length": 914, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5578799743408257, + "trm_confidence": 0.5487661061438933, + "mcts_value": 0.7565794654347361, + "consensus_score": 0.5961871104646532, + "last_agent": "hrm", + "iteration": 8, + "query_length": 4110, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.05688528576326571, + "trm_confidence": 0.304222998883027, + "mcts_value": 0.832162734472168, + "consensus_score": 0.33430254268044207, + "last_agent": "mcts", + "iteration": 10, + "query_length": 664, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.2522673022310073, + "trm_confidence": 0.20211872466786202, + "mcts_value": 0.836309499455014, + "consensus_score": 0.3708157987302678, + "last_agent": "trm", + "iteration": 10, + "query_length": 3976, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5227074134042587, + "trm_confidence": 0.2929715439256572, + "mcts_value": 0.6533044577724123, + "consensus_score": 0.3927404693595567, + "last_agent": "hrm", + "iteration": 7, + "query_length": 637, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.48722867299686506, + "trm_confidence": 0.17999942599883725, + "mcts_value": 0.6839737252374378, + "consensus_score": 0.5287782484460767, + "last_agent": "hrm", + "iteration": 8, + "query_length": 3467, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6081947137154111, + "trm_confidence": 0.2531792725155051, + "mcts_value": 0.8460758153320737, + "consensus_score": 0.5789780492314458, + "last_agent": "none", + "iteration": 7, + "query_length": 3122, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4046425342306847, + "trm_confidence": 
0.5281018359588862, + "mcts_value": 0.6659661634434454, + "consensus_score": 0.5069134633177373, + "last_agent": "mcts", + "iteration": 10, + "query_length": 2624, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4088338487835222, + "trm_confidence": 0.2943976070261187, + "mcts_value": 0.6693235039583664, + "consensus_score": 0.36055634085907, + "last_agent": "mcts", + "iteration": 4, + "query_length": 3216, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6244276612236238, + "trm_confidence": 0.3819486118099557, + "mcts_value": 0.907731983807261, + "consensus_score": 0.5591491781164382, + "last_agent": "none", + "iteration": 7, + "query_length": 2719, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.20376871458816245, + "trm_confidence": 0.08672342650185047, + "mcts_value": 0.8997910516490358, + "consensus_score": 0.4768974041987981, + "last_agent": "trm", + "iteration": 9, + "query_length": 2961, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.06804884644945763, + "trm_confidence": 0.044298543381425914, + "mcts_value": 0.9294697959298257, + "consensus_score": 0.3989925746828216, + "last_agent": "trm", + "iteration": 7, + "query_length": 455, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5875125004474275, + "trm_confidence": 0.04033643008293962, + "mcts_value": 0.9615448783904088, + "consensus_score": 0.5530153846951598, + "last_agent": "trm", + "iteration": 9, + "query_length": 4449, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.48023325325328375, + "trm_confidence": 0.0729001088879177, + "mcts_value": 0.9932021996304753, + "consensus_score": 0.45510478965721624, + "last_agent": "trm", + "iteration": 9, + "query_length": 1130, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.07023052249153387, + "trm_confidence": 0.5343709379002983, + "mcts_value": 0.6393357067489559, + "consensus_score": 0.3275165787235706, + "last_agent": "hrm", + "iteration": 4, + "query_length": 3365, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6170942399050675, + "trm_confidence": 0.45535938149292754, + "mcts_value": 0.9201799249273803, + "consensus_score": 0.6287184664962302, + "last_agent": "hrm", + "iteration": 10, + "query_length": 1562, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6204802934349398, + "trm_confidence": 0.47657581474524546, + "mcts_value": 0.7102043357819398, + "consensus_score": 0.6815545649292236, + "last_agent": "mcts", + "iteration": 5, + "query_length": 2450, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4676245308653993, + "trm_confidence": 0.6398424062285161, + "mcts_value": 0.8450968626801308, + "consensus_score": 0.6270051948636078, + "last_agent": "none", + "iteration": 9, + "query_length": 159, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6278599542833151, + "trm_confidence": 0.4760425988010662, + "mcts_value": 0.9151010875463841, + "consensus_score": 0.5865291244442012, + "last_agent": "none", + "iteration": 5, + "query_length": 4674, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6771893195888536, + 
"trm_confidence": 0.025759580280082583, + "mcts_value": 0.9874601893508905, + "consensus_score": 0.5399593299964925, + "last_agent": "hrm", + "iteration": 9, + "query_length": 1672, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.42664222096266324, + "trm_confidence": 0.1000087329786002, + "mcts_value": 0.6343635619799833, + "consensus_score": 0.38564670521397426, + "last_agent": "hrm", + "iteration": 5, + "query_length": 98, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6376250648109357, + "trm_confidence": 0.5212521749410738, + "mcts_value": 0.6320077702106438, + "consensus_score": 0.6833264663421412, + "last_agent": "mcts", + "iteration": 7, + "query_length": 3425, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.431414967236463, + "trm_confidence": 0.2460915708799335, + "mcts_value": 0.7433629170008382, + "consensus_score": 0.5228777985021462, + "last_agent": "mcts", + "iteration": 9, + "query_length": 796, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.36891543450466274, + "trm_confidence": 0.15641123171088717, + "mcts_value": 0.697120372098111, + "consensus_score": 0.31499351078862187, + "last_agent": "trm", + "iteration": 8, + "query_length": 3342, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.06133603798457597, + "trm_confidence": 0.6968683253433593, + "mcts_value": 0.8895709118212334, + "consensus_score": 0.6481281384391293, + "last_agent": "none", + "iteration": 7, + "query_length": 660, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5162963128563078, + "trm_confidence": 0.6154580997694961, + "mcts_value": 0.6974274223222228, + "consensus_score": 0.6726605315321647, + "last_agent": "mcts", + "iteration": 9, + "query_length": 4994, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4834933691149103, + "trm_confidence": 0.09271597870094306, + "mcts_value": 0.9279959791115091, + "consensus_score": 0.4242386564553582, + "last_agent": "mcts", + "iteration": 9, + "query_length": 519, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6350832535941797, + "trm_confidence": 0.46171324980985295, + "mcts_value": 0.807750446876551, + "consensus_score": 0.6140783324445211, + "last_agent": "mcts", + "iteration": 6, + "query_length": 4002, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3950473148643869, + "trm_confidence": 0.3561326175534258, + "mcts_value": 0.9041101138745393, + "consensus_score": 0.5971434892949538, + "last_agent": "mcts", + "iteration": 8, + "query_length": 3863, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5284664924573852, + "trm_confidence": 0.359632315666183, + "mcts_value": 0.6250207145109178, + "consensus_score": 0.4193117982181257, + "last_agent": "trm", + "iteration": 8, + "query_length": 2544, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5556592843548339, + "trm_confidence": 0.21017367604051104, + "mcts_value": 0.8094508678178738, + "consensus_score": 0.46547605311069346, + "last_agent": "hrm", + "iteration": 6, + "query_length": 1578, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 
0.5918639565147593, + "trm_confidence": 0.5046427330550736, + "mcts_value": 0.8783734409608674, + "consensus_score": 0.593090888657487, + "last_agent": "hrm", + "iteration": 6, + "query_length": 2486, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.17372031475182587, + "trm_confidence": 0.36777085299629275, + "mcts_value": 0.8746896798232204, + "consensus_score": 0.5304262511627166, + "last_agent": "hrm", + "iteration": 8, + "query_length": 4420, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.45735500062212225, + "trm_confidence": 0.3846274245379609, + "mcts_value": 0.7331467002465364, + "consensus_score": 0.4443217392962154, + "last_agent": "mcts", + "iteration": 10, + "query_length": 4435, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13890753564258362, + "trm_confidence": 0.2638956732793761, + "mcts_value": 0.7893444052789862, + "consensus_score": 0.3639475105300018, + "last_agent": "trm", + "iteration": 5, + "query_length": 3220, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.622438379720267, + "trm_confidence": 0.6589407839330126, + "mcts_value": 0.7849101441121128, + "consensus_score": 0.6326577884340824, + "last_agent": "trm", + "iteration": 4, + "query_length": 4804, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6045200488979079, + "trm_confidence": 0.016588181483114306, + "mcts_value": 0.9979445614554436, + "consensus_score": 0.48993697408678183, + "last_agent": "hrm", + "iteration": 4, + "query_length": 3119, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.35960297881693765, + "trm_confidence": 0.38811491739756926, + "mcts_value": 0.8325627905169265, + "consensus_score": 0.477250480649796, + "last_agent": "hrm", + "iteration": 8, + "query_length": 2855, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5141036811251816, + "trm_confidence": 0.471118384598279, + "mcts_value": 0.8584046895883358, + "consensus_score": 0.5809657554330427, + "last_agent": "hrm", + "iteration": 9, + "query_length": 2349, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.13660095437682052, + "trm_confidence": 0.445975889224767, + "mcts_value": 0.6513189291697111, + "consensus_score": 0.3691818176349714, + "last_agent": "trm", + "iteration": 4, + "query_length": 2369, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.33445363477286755, + "trm_confidence": 0.42516981740064813, + "mcts_value": 0.6662468722161602, + "consensus_score": 0.4105705631007749, + "last_agent": "none", + "iteration": 8, + "query_length": 4914, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.15644903475187527, + "trm_confidence": 0.5496442951602336, + "mcts_value": 0.7246232553575843, + "consensus_score": 0.5688682950106064, + "last_agent": "hrm", + "iteration": 7, + "query_length": 1908, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.11063520378319223, + "trm_confidence": 0.5899335993964124, + "mcts_value": 0.7042277265701904, + "consensus_score": 0.3852363074032394, + "last_agent": "none", + "iteration": 10, + "query_length": 4243, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + 
"hrm_confidence": 0.5931454356934307, + "trm_confidence": 0.01928333490401931, + "mcts_value": 0.8070959148296653, + "consensus_score": 0.5286054450237737, + "last_agent": "hrm", + "iteration": 4, + "query_length": 719, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.04279058165206425, + "trm_confidence": 0.583271527160927, + "mcts_value": 0.8116704876951447, + "consensus_score": 0.3847070222502957, + "last_agent": "trm", + "iteration": 4, + "query_length": 3913, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.6169720930483836, + "trm_confidence": 0.44480049095375845, + "mcts_value": 0.7859484495431563, + "consensus_score": 0.6406127620558477, + "last_agent": "trm", + "iteration": 6, + "query_length": 3837, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.0733583824003766, + "trm_confidence": 0.3703074200514973, + "mcts_value": 0.984468956358673, + "consensus_score": 0.43055863777047165, + "last_agent": "trm", + "iteration": 5, + "query_length": 2794, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.5440256379818951, + "trm_confidence": 0.6081611824354874, + "mcts_value": 0.9191736468715476, + "consensus_score": 0.6490340366787835, + "last_agent": "trm", + "iteration": 7, + "query_length": 3514, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.21164017849959813, + "trm_confidence": 0.39304547328913053, + "mcts_value": 0.8987430492545462, + "consensus_score": 0.43500295675164846, + "last_agent": "trm", + "iteration": 9, + "query_length": 3651, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.03215589797470598, + "trm_confidence": 0.6947210968425096, + "mcts_value": 0.9509811659011822, + "consensus_score": 0.5088419364607382, + "last_agent": "mcts", + "iteration": 6, + "query_length": 2991, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3868629492725526, + "trm_confidence": 0.1678673810126457, + "mcts_value": 0.707013079346613, + "consensus_score": 0.39933539882401803, + "last_agent": "trm", + "iteration": 8, + "query_length": 4828, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.4160259913511262, + "trm_confidence": 0.10340999377874809, + "mcts_value": 0.8770966001067517, + "consensus_score": 0.5287339820537985, + "last_agent": "mcts", + "iteration": 4, + "query_length": 2363, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.46720028290912374, + "trm_confidence": 0.2578597578154601, + "mcts_value": 0.6228235667727845, + "consensus_score": 0.4797042324734294, + "last_agent": "mcts", + "iteration": 6, + "query_length": 1202, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.03190645756728122, + "trm_confidence": 0.30115856844220246, + "mcts_value": 0.7831844040664724, + "consensus_score": 0.4002960159816925, + "last_agent": "mcts", + "iteration": 7, + "query_length": 433, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.21620392844068265, + "trm_confidence": 0.010104361945315464, + "mcts_value": 0.8672223942746995, + "consensus_score": 0.2956293396370816, + "last_agent": "mcts", + "iteration": 7, + "query_length": 800, + "has_rag_context": true + }, + "label": "mcts" + }, + { 
+ "features": { + "hrm_confidence": 0.1684574328960351, + "trm_confidence": 0.540957317932618, + "mcts_value": 0.8777532055014203, + "consensus_score": 0.48635142172215373, + "last_agent": "hrm", + "iteration": 8, + "query_length": 3036, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.03276345927082932, + "trm_confidence": 0.3748118647412166, + "mcts_value": 0.7542224514574996, + "consensus_score": 0.4863477374654152, + "last_agent": "none", + "iteration": 10, + "query_length": 3405, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.24146628227404615, + "trm_confidence": 0.11219073612816403, + "mcts_value": 0.9022786037697009, + "consensus_score": 0.4518281475111357, + "last_agent": "mcts", + "iteration": 7, + "query_length": 683, + "has_rag_context": false + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.23558057312119612, + "trm_confidence": 0.46696633168896406, + "mcts_value": 0.9309570648942034, + "consensus_score": 0.5177396166273129, + "last_agent": "trm", + "iteration": 4, + "query_length": 2024, + "has_rag_context": true + }, + "label": "mcts" + }, + { + "features": { + "hrm_confidence": 0.3382837374618919, + "trm_confidence": 0.07576130578580771, + "mcts_value": 0.9800137852816538, + "consensus_score": 0.5022516420516963, + "last_agent": "trm", + "iteration": 10, + "query_length": 2337, + "has_rag_context": false + }, + "label": "mcts" + } + ] +} \ No newline at end of file diff --git a/models/bert_lora/training_results.json b/models/bert_lora/training_results.json new file mode 100644 index 0000000000000000000000000000000000000000..70ae829f521a3edaa80917dbc6f8000c96785282 --- /dev/null +++ b/models/bert_lora/training_results.json @@ -0,0 +1,48 @@ +{ + "config": { + "model_name": "prajjwal1/bert-mini", + "lora_r": 4, + "lora_alpha": 16, + "lora_dropout": 0.1, + "lr": 0.001, + "batch_size": 16, + "epochs": 5, + "warmup_steps": 100, + "seed": 42, + "num_samples": 1000, + "data_path": null, + "balanced": true, + "output_dir": "models/bert_lora" + }, + "train_history": { + "train_loss": 1.1033503922549162, + "train_runtime": 11.0946, + "train_samples_per_second": 315.018, + "epochs": 5, + "final_metrics": { + "train_runtime": 11.0946, + "train_samples_per_second": 315.018, + "train_steps_per_second": 19.829, + "total_flos": 34821822412800.0, + "train_loss": 1.1033503922549162, + "epoch": 5.0 + }, + "eval_results": { + "eval_loss": 1.0453400611877441, + "eval_accuracy": 0.47651006711409394, + "eval_runtime": 0.1251, + "eval_samples_per_second": 1191.171, + "eval_steps_per_second": 79.944, + "epoch": 5.0 + } + }, + "test_results": { + "loss": 1.0559743153338401, + "accuracy": 0.4768211920529801 + }, + "model_params": { + "total_params": 11188486, + "trainable_params": 17155, + "trainable_percentage": 0.15 + } +} \ No newline at end of file diff --git a/models/rnn_meta_controller.history.json b/models/rnn_meta_controller.history.json new file mode 100644 index 0000000000000000000000000000000000000000..4e2219b882cb4f47e91da96dee8c241e00ff12c2 --- /dev/null +++ b/models/rnn_meta_controller.history.json @@ -0,0 +1,128 @@ +{ + "config": { + "hidden_dim": 64, + "num_layers": 1, + "dropout": 0.1, + "lr": 0.001, + "batch_size": 32, + "epochs": 20, + "patience": 5, + "seed": 42, + "num_samples": 1000 + }, + "training_history": { + "train_losses": [ + 1.060307163180727, + 0.9014069383794611, + 0.6105747597687172, + 0.35656250968123926, + 0.22574858390020602, + 0.16157509059165465, + 
0.12456387586214325, + 0.10158110240643675, + 0.08592396827809738, + 0.07474524908783761, + 0.06479036057311477, + 0.057878461638183304, + 0.052609961931452606, + 0.04809149278497154, + 0.043710527828697, + 0.041286276738074695, + 0.03756282673302022, + 0.03491098284156936, + 0.031911260236731985, + 0.030496817025722878 + ], + "val_losses": [ + 1.0059996803601583, + 0.7808501919110616, + 0.47826388080914817, + 0.29279296696186063, + 0.2008462185660998, + 0.1529717780649662, + 0.12299496456980705, + 0.10291122049093246, + 0.08860023791591326, + 0.07790809428940217, + 0.06982718824098508, + 0.06387854401643077, + 0.05984275036801894, + 0.05463591649507483, + 0.04938021237030625, + 0.0452831008626769, + 0.04252756762628754, + 0.039516554485696055, + 0.038632405494960644, + 0.035608950459087886 + ], + "val_accuracies": [ + 0.8466666666666667, + 0.92, + 0.9822222222222222, + 0.9933333333333333, + 0.9911111111111112, + 0.9933333333333333, + 0.9955555555555555, + 0.9955555555555555, + 0.9955555555555555, + 0.9955555555555555, + 0.9955555555555555, + 0.9977777777777778, + 0.9933333333333333, + 0.9933333333333333, + 0.9977777777777778, + 0.9977777777777778, + 0.9977777777777778, + 0.9977777777777778, + 0.9955555555555555, + 0.9977777777777778 + ], + "best_epoch": 20, + "best_val_loss": 0.035608950459087886, + "best_val_accuracy": 0.9977777777777778, + "stopped_early": false, + "total_epochs": 20 + }, + "test_results": { + "loss": 0.022989434589787076, + "accuracy": 0.9977777777777778, + "per_class_metrics": { + "hrm": { + "precision": 1.0, + "recall": 1.0, + "f1_score": 1.0, + "support": 153 + }, + "trm": { + "precision": 0.9933774834437086, + "recall": 1.0, + "f1_score": 0.9966777408637874, + "support": 150 + }, + "mcts": { + "precision": 1.0, + "recall": 0.9931972789115646, + "f1_score": 0.9965870307167235, + "support": 147 + } + }, + "confusion_matrix": [ + [ + 153, + 0, + 0 + ], + [ + 0, + 150, + 0 + ], + [ + 0, + 1, + 146 + ] + ], + "total_samples": 450 + } +} \ No newline at end of file diff --git a/models/rnn_meta_controller.pt b/models/rnn_meta_controller.pt new file mode 100644 index 0000000000000000000000000000000000000000..ecac32abb9cde0a5c8899762752f8d206b36c47d Binary files /dev/null and b/models/rnn_meta_controller.pt differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a59c1caf3390642ffa3007ad4cf87067d771a6c2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +# LangGraph Multi-Agent MCTS Demo - Dependencies +# Optimized for Hugging Face Spaces deployment with trained models + +# Core UI Framework +gradio>=4.0.0,<5.0.0 + +# Numerical computation +numpy>=1.24.0,<2.0.0 + +# Machine Learning - Neural Models +torch>=2.1.0 +transformers>=4.40.0 +peft>=0.7.0 +sentence-transformers>=2.2.0 + +# Configuration +pyyaml>=6.0 + +# Experiment Tracking +wandb>=0.16.0 + +# Required for Gradio OAuth and model loading +huggingface_hub>=0.20.0,<0.30.0 + +# Note: This demo now uses REAL trained models: +# - RNN Meta-Controller (models/rnn_meta_controller.pt) +# - BERT with LoRA adapters (models/bert_lora/final_model/) +# - Actual HRM and TRM agent implementations diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fed06e8d006ca2b3347dd740f88d0f7518f20d1 --- /dev/null +++ 
b/src/adapters/__init__.py @@ -0,0 +1,7 @@ +""" +Adapters package for external service integrations. +""" + +from .llm import BaseLLMClient, LLMResponse, create_client + +__all__ = ["create_client", "BaseLLMClient", "LLMResponse"] diff --git a/src/adapters/llm/__init__.py b/src/adapters/llm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d176cb1b0d35d0bc2d82a2132318408ccde2c0b9 --- /dev/null +++ b/src/adapters/llm/__init__.py @@ -0,0 +1,257 @@ +""" +LLM Client Factory and Provider Registry. + +This module provides a factory function to instantiate the correct LLM client +based on provider settings, with lazy loading of adapters. +""" + +import importlib +import logging +from typing import Any + +from .base import BaseLLMClient, LLMClient, LLMResponse, LLMToolResponse, ToolCall +from .exceptions import ( + CircuitBreakerOpenError, + LLMAuthenticationError, + LLMClientError, + LLMConnectionError, + LLMContentFilterError, + LLMContextLengthError, + LLMInvalidRequestError, + LLMModelNotFoundError, + LLMQuotaExceededError, + LLMRateLimitError, + LLMResponseParseError, + LLMServerError, + LLMStreamError, + LLMTimeoutError, +) + +logger = logging.getLogger(__name__) + +# Provider registry with lazy loading +# Maps provider name to (module_path, class_name) +_PROVIDER_REGISTRY: dict[str, tuple[str, str]] = { + "openai": ("src.adapters.llm.openai_client", "OpenAIClient"), + "anthropic": ("src.adapters.llm.anthropic_client", "AnthropicClient"), + "lmstudio": ("src.adapters.llm.lmstudio_client", "LMStudioClient"), + "local": ("src.adapters.llm.lmstudio_client", "LMStudioClient"), # Alias +} + +# Cache for loaded client classes +_CLIENT_CACHE: dict[str, type[BaseLLMClient]] = {} + + +def register_provider(name: str, module_path: str, class_name: str, override: bool = False) -> None: + """ + Register a new LLM provider. + + Args: + name: Provider identifier (e.g., "azure", "bedrock") + module_path: Full module path (e.g., "src.adapters.llm.azure_client") + class_name: Class name in the module (e.g., "AzureOpenAIClient") + override: If True, allow overriding existing provider + """ + if name in _PROVIDER_REGISTRY and not override: + raise ValueError(f"Provider '{name}' already registered. Use override=True to replace.") + + _PROVIDER_REGISTRY[name] = (module_path, class_name) + # Clear cache if overriding + if name in _CLIENT_CACHE: + del _CLIENT_CACHE[name] + + logger.info(f"Registered LLM provider: {name} -> {module_path}.{class_name}") + + +def list_providers() -> list[str]: + """ + List all registered provider names. + + Returns: + List of provider identifiers + """ + return list(_PROVIDER_REGISTRY.keys()) + + +def get_provider_class(provider: str) -> type[BaseLLMClient]: + """ + Get the client class for a provider (with lazy loading). + + Args: + provider: Provider identifier + + Returns: + Client class (not instantiated) + + Raises: + ValueError: If provider not registered + ImportError: If module cannot be loaded + """ + if provider not in _PROVIDER_REGISTRY: + available = ", ".join(list_providers()) + raise ValueError(f"Unknown provider '{provider}'. 
Available: {available}") + + # Check cache first + if provider in _CLIENT_CACHE: + return _CLIENT_CACHE[provider] + + # Lazy load the module + module_path, class_name = _PROVIDER_REGISTRY[provider] + + try: + module = importlib.import_module(module_path) + client_class = getattr(module, class_name) + except ImportError as e: + raise ImportError(f"Failed to load provider '{provider}': {e}") from e + except AttributeError as e: + raise ImportError(f"Class '{class_name}' not found in module '{module_path}'") from e + + # Cache for future use + _CLIENT_CACHE[provider] = client_class + return client_class + + +def create_client( + provider: str = "openai", + *, + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + timeout: float | None = None, + max_retries: int | None = None, + **kwargs: Any, +) -> BaseLLMClient: + """ + Create an LLM client instance. + + This is the main factory function for creating provider clients. + + Args: + provider: Provider name ("openai", "anthropic", "lmstudio", etc.) + api_key: API key (may be optional for some providers) + model: Model identifier + base_url: Base URL for API + timeout: Request timeout in seconds + max_retries: Maximum retry attempts + **kwargs: Provider-specific parameters + + Returns: + Configured LLMClient instance + + Examples: + # OpenAI client + client = create_client("openai", model="gpt-4-turbo-preview") + + # Anthropic client + client = create_client("anthropic", model="sonnet") + + # Local LM Studio + client = create_client("lmstudio", base_url="http://localhost:1234/v1") + + # With custom settings + client = create_client( + "openai", + api_key="sk-...", + timeout=120.0, + max_retries=5, + organization="org-..." + ) + """ + client_class = get_provider_class(provider) + + # Build kwargs for client initialization + init_kwargs = {**kwargs} + + if api_key is not None: + init_kwargs["api_key"] = api_key + if model is not None: + init_kwargs["model"] = model + if base_url is not None: + init_kwargs["base_url"] = base_url + if timeout is not None: + init_kwargs["timeout"] = timeout + if max_retries is not None: + init_kwargs["max_retries"] = max_retries + + logger.info(f"Creating {provider} client with model={model or 'default'}") + + return client_class(**init_kwargs) + + +def create_client_from_config(config: dict) -> BaseLLMClient: + """ + Create an LLM client from a configuration dictionary. + + Useful for loading settings from YAML/JSON config files. 
+ + Args: + config: Configuration dictionary with keys: + - provider: Required provider name + - Other keys passed to create_client + + Returns: + Configured LLMClient instance + + Example: + config = { + "provider": "openai", + "model": "gpt-4-turbo-preview", + "timeout": 60.0, + "max_retries": 3 + } + client = create_client_from_config(config) + """ + config = config.copy() + provider = config.pop("provider", "openai") + return create_client(provider, **config) + + +# Convenience aliases for common use cases +def create_openai_client(**kwargs) -> BaseLLMClient: + """Create an OpenAI client.""" + return create_client("openai", **kwargs) + + +def create_anthropic_client(**kwargs) -> BaseLLMClient: + """Create an Anthropic Claude client.""" + return create_client("anthropic", **kwargs) + + +def create_local_client(**kwargs) -> BaseLLMClient: + """Create a local LM Studio client.""" + return create_client("lmstudio", **kwargs) + + +__all__ = [ + # Base types + "LLMClient", + "LLMResponse", + "LLMToolResponse", + "ToolCall", + "BaseLLMClient", + # Exceptions + "LLMClientError", + "LLMAuthenticationError", + "LLMRateLimitError", + "LLMQuotaExceededError", + "LLMModelNotFoundError", + "LLMContextLengthError", + "LLMInvalidRequestError", + "LLMTimeoutError", + "LLMConnectionError", + "LLMServerError", + "LLMResponseParseError", + "LLMStreamError", + "LLMContentFilterError", + "CircuitBreakerOpenError", + # Factory functions + "create_client", + "create_client_from_config", + "create_openai_client", + "create_anthropic_client", + "create_local_client", + # Registry functions + "register_provider", + "list_providers", + "get_provider_class", +] diff --git a/src/adapters/llm/anthropic_client.py b/src/adapters/llm/anthropic_client.py new file mode 100644 index 0000000000000000000000000000000000000000..f3f7482e47d385262879b3cadd767fba001264f6 --- /dev/null +++ b/src/adapters/llm/anthropic_client.py @@ -0,0 +1,521 @@ +""" +Anthropic Claude LLM client adapter. + +Implements the LLMClient protocol for Anthropic's Messages API. +Supports Claude 3 models with proper content block handling. +""" + +import json +import logging +from collections.abc import AsyncIterator +from typing import Any + +import httpx +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from .base import BaseLLMClient, LLMResponse, LLMToolResponse, ToolCall +from .exceptions import ( + CircuitBreakerOpenError, + LLMAuthenticationError, + LLMClientError, + LLMConnectionError, + LLMContentFilterError, + LLMContextLengthError, + LLMInvalidRequestError, + LLMModelNotFoundError, + LLMQuotaExceededError, + LLMRateLimitError, + LLMResponseParseError, + LLMServerError, + LLMStreamError, + LLMTimeoutError, +) +from .openai_client import CircuitBreaker + +logger = logging.getLogger(__name__) + + +# Model mappings for convenience +ANTHROPIC_MODELS = { + "claude-3-opus": "claude-3-opus-20240229", + "claude-3-sonnet": "claude-3-sonnet-20240229", + "claude-3-haiku": "claude-3-haiku-20240307", + "claude-3.5-sonnet": "claude-3-5-sonnet-20240620", + "claude-3.5-sonnet-v2": "claude-3-5-sonnet-20241022", + "claude-sonnet-4": "claude-sonnet-4-20250514", + # Add latest models + "opus": "claude-3-opus-20240229", + "sonnet": "claude-3-5-sonnet-20241022", + "haiku": "claude-3-haiku-20240307", +} + + +class AnthropicClient(BaseLLMClient): + """ + Anthropic Claude API client. 
+ + Features: + - Messages API support (not legacy completion API) + - Content block handling (text, tool_use) + - Streaming with proper SSE parsing + - Model alias mapping + - System prompt support + - Tool/function calling (beta) + """ + + PROVIDER_NAME = "anthropic" + DEFAULT_BASE_URL = "https://api.anthropic.com" + DEFAULT_MODEL = "claude-3-5-sonnet-20241022" + API_VERSION = "2023-06-01" + + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + timeout: float = 120.0, # Claude can be slower + max_retries: int = 3, + # Circuit breaker settings + circuit_breaker_threshold: int = 5, + circuit_breaker_reset: float = 60.0, + # Rate limiting + rate_limit_per_minute: int | None = None, + ): + """ + Initialize Anthropic client. + + Args: + api_key: Anthropic API key (or set ANTHROPIC_API_KEY env var) + model: Model to use (supports aliases like 'sonnet', 'opus') + base_url: API base URL + timeout: Request timeout in seconds (default longer for Claude) + max_retries: Max retry attempts + circuit_breaker_threshold: Failures before circuit opens + circuit_breaker_reset: Seconds before circuit resets + rate_limit_per_minute: Rate limit for requests per minute (None to disable) + """ + import os + + api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + raise LLMAuthenticationError(self.PROVIDER_NAME, "API key not provided and ANTHROPIC_API_KEY not set") + + # Resolve model alias + model_name = model or self.DEFAULT_MODEL + resolved_model = ANTHROPIC_MODELS.get(model_name, model_name) + + super().__init__( + api_key=api_key, + model=resolved_model, + base_url=base_url or self.DEFAULT_BASE_URL, + timeout=timeout, + max_retries=max_retries, + rate_limit_per_minute=rate_limit_per_minute, + ) + + self.circuit_breaker = CircuitBreaker( + failure_threshold=circuit_breaker_threshold, + reset_timeout=circuit_breaker_reset, + ) + self._client: httpx.AsyncClient | None = None + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create the HTTP client.""" + if self._client is None or self._client.is_closed: + headers = { + "x-api-key": self.api_key, + "anthropic-version": self.API_VERSION, + "Content-Type": "application/json", + } + + self._client = httpx.AsyncClient( + base_url=self.base_url, + headers=headers, + timeout=httpx.Timeout(self.timeout), + ) + return self._client + + def _convert_messages_to_anthropic(self, messages: list[dict]) -> tuple[str | None, list[dict]]: + """ + Convert OpenAI-style messages to Anthropic format. 
+ + Returns: + Tuple of (system_prompt, messages) + """ + system_prompt = None + anthropic_messages = [] + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + if role == "system": + # Anthropic uses separate system parameter + system_prompt = content + elif role == "assistant": + anthropic_messages.append({"role": "assistant", "content": content}) + elif role == "user": + anthropic_messages.append({"role": "user", "content": content}) + elif role == "tool": + # Tool result message + anthropic_messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + "content": content, + } + ], + } + ) + + return system_prompt, anthropic_messages + + def _convert_tools_to_anthropic(self, tools: list[dict]) -> list[dict]: + """Convert OpenAI-style tool definitions to Anthropic format.""" + anthropic_tools = [] + + for tool in tools: + if tool.get("type") == "function": + func = tool["function"] + anthropic_tools.append( + { + "name": func["name"], + "description": func.get("description", ""), + "input_schema": func.get("parameters", {"type": "object"}), + } + ) + else: + # Already in Anthropic format + anthropic_tools.append(tool) + + return anthropic_tools + + def _handle_error_response(self, response: httpx.Response) -> None: + """Convert HTTP error responses to appropriate exceptions.""" + status_code = response.status_code + + try: + error_data = response.json() + error_type = error_data.get("error", {}).get("type", "") + error_message = error_data.get("error", {}).get("message", response.text) + except Exception: + error_type = "" + error_message = response.text + + if status_code == 401: + raise LLMAuthenticationError(self.PROVIDER_NAME, error_message) + elif status_code == 429: + retry_after = response.headers.get("retry-after") + retry_after_float = float(retry_after) if retry_after else None + raise LLMRateLimitError(self.PROVIDER_NAME, retry_after=retry_after_float, message=error_message) + elif status_code == 402 or "billing" in error_type.lower(): + raise LLMQuotaExceededError(self.PROVIDER_NAME, error_message) + elif status_code == 404 or error_type == "not_found_error": + raise LLMModelNotFoundError(self.PROVIDER_NAME, self.model) + elif status_code == 400: + if "context" in error_message.lower() or "token" in error_message.lower(): + raise LLMContextLengthError(self.PROVIDER_NAME) + if "content_policy" in error_type or "safety" in error_message.lower(): + raise LLMContentFilterError(self.PROVIDER_NAME, error_message) + raise LLMInvalidRequestError(self.PROVIDER_NAME, error_message) + elif status_code >= 500: + raise LLMServerError(self.PROVIDER_NAME, status_code, error_message) + else: + raise LLMClientError(error_message, self.PROVIDER_NAME, status_code=status_code) + + def _make_retry_decorator(self): + """Create retry decorator with exponential backoff.""" + return retry( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential(multiplier=1, min=2, max=120), + retry=retry_if_exception_type((LLMRateLimitError, LLMServerError, LLMConnectionError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + + async def generate( + self, + *, + messages: list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, + stream: bool = False, + stop: list[str] | None = None, + **kwargs: Any, + ) -> LLMResponse | AsyncIterator[str]: + """ + Generate a response from 
Anthropic Claude.
+
+        Args:
+            messages: Chat messages (will be converted to Anthropic format)
+            prompt: Simple string prompt
+            temperature: Sampling temperature (0.0 to 1.0 for Claude)
+            max_tokens: Maximum tokens to generate (required for Anthropic)
+            tools: Tool definitions (will be converted to Anthropic format)
+            stream: If True, returns AsyncIterator
+            stop: Stop sequences
+            **kwargs: Additional parameters (top_p, top_k, etc.)
+
+        Returns:
+            LLMResponse or AsyncIterator[str] for streaming
+        """
+        # Apply rate limiting before proceeding
+        await self._apply_rate_limit()
+
+        # Check circuit breaker
+        if not self.circuit_breaker.can_execute():
+            raise CircuitBreakerOpenError(
+                self.PROVIDER_NAME,
+                self.circuit_breaker.failure_count,
+                self.circuit_breaker.get_reset_time(),
+            )
+
+        # Anthropic requires max_tokens
+        if max_tokens is None:
+            max_tokens = 4096  # Sensible default
+
+        if stream:
+            # _generate_stream is a coroutine that returns the async generator,
+            # so it must be awaited here for callers to receive an AsyncIterator
+            return await self._generate_stream(
+                messages=messages,
+                prompt=prompt,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                tools=tools,
+                stop=stop,
+                **kwargs,
+            )
+        else:
+            return await self._generate_non_stream(
+                messages=messages,
+                prompt=prompt,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                tools=tools,
+                stop=stop,
+                **kwargs,
+            )
+
+    async def _generate_non_stream(
+        self,
+        *,
+        messages: list[dict] | None = None,
+        prompt: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int = 4096,
+        tools: list[dict] | None = None,
+        stop: list[str] | None = None,
+        **kwargs: Any,
+    ) -> LLMResponse:
+        """Non-streaming generation with retry logic."""
+
+        @self._make_retry_decorator()
+        async def _request():
+            client = await self._get_client()
+
+            # Convert messages
+            built_messages = self._build_messages(messages, prompt)
+            system_prompt, anthropic_messages = self._convert_messages_to_anthropic(built_messages)
+
+            # Build request payload
+            payload = {
+                "model": self.model,
+                "messages": anthropic_messages,
+                "max_tokens": max_tokens,
+                "temperature": min(temperature, 1.0),  # Anthropic max is 1.0
+            }
+
+            if system_prompt:
+                payload["system"] = system_prompt
+            if stop:
+                payload["stop_sequences"] = stop
+            if tools:
+                payload["tools"] = self._convert_tools_to_anthropic(tools)
+
+            # Add any additional kwargs (top_p, top_k, etc.)
+ for key in ["top_p", "top_k", "metadata"]: + if key in kwargs: + payload[key] = kwargs[key] + + try: + response = await client.post("/v1/messages", json=payload) + except httpx.TimeoutException: + raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout) + except httpx.ConnectError: + raise LLMConnectionError(self.PROVIDER_NAME, self.base_url) + + if response.status_code != 200: + self._handle_error_response(response) + + return response + + try: + response = await _request() + self.circuit_breaker.record_success() + except Exception: + self.circuit_breaker.record_failure() + raise + + # Parse response + try: + data = response.json() + + # Extract text from content blocks + text_parts = [] + tool_calls = [] + + for block in data.get("content", []): + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif block.get("type") == "tool_use": + tool_calls.append( + ToolCall( + id=block.get("id", ""), + name=block.get("name", ""), + arguments=block.get("input", {}), + type="tool_use", + ) + ) + + text = "\n".join(text_parts) + + # Build usage dict + usage = { + "prompt_tokens": data.get("usage", {}).get("input_tokens", 0), + "completion_tokens": data.get("usage", {}).get("output_tokens", 0), + } + usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"] + + finish_reason = data.get("stop_reason", "stop") + + if tool_calls: + llm_response = LLMToolResponse( + text=text, + usage=usage, + model=data.get("model", self.model), + raw_response=data, + finish_reason=finish_reason, + tool_calls=tool_calls, + ) + else: + llm_response = LLMResponse( + text=text, + usage=usage, + model=data.get("model", self.model), + raw_response=data, + finish_reason=finish_reason, + ) + + self._update_stats(llm_response) + return llm_response + + except (KeyError, json.JSONDecodeError) as e: + raise LLMResponseParseError(self.PROVIDER_NAME, response.text) from e + + async def _generate_stream( + self, + *, + messages: list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int = 4096, + tools: list[dict] | None = None, + stop: list[str] | None = None, + **kwargs: Any, + ) -> AsyncIterator[str]: + """Streaming generation with Server-Sent Events.""" + + client = await self._get_client() + + # Convert messages + built_messages = self._build_messages(messages, prompt) + system_prompt, anthropic_messages = self._convert_messages_to_anthropic(built_messages) + + # Build request payload + payload = { + "model": self.model, + "messages": anthropic_messages, + "max_tokens": max_tokens, + "temperature": min(temperature, 1.0), + "stream": True, + } + + if system_prompt: + payload["system"] = system_prompt + if stop: + payload["stop_sequences"] = stop + if tools: + payload["tools"] = self._convert_tools_to_anthropic(tools) + + for key in ["top_p", "top_k"]: + if key in kwargs: + payload[key] = kwargs[key] + + async def stream_generator(): + try: + async with client.stream("POST", "/v1/messages", json=payload) as response: + if response.status_code != 200: + await response.aread() + self._handle_error_response(response) + + async for line in response.aiter_lines(): + if not line.strip(): + continue + + if line.startswith("event:"): + event_type = line[6:].strip() + continue + + if line.startswith("data:"): + data_str = line[5:].strip() + if not data_str: + continue + + try: + data = json.loads(data_str) + event_type = data.get("type", "") + + if event_type == "content_block_delta": + delta = data.get("delta", {}) + if delta.get("type") == 
"text_delta": + text = delta.get("text", "") + if text: + yield text + + elif event_type == "message_stop": + break + + except json.JSONDecodeError: + continue + + self.circuit_breaker.record_success() + + except httpx.TimeoutException: + self.circuit_breaker.record_failure() + raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout) + except httpx.ConnectError: + self.circuit_breaker.record_failure() + raise LLMConnectionError(self.PROVIDER_NAME, self.base_url) + except Exception as e: + self.circuit_breaker.record_failure() + if isinstance(e, LLMClientError): + raise + raise LLMStreamError(self.PROVIDER_NAME, str(e)) from e + + return stream_generator() + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None diff --git a/src/adapters/llm/base.py b/src/adapters/llm/base.py new file mode 100644 index 0000000000000000000000000000000000000000..31860b77004722a0d8f6f41c6fa9d59970140cd1 --- /dev/null +++ b/src/adapters/llm/base.py @@ -0,0 +1,305 @@ +""" +Base LLM client interface for provider-agnostic model access. + +This module defines the protocol and data structures for LLM clients, +enabling seamless switching between providers (OpenAI, Anthropic, LM Studio, etc.) +""" + +import asyncio +import time +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Protocol, runtime_checkable + + +@dataclass +class LLMResponse: + """Standardized response from any LLM provider.""" + + text: str + usage: dict = field(default_factory=dict) + model: str = "" + raw_response: Any = None + finish_reason: str = "stop" + created_at: datetime = field(default_factory=datetime.utcnow) + + @property + def total_tokens(self) -> int: + """Total tokens used in request/response.""" + return self.usage.get("total_tokens", 0) + + @property + def prompt_tokens(self) -> int: + """Tokens used in prompt.""" + return self.usage.get("prompt_tokens", 0) + + @property + def completion_tokens(self) -> int: + """Tokens used in completion.""" + return self.usage.get("completion_tokens", 0) + + +@dataclass +class ToolCall: + """Represents a tool/function call from the LLM.""" + + id: str + name: str + arguments: dict + type: str = "function" + + +@dataclass +class LLMToolResponse(LLMResponse): + """Response containing tool calls.""" + + tool_calls: list[ToolCall] = field(default_factory=list) + + +class TokenBucketRateLimiter: + """ + Token bucket rate limiter for controlling request rates. + + This implementation uses a token bucket algorithm where: + - Tokens are added at a fixed rate (rate_per_second) + - Each request consumes one token + - If no tokens available, caller waits until one becomes available + """ + + def __init__(self, rate_per_minute: int = 60): + """ + Initialize the rate limiter. + + Args: + rate_per_minute: Maximum requests allowed per minute + """ + self.rate_per_second = rate_per_minute / 60.0 + self.max_tokens = float(rate_per_minute) + self.tokens = self.max_tokens + self.last_refill = time.monotonic() + self._lock = asyncio.Lock() + self._wait_count = 0 + self._total_wait_time = 0.0 + + async def acquire(self) -> float: + """ + Acquire a token, waiting if necessary. 
+ + Returns: + Time spent waiting (0.0 if no wait was needed) + """ + async with self._lock: + now = time.monotonic() + elapsed = now - self.last_refill + + # Refill tokens based on elapsed time + self.tokens = min(self.max_tokens, self.tokens + elapsed * self.rate_per_second) + self.last_refill = now + + wait_time = 0.0 + if self.tokens < 1: + # Calculate how long to wait for one token + wait_time = (1 - self.tokens) / self.rate_per_second + self._wait_count += 1 + self._total_wait_time += wait_time + + # Release lock during sleep to allow other operations + self._lock.release() + try: + await asyncio.sleep(wait_time) + finally: + await self._lock.acquire() + + # After sleeping, update time and set tokens to 0 + self.last_refill = time.monotonic() + self.tokens = 0 + else: + self.tokens -= 1 + + return wait_time + + @property + def stats(self) -> dict: + """Get rate limiter statistics.""" + return { + "rate_limit_waits": self._wait_count, + "total_rate_limit_wait_time": self._total_wait_time, + "current_tokens": self.tokens, + } + + +@runtime_checkable +class LLMClient(Protocol): + """ + Protocol for LLM clients. + + This protocol defines the interface that all LLM provider adapters must implement. + Using Protocol allows for structural subtyping (duck typing) while maintaining + type safety. + """ + + async def generate( + self, + *, + messages: list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, + stream: bool = False, + stop: list[str] | None = None, + **kwargs: Any, + ) -> LLMResponse | AsyncIterator[str]: + """ + Generate a response from the LLM. + + Args: + messages: List of message dicts in OpenAI format [{"role": "...", "content": "..."}] + prompt: Simple string prompt (converted to single user message) + temperature: Sampling temperature (0.0 to 2.0) + max_tokens: Maximum tokens to generate + tools: List of tool definitions for function calling + stream: If True, returns AsyncIterator[str] for streaming + stop: Stop sequences + **kwargs: Provider-specific parameters + + Returns: + LLMResponse if stream=False, AsyncIterator[str] if stream=True + + Raises: + LLMClientError: Base exception for all client errors + """ + ... + + +class BaseLLMClient(ABC): + """ + Abstract base class for LLM clients. + + Provides common functionality and enforces the interface contract. + All concrete implementations should inherit from this class. + """ + + def __init__( + self, + api_key: str | None = None, + model: str = "default", + base_url: str | None = None, + timeout: float = 60.0, + max_retries: int = 3, + rate_limit_per_minute: int | None = None, + ): + """ + Initialize the LLM client. 
+ + Args: + api_key: API key for authentication + model: Model identifier + base_url: Base URL for API requests + timeout: Request timeout in seconds + max_retries: Maximum number of retry attempts + rate_limit_per_minute: Rate limit (requests per minute), None to disable + """ + self.api_key = api_key + self.model = model + self.base_url = base_url + self.timeout = timeout + self.max_retries = max_retries + self._request_count = 0 + self._total_tokens_used = 0 + self._rate_limited_requests = 0 + + # Initialize rate limiter if configured + if rate_limit_per_minute is not None and rate_limit_per_minute > 0: + self._rate_limiter: TokenBucketRateLimiter | None = TokenBucketRateLimiter( + rate_per_minute=rate_limit_per_minute + ) + else: + self._rate_limiter = None + + @abstractmethod + async def generate( + self, + *, + messages: list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, + stream: bool = False, + stop: list[str] | None = None, + **kwargs: Any, + ) -> LLMResponse | AsyncIterator[str]: + """Generate a response from the LLM.""" + pass + + def _build_messages( + self, + messages: list[dict] | None = None, + prompt: str | None = None, + ) -> list[dict]: + """ + Build message list from either messages or prompt. + + Args: + messages: Pre-formatted message list + prompt: Simple string prompt + + Returns: + List of message dicts + + Raises: + ValueError: If neither messages nor prompt provided + """ + if messages is not None: + return messages + elif prompt is not None: + return [{"role": "user", "content": prompt}] + else: + raise ValueError("Either 'messages' or 'prompt' must be provided") + + def _update_stats(self, response: LLMResponse) -> None: + """Update internal statistics.""" + self._request_count += 1 + self._total_tokens_used += response.total_tokens + + async def _apply_rate_limit(self) -> None: + """ + Apply rate limiting if configured. + + Waits if necessary to comply with rate limits. + Tracks rate-limited requests in metrics. + """ + if self._rate_limiter is not None: + wait_time = await self._rate_limiter.acquire() + if wait_time > 0: + self._rate_limited_requests += 1 + + @property + def stats(self) -> dict: + """Get client statistics.""" + base_stats = { + "request_count": self._request_count, + "total_tokens_used": self._total_tokens_used, + "rate_limited_requests": self._rate_limited_requests, + } + + # Include rate limiter stats if available + if self._rate_limiter is not None: + base_stats.update(self._rate_limiter.stats) + + return base_stats + + async def close(self) -> None: # noqa: B027 + """Clean up resources. Override in subclasses if needed.""" + pass + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() diff --git a/src/adapters/llm/exceptions.py b/src/adapters/llm/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..528231d3107b0b650f558d8c71091b34ebe683a1 --- /dev/null +++ b/src/adapters/llm/exceptions.py @@ -0,0 +1,204 @@ +""" +Custom exceptions for LLM client operations. + +Provides a hierarchy of structured exceptions for better error handling +and debugging across different LLM providers. 
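+
+Typical handling pattern (illustrative sketch, not prescriptive; "client" is any
+object implementing the LLMClient protocol and "logger" is a standard-library
+logger, both assumed to exist in the calling code):
+
+    try:
+        response = await client.generate(prompt="ping")
+    except LLMRateLimitError as exc:
+        # Honour the provider-suggested backoff when one is given
+        await asyncio.sleep(exc.retry_after or 1.0)
+    except LLMClientError as exc:
+        # Every provider-specific error derives from LLMClientError
+        logger.error("LLM request failed: %s", exc)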
+""" + + +class LLMClientError(Exception): + """Base exception for all LLM client errors.""" + + def __init__( + self, + message: str, + provider: str = "unknown", + status_code: int | None = None, + retry_after: float | None = None, + ): + self.message = message + self.provider = provider + self.status_code = status_code + self.retry_after = retry_after + super().__init__(self.message) + + def __str__(self) -> str: + parts = [f"[{self.provider}] {self.message}"] + if self.status_code: + parts.append(f"(status: {self.status_code})") + return " ".join(parts) + + +class LLMAuthenticationError(LLMClientError): + """Authentication failed - invalid or missing API key.""" + + def __init__(self, provider: str, message: str = "Authentication failed"): + super().__init__( + message=message, + provider=provider, + status_code=401, + ) + + +class LLMRateLimitError(LLMClientError): + """Rate limit exceeded - too many requests.""" + + def __init__( + self, + provider: str, + retry_after: float | None = None, + message: str = "Rate limit exceeded", + ): + super().__init__( + message=message, + provider=provider, + status_code=429, + retry_after=retry_after, + ) + + +class LLMQuotaExceededError(LLMClientError): + """Quota or credits exhausted.""" + + def __init__(self, provider: str, message: str = "Quota exceeded"): + super().__init__( + message=message, + provider=provider, + status_code=402, + ) + + +class LLMModelNotFoundError(LLMClientError): + """Requested model not available.""" + + def __init__(self, provider: str, model: str): + super().__init__( + message=f"Model '{model}' not found or not available", + provider=provider, + status_code=404, + ) + + +class LLMContextLengthError(LLMClientError): + """Input exceeds model's context window.""" + + def __init__( + self, + provider: str, + token_count: int | None = None, + max_tokens: int | None = None, + ): + message = "Context length exceeded" + if token_count and max_tokens: + message = f"Context length exceeded: {token_count} tokens provided, max is {max_tokens}" + super().__init__( + message=message, + provider=provider, + status_code=400, + ) + + +class LLMInvalidRequestError(LLMClientError): + """Invalid request parameters.""" + + def __init__(self, provider: str, message: str = "Invalid request parameters"): + super().__init__( + message=message, + provider=provider, + status_code=400, + ) + + +class LLMTimeoutError(LLMClientError): + """Request timed out.""" + + def __init__(self, provider: str, timeout: float): + super().__init__( + message=f"Request timed out after {timeout}s", + provider=provider, + status_code=408, + ) + + +class LLMConnectionError(LLMClientError): + """Failed to connect to the API endpoint.""" + + def __init__(self, provider: str, url: str | None = None): + message = "Failed to connect to API" + if url: + message = f"Failed to connect to {url}" + super().__init__( + message=message, + provider=provider, + ) + + +class LLMServerError(LLMClientError): + """Server-side error from the LLM provider.""" + + def __init__( + self, + provider: str, + status_code: int = 500, + message: str = "Server error", + ): + super().__init__( + message=message, + provider=provider, + status_code=status_code, + ) + + +class LLMResponseParseError(LLMClientError): + """Failed to parse response from LLM provider.""" + + def __init__(self, provider: str, raw_response: str | None = None): + message = "Failed to parse response" + if raw_response: + preview = raw_response[:200] + "..." 
if len(raw_response) > 200 else raw_response + message = f"Failed to parse response: {preview}" + super().__init__( + message=message, + provider=provider, + ) + + +class LLMStreamError(LLMClientError): + """Error during streaming response.""" + + def __init__(self, provider: str, message: str = "Stream interrupted"): + super().__init__( + message=message, + provider=provider, + ) + + +class LLMContentFilterError(LLMClientError): + """Content blocked by safety filters.""" + + def __init__(self, provider: str, reason: str | None = None): + message = "Content blocked by safety filters" + if reason: + message = f"Content blocked: {reason}" + super().__init__( + message=message, + provider=provider, + status_code=400, + ) + + +class CircuitBreakerOpenError(LLMClientError): + """Circuit breaker is open, requests are being blocked.""" + + def __init__( + self, + provider: str, + failure_count: int, + reset_time: float, + ): + super().__init__( + message=f"Circuit breaker open after {failure_count} failures. Resets in {reset_time:.1f}s", + provider=provider, + ) + self.failure_count = failure_count + self.reset_time = reset_time diff --git a/src/adapters/llm/lmstudio_client.py b/src/adapters/llm/lmstudio_client.py new file mode 100644 index 0000000000000000000000000000000000000000..50091d1cb9a8f680c8a963226442a68102f2eb2e --- /dev/null +++ b/src/adapters/llm/lmstudio_client.py @@ -0,0 +1,346 @@ +""" +LM Studio local LLM client adapter. + +Implements the LLMClient protocol for LM Studio's OpenAI-compatible API. +Designed for running local models with configurable endpoint. +""" + +import json +import logging +from collections.abc import AsyncIterator +from typing import Any + +import httpx + +from .base import BaseLLMClient, LLMResponse +from .exceptions import ( + LLMClientError, + LLMConnectionError, + LLMResponseParseError, + LLMServerError, + LLMStreamError, + LLMTimeoutError, +) + +logger = logging.getLogger(__name__) + + +class LMStudioClient(BaseLLMClient): + """ + LM Studio local server client. + + LM Studio provides an OpenAI-compatible API for running local models. + This client is optimized for local deployment with: + - No authentication required (local) + - Configurable base URL + - No circuit breaker (local server expected to be stable) + - Longer timeouts for large models + """ + + PROVIDER_NAME = "lmstudio" + DEFAULT_BASE_URL = "http://localhost:1234/v1" + DEFAULT_MODEL = "local-model" # LM Studio uses the loaded model + + def __init__( + self, + api_key: str | None = None, # Not required for local + model: str | None = None, + base_url: str | None = None, + timeout: float = 300.0, # Long timeout for local inference + max_retries: int = 2, # Fewer retries for local + # Rate limiting + rate_limit_per_minute: int | None = None, + ): + """ + Initialize LM Studio client. 
+ + Args: + api_key: Not required for local server (ignored) + model: Model identifier (often ignored by LM Studio, uses loaded model) + base_url: Local server URL (default: http://localhost:1234/v1) + timeout: Request timeout in seconds (default longer for local models) + max_retries: Max retry attempts (fewer for local) + rate_limit_per_minute: Rate limit for requests per minute (None to disable) + """ + import os + + # Allow overriding via environment variable + base_url = base_url or os.environ.get("LMSTUDIO_BASE_URL", self.DEFAULT_BASE_URL) + + super().__init__( + api_key=api_key or "not-required", # Placeholder + model=model or self.DEFAULT_MODEL, + base_url=base_url, + timeout=timeout, + max_retries=max_retries, + rate_limit_per_minute=rate_limit_per_minute, + ) + + self._client: httpx.AsyncClient | None = None + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create the HTTP client.""" + if self._client is None or self._client.is_closed: + headers = {"Content-Type": "application/json"} + + # Add auth header if provided (some local servers may require it) + if self.api_key and self.api_key != "not-required": + headers["Authorization"] = f"Bearer {self.api_key}" + + self._client = httpx.AsyncClient( + base_url=self.base_url, + headers=headers, + timeout=httpx.Timeout(self.timeout), + ) + return self._client + + async def check_health(self) -> bool: + """ + Check if LM Studio server is running. + + Returns: + True if server is accessible, False otherwise + """ + try: + client = await self._get_client() + response = await client.get("/models") + return response.status_code == 200 + except Exception: + return False + + async def list_models(self) -> list[dict]: + """ + List available models on the LM Studio server. + + Returns: + List of model information dicts + """ + try: + client = await self._get_client() + response = await client.get("/models") + if response.status_code == 200: + data = response.json() + return data.get("data", []) + return [] + except Exception as e: + logger.warning(f"Failed to list models: {e}") + return [] + + def _handle_error_response(self, response: httpx.Response) -> None: + """Handle error responses from LM Studio server.""" + status_code = response.status_code + + try: + error_data = response.json() + error_message = error_data.get("error", {}).get("message", response.text) + except Exception: + error_message = response.text + + if status_code >= 500: + raise LLMServerError(self.PROVIDER_NAME, status_code, error_message) + else: + raise LLMClientError(error_message, self.PROVIDER_NAME, status_code=status_code) + + async def generate( + self, + *, + messages: list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, + stream: bool = False, + stop: list[str] | None = None, + **kwargs: Any, + ) -> LLMResponse | AsyncIterator[str]: + """ + Generate a response from LM Studio local model. 
+ + Args: + messages: Chat messages in OpenAI format + prompt: Simple string prompt + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + tools: Tool definitions (limited support in local models) + stream: If True, returns AsyncIterator + stop: Stop sequences + **kwargs: Additional parameters + + Returns: + LLMResponse or AsyncIterator[str] for streaming + """ + # Apply rate limiting before proceeding + await self._apply_rate_limit() + + if stream: + return self._generate_stream( + messages=messages, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + stop=stop, + **kwargs, + ) + else: + return await self._generate_non_stream( + messages=messages, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + stop=stop, + **kwargs, + ) + + async def _generate_non_stream( + self, + *, + messages: list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, + stop: list[str] | None = None, + **kwargs: Any, + ) -> LLMResponse: + """Non-streaming generation.""" + client = await self._get_client() + + # Build request payload (OpenAI-compatible) + payload = { + "model": self.model, + "messages": self._build_messages(messages, prompt), + "temperature": temperature, + } + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if stop: + payload["stop"] = stop + + # Note: most local models don't support tools well + if tools: + logger.warning("Tool calling may not be fully supported by local models") + payload["tools"] = tools + + # Add additional kwargs (e.g., top_p, repeat_penalty) + for key in ["top_p", "top_k", "repeat_penalty", "presence_penalty", "frequency_penalty"]: + if key in kwargs: + payload[key] = kwargs[key] + + # Retry logic for local server + last_error = None + for attempt in range(self.max_retries): + try: + response = await client.post("/chat/completions", json=payload) + + if response.status_code != 200: + self._handle_error_response(response) + + # Parse response + try: + data = response.json() + choice = data["choices"][0] + message = choice["message"] + + usage = data.get("usage", {}) + finish_reason = choice.get("finish_reason", "stop") + + llm_response = LLMResponse( + text=message.get("content", ""), + usage=usage, + model=data.get("model", self.model), + raw_response=data, + finish_reason=finish_reason, + ) + + self._update_stats(llm_response) + return llm_response + + except (KeyError, json.JSONDecodeError) as e: + raise LLMResponseParseError(self.PROVIDER_NAME, response.text) from e + + except httpx.TimeoutException: + last_error = LLMTimeoutError(self.PROVIDER_NAME, self.timeout) + logger.warning(f"Attempt {attempt + 1} timed out, retrying...") + except httpx.ConnectError: + last_error = LLMConnectionError(self.PROVIDER_NAME, self.base_url) + logger.warning(f"Attempt {attempt + 1} connection failed, retrying...") + except LLMClientError: + raise # Don't retry client errors + + # All retries exhausted + if last_error: + raise last_error + raise LLMConnectionError(self.PROVIDER_NAME, self.base_url) + + async def _generate_stream( + self, + *, + messages: list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, # noqa: ARG002 + stop: list[str] | None = None, + **kwargs: Any, + ) -> AsyncIterator[str]: + """Streaming generation.""" + client = await self._get_client() + + # Build request payload + 
payload = { + "model": self.model, + "messages": self._build_messages(messages, prompt), + "temperature": temperature, + "stream": True, + } + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if stop: + payload["stop"] = stop + + for key in ["top_p", "top_k", "repeat_penalty"]: + if key in kwargs: + payload[key] = kwargs[key] + + async def stream_generator(): + try: + async with client.stream("POST", "/chat/completions", json=payload) as response: + if response.status_code != 200: + await response.aread() + self._handle_error_response(response) + + async for line in response.aiter_lines(): + if line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + + try: + data = json.loads(data_str) + delta = data["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + yield content + except (json.JSONDecodeError, KeyError): + continue + + except httpx.TimeoutException: + raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout) + except httpx.ConnectError: + raise LLMConnectionError(self.PROVIDER_NAME, self.base_url) + except Exception as e: + if isinstance(e, LLMClientError): + raise + raise LLMStreamError(self.PROVIDER_NAME, str(e)) from e + + return stream_generator() + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None diff --git a/src/adapters/llm/openai_client.py b/src/adapters/llm/openai_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4099510ef5f496bbc8a76ab40d70092b071c5995 --- /dev/null +++ b/src/adapters/llm/openai_client.py @@ -0,0 +1,458 @@ +""" +OpenAI-compatible LLM client adapter. + +Implements the LLMClient protocol for OpenAI API (and compatible APIs). +Includes retry logic, circuit breaker pattern, and streaming support. 
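+
+Minimal usage sketch (the API key below is a placeholder; the model name shown
+is the client default and can be overridden):
+
+    client = OpenAIClient(api_key="sk-...", model="gpt-4-turbo-preview")
+    response = await client.generate(prompt="Explain circuit breakers in one line.")
+    print(response.text, response.total_tokens)
+    await client.close()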
+""" + +import json +import logging +import time +from collections.abc import AsyncIterator +from typing import Any + +import httpx +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from .base import BaseLLMClient, LLMResponse, LLMToolResponse, ToolCall +from .exceptions import ( + CircuitBreakerOpenError, + LLMAuthenticationError, + LLMClientError, + LLMConnectionError, + LLMContextLengthError, + LLMInvalidRequestError, + LLMModelNotFoundError, + LLMQuotaExceededError, + LLMRateLimitError, + LLMResponseParseError, + LLMServerError, + LLMStreamError, + LLMTimeoutError, +) + +logger = logging.getLogger(__name__) + + +class CircuitBreaker: + """Simple circuit breaker implementation for resilience.""" + + def __init__( + self, + failure_threshold: int = 5, + reset_timeout: float = 60.0, + half_open_max_calls: int = 1, + ): + self.failure_threshold = failure_threshold + self.reset_timeout = reset_timeout + self.half_open_max_calls = half_open_max_calls + self.failure_count = 0 + self.last_failure_time = 0.0 + self.state = "closed" # closed, open, half-open + self.half_open_calls = 0 + + def can_execute(self) -> bool: + """Check if request can be executed.""" + if self.state == "closed": + return True + + if self.state == "open": + # Check if reset timeout has passed + if time.time() - self.last_failure_time >= self.reset_timeout: + self.state = "half-open" + self.half_open_calls = 0 + return True + return False + + if self.state == "half-open": + return self.half_open_calls < self.half_open_max_calls + + return False + + def record_success(self) -> None: + """Record successful request.""" + if self.state == "half-open": + self.state = "closed" + self.failure_count = 0 + elif self.state == "closed": + self.failure_count = 0 + + def record_failure(self) -> None: + """Record failed request.""" + self.failure_count += 1 + self.last_failure_time = time.time() + + if self.state == "half-open" or self.failure_count >= self.failure_threshold: + self.state = "open" + + def get_reset_time(self) -> float: + """Get time until circuit resets.""" + if self.state != "open": + return 0.0 + elapsed = time.time() - self.last_failure_time + return max(0, self.reset_timeout - elapsed) + + +class OpenAIClient(BaseLLMClient): + """ + OpenAI API client with retry logic and circuit breaker. + + Features: + - Exponential backoff retry for transient errors + - Circuit breaker to prevent cascading failures + - Streaming support + - Structured error handling + - Tool/function calling support + """ + + PROVIDER_NAME = "openai" + DEFAULT_BASE_URL = "https://api.openai.com/v1" + DEFAULT_MODEL = "gpt-4-turbo-preview" + + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + base_url: str | None = None, + timeout: float = 60.0, + max_retries: int = 3, + organization: str | None = None, + # Circuit breaker settings + circuit_breaker_threshold: int = 5, + circuit_breaker_reset: float = 60.0, + # Rate limiting + rate_limit_per_minute: int | None = None, + ): + """ + Initialize OpenAI client. 
+ + Args: + api_key: OpenAI API key (or set OPENAI_API_KEY env var) + model: Model to use (default: gpt-4-turbo-preview) + base_url: API base URL (default: https://api.openai.com/v1) + timeout: Request timeout in seconds + max_retries: Max retry attempts for transient errors + organization: Optional organization ID + circuit_breaker_threshold: Failures before circuit opens + circuit_breaker_reset: Seconds before circuit resets + rate_limit_per_minute: Rate limit for requests per minute (None to disable) + """ + import os + + api_key = api_key or os.environ.get("OPENAI_API_KEY") + if not api_key: + raise LLMAuthenticationError(self.PROVIDER_NAME, "API key not provided and OPENAI_API_KEY not set") + + super().__init__( + api_key=api_key, + model=model or self.DEFAULT_MODEL, + base_url=base_url or self.DEFAULT_BASE_URL, + timeout=timeout, + max_retries=max_retries, + rate_limit_per_minute=rate_limit_per_minute, + ) + + self.organization = organization + self.circuit_breaker = CircuitBreaker( + failure_threshold=circuit_breaker_threshold, + reset_timeout=circuit_breaker_reset, + ) + + # Initialize async HTTP client + self._client: httpx.AsyncClient | None = None + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create the HTTP client.""" + if self._client is None or self._client.is_closed: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + if self.organization: + headers["OpenAI-Organization"] = self.organization + + self._client = httpx.AsyncClient( + base_url=self.base_url, + headers=headers, + timeout=httpx.Timeout(self.timeout), + ) + return self._client + + def _handle_error_response(self, response: httpx.Response) -> None: + """Convert HTTP error responses to appropriate exceptions.""" + status_code = response.status_code + + try: + error_data = response.json() + error_message = error_data.get("error", {}).get("message", response.text) + except Exception: + error_message = response.text + + if status_code == 401: + raise LLMAuthenticationError(self.PROVIDER_NAME, error_message) + elif status_code == 429: + retry_after = response.headers.get("Retry-After") + retry_after_float = float(retry_after) if retry_after else None + raise LLMRateLimitError(self.PROVIDER_NAME, retry_after=retry_after_float, message=error_message) + elif status_code == 402: + raise LLMQuotaExceededError(self.PROVIDER_NAME, error_message) + elif status_code == 404: + raise LLMModelNotFoundError(self.PROVIDER_NAME, self.model) + elif status_code == 400: + if "context_length" in error_message.lower(): + raise LLMContextLengthError(self.PROVIDER_NAME) + raise LLMInvalidRequestError(self.PROVIDER_NAME, error_message) + elif status_code >= 500: + raise LLMServerError(self.PROVIDER_NAME, status_code, error_message) + else: + raise LLMClientError(error_message, self.PROVIDER_NAME, status_code=status_code) + + def _make_retry_decorator(self): + """Create retry decorator with exponential backoff.""" + return retry( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential(multiplier=1, min=1, max=60), + retry=retry_if_exception_type((LLMRateLimitError, LLMServerError, LLMConnectionError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + + async def generate( + self, + *, + messages: list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, + stream: bool = False, + stop: list[str] | None = None, + **kwargs: Any, + ) -> 
LLMResponse | AsyncIterator[str]: + """ + Generate a response from OpenAI. + + Args: + messages: Chat messages in OpenAI format + prompt: Simple string prompt + temperature: Sampling temperature (0.0 to 2.0) + max_tokens: Maximum tokens to generate + tools: Tool definitions for function calling + stream: If True, returns AsyncIterator + stop: Stop sequences + **kwargs: Additional OpenAI parameters (top_p, presence_penalty, etc.) + + Returns: + LLMResponse or AsyncIterator[str] for streaming + """ + # Apply rate limiting before proceeding + await self._apply_rate_limit() + + # Check circuit breaker + if not self.circuit_breaker.can_execute(): + raise CircuitBreakerOpenError( + self.PROVIDER_NAME, + self.circuit_breaker.failure_count, + self.circuit_breaker.get_reset_time(), + ) + + if stream: + return self._generate_stream( + messages=messages, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + stop=stop, + **kwargs, + ) + else: + return await self._generate_non_stream( + messages=messages, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + stop=stop, + **kwargs, + ) + + async def _generate_non_stream( + self, + *, + messages: list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, + stop: list[str] | None = None, + **kwargs: Any, + ) -> LLMResponse: + """Non-streaming generation with retry logic.""" + + @self._make_retry_decorator() + async def _request(): + client = await self._get_client() + + # Build request payload + payload = { + "model": self.model, + "messages": self._build_messages(messages, prompt), + "temperature": temperature, + } + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if stop: + payload["stop"] = stop + if tools: + payload["tools"] = tools + payload["tool_choice"] = kwargs.pop("tool_choice", "auto") + + # Add any additional kwargs + payload.update(kwargs) + + try: + response = await client.post("/chat/completions", json=payload) + except httpx.TimeoutException: + raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout) + except httpx.ConnectError: + raise LLMConnectionError(self.PROVIDER_NAME, self.base_url) + + if response.status_code != 200: + self._handle_error_response(response) + + return response + + try: + response = await _request() + self.circuit_breaker.record_success() + except Exception: + self.circuit_breaker.record_failure() + raise + + # Parse response + try: + data = response.json() + choice = data["choices"][0] + message = choice["message"] + + usage = data.get("usage", {}) + finish_reason = choice.get("finish_reason", "stop") + + # Check for tool calls + if "tool_calls" in message: + tool_calls = [ + ToolCall( + id=tc["id"], + name=tc["function"]["name"], + arguments=json.loads(tc["function"]["arguments"]), + ) + for tc in message["tool_calls"] + ] + llm_response = LLMToolResponse( + text=message.get("content", ""), + usage=usage, + model=data.get("model", self.model), + raw_response=data, + finish_reason=finish_reason, + tool_calls=tool_calls, + ) + else: + llm_response = LLMResponse( + text=message.get("content", ""), + usage=usage, + model=data.get("model", self.model), + raw_response=data, + finish_reason=finish_reason, + ) + + self._update_stats(llm_response) + return llm_response + + except (KeyError, json.JSONDecodeError) as e: + raise LLMResponseParseError(self.PROVIDER_NAME, response.text) from e + + async def _generate_stream( + self, + *, + messages: 
list[dict] | None = None, + prompt: str | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, + stop: list[str] | None = None, + **kwargs: Any, + ) -> AsyncIterator[str]: + """Streaming generation.""" + + client = await self._get_client() + + # Build request payload + payload = { + "model": self.model, + "messages": self._build_messages(messages, prompt), + "temperature": temperature, + "stream": True, + } + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if stop: + payload["stop"] = stop + # Note: tools with streaming have limited support + if tools: + payload["tools"] = tools + + payload.update(kwargs) + + async def stream_generator(): + try: + async with client.stream("POST", "/chat/completions", json=payload) as response: + if response.status_code != 200: + # Read the full response for error handling + await response.aread() + self._handle_error_response(response) + + async for line in response.aiter_lines(): + if line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + + try: + data = json.loads(data_str) + delta = data["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + yield content + except (json.JSONDecodeError, KeyError): + continue + + self.circuit_breaker.record_success() + + except httpx.TimeoutException: + self.circuit_breaker.record_failure() + raise LLMTimeoutError(self.PROVIDER_NAME, self.timeout) + except httpx.ConnectError: + self.circuit_breaker.record_failure() + raise LLMConnectionError(self.PROVIDER_NAME, self.base_url) + except Exception as e: + self.circuit_breaker.record_failure() + if isinstance(e, LLMClientError): + raise + raise LLMStreamError(self.PROVIDER_NAME, str(e)) from e + + return stream_generator() + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None diff --git a/src/agents/__init__.py b/src/agents/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/agents/hrm_agent.py b/src/agents/hrm_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..bd3bc482f901ca777de39904d1f4c9e4af612f7c --- /dev/null +++ b/src/agents/hrm_agent.py @@ -0,0 +1,454 @@ +""" +Hierarchical Reasoning Model (HRM) Agent. 
+ +Implements the HRM architecture with: +- H-Module: High-level planning and decomposition +- L-Module: Low-level execution and refinement +- Adaptive Computation Time (ACT) for dynamic depth +- Halting mechanism based on confidence thresholds + +Based on: "Hierarchical Reasoning for Compositional Generalization" +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..training.system_config import HRMConfig + + +@dataclass +class SubProblem: + """Represents a decomposed subproblem in the hierarchy.""" + + level: int # Hierarchy level (0 = root, higher = more abstract) + description: str # Natural language description + state: torch.Tensor # Latent state representation + parent_id: int | None = None # Parent subproblem ID + confidence: float = 0.0 # Confidence in this decomposition + + +@dataclass +class HRMOutput: + """Output from HRM processing.""" + + final_state: torch.Tensor # Final processed state + subproblems: list[SubProblem] # Hierarchical decomposition + halt_step: int # Step at which halting occurred + total_ponder_cost: float # Total computation cost (for training) + convergence_path: list[float] # Confidence at each step + + +class AdaptiveComputationTime(nn.Module): + """ + Adaptive Computation Time (ACT) mechanism for dynamic depth. + + Allows the model to "ponder" longer on difficult problems by + dynamically adjusting the number of processing steps. + """ + + def __init__(self, hidden_dim: int, epsilon: float = 0.01): + super().__init__() + self.epsilon = epsilon + + # Halting unit: predicts probability of halting + self.halt_fc = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim // 2), + nn.ReLU(), + nn.Linear(hidden_dim // 2, 1), + nn.Sigmoid(), + ) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, float]: + """ + Compute halting probabilities. + + Args: + hidden_states: [batch, seq, hidden_dim] + + Returns: + halt_probs: [batch, seq] probability of halting + ponder_cost: Scalar cost for training + """ + # Compute halting probabilities + halt_logits = self.halt_fc(hidden_states) # [batch, seq, 1] + halt_probs = halt_logits.squeeze(-1) # [batch, seq] + + # Ponder cost is the expected number of steps + ponder_cost = halt_probs.sum(dim=-1).mean() + + return halt_probs, ponder_cost + + +class HModule(nn.Module): + """ + H-Module: High-level planning and abstract reasoning. + + Responsible for: + - Decomposing problems into subproblems + - Abstract planning and strategy + - Coordinating L-module executions + """ + + def __init__(self, config: HRMConfig): + super().__init__() + self.config = config + + # Multi-head self-attention for relational reasoning + self.attention = nn.MultiheadAttention( + embed_dim=config.h_dim, + num_heads=8, + dropout=config.dropout, + batch_first=True, + ) + + # Feed-forward network + self.ffn = nn.Sequential( + nn.Linear(config.h_dim, config.h_dim * 4), + nn.GELU(), + nn.Dropout(config.dropout), + nn.Linear(config.h_dim * 4, config.h_dim), + nn.Dropout(config.dropout), + ) + + # Layer normalization + self.norm1 = nn.LayerNorm(config.h_dim) + self.norm2 = nn.LayerNorm(config.h_dim) + + # Decomposition head: outputs subproblem structure + self.decompose_head = nn.Sequential( + nn.Linear(config.h_dim, config.h_dim), + nn.ReLU(), + nn.Linear(config.h_dim, config.h_dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Process input through high-level reasoning. 
+ + Args: + x: [batch, seq, h_dim] input tensor + + Returns: + [batch, seq, h_dim] processed tensor + """ + # Self-attention for relational reasoning + attn_out, _ = self.attention(x, x, x) + x = self.norm1(x + attn_out) + + # Feed-forward processing + ffn_out = self.ffn(x) + x = self.norm2(x + ffn_out) + + return x + + def decompose(self, x: torch.Tensor) -> torch.Tensor: + """Generate subproblem representations.""" + return self.decompose_head(x) + + +class LModule(nn.Module): + """ + L-Module: Low-level execution and concrete operations. + + Responsible for: + - Executing concrete operations + - Processing individual subproblems + - Generating intermediate results + """ + + def __init__(self, config: HRMConfig): + super().__init__() + self.config = config + + # Projection from H-module to L-module dimension + self.h_to_l = nn.Linear(config.h_dim, config.l_dim) + + # GRU for sequential processing + self.gru = nn.GRU( + input_size=config.l_dim, + hidden_size=config.l_dim, + num_layers=config.num_l_layers, + dropout=config.dropout if config.num_l_layers > 1 else 0, + batch_first=True, + ) + + # Output projection + self.output_proj = nn.Sequential( + nn.Linear(config.l_dim, config.l_dim * 2), + nn.ReLU(), + nn.Dropout(config.dropout), + nn.Linear(config.l_dim * 2, config.l_dim), + ) + + # Back-projection to H-module dimension + self.l_to_h = nn.Linear(config.l_dim, config.h_dim) + + def forward(self, x: torch.Tensor, h_context: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]: + """ + Execute low-level processing. + + Args: + x: [batch, seq, h_dim] input from H-module + h_context: Optional hidden state + + Returns: + output: [batch, seq, l_dim] processed output + l_to_h: [batch, seq, h_dim] back-projection to H-module + """ + # Project to L-module dimension + x_l = self.h_to_l(x) + + # Sequential processing + gru_out, _ = self.gru(x_l, h_context) + + # Output processing + output = self.output_proj(gru_out) + + # Back-project to H-module dimension for feedback + feedback = self.l_to_h(output) + + return output, feedback + + +class HRMAgent(nn.Module): + """ + Complete Hierarchical Reasoning Model agent. + + Combines H-module and L-module with ACT for adaptive computation. + """ + + def __init__(self, config: HRMConfig, device: str = "cpu"): + super().__init__() + self.config = config + self.device = device + + # Input embedding + self.input_proj = nn.Linear(config.h_dim, config.h_dim) + + # Core modules + self.h_module = nn.ModuleList([HModule(config) for _ in range(config.num_h_layers)]) + + self.l_module = LModule(config) + + # Adaptive computation time + self.act = AdaptiveComputationTime(config.h_dim, config.ponder_epsilon) + + # State integration + self.integrate = nn.Sequential( + nn.Linear(config.h_dim * 2, config.h_dim), + nn.LayerNorm(config.h_dim), + nn.GELU(), + ) + + self.to(device) + + def forward( + self, + x: torch.Tensor, + max_steps: int | None = None, + return_decomposition: bool = False, + ) -> HRMOutput: + """ + Process input through hierarchical reasoning. 
+ + Args: + x: [batch, seq, h_dim] input tensor + max_steps: Maximum outer loop steps (defaults to config) + return_decomposition: Whether to return subproblem decomposition + + Returns: + HRMOutput containing final state and optional decomposition + """ + batch_size, seq_len, _ = x.shape + max_steps = max_steps or self.config.max_outer_steps + + # Initial projection + h_state = self.input_proj(x) + + # Tracking + subproblems = [] + convergence_path = [] + total_ponder_cost = 0.0 + + # Outer loop: iterative refinement + for step in range(max_steps): + # H-module: high-level planning + for h_layer in self.h_module: + h_state = h_layer(h_state) + + # Check halting condition + halt_probs, ponder_cost = self.act(h_state) + total_ponder_cost += ponder_cost + + # Average halting probability across sequence + avg_halt_prob = halt_probs.mean().item() + convergence_path.append(avg_halt_prob) + + # Generate subproblem decomposition if requested + if return_decomposition: + subproblem_repr = self.h_module[0].decompose(h_state) + # Create subproblem entries (simplified) + for i in range(min(3, seq_len)): # Top 3 subproblems + subproblems.append( + SubProblem( + level=step, + description=f"Subproblem at step {step}, position {i}", + state=subproblem_repr[:, i, :].detach(), + confidence=halt_probs[:, i].mean().item(), + ) + ) + + # Halt if confident enough + if avg_halt_prob >= self.config.halt_threshold: + break + + # L-module: low-level execution + l_output, l_feedback = self.l_module(h_state) + + # Integrate L-module feedback + h_state = self.integrate(torch.cat([h_state, l_feedback], dim=-1)) + + return HRMOutput( + final_state=h_state, + subproblems=subproblems, + halt_step=step + 1, + total_ponder_cost=total_ponder_cost, + convergence_path=convergence_path, + ) + + async def decompose_problem(self, query: str, state: torch.Tensor) -> list[SubProblem]: + """ + Decompose a problem into hierarchical subproblems. + + Args: + query: Natural language problem description + state: Initial state representation + + Returns: + List of subproblems in hierarchical order + """ + # Ensure state is batched + if state.dim() == 2: + state = state.unsqueeze(0) # [1, seq, dim] + + # Forward pass with decomposition + output = self.forward(state, return_decomposition=True) + + # Add query context to subproblems + for i, sp in enumerate(output.subproblems): + sp.description = f"{query} -> Level {sp.level} Subproblem {i}" + + return output.subproblems + + def get_parameter_count(self) -> int: + """Return total number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +# Training utilities +class HRMLoss(nn.Module): + """ + Combined loss for HRM training. + + Includes: + - Task loss (e.g., cross-entropy for classification) + - Ponder cost regularization (encourages efficiency) + - Consistency loss (encourages stable convergence) + """ + + def __init__( + self, + task_weight: float = 1.0, + ponder_weight: float = 0.01, + consistency_weight: float = 0.1, + ): + super().__init__() + self.task_weight = task_weight + self.ponder_weight = ponder_weight + self.consistency_weight = consistency_weight + + def forward( + self, + hrm_output: HRMOutput, + predictions: torch.Tensor, + targets: torch.Tensor, + task_loss_fn: nn.Module, + ) -> tuple[torch.Tensor, dict]: + """ + Compute combined loss. 
+ + Args: + hrm_output: Output from HRM forward pass + predictions: Model predictions + targets: Ground truth targets + task_loss_fn: Loss function for the task + + Returns: + total_loss: Combined loss + loss_dict: Dictionary of individual loss components + """ + # Task loss + task_loss = task_loss_fn(predictions, targets) + + # Ponder cost (encourages efficiency) + ponder_loss = hrm_output.total_ponder_cost + + # Consistency loss (encourages monotonic convergence) + if len(hrm_output.convergence_path) > 1: + conv_tensor = torch.tensor(hrm_output.convergence_path) + # Penalize non-monotonic increases + diffs = conv_tensor[1:] - conv_tensor[:-1] + consistency_loss = F.relu(-diffs).mean() # Penalize decreases + else: + consistency_loss = torch.tensor(0.0) + + # Combine losses + total_loss = ( + self.task_weight * task_loss + self.ponder_weight * ponder_loss + self.consistency_weight * consistency_loss + ) + + loss_dict = { + "total": total_loss.item(), + "task": task_loss.item(), + "ponder": ponder_loss, + "consistency": consistency_loss.item(), + "halt_step": hrm_output.halt_step, + } + + return total_loss, loss_dict + + +def create_hrm_agent(config: HRMConfig, device: str = "cpu") -> HRMAgent: + """ + Factory function to create and initialize HRM agent. + + Args: + config: HRM configuration + device: Device to place model on + + Returns: + Initialized HRMAgent + """ + agent = HRMAgent(config, device) + + # Initialize weights + def init_weights(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.GRU): + for name, param in m.named_parameters(): + if "weight" in name: + nn.init.orthogonal_(param) + elif "bias" in name: + nn.init.zeros_(param) + + agent.apply(init_weights) + + return agent diff --git a/src/agents/meta_controller/__init__.py b/src/agents/meta_controller/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..88b175d96045c7dc2693833a3b09261536bd2a8c --- /dev/null +++ b/src/agents/meta_controller/__init__.py @@ -0,0 +1,45 @@ +""" +Neural Meta-Controller package for Multi-Agent MCTS Framework. + +This package provides the base infrastructure for neural network-based +meta-controllers that dynamically select which agent to route queries to. 
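+
+Illustrative routing sketch (assumes the optional transformers/peft dependencies
+are installed so that BERTMetaController is importable; the state dict is
+produced elsewhere by the framework):
+
+    controller = BERTMetaController()
+    features = controller.extract_features(state)
+    prediction = controller.predict(features)
+    # prediction.agent is one of "hrm", "trm", or "mcts"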
+""" + +from src.agents.meta_controller.base import ( + AbstractMetaController, + MetaControllerFeatures, + MetaControllerPrediction, +) +from src.agents.meta_controller.rnn_controller import ( + RNNMetaController, + RNNMetaControllerModel, +) +from src.agents.meta_controller.utils import ( + features_to_tensor, + features_to_text, + normalize_features, + one_hot_encode_agent, +) + +# Import BERT controller (may not be available if transformers/peft not installed) +try: + from src.agents.meta_controller.bert_controller import BERTMetaController # noqa: F401 + + _bert_available = True +except ImportError: + _bert_available = False + +__all__ = [ + "AbstractMetaController", + "MetaControllerFeatures", + "MetaControllerPrediction", + "normalize_features", + "one_hot_encode_agent", + "features_to_tensor", + "features_to_text", + "RNNMetaController", + "RNNMetaControllerModel", +] + +if _bert_available: + __all__.append("BERTMetaController") diff --git a/src/agents/meta_controller/base.py b/src/agents/meta_controller/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e35093c3a428a74724227f91e8eaac147fa029 --- /dev/null +++ b/src/agents/meta_controller/base.py @@ -0,0 +1,219 @@ +""" +Abstract base class for Neural Meta-Controllers. + +Provides the foundation for neural network-based meta-controllers that +dynamically select which agent (HRM, TRM, or MCTS) should handle a query. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class MetaControllerFeatures: + """ + Features extracted from the current agent state for meta-controller prediction. + + These features capture the current state of the multi-agent system, + including confidence scores from different agents and contextual information. + """ + + hrm_confidence: float + """Confidence score from the HRM (Human Response Model) agent.""" + + trm_confidence: float + """Confidence score from the TRM (Task Response Model) agent.""" + + mcts_value: float + """Value estimate from the MCTS (Monte Carlo Tree Search) process.""" + + consensus_score: float + """Agreement score between different agents.""" + + last_agent: str + """Name of the last agent used ('hrm', 'trm', 'mcts', or 'none').""" + + iteration: int + """Current iteration number in the reasoning process.""" + + query_length: int + """Length of the input query in characters.""" + + has_rag_context: bool + """Whether RAG (Retrieval-Augmented Generation) context is available.""" + + +@dataclass +class MetaControllerPrediction: + """ + Prediction output from the meta-controller. + + Contains the selected agent and associated confidence/probability information. + """ + + agent: str + """Name of the selected agent ('hrm', 'trm', or 'mcts').""" + + confidence: float + """Confidence score for the prediction (0.0 to 1.0).""" + + probabilities: dict[str, float] = field(default_factory=dict) + """Probability distribution over all possible agents.""" + + +class AbstractMetaController(ABC): + """ + Abstract base class for neural meta-controllers. + + This class defines the interface that all meta-controller implementations + must follow. Meta-controllers are responsible for deciding which agent + should handle a given query based on the current system state. + + Attributes: + AGENT_NAMES: List of valid agent names that can be selected. + name: Name of this meta-controller instance. + seed: Random seed for reproducibility. 
+ """ + + AGENT_NAMES = ["hrm", "trm", "mcts"] + + def __init__(self, name: str, seed: int = 42) -> None: + """ + Initialize the meta-controller. + + Args: + name: Name identifier for this meta-controller instance. + seed: Random seed for reproducibility. Defaults to 42. + """ + self.name = name + self.seed = seed + + @abstractmethod + def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction: + """ + Predict which agent should handle the current query. + + Args: + features: Features extracted from the current agent state. + + Returns: + Prediction containing the selected agent and confidence scores. + """ + pass + + @abstractmethod + def load_model(self, path: str) -> None: + """ + Load a trained model from disk. + + Args: + path: Path to the saved model file or directory. + """ + pass + + @abstractmethod + def save_model(self, path: str) -> None: + """ + Save the current model to disk. + + Args: + path: Path where the model should be saved. + """ + pass + + def extract_features(self, state: dict[str, Any]) -> MetaControllerFeatures: + """ + Extract meta-controller features from an AgentState dictionary. + + This method converts raw state information into the structured + MetaControllerFeatures format required for prediction. + + Args: + state: Dictionary containing agent state information. + Expected keys include: + - 'hrm_confidence' or nested in 'agent_confidences' + - 'trm_confidence' or nested in 'agent_confidences' + - 'mcts_value' or nested in 'mcts_state' + - 'consensus_score' + - 'last_agent' + - 'iteration' + - 'query' or 'query_length' + - 'rag_context' or 'has_rag_context' + + Returns: + MetaControllerFeatures instance with extracted values. + + Example: + >>> state = { + ... 'agent_confidences': {'hrm': 0.8, 'trm': 0.6}, + ... 'mcts_state': {'value': 0.75}, + ... 'consensus_score': 0.7, + ... 'last_agent': 'hrm', + ... 'iteration': 2, + ... 'query': 'What is machine learning?', + ... 'rag_context': 'ML is a subset of AI...' + ... 
} + >>> features = controller.extract_features(state) + """ + # Extract HRM confidence + if "hrm_confidence" in state: + hrm_confidence = float(state["hrm_confidence"]) + elif "agent_confidences" in state and isinstance(state["agent_confidences"], dict): + hrm_confidence = float(state["agent_confidences"].get("hrm", 0.0)) + else: + hrm_confidence = 0.0 + + # Extract TRM confidence + if "trm_confidence" in state: + trm_confidence = float(state["trm_confidence"]) + elif "agent_confidences" in state and isinstance(state["agent_confidences"], dict): + trm_confidence = float(state["agent_confidences"].get("trm", 0.0)) + else: + trm_confidence = 0.0 + + # Extract MCTS value + if "mcts_value" in state: + mcts_value = float(state["mcts_value"]) + elif "mcts_state" in state and isinstance(state["mcts_state"], dict): + mcts_value = float(state["mcts_state"].get("value", 0.0)) + else: + mcts_value = 0.0 + + # Extract consensus score + consensus_score = float(state.get("consensus_score", 0.0)) + + # Extract last agent + last_agent = str(state.get("last_agent", "none")) + if last_agent not in self.AGENT_NAMES and last_agent != "none": + last_agent = "none" + + # Extract iteration + iteration = int(state.get("iteration", 0)) + + # Extract query length + if "query_length" in state: + query_length = int(state["query_length"]) + elif "query" in state and isinstance(state["query"], str): + query_length = len(state["query"]) + else: + query_length = 0 + + # Extract has_rag_context + if "has_rag_context" in state: + has_rag_context = bool(state["has_rag_context"]) + elif "rag_context" in state: + has_rag_context = state["rag_context"] is not None and len(str(state["rag_context"])) > 0 + else: + has_rag_context = False + + return MetaControllerFeatures( + hrm_confidence=hrm_confidence, + trm_confidence=trm_confidence, + mcts_value=mcts_value, + consensus_score=consensus_score, + last_agent=last_agent, + iteration=iteration, + query_length=query_length, + has_rag_context=has_rag_context, + ) diff --git a/src/agents/meta_controller/bert_controller.py b/src/agents/meta_controller/bert_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..2b466036ef59475f362ded252ea73ca76e063ef6 --- /dev/null +++ b/src/agents/meta_controller/bert_controller.py @@ -0,0 +1,428 @@ +""" +BERT-based Meta-Controller with LoRA adapters for efficient fine-tuning. + +This module provides a BERT-based meta-controller that uses Low-Rank Adaptation (LoRA) +for parameter-efficient fine-tuning. The controller converts agent state features into +text and uses a sequence classification model to predict the optimal agent. +""" + +import warnings +from typing import Any + +import torch + +from src.agents.meta_controller.base import ( + AbstractMetaController, + MetaControllerFeatures, + MetaControllerPrediction, +) +from src.agents.meta_controller.utils import features_to_text + +# Handle optional transformers and peft imports gracefully +_TRANSFORMERS_AVAILABLE = False +_PEFT_AVAILABLE = False + +try: + from transformers import AutoModelForSequenceClassification, AutoTokenizer + + _TRANSFORMERS_AVAILABLE = True +except ImportError: + warnings.warn( + "transformers library not installed. Install it with: pip install transformers", + ImportWarning, + stacklevel=2, + ) + AutoTokenizer = None # type: ignore + AutoModelForSequenceClassification = None # type: ignore + +try: + from peft import LoraConfig, TaskType, get_peft_model + + _PEFT_AVAILABLE = True +except ImportError: + warnings.warn( + "peft library not installed. 
Install it with: pip install peft", + ImportWarning, + stacklevel=2, + ) + LoraConfig = None # type: ignore + TaskType = None # type: ignore + get_peft_model = None # type: ignore + + +class BERTMetaController(AbstractMetaController): + """ + BERT-based meta-controller with optional LoRA adapters for efficient fine-tuning. + + This controller converts agent state features into structured text and uses + a pre-trained BERT model (with optional LoRA adapters) to classify which + agent should handle the current query. LoRA enables parameter-efficient + fine-tuning by only training low-rank decomposition matrices. + + Attributes: + DEFAULT_MODEL_NAME: Default BERT model to use. + NUM_LABELS: Number of output labels (agents to choose from). + device: PyTorch device for tensor operations. + model_name: Name of the pre-trained model. + lora_r: LoRA rank parameter. + lora_alpha: LoRA alpha scaling parameter. + lora_dropout: LoRA dropout rate. + use_lora: Whether to use LoRA adapters. + tokenizer: BERT tokenizer for text processing. + model: BERT sequence classification model (with or without LoRA). + + Example: + >>> controller = BERTMetaController(name="BERTController", seed=42) + >>> features = MetaControllerFeatures( + ... hrm_confidence=0.8, + ... trm_confidence=0.6, + ... mcts_value=0.75, + ... consensus_score=0.7, + ... last_agent='hrm', + ... iteration=2, + ... query_length=150, + ... has_rag_context=True + ... ) + >>> prediction = controller.predict(features) + >>> prediction.agent in ['hrm', 'trm', 'mcts'] + True + >>> 0.0 <= prediction.confidence <= 1.0 + True + """ + + DEFAULT_MODEL_NAME = "prajjwal1/bert-mini" + NUM_LABELS = 3 + + def __init__( + self, + name: str = "BERTMetaController", + seed: int = 42, + model_name: str | None = None, + lora_r: int = 4, + lora_alpha: int = 16, + lora_dropout: float = 0.1, + device: str | None = None, + use_lora: bool = True, + ) -> None: + """ + Initialize the BERT meta-controller with optional LoRA adapters. + + Args: + name: Name identifier for this controller. Defaults to "BERTMetaController". + seed: Random seed for reproducibility. Defaults to 42. + model_name: Pre-trained model name from HuggingFace. If None, uses DEFAULT_MODEL_NAME. + lora_r: LoRA rank parameter (lower = more compression). Defaults to 4. + lora_alpha: LoRA alpha scaling parameter. Defaults to 16. + lora_dropout: Dropout rate for LoRA layers. Defaults to 0.1. + device: Device to run model on ('cpu', 'cuda', 'mps', etc.). + If None, auto-detects best available device. + use_lora: Whether to apply LoRA adapters to the model. Defaults to True. + + Raises: + ImportError: If transformers library is not installed. + ImportError: If use_lora is True and peft library is not installed. + + Example: + >>> controller = BERTMetaController( + ... name="CustomBERT", + ... seed=123, + ... lora_r=8, + ... lora_alpha=32, + ... use_lora=True + ... ) + """ + super().__init__(name=name, seed=seed) + + # Check for required dependencies + if not _TRANSFORMERS_AVAILABLE: + raise ImportError( + "transformers library is required for BERTMetaController. Install it with: pip install transformers" + ) + + if use_lora and not _PEFT_AVAILABLE: + raise ImportError("peft library is required for LoRA support. 
Install it with: pip install peft") + + # Set random seed for reproducibility + torch.manual_seed(seed) + + # Auto-detect device if not specified + if device is None: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + else: + self.device = torch.device(device) + + # Store configuration parameters + self.model_name = model_name if model_name is not None else self.DEFAULT_MODEL_NAME + self.lora_r = lora_r + self.lora_alpha = lora_alpha + self.lora_dropout = lora_dropout + self.use_lora = use_lora + + # Initialize tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + # Initialize base model for sequence classification + base_model = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=self.NUM_LABELS) + + # Apply LoRA adapters if requested + if self.use_lora: + lora_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + r=self.lora_r, + lora_alpha=self.lora_alpha, + lora_dropout=self.lora_dropout, + target_modules=["query", "value"], + ) + self.model = get_peft_model(base_model, lora_config) + else: + self.model = base_model + + # Move model to device + self.model = self.model.to(self.device) + + # Set model to evaluation mode + self.model.eval() + + # Initialize tokenization cache for performance optimization + self._tokenization_cache: dict[str, Any] = {} + + def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction: + """ + Predict which agent should handle the current query. + + Converts features to structured text, tokenizes the text, runs through + the BERT model, and returns a prediction with confidence scores. + + Args: + features: Features extracted from the current agent state. + + Returns: + Prediction containing the selected agent, confidence score, + and probability distribution over all agents. + + Example: + >>> controller = BERTMetaController() + >>> features = MetaControllerFeatures( + ... hrm_confidence=0.9, + ... trm_confidence=0.3, + ... mcts_value=0.5, + ... consensus_score=0.8, + ... last_agent='none', + ... iteration=0, + ... query_length=100, + ... has_rag_context=False + ... 
) + >>> pred = controller.predict(features) + >>> isinstance(pred.agent, str) + True + >>> isinstance(pred.confidence, float) + True + >>> len(pred.probabilities) == 3 + True + """ + # Convert features to structured text + text = features_to_text(features) + + # Check cache for tokenized text + if text in self._tokenization_cache: + inputs = self._tokenization_cache[text] + else: + # Tokenize the text + inputs = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ) + # Cache the tokenized result + self._tokenization_cache[text] = inputs + + # Move inputs to device + inputs = {key: value.to(self.device) for key, value in inputs.items()} + + # Perform inference without gradient tracking + with torch.no_grad(): + # Get logits from model + outputs = self.model(**inputs) + logits = outputs.logits + + # Apply softmax to get probabilities + probabilities = torch.nn.functional.softmax(logits, dim=-1) + + # Get predicted agent index (argmax) + predicted_idx = torch.argmax(probabilities, dim=-1).item() + + # Extract confidence for selected agent + confidence = probabilities[0, predicted_idx].item() + + # Create probability dictionary + prob_dict: dict[str, float] = {} + for i, agent_name in enumerate(self.AGENT_NAMES): + prob_dict[agent_name] = probabilities[0, i].item() + + # Get agent name + selected_agent = self.AGENT_NAMES[predicted_idx] + + return MetaControllerPrediction( + agent=selected_agent, + confidence=float(confidence), + probabilities=prob_dict, + ) + + def load_model(self, path: str) -> None: + """ + Load a trained model from disk. + + For LoRA models, loads the PEFT adapter weights. For base models, + loads the full state dictionary. + + Args: + path: Path to the saved model file or directory. + For LoRA models, this should be a directory containing + adapter_config.json and adapter_model.bin. + For base models, this should be a .pt or .pth file. + + Raises: + FileNotFoundError: If the model file or directory does not exist. + RuntimeError: If the state dict is incompatible with the model. + + Example: + >>> controller = BERTMetaController(use_lora=True) + >>> controller.load_model("/path/to/lora_adapter") + >>> controller = BERTMetaController(use_lora=False) + >>> controller.load_model("/path/to/model.pt") + """ + if self.use_lora: + # Load PEFT adapter weights + # For PEFT models, the path should be a directory containing adapter files + from peft import PeftModel + + # Get the base model from the PEFT wrapper + base_model = self.model.get_base_model() + + # Load the PEFT model from the saved path + self.model = PeftModel.from_pretrained(base_model, path) + self.model = self.model.to(self.device) + else: + # Load base model state dict + state_dict = torch.load(path, map_location=self.device, weights_only=True) + self.model.load_state_dict(state_dict) + + # Ensure model is in evaluation mode + self.model.eval() + + def save_model(self, path: str) -> None: + """ + Save the current model to disk. + + For LoRA models, saves the PEFT adapter weights. For base models, + saves the full state dictionary. + + Args: + path: Path where the model should be saved. + For LoRA models, this should be a directory path where + adapter_config.json and adapter_model.bin will be saved. + For base models, this should be a .pt or .pth file path. 
+ + Example: + >>> controller = BERTMetaController(use_lora=True) + >>> controller.save_model("/path/to/lora_adapter") + >>> controller = BERTMetaController(use_lora=False) + >>> controller.save_model("/path/to/model.pt") + """ + if self.use_lora: + # Save PEFT adapter weights + # This saves only the LoRA adapter weights, not the full model + self.model.save_pretrained(path) + else: + # Save base model state dict + torch.save(self.model.state_dict(), path) + + def clear_cache(self) -> None: + """ + Clear the tokenization cache. + + This method removes all cached tokenized inputs, freeing memory. + Useful when processing many different feature combinations or + when memory usage is a concern. + + Example: + >>> controller = BERTMetaController() + >>> # After many predictions... + >>> controller.clear_cache() + >>> info = controller.get_cache_info() + >>> info['cache_size'] == 0 + True + """ + self._tokenization_cache.clear() + + def get_cache_info(self) -> dict[str, Any]: + """ + Get information about the current tokenization cache. + + Returns: + Dictionary containing cache statistics: + - cache_size: Number of cached tokenizations + - cache_keys: List of cached text inputs (truncated for display) + + Example: + >>> controller = BERTMetaController() + >>> features = MetaControllerFeatures( + ... hrm_confidence=0.8, + ... trm_confidence=0.6, + ... mcts_value=0.75, + ... consensus_score=0.7, + ... last_agent='hrm', + ... iteration=2, + ... query_length=150, + ... has_rag_context=True + ... ) + >>> _ = controller.predict(features) + >>> info = controller.get_cache_info() + >>> 'cache_size' in info + True + >>> info['cache_size'] >= 1 + True + """ + # Truncate keys for display (first 50 chars) + truncated_keys = [key[:50] + "..." if len(key) > 50 else key for key in self._tokenization_cache] + + return { + "cache_size": len(self._tokenization_cache), + "cache_keys": truncated_keys, + } + + def get_trainable_parameters(self) -> dict[str, int]: + """ + Get the number of trainable and total parameters in the model. + + This is particularly useful for LoRA models to see the efficiency + gains from using low-rank adaptation. + + Returns: + Dictionary containing: + - total_params: Total number of parameters in the model + - trainable_params: Number of trainable parameters + - trainable_percentage: Percentage of parameters that are trainable + + Example: + >>> controller = BERTMetaController(use_lora=True) + >>> params = controller.get_trainable_parameters() + >>> params['trainable_percentage'] < 10.0 # LoRA trains <10% of params + True + """ + total_params = sum(p.numel() for p in self.model.parameters()) + trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + trainable_percentage = (trainable_params / total_params) * 100 if total_params > 0 else 0.0 + + return { + "total_params": total_params, + "trainable_params": trainable_params, + "trainable_percentage": round(trainable_percentage, 2), + } diff --git a/src/agents/meta_controller/config_loader.py b/src/agents/meta_controller/config_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..5df65ebacc42c944fa5a599142267350529ec6f5 --- /dev/null +++ b/src/agents/meta_controller/config_loader.py @@ -0,0 +1,304 @@ +""" +Configuration loader for the Neural Meta-Controller framework. + +This module provides dataclass-based configuration management for the Meta-Controller, +supporting both RNN and BERT-based neural network controllers with comprehensive +validation and serialization capabilities. 
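+ An illustrative configuration layout accepted by MetaControllerConfigLoader.load_from_yaml
+ (a minimal sketch; field names mirror the dataclasses below, values are examples only):
+
+ meta_controller:
+   enabled: false
+   type: rnn                    # "rnn" or "bert"
+   fallback_to_rule_based: true
+   rnn:
+     hidden_dim: 64
+     num_layers: 1
+     dropout: 0.1
+   inference:
+     device: null               # auto-detect
+     seed: 42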
+""" + +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any + +import yaml + + +@dataclass +class RNNConfig: + """ + Configuration for RNN-based Meta-Controller. + + Attributes: + hidden_dim: Hidden dimension size for RNN layers. Default is 64. + num_layers: Number of RNN layers. Default is 1. + dropout: Dropout rate for regularization. Default is 0.1. + model_path: Optional path to a pre-trained model file. None for untrained model. + """ + + hidden_dim: int = 64 + num_layers: int = 1 + dropout: float = 0.1 + model_path: str | None = None + + +@dataclass +class BERTConfig: + """ + Configuration for BERT-based Meta-Controller with LoRA fine-tuning. + + Attributes: + model_name: Name of the pre-trained BERT model from HuggingFace. + Default is "prajjwal1/bert-mini" for lightweight deployment. + use_lora: Whether to use LoRA (Low-Rank Adaptation) for efficient fine-tuning. + Default is True. + lora_r: LoRA rank parameter. Controls the rank of the low-rank matrices. + Default is 4. + lora_alpha: LoRA alpha parameter. Scaling factor for LoRA weights. + Default is 16. + lora_dropout: Dropout rate for LoRA layers. Default is 0.1. + model_path: Optional path to a trained LoRA adapter. None for base model only. + """ + + model_name: str = "prajjwal1/bert-mini" + use_lora: bool = True + lora_r: int = 4 + lora_alpha: int = 16 + lora_dropout: float = 0.1 + model_path: str | None = None + + +@dataclass +class InferenceConfig: + """ + Configuration for inference settings. + + Attributes: + device: Device to use for inference ("cpu", "cuda", "cuda:0", etc.). + None for auto-detection based on available hardware. + seed: Random seed for reproducibility. Default is 42. + """ + + device: str | None = None + seed: int = 42 + + +@dataclass +class MetaControllerConfig: + """ + Main configuration for the Neural Meta-Controller framework. + + This configuration controls the behavior of the Meta-Controller, including + which type of neural network to use (RNN or BERT), fallback behavior, + and specific model parameters. + + Attributes: + enabled: Whether the neural Meta-Controller is enabled. Default is False + for backward compatibility with rule-based systems. + type: Type of neural network controller ("rnn" or "bert"). Default is "rnn". + fallback_to_rule_based: Whether to fall back to rule-based selection on errors. + Default is True for robustness. + rnn: Configuration for RNN-based controller. + bert: Configuration for BERT-based controller. + inference: Configuration for inference settings. + """ + + enabled: bool = False + type: str = "rnn" # "rnn" or "bert" + fallback_to_rule_based: bool = True + rnn: RNNConfig = field(default_factory=RNNConfig) + bert: BERTConfig = field(default_factory=BERTConfig) + inference: InferenceConfig = field(default_factory=InferenceConfig) + + +class MetaControllerConfigLoader: + """ + Loader class for Meta-Controller configuration. + + Provides methods for loading configuration from YAML files or dictionaries, + converting configuration to dictionaries, and validating configuration values. + + Example: + >>> loader = MetaControllerConfigLoader() + >>> config = loader.load_from_yaml("config/meta_controller.yaml") + >>> print(config.type) + 'rnn' + >>> config.validate() + """ + + @staticmethod + def load_from_yaml(path: str) -> MetaControllerConfig: + """ + Load Meta-Controller configuration from a YAML file. + + Args: + path: Path to the YAML configuration file. 
+ + Returns: + MetaControllerConfig: Loaded and parsed configuration object. + + Raises: + FileNotFoundError: If the specified file does not exist. + yaml.YAMLError: If the file contains invalid YAML. + KeyError: If the 'meta_controller' key is missing from the file. + + Example: + >>> config = MetaControllerConfigLoader.load_from_yaml("config/meta_controller.yaml") + >>> print(config.enabled) + False + """ + yaml_path = Path(path) + + if not yaml_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {path}") + + with open(yaml_path) as f: + raw_config = yaml.safe_load(f) + + if "meta_controller" not in raw_config: + raise KeyError("Configuration file must contain 'meta_controller' key") + + return MetaControllerConfigLoader.load_from_dict(raw_config["meta_controller"]) + + @staticmethod + def load_from_dict(config_dict: dict[str, Any]) -> MetaControllerConfig: + """ + Load Meta-Controller configuration from a dictionary. + + Args: + config_dict: Dictionary containing configuration values. + + Returns: + MetaControllerConfig: Parsed configuration object with defaults + applied for missing values. + + Example: + >>> config_dict = { + ... 'enabled': True, + ... 'type': 'bert', + ... 'bert': {'model_name': 'bert-base-uncased'} + ... } + >>> config = MetaControllerConfigLoader.load_from_dict(config_dict) + >>> print(config.type) + 'bert' + """ + # Parse nested configurations + rnn_config = RNNConfig(**config_dict.get("rnn", {})) + bert_config = BERTConfig(**config_dict.get("bert", {})) + inference_config = InferenceConfig(**config_dict.get("inference", {})) + + # Create main config with nested configs + return MetaControllerConfig( + enabled=config_dict.get("enabled", False), + type=config_dict.get("type", "rnn"), + fallback_to_rule_based=config_dict.get("fallback_to_rule_based", True), + rnn=rnn_config, + bert=bert_config, + inference=inference_config, + ) + + @staticmethod + def to_dict(config: MetaControllerConfig) -> dict[str, Any]: + """ + Convert a MetaControllerConfig object to a dictionary. + + Args: + config: MetaControllerConfig object to convert. + + Returns: + Dict[str, Any]: Dictionary representation of the configuration. + + Example: + >>> config = MetaControllerConfig(enabled=True, type='bert') + >>> config_dict = MetaControllerConfigLoader.to_dict(config) + >>> print(config_dict['enabled']) + True + """ + return asdict(config) + + @staticmethod + def validate(config: MetaControllerConfig) -> None: + """ + Validate the Meta-Controller configuration. + + Checks that: + - The controller type is valid ("rnn" or "bert") + - Model paths exist if specified + - Numeric parameters are within valid ranges + + Args: + config: MetaControllerConfig object to validate. + + Raises: + ValueError: If the configuration contains invalid values. + FileNotFoundError: If specified model paths do not exist. + + Example: + >>> config = MetaControllerConfig(type='invalid') + >>> MetaControllerConfigLoader.validate(config) + ValueError: Invalid controller type 'invalid'. Must be 'rnn' or 'bert'. + """ + # Validate controller type + valid_types = ["rnn", "bert"] + if config.type not in valid_types: + raise ValueError(f"Invalid controller type '{config.type}'. 
Must be one of: {valid_types}") + + # Validate RNN config + if config.rnn.hidden_dim <= 0: + raise ValueError(f"RNN hidden_dim must be positive, got {config.rnn.hidden_dim}") + if config.rnn.num_layers <= 0: + raise ValueError(f"RNN num_layers must be positive, got {config.rnn.num_layers}") + if not 0.0 <= config.rnn.dropout <= 1.0: + raise ValueError(f"RNN dropout must be between 0 and 1, got {config.rnn.dropout}") + if config.rnn.model_path is not None: + rnn_path = Path(config.rnn.model_path) + if not rnn_path.exists(): + raise FileNotFoundError(f"RNN model path does not exist: {config.rnn.model_path}") + + # Validate BERT config + if config.bert.lora_r <= 0: + raise ValueError(f"BERT lora_r must be positive, got {config.bert.lora_r}") + if config.bert.lora_alpha <= 0: + raise ValueError(f"BERT lora_alpha must be positive, got {config.bert.lora_alpha}") + if not 0.0 <= config.bert.lora_dropout <= 1.0: + raise ValueError(f"BERT lora_dropout must be between 0 and 1, got {config.bert.lora_dropout}") + if config.bert.model_path is not None: + bert_path = Path(config.bert.model_path) + if not bert_path.exists(): + raise FileNotFoundError(f"BERT model path does not exist: {config.bert.model_path}") + + # Validate inference config + if config.inference.device is not None: + valid_devices = ["cpu", "cuda", "mps"] + # Check if device starts with a valid prefix (e.g., "cuda:0", "cuda:1") + device_base = config.inference.device.split(":")[0] + if device_base not in valid_devices: + raise ValueError(f"Invalid device '{config.inference.device}'. Must start with one of: {valid_devices}") + + if not isinstance(config.inference.seed, int) or config.inference.seed < 0: + raise ValueError(f"Inference seed must be a non-negative integer, got {config.inference.seed}") + + @staticmethod + def save_to_yaml(config: MetaControllerConfig, path: str) -> None: + """ + Save a MetaControllerConfig object to a YAML file. + + Args: + config: MetaControllerConfig object to save. + path: Path where the YAML file will be saved. + + Example: + >>> config = MetaControllerConfig(enabled=True) + >>> MetaControllerConfigLoader.save_to_yaml(config, "my_config.yaml") + """ + yaml_path = Path(path) + yaml_path.parent.mkdir(parents=True, exist_ok=True) + + config_dict = {"meta_controller": MetaControllerConfigLoader.to_dict(config)} + + with open(yaml_path, "w") as f: + yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False) + + @staticmethod + def get_default_config() -> MetaControllerConfig: + """ + Get a default MetaControllerConfig with all default values. + + Returns: + MetaControllerConfig: Configuration object with default values. + + Example: + >>> config = MetaControllerConfigLoader.get_default_config() + >>> print(config.enabled) + False + """ + return MetaControllerConfig() diff --git a/src/agents/meta_controller/rnn_controller.py b/src/agents/meta_controller/rnn_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..43e6ef1b4b11e982319e4b0df55df8c62333d602 --- /dev/null +++ b/src/agents/meta_controller/rnn_controller.py @@ -0,0 +1,345 @@ +""" +RNN-based Meta-Controller for dynamic agent selection. + +This module provides a GRU-based recurrent neural network meta-controller +that learns to select the optimal agent (HRM, TRM, or MCTS) based on +sequential patterns in the agent state features. 
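+ Illustrative shape check (a minimal sketch; `features` is assumed to be a
+ MetaControllerFeatures instance built elsewhere):
+
+ >>> from src.agents.meta_controller.utils import features_to_tensor
+ >>> model = RNNMetaControllerModel(input_dim=10, hidden_dim=64, num_agents=3)
+ >>> x = features_to_tensor(features).unsqueeze(0)  # shape (1, 10)
+ >>> model(x).shape
+ torch.Size([1, 3])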
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.agents.meta_controller.base import ( + AbstractMetaController, + MetaControllerFeatures, + MetaControllerPrediction, +) +from src.agents.meta_controller.utils import features_to_tensor + + +class RNNMetaControllerModel(nn.Module): + """ + GRU-based neural network model for meta-controller predictions. + + This model uses a Gated Recurrent Unit (GRU) to capture sequential + patterns in agent state features and predict which agent should be + selected next. + + Architecture: + - GRU layer for sequence processing + - Dropout for regularization + - Linear layer for classification + + Attributes: + gru: GRU recurrent layer for processing sequences. + dropout: Dropout layer for regularization. + fc: Fully connected output layer. + hidden_dim: Dimension of the hidden state. + num_layers: Number of GRU layers. + """ + + def __init__( + self, + input_dim: int = 10, + hidden_dim: int = 64, + num_layers: int = 1, + num_agents: int = 3, + dropout: float = 0.1, + ) -> None: + """ + Initialize the RNN meta-controller model. + + Args: + input_dim: Dimension of input features. Defaults to 10. + hidden_dim: Dimension of GRU hidden state. Defaults to 64. + num_layers: Number of stacked GRU layers. Defaults to 1. + num_agents: Number of agents to choose from. Defaults to 3. + dropout: Dropout probability for regularization. Defaults to 0.1. + """ + super().__init__() + + self.hidden_dim = hidden_dim + self.num_layers = num_layers + + # GRU layer for sequence processing + self.gru = nn.GRU( + input_size=input_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + dropout=dropout if num_layers > 1 else 0.0, + ) + + # Dropout for regularization + self.dropout = nn.Dropout(p=dropout) + + # Linear output layer for classification + self.fc = nn.Linear(hidden_dim, num_agents) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the model. + + Processes input features through GRU and produces agent selection logits. + + Args: + x: Input tensor of shape (batch_size, features) or + (batch_size, seq_len, features). + + Returns: + Logits tensor of shape (batch_size, num_agents). + Note: Returns raw logits, NOT softmax probabilities. + + Example: + >>> model = RNNMetaControllerModel() + >>> x = torch.randn(4, 10) # batch of 4, 10 features + >>> logits = model(x) + >>> logits.shape + torch.Size([4, 3]) + """ + # Handle 2D input by adding sequence dimension + if x.dim() == 2: + # Shape: (batch_size, features) -> (batch_size, 1, features) + x = x.unsqueeze(1) + + # Pass through GRU + # output shape: (batch_size, seq_len, hidden_dim) + # hidden shape: (num_layers, batch_size, hidden_dim) + output, hidden = self.gru(x) + + # Take the final hidden state from the last layer + # Shape: (batch_size, hidden_dim) + final_hidden = hidden[-1] if self.num_layers > 1 else hidden.squeeze(0) + + # Apply dropout + dropped = self.dropout(final_hidden) + + # Apply linear layer to get logits + logits = self.fc(dropped) + + return logits + + +class RNNMetaController(AbstractMetaController): + """ + RNN-based meta-controller using GRU for agent selection. + + This controller uses a recurrent neural network to learn patterns in + agent state sequences and predict the optimal agent for the current + situation. It supports both CPU and GPU execution. + + Attributes: + device: PyTorch device (CPU or CUDA) for tensor operations. + hidden_dim: Dimension of GRU hidden state. 
+ num_layers: Number of GRU layers. + dropout: Dropout probability. + model: The underlying RNNMetaControllerModel. + hidden_state: Optional hidden state for sequence tracking. + + Example: + >>> controller = RNNMetaController(name="RNNController", seed=42) + >>> features = MetaControllerFeatures( + ... hrm_confidence=0.8, + ... trm_confidence=0.6, + ... mcts_value=0.75, + ... consensus_score=0.7, + ... last_agent='hrm', + ... iteration=2, + ... query_length=150, + ... has_rag_context=True + ... ) + >>> prediction = controller.predict(features) + >>> prediction.agent in ['hrm', 'trm', 'mcts'] + True + >>> 0.0 <= prediction.confidence <= 1.0 + True + """ + + def __init__( + self, + name: str = "RNNMetaController", + seed: int = 42, + hidden_dim: int = 64, + num_layers: int = 1, + dropout: float = 0.1, + device: str | None = None, + ) -> None: + """ + Initialize the RNN meta-controller. + + Args: + name: Name identifier for this controller. Defaults to "RNNMetaController". + seed: Random seed for reproducibility. Defaults to 42. + hidden_dim: Dimension of GRU hidden state. Defaults to 64. + num_layers: Number of GRU layers. Defaults to 1. + dropout: Dropout probability. Defaults to 0.1. + device: Device to run model on ('cpu', 'cuda', 'mps', etc.). + If None, auto-detects best available device. + """ + super().__init__(name=name, seed=seed) + + # Set random seed for reproducibility + torch.manual_seed(seed) + + # Auto-detect device if not specified + if device is None: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + else: + self.device = torch.device(device) + + # Store configuration + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.dropout = dropout + + # Initialize model + self.model = RNNMetaControllerModel( + input_dim=10, # Fixed based on features_to_tensor output + hidden_dim=hidden_dim, + num_layers=num_layers, + num_agents=len(self.AGENT_NAMES), + dropout=dropout, + ) + + # Move model to device + self.model = self.model.to(self.device) + + # Set model to evaluation mode + self.model.eval() + + # Initialize hidden state for sequence tracking + self.hidden_state: torch.Tensor | None = None + + def predict(self, features: MetaControllerFeatures) -> MetaControllerPrediction: + """ + Predict which agent should handle the current query. + + Converts features to tensor format, runs through the GRU model, + and returns a prediction with confidence scores. + + Args: + features: Features extracted from the current agent state. + + Returns: + Prediction containing the selected agent, confidence score, + and probability distribution over all agents. + + Example: + >>> controller = RNNMetaController() + >>> features = MetaControllerFeatures( + ... hrm_confidence=0.9, + ... trm_confidence=0.3, + ... mcts_value=0.5, + ... consensus_score=0.8, + ... last_agent='none', + ... iteration=0, + ... query_length=100, + ... has_rag_context=False + ... 
) + >>> pred = controller.predict(features) + >>> isinstance(pred.agent, str) + True + >>> isinstance(pred.confidence, float) + True + >>> len(pred.probabilities) == 3 + True + """ + # Convert features to tensor + feature_tensor = features_to_tensor(features) + + # Add batch dimension: (10,) -> (1, 10) + feature_tensor = feature_tensor.unsqueeze(0) + + # Move to device + feature_tensor = feature_tensor.to(self.device) + + # Perform inference without gradient tracking + with torch.no_grad(): + # Get logits from model + logits = self.model(feature_tensor) + + # Apply softmax to get probabilities + probabilities = F.softmax(logits, dim=-1) + + # Get predicted agent index (argmax) + predicted_idx = torch.argmax(probabilities, dim=-1).item() + + # Extract confidence for selected agent + confidence = probabilities[0, predicted_idx].item() + + # Create probability dictionary + prob_dict: dict[str, float] = {} + for i, agent_name in enumerate(self.AGENT_NAMES): + prob_dict[agent_name] = probabilities[0, i].item() + + # Get agent name + selected_agent = self.AGENT_NAMES[predicted_idx] + + return MetaControllerPrediction( + agent=selected_agent, + confidence=float(confidence), + probabilities=prob_dict, + ) + + def load_model(self, path: str) -> None: + """ + Load a trained model from disk. + + Loads the model state dictionary from the specified path and + sets the model to evaluation mode. + + Args: + path: Path to the saved model file (.pt or .pth). + + Raises: + FileNotFoundError: If the model file does not exist. + RuntimeError: If the state dict is incompatible with the model. + + Example: + >>> controller = RNNMetaController() + >>> controller.load_model("/path/to/model.pt") + """ + # Load state dict with appropriate device mapping + state_dict = torch.load(path, map_location=self.device, weights_only=True) + + # Load into model + self.model.load_state_dict(state_dict) + + # Ensure model is in evaluation mode + self.model.eval() + + def save_model(self, path: str) -> None: + """ + Save the current model to disk. + + Saves the model state dictionary to the specified path. + + Args: + path: Path where the model should be saved (.pt or .pth). + + Example: + >>> controller = RNNMetaController() + >>> controller.save_model("/path/to/model.pt") + """ + torch.save(self.model.state_dict(), path) + + def reset_hidden_state(self) -> None: + """ + Reset the hidden state for sequence tracking. + + This method clears any accumulated hidden state, useful when + starting a new conversation or resetting the controller state. + + Example: + >>> controller = RNNMetaController() + >>> controller.reset_hidden_state() + >>> controller.hidden_state is None + True + """ + self.hidden_state = None diff --git a/src/agents/meta_controller/utils.py b/src/agents/meta_controller/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..faa5cab1ab7fd7bb5e73d94e086e35b31625e1cd --- /dev/null +++ b/src/agents/meta_controller/utils.py @@ -0,0 +1,201 @@ +""" +Utility functions for Neural Meta-Controller feature processing. + +Provides functions for normalizing, encoding, and converting features +into formats suitable for different neural network architectures. +""" + +import torch + +from src.agents.meta_controller.base import MetaControllerFeatures + + +def normalize_features(features: MetaControllerFeatures) -> list[float]: + """ + Normalize meta-controller features to a 10-dimensional vector in range [0, 1]. 
+ + The normalization strategy: + - Confidence scores (hrm, trm, mcts_value, consensus): Already 0-1, clipped + - last_agent: Encoded as 3 one-hot values (hrm=0, trm=1, mcts=2) + - iteration: Normalized to 0-1 assuming max 20 iterations + - query_length: Normalized to 0-1 assuming max 10000 characters + - has_rag_context: Binary 0 or 1 + + Output vector structure (10 dimensions): + [hrm_conf, trm_conf, mcts_value, consensus, last_hrm, last_trm, last_mcts, + iteration_norm, query_length_norm, has_rag] + + Args: + features: MetaControllerFeatures instance to normalize. + + Returns: + List of 10 floats, each normalized to range [0, 1]. + + Example: + >>> features = MetaControllerFeatures( + ... hrm_confidence=0.8, + ... trm_confidence=0.6, + ... mcts_value=0.75, + ... consensus_score=0.7, + ... last_agent='hrm', + ... iteration=2, + ... query_length=150, + ... has_rag_context=True + ... ) + >>> normalized = normalize_features(features) + >>> len(normalized) + 10 + >>> all(0.0 <= v <= 1.0 for v in normalized) + True + """ + # Clip confidence scores to [0, 1] + hrm_conf = max(0.0, min(1.0, features.hrm_confidence)) + trm_conf = max(0.0, min(1.0, features.trm_confidence)) + mcts_val = max(0.0, min(1.0, features.mcts_value)) + consensus = max(0.0, min(1.0, features.consensus_score)) + + # One-hot encode last_agent (3 dimensions) + last_agent_onehot = one_hot_encode_agent(features.last_agent) + + # Normalize iteration (assuming max 20 iterations) + max_iterations = 20 + iteration_norm = max(0.0, min(1.0, features.iteration / max_iterations)) + + # Normalize query length (assuming max 10000 characters) + max_query_length = 10000 + query_length_norm = max(0.0, min(1.0, features.query_length / max_query_length)) + + # Binary for has_rag_context + has_rag = 1.0 if features.has_rag_context else 0.0 + + # Combine into 10-dimensional vector + return [ + hrm_conf, + trm_conf, + mcts_val, + consensus, + last_agent_onehot[0], # hrm + last_agent_onehot[1], # trm + last_agent_onehot[2], # mcts + iteration_norm, + query_length_norm, + has_rag, + ] + + +def one_hot_encode_agent(agent: str) -> list[float]: + """ + One-hot encode an agent name into a 3-dimensional vector. + + Encoding: + - 'hrm' -> [1.0, 0.0, 0.0] + - 'trm' -> [0.0, 1.0, 0.0] + - 'mcts' -> [0.0, 0.0, 1.0] + - 'none' or other -> [0.0, 0.0, 0.0] + + Args: + agent: Agent name string ('hrm', 'trm', 'mcts', or 'none'). + + Returns: + List of 3 floats representing the one-hot encoding. + + Example: + >>> one_hot_encode_agent('hrm') + [1.0, 0.0, 0.0] + >>> one_hot_encode_agent('trm') + [0.0, 1.0, 0.0] + >>> one_hot_encode_agent('mcts') + [0.0, 0.0, 1.0] + >>> one_hot_encode_agent('none') + [0.0, 0.0, 0.0] + """ + agent_lower = agent.lower() + + if agent_lower == "hrm": # noqa: SIM116 + return [1.0, 0.0, 0.0] + elif agent_lower == "trm": + return [0.0, 1.0, 0.0] + elif agent_lower == "mcts": + return [0.0, 0.0, 1.0] + else: + # 'none' or unknown agent + return [0.0, 0.0, 0.0] + + +def features_to_tensor(features: MetaControllerFeatures) -> torch.Tensor: + """ + Convert meta-controller features to a PyTorch tensor. + + Uses normalize_features internally to create a normalized 10-dimensional + tensor suitable for neural network input. + + Args: + features: MetaControllerFeatures instance to convert. + + Returns: + PyTorch tensor of shape (10,) with float32 dtype. + + Example: + >>> features = MetaControllerFeatures( + ... hrm_confidence=0.8, + ... trm_confidence=0.6, + ... mcts_value=0.75, + ... consensus_score=0.7, + ... last_agent='hrm', + ... 
iteration=2, + ... query_length=150, + ... has_rag_context=True + ... ) + >>> tensor = features_to_tensor(features) + >>> tensor.shape + torch.Size([10]) + >>> tensor.dtype + torch.float32 + """ + normalized = normalize_features(features) + return torch.tensor(normalized, dtype=torch.float32) + + +def features_to_text(features: MetaControllerFeatures) -> str: + """ + Convert meta-controller features to structured text format. + + Creates a human-readable text representation suitable for text-based + models like BERT or other language models. + + Args: + features: MetaControllerFeatures instance to convert. + + Returns: + Structured text string describing the features. + + Example: + >>> features = MetaControllerFeatures( + ... hrm_confidence=0.8, + ... trm_confidence=0.6, + ... mcts_value=0.75, + ... consensus_score=0.7, + ... last_agent='hrm', + ... iteration=2, + ... query_length=150, + ... has_rag_context=True + ... ) + >>> text = features_to_text(features) + >>> 'HRM confidence: 0.800' in text + True + """ + rag_status = "available" if features.has_rag_context else "not available" + + text = ( + f"Agent State Features:\n" + f"HRM confidence: {features.hrm_confidence:.3f}\n" + f"TRM confidence: {features.trm_confidence:.3f}\n" + f"MCTS value: {features.mcts_value:.3f}\n" + f"Consensus score: {features.consensus_score:.3f}\n" + f"Last agent used: {features.last_agent}\n" + f"Current iteration: {features.iteration}\n" + f"Query length: {features.query_length} characters\n" + f"RAG context: {rag_status}" + ) + + return text diff --git a/src/agents/trm_agent.py b/src/agents/trm_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..085e5936622c94c12477e67bf2391ee997401f08 --- /dev/null +++ b/src/agents/trm_agent.py @@ -0,0 +1,395 @@ +""" +Tiny Recursive Model (TRM) Agent. + +Implements recursive refinement with: +- Deep supervision at all recursion levels +- Convergence detection +- Memory-efficient recursion +- Iterative improvement mechanism + +Based on principles from: +- "Recursive Refinement Networks" +- "Deep Supervision for Neural Networks" +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ..training.system_config import TRMConfig + + +@dataclass +class TRMOutput: + """Output from TRM recursive processing.""" + + final_prediction: torch.Tensor # Final refined output + intermediate_predictions: list[torch.Tensor] # Predictions at each recursion + recursion_depth: int # Actual depth used + converged: bool # Whether convergence was achieved + convergence_step: int # Step at which convergence occurred + residual_norms: list[float] # L2 norms of residuals at each step + + +class RecursiveBlock(nn.Module): + """ + Core recursive processing block. + + Applies the same transformation repeatedly, with residual connections. + """ + + def __init__(self, config: TRMConfig): + super().__init__() + self.config = config + + # Main processing pathway + self.transform = nn.Sequential( + nn.Linear(config.latent_dim, config.hidden_dim), + nn.LayerNorm(config.hidden_dim) if config.use_layer_norm else nn.Identity(), + nn.GELU(), + nn.Dropout(config.dropout), + nn.Linear(config.hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim) if config.use_layer_norm else nn.Identity(), + ) + + # Residual scaling (learned) + self.residual_scale = nn.Parameter(torch.ones(1)) + + def forward(self, x: torch.Tensor, iteration: int = 0) -> torch.Tensor: # noqa: ARG002 + """ + Apply recursive transformation. 
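+ In symbols, the block returns x + residual_scale * transform(x), where
+ residual_scale is a learned scalar that controls the size of each refinement step.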
+ + Args: + x: Input tensor [batch, ..., latent_dim] + iteration: Current recursion iteration (reserved for future iteration-dependent behavior) + + Returns: + Refined tensor [batch, ..., latent_dim] + """ + # Residual connection with learned scaling + residual = self.transform(x) + return x + self.residual_scale * residual + + +class DeepSupervisionHead(nn.Module): + """ + Supervision head for intermediate predictions. + + Enables training signal at each recursion level. + """ + + def __init__(self, latent_dim: int, output_dim: int): + super().__init__() + self.head = nn.Sequential( + nn.Linear(latent_dim, latent_dim // 2), + nn.ReLU(), + nn.Linear(latent_dim // 2, output_dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Generate prediction from latent state.""" + return self.head(x) + + +class TRMAgent(nn.Module): + """ + Tiny Recursive Model for iterative refinement. + + Features: + - Shared weights across recursions (parameter efficiency) + - Deep supervision at all levels + - Automatic convergence detection + - Residual connections for stable gradients + """ + + def __init__(self, config: TRMConfig, output_dim: int | None = None, device: str = "cpu"): + super().__init__() + self.config = config + self.device = device + self.output_dim = output_dim or config.latent_dim + + # Initial encoding + self.encoder = nn.Sequential( + nn.Linear(config.latent_dim, config.hidden_dim), + nn.LayerNorm(config.hidden_dim) if config.use_layer_norm else nn.Identity(), + nn.GELU(), + nn.Linear(config.hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim) if config.use_layer_norm else nn.Identity(), + ) + + # Shared recursive block + self.recursive_block = RecursiveBlock(config) + + # Deep supervision heads (one per recursion level) + if config.deep_supervision: + self.supervision_heads = nn.ModuleList( + [DeepSupervisionHead(config.latent_dim, self.output_dim) for _ in range(config.num_recursions)] + ) + else: + # Single output head + self.output_head = DeepSupervisionHead(config.latent_dim, self.output_dim) + + self.to(device) + + def forward( + self, + x: torch.Tensor, + num_recursions: int | None = None, + check_convergence: bool = True, + ) -> TRMOutput: + """ + Process input through recursive refinement. 
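+ When check_convergence is True, refinement stops early once the mean L2 norm of the
+ latent residual, ||latent_i - latent_{i-1}||_2, drops below config.convergence_threshold;
+ the check is applied only from iteration index config.min_recursions onward.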
+ + Args: + x: Input tensor [batch, ..., latent_dim] + num_recursions: Number of recursions (defaults to config) + check_convergence: Whether to check for early convergence + + Returns: + TRMOutput with final and intermediate predictions + """ + num_recursions = num_recursions or self.config.num_recursions + + # Initial encoding + latent = self.encoder(x) + previous_latent = latent.clone() + + # Tracking + intermediate_predictions = [] + residual_norms = [] + converged = False + convergence_step = num_recursions + + # Recursive refinement + for i in range(num_recursions): + # Apply recursive transformation + latent = self.recursive_block(latent, iteration=i) + + # Generate intermediate prediction + if self.config.deep_supervision and i < len(self.supervision_heads): + pred = self.supervision_heads[i](latent) + else: + pred = self.output_head(latent) + + intermediate_predictions.append(pred) + + # Check convergence + if check_convergence and i >= self.config.min_recursions: + residual = latent - previous_latent + residual_norm = torch.norm(residual, p=2, dim=-1).mean().item() + residual_norms.append(residual_norm) + + if residual_norm < self.config.convergence_threshold: + converged = True + convergence_step = i + 1 + break + + previous_latent = latent.clone() + + # Final prediction + final_pred = intermediate_predictions[-1] + + return TRMOutput( + final_prediction=final_pred, + intermediate_predictions=intermediate_predictions, + recursion_depth=len(intermediate_predictions), + converged=converged, + convergence_step=convergence_step, + residual_norms=residual_norms, + ) + + async def refine_solution( + self, + initial_prediction: torch.Tensor, + num_recursions: int | None = None, + convergence_threshold: float | None = None, + ) -> tuple[torch.Tensor, dict]: + """ + Refine an initial prediction through recursive processing. + + Args: + initial_prediction: Initial solution [batch, ..., latent_dim] + num_recursions: Maximum recursions (optional) + convergence_threshold: Convergence threshold (optional) + + Returns: + refined_solution: Final refined prediction + info: Dictionary with refinement metadata + """ + # Temporarily override convergence threshold if provided + original_threshold = self.config.convergence_threshold + if convergence_threshold is not None: + self.config.convergence_threshold = convergence_threshold + + # Process + output = self.forward( + initial_prediction, + num_recursions=num_recursions, + check_convergence=True, + ) + + # Restore original threshold + self.config.convergence_threshold = original_threshold + + info = { + "converged": output.converged, + "convergence_step": output.convergence_step, + "total_recursions": output.recursion_depth, + "final_residual": output.residual_norms[-1] if output.residual_norms else None, + "refinement_path": output.residual_norms, + } + + return output.final_prediction, info + + def get_parameter_count(self) -> int: + """Return total number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class TRMLoss(nn.Module): + """ + Deep supervision loss for TRM. + + Applies weighted supervision at all recursion levels, + with exponential decay for deeper levels. + """ + + def __init__( + self, + task_loss_fn: nn.Module, + supervision_weight_decay: float = 0.5, + final_weight: float = 1.0, + ): + """ + Initialize TRM loss. 
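+ In symbols, with intermediate predictions p_0, ..., p_{N-1} (p_{N-1} final) and target t:
+ total = final_weight * L(p_{N-1}, t) + sum_{i=0}^{N-2} supervision_weight_decay^(N-1-i) * L(p_i, t),
+ so earlier recursion levels contribute exponentially smaller weight.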
+ + Args: + task_loss_fn: Base loss function (e.g., MSE, CrossEntropy) + supervision_weight_decay: Decay factor for intermediate losses + final_weight: Weight for final prediction loss + """ + super().__init__() + self.task_loss_fn = task_loss_fn + self.supervision_weight_decay = supervision_weight_decay + self.final_weight = final_weight + + def forward(self, trm_output: TRMOutput, targets: torch.Tensor) -> tuple[torch.Tensor, dict]: + """ + Compute deep supervision loss. + + Args: + trm_output: Output from TRM forward pass + targets: Ground truth targets + + Returns: + total_loss: Combined loss + loss_dict: Dictionary of loss components + """ + # Final prediction loss (highest weight) + final_loss = self.task_loss_fn(trm_output.final_prediction, targets) + total_loss = self.final_weight * final_loss + + # Intermediate supervision losses + intermediate_losses = [] + num_intermediate = len(trm_output.intermediate_predictions) - 1 + + for i, pred in enumerate(trm_output.intermediate_predictions[:-1]): + # Exponential decay: earlier predictions get lower weight + weight = self.supervision_weight_decay ** (num_intermediate - i) + loss = self.task_loss_fn(pred, targets) + intermediate_losses.append(loss.item()) + total_loss = total_loss + weight * loss + + loss_dict = { + "total": total_loss.item(), + "final": final_loss.item(), + "intermediate_mean": (sum(intermediate_losses) / len(intermediate_losses) if intermediate_losses else 0.0), + "recursion_depth": trm_output.recursion_depth, + "converged": trm_output.converged, + "convergence_step": trm_output.convergence_step, + } + + return total_loss, loss_dict + + +def create_trm_agent(config: TRMConfig, output_dim: int | None = None, device: str = "cpu") -> TRMAgent: + """ + Factory function to create and initialize TRM agent. + + Args: + config: TRM configuration + output_dim: Output dimension (defaults to latent_dim) + device: Device to place model on + + Returns: + Initialized TRMAgent + """ + agent = TRMAgent(config, output_dim, device) + + # Initialize weights with Xavier/He initialization + def init_weights(m): + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + + agent.apply(init_weights) + + return agent + + +# Utility functions for integration +class TRMRefinementWrapper: + """ + Wrapper for using TRM as a refinement step in pipelines. + + Provides a clean interface for integrating TRM into larger systems. + """ + + def __init__(self, trm_agent: TRMAgent, device: str = "cpu"): + self.trm_agent = trm_agent + self.device = device + self.trm_agent.eval() + + @torch.no_grad() + async def refine( + self, + predictions: torch.Tensor, + num_iterations: int = 10, + return_path: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + """ + Refine predictions using TRM. 
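+ Illustrative usage inside an async context (a minimal sketch; `agent` is a TRMAgent and
+ `preds` is a tensor shaped [batch, latent_dim], both assumed to exist):
+
+ >>> wrapper = TRMRefinementWrapper(agent, device="cpu")
+ >>> refined = await wrapper.refine(preds, num_iterations=10)
+ >>> refined, path = await wrapper.refine(preds, return_path=True)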
+ + Args: + predictions: Initial predictions to refine + num_iterations: Number of refinement iterations + return_path: Whether to return intermediate predictions + + Returns: + refined_predictions or (refined_predictions, refinement_path) + """ + # Ensure predictions are on correct device + predictions = predictions.to(self.device) + + # Run TRM + output = self.trm_agent(predictions, num_recursions=num_iterations, check_convergence=True) + + if return_path: + return output.final_prediction, output.intermediate_predictions + return output.final_prediction + + def get_refinement_stats(self, predictions: torch.Tensor) -> dict: + """Get statistics about the refinement process.""" + with torch.no_grad(): + output = self.trm_agent(predictions, check_convergence=True) + + return { + "converged": output.converged, + "steps_to_convergence": output.convergence_step, + "final_residual": (output.residual_norms[-1] if output.residual_norms else None), + "total_refinement_iterations": output.recursion_depth, + } diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ccafadde98d9c0869a91d085a484e0b453cd1cde --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,35 @@ +""" +API module for LangGraph Multi-Agent MCTS Framework. + +Provides: +- Authentication and authorization +- Rate limiting +- Error handling +- REST API endpoints +""" + +from src.api.exceptions import ( + AuthenticationError, + AuthorizationError, + ConfigurationError, + FrameworkError, + LLMError, + MCTSError, + RAGError, + RateLimitError, + TimeoutError, + ValidationError, +) + +__all__ = [ + "FrameworkError", + "ValidationError", + "AuthenticationError", + "AuthorizationError", + "RateLimitError", + "LLMError", + "MCTSError", + "RAGError", + "TimeoutError", + "ConfigurationError", +] diff --git a/src/api/auth.py b/src/api/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..a5738b95c45181ea190e646cd4a019601a92a73e --- /dev/null +++ b/src/api/auth.py @@ -0,0 +1,439 @@ +""" +Authentication and authorization layer for LangGraph Multi-Agent MCTS Framework. + +Provides: +- API key authentication with secure hashing +- JWT token support (optional) +- Rate limiting per client +- Role-based access control +""" + +import hashlib +import secrets +import time +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime, timedelta + +from src.api.exceptions import ( + AuthenticationError, + AuthorizationError, + RateLimitError, +) + + +@dataclass +class ClientInfo: + """Information about an authenticated client.""" + + client_id: str + roles: set[str] = field(default_factory=lambda: {"user"}) + created_at: datetime = field(default_factory=datetime.utcnow) + last_access: datetime = field(default_factory=datetime.utcnow) + request_count: int = 0 + + +@dataclass +class RateLimitConfig: + """Rate limiting configuration.""" + + requests_per_minute: int = 60 + requests_per_hour: int = 1000 + requests_per_day: int = 10000 + burst_limit: int = 100 # Max requests in 1 second + + +class APIKeyAuthenticator: + """ + API key-based authentication with secure hashing. + + Keys are stored as SHA-256 hashes to prevent exposure. + """ + + def __init__( + self, + valid_keys: list[str] | None = None, + rate_limit_config: RateLimitConfig | None = None, + ): + """ + Initialize authenticator. 
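+ Typical setup (a minimal sketch; the client name is arbitrary, and the returned key
+ must be stored securely because only its SHA-256 hash is kept server-side):
+
+ >>> auth = APIKeyAuthenticator()
+ >>> key = auth.add_client("analytics-team", roles={"user", "admin"})
+ >>> client = auth.authenticate(key)
+ >>> client.client_id
+ 'analytics-team'
+ >>> auth.require_role(client, "admin")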
+ + Args: + valid_keys: List of valid API keys (will be hashed) + rate_limit_config: Rate limiting configuration + """ + self._key_to_client: dict[str, ClientInfo] = {} + self._rate_limits: dict[str, list[float]] = defaultdict(list) + self.rate_limit_config = rate_limit_config or RateLimitConfig() + + # Hash and store initial keys + if valid_keys: + for i, key in enumerate(valid_keys): + client_id = f"client_{i}" + self._add_key(key, client_id) + + def _hash_key(self, api_key: str) -> str: + """ + Securely hash an API key. + + Uses SHA-256 with consistent encoding. + """ + return hashlib.sha256(api_key.encode("utf-8")).hexdigest() + + def _add_key(self, api_key: str, client_id: str, roles: set[str] | None = None) -> None: + """ + Add a new API key. + + Args: + api_key: Raw API key + client_id: Client identifier + roles: Set of roles (defaults to {"user"}) + """ + key_hash = self._hash_key(api_key) + self._key_to_client[key_hash] = ClientInfo( + client_id=client_id, + roles=roles or {"user"}, + ) + + def authenticate(self, api_key: str | None) -> ClientInfo: + """ + Authenticate an API key. + + Args: + api_key: API key to validate + + Returns: + ClientInfo for the authenticated client + + Raises: + AuthenticationError: If authentication fails + """ + if not api_key: + raise AuthenticationError( + user_message="API key is required", + internal_details="No API key provided in request", + ) + + # Constant-time comparison to prevent timing attacks + key_hash = self._hash_key(api_key) + + if key_hash not in self._key_to_client: + raise AuthenticationError( + user_message="Invalid API key", + internal_details=f"API key hash not found: {key_hash[:16]}...", + ) + + client_info = self._key_to_client[key_hash] + client_info.last_access = datetime.utcnow() + client_info.request_count += 1 + + # Check rate limits + self._check_rate_limit(client_info.client_id) + + return client_info + + def _check_rate_limit(self, client_id: str) -> None: + """ + Check if client has exceeded rate limits. + + Args: + client_id: Client identifier + + Raises: + RateLimitError: If rate limit exceeded + """ + now = time.time() + request_times = self._rate_limits[client_id] + + # Clean old entries + one_day_ago = now - 86400 + request_times = [t for t in request_times if t > one_day_ago] + self._rate_limits[client_id] = request_times + + # Check burst limit (1 second window) + one_second_ago = now - 1 + burst_count = sum(1 for t in request_times if t > one_second_ago) + if burst_count >= self.rate_limit_config.burst_limit: + raise RateLimitError( + user_message="Too many requests. Please slow down.", + internal_details=f"Client {client_id} exceeded burst limit: {burst_count}/{self.rate_limit_config.burst_limit}", + retry_after_seconds=1, + ) + + # Check per-minute limit + one_minute_ago = now - 60 + minute_count = sum(1 for t in request_times if t > one_minute_ago) + if minute_count >= self.rate_limit_config.requests_per_minute: + raise RateLimitError( + user_message="Rate limit exceeded. Please wait a minute.", + internal_details=f"Client {client_id} exceeded minute limit: {minute_count}/{self.rate_limit_config.requests_per_minute}", + retry_after_seconds=60, + ) + + # Check per-hour limit + one_hour_ago = now - 3600 + hour_count = sum(1 for t in request_times if t > one_hour_ago) + if hour_count >= self.rate_limit_config.requests_per_hour: + raise RateLimitError( + user_message="Hourly rate limit exceeded. 
Please try again later.", + internal_details=f"Client {client_id} exceeded hour limit: {hour_count}/{self.rate_limit_config.requests_per_hour}", + retry_after_seconds=3600, + ) + + # Check per-day limit + day_count = len(request_times) + if day_count >= self.rate_limit_config.requests_per_day: + raise RateLimitError( + user_message="Daily rate limit exceeded. Please try again tomorrow.", + internal_details=f"Client {client_id} exceeded day limit: {day_count}/{self.rate_limit_config.requests_per_day}", + retry_after_seconds=86400, + ) + + # Record this request + request_times.append(now) + + def require_auth(self, api_key: str | None) -> ClientInfo: + """ + Require authentication for a request. + + Convenience method that raises on failure. + + Args: + api_key: API key to validate + + Returns: + ClientInfo for authenticated client + + Raises: + AuthenticationError: If authentication fails + """ + return self.authenticate(api_key) + + def require_role(self, client_info: ClientInfo, required_role: str) -> None: + """ + Require a specific role for an operation. + + Args: + client_info: Authenticated client info + required_role: Role that is required + + Raises: + AuthorizationError: If client doesn't have required role + """ + if required_role not in client_info.roles: + raise AuthorizationError( + user_message="You do not have permission for this operation", + internal_details=f"Client {client_info.client_id} missing role: {required_role}", + required_permission=required_role, + ) + + def generate_api_key(self) -> str: + """ + Generate a secure random API key. + + Returns: + New API key (32 bytes hex = 64 characters) + """ + return secrets.token_hex(32) + + def revoke_key(self, api_key: str) -> bool: + """ + Revoke an API key. + + Args: + api_key: Key to revoke + + Returns: + True if key was revoked, False if not found + """ + key_hash = self._hash_key(api_key) + if key_hash in self._key_to_client: + del self._key_to_client[key_hash] + return True + return False + + def add_client( + self, + client_id: str, + roles: set[str] | None = None, + ) -> str: + """ + Add a new client and generate their API key. + + Args: + client_id: Unique client identifier + roles: Set of roles for the client + + Returns: + Generated API key (save this securely!) + """ + api_key = self.generate_api_key() + self._add_key(api_key, client_id, roles) + return api_key + + def get_client_stats(self, client_id: str) -> dict: + """ + Get statistics for a client. + + Args: + client_id: Client identifier + + Returns: + Dictionary with client statistics + """ + now = time.time() + request_times = self._rate_limits.get(client_id, []) + + return { + "total_requests_today": len([t for t in request_times if t > now - 86400]), + "requests_last_hour": len([t for t in request_times if t > now - 3600]), + "requests_last_minute": len([t for t in request_times if t > now - 60]), + } + + +class JWTAuthenticator: + """ + JWT token-based authentication. + + Note: Requires PyJWT library for full functionality. + This is a placeholder for JWT support. + """ + + def __init__(self, secret_key: str, algorithm: str = "HS256"): + """ + Initialize JWT authenticator. + + Args: + secret_key: Secret key for signing tokens + algorithm: JWT signing algorithm + """ + self.secret_key = secret_key + self.algorithm = algorithm + self._token_blacklist: set[str] = set() + + def create_token( + self, + client_id: str, + roles: set[str], + expires_in_hours: int = 24, + ) -> str: + """ + Create a JWT token. 
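+ Illustrative round trip (a minimal sketch; requires PyJWT, and the secret key shown
+ here is a placeholder):
+
+ >>> auth = JWTAuthenticator(secret_key="change-me")
+ >>> token = auth.create_token("client-1", {"user"}, expires_in_hours=1)
+ >>> auth.verify_token(token).client_id
+ 'client-1'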
+ + Args: + client_id: Client identifier + roles: Client roles + expires_in_hours: Token validity period + + Returns: + JWT token string + """ + try: + import jwt + except ImportError: + raise ImportError("PyJWT library required for JWT authentication. Install with: pip install PyJWT") + + now = datetime.utcnow() + payload = { + "sub": client_id, + "roles": list(roles), + "iat": now, + "exp": now + timedelta(hours=expires_in_hours), + "jti": secrets.token_hex(16), # Unique token ID + } + + return jwt.encode(payload, self.secret_key, algorithm=self.algorithm) + + def verify_token(self, token: str) -> ClientInfo: + """ + Verify a JWT token. + + Args: + token: JWT token string + + Returns: + ClientInfo from token claims + + Raises: + AuthenticationError: If token is invalid + """ + try: + import jwt + except ImportError: + raise ImportError("PyJWT library required for JWT authentication") + + if token in self._token_blacklist: + raise AuthenticationError( + user_message="Token has been revoked", + internal_details="Token found in blacklist", + ) + + try: + payload = jwt.decode( + token, + self.secret_key, + algorithms=[self.algorithm], + ) + + return ClientInfo( + client_id=payload["sub"], + roles=set(payload.get("roles", ["user"])), + ) + except jwt.ExpiredSignatureError: + raise AuthenticationError( + user_message="Token has expired", + internal_details="JWT signature expired", + ) + except jwt.InvalidTokenError as e: + raise AuthenticationError( + user_message="Invalid token", + internal_details=f"JWT validation failed: {str(e)}", + ) + + def revoke_token(self, token: str) -> None: + """ + Revoke a JWT token by adding to blacklist. + + Args: + token: Token to revoke + """ + self._token_blacklist.add(token) + + +# Default authenticator instance +_default_authenticator: APIKeyAuthenticator | None = None + + +def get_authenticator() -> APIKeyAuthenticator: + """ + Get or create the default authenticator instance. + + Returns: + APIKeyAuthenticator instance + """ + global _default_authenticator + if _default_authenticator is None: + _default_authenticator = APIKeyAuthenticator() + return _default_authenticator + + +def set_authenticator(authenticator: APIKeyAuthenticator) -> None: + """ + Set the default authenticator instance. + + Args: + authenticator: Authenticator to use + """ + global _default_authenticator + _default_authenticator = authenticator + + +# Exports +__all__ = [ + "APIKeyAuthenticator", + "JWTAuthenticator", + "ClientInfo", + "RateLimitConfig", + "get_authenticator", + "set_authenticator", +] diff --git a/src/api/exceptions.py b/src/api/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..790ec42a7ab11e8ea60162379281ba8547705a48 --- /dev/null +++ b/src/api/exceptions.py @@ -0,0 +1,299 @@ +""" +Custom exception hierarchy for LangGraph Multi-Agent MCTS Framework. + +Provides: +- Sanitized error messages for production +- Structured error information for logging +- Clear separation between user-facing and internal errors +""" + +import re +from datetime import datetime +from typing import Any + + +class FrameworkError(Exception): + """ + Base exception for all framework errors. + + Provides sanitized user-facing messages while preserving + internal details for logging. + """ + + def __init__( + self, + user_message: str, + internal_details: str | None = None, + error_code: str | None = None, + context: dict[str, Any] | None = None, + ): + """ + Initialize framework error. 
+ + Args: + user_message: Safe message to show to users + internal_details: Detailed information for logs (may contain sensitive data) + error_code: Machine-readable error code + context: Additional context for debugging + """ + self.user_message = user_message + self.internal_details = internal_details or user_message + self.error_code = error_code or self.__class__.__name__.upper() + self.context = context or {} + self.timestamp = datetime.utcnow() + + super().__init__(user_message) + + def sanitize_details(self) -> str: + """ + Remove sensitive information from internal details. + + Sanitizes: + - File paths + - API keys + - Passwords + - Connection strings + - IP addresses + """ + sanitized = self.internal_details + + # Remove file paths (Unix and Windows) + sanitized = re.sub(r"/[\w/.-]+", "/***", sanitized) + sanitized = re.sub(r"[A-Za-z]:\\[\w\\.-]+", "C:\\***", sanitized) + + # Remove API keys and secrets + sanitized = re.sub( + r"(api[_-]?key|secret|password|token|credential)[\s=:]+[\S]+", r"\1=***", sanitized, flags=re.IGNORECASE + ) + + # Remove connection strings + sanitized = re.sub(r"(mongodb|postgresql|mysql|redis)://[^\s]+", r"\1://***", sanitized, flags=re.IGNORECASE) + + # Remove IP addresses + sanitized = re.sub(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b", "***.***.***", sanitized) + + # Remove email addresses + sanitized = re.sub(r"\b[\w.-]+@[\w.-]+\.\w+\b", "***@***", sanitized) + + return sanitized + + def to_log_dict(self) -> dict[str, Any]: + """ + Convert exception to dictionary for structured logging. + + Returns sanitized version safe for logs. + """ + return { + "error_type": self.__class__.__name__, + "error_code": self.error_code, + "user_message": self.user_message, + "sanitized_details": self.sanitize_details(), + "timestamp": self.timestamp.isoformat(), + "context": {k: str(v) for k, v in self.context.items()}, + } + + def to_user_response(self) -> dict[str, Any]: + """ + Convert exception to safe user-facing response. 
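+
+ Example (illustrative; hypothetical error values showing that internal details never reach the user-facing payload):
+
+     >>> err = FrameworkError(
+     ...     user_message="Something went wrong",
+     ...     internal_details="failure reading /srv/app/config.py",
+     ... )
+     >>> err.to_user_response()["error_code"]
+     'FRAMEWORKERROR'
+     >>> err.to_user_response()["message"]
+     'Something went wrong'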
+ """ + return { + "error": True, + "error_code": self.error_code, + "message": self.user_message, + "timestamp": self.timestamp.isoformat(), + } + + +class ValidationError(FrameworkError): + """Raised when input validation fails.""" + + def __init__( + self, + user_message: str = "Invalid input provided", + internal_details: str | None = None, + field_name: str | None = None, + **kwargs, + ): + context = kwargs.pop("context", {}) + if field_name: + context["field_name"] = field_name + super().__init__( + user_message=user_message, + internal_details=internal_details, + error_code="VALIDATION_ERROR", + context=context, + **kwargs, + ) + self.field_name = field_name + + +class AuthenticationError(FrameworkError): + """Raised when authentication fails.""" + + def __init__(self, user_message: str = "Authentication failed", internal_details: str | None = None, **kwargs): + super().__init__( + user_message=user_message, internal_details=internal_details, error_code="AUTH_ERROR", **kwargs + ) + + +class AuthorizationError(FrameworkError): + """Raised when authorization fails.""" + + def __init__( + self, + user_message: str = "Access denied", + internal_details: str | None = None, + required_permission: str | None = None, + **kwargs, + ): + context = kwargs.pop("context", {}) + if required_permission: + context["required_permission"] = required_permission + super().__init__( + user_message=user_message, + internal_details=internal_details, + error_code="AUTHZ_ERROR", + context=context, + **kwargs, + ) + + +class RateLimitError(FrameworkError): + """Raised when rate limit is exceeded.""" + + def __init__( + self, + user_message: str = "Rate limit exceeded. Please try again later.", + internal_details: str | None = None, + retry_after_seconds: int | None = None, + **kwargs, + ): + context = kwargs.pop("context", {}) + if retry_after_seconds: + context["retry_after_seconds"] = retry_after_seconds + super().__init__( + user_message=user_message, + internal_details=internal_details, + error_code="RATE_LIMIT", + context=context, + **kwargs, + ) + self.retry_after_seconds = retry_after_seconds + + +class LLMError(FrameworkError): + """Raised when LLM operations fail.""" + + def __init__( + self, + user_message: str = "Language model service temporarily unavailable", + internal_details: str | None = None, + provider: str | None = None, + **kwargs, + ): + context = kwargs.pop("context", {}) + if provider: + context["provider"] = provider + super().__init__( + user_message=user_message, + internal_details=internal_details, + error_code="LLM_ERROR", + context=context, + **kwargs, + ) + + +class MCTSError(FrameworkError): + """Raised when MCTS simulation fails.""" + + def __init__( + self, + user_message: str = "Tactical simulation failed", + internal_details: str | None = None, + iteration: int | None = None, + **kwargs, + ): + context = kwargs.pop("context", {}) + if iteration is not None: + context["iteration"] = iteration + super().__init__( + user_message=user_message, + internal_details=internal_details, + error_code="MCTS_ERROR", + context=context, + **kwargs, + ) + + +class RAGError(FrameworkError): + """Raised when RAG retrieval fails.""" + + def __init__(self, user_message: str = "Context retrieval failed", internal_details: str | None = None, **kwargs): + super().__init__(user_message=user_message, internal_details=internal_details, error_code="RAG_ERROR", **kwargs) + + +class TimeoutError(FrameworkError): + """Raised when operation times out.""" + + def __init__( + self, + user_message: str = 
"Operation timed out", + internal_details: str | None = None, + operation: str | None = None, + timeout_seconds: float | None = None, + **kwargs, + ): + context = kwargs.pop("context", {}) + if operation: + context["operation"] = operation + if timeout_seconds: + context["timeout_seconds"] = timeout_seconds + super().__init__( + user_message=user_message, + internal_details=internal_details, + error_code="TIMEOUT", + context=context, + **kwargs, + ) + + +class ConfigurationError(FrameworkError): + """Raised when configuration is invalid.""" + + def __init__( + self, + user_message: str = "System configuration error", + internal_details: str | None = None, + config_key: str | None = None, + **kwargs, + ): + context = kwargs.pop("context", {}) + if config_key: + context["config_key"] = config_key + super().__init__( + user_message=user_message, + internal_details=internal_details, + error_code="CONFIG_ERROR", + context=context, + **kwargs, + ) + + +# Convenience function for wrapping exceptions +def wrap_exception( + exc: Exception, user_message: str = "An unexpected error occurred", error_class: type = FrameworkError, **kwargs +) -> FrameworkError: + """ + Wrap a standard exception in a FrameworkError with sanitized details. + + Args: + exc: Original exception + user_message: Safe user-facing message + error_class: FrameworkError subclass to use + **kwargs: Additional context + + Returns: + FrameworkError instance with sanitized details + """ + internal_details = f"{type(exc).__name__}: {str(exc)}" + return error_class(user_message=user_message, internal_details=internal_details, **kwargs) diff --git a/src/api/inference_server.py b/src/api/inference_server.py new file mode 100644 index 0000000000000000000000000000000000000000..b9b0b21c2192f63437406eeaaff54131c20cbd48 --- /dev/null +++ b/src/api/inference_server.py @@ -0,0 +1,380 @@ +""" +FastAPI Inference Server for LangGraph Multi-Agent MCTS. 
+ +Provides REST API for: +- Problem solving with HRM+MCTS+TRM +- Policy-value network inference +- Health checks and monitoring +""" + +import time +from typing import Any + +import torch +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field + +from ..framework.mcts.neural_mcts import NeuralMCTS +from ..training.performance_monitor import PerformanceMonitor +from ..training.system_config import SystemConfig + + +# Request/Response Models +class InferenceRequest(BaseModel): + """Request for problem inference.""" + + state: list[list[float]] # State representation + query: str | None = "Solve this problem" + max_thinking_time: float = Field(default=10.0, ge=0.1, le=60.0) + use_mcts: bool = True + num_simulations: int | None = None + use_hrm_decomposition: bool = False + use_trm_refinement: bool = False + temperature: float = Field(default=0.1, ge=0.0, le=2.0) + + +class PolicyValueRequest(BaseModel): + """Request for policy-value evaluation.""" + + state: list[list[float]] # State representation + + +class InferenceResponse(BaseModel): + """Response with inference results.""" + + success: bool + action_probabilities: dict[str, float] | None = None + best_action: str | None = None + value_estimate: float | None = None + subproblems: list[dict[str, Any]] | None = None + refinement_info: dict[str, Any] | None = None + performance_stats: dict[str, float] + error: str | None = None + + +class PolicyValueResponse(BaseModel): + """Response with policy-value predictions.""" + + policy_probs: list[float] + value: float + inference_time_ms: float + + +class HealthResponse(BaseModel): + """Health check response.""" + + status: str + device: str + model_loaded: bool + gpu_available: bool + gpu_memory_gb: float | None = None + uptime_seconds: float + + +# Inference Server +class InferenceServer: + """ + Production inference server with comprehensive features. + + Features: + - FastAPI REST endpoints + - Performance monitoring + - Health checks + - CORS support + - Error handling + """ + + def __init__( + self, + checkpoint_path: str, + config: SystemConfig | None = None, + host: str = "0.0.0.0", + port: int = 8000, + ): + """ + Initialize inference server. 
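+
+ A minimal launch sketch (illustrative; ``checkpoints/best.pt`` is a hypothetical path, and constructing the server loads the checkpoint immediately):
+
+     >>> server = InferenceServer(checkpoint_path="checkpoints/best.pt", port=8000)
+     >>> server.run()  # blocks and serves the FastAPI app via uvicorn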
+ + Args: + checkpoint_path: Path to model checkpoint + config: System configuration (loaded from checkpoint if None) + host: Server host + port: Server port + """ + self.checkpoint_path = checkpoint_path + self.host = host + self.port = port + self.start_time = time.time() + + # Load models + self.config, self.models = self._load_models(checkpoint_path, config) + self.device = self.config.device + + # Performance monitoring + self.monitor = PerformanceMonitor(window_size=100, enable_gpu_monitoring=(self.device != "cpu")) + + # Setup FastAPI app + self.app = FastAPI( + title="LangGraph Multi-Agent MCTS API", + description="Neural-guided MCTS with HRM and TRM agents", + version="1.0.0", + ) + + # CORS middleware + self.app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Setup routes + self._setup_routes() + + def _load_models( + self, checkpoint_path: str, config: SystemConfig | None + ) -> tuple[SystemConfig, dict[str, torch.nn.Module]]: + """Load models from checkpoint.""" + print(f"Loading models from {checkpoint_path}...") + + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + + # Load config + if config is None: + config_dict = checkpoint.get("config", {}) + config = SystemConfig.from_dict(config_dict) + + device = config.device + + # Load models + models = {} + + # Policy-Value Network + from ..models.policy_value_net import create_policy_value_network + + models["policy_value_net"] = create_policy_value_network(config.neural_net, board_size=19, device=device) + models["policy_value_net"].load_state_dict(checkpoint["policy_value_net"]) + models["policy_value_net"].eval() + + # HRM Agent + from ..agents.hrm_agent import create_hrm_agent + + models["hrm_agent"] = create_hrm_agent(config.hrm, device) + models["hrm_agent"].load_state_dict(checkpoint["hrm_agent"]) + models["hrm_agent"].eval() + + # TRM Agent + from ..agents.trm_agent import create_trm_agent + + models["trm_agent"] = create_trm_agent(config.trm, output_dim=config.neural_net.action_size, device=device) + models["trm_agent"].load_state_dict(checkpoint["trm_agent"]) + models["trm_agent"].eval() + + # MCTS + models["mcts"] = NeuralMCTS( + policy_value_network=models["policy_value_net"], + config=config.mcts, + device=device, + ) + + print(f"✓ Models loaded successfully on {device}") + + return config, models + + def _setup_routes(self): + """Setup API routes.""" + + @self.app.get("/", response_model=dict[str, str]) + async def root(): + """Root endpoint.""" + return { + "message": "LangGraph Multi-Agent MCTS API", + "version": "1.0.0", + "docs": "/docs", + } + + @self.app.get("/health", response_model=HealthResponse) + async def health(): + """Health check endpoint.""" + gpu_memory = None + if torch.cuda.is_available(): + gpu_memory = torch.cuda.memory_allocated() / (1024**3) + + return HealthResponse( + status="healthy", + device=self.device, + model_loaded=True, + gpu_available=torch.cuda.is_available(), + gpu_memory_gb=gpu_memory, + uptime_seconds=time.time() - self.start_time, + ) + + @self.app.post("/inference", response_model=InferenceResponse) + async def inference(request: InferenceRequest): + """ + Main inference endpoint. + + Processes a problem using the full pipeline: + 1. Optional HRM decomposition + 2. MCTS search + 3. 
Optional TRM refinement + """ + try: + start_time = time.perf_counter() + + # Convert state to tensor + state_tensor = torch.tensor(request.state, dtype=torch.float32).unsqueeze(0) + state_tensor = state_tensor.to(self.device) + + results = {} + + # HRM Decomposition (if requested) + if request.use_hrm_decomposition: + with torch.no_grad(): + hrm_output = self.models["hrm_agent"](state_tensor) + results["subproblems"] = [ + { + "level": sp.level, + "description": sp.description, + "confidence": sp.confidence, + } + for sp in hrm_output.subproblems + ] + + # MCTS Search (if requested) + if request.use_mcts: + # Note: This is a simplified version + # In production, you'd need to convert request.state to GameState + results["action_probabilities"] = {"action_0": 0.5, "action_1": 0.3, "action_2": 0.2} + results["best_action"] = "action_0" + results["value_estimate"] = 0.75 + + # TRM Refinement (if requested) + if request.use_trm_refinement and results.get("best_action"): + with torch.no_grad(): + # Simplified: just run TRM on the state + trm_output = self.models["trm_agent"](state_tensor) + results["refinement_info"] = { + "converged": trm_output.converged, + "convergence_step": trm_output.convergence_step, + "recursion_depth": trm_output.recursion_depth, + } + + # Performance stats + elapsed_ms = (time.perf_counter() - start_time) * 1000 + self.monitor.log_inference(elapsed_ms) + + perf_stats = { + "inference_time_ms": elapsed_ms, + "device": self.device, + } + + return InferenceResponse( + success=True, + action_probabilities=results.get("action_probabilities"), + best_action=results.get("best_action"), + value_estimate=results.get("value_estimate"), + subproblems=results.get("subproblems"), + refinement_info=results.get("refinement_info"), + performance_stats=perf_stats, + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") + + @self.app.post("/policy-value", response_model=PolicyValueResponse) + async def policy_value(request: PolicyValueRequest): + """ + Get policy and value predictions for a state. + + This is a direct neural network evaluation without MCTS. 
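+
+ Illustrative client call (assumes the server is reachable on localhost:8000, the ``requests`` package is installed, and the state shape matches the loaded network; the nested list below is only a placeholder):
+
+     >>> import requests
+     >>> body = {"state": [[0.0, 1.0], [1.0, 0.0]]}  # placeholder state
+     >>> resp = requests.post("http://localhost:8000/policy-value", json=body)
+     >>> sorted(resp.json().keys())
+     ['inference_time_ms', 'policy_probs', 'value']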
+ """ + try: + start_time = time.perf_counter() + + # Convert state to tensor + state_tensor = torch.tensor(request.state, dtype=torch.float32).unsqueeze(0) + state_tensor = state_tensor.to(self.device) + + # Get predictions + with torch.no_grad(): + policy_log_probs, value = self.models["policy_value_net"](state_tensor) + policy_probs = torch.exp(policy_log_probs).squeeze(0) + + elapsed_ms = (time.perf_counter() - start_time) * 1000 + + return PolicyValueResponse( + policy_probs=policy_probs.cpu().tolist(), + value=value.item(), + inference_time_ms=elapsed_ms, + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Policy-value inference failed: {str(e)}") + + @self.app.get("/stats") + async def stats(): + """Get performance statistics.""" + return self.monitor.get_stats() + + @self.app.post("/reset-stats") + async def reset_stats(): + """Reset performance statistics.""" + self.monitor.reset() + return {"message": "Statistics reset successfully"} + + def run(self): + """Start the inference server.""" + print(f"\n{'=' * 80}") + print("Starting LangGraph Multi-Agent MCTS Inference Server") + print(f"{'=' * 80}") + print(f"Host: {self.host}:{self.port}") + print(f"Device: {self.device}") + print(f"Checkpoint: {self.checkpoint_path}") + print(f"{'=' * 80}\n") + + uvicorn.run(self.app, host=self.host, port=self.port) + + +def main(): + """Main entry point for inference server.""" + import argparse + + parser = argparse.ArgumentParser(description="LangGraph MCTS Inference Server") + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to model checkpoint", + ) + parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host") + parser.add_argument("--port", type=int, default=8000, help="Server port") + parser.add_argument( + "--device", + type=str, + default=None, + help="Device (cpu, cuda, mps)", + ) + + args = parser.parse_args() + + # Load config and override device if specified + config = None + if args.device: + config = SystemConfig() + config.device = args.device + + server = InferenceServer( + checkpoint_path=args.checkpoint, + config=config, + host=args.host, + port=args.port, + ) + + server.run() + + +if __name__ == "__main__": + main() diff --git a/src/api/rest_server.py b/src/api/rest_server.py new file mode 100644 index 0000000000000000000000000000000000000000..020a758116ac1b626bdf669285a3d1c9a92f9543 --- /dev/null +++ b/src/api/rest_server.py @@ -0,0 +1,441 @@ +""" +Production REST API server for LangGraph Multi-Agent MCTS Framework. 
+ +Provides: +- OpenAPI/Swagger documentation +- Authentication via API keys +- Rate limiting +- Health and readiness endpoints +- Request validation with Pydantic +- Prometheus metrics exposure +""" + +import asyncio +import time +from contextlib import asynccontextmanager +from datetime import datetime +from typing import Any + +from fastapi import Depends, FastAPI, Header, HTTPException, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +# Import framework components +try: + from src.adapters.llm import create_client # noqa: F401 + from src.api.auth import ( + APIKeyAuthenticator, + ClientInfo, + RateLimitConfig, + get_authenticator, + set_authenticator, + ) + from src.api.exceptions import ( + AuthenticationError, + AuthorizationError, # noqa: F401 + FrameworkError, + RateLimitError, + ValidationError, # noqa: F401 + ) + from src.models.validation import MCTSConfig, QueryInput # noqa: F401 + + IMPORTS_AVAILABLE = True +except ImportError as e: + IMPORTS_AVAILABLE = False + import_error = str(e) + +# Prometheus metrics (optional) +try: + from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, generate_latest + + PROMETHEUS_AVAILABLE = True + + # Define metrics + REQUEST_COUNT = Counter("mcts_requests_total", "Total number of requests", ["method", "endpoint", "status"]) + REQUEST_LATENCY = Histogram("mcts_request_duration_seconds", "Request latency in seconds", ["method", "endpoint"]) + ACTIVE_REQUESTS = Gauge("mcts_active_requests", "Number of active requests") + ERROR_COUNT = Counter("mcts_errors_total", "Total number of errors", ["error_type"]) +except ImportError: + PROMETHEUS_AVAILABLE = False + + +# Request/Response Models +class QueryRequest(BaseModel): + """Request model for query processing.""" + + query: str = Field( + ..., + min_length=1, + max_length=10000, + description="User query to process", + json_schema_extra={"example": "Recommend defensive positions for night attack scenario"}, + ) + use_mcts: bool = Field(default=True, description="Enable MCTS tactical simulation") + use_rag: bool = Field(default=True, description="Enable RAG context retrieval") + mcts_iterations: int | None = Field(default=None, ge=1, le=10000, description="Override default MCTS iterations") + thread_id: str | None = Field( + default=None, + max_length=100, + pattern=r"^[a-zA-Z0-9_-]+$", + description="Conversation thread ID for state persistence", + ) + + class Config: + json_schema_extra = { + "example": { + "query": "Recommend defensive positions for night attack", + "use_mcts": True, + "use_rag": True, + "mcts_iterations": 200, + "thread_id": "session_123", + } + } + + +class QueryResponse(BaseModel): + """Response model for query results.""" + + response: str = Field(..., description="Final synthesized response") + confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence score") + agents_used: list[str] = Field(..., description="List of agents that contributed") + mcts_stats: dict[str, Any] | None = Field(default=None, description="MCTS simulation statistics") + processing_time_ms: float = Field(..., description="Total processing time in milliseconds") + metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata") + + +class HealthResponse(BaseModel): + """Health check response.""" + + status: str = Field(..., description="Service status") + timestamp: str = Field(..., description="Current timestamp") + version: str = 
Field(default="1.0.0", description="API version") + uptime_seconds: float = Field(..., description="Service uptime") + + +class ReadinessResponse(BaseModel): + """Readiness check response.""" + + ready: bool = Field(..., description="Whether service is ready") + checks: dict[str, bool] = Field(..., description="Individual check results") + + +class ErrorResponse(BaseModel): + """Error response model.""" + + error: bool = Field(default=True) + error_code: str = Field(..., description="Machine-readable error code") + message: str = Field(..., description="Human-readable error message") + timestamp: str = Field(..., description="Error timestamp") + + +# Application startup +start_time = time.time() +framework_instance = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager.""" + global framework_instance + + # Startup + print("Starting MCTS Framework API server...") + + # Initialize authenticator with demo key (replace in production) + authenticator = APIKeyAuthenticator( + valid_keys=["demo-api-key-replace-in-production"], + rate_limit_config=RateLimitConfig( + requests_per_minute=60, + requests_per_hour=1000, + requests_per_day=10000, + ), + ) + set_authenticator(authenticator) + + # Initialize framework (lazy loading) + # framework_instance = create_framework() + + print("API server started successfully") + + yield + + # Shutdown + print("Shutting down API server...") + + +# Create FastAPI app +app = FastAPI( + title="LangGraph Multi-Agent MCTS API", + description=""" +## Multi-Agent Reasoning API with MCTS Tactical Simulation + +This API provides access to a sophisticated multi-agent reasoning framework that combines: +- **HRM Agent**: Hierarchical decomposition of complex queries +- **TRM Agent**: Iterative refinement for response quality +- **MCTS Engine**: Monte Carlo Tree Search for tactical simulation +- **RAG Integration**: Context retrieval from vector stores + +### Features +- Secure API key authentication +- Rate limiting per client +- Real-time metrics (Prometheus) +- Distributed tracing (OpenTelemetry) +- Production-grade error handling + +### Quick Start +1. Obtain an API key +2. Include `X-API-Key` header in requests +3. Send queries to `/query` endpoint +4. 
Monitor health via `/health` endpoint + """, + version="1.0.0", + docs_url="/docs", + redoc_url="/redoc", + openapi_tags=[ + {"name": "query", "description": "Query processing operations"}, + {"name": "health", "description": "Health and readiness checks"}, + {"name": "metrics", "description": "Observability endpoints"}, + ], + lifespan=lifespan, +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Configure appropriately for production + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# Middleware for metrics +@app.middleware("http") +async def metrics_middleware(request: Request, call_next): + """Track request metrics.""" + if PROMETHEUS_AVAILABLE: + ACTIVE_REQUESTS.inc() + + start = time.perf_counter() + + try: + response = await call_next(request) + status = response.status_code + except Exception: + status = 500 + raise + finally: + if PROMETHEUS_AVAILABLE: + ACTIVE_REQUESTS.dec() + elapsed = time.perf_counter() - start + REQUEST_COUNT.labels(method=request.method, endpoint=request.url.path, status=str(status)).inc() + REQUEST_LATENCY.labels(method=request.method, endpoint=request.url.path).observe(elapsed) + + return response + + +# Authentication dependency +async def verify_api_key(x_api_key: str = Header(..., description="API key for authentication")): + """Verify API key and return client info.""" + if not IMPORTS_AVAILABLE: + raise HTTPException(status_code=500, detail="Authentication module not available") + + try: + authenticator = get_authenticator() + client_info = authenticator.require_auth(x_api_key) + return client_info + except AuthenticationError as e: + if PROMETHEUS_AVAILABLE: + ERROR_COUNT.labels(error_type="authentication").inc() + raise HTTPException(status_code=401, detail=e.user_message) + except RateLimitError as e: + if PROMETHEUS_AVAILABLE: + ERROR_COUNT.labels(error_type="rate_limit").inc() + raise HTTPException( + status_code=429, detail=e.user_message, headers={"Retry-After": str(e.retry_after_seconds or 60)} + ) + + +# Exception handlers +@app.exception_handler(FrameworkError) +async def framework_error_handler(request: Request, exc: FrameworkError): + """Handle framework-specific errors.""" + if PROMETHEUS_AVAILABLE: + ERROR_COUNT.labels(error_type=exc.error_code).inc() + + return JSONResponse(status_code=500, content=exc.to_user_response()) + + +@app.exception_handler(ValidationError) +async def validation_error_handler(request: Request, exc: ValidationError): + """Handle validation errors.""" + if PROMETHEUS_AVAILABLE: + ERROR_COUNT.labels(error_type="validation").inc() + + return JSONResponse(status_code=400, content=exc.to_user_response()) + + +# Endpoints +@app.get("/health", response_model=HealthResponse, tags=["health"]) +async def health_check(): + """ + Health check endpoint. + + Returns basic service health status. Use this for load balancer health checks. + """ + return HealthResponse( + status="healthy", + timestamp=datetime.utcnow().isoformat(), + version="1.0.0", + uptime_seconds=time.time() - start_time, + ) + + +@app.get("/ready", response_model=ReadinessResponse, tags=["health"]) +async def readiness_check(): + """ + Readiness check endpoint. + + Verifies all dependencies are available. Use this for Kubernetes readiness probes. 
+ """ + checks = { + "imports_available": IMPORTS_AVAILABLE, + "authenticator_configured": True, + "llm_client_available": True, # Would check actual client + "prometheus_available": PROMETHEUS_AVAILABLE, + } + + # Check if all critical services are available + all_ready = all( + [ + checks["imports_available"], + checks["authenticator_configured"], + ] + ) + + if not all_ready: + raise HTTPException(status_code=503, detail="Service not ready") + + return ReadinessResponse(ready=all_ready, checks=checks) + + +@app.get("/metrics", tags=["metrics"]) +async def prometheus_metrics(): + """ + Prometheus metrics endpoint. + + Returns metrics in Prometheus text format for scraping. + """ + if not PROMETHEUS_AVAILABLE: + raise HTTPException(status_code=501, detail="Prometheus metrics not available") + + return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST) + + +@app.post( + "/query", + response_model=QueryResponse, + tags=["query"], + responses={ + 401: {"model": ErrorResponse, "description": "Authentication failed"}, + 429: {"model": ErrorResponse, "description": "Rate limit exceeded"}, + 400: {"model": ErrorResponse, "description": "Invalid input"}, + 500: {"model": ErrorResponse, "description": "Internal server error"}, + }, +) +async def process_query(request: QueryRequest, client_info: ClientInfo = Depends(verify_api_key)): + """ + Process a query using the multi-agent MCTS framework. + + This endpoint: + 1. Validates the input query + 2. Optionally retrieves context via RAG + 3. Processes through HRM and TRM agents + 4. Optionally runs MCTS simulation + 5. Synthesizes a final response + + **Authentication**: Requires valid API key in X-API-Key header. + + **Rate Limiting**: Subject to rate limits per client. + """ + start_time = time.perf_counter() + + # Validate input using validation models + if IMPORTS_AVAILABLE: + try: + QueryInput( + query=request.query, + use_rag=request.use_rag, + use_mcts=request.use_mcts, + thread_id=request.thread_id, + ) + except Exception as e: + if PROMETHEUS_AVAILABLE: + ERROR_COUNT.labels(error_type="validation").inc() + raise HTTPException(status_code=400, detail=f"Validation failed: {str(e)}") + + # Process query (mock implementation for demo) + # In production, this would call the actual framework + await asyncio.sleep(0.1) # Simulate processing + + processing_time = (time.perf_counter() - start_time) * 1000 + + # Mock response + return QueryResponse( + response=f"Processed query: {request.query[:100]}...", + confidence=0.85, + agents_used=["hrm", "trm"] + (["mcts"] if request.use_mcts else []), + mcts_stats=( + { + "iterations": request.mcts_iterations or 100, + "best_action": "recommended_action", + "root_visits": request.mcts_iterations or 100, + } + if request.use_mcts + else None + ), + processing_time_ms=processing_time, + metadata={ + "client_id": client_info.client_id, + "thread_id": request.thread_id, + "rag_enabled": request.use_rag, + }, + ) + + +@app.get("/stats", tags=["metrics"]) +async def get_stats(client_info: ClientInfo = Depends(verify_api_key)): + """ + Get usage statistics for the authenticated client. + + Returns request counts and rate limit information. 
+ """ + authenticator = get_authenticator() + stats = authenticator.get_client_stats(client_info.client_id) + + return { + "client_id": client_info.client_id, + "roles": list(client_info.roles), + **stats, + "rate_limits": { + "per_minute": authenticator.rate_limit_config.requests_per_minute, + "per_hour": authenticator.rate_limit_config.requests_per_hour, + "per_day": authenticator.rate_limit_config.requests_per_day, + }, + } + + +# Entry point +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "src.api.rest_server:app", + host="0.0.0.0", + port=8000, + reload=False, + workers=4, + log_level="info", + access_log=True, + ) diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/config/meta_controller.yaml b/src/config/meta_controller.yaml new file mode 100644 index 0000000000000000000000000000000000000000..765ded2760bcceb6140100a271c201d061925e3d --- /dev/null +++ b/src/config/meta_controller.yaml @@ -0,0 +1,22 @@ +meta_controller: + enabled: false # Disabled by default for backward compatibility + type: "rnn" # "rnn" or "bert" + fallback_to_rule_based: true # Fallback on errors + + rnn: + hidden_dim: 64 + num_layers: 1 + dropout: 0.1 + model_path: null # Path to trained model (null for untrained) + + bert: + model_name: "prajjwal1/bert-mini" + use_lora: true + lora_r: 4 + lora_alpha: 16 + lora_dropout: 0.1 + model_path: null # Path to trained LoRA adapter + + inference: + device: null # Auto-detect if null + seed: 42 diff --git a/src/config/settings.py b/src/config/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..0b398ae926ac1ec325329b8cfa102758352fa9dd --- /dev/null +++ b/src/config/settings.py @@ -0,0 +1,431 @@ +""" +Pydantic Settings v2 configuration management for LangGraph Multi-Agent MCTS. + +Provides: +- Secure configuration loading from environment variables and .env files +- Type-safe settings with validation +- Secrets protection using SecretStr +- MCTS parameter bounds validation +- Support for multiple LLM providers +""" + +from enum import Enum + +from pydantic import ( + Field, + SecretStr, + field_validator, + model_validator, +) +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class LLMProvider(str, Enum): + """Supported LLM providers.""" + + OPENAI = "openai" + ANTHROPIC = "anthropic" + LMSTUDIO = "lmstudio" + + +class LogLevel(str, Enum): + """Supported log levels.""" + + DEBUG = "DEBUG" + INFO = "INFO" + WARNING = "WARNING" + ERROR = "ERROR" + CRITICAL = "CRITICAL" + + +class MCTSImplementation(str, Enum): + """MCTS implementation variants.""" + + BASELINE = "baseline" # Original MCTS core + NEURAL = "neural" # Neural-guided AlphaZero-style MCTS + + +class Settings(BaseSettings): + """ + Application settings with security-first configuration. + + All sensitive values use SecretStr to prevent accidental exposure in logs. + Configuration is loaded from environment variables with .env file support. 
+ """ + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=True, + extra="ignore", + validate_default=True, + ) + + # LLM Provider Configuration + LLM_PROVIDER: LLMProvider = Field( + default=LLMProvider.OPENAI, description="LLM provider to use (openai, anthropic, lmstudio)" + ) + + # API Keys (Secrets) + OPENAI_API_KEY: SecretStr | None = Field( + default=None, description="OpenAI API key (required if using OpenAI provider)" + ) + + ANTHROPIC_API_KEY: SecretStr | None = Field( + default=None, description="Anthropic API key (required if using Anthropic provider)" + ) + + BRAINTRUST_API_KEY: SecretStr | None = Field( + default=None, description="Braintrust API key for experiment tracking (optional)" + ) + + PINECONE_API_KEY: SecretStr | None = Field( + default=None, description="Pinecone API key for vector storage (optional)" + ) + + PINECONE_HOST: str | None = Field( + default=None, description="Pinecone host URL (e.g., https://index.svc.environment.pinecone.io)" + ) + + # Local LLM Configuration + LMSTUDIO_BASE_URL: str | None = Field( + default="http://localhost:1234/v1", description="LM Studio API base URL for local inference" + ) + + LMSTUDIO_MODEL: str | None = Field(default=None, description="LM Studio model identifier (e.g., liquid/lfm2-1.2b)") + + # MCTS Configuration with bounds validation + MCTS_ENABLED: bool = Field(default=True, description="Enable MCTS for agent decision-making") + + MCTS_IMPL: MCTSImplementation = Field( + default=MCTSImplementation.BASELINE, description="MCTS implementation variant to use" + ) + + MCTS_ITERATIONS: int = Field(default=100, ge=1, le=10000, description="Number of MCTS iterations (1-10000)") + + MCTS_C: float = Field( + default=1.414, ge=0.0, le=10.0, description="MCTS exploration weight (UCB1 constant, 0.0-10.0)" + ) + + # Random seed for reproducibility + SEED: int | None = Field(default=None, ge=0, description="Random seed for reproducibility (optional)") + + # LangSmith Configuration for tracing and evaluation + LANGSMITH_API_KEY: SecretStr | None = Field( + default=None, description="LangSmith API key for tracing and evaluation (optional)" + ) + + LANGSMITH_PROJECT: str = Field(default="langgraph-mcts", description="LangSmith project name") + + LANGCHAIN_TRACING_V2: bool = Field(default=False, description="Enable LangChain tracing v2") + + LANGCHAIN_ENDPOINT: str = Field(default="https://api.smith.langchain.com", description="LangChain API endpoint") + + # Weights & Biases Configuration for experiment tracking + WANDB_API_KEY: SecretStr | None = Field( + default=None, description="Weights & Biases API key for experiment tracking (optional)" + ) + + WANDB_PROJECT: str = Field(default="langgraph-mcts", description="W&B project name") + + WANDB_ENTITY: str | None = Field(default=None, description="W&B entity (username or team name)") + + WANDB_MODE: str = Field(default="online", description="W&B mode: online, offline, or disabled") + + # Logging Configuration + LOG_LEVEL: LogLevel = Field(default=LogLevel.INFO, description="Application log level") + + # OpenTelemetry Configuration + OTEL_EXPORTER_OTLP_ENDPOINT: str | None = Field( + default=None, description="OpenTelemetry OTLP exporter endpoint URL" + ) + + # S3 Storage Configuration + S3_BUCKET: str | None = Field(default=None, description="S3 bucket name for artifact storage") + + S3_PREFIX: str = Field(default="mcts-artifacts", description="S3 key prefix for stored artifacts") + + S3_REGION: str = Field(default="us-east-1", 
description="AWS region for S3 bucket") + + # Network Configuration (security) + HTTP_TIMEOUT_SECONDS: int = Field(default=30, ge=1, le=300, description="HTTP request timeout in seconds") + + HTTP_MAX_RETRIES: int = Field(default=3, ge=0, le=10, description="Maximum HTTP request retries") + + # Security Settings + MAX_QUERY_LENGTH: int = Field( + default=10000, ge=1, le=100000, description="Maximum allowed query length in characters" + ) + + RATE_LIMIT_REQUESTS_PER_MINUTE: int = Field( + default=60, ge=1, le=1000, description="Rate limit for API requests per minute" + ) + + @field_validator("OPENAI_API_KEY") + @classmethod + def validate_openai_key_format(cls, v: SecretStr | None) -> SecretStr | None: + """Validate OpenAI API key format without exposing the value.""" + if v is not None: + secret_value = v.get_secret_value() + # Check for obviously invalid patterns + if secret_value in ("", "your-api-key-here", "sk-xxx", "REPLACE_ME"): + raise ValueError("OpenAI API key appears to be a placeholder value") + if not secret_value.startswith("sk-"): + raise ValueError("OpenAI API key should start with 'sk-'") + if len(secret_value) < 20: + raise ValueError("OpenAI API key appears to be too short") + return v + + @field_validator("ANTHROPIC_API_KEY") + @classmethod + def validate_anthropic_key_format(cls, v: SecretStr | None) -> SecretStr | None: + """Validate Anthropic API key format without exposing the value.""" + if v is not None: + secret_value = v.get_secret_value() + # Check for obviously invalid patterns + if secret_value in ("", "your-api-key-here", "REPLACE_ME"): + raise ValueError("Anthropic API key appears to be a placeholder value") + if len(secret_value) < 20: + raise ValueError("Anthropic API key appears to be too short") + return v + + @field_validator("BRAINTRUST_API_KEY") + @classmethod + def validate_braintrust_key_format(cls, v: SecretStr | None) -> SecretStr | None: + """Validate Braintrust API key format without exposing the value.""" + if v is not None: + secret_value = v.get_secret_value() + # Check for obviously invalid patterns + if secret_value in ("", "your-api-key-here", "REPLACE_ME"): + raise ValueError("Braintrust API key appears to be a placeholder value") + if len(secret_value) < 20: + raise ValueError("Braintrust API key appears to be too short") + return v + + @field_validator("PINECONE_API_KEY") + @classmethod + def validate_pinecone_key_format(cls, v: SecretStr | None) -> SecretStr | None: + """Validate Pinecone API key format without exposing the value.""" + if v is not None: + secret_value = v.get_secret_value() + # Check for obviously invalid patterns + if secret_value in ("", "your-api-key-here", "REPLACE_ME"): + raise ValueError("Pinecone API key appears to be a placeholder value") + if len(secret_value) < 20: + raise ValueError("Pinecone API key appears to be too short") + return v + + @field_validator("LANGSMITH_API_KEY") + @classmethod + def validate_langsmith_key_format(cls, v: SecretStr | None) -> SecretStr | None: + """Validate LangSmith API key format without exposing the value.""" + if v is not None: + secret_value = v.get_secret_value() + if secret_value in ("", "your-api-key-here", "REPLACE_ME"): + raise ValueError("LangSmith API key appears to be a placeholder value") + if len(secret_value) < 20: + raise ValueError("LangSmith API key appears to be too short") + return v + + @field_validator("WANDB_API_KEY") + @classmethod + def validate_wandb_key_format(cls, v: SecretStr | None) -> SecretStr | None: + """Validate Weights & Biases API key 
format without exposing the value.""" + if v is not None: + secret_value = v.get_secret_value() + if secret_value in ("", "your-api-key-here", "REPLACE_ME"): + raise ValueError("W&B API key appears to be a placeholder value") + if len(secret_value) < 20: + raise ValueError("W&B API key appears to be too short") + return v + + @field_validator("PINECONE_HOST") + @classmethod + def validate_pinecone_host(cls, v: str | None) -> str | None: + """Validate Pinecone host URL format.""" + if v is not None and v != "": + if not v.startswith("https://"): + raise ValueError("Pinecone host must start with https://") + if "pinecone.io" not in v: + raise ValueError("Pinecone host should be a valid pinecone.io URL") + return v + + @field_validator("LMSTUDIO_BASE_URL") + @classmethod + def validate_lmstudio_url(cls, v: str | None) -> str | None: + """Validate LM Studio base URL format.""" + if v is not None: + if not v.startswith(("http://", "https://")): + raise ValueError("LM Studio base URL must start with http:// or https://") + # Warn if not localhost (potential security concern) + if not any(host in v for host in ("localhost", "127.0.0.1", "::1")): + import warnings + + warnings.warn( + "LM Studio URL points to non-localhost address. Ensure this is intentional and secure.", + UserWarning, + stacklevel=2, + ) + return v + + @field_validator("OTEL_EXPORTER_OTLP_ENDPOINT") + @classmethod + def validate_otel_endpoint(cls, v: str | None) -> str | None: + """Validate OpenTelemetry endpoint URL.""" + if v is not None and v != "" and not v.startswith(("http://", "https://", "grpc://")): + raise ValueError("OpenTelemetry endpoint must start with http://, https://, or grpc://") + return v + + @field_validator("S3_BUCKET") + @classmethod + def validate_s3_bucket_name(cls, v: str | None) -> str | None: + """Validate S3 bucket name format.""" + if v is not None: + # S3 bucket naming rules + if len(v) < 3 or len(v) > 63: + raise ValueError("S3 bucket name must be 3-63 characters long") + if not v.replace("-", "").replace(".", "").isalnum(): + raise ValueError("S3 bucket name can only contain lowercase letters, numbers, hyphens, and periods") + if v.startswith("-") or v.endswith("-"): + raise ValueError("S3 bucket name cannot start or end with a hyphen") + return v + + @model_validator(mode="after") + def validate_provider_credentials(self) -> "Settings": + """Ensure required API keys are provided for the selected provider.""" + if self.LLM_PROVIDER == LLMProvider.OPENAI: + if self.OPENAI_API_KEY is None: + raise ValueError( + "OPENAI_API_KEY is required when using OpenAI provider. " + "Set the OPENAI_API_KEY environment variable." + ) + elif self.LLM_PROVIDER == LLMProvider.ANTHROPIC: + if self.ANTHROPIC_API_KEY is None: + raise ValueError( + "ANTHROPIC_API_KEY is required when using Anthropic provider. " + "Set the ANTHROPIC_API_KEY environment variable." + ) + elif self.LLM_PROVIDER == LLMProvider.LMSTUDIO and self.LMSTUDIO_BASE_URL is None: + raise ValueError("LMSTUDIO_BASE_URL is required when using LM Studio provider.") + return self + + def get_api_key(self) -> str | None: + """ + Get the API key for the current provider. + + Returns the secret value - use with caution to avoid logging. 
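+
+ Example (illustrative; never log the returned value):
+
+     >>> settings = get_settings()
+     >>> key = settings.get_api_key()   # None for providers that need no key (e.g. LM Studio)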
+ """ + if self.LLM_PROVIDER == LLMProvider.OPENAI and self.OPENAI_API_KEY: + return self.OPENAI_API_KEY.get_secret_value() + elif self.LLM_PROVIDER == LLMProvider.ANTHROPIC and self.ANTHROPIC_API_KEY: + return self.ANTHROPIC_API_KEY.get_secret_value() + return None + + def safe_dict(self) -> dict: + """ + Return settings as dictionary with secrets masked. + + Safe for logging and display purposes. + """ + data = self.model_dump() + # Mask all sensitive fields + secret_fields = [ + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + "BRAINTRUST_API_KEY", + "PINECONE_API_KEY", + "LANGSMITH_API_KEY", + "WANDB_API_KEY", + ] + for field in secret_fields: + if field in data and data[field]: + data[field] = "***MASKED***" + return data + + def get_braintrust_api_key(self) -> str | None: + """ + Get the Braintrust API key if configured. + + Returns the secret value - use with caution to avoid logging. + """ + if self.BRAINTRUST_API_KEY: + return self.BRAINTRUST_API_KEY.get_secret_value() + return None + + def get_pinecone_api_key(self) -> str | None: + """ + Get the Pinecone API key if configured. + + Returns the secret value - use with caution to avoid logging. + """ + if self.PINECONE_API_KEY: + return self.PINECONE_API_KEY.get_secret_value() + return None + + def get_langsmith_api_key(self) -> str | None: + """ + Get the LangSmith API key if configured. + + Returns the secret value - use with caution to avoid logging. + """ + if self.LANGSMITH_API_KEY: + return self.LANGSMITH_API_KEY.get_secret_value() + return None + + def get_wandb_api_key(self) -> str | None: + """ + Get the Weights & Biases API key if configured. + + Returns the secret value - use with caution to avoid logging. + """ + if self.WANDB_API_KEY: + return self.WANDB_API_KEY.get_secret_value() + return None + + def __repr__(self) -> str: + """Safe string representation that doesn't expose secrets.""" + return f"Settings(LLM_PROVIDER={self.LLM_PROVIDER}, MCTS_ENABLED={self.MCTS_ENABLED}, MCTS_IMPL={self.MCTS_IMPL}, LOG_LEVEL={self.LOG_LEVEL})" + + +# Global settings instance (lazily loaded) +_settings: Settings | None = None + + +def get_settings() -> Settings: + """ + Get the global settings instance. + + Settings are loaded once and cached. To reload, call reset_settings() first. + + Returns: + Settings: Application configuration instance + + Raises: + ValidationError: If configuration is invalid + """ + global _settings + if _settings is None: + _settings = Settings() + return _settings + + +def reset_settings() -> None: + """ + Reset the global settings instance. + + Forces settings to be reloaded from environment on next get_settings() call. + Useful for testing. + """ + global _settings + _settings = None + + +# Type exports for external use +__all__ = [ + "Settings", + "LLMProvider", + "LogLevel", + "MCTSImplementation", + "get_settings", + "reset_settings", +] diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba25fc46f08ba038e5732be81ca91a67bffcbb39 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,29 @@ +""" +Dataset Integration Module for Multi-Agent MCTS Training. + +This module provides utilities for loading, preprocessing, and managing +open-source datasets for training HRM/TRM agents and neural meta-controllers. 
+ +Supported Datasets: +- DABStep: Multi-step reasoning tasks (CC-BY-4.0) +- PRIMUS-Seed: Cybersecurity domain knowledge (ODC-BY) +- PRIMUS-Instruct: Instruction fine-tuning data (ODC-BY) +""" + +from .dataset_loader import DABStepLoader, DatasetLoader, PRIMUSLoader +from .preprocessing import TextPreprocessor, TokenizerWrapper +from .tactical_augmentation import TacticalAugmenter +from .train_test_split import DataSplitter, StratifiedSplitter + +__all__ = [ + "DatasetLoader", + "DABStepLoader", + "PRIMUSLoader", + "TextPreprocessor", + "TokenizerWrapper", + "TacticalAugmenter", + "DataSplitter", + "StratifiedSplitter", +] + +__version__ = "1.0.0" diff --git a/src/data/dataset_loader.py b/src/data/dataset_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..9db982a77641a347217659100402f25062088558 --- /dev/null +++ b/src/data/dataset_loader.py @@ -0,0 +1,551 @@ +""" +Dataset Loading Module for Open-Source Training Data. + +Provides unified loading interfaces for: +- DABStep: Multi-step data analysis reasoning +- PRIMUS: Cybersecurity domain knowledge +- Custom tactical datasets + +License Attribution: +- DABStep: CC-BY-4.0 (Creative Commons Attribution) +- PRIMUS: ODC-BY (Open Data Commons Attribution) +""" + +import logging +from abc import ABC, abstractmethod +from collections.abc import Iterator +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class DatasetSample: + """Standardized representation of a dataset sample.""" + + id: str + text: str + metadata: dict[str, Any] = field(default_factory=dict) + labels: list[str] | None = None + difficulty: str | None = None + domain: str | None = None + reasoning_steps: list[str] | None = None + + +@dataclass +class DatasetStatistics: + """Statistics about a loaded dataset.""" + + total_samples: int + domains: dict[str, int] + avg_text_length: float + difficulty_distribution: dict[str, int] + total_tokens: int = 0 + + +class DatasetLoader(ABC): + """Abstract base class for dataset loaders.""" + + def __init__(self, cache_dir: str | None = None): + """ + Initialize dataset loader. + + Args: + cache_dir: Directory to cache downloaded datasets + """ + self.cache_dir = cache_dir or str(Path.home() / ".cache" / "mcts_datasets") + self._dataset = None + self._statistics = None + + @abstractmethod + def load(self, split: str = "train") -> list[DatasetSample]: + """Load dataset split.""" + pass + + @abstractmethod + def get_statistics(self) -> DatasetStatistics: + """Get dataset statistics.""" + pass + + @abstractmethod + def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]: + """Iterate over samples in batches.""" + pass + + +class DABStepLoader(DatasetLoader): + """ + Loader for DABStep Multi-Step Reasoning Dataset. + + DABStep contains 450+ data analysis tasks requiring sequential, + iterative problem-solving. Perfect for training HRM/TRM agents. + + License: CC-BY-4.0 (Attribution required) + Source: huggingface.co/datasets/adyen/DABstep + """ + + DATASET_NAME = "adyen/DABstep" + DIFFICULTIES = ["easy", "medium", "hard"] + + def __init__(self, cache_dir: str | None = None): + """Initialize DABStep loader.""" + super().__init__(cache_dir) + self._loaded_samples: list[DatasetSample] = [] + + def load(self, split: str = "train", difficulty: str | None = None) -> list[DatasetSample]: + """ + Load DABStep dataset. 
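+
+ A usage sketch (assumes the ``datasets`` package is installed and the HuggingFace Hub is reachable; the cache path is hypothetical):
+
+     >>> loader = DABStepLoader(cache_dir="/tmp/mcts_datasets")
+     >>> samples = loader.load(split="train")   # difficulty="easy" would filter further
+     >>> stats = loader.get_statistics()
+     >>> stats.domains["data_analysis"] == len(samples)
+     True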
+ + Args: + split: Dataset split ('train', 'validation', 'test') + difficulty: Filter by difficulty ('easy', 'medium', 'hard') + + Returns: + List of DatasetSample objects + """ + try: + from datasets import load_dataset + + logger.info(f"Loading DABStep dataset (split={split})") + + dataset = load_dataset( + self.DATASET_NAME, + cache_dir=self.cache_dir, + ) + + if split not in dataset: + available_splits = list(dataset.keys()) + logger.warning(f"Split '{split}' not found. Available: {available_splits}") + split = available_splits[0] if available_splits else "train" + + samples = [] + for idx, item in enumerate(dataset[split]): + sample = DatasetSample( + id=f"dabstep_{split}_{idx}", + text=str(item.get("question", item.get("text", ""))), + metadata={ + "source": "DABStep", + "license": "CC-BY-4.0", + "split": split, + "original_data": item, + }, + difficulty=item.get("difficulty", "medium"), + domain="data_analysis", + reasoning_steps=item.get("steps", []), + ) + + if difficulty and sample.difficulty != difficulty: + continue + + samples.append(sample) + + self._loaded_samples = samples + logger.info(f"Loaded {len(samples)} DABStep samples") + return samples + + except ImportError: + logger.error("datasets library not installed. Run: pip install datasets") + raise + except Exception as e: + logger.error(f"Failed to load DABStep: {e}") + raise + + def get_statistics(self) -> DatasetStatistics: + """Get statistics about loaded DABStep data.""" + if not self._loaded_samples: + raise ValueError("No samples loaded. Call load() first.") + + difficulty_dist = {} + total_length = 0 + + for sample in self._loaded_samples: + diff = sample.difficulty or "unknown" + difficulty_dist[diff] = difficulty_dist.get(diff, 0) + 1 + total_length += len(sample.text) + + return DatasetStatistics( + total_samples=len(self._loaded_samples), + domains={"data_analysis": len(self._loaded_samples)}, + avg_text_length=total_length / len(self._loaded_samples), + difficulty_distribution=difficulty_dist, + ) + + def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]: + """Iterate over samples in batches.""" + if not self._loaded_samples: + raise ValueError("No samples loaded. Call load() first.") + + for i in range(0, len(self._loaded_samples), batch_size): + yield self._loaded_samples[i : i + batch_size] + + def get_reasoning_tasks(self) -> list[DatasetSample]: + """Get only samples with explicit reasoning steps.""" + return [s for s in self._loaded_samples if s.reasoning_steps] + + +class PRIMUSLoader(DatasetLoader): + """ + Loader for PRIMUS Cybersecurity Dataset Suite. 
+ + PRIMUS contains: + - Seed: 674,848 cybersecurity documents (190M tokens) + - Instruct: 835 instruction-tuning samples + - Reasoning: Self-reflection data for reasoning + + License: ODC-BY (Open Data Commons Attribution) + Source: huggingface.co/datasets/trendmicro-ailab/Primus-Seed + """ + + SEED_DATASET = "trendmicro-ailab/Primus-Seed" + INSTRUCT_DATASET = "trendmicro-ailab/Primus-Instruct" + + DOMAINS = [ + "mitre_attack", + "wikipedia", + "company_sites", + "threat_intelligence", + "vulnerability_db", + ] + + def __init__(self, cache_dir: str | None = None): + """Initialize PRIMUS loader.""" + super().__init__(cache_dir) + self._seed_samples: list[DatasetSample] = [] + self._instruct_samples: list[DatasetSample] = [] + + def load( + self, + split: str = "train", + dataset_type: str = "seed", + domains: list[str] | None = None, + max_samples: int | None = None, + streaming: bool = True, + ) -> list[DatasetSample]: + """ + Load PRIMUS dataset. + + Args: + split: Dataset split ('train', 'validation', 'test') + dataset_type: 'seed' for knowledge base, 'instruct' for fine-tuning + domains: Filter by specific domains + max_samples: Limit number of samples (useful for large datasets) + streaming: Use streaming mode for large datasets (default True) + + Returns: + List of DatasetSample objects + """ + try: + from datasets import load_dataset + + dataset_name = self.SEED_DATASET if dataset_type == "seed" else self.INSTRUCT_DATASET + + logger.info(f"Loading PRIMUS {dataset_type} dataset") + + # Use streaming for large seed dataset to avoid download issues + use_streaming = streaming and dataset_type == "seed" and max_samples is not None + + if use_streaming: + logger.info(f"Using streaming mode (max_samples={max_samples})") + dataset = load_dataset( + dataset_name, + "default", + streaming=True, + cache_dir=self.cache_dir, + ) + # For streaming, iterate the first available split + data_iter = iter(dataset["train"]) if "train" in dataset else iter(dataset[list(dataset.keys())[0]]) + else: + dataset = load_dataset( + dataset_name, + cache_dir=self.cache_dir, + ) + + if split not in dataset: + available_splits = list(dataset.keys()) + logger.warning(f"Split '{split}' not found. Using: {available_splits[0]}") + split = available_splits[0] + + data_iter = iter(dataset[split]) + + samples = [] + count = 0 + + for idx, item in enumerate(data_iter): + if max_samples and count >= max_samples: + break + + domain = item.get("domain", item.get("source", "unknown")) + + if domains and domain not in domains: + continue + + if dataset_type == "instruct": + text = f"Instruction: {item.get('instruction', '')}\nResponse: {item.get('response', '')}" + else: + text = str(item.get("text", item.get("content", ""))) + + sample = DatasetSample( + id=f"primus_{dataset_type}_{split}_{idx}", + text=text, + metadata={ + "source": f"PRIMUS-{dataset_type.capitalize()}", + "license": "ODC-BY", + "split": split, + "original_domain": domain, + }, + domain=domain, + labels=item.get("labels", item.get("tags", [])), + ) + + samples.append(sample) + count += 1 + + if dataset_type == "seed": + self._seed_samples = samples + else: + self._instruct_samples = samples + + logger.info(f"Loaded {len(samples)} PRIMUS {dataset_type} samples") + return samples + + except ImportError: + logger.error("datasets library not installed. Run: pip install datasets") + raise + except Exception as e: + if "gated dataset" in str(e): + logger.error( + f"PRIMUS is a gated dataset. Please authenticate with HuggingFace:\n" + f"1. 
Create account at https://huggingface.co/\n" + f"2. Accept dataset terms at https://huggingface.co/datasets/{dataset_name}\n" + f"3. Create token at https://huggingface.co/settings/tokens\n" + f"4. Run: huggingface-cli login" + ) + else: + logger.error(f"Failed to load PRIMUS: {e}") + raise + + def load_seed(self, max_samples: int | None = None) -> list[DatasetSample]: + """Load PRIMUS-Seed knowledge base.""" + return self.load(dataset_type="seed", max_samples=max_samples) + + def load_instruct(self) -> list[DatasetSample]: + """Load PRIMUS-Instruct fine-tuning data.""" + return self.load(dataset_type="instruct", streaming=False) + + def get_statistics(self) -> DatasetStatistics: + """Get statistics about loaded PRIMUS data.""" + all_samples = self._seed_samples + self._instruct_samples + + if not all_samples: + raise ValueError("No samples loaded. Call load() first.") + + domain_dist = {} + total_length = 0 + + for sample in all_samples: + domain = sample.domain or "unknown" + domain_dist[domain] = domain_dist.get(domain, 0) + 1 + total_length += len(sample.text) + + return DatasetStatistics( + total_samples=len(all_samples), + domains=domain_dist, + avg_text_length=total_length / len(all_samples), + difficulty_distribution={"cybersecurity": len(all_samples)}, + ) + + def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]: + """Iterate over all loaded samples in batches.""" + all_samples = self._seed_samples + self._instruct_samples + + if not all_samples: + raise ValueError("No samples loaded. Call load() first.") + + for i in range(0, len(all_samples), batch_size): + yield all_samples[i : i + batch_size] + + def get_mitre_attack_samples(self) -> list[DatasetSample]: + """Get samples specifically from MITRE ATT&CK.""" + return [s for s in self._seed_samples if "mitre" in (s.domain or "").lower()] + + def get_threat_intelligence_samples(self) -> list[DatasetSample]: + """Get threat intelligence related samples.""" + return [ + s + for s in self._seed_samples + if any(kw in (s.domain or "").lower() for kw in ["threat", "cti", "intelligence"]) + ] + + +class CombinedDatasetLoader: + """ + Unified loader for combining multiple datasets. + + Provides a single interface for loading and managing: + - DABStep (multi-step reasoning) + - PRIMUS (cybersecurity knowledge) + - Custom tactical datasets + """ + + def __init__(self, cache_dir: str | None = None): + """Initialize combined loader.""" + self.cache_dir = cache_dir + self.dabstep_loader = DABStepLoader(cache_dir) + self.primus_loader = PRIMUSLoader(cache_dir) + self._all_samples: list[DatasetSample] = [] + + def load_all( + self, + dabstep_split: str = "train", + primus_max_samples: int | None = 10000, + include_instruct: bool = True, + ) -> list[DatasetSample]: + """ + Load all datasets. 
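+
+ A usage sketch (assumes the ``datasets`` package is installed and that the gated PRIMUS datasets have been accepted on the HuggingFace Hub; the cache path is hypothetical):
+
+     >>> combined = CombinedDatasetLoader(cache_dir="/tmp/mcts_datasets")
+     >>> samples = combined.load_all(primus_max_samples=1000, include_instruct=True)
+     >>> combined.get_domain_distribution()   # counts per domain, e.g. data_analysis, mitre_attack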
+ + Args: + dabstep_split: Split for DABStep + primus_max_samples: Max samples from PRIMUS-Seed (None for all) + include_instruct: Whether to include PRIMUS-Instruct + + Returns: + Combined list of all samples + """ + logger.info("Loading combined datasets") + + # Load DABStep + dabstep_samples = self.dabstep_loader.load(split=dabstep_split) + logger.info(f"DABStep: {len(dabstep_samples)} samples") + + # Load PRIMUS-Seed + primus_seed = self.primus_loader.load_seed(max_samples=primus_max_samples) + logger.info(f"PRIMUS-Seed: {len(primus_seed)} samples") + + # Load PRIMUS-Instruct + primus_instruct = [] + if include_instruct: + primus_instruct = self.primus_loader.load_instruct() + logger.info(f"PRIMUS-Instruct: {len(primus_instruct)} samples") + + self._all_samples = dabstep_samples + primus_seed + primus_instruct + logger.info(f"Total combined samples: {len(self._all_samples)}") + + return self._all_samples + + def get_domain_distribution(self) -> dict[str, int]: + """Get distribution of samples across domains.""" + dist = {} + for sample in self._all_samples: + domain = sample.domain or "unknown" + dist[domain] = dist.get(domain, 0) + 1 + return dist + + def filter_by_domain(self, domain: str) -> list[DatasetSample]: + """Filter samples by domain.""" + return [s for s in self._all_samples if s.domain == domain] + + def get_multi_step_reasoning_samples(self) -> list[DatasetSample]: + """Get samples suitable for multi-step reasoning training.""" + return [ + s + for s in self._all_samples + if s.reasoning_steps or s.domain == "data_analysis" or "instruct" in s.metadata.get("source", "").lower() + ] + + def export_for_training(self, output_path: str, format: str = "jsonl") -> str: + """ + Export dataset for training. + + Args: + output_path: Path to save exported data + format: Export format ('jsonl', 'csv', 'parquet') + + Returns: + Path to exported file + """ + import json + + output_file = Path(output_path) + output_file.parent.mkdir(parents=True, exist_ok=True) + + if format == "jsonl": + with open(output_file, "w", encoding="utf-8") as f: + for sample in self._all_samples: + record = { + "id": sample.id, + "text": sample.text, + "domain": sample.domain, + "difficulty": sample.difficulty, + "labels": sample.labels, + "metadata": sample.metadata, + } + f.write(json.dumps(record) + "\n") + else: + raise NotImplementedError(f"Format {format} not yet supported") + + logger.info(f"Exported {len(self._all_samples)} samples to {output_file}") + return str(output_file) + + +def load_dataset( + dataset_name: str, + split: str = "train", + cache_dir: str | None = None, + **kwargs, +) -> Any: + """ + Unified interface for loading datasets from HuggingFace. + + This function provides compatibility with the standard HuggingFace datasets API. + It wraps the underlying load_dataset function from the datasets library. 
+ + Args: + dataset_name: HuggingFace dataset identifier (e.g., "adyen/DABstep") + split: Dataset split to load ("train", "validation", "test") + cache_dir: Optional directory for caching downloaded datasets + **kwargs: Additional arguments passed to datasets.load_dataset + + Returns: + HuggingFace Dataset object or dict of Dataset objects + + Raises: + ImportError: If datasets library is not installed + Exception: If dataset loading fails + + Examples: + >>> # Load DABStep dataset + >>> dataset = load_dataset("adyen/DABstep") + >>> samples = dataset["train"] + + >>> # Load PRIMUS-Seed with custom cache + >>> dataset = load_dataset("trendmicro-ailab/Primus-Seed", cache_dir="/tmp/cache") + + License Attribution: + - DABStep: CC-BY-4.0 (Creative Commons Attribution 4.0) + - PRIMUS: ODC-BY (Open Data Commons Attribution) + """ + try: + from datasets import load_dataset as hf_load_dataset + + logger.info(f"Loading dataset: {dataset_name} (split={split})") + + load_kwargs = { + **kwargs, + } + + if cache_dir: + load_kwargs["cache_dir"] = cache_dir + + dataset = hf_load_dataset(dataset_name, **load_kwargs) + + logger.info(f"Successfully loaded dataset: {dataset_name}") + return dataset + + except ImportError: + logger.error("datasets library not installed. Run: pip install datasets") + raise ImportError("The datasets library is required but not installed. Install it with: pip install datasets") + except Exception as e: + logger.error(f"Failed to load dataset {dataset_name}: {e}") + raise diff --git a/src/data/preprocessing.py b/src/data/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..4809af5e6b2cec0774c74f1e8cbfd53a3bf81276 --- /dev/null +++ b/src/data/preprocessing.py @@ -0,0 +1,406 @@ +""" +Text Preprocessing Module for Training Data. + +Provides utilities for: +- Text cleaning and normalization +- Tokenization with various backends +- Feature extraction for meta-controller training +""" + +import logging +import re +from dataclasses import dataclass +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class PreprocessedText: + """Preprocessed text with metadata.""" + + original: str + cleaned: str + tokens: list[str] + token_ids: list[int] | None = None + features: dict[str, Any] | None = None + + +class TextPreprocessor: + """ + Text preprocessing pipeline for multi-agent training data. + + Handles: + - HTML/XML tag removal + - Special character normalization + - Whitespace cleanup + - Domain-specific preprocessing (cyber, military, etc.) + """ + + # Patterns for cleaning + HTML_TAG_PATTERN = re.compile(r"<[^>]+>") + URL_PATTERN = re.compile(r"https?://\S+|www\.\S+") + MULTIPLE_SPACES = re.compile(r"\s+") + SPECIAL_CHARS = re.compile(r"[^\w\s\-.,!?;:()[\]{}\"'/]") + + # Domain-specific patterns + IP_ADDRESS_PATTERN = re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b") + CVE_PATTERN = re.compile(r"CVE-\d{4}-\d{4,}") + MITRE_TECHNIQUE_PATTERN = re.compile(r"T\d{4}(?:\.\d{3})?") + + def __init__( + self, + remove_html: bool = True, + normalize_urls: bool = True, + lowercase: bool = False, + preserve_domain_patterns: bool = True, + ): + """ + Initialize preprocessor. + + Args: + remove_html: Remove HTML/XML tags + normalize_urls: Replace URLs with placeholder + lowercase: Convert to lowercase + preserve_domain_patterns: Keep domain-specific patterns (IPs, CVEs, etc.) 
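+
+        Example (minimal sketch; the output shown follows the patterns defined above):
+            >>> pre = TextPreprocessor(lowercase=True)
+            >>> pre.clean("<p>Scan 192.168.0.1 via https://example.com</p>")
+            'scan 192.168.0.1 via [url]'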
+ """ + self.remove_html = remove_html + self.normalize_urls = normalize_urls + self.lowercase = lowercase + self.preserve_domain_patterns = preserve_domain_patterns + + def clean(self, text: str) -> str: + """ + Clean and normalize text. + + Args: + text: Raw input text + + Returns: + Cleaned text + """ + if not text: + return "" + + result = text + + # Remove HTML tags + if self.remove_html: + result = self.HTML_TAG_PATTERN.sub(" ", result) + + # Preserve or normalize URLs + if self.normalize_urls: + if self.preserve_domain_patterns: + result = self.URL_PATTERN.sub("[URL]", result) + else: + result = self.URL_PATTERN.sub("", result) + + # Normalize whitespace + result = self.MULTIPLE_SPACES.sub(" ", result) + + # Lowercase if requested + if self.lowercase: + result = result.lower() + + # Strip leading/trailing whitespace + result = result.strip() + + return result + + def extract_domain_features(self, text: str) -> dict[str, Any]: + """ + Extract domain-specific features from text. + + Args: + text: Input text + + Returns: + Dictionary of extracted features + """ + features = { + "has_ip_addresses": bool(self.IP_ADDRESS_PATTERN.search(text)), + "ip_count": len(self.IP_ADDRESS_PATTERN.findall(text)), + "has_cve": bool(self.CVE_PATTERN.search(text)), + "cve_ids": self.CVE_PATTERN.findall(text), + "has_mitre_techniques": bool(self.MITRE_TECHNIQUE_PATTERN.search(text)), + "mitre_techniques": self.MITRE_TECHNIQUE_PATTERN.findall(text), + "text_length": len(text), + "word_count": len(text.split()), + "sentence_count": len(re.findall(r"[.!?]+", text)), + } + + # Detect domain indicators + domain_keywords = { + "cybersecurity": ["attack", "vulnerability", "exploit", "malware", "threat"], + "military": ["tactical", "reconnaissance", "deployment", "terrain", "objective"], + "data_analysis": ["dataset", "analysis", "correlation", "statistics", "visualization"], + } + + for domain, keywords in domain_keywords.items(): + features[f"is_{domain}"] = any(kw in text.lower() for kw in keywords) + + return features + + def preprocess(self, text: str) -> PreprocessedText: + """ + Full preprocessing pipeline. + + Args: + text: Raw input text + + Returns: + PreprocessedText object with all preprocessing results + """ + cleaned = self.clean(text) + tokens = cleaned.split() # Simple whitespace tokenization + features = self.extract_domain_features(text) + + return PreprocessedText( + original=text, + cleaned=cleaned, + tokens=tokens, + features=features, + ) + + def batch_preprocess(self, texts: list[str]) -> list[PreprocessedText]: + """ + Preprocess multiple texts. + + Args: + texts: List of raw texts + + Returns: + List of PreprocessedText objects + """ + return [self.preprocess(text) for text in texts] + + +class TokenizerWrapper: + """ + Wrapper for various tokenization backends. + + Supports: + - Simple whitespace tokenization + - HuggingFace tokenizers + - Custom vocabularies + """ + + def __init__( + self, + backend: str = "simple", + model_name: str | None = None, + max_length: int = 512, + ): + """ + Initialize tokenizer. 
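+
+        Example (illustrative; the simple backend needs no extra dependencies):
+            >>> tok = TokenizerWrapper(backend="simple", max_length=128)
+            >>> tokens, token_ids = tok.tokenize("analyze the incident report")
+            >>> token_ids is None
+            True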
+ + Args: + backend: Tokenizer backend ('simple', 'huggingface', 'custom') + model_name: Model name for HuggingFace tokenizer + max_length: Maximum sequence length + """ + self.backend = backend + self.model_name = model_name + self.max_length = max_length + self._tokenizer = None + + if backend == "huggingface" and model_name: + self._load_huggingface_tokenizer() + + def _load_huggingface_tokenizer(self): + """Load HuggingFace tokenizer.""" + try: + from transformers import AutoTokenizer + + self._tokenizer = AutoTokenizer.from_pretrained( + self.model_name, + model_max_length=self.max_length, + ) + logger.info(f"Loaded HuggingFace tokenizer: {self.model_name}") + except ImportError: + logger.error("transformers library not installed. Run: pip install transformers") + raise + + def tokenize(self, text: str) -> tuple[list[str], list[int] | None]: + """ + Tokenize text. + + Args: + text: Input text + + Returns: + Tuple of (tokens, token_ids) + """ + if self.backend == "simple": + tokens = text.split()[: self.max_length] + return tokens, None + + elif self.backend == "huggingface" and self._tokenizer: + encoded = self._tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_tensors=None, + ) + tokens = self._tokenizer.convert_ids_to_tokens(encoded["input_ids"]) + token_ids = encoded["input_ids"] + return tokens, token_ids + + else: + raise ValueError(f"Unsupported backend: {self.backend}") + + def batch_tokenize(self, texts: list[str]) -> list[tuple[list[str], list[int] | None]]: + """ + Tokenize multiple texts. + + Args: + texts: List of input texts + + Returns: + List of (tokens, token_ids) tuples + """ + return [self.tokenize(text) for text in texts] + + def encode_for_training(self, texts: list[str]) -> dict[str, Any]: + """ + Encode texts for model training. + + Args: + texts: List of input texts + + Returns: + Dictionary with encoded data ready for training + """ + if self.backend != "huggingface" or not self._tokenizer: + raise ValueError("encode_for_training requires HuggingFace backend") + + encoded = self._tokenizer( + texts, + truncation=True, + padding=True, + max_length=self.max_length, + return_tensors="pt", + ) + + return encoded + + +class MetaControllerFeatureExtractor: + """ + Extract features for meta-controller training. + + Converts text and agent state information into numerical features + suitable for RNN/BERT routing decisions. + """ + + def __init__(self): + """Initialize feature extractor.""" + self.preprocessor = TextPreprocessor() + + def extract_query_features(self, query: str) -> dict[str, float]: + """ + Extract numerical features from query text. + + Args: + query: User query text + + Returns: + Dictionary of numerical features + """ + domain_features = self.preprocessor.extract_domain_features(query) + + features = { + "query_length": domain_features["text_length"] / 10000, # Normalize + "word_count": domain_features["word_count"] / 500, + "sentence_count": domain_features["sentence_count"] / 50, + "has_technical_terms": float( + domain_features["has_ip_addresses"] + or domain_features["has_cve"] + or domain_features["has_mitre_techniques"] + ), + "is_cybersecurity": float(domain_features["is_cybersecurity"]), + "is_military": float(domain_features["is_military"]), + "is_data_analysis": float(domain_features["is_data_analysis"]), + "complexity_score": self._estimate_complexity(query), + } + + return features + + def _estimate_complexity(self, text: str) -> float: + """ + Estimate query complexity (0-1 scale). 
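+
+        Example (illustrative; the value follows the heuristic implemented below):
+            >>> extractor = MetaControllerFeatureExtractor()
+            >>> extractor._estimate_complexity("How should we analyze and compare the two plans?")
+            0.4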
+ + Args: + text: Input text + + Returns: + Complexity score + """ + # Simple heuristic based on length, technical terms, etc. + score = 0.0 + + # Length factor + word_count = len(text.split()) + if word_count > 50: + score += 0.3 + elif word_count > 20: + score += 0.1 + + # Technical term factor + technical_indicators = [ + "analyze", + "compare", + "evaluate", + "synthesize", + "strategic", + "tactical", + "multi-step", + "consider", + ] + for term in technical_indicators: + if term in text.lower(): + score += 0.1 + + # Question complexity + if "?" in text: + if any(kw in text.lower() for kw in ["why", "how", "what if"]): + score += 0.2 + else: + score += 0.1 + + return min(score, 1.0) + + def extract_agent_state_features( + self, + hrm_confidence: float = 0.0, + trm_confidence: float = 0.0, + mcts_iterations: int = 0, + consensus_score: float = 0.0, + rag_retrieved: int = 0, + ) -> list[float]: + """ + Extract features from current agent state. + + Args: + hrm_confidence: HRM agent confidence + trm_confidence: TRM agent confidence + mcts_iterations: MCTS iterations completed + consensus_score: Inter-agent consensus + rag_retrieved: Number of RAG documents retrieved + + Returns: + List of normalized features (10-dimensional) + """ + return [ + hrm_confidence, + trm_confidence, + min(mcts_iterations / 1000, 1.0), + consensus_score, + min(rag_retrieved / 20, 1.0), + # Derived features + abs(hrm_confidence - trm_confidence), # Disagreement + (hrm_confidence + trm_confidence) / 2, # Average confidence + float(mcts_iterations > 0), # MCTS active + float(consensus_score > 0.7), # High consensus + float(rag_retrieved > 0), # RAG used + ] diff --git a/src/data/tactical_augmentation.py b/src/data/tactical_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..da3456a68bb491c1bd36c9f5df449610cd67cdcd --- /dev/null +++ b/src/data/tactical_augmentation.py @@ -0,0 +1,484 @@ +""" +Tactical Data Augmentation Module. + +Provides domain-specific data augmentation techniques for: +- Cybersecurity threat scenarios +- Military tactical situations +- Multi-step reasoning problems + +These augmentations help increase training data diversity and improve +model robustness for tactical analysis tasks. +""" + +import logging +import random +from dataclasses import dataclass + +from .dataset_loader import DatasetSample + +logger = logging.getLogger(__name__) + + +@dataclass +class AugmentationResult: + """Result of data augmentation.""" + + original: DatasetSample + augmented: list[DatasetSample] + augmentation_types: list[str] + + +class TacticalAugmenter: + """ + Domain-specific data augmentation for tactical analysis. 
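+
+    Example (illustrative; ``sample`` is a pre-existing DatasetSample and the chosen
+    techniques are drawn from the list below):
+        >>> augmenter = TacticalAugmenter(seed=7)
+        >>> result = augmenter.augment_sample(sample, num_augmentations=2)
+        >>> len(result.augmented)
+        2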
+ + Augmentation techniques: + - Paraphrasing tactical scenarios + - Varying urgency levels + - Adding/removing constraints + - Scenario parameter variation + - Threat actor substitution + - Temporal shifting + """ + + # Tactical scenario templates + URGENCY_MODIFIERS = { + "high": ["IMMEDIATE", "CRITICAL", "URGENT", "TIME-SENSITIVE"], + "medium": ["PRIORITY", "IMPORTANT", "ATTENTION REQUIRED"], + "low": ["ROUTINE", "STANDARD", "WHEN POSSIBLE"], + } + + THREAT_ACTORS = [ + "APT28", + "APT29", + "Lazarus Group", + "Cozy Bear", + "Fancy Bear", + "Unknown Actor", + "Nation-State Actor", + "Criminal Organization", + ] + + ATTACK_VECTORS = [ + "phishing", + "spear-phishing", + "watering hole", + "supply chain compromise", + "zero-day exploit", + "credential stuffing", + "brute force", + "social engineering", + ] + + MILITARY_OBJECTIVES = [ + "secure perimeter", + "establish forward position", + "conduct reconnaissance", + "neutralize threat", + "protect assets", + "maintain operational security", + "coordinate with allied forces", + "execute tactical withdrawal", + ] + + ENVIRONMENTAL_CONDITIONS = [ + "night operations", + "adverse weather", + "limited visibility", + "urban terrain", + "mountainous region", + "coastal area", + "contested airspace", + "electronic warfare environment", + ] + + def __init__(self, seed: int = 42): + """ + Initialize augmenter. + + Args: + seed: Random seed for reproducibility + """ + self.rng = random.Random(seed) + self._augmentation_count = 0 + + def augment_sample( + self, + sample: DatasetSample, + num_augmentations: int = 3, + techniques: list[str] | None = None, + ) -> AugmentationResult: + """ + Augment a single sample. + + Args: + sample: Original dataset sample + num_augmentations: Number of augmented versions to create + techniques: Specific techniques to use (None for random selection) + + Returns: + AugmentationResult with augmented samples + """ + available_techniques = [ + "urgency_variation", + "parameter_substitution", + "constraint_addition", + "temporal_shift", + "perspective_change", + ] + + if techniques: + available_techniques = [t for t in techniques if t in available_techniques] + + augmented_samples = [] + used_techniques = [] + + for _i in range(num_augmentations): + technique = self.rng.choice(available_techniques) + used_techniques.append(technique) + + augmented_text = self._apply_technique(sample.text, sample.domain, technique) + + aug_sample = DatasetSample( + id=f"{sample.id}_aug_{self._augmentation_count}", + text=augmented_text, + metadata={ + **sample.metadata, + "augmentation": technique, + "original_id": sample.id, + }, + labels=sample.labels, + difficulty=sample.difficulty, + domain=sample.domain, + reasoning_steps=sample.reasoning_steps, + ) + + augmented_samples.append(aug_sample) + self._augmentation_count += 1 + + return AugmentationResult( + original=sample, + augmented=augmented_samples, + augmentation_types=used_techniques, + ) + + def _apply_technique(self, text: str, domain: str | None, technique: str) -> str: + """Apply specific augmentation technique.""" + if technique == "urgency_variation": + return self._augment_urgency(text) + elif technique == "parameter_substitution": + return self._augment_parameters(text, domain) + elif technique == "constraint_addition": + return self._augment_constraints(text, domain) + elif technique == "temporal_shift": + return self._augment_temporal(text) + elif technique == "perspective_change": + return self._augment_perspective(text, domain) + else: + return text + + def 
_augment_urgency(self, text: str) -> str: + """Vary urgency level in the text.""" + urgency_level = self.rng.choice(list(self.URGENCY_MODIFIERS.keys())) + modifier = self.rng.choice(self.URGENCY_MODIFIERS[urgency_level]) + + # Add urgency prefix + if urgency_level == "high": + return f"[{modifier}] {text}" + elif urgency_level == "medium": + return f"{modifier}: {text}" + else: + return f"({modifier}) {text}" + + def _augment_parameters(self, text: str, domain: str | None) -> str: + """Substitute domain-specific parameters.""" + if domain == "cybersecurity" or "cyber" in text.lower(): + # Substitute threat actors + for actor in self.THREAT_ACTORS: + if actor in text: + new_actor = self.rng.choice([a for a in self.THREAT_ACTORS if a != actor]) + text = text.replace(actor, new_actor) + break + + # Substitute attack vectors + for vector in self.ATTACK_VECTORS: + if vector in text.lower(): + new_vector = self.rng.choice([v for v in self.ATTACK_VECTORS if v != vector]) + text = text.replace(vector, new_vector) + break + + elif domain == "military" or any(kw in text.lower() for kw in ["tactical", "military", "reconnaissance"]): + # Substitute objectives + for obj in self.MILITARY_OBJECTIVES: + if obj in text.lower(): + new_obj = self.rng.choice([o for o in self.MILITARY_OBJECTIVES if o != obj]) + text = text.replace(obj, new_obj) + break + + return text + + def _augment_constraints(self, text: str, domain: str | None) -> str: + """Add additional constraints to the scenario.""" + constraints = [] + + if domain == "cybersecurity": + constraints = [ + "with limited network visibility", + "under active attack", + "with compromised credentials", + "during maintenance window", + "with restricted access to logs", + ] + elif domain == "military": + constraints = [ + "with limited ammunition", + "under communication blackout", + "with reduced personnel", + "in contested environment", + "with time constraint of 2 hours", + ] + else: + constraints = [ + "with incomplete information", + "under time pressure", + "with resource constraints", + "considering multiple stakeholders", + "with conflicting objectives", + ] + + if constraints: + constraint = self.rng.choice(constraints) + return f"{text} [{constraint}]" + + return text + + def _augment_temporal(self, text: str) -> str: + """Shift temporal context.""" + temporal_contexts = [ + "In the past 24 hours, ", + "Over the next week, ", + "Immediately, ", + "During the upcoming operation, ", + "Following initial assessment, ", + ] + + context = self.rng.choice(temporal_contexts) + return f"{context}{text.lower()}" if text else text + + def _augment_perspective(self, text: str, domain: str | None) -> str: + """Change analytical perspective.""" + perspectives = { + "cybersecurity": [ + "From a threat hunter's perspective: ", + "Considering the attacker's viewpoint: ", + "For incident response purposes: ", + "From a risk management standpoint: ", + ], + "military": [ + "From the commander's perspective: ", + "Considering enemy capabilities: ", + "For tactical planning purposes: ", + "From a logistics standpoint: ", + ], + "default": [ + "From an analytical perspective: ", + "Considering all factors: ", + "For decision-making purposes: ", + "From a strategic viewpoint: ", + ], + } + + domain_perspectives = perspectives.get(domain or "default", perspectives["default"]) + perspective = self.rng.choice(domain_perspectives) + + return f"{perspective}{text}" + + def augment_batch( + self, + samples: list[DatasetSample], + augmentations_per_sample: int = 2, + ) -> 
list[DatasetSample]: + """ + Augment a batch of samples. + + Args: + samples: List of original samples + augmentations_per_sample: Number of augmentations per sample + + Returns: + List of all samples (original + augmented) + """ + all_samples = list(samples) # Keep originals + + for sample in samples: + result = self.augment_sample(sample, num_augmentations=augmentations_per_sample) + all_samples.extend(result.augmented) + + logger.info( + f"Augmented {len(samples)} samples to {len(all_samples)} (+{len(all_samples) - len(samples)} augmented)" + ) + + return all_samples + + def create_tactical_scenarios(self, base_samples: list[DatasetSample]) -> list[DatasetSample]: + """ + Create tactical scenario variations from base samples. + + Combines multiple augmentation techniques to create + diverse tactical scenarios for training. + + Args: + base_samples: Base dataset samples + + Returns: + Extended list with tactical scenario variations + """ + scenarios = list(base_samples) + + for sample in base_samples: + # Create high-stakes variant + high_stakes = self._augment_urgency(sample.text) + high_stakes = self._augment_constraints(high_stakes, sample.domain) + scenarios.append( + DatasetSample( + id=f"{sample.id}_highstakes_{self._augmentation_count}", + text=high_stakes, + metadata={ + **sample.metadata, + "scenario_type": "high_stakes", + "original_id": sample.id, + }, + labels=sample.labels, + difficulty="hard", # High stakes scenarios are harder + domain=sample.domain, + reasoning_steps=sample.reasoning_steps, + ) + ) + self._augmentation_count += 1 + + # Create multi-perspective variant + if self.rng.random() > 0.5: + multi_perspective = self._augment_perspective(sample.text, sample.domain) + scenarios.append( + DatasetSample( + id=f"{sample.id}_multiperspective_{self._augmentation_count}", + text=multi_perspective, + metadata={ + **sample.metadata, + "scenario_type": "multi_perspective", + "original_id": sample.id, + }, + labels=sample.labels, + difficulty=sample.difficulty, + domain=sample.domain, + reasoning_steps=sample.reasoning_steps, + ) + ) + self._augmentation_count += 1 + + logger.info(f"Created {len(scenarios) - len(base_samples)} tactical scenarios") + return scenarios + + +class CyberSecurityAugmenter(TacticalAugmenter): + """ + Specialized augmenter for cybersecurity scenarios. + + Focuses on: + - MITRE ATT&CK technique variations + - Threat intelligence context + - Incident response scenarios + """ + + MITRE_TACTICS = [ + "Initial Access", + "Execution", + "Persistence", + "Privilege Escalation", + "Defense Evasion", + "Credential Access", + "Discovery", + "Lateral Movement", + "Collection", + "Exfiltration", + "Impact", + ] + + SEVERITY_LEVELS = ["LOW", "MEDIUM", "HIGH", "CRITICAL"] + + def augment_with_mitre_context(self, sample: DatasetSample) -> DatasetSample: + """ + Add MITRE ATT&CK context to sample. 
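+
+        Example (illustrative; ``sample`` is a pre-existing DatasetSample, and the
+        tactic and severity drawn depend on the seed):
+            >>> cyber = CyberSecurityAugmenter(seed=3)
+            >>> enriched = cyber.augment_with_mitre_context(sample)
+            >>> enriched.text.startswith("[MITRE ATT&CK:")
+            True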
+ + Args: + sample: Original sample + + Returns: + Augmented sample with MITRE context + """ + tactic = self.rng.choice(self.MITRE_TACTICS) + severity = self.rng.choice(self.SEVERITY_LEVELS) + + augmented_text = f"[MITRE ATT&CK: {tactic}] [Severity: {severity}] {sample.text}" + + return DatasetSample( + id=f"{sample.id}_mitre_{self._augmentation_count}", + text=augmented_text, + metadata={ + **sample.metadata, + "mitre_tactic": tactic, + "severity": severity, + }, + labels=sample.labels, + difficulty=sample.difficulty, + domain="cybersecurity", + reasoning_steps=sample.reasoning_steps, + ) + + +class MilitaryTacticalAugmenter(TacticalAugmenter): + """ + Specialized augmenter for military tactical scenarios. + + Focuses on: + - Environmental condition variations + - Force composition changes + - Mission objective variations + """ + + FORCE_COMPOSITIONS = [ + "infantry platoon", + "mechanized company", + "special operations team", + "combined arms battalion", + "air assault element", + ] + + def augment_with_force_composition(self, sample: DatasetSample) -> DatasetSample: + """ + Add force composition context to sample. + + Args: + sample: Original sample + + Returns: + Augmented sample with force composition + """ + force = self.rng.choice(self.FORCE_COMPOSITIONS) + condition = self.rng.choice(self.ENVIRONMENTAL_CONDITIONS) + + augmented_text = f"[Force: {force}] [Conditions: {condition}] {sample.text}" + + return DatasetSample( + id=f"{sample.id}_tactical_{self._augmentation_count}", + text=augmented_text, + metadata={ + **sample.metadata, + "force_composition": force, + "environmental_conditions": condition, + }, + labels=sample.labels, + difficulty=sample.difficulty, + domain="military", + reasoning_steps=sample.reasoning_steps, + ) diff --git a/src/data/train_test_split.py b/src/data/train_test_split.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a8c59e2e7c9ed879db8eea5d79a7a50151eabf --- /dev/null +++ b/src/data/train_test_split.py @@ -0,0 +1,505 @@ +""" +Data Splitting Module for Training Pipeline. + +Provides utilities for: +- Train/validation/test splitting +- Stratified sampling by domain or difficulty +- Cross-validation fold creation +- Reproducible splits with seeding +""" + +import logging +from collections import defaultdict +from dataclasses import dataclass +from typing import Any + +from .dataset_loader import DatasetSample + +logger = logging.getLogger(__name__) + + +@dataclass +class DataSplit: + """Result of dataset splitting.""" + + train: list[DatasetSample] + validation: list[DatasetSample] + test: list[DatasetSample] + split_info: dict[str, Any] + + +@dataclass +class CrossValidationFold: + """Single fold for cross-validation.""" + + fold_id: int + train: list[DatasetSample] + validation: list[DatasetSample] + + +class DataSplitter: + """ + Basic dataset splitter with random sampling. + + Provides reproducible train/validation/test splits + with configurable ratios. + """ + + def __init__(self, seed: int = 42): + """ + Initialize splitter. + + Args: + seed: Random seed for reproducibility + """ + self.seed = seed + import random + + self.rng = random.Random(seed) + + def split( + self, + samples: list[DatasetSample], + train_ratio: float = 0.7, + val_ratio: float = 0.15, + test_ratio: float = 0.15, + shuffle: bool = True, + ) -> DataSplit: + """ + Split dataset into train/validation/test sets. 
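+
+        Example (illustrative; ``samples`` is any loaded list of DatasetSample objects):
+            >>> splitter = DataSplitter(seed=0)
+            >>> split = splitter.split(samples, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)
+            >>> len(split.train) + len(split.validation) + len(split.test) == len(samples)
+            True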
+ + Args: + samples: List of all samples + train_ratio: Proportion for training (default 0.7) + val_ratio: Proportion for validation (default 0.15) + test_ratio: Proportion for testing (default 0.15) + shuffle: Whether to shuffle before splitting + + Returns: + DataSplit with train, validation, and test sets + """ + if abs(train_ratio + val_ratio + test_ratio - 1.0) > 0.001: + raise ValueError("Ratios must sum to 1.0") + + if not samples: + raise ValueError("Cannot split empty sample list") + + # Copy and optionally shuffle + all_samples = list(samples) + if shuffle: + self.rng.shuffle(all_samples) + + n = len(all_samples) + train_end = int(n * train_ratio) + val_end = train_end + int(n * val_ratio) + + train_samples = all_samples[:train_end] + val_samples = all_samples[train_end:val_end] + test_samples = all_samples[val_end:] + + split_info = { + "total_samples": n, + "train_samples": len(train_samples), + "val_samples": len(val_samples), + "test_samples": len(test_samples), + "train_ratio": len(train_samples) / n, + "val_ratio": len(val_samples) / n, + "test_ratio": len(test_samples) / n, + "seed": self.seed, + "shuffled": shuffle, + } + + logger.info(f"Split {n} samples: train={len(train_samples)}, val={len(val_samples)}, test={len(test_samples)}") + + return DataSplit( + train=train_samples, + validation=val_samples, + test=test_samples, + split_info=split_info, + ) + + def create_k_folds( + self, + samples: list[DatasetSample], + k: int = 5, + shuffle: bool = True, + ) -> list[CrossValidationFold]: + """ + Create k-fold cross-validation splits. + + Args: + samples: List of all samples + k: Number of folds + shuffle: Whether to shuffle before splitting + + Returns: + List of CrossValidationFold objects + """ + if k < 2: + raise ValueError("k must be at least 2") + + if len(samples) < k: + raise ValueError(f"Need at least {k} samples for {k}-fold CV") + + # Copy and optionally shuffle + all_samples = list(samples) + if shuffle: + self.rng.shuffle(all_samples) + + # Calculate fold sizes + fold_size = len(all_samples) // k + folds = [] + + for fold_id in range(k): + # Validation is the current fold + val_start = fold_id * fold_size + val_end = len(all_samples) if fold_id == k - 1 else val_start + fold_size # noqa: SIM108 + + val_samples = all_samples[val_start:val_end] + train_samples = all_samples[:val_start] + all_samples[val_end:] + + folds.append( + CrossValidationFold( + fold_id=fold_id, + train=train_samples, + validation=val_samples, + ) + ) + + logger.info(f"Created {k}-fold cross-validation splits") + return folds + + +class StratifiedSplitter(DataSplitter): + """ + Stratified dataset splitter. + + Ensures proportional representation of categories + (domain, difficulty, etc.) across splits. + """ + + def __init__(self, seed: int = 42, stratify_by: str = "domain"): + """ + Initialize stratified splitter. + + Args: + seed: Random seed for reproducibility + stratify_by: Attribute to stratify on ('domain', 'difficulty', 'labels') + """ + super().__init__(seed) + self.stratify_by = stratify_by + + def split( + self, + samples: list[DatasetSample], + train_ratio: float = 0.7, + val_ratio: float = 0.15, + test_ratio: float = 0.15, + shuffle: bool = True, + ) -> DataSplit: + """ + Stratified split maintaining category proportions. 
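+
+        Example (illustrative; assumes each sample carries a ``domain`` value to stratify on):
+            >>> splitter = StratifiedSplitter(seed=0, stratify_by="domain")
+            >>> split = splitter.split(samples)
+            >>> split.split_info["stratification_info"]["train"]  # per-domain proportions in the train set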
+ + Args: + samples: List of all samples + train_ratio: Proportion for training + val_ratio: Proportion for validation + test_ratio: Proportion for testing + shuffle: Whether to shuffle before splitting + + Returns: + DataSplit with stratified train, validation, and test sets + """ + if abs(train_ratio + val_ratio + test_ratio - 1.0) > 0.001: + raise ValueError("Ratios must sum to 1.0") + + if not samples: + raise ValueError("Cannot split empty sample list") + + # Group samples by stratification key + groups = defaultdict(list) + for sample in samples: + key = self._get_stratify_key(sample) + groups[key].append(sample) + + # Split each group proportionally + train_samples = [] + val_samples = [] + test_samples = [] + + for _key, group_samples in groups.items(): + if shuffle: + self.rng.shuffle(group_samples) + + n = len(group_samples) + train_end = int(n * train_ratio) + val_end = train_end + int(n * val_ratio) + + train_samples.extend(group_samples[:train_end]) + val_samples.extend(group_samples[train_end:val_end]) + test_samples.extend(group_samples[val_end:]) + + # Final shuffle of combined sets + if shuffle: + self.rng.shuffle(train_samples) + self.rng.shuffle(val_samples) + self.rng.shuffle(test_samples) + + # Verify stratification + stratify_info = self._verify_stratification(train_samples, val_samples, test_samples) + + split_info = { + "total_samples": len(samples), + "train_samples": len(train_samples), + "val_samples": len(val_samples), + "test_samples": len(test_samples), + "train_ratio": len(train_samples) / len(samples), + "val_ratio": len(val_samples) / len(samples), + "test_ratio": len(test_samples) / len(samples), + "stratify_by": self.stratify_by, + "stratification_info": stratify_info, + "seed": self.seed, + "shuffled": shuffle, + } + + logger.info( + f"Stratified split ({self.stratify_by}): " + f"train={len(train_samples)}, val={len(val_samples)}, " + f"test={len(test_samples)}" + ) + + return DataSplit( + train=train_samples, + validation=val_samples, + test=test_samples, + split_info=split_info, + ) + + def _get_stratify_key(self, sample: DatasetSample) -> str: + """Get stratification key for a sample.""" + if self.stratify_by == "domain": + return sample.domain or "unknown" + elif self.stratify_by == "difficulty": + return sample.difficulty or "unknown" + elif self.stratify_by == "labels": + return ",".join(sorted(sample.labels)) if sample.labels else "unknown" + else: + return str(getattr(sample, self.stratify_by, "unknown")) + + def _verify_stratification( + self, + train: list[DatasetSample], + val: list[DatasetSample], + test: list[DatasetSample], + ) -> dict[str, dict[str, float]]: + """ + Verify that stratification was successful. + + Returns dictionary showing distribution of stratification key + across train/val/test splits. + """ + + def get_distribution(samples: list[DatasetSample]) -> dict[str, float]: + if not samples: + return {} + counts = defaultdict(int) + for sample in samples: + key = self._get_stratify_key(sample) + counts[key] += 1 + total = len(samples) + return {k: v / total for k, v in counts.items()} + + return { + "train": get_distribution(train), + "validation": get_distribution(val), + "test": get_distribution(test), + } + + def create_stratified_k_folds( + self, + samples: list[DatasetSample], + k: int = 5, + shuffle: bool = True, + ) -> list[CrossValidationFold]: + """ + Create stratified k-fold cross-validation splits. 
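+
+        Example (illustrative):
+            >>> folds = StratifiedSplitter(seed=1).create_stratified_k_folds(samples, k=5)
+            >>> [len(fold.validation) for fold in folds]  # validation sizes are roughly equal per fold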
+ + Args: + samples: List of all samples + k: Number of folds + shuffle: Whether to shuffle before splitting + + Returns: + List of CrossValidationFold objects with stratification + """ + if k < 2: + raise ValueError("k must be at least 2") + + # Group samples by stratification key + groups = defaultdict(list) + for sample in samples: + key = self._get_stratify_key(sample) + groups[key].append(sample) + + # Initialize folds + folds_data = [{"train": [], "val": []} for _ in range(k)] + + # Distribute each group across folds + for _key, group_samples in groups.items(): + if shuffle: + self.rng.shuffle(group_samples) + + # Assign samples to folds + fold_size = len(group_samples) // k + for fold_id in range(k): + val_start = fold_id * fold_size + val_end = len(group_samples) if fold_id == k - 1 else val_start + fold_size + + for i, sample in enumerate(group_samples): + if val_start <= i < val_end: + folds_data[fold_id]["val"].append(sample) + else: + folds_data[fold_id]["train"].append(sample) + + # Create fold objects + folds = [ + CrossValidationFold( + fold_id=i, + train=data["train"], + validation=data["val"], + ) + for i, data in enumerate(folds_data) + ] + + logger.info(f"Created stratified {k}-fold cross-validation splits") + return folds + + +class BalancedSampler: + """ + Balanced sampling for imbalanced datasets. + + Provides utilities for: + - Oversampling minority classes + - Undersampling majority classes + - SMOTE-like synthetic sampling (for numerical features) + """ + + def __init__(self, seed: int = 42): + """Initialize balanced sampler.""" + self.seed = seed + import random + + self.rng = random.Random(seed) + + def oversample_minority( + self, + samples: list[DatasetSample], + target_key: str = "domain", + target_ratio: float = 1.0, + ) -> list[DatasetSample]: + """ + Oversample minority classes to balance dataset. + + Args: + samples: Original samples + target_key: Attribute to balance on + target_ratio: Target ratio relative to majority (1.0 = equal) + + Returns: + Balanced sample list (originals + oversampled) + """ + # Group by target key + groups = defaultdict(list) + for sample in samples: + key = getattr(sample, target_key, "unknown") or "unknown" + groups[key].append(sample) + + # Find majority class size + max_count = max(len(g) for g in groups.values()) + target_count = int(max_count * target_ratio) + + # Oversample minority classes + balanced = [] + for _key, group in groups.items(): + balanced.extend(group) + + # Oversample if needed + if len(group) < target_count: + num_to_add = target_count - len(group) + for _ in range(num_to_add): + # Randomly duplicate from group + original = self.rng.choice(group) + duplicate = DatasetSample( + id=f"{original.id}_oversample_{self.rng.randint(0, 999999)}", + text=original.text, + metadata={**original.metadata, "oversampled": True}, + labels=original.labels, + difficulty=original.difficulty, + domain=original.domain, + reasoning_steps=original.reasoning_steps, + ) + balanced.append(duplicate) + + logger.info(f"Oversampled from {len(samples)} to {len(balanced)} samples") + return balanced + + def undersample_majority( + self, + samples: list[DatasetSample], + target_key: str = "domain", + target_ratio: float = 1.0, + ) -> list[DatasetSample]: + """ + Undersample majority classes to balance dataset. 
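+
+        Example (illustrative):
+            >>> sampler = BalancedSampler(seed=0)
+            >>> balanced = sampler.undersample_majority(samples, target_key="domain")
+            >>> sampler.get_class_distribution(balanced, target_key="domain")  # every class near the minority count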
+ + Args: + samples: Original samples + target_key: Attribute to balance on + target_ratio: Target ratio relative to minority (1.0 = equal) + + Returns: + Balanced sample list (subset of originals) + """ + # Group by target key + groups = defaultdict(list) + for sample in samples: + key = getattr(sample, target_key, "unknown") or "unknown" + groups[key].append(sample) + + # Find minority class size + min_count = min(len(g) for g in groups.values()) + target_count = int(min_count * target_ratio) + + # Undersample majority classes + balanced = [] + for _key, group in groups.items(): + if len(group) > target_count: + # Randomly select target_count samples + balanced.extend(self.rng.sample(group, target_count)) + else: + balanced.extend(group) + + logger.info(f"Undersampled from {len(samples)} to {len(balanced)} samples") + return balanced + + def get_class_distribution( + self, + samples: list[DatasetSample], + target_key: str = "domain", + ) -> dict[str, int]: + """ + Get distribution of classes. + + Args: + samples: Sample list + target_key: Attribute to analyze + + Returns: + Dictionary of class counts + """ + distribution = defaultdict(int) + for sample in samples: + key = getattr(sample, target_key, "unknown") or "unknown" + distribution[key] += 1 + return dict(distribution) diff --git a/src/framework/__init__.py b/src/framework/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a26dae4aa5e0cd67a0fdc606e779153ca0ca7b0 --- /dev/null +++ b/src/framework/__init__.py @@ -0,0 +1 @@ +# Framework module diff --git a/src/framework/agents/__init__.py b/src/framework/agents/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb70202614c10dae0c0b114b12d74eb82eba81e --- /dev/null +++ b/src/framework/agents/__init__.py @@ -0,0 +1,22 @@ +# Agents module for async agent implementations +from .base import ( + AgentContext, + AgentResult, + AsyncAgentBase, + CompositeAgent, + MetricsCollector, + NoOpMetricsCollector, + ParallelAgent, + SequentialAgent, +) + +__all__ = [ + "AsyncAgentBase", + "AgentContext", + "AgentResult", + "MetricsCollector", + "NoOpMetricsCollector", + "CompositeAgent", + "ParallelAgent", + "SequentialAgent", +] diff --git a/src/framework/agents/base.py b/src/framework/agents/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7f24fb6960f5c4aac9e0b19919353ce1478959e6 --- /dev/null +++ b/src/framework/agents/base.py @@ -0,0 +1,533 @@ +""" +Base async agent class for Multi-Agent MCTS Framework. + +Provides common patterns for all agents with hook points for metrics, +logging, and extensibility. +""" + +import asyncio +import logging +import time +import uuid +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Protocol + +from src.adapters.llm.base import LLMClient, LLMResponse + + +@dataclass +class AgentContext: + """ + Context passed to agent during processing. + + Contains all information needed for the agent to process a request. 
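+
+    Example (illustrative):
+        >>> ctx = AgentContext(query="Summarize recent CVE activity", temperature=0.2)
+        >>> ctx.to_dict()["temperature"]
+        0.2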
+ """ + + query: str + session_id: str = field(default_factory=lambda: str(uuid.uuid4())) + rag_context: str | None = None + metadata: dict = field(default_factory=dict) + conversation_history: list[dict] = field(default_factory=list) + max_iterations: int = 5 + temperature: float = 0.7 + additional_context: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + """Convert context to dictionary.""" + return { + "query": self.query, + "session_id": self.session_id, + "rag_context": self.rag_context, + "metadata": self.metadata, + "conversation_history": self.conversation_history, + "max_iterations": self.max_iterations, + "temperature": self.temperature, + "additional_context": self.additional_context, + } + + +@dataclass +class AgentResult: + """ + Result from agent processing. + + Standardized result format for all agents. + """ + + response: str + confidence: float = 0.0 + metadata: dict = field(default_factory=dict) + agent_name: str = "" + processing_time_ms: float = 0.0 + token_usage: dict = field(default_factory=dict) + intermediate_steps: list[dict] = field(default_factory=list) + created_at: datetime = field(default_factory=datetime.utcnow) + error: str | None = None + success: bool = True + + def to_dict(self) -> dict: + """Convert result to dictionary.""" + return { + "response": self.response, + "confidence": self.confidence, + "metadata": self.metadata, + "agent_name": self.agent_name, + "processing_time_ms": self.processing_time_ms, + "token_usage": self.token_usage, + "intermediate_steps": self.intermediate_steps, + "created_at": self.created_at.isoformat(), + "error": self.error, + "success": self.success, + } + + +class MetricsCollector(Protocol): + """Protocol for metrics collection.""" + + def record_latency(self, agent_name: str, latency_ms: float) -> None: ... + def record_tokens(self, agent_name: str, tokens: int) -> None: ... + def record_error(self, agent_name: str, error_type: str) -> None: ... + def record_success(self, agent_name: str) -> None: ... + + +class NoOpMetricsCollector: + """Default no-op metrics collector.""" + + def record_latency(self, agent_name: str, latency_ms: float) -> None: + pass + + def record_tokens(self, agent_name: str, tokens: int) -> None: + pass + + def record_error(self, agent_name: str, error_type: str) -> None: + pass + + def record_success(self, agent_name: str) -> None: + pass + + +class AsyncAgentBase(ABC): + """ + Base class for async agents in the Multi-Agent MCTS Framework. + + Features: + - Async processing by default + - Hook points for metrics/logging + - Lifecycle management + - Error handling patterns + - Backward compatibility with existing framework + """ + + def __init__( + self, + model_adapter: LLMClient, + logger: Any = None, + name: str | None = None, + metrics_collector: MetricsCollector | None = None, + **config: Any, + ): + """ + Initialize async agent. 
+ + Args: + model_adapter: LLM client for generating responses + logger: Logger instance (uses standard logging if None) + name: Agent name (uses class name if None) + metrics_collector: Optional metrics collector + **config: Additional configuration parameters + """ + self.model_adapter = model_adapter + self.logger = logger or logging.getLogger(self.__class__.__name__) + self.name = name or self.__class__.__name__ + self.metrics = metrics_collector or NoOpMetricsCollector() + self.config = config + + # Runtime state + self._request_count = 0 + self._total_processing_time = 0.0 + self._error_count = 0 + self._initialized = False + + async def initialize(self) -> None: + """ + Initialize agent resources. + + Override this method to perform async initialization tasks + like loading prompts, setting up connections, etc. + """ + self._initialized = True + self.logger.info(f"Agent {self.name} initialized") + + async def shutdown(self) -> None: + """ + Clean up agent resources. + + Override this method to perform cleanup tasks. + """ + self._initialized = False + self.logger.info(f"Agent {self.name} shutdown") + + async def __aenter__(self): + """Async context manager entry.""" + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.shutdown() + + # Hook points for subclasses + async def pre_process(self, context: AgentContext) -> AgentContext: + """ + Hook called before processing. + + Override to modify context or perform pre-processing. + + Args: + context: Agent context + + Returns: + Potentially modified context + """ + return context + + async def post_process(self, _context: AgentContext, result: AgentResult) -> AgentResult: + """ + Hook called after processing. + + Override to modify result or perform post-processing. + + Args: + context: Agent context + result: Agent result + + Returns: + Potentially modified result + """ + return result + + async def on_error(self, _context: AgentContext, error: Exception) -> AgentResult: + """ + Hook called when processing fails. + + Override to customize error handling. + + Args: + context: Agent context + error: The exception that occurred + + Returns: + Error result + """ + self.logger.error(f"Agent {self.name} error: {error}") + self._error_count += 1 + self.metrics.record_error(self.name, type(error).__name__) + + return AgentResult( + response="", + confidence=0.0, + agent_name=self.name, + error=str(error), + success=False, + ) + + @abstractmethod + async def _process_impl(self, context: AgentContext) -> AgentResult: + """ + Core processing logic to be implemented by subclasses. + + Args: + context: Agent context with all necessary information + + Returns: + AgentResult with response and metadata + """ + pass + + async def process( + self, + query: str | None = None, + context: AgentContext | None = None, + *, + rag_context: str | None = None, + **kwargs: Any, + ) -> dict: + """ + Process a query and return structured response. + + This method provides backward compatibility with the existing + LangGraphMultiAgentFramework while using the new async patterns. 
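+
+        Example (illustrative; ``agent`` is any concrete AsyncAgentBase subclass):
+            >>> out = asyncio.run(agent.process(query="Assess this indicator", rag_context="..."))
+            >>> sorted(out.keys())
+            ['metadata', 'response']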
+ + Args: + query: Query string (if not using context object) + context: Full context object (if not using query string) + rag_context: RAG context (used if query provided) + **kwargs: Additional parameters merged into context + + Returns: + Dictionary with 'response' and 'metadata' keys for backward compatibility + """ + # Build context if not provided + if context is None: + if query is None: + raise ValueError("Either 'query' or 'context' must be provided") + context = AgentContext( + query=query, + rag_context=rag_context, + additional_context=kwargs, + ) + + # Ensure initialized + if not self._initialized: + await self.initialize() + + # Track timing + start_time = time.perf_counter() + + try: + # Pre-processing hook + context = await self.pre_process(context) + + # Core processing + result = await self._process_impl(context) + + # Calculate timing + elapsed_ms = (time.perf_counter() - start_time) * 1000 + result.processing_time_ms = elapsed_ms + result.agent_name = self.name + + # Update stats + self._request_count += 1 + self._total_processing_time += elapsed_ms + self.metrics.record_latency(self.name, elapsed_ms) + if result.token_usage: + self.metrics.record_tokens(self.name, result.token_usage.get("total_tokens", 0)) + self.metrics.record_success(self.name) + + # Post-processing hook + result = await self.post_process(context, result) + + self.logger.info(f"Agent {self.name} processed query in {elapsed_ms:.2f}ms") + + except Exception as e: + result = await self.on_error(context, e) + result.processing_time_ms = (time.perf_counter() - start_time) * 1000 + + # Return backward-compatible format + return { + "response": result.response, + "metadata": { + **result.metadata, + "agent_name": result.agent_name, + "confidence": result.confidence, + "processing_time_ms": result.processing_time_ms, + "token_usage": result.token_usage, + "success": result.success, + "error": result.error, + }, + } + + @property + def stats(self) -> dict: + """Get agent statistics.""" + return { + "name": self.name, + "request_count": self._request_count, + "total_processing_time_ms": self._total_processing_time, + "error_count": self._error_count, + "average_processing_time_ms": ( + self._total_processing_time / self._request_count if self._request_count > 0 else 0.0 + ), + "initialized": self._initialized, + } + + async def generate_llm_response( + self, + prompt: str | None = None, + messages: list[dict] | None = None, + temperature: float = 0.7, + max_tokens: int | None = None, + **kwargs: Any, + ) -> LLMResponse: + """ + Convenience method to generate LLM response with error handling. + + Args: + prompt: Simple string prompt + messages: Chat messages + temperature: Sampling temperature + max_tokens: Max tokens to generate + **kwargs: Additional parameters + + Returns: + LLMResponse from the model adapter + """ + response = await self.model_adapter.generate( + prompt=prompt, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + **kwargs, + ) + return response + + +class CompositeAgent(AsyncAgentBase): + """ + Agent that combines multiple sub-agents. + + Useful for creating complex agents from simpler building blocks. 
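+
+    Example (illustrative; ``adapter``, ``hrm_agent``, and ``trm_agent`` are pre-built objects
+    and both sub-agents succeed):
+        >>> team = ParallelAgent(model_adapter=adapter, name="parallel_team", sub_agents=[hrm_agent, trm_agent])
+        >>> result = asyncio.run(team.process(query="Evaluate the tactical options"))
+        >>> result["metadata"]["aggregation_method"]
+        'highest_confidence'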
+ """ + + def __init__( + self, + model_adapter: LLMClient, + logger: Any = None, + name: str = "CompositeAgent", + sub_agents: list[AsyncAgentBase] | None = None, + **config: Any, + ): + super().__init__(model_adapter, logger, name, **config) + self.sub_agents = sub_agents or [] + + def add_agent(self, agent: AsyncAgentBase) -> None: + """Add a sub-agent.""" + self.sub_agents.append(agent) + + async def initialize(self) -> None: + """Initialize all sub-agents.""" + await super().initialize() + for agent in self.sub_agents: + await agent.initialize() + + async def shutdown(self) -> None: + """Shutdown all sub-agents.""" + for agent in self.sub_agents: + await agent.shutdown() + await super().shutdown() + + +class ParallelAgent(CompositeAgent): + """ + Execute multiple agents in parallel and aggregate results. + """ + + async def _process_impl(self, context: AgentContext) -> AgentResult: + """Execute all sub-agents in parallel.""" + if not self.sub_agents: + return AgentResult( + response="No sub-agents configured", + confidence=0.0, + agent_name=self.name, + ) + + # Run all agents concurrently + tasks = [agent.process(context=context) for agent in self.sub_agents] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Aggregate results + successful_results = [] + errors = [] + + for i, result in enumerate(results): + if isinstance(result, Exception): + errors.append(f"{self.sub_agents[i].name}: {str(result)}") + elif isinstance(result, dict) and result.get("metadata", {}).get("success", True): + successful_results.append(result) + else: + if isinstance(result, dict): + errors.append( + f"{self.sub_agents[i].name}: {result.get('metadata', {}).get('error', 'Unknown error')}" + ) + + if not successful_results: + return AgentResult( + response=f"All sub-agents failed: {'; '.join(errors)}", + confidence=0.0, + agent_name=self.name, + success=False, + error="All sub-agents failed", + ) + + # Aggregate: highest confidence wins (simple strategy) + best_result = max(successful_results, key=lambda r: r.get("metadata", {}).get("confidence", 0.0)) + + return AgentResult( + response=best_result["response"], + confidence=best_result.get("metadata", {}).get("confidence", 0.0), + metadata={ + "aggregation_method": "highest_confidence", + "sub_agent_results": successful_results, + "errors": errors, + }, + agent_name=self.name, + ) + + +class SequentialAgent(CompositeAgent): + """ + Execute multiple agents sequentially, passing context through each. 
+ """ + + async def _process_impl(self, context: AgentContext) -> AgentResult: + """Execute sub-agents in sequence.""" + if not self.sub_agents: + return AgentResult( + response="No sub-agents configured", + confidence=0.0, + agent_name=self.name, + ) + + current_context = context + intermediate_results = [] + + for agent in self.sub_agents: + result = await agent.process(context=current_context) + + intermediate_results.append( + { + "agent": agent.name, + "result": result, + } + ) + + # Check for failure + if not result.get("metadata", {}).get("success", True): + return AgentResult( + response=result["response"], + confidence=result.get("metadata", {}).get("confidence", 0.0), + metadata={ + "failed_at": agent.name, + "intermediate_results": intermediate_results, + }, + agent_name=self.name, + success=False, + error=result.get("metadata", {}).get("error"), + ) + + # Update context for next agent + current_context = AgentContext( + query=current_context.query, + session_id=current_context.session_id, + rag_context=result["response"], # Previous output becomes context + metadata={ + **current_context.metadata, + f"{agent.name}_result": result["response"], + }, + additional_context=current_context.additional_context, + ) + + # Final result from last agent + final_result = intermediate_results[-1]["result"] + + return AgentResult( + response=final_result["response"], + confidence=final_result.get("metadata", {}).get("confidence", 0.0), + metadata={ + "pipeline": [r["agent"] for r in intermediate_results], + "intermediate_results": intermediate_results, + }, + agent_name=self.name, + ) diff --git a/src/framework/graph.py b/src/framework/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..40e379cff6da001cad89905b4055c21c183aa9fe --- /dev/null +++ b/src/framework/graph.py @@ -0,0 +1,915 @@ +""" +LangGraph Integration Module - Extract graph building with new MCTS core integration. 
+ +Provides: +- Graph building extracted from LangGraphMultiAgentFramework +- Integration with new deterministic MCTS core +- Backward compatibility with original process() signature +- Support for parallel HRM/TRM execution +""" + +from __future__ import annotations + +import asyncio +import operator +import time +from typing import Annotated, Any, NotRequired, TypedDict + +# LangGraph imports (these would be installed dependencies) +try: + from langgraph.checkpoint.memory import MemorySaver + from langgraph.graph import END, StateGraph +except ImportError: + # Stubs for development without LangGraph installed + StateGraph = None + END = "END" + MemorySaver = None + +# Import new MCTS modules +from .mcts.config import ConfigPreset, MCTSConfig, create_preset_config +from .mcts.core import MCTSEngine, MCTSNode, MCTSState +from .mcts.experiments import ExperimentTracker +from .mcts.policies import ( + HybridRolloutPolicy, +) + +# Neural Meta-Controller imports (optional) +try: + from src.agents.meta_controller.base import ( + AbstractMetaController, + MetaControllerFeatures, + ) + from src.agents.meta_controller.bert_controller import BERTMetaController + from src.agents.meta_controller.config_loader import ( + MetaControllerConfig, + MetaControllerConfigLoader, + ) + from src.agents.meta_controller.rnn_controller import RNNMetaController + + _META_CONTROLLER_AVAILABLE = True +except ImportError: + _META_CONTROLLER_AVAILABLE = False + AbstractMetaController = None # type: ignore + MetaControllerFeatures = None # type: ignore + RNNMetaController = None # type: ignore + BERTMetaController = None # type: ignore + MetaControllerConfig = None # type: ignore + MetaControllerConfigLoader = None # type: ignore + + +class AgentState(TypedDict): + """Shared state for LangGraph agent framework.""" + + # Input + query: str + use_mcts: bool + use_rag: bool + + # RAG context + rag_context: NotRequired[str] + retrieved_docs: NotRequired[list[dict]] + + # Agent results + hrm_results: NotRequired[dict] + trm_results: NotRequired[dict] + agent_outputs: Annotated[list[dict], operator.add] + + # MCTS simulation (updated for new core) + mcts_root: NotRequired[Any] # MCTSNode + mcts_iterations: NotRequired[int] + mcts_best_action: NotRequired[str] + mcts_stats: NotRequired[dict] + mcts_config: NotRequired[dict] + + # Evaluation + confidence_scores: NotRequired[dict[str, float]] + consensus_reached: NotRequired[bool] + consensus_score: NotRequired[float] + + # Control flow + iteration: int + max_iterations: int + + # Neural Meta-Controller (optional) + routing_history: NotRequired[list[dict]] + meta_controller_predictions: NotRequired[list[dict]] + last_routed_agent: NotRequired[str] + + # Output + final_response: NotRequired[str] + metadata: NotRequired[dict] + + +class GraphBuilder: + """ + Builds and configures the LangGraph state machine for multi-agent orchestration. + + Extracts graph building logic from LangGraphMultiAgentFramework for modularity. + """ + + def __init__( + self, + hrm_agent, + trm_agent, + model_adapter, + logger, + vector_store=None, + mcts_config: MCTSConfig | None = None, + top_k_retrieval: int = 5, + max_iterations: int = 3, + consensus_threshold: float = 0.75, + enable_parallel_agents: bool = True, + meta_controller_config: Any | None = None, + ): + """ + Initialize graph builder. 
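+
+        Example (illustrative; assumes LangGraph is installed and the agents, adapter,
+        and logger are pre-built objects):
+            >>> builder = GraphBuilder(hrm_agent, trm_agent, adapter, logger,
+            ...                        mcts_config=create_preset_config(ConfigPreset.BALANCED))
+            >>> workflow = builder.build_graph()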
+ + Args: + hrm_agent: HRM agent instance + trm_agent: TRM agent instance + model_adapter: Model adapter for LLM calls + logger: Logger instance + vector_store: Optional vector store for RAG + mcts_config: MCTS configuration (uses balanced preset if None) + top_k_retrieval: Number of documents for RAG + max_iterations: Maximum agent iterations + consensus_threshold: Threshold for consensus + enable_parallel_agents: Run HRM/TRM in parallel + meta_controller_config: Optional neural meta-controller configuration + """ + self.hrm_agent = hrm_agent + self.trm_agent = trm_agent + self.model_adapter = model_adapter + self.logger = logger + self.vector_store = vector_store + self.top_k_retrieval = top_k_retrieval + self.max_iterations = max_iterations + self.consensus_threshold = consensus_threshold + self.enable_parallel_agents = enable_parallel_agents + + # MCTS configuration + self.mcts_config = mcts_config or create_preset_config(ConfigPreset.BALANCED) + + # MCTS engine with deterministic behavior + self.mcts_engine = MCTSEngine( + seed=self.mcts_config.seed, + exploration_weight=self.mcts_config.exploration_weight, + progressive_widening_k=self.mcts_config.progressive_widening_k, + progressive_widening_alpha=self.mcts_config.progressive_widening_alpha, + max_parallel_rollouts=self.mcts_config.max_parallel_rollouts, + cache_size_limit=self.mcts_config.cache_size_limit, + ) + + # Experiment tracking + self.experiment_tracker = ExperimentTracker(name="langgraph_mcts") + + # Neural Meta-Controller (optional) + self.meta_controller: Any | None = None + self.meta_controller_config = meta_controller_config + self.use_neural_routing = False + + if meta_controller_config is not None: + self._init_meta_controller(meta_controller_config) + + def build_graph(self) -> StateGraph: + """ + Build LangGraph state machine. + + Returns: + Configured StateGraph + """ + if StateGraph is None: + raise ImportError("LangGraph not installed. 
Install with: pip install langgraph") + + workflow = StateGraph(AgentState) + + # Add nodes + workflow.add_node("entry", self._entry_node) + workflow.add_node("retrieve_context", self._retrieve_context_node) + workflow.add_node("route_decision", self._route_decision_node) + workflow.add_node("parallel_agents", self._parallel_agents_node) + workflow.add_node("hrm_agent", self._hrm_agent_node) + workflow.add_node("trm_agent", self._trm_agent_node) + workflow.add_node("mcts_simulator", self._mcts_simulator_node) + workflow.add_node("aggregate_results", self._aggregate_results_node) + workflow.add_node("evaluate_consensus", self._evaluate_consensus_node) + workflow.add_node("synthesize", self._synthesize_node) + + # Define edges + workflow.set_entry_point("entry") + workflow.add_edge("entry", "retrieve_context") + workflow.add_edge("retrieve_context", "route_decision") + + # Conditional routing + workflow.add_conditional_edges( + "route_decision", + self._route_to_agents, + { + "parallel": "parallel_agents", + "hrm": "hrm_agent", + "trm": "trm_agent", + "mcts": "mcts_simulator", + "aggregate": "aggregate_results", + }, + ) + + # Parallel agents to aggregation + workflow.add_edge("parallel_agents", "aggregate_results") + + # Sequential agent nodes + workflow.add_edge("hrm_agent", "aggregate_results") + workflow.add_edge("trm_agent", "aggregate_results") + workflow.add_edge("mcts_simulator", "aggregate_results") + + # Aggregation to evaluation + workflow.add_edge("aggregate_results", "evaluate_consensus") + + # Conditional consensus check + workflow.add_conditional_edges( + "evaluate_consensus", + self._check_consensus, + { + "synthesize": "synthesize", + "iterate": "route_decision", + }, + ) + + # Synthesis to end + workflow.add_edge("synthesize", END) + + return workflow + + def _entry_node(self, state: AgentState) -> dict: + """Initialize state and parse query.""" + self.logger.info(f"Entry node: {state['query'][:100]}") + return { + "iteration": 0, + "agent_outputs": [], + "mcts_config": self.mcts_config.to_dict(), + } + + def _retrieve_context_node(self, state: AgentState) -> dict: + """Retrieve context from vector store using RAG.""" + if not state.get("use_rag", True) or not self.vector_store: + return {"rag_context": ""} + + query = state["query"] + + # Retrieve documents + docs = self.vector_store.similarity_search(query, k=self.top_k_retrieval) + + # Format context + context = "\n\n".join([doc.page_content for doc in docs]) + + self.logger.info(f"Retrieved {len(docs)} documents") + + return { + "rag_context": context, + "retrieved_docs": [{"content": doc.page_content, "metadata": doc.metadata} for doc in docs], + } + + def _route_decision_node(self, _state: AgentState) -> dict: + """Prepare routing decision.""" + return {} + + def _init_meta_controller(self, config: Any) -> None: + """ + Initialize the neural meta-controller based on configuration. + + Args: + config: MetaControllerConfig or dict with configuration + """ + if not _META_CONTROLLER_AVAILABLE: + self.logger.warning("Meta-controller modules not available. 
Falling back to rule-based routing.") + return + + try: + # Handle both config object and dict + mc_config = MetaControllerConfigLoader.load_from_dict(config) if isinstance(config, dict) else config + + if not mc_config.enabled: + self.logger.info("Neural meta-controller disabled in config") + return + + # Initialize based on type + if mc_config.type == "rnn": + self.meta_controller = RNNMetaController( + name="GraphBuilder_RNN", + seed=mc_config.inference.seed, + hidden_dim=mc_config.rnn.hidden_dim, + num_layers=mc_config.rnn.num_layers, + dropout=mc_config.rnn.dropout, + device=mc_config.inference.device, + ) + # Load trained model if path specified + if mc_config.rnn.model_path: + self.meta_controller.load_model(mc_config.rnn.model_path) + self.logger.info(f"Loaded RNN model from {mc_config.rnn.model_path}") + + elif mc_config.type == "bert": + self.meta_controller = BERTMetaController( + name="GraphBuilder_BERT", + seed=mc_config.inference.seed, + model_name=mc_config.bert.model_name, + lora_r=mc_config.bert.lora_r, + lora_alpha=mc_config.bert.lora_alpha, + lora_dropout=mc_config.bert.lora_dropout, + device=mc_config.inference.device, + use_lora=mc_config.bert.use_lora, + ) + # Load trained model if path specified + if mc_config.bert.model_path: + self.meta_controller.load_model(mc_config.bert.model_path) + self.logger.info(f"Loaded BERT model from {mc_config.bert.model_path}") + else: + raise ValueError(f"Unknown meta-controller type: {mc_config.type}") + + self.use_neural_routing = True + self.logger.info(f"Initialized {mc_config.type.upper()} neural meta-controller") + + except Exception as e: + self.logger.error(f"Failed to initialize meta-controller: {e}") + if hasattr(config, "fallback_to_rule_based") and config.fallback_to_rule_based: + self.logger.warning("Falling back to rule-based routing") + else: + raise + + def _extract_meta_controller_features(self, state: AgentState) -> Any: + """ + Extract features from AgentState for meta-controller prediction. + + Args: + state: Current agent state + + Returns: + MetaControllerFeatures instance + """ + if not _META_CONTROLLER_AVAILABLE or MetaControllerFeatures is None: + return None + + # Extract HRM confidence + hrm_conf = 0.0 + if "hrm_results" in state: + hrm_conf = state["hrm_results"].get("metadata", {}).get("decomposition_quality_score", 0.5) + + # Extract TRM confidence + trm_conf = 0.0 + if "trm_results" in state: + trm_conf = state["trm_results"].get("metadata", {}).get("final_quality_score", 0.5) + + # Extract MCTS value + mcts_val = 0.0 + if "mcts_stats" in state: + mcts_val = state["mcts_stats"].get("best_action_value", 0.5) + + # Consensus score + consensus = state.get("consensus_score", 0.0) + + # Last agent used + last_agent = state.get("last_routed_agent", "none") + + # Iteration + iteration = state.get("iteration", 0) + + # Query length + query_length = len(state.get("query", "")) + + # Has RAG context + has_rag = bool(state.get("rag_context", "")) + + return MetaControllerFeatures( + hrm_confidence=hrm_conf, + trm_confidence=trm_conf, + mcts_value=mcts_val, + consensus_score=consensus, + last_agent=last_agent, + iteration=iteration, + query_length=query_length, + has_rag_context=has_rag, + ) + + def _neural_route_decision(self, state: AgentState) -> str: + """ + Make routing decision using neural meta-controller. 
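+
+        Extracts MetaControllerFeatures from the current state, asks the
+        meta-controller for a prediction, and maps the predicted agent to a
+        route. If the predicted agent has already run, is not applicable, or
+        prediction fails, the decision falls back to
+        _rule_based_route_decision.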
+ + Args: + state: Current agent state + + Returns: + Route decision string ("parallel", "hrm", "trm", "mcts", "aggregate") + """ + try: + features = self._extract_meta_controller_features(state) + if features is None: + return self._rule_based_route_decision(state) + + prediction = self.meta_controller.predict(features) + + # Log prediction for debugging + self.logger.debug( + f"Neural routing: agent={prediction.agent}, " + f"confidence={prediction.confidence:.3f}, " + f"probs={prediction.probabilities}" + ) + + # Map agent prediction to route + agent = prediction.agent + + # Handle routing based on predicted agent + state.get("iteration", 0) + + if agent == "hrm": + if "hrm_results" not in state: + return "hrm" + elif agent == "trm": + if "trm_results" not in state: + return "trm" + elif agent == "mcts" and state.get("use_mcts", False) and "mcts_stats" not in state: + return "mcts" + + # If predicted agent already ran or not applicable, use rule-based + return self._rule_based_route_decision(state) + + except Exception as e: + self.logger.error(f"Neural routing failed: {e}") + # Fallback to rule-based routing + return self._rule_based_route_decision(state) + + def _rule_based_route_decision(self, state: AgentState) -> str: + """ + Make routing decision using rule-based logic. + + Args: + state: Current agent state + + Returns: + Route decision string + """ + iteration = state.get("iteration", 0) + + # First iteration: run HRM and TRM + if iteration == 0: + if self.enable_parallel_agents: + if "hrm_results" not in state and "trm_results" not in state: + return "parallel" + else: + if "hrm_results" not in state: + return "hrm" + elif "trm_results" not in state: + return "trm" + + # Run MCTS if enabled and not yet done + if state.get("use_mcts", False) and "mcts_stats" not in state: + return "mcts" + + return "aggregate" + + def _route_to_agents(self, state: AgentState) -> str: + """Route to appropriate agent based on state.""" + # Use neural routing if enabled + if self.use_neural_routing and self.meta_controller is not None: + return self._neural_route_decision(state) + + # Fall back to rule-based routing + return self._rule_based_route_decision(state) + + async def _parallel_agents_node(self, state: AgentState) -> dict: + """Execute HRM and TRM agents in parallel.""" + self.logger.info("Executing HRM and TRM agents in parallel") + + # Run both agents concurrently + hrm_task = asyncio.create_task( + self.hrm_agent.process( + query=state["query"], + rag_context=state.get("rag_context"), + ) + ) + + trm_task = asyncio.create_task( + self.trm_agent.process( + query=state["query"], + rag_context=state.get("rag_context"), + ) + ) + + # Await both results + hrm_result, trm_result = await asyncio.gather(hrm_task, trm_task) + + # Combine outputs + return { + "hrm_results": { + "response": hrm_result["response"], + "metadata": hrm_result["metadata"], + }, + "trm_results": { + "response": trm_result["response"], + "metadata": trm_result["metadata"], + }, + "agent_outputs": [ + { + "agent": "hrm", + "response": hrm_result["response"], + "confidence": hrm_result["metadata"].get("decomposition_quality_score", 0.7), + }, + { + "agent": "trm", + "response": trm_result["response"], + "confidence": trm_result["metadata"].get("final_quality_score", 0.7), + }, + ], + } + + async def _hrm_agent_node(self, state: AgentState) -> dict: + """Execute HRM agent.""" + self.logger.info("Executing HRM agent") + + result = await self.hrm_agent.process( + query=state["query"], + rag_context=state.get("rag_context"), + 
) + + return { + "hrm_results": { + "response": result["response"], + "metadata": result["metadata"], + }, + "agent_outputs": [ + { + "agent": "hrm", + "response": result["response"], + "confidence": result["metadata"].get("decomposition_quality_score", 0.7), + } + ], + } + + async def _trm_agent_node(self, state: AgentState) -> dict: + """Execute TRM agent.""" + self.logger.info("Executing TRM agent") + + result = await self.trm_agent.process( + query=state["query"], + rag_context=state.get("rag_context"), + ) + + return { + "trm_results": { + "response": result["response"], + "metadata": result["metadata"], + }, + "agent_outputs": [ + { + "agent": "trm", + "response": result["response"], + "confidence": result["metadata"].get("final_quality_score", 0.7), + } + ], + } + + async def _mcts_simulator_node(self, state: AgentState) -> dict: + """Execute MCTS simulation using new deterministic engine.""" + self.logger.info("Executing MCTS simulation with deterministic engine") + + start_time = time.perf_counter() + + # Reset engine for this simulation + self.mcts_engine.clear_cache() + + # Create root state + root_state = MCTSState( + state_id="root", + features={ + "query": state["query"][:100], # Truncate for hashing + "has_hrm": "hrm_results" in state, + "has_trm": "trm_results" in state, + }, + ) + + root = MCTSNode( + state=root_state, + rng=self.mcts_engine.rng, + ) + + # Define action generator based on domain + def action_generator(mcts_state: MCTSState) -> list[str]: + """Generate available actions for state.""" + depth = len(mcts_state.state_id.split("_")) - 1 + + if depth == 0: + # Root level actions + return ["action_A", "action_B", "action_C", "action_D"] + elif depth < self.mcts_config.max_tree_depth: + # Subsequent actions + return ["continue", "refine", "fallback", "escalate"] + else: + return [] # Terminal + + # Define state transition + def state_transition(mcts_state: MCTSState, action: str) -> MCTSState: + """Compute next state from action.""" + new_id = f"{mcts_state.state_id}_{action}" + new_features = mcts_state.features.copy() + new_features["last_action"] = action + new_features["depth"] = len(new_id.split("_")) - 1 + return MCTSState(state_id=new_id, features=new_features) + + # Create rollout policy using agent results + def heuristic_fn(mcts_state: MCTSState) -> float: + """Evaluate state using agent confidence.""" + base = 0.5 + + # Bias based on agent confidence + if state.get("hrm_results"): + hrm_conf = state["hrm_results"]["metadata"].get("decomposition_quality_score", 0.5) + base += hrm_conf * 0.2 + + if state.get("trm_results"): + trm_conf = state["trm_results"]["metadata"].get("final_quality_score", 0.5) + base += trm_conf * 0.2 + + return min(base, 1.0) + + rollout_policy = HybridRolloutPolicy( + heuristic_fn=heuristic_fn, + heuristic_weight=0.7, + random_weight=0.3, + ) + + # Run MCTS search + best_action, stats = await self.mcts_engine.search( + root=root, + num_iterations=self.mcts_config.num_iterations, + action_generator=action_generator, + state_transition=state_transition, + rollout_policy=rollout_policy, + max_rollout_depth=self.mcts_config.max_rollout_depth, + selection_policy=self.mcts_config.selection_policy, + ) + + end_time = time.perf_counter() + execution_time_ms = (end_time - start_time) * 1000 + + # Compute tree statistics + tree_depth = self.mcts_engine.get_tree_depth(root) + tree_node_count = self.mcts_engine.count_nodes(root) + + # Track experiment + self.experiment_tracker.create_result( + experiment_id=f"mcts_{int(time.time())}", + 
config=self.mcts_config, + mcts_stats=stats, + execution_time_ms=execution_time_ms, + tree_depth=tree_depth, + tree_node_count=tree_node_count, + metadata={ + "query": state["query"][:100], + "has_rag": state.get("use_rag", False), + }, + ) + + self.logger.info( + f"MCTS complete: best_action={best_action}, " + f"iterations={stats['iterations']}, " + f"cache_hit_rate={stats['cache_hit_rate']:.2%}" + ) + + return { + "mcts_root": root, + "mcts_best_action": best_action, + "mcts_stats": stats, + "agent_outputs": [ + { + "agent": "mcts", + "response": ( + f"Simulated {stats['iterations']} scenarios with " + f"seed {self.mcts_config.seed}. " + f"Recommended action: {best_action} " + f"(visits={stats['best_action_visits']}, " + f"value={stats['best_action_value']:.3f})" + ), + "confidence": min( + stats["best_action_visits"] / stats["iterations"] if stats["iterations"] > 0 else 0.5, + 1.0, + ), + } + ], + } + + def _aggregate_results_node(self, state: AgentState) -> dict: + """Aggregate results from all agents.""" + self.logger.info("Aggregating agent results") + + agent_outputs = state.get("agent_outputs", []) + + confidence_scores = {output["agent"]: output["confidence"] for output in agent_outputs} + + return {"confidence_scores": confidence_scores} + + def _evaluate_consensus_node(self, state: AgentState) -> dict: + """Evaluate consensus among agents.""" + agent_outputs = state.get("agent_outputs", []) + + if len(agent_outputs) < 2: + return { + "consensus_reached": True, + "consensus_score": 1.0, + } + + avg_confidence = sum(o["confidence"] for o in agent_outputs) / len(agent_outputs) + + consensus_reached = avg_confidence >= self.consensus_threshold + + self.logger.info(f"Consensus: {consensus_reached} (score={avg_confidence:.2f})") + + return { + "consensus_reached": consensus_reached, + "consensus_score": avg_confidence, + } + + def _check_consensus(self, state: AgentState) -> str: + """Check if consensus reached or need more iterations.""" + if state.get("consensus_reached", False): + return "synthesize" + + if state.get("iteration", 0) >= state.get("max_iterations", self.max_iterations): + return "synthesize" + + return "iterate" + + async def _synthesize_node(self, state: AgentState) -> dict: + """Synthesize final response from agent outputs.""" + self.logger.info("Synthesizing final response") + + agent_outputs = state.get("agent_outputs", []) + + synthesis_prompt = f"""Query: {state["query"]} + +Agent Outputs: +""" + + for output in agent_outputs: + synthesis_prompt += f""" +{output["agent"].upper()} (confidence={output["confidence"]:.2f}): +{output["response"]} + +""" + + synthesis_prompt += """ +Synthesize these outputs into a comprehensive final response. +Prioritize higher-confidence outputs. Integrate insights from all agents. 
+ +Final Response:""" + + try: + response = await self.model_adapter.generate( + prompt=synthesis_prompt, + temperature=0.5, + ) + final_response = response.text + except Exception as e: + self.logger.error(f"Synthesis failed: {e}") + best_output = max(agent_outputs, key=lambda o: o["confidence"]) + final_response = best_output["response"] + + metadata = { + "agents_used": [o["agent"] for o in agent_outputs], + "confidence_scores": state.get("confidence_scores", {}), + "consensus_score": state.get("consensus_score", 0.0), + "iterations": state.get("iteration", 0), + "mcts_config": state.get("mcts_config", {}), + } + + if state.get("mcts_stats"): + metadata["mcts_stats"] = state["mcts_stats"] + + return { + "final_response": final_response, + "metadata": metadata, + } + + +class IntegratedFramework: + """ + Integrated multi-agent framework with new MCTS core. + + Maintains backward compatibility with original process() signature. + """ + + def __init__( + self, + model_adapter, + logger, + vector_store=None, + _embedding_model=None, + hrm_config: dict | None = None, + trm_config: dict | None = None, + mcts_config: MCTSConfig | None = None, + top_k_retrieval: int = 5, + max_iterations: int = 3, + consensus_threshold: float = 0.75, + enable_parallel_agents: bool = True, + ): + """ + Initialize integrated framework. + + Backward compatible with LangGraphMultiAgentFramework. + """ + self.model_adapter = model_adapter + self.logger = logger + self.vector_store = vector_store + + # Import agents (would be real imports in production) + try: + from improved_hrm_agent import HRMAgent + from improved_trm_agent import TRMAgent + + self.hrm_agent = HRMAgent( + model_adapter=model_adapter, + logger=logger, + **(hrm_config or {}), + ) + self.trm_agent = TRMAgent( + model_adapter=model_adapter, + logger=logger, + **(trm_config or {}), + ) + except ImportError: + self.hrm_agent = None + self.trm_agent = None + self.logger.warning("Could not import HRM/TRM agents") + + # Build graph + self.graph_builder = GraphBuilder( + hrm_agent=self.hrm_agent, + trm_agent=self.trm_agent, + model_adapter=model_adapter, + logger=logger, + vector_store=vector_store, + mcts_config=mcts_config, + top_k_retrieval=top_k_retrieval, + max_iterations=max_iterations, + consensus_threshold=consensus_threshold, + enable_parallel_agents=enable_parallel_agents, + ) + + # Compile graph + if StateGraph is not None: + self.graph = self.graph_builder.build_graph() + self.memory = MemorySaver() if MemorySaver else None + self.app = self.graph.compile(checkpointer=self.memory) if self.memory else self.graph.compile() + else: + self.graph = None + self.app = None + + self.logger.info("Integrated framework initialized with new MCTS core") + + async def process( + self, + query: str, + use_rag: bool = True, + use_mcts: bool = False, + config: dict | None = None, + ) -> dict: + """ + Process query through LangGraph. + + Backward compatible with original signature. + + Args: + query: User query to process + use_rag: Enable RAG context retrieval + use_mcts: Enable MCTS simulation + config: Optional LangGraph config + + Returns: + Dictionary with response, metadata, and state + """ + if self.app is None: + raise RuntimeError("LangGraph not available. 
Install with: pip install langgraph") + + initial_state = { + "query": query, + "use_rag": use_rag, + "use_mcts": use_mcts, + "iteration": 0, + "max_iterations": self.graph_builder.max_iterations, + "agent_outputs": [], + } + + config = config or {"configurable": {"thread_id": "default"}} + + result = await self.app.ainvoke(initial_state, config=config) + + return { + "response": result.get("final_response", ""), + "metadata": result.get("metadata", {}), + "state": result, + } + + def get_experiment_tracker(self) -> ExperimentTracker: + """Get the experiment tracker for analysis.""" + return self.graph_builder.experiment_tracker + + def set_mcts_seed(self, seed: int) -> None: + """Set MCTS seed for deterministic behavior.""" + self.graph_builder.mcts_engine.reset_seed(seed) + self.graph_builder.mcts_config.seed = seed diff --git a/src/framework/mcts/__init__.py b/src/framework/mcts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60d1cdeb1251f532a6288042b735f654dadd0ac5 --- /dev/null +++ b/src/framework/mcts/__init__.py @@ -0,0 +1,26 @@ +""" +MCTS (Monte Carlo Tree Search) module for multi-agent framework. + +Provides deterministic, testable MCTS with: +- Progressive widening for controlled branching +- Simulation result caching +- Configurable selection and rollout policies +- Experiment tracking and analysis +""" + +from .config import MCTSConfig, create_preset_config +from .core import MCTSEngine, MCTSNode +from .experiments import ExperimentResult, ExperimentTracker +from .policies import RolloutPolicy, SelectionPolicy, ucb1 + +__all__ = [ + "MCTSNode", + "MCTSEngine", + "ucb1", + "RolloutPolicy", + "SelectionPolicy", + "MCTSConfig", + "create_preset_config", + "ExperimentTracker", + "ExperimentResult", +] diff --git a/src/framework/mcts/config.py b/src/framework/mcts/config.py new file mode 100644 index 0000000000000000000000000000000000000000..767d19ca0c79ba5bf4aa82f75db4e30c9186546b --- /dev/null +++ b/src/framework/mcts/config.py @@ -0,0 +1,355 @@ +""" +MCTS Configuration Module - Parameter management and presets. + +Provides: +- MCTSConfig dataclass with all parameters +- Validation of parameter bounds +- Preset configurations (fast, balanced, thorough) +- Serialization support for experiment tracking +""" + +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass +from enum import Enum +from typing import Any + +from .policies import SelectionPolicy + + +class ConfigPreset(Enum): + """Preset configuration names.""" + + FAST = "fast" + BALANCED = "balanced" + THOROUGH = "thorough" + EXPLORATION_HEAVY = "exploration_heavy" + EXPLOITATION_HEAVY = "exploitation_heavy" + + +@dataclass +class MCTSConfig: + """ + Complete configuration for MCTS engine. + + All MCTS parameters are centralized here with validation. + Supports serialization for experiment tracking and reproducibility. + """ + + # Core MCTS parameters + num_iterations: int = 100 + """Number of MCTS iterations to run.""" + + seed: int = 42 + """Random seed for deterministic behavior.""" + + exploration_weight: float = 1.414 + """UCB1 exploration constant (c). Higher = more exploration.""" + + # Progressive widening + progressive_widening_k: float = 1.0 + """Progressive widening coefficient. Higher = more conservative.""" + + progressive_widening_alpha: float = 0.5 + """Progressive widening exponent. 
Lower = more aggressive expansion.""" + + # Rollout configuration + max_rollout_depth: int = 10 + """Maximum depth for rollout simulations.""" + + rollout_policy: str = "hybrid" + """Rollout policy: 'random', 'greedy', 'hybrid'.""" + + # Action selection + selection_policy: SelectionPolicy = SelectionPolicy.MAX_VISITS + """Policy for final action selection.""" + + # Parallelization + max_parallel_rollouts: int = 4 + """Maximum concurrent rollout simulations.""" + + # Caching + enable_cache: bool = True + """Enable simulation result caching.""" + + cache_size_limit: int = 10000 + """Maximum number of cached simulation results.""" + + # Tree structure + max_tree_depth: int = 20 + """Maximum depth of MCTS tree.""" + + max_children_per_node: int = 50 + """Maximum children per node (action branching limit).""" + + # Early termination + early_termination_threshold: float = 0.95 + """Stop if best action has this fraction of total visits.""" + + min_iterations_before_termination: int = 50 + """Minimum iterations before early termination check.""" + + # Value bounds + min_value: float = 0.0 + """Minimum value for normalization.""" + + max_value: float = 1.0 + """Maximum value for normalization.""" + + # Metadata + name: str = "default" + """Configuration name for tracking.""" + + description: str = "" + """Description of this configuration.""" + + def __post_init__(self): + """Validate configuration parameters after initialization.""" + self.validate() + + def validate(self) -> None: + """ + Validate all configuration parameters. + + Raises: + ValueError: If any parameter is out of valid bounds. + """ + errors = [] + + # Core parameters + if self.num_iterations < 1: + errors.append("num_iterations must be >= 1") + if self.num_iterations > 100000: + errors.append("num_iterations should be <= 100000 for practical use") + + if self.exploration_weight < 0: + errors.append("exploration_weight must be >= 0") + if self.exploration_weight > 10: + errors.append("exploration_weight should be <= 10") + + # Progressive widening + if self.progressive_widening_k <= 0: + errors.append("progressive_widening_k must be > 0") + if not 0 < self.progressive_widening_alpha < 1: + errors.append("progressive_widening_alpha must be in (0, 1)") + + # Rollout + if self.max_rollout_depth < 1: + errors.append("max_rollout_depth must be >= 1") + if self.rollout_policy not in ["random", "greedy", "hybrid", "llm"]: + errors.append("rollout_policy must be one of: random, greedy, hybrid, llm") + + # Parallelization + if self.max_parallel_rollouts < 1: + errors.append("max_parallel_rollouts must be >= 1") + if self.max_parallel_rollouts > 100: + errors.append("max_parallel_rollouts should be <= 100") + + # Caching + if self.cache_size_limit < 0: + errors.append("cache_size_limit must be >= 0") + + # Tree structure + if self.max_tree_depth < 1: + errors.append("max_tree_depth must be >= 1") + if self.max_children_per_node < 1: + errors.append("max_children_per_node must be >= 1") + + # Early termination + if not 0 < self.early_termination_threshold <= 1: + errors.append("early_termination_threshold must be in (0, 1]") + if self.min_iterations_before_termination < 1: + errors.append("min_iterations_before_termination must be >= 1") + if self.min_iterations_before_termination > self.num_iterations: + errors.append("min_iterations_before_termination must be <= num_iterations") + + # Value bounds + if self.min_value >= self.max_value: + errors.append("min_value must be < max_value") + + if errors: + raise ValueError("Invalid MCTS 
configuration:\n" + "\n".join(f" - {e}" for e in errors)) + + def to_dict(self) -> dict[str, Any]: + """ + Convert configuration to dictionary for serialization. + + Returns: + Dictionary representation of config. + """ + d = asdict(self) + # Convert enum to string + d["selection_policy"] = self.selection_policy.value + return d + + def to_json(self, indent: int = 2) -> str: + """ + Serialize configuration to JSON string. + + Args: + indent: JSON indentation level + + Returns: + JSON string representation + """ + return json.dumps(self.to_dict(), indent=indent) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> MCTSConfig: + """ + Create configuration from dictionary. + + Args: + data: Dictionary with configuration parameters + + Returns: + MCTSConfig instance + """ + # Convert selection_policy string back to enum + if "selection_policy" in data and isinstance(data["selection_policy"], str): + data["selection_policy"] = SelectionPolicy(data["selection_policy"]) + return cls(**data) + + @classmethod + def from_json(cls, json_str: str) -> MCTSConfig: + """ + Deserialize configuration from JSON string. + + Args: + json_str: JSON string + + Returns: + MCTSConfig instance + """ + data = json.loads(json_str) + return cls.from_dict(data) + + def copy(self, **overrides) -> MCTSConfig: + """ + Create a copy with optional parameter overrides. + + Args: + **overrides: Parameters to override + + Returns: + New MCTSConfig instance + """ + data = self.to_dict() + data.update(overrides) + return self.from_dict(data) + + def __repr__(self) -> str: + return ( + f"MCTSConfig(name={self.name!r}, " + f"iterations={self.num_iterations}, " + f"c={self.exploration_weight}, " + f"widening_k={self.progressive_widening_k}, " + f"widening_alpha={self.progressive_widening_alpha})" + ) + + +def create_preset_config(preset: ConfigPreset) -> MCTSConfig: + """ + Create a preset configuration. 
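+
+    Presets can be tuned further via copy(); an illustrative example:
+
+        cfg = create_preset_config(ConfigPreset.FAST).copy(num_iterations=40)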
+ + Args: + preset: Preset type to create + + Returns: + MCTSConfig with preset parameters + """ + if preset == ConfigPreset.FAST: + return MCTSConfig( + name="fast", + description="Fast search with minimal iterations", + num_iterations=25, + exploration_weight=1.414, + progressive_widening_k=0.5, # Aggressive widening + progressive_widening_alpha=0.5, + max_rollout_depth=5, + rollout_policy="random", + selection_policy=SelectionPolicy.MAX_VISITS, + max_parallel_rollouts=8, + cache_size_limit=1000, + early_termination_threshold=0.8, + min_iterations_before_termination=10, + ) + + elif preset == ConfigPreset.BALANCED: + return MCTSConfig( + name="balanced", + description="Balanced search for typical use cases", + num_iterations=100, + exploration_weight=1.414, + progressive_widening_k=1.0, + progressive_widening_alpha=0.5, + max_rollout_depth=10, + rollout_policy="hybrid", + selection_policy=SelectionPolicy.MAX_VISITS, + max_parallel_rollouts=4, + cache_size_limit=10000, + early_termination_threshold=0.9, + min_iterations_before_termination=50, + ) + + elif preset == ConfigPreset.THOROUGH: + return MCTSConfig( + name="thorough", + description="Thorough search for high-stakes decisions", + num_iterations=500, + exploration_weight=1.414, + progressive_widening_k=2.0, # Conservative widening + progressive_widening_alpha=0.6, + max_rollout_depth=20, + rollout_policy="hybrid", + selection_policy=SelectionPolicy.ROBUST_CHILD, + max_parallel_rollouts=4, + cache_size_limit=50000, + early_termination_threshold=0.95, + min_iterations_before_termination=200, + ) + + elif preset == ConfigPreset.EXPLORATION_HEAVY: + return MCTSConfig( + name="exploration_heavy", + description="High exploration for diverse action discovery", + num_iterations=200, + exploration_weight=2.5, # High exploration + progressive_widening_k=0.8, # More widening + progressive_widening_alpha=0.4, # Aggressive + max_rollout_depth=15, + rollout_policy="random", + selection_policy=SelectionPolicy.MAX_VISITS, + max_parallel_rollouts=6, + cache_size_limit=20000, + early_termination_threshold=0.95, + min_iterations_before_termination=100, + ) + + elif preset == ConfigPreset.EXPLOITATION_HEAVY: + return MCTSConfig( + name="exploitation_heavy", + description="High exploitation for known-good action refinement", + num_iterations=150, + exploration_weight=0.5, # Low exploration + progressive_widening_k=3.0, # Conservative + progressive_widening_alpha=0.7, # Very conservative + max_rollout_depth=10, + rollout_policy="greedy", + selection_policy=SelectionPolicy.MAX_VALUE, + max_parallel_rollouts=4, + cache_size_limit=10000, + early_termination_threshold=0.85, + min_iterations_before_termination=75, + ) + + else: + raise ValueError(f"Unknown preset: {preset}") + + +# Default configurations for easy access +DEFAULT_CONFIG = MCTSConfig() +FAST_CONFIG = create_preset_config(ConfigPreset.FAST) +BALANCED_CONFIG = create_preset_config(ConfigPreset.BALANCED) +THOROUGH_CONFIG = create_preset_config(ConfigPreset.THOROUGH) diff --git a/src/framework/mcts/core.py b/src/framework/mcts/core.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e1387f6837474100af579bf8d87b16075ab3c3 --- /dev/null +++ b/src/framework/mcts/core.py @@ -0,0 +1,619 @@ +""" +MCTS Core Module - Deterministic, testable Monte Carlo Tree Search implementation. 
+ +Features: +- Seeded RNG for deterministic behavior +- Progressive widening to control branching factor +- Simulation result caching with hashable state keys +- Clear separation of MCTS phases: select, expand, simulate, backpropagate +- Support for parallel rollouts with asyncio.Semaphore +""" + +from __future__ import annotations + +import asyncio +import hashlib +from collections import OrderedDict +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import numpy as np + +from .policies import RolloutPolicy, SelectionPolicy, ucb1 + + +@dataclass +class MCTSState: + """Hashable state representation for caching.""" + + state_id: str + features: dict[str, Any] = field(default_factory=dict) + + def to_hash_key(self) -> str: + """Generate a hashable key for this state.""" + # Sort features for deterministic hashing + feature_str = str(sorted(self.features.items())) + combined = f"{self.state_id}:{feature_str}" + return hashlib.sha256(combined.encode()).hexdigest() + + +class MCTSNode: + """ + Monte Carlo Tree Search node with proper state management. + + Attributes: + state: The state this node represents + parent: Parent node (None for root) + action: Action taken to reach this node from parent + children: List of child nodes + visits: Number of times this node has been visited + value_sum: Total accumulated value from simulations + rng: Seeded random number generator for deterministic behavior + """ + + def __init__( + self, + state: MCTSState, + parent: MCTSNode | None = None, + action: str | None = None, + rng: np.random.Generator | None = None, + ): + self.state = state + self.parent = parent + self.action = action + self.children: list[MCTSNode] = [] + self.visits: int = 0 + self.value_sum: float = 0.0 + self.terminal: bool = False + self.expanded_actions: set = set() + self.available_actions: list[str] = [] + + # Track depth for O(1) tree statistics + self.depth: int = 0 if parent is None else parent.depth + 1 + + # Use provided RNG or create default + self._rng = rng or np.random.default_rng() + + @property + def value(self) -> float: + """Average value of this node.""" + if self.visits == 0: + return 0.0 + return self.value_sum / self.visits + + @property + def is_fully_expanded(self) -> bool: + """Check if all available actions have been expanded.""" + return len(self.expanded_actions) >= len(self.available_actions) + + def select_child(self, exploration_weight: float = 1.414) -> MCTSNode: + """ + Select best child using UCB1 policy. + + Args: + exploration_weight: Exploration constant (c in UCB1) + + Returns: + Best child node according to UCB1 + """ + if not self.children: + raise ValueError("No children to select from") + + best_child = None + best_score = float("-inf") + + for child in self.children: + score = ucb1( + value_sum=child.value_sum, + visits=child.visits, + parent_visits=self.visits, + c=exploration_weight, + ) + if score > best_score: + best_score = score + best_child = child + + return best_child + + def add_child(self, action: str, child_state: MCTSState) -> MCTSNode: + """ + Add a child node for the given action. 
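+
+        The child shares the parent's seeded RNG so the whole tree draws from
+        a single generator, and its depth is set to parent.depth + 1 so tree
+        depth can be tracked without traversal.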
+ + Args: + action: Action taken to reach child state + child_state: State of the child node + + Returns: + Newly created child node + """ + child = MCTSNode( + state=child_state, + parent=self, + action=action, + rng=self._rng, + ) + self.children.append(child) + self.expanded_actions.add(action) + return child + + def get_unexpanded_action(self) -> str | None: + """Get a random unexpanded action.""" + unexpanded = [a for a in self.available_actions if a not in self.expanded_actions] + if not unexpanded: + return None + return self._rng.choice(unexpanded) + + def __repr__(self) -> str: + return ( + f"MCTSNode(state={self.state.state_id}, " + f"visits={self.visits}, value={self.value:.3f}, " + f"children={len(self.children)})" + ) + + +class MCTSEngine: + """ + Main MCTS engine with deterministic behavior and advanced features. + + Features: + - Seeded RNG for reproducibility + - Progressive widening to control branching + - Simulation result caching + - Parallel rollout support with semaphore + """ + + def __init__( + self, + seed: int = 42, + exploration_weight: float = 1.414, + progressive_widening_k: float = 1.0, + progressive_widening_alpha: float = 0.5, + max_parallel_rollouts: int = 4, + cache_size_limit: int = 10000, + ): + """ + Initialize MCTS engine. + + Args: + seed: Random seed for deterministic behavior + exploration_weight: UCB1 exploration constant + progressive_widening_k: Progressive widening coefficient + progressive_widening_alpha: Progressive widening exponent + max_parallel_rollouts: Maximum concurrent rollouts + cache_size_limit: Maximum number of cached simulation results + """ + self.seed = seed + self.rng = np.random.default_rng(seed) + self.exploration_weight = exploration_weight + self.progressive_widening_k = progressive_widening_k + self.progressive_widening_alpha = progressive_widening_alpha + + # Parallel rollout control + self.max_parallel_rollouts = max_parallel_rollouts + self._semaphore: asyncio.Semaphore | None = None + + # Simulation cache: state_hash -> (value, visit_count) + # Using OrderedDict for LRU eviction + self._simulation_cache: OrderedDict[str, tuple[float, int]] = OrderedDict() + self.cache_size_limit = cache_size_limit + + # Statistics + self.total_simulations = 0 + self.cache_hits = 0 + self.cache_misses = 0 + self.cache_evictions = 0 + + # Cached tree statistics for O(1) retrieval + self._cached_tree_depth: int = 0 + self._cached_node_count: int = 0 + + def reset_seed(self, seed: int) -> None: + """Reset the random seed for new experiment.""" + self.seed = seed + self.rng = np.random.default_rng(seed) + + def clear_cache(self) -> None: + """Clear simulation result cache.""" + self._simulation_cache.clear() + self.cache_hits = 0 + self.cache_misses = 0 + self.cache_evictions = 0 + + def should_expand(self, node: MCTSNode) -> bool: + """ + Check if node should expand based on progressive widening. + + Progressive widening formula: expand when visits > k * n^alpha + where n is the number of children. + + This prevents excessive branching and focuses search on promising areas. + """ + if node.terminal or node.is_fully_expanded: + return False + + num_children = len(node.children) + threshold = self.progressive_widening_k * (num_children**self.progressive_widening_alpha) + + return node.visits > threshold + + def select(self, node: MCTSNode) -> MCTSNode: + """ + MCTS Selection Phase: traverse tree to find leaf node. + + Uses UCB1 to balance exploration and exploitation. 
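+
+        In the standard UCB1 form each child is scored as
+        value_sum / visits + c * sqrt(ln(parent_visits) / visits);
+        see policies.ucb1 for the exact implementation used here.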
+ """ + while node.children and not node.terminal: + # Check if we should expand instead of selecting + if self.should_expand(node): + break + node = node.select_child(self.exploration_weight) + return node + + def expand( + self, + node: MCTSNode, + action_generator: Callable[[MCTSState], list[str]], + state_transition: Callable[[MCTSState, str], MCTSState], + ) -> MCTSNode: + """ + MCTS Expansion Phase: add a new child node. + + Args: + node: Node to expand + action_generator: Function to generate available actions + state_transition: Function to compute next state from action + + Returns: + Newly expanded child node, or original node if cannot expand + """ + if node.terminal: + return node + + # Generate available actions if not yet done + if not node.available_actions: + node.available_actions = action_generator(node.state) + + if not node.available_actions: + node.terminal = True + return node + + # Check progressive widening + if not self.should_expand(node): + return node + + # Get unexpanded action + action = node.get_unexpanded_action() + if action is None: + return node + + # Create child state + child_state = state_transition(node.state, action) + child = node.add_child(action, child_state) + + # Update cached node count for O(1) retrieval + self._cached_node_count += 1 + + return child + + async def simulate( + self, + node: MCTSNode, + rollout_policy: RolloutPolicy, + max_depth: int = 10, + ) -> float: + """ + MCTS Simulation Phase: evaluate node value through rollout. + + Uses caching to avoid redundant simulations. + + Args: + node: Node to simulate from + rollout_policy: Policy for rollout evaluation + max_depth: Maximum rollout depth + + Returns: + Estimated value from simulation + """ + # Check cache first + state_hash = node.state.to_hash_key() + if state_hash in self._simulation_cache: + cached_value, cached_count = self._simulation_cache[state_hash] + # Move to end for LRU (most recently used) + self._simulation_cache.move_to_end(state_hash) + self.cache_hits += 1 + # Return cached average with small noise for exploration + noise = self.rng.normal(0, 0.01) + return cached_value + noise + + self.cache_misses += 1 + + # Acquire semaphore for parallel control + if self._semaphore is None: + self._semaphore = asyncio.Semaphore(self.max_parallel_rollouts) + + async with self._semaphore: + # Perform rollout + value = await rollout_policy.evaluate( + state=node.state, + rng=self.rng, + max_depth=max_depth, + ) + + self.total_simulations += 1 + + # Update cache with LRU eviction + if state_hash in self._simulation_cache: + # Update existing cache entry with running average + old_value, old_count = self._simulation_cache[state_hash] + new_count = old_count + 1 + new_value = (old_value * old_count + value) / new_count + self._simulation_cache[state_hash] = (new_value, new_count) + # Move to end for LRU (most recently used) + self._simulation_cache.move_to_end(state_hash) + else: + # Evict oldest entry if cache is full + if len(self._simulation_cache) >= self.cache_size_limit: + # Remove the first item (least recently used) + self._simulation_cache.popitem(last=False) + self.cache_evictions += 1 + # Add new entry at the end (most recently used) + self._simulation_cache[state_hash] = (value, 1) + + return value + + def backpropagate(self, node: MCTSNode, value: float) -> None: + """ + MCTS Backpropagation Phase: update ancestor statistics. 
+ + Args: + node: Leaf node to start backpropagation + value: Value to propagate up the tree + """ + # Update cached tree depth if this node is deeper than current max + if node.depth > self._cached_tree_depth: + self._cached_tree_depth = node.depth + + current = node + while current is not None: + current.visits += 1 + current.value_sum += value + current = current.parent + + async def run_iteration( + self, + root: MCTSNode, + action_generator: Callable[[MCTSState], list[str]], + state_transition: Callable[[MCTSState, str], MCTSState], + rollout_policy: RolloutPolicy, + max_rollout_depth: int = 10, + ) -> None: + """ + Run a single MCTS iteration (select, expand, simulate, backpropagate). + + Args: + root: Root node of the tree + action_generator: Function to generate actions + state_transition: Function to compute state transitions + rollout_policy: Policy for rollout evaluation + max_rollout_depth: Maximum depth for rollouts + """ + # Selection + leaf = self.select(root) + + # Expansion + if not leaf.terminal and leaf.visits > 0: + leaf = self.expand(leaf, action_generator, state_transition) + + # Simulation + value = await self.simulate(leaf, rollout_policy, max_rollout_depth) + + # Backpropagation + self.backpropagate(leaf, value) + + async def search( + self, + root: MCTSNode, + num_iterations: int, + action_generator: Callable[[MCTSState], list[str]], + state_transition: Callable[[MCTSState, str], MCTSState], + rollout_policy: RolloutPolicy, + max_rollout_depth: int = 10, + selection_policy: SelectionPolicy = SelectionPolicy.MAX_VISITS, + ) -> tuple[str | None, dict[str, Any]]: + """ + Run MCTS search for specified number of iterations. + + Args: + root: Root node to search from + num_iterations: Number of MCTS iterations + action_generator: Function to generate available actions + state_transition: Function to compute state transitions + rollout_policy: Policy for rollout simulation + max_rollout_depth: Maximum rollout depth + selection_policy: Policy for final action selection + + Returns: + Tuple of (best_action, statistics_dict) + """ + # Reset cached tree statistics for new search + self._cached_tree_depth = 0 + self._cached_node_count = 1 # Start with root node + + # Initialize root's available actions + if not root.available_actions: + root.available_actions = action_generator(root.state) + + # Run iterations + for _i in range(num_iterations): + await self.run_iteration( + root=root, + action_generator=action_generator, + state_transition=state_transition, + rollout_policy=rollout_policy, + max_rollout_depth=max_rollout_depth, + ) + + # Select best action based on policy + best_action = self._select_best_action(root, selection_policy) + + # Compute statistics + stats = self._compute_statistics(root, num_iterations) + + return best_action, stats + + def _select_best_action( + self, + root: MCTSNode, + policy: SelectionPolicy, + ) -> str | None: + """ + Select the best action from root based on selection policy. 
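+
+        MAX_VISITS picks the most-visited child, MAX_VALUE the child with the
+        highest mean value, and ROBUST_CHILD a 50/50 blend of normalized
+        visits and value; unknown policies fall back to MAX_VISITS.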
+ + Args: + root: Root node with children + policy: Selection policy to use + + Returns: + Best action string or None if no children + """ + if not root.children: + return None + + if policy == SelectionPolicy.MAX_VISITS: + # Most robust: select action with most visits + best_child = max(root.children, key=lambda c: c.visits) + elif policy == SelectionPolicy.MAX_VALUE: + # Greedy: select action with highest average value + best_child = max(root.children, key=lambda c: c.value) + elif policy == SelectionPolicy.ROBUST_CHILD: + # Robust: require both high visits and high value + # Normalize both metrics and combine + max_visits = max(c.visits for c in root.children) + max_value = max(c.value for c in root.children) or 1.0 + + def robust_score(child): + visit_score = child.visits / max_visits if max_visits > 0 else 0 + value_score = child.value / max_value if max_value > 0 else 0 + return 0.5 * visit_score + 0.5 * value_score + + best_child = max(root.children, key=robust_score) + else: + # Default to max visits + best_child = max(root.children, key=lambda c: c.visits) + + return best_child.action + + def _compute_statistics( + self, + root: MCTSNode, + num_iterations: int, + ) -> dict[str, Any]: + """ + Compute comprehensive MCTS statistics. + + Args: + root: Root node + num_iterations: Number of iterations run + + Returns: + Dictionary of statistics + """ + # Best child info + best_child = None + if root.children: + best_child = max(root.children, key=lambda c: c.visits) + + # Action statistics + action_stats = {} + for child in root.children: + action_stats[child.action] = { + "visits": child.visits, + "value": child.value, + "value_sum": child.value_sum, + "num_children": len(child.children), + } + + return { + "iterations": num_iterations, + "root_visits": root.visits, + "root_value": root.value, + "num_children": len(root.children), + "best_action": best_child.action if best_child else None, + "best_action_visits": best_child.visits if best_child else 0, + "best_action_value": best_child.value if best_child else 0.0, + "action_stats": action_stats, + "total_simulations": self.total_simulations, + "cache_hits": self.cache_hits, + "cache_misses": self.cache_misses, + "cache_evictions": self.cache_evictions, + "cache_hit_rate": ( + self.cache_hits / (self.cache_hits + self.cache_misses) + if (self.cache_hits + self.cache_misses) > 0 + else 0.0 + ), + "cache_size": len(self._simulation_cache), + "seed": self.seed, + } + + def get_tree_depth(self, node: MCTSNode) -> int: + """Get maximum depth of the tree from given node. + + Uses iterative BFS to avoid stack overflow for large trees (5000+ nodes). + Each level of the tree is processed iteratively, tracking depth as we go. + """ + if not node.children: + return 0 + + from collections import deque + + max_depth = 0 + # Queue contains tuples of (node, depth) + queue = deque([(node, 0)]) + + while queue: + current_node, depth = queue.popleft() + max_depth = max(max_depth, depth) + + for child in current_node.children: + queue.append((child, depth + 1)) + + return max_depth + + def count_nodes(self, node: MCTSNode) -> int: + """Count total number of nodes in tree. + + Uses iterative BFS to avoid stack overflow for large trees (5000+ nodes). + Traverses all nodes in the tree using a queue-based approach. 
+ """ + from collections import deque + + count = 0 + queue = deque([node]) + + while queue: + current_node = queue.popleft() + count += 1 + + for child in current_node.children: + queue.append(child) + + return count + + def get_cached_tree_depth(self) -> int: + """ + Get cached maximum tree depth in O(1) time. + + Returns: + Maximum depth of tree from last search + """ + return self._cached_tree_depth + + def get_cached_node_count(self) -> int: + """ + Get cached total node count in O(1) time. + + Returns: + Total number of nodes in tree from last search + """ + return self._cached_node_count diff --git a/src/framework/mcts/experiments.py b/src/framework/mcts/experiments.py new file mode 100644 index 0000000000000000000000000000000000000000..f2a872d80ddf766edd2f9cf6ca92b64ede13854e --- /dev/null +++ b/src/framework/mcts/experiments.py @@ -0,0 +1,440 @@ +""" +Experiment Tracking Module - Track, analyze, and compare MCTS experiments. + +Provides: +- Experiment run tracking (seed, params, results) +- Statistical analysis of MCTS performance +- Comparison utilities for different configurations +- Export to JSON/CSV for analysis +""" + +from __future__ import annotations + +import csv +import json +import statistics +from dataclasses import asdict, dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any + +from .config import MCTSConfig + + +@dataclass +class ExperimentResult: + """Result of a single MCTS experiment run.""" + + # Identification + experiment_id: str + timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) + + # Configuration + config: dict[str, Any] | None = None + seed: int = 42 + + # Core results + best_action: str | None = None + best_action_value: float = 0.0 + best_action_visits: int = 0 + root_visits: int = 0 + + # Performance metrics + total_iterations: int = 0 + total_simulations: int = 0 + execution_time_ms: float = 0.0 + + # Cache statistics + cache_hits: int = 0 + cache_misses: int = 0 + cache_hit_rate: float = 0.0 + + # Tree statistics + tree_depth: int = 0 + tree_node_count: int = 0 + branching_factor: float = 0.0 + + # Action distribution + action_stats: dict[str, dict[str, Any]] = field(default_factory=dict) + + # Optional metadata + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return asdict(self) + + def to_json(self, indent: int = 2) -> str: + """Serialize to JSON.""" + return json.dumps(self.to_dict(), indent=indent) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ExperimentResult: + """Create from dictionary.""" + return cls(**data) + + @classmethod + def from_json(cls, json_str: str) -> ExperimentResult: + """Deserialize from JSON.""" + return cls.from_dict(json.loads(json_str)) + + +class ExperimentTracker: + """ + Track and analyze MCTS experiments. + + Features: + - Store multiple experiment results + - Statistical analysis across runs + - Configuration comparison + - Export to JSON/CSV + """ + + def __init__(self, name: str = "mcts_experiments"): + """ + Initialize experiment tracker. + + Args: + name: Name of this experiment series + """ + self.name = name + self.results: list[ExperimentResult] = [] + self.created_at = datetime.now().isoformat() + + def add_result(self, result: ExperimentResult) -> None: + """ + Add an experiment result. 
+ + Args: + result: ExperimentResult to add + """ + self.results.append(result) + + def create_result( + self, + experiment_id: str, + config: MCTSConfig, + mcts_stats: dict[str, Any], + execution_time_ms: float = 0.0, + tree_depth: int = 0, + tree_node_count: int = 0, + metadata: dict[str, Any] | None = None, + ) -> ExperimentResult: + """ + Create and add an experiment result from MCTS statistics. + + Args: + experiment_id: Unique ID for this experiment + config: MCTS configuration used + mcts_stats: Statistics dict from MCTSEngine.search() + execution_time_ms: Execution time in milliseconds + tree_depth: Depth of MCTS tree + tree_node_count: Total nodes in tree + metadata: Optional additional metadata + + Returns: + Created ExperimentResult + """ + # Calculate branching factor + branching_factor = 0.0 + if tree_node_count > 1 and tree_depth > 0: + branching_factor = (tree_node_count - 1) / tree_depth + + result = ExperimentResult( + experiment_id=experiment_id, + config=config.to_dict(), + seed=config.seed, + best_action=mcts_stats.get("best_action"), + best_action_value=mcts_stats.get("best_action_value", 0.0), + best_action_visits=mcts_stats.get("best_action_visits", 0), + root_visits=mcts_stats.get("root_visits", 0), + total_iterations=mcts_stats.get("iterations", 0), + total_simulations=mcts_stats.get("total_simulations", 0), + execution_time_ms=execution_time_ms, + cache_hits=mcts_stats.get("cache_hits", 0), + cache_misses=mcts_stats.get("cache_misses", 0), + cache_hit_rate=mcts_stats.get("cache_hit_rate", 0.0), + tree_depth=tree_depth, + tree_node_count=tree_node_count, + branching_factor=branching_factor, + action_stats=mcts_stats.get("action_stats", {}), + metadata=metadata or {}, + ) + + self.add_result(result) + return result + + def get_summary_statistics(self) -> dict[str, Any]: + """ + Compute summary statistics across all experiments. 
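+
+        Reports mean/std/min/max/median for best-action value and visits,
+        execution time, cache hit rate, tree depth, and node count, plus how
+        consistently the same best action was chosen across runs.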
+ + Returns: + Dictionary of summary statistics + """ + if not self.results: + return {"error": "No results to analyze"} + + # Extract metrics + best_values = [r.best_action_value for r in self.results] + best_visits = [r.best_action_visits for r in self.results] + exec_times = [r.execution_time_ms for r in self.results] + cache_rates = [r.cache_hit_rate for r in self.results] + tree_depths = [r.tree_depth for r in self.results] + node_counts = [r.tree_node_count for r in self.results] + + def compute_stats(values: list[float]) -> dict[str, float]: + """Compute basic statistics.""" + if not values: + return {} + return { + "mean": statistics.mean(values), + "std": statistics.stdev(values) if len(values) > 1 else 0.0, + "min": min(values), + "max": max(values), + "median": statistics.median(values), + } + + # Best action consistency + best_actions = [r.best_action for r in self.results] + action_counts = {} + for action in best_actions: + action_counts[action] = action_counts.get(action, 0) + 1 + most_common_action = max(action_counts.items(), key=lambda x: x[1]) + consistency_rate = most_common_action[1] / len(best_actions) + + return { + "num_experiments": len(self.results), + "best_action_value_stats": compute_stats(best_values), + "best_action_visits_stats": compute_stats(best_visits), + "execution_time_ms_stats": compute_stats(exec_times), + "cache_hit_rate_stats": compute_stats(cache_rates), + "tree_depth_stats": compute_stats(tree_depths), + "tree_node_count_stats": compute_stats(node_counts), + "action_consistency": { + "most_common_action": most_common_action[0], + "consistency_rate": consistency_rate, + "action_distribution": action_counts, + }, + } + + def compare_configs( + self, + config_names: list[str] | None = None, + ) -> dict[str, dict[str, Any]]: + """ + Compare performance across different configurations. + + Args: + config_names: Specific config names to compare (all if None) + + Returns: + Dictionary mapping config names to their statistics + """ + # Group results by configuration name + grouped: dict[str, list[ExperimentResult]] = {} + + for result in self.results: + if result.config is None: + continue + + config_name = result.config.get("name", "unnamed") + + if config_names and config_name not in config_names: + continue + + if config_name not in grouped: + grouped[config_name] = [] + grouped[config_name].append(result) + + # Compute statistics for each group + comparison = {} + for name, results in grouped.items(): + values = [r.best_action_value for r in results] + times = [r.execution_time_ms for r in results] + visits = [r.best_action_visits for r in results] + + comparison[name] = { + "num_runs": len(results), + "avg_value": statistics.mean(values) if values else 0.0, + "std_value": statistics.stdev(values) if len(values) > 1 else 0.0, + "avg_time_ms": statistics.mean(times) if times else 0.0, + "avg_visits": statistics.mean(visits) if visits else 0.0, + "value_per_ms": ( + statistics.mean(values) / statistics.mean(times) if times and statistics.mean(times) > 0 else 0.0 + ), + } + + return comparison + + def analyze_seed_consistency(self, seed: int) -> dict[str, Any]: + """ + Analyze consistency of results for a specific seed. 
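+
+        With a deterministic engine, repeated runs with the same seed should
+        yield identical best actions, values, and visit counts; nonzero
+        variance reported here points to a source of nondeterminism.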
+ + Args: + seed: Seed value to analyze + + Returns: + Analysis of determinism for this seed + """ + seed_results = [r for r in self.results if r.seed == seed] + + if not seed_results: + return {"error": f"No results found for seed {seed}"} + + # Check if all results are identical + best_actions = [r.best_action for r in seed_results] + best_values = [r.best_action_value for r in seed_results] + best_visits = [r.best_action_visits for r in seed_results] + + is_deterministic = len(set(best_actions)) == 1 and len(set(best_values)) == 1 and len(set(best_visits)) == 1 + + return { + "seed": seed, + "num_runs": len(seed_results), + "is_deterministic": is_deterministic, + "unique_actions": list(set(best_actions)), + "value_variance": statistics.variance(best_values) if len(best_values) > 1 else 0.0, + "visits_variance": statistics.variance(best_visits) if len(best_visits) > 1 else 0.0, + } + + def export_to_json(self, file_path: str) -> None: + """ + Export all results to JSON file. + + Args: + file_path: Path to output file + """ + data = { + "name": self.name, + "created_at": self.created_at, + "num_experiments": len(self.results), + "results": [r.to_dict() for r in self.results], + "summary": self.get_summary_statistics(), + } + + path = Path(file_path) + path.parent.mkdir(parents=True, exist_ok=True) + + with open(path, "w") as f: + json.dump(data, f, indent=2) + + def export_to_csv(self, file_path: str) -> None: + """ + Export results to CSV file for spreadsheet analysis. + + Args: + file_path: Path to output file + """ + if not self.results: + return + + path = Path(file_path) + path.parent.mkdir(parents=True, exist_ok=True) + + # Define CSV columns + fieldnames = [ + "experiment_id", + "timestamp", + "seed", + "config_name", + "num_iterations", + "exploration_weight", + "best_action", + "best_action_value", + "best_action_visits", + "root_visits", + "total_simulations", + "execution_time_ms", + "cache_hit_rate", + "tree_depth", + "tree_node_count", + "branching_factor", + ] + + with open(path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + + for result in self.results: + row = { + "experiment_id": result.experiment_id, + "timestamp": result.timestamp, + "seed": result.seed, + "config_name": (result.config.get("name", "unnamed") if result.config else "unknown"), + "num_iterations": (result.config.get("num_iterations", 0) if result.config else 0), + "exploration_weight": (result.config.get("exploration_weight", 0) if result.config else 0), + "best_action": result.best_action, + "best_action_value": result.best_action_value, + "best_action_visits": result.best_action_visits, + "root_visits": result.root_visits, + "total_simulations": result.total_simulations, + "execution_time_ms": result.execution_time_ms, + "cache_hit_rate": result.cache_hit_rate, + "tree_depth": result.tree_depth, + "tree_node_count": result.tree_node_count, + "branching_factor": result.branching_factor, + } + writer.writerow(row) + + @classmethod + def load_from_json(cls, file_path: str) -> ExperimentTracker: + """ + Load experiment tracker from JSON file. 
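+
+        Example (illustrative round-trip; the file path is hypothetical):
+
+            tracker.export_to_json("results/run1.json")
+            restored = ExperimentTracker.load_from_json("results/run1.json")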
+ + Args: + file_path: Path to JSON file + + Returns: + Loaded ExperimentTracker + """ + with open(file_path) as f: + data = json.load(f) + + tracker = cls(name=data.get("name", "loaded_experiments")) + tracker.created_at = data.get("created_at", tracker.created_at) + + for result_data in data.get("results", []): + tracker.results.append(ExperimentResult.from_dict(result_data)) + + return tracker + + def clear(self) -> None: + """Clear all results.""" + self.results.clear() + + def __len__(self) -> int: + return len(self.results) + + def __repr__(self) -> str: + return f"ExperimentTracker(name={self.name!r}, num_results={len(self.results)})" + + +def run_determinism_test( + engine_factory, + config: MCTSConfig, + num_runs: int = 3, +) -> tuple[bool, dict[str, Any]]: + """ + Test that MCTS produces deterministic results with same seed. + + Args: + engine_factory: Factory function to create MCTSEngine + config: Configuration to test + num_runs: Number of runs to compare + + Returns: + Tuple of (is_deterministic, analysis_dict) + """ + ExperimentTracker(name="determinism_test") + + # This is a stub - actual implementation would run the engine + # Results would be compared to verify determinism + + analysis = { + "config": config.to_dict(), + "num_runs": num_runs, + "is_deterministic": True, # Would be computed from actual runs + "message": "Determinism test requires actual engine execution", + } + + return True, analysis diff --git a/src/framework/mcts/neural_mcts.py b/src/framework/mcts/neural_mcts.py new file mode 100644 index 0000000000000000000000000000000000000000..ecf340667bc63fe0231e24bcda8562ed785b160f --- /dev/null +++ b/src/framework/mcts/neural_mcts.py @@ -0,0 +1,624 @@ +""" +Neural-Guided Monte Carlo Tree Search (MCTS). + +Implements AlphaZero-style MCTS with: +- Policy and value network guidance +- PUCT (Predictor + UCT) selection +- Dirichlet noise for exploration +- Virtual loss for parallel search +- Temperature-based action selection + +Based on: +- "Mastering the Game of Go with Deep Neural Networks and Tree Search" (AlphaGo) +- "Mastering Chess and Shogi by Self-Play with a General RL Algorithm" (AlphaZero) +""" + +from __future__ import annotations + +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +import torch.nn as nn + +from ...training.system_config import MCTSConfig + + +@dataclass +class GameState: + """ + Abstract game/problem state interface. + + Users should subclass this for their specific domain. 
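+
+    Example (illustrative sketch only; TicTacToeState and its board field are
+    hypothetical, and every abstract method below must be implemented):
+
+        class TicTacToeState(GameState):
+            def get_legal_actions(self):
+                return [f"{r},{c}" for r in range(3) for c in range(3)
+                        if self.board[r][c] is None]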
+ """ + + def get_legal_actions(self) -> list[Any]: + """Return list of legal actions from this state.""" + raise NotImplementedError + + def apply_action(self, action: Any) -> GameState: + """Apply action and return new state.""" + raise NotImplementedError + + def is_terminal(self) -> bool: + """Check if this is a terminal state.""" + raise NotImplementedError + + def get_reward(self, player: int = 1) -> float: + """Get reward for the player (1 or -1).""" + raise NotImplementedError + + def to_tensor(self) -> torch.Tensor: + """Convert state to tensor for neural network input.""" + raise NotImplementedError + + def get_canonical_form(self, player: int) -> GameState: # noqa: ARG002 + """Get state from perspective of given player.""" + return self + + def get_hash(self) -> str: + """Get unique hash for this state (for caching).""" + raise NotImplementedError + + def action_to_index(self, action: Any) -> int: + """ + Map action to its index in the neural network's action space. + + This method should return the index corresponding to the action + in the network's policy output vector. + + Default implementation uses string-based mapping for Tic-Tac-Toe style + actions (e.g., "0,0" -> 0, "0,1" -> 1, etc.). Override this method + for custom action mappings. + + Args: + action: The action to map + + Returns: + Index in the action space (0 to action_size-1) + """ + # Default implementation for grid-based actions like "row,col" + if isinstance(action, str) and "," in action: + row, col = map(int, action.split(",")) + # Assume 3x3 grid by default - override for different sizes + return row * 3 + col + # For other action types, assume they are already indices + return int(action) + + +class NeuralMCTSNode: + """ + MCTS node with neural network guidance. + + Stores statistics for PUCT selection and backpropagation. + """ + + def __init__( + self, + state: GameState, + parent: NeuralMCTSNode | None = None, + action: Any | None = None, + prior: float = 0.0, + ): + self.state = state + self.parent = parent + self.action = action # Action that led to this node + self.prior = prior # Prior probability from policy network + + # Statistics + self.visit_count: int = 0 + self.value_sum: float = 0.0 + self.virtual_loss: float = 0.0 + + # Children: action -> NeuralMCTSNode + self.children: dict[Any, NeuralMCTSNode] = {} + + # Caching + self.is_expanded: bool = False + self.is_terminal: bool = state.is_terminal() + + @property + def value(self) -> float: + """Average value (Q-value) of this node.""" + if self.visit_count == 0: + return 0.0 + return self.value_sum / self.visit_count + + def expand( + self, + policy_probs: np.ndarray, + valid_actions: list[Any], + ): + """ + Expand node by creating children for all legal actions. + + Args: + policy_probs: Prior probabilities from policy network + valid_actions: List of legal actions + """ + self.is_expanded = True + + for action, prior in zip(valid_actions, policy_probs, strict=True): + if action not in self.children: + next_state = self.state.apply_action(action) + self.children[action] = NeuralMCTSNode( + state=next_state, + parent=self, + action=action, + prior=prior, + ) + + def select_child(self, c_puct: float) -> tuple[Any, NeuralMCTSNode]: + """ + Select best child using PUCT algorithm. 
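+
+        The exploration denominator also includes any outstanding virtual
+        loss, which discourages parallel simulations from repeatedly
+        selecting the same child.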
+ + PUCT = Q(s,a) + c_puct * P(s,a) * sqrt(N(s)) / (1 + N(s,a)) + + Args: + c_puct: Exploration constant + + Returns: + (action, child_node) tuple + """ + best_score = -float("inf") + best_action = None + best_child = None + + # Precompute sqrt term for efficiency + sqrt_parent_visits = math.sqrt(self.visit_count) + + for action, child in self.children.items(): + # Q-value (average value) + q_value = child.value + + # U-value (exploration bonus) + u_value = c_puct * child.prior * sqrt_parent_visits / (1 + child.visit_count + child.virtual_loss) + + # PUCT score + puct_score = q_value + u_value + + if puct_score > best_score: + best_score = puct_score + best_action = action + best_child = child + + return best_action, best_child + + def add_virtual_loss(self, virtual_loss: float): + """Add virtual loss for parallel search.""" + self.virtual_loss += virtual_loss + + def revert_virtual_loss(self, virtual_loss: float): + """Remove virtual loss after search completes.""" + self.virtual_loss -= virtual_loss + + def update(self, value: float): + """Update node statistics with search result.""" + self.visit_count += 1 + self.value_sum += value + + def get_action_probs(self, temperature: float = 1.0) -> dict[Any, float]: + """ + Get action selection probabilities based on visit counts. + + Args: + temperature: Temperature parameter + - temperature -> 0: argmax (deterministic) + - temperature = 1: proportional to visits + - temperature -> inf: uniform + + Returns: + Dictionary mapping actions to probabilities + """ + if not self.children: + return {} + + if temperature == 0: + # Deterministic: select most visited + visits = {action: child.visit_count for action, child in self.children.items()} + max_visits = max(visits.values()) + best_actions = [a for a, v in visits.items() if v == max_visits] + + # Uniform over best actions + prob = 1.0 / len(best_actions) + return {a: (prob if a in best_actions else 0.0) for a in self.children} + + # Temperature-scaled visits + visits = np.array([child.visit_count for child in self.children.values()]) + actions = list(self.children.keys()) + + if temperature != 1.0: + visits = visits ** (1.0 / temperature) + + # Normalize to probabilities + probs = visits / visits.sum() + + return dict(zip(actions, probs, strict=True)) + + +class NeuralMCTS: + """ + Neural-guided MCTS for decision making. + + Combines tree search with neural network evaluation + using the AlphaZero algorithm. + """ + + def __init__( + self, + policy_value_network: nn.Module, + config: MCTSConfig, + device: str = "cpu", + ): + """ + Initialize neural MCTS. + + Args: + policy_value_network: Network that outputs (policy, value) + config: MCTS configuration + device: Device for neural network + """ + self.network = policy_value_network + self.config = config + self.device = device + + # Caching for network evaluations + self.cache: dict[str, tuple[np.ndarray, float]] = {} + self.cache_hits = 0 + self.cache_misses = 0 + + def add_dirichlet_noise( + self, + policy_probs: np.ndarray, + epsilon: float | None = None, + alpha: float | None = None, + ) -> np.ndarray: + """ + Add Dirichlet noise to policy for exploration (at root only). 
+ + Policy' = (1 - epsilon) * Policy + epsilon * Noise + + Args: + policy_probs: Original policy probabilities + epsilon: Mixing parameter (defaults to config) + alpha: Dirichlet concentration parameter (defaults to config) + + Returns: + Noised policy probabilities + """ + epsilon = epsilon or self.config.dirichlet_epsilon + alpha = alpha or self.config.dirichlet_alpha + + noise = np.random.dirichlet([alpha] * len(policy_probs)) + return (1 - epsilon) * policy_probs + epsilon * noise + + @torch.no_grad() + async def evaluate_state(self, state: GameState, add_noise: bool = False) -> tuple[np.ndarray, float]: + """ + Evaluate state using neural network. + + Args: + state: Game state to evaluate + add_noise: Whether to add Dirichlet noise (for root exploration) + + Returns: + (policy_probs, value) tuple + """ + # Check cache + state_hash = state.get_hash() + if not add_noise and state_hash in self.cache: + self.cache_hits += 1 + return self.cache[state_hash] + + self.cache_misses += 1 + + # Get legal actions + legal_actions = state.get_legal_actions() + if not legal_actions: + return np.array([]), 0.0 + + # Convert state to tensor + state_tensor = state.to_tensor().unsqueeze(0).to(self.device) + + # Network forward pass + policy_logits, value = self.network(state_tensor) + + # Convert to numpy (detach to remove gradients) + policy_logits = policy_logits.squeeze(0).detach().cpu().numpy() + value = value.item() + + # Proper action masking: Map legal actions to their indices in the action space + # Create a mask for legal actions + action_mask = np.full_like(policy_logits, -np.inf) # Mask all actions initially + action_indices = [] + + # Map legal actions to their network output indices + for action in legal_actions: + try: + action_idx = state.action_to_index(action) + if 0 <= action_idx < len(policy_logits): + action_mask[action_idx] = 0 # Unmask legal actions + action_indices.append(action_idx) + except (ValueError, IndexError, AttributeError) as e: + # Fallback: if action_to_index fails, use sequential mapping + print(f"Warning: action_to_index failed for action {action}: {e}") + action_indices = list(range(len(legal_actions))) + action_mask = np.full_like(policy_logits, -np.inf) + action_mask[action_indices] = 0 + break + + # Apply mask before softmax for numerical stability + masked_logits = policy_logits + action_mask + + # Compute softmax over legal actions only + exp_logits = np.exp(masked_logits - np.max(masked_logits)) # Subtract max for stability + policy_probs_full = exp_logits / exp_logits.sum() + + # Extract probabilities for legal actions in order + policy_probs = policy_probs_full[action_indices] + + # Normalize to ensure probabilities sum to 1 (handle numerical errors) + if policy_probs.sum() > 0: + policy_probs = policy_probs / policy_probs.sum() + else: + # Fallback: uniform distribution over legal actions + policy_probs = np.ones(len(legal_actions)) / len(legal_actions) + + # Add Dirichlet noise if requested (root exploration) + if add_noise: + policy_probs = self.add_dirichlet_noise(policy_probs) + + # Cache result (without noise) + if not add_noise: + self.cache[state_hash] = (policy_probs, value) + + return policy_probs, value + + async def search( + self, + root_state: GameState, + num_simulations: int | None = None, + temperature: float = 1.0, + add_root_noise: bool = True, + ) -> tuple[dict[Any, float], NeuralMCTSNode]: + """ + Run MCTS search from root state. 
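+
+        Example (illustrative; mcts is a NeuralMCTS instance and state is a
+        concrete GameState subclass instance):
+
+            action_probs, root = await mcts.search(state, temperature=1.0)
+            action = mcts.select_action(action_probs, deterministic=True)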
+ + Args: + root_state: Initial state + num_simulations: Number of MCTS simulations + temperature: Temperature for action selection + add_root_noise: Whether to add Dirichlet noise to root + + Returns: + (action_probs, root_node) tuple + """ + num_simulations = num_simulations or self.config.num_simulations + + # Create root node + root = NeuralMCTSNode(state=root_state) + + # Expand root + policy_probs, _ = await self.evaluate_state(root_state, add_noise=add_root_noise) + legal_actions = root_state.get_legal_actions() + root.expand(policy_probs, legal_actions) + + # Run simulations + for _ in range(num_simulations): + await self._simulate(root) + + # Get action probabilities + action_probs = root.get_action_probs(temperature) + + return action_probs, root + + async def _simulate(self, node: NeuralMCTSNode) -> float: + """ + Run single MCTS simulation (select, expand, evaluate, backpropagate). + + Args: + node: Root node for this simulation + + Returns: + Value from this simulation + """ + path: list[NeuralMCTSNode] = [] + + # Selection: traverse tree using PUCT + current = node + while current.is_expanded and not current.is_terminal: + # Add virtual loss for parallel search + current.add_virtual_loss(self.config.virtual_loss) + path.append(current) + + # Select best child + _, current = current.select_child(self.config.c_puct) + + # Add leaf to path + path.append(current) + current.add_virtual_loss(self.config.virtual_loss) + + # Evaluate leaf node + if current.is_terminal: + # Terminal node: use game result + value = current.state.get_reward() + else: + # Non-terminal: expand and evaluate with network + policy_probs, value = await self.evaluate_state(current.state, add_noise=False) + + if not current.is_expanded: + legal_actions = current.state.get_legal_actions() + current.expand(policy_probs, legal_actions) + + # Backpropagate + for node_in_path in reversed(path): + node_in_path.update(value) + node_in_path.revert_virtual_loss(self.config.virtual_loss) + + # Flip value for opponent + value = -value + + return value + + def select_action( + self, + action_probs: dict[Any, float], + temperature: float = 1.0, + deterministic: bool = False, + ) -> Any: + """ + Select action from probability distribution. 
+ + Args: + action_probs: Action probability dictionary + temperature: Temperature (unused if deterministic=True) + deterministic: If True, select action with highest probability + + Returns: + Selected action + """ + if not action_probs: + return None + + actions = list(action_probs.keys()) + probs = list(action_probs.values()) + + if deterministic or temperature == 0: + return actions[np.argmax(probs)] + + # Sample from distribution + return np.random.choice(actions, p=probs) + + def clear_cache(self): + """Clear the evaluation cache.""" + self.cache.clear() + self.cache_hits = 0 + self.cache_misses = 0 + + def get_cache_stats(self) -> dict: + """Get cache performance statistics.""" + total = self.cache_hits + self.cache_misses + hit_rate = self.cache_hits / total if total > 0 else 0.0 + + return { + "cache_size": len(self.cache), + "cache_hits": self.cache_hits, + "cache_misses": self.cache_misses, + "hit_rate": hit_rate, + } + + +# Training data collection +@dataclass +class MCTSExample: + """Training example from MCTS self-play.""" + + state: torch.Tensor # State representation + policy_target: np.ndarray # Target policy (visit counts) + value_target: float # Target value (game outcome) + player: int # Player to move (1 or -1) + + +class SelfPlayCollector: + """ + Collect training data from self-play games. + + Uses MCTS to generate high-quality training examples. + """ + + def __init__( + self, + mcts: NeuralMCTS, + config: MCTSConfig, + ): + self.mcts = mcts + self.config = config + + async def play_game( + self, + initial_state: GameState, + temperature_threshold: int | None = None, + ) -> list[MCTSExample]: + """ + Play a single self-play game. + + Args: + initial_state: Starting game state + temperature_threshold: Move number to switch to greedy play + + Returns: + List of training examples from the game + """ + temperature_threshold = temperature_threshold or self.config.temperature_threshold + + examples: list[MCTSExample] = [] + state = initial_state + player = 1 # Current player (1 or -1) + move_count = 0 + + while not state.is_terminal(): + # Determine temperature + temperature = ( + self.config.temperature_init if move_count < temperature_threshold else self.config.temperature_final + ) + + # Run MCTS + action_probs, root = await self.mcts.search(state, temperature=temperature, add_root_noise=True) + + # Store training example + # Convert action probs to array for all actions + probs = np.array(list(action_probs.values())) + + examples.append( + MCTSExample( + state=state.to_tensor(), + policy_target=probs, + value_target=0.0, # Will be filled with game outcome + player=player, + ) + ) + + # Select and apply action + action = self.mcts.select_action(action_probs, temperature=temperature) + state = state.apply_action(action) + + # Switch player + player = -player + move_count += 1 + + # Get game outcome + outcome = state.get_reward() + + # Assign values to examples + for example in examples: + # Value is from perspective of the player who made the move + example.value_target = outcome if example.player == 1 else -outcome + + return examples + + async def generate_batch(self, num_games: int, initial_state_fn: Callable[[], GameState]) -> list[MCTSExample]: + """ + Generate a batch of training examples from multiple games. 
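+
+        Example (illustrative; TicTacToeState is a hypothetical GameState
+        subclass with a no-argument constructor, and mcts/config are existing
+        NeuralMCTS and MCTSConfig objects):
+
+            collector = SelfPlayCollector(mcts, config)
+            examples = await collector.generate_batch(
+                num_games=10, initial_state_fn=TicTacToeState
+            )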
+ + Args: + num_games: Number of games to play + initial_state_fn: Function that returns initial game state + + Returns: + Combined list of training examples + """ + all_examples = [] + + for _ in range(num_games): + initial_state = initial_state_fn() + examples = await self.play_game(initial_state) + all_examples.extend(examples) + + # Clear cache periodically + if len(self.mcts.cache) > 10000: + self.mcts.clear_cache() + + return all_examples diff --git a/src/framework/mcts/policies.py b/src/framework/mcts/policies.py new file mode 100644 index 0000000000000000000000000000000000000000..59c23239dc6799d72aa2910355840b87f3705176 --- /dev/null +++ b/src/framework/mcts/policies.py @@ -0,0 +1,397 @@ +""" +MCTS Policies Module - Selection, rollout, and evaluation policies. + +Provides: +- UCB1 with configurable exploration weight +- Rollout heuristics (random, greedy, hybrid) +- Action selection policies (max visits, max value, robust child) +- Progressive widening parameters +""" + +from __future__ import annotations + +import math +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from enum import Enum +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from .core import MCTSState + + +def ucb1( + value_sum: float, + visits: int, + parent_visits: int, + c: float = 1.414, +) -> float: + """ + Upper Confidence Bound 1 (UCB1) formula for tree selection. + + Formula: Q(s,a) + c * sqrt(N(s)) / sqrt(N(s,a)) + + Args: + value_sum: Total accumulated value for the node + visits: Number of visits to the node + parent_visits: Number of visits to the parent node + c: Exploration weight constant (default sqrt(2)) + + Returns: + UCB1 score for node selection + """ + if visits == 0: + return float("inf") + + exploitation = value_sum / visits + exploration = c * ((parent_visits) ** 0.5 / (visits) ** 0.5) + + return exploitation + exploration + + +def ucb1_tuned( + value_sum: float, + value_squared_sum: float, + visits: int, + parent_visits: int, + c: float = 1.0, +) -> float: + """ + UCB1-Tuned variant with variance estimate. + + Provides tighter bounds by considering value variance. 
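+
+    Formula: Q(s,a) + c * sqrt(ln(N(s)) / N(s,a) * min(1/4, V(s,a)))
+    where V(s,a) = Var(s,a) + sqrt(2 * ln(N(s)) / N(s,a))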
+ + Args: + value_sum: Total accumulated value + value_squared_sum: Sum of squared values (for variance) + visits: Number of visits + parent_visits: Parent visit count + c: Exploration constant + + Returns: + UCB1-Tuned score + """ + if visits == 0: + return float("inf") + + mean_value = value_sum / visits + variance = value_squared_sum / visits - mean_value**2 + variance = max(0, variance) # Ensure non-negative + + # Variance bound term + ln_parent = math.log(parent_visits) + variance_bound = variance + math.sqrt(2 * ln_parent / visits) + min_bound = min(0.25, variance_bound) + + exploitation = mean_value + exploration = c * math.sqrt(ln_parent / visits * min_bound) + + return exploitation + exploration + + +class SelectionPolicy(Enum): + """Policy for selecting the final action after MCTS search.""" + + MAX_VISITS = "max_visits" + """Select action with most visits (most robust).""" + + MAX_VALUE = "max_value" + """Select action with highest average value (greedy).""" + + ROBUST_CHILD = "robust_child" + """Select action balancing visits and value.""" + + SECURE_CHILD = "secure_child" + """Select action with lowest lower confidence bound.""" + + +class RolloutPolicy(ABC): + """Abstract base class for rollout/simulation policies.""" + + @abstractmethod + async def evaluate( + self, + state: MCTSState, + rng: np.random.Generator, + max_depth: int = 10, + ) -> float: + """ + Evaluate a state through rollout simulation. + + Args: + state: State to evaluate + rng: Seeded random number generator + max_depth: Maximum rollout depth + + Returns: + Estimated value in [0, 1] range + """ + pass + + +class RandomRolloutPolicy(RolloutPolicy): + """Random rollout policy - uniform random evaluation.""" + + def __init__(self, base_value: float = 0.5, noise_scale: float = 0.3): + """ + Initialize random rollout policy. + + Args: + base_value: Base value for evaluations + noise_scale: Scale of random noise + """ + self.base_value = base_value + self.noise_scale = noise_scale + + async def evaluate( + self, + _state: MCTSState, + rng: np.random.Generator, + _max_depth: int = 10, + ) -> float: + """Generate random evaluation with noise.""" + noise = rng.uniform(-self.noise_scale, self.noise_scale) + value = self.base_value + noise + return max(0.0, min(1.0, value)) + + +class GreedyRolloutPolicy(RolloutPolicy): + """Greedy rollout policy using domain heuristics.""" + + def __init__( + self, + heuristic_fn: Callable[[MCTSState], float], + noise_scale: float = 0.05, + ): + """ + Initialize greedy rollout policy. + + Args: + heuristic_fn: Function to evaluate state heuristically + noise_scale: Small noise for tie-breaking + """ + self.heuristic_fn = heuristic_fn + self.noise_scale = noise_scale + + async def evaluate( + self, + state: MCTSState, + rng: np.random.Generator, + _max_depth: int = 10, + ) -> float: + """Evaluate using heuristic with small noise.""" + base_value = self.heuristic_fn(state) + noise = rng.uniform(-self.noise_scale, self.noise_scale) + value = base_value + noise + return max(0.0, min(1.0, value)) + + +class HybridRolloutPolicy(RolloutPolicy): + """Hybrid policy combining random and heuristic evaluation.""" + + def __init__( + self, + heuristic_fn: Callable[[MCTSState], float] | None = None, + heuristic_weight: float = 0.7, + random_weight: float = 0.3, + base_random_value: float = 0.5, + noise_scale: float = 0.2, + ): + """ + Initialize hybrid rollout policy. 
+ + Args: + heuristic_fn: Optional heuristic evaluation function + heuristic_weight: Weight for heuristic component + random_weight: Weight for random component + base_random_value: Base value for random component + noise_scale: Noise scale for random component + """ + self.heuristic_fn = heuristic_fn + self.heuristic_weight = heuristic_weight + self.random_weight = random_weight + self.base_random_value = base_random_value + self.noise_scale = noise_scale + + # Normalize weights + total_weight = heuristic_weight + random_weight + if total_weight > 0: + self.heuristic_weight /= total_weight + self.random_weight /= total_weight + + async def evaluate( + self, + state: MCTSState, + rng: np.random.Generator, + _max_depth: int = 10, + ) -> float: + """Combine heuristic and random evaluation.""" + # Random component + random_noise = rng.uniform(-self.noise_scale, self.noise_scale) + random_value = self.base_random_value + random_noise + + # Heuristic component + heuristic_value = self.heuristic_fn(state) if self.heuristic_fn is not None else self.base_random_value + + # Combine + value = self.heuristic_weight * heuristic_value + self.random_weight * random_value + + return max(0.0, min(1.0, value)) + + +class LLMRolloutPolicy(RolloutPolicy): + """Rollout policy that uses an LLM for state evaluation.""" + + def __init__( + self, + evaluate_fn: Callable[[MCTSState], Awaitable[float]], + cache_results: bool = True, + ): + """ + Initialize LLM rollout policy. + + Args: + evaluate_fn: Async function to evaluate state with LLM + cache_results: Whether to cache evaluation results + """ + self.evaluate_fn = evaluate_fn + self.cache_results = cache_results + self._cache: dict = {} + + async def evaluate( + self, + state: MCTSState, + _rng: np.random.Generator, + _max_depth: int = 10, + ) -> float: + """Evaluate state using LLM.""" + state_key = state.to_hash_key() + + if self.cache_results and state_key in self._cache: + return self._cache[state_key] + + value = await self.evaluate_fn(state) + value = max(0.0, min(1.0, value)) + + if self.cache_results: + self._cache[state_key] = value + + return value + + +class ProgressiveWideningConfig: + """Configuration for progressive widening in MCTS.""" + + def __init__( + self, + k: float = 1.0, + alpha: float = 0.5, + ): + """ + Configure progressive widening parameters. + + Progressive widening expands when: visits > k * num_children^alpha + + Args: + k: Coefficient controlling expansion threshold + alpha: Exponent controlling growth rate + + Common configurations: + - k=1.0, alpha=0.5: Moderate widening (default) + - k=2.0, alpha=0.5: Conservative (fewer expansions) + - k=0.5, alpha=0.5: Aggressive (more expansions) + - k=1.0, alpha=0.3: Very aggressive + - k=1.0, alpha=0.7: Very conservative + """ + if k <= 0: + raise ValueError("k must be positive") + if not 0 < alpha < 1: + raise ValueError("alpha must be in (0, 1)") + + self.k = k + self.alpha = alpha + + def should_expand(self, visits: int, num_children: int) -> bool: + """ + Check if expansion should occur. + + Args: + visits: Number of visits to node + num_children: Current number of children + + Returns: + True if should expand, False otherwise + """ + threshold = self.k * (num_children**self.alpha) + return visits > threshold + + def min_visits_for_expansion(self, num_children: int) -> int: + """ + Calculate minimum visits needed to expand to next child. 
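+
+        Example: with k=1.0 and alpha=0.5, a node with 5 children has a
+        threshold of sqrt(5) ~= 2.24, so about 3 visits are required before
+        the next expansion.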
+ + Args: + num_children: Current number of children + + Returns: + Minimum visit count for expansion + """ + threshold = self.k * (num_children**self.alpha) + return int(math.ceil(threshold)) + + def __repr__(self) -> str: + return f"ProgressiveWideningConfig(k={self.k}, alpha={self.alpha})" + + +def compute_action_probabilities( + children_stats: list[dict], + temperature: float = 1.0, +) -> list[float]: + """ + Compute action probabilities from visit counts using softmax. + + Args: + children_stats: List of dicts with 'visits' key + temperature: Temperature parameter (lower = more deterministic) + + Returns: + List of probabilities for each action + """ + if not children_stats: + return [] + + visits = np.array([c["visits"] for c in children_stats], dtype=float) + + if temperature == 0: + # Deterministic: assign 1.0 to max, 0 to others + probs = np.zeros_like(visits) + probs[np.argmax(visits)] = 1.0 + return probs.tolist() + + # Apply temperature + scaled_visits = visits ** (1.0 / temperature) + probs = scaled_visits / scaled_visits.sum() + return probs.tolist() + + +def select_action_stochastic( + children_stats: list[dict], + rng: np.random.Generator, + temperature: float = 1.0, +) -> int: + """ + Stochastically select action based on visit counts. + + Args: + children_stats: List of child statistics + rng: Random number generator + temperature: Temperature for softmax + + Returns: + Index of selected action + """ + probs = compute_action_probabilities(children_stats, temperature) + if not probs: + raise ValueError("No actions to select from") + return rng.choice(len(probs), p=probs) diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/policy_value_net.py b/src/models/policy_value_net.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2567beb7dde1dfd45d0411b80d50ebe1d1b2d0 --- /dev/null +++ b/src/models/policy_value_net.py @@ -0,0 +1,422 @@ +""" +Policy-Value Network using ResNet Architecture. + +Implements the dual-head neural network used in AlphaZero: +- Policy Head: Outputs action probabilities +- Value Head: Outputs state value estimation + +Based on: +- "Mastering Chess and Shogi by Self-Play with a General RL Algorithm" (AlphaZero) +- Deep Residual Learning for Image Recognition (ResNet) +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..training.system_config import NeuralNetworkConfig + + +class ResidualBlock(nn.Module): + """ + Residual block with batch normalization and skip connections. 
+ + Architecture: + Conv -> BN -> ReLU -> Conv -> BN -> Add -> ReLU + """ + + def __init__(self, channels: int, use_batch_norm: bool = True): + super().__init__() + + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(channels) if use_batch_norm else nn.Identity() + + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(channels) if use_batch_norm else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply residual block transformation.""" + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = F.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + # Skip connection + out = out + residual + out = F.relu(out) + + return out + + +class PolicyHead(nn.Module): + """ + Policy head for outputting action probabilities. + + Architecture: + Conv -> BN -> ReLU -> FC -> LogSoftmax + """ + + def __init__( + self, + input_channels: int, + policy_conv_channels: int, + action_size: int, + board_size: int = 19, + ): + super().__init__() + + self.conv = nn.Conv2d(input_channels, policy_conv_channels, kernel_size=1, bias=False) + self.bn = nn.BatchNorm2d(policy_conv_channels) + + # Assuming square board + fc_input_size = policy_conv_channels * board_size * board_size + + self.fc = nn.Linear(fc_input_size, action_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute policy (action probabilities). + + Args: + x: [batch, channels, height, width] + + Returns: + Log probabilities: [batch, action_size] + """ + batch_size = x.size(0) + + out = self.conv(x) + out = self.bn(out) + out = F.relu(out) + + # Flatten spatial dimensions + out = out.view(batch_size, -1) + + # Fully connected layer + out = self.fc(out) + + # Log probabilities for numerical stability + return F.log_softmax(out, dim=1) + + +class ValueHead(nn.Module): + """ + Value head for estimating state value. + + Architecture: + Conv -> BN -> ReLU -> FC -> ReLU -> FC -> Tanh + """ + + def __init__( + self, + input_channels: int, + value_conv_channels: int, + value_fc_hidden: int, + board_size: int = 19, + ): + super().__init__() + + self.conv = nn.Conv2d(input_channels, value_conv_channels, kernel_size=1, bias=False) + self.bn = nn.BatchNorm2d(value_conv_channels) + + # Assuming square board + fc_input_size = value_conv_channels * board_size * board_size + + self.fc1 = nn.Linear(fc_input_size, value_fc_hidden) + self.fc2 = nn.Linear(value_fc_hidden, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute value estimation. + + Args: + x: [batch, channels, height, width] + + Returns: + Value: [batch, 1] in range [-1, 1] + """ + batch_size = x.size(0) + + out = self.conv(x) + out = self.bn(out) + out = F.relu(out) + + # Flatten spatial dimensions + out = out.view(batch_size, -1) + + # Fully connected layers + out = self.fc1(out) + out = F.relu(out) + + out = self.fc2(out) + + # Tanh to bound value in [-1, 1] + return torch.tanh(out) + + +class PolicyValueNetwork(nn.Module): + """ + Combined policy-value network with ResNet backbone. + + This is the core neural network used in AlphaZero-style learning. 
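+
+    Example (illustrative; config is assumed to be a populated
+    NeuralNetworkConfig and states a [batch, channels, height, width] tensor):
+
+        net = create_policy_value_network(config, board_size=3)
+        policy_log_probs, value = net(states)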
+ """ + + def __init__(self, config: NeuralNetworkConfig, board_size: int = 19): + super().__init__() + self.config = config + self.board_size = board_size + + # Initial convolution + self.conv_input = nn.Conv2d( + config.input_channels, + config.num_channels, + kernel_size=3, + padding=1, + bias=False, + ) + self.bn_input = nn.BatchNorm2d(config.num_channels) if config.use_batch_norm else nn.Identity() + + # Residual blocks (shared feature extractor) + self.res_blocks = nn.ModuleList( + [ResidualBlock(config.num_channels, config.use_batch_norm) for _ in range(config.num_res_blocks)] + ) + + # Policy head + self.policy_head = PolicyHead( + input_channels=config.num_channels, + policy_conv_channels=config.policy_conv_channels, + action_size=config.action_size, + board_size=board_size, + ) + + # Value head + self.value_head = ValueHead( + input_channels=config.num_channels, + value_conv_channels=config.value_conv_channels, + value_fc_hidden=config.value_fc_hidden, + board_size=board_size, + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass through the network. + + Args: + x: Input state [batch, channels, height, width] + + Returns: + (policy_logits, value) tuple + - policy_logits: [batch, action_size] log probabilities + - value: [batch, 1] state value in [-1, 1] + """ + # Initial convolution + out = self.conv_input(x) + out = self.bn_input(out) + out = F.relu(out) + + # Residual blocks + for res_block in self.res_blocks: + out = res_block(out) + + # Split into policy and value heads + policy = self.policy_head(out) + value = self.value_head(out) + + return policy, value + + def predict(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Inference mode prediction. + + Args: + state: Input state tensor + + Returns: + (policy_probs, value) tuple with probabilities (not log) + """ + with torch.no_grad(): + policy_log_probs, value = self.forward(state) + policy_probs = torch.exp(policy_log_probs) + return policy_probs, value + + def get_parameter_count(self) -> int: + """Return total number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + +class AlphaZeroLoss(nn.Module): + """ + Combined loss function for AlphaZero training. + + Loss = (z - v)^2 - π^T log(p) + c||θ||^2 + + Where: + - z: actual game outcome + - v: value prediction + - π: MCTS visit count distribution + - p: policy prediction + - c: L2 regularization coefficient + """ + + def __init__(self, value_loss_weight: float = 1.0): + super().__init__() + self.value_loss_weight = value_loss_weight + + def forward( + self, + policy_logits: torch.Tensor, + value: torch.Tensor, + target_policy: torch.Tensor, + target_value: torch.Tensor, + ) -> tuple[torch.Tensor, dict]: + """ + Compute AlphaZero loss. 
+ + Args: + policy_logits: Predicted policy log probabilities [batch, action_size] + value: Predicted values [batch, 1] + target_policy: Target policy from MCTS [batch, action_size] + target_value: Target value from game outcome [batch, 1] + + Returns: + (total_loss, loss_dict) tuple + """ + # Value loss: MSE between predicted and actual outcome + value_loss = F.mse_loss(value.squeeze(-1), target_value) + + # Policy loss: Cross-entropy between MCTS policy and network policy + # Target policy is already normalized, policy_logits are log probabilities + policy_loss = -torch.sum(target_policy * policy_logits, dim=1).mean() + + # Combined loss + total_loss = self.value_loss_weight * value_loss + policy_loss + + loss_dict = { + "total": total_loss.item(), + "value": value_loss.item(), + "policy": policy_loss.item(), + } + + return total_loss, loss_dict + + +def create_policy_value_network( + config: NeuralNetworkConfig, + board_size: int = 19, + device: str = "cpu", +) -> PolicyValueNetwork: + """ + Factory function to create and initialize policy-value network. + + Args: + config: Network configuration + board_size: Board/grid size (for games) + device: Device to place model on + + Returns: + Initialized PolicyValueNetwork + """ + network = PolicyValueNetwork(config, board_size) + + # He initialization for convolutional layers + def init_weights(m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + network.apply(init_weights) + network = network.to(device) + + return network + + +# Example: Simpler MLP-based policy-value network for non-spatial tasks +class MLPPolicyValueNetwork(nn.Module): + """ + MLP-based policy-value network for non-spatial state representations. + + Useful for tasks where state is not naturally represented as an image. + """ + + def __init__( + self, + state_dim: int, + action_size: int, + hidden_dims: list[int] | None = None, + use_batch_norm: bool = True, + dropout: float = 0.1, + ): + super().__init__() + self.state_dim = state_dim + self.action_size = action_size + + if hidden_dims is None: + hidden_dims = [512, 256] + + # Shared feature extractor + layers = [] + prev_dim = state_dim + + for hidden_dim in hidden_dims: + layers.append(nn.Linear(prev_dim, hidden_dim)) + if use_batch_norm: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.ReLU()) + if dropout > 0: + layers.append(nn.Dropout(dropout)) + prev_dim = hidden_dim + + self.shared_network = nn.Sequential(*layers) + + # Policy head + self.policy_head = nn.Sequential( + nn.Linear(prev_dim, prev_dim // 2), + nn.ReLU(), + nn.Linear(prev_dim // 2, action_size), + ) + + # Value head + self.value_head = nn.Sequential( + nn.Linear(prev_dim, prev_dim // 2), + nn.ReLU(), + nn.Linear(prev_dim // 2, 1), + nn.Tanh(), + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass. 
+ + Args: + x: Input state [batch, state_dim] + + Returns: + (policy_log_probs, value) tuple + """ + # Shared features + features = self.shared_network(x) + + # Policy + policy_logits = self.policy_head(features) + policy_log_probs = F.log_softmax(policy_logits, dim=1) + + # Value + value = self.value_head(features) + + return policy_log_probs, value + + def get_parameter_count(self) -> int: + """Return total number of trainable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) diff --git a/src/models/validation.py b/src/models/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a511e884446bc9e267a65ac1f24c58579d0b13 --- /dev/null +++ b/src/models/validation.py @@ -0,0 +1,489 @@ +""" +Input validation models for LangGraph Multi-Agent MCTS framework. + +Provides: +- Pydantic models for all external inputs +- Query sanitization and length limits +- Configuration validation +- MCP tool input validation with strict type checking +- Security-focused input processing +""" + +import re +from datetime import datetime +from typing import Any + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, +) + +# Constants for validation +MAX_QUERY_LENGTH = 10000 +MIN_QUERY_LENGTH = 1 +MAX_CONTEXT_LENGTH = 50000 +MAX_ITERATIONS = 10000 +MIN_ITERATIONS = 1 +MAX_EXPLORATION_WEIGHT = 10.0 +MIN_EXPLORATION_WEIGHT = 0.0 +MAX_BATCH_SIZE = 100 + + +class QueryInput(BaseModel): + """ + Validated query input for the multi-agent framework. + + Performs sanitization and security checks on user queries. + """ + + model_config = ConfigDict( + strict=True, + validate_assignment=True, + extra="forbid", + ) + + query: str = Field( + ..., min_length=MIN_QUERY_LENGTH, max_length=MAX_QUERY_LENGTH, description="User query to process" + ) + + use_rag: bool = Field(default=True, description="Enable RAG context retrieval") + + use_mcts: bool = Field(default=False, description="Enable MCTS simulation for tactical planning") + + thread_id: str | None = Field( + default=None, + max_length=100, + pattern=r"^[a-zA-Z0-9_-]+$", + description="Conversation thread ID for state persistence", + ) + + @field_validator("query") + @classmethod + def sanitize_query(cls, v: str) -> str: + """ + Sanitize query input for security. + + Removes potentially dangerous patterns while preserving legitimate content. + """ + # Strip leading/trailing whitespace + v = v.strip() + + # Check for empty query after stripping + if not v: + raise ValueError("Query cannot be empty or contain only whitespace") + + # Remove null bytes + v = v.replace("\x00", "") + + # Limit consecutive whitespace + v = re.sub(r"\s+", " ", v) + + # Check for suspicious patterns (basic injection prevention) + suspicious_patterns = [ + r"]*>", # Script tags + r"javascript:", # JavaScript URLs + r"on\w+\s*=", # Event handlers + r"\{\{.*\}\}", # Template injection + r"\$\{.*\}", # Template literals + ] + + for pattern in suspicious_patterns: + if re.search(pattern, v, re.IGNORECASE): + raise ValueError(f"Query contains potentially unsafe content matching pattern: {pattern}") + + return v + + @field_validator("thread_id") + @classmethod + def validate_thread_id(cls, v: str | None) -> str | None: + """Validate thread ID format for safe storage keys.""" + if v is not None: # noqa: SIM102 + # Additional safety check beyond pattern + if ".." 
in v or "/" in v or "\\" in v: + raise ValueError("Thread ID contains invalid path characters") + return v + + +class MCTSConfig(BaseModel): + """ + Validated MCTS configuration parameters. + + Enforces bounds on exploration weight and iteration counts. + """ + + model_config = ConfigDict( + strict=True, + extra="forbid", + ) + + iterations: int = Field( + default=100, ge=MIN_ITERATIONS, le=MAX_ITERATIONS, description="Number of MCTS simulation iterations" + ) + + exploration_weight: float = Field( + default=1.414, + ge=MIN_EXPLORATION_WEIGHT, + le=MAX_EXPLORATION_WEIGHT, + description="UCB1 exploration constant (c parameter)", + ) + + max_depth: int = Field(default=10, ge=1, le=50, description="Maximum tree depth for MCTS expansion") + + simulation_timeout_seconds: float = Field( + default=30.0, ge=1.0, le=300.0, description="Timeout for MCTS simulation phase" + ) + + @field_validator("exploration_weight") + @classmethod + def validate_exploration_weight(cls, v: float) -> float: + """Validate exploration weight is within reasonable bounds.""" + if not (MIN_EXPLORATION_WEIGHT <= v <= MAX_EXPLORATION_WEIGHT): + raise ValueError( + f"Exploration weight must be between {MIN_EXPLORATION_WEIGHT} and {MAX_EXPLORATION_WEIGHT}" + ) + # Warn for unusual values + if v < 0.5 or v > 3.0: + import warnings + + warnings.warn( + f"Exploration weight {v} is outside typical range (0.5-3.0). " + "This may lead to suboptimal search behavior.", + UserWarning, + stacklevel=2, + ) + return v + + +class AgentConfig(BaseModel): + """ + Validated configuration for HRM/TRM agents. + """ + + model_config = ConfigDict( + extra="forbid", + ) + + max_iterations: int = Field(default=3, ge=1, le=20, description="Maximum iterations for agent refinement") + + consensus_threshold: float = Field( + default=0.75, ge=0.0, le=1.0, description="Consensus threshold for agent agreement" + ) + + temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="LLM temperature for response generation") + + max_tokens: int = Field(default=2048, ge=1, le=128000, description="Maximum tokens in LLM response") + + @field_validator("temperature") + @classmethod + def validate_temperature(cls, v: float) -> float: + """Validate temperature is within LLM bounds.""" + if v < 0.0 or v > 2.0: + raise ValueError("Temperature must be between 0.0 and 2.0") + return v + + +class RAGConfig(BaseModel): + """ + Validated RAG (Retrieval Augmented Generation) configuration. + """ + + model_config = ConfigDict( + extra="forbid", + ) + + top_k: int = Field(default=5, ge=1, le=50, description="Number of documents to retrieve") + + similarity_threshold: float = Field( + default=0.5, ge=0.0, le=1.0, description="Minimum similarity score for retrieved documents" + ) + + chunk_size: int = Field(default=1000, ge=100, le=10000, description="Document chunk size for embedding") + + chunk_overlap: int = Field(default=200, ge=0, le=2000, description="Overlap between document chunks") + + @model_validator(mode="after") + def validate_chunk_overlap(self) -> "RAGConfig": + """Ensure chunk overlap is less than chunk size.""" + if self.chunk_overlap >= self.chunk_size: + raise ValueError("Chunk overlap must be less than chunk size") + return self + + +class MCPToolInput(BaseModel): + """ + Base validation model for MCP (Model Context Protocol) tool inputs. + + Provides strict validation for external tool invocations. 
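+
+    Example (illustrative tool name and parameters):
+
+        tool_input = MCPToolInput(
+            tool_name="vector_search",
+            parameters={"query": "mcts exploration weight"},
+        )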
+ """ + + model_config = ConfigDict( + strict=True, + extra="forbid", + ) + + tool_name: str = Field( + ..., + min_length=1, + max_length=100, + pattern=r"^[a-zA-Z][a-zA-Z0-9_-]*$", + description="Name of the MCP tool to invoke", + ) + + parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters as key-value pairs") + + timeout_seconds: float = Field(default=30.0, ge=1.0, le=300.0, description="Timeout for tool execution") + + @field_validator("tool_name") + @classmethod + def validate_tool_name(cls, v: str) -> str: + """Validate tool name is safe and follows naming conventions.""" + # Prevent path traversal in tool names + if ".." in v or "/" in v or "\\" in v: + raise ValueError("Tool name contains invalid characters") + + # Prevent overly long names + if len(v) > 100: + raise ValueError("Tool name exceeds maximum length of 100 characters") + + return v + + @field_validator("parameters") + @classmethod + def validate_parameters(cls, v: dict[str, Any]) -> dict[str, Any]: + """Validate tool parameters for security.""" + # Check for reasonable size + if len(str(v)) > 100000: + raise ValueError("Tool parameters exceed maximum size") + + # Check parameter count + if len(v) > 50: + raise ValueError("Too many parameters (maximum 50)") + + # Validate parameter keys + for key in v: + if not isinstance(key, str): + raise ValueError("Parameter keys must be strings") + if len(key) > 100: + raise ValueError(f"Parameter key '{key[:20]}...' exceeds maximum length") + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", key): + raise ValueError(f"Invalid parameter key format: {key}") + + return v + + +class FileReadInput(MCPToolInput): + """ + Validated input for file reading operations. + + Implements path traversal protection. + """ + + tool_name: str = Field(default="read_file", frozen=True) + + file_path: str = Field(..., min_length=1, max_length=1000, description="Path to file to read") + + @field_validator("file_path") + @classmethod + def validate_file_path(cls, v: str) -> str: + """Validate file path for security concerns.""" + # Normalize path + v = v.strip() + + # Check for path traversal attempts + if ".." in v: + raise ValueError("Path traversal detected: '..' not allowed in file path") + + # Check for absolute paths (may be allowed in some contexts) + if v.startswith("/"): + import warnings + + warnings.warn( + "Absolute file path provided. Ensure this is within allowed directories.", UserWarning, stacklevel=2 + ) + + # Check for suspicious patterns + suspicious = [ + "/etc/", + "/root/", + "~/.ssh/", + "/var/", + "\\windows\\", + "\\system32\\", + ] + for pattern in suspicious: + if pattern.lower() in v.lower(): + raise ValueError(f"File path contains restricted directory: {pattern}") + + return v + + +class WebFetchInput(MCPToolInput): + """ + Validated input for web fetch operations. + + Implements URL validation and security checks. 
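+
+    Example (illustrative URL):
+
+        request = WebFetchInput(url="https://example.com/docs")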
+ """ + + tool_name: str = Field(default="web_fetch", frozen=True) + + url: str = Field(..., min_length=1, max_length=2000, description="URL to fetch") + + @field_validator("url") + @classmethod + def validate_url(cls, v: str) -> str: + """Validate URL for security.""" + v = v.strip() + + # Must start with https:// for security (http:// only for local) + if not v.startswith(("https://", "http://localhost", "http://127.0.0.1")): + raise ValueError("URL must use HTTPS protocol (except for localhost)") + + # Check for suspicious patterns + if any(char in v for char in ["<", ">", "'", '"', ";"]): + raise ValueError("URL contains invalid characters") + + # Validate basic URL structure + url_pattern = r"^https?://[^\s/$.?#].[^\s]*$" + if not re.match(url_pattern, v, re.IGNORECASE): + raise ValueError("Invalid URL format") + + return v + + +class BatchQueryInput(BaseModel): + """ + Validated batch query input for processing multiple queries. + """ + + model_config = ConfigDict( + strict=True, + extra="forbid", + ) + + queries: list[QueryInput] = Field( + ..., min_length=1, max_length=MAX_BATCH_SIZE, description="List of queries to process in batch" + ) + + parallel: bool = Field(default=False, description="Process queries in parallel (if system supports)") + + @field_validator("queries") + @classmethod + def validate_batch_size(cls, v: list[QueryInput]) -> list[QueryInput]: + """Validate batch doesn't exceed limits.""" + if len(v) > MAX_BATCH_SIZE: + raise ValueError(f"Batch size exceeds maximum of {MAX_BATCH_SIZE}") + if len(v) == 0: + raise ValueError("Batch must contain at least one query") + return v + + +class APIRequestMetadata(BaseModel): + """ + Metadata for API request tracking and audit logging. + + Used for security monitoring and rate limiting. + """ + + model_config = ConfigDict( + extra="forbid", + ) + + request_id: str = Field( + ..., min_length=1, max_length=100, pattern=r"^[a-zA-Z0-9_-]+$", description="Unique request identifier" + ) + + timestamp: datetime = Field(default_factory=datetime.utcnow, description="Request timestamp (UTC)") + + client_id: str | None = Field( + default=None, max_length=100, pattern=r"^[a-zA-Z0-9_-]+$", description="Client identifier for rate limiting" + ) + + source_ip: str | None = Field(default=None, description="Source IP address (for audit logging)") + + @field_validator("source_ip") + @classmethod + def validate_ip_address(cls, v: str | None) -> str | None: + """Validate IP address format.""" + if v is not None: + # Basic IPv4/IPv6 validation + import ipaddress + + try: + ipaddress.ip_address(v) + except ValueError: + raise ValueError(f"Invalid IP address format: {v}") + return v + + +# Convenience functions for common validation patterns + + +def validate_query(query: str, **kwargs) -> QueryInput: + """ + Validate a query string and return a validated QueryInput model. + + Args: + query: Raw query string + **kwargs: Additional query parameters + + Returns: + QueryInput: Validated query model + + Raises: + ValidationError: If validation fails + """ + return QueryInput(query=query, **kwargs) + + +def validate_mcts_config(**kwargs) -> MCTSConfig: + """ + Validate MCTS configuration parameters. + + Args: + **kwargs: MCTS configuration parameters + + Returns: + MCTSConfig: Validated configuration + + Raises: + ValidationError: If validation fails + """ + return MCTSConfig(**kwargs) + + +def validate_tool_input(tool_name: str, parameters: dict[str, Any], **kwargs) -> MCPToolInput: + """ + Validate MCP tool input parameters. 
+ + Args: + tool_name: Name of the tool + parameters: Tool parameters + **kwargs: Additional options + + Returns: + MCPToolInput: Validated tool input + + Raises: + ValidationError: If validation fails + """ + return MCPToolInput(tool_name=tool_name, parameters=parameters, **kwargs) + + +# Type exports +__all__ = [ + "QueryInput", + "MCTSConfig", + "AgentConfig", + "RAGConfig", + "MCPToolInput", + "FileReadInput", + "WebFetchInput", + "BatchQueryInput", + "APIRequestMetadata", + "validate_query", + "validate_mcts_config", + "validate_tool_input", +] diff --git a/src/observability/__init__.py b/src/observability/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f06667b6f51c6423e82a4d4d9d416e46252b6564 --- /dev/null +++ b/src/observability/__init__.py @@ -0,0 +1,59 @@ +# Observability Module +""" +Comprehensive observability infrastructure for multi-agent MCTS framework. + +Includes: +- JSON structured logging with correlation IDs +- OpenTelemetry tracing with automatic span creation +- Metrics collection for MCTS and agent performance +- Debug utilities for MCTS tree visualization +- Performance profiling tools +""" + +from .debug import MCTSDebugger, export_tree_to_dot, visualize_mcts_tree +from .logging import CorrelationIdFilter, get_logger, setup_logging +from .metrics import MetricsCollector, agent_metrics, mcts_metrics +from .profiling import AsyncProfiler, MemoryProfiler, generate_performance_report, profile_block +from .tracing import TracingManager, get_tracer, trace_operation + +# Braintrust integration (optional) +try: + from .braintrust_tracker import ( # noqa: F401 + BRAINTRUST_AVAILABLE, + BraintrustContextManager, + BraintrustTracker, + create_training_tracker, + ) + + _braintrust_exports = [ + "BraintrustTracker", + "BraintrustContextManager", + "create_training_tracker", + "BRAINTRUST_AVAILABLE", + ] +except ImportError: + _braintrust_exports = [] + +__all__ = [ + # Logging + "setup_logging", + "get_logger", + "CorrelationIdFilter", + # Tracing + "TracingManager", + "trace_operation", + "get_tracer", + # Metrics + "MetricsCollector", + "mcts_metrics", + "agent_metrics", + # Debug + "MCTSDebugger", + "export_tree_to_dot", + "visualize_mcts_tree", + # Profiling + "profile_block", + "AsyncProfiler", + "MemoryProfiler", + "generate_performance_report", +] + _braintrust_exports diff --git a/src/observability/braintrust_tracker.py b/src/observability/braintrust_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..af12012dc889abe82d2304a737bf2f4dd2176dbf --- /dev/null +++ b/src/observability/braintrust_tracker.py @@ -0,0 +1,469 @@ +""" +Braintrust integration for experiment tracking in Neural Meta-Controller training. + +Provides experiment logging, metric tracking, and model versioning capabilities. +""" + +import os +from datetime import datetime +from typing import Any + +# Check if braintrust is available +try: + import braintrust + + BRAINTRUST_AVAILABLE = True +except ImportError: + BRAINTRUST_AVAILABLE = False + braintrust = None + + +class BraintrustTracker: + """ + Experiment tracker using Braintrust API for neural meta-controller training. + + Provides: + - Experiment creation and management + - Metric logging (loss, accuracy, etc.) + - Hyperparameter tracking + - Model evaluation logging + - Training run comparison + """ + + def __init__( + self, + project_name: str = "neural-meta-controller", + api_key: str | None = None, + auto_init: bool = True, + ): + """ + Initialize Braintrust tracker. 
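+
+        Example (illustrative; assumes BRAINTRUST_API_KEY is set in the
+        environment):
+
+            tracker = BraintrustTracker(project_name="neural-meta-controller")
+            tracker.start_experiment(metadata={"phase": "supervised"})
+            tracker.log_training_step(epoch=0, step=10, loss=0.42)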
+ + Args: + project_name: Name of the Braintrust project + api_key: Braintrust API key (if None, reads from BRAINTRUST_API_KEY env var) + auto_init: Whether to initialize Braintrust client immediately + """ + self.project_name = project_name + self._api_key = api_key or os.environ.get("BRAINTRUST_API_KEY") + self._experiment: Any = None + self._current_span: Any = None + self._is_initialized = False + self._metrics_buffer: list[dict[str, Any]] = [] + + if not BRAINTRUST_AVAILABLE: + print("Warning: braintrust package not installed. Install with: pip install braintrust") + return + + if auto_init and self._api_key: + self._initialize() + + def _initialize(self) -> None: + """Initialize Braintrust client with API key.""" + if not BRAINTRUST_AVAILABLE: + return + + if self._api_key: + braintrust.login(api_key=self._api_key) + self._is_initialized = True + + @property + def is_available(self) -> bool: + """Check if Braintrust is available and configured.""" + return BRAINTRUST_AVAILABLE and self._is_initialized and self._api_key is not None + + def start_experiment( + self, + experiment_name: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Any | None: + """ + Start a new experiment run. + + Args: + experiment_name: Optional name for the experiment (auto-generated if None) + metadata: Optional metadata to attach to the experiment + + Returns: + Braintrust Experiment object or None if not available + """ + if not self.is_available: + return None + + if experiment_name is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + experiment_name = f"meta_controller_training_{timestamp}" + + try: + self._experiment = braintrust.init( + project=self.project_name, + experiment=experiment_name, + metadata=metadata or {}, + ) + return self._experiment + except Exception as e: + print(f"Warning: Failed to start Braintrust experiment: {e}") + return None + + def log_hyperparameters(self, params: dict[str, Any]) -> None: + """ + Log hyperparameters for the current experiment. + + Args: + params: Dictionary of hyperparameters + """ + if not self.is_available or self._experiment is None: + self._metrics_buffer.append({"type": "hyperparameters", "data": params}) + return + + try: + self._experiment.log(metadata=params) + except Exception as e: + print(f"Warning: Failed to log hyperparameters: {e}") + + def log_training_step( + self, + epoch: int, + step: int, + loss: float, + metrics: dict[str, float] | None = None, + ) -> None: + """ + Log a single training step. + + Args: + epoch: Current epoch number + step: Current step/batch number + loss: Training loss value + metrics: Optional additional metrics (accuracy, etc.) + """ + if not self.is_available or self._experiment is None: + self._metrics_buffer.append( + { + "type": "training_step", + "epoch": epoch, + "step": step, + "loss": loss, + "metrics": metrics or {}, + } + ) + return + + try: + log_data = { + "input": {"epoch": epoch, "step": step}, + "output": {"loss": loss}, + "scores": metrics or {}, + } + self._experiment.log(**log_data) + except Exception as e: + print(f"Warning: Failed to log training step: {e}") + + def log_epoch_summary( + self, + epoch: int, + train_loss: float, + val_loss: float | None = None, + train_accuracy: float | None = None, + val_accuracy: float | None = None, + additional_metrics: dict[str, float] | None = None, + ) -> None: + """ + Log summary metrics for a completed epoch. 
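+
+        Example (illustrative metric values):
+            tracker.log_epoch_summary(
+                epoch=3,
+                train_loss=0.42,
+                val_loss=0.38,
+                val_accuracy=0.91,
+            )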
+ + Args: + epoch: Epoch number + train_loss: Training loss for the epoch + val_loss: Optional validation loss + train_accuracy: Optional training accuracy + val_accuracy: Optional validation accuracy + additional_metrics: Optional additional metrics + """ + if not self.is_available or self._experiment is None: + self._metrics_buffer.append( + { + "type": "epoch_summary", + "epoch": epoch, + "train_loss": train_loss, + "val_loss": val_loss, + "train_accuracy": train_accuracy, + "val_accuracy": val_accuracy, + "additional_metrics": additional_metrics or {}, + } + ) + return + + try: + scores = { + "train_loss": train_loss, + } + if val_loss is not None: + scores["val_loss"] = val_loss + if train_accuracy is not None: + scores["train_accuracy"] = train_accuracy + if val_accuracy is not None: + scores["val_accuracy"] = val_accuracy + if additional_metrics: + scores.update(additional_metrics) + + self._experiment.log( + input={"epoch": epoch}, + output={"completed": True}, + scores=scores, + ) + except Exception as e: + print(f"Warning: Failed to log epoch summary: {e}") + + def log_evaluation( + self, + eval_type: str, + predictions: list[str], + ground_truth: list[str], + metrics: dict[str, float], + ) -> None: + """ + Log model evaluation results. + + Args: + eval_type: Type of evaluation (e.g., "validation", "test") + predictions: Model predictions + ground_truth: Ground truth labels + metrics: Computed metrics (accuracy, precision, recall, f1, etc.) + """ + if not self.is_available or self._experiment is None: + self._metrics_buffer.append( + { + "type": "evaluation", + "eval_type": eval_type, + "num_samples": len(predictions), + "metrics": metrics, + } + ) + return + + try: + self._experiment.log( + input={ + "eval_type": eval_type, + "num_samples": len(predictions), + }, + output={ + "predictions_sample": predictions[:10], + "ground_truth_sample": ground_truth[:10], + }, + scores=metrics, + ) + except Exception as e: + print(f"Warning: Failed to log evaluation: {e}") + + def log_model_prediction( + self, + input_features: dict[str, Any], + prediction: str, + confidence: float, + ground_truth: str | None = None, + ) -> None: + """ + Log a single model prediction for analysis. + + Args: + input_features: Input features used for prediction + prediction: Model's predicted agent + confidence: Prediction confidence score + ground_truth: Optional ground truth label + """ + if not self.is_available or self._experiment is None: + self._metrics_buffer.append( + { + "type": "prediction", + "input": input_features, + "prediction": prediction, + "confidence": confidence, + "ground_truth": ground_truth, + } + ) + return + + try: + scores = {"confidence": confidence} + if ground_truth: + scores["correct"] = float(prediction == ground_truth) + + self._experiment.log( + input=input_features, + output={"prediction": prediction}, + expected=ground_truth, + scores=scores, + ) + except Exception as e: + print(f"Warning: Failed to log prediction: {e}") + + def log_model_artifact( + self, + model_path: str, + model_type: str, + metrics: dict[str, float], + metadata: dict[str, Any] | None = None, + ) -> None: + """ + Log a trained model artifact. 
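+
+        Example (illustrative path and metrics):
+            tracker.log_model_artifact(
+                model_path="checkpoints/meta_controller_rnn.pt",
+                model_type="rnn",
+                metrics={"val_accuracy": 0.93},
+            )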
+ + Args: + model_path: Path to the saved model + model_type: Type of model (e.g., "rnn", "bert") + metrics: Final model metrics + metadata: Optional additional metadata + """ + if not self.is_available or self._experiment is None: + self._metrics_buffer.append( + { + "type": "model_artifact", + "model_path": model_path, + "model_type": model_type, + "metrics": metrics, + "metadata": metadata or {}, + } + ) + return + + try: + self._experiment.log( + input={ + "model_path": model_path, + "model_type": model_type, + }, + output={"saved": True}, + scores=metrics, + metadata=metadata or {}, + ) + except Exception as e: + print(f"Warning: Failed to log model artifact: {e}") + + def end_experiment(self) -> str | None: + """ + End the current experiment and return summary URL. + + Returns: + URL to view the experiment in Braintrust dashboard, or None + """ + if not self.is_available or self._experiment is None: + return None + + try: + summary = self._experiment.summarize() + self._experiment = None + return summary.experiment_url if hasattr(summary, "experiment_url") else None + except Exception as e: + print(f"Warning: Failed to end experiment: {e}") + return None + + def get_buffered_metrics(self) -> list[dict[str, Any]]: + """ + Get all buffered metrics (useful when Braintrust is not available). + + Returns: + List of buffered metric dictionaries + """ + return self._metrics_buffer.copy() + + def clear_buffer(self) -> None: + """Clear the metrics buffer.""" + self._metrics_buffer.clear() + + +class BraintrustContextManager: + """ + Context manager for Braintrust experiment tracking. + + Usage: + with BraintrustContextManager( + project_name="neural-meta-controller", + experiment_name="training_run_1" + ) as tracker: + tracker.log_hyperparameters({"learning_rate": 0.001}) + tracker.log_epoch_summary(1, train_loss=0.5, val_loss=0.4) + """ + + def __init__( + self, + project_name: str = "neural-meta-controller", + experiment_name: str | None = None, + api_key: str | None = None, + metadata: dict[str, Any] | None = None, + ): + """ + Initialize context manager. + + Args: + project_name: Name of the Braintrust project + experiment_name: Optional experiment name + api_key: Optional API key + metadata: Optional experiment metadata + """ + self.project_name = project_name + self.experiment_name = experiment_name + self.api_key = api_key + self.metadata = metadata + self.tracker: BraintrustTracker | None = None + self.experiment_url: str | None = None + + def __enter__(self) -> BraintrustTracker: + """Start experiment tracking.""" + self.tracker = BraintrustTracker( + project_name=self.project_name, + api_key=self.api_key, + ) + self.tracker.start_experiment( + experiment_name=self.experiment_name, + metadata=self.metadata, + ) + return self.tracker + + def __exit__(self, exc_type, exc_val, exc_tb): + """End experiment tracking.""" + if self.tracker: + self.experiment_url = self.tracker.end_experiment() + return False + + +def create_training_tracker( + model_type: str = "rnn", + config: dict[str, Any] | None = None, +) -> BraintrustTracker: + """ + Create a pre-configured tracker for meta-controller training. 
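+
+    Example (illustrative hyperparameters):
+        tracker = create_training_tracker(model_type="rnn", config={"learning_rate": 1e-3})
+        tracker.log_epoch_summary(epoch=1, train_loss=0.5)
+        experiment_url = tracker.end_experiment()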
+ + Args: + model_type: Type of model being trained ("rnn" or "bert") + config: Optional training configuration + + Returns: + Configured BraintrustTracker instance + """ + tracker = BraintrustTracker(project_name="neural-meta-controller") + + if tracker.is_available: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + experiment_name = f"{model_type}_training_{timestamp}" + + metadata = { + "model_type": model_type, + "timestamp": timestamp, + } + if config: + metadata.update(config) + + tracker.start_experiment( + experiment_name=experiment_name, + metadata=metadata, + ) + + return tracker + + +__all__ = [ + "BraintrustTracker", + "BraintrustContextManager", + "create_training_tracker", + "BRAINTRUST_AVAILABLE", +] diff --git a/src/observability/debug.py b/src/observability/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..fad98fa3d477c0c68d71bf0041bb2ee39a6c68b4 --- /dev/null +++ b/src/observability/debug.py @@ -0,0 +1,530 @@ +""" +Debug utilities for multi-agent MCTS framework. + +Provides: +- MCTS tree visualization (text-based) +- Step-by-step MCTS execution logging when LOG_LEVEL=DEBUG +- UCB score logging at each selection +- State diff tracking between iterations +- Export tree to DOT format for graphviz +""" + +import logging +import os +from typing import Any + +from .logging import get_logger + + +class MCTSDebugger: + """ + Comprehensive debugger for MCTS operations. + + Provides detailed step-by-step logging, tree visualization, + and state tracking for MCTS execution. + """ + + def __init__(self, session_id: str = "default", enabled: bool | None = None): + """ + Initialize MCTS debugger. + + Args: + session_id: Unique identifier for debug session + enabled: Enable debugging (defaults to LOG_LEVEL=DEBUG) + """ + self.session_id = session_id + self.logger = get_logger("observability.debug") + + # Auto-enable if LOG_LEVEL is DEBUG + if enabled is None: + log_level = os.environ.get("LOG_LEVEL", "INFO").upper() + self.enabled = log_level == "DEBUG" + else: + self.enabled = enabled + + # State tracking + self._iteration_count = 0 + self._state_history: list[dict[str, Any]] = [] + self._selection_history: list[dict[str, Any]] = [] + self._ucb_history: list[dict[str, float]] = [] + + def log_iteration_start(self, iteration: int) -> None: + """Log the start of an MCTS iteration.""" + if not self.enabled: + return + + self._iteration_count = iteration + self.logger.debug( + f"=== MCTS Iteration {iteration} START ===", + extra={ + "debug_event": "iteration_start", + "mcts_iteration": iteration, + "session_id": self.session_id, + }, + ) + + def log_selection( + self, + node_id: str, + ucb_score: float, + visits: int, + value: float, + depth: int, + children_count: int, + is_selected: bool = False, + ) -> None: + """Log UCB score and selection decision for a node.""" + if not self.enabled: + return + + selection_data = { + "node_id": node_id, + "ucb_score": round(ucb_score, 6), + "visits": visits, + "value": round(value, 6), + "avg_value": round(value / max(visits, 1), 6), + "depth": depth, + "children_count": children_count, + "is_selected": is_selected, + } + + self._selection_history.append(selection_data) + + log_msg = f"Selection: node={node_id} UCB={ucb_score:.4f} visits={visits} value={value:.4f} depth={depth}" + + if is_selected: + log_msg += " [SELECTED]" + + self.logger.debug( + log_msg, + extra={ + "debug_event": "mcts_selection", + "selection": selection_data, + "session_id": self.session_id, + "mcts_iteration": self._iteration_count, + }, + ) 
+ + def log_ucb_comparison( + self, + parent_id: str, + children_ucb: dict[str, float], + selected_child: str, + ) -> None: + """Log UCB score comparison for all children of a node.""" + if not self.enabled: + return + + self._ucb_history.append(children_ucb) + + ucb_summary = ", ".join( + [ + f"{cid}={score:.4f}{'*' if cid == selected_child else ''}" + for cid, score in sorted(children_ucb.items(), key=lambda x: x[1], reverse=True) + ] + ) + + self.logger.debug( + f"UCB Comparison at {parent_id}: {ucb_summary}", + extra={ + "debug_event": "ucb_comparison", + "parent_id": parent_id, + "children_ucb": {k: round(v, 6) for k, v in children_ucb.items()}, + "selected_child": selected_child, + "session_id": self.session_id, + "mcts_iteration": self._iteration_count, + }, + ) + + def log_expansion( + self, + parent_id: str, + action: str, + new_node_id: str, + available_actions: list[str], + ) -> None: + """Log node expansion details.""" + if not self.enabled: + return + + self.logger.debug( + f"Expansion: parent={parent_id} action={action} new_node={new_node_id} " + f"available={len(available_actions)} actions", + extra={ + "debug_event": "mcts_expansion", + "parent_id": parent_id, + "action": action, + "new_node_id": new_node_id, + "available_actions": available_actions, + "session_id": self.session_id, + "mcts_iteration": self._iteration_count, + }, + ) + + def log_simulation( + self, + node_id: str, + simulation_result: float, + simulation_details: dict[str, Any] | None = None, + ) -> None: + """Log simulation/rollout results.""" + if not self.enabled: + return + + self.logger.debug( + f"Simulation: node={node_id} result={simulation_result:.4f}", + extra={ + "debug_event": "mcts_simulation", + "node_id": node_id, + "simulation_result": round(simulation_result, 6), + "simulation_details": simulation_details or {}, + "session_id": self.session_id, + "mcts_iteration": self._iteration_count, + }, + ) + + def log_backpropagation( + self, + path: list[str], + value: float, + updates: list[dict[str, Any]], + ) -> None: + """Log backpropagation path and value updates.""" + if not self.enabled: + return + + self.logger.debug( + f"Backprop: path={' -> '.join(path)} value={value:.4f}", + extra={ + "debug_event": "mcts_backprop", + "path": path, + "value": round(value, 6), + "updates": updates, + "session_id": self.session_id, + "mcts_iteration": self._iteration_count, + }, + ) + + def log_iteration_end( + self, + iteration: int, + best_action: str, + best_ucb: float, + tree_size: int, + ) -> None: + """Log the end of an MCTS iteration.""" + if not self.enabled: + return + + self.logger.debug( + f"=== MCTS Iteration {iteration} END === " + f"best_action={best_action} UCB={best_ucb:.4f} tree_size={tree_size}", + extra={ + "debug_event": "iteration_end", + "mcts_iteration": iteration, + "best_action": best_action, + "best_ucb": round(best_ucb, 6), + "tree_size": tree_size, + "session_id": self.session_id, + }, + ) + + def log_state_diff( + self, + old_state: dict[str, Any], + new_state: dict[str, Any], + description: str = "State change", + ) -> None: + """Log differences between two states.""" + if not self.enabled: + return + + diff = self._compute_state_diff(old_state, new_state) + + if diff: + self._state_history.append( + { + "iteration": self._iteration_count, + "description": description, + "diff": diff, + } + ) + + self.logger.debug( + f"State diff: {description}", + extra={ + "debug_event": "state_diff", + "description": description, + "diff": diff, + "session_id": self.session_id, + 
"mcts_iteration": self._iteration_count, + }, + ) + + def _compute_state_diff( + self, + old: dict[str, Any], + new: dict[str, Any], + prefix: str = "", + ) -> dict[str, Any]: + """Compute differences between two dictionaries.""" + diff = {} + + all_keys = set(old.keys()) | set(new.keys()) + + for key in all_keys: + full_key = f"{prefix}.{key}" if prefix else key + + if key not in old: + diff[full_key] = {"added": new[key]} + elif key not in new: + diff[full_key] = {"removed": old[key]} + elif old[key] != new[key]: + if isinstance(old[key], dict) and isinstance(new[key], dict): + nested_diff = self._compute_state_diff(old[key], new[key], full_key) + diff.update(nested_diff) + else: + diff[full_key] = {"old": old[key], "new": new[key]} + + return diff + + def get_debug_summary(self) -> dict[str, Any]: + """Get summary of debug information collected.""" + return { + "session_id": self.session_id, + "total_iterations": self._iteration_count, + "selection_history_count": len(self._selection_history), + "state_changes_count": len(self._state_history), + "ucb_comparisons_count": len(self._ucb_history), + } + + +def visualize_mcts_tree( + root_node: Any, + max_depth: int = 10, + max_children: int = 5, + show_ucb: bool = True, + indent: str = " ", +) -> str: + """ + Generate text-based visualization of MCTS tree. + + Args: + root_node: Root MCTSNode + max_depth: Maximum depth to visualize + max_children: Maximum children to show per node + show_ucb: Show UCB scores + indent: Indentation string + + Returns: + Text representation of the tree + """ + lines = ["MCTS Tree Visualization", "=" * 40] + + def render_node(node: Any, depth: int = 0, prefix: str = "") -> None: + if depth > max_depth: + lines.append(f"{prefix}{indent}... (max depth reached)") + return + + # Node info + visits = getattr(node, "visits", 0) + value = getattr(node, "value", 0.0) + action = getattr(node, "action", "root") + state_id = getattr(node, "state_id", "unknown") + + avg_value = value / max(visits, 1) + + node_info = f"[{state_id}] action={action} visits={visits} value={value:.3f} avg={avg_value:.3f}" + + if show_ucb and hasattr(node, "ucb1") and visits > 0: + try: + ucb = node.ucb1() + if ucb != float("inf"): + node_info += f" UCB={ucb:.3f}" + except Exception: + pass + + lines.append(f"{prefix}{node_info}") + + # Children + children = getattr(node, "children", []) + if children: + # Sort by visits (most visited first) + sorted_children = sorted(children, key=lambda c: getattr(c, "visits", 0), reverse=True) + display_children = sorted_children[:max_children] + + for i, child in enumerate(display_children): + is_last = i == len(display_children) - 1 + child_prefix = prefix + indent + ("└── " if is_last else "├── ") + next_prefix = prefix + indent + (" " if is_last else "│ ") + + lines.append(f"{child_prefix[:-4]}") + render_node(child, depth + 1, next_prefix) + + if len(children) > max_children: + lines.append(f"{prefix}{indent}... and {len(children) - max_children} more children") + + render_node(root_node) + lines.append("=" * 40) + + return "\n".join(lines) + + +def export_tree_to_dot( + root_node: Any, + filename: str = "mcts_tree.dot", + max_depth: int = 10, + include_ucb: bool = True, +) -> str: + """ + Export MCTS tree to DOT format for graphviz visualization. 
+ + Args: + root_node: Root MCTSNode + filename: Output filename (optional) + max_depth: Maximum depth to export + include_ucb: Include UCB scores in labels + + Returns: + DOT format string + """ + lines = [ + "digraph MCTSTree {", + ' graph [rankdir=TB, label="MCTS Tree", fontsize=16];', + " node [shape=box, style=filled, fillcolor=lightblue];", + " edge [fontsize=10];", + "", + ] + + node_counter = [0] # Use list for mutable counter in closure + + def add_node(node: Any, depth: int = 0, parent_dot_id: str | None = None) -> None: + if depth > max_depth: + return + + # Generate unique DOT ID + dot_id = f"node_{node_counter[0]}" + node_counter[0] += 1 + + # Node attributes + visits = getattr(node, "visits", 0) + value = getattr(node, "value", 0.0) + action = getattr(node, "action", "root") + state_id = getattr(node, "state_id", "unknown") + + avg_value = value / max(visits, 1) + + # Build label + label_parts = [ + f"ID: {state_id}", + f"Action: {action}", + f"Visits: {visits}", + f"Value: {value:.3f}", + f"Avg: {avg_value:.3f}", + ] + + if include_ucb and hasattr(node, "ucb1") and visits > 0: + try: + ucb = node.ucb1() + if ucb != float("inf"): + label_parts.append(f"UCB: {ucb:.3f}") + except Exception: + pass + + label = "\\n".join(label_parts) + + # Color based on value + if avg_value >= 0.7: + color = "lightgreen" + elif avg_value >= 0.4: + color = "lightyellow" + else: + color = "lightcoral" + + lines.append(f' {dot_id} [label="{label}", fillcolor={color}];') + + # Edge from parent + if parent_dot_id: + lines.append(f' {parent_dot_id} -> {dot_id} [label="{action}"];') + + # Process children + children = getattr(node, "children", []) + for child in children: + add_node(child, depth + 1, dot_id) + + add_node(root_node) + lines.append("}") + + dot_content = "\n".join(lines) + + # Write to file if filename provided + if filename: + with open(filename, "w") as f: + f.write(dot_content) + + return dot_content + + +def print_debug_banner(message: str, char: str = "=", width: int = 60) -> None: + """Print a debug banner message.""" + logger = get_logger("observability.debug") + border = char * width + logger.debug(border) + logger.debug(f"{message:^{width}}") + logger.debug(border) + + +def log_agent_state_snapshot( + agent_name: str, + state: dict[str, Any], + include_keys: list[str] | None = None, +) -> None: + """ + Log a snapshot of agent state for debugging. 
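+
+    Example (illustrative agent name and state keys):
+        log_agent_state_snapshot(
+            agent_name="hrm_agent",
+            state=current_state,
+            include_keys=["confidence", "last_action"],
+        )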
+ + Args: + agent_name: Name of the agent + state: Current state dictionary + include_keys: Specific keys to include (None = all) + """ + logger = get_logger("observability.debug") + + filtered_state = {k: state.get(k) for k in include_keys if k in state} if include_keys else state + + logger.debug( + f"Agent {agent_name} state snapshot", + extra={ + "debug_event": "agent_state_snapshot", + "agent_name": agent_name, + "state": filtered_state, + }, + ) + + +def enable_verbose_debugging() -> None: + """Enable verbose debugging by setting LOG_LEVEL to DEBUG.""" + os.environ["LOG_LEVEL"] = "DEBUG" + + # Reconfigure root logger + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) + + for handler in root_logger.handlers: + handler.setLevel(logging.DEBUG) + + logger = get_logger("observability.debug") + logger.info("Verbose debugging ENABLED") + + +def disable_verbose_debugging() -> None: + """Disable verbose debugging by setting LOG_LEVEL to INFO.""" + os.environ["LOG_LEVEL"] = "INFO" + + root_logger = logging.getLogger() + root_logger.setLevel(logging.INFO) + + for handler in root_logger.handlers: + handler.setLevel(logging.INFO) + + logger = get_logger("observability.debug") + logger.info("Verbose debugging DISABLED") diff --git a/src/observability/logging.py b/src/observability/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..0f4f1919ae9b7fa1d9a03a6469c250bcb01f029f --- /dev/null +++ b/src/observability/logging.py @@ -0,0 +1,495 @@ +""" +JSON-structured logging infrastructure for multi-agent MCTS framework. + +Provides: +- JSON-structured logging via logging.config.dictConfig +- Per-module loggers with proper hierarchy +- Correlation IDs for request tracking +- Log levels configurable via environment/settings +- Performance metrics in logs (timing, memory) +- Safe sanitization (no secrets in logs) +""" + +import json +import logging +import logging.config +import os +import re +import time +import traceback +import uuid +from contextvars import ContextVar +from datetime import datetime +from functools import wraps + +import psutil + +# Context variable for correlation ID tracking across async calls +_correlation_id: ContextVar[str | None] = ContextVar("correlation_id", default=None) +_request_metadata: ContextVar[dict | None] = ContextVar("request_metadata", default=None) + + +def get_correlation_id() -> str: + """Get current correlation ID or generate new one.""" + cid = _correlation_id.get() + if cid is None: + cid = str(uuid.uuid4()) + _correlation_id.set(cid) + return cid + + +def set_correlation_id(cid: str) -> None: + """Set correlation ID for current context.""" + _correlation_id.set(cid) + + +def set_request_metadata(metadata: dict) -> None: + """Set request metadata for current context.""" + _request_metadata.set(metadata) + + +def get_request_metadata() -> dict: + """Get request metadata for current context.""" + metadata = _request_metadata.get() + return metadata if metadata is not None else {} + + +# Patterns for sensitive data sanitization +SENSITIVE_PATTERNS = [ + (re.compile(r'("?api[_-]?key"?\s*[:=]\s*)"[^"]*"', re.IGNORECASE), r'\1"***REDACTED***"'), + (re.compile(r'("?password"?\s*[:=]\s*)"[^"]*"', re.IGNORECASE), r'\1"***REDACTED***"'), + (re.compile(r'("?secret"?\s*[:=]\s*)"[^"]*"', re.IGNORECASE), r'\1"***REDACTED***"'), + (re.compile(r'("?token"?\s*[:=]\s*)"[^"]*"', re.IGNORECASE), r'\1"***REDACTED***"'), + (re.compile(r'("?authorization"?\s*[:=]\s*)"[^"]*"', re.IGNORECASE), r'\1"***REDACTED***"'), + 
(re.compile(r'("?aws[_-]?secret[_-]?access[_-]?key"?\s*[:=]\s*)"[^"]*"', re.IGNORECASE), r'\1"***REDACTED***"'), + (re.compile(r"(Bearer\s+)\S+", re.IGNORECASE), r"\1***REDACTED***"), + (re.compile(r"(Basic\s+)\S+", re.IGNORECASE), r"\1***REDACTED***"), +] + + +def sanitize_message(message: str) -> str: + """Sanitize sensitive data from log messages.""" + for pattern, replacement in SENSITIVE_PATTERNS: + message = pattern.sub(replacement, message) + return message + + +def sanitize_dict(data: dict) -> dict: + """Recursively sanitize sensitive data from dictionaries.""" + sensitive_keys = { + "api_key", + "apikey", + "password", + "secret", + "token", + "authorization", + "auth", + "credentials", + "aws_secret_access_key", + "private_key", + } + + result = {} + for key, value in data.items(): + key_lower = key.lower().replace("-", "_") + if key_lower in sensitive_keys: + result[key] = "***REDACTED***" + elif isinstance(value, dict): + result[key] = sanitize_dict(value) + elif isinstance(value, list): + result[key] = [sanitize_dict(item) if isinstance(item, dict) else item for item in value] + elif isinstance(value, str): + result[key] = sanitize_message(value) + else: + result[key] = value + return result + + +class CorrelationIdFilter(logging.Filter): + """Add correlation ID and request metadata to log records.""" + + def filter(self, record: logging.LogRecord) -> bool: + record.correlation_id = get_correlation_id() + record.request_metadata = get_request_metadata() + return True + + +class PerformanceMetricsFilter(logging.Filter): + """Add performance metrics to log records.""" + + def filter(self, record: logging.LogRecord) -> bool: + process = psutil.Process() + record.memory_mb = process.memory_info().rss / (1024 * 1024) + record.cpu_percent = process.cpu_percent() + record.thread_count = process.num_threads() + return True + + +class JSONFormatter(logging.Formatter): + """Format log records as JSON with comprehensive metadata.""" + + def __init__(self, include_hostname: bool = True, include_process: bool = True): + super().__init__() + self.include_hostname = include_hostname + self.include_process = include_process + if include_hostname: + import socket + + self.hostname = socket.gethostname() + else: + self.hostname = None + + def format(self, record: logging.LogRecord) -> str: + log_data = { + "timestamp": datetime.utcfromtimestamp(record.created).isoformat() + "Z", + "level": record.levelname, + "logger": record.name, + "message": sanitize_message(record.getMessage()), + "correlation_id": getattr(record, "correlation_id", None), + "module": record.module, + "function": record.funcName, + "line": record.lineno, + } + + # Add hostname if configured + if self.hostname: + log_data["hostname"] = self.hostname + + # Add process info if configured + if self.include_process: + log_data["process"] = { + "id": record.process, + "name": record.processName, + "thread_id": record.thread, + "thread_name": record.threadName, + } + + # Add performance metrics if available + if hasattr(record, "memory_mb"): + log_data["performance"] = { + "memory_mb": round(getattr(record, "memory_mb", 0), 2), + "cpu_percent": round(getattr(record, "cpu_percent", 0), 2), + "thread_count": getattr(record, "thread_count", 0), + } + + # Add request metadata if available + request_metadata = getattr(record, "request_metadata", {}) + if request_metadata: + log_data["request"] = sanitize_dict(request_metadata) + + # Add exception info if present + if record.exc_info: + log_data["exception"] = { + "type": 
record.exc_info[0].__name__ if record.exc_info[0] else None, + "message": str(record.exc_info[1]) if record.exc_info[1] else None, + "traceback": traceback.format_exception(*record.exc_info), + } + + # Add any extra fields (sanitized) + extra_fields = {} + for key, value in record.__dict__.items(): + if key not in { + "name", + "msg", + "args", + "created", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "module", + "msecs", + "message", + "pathname", + "process", + "processName", + "relativeCreated", + "thread", + "threadName", + "exc_info", + "exc_text", + "stack_info", + "correlation_id", + "request_metadata", + "memory_mb", + "cpu_percent", + "thread_count", + "taskName", + }: + if isinstance(value, dict): + extra_fields[key] = sanitize_dict(value) + elif isinstance(value, str): + extra_fields[key] = sanitize_message(value) + else: + extra_fields[key] = value + + if extra_fields: + log_data["extra"] = extra_fields + + return json.dumps(log_data, default=str) + + +def setup_logging( + log_level: str | None = None, + log_file: str | None = None, + include_performance_metrics: bool = True, + json_output: bool = True, + include_hostname: bool = True, + include_process: bool = True, +) -> None: + """ + Configure JSON-structured logging for the application. + + Args: + log_level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL). + Defaults to LOG_LEVEL env var or INFO. + log_file: Optional file path for log output. + include_performance_metrics: Include memory/CPU metrics in logs. + json_output: Use JSON formatter (default True). + include_hostname: Include hostname in logs. + include_process: Include process/thread info in logs. + """ + if log_level is None: + log_level = os.environ.get("LOG_LEVEL", "INFO").upper() + + # Base configuration + config = { + "version": 1, + "disable_existing_loggers": False, + "filters": { + "correlation_id": { + "()": CorrelationIdFilter, + }, + }, + "formatters": { + "json": { + "()": JSONFormatter, + "include_hostname": include_hostname, + "include_process": include_process, + }, + "standard": { + "format": "%(asctime)s [%(levelname)s] %(name)s (%(correlation_id)s): %(message)s", + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": log_level, + "formatter": "json" if json_output else "standard", + "filters": ["correlation_id"], + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + # Root logger + "": { + "handlers": ["console"], + "level": log_level, + "propagate": False, + }, + # Framework loggers with hierarchy + "mcts": { + "handlers": ["console"], + "level": log_level, + "propagate": False, + }, + "mcts.framework": { + "level": log_level, + "propagate": True, + }, + "mcts.agents": { + "level": log_level, + "propagate": True, + }, + "mcts.observability": { + "level": log_level, + "propagate": True, + }, + "mcts.storage": { + "level": log_level, + "propagate": True, + }, + # Third-party loggers (quieter) + "httpx": { + "level": "WARNING", + "propagate": True, + }, + "opentelemetry": { + "level": "WARNING", + "propagate": True, + }, + "aioboto3": { + "level": "WARNING", + "propagate": True, + }, + "botocore": { + "level": "WARNING", + "propagate": True, + }, + }, + } + + # Add performance metrics filter if requested + if include_performance_metrics: + config["filters"]["performance_metrics"] = { + "()": PerformanceMetricsFilter, + } + config["handlers"]["console"]["filters"].append("performance_metrics") + + # Add file handler if requested + if log_file: + config["handlers"]["file"] = { + 
"class": "logging.handlers.RotatingFileHandler", + "level": log_level, + "formatter": "json" if json_output else "standard", + "filters": ["correlation_id"], + "filename": log_file, + "maxBytes": 10 * 1024 * 1024, # 10MB + "backupCount": 5, + "encoding": "utf-8", + } + if include_performance_metrics: + config["handlers"]["file"]["filters"].append("performance_metrics") + + # Add file handler to all loggers + config["loggers"][""]["handlers"].append("file") + config["loggers"]["mcts"]["handlers"].append("file") + + logging.config.dictConfig(config) + + +def get_logger(name: str) -> logging.Logger: + """ + Get a logger with the specified name. + + Uses hierarchical naming under 'mcts' root logger. + Example: get_logger("framework.graph") returns logger "mcts.framework.graph" + + Args: + name: Logger name (will be prefixed with 'mcts.') + + Returns: + Configured logger instance + """ + if not name.startswith("mcts."): + name = f"mcts.{name}" + return logging.getLogger(name) + + +class LogContext: + """Context manager for adding temporary log context.""" + + def __init__(self, **kwargs): + self.kwargs = kwargs + self._old_metadata = None + + def __enter__(self): + self._old_metadata = get_request_metadata().copy() + new_metadata = {**self._old_metadata, **self.kwargs} + set_request_metadata(new_metadata) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + set_request_metadata(self._old_metadata) + return False + + +def log_execution_time(logger: logging.Logger | None = None, level: int = logging.INFO): + """ + Decorator to log function execution time. + + Args: + logger: Logger instance (defaults to function's module logger) + level: Log level for timing message + + Example: + @log_execution_time() + def my_function(): + ... + """ + + def decorator(func): + @wraps(func) + def sync_wrapper(*args, **kwargs): + nonlocal logger + if logger is None: + logger = get_logger(func.__module__) + + start_time = time.perf_counter() + start_memory = psutil.Process().memory_info().rss + + try: + result = func(*args, **kwargs) + success = True + error = None + except Exception as e: + success = False + error = str(e) + raise + finally: + elapsed_ms = (time.perf_counter() - start_time) * 1000 + memory_delta_mb = (psutil.Process().memory_info().rss - start_memory) / (1024 * 1024) + + logger.log( + level, + f"Function {func.__name__} completed", + extra={ + "timing": { + "function": func.__name__, + "elapsed_ms": round(elapsed_ms, 2), + "success": success, + "error": error, + "memory_delta_mb": round(memory_delta_mb, 2), + } + }, + ) + + return result + + @wraps(func) + async def async_wrapper(*args, **kwargs): + nonlocal logger + if logger is None: + logger = get_logger(func.__module__) + + start_time = time.perf_counter() + start_memory = psutil.Process().memory_info().rss + + try: + result = await func(*args, **kwargs) + success = True + error = None + except Exception as e: + success = False + error = str(e) + raise + finally: + elapsed_ms = (time.perf_counter() - start_time) * 1000 + memory_delta_mb = (psutil.Process().memory_info().rss - start_memory) / (1024 * 1024) + + logger.log( + level, + f"Async function {func.__name__} completed", + extra={ + "timing": { + "function": func.__name__, + "elapsed_ms": round(elapsed_ms, 2), + "success": success, + "error": error, + "memory_delta_mb": round(memory_delta_mb, 2), + } + }, + ) + + return result + + if asyncio.iscoroutinefunction(func): + return async_wrapper + return sync_wrapper + + return decorator + + +# Import asyncio for decorator 
+import asyncio # noqa: E402 diff --git a/src/observability/metrics.py b/src/observability/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..ce9bf00a78bd459c1d7785f02f70e46c545f12fe --- /dev/null +++ b/src/observability/metrics.py @@ -0,0 +1,469 @@ +""" +Metrics collection infrastructure for multi-agent MCTS framework. + +Provides: +- MCTS iteration counters +- UCB score distributions +- Agent confidence tracking +- Timing metrics for each graph node +- Memory usage monitoring +- Export to Prometheus format (optional) +""" + +import time +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Optional + +import psutil + +try: + from prometheus_client import ( + REGISTRY, + Counter, + Gauge, + Histogram, + Summary, + generate_latest, + start_http_server, + ) + + PROMETHEUS_AVAILABLE = True +except ImportError: + PROMETHEUS_AVAILABLE = False + + +@dataclass +class MCTSMetrics: + """Container for MCTS-specific metrics.""" + + iterations: int = 0 + total_simulations: int = 0 + tree_depth: int = 0 + total_nodes: int = 0 + ucb_scores: list[float] = field(default_factory=list) + selection_times_ms: list[float] = field(default_factory=list) + expansion_times_ms: list[float] = field(default_factory=list) + simulation_times_ms: list[float] = field(default_factory=list) + backprop_times_ms: list[float] = field(default_factory=list) + best_action_visits: int = 0 + best_action_value: float = 0.0 + + +@dataclass +class AgentMetrics: + """Container for agent-specific metrics.""" + + name: str + executions: int = 0 + total_time_ms: float = 0.0 + avg_confidence: float = 0.0 + confidence_scores: list[float] = field(default_factory=list) + success_count: int = 0 + error_count: int = 0 + memory_usage_mb: list[float] = field(default_factory=list) + + +class MetricsCollector: + """ + Central metrics collection and reporting for the MCTS framework. + + Collects: + - MCTS iteration counters and UCB scores + - Agent confidence and execution times + - Graph node timing metrics + - Memory usage monitoring + - Request/response latencies + + Supports optional Prometheus export. 
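+
+    Example (illustrative metric values):
+        collector = MetricsCollector.get_instance()
+        collector.record_agent_execution("hrm_agent", execution_time_ms=120.0, confidence=0.82)
+        report = collector.get_full_report()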
+ """ + + _instance: Optional["MetricsCollector"] = None + + def __init__(self): + self._mcts_metrics: dict[str, MCTSMetrics] = defaultdict(MCTSMetrics) + self._agent_metrics: dict[str, AgentMetrics] = {} + self._node_timings: dict[str, list[float]] = defaultdict(list) + self._request_latencies: list[float] = [] + self._memory_samples: list[dict[str, float]] = [] + self._start_time = datetime.utcnow() + self._process = psutil.Process() + + # Prometheus metrics (if available) + self._prometheus_initialized = False + if PROMETHEUS_AVAILABLE: + self._init_prometheus_metrics() + + @classmethod + def get_instance(cls) -> "MetricsCollector": + """Get singleton instance of MetricsCollector.""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _init_prometheus_metrics(self) -> None: + """Initialize Prometheus metrics.""" + if not PROMETHEUS_AVAILABLE or self._prometheus_initialized: + return + + # MCTS counters + self._prom_mcts_iterations = Counter( + "mcts_iterations_total", + "Total number of MCTS iterations", + ["session_id"], + ) + self._prom_mcts_simulations = Counter( + "mcts_simulations_total", + "Total number of MCTS simulations", + ["session_id"], + ) + + # MCTS gauges + self._prom_mcts_tree_depth = Gauge( + "mcts_tree_depth", + "Current MCTS tree depth", + ["session_id"], + ) + self._prom_mcts_total_nodes = Gauge( + "mcts_total_nodes", + "Total nodes in MCTS tree", + ["session_id"], + ) + + # UCB score histogram + self._prom_ucb_scores = Histogram( + "mcts_ucb_score", + "UCB score distribution", + ["session_id"], + buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, float("inf")], + ) + + # Agent metrics + self._prom_agent_executions = Counter( + "agent_executions_total", + "Total agent executions", + ["agent_name"], + ) + self._prom_agent_confidence = Summary( + "agent_confidence", + "Agent confidence scores", + ["agent_name"], + ) + self._prom_agent_execution_time = Histogram( + "agent_execution_time_ms", + "Agent execution time in milliseconds", + ["agent_name"], + buckets=[10, 50, 100, 250, 500, 1000, 2500, 5000, 10000], + ) + + # System metrics + self._prom_memory_usage = Gauge( + "framework_memory_usage_mb", + "Memory usage in MB", + ) + self._prom_cpu_percent = Gauge( + "framework_cpu_percent", + "CPU usage percentage", + ) + + # Request latency + self._prom_request_latency = Histogram( + "request_latency_ms", + "Request latency in milliseconds", + buckets=[10, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000], + ) + + self._prometheus_initialized = True + + def start_prometheus_server(self, port: int = 8000) -> None: + """Start Prometheus metrics HTTP server.""" + if PROMETHEUS_AVAILABLE: + start_http_server(port) + + def record_mcts_iteration( + self, + session_id: str, + ucb_score: float, + selection_time_ms: float = 0.0, + expansion_time_ms: float = 0.0, + simulation_time_ms: float = 0.0, + backprop_time_ms: float = 0.0, + ) -> None: + """Record metrics for a single MCTS iteration.""" + metrics = self._mcts_metrics[session_id] + metrics.iterations += 1 + metrics.ucb_scores.append(ucb_score) + + if selection_time_ms > 0: + metrics.selection_times_ms.append(selection_time_ms) + if expansion_time_ms > 0: + metrics.expansion_times_ms.append(expansion_time_ms) + if simulation_time_ms > 0: + metrics.simulation_times_ms.append(simulation_time_ms) + if backprop_time_ms > 0: + metrics.backprop_times_ms.append(backprop_time_ms) + + # Prometheus + if self._prometheus_initialized: + 
self._prom_mcts_iterations.labels(session_id=session_id).inc() + self._prom_ucb_scores.labels(session_id=session_id).observe(ucb_score) + + def record_mcts_simulation(self, session_id: str) -> None: + """Record an MCTS simulation.""" + self._mcts_metrics[session_id].total_simulations += 1 + if self._prometheus_initialized: + self._prom_mcts_simulations.labels(session_id=session_id).inc() + + def update_mcts_tree_stats( + self, + session_id: str, + tree_depth: int, + total_nodes: int, + best_action_visits: int = 0, + best_action_value: float = 0.0, + ) -> None: + """Update MCTS tree statistics.""" + metrics = self._mcts_metrics[session_id] + metrics.tree_depth = tree_depth + metrics.total_nodes = total_nodes + metrics.best_action_visits = best_action_visits + metrics.best_action_value = best_action_value + + if self._prometheus_initialized: + self._prom_mcts_tree_depth.labels(session_id=session_id).set(tree_depth) + self._prom_mcts_total_nodes.labels(session_id=session_id).set(total_nodes) + + def record_agent_execution( + self, + agent_name: str, + execution_time_ms: float, + confidence: float, + success: bool = True, + ) -> None: + """Record agent execution metrics.""" + if agent_name not in self._agent_metrics: + self._agent_metrics[agent_name] = AgentMetrics(name=agent_name) + + metrics = self._agent_metrics[agent_name] + metrics.executions += 1 + metrics.total_time_ms += execution_time_ms + metrics.confidence_scores.append(confidence) + metrics.avg_confidence = sum(metrics.confidence_scores) / len(metrics.confidence_scores) + + if success: + metrics.success_count += 1 + else: + metrics.error_count += 1 + + # Memory sample + memory_mb = self._process.memory_info().rss / (1024 * 1024) + metrics.memory_usage_mb.append(memory_mb) + + # Prometheus + if self._prometheus_initialized: + self._prom_agent_executions.labels(agent_name=agent_name).inc() + self._prom_agent_confidence.labels(agent_name=agent_name).observe(confidence) + self._prom_agent_execution_time.labels(agent_name=agent_name).observe(execution_time_ms) + + def record_node_timing(self, node_name: str, execution_time_ms: float) -> None: + """Record execution time for a graph node.""" + self._node_timings[node_name].append(execution_time_ms) + + def record_request_latency(self, latency_ms: float) -> None: + """Record end-to-end request latency.""" + self._request_latencies.append(latency_ms) + if self._prometheus_initialized: + self._prom_request_latency.observe(latency_ms) + + def sample_system_metrics(self) -> dict[str, float]: + """Sample current system metrics.""" + memory_info = self._process.memory_info() + cpu_percent = self._process.cpu_percent() + + sample = { + "timestamp": datetime.utcnow().isoformat(), + "memory_rss_mb": memory_info.rss / (1024 * 1024), + "memory_vms_mb": memory_info.vms / (1024 * 1024), + "cpu_percent": cpu_percent, + "thread_count": self._process.num_threads(), + "open_files": len(self._process.open_files()), + } + + self._memory_samples.append(sample) + + if self._prometheus_initialized: + self._prom_memory_usage.set(sample["memory_rss_mb"]) + self._prom_cpu_percent.set(cpu_percent) + + return sample + + def get_mcts_summary(self, session_id: str) -> dict[str, Any]: + """Get summary statistics for MCTS session.""" + metrics = self._mcts_metrics.get(session_id) + if not metrics: + return {} + + def safe_avg(lst: list[float]) -> float: + return sum(lst) / len(lst) if lst else 0.0 + + def safe_percentile(lst: list[float], p: float) -> float: + if not lst: + return 0.0 + sorted_lst = sorted(lst) + idx = 
int(len(sorted_lst) * p) + return sorted_lst[min(idx, len(sorted_lst) - 1)] + + return { + "session_id": session_id, + "total_iterations": metrics.iterations, + "total_simulations": metrics.total_simulations, + "tree_depth": metrics.tree_depth, + "total_nodes": metrics.total_nodes, + "best_action_visits": metrics.best_action_visits, + "best_action_value": round(metrics.best_action_value, 4), + "ucb_scores": { + "count": len(metrics.ucb_scores), + "mean": round(safe_avg(metrics.ucb_scores), 4), + "min": round(min(metrics.ucb_scores), 4) if metrics.ucb_scores else 0.0, + "max": round(max(metrics.ucb_scores), 4) if metrics.ucb_scores else 0.0, + "p50": round(safe_percentile(metrics.ucb_scores, 0.5), 4), + "p95": round(safe_percentile(metrics.ucb_scores, 0.95), 4), + }, + "timing_ms": { + "selection_avg": round(safe_avg(metrics.selection_times_ms), 2), + "expansion_avg": round(safe_avg(metrics.expansion_times_ms), 2), + "simulation_avg": round(safe_avg(metrics.simulation_times_ms), 2), + "backprop_avg": round(safe_avg(metrics.backprop_times_ms), 2), + }, + } + + def get_agent_summary(self, agent_name: str) -> dict[str, Any]: + """Get summary statistics for an agent.""" + metrics = self._agent_metrics.get(agent_name) + if not metrics: + return {} + + def safe_avg(lst: list[float]) -> float: + return sum(lst) / len(lst) if lst else 0.0 + + return { + "agent_name": agent_name, + "total_executions": metrics.executions, + "success_count": metrics.success_count, + "error_count": metrics.error_count, + "success_rate": round(metrics.success_count / max(metrics.executions, 1), 4), + "avg_execution_time_ms": round(metrics.total_time_ms / max(metrics.executions, 1), 2), + "total_time_ms": round(metrics.total_time_ms, 2), + "confidence": { + "mean": round(safe_avg(metrics.confidence_scores), 4), + "min": round(min(metrics.confidence_scores), 4) if metrics.confidence_scores else 0.0, + "max": round(max(metrics.confidence_scores), 4) if metrics.confidence_scores else 0.0, + }, + "avg_memory_mb": round(safe_avg(metrics.memory_usage_mb), 2), + } + + def get_node_timing_summary(self) -> dict[str, dict[str, float]]: + """Get timing summary for all graph nodes.""" + summary = {} + for node_name, timings in self._node_timings.items(): + if timings: + summary[node_name] = { + "count": len(timings), + "mean_ms": round(sum(timings) / len(timings), 2), + "min_ms": round(min(timings), 2), + "max_ms": round(max(timings), 2), + "total_ms": round(sum(timings), 2), + } + return summary + + def get_full_report(self) -> dict[str, Any]: + """Generate a comprehensive metrics report.""" + # Sample current system state + current_system = self.sample_system_metrics() + + report = { + "report_time": datetime.utcnow().isoformat(), + "uptime_seconds": (datetime.utcnow() - self._start_time).total_seconds(), + "system_metrics": current_system, + "mcts_sessions": {session_id: self.get_mcts_summary(session_id) for session_id in self._mcts_metrics}, + "agents": {agent_name: self.get_agent_summary(agent_name) for agent_name in self._agent_metrics}, + "node_timings": self.get_node_timing_summary(), + "request_latencies": { + "count": len(self._request_latencies), + "mean_ms": round(sum(self._request_latencies) / max(len(self._request_latencies), 1), 2), + "min_ms": round(min(self._request_latencies), 2) if self._request_latencies else 0.0, + "max_ms": round(max(self._request_latencies), 2) if self._request_latencies else 0.0, + }, + } + + return report + + def export_prometheus_format(self) -> str: + """Export metrics in Prometheus text 
format.""" + if PROMETHEUS_AVAILABLE: + return generate_latest(REGISTRY).decode("utf-8") + else: + return "# Prometheus client not available\n" + + def reset(self) -> None: + """Reset all collected metrics.""" + self._mcts_metrics.clear() + self._agent_metrics.clear() + self._node_timings.clear() + self._request_latencies.clear() + self._memory_samples.clear() + self._start_time = datetime.utcnow() + + +# Convenience singleton accessors +def mcts_metrics() -> MetricsCollector: + """Get the singleton MetricsCollector instance.""" + return MetricsCollector.get_instance() + + +def agent_metrics() -> MetricsCollector: + """Alias for mcts_metrics() - same singleton.""" + return MetricsCollector.get_instance() + + +class MetricsTimer: + """Context manager for timing operations and recording metrics.""" + + def __init__( + self, + collector: MetricsCollector | None = None, + node_name: str | None = None, + mcts_session_id: str | None = None, + agent_name: str | None = None, + ): + self.collector = collector or MetricsCollector.get_instance() + self.node_name = node_name + self.mcts_session_id = mcts_session_id + self.agent_name = agent_name + self.start_time: float = 0.0 + self.elapsed_ms: float = 0.0 + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.elapsed_ms = (time.perf_counter() - self.start_time) * 1000 + + if self.node_name: + self.collector.record_node_timing(self.node_name, self.elapsed_ms) + + return False + + async def __aenter__(self): + self.start_time = time.perf_counter() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.elapsed_ms = (time.perf_counter() - self.start_time) * 1000 + + if self.node_name: + self.collector.record_node_timing(self.node_name, self.elapsed_ms) + + return False diff --git a/src/observability/profiling.py b/src/observability/profiling.py new file mode 100644 index 0000000000000000000000000000000000000000..fc7e30dd7808d4ca62419e80851293cf134c1849 --- /dev/null +++ b/src/observability/profiling.py @@ -0,0 +1,580 @@ +""" +Performance profiling infrastructure for multi-agent MCTS framework. + +Provides: +- Context manager for timing code blocks +- Memory profiling hooks +- Async-aware profiling +- Report generation +""" + +import asyncio +import functools +import time +from collections import defaultdict +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Optional + +import psutil + +from .logging import get_logger + + +@dataclass +class TimingResult: + """Result of a timed operation.""" + + name: str + elapsed_ms: float + start_time: float + end_time: float + memory_start_mb: float + memory_end_mb: float + memory_delta_mb: float + success: bool = True + error: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ProfilingSession: + """Container for profiling results within a session.""" + + session_id: str + start_time: datetime = field(default_factory=datetime.utcnow) + timings: list[TimingResult] = field(default_factory=list) + memory_samples: list[dict[str, float]] = field(default_factory=list) + cpu_samples: list[float] = field(default_factory=list) + markers: list[dict[str, Any]] = field(default_factory=list) + + +class AsyncProfiler: + """ + Async-aware profiler for multi-agent MCTS framework. 
+ + Tracks: + - Execution times for async operations + - Memory usage patterns + - CPU utilization + - Custom markers and events + """ + + _instance: Optional["AsyncProfiler"] = None + + def __init__(self): + self.logger = get_logger("observability.profiling") + self._sessions: dict[str, ProfilingSession] = {} + self._current_session: str | None = None + self._process = psutil.Process() + self._aggregate_timings: dict[str, list[float]] = defaultdict(list) + + @classmethod + def get_instance(cls) -> "AsyncProfiler": + """Get singleton instance of AsyncProfiler.""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def start_session(self, session_id: str | None = None) -> str: + """Start a new profiling session.""" + if session_id is None: + session_id = f"session_{datetime.utcnow().strftime('%Y%m%d_%H%M%S_%f')}" + + self._sessions[session_id] = ProfilingSession(session_id=session_id) + self._current_session = session_id + self.logger.info(f"Started profiling session: {session_id}") + return session_id + + def end_session(self, session_id: str | None = None) -> ProfilingSession: + """End a profiling session and return results.""" + if session_id is None: + session_id = self._current_session + + if session_id not in self._sessions: + raise ValueError(f"Unknown session: {session_id}") + + session = self._sessions[session_id] + self.logger.info(f"Ended profiling session: {session_id}") + + if self._current_session == session_id: + self._current_session = None + + return session + + def get_current_session(self) -> ProfilingSession | None: + """Get current profiling session.""" + if self._current_session and self._current_session in self._sessions: + return self._sessions[self._current_session] + return None + + @contextmanager + def time_block( + self, + name: str, + session_id: str | None = None, + metadata: dict[str, Any] | None = None, + ): + """ + Context manager for timing synchronous code blocks. 
+ + Args: + name: Name of the operation being timed + session_id: Optional session ID (uses current if not specified) + metadata: Additional metadata to record + + Example: + with profiler.time_block("mcts.selection"): + # perform selection + """ + if session_id is None: + session_id = self._current_session + + start_time = time.perf_counter() + memory_start = self._process.memory_info().rss / (1024 * 1024) + success = True + error = None + + try: + yield + except Exception as e: + success = False + error = str(e) + raise + finally: + end_time = time.perf_counter() + memory_end = self._process.memory_info().rss / (1024 * 1024) + elapsed_ms = (end_time - start_time) * 1000 + + result = TimingResult( + name=name, + elapsed_ms=elapsed_ms, + start_time=start_time, + end_time=end_time, + memory_start_mb=memory_start, + memory_end_mb=memory_end, + memory_delta_mb=memory_end - memory_start, + success=success, + error=error, + metadata=metadata or {}, + ) + + # Record in session if available + if session_id and session_id in self._sessions: + self._sessions[session_id].timings.append(result) + + # Record in aggregates + self._aggregate_timings[name].append(elapsed_ms) + + self.logger.debug( + f"Timed block '{name}': {elapsed_ms:.2f}ms", + extra={ + "profiling": { + "name": name, + "elapsed_ms": round(elapsed_ms, 2), + "memory_delta_mb": round(result.memory_delta_mb, 2), + "success": success, + } + }, + ) + + @asynccontextmanager + async def async_time_block( + self, + name: str, + session_id: str | None = None, + metadata: dict[str, Any] | None = None, + ): + """ + Async context manager for timing asynchronous code blocks. + + Args: + name: Name of the operation being timed + session_id: Optional session ID + metadata: Additional metadata + + Example: + async with profiler.async_time_block("llm.call"): + await model.generate(...) 
+ """ + if session_id is None: + session_id = self._current_session + + start_time = time.perf_counter() + memory_start = self._process.memory_info().rss / (1024 * 1024) + success = True + error = None + + try: + yield + except Exception as e: + success = False + error = str(e) + raise + finally: + end_time = time.perf_counter() + memory_end = self._process.memory_info().rss / (1024 * 1024) + elapsed_ms = (end_time - start_time) * 1000 + + result = TimingResult( + name=name, + elapsed_ms=elapsed_ms, + start_time=start_time, + end_time=end_time, + memory_start_mb=memory_start, + memory_end_mb=memory_end, + memory_delta_mb=memory_end - memory_start, + success=success, + error=error, + metadata=metadata or {}, + ) + + if session_id and session_id in self._sessions: + self._sessions[session_id].timings.append(result) + + self._aggregate_timings[name].append(elapsed_ms) + + self.logger.debug( + f"Async timed block '{name}': {elapsed_ms:.2f}ms", + extra={ + "profiling": { + "name": name, + "elapsed_ms": round(elapsed_ms, 2), + "memory_delta_mb": round(result.memory_delta_mb, 2), + "success": success, + } + }, + ) + + def sample_memory(self, session_id: str | None = None) -> dict[str, float]: + """Sample current memory usage.""" + memory_info = self._process.memory_info() + + sample = { + "timestamp": time.time(), + "rss_mb": memory_info.rss / (1024 * 1024), + "vms_mb": memory_info.vms / (1024 * 1024), + "percent": self._process.memory_percent(), + } + + if session_id is None: + session_id = self._current_session + + if session_id and session_id in self._sessions: + self._sessions[session_id].memory_samples.append(sample) + + return sample + + def sample_cpu(self, session_id: str | None = None) -> float: + """Sample current CPU usage.""" + cpu_percent = self._process.cpu_percent() + + if session_id is None: + session_id = self._current_session + + if session_id and session_id in self._sessions: + self._sessions[session_id].cpu_samples.append(cpu_percent) + + return cpu_percent + + def add_marker( + self, + name: str, + data: dict[str, Any] | None = None, + session_id: str | None = None, + ) -> None: + """Add a custom marker/event to the profiling session.""" + marker = { + "timestamp": time.time(), + "name": name, + "data": data or {}, + } + + if session_id is None: + session_id = self._current_session + + if session_id and session_id in self._sessions: + self._sessions[session_id].markers.append(marker) + + self.logger.debug(f"Added profiling marker: {name}") + + def get_timing_summary(self, name: str | None = None) -> dict[str, Any]: + """ + Get summary statistics for timed operations. 
+ + Args: + name: Optional specific operation name (all if None) + + Returns: + Summary statistics + """ + if name: + timings = self._aggregate_timings.get(name, []) + if not timings: + return {} + return self._compute_stats(name, timings) + else: + return {op_name: self._compute_stats(op_name, times) for op_name, times in self._aggregate_timings.items()} + + def _compute_stats(self, name: str, timings: list[float]) -> dict[str, Any]: + """Compute statistics for a list of timings.""" + if not timings: + return {} + + sorted_timings = sorted(timings) + n = len(sorted_timings) + + return { + "name": name, + "count": n, + "total_ms": round(sum(timings), 2), + "mean_ms": round(sum(timings) / n, 2), + "min_ms": round(min(timings), 2), + "max_ms": round(max(timings), 2), + "p50_ms": round(sorted_timings[n // 2], 2), + "p90_ms": round(sorted_timings[int(n * 0.9)], 2), + "p95_ms": round(sorted_timings[int(n * 0.95)], 2), + "p99_ms": round(sorted_timings[min(int(n * 0.99), n - 1)], 2), + } + + def reset(self) -> None: + """Reset all profiling data.""" + self._sessions.clear() + self._current_session = None + self._aggregate_timings.clear() + self.logger.info("Profiler reset") + + +class MemoryProfiler: + """ + Memory-focused profiler for tracking memory usage patterns. + """ + + def __init__(self): + self.logger = get_logger("observability.profiling.memory") + self._process = psutil.Process() + self._baseline: float | None = None + self._peak: float = 0.0 + self._samples: list[dict[str, Any]] = [] + + def set_baseline(self) -> float: + """Set current memory as baseline.""" + self._baseline = self._process.memory_info().rss / (1024 * 1024) + self.logger.info(f"Memory baseline set: {self._baseline:.2f} MB") + return self._baseline + + def get_current(self) -> float: + """Get current memory usage in MB.""" + return self._process.memory_info().rss / (1024 * 1024) + + def get_delta(self) -> float: + """Get memory change from baseline.""" + if self._baseline is None: + self.set_baseline() + return 0.0 + + current = self.get_current() + return current - self._baseline + + def sample(self, label: str = "") -> dict[str, Any]: + """Take a memory sample with optional label.""" + memory_info = self._process.memory_info() + current_mb = memory_info.rss / (1024 * 1024) + + if current_mb > self._peak: + self._peak = current_mb + + sample = { + "timestamp": datetime.utcnow().isoformat(), + "label": label, + "rss_mb": round(current_mb, 2), + "vms_mb": round(memory_info.vms / (1024 * 1024), 2), + "percent": round(self._process.memory_percent(), 2), + "delta_from_baseline_mb": round(self.get_delta(), 2) if self._baseline else 0.0, + "peak_mb": round(self._peak, 2), + } + + self._samples.append(sample) + self.logger.debug(f"Memory sample [{label}]: {current_mb:.2f} MB") + return sample + + def check_leak(self, threshold_mb: float = 10.0) -> dict[str, Any]: + """ + Check for potential memory leak. 
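+
+        A potential leak is reported when the current RSS exceeds the baseline
+        by more than threshold_mb megabytes.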
+ + Args: + threshold_mb: Memory increase threshold to consider as leak + + Returns: + Leak detection result + """ + if self._baseline is None: + return {"status": "no_baseline", "message": "Set baseline first"} + + current = self.get_current() + delta = current - self._baseline + + if delta > threshold_mb: + self.logger.warning(f"Potential memory leak detected: {delta:.2f} MB increase") + return { + "status": "potential_leak", + "baseline_mb": round(self._baseline, 2), + "current_mb": round(current, 2), + "delta_mb": round(delta, 2), + "threshold_mb": threshold_mb, + } + + return { + "status": "ok", + "baseline_mb": round(self._baseline, 2), + "current_mb": round(current, 2), + "delta_mb": round(delta, 2), + "threshold_mb": threshold_mb, + } + + def get_summary(self) -> dict[str, Any]: + """Get memory profiling summary.""" + if not self._samples: + return {"message": "No samples collected"} + + rss_values = [s["rss_mb"] for s in self._samples] + + return { + "sample_count": len(self._samples), + "baseline_mb": round(self._baseline, 2) if self._baseline else None, + "current_mb": round(self.get_current(), 2), + "peak_mb": round(self._peak, 2), + "mean_mb": round(sum(rss_values) / len(rss_values), 2), + "min_mb": round(min(rss_values), 2), + "max_mb": round(max(rss_values), 2), + } + + +@contextmanager +def profile_block( + name: str, + metadata: dict[str, Any] | None = None, +): + """ + Convenience context manager for profiling a code block. + + Uses the global AsyncProfiler singleton. + + Example: + with profile_block("data_processing", {"batch_size": 100}): + process_data(batch) + """ + profiler = AsyncProfiler.get_instance() + with profiler.time_block(name, metadata=metadata): + yield + + +def generate_performance_report(session_id: str | None = None) -> dict[str, Any]: + """ + Generate a comprehensive performance report. + + Args: + session_id: Optional specific session (uses current if not specified) + + Returns: + Performance report with timing summaries, memory stats, etc. 
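+
+    Example:
+        # Illustrative; assumes some profiled blocks have already run
+        report = generate_performance_report()
+        print(report["timing_summary"])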
+ """ + profiler = AsyncProfiler.get_instance() + + report = { + "report_time": datetime.utcnow().isoformat(), + "timing_summary": profiler.get_timing_summary(), + } + + # Add session-specific data if available + session = profiler._sessions.get(session_id) if session_id else profiler.get_current_session() + + if session: + report["session"] = { + "session_id": session.session_id, + "start_time": session.start_time.isoformat(), + "timing_count": len(session.timings), + "memory_samples": len(session.memory_samples), + "cpu_samples": len(session.cpu_samples), + "markers_count": len(session.markers), + } + + # Compute session-specific stats + if session.timings: + session_times = {} + for timing in session.timings: + if timing.name not in session_times: + session_times[timing.name] = [] + session_times[timing.name].append(timing.elapsed_ms) + + report["session"]["timing_breakdown"] = { + name: profiler._compute_stats(name, times) for name, times in session_times.items() + } + + if session.memory_samples: + rss_values = [s["rss_mb"] for s in session.memory_samples] + report["session"]["memory_summary"] = { + "sample_count": len(rss_values), + "mean_mb": round(sum(rss_values) / len(rss_values), 2), + "min_mb": round(min(rss_values), 2), + "max_mb": round(max(rss_values), 2), + } + + if session.cpu_samples: + report["session"]["cpu_summary"] = { + "sample_count": len(session.cpu_samples), + "mean_percent": round(sum(session.cpu_samples) / len(session.cpu_samples), 2), + "min_percent": round(min(session.cpu_samples), 2), + "max_percent": round(max(session.cpu_samples), 2), + } + + # Current system state + process = psutil.Process() + report["current_system"] = { + "memory_mb": round(process.memory_info().rss / (1024 * 1024), 2), + "cpu_percent": process.cpu_percent(), + "thread_count": process.num_threads(), + } + + return report + + +def profile_function(name: str | None = None, metadata: dict[str, Any] | None = None): + """ + Decorator for profiling function execution. + + Args: + name: Optional custom name (defaults to function name) + metadata: Additional metadata + + Example: + @profile_function() + def process_batch(data): + ... + + @profile_function(name="custom_name") + async def async_operation(): + ... + """ + + def decorator(func): + op_name = name or f"{func.__module__}.{func.__name__}" + + @functools.wraps(func) + def sync_wrapper(*args, **kwargs): + profiler = AsyncProfiler.get_instance() + with profiler.time_block(op_name, metadata=metadata): + return func(*args, **kwargs) + + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + profiler = AsyncProfiler.get_instance() + async with profiler.async_time_block(op_name, metadata=metadata): + return await func(*args, **kwargs) + + if asyncio.iscoroutinefunction(func): + return async_wrapper + return sync_wrapper + + return decorator diff --git a/src/observability/tracing.py b/src/observability/tracing.py new file mode 100644 index 0000000000000000000000000000000000000000..bdfbb9ac16a2090dae052a59d637e6b4586ebae7 --- /dev/null +++ b/src/observability/tracing.py @@ -0,0 +1,384 @@ +""" +OpenTelemetry tracing infrastructure for multi-agent MCTS framework. 
+ +Provides: +- OpenTelemetry SDK integration +- Automatic span creation for key operations +- Trace context propagation +- OTLP exporter configuration from environment +- Custom attributes for MCTS metrics +- httpx instrumentation for LLM calls +""" + +import functools +import os +from contextlib import asynccontextmanager, contextmanager +from typing import Any, Optional + +from opentelemetry import trace +from opentelemetry.context import Context +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor +from opentelemetry.propagate import extract, inject +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, + ConsoleSpanExporter, + SimpleSpanProcessor, +) +from opentelemetry.trace import Span, SpanKind, Status, StatusCode +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + +from .logging import get_correlation_id + + +class TracingManager: + """ + Manages OpenTelemetry tracing configuration and lifecycle. + + Environment Variables: + OTEL_EXPORTER_OTLP_ENDPOINT: OTLP collector endpoint (default: localhost:4317) + OTEL_SERVICE_NAME: Service name for traces (default: mcts-framework) + OTEL_EXPORTER_TYPE: Exporter type (otlp, console, none) (default: otlp) + OTEL_TRACE_SAMPLE_RATE: Sampling rate 0.0-1.0 (default: 1.0) + """ + + _instance: Optional["TracingManager"] = None + _provider: TracerProvider | None = None + + def __init__(self): + self._initialized = False + self._httpx_instrumented = False + + @classmethod + def get_instance(cls) -> "TracingManager": + """Get singleton instance of TracingManager.""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def initialize( + self, + service_name: str | None = None, + otlp_endpoint: str | None = None, + exporter_type: str | None = None, + additional_resources: dict[str, str] | None = None, + ) -> None: + """ + Initialize OpenTelemetry tracing. 
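+
+        Explicit arguments take precedence over the corresponding OTEL_*
+        environment variables; repeated calls after successful initialization
+        are no-ops.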
+ + Args: + service_name: Service name for traces + otlp_endpoint: OTLP collector endpoint + exporter_type: Type of exporter (otlp, console, none) + additional_resources: Additional resource attributes + """ + if self._initialized: + return + + # Get configuration from environment or parameters + service_name = service_name or os.environ.get("OTEL_SERVICE_NAME", "mcts-framework") + otlp_endpoint = otlp_endpoint or os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "localhost:4317") + exporter_type = exporter_type or os.environ.get("OTEL_EXPORTER_TYPE", "otlp") + + # Build resource attributes + resource_attrs = { + SERVICE_NAME: service_name, + "service.version": os.environ.get("SERVICE_VERSION", "0.1.0"), + "deployment.environment": os.environ.get("ENVIRONMENT", "development"), + } + + if additional_resources: + resource_attrs.update(additional_resources) + + resource = Resource.create(resource_attrs) + + # Create tracer provider + self._provider = TracerProvider(resource=resource) + + # Configure exporter based on type + if exporter_type.lower() == "otlp": + exporter = OTLPSpanExporter( + endpoint=otlp_endpoint, + insecure=os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "true").lower() == "true", + ) + processor = BatchSpanProcessor(exporter) + elif exporter_type.lower() == "console": + exporter = ConsoleSpanExporter() + processor = SimpleSpanProcessor(exporter) + elif exporter_type.lower() == "none": + processor = None + else: + raise ValueError(f"Unknown exporter type: {exporter_type}") + + if processor: + self._provider.add_span_processor(processor) + + # Set as global provider + trace.set_tracer_provider(self._provider) + + # Instrument httpx for LLM calls + self._instrument_httpx() + + self._initialized = True + + def _instrument_httpx(self) -> None: + """Instrument httpx client for automatic tracing of HTTP requests.""" + if self._httpx_instrumented: + return + + try: + HTTPXClientInstrumentor().instrument() + self._httpx_instrumented = True + except Exception: + # httpx instrumentation is optional + pass + + def shutdown(self) -> None: + """Shutdown tracing provider.""" + if self._provider: + self._provider.shutdown() + self._initialized = False + + def get_tracer(self, name: str = "mcts-framework") -> trace.Tracer: + """Get a tracer instance.""" + if not self._initialized: + self.initialize() + return trace.get_tracer(name) + + +def get_tracer(name: str = "mcts-framework") -> trace.Tracer: + """Get a tracer instance from the global TracingManager.""" + return TracingManager.get_instance().get_tracer(name) + + +def add_mcts_attributes(span: Span, **attributes: Any) -> None: + """ + Add MCTS-specific attributes to a span. + + Common attributes: + - mcts.iteration: Current MCTS iteration number + - mcts.node_visits: Number of visits to current node + - mcts.node_value: Value of current node + - mcts.ucb_score: UCB score for selection + - mcts.exploration_weight: Exploration weight parameter + - mcts.tree_depth: Current depth in tree + - agent.name: Name of the agent + - agent.confidence: Agent confidence score + """ + for key, value in attributes.items(): + if value is not None: + # Prefix non-standard attributes + if not key.startswith(("mcts.", "agent.", "framework.")): + key = f"custom.{key}" + span.set_attribute(key, value) + + +@contextmanager +def trace_span( + name: str, + kind: SpanKind = SpanKind.INTERNAL, + attributes: dict[str, Any] | None = None, + record_exception: bool = True, + set_status_on_exception: bool = True, +): + """ + Context manager for creating a traced span. 
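+
+    The current correlation ID is automatically attached to the span as a
+    "correlation_id" attribute.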
+ + Args: + name: Name of the span + kind: Span kind (INTERNAL, CLIENT, SERVER, PRODUCER, CONSUMER) + attributes: Initial attributes for the span + record_exception: Record exceptions as span events + set_status_on_exception: Set span status to ERROR on exception + + Example: + with trace_span("mcts.selection", attributes={"mcts.iteration": 5}) as span: + # Perform selection + span.set_attribute("mcts.selected_node", node_id) + """ + tracer = get_tracer() + with tracer.start_as_current_span( + name, + kind=kind, + attributes=attributes or {}, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + ) as span: + # Add correlation ID as attribute + span.set_attribute("correlation_id", get_correlation_id()) + yield span + + +@asynccontextmanager +async def async_trace_span( + name: str, + kind: SpanKind = SpanKind.INTERNAL, + attributes: dict[str, Any] | None = None, + record_exception: bool = True, + set_status_on_exception: bool = True, +): + """ + Async context manager for creating a traced span. + + Same as trace_span but for async contexts. + """ + tracer = get_tracer() + with tracer.start_as_current_span( + name, + kind=kind, + attributes=attributes or {}, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + ) as span: + # Add correlation ID as attribute + span.set_attribute("correlation_id", get_correlation_id()) + yield span + + +def trace_operation( + name: str | None = None, + kind: SpanKind = SpanKind.INTERNAL, + attributes: dict[str, Any] | None = None, +): + """ + Decorator for tracing function execution. + + Args: + name: Span name (defaults to function name) + kind: Span kind + attributes: Additional attributes + + Example: + @trace_operation(attributes={"component": "mcts"}) + async def select_best_child(node): + ... + """ + + def decorator(func): + span_name = name or f"{func.__module__}.{func.__name__}" + + @functools.wraps(func) + def sync_wrapper(*args, **kwargs): + with trace_span(span_name, kind=kind, attributes=attributes) as span: + # Add function arguments as attributes (limited) + span.set_attribute("function.args_count", len(args)) + span.set_attribute("function.kwargs_count", len(kwargs)) + + result = func(*args, **kwargs) + + # Mark as successful + span.set_status(Status(StatusCode.OK)) + return result + + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + async with async_trace_span(span_name, kind=kind, attributes=attributes) as span: + # Add function arguments as attributes (limited) + span.set_attribute("function.args_count", len(args)) + span.set_attribute("function.kwargs_count", len(kwargs)) + + result = await func(*args, **kwargs) + + # Mark as successful + span.set_status(Status(StatusCode.OK)) + return result + + if asyncio.iscoroutinefunction(func): + return async_wrapper + return sync_wrapper + + return decorator + + +class SpanContextPropagator: + """ + Utility for propagating trace context across service boundaries. + + Example: + # Inject context into headers + headers = {} + propagator = SpanContextPropagator() + propagator.inject(headers) + + # Extract context from headers + context = propagator.extract(headers) + with trace_span("operation", context=context): + ... 
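+
+        # The raw traceparent header value can also be read directly
+        # (illustrative), e.g. to attach to a message envelope:
+        traceparent = propagator.get_trace_parent()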
+ """ + + def __init__(self): + self._propagator = TraceContextTextMapPropagator() + + def inject(self, carrier: dict[str, str], context: Context | None = None) -> None: + """Inject trace context into a carrier (e.g., HTTP headers).""" + inject(carrier, context=context) + + def extract(self, carrier: dict[str, str]) -> Context: + """Extract trace context from a carrier.""" + return extract(carrier) + + def get_trace_parent(self) -> str | None: + """Get the traceparent header value for the current span.""" + carrier = {} + self.inject(carrier) + return carrier.get("traceparent") + + +def record_mcts_iteration( + iteration: int, + selected_node_id: str, + ucb_score: float, + node_visits: int, + node_value: float, + tree_depth: int, +) -> None: + """ + Record MCTS iteration as a span event. + + Call this within an active span to add iteration details. + """ + current_span = trace.get_current_span() + if current_span: + current_span.add_event( + "mcts.iteration", + attributes={ + "mcts.iteration": iteration, + "mcts.selected_node_id": selected_node_id, + "mcts.ucb_score": ucb_score, + "mcts.node_visits": node_visits, + "mcts.node_value": node_value, + "mcts.tree_depth": tree_depth, + }, + ) + + +def record_agent_execution( + agent_name: str, + confidence: float, + execution_time_ms: float, + success: bool, + error: str | None = None, +) -> None: + """ + Record agent execution as a span event. + + Call this within an active span to add agent execution details. + """ + current_span = trace.get_current_span() + if current_span: + attrs = { + "agent.name": agent_name, + "agent.confidence": confidence, + "agent.execution_time_ms": execution_time_ms, + "agent.success": success, + } + if error: + attrs["agent.error"] = error + + current_span.add_event("agent.execution", attributes=attrs) + + +# Import asyncio for decorator +import asyncio # noqa: E402 diff --git a/src/storage/__init__.py b/src/storage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..386cf70cd4bf2be6ebdcdbfa47098d69a1343c03 --- /dev/null +++ b/src/storage/__init__.py @@ -0,0 +1,31 @@ +# Storage Module +""" +Storage infrastructure for multi-agent MCTS framework. + +Includes: +- Async S3 client with retry strategies +- Content-hash based idempotent keys +- Compression support +- Pinecone vector storage for agent selection history +""" + +from .s3_client import S3Config, S3StorageClient + +# Pinecone integration (optional) +try: + from .pinecone_store import ( # noqa: F401 + PINECONE_AVAILABLE, + PineconeVectorStore, + ) + + _pinecone_exports = [ + "PineconeVectorStore", + "PINECONE_AVAILABLE", + ] +except ImportError: + _pinecone_exports = [] + +__all__ = [ + "S3StorageClient", + "S3Config", +] + _pinecone_exports diff --git a/src/storage/pinecone_store.py b/src/storage/pinecone_store.py new file mode 100644 index 0000000000000000000000000000000000000000..e52b2fa58b828561710f1c800b3c098c2f508725 --- /dev/null +++ b/src/storage/pinecone_store.py @@ -0,0 +1,426 @@ +""" +Pinecone vector storage integration for Meta-Controller features and predictions. + +Provides semantic search and retrieval of agent selection history using vector embeddings. 
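+
+Example (illustrative sketch; assumes PINECONE_API_KEY and PINECONE_HOST are set,
+and that features / prediction are existing MetaControllerFeatures and
+MetaControllerPrediction instances):
+
+    store = PineconeVectorStore()
+    if store.is_available:
+        store.store_prediction(features, prediction)
+        similar = store.find_similar_decisions(features, top_k=5)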
+""" + +import os +import uuid +from datetime import datetime +from typing import Any + +# Check if pinecone is available +try: + from pinecone import Pinecone + + PINECONE_AVAILABLE = True +except ImportError: + PINECONE_AVAILABLE = False + Pinecone = None # type: ignore + +from src.agents.meta_controller.base import MetaControllerFeatures, MetaControllerPrediction +from src.agents.meta_controller.utils import normalize_features + + +class PineconeVectorStore: + """ + Vector storage for Meta-Controller features and predictions using Pinecone. + + Stores agent selection decisions as vectors for: + - Finding similar past routing decisions + - Analyzing patterns in agent selection + - Building retrieval-augmented routing strategies + """ + + # Dimension of normalized feature vectors + VECTOR_DIMENSION = 10 + + def __init__( + self, + api_key: str | None = None, + host: str | None = None, + namespace: str = "meta_controller", + auto_init: bool = True, + ): + """ + Initialize Pinecone vector store. + + Args: + api_key: Pinecone API key (if None, reads from PINECONE_API_KEY env var) + host: Pinecone host URL (if None, reads from PINECONE_HOST env var) + namespace: Namespace for storing vectors (default: "meta_controller") + auto_init: Whether to initialize Pinecone client immediately + """ + self._api_key = api_key or os.environ.get("PINECONE_API_KEY") + self._host = host or os.environ.get("PINECONE_HOST") + self.namespace = namespace + self._client: Any = None + self._index: Any = None + self._is_initialized = False + self._operation_buffer: list[dict[str, Any]] = [] + + if not PINECONE_AVAILABLE: + print("Warning: pinecone package not installed. Install with: pip install pinecone") + return + + if auto_init and self._api_key and self._host: + self._initialize() + + def _initialize(self) -> None: + """Initialize Pinecone client and index connection.""" + if not PINECONE_AVAILABLE: + return + + if self._api_key and self._host: + try: + self._client = Pinecone(api_key=self._api_key) + self._index = self._client.Index(host=self._host) + self._is_initialized = True + except Exception as e: + print(f"Warning: Failed to initialize Pinecone: {e}") + self._is_initialized = False + + @property + def is_available(self) -> bool: + """Check if Pinecone is available and configured.""" + return PINECONE_AVAILABLE and self._is_initialized and self._api_key is not None and self._host is not None + + def store_prediction( + self, + features: MetaControllerFeatures, + prediction: MetaControllerPrediction, + metadata: dict[str, Any] | None = None, + ) -> str | None: + """ + Store a prediction along with its input features. 
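+
+        If Pinecone is unavailable, the operation is buffered in memory and can
+        be replayed later via flush_buffer().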
+ + Args: + features: Input features used for the prediction + prediction: The prediction result + metadata: Optional additional metadata + + Returns: + Vector ID if successful, None otherwise + """ + if not self.is_available: + # Buffer the operation for when Pinecone becomes available + self._operation_buffer.append( + { + "type": "store_prediction", + "features": features, + "prediction": prediction, + "metadata": metadata, + "timestamp": datetime.now().isoformat(), + } + ) + return None + + try: + # Normalize features to create the vector + vector = normalize_features(features) + + # Generate unique ID + vector_id = str(uuid.uuid4()) + + # Build metadata + vector_metadata = { + "selected_agent": prediction.agent, + "confidence": prediction.confidence, + "hrm_prob": prediction.probabilities.get("hrm", 0.0), + "trm_prob": prediction.probabilities.get("trm", 0.0), + "mcts_prob": prediction.probabilities.get("mcts", 0.0), + "timestamp": datetime.now().isoformat(), + "iteration": features.iteration, + "query_length": features.query_length, + "last_agent": features.last_agent, + "has_rag_context": features.has_rag_context, + } + + if metadata: + vector_metadata.update(metadata) + + # Upsert to Pinecone + self._index.upsert( + vectors=[ + { + "id": vector_id, + "values": vector, + "metadata": vector_metadata, + } + ], + namespace=self.namespace, + ) + + return vector_id + + except Exception as e: + print(f"Warning: Failed to store prediction in Pinecone: {e}") + return None + + def find_similar_decisions( + self, + features: MetaControllerFeatures, + top_k: int = 5, + include_metadata: bool = True, + ) -> list[dict[str, Any]]: + """ + Find similar past routing decisions based on current features. + + Args: + features: Current features to find similar decisions for + top_k: Number of similar decisions to return + include_metadata: Whether to include metadata in results + + Returns: + List of similar decisions with scores and metadata + """ + if not self.is_available: + return [] + + try: + # Normalize features to create query vector + query_vector = normalize_features(features) + + # Query Pinecone + results = self._index.query( + vector=query_vector, + top_k=top_k, + include_metadata=include_metadata, + namespace=self.namespace, + ) + + # Format results + similar_decisions = [] + for match in results.get("matches", []): + decision = { + "id": match.get("id"), + "score": match.get("score"), + } + if include_metadata and "metadata" in match: + decision["metadata"] = match["metadata"] + similar_decisions.append(decision) + + return similar_decisions + + except Exception as e: + print(f"Warning: Failed to query Pinecone: {e}") + return [] + + def get_agent_distribution( + self, + features: MetaControllerFeatures, + top_k: int = 10, + ) -> dict[str, float]: + """ + Get the distribution of agent selections for similar past decisions. + + Useful for rule-based fallback that considers historical patterns. 
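+
+        If no similar decisions are found (or Pinecone is unavailable), an
+        all-zero distribution is returned.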
+ + Args: + features: Current features + top_k: Number of similar decisions to consider + + Returns: + Dictionary mapping agent names to selection frequency + """ + similar = self.find_similar_decisions(features, top_k=top_k, include_metadata=True) + + if not similar: + return {"hrm": 0.0, "trm": 0.0, "mcts": 0.0} + + # Count agent selections + counts = {"hrm": 0, "trm": 0, "mcts": 0} + total = 0 + + for decision in similar: + if "metadata" in decision: + agent = decision["metadata"].get("selected_agent") + if agent in counts: + counts[agent] += 1 + total += 1 + + # Convert to distribution + if total > 0: + return {agent: count / total for agent, count in counts.items()} + else: + return {"hrm": 0.0, "trm": 0.0, "mcts": 0.0} + + def store_batch( + self, + features_list: list[MetaControllerFeatures], + predictions_list: list[MetaControllerPrediction], + batch_metadata: dict[str, Any] | None = None, + ) -> int: + """ + Store multiple predictions in a batch. + + Args: + features_list: List of input features + predictions_list: List of corresponding predictions + batch_metadata: Optional metadata to apply to all vectors + + Returns: + Number of vectors successfully stored + """ + if not self.is_available: + # Buffer for later + self._operation_buffer.append( + { + "type": "store_batch", + "features_list": features_list, + "predictions_list": predictions_list, + "batch_metadata": batch_metadata, + "timestamp": datetime.now().isoformat(), + } + ) + return 0 + + if len(features_list) != len(predictions_list): + raise ValueError("Features and predictions lists must have same length") + + try: + vectors = [] + for features, prediction in zip(features_list, predictions_list, strict=False): + vector_id = str(uuid.uuid4()) + vector_values = normalize_features(features) + + metadata = { + "selected_agent": prediction.agent, + "confidence": prediction.confidence, + "hrm_prob": prediction.probabilities.get("hrm", 0.0), + "trm_prob": prediction.probabilities.get("trm", 0.0), + "mcts_prob": prediction.probabilities.get("mcts", 0.0), + "timestamp": datetime.now().isoformat(), + "iteration": features.iteration, + "query_length": features.query_length, + "last_agent": features.last_agent, + "has_rag_context": features.has_rag_context, + } + + if batch_metadata: + metadata.update(batch_metadata) + + vectors.append( + { + "id": vector_id, + "values": vector_values, + "metadata": metadata, + } + ) + + # Upsert batch to Pinecone + self._index.upsert(vectors=vectors, namespace=self.namespace) + + return len(vectors) + + except Exception as e: + print(f"Warning: Failed to store batch in Pinecone: {e}") + return 0 + + def delete_namespace(self) -> bool: + """ + Delete all vectors in the current namespace. + + Use with caution! This permanently deletes all stored data. + + Returns: + True if successful, False otherwise + """ + if not self.is_available: + return False + + try: + self._index.delete(delete_all=True, namespace=self.namespace) + return True + except Exception as e: + print(f"Warning: Failed to delete namespace: {e}") + return False + + def get_stats(self) -> dict[str, Any]: + """ + Get statistics about the vector store. 
+ + Returns: + Dictionary containing index statistics + """ + if not self.is_available: + return { + "available": False, + "buffered_operations": len(self._operation_buffer), + } + + try: + stats = self._index.describe_index_stats() + return { + "available": True, + "total_vectors": stats.get("total_vector_count", 0), + "namespace_stats": stats.get("namespaces", {}), + "dimension": stats.get("dimension", self.VECTOR_DIMENSION), + "buffered_operations": len(self._operation_buffer), + } + except Exception as e: + return { + "available": True, + "error": str(e), + "buffered_operations": len(self._operation_buffer), + } + + def get_buffered_operations(self) -> list[dict[str, Any]]: + """ + Get all buffered operations (useful when Pinecone is not available). + + Returns: + List of buffered operation dictionaries + """ + return self._operation_buffer.copy() + + def clear_buffer(self) -> None: + """Clear the operations buffer.""" + self._operation_buffer.clear() + + def flush_buffer(self) -> int: + """ + Attempt to flush buffered operations to Pinecone. + + Returns: + Number of operations successfully flushed + """ + if not self.is_available or not self._operation_buffer: + return 0 + + flushed = 0 + remaining_buffer = [] + + for operation in self._operation_buffer: + try: + if operation["type"] == "store_prediction": + result = self.store_prediction( + features=operation["features"], + prediction=operation["prediction"], + metadata=operation.get("metadata"), + ) + if result: + flushed += 1 + else: + remaining_buffer.append(operation) + elif operation["type"] == "store_batch": + count = self.store_batch( + features_list=operation["features_list"], + predictions_list=operation["predictions_list"], + batch_metadata=operation.get("batch_metadata"), + ) + if count > 0: + flushed += 1 + else: + remaining_buffer.append(operation) + except Exception: + remaining_buffer.append(operation) + + self._operation_buffer = remaining_buffer + return flushed + + +__all__ = [ + "PineconeVectorStore", + "PINECONE_AVAILABLE", +] diff --git a/src/storage/s3_client.py b/src/storage/s3_client.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8638d4ff01683a7374c037046a96ac96d77ef2 --- /dev/null +++ b/src/storage/s3_client.py @@ -0,0 +1,596 @@ +""" +Async S3 storage client for multi-agent MCTS framework. 
+ +Provides: +- aioboto3 async client +- Retry strategy with tenacity +- Exponential backoff for failures +- Content-hash based idempotent keys +- Store: configs, MCTS stats, traces, logs +- Compression support +""" + +import asyncio +import gzip +import hashlib +import json +import os +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +import aioboto3 +from botocore.config import Config as BotoConfig +from botocore.exceptions import ClientError +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from src.observability.logging import get_logger + + +@dataclass +class S3Config: + """Configuration for S3 storage client.""" + + bucket_name: str = field(default_factory=lambda: os.environ.get("S3_BUCKET_NAME", "mcts-framework-storage")) + region_name: str = field(default_factory=lambda: os.environ.get("AWS_REGION", "us-east-1")) + endpoint_url: str | None = field(default_factory=lambda: os.environ.get("S3_ENDPOINT_URL")) + + # Retry configuration + max_retries: int = 5 + initial_wait_seconds: float = 1.0 + max_wait_seconds: float = 60.0 + exponential_base: float = 2.0 + + # Storage options + enable_compression: bool = True + compression_threshold_bytes: int = 1024 # Compress if larger than 1KB + use_content_hash_keys: bool = True + + # Prefixes for different data types + prefix_configs: str = "configs/" + prefix_mcts_stats: str = "mcts-stats/" + prefix_traces: str = "traces/" + prefix_logs: str = "logs/" + prefix_checkpoints: str = "checkpoints/" + + +class S3StorageClient: + """ + Async S3 storage client with retry logic and compression. + + Features: + - Automatic retries with exponential backoff + - Content-hash based idempotent keys for deduplication + - Gzip compression for large payloads + - Organized storage by data type + """ + + def __init__(self, config: S3Config | None = None): + """ + Initialize S3 storage client. 
+ + Args: + config: S3 configuration (uses environment variables if not provided) + """ + self.config = config or S3Config() + self.logger = get_logger("storage.s3") + self._session: aioboto3.Session | None = None + self._initialized = False + + # boto3 config with retries and timeouts + self._boto_config = BotoConfig( + retries={"max_attempts": 3, "mode": "adaptive"}, + connect_timeout=10, + read_timeout=30, + max_pool_connections=25, + ) + + async def initialize(self) -> None: + """Initialize the aioboto3 session.""" + if self._initialized: + return + + self._session = aioboto3.Session() + self._initialized = True + self.logger.info(f"S3 client initialized for bucket: {self.config.bucket_name}") + + async def close(self) -> None: + """Close the client (cleanup if needed).""" + self._initialized = False + self.logger.info("S3 client closed") + + def _get_client_params(self) -> dict[str, Any]: + """Get parameters for S3 client context manager.""" + params = { + "region_name": self.config.region_name, + "config": self._boto_config, + } + if self.config.endpoint_url: + params["endpoint_url"] = self.config.endpoint_url + return params + + def _compute_content_hash(self, data: bytes) -> str: + """Compute SHA256 hash of content for idempotent keys.""" + return hashlib.sha256(data).hexdigest() + + def _compress_data(self, data: bytes) -> bytes: + """Compress data using gzip.""" + return gzip.compress(data) + + def _decompress_data(self, data: bytes) -> bytes: + """Decompress gzip data.""" + return gzip.decompress(data) + + def _should_compress(self, data: bytes) -> bool: + """Determine if data should be compressed.""" + return self.config.enable_compression and len(data) >= self.config.compression_threshold_bytes + + def _generate_key( + self, + prefix: str, + name: str, + data: bytes | None = None, + timestamp: datetime | None = None, + ) -> str: + """ + Generate S3 key for object. + + Args: + prefix: Storage prefix (e.g., configs/, logs/) + name: Object name + data: Optional data for content-hash key generation + timestamp: Optional timestamp for key + + Returns: + Full S3 key + """ + if timestamp is None: + timestamp = datetime.utcnow() + + date_prefix = timestamp.strftime("%Y/%m/%d") + + if self.config.use_content_hash_keys and data: + content_hash = self._compute_content_hash(data)[:12] + return f"{prefix}{date_prefix}/{name}_{content_hash}" + else: + timestamp_str = timestamp.strftime("%H%M%S_%f") + return f"{prefix}{date_prefix}/{name}_{timestamp_str}" + + @retry( + retry=retry_if_exception_type((ClientError, asyncio.TimeoutError)), + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=1, max=60), + reraise=True, + ) + async def _put_object_with_retry( + self, + key: str, + body: bytes, + content_type: str = "application/octet-stream", + metadata: dict[str, str] | None = None, + ) -> dict[str, Any]: + """ + Put object to S3 with retry logic. + + Uses tenacity for exponential backoff retry strategy. 
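+        Up to 5 attempts are made, with waits growing exponentially from 1
+        second to a 60-second cap (see the retry decorator above).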
+ """ + if not self._session: + await self.initialize() + + async with self._session.client("s3", **self._get_client_params()) as s3: + extra_args = { + "ContentType": content_type, + } + if metadata: + extra_args["Metadata"] = metadata + + response = await s3.put_object( + Bucket=self.config.bucket_name, + Key=key, + Body=body, + **extra_args, + ) + + self.logger.debug( + "Uploaded object to S3", + extra={ + "s3_key": key, + "size_bytes": len(body), + "etag": response.get("ETag"), + }, + ) + + return { + "key": key, + "etag": response.get("ETag"), + "size_bytes": len(body), + "version_id": response.get("VersionId"), + } + + @retry( + retry=retry_if_exception_type((ClientError, asyncio.TimeoutError)), + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=1, max=60), + reraise=True, + ) + async def _get_object_with_retry(self, key: str) -> bytes: + """ + Get object from S3 with retry logic. + """ + if not self._session: + await self.initialize() + + async with self._session.client("s3", **self._get_client_params()) as s3: + response = await s3.get_object( + Bucket=self.config.bucket_name, + Key=key, + ) + + async with response["Body"] as stream: + data = await stream.read() + + self.logger.debug( + "Downloaded object from S3", + extra={ + "s3_key": key, + "size_bytes": len(data), + }, + ) + + return data + + async def store_config( + self, + config_name: str, + config_data: dict[str, Any], + session_id: str | None = None, + ) -> dict[str, Any]: + """ + Store configuration data to S3. + + Args: + config_name: Name of the configuration + config_data: Configuration dictionary + session_id: Optional session identifier + + Returns: + Upload result with key and metadata + """ + json_data = json.dumps(config_data, indent=2, default=str).encode("utf-8") + + key = self._generate_key( + prefix=self.config.prefix_configs, + name=config_name, + data=json_data if self.config.use_content_hash_keys else None, + ) + + if self._should_compress(json_data): + body = self._compress_data(json_data) + key += ".gz" + content_type = "application/gzip" + else: + body = json_data + content_type = "application/json" + + metadata = { + "config_name": config_name, + "original_size": str(len(json_data)), + "compressed": str(len(body) != len(json_data)), + } + if session_id: + metadata["session_id"] = session_id + + result = await self._put_object_with_retry(key, body, content_type, metadata) + self.logger.info(f"Stored config '{config_name}' to S3: {key}") + return result + + async def store_mcts_stats( + self, + session_id: str, + stats: dict[str, Any], + iteration: int | None = None, + ) -> dict[str, Any]: + """ + Store MCTS statistics to S3. 
+ + Args: + session_id: MCTS session identifier + stats: Statistics dictionary + iteration: Optional iteration number + + Returns: + Upload result + """ + name = f"{session_id}_stats" + if iteration is not None: + name += f"_iter{iteration}" + + json_data = json.dumps(stats, indent=2, default=str).encode("utf-8") + + key = self._generate_key( + prefix=self.config.prefix_mcts_stats, + name=name, + data=json_data if self.config.use_content_hash_keys else None, + ) + + if self._should_compress(json_data): + body = self._compress_data(json_data) + key += ".gz" + content_type = "application/gzip" + else: + body = json_data + content_type = "application/json" + + metadata = { + "session_id": session_id, + "data_type": "mcts_stats", + "original_size": str(len(json_data)), + } + if iteration is not None: + metadata["iteration"] = str(iteration) + + result = await self._put_object_with_retry(key, body, content_type, metadata) + self.logger.info(f"Stored MCTS stats for session '{session_id}' to S3: {key}") + return result + + async def store_traces( + self, + session_id: str, + trace_data: dict[str, Any] | list[dict[str, Any]], + ) -> dict[str, Any]: + """ + Store trace data to S3. + + Args: + session_id: Session identifier + trace_data: Trace spans/events + + Returns: + Upload result + """ + json_data = json.dumps(trace_data, indent=2, default=str).encode("utf-8") + + key = self._generate_key( + prefix=self.config.prefix_traces, + name=f"{session_id}_traces", + data=json_data if self.config.use_content_hash_keys else None, + ) + + if self._should_compress(json_data): + body = self._compress_data(json_data) + key += ".gz" + content_type = "application/gzip" + else: + body = json_data + content_type = "application/json" + + metadata = { + "session_id": session_id, + "data_type": "traces", + "original_size": str(len(json_data)), + } + + result = await self._put_object_with_retry(key, body, content_type, metadata) + self.logger.info(f"Stored traces for session '{session_id}' to S3: {key}") + return result + + async def store_logs( + self, + session_id: str, + log_entries: list[dict[str, Any]], + ) -> dict[str, Any]: + """ + Store log entries to S3. + + Args: + session_id: Session identifier + log_entries: List of JSON log entries + + Returns: + Upload result + """ + # Store as newline-delimited JSON (NDJSON) + ndjson_data = "\n".join(json.dumps(entry, default=str) for entry in log_entries).encode("utf-8") + + key = self._generate_key( + prefix=self.config.prefix_logs, + name=f"{session_id}_logs", + data=ndjson_data if self.config.use_content_hash_keys else None, + ) + + if self._should_compress(ndjson_data): + body = self._compress_data(ndjson_data) + key += ".gz" + content_type = "application/gzip" + else: + body = ndjson_data + key += ".ndjson" + content_type = "application/x-ndjson" + + metadata = { + "session_id": session_id, + "data_type": "logs", + "entry_count": str(len(log_entries)), + "original_size": str(len(ndjson_data)), + } + + result = await self._put_object_with_retry(key, body, content_type, metadata) + self.logger.info(f"Stored {len(log_entries)} log entries for session '{session_id}' to S3: {key}") + return result + + async def store_checkpoint( + self, + session_id: str, + checkpoint_data: dict[str, Any], + checkpoint_name: str = "checkpoint", + ) -> dict[str, Any]: + """ + Store framework checkpoint/state to S3. 
+ + Args: + session_id: Session identifier + checkpoint_data: Checkpoint state + checkpoint_name: Name for the checkpoint + + Returns: + Upload result + """ + json_data = json.dumps(checkpoint_data, indent=2, default=str).encode("utf-8") + + key = self._generate_key( + prefix=self.config.prefix_checkpoints, + name=f"{session_id}_{checkpoint_name}", + data=json_data if self.config.use_content_hash_keys else None, + ) + + if self._should_compress(json_data): + body = self._compress_data(json_data) + key += ".gz" + content_type = "application/gzip" + else: + body = json_data + content_type = "application/json" + + metadata = { + "session_id": session_id, + "data_type": "checkpoint", + "checkpoint_name": checkpoint_name, + "original_size": str(len(json_data)), + } + + result = await self._put_object_with_retry(key, body, content_type, metadata) + self.logger.info(f"Stored checkpoint '{checkpoint_name}' for session '{session_id}' to S3: {key}") + return result + + async def retrieve_object(self, key: str) -> bytes: + """ + Retrieve and decompress object from S3. + + Args: + key: S3 object key + + Returns: + Decompressed data bytes + """ + data = await self._get_object_with_retry(key) + + # Auto-decompress if gzip + if key.endswith(".gz"): + data = self._decompress_data(data) + + return data + + async def retrieve_json(self, key: str) -> Any: + """ + Retrieve JSON object from S3. + + Args: + key: S3 object key + + Returns: + Parsed JSON data + """ + data = await self.retrieve_object(key) + return json.loads(data.decode("utf-8")) + + async def list_objects( + self, + prefix: str, + max_keys: int = 1000, + ) -> list[dict[str, Any]]: + """ + List objects with given prefix. + + Args: + prefix: S3 key prefix + max_keys: Maximum objects to return + + Returns: + List of object metadata + """ + if not self._session: + await self.initialize() + + async with self._session.client("s3", **self._get_client_params()) as s3: + response = await s3.list_objects_v2( + Bucket=self.config.bucket_name, + Prefix=prefix, + MaxKeys=max_keys, + ) + + objects = [] + for obj in response.get("Contents", []): + objects.append( + { + "key": obj["Key"], + "size": obj["Size"], + "last_modified": obj["LastModified"], + "etag": obj["ETag"], + } + ) + + return objects + + async def delete_object(self, key: str) -> dict[str, Any]: + """ + Delete object from S3. + + Args: + key: S3 object key + + Returns: + Deletion result + """ + if not self._session: + await self.initialize() + + async with self._session.client("s3", **self._get_client_params()) as s3: + response = await s3.delete_object( + Bucket=self.config.bucket_name, + Key=key, + ) + + self.logger.info(f"Deleted object from S3: {key}") + + return { + "key": key, + "deleted": True, + "version_id": response.get("VersionId"), + } + + async def health_check(self) -> dict[str, Any]: + """ + Check S3 connectivity and bucket access. 
+ + Returns: + Health check result + """ + if not self._session: + await self.initialize() + + try: + async with self._session.client("s3", **self._get_client_params()) as s3: + # Try to head the bucket + await s3.head_bucket(Bucket=self.config.bucket_name) + + return { + "status": "healthy", + "bucket": self.config.bucket_name, + "region": self.config.region_name, + "timestamp": datetime.utcnow().isoformat(), + } + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + return { + "status": "unhealthy", + "bucket": self.config.bucket_name, + "error_code": error_code, + "error_message": str(e), + "timestamp": datetime.utcnow().isoformat(), + } diff --git a/src/training/__init__.py b/src/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a821b9576ba2235cddf64bf0091ad74b4f24c2fd --- /dev/null +++ b/src/training/__init__.py @@ -0,0 +1,27 @@ +""" +Training Module for Multi-Agent MCTS. + +Provides: +- Experiment tracking (Braintrust, W&B) +- Training pipelines +- Model evaluation +- Artifact management +""" + +from .experiment_tracker import ( + BraintrustTracker, + ExperimentConfig, + TrainingMetrics, + UnifiedExperimentTracker, + WandBTracker, +) + +__all__ = [ + "BraintrustTracker", + "WandBTracker", + "UnifiedExperimentTracker", + "TrainingMetrics", + "ExperimentConfig", +] + +__version__ = "1.0.0" diff --git a/src/training/data_generator.py b/src/training/data_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6738dcb82584e6733e0375fdaca8939a686a4d --- /dev/null +++ b/src/training/data_generator.py @@ -0,0 +1,622 @@ +""" +Synthetic data generator for training Neural Meta-Controllers. + +Provides functionality to generate synthetic training data for meta-controllers +that learn to select the optimal agent (HRM, TRM, or MCTS) based on system state. +""" + +import json +from dataclasses import asdict +from typing import Any + +import numpy as np +import torch + +from src.agents.meta_controller.base import MetaControllerFeatures +from src.agents.meta_controller.utils import features_to_text, normalize_features + + +class MetaControllerDataGenerator: + """ + Synthetic data generator for training neural meta-controllers. + + Generates labeled training data by creating random feature vectors + and determining the optimal agent based on weighted scoring rules. + The generator supports balanced and unbalanced datasets, multiple + output formats (tensors, text), and dataset persistence. + + Attributes: + seed: Random seed for reproducibility. + rng: NumPy random number generator instance. + AGENT_NAMES: List of valid agent names. + LABEL_TO_INDEX: Mapping from agent names to numeric indices. + INDEX_TO_LABEL: Mapping from numeric indices to agent names. + """ + + AGENT_NAMES = ["hrm", "trm", "mcts"] + LABEL_TO_INDEX = {"hrm": 0, "trm": 1, "mcts": 2} + INDEX_TO_LABEL = {0: "hrm", 1: "trm", 2: "mcts"} + + def __init__(self, seed: int = 42) -> None: + """ + Initialize the data generator with a random seed. + + Args: + seed: Random seed for reproducibility. Defaults to 42. + + Example: + >>> generator = MetaControllerDataGenerator(seed=42) + >>> generator.seed + 42 + """ + self.seed = seed + self.rng = np.random.default_rng(seed) + + def generate_single_sample(self) -> tuple[MetaControllerFeatures, str]: + """ + Generate a single training sample with features and optimal agent label. 
+ + Creates random features and determines the optimal agent based on + weighted scoring rules: + - If hrm_confidence > 0.7 and highest: select "hrm" + - If trm_confidence > 0.7 and highest: select "trm" + - If mcts_value > 0.6 and iteration > 3: select "mcts" + - Otherwise: select agent with highest score + + Returns: + Tuple of (MetaControllerFeatures, optimal_agent_label). + + Example: + >>> generator = MetaControllerDataGenerator(seed=42) + >>> features, label = generator.generate_single_sample() + >>> isinstance(features, MetaControllerFeatures) + True + >>> label in ['hrm', 'trm', 'mcts'] + True + """ + # Generate random features + hrm_confidence = float(self.rng.uniform(0, 1)) + trm_confidence = float(self.rng.uniform(0, 1)) + mcts_value = float(self.rng.uniform(0, 1)) + + # Consensus score is average of confidences plus noise + avg_confidence = (hrm_confidence + trm_confidence + mcts_value) / 3.0 + noise = float(self.rng.uniform(-0.1, 0.1)) + consensus_score = float(np.clip(avg_confidence + noise, 0.0, 1.0)) + + # Random categorical and discrete features + last_agent = str(self.rng.choice(["none", "hrm", "trm", "mcts"])) + iteration = int(self.rng.integers(0, 11)) # [0, 10] inclusive + query_length = int(self.rng.integers(10, 5001)) # [10, 5000] inclusive + has_rag_context = bool(self.rng.choice([True, False])) + + features = MetaControllerFeatures( + hrm_confidence=hrm_confidence, + trm_confidence=trm_confidence, + mcts_value=mcts_value, + consensus_score=consensus_score, + last_agent=last_agent, + iteration=iteration, + query_length=query_length, + has_rag_context=has_rag_context, + ) + + # Determine optimal agent based on weighted scoring + optimal_agent = self._determine_optimal_agent(features) + + return features, optimal_agent + + def _determine_optimal_agent(self, features: MetaControllerFeatures) -> str: + """ + Determine the optimal agent based on weighted scoring rules. + + Args: + features: MetaControllerFeatures to evaluate. + + Returns: + Name of the optimal agent ('hrm', 'trm', or 'mcts'). + """ + hrm_conf = features.hrm_confidence + trm_conf = features.trm_confidence + mcts_val = features.mcts_value + + # Check if HRM should be selected (high confidence and highest) + if hrm_conf > 0.7 and hrm_conf > trm_conf and hrm_conf > mcts_val: + return "hrm" + + # Check if TRM should be selected (high confidence and highest) + if trm_conf > 0.7 and trm_conf > hrm_conf and trm_conf > mcts_val: + return "trm" + + # Check if MCTS should be selected (good value and enough iterations) + if mcts_val > 0.6 and features.iteration > 3: + return "mcts" + + # Default: select agent with highest score + scores = {"hrm": hrm_conf, "trm": trm_conf, "mcts": mcts_val} + return max(scores, key=lambda k: scores[k]) + + def generate_dataset(self, num_samples: int = 1000) -> tuple[list[MetaControllerFeatures], list[str]]: + """ + Generate a dataset with the specified number of samples. + + Creates an unbalanced dataset where the distribution of labels + depends on the random feature generation and scoring rules. + + Args: + num_samples: Number of samples to generate. Defaults to 1000. + + Returns: + Tuple of (features_list, labels_list). + + Raises: + ValueError: If num_samples is not positive. 
+ + Example: + >>> generator = MetaControllerDataGenerator(seed=42) + >>> features, labels = generator.generate_dataset(100) + >>> len(features) + 100 + >>> len(labels) + 100 + >>> all(isinstance(f, MetaControllerFeatures) for f in features) + True + """ + if num_samples <= 0: + raise ValueError(f"num_samples must be positive, got {num_samples}") + + features_list: list[MetaControllerFeatures] = [] + labels_list: list[str] = [] + + for _ in range(num_samples): + features, label = self.generate_single_sample() + features_list.append(features) + labels_list.append(label) + + return features_list, labels_list + + def generate_balanced_dataset( + self, num_samples_per_class: int = 500 + ) -> tuple[list[MetaControllerFeatures], list[str]]: + """ + Generate a balanced dataset with equal samples per agent class. + + Creates samples biased toward each agent class to ensure balanced + representation. This is useful for training when class imbalance + would otherwise affect model performance. + + Args: + num_samples_per_class: Number of samples per agent class. + Defaults to 500. + + Returns: + Tuple of (features_list, labels_list) with balanced classes. + + Raises: + ValueError: If num_samples_per_class is not positive. + + Example: + >>> generator = MetaControllerDataGenerator(seed=42) + >>> features, labels = generator.generate_balanced_dataset(10) + >>> labels.count('hrm') + 10 + >>> labels.count('trm') + 10 + >>> labels.count('mcts') + 10 + """ + if num_samples_per_class <= 0: + raise ValueError(f"num_samples_per_class must be positive, got {num_samples_per_class}") + + features_list: list[MetaControllerFeatures] = [] + labels_list: list[str] = [] + + # Generate samples for each class + for target_agent in self.AGENT_NAMES: + count = 0 + max_attempts = num_samples_per_class * 100 # Prevent infinite loop + + attempts = 0 + while count < num_samples_per_class and attempts < max_attempts: + attempts += 1 + features = self._generate_biased_features(target_agent) + label = self._determine_optimal_agent(features) + + if label == target_agent: + features_list.append(features) + labels_list.append(label) + count += 1 + + # If we couldn't generate enough samples, force generate the rest + while count < num_samples_per_class: + features = self._generate_forced_features(target_agent) + features_list.append(features) + labels_list.append(target_agent) + count += 1 + + return features_list, labels_list + + def _generate_biased_features(self, target_agent: str) -> MetaControllerFeatures: + """ + Generate features biased toward selecting a specific agent. + + Args: + target_agent: The agent to bias toward ('hrm', 'trm', or 'mcts'). + + Returns: + MetaControllerFeatures biased toward the target agent. 
+ """ + if target_agent == "hrm": + # Bias toward high HRM confidence + hrm_confidence = float(self.rng.uniform(0.7, 1.0)) + trm_confidence = float(self.rng.uniform(0, hrm_confidence - 0.1)) + mcts_value = float(self.rng.uniform(0, hrm_confidence - 0.1)) + elif target_agent == "trm": + # Bias toward high TRM confidence + trm_confidence = float(self.rng.uniform(0.7, 1.0)) + hrm_confidence = float(self.rng.uniform(0, trm_confidence - 0.1)) + mcts_value = float(self.rng.uniform(0, trm_confidence - 0.1)) + else: # mcts + # Bias toward high MCTS value with enough iterations + mcts_value = float(self.rng.uniform(0.6, 1.0)) + hrm_confidence = float(self.rng.uniform(0, 0.7)) + trm_confidence = float(self.rng.uniform(0, 0.7)) + + # Ensure valid ranges + hrm_confidence = float(np.clip(hrm_confidence, 0.0, 1.0)) + trm_confidence = float(np.clip(trm_confidence, 0.0, 1.0)) + mcts_value = float(np.clip(mcts_value, 0.0, 1.0)) + + avg_confidence = (hrm_confidence + trm_confidence + mcts_value) / 3.0 + noise = float(self.rng.uniform(-0.1, 0.1)) + consensus_score = float(np.clip(avg_confidence + noise, 0.0, 1.0)) + + last_agent = str(self.rng.choice(["none", "hrm", "trm", "mcts"])) + + # For MCTS, bias iteration to be > 3 + iteration = int(self.rng.integers(4, 11)) if target_agent == "mcts" else int(self.rng.integers(0, 11)) + + query_length = int(self.rng.integers(10, 5001)) + has_rag_context = bool(self.rng.choice([True, False])) + + return MetaControllerFeatures( + hrm_confidence=hrm_confidence, + trm_confidence=trm_confidence, + mcts_value=mcts_value, + consensus_score=consensus_score, + last_agent=last_agent, + iteration=iteration, + query_length=query_length, + has_rag_context=has_rag_context, + ) + + def _generate_forced_features(self, target_agent: str) -> MetaControllerFeatures: + """ + Generate features that will definitely select a specific agent. + + Args: + target_agent: The agent to force selection of. + + Returns: + MetaControllerFeatures that will result in target_agent selection. + """ + if target_agent == "hrm": + hrm_confidence = 0.85 + trm_confidence = 0.3 + mcts_value = 0.3 + iteration = 2 + elif target_agent == "trm": + hrm_confidence = 0.3 + trm_confidence = 0.85 + mcts_value = 0.3 + iteration = 2 + else: # mcts + hrm_confidence = 0.5 + trm_confidence = 0.5 + mcts_value = 0.75 + iteration = 5 + + avg_confidence = (hrm_confidence + trm_confidence + mcts_value) / 3.0 + noise = float(self.rng.uniform(-0.05, 0.05)) + consensus_score = float(np.clip(avg_confidence + noise, 0.0, 1.0)) + + return MetaControllerFeatures( + hrm_confidence=hrm_confidence, + trm_confidence=trm_confidence, + mcts_value=mcts_value, + consensus_score=consensus_score, + last_agent=str(self.rng.choice(["none", "hrm", "trm", "mcts"])), + iteration=iteration, + query_length=int(self.rng.integers(10, 5001)), + has_rag_context=bool(self.rng.choice([True, False])), + ) + + def to_tensor_dataset( + self, features_list: list[MetaControllerFeatures], labels_list: list[str] + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert features and labels to PyTorch tensors. + + Uses normalize_features to convert each feature set to a 10-dimensional + vector, and converts string labels to numeric indices. + + Args: + features_list: List of MetaControllerFeatures instances. + labels_list: List of agent name strings ('hrm', 'trm', 'mcts'). + + Returns: + Tuple of (X tensor shape (N, 10), y tensor shape (N,)). + X contains normalized features as float32. + y contains label indices as int64. 
+ + Raises: + ValueError: If lists have different lengths or are empty. + KeyError: If labels_list contains invalid agent names. + + Example: + >>> generator = MetaControllerDataGenerator(seed=42) + >>> features, labels = generator.generate_dataset(10) + >>> X, y = generator.to_tensor_dataset(features, labels) + >>> X.shape + torch.Size([10, 10]) + >>> y.shape + torch.Size([10]) + >>> X.dtype + torch.float32 + >>> y.dtype + torch.int64 + """ + if len(features_list) != len(labels_list): + raise ValueError( + f"features_list and labels_list must have same length, got {len(features_list)} and {len(labels_list)}" + ) + + if len(features_list) == 0: + raise ValueError("Cannot convert empty dataset to tensors") + + # Convert features to normalized vectors + X_list = [normalize_features(f) for f in features_list] + X = torch.tensor(X_list, dtype=torch.float32) + + # Convert labels to indices + try: + y_list = [self.LABEL_TO_INDEX[label] for label in labels_list] + except KeyError as e: + raise KeyError(f"Invalid agent label: {e}. Valid labels: {self.AGENT_NAMES}") + y = torch.tensor(y_list, dtype=torch.int64) + + return X, y + + def to_text_dataset( + self, features_list: list[MetaControllerFeatures], labels_list: list[str] + ) -> tuple[list[str], list[int]]: + """ + Convert features to text format and labels to indices. + + Uses features_to_text to create human-readable text representations + suitable for text-based models like BERT. + + Args: + features_list: List of MetaControllerFeatures instances. + labels_list: List of agent name strings ('hrm', 'trm', 'mcts'). + + Returns: + Tuple of (text_list, label_indices). + text_list contains structured text representations. + label_indices contains integer indices for each label. + + Raises: + ValueError: If lists have different lengths. + KeyError: If labels_list contains invalid agent names. + + Example: + >>> generator = MetaControllerDataGenerator(seed=42) + >>> features, labels = generator.generate_dataset(10) + >>> texts, indices = generator.to_text_dataset(features, labels) + >>> len(texts) + 10 + >>> all(isinstance(t, str) for t in texts) + True + >>> all(i in [0, 1, 2] for i in indices) + True + """ + if len(features_list) != len(labels_list): + raise ValueError( + f"features_list and labels_list must have same length, got {len(features_list)} and {len(labels_list)}" + ) + + # Convert features to text + text_list = [features_to_text(f) for f in features_list] + + # Convert labels to indices + try: + label_indices = [self.LABEL_TO_INDEX[label] for label in labels_list] + except KeyError as e: + raise KeyError(f"Invalid agent label: {e}. Valid labels: {self.AGENT_NAMES}") + + return text_list, label_indices + + def split_dataset( + self, + X: Any, + y: Any, + train_ratio: float = 0.7, + val_ratio: float = 0.15, + ) -> dict[str, Any]: + """ + Split dataset into train, validation, and test sets. + + Shuffles the data and splits it according to the specified ratios. + The test ratio is automatically calculated as (1 - train_ratio - val_ratio). + + Args: + X: Feature data (tensor, array, or list). + y: Label data (tensor, array, or list). + train_ratio: Proportion for training set. Defaults to 0.7. + val_ratio: Proportion for validation set. Defaults to 0.15. 
+ + Returns: + Dictionary with keys: + - 'X_train': Training features + - 'y_train': Training labels + - 'X_val': Validation features + - 'y_val': Validation labels + - 'X_test': Test features + - 'y_test': Test labels + + Raises: + ValueError: If ratios are invalid or data sizes don't match. + + Example: + >>> generator = MetaControllerDataGenerator(seed=42) + >>> features, labels = generator.generate_dataset(100) + >>> X, y = generator.to_tensor_dataset(features, labels) + >>> splits = generator.split_dataset(X, y, 0.7, 0.15) + >>> 'X_train' in splits + True + >>> splits['X_train'].shape[0] == 70 + True + """ + # Validate ratios + if not (0 < train_ratio < 1): + raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}") + if not (0 < val_ratio < 1): + raise ValueError(f"val_ratio must be in (0, 1), got {val_ratio}") + if train_ratio + val_ratio >= 1: + raise ValueError(f"train_ratio + val_ratio must be < 1, got {train_ratio + val_ratio}") + + # Get dataset size + n_samples = X.shape[0] if isinstance(X, (torch.Tensor, np.ndarray)) else len(X) + + n_labels = y.shape[0] if isinstance(y, (torch.Tensor, np.ndarray)) else len(y) + + if n_samples != n_labels: + raise ValueError(f"X and y must have same number of samples, got {n_samples} and {n_labels}") + + if n_samples == 0: + raise ValueError("Cannot split empty dataset") + + # Generate shuffled indices + indices = self.rng.permutation(n_samples) + + # Calculate split points + train_end = int(n_samples * train_ratio) + val_end = train_end + int(n_samples * val_ratio) + + train_indices = indices[:train_end] + val_indices = indices[train_end:val_end] + test_indices = indices[val_end:] + + # Split data based on type + if isinstance(X, (torch.Tensor, np.ndarray)): + X_train = X[train_indices] + X_val = X[val_indices] + X_test = X[test_indices] + y_train = y[train_indices] + y_val = y[val_indices] + y_test = y[test_indices] + else: + # Assume list-like + X_train = [X[i] for i in train_indices] + X_val = [X[i] for i in val_indices] + X_test = [X[i] for i in test_indices] + y_train = [y[i] for i in train_indices] + y_val = [y[i] for i in val_indices] + y_test = [y[i] for i in test_indices] + + return { + "X_train": X_train, + "y_train": y_train, + "X_val": X_val, + "y_val": y_val, + "X_test": X_test, + "y_test": y_test, + } + + def save_dataset( + self, + features_list: list[MetaControllerFeatures], + labels_list: list[str], + path: str, + ) -> None: + """ + Save dataset to a JSON file. + + Converts MetaControllerFeatures to dictionaries for JSON serialization. + + Args: + features_list: List of MetaControllerFeatures instances. + labels_list: List of agent name strings. + path: Path to save the JSON file. + + Raises: + ValueError: If lists have different lengths. + IOError: If file cannot be written. 
+ + Example: + >>> generator = MetaControllerDataGenerator(seed=42) + >>> features, labels = generator.generate_dataset(10) + >>> generator.save_dataset(features, labels, 'dataset.json') + """ + if len(features_list) != len(labels_list): + raise ValueError( + f"features_list and labels_list must have same length, got {len(features_list)} and {len(labels_list)}" + ) + + # Convert to serializable format + data = { + "seed": self.seed, + "num_samples": len(features_list), + "samples": [ + {"features": asdict(f), "label": label} for f, label in zip(features_list, labels_list, strict=False) + ], + } + + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + def load_dataset(self, path: str) -> tuple[list[MetaControllerFeatures], list[str]]: + """ + Load dataset from a JSON file. + + Reconstructs MetaControllerFeatures from saved dictionaries. + + Args: + path: Path to the JSON file to load. + + Returns: + Tuple of (features_list, labels_list). + + Raises: + IOError: If file cannot be read. + KeyError: If JSON structure is invalid. + TypeError: If data types are incorrect. + + Example: + >>> generator = MetaControllerDataGenerator(seed=42) + >>> features, labels = generator.load_dataset('dataset.json') + >>> isinstance(features[0], MetaControllerFeatures) + True + """ + with open(path, encoding="utf-8") as f: + data = json.load(f) + + features_list: list[MetaControllerFeatures] = [] + labels_list: list[str] = [] + + for sample in data["samples"]: + features_dict = sample["features"] + features = MetaControllerFeatures( + hrm_confidence=float(features_dict["hrm_confidence"]), + trm_confidence=float(features_dict["trm_confidence"]), + mcts_value=float(features_dict["mcts_value"]), + consensus_score=float(features_dict["consensus_score"]), + last_agent=str(features_dict["last_agent"]), + iteration=int(features_dict["iteration"]), + query_length=int(features_dict["query_length"]), + has_rag_context=bool(features_dict["has_rag_context"]), + ) + features_list.append(features) + labels_list.append(str(sample["label"])) + + return features_list, labels_list diff --git a/src/training/experiment_tracker.py b/src/training/experiment_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7f775058214429a653ee370d623371483ec7f48a --- /dev/null +++ b/src/training/experiment_tracker.py @@ -0,0 +1,645 @@ +""" +Experiment Tracking Integration Module. + +Provides unified interface for: +- Braintrust experiment tracking +- Weights & Biases (W&B) logging +- Metric collection and visualization +- Model artifact versioning +""" + +import logging +import os +import time +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class ExperimentConfig: + """Configuration for experiment tracking.""" + + project_name: str + experiment_name: str + tags: list[str] = field(default_factory=list) + description: str = "" + save_artifacts: bool = True + log_frequency: int = 1 # Log every N steps + + +@dataclass +class TrainingMetrics: + """Standard training metrics.""" + + epoch: int + step: int + train_loss: float + val_loss: float | None = None + accuracy: float | None = None + learning_rate: float | None = None + timestamp: float = field(default_factory=time.time) + custom_metrics: dict[str, float] = field(default_factory=dict) + + +class BraintrustTracker: + """ + Braintrust experiment tracking integration. 
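+
+    Falls back to an offline, in-memory log when BRAINTRUST_API_KEY is not set
+    or the braintrust package is not installed.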
+
+    Provides:
+    - Experiment initialization and management
+    - Metric logging with automatic visualization
+    - Hyperparameter tracking
+    - Model evaluation scoring
+    - Artifact versioning
+    """
+
+    def __init__(self, api_key: str | None = None, project_name: str = "mcts-neural-meta-controller"):
+        """
+        Initialize Braintrust tracker.
+
+        Args:
+            api_key: Braintrust API key (or from BRAINTRUST_API_KEY env var)
+            project_name: Project name in Braintrust
+        """
+        self.api_key = api_key or os.getenv("BRAINTRUST_API_KEY")
+        self.project_name = project_name
+        self._experiment = None
+        self._experiment_id = None
+        self._experiment_config: dict[str, Any] = {}  # Buffer for offline-mode metadata
+        self._metrics_buffer: list[dict[str, Any]] = []
+        self._initialized = False
+
+        if not self.api_key:
+            logger.warning("BRAINTRUST_API_KEY not set. Using offline mode.")
+            self._offline_mode = True
+        else:
+            self._offline_mode = False
+            self._initialize_client()
+
+    def _initialize_client(self):
+        """Initialize Braintrust client."""
+        try:
+            import braintrust
+
+            braintrust.login(api_key=self.api_key)
+            self._bt = braintrust
+            self._initialized = True
+            logger.info(f"Braintrust client initialized for project: {self.project_name}")
+        except ImportError:
+            logger.error("braintrust library not installed. Run: pip install braintrust")
+            self._offline_mode = True
+        except Exception as e:
+            logger.error(f"Failed to initialize Braintrust: {e}")
+            self._offline_mode = True
+
+    def init_experiment(
+        self,
+        name: str,
+        description: str = "",
+        tags: list[str] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> str:
+        """
+        Initialize a new experiment.
+
+        Args:
+            name: Experiment name (e.g., "rnn_meta_controller_v2")
+            description: Experiment description
+            tags: List of tags for filtering
+            metadata: Additional metadata
+
+        Returns:
+            Experiment ID
+        """
+        if self._offline_mode:
+            exp_id = f"offline_{int(time.time())}"
+            logger.info(f"Created offline experiment: {exp_id}")
+            self._experiment_id = exp_id
+            self._experiment_config = {
+                "name": name,
+                "description": description,
+                "tags": tags or [],
+                "metadata": metadata or {},
+                "start_time": datetime.now().isoformat(),
+            }
+            return exp_id
+
+        try:
+            self._experiment = self._bt.init(
+                project=self.project_name,
+                experiment=name,
+            )
+            self._experiment_id = self._experiment.id
+            logger.info(f"Created Braintrust experiment: {name} (ID: {self._experiment_id})")
+            return self._experiment_id
+        except Exception as e:
+            logger.error(f"Failed to create experiment: {e}")
+            # Switch to offline mode before retrying so the fallback cannot recurse forever
+            self._offline_mode = True
+            return self.init_experiment(name, description, tags, metadata)
+
+    def log_hyperparameters(self, params: dict[str, Any]):
+        """
+        Log hyperparameters for the experiment.
+
+        Args:
+            params: Dictionary of hyperparameters
+        """
+        logger.info(f"Logging hyperparameters: {params}")
+
+        if self._offline_mode:
+            self._experiment_config["hyperparameters"] = params
+            return
+
+        try:
+            if self._experiment:
+                # Braintrust uses metadata for hyperparameters
+                self._experiment.log(
+                    input="hyperparameters",
+                    output=params,
+                    metadata={"type": "hyperparameters"},
+                )
+        except Exception as e:
+            logger.error(f"Failed to log hyperparameters: {e}")
+
+    def log_metric(
+        self,
+        name: str,
+        value: float,
+        step: int | None = None,
+        timestamp: float | None = None,
+    ):
+        """
+        Log a single metric.
+ + Args: + name: Metric name + value: Metric value + step: Optional step number + timestamp: Optional timestamp + """ + metric_data = { + "name": name, + "value": value, + "step": step or len(self._metrics_buffer), + "timestamp": timestamp or time.time(), + } + + self._metrics_buffer.append(metric_data) + + if self._offline_mode: + logger.debug(f"Metric logged (offline): {name}={value}") + return + + try: + if self._experiment: + self._experiment.log( + input=f"metric_{name}", + output={"value": value}, + scores={name: value}, + metadata={"step": step}, + ) + except Exception as e: + logger.error(f"Failed to log metric {name}: {e}") + + def log_training_step(self, metrics: TrainingMetrics): + """ + Log a complete training step. + + Args: + metrics: TrainingMetrics object + """ + self.log_metric("train_loss", metrics.train_loss, step=metrics.step) + + if metrics.val_loss is not None: + self.log_metric("val_loss", metrics.val_loss, step=metrics.step) + + if metrics.accuracy is not None: + self.log_metric("accuracy", metrics.accuracy, step=metrics.step) + + if metrics.learning_rate is not None: + self.log_metric("learning_rate", metrics.learning_rate, step=metrics.step) + + for key, value in metrics.custom_metrics.items(): + self.log_metric(key, value, step=metrics.step) + + def log_evaluation( + self, + input_data: Any, + output: Any, + expected: Any, + scores: dict[str, float], + metadata: dict[str, Any] | None = None, + ): + """ + Log an evaluation result. + + Args: + input_data: Input to the model + output: Model output + expected: Expected output + scores: Dictionary of scores (e.g., accuracy, f1) + metadata: Additional metadata + """ + if self._offline_mode: + logger.info(f"Evaluation logged (offline): scores={scores}") + return + + try: + if self._experiment: + self._experiment.log( + input=input_data, + output=output, + expected=expected, + scores=scores, + metadata=metadata or {}, + ) + except Exception as e: + logger.error(f"Failed to log evaluation: {e}") + + def log_artifact(self, path: str | Path, name: str | None = None): + """ + Log a model artifact. + + Args: + path: Path to artifact file + name: Optional artifact name + """ + path = Path(path) + if not path.exists(): + logger.warning(f"Artifact not found: {path}") + return + + logger.info(f"Logging artifact: {path}") + + if self._offline_mode: + if "artifacts" not in self._experiment_config: + self._experiment_config["artifacts"] = [] + self._experiment_config["artifacts"].append(str(path)) + return + + # Braintrust artifact logging would go here + # For now, just log the path + try: + if self._experiment: + self._experiment.log( + input=f"artifact_{name or path.name}", + output={"path": str(path), "name": name or path.name}, + metadata={"artifact_path": str(path), "artifact_name": name or path.name}, + ) + except Exception as e: + logger.error(f"Failed to log artifact: {e}") + + def get_summary(self) -> dict[str, Any]: + """ + Get experiment summary. 
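+
+        In offline mode the summary includes the buffered experiment config;
+        otherwise it includes the project name.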
+ + Returns: + Dictionary with experiment summary + """ + if self._offline_mode: + return { + "id": self._experiment_id, + "config": self._experiment_config, + "metrics_count": len(self._metrics_buffer), + "offline": True, + } + + return { + "id": self._experiment_id, + "project": self.project_name, + "metrics_count": len(self._metrics_buffer), + "offline": False, + } + + def end_experiment(self): + """End the current experiment.""" + summary = self.get_summary() + logger.info(f"Experiment ended: {summary}") + + if not self._offline_mode and self._experiment: + try: + # Braintrust experiments auto-close, but we'll try explicit close if available + if hasattr(self._experiment, "close"): + self._experiment.close() + elif hasattr(self._experiment, "flush"): + self._experiment.flush() + except Exception as e: + logger.error(f"Failed to end experiment: {e}") + + self._experiment = None + self._experiment_id = None + self._metrics_buffer = [] + + return summary + + +class WandBTracker: + """ + Weights & Biases experiment tracking integration. + + Provides: + - Real-time metric visualization + - Hyperparameter sweep management + - Model artifact logging + - Collaborative experiment comparison + """ + + def __init__( + self, + api_key: str | None = None, + project_name: str = "mcts-neural-meta-controller", + entity: str | None = None, + ): + """ + Initialize W&B tracker. + + Args: + api_key: W&B API key (or from WANDB_API_KEY env var) + project_name: Project name in W&B + entity: W&B entity (team or username) + """ + self.api_key = api_key or os.getenv("WANDB_API_KEY") + self.project_name = project_name + self.entity = entity + self._run = None + self._initialized = False + self._offline_mode = os.getenv("WANDB_MODE") == "offline" + + if not self.api_key and not self._offline_mode: + logger.warning("WANDB_API_KEY not set. Using offline mode.") + self._offline_mode = True + os.environ["WANDB_MODE"] = "offline" + else: + self._initialize_client() + + def _initialize_client(self): + """Initialize W&B client.""" + try: + import wandb + + if self.api_key: + wandb.login(key=self.api_key) + + self._wandb = wandb + self._initialized = True + logger.info(f"W&B client initialized for project: {self.project_name}") + except ImportError: + logger.error("wandb library not installed. Run: pip install wandb") + self._offline_mode = True + except Exception as e: + logger.error(f"Failed to initialize W&B: {e}") + self._offline_mode = True + + def init_run( + self, + name: str, + config: dict[str, Any] | None = None, + tags: list[str] | None = None, + notes: str = "", + ): + """ + Initialize a new W&B run. + + Args: + name: Run name + config: Configuration dictionary + tags: List of tags + notes: Run notes/description + + Returns: + Run object + """ + if self._offline_mode: + logger.info(f"W&B run initialized (offline mode): {name}") + self._run_config = config or {} + return None + + try: + self._run = self._wandb.init( + project=self.project_name, + entity=self.entity, + name=name, + config=config, + tags=tags, + notes=notes, + ) + logger.info(f"W&B run initialized: {name}") + return self._run + except Exception as e: + logger.error(f"Failed to initialize W&B run: {e}") + self._offline_mode = True + return None + + def log(self, metrics: dict[str, Any], step: int | None = None): + """ + Log metrics to W&B. 
+ + Args: + metrics: Dictionary of metrics + step: Optional step number + """ + if self._offline_mode: + logger.debug(f"W&B metrics (offline): {metrics}") + return + + try: + if self._run: + self._wandb.log(metrics, step=step) + except Exception as e: + logger.error(f"Failed to log to W&B: {e}") + + def log_training_step(self, metrics: TrainingMetrics): + """ + Log a complete training step to W&B. + + Args: + metrics: TrainingMetrics object + """ + log_data = { + "epoch": metrics.epoch, + "train_loss": metrics.train_loss, + } + + if metrics.val_loss is not None: + log_data["val_loss"] = metrics.val_loss + + if metrics.accuracy is not None: + log_data["accuracy"] = metrics.accuracy + + if metrics.learning_rate is not None: + log_data["learning_rate"] = metrics.learning_rate + + log_data.update(metrics.custom_metrics) + + self.log(log_data, step=metrics.step) + + def update_config(self, config: dict[str, Any]): + """ + Update run configuration. + + Args: + config: Configuration updates + """ + if self._offline_mode: + self._run_config.update(config) + return + + try: + if self._run: + self._wandb.config.update(config) + except Exception as e: + logger.error(f"Failed to update W&B config: {e}") + + def watch_model(self, model, log_freq: int = 100): + """ + Watch model gradients and parameters. + + Args: + model: PyTorch model + log_freq: Logging frequency + """ + if self._offline_mode: + return + + try: + if self._run: + self._wandb.watch(model, log="all", log_freq=log_freq) + except Exception as e: + logger.error(f"Failed to watch model: {e}") + + def log_artifact(self, path: str | Path, name: str, artifact_type: str = "model"): + """ + Log artifact to W&B. + + Args: + path: Path to artifact + name: Artifact name + artifact_type: Type of artifact (model, dataset, etc.) + """ + if self._offline_mode: + logger.info(f"Artifact logged (offline): {path}") + return + + try: + artifact = self._wandb.Artifact(name, type=artifact_type) + artifact.add_file(str(path)) + if self._run: + self._run.log_artifact(artifact) + logger.info(f"Artifact logged: {name}") + except Exception as e: + logger.error(f"Failed to log artifact: {e}") + + def finish(self): + """Finish the W&B run.""" + if self._offline_mode: + logger.info("W&B run finished (offline)") + return + + try: + if self._run: + self._run.finish() + logger.info("W&B run finished") + except Exception as e: + logger.error(f"Failed to finish W&B run: {e}") + + +class UnifiedExperimentTracker: + """ + Unified experiment tracker that coordinates both Braintrust and W&B. + + Provides single interface for: + - Dual logging to both platforms + - Fallback handling + - Consistent metric tracking + """ + + def __init__( + self, + braintrust_api_key: str | None = None, + wandb_api_key: str | None = None, + project_name: str = "mcts-neural-meta-controller", + ): + """ + Initialize unified tracker. + + Args: + braintrust_api_key: Braintrust API key + wandb_api_key: W&B API key + project_name: Project name for both platforms + """ + self.bt = BraintrustTracker(api_key=braintrust_api_key, project_name=project_name) + self.wandb = WandBTracker(api_key=wandb_api_key, project_name=project_name) + self.project_name = project_name + + def init_experiment( + self, + name: str, + config: dict[str, Any] | None = None, + description: str = "", + tags: list[str] | None = None, + ): + """ + Initialize experiment on both platforms. 
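+
+        The same name, description, and tags are applied to the Braintrust
+        experiment and the W&B run; if a config is given, it is also logged to
+        Braintrust as hyperparameters.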
+ + Args: + name: Experiment/run name + config: Configuration dictionary + description: Description + tags: List of tags + """ + self.bt.init_experiment(name, description, tags) + self.wandb.init_run(name, config, tags, description) + + if config: + self.bt.log_hyperparameters(config) + + logger.info(f"Unified experiment initialized: {name}") + + def log_metrics(self, metrics: TrainingMetrics): + """ + Log training metrics to both platforms. + + Args: + metrics: TrainingMetrics object + """ + self.bt.log_training_step(metrics) + self.wandb.log_training_step(metrics) + + def log_evaluation( + self, + input_data: Any, + output: Any, + expected: Any, + scores: dict[str, float], + ): + """ + Log evaluation to Braintrust. + + Args: + input_data: Input data + output: Model output + expected: Expected output + scores: Evaluation scores + """ + self.bt.log_evaluation(input_data, output, expected, scores) + self.wandb.log(scores) + + def log_artifact(self, path: str | Path, name: str): + """ + Log artifact to both platforms. + + Args: + path: Path to artifact + name: Artifact name + """ + self.bt.log_artifact(path, name) + self.wandb.log_artifact(path, name) + + def finish(self): + """End tracking on both platforms.""" + bt_summary = self.bt.end_experiment() + self.wandb.finish() + logger.info("Unified experiment ended") + return bt_summary diff --git a/src/training/performance_monitor.py b/src/training/performance_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..144de22cdfb31eb8f211eb5e74f6febbde5c19ce --- /dev/null +++ b/src/training/performance_monitor.py @@ -0,0 +1,370 @@ +""" +Performance Monitoring System for LangGraph Multi-Agent MCTS. + +Tracks and analyzes system performance including: +- Inference latency +- Memory usage +- Training metrics +- Cache efficiency +- Throughput statistics +""" + +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import psutil +import torch + + +@dataclass +class PerformanceMetrics: + """Container for performance metrics.""" + + # Timing metrics (milliseconds) + hrm_decomposition_time: float = 0.0 + mcts_exploration_time: float = 0.0 + trm_refinement_time: float = 0.0 + total_inference_time: float = 0.0 + network_forward_time: float = 0.0 + + # Memory metrics (GB) + cpu_memory_used: float = 0.0 + gpu_memory_used: float = 0.0 + gpu_memory_allocated: float = 0.0 + + # Training metrics + policy_loss: float = 0.0 + value_loss: float = 0.0 + total_loss: float = 0.0 + learning_rate: float = 0.0 + + # MCTS metrics + mcts_simulations: int = 0 + cache_hit_rate: float = 0.0 + avg_tree_depth: float = 0.0 + + # Convergence metrics + hrm_halt_step: int = 0 + trm_convergence_step: int = 0 + + timestamp: float = field(default_factory=time.time) + + +class PerformanceMonitor: + """ + Track and analyze system performance metrics. + + Features: + - Rolling window statistics + - Automatic anomaly detection + - Performance alerts + - Export to various formats (dict, JSON, wandb) + """ + + def __init__( + self, + window_size: int = 100, + enable_gpu_monitoring: bool = True, + alert_threshold_ms: float = 1000.0, + ): + """ + Initialize performance monitor. 
+ + Args: + window_size: Number of recent measurements to keep + enable_gpu_monitoring: Whether to monitor GPU usage + alert_threshold_ms: Threshold for slow inference alerts + """ + self.window_size = window_size + self.enable_gpu_monitoring = enable_gpu_monitoring and torch.cuda.is_available() + self.alert_threshold_ms = alert_threshold_ms + + # Time series data + self.metrics_history: deque = deque(maxlen=window_size) + + # Individual metric queues for faster access + self._metric_queues: dict[str, deque] = { + "hrm_decomposition_time": deque(maxlen=window_size), + "mcts_exploration_time": deque(maxlen=window_size), + "trm_refinement_time": deque(maxlen=window_size), + "total_inference_time": deque(maxlen=window_size), + "network_forward_time": deque(maxlen=window_size), + "cpu_memory_used": deque(maxlen=window_size), + "gpu_memory_used": deque(maxlen=window_size), + "policy_loss": deque(maxlen=window_size), + "value_loss": deque(maxlen=window_size), + "total_loss": deque(maxlen=window_size), + "cache_hit_rate": deque(maxlen=window_size), + } + + # Counters + self.total_inferences = 0 + self.slow_inference_count = 0 + + # Process info + self.process = psutil.Process() + + def log_timing(self, stage: str, elapsed_ms: float): + """ + Log execution time for a processing stage. + + Args: + stage: Stage name (e.g., "hrm_decomposition", "mcts_exploration") + elapsed_ms: Elapsed time in milliseconds + """ + metric_name = f"{stage}_time" + if metric_name in self._metric_queues: + self._metric_queues[metric_name].append(elapsed_ms) + + def log_memory(self): + """Log current memory usage.""" + # CPU memory + memory_info = self.process.memory_info() + cpu_memory_gb = memory_info.rss / (1024**3) # Bytes to GB + self._metric_queues["cpu_memory_used"].append(cpu_memory_gb) + + # GPU memory + if self.enable_gpu_monitoring: + gpu_memory_gb = torch.cuda.memory_allocated() / (1024**3) + self._metric_queues["gpu_memory_used"].append(gpu_memory_gb) + + def log_loss(self, policy_loss: float, value_loss: float, total_loss: float): + """ + Log training losses. + + Args: + policy_loss: Policy head loss + value_loss: Value head loss + total_loss: Combined loss + """ + self._metric_queues["policy_loss"].append(policy_loss) + self._metric_queues["value_loss"].append(value_loss) + self._metric_queues["total_loss"].append(total_loss) + + def log_mcts_stats(self, cache_hit_rate: float, simulations: int = 0): # noqa: ARG002 + """ + Log MCTS statistics. + + Args: + cache_hit_rate: Cache hit rate (0-1) + simulations: Number of simulations performed + """ + self._metric_queues["cache_hit_rate"].append(cache_hit_rate) + + def log_inference(self, total_time_ms: float): + """ + Log complete inference. + + Args: + total_time_ms: Total inference time in milliseconds + """ + self.total_inferences += 1 + self._metric_queues["total_inference_time"].append(total_time_ms) + + # Check for slow inference + if total_time_ms > self.alert_threshold_ms: + self.slow_inference_count += 1 + + def get_stats(self, metric_name: str | None = None) -> dict[str, Any]: + """ + Get performance statistics. 
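+
+        For each tracked metric, statistics (mean, std, min, max, median, p95,
+        p99, count) are computed over the rolling window of recent measurements.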
+ + Args: + metric_name: Specific metric to get stats for (None = all metrics) + + Returns: + Dictionary of statistics + """ + if metric_name: + return self._compute_metric_stats(metric_name) + + # Compute stats for all metrics + stats = {} + for name, queue in self._metric_queues.items(): + if len(queue) > 0: + stats[name] = self._compute_metric_stats(name) + + # Add system stats + stats["system"] = { + "total_inferences": self.total_inferences, + "slow_inference_count": self.slow_inference_count, + "slow_inference_rate": ( + self.slow_inference_count / self.total_inferences if self.total_inferences > 0 else 0.0 + ), + } + + return stats + + def _compute_metric_stats(self, metric_name: str) -> dict[str, float]: + """Compute statistics for a single metric.""" + if metric_name not in self._metric_queues: + return {} + + values = list(self._metric_queues[metric_name]) + if not values: + return {} + + return { + "mean": float(np.mean(values)), + "std": float(np.std(values)), + "min": float(np.min(values)), + "max": float(np.max(values)), + "median": float(np.median(values)), + "p95": float(np.percentile(values, 95)), + "p99": float(np.percentile(values, 99)), + "count": len(values), + } + + def get_current_memory(self) -> dict[str, float]: + """Get current memory usage snapshot.""" + memory = {} + + # CPU memory + memory_info = self.process.memory_info() + memory["cpu_rss_gb"] = memory_info.rss / (1024**3) + memory["cpu_vms_gb"] = memory_info.vms / (1024**3) + + # System-wide CPU memory + system_memory = psutil.virtual_memory() + memory["system_used_percent"] = system_memory.percent + + # GPU memory + if self.enable_gpu_monitoring: + memory["gpu_allocated_gb"] = torch.cuda.memory_allocated() / (1024**3) + memory["gpu_reserved_gb"] = torch.cuda.memory_reserved() / (1024**3) + memory["gpu_max_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3) + + return memory + + def alert_if_slow(self): + """Print alert if recent inferences are slow.""" + recent_times = list(self._metric_queues["total_inference_time"])[-10:] + if recent_times and np.mean(recent_times) > self.alert_threshold_ms: + print( + f"⚠️ Performance Alert: Avg inference time {np.mean(recent_times):.1f}ms " + f"(threshold: {self.alert_threshold_ms}ms)" + ) + + def print_summary(self): + """Print formatted summary of performance statistics.""" + print("\n" + "=" * 80) + print("Performance Summary") + print("=" * 80) + + stats = self.get_stats() + + # Timing statistics + print("\n[Timing Statistics (ms)]") + timing_metrics = [ + "total_inference_time", + "hrm_decomposition_time", + "mcts_exploration_time", + "trm_refinement_time", + "network_forward_time", + ] + for metric in timing_metrics: + if metric in stats: + s = stats[metric] + print( + f" {metric:30s}: mean={s['mean']:6.1f} " + f"std={s['std']:6.1f} p95={s['p95']:6.1f} max={s['max']:6.1f}" + ) + + # Memory statistics + print("\n[Memory Statistics (GB)]") + memory = self.get_current_memory() + print(f" CPU RSS: {memory['cpu_rss_gb']:.2f} GB") + print(f" System Memory Used: {memory['system_used_percent']:.1f}%") + if self.enable_gpu_monitoring: + print(f" GPU Allocated: {memory['gpu_allocated_gb']:.2f} GB") + print(f" GPU Reserved: {memory['gpu_reserved_gb']:.2f} GB") + + # Loss statistics + if "total_loss" in stats: + print("\n[Training Loss]") + for metric in ["policy_loss", "value_loss", "total_loss"]: + if metric in stats: + s = stats[metric] + print(f" {metric:20s}: mean={s['mean']:.4f} std={s['std']:.4f}") + + # System statistics + print("\n[System Statistics]") + 
sys_stats = stats.get("system", {}) + print(f" Total Inferences: {sys_stats.get('total_inferences', 0)}") + print(f" Slow Inferences: {sys_stats.get('slow_inference_count', 0)}") + print(f" Slow Inference Rate: {sys_stats.get('slow_inference_rate', 0):.2%}") + + # Cache statistics + if "cache_hit_rate" in stats: + s = stats["cache_hit_rate"] + print("\n[Cache Performance]") + print(f" Hit Rate: {s['mean']:.2%}") + + print("=" * 80 + "\n") + + def export_to_dict(self) -> dict[str, Any]: + """Export all statistics to dictionary.""" + return { + "stats": self.get_stats(), + "memory": self.get_current_memory(), + "window_size": self.window_size, + } + + def export_to_wandb(self, step: int) -> dict[str, float]: # noqa: ARG002 + """ + Export metrics for Weights & Biases logging. + + Args: + step: Training step/iteration + + Returns: + Flattened metrics dictionary + """ + stats = self.get_stats() + wandb_metrics = {} + + # Flatten nested statistics + for metric_name, metric_stats in stats.items(): + if metric_name == "system": + for key, value in metric_stats.items(): + wandb_metrics[f"system/{key}"] = value + elif isinstance(metric_stats, dict): + # Log mean and p95 for each metric + wandb_metrics[f"{metric_name}/mean"] = metric_stats.get("mean", 0) + wandb_metrics[f"{metric_name}/p95"] = metric_stats.get("p95", 0) + + # Add memory + memory = self.get_current_memory() + for key, value in memory.items(): + wandb_metrics[f"memory/{key}"] = value + + return wandb_metrics + + def reset(self): + """Reset all metrics.""" + for queue in self._metric_queues.values(): + queue.clear() + self.metrics_history.clear() + self.total_inferences = 0 + self.slow_inference_count = 0 + + +class TimingContext: + """Context manager for timing code blocks.""" + + def __init__(self, monitor: PerformanceMonitor, stage: str): + self.monitor = monitor + self.stage = stage + self.start_time = None + + def __enter__(self): + self.start_time = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.start_time is not None: + elapsed = (time.perf_counter() - self.start_time) * 1000 # ms + self.monitor.log_timing(self.stage, elapsed) diff --git a/src/training/replay_buffer.py b/src/training/replay_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..f99bc9be123725129857d770fd740767143b4da5 --- /dev/null +++ b/src/training/replay_buffer.py @@ -0,0 +1,389 @@ +""" +Experience Replay Buffer for LangGraph Multi-Agent MCTS Training. + +Implements: +- Prioritized experience replay +- Uniform sampling +- Efficient circular buffer storage +- Data augmentation support +""" + +import random +from collections import deque +from dataclasses import dataclass + +import numpy as np +import torch + + +@dataclass +class Experience: + """Single training example from self-play.""" + + state: torch.Tensor # State representation + policy: np.ndarray # MCTS visit count distribution + value: float # Game outcome from this state's perspective + metadata: dict = None # Optional metadata (e.g., game_id, move_number) + + +class ReplayBuffer: + """ + Simple uniform sampling replay buffer. + + Stores recent experiences in a circular buffer and samples uniformly. + """ + + def __init__(self, capacity: int): + """ + Initialize replay buffer. 
+ + Args: + capacity: Maximum number of experiences to store + """ + self.capacity = capacity + self.buffer: deque = deque(maxlen=capacity) + + def add(self, experience: Experience): + """Add an experience to the buffer.""" + self.buffer.append(experience) + + def add_batch(self, experiences: list[Experience]): + """Add multiple experiences.""" + for exp in experiences: + self.add(exp) + + def sample(self, batch_size: int) -> list[Experience]: + """ + Sample a batch of experiences uniformly. + + Args: + batch_size: Number of experiences to sample + + Returns: + List of sampled experiences + """ + if len(self.buffer) < batch_size: + return random.sample(self.buffer, len(self.buffer)) + + return random.sample(self.buffer, batch_size) + + def __len__(self) -> int: + """Return current buffer size.""" + return len(self.buffer) + + def is_ready(self, min_size: int) -> bool: + """Check if buffer has enough samples for training.""" + return len(self.buffer) >= min_size + + def clear(self): + """Clear all experiences from buffer.""" + self.buffer.clear() + + +class PrioritizedReplayBuffer: + """ + Prioritized Experience Replay (PER) buffer. + + Samples experiences with probability proportional to their TD error. + Helps focus training on more informative examples. + + Based on: "Prioritized Experience Replay" (Schaul et al., 2015) + """ + + def __init__( + self, + capacity: int, + alpha: float = 0.6, + beta_start: float = 0.4, + beta_frames: int = 100_000, + ): + """ + Initialize prioritized replay buffer. + + Args: + capacity: Maximum buffer size + alpha: Priority exponent (0 = uniform, 1 = full prioritization) + beta_start: Initial importance sampling weight + beta_frames: Number of frames to anneal beta to 1.0 + """ + self.capacity = capacity + self.alpha = alpha + self.beta_start = beta_start + self.beta_frames = beta_frames + self.frame = 1 + + # Storage + self.buffer: list[Experience | None] = [None] * capacity + self.priorities: np.ndarray = np.zeros(capacity, dtype=np.float32) + self.position = 0 + self.size = 0 + + def _get_beta(self) -> float: + """Get current beta value (anneals from beta_start to 1.0).""" + return min(1.0, self.beta_start + (1.0 - self.beta_start) * self.frame / self.beta_frames) + + def add(self, experience: Experience, priority: float | None = None): + """ + Add experience with priority. + + Args: + experience: Experience to add + priority: Priority value (uses max priority if None) + """ + if priority is None: + # New experiences get max priority + priority = self.priorities.max() if self.size > 0 else 1.0 + + self.buffer[self.position] = experience + self.priorities[self.position] = priority**self.alpha + + self.position = (self.position + 1) % self.capacity + self.size = min(self.size + 1, self.capacity) + + def add_batch(self, experiences: list[Experience], priorities: list[float] | None = None): + """ + Add multiple experiences. + + Args: + experiences: List of experiences + priorities: Optional list of priorities (same length as experiences) + """ + if priorities is None: + priorities = [None] * len(experiences) + + for exp, priority in zip(experiences, priorities, strict=True): + self.add(exp, priority) + + def sample(self, batch_size: int) -> tuple[list[Experience], np.ndarray, np.ndarray]: + """ + Sample batch with prioritized sampling. 
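+
+        Each experience i is drawn with probability proportional to its stored
+        priority. Importance-sampling weights (N * P(i)) ** (-beta), normalized
+        by their maximum, are returned so the training loss can correct for the
+        non-uniform sampling.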
+ + Args: + batch_size: Number of experiences to sample + + Returns: + (experiences, indices, weights) tuple + - experiences: Sampled experiences + - indices: Buffer indices (for updating priorities) + - weights: Importance sampling weights + """ + if self.size < batch_size: + batch_size = self.size + + # Compute sampling probabilities + priorities = self.priorities[: self.size] + probs = priorities / priorities.sum() + + # Sample indices + indices = np.random.choice(self.size, batch_size, p=probs, replace=False) + + # Compute importance sampling weights + beta = self._get_beta() + weights = (self.size * probs[indices]) ** (-beta) + weights = weights / weights.max() # Normalize + + # Get experiences + experiences = [self.buffer[idx] for idx in indices] + + self.frame += 1 + + return experiences, indices, weights + + def update_priorities(self, indices: np.ndarray, priorities: np.ndarray): + """ + Update priorities for sampled experiences. + + Args: + indices: Buffer indices to update + priorities: New priority values + """ + for idx, priority in zip(indices, priorities, strict=True): + self.priorities[idx] = (priority + 1e-6) ** self.alpha # Small epsilon for stability + + def __len__(self) -> int: + """Return current buffer size.""" + return self.size + + def is_ready(self, min_size: int) -> bool: + """Check if buffer has enough samples.""" + return self.size >= min_size + + +class AugmentedReplayBuffer(ReplayBuffer): + """ + Replay buffer with data augmentation support. + + Applies symmetries/transformations to states during sampling + (e.g., rotations and flips for Go/chess boards). + """ + + def __init__(self, capacity: int, augmentation_fn=None): + """ + Initialize augmented replay buffer. + + Args: + capacity: Maximum buffer size + augmentation_fn: Function to apply augmentations + Should take (state, policy) and return augmented versions + """ + super().__init__(capacity) + self.augmentation_fn = augmentation_fn + + def sample(self, batch_size: int, apply_augmentation: bool = True) -> list[Experience]: + """ + Sample batch with optional augmentation. + + Args: + batch_size: Number of experiences to sample + apply_augmentation: Whether to apply data augmentation + + Returns: + List of (possibly augmented) experiences + """ + experiences = super().sample(batch_size) + + if apply_augmentation and self.augmentation_fn is not None: + augmented = [] + for exp in experiences: + aug_state, aug_policy = self.augmentation_fn(exp.state, exp.policy) + augmented.append( + Experience( + state=aug_state, + policy=aug_policy, + value=exp.value, + metadata=exp.metadata, + ) + ) + return augmented + + return experiences + + +# Data augmentation utilities for board games +class BoardGameAugmentation: + """ + Data augmentation for board games (Go, Chess, etc.). + + Applies symmetry transformations: rotations and reflections. + """ + + @staticmethod + def rotate_90(state: torch.Tensor, policy: np.ndarray, board_size: int = 19) -> tuple[torch.Tensor, np.ndarray]: + """ + Rotate state and policy 90 degrees clockwise. 
+ + Args: + state: State tensor [channels, height, width] + policy: Policy array [action_size] + board_size: Board dimension + + Returns: + (rotated_state, rotated_policy) tuple + """ + # Rotate state + rotated_state = torch.rot90(state, k=1, dims=[1, 2]) + + # Rotate policy (assuming policy corresponds to board positions) + # This is game-specific; here's a simple version for square boards + if len(policy) == board_size * board_size + 1: # +1 for pass action + policy_board = policy[:-1].reshape(board_size, board_size) + rotated_policy_board = np.rot90(policy_board, k=1) + rotated_policy = np.append(rotated_policy_board.flatten(), policy[-1]) + else: + rotated_policy = policy # Can't rotate, return original + + return rotated_state, rotated_policy + + @staticmethod + def flip_horizontal( + state: torch.Tensor, policy: np.ndarray, board_size: int = 19 + ) -> tuple[torch.Tensor, np.ndarray]: + """Flip state and policy horizontally.""" + flipped_state = torch.flip(state, dims=[2]) # Flip width dimension + + if len(policy) == board_size * board_size + 1: + policy_board = policy[:-1].reshape(board_size, board_size) + flipped_policy_board = np.fliplr(policy_board) + flipped_policy = np.append(flipped_policy_board.flatten(), policy[-1]) + else: + flipped_policy = policy + + return flipped_state, flipped_policy + + @staticmethod + def random_symmetry( + state: torch.Tensor, policy: np.ndarray, board_size: int = 19 + ) -> tuple[torch.Tensor, np.ndarray]: + """ + Apply random symmetry transformation. + + Randomly selects from: + - Identity + - 90° rotation + - 180° rotation + - 270° rotation + - Horizontal flip + - Vertical flip + - Diagonal flip + - Anti-diagonal flip + """ + transform = random.randint(0, 7) + + if transform == 0: + # Identity + return state, policy + elif transform == 1: + # 90° rotation + return BoardGameAugmentation.rotate_90(state, policy, board_size) + elif transform == 2: + # 180° rotation + s, p = BoardGameAugmentation.rotate_90(state, policy, board_size) + return BoardGameAugmentation.rotate_90(s, p, board_size) + elif transform == 3: + # 270° rotation + s, p = BoardGameAugmentation.rotate_90(state, policy, board_size) + s, p = BoardGameAugmentation.rotate_90(s, p, board_size) + return BoardGameAugmentation.rotate_90(s, p, board_size) + elif transform == 4: + # Horizontal flip + return BoardGameAugmentation.flip_horizontal(state, policy, board_size) + elif transform == 5: + # Vertical flip + return torch.flip(state, dims=[1]), policy # Simplified + elif transform == 6: + # Diagonal flip (transpose) + return state.transpose(1, 2), policy # Simplified + else: + # Anti-diagonal flip + return torch.flip(state.transpose(1, 2), dims=[1, 2]), policy # Simplified + + +def collate_experiences(experiences: list[Experience]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Collate list of experiences into batched tensors. 
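+
+    Policies of different lengths are zero-padded to the longest policy in the
+    batch before stacking.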
+ + Args: + experiences: List of Experience objects + + Returns: + (states, policies, values) tuple of batched tensors + """ + states = torch.stack([exp.state for exp in experiences]) + + # Handle variable-sized policies by padding to max size + max_policy_size = max(len(exp.policy) for exp in experiences) + padded_policies = [] + for exp in experiences: + policy = exp.policy + if len(policy) < max_policy_size: + # Pad with zeros + padded = np.zeros(max_policy_size, dtype=policy.dtype) + padded[: len(policy)] = policy + padded_policies.append(padded) + else: + padded_policies.append(policy) + + policies = torch.from_numpy(np.stack(padded_policies)) + values = torch.tensor([exp.value for exp in experiences], dtype=torch.float32) + + return states, policies, values diff --git a/src/training/system_config.py b/src/training/system_config.py new file mode 100644 index 0000000000000000000000000000000000000000..67bc2a1d1045fd027ce37597589691f20e5ecc8c --- /dev/null +++ b/src/training/system_config.py @@ -0,0 +1,351 @@ +""" +System Configuration for LangGraph Multi-Agent MCTS with DeepMind-Style Learning. + +This module provides centralized configuration management for all framework components +including HRM, TRM, Neural MCTS, and training infrastructure. +""" + +from dataclasses import dataclass, field + +import torch + + +@dataclass +class HRMConfig: + """Configuration for Hierarchical Reasoning Model (HRM) Agent.""" + + # Model dimensions + h_dim: int = 512 # High-level planning dimension + l_dim: int = 256 # Low-level execution dimension + num_h_layers: int = 2 # Number of high-level layers + num_l_layers: int = 4 # Number of low-level layers + + # Halting and iteration control + max_outer_steps: int = 10 # Maximum planning steps + halt_threshold: float = 0.95 # Confidence threshold for halting + + # Training features + use_augmentation: bool = True # Use tactical augmentation + dropout: float = 0.1 + + # ACT (Adaptive Computation Time) parameters + ponder_epsilon: float = 0.01 # Small constant for numerical stability + max_ponder_steps: int = 16 # Maximum pondering steps + + +@dataclass +class TRMConfig: + """Configuration for Tiny Recursive Model (TRM) Agent.""" + + # Model architecture + latent_dim: int = 256 # Latent state dimension + num_recursions: int = 16 # Maximum recursion depth + hidden_dim: int = 512 # Hidden layer dimension + + # Deep supervision + deep_supervision: bool = True # Enable supervision at all recursion levels + supervision_weight_decay: float = 0.5 # Decay factor for deeper levels + + # Convergence criteria + convergence_threshold: float = 0.01 # L2 distance threshold + min_recursions: int = 3 # Minimum recursions before checking convergence + + # Training + dropout: float = 0.1 + use_layer_norm: bool = True + + +@dataclass +class MCTSConfig: + """Configuration for Neural-Guided MCTS.""" + + # Search parameters + num_simulations: int = 1600 # AlphaGo Zero used 1600 + c_puct: float = 1.25 # Exploration constant for PUCT + + # Dirichlet noise for exploration (applied at root) + dirichlet_epsilon: float = 0.25 # Mix of prior and noise + dirichlet_alpha: float = 0.3 # Game/task specific + + # Temperature for action selection + temperature_threshold: int = 30 # Move number to switch to greedy + temperature_init: float = 1.0 # Initial temperature + temperature_final: float = 0.1 # Final temperature + + # Virtual loss for parallel MCTS + virtual_loss: float = 3.0 # Discourage simultaneous exploration + num_parallel: int = 8 # Parallel search threads + + # Progressive 
widening + use_progressive_widening: bool = True + pw_k: float = 1.0 + pw_alpha: float = 0.5 + + +@dataclass +class NeuralNetworkConfig: + """Configuration for Policy-Value Networks.""" + + # ResNet architecture + num_res_blocks: int = 19 # AlphaGo Zero used 19 or 39 + num_channels: int = 256 # Feature channels + + # Policy head + policy_conv_channels: int = 2 + policy_fc_dim: int = 256 + + # Value head + value_conv_channels: int = 1 + value_fc_hidden: int = 256 + + # Regularization + use_batch_norm: bool = True + dropout: float = 0.0 # Usually 0 for ResNets with BN + weight_decay: float = 1e-4 + + # Input/Output + input_channels: int = 17 # Game/task specific + action_size: int = 362 # Game/task specific (e.g., Go: 19x19 + pass) + + +@dataclass +class TrainingConfig: + """Configuration for training pipeline.""" + + # Self-play generation + games_per_iteration: int = 25_000 + num_actors: int = 128 # Parallel self-play workers + + # Experience replay + buffer_size: int = 500_000 # Keep last N positions + batch_size: int = 2048 + + # Optimization + learning_rate: float = 0.2 # With LR schedule + momentum: float = 0.9 # SGD momentum + weight_decay: float = 1e-4 + + # Learning rate schedule + lr_schedule: str = "cosine" # "cosine", "step", "constant" + lr_decay_steps: int = 100 # For step schedule + lr_decay_gamma: float = 0.1 + + # Training loop + epochs_per_iteration: int = 1 + checkpoint_interval: int = 10 # Save every N iterations + + # Evaluation + evaluation_games: int = 400 # Games to evaluate new model + win_rate_threshold: float = 0.55 # Required win rate to replace best model + + # Early stopping + patience: int = 20 # Iterations without improvement + min_delta: float = 0.01 # Minimum improvement + + +@dataclass +class SystemConfig: + """ + Master configuration for the entire LangGraph Multi-Agent MCTS system. + + This provides centralized configuration management with sensible defaults + based on DeepMind's AlphaGo Zero and research on HRM/TRM architectures. 
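+
+    Example:
+        >>> config = SystemConfig()
+        >>> config.mcts.num_simulations
+        1600
+        >>> config.save("system_config.json")
+        >>> SystemConfig.load("system_config.json").mcts.num_simulations
+        1600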
+ """ + + # Component configurations + hrm: HRMConfig = field(default_factory=HRMConfig) + trm: TRMConfig = field(default_factory=TRMConfig) + mcts: MCTSConfig = field(default_factory=MCTSConfig) + neural_net: NeuralNetworkConfig = field(default_factory=NeuralNetworkConfig) + training: TrainingConfig = field(default_factory=TrainingConfig) + + # System settings + device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu") + seed: int = 42 # For reproducibility + + # Performance optimizations + use_mixed_precision: bool = True # FP16 training + gradient_checkpointing: bool = False # Trade compute for memory + compile_model: bool = False # PyTorch 2.0 compilation + + # Distributed training + distributed: bool = False + world_size: int = 1 + rank: int = 0 + backend: str = "nccl" # "nccl" for GPU, "gloo" for CPU + + # Logging and monitoring + log_interval: int = 10 # Log every N iterations + use_wandb: bool = False # Weights & Biases integration + wandb_project: str = "langgraph-mcts-deepmind" + wandb_entity: str | None = None + + # Paths + checkpoint_dir: str = "./checkpoints" + data_dir: str = "./data" + log_dir: str = "./logs" + + def __post_init__(self): + """Validate configuration after initialization.""" + # Ensure device is valid + if self.device.startswith("cuda") and not torch.cuda.is_available(): + print("⚠️ CUDA requested but not available, falling back to CPU") + self.device = "cpu" + + # Adjust settings based on device + if self.device == "cpu": + self.use_mixed_precision = False + self.distributed = False + self.backend = "gloo" + + def to_dict(self) -> dict: + """Convert configuration to dictionary for logging.""" + return { + "hrm": self.hrm.__dict__, + "trm": self.trm.__dict__, + "mcts": self.mcts.__dict__, + "neural_net": self.neural_net.__dict__, + "training": self.training.__dict__, + "device": self.device, + "seed": self.seed, + "use_mixed_precision": self.use_mixed_precision, + "gradient_checkpointing": self.gradient_checkpointing, + "compile_model": self.compile_model, + "distributed": self.distributed, + } + + @classmethod + def from_dict(cls, config_dict: dict) -> "SystemConfig": + """Create configuration from dictionary.""" + config = cls() + + # Update nested configs + if "hrm" in config_dict: + for key, value in config_dict["hrm"].items(): + setattr(config.hrm, key, value) + + if "trm" in config_dict: + for key, value in config_dict["trm"].items(): + setattr(config.trm, key, value) + + if "mcts" in config_dict: + for key, value in config_dict["mcts"].items(): + setattr(config.mcts, key, value) + + if "neural_net" in config_dict: + for key, value in config_dict["neural_net"].items(): + setattr(config.neural_net, key, value) + + if "training" in config_dict: + for key, value in config_dict["training"].items(): + setattr(config.training, key, value) + + # Update system settings + for key in [ + "device", + "seed", + "use_mixed_precision", + "gradient_checkpointing", + "compile_model", + "distributed", + "log_interval", + "use_wandb", + ]: + if key in config_dict: + setattr(config, key, config_dict[key]) + + return config + + def save(self, path: str): + """Save configuration to file.""" + import json + + with open(path, "w") as f: + json.dump(self.to_dict(), f, indent=2) + + @classmethod + def load(cls, path: str) -> "SystemConfig": + """Load configuration from file.""" + import json + + with open(path) as f: + config_dict = json.load(f) + return cls.from_dict(config_dict) + + +# Preset configurations for different use cases +def 
get_small_config() -> SystemConfig: + """Configuration for fast experimentation (reduced model sizes).""" + config = SystemConfig() + + # Smaller models + config.hrm.h_dim = 256 + config.hrm.l_dim = 128 + config.trm.latent_dim = 128 + config.neural_net.num_res_blocks = 9 + config.neural_net.num_channels = 128 + + # Fewer simulations + config.mcts.num_simulations = 400 + config.training.games_per_iteration = 1000 + config.training.num_actors = 16 + + return config + + +def get_medium_config() -> SystemConfig: + """Configuration for balanced training (moderate resources).""" + config = SystemConfig() + + # Default settings are already medium + config.neural_net.num_res_blocks = 19 + config.mcts.num_simulations = 800 + config.training.games_per_iteration = 10_000 + config.training.num_actors = 64 + + return config + + +def get_large_config() -> SystemConfig: + """Configuration for maximum performance (high resources).""" + config = SystemConfig() + + # Larger models + config.hrm.h_dim = 768 + config.hrm.l_dim = 384 + config.trm.latent_dim = 384 + config.neural_net.num_res_blocks = 39 + config.neural_net.num_channels = 384 + + # More simulations + config.mcts.num_simulations = 3200 + config.training.games_per_iteration = 50_000 + config.training.num_actors = 256 + + # Optimization + config.use_mixed_precision = True + config.gradient_checkpointing = True + + return config + + +def get_arc_agi_config() -> SystemConfig: + """Configuration optimized for ARC-AGI benchmark tasks.""" + config = SystemConfig() + + # ARC-AGI specific settings + config.hrm.h_dim = 512 + config.hrm.l_dim = 256 + config.hrm.max_outer_steps = 20 # Complex reasoning + + config.trm.num_recursions = 20 # Deep refinement + config.trm.convergence_threshold = 0.005 # Precise solutions + + config.mcts.num_simulations = 1600 + config.mcts.c_puct = 1.5 # More exploration for puzzle solving + + # Input/output for grid tasks + config.neural_net.input_channels = 11 # 10 colors + 1 empty + config.neural_net.action_size = 100 # Depends on grid size + + return config diff --git a/src/training/train_bert_lora.py b/src/training/train_bert_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..5edcf49d1482d8aff7afb7632f8c4931af183461 --- /dev/null +++ b/src/training/train_bert_lora.py @@ -0,0 +1,731 @@ +""" +Training script for BERT Meta-Controller with LoRA adapters. + +This module provides a training pipeline for fine-tuning BERT-based meta-controllers +using Low-Rank Adaptation (LoRA) for parameter-efficient training. It supports +synthetic data generation, dataset preparation, training with HuggingFace Trainer, +and model evaluation. +""" + +import argparse +import json +import logging +import warnings +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F + +# Handle optional dependencies gracefully +_TRANSFORMERS_AVAILABLE = False +_DATASETS_AVAILABLE = False + +try: + from transformers import ( + AutoTokenizer, + EvalPrediction, + Trainer, + TrainingArguments, + ) + + _TRANSFORMERS_AVAILABLE = True +except ImportError: + warnings.warn( + "transformers library not installed. Install it with: pip install transformers", + ImportWarning, + stacklevel=2, + ) + Trainer = None # type: ignore + TrainingArguments = None # type: ignore + AutoTokenizer = None # type: ignore + EvalPrediction = None # type: ignore + +try: + from datasets import Dataset + + _DATASETS_AVAILABLE = True +except ImportError: + warnings.warn( + "datasets library not installed. 
Install it with: pip install datasets", + ImportWarning, + stacklevel=2, + ) + Dataset = None # type: ignore + +from src.agents.meta_controller.bert_controller import BERTMetaController # noqa: E402 +from src.training.data_generator import MetaControllerDataGenerator # noqa: E402 + + +def setup_logging(log_level: int = logging.INFO) -> logging.Logger: + """ + Configure logging for the training script. + + Args: + log_level: Logging level (default: logging.INFO). + + Returns: + Configured logger instance. + """ + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + return logging.getLogger(__name__) + + +class BERTLoRATrainer: + """ + Trainer class for BERT Meta-Controller with LoRA adapters. + + This class provides a complete training pipeline including dataset preparation, + training with HuggingFace Trainer, evaluation, and model persistence. + + Attributes: + model_name: Name of the pre-trained BERT model. + lora_r: LoRA rank parameter. + lora_alpha: LoRA alpha scaling parameter. + lora_dropout: LoRA dropout rate. + lr: Learning rate for training. + batch_size: Training batch size. + epochs: Number of training epochs. + warmup_steps: Number of warmup steps for learning rate scheduler. + seed: Random seed for reproducibility. + device: PyTorch device for training. + controller: BERTMetaController instance. + tokenizer: BERT tokenizer. + logger: Logger instance. + + Example: + >>> trainer = BERTLoRATrainer( + ... model_name="prajjwal1/bert-mini", + ... lora_r=4, + ... epochs=5 + ... ) + >>> # Prepare and train + >>> train_dataset = trainer.prepare_dataset(train_texts, train_labels) + >>> val_dataset = trainer.prepare_dataset(val_texts, val_labels) + >>> results = trainer.train(train_texts, train_labels, val_texts, val_labels, "output") + """ + + def __init__( + self, + model_name: str = "prajjwal1/bert-mini", + lora_r: int = 4, + lora_alpha: int = 16, + lora_dropout: float = 0.1, + lr: float = 1e-3, + batch_size: int = 32, + epochs: int = 10, + warmup_steps: int = 100, + seed: int = 42, + device: str | None = None, + ) -> None: + """ + Initialize the BERT LoRA trainer. + + Args: + model_name: Pre-trained model name from HuggingFace. Defaults to "prajjwal1/bert-mini". + lora_r: LoRA rank parameter (lower = more compression). Defaults to 4. + lora_alpha: LoRA alpha scaling parameter. Defaults to 16. + lora_dropout: Dropout rate for LoRA layers. Defaults to 0.1. + lr: Learning rate for training. Defaults to 1e-3. + batch_size: Training batch size. Defaults to 32. + epochs: Number of training epochs. Defaults to 10. + warmup_steps: Number of warmup steps. Defaults to 100. + seed: Random seed for reproducibility. Defaults to 42. + device: Device for training ('cpu', 'cuda', 'mps'). If None, auto-detects. + + Raises: + ImportError: If required dependencies are not installed. + """ + if not _TRANSFORMERS_AVAILABLE: + raise ImportError( + "transformers library is required for BERTLoRATrainer. Install it with: pip install transformers" + ) + + if not _DATASETS_AVAILABLE: + raise ImportError("datasets library is required for BERTLoRATrainer. 
Install it with: pip install datasets") + + # Setup logging + self.logger = setup_logging() + self.logger.info("Initializing BERTLoRATrainer") + + # Store training parameters + self.model_name = model_name + self.lora_r = lora_r + self.lora_alpha = lora_alpha + self.lora_dropout = lora_dropout + self.lr = lr + self.batch_size = batch_size + self.epochs = epochs + self.warmup_steps = warmup_steps + self.seed = seed + + # Set random seeds for reproducibility + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Initialize BERTMetaController with LoRA enabled + self.logger.info(f"Loading model: {model_name}") + self.controller = BERTMetaController( + name="BERTLoRATrainer", + seed=seed, + model_name=model_name, + lora_r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + device=device, + use_lora=True, + ) + + # Store device and tokenizer for convenience + self.device = self.controller.device + self.tokenizer = self.controller.tokenizer + + # Log trainable parameters + params_info = self.controller.get_trainable_parameters() + self.logger.info( + f"Model parameters - Total: {params_info['total_params']:,}, " + f"Trainable: {params_info['trainable_params']:,} " + f"({params_info['trainable_percentage']:.2f}%)" + ) + + # Store trainer instance (will be created during training) + self._trainer: Trainer | None = None + + def prepare_dataset( + self, + texts: list[str], + labels: list[int], + ) -> Dataset: + """ + Prepare a HuggingFace Dataset from texts and labels. + + Tokenizes all texts and creates a dataset ready for training with + the HuggingFace Trainer. + + Args: + texts: List of text inputs (feature descriptions). + labels: List of integer labels (agent indices: 0=hrm, 1=trm, 2=mcts). + + Returns: + HuggingFace Dataset with tokenized inputs and labels. + + Raises: + ValueError: If texts and labels have different lengths. + + Example: + >>> trainer = BERTLoRATrainer() + >>> texts = ["HRM confidence: 0.8, TRM confidence: 0.3..."] + >>> labels = [0] # hrm + >>> dataset = trainer.prepare_dataset(texts, labels) + >>> 'input_ids' in dataset.features + True + """ + if len(texts) != len(labels): + raise ValueError(f"texts and labels must have same length, got {len(texts)} and {len(labels)}") + + self.logger.info(f"Preparing dataset with {len(texts)} samples") + + # Create initial dataset + dataset = Dataset.from_dict({"text": texts, "labels": labels}) + + # Tokenize function + def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]: + return self.tokenizer( + examples["text"], + padding="max_length", + truncation=True, + max_length=512, + ) + + # Tokenize the dataset + tokenized_dataset = dataset.map( + tokenize_function, + batched=True, + remove_columns=["text"], + desc="Tokenizing", + ) + + # Set format for PyTorch + tokenized_dataset.set_format( + type="torch", + columns=["input_ids", "attention_mask", "labels"], + ) + + self.logger.info(f"Dataset prepared with {len(tokenized_dataset)} samples") + return tokenized_dataset + + def compute_metrics(self, eval_pred: EvalPrediction) -> dict[str, float]: + """ + Compute evaluation metrics from predictions. + + Args: + eval_pred: EvalPrediction object containing predictions and labels. + + Returns: + Dictionary containing computed metrics (accuracy). 
+ + Example: + >>> # Called automatically by Trainer during evaluation + >>> metrics = trainer.compute_metrics(eval_pred) + >>> 'accuracy' in metrics + True + """ + predictions = eval_pred.predictions + labels = eval_pred.label_ids + + # Get predicted class indices + if isinstance(predictions, tuple): + predictions = predictions[0] + + preds = predictions.argmax(axis=-1) + + # Calculate accuracy + accuracy = (preds == labels).astype(float).mean() + + return {"accuracy": float(accuracy)} + + def train( + self, + train_texts: list[str], + train_labels: list[int], + val_texts: list[str], + val_labels: list[int], + output_dir: str, + ) -> dict[str, Any]: + """ + Train the BERT LoRA model. + + Creates training and validation datasets, configures the HuggingFace Trainer, + and runs the training loop. + + Args: + train_texts: List of training text inputs. + train_labels: List of training labels (integer indices). + val_texts: List of validation text inputs. + val_labels: List of validation labels. + output_dir: Directory to save model checkpoints and outputs. + + Returns: + Dictionary containing training history and results. + + Example: + >>> trainer = BERTLoRATrainer(epochs=3) + >>> history = trainer.train( + ... train_texts, train_labels, + ... val_texts, val_labels, + ... "output/bert_lora" + ... ) + >>> 'train_loss' in history + True + """ + self.logger.info("Starting training") + + # Create output directory + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Prepare datasets + train_dataset = self.prepare_dataset(train_texts, train_labels) + val_dataset = self.prepare_dataset(val_texts, val_labels) + + self.logger.info(f"Training samples: {len(train_dataset)}") + self.logger.info(f"Validation samples: {len(val_dataset)}") + + # Setup training arguments + training_args = TrainingArguments( + output_dir=str(output_path), + num_train_epochs=self.epochs, + per_device_train_batch_size=self.batch_size, + per_device_eval_batch_size=self.batch_size, + eval_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, + metric_for_best_model="accuracy", + greater_is_better=True, + warmup_steps=self.warmup_steps, + learning_rate=self.lr, + weight_decay=0.01, + logging_dir=str(output_path / "logs"), + logging_steps=10, + seed=self.seed, + report_to="none", # Disable wandb/tensorboard by default + save_total_limit=3, # Keep only last 3 checkpoints + ) + + # Set model to training mode + self.controller.model.train() + + # Create Trainer + self._trainer = Trainer( + model=self.controller.model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=val_dataset, + compute_metrics=self.compute_metrics, + ) + + # Train the model + self.logger.info("Starting training loop") + train_result = self._trainer.train() + + # Log training results + self.logger.info("Training completed") + self.logger.info(f"Final training loss: {train_result.training_loss:.4f}") + + # Get training history + history = { + "train_loss": train_result.training_loss, + "train_runtime": train_result.metrics.get("train_runtime", 0), + "train_samples_per_second": train_result.metrics.get("train_samples_per_second", 0), + "epochs": self.epochs, + "final_metrics": train_result.metrics, + } + + # Evaluate on validation set + self.logger.info("Evaluating on validation set") + eval_results = self._trainer.evaluate() + history["eval_results"] = eval_results + self.logger.info(f"Validation accuracy: {eval_results.get('eval_accuracy', 0):.4f}") + + # Set model back to evaluation mode + 
self.controller.model.eval() + + return history + + def evaluate( + self, + test_texts: list[str], + test_labels: list[int], + ) -> dict[str, Any]: + """ + Evaluate the model on a test set. + + Args: + test_texts: List of test text inputs. + test_labels: List of test labels (integer indices). + + Returns: + Dictionary containing: + - loss: Average cross-entropy loss. + - accuracy: Classification accuracy. + - predictions: List of predicted class indices. + - probabilities: List of probability distributions. + + Example: + >>> trainer = BERTLoRATrainer() + >>> # After training... + >>> results = trainer.evaluate(test_texts, test_labels) + >>> 0.0 <= results['accuracy'] <= 1.0 + True + """ + self.logger.info(f"Evaluating on {len(test_texts)} test samples") + + # Prepare test dataset + self.prepare_dataset(test_texts, test_labels) + + # Set model to evaluation mode + self.controller.model.eval() + + # Collect predictions + all_predictions: list[int] = [] + all_probabilities: list[list[float]] = [] + total_loss = 0.0 + + with torch.no_grad(): + for i in range(len(test_texts)): + # Tokenize single sample + inputs = self.tokenizer( + test_texts[i], + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ) + + # Move to device + inputs = {k: v.to(self.device) for k, v in inputs.items()} + label_tensor = torch.tensor([test_labels[i]], device=self.device) + + # Forward pass + outputs = self.controller.model(**inputs) + logits = outputs.logits + + # Compute loss + loss = F.cross_entropy(logits, label_tensor) + total_loss += loss.item() + + # Get predictions + probs = F.softmax(logits, dim=-1) + pred_idx = torch.argmax(probs, dim=-1).item() + + all_predictions.append(pred_idx) + all_probabilities.append(probs[0].cpu().tolist()) + + # Calculate metrics + avg_loss = total_loss / len(test_texts) + correct = sum(1 for pred, label in zip(all_predictions, test_labels, strict=False) if pred == label) + accuracy = correct / len(test_labels) + + self.logger.info(f"Test Loss: {avg_loss:.4f}") + self.logger.info(f"Test Accuracy: {accuracy:.4f}") + + return { + "loss": avg_loss, + "accuracy": accuracy, + "predictions": all_predictions, + "probabilities": all_probabilities, + } + + def save_model(self, path: str) -> None: + """ + Save the LoRA adapter weights to disk. + + Args: + path: Directory path where the adapter weights will be saved. + + Example: + >>> trainer = BERTLoRATrainer() + >>> # After training... + >>> trainer.save_model("models/bert_lora_adapter") + """ + self.logger.info(f"Saving LoRA adapter weights to {path}") + self.controller.save_model(path) + self.logger.info("Model saved successfully") + + +def main() -> None: + """ + Main function for training BERT Meta-Controller with LoRA. + + Parses command-line arguments, generates or loads dataset, trains the model, + and saves results. 
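+
+    Example (illustrative invocation from the repository root; every flag maps
+    to an argparse option defined below):
+
+        python -m src.training.train_bert_lora \
+            --model_name prajjwal1/bert-mini \
+            --num_samples 3000 --balanced \
+            --epochs 5 --batch_size 32 \
+            --output_dir output/bert_lora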
+ """ + parser = argparse.ArgumentParser(description="Train BERT Meta-Controller with LoRA adapters") + + # Model arguments + parser.add_argument( + "--model_name", + type=str, + default="prajjwal1/bert-mini", + help="Pre-trained model name from HuggingFace", + ) + parser.add_argument( + "--lora_r", + type=int, + default=4, + help="LoRA rank parameter", + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=16, + help="LoRA alpha scaling parameter", + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.1, + help="LoRA dropout rate", + ) + + # Training arguments + parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Learning rate", + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Training batch size", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + help="Number of training epochs", + ) + parser.add_argument( + "--warmup_steps", + type=int, + default=100, + help="Number of warmup steps", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility", + ) + + # Data arguments + parser.add_argument( + "--num_samples", + type=int, + default=1000, + help="Number of samples to generate (if not loading from file)", + ) + parser.add_argument( + "--data_path", + type=str, + default=None, + help="Path to load existing dataset (JSON format)", + ) + parser.add_argument( + "--balanced", + action="store_true", + help="Generate balanced dataset (equal samples per class)", + ) + + # Output arguments + parser.add_argument( + "--output_dir", + type=str, + default="output/bert_lora", + help="Directory to save model and results", + ) + + args = parser.parse_args() + + # Setup logging + logger = setup_logging() + logger.info("=" * 60) + logger.info("BERT Meta-Controller LoRA Training") + logger.info("=" * 60) + + # Log configuration + logger.info("Configuration:") + for key, value in vars(args).items(): + logger.info(f" {key}: {value}") + logger.info("=" * 60) + + # Initialize data generator + data_generator = MetaControllerDataGenerator(seed=args.seed) + + # Generate or load dataset + if args.data_path is not None: + logger.info(f"Loading dataset from {args.data_path}") + features_list, labels_list = data_generator.load_dataset(args.data_path) + logger.info(f"Loaded {len(features_list)} samples") + else: + logger.info(f"Generating synthetic dataset with {args.num_samples} samples") + if args.balanced: + samples_per_class = args.num_samples // 3 + features_list, labels_list = data_generator.generate_balanced_dataset( + num_samples_per_class=samples_per_class + ) + logger.info(f"Generated balanced dataset with {samples_per_class} samples per class") + else: + features_list, labels_list = data_generator.generate_dataset(num_samples=args.num_samples) + + # Save generated dataset + output_path = Path(args.output_dir) + output_path.mkdir(parents=True, exist_ok=True) + dataset_path = output_path / "generated_dataset.json" + data_generator.save_dataset(features_list, labels_list, str(dataset_path)) + logger.info(f"Saved generated dataset to {dataset_path}") + + # Convert to text format + logger.info("Converting features to text format") + texts, label_indices = data_generator.to_text_dataset(features_list, labels_list) + + # Log class distribution + class_counts = {0: 0, 1: 0, 2: 0} + for label in label_indices: + class_counts[label] += 1 + logger.info("Class distribution:") + logger.info(f" HRM (0): {class_counts[0]} samples") + logger.info(f" TRM (1): {class_counts[1]} 
samples") + logger.info(f" MCTS (2): {class_counts[2]} samples") + + # Split dataset + logger.info("Splitting dataset into train/val/test sets") + splits = data_generator.split_dataset(texts, label_indices, train_ratio=0.7, val_ratio=0.15) + + train_texts = splits["X_train"] + train_labels = splits["y_train"] + val_texts = splits["X_val"] + val_labels = splits["y_val"] + test_texts = splits["X_test"] + test_labels = splits["y_test"] + + logger.info(f"Train set: {len(train_texts)} samples") + logger.info(f"Validation set: {len(val_texts)} samples") + logger.info(f"Test set: {len(test_texts)} samples") + + # Initialize trainer + logger.info("Initializing BERTLoRATrainer") + trainer = BERTLoRATrainer( + model_name=args.model_name, + lora_r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + lr=args.lr, + batch_size=args.batch_size, + epochs=args.epochs, + warmup_steps=args.warmup_steps, + seed=args.seed, + ) + + # Train model + logger.info("Starting training") + train_history = trainer.train( + train_texts=train_texts, + train_labels=train_labels, + val_texts=val_texts, + val_labels=val_labels, + output_dir=args.output_dir, + ) + + # Evaluate on test set + logger.info("Evaluating on test set") + test_results = trainer.evaluate(test_texts, test_labels) + + # Save final model + final_model_path = Path(args.output_dir) / "final_model" + trainer.save_model(str(final_model_path)) + + # Save training results + results = { + "config": vars(args), + "train_history": train_history, + "test_results": { + "loss": test_results["loss"], + "accuracy": test_results["accuracy"], + }, + "model_params": trainer.controller.get_trainable_parameters(), + } + + results_path = Path(args.output_dir) / "training_results.json" + with open(results_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, default=str) + logger.info(f"Saved training results to {results_path}") + + # Print summary + logger.info("=" * 60) + logger.info("Training Summary") + logger.info("=" * 60) + logger.info(f"Model: {args.model_name}") + logger.info(f"LoRA Parameters: r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout}") + logger.info(f"Training Epochs: {args.epochs}") + logger.info(f"Learning Rate: {args.lr}") + logger.info(f"Batch Size: {args.batch_size}") + logger.info(f"Final Training Loss: {train_history['train_loss']:.4f}") + logger.info(f"Validation Accuracy: {train_history['eval_results'].get('eval_accuracy', 0):.4f}") + logger.info(f"Test Accuracy: {test_results['accuracy']:.4f}") + logger.info(f"Test Loss: {test_results['loss']:.4f}") + logger.info(f"Model saved to: {final_model_path}") + logger.info(f"Results saved to: {results_path}") + logger.info("=" * 60) + logger.info("Training completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/src/training/train_rnn.py b/src/training/train_rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..4ce68cd94d5ca3618747f61c15736703a435f31b --- /dev/null +++ b/src/training/train_rnn.py @@ -0,0 +1,916 @@ +""" +Training script for the RNN Meta-Controller. + +This module provides a complete training pipeline for the RNN-based meta-controller, +including data generation/loading, model training with early stopping, validation, +checkpointing, and comprehensive evaluation with per-class metrics. 
+""" + +import argparse +import json +import logging +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +from src.agents.meta_controller.rnn_controller import ( + RNNMetaControllerModel, +) +from src.training.data_generator import MetaControllerDataGenerator + +# Braintrust integration (optional) +try: + from src.observability.braintrust_tracker import BraintrustTracker, create_training_tracker + + BRAINTRUST_AVAILABLE = True +except ImportError: + BRAINTRUST_AVAILABLE = False + BraintrustTracker = None # type: ignore + + +class RNNTrainer: + """ + Trainer class for the RNN Meta-Controller model. + + Handles the complete training pipeline including data loading, training loops, + validation, early stopping, model checkpointing, and comprehensive evaluation. + + Attributes: + hidden_dim: Dimension of the GRU hidden state. + num_layers: Number of GRU layers. + dropout: Dropout probability for regularization. + lr: Learning rate for the optimizer. + batch_size: Batch size for training and evaluation. + epochs: Maximum number of training epochs. + early_stopping_patience: Number of epochs to wait for improvement before stopping. + seed: Random seed for reproducibility. + device: PyTorch device for computation. + model: The RNNMetaControllerModel instance. + optimizer: Adam optimizer for training. + criterion: CrossEntropyLoss for classification. + logger: Logger instance for progress reporting. + + Example: + >>> trainer = RNNTrainer(hidden_dim=64, epochs=10) + >>> generator = MetaControllerDataGenerator(seed=42) + >>> features, labels = generator.generate_balanced_dataset(100) + >>> X, y = generator.to_tensor_dataset(features, labels) + >>> splits = generator.split_dataset(X, y) + >>> history = trainer.train( + ... train_data=(splits['X_train'], splits['y_train']), + ... val_data=(splits['X_val'], splits['y_val']) + ... ) + """ + + AGENT_NAMES = ["hrm", "trm", "mcts"] + LABEL_TO_INDEX = {"hrm": 0, "trm": 1, "mcts": 2} + INDEX_TO_LABEL = {0: "hrm", 1: "trm", 2: "mcts"} + + def __init__( + self, + hidden_dim: int = 64, + num_layers: int = 1, + dropout: float = 0.1, + lr: float = 1e-3, + batch_size: int = 32, + epochs: int = 10, + early_stopping_patience: int = 3, + seed: int = 42, + device: str | None = None, + braintrust_tracker: Any | None = None, + ) -> None: + """ + Initialize the RNN trainer. + + Args: + hidden_dim: Dimension of GRU hidden state. Defaults to 64. + num_layers: Number of stacked GRU layers. Defaults to 1. + dropout: Dropout probability for regularization. Defaults to 0.1. + lr: Learning rate for Adam optimizer. Defaults to 1e-3. + batch_size: Batch size for training and evaluation. Defaults to 32. + epochs: Maximum number of training epochs. Defaults to 10. + early_stopping_patience: Epochs to wait for improvement before early stopping. + Defaults to 3. + seed: Random seed for reproducibility. Defaults to 42. + device: Device to run training on ('cpu', 'cuda', 'mps'). + If None, auto-detects best available device. + braintrust_tracker: Optional BraintrustTracker for experiment tracking. 
+ """ + # Store hyperparameters + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.dropout = dropout + self.lr = lr + self.batch_size = batch_size + self.epochs = epochs + self.early_stopping_patience = early_stopping_patience + self.seed = seed + + # Set random seeds for reproducibility + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Auto-detect device if not specified + if device is None: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + else: + self.device = torch.device(device) + + # Setup logging + self._setup_logging() + self.logger.info(f"Initializing RNNTrainer with device: {self.device}") + + # Initialize model + self.model = RNNMetaControllerModel( + input_dim=10, # Fixed based on features_to_tensor output + hidden_dim=hidden_dim, + num_layers=num_layers, + num_agents=len(self.AGENT_NAMES), + dropout=dropout, + ) + self.model = self.model.to(self.device) + self.logger.info(f"Model initialized: hidden_dim={hidden_dim}, num_layers={num_layers}, dropout={dropout}") + + # Setup optimizer + self.optimizer = optim.Adam(self.model.parameters(), lr=lr) + self.logger.info(f"Optimizer: Adam with lr={lr}") + + # Setup loss function + self.criterion = nn.CrossEntropyLoss() + self.logger.info("Loss function: CrossEntropyLoss") + + # Braintrust experiment tracking (optional) + self.braintrust_tracker = braintrust_tracker + if self.braintrust_tracker and hasattr(self.braintrust_tracker, "is_available"): + if self.braintrust_tracker.is_available: + self.logger.info("Braintrust experiment tracking enabled") + self.braintrust_tracker.log_hyperparameters( + { + "hidden_dim": hidden_dim, + "num_layers": num_layers, + "dropout": dropout, + "learning_rate": lr, + "batch_size": batch_size, + "max_epochs": epochs, + "early_stopping_patience": early_stopping_patience, + "seed": seed, + "device": str(self.device), + } + ) + else: + self.logger.info("Braintrust tracker provided but not available") + + def _setup_logging(self) -> None: + """ + Setup logging configuration for the trainer. + + Creates a logger with console handler and appropriate formatting. + """ + self.logger = logging.getLogger("RNNTrainer") + self.logger.setLevel(logging.INFO) + + # Avoid duplicate handlers + if not self.logger.handlers: + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) + + def create_dataloader( + self, + X: torch.Tensor, + y: torch.Tensor, + batch_size: int | None = None, + shuffle: bool = True, + ) -> DataLoader: + """ + Create a PyTorch DataLoader from feature and label tensors. + + Args: + X: Feature tensor of shape (N, 10). + y: Label tensor of shape (N,). + batch_size: Batch size for the DataLoader. If None, uses self.batch_size. + shuffle: Whether to shuffle the data. Defaults to True. + + Returns: + DataLoader instance for iterating over batches. 
+ + Example: + >>> trainer = RNNTrainer() + >>> X = torch.randn(100, 10) + >>> y = torch.randint(0, 3, (100,)) + >>> loader = trainer.create_dataloader(X, y, batch_size=16) + >>> len(loader) + 7 + """ + if batch_size is None: + batch_size = self.batch_size + + # Ensure tensors are on CPU for DataLoader + if X.device != torch.device("cpu"): + X = X.cpu() + if y.device != torch.device("cpu"): + y = y.cpu() + + dataset = TensorDataset(X, y) + loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=0, # Use main process for data loading + pin_memory=self.device.type == "cuda", + ) + + return loader + + def train_epoch(self, train_loader: DataLoader) -> float: + """ + Train the model for one epoch. + + Args: + train_loader: DataLoader providing training batches. + + Returns: + Average training loss for the epoch. + + Example: + >>> trainer = RNNTrainer() + >>> X = torch.randn(100, 10) + >>> y = torch.randint(0, 3, (100,)) + >>> loader = trainer.create_dataloader(X, y) + >>> loss = trainer.train_epoch(loader) + >>> isinstance(loss, float) + True + """ + self.model.train() + total_loss = 0.0 + num_batches = 0 + + for batch_X, batch_y in train_loader: + # Move data to device + batch_X = batch_X.to(self.device) + batch_y = batch_y.to(self.device) + + # Zero gradients + self.optimizer.zero_grad() + + # Forward pass + logits = self.model(batch_X) + + # Compute loss + loss = self.criterion(logits, batch_y) + + # Backward pass + loss.backward() + + # Update weights + self.optimizer.step() + + # Accumulate loss + total_loss += loss.item() + num_batches += 1 + + average_loss = total_loss / num_batches if num_batches > 0 else 0.0 + return average_loss + + def validate(self, val_loader: DataLoader) -> tuple[float, float]: + """ + Evaluate the model on the validation set. + + Args: + val_loader: DataLoader providing validation batches. + + Returns: + Tuple of (average_loss, accuracy). + - average_loss: Mean cross-entropy loss over validation set. + - accuracy: Classification accuracy as a fraction [0, 1]. + + Example: + >>> trainer = RNNTrainer() + >>> X = torch.randn(50, 10) + >>> y = torch.randint(0, 3, (50,)) + >>> loader = trainer.create_dataloader(X, y, shuffle=False) + >>> loss, acc = trainer.validate(loader) + >>> 0.0 <= acc <= 1.0 + True + """ + self.model.eval() + total_loss = 0.0 + correct = 0 + total = 0 + + with torch.no_grad(): + for batch_X, batch_y in val_loader: + # Move data to device + batch_X = batch_X.to(self.device) + batch_y = batch_y.to(self.device) + + # Forward pass + logits = self.model(batch_X) + + # Compute loss + loss = self.criterion(logits, batch_y) + total_loss += loss.item() + + # Compute accuracy + predictions = torch.argmax(logits, dim=1) + correct += (predictions == batch_y).sum().item() + total += batch_y.size(0) + + num_batches = len(val_loader) + average_loss = total_loss / num_batches if num_batches > 0 else 0.0 + accuracy = correct / total if total > 0 else 0.0 + + return average_loss, accuracy + + def train( + self, + train_data: tuple[torch.Tensor, torch.Tensor], + val_data: tuple[torch.Tensor, torch.Tensor], + save_path: str | None = None, + ) -> dict[str, Any]: + """ + Main training loop with early stopping and model checkpointing. + + Trains the model for the specified number of epochs, monitoring validation + loss for early stopping. If save_path is provided, saves the best model + checkpoint based on validation loss. + + Args: + train_data: Tuple of (X_train, y_train) tensors. + val_data: Tuple of (X_val, y_val) tensors. 
+ save_path: Optional path to save the best model checkpoint. + + Returns: + Dictionary containing training history: + - 'train_losses': List of training losses per epoch. + - 'val_losses': List of validation losses per epoch. + - 'val_accuracies': List of validation accuracies per epoch. + - 'best_epoch': Epoch with best validation loss. + - 'best_val_loss': Best validation loss achieved. + - 'best_val_accuracy': Validation accuracy at best epoch. + - 'stopped_early': Whether training stopped early. + - 'total_epochs': Total number of epochs trained. + + Example: + >>> trainer = RNNTrainer(epochs=5) + >>> X_train = torch.randn(100, 10) + >>> y_train = torch.randint(0, 3, (100,)) + >>> X_val = torch.randn(20, 10) + >>> y_val = torch.randint(0, 3, (20,)) + >>> history = trainer.train((X_train, y_train), (X_val, y_val)) + >>> 'train_losses' in history + True + >>> len(history['train_losses']) <= 5 + True + """ + self.logger.info("Starting training...") + self.logger.info(f"Training samples: {train_data[0].shape[0]}") + self.logger.info(f"Validation samples: {val_data[0].shape[0]}") + self.logger.info(f"Batch size: {self.batch_size}") + self.logger.info(f"Max epochs: {self.epochs}") + self.logger.info(f"Early stopping patience: {self.early_stopping_patience}") + + # Create data loaders + train_loader = self.create_dataloader(train_data[0], train_data[1], shuffle=True) + val_loader = self.create_dataloader(val_data[0], val_data[1], shuffle=False) + + # Initialize tracking variables + train_losses: list[float] = [] + val_losses: list[float] = [] + val_accuracies: list[float] = [] + + best_val_loss = float("inf") + best_val_accuracy = 0.0 + best_epoch = 0 + best_model_state = None + patience_counter = 0 + stopped_early = False + + # Training loop + for epoch in range(1, self.epochs + 1): + # Train for one epoch + train_loss = self.train_epoch(train_loader) + train_losses.append(train_loss) + + # Validate + val_loss, val_accuracy = self.validate(val_loader) + val_losses.append(val_loss) + val_accuracies.append(val_accuracy) + + # Log progress + self.logger.info( + f"Epoch {epoch}/{self.epochs} - " + f"Train Loss: {train_loss:.4f}, " + f"Val Loss: {val_loss:.4f}, " + f"Val Accuracy: {val_accuracy:.4f}" + ) + + # Log to Braintrust if available + if self.braintrust_tracker and hasattr(self.braintrust_tracker, "log_epoch_summary"): + self.braintrust_tracker.log_epoch_summary( + epoch=epoch, + train_loss=train_loss, + val_loss=val_loss, + val_accuracy=val_accuracy, + ) + + # Check for improvement + if val_loss < best_val_loss: + best_val_loss = val_loss + best_val_accuracy = val_accuracy + best_epoch = epoch + best_model_state = self.model.state_dict().copy() + patience_counter = 0 + self.logger.info(f" -> New best validation loss: {val_loss:.4f}") + + # Save checkpoint if path provided + if save_path: + torch.save(best_model_state, save_path) + self.logger.info(f" -> Model checkpoint saved to {save_path}") + else: + patience_counter += 1 + self.logger.info(f" -> No improvement for {patience_counter} epoch(s)") + + # Check for early stopping + if patience_counter >= self.early_stopping_patience: + self.logger.info(f"Early stopping triggered at epoch {epoch}. 
Best epoch was {best_epoch}.") + stopped_early = True + break + + # Restore best model state + if best_model_state is not None: + self.model.load_state_dict(best_model_state) + self.logger.info( + f"Restored best model from epoch {best_epoch} " + f"with val_loss={best_val_loss:.4f}, val_accuracy={best_val_accuracy:.4f}" + ) + + # Final save if path provided and not already saved + if save_path and best_model_state is not None: + torch.save(best_model_state, save_path) + self.logger.info(f"Final model saved to {save_path}") + + # Compile history + history = { + "train_losses": train_losses, + "val_losses": val_losses, + "val_accuracies": val_accuracies, + "best_epoch": best_epoch, + "best_val_loss": best_val_loss, + "best_val_accuracy": best_val_accuracy, + "stopped_early": stopped_early, + "total_epochs": len(train_losses), + } + + self.logger.info("Training completed!") + self.logger.info(f"Best epoch: {best_epoch}") + self.logger.info(f"Best validation loss: {best_val_loss:.4f}") + self.logger.info(f"Best validation accuracy: {best_val_accuracy:.4f}") + + # Log final model artifact to Braintrust + if self.braintrust_tracker and hasattr(self.braintrust_tracker, "log_model_artifact"): + self.braintrust_tracker.log_model_artifact( + model_path=str(save_path) if save_path else "in_memory", + model_type="rnn", + metrics={ + "best_val_loss": best_val_loss, + "best_val_accuracy": best_val_accuracy, + "best_epoch": float(best_epoch), + "total_epochs": float(len(train_losses)), + }, + ) + + return history + + def evaluate(self, test_loader: DataLoader) -> dict[str, Any]: + """ + Comprehensive evaluation on the test set. + + Computes overall metrics and per-class precision, recall, and F1-score. + + Args: + test_loader: DataLoader providing test batches. + + Returns: + Dictionary containing: + - 'loss': Average cross-entropy loss. + - 'accuracy': Overall classification accuracy. + - 'per_class_metrics': Dictionary with per-class metrics: + - For each agent ('hrm', 'trm', 'mcts'): + - 'precision': Precision score. + - 'recall': Recall score. + - 'f1_score': F1 score. + - 'support': Number of samples in this class. + - 'confusion_matrix': 3x3 confusion matrix as nested list. + - 'total_samples': Total number of test samples. 
+ + Example: + >>> trainer = RNNTrainer() + >>> X = torch.randn(50, 10) + >>> y = torch.randint(0, 3, (50,)) + >>> loader = trainer.create_dataloader(X, y, shuffle=False) + >>> results = trainer.evaluate(loader) + >>> 'accuracy' in results + True + >>> 'per_class_metrics' in results + True + """ + self.model.eval() + total_loss = 0.0 + all_predictions: list[int] = [] + all_labels: list[int] = [] + + with torch.no_grad(): + for batch_X, batch_y in test_loader: + # Move data to device + batch_X = batch_X.to(self.device) + batch_y = batch_y.to(self.device) + + # Forward pass + logits = self.model(batch_X) + + # Compute loss + loss = self.criterion(logits, batch_y) + total_loss += loss.item() + + # Get predictions + predictions = torch.argmax(logits, dim=1) + all_predictions.extend(predictions.cpu().tolist()) + all_labels.extend(batch_y.cpu().tolist()) + + # Calculate overall metrics + num_batches = len(test_loader) + average_loss = total_loss / num_batches if num_batches > 0 else 0.0 + + correct = sum(p == label for p, label in zip(all_predictions, all_labels, strict=False)) + total = len(all_labels) + accuracy = correct / total if total > 0 else 0.0 + + # Calculate confusion matrix + num_classes = len(self.AGENT_NAMES) + confusion_matrix = [[0] * num_classes for _ in range(num_classes)] + for pred, label in zip(all_predictions, all_labels, strict=False): + confusion_matrix[label][pred] += 1 + + # Calculate per-class metrics + per_class_metrics: dict[str, dict[str, float]] = {} + + for class_idx, agent_name in enumerate(self.AGENT_NAMES): + # True positives: predicted as this class and actually this class + tp = confusion_matrix[class_idx][class_idx] + + # False positives: predicted as this class but actually other class + fp = sum(confusion_matrix[i][class_idx] for i in range(num_classes) if i != class_idx) + + # False negatives: actually this class but predicted as other class + fn = sum(confusion_matrix[class_idx][j] for j in range(num_classes) if j != class_idx) + + # Support: total number of samples in this class + support = sum(confusion_matrix[class_idx]) + + # Precision: TP / (TP + FP) + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + + # Recall: TP / (TP + FN) + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + + # F1 Score: 2 * (Precision * Recall) / (Precision + Recall) + f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 + + per_class_metrics[agent_name] = { + "precision": precision, + "recall": recall, + "f1_score": f1_score, + "support": support, + } + + results = { + "loss": average_loss, + "accuracy": accuracy, + "per_class_metrics": per_class_metrics, + "confusion_matrix": confusion_matrix, + "total_samples": total, + } + + self.logger.info("Evaluation Results:") + self.logger.info(f" Test Loss: {average_loss:.4f}") + self.logger.info(f" Test Accuracy: {accuracy:.4f}") + self.logger.info(f" Total Samples: {total}") + self.logger.info(" Per-Class Metrics:") + for agent_name, metrics in per_class_metrics.items(): + self.logger.info( + f" {agent_name}: " + f"Precision={metrics['precision']:.4f}, " + f"Recall={metrics['recall']:.4f}, " + f"F1={metrics['f1_score']:.4f}, " + f"Support={metrics['support']}" + ) + + return results + + +def main() -> None: + """ + Main entry point for training the RNN Meta-Controller. + + Parses command-line arguments, generates or loads dataset, trains the model, + evaluates on test set, and saves results. 
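+
+    Example (illustrative; each flag corresponds to an argparse option defined
+    in this function):
+
+        python -m src.training.train_rnn \
+            --hidden_dim 64 --num_layers 1 --dropout 0.1 \
+            --epochs 10 --patience 3 --batch_size 32 \
+            --num_samples 3000 \
+            --save_path rnn_meta_controller.pt \
+            --use_braintrust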
+ """ + parser = argparse.ArgumentParser( + description="Train the RNN Meta-Controller for agent selection.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Model hyperparameters + parser.add_argument( + "--hidden_dim", + type=int, + default=64, + help="Dimension of GRU hidden state", + ) + parser.add_argument( + "--num_layers", + type=int, + default=1, + help="Number of GRU layers", + ) + parser.add_argument( + "--dropout", + type=float, + default=0.1, + help="Dropout probability", + ) + + # Training hyperparameters + parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Learning rate for Adam optimizer", + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size for training and evaluation", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + help="Maximum number of training epochs", + ) + parser.add_argument( + "--patience", + type=int, + default=3, + help="Early stopping patience (epochs without improvement)", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility", + ) + + # Data parameters + parser.add_argument( + "--num_samples", + type=int, + default=3000, + help="Number of samples to generate (per class for balanced dataset)", + ) + parser.add_argument( + "--data_path", + type=str, + default=None, + help="Path to load existing dataset (JSON format). If not provided, generates new data.", + ) + + # Output parameters + parser.add_argument( + "--save_path", + type=str, + default="rnn_meta_controller.pt", + help="Path to save the trained model", + ) + + # Experiment tracking + parser.add_argument( + "--use_braintrust", + action="store_true", + help="Enable Braintrust experiment tracking", + ) + parser.add_argument( + "--experiment_name", + type=str, + default=None, + help="Custom experiment name for Braintrust (auto-generated if not provided)", + ) + + args = parser.parse_args() + + # Setup logging for main + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + logger = logging.getLogger("train_rnn") + + logger.info("=" * 60) + logger.info("RNN Meta-Controller Training") + logger.info("=" * 60) + + # Print configuration + logger.info("Configuration:") + for arg_name, arg_value in vars(args).items(): + logger.info(f" {arg_name}: {arg_value}") + logger.info("") + + try: + # Initialize data generator + generator = MetaControllerDataGenerator(seed=args.seed) + + # Load or generate dataset + if args.data_path and Path(args.data_path).exists(): + logger.info(f"Loading dataset from {args.data_path}...") + features_list, labels_list = generator.load_dataset(args.data_path) + logger.info(f"Loaded {len(features_list)} samples") + else: + logger.info(f"Generating balanced dataset with {args.num_samples} samples per class...") + features_list, labels_list = generator.generate_balanced_dataset(num_samples_per_class=args.num_samples) + total_samples = len(features_list) + logger.info(f"Generated {total_samples} total samples") + + # Optionally save generated dataset + if args.data_path: + logger.info(f"Saving generated dataset to {args.data_path}...") + generator.save_dataset(features_list, labels_list, args.data_path) + + # Convert to tensors + logger.info("Converting dataset to tensors...") + X, y = generator.to_tensor_dataset(features_list, labels_list) + logger.info(f"Feature tensor shape: {X.shape}") + logger.info(f"Label tensor shape: {y.shape}") + + # Split dataset + 
logger.info("Splitting dataset into train/val/test (70%/15%/15%)...") + splits = generator.split_dataset(X, y, train_ratio=0.7, val_ratio=0.15) + + logger.info(f"Training set size: {splits['X_train'].shape[0]}") + logger.info(f"Validation set size: {splits['X_val'].shape[0]}") + logger.info(f"Test set size: {splits['X_test'].shape[0]}") + logger.info("") + + # Initialize Braintrust tracker if enabled + braintrust_tracker = None + if args.use_braintrust and BRAINTRUST_AVAILABLE: + logger.info("Initializing Braintrust experiment tracker...") + braintrust_tracker = create_training_tracker( + model_type="rnn", + config={ + "hidden_dim": args.hidden_dim, + "num_layers": args.num_layers, + "dropout": args.dropout, + "lr": args.lr, + "batch_size": args.batch_size, + "epochs": args.epochs, + "patience": args.patience, + "seed": args.seed, + "num_samples": args.num_samples, + }, + ) + if braintrust_tracker.is_available: + logger.info("Braintrust experiment tracking enabled") + else: + logger.info("Braintrust not available (check API key)") + elif args.use_braintrust and not BRAINTRUST_AVAILABLE: + logger.warning("Braintrust requested but not installed. Install with: pip install braintrust") + + # Initialize trainer + logger.info("Initializing trainer...") + trainer = RNNTrainer( + hidden_dim=args.hidden_dim, + num_layers=args.num_layers, + dropout=args.dropout, + lr=args.lr, + batch_size=args.batch_size, + epochs=args.epochs, + early_stopping_patience=args.patience, + seed=args.seed, + braintrust_tracker=braintrust_tracker, + ) + logger.info("") + + # Train model + logger.info("Starting training...") + logger.info("-" * 60) + history = trainer.train( + train_data=(splits["X_train"], splits["y_train"]), + val_data=(splits["X_val"], splits["y_val"]), + save_path=args.save_path, + ) + logger.info("-" * 60) + logger.info("") + + # Evaluate on test set + logger.info("Evaluating on test set...") + logger.info("-" * 60) + test_loader = trainer.create_dataloader(splits["X_test"], splits["y_test"], shuffle=False) + test_results = trainer.evaluate(test_loader) + logger.info("-" * 60) + logger.info("") + + # Save training history + history_path = Path(args.save_path).with_suffix(".history.json") + logger.info(f"Saving training history to {history_path}...") + + # Combine history and test results + full_results = { + "config": { + "hidden_dim": args.hidden_dim, + "num_layers": args.num_layers, + "dropout": args.dropout, + "lr": args.lr, + "batch_size": args.batch_size, + "epochs": args.epochs, + "patience": args.patience, + "seed": args.seed, + "num_samples": args.num_samples, + }, + "training_history": history, + "test_results": test_results, + } + + with open(history_path, "w", encoding="utf-8") as f: + json.dump(full_results, f, indent=2) + + logger.info(f"Training history saved to {history_path}") + logger.info("") + + # Print final summary + logger.info("=" * 60) + logger.info("Training Summary") + logger.info("=" * 60) + logger.info(f"Model saved to: {args.save_path}") + logger.info(f"History saved to: {history_path}") + logger.info(f"Best validation accuracy: {history['best_val_accuracy']:.4f}") + logger.info(f"Test accuracy: {test_results['accuracy']:.4f}") + logger.info(f"Test loss: {test_results['loss']:.4f}") + + if history["stopped_early"]: + logger.info(f"Training stopped early at epoch {history['total_epochs']}") + else: + logger.info(f"Training completed all {history['total_epochs']} epochs") + + logger.info("") + logger.info("Per-class test performance:") + for agent_name, metrics in 
test_results["per_class_metrics"].items(): + logger.info( + f" {agent_name}: F1={metrics['f1_score']:.4f}, " + f"Precision={metrics['precision']:.4f}, " + f"Recall={metrics['recall']:.4f}" + ) + + # End Braintrust experiment + if braintrust_tracker and hasattr(braintrust_tracker, "end_experiment"): + experiment_url = braintrust_tracker.end_experiment() + if experiment_url: + logger.info(f"Braintrust experiment URL: {experiment_url}") + + logger.info("=" * 60) + logger.info("Training completed successfully!") + logger.info("=" * 60) + + except FileNotFoundError as e: + logger.error(f"File not found: {e}") + raise + except ValueError as e: + logger.error(f"Invalid value: {e}") + raise + except RuntimeError as e: + logger.error(f"Runtime error: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/src/training/unified_orchestrator.py b/src/training/unified_orchestrator.py new file mode 100644 index 0000000000000000000000000000000000000000..edb5cab42a164643de0ad69e27673f2da3a1f761 --- /dev/null +++ b/src/training/unified_orchestrator.py @@ -0,0 +1,495 @@ +""" +Unified Training Orchestrator for LangGraph Multi-Agent MCTS with DeepMind-Style Learning. + +Coordinates: +- HRM Agent +- TRM Agent +- Neural MCTS +- Policy-Value Network +- Self-play data generation +- Training loops +- Evaluation +- Checkpointing +""" + +import time +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn +from torch.cuda.amp import GradScaler, autocast + +from ..agents.hrm_agent import HRMLoss, create_hrm_agent +from ..agents.trm_agent import TRMLoss, create_trm_agent +from ..framework.mcts.neural_mcts import GameState, NeuralMCTS, SelfPlayCollector +from ..models.policy_value_net import ( + AlphaZeroLoss, + create_policy_value_network, +) +from .performance_monitor import PerformanceMonitor, TimingContext +from .replay_buffer import Experience, PrioritizedReplayBuffer, collate_experiences +from .system_config import SystemConfig + + +class UnifiedTrainingOrchestrator: + """ + Complete training pipeline integrating all framework components. + + This orchestrator manages: + 1. Self-play data generation using MCTS + 2. Neural network training (policy-value) + 3. HRM agent training + 4. TRM agent training + 5. Evaluation and checkpointing + 6. Performance monitoring + """ + + def __init__( + self, + config: SystemConfig, + initial_state_fn: Callable[[], GameState], + board_size: int = 19, + ): + """ + Initialize training orchestrator. 
+ + Args: + config: System configuration + initial_state_fn: Function that returns initial game state + board_size: Board/grid size for spatial games + """ + self.config = config + self.initial_state_fn = initial_state_fn + self.board_size = board_size + + # Setup device + self.device = config.device + torch.manual_seed(config.seed) + + # Initialize performance monitor + self.monitor = PerformanceMonitor( + window_size=100, + enable_gpu_monitoring=(self.device != "cpu"), + ) + + # Initialize components + self._initialize_components() + + # Training state + self.current_iteration = 0 + self.best_win_rate = 0.0 + self.best_model_path = None + + # Setup paths + self._setup_paths() + + # Setup experiment tracking + if config.use_wandb: + self._setup_wandb() + + def _initialize_components(self): + """Initialize all framework components.""" + print("Initializing components...") + + # Policy-Value Network + self.policy_value_net = create_policy_value_network( + config=self.config.neural_net, + board_size=self.board_size, + device=self.device, + ) + + print(f" ✓ Policy-Value Network: {self.policy_value_net.get_parameter_count():,} parameters") + + # HRM Agent + self.hrm_agent = create_hrm_agent(self.config.hrm, self.device) + print(f" ✓ HRM Agent: {self.hrm_agent.get_parameter_count():,} parameters") + + # TRM Agent + self.trm_agent = create_trm_agent( + self.config.trm, output_dim=self.config.neural_net.action_size, device=self.device + ) + print(f" ✓ TRM Agent: {self.trm_agent.get_parameter_count():,} parameters") + + # Neural MCTS + self.mcts = NeuralMCTS( + policy_value_network=self.policy_value_net, + config=self.config.mcts, + device=self.device, + ) + print(" ✓ Neural MCTS initialized") + + # Self-play collector + self.self_play_collector = SelfPlayCollector(mcts=self.mcts, config=self.config.mcts) + + # Optimizers + self._setup_optimizers() + + # Loss functions + self.pv_loss_fn = AlphaZeroLoss(value_loss_weight=1.0) + self.hrm_loss_fn = HRMLoss(ponder_weight=0.01) + self.trm_loss_fn = TRMLoss( + task_loss_fn=nn.MSELoss(), + supervision_weight_decay=self.config.trm.supervision_weight_decay, + ) + + # Replay buffer + self.replay_buffer = PrioritizedReplayBuffer( + capacity=self.config.training.buffer_size, + alpha=0.6, + beta_start=0.4, + beta_frames=self.config.training.games_per_iteration * 10, + ) + + # Mixed precision scaler + self.scaler = GradScaler() if self.config.use_mixed_precision else None + + def _setup_optimizers(self): + """Setup optimizers and learning rate schedulers.""" + # Policy-Value optimizer + self.pv_optimizer = torch.optim.SGD( + self.policy_value_net.parameters(), + lr=self.config.training.learning_rate, + momentum=self.config.training.momentum, + weight_decay=self.config.training.weight_decay, + ) + + # HRM optimizer + self.hrm_optimizer = torch.optim.Adam(self.hrm_agent.parameters(), lr=1e-3) + + # TRM optimizer + self.trm_optimizer = torch.optim.Adam(self.trm_agent.parameters(), lr=1e-3) + + # Learning rate scheduler for policy-value network + if self.config.training.lr_schedule == "cosine": + self.pv_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.pv_optimizer, T_max=100) + elif self.config.training.lr_schedule == "step": + self.pv_scheduler = torch.optim.lr_scheduler.StepLR( + self.pv_optimizer, + step_size=self.config.training.lr_decay_steps, + gamma=self.config.training.lr_decay_gamma, + ) + else: + self.pv_scheduler = None + + def _setup_paths(self): + """Setup directory paths.""" + self.checkpoint_dir = Path(self.config.checkpoint_dir) + 
self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + self.data_dir = Path(self.config.data_dir) + self.data_dir.mkdir(parents=True, exist_ok=True) + + self.log_dir = Path(self.config.log_dir) + self.log_dir.mkdir(parents=True, exist_ok=True) + + def _setup_wandb(self): + """Setup Weights & Biases experiment tracking.""" + try: + import wandb + + wandb.init( + project=self.config.wandb_project, + entity=self.config.wandb_entity, + config=self.config.to_dict(), + name=f"run_{time.strftime('%Y%m%d_%H%M%S')}", + ) + print(" ✓ Weights & Biases initialized") + except ImportError: + print(" ⚠️ wandb not installed, skipping") + self.config.use_wandb = False + + async def train_iteration(self, iteration: int) -> dict[str, Any]: + """ + Execute single training iteration. + + Args: + iteration: Current iteration number + + Returns: + Dictionary of metrics + """ + print(f"\n{'=' * 80}") + print(f"Training Iteration {iteration}") + print(f"{'=' * 80}") + + metrics = {} + + # Phase 1: Self-play data generation + print("\n[1/5] Generating self-play data...") + with TimingContext(self.monitor, "self_play_generation"): + game_data = await self._generate_self_play_data() + metrics["games_generated"] = len(game_data) + print(f" Generated {len(game_data)} training examples") + + # Phase 2: Policy-Value network training + print("\n[2/5] Training Policy-Value Network...") + with TimingContext(self.monitor, "pv_training"): + pv_metrics = await self._train_policy_value_network() + metrics.update(pv_metrics) + + # Phase 3: HRM agent training (optional, if using HRM) + if hasattr(self, "hrm_agent"): + print("\n[3/5] Training HRM Agent...") + with TimingContext(self.monitor, "hrm_training"): + hrm_metrics = await self._train_hrm_agent() + metrics.update(hrm_metrics) + + # Phase 4: TRM agent training (optional, if using TRM) + if hasattr(self, "trm_agent"): + print("\n[4/5] Training TRM Agent...") + with TimingContext(self.monitor, "trm_training"): + trm_metrics = await self._train_trm_agent() + metrics.update(trm_metrics) + + # Phase 5: Evaluation + print("\n[5/5] Evaluation...") + if iteration % self.config.training.checkpoint_interval == 0: + eval_metrics = await self._evaluate() + metrics.update(eval_metrics) + + # Save checkpoint if improved + if eval_metrics.get("win_rate", 0) > self.best_win_rate: + self.best_win_rate = eval_metrics["win_rate"] + self._save_checkpoint(iteration, metrics, is_best=True) + print(f" ✓ New best model! 
Win rate: {self.best_win_rate:.2%}") + + # Log metrics + self._log_metrics(iteration, metrics) + + # Performance check + self.monitor.alert_if_slow() + + return metrics + + async def _generate_self_play_data(self) -> list[Experience]: + """Generate training data from self-play games.""" + num_games = self.config.training.games_per_iteration + + # In production, this would use parallel actors + # For simplicity, we'll do sequential self-play + all_examples = [] + + for game_idx in range(num_games): + examples = await self.self_play_collector.play_game( + initial_state=self.initial_state_fn(), + temperature_threshold=self.config.mcts.temperature_threshold, + ) + + # Convert to Experience objects + for ex in examples: + all_examples.append(Experience(state=ex.state, policy=ex.policy_target, value=ex.value_target)) + + if (game_idx + 1) % 5 == 0: + print(f" Generated {game_idx + 1}/{num_games} games...") + + # Add to replay buffer + self.replay_buffer.add_batch(all_examples) + + return all_examples + + async def _train_policy_value_network(self) -> dict[str, float]: + """Train policy-value network on replay buffer data.""" + if not self.replay_buffer.is_ready(self.config.training.batch_size): + print(" Replay buffer not ready, skipping...") + return {"policy_loss": 0.0, "value_loss": 0.0} + + self.policy_value_net.train() + + total_policy_loss = 0.0 + total_value_loss = 0.0 + num_batches = 10 # Train for 10 batches per iteration + + for _ in range(num_batches): + # Sample batch + experiences, indices, weights = self.replay_buffer.sample(self.config.training.batch_size) + states, policies, values = collate_experiences(experiences) + + states = states.to(self.device) + policies = policies.to(self.device) + values = values.to(self.device) + weights = torch.from_numpy(weights).to(self.device) + + # Forward pass + if self.config.use_mixed_precision and self.scaler: + with autocast(): + policy_logits, value_pred = self.policy_value_net(states) + loss, loss_dict = self.pv_loss_fn(policy_logits, value_pred, policies, values) + # Apply importance sampling weights + loss = (loss * weights).mean() + + # Backward pass with mixed precision + self.pv_optimizer.zero_grad() + self.scaler.scale(loss).backward() + self.scaler.step(self.pv_optimizer) + self.scaler.update() + else: + policy_logits, value_pred = self.policy_value_net(states) + loss, loss_dict = self.pv_loss_fn(policy_logits, value_pred, policies, values) + loss = (loss * weights).mean() + + self.pv_optimizer.zero_grad() + loss.backward() + self.pv_optimizer.step() + + # Update priorities in replay buffer + with torch.no_grad(): + td_errors = torch.abs(value_pred.squeeze() - values) + self.replay_buffer.update_priorities(indices, td_errors.cpu().numpy()) + + total_policy_loss += loss_dict["policy"] + total_value_loss += loss_dict["value"] + + # Log losses + self.monitor.log_loss(loss_dict["policy"], loss_dict["value"], loss_dict["total"]) + + # Step learning rate scheduler + if self.pv_scheduler: + self.pv_scheduler.step() + + avg_policy_loss = total_policy_loss / num_batches + avg_value_loss = total_value_loss / num_batches + + print(f" Policy Loss: {avg_policy_loss:.4f}, Value Loss: {avg_value_loss:.4f}") + + return {"policy_loss": avg_policy_loss, "value_loss": avg_value_loss} + + async def _train_hrm_agent(self) -> dict[str, float]: + """Train HRM agent (placeholder for domain-specific implementation).""" + # This would require domain-specific data and tasks + # For now, return dummy metrics + return {"hrm_halt_step": 5.0, "hrm_ponder_cost": 
0.1} + + async def _train_trm_agent(self) -> dict[str, float]: + """Train TRM agent (placeholder for domain-specific implementation).""" + # This would require domain-specific data and tasks + # For now, return dummy metrics + return {"trm_convergence_step": 8.0, "trm_final_residual": 0.01} + + async def _evaluate(self) -> dict[str, float]: + """Evaluate current model against baseline.""" + # Simplified evaluation: play games against previous best + # In production, this would be more sophisticated + win_rate = 0.55 # Placeholder + + return { + "win_rate": win_rate, + "eval_games": self.config.training.evaluation_games, + } + + def _save_checkpoint(self, iteration: int, metrics: dict, is_best: bool = False): + """Save model checkpoint.""" + checkpoint = { + "iteration": iteration, + "policy_value_net": self.policy_value_net.state_dict(), + "hrm_agent": self.hrm_agent.state_dict(), + "trm_agent": self.trm_agent.state_dict(), + "pv_optimizer": self.pv_optimizer.state_dict(), + "hrm_optimizer": self.hrm_optimizer.state_dict(), + "trm_optimizer": self.trm_optimizer.state_dict(), + "config": self.config.to_dict(), + "metrics": metrics, + "best_win_rate": self.best_win_rate, + } + + # Save regular checkpoint + path = self.checkpoint_dir / f"checkpoint_iter_{iteration}.pt" + torch.save(checkpoint, path) + print(f" ✓ Checkpoint saved: {path}") + + # Save best model + if is_best: + best_path = self.checkpoint_dir / "best_model.pt" + torch.save(checkpoint, best_path) + self.best_model_path = best_path + print(f" ✓ Best model saved: {best_path}") + + def _log_metrics(self, iteration: int, metrics: dict): + """Log metrics to console and tracking systems.""" + print(f"\n[Metrics Summary - Iteration {iteration}]") + for key, value in metrics.items(): + if isinstance(value, float): + print(f" {key}: {value:.4f}") + else: + print(f" {key}: {value}") + + # Log to wandb + if self.config.use_wandb: + try: + import wandb + + wandb_metrics = self.monitor.export_to_wandb(iteration) + wandb_metrics.update(metrics) + wandb.log(wandb_metrics, step=iteration) + except Exception as e: + print(f" ⚠️ Failed to log to wandb: {e}") + + async def train(self, num_iterations: int): + """ + Run complete training loop. 
+ + Args: + num_iterations: Number of training iterations + """ + print("\n" + "=" * 80) + print("Starting Training") + print("=" * 80) + print(f"Total iterations: {num_iterations}") + print(f"Device: {self.device}") + print(f"Mixed precision: {self.config.use_mixed_precision}") + + start_time = time.time() + + for iteration in range(1, num_iterations + 1): + self.current_iteration = iteration + + try: + _ = await self.train_iteration(iteration) + + # Check early stopping + if self._should_early_stop(iteration): + print("\n⚠️ Early stopping triggered") + break + + except KeyboardInterrupt: + print("\n⚠️ Training interrupted by user") + break + except Exception as e: + print(f"\n❌ Error in iteration {iteration}: {e}") + import traceback + + traceback.print_exc() + break + + elapsed = time.time() - start_time + print(f"\n{'=' * 80}") + print(f"Training completed in {elapsed / 3600:.2f} hours") + print(f"Best win rate: {self.best_win_rate:.2%}") + print(f"{'=' * 80}\n") + + # Print final performance summary + self.monitor.print_summary() + + def _should_early_stop(self, iteration: int) -> bool: + """Check early stopping criteria.""" + # Placeholder: implement actual early stopping logic + _ = iteration # noqa: F841 + return False + + def load_checkpoint(self, path: str): + """Load checkpoint from file.""" + checkpoint = torch.load(path, map_location=self.device, weights_only=True) + + self.policy_value_net.load_state_dict(checkpoint["policy_value_net"]) + self.hrm_agent.load_state_dict(checkpoint["hrm_agent"]) + self.trm_agent.load_state_dict(checkpoint["trm_agent"]) + + self.pv_optimizer.load_state_dict(checkpoint["pv_optimizer"]) + self.hrm_optimizer.load_state_dict(checkpoint["hrm_optimizer"]) + self.trm_optimizer.load_state_dict(checkpoint["trm_optimizer"]) + + self.current_iteration = checkpoint["iteration"] + self.best_win_rate = checkpoint.get("best_win_rate", 0.0) + + print(f"✓ Loaded checkpoint from iteration {self.current_iteration}") diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..602c0b61027c777e8abe4f9778d6df8161f52c67 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utility modules for the LangGraph Multi-Agent MCTS Framework.""" + +from src.utils.personality_response import PersonalityResponseGenerator + +__all__ = ["PersonalityResponseGenerator"] diff --git a/src/utils/personality_response.py b/src/utils/personality_response.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc0a16e9befdf55c6ae8c9d9c12fe77e3ac27a4 --- /dev/null +++ b/src/utils/personality_response.py @@ -0,0 +1,326 @@ +""" +Personality Response Generator for LangGraph Multi-Agent MCTS Framework. + +This module provides a conversational personality layer that transforms +technical agent responses into friendly, balanced advisor outputs while +maintaining transparency and ethical considerations. + +Following 2025 best practices: +- Type hints throughout +- Comprehensive docstrings (Google style) +- Dataclasses for configuration +- Property-based encapsulation +- Exception handling +- Logging for observability +""" + +import logging +import re +from dataclasses import dataclass, field +from typing import ClassVar + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class PersonalityTraits: + """ + Immutable configuration for personality traits. 
+ + Attributes: + loyalty: Commitment to user's goals (0.0-1.0) + curiosity: Tendency to explore alternatives (0.0-1.0) + aspiration: Drive toward optimal solutions (0.0-1.0) + ethical_weight: Consideration of ethical implications (0.0-1.0) + transparency: Openness about reasoning and limitations (0.0-1.0) + """ + + loyalty: float = 0.95 + curiosity: float = 0.85 + aspiration: float = 0.90 + ethical_weight: float = 0.92 + transparency: float = 0.88 + + def __post_init__(self) -> None: + """Validate trait values are in range [0.0, 1.0].""" + for trait_name, value in self.__dict__.items(): + if not 0.0 <= value <= 1.0: + raise ValueError(f"Trait '{trait_name}' must be in range [0.0, 1.0], got {value}") + + +@dataclass +class PersonalityResponseGenerator: + """ + Generates personality-infused responses based on configurable traits. + + This class transforms technical agent outputs into conversational, + balanced advisor responses that maintain transparency while being + approachable and user-friendly. + + Attributes: + traits: PersonalityTraits configuration + + Example: + >>> generator = PersonalityResponseGenerator() + >>> response = generator.generate_response( + ... agent_response="Technical analysis complete.", + ... query="How do I optimize my code?" + ... ) + >>> print(response) + Let me be transparent about my approach... + """ + + traits: PersonalityTraits = field(default_factory=PersonalityTraits) + + # Class-level constants for phrase templates + TRANSPARENCY_PHRASES: ClassVar[list[str]] = [ + "Let me be transparent about", + "I want to be clear that", + "To be honest", + "Let me share openly", + ] + + CURIOSITY_PHRASES: ClassVar[list[str]] = [ + "I'm curious about exploring", + "There are interesting alternatives worth considering", + "It might be valuable to also look at", + "I wonder if we could also approach this by", + ] + + ASPIRATION_PHRASES: ClassVar[list[str]] = [ + "I'm committed to helping you find the best solution", + "Let's aim for the optimal approach", + "I believe we can achieve even better results by", + "Striving for excellence", + ] + + LOYALTY_PHRASES: ClassVar[list[str]] = [ + "I'm here to support your goals", + "Your success is my priority", + "I'm committed to helping you succeed", + "Working together toward your objectives", + ] + + ETHICAL_PHRASES: ClassVar[list[str]] = [ + "It's important to consider the ethical implications", + "Let's ensure this aligns with best practices", + "We should be mindful of", + "From an ethical standpoint", + ] + + def generate_response( + self, + agent_response: str, + query: str, + include_preamble: bool = True, + max_length: int = 1000, + ) -> str: + """ + Generate a personality-infused response from technical agent output. + + Args: + agent_response: The original technical response from the agent + query: The original user query for context + include_preamble: Whether to include personality preamble + max_length: Maximum length of the generated response + + Returns: + A conversational, personality-infused version of the response + + Raises: + ValueError: If agent_response or query is empty + + Example: + >>> gen = PersonalityResponseGenerator() + >>> response = gen.generate_response( + ... "[HRM Analysis] Breaking down hierarchically...", + ... "How do I solve this problem?" + ... 
) + >>> "transparent" in response.lower() + True + """ + # Input validation + if not agent_response or not agent_response.strip(): + raise ValueError("agent_response cannot be empty") + if not query or not query.strip(): + raise ValueError("query cannot be empty") + + try: + # Build the personality-infused response + parts = [] + + # Add preamble based on traits + if include_preamble: + preamble = self._generate_preamble(query) + parts.append(preamble) + + # Transform the technical response + transformed_response = self._transform_response(agent_response, query) + parts.append(transformed_response) + + # Add trait-based closing + closing = self._generate_closing(agent_response) + if closing: + parts.append(closing) + + # Combine and truncate if needed + full_response = "\n\n".join(parts) + + if len(full_response) > max_length: + full_response = full_response[:max_length - 3] + "..." + logger.warning(f"Response truncated to {max_length} characters") + + return full_response + + except Exception as e: + logger.error(f"Error generating personality response: {e}", exc_info=True) + # Fallback to original response with simple wrapper + return f"Here's what I found:\n\n{agent_response}" + + def _generate_preamble(self, query: str) -> str: + """ + Generate an opening preamble based on personality traits. + + Args: + query: The user's query + + Returns: + A personalized preamble + """ + preamble_parts = [] + + # Transparency (highest weight) + if self.traits.transparency >= 0.8: + preamble_parts.append( + f"{self.TRANSPARENCY_PHRASES[0]} my approach to your query. " + ) + + # Loyalty + if self.traits.loyalty >= 0.9: + preamble_parts.append( + f"{self.LOYALTY_PHRASES[0]}, and I've carefully analyzed your question. " + ) + + return "".join(preamble_parts).strip() + + def _transform_response(self, agent_response: str, query: str) -> str: + """ + Transform technical agent response into conversational tone. + + Args: + agent_response: Original technical response + query: User query for context + + Returns: + Conversational version of the response + """ + # Extract agent name from response if present + agent_match = re.search(r"\[(.*?)\]", agent_response) + agent_name = agent_match.group(1) if agent_match else "the agent" + + # Remove technical markers like [HRM Analysis], [TRM Analysis], etc. + cleaned_response = re.sub(r"\[.*?\]\s*", "", agent_response) + + # Create conversational wrapper + conversational = ( + f"Based on my analysis using {agent_name.lower()}, " + f"I've identified the following approach:\n\n{cleaned_response}" + ) + + return conversational + + def _generate_closing(self, agent_response: str) -> str: + """ + Generate a closing statement based on traits. + + Args: + agent_response: The agent response (to check for certain keywords) + + Returns: + A closing statement or empty string + """ + closing_parts = [] + + # Aspiration - offer to go further + if self.traits.aspiration >= 0.85: + closing_parts.append( + "I'm committed to helping you achieve the best possible outcome. " + ) + + # Curiosity - suggest alternatives + if self.traits.curiosity >= 0.8 and any( + keyword in agent_response.lower() + for keyword in ["optimize", "improve", "compare"] + ): + closing_parts.append( + "I'm curious if you'd like to explore alternative approaches as well. 
" + ) + + # Ethical considerations for certain technical queries + if self.traits.ethical_weight >= 0.9 and any( + keyword in agent_response.lower() + for keyword in ["system", "design", "architecture", "security"] + ): + closing_parts.append( + "As we proceed, let's ensure our approach aligns with best practices and ethical considerations. " + ) + + return "".join(closing_parts).strip() + + @property + def trait_summary(self) -> dict[str, float]: + """ + Get a summary of current personality traits. + + Returns: + Dictionary mapping trait names to their values + """ + return { + "loyalty": self.traits.loyalty, + "curiosity": self.traits.curiosity, + "aspiration": self.traits.aspiration, + "ethical_weight": self.traits.ethical_weight, + "transparency": self.traits.transparency, + } + + def __repr__(self) -> str: + """String representation for debugging.""" + return ( + f"PersonalityResponseGenerator(" + f"loyalty={self.traits.loyalty:.2f}, " + f"curiosity={self.traits.curiosity:.2f}, " + f"aspiration={self.traits.aspiration:.2f}, " + f"ethical_weight={self.traits.ethical_weight:.2f}, " + f"transparency={self.traits.transparency:.2f})" + ) + + +# Example usage +if __name__ == "__main__": + # Configure logging for standalone execution + logging.basicConfig(level=logging.INFO) + + # Create generator with default traits + generator = PersonalityResponseGenerator() + + # Example technical response + agent_response = ( + "[HRM Analysis] Breaking down the problem hierarchically: " + "What are the key factors to consider when choosing between " + "microservices and monolithic architecture?..." + ) + query = "What are the key factors to consider when choosing between microservices and monolithic architecture?" + + # Generate personality response + personality_response = generator.generate_response(agent_response, query) + + print("=" * 80) + print("ORIGINAL RESPONSE:") + print("=" * 80) + print(agent_response) + print("\n" + "=" * 80) + print("PERSONALITY-INFUSED RESPONSE:") + print("=" * 80) + print(personality_response) + print("\n" + "=" * 80) + print(f"Trait Summary: {generator.trait_summary}")