"""
LangGraph Multi-Agent MCTS Framework - Hugging Face Spaces Demo
A proof-of-concept demonstration of multi-agent reasoning with Monte Carlo Tree Search.
"""
import asyncio
import time
from dataclasses import dataclass
import gradio as gr
import numpy as np
# Demo-specific simplified implementations
from demo_src.agents_demo import HRMAgent, TRMAgent
from demo_src.llm_mock import HuggingFaceClient, MockLLMClient
from demo_src.mcts_demo import MCTSDemo
from demo_src.wandb_tracker import WandBTracker, is_wandb_available
@dataclass
class AgentResult:
"""Result from a single agent."""
agent_name: str
response: str
confidence: float
reasoning_steps: list[str]
execution_time_ms: float
@dataclass
class FrameworkResult:
"""Combined result from all agents."""
query: str
hrm_result: AgentResult | None
trm_result: AgentResult | None
mcts_result: dict | None
consensus_score: float
final_response: str
total_time_ms: float
metadata: dict
class MultiAgentFrameworkDemo:
"""Simplified multi-agent framework for Hugging Face Spaces demo."""
def __init__(self, use_hf_inference: bool = False, hf_model: str = ""):
"""Initialize the demo framework.
Args:
            use_hf_inference: Use the Hugging Face Inference API instead of the mock client
hf_model: Hugging Face model ID for inference
"""
self.use_hf_inference = use_hf_inference
self.hf_model = hf_model
# Initialize components
if use_hf_inference and hf_model:
self.llm_client = HuggingFaceClient(model_id=hf_model)
else:
self.llm_client = MockLLMClient()
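        # Both reasoning agents share the single LLM client chosen above.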
self.hrm_agent = HRMAgent(self.llm_client)
self.trm_agent = TRMAgent(self.llm_client)
self.mcts = MCTSDemo()
async def process_query(
self,
query: str,
use_hrm: bool = True,
use_trm: bool = True,
use_mcts: bool = False,
mcts_iterations: int = 25,
exploration_weight: float = 1.414,
seed: int | None = None,
) -> FrameworkResult:
"""Process a query through the multi-agent framework.
Args:
query: The input query to process
use_hrm: Enable Hierarchical Reasoning Module
use_trm: Enable Tree Reasoning Module
use_mcts: Enable Monte Carlo Tree Search
mcts_iterations: Number of MCTS iterations
exploration_weight: UCB1 exploration parameter
seed: Random seed for reproducibility
Returns:
FrameworkResult with all agent outputs and consensus
"""
start_time = time.perf_counter()
hrm_result = None
trm_result = None
mcts_result = None
# Run enabled agents
tasks = []
agent_names = []
if use_hrm:
tasks.append(self._run_hrm(query))
agent_names.append("hrm")
if use_trm:
tasks.append(self._run_trm(query))
agent_names.append("trm")
if use_mcts:
tasks.append(self._run_mcts(query, mcts_iterations, exploration_weight, seed))
agent_names.append("mcts")
# Execute agents concurrently
if tasks:
results = await asyncio.gather(*tasks, return_exceptions=True)
            for name, result in zip(agent_names, results, strict=True):
if isinstance(result, Exception):
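                    # A failed agent is skipped so the demo degrades gracefully
                    # rather than crashing the whole query.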
continue
if name == "hrm":
hrm_result = result
elif name == "trm":
trm_result = result
elif name == "mcts":
mcts_result = result
# Calculate consensus score
consensus_score = self._calculate_consensus(hrm_result, trm_result, mcts_result)
# Generate final synthesized response
final_response = self._synthesize_response(query, hrm_result, trm_result, mcts_result, consensus_score)
total_time = (time.perf_counter() - start_time) * 1000
return FrameworkResult(
query=query,
hrm_result=hrm_result,
trm_result=trm_result,
mcts_result=mcts_result,
consensus_score=consensus_score,
final_response=final_response,
total_time_ms=round(total_time, 2),
metadata={
"agents_used": agent_names,
"mcts_config": (
{"iterations": mcts_iterations, "exploration_weight": exploration_weight, "seed": seed}
if use_mcts
else None
),
},
)
async def _run_hrm(self, query: str) -> AgentResult:
"""Run Hierarchical Reasoning Module."""
start = time.perf_counter()
result = await self.hrm_agent.process(query)
elapsed = (time.perf_counter() - start) * 1000
return AgentResult(
agent_name="HRM (Hierarchical Reasoning)",
response=result["response"],
confidence=result["confidence"],
reasoning_steps=result["steps"],
execution_time_ms=round(elapsed, 2),
)
async def _run_trm(self, query: str) -> AgentResult:
"""Run Tree Reasoning Module."""
start = time.perf_counter()
result = await self.trm_agent.process(query)
elapsed = (time.perf_counter() - start) * 1000
return AgentResult(
agent_name="TRM (Iterative Refinement)",
response=result["response"],
confidence=result["confidence"],
reasoning_steps=result["steps"],
execution_time_ms=round(elapsed, 2),
)
async def _run_mcts(self, query: str, iterations: int, exploration_weight: float, seed: int | None) -> dict:
"""Run Monte Carlo Tree Search."""
start = time.perf_counter()
        # MCTSDemo.search is async and delegates to the production framework.
result = await self.mcts.search(query=query, iterations=iterations, exploration_weight=exploration_weight, seed=seed)
elapsed = (time.perf_counter() - start) * 1000
result["execution_time_ms"] = round(elapsed, 2)
return result
def _calculate_consensus(
self, hrm_result: AgentResult | None, trm_result: AgentResult | None, mcts_result: dict | None
) -> float:
"""Calculate agreement score between agents."""
confidences = []
if hrm_result:
confidences.append(hrm_result.confidence)
if trm_result:
confidences.append(trm_result.confidence)
if mcts_result:
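            # MCTS reports a search value ("best_value") rather than an agent
            # confidence; treat a missing value as a neutral 0.5.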
confidences.append(mcts_result.get("best_value", 0.5))
if not confidences:
return 0.0
# Consensus is based on confidence alignment and average
if len(confidences) == 1:
return confidences[0]
avg_confidence = np.mean(confidences)
std_confidence = np.std(confidences)
# Higher consensus when agents agree (low std) and are confident (high avg)
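        # Worked example (illustrative): confidences [0.8, 0.9] give
        # avg = 0.85, std = 0.05, agreement_factor = 0.9, consensus = 0.765.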
agreement_factor = max(0, 1 - std_confidence * 2)
consensus = avg_confidence * agreement_factor
return round(min(1.0, consensus), 3)
def _synthesize_response(
self,
query: str,
hrm_result: AgentResult | None,
trm_result: AgentResult | None,
mcts_result: dict | None,
consensus_score: float,
) -> str:
"""Synthesize final response from all agent outputs."""
parts = []
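        # Only surface agents that cleared a 0.5 confidence threshold.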
if hrm_result and hrm_result.confidence > 0.5:
parts.append(f"[HRM] {hrm_result.response}")
if trm_result and trm_result.confidence > 0.5:
parts.append(f"[TRM] {trm_result.response}")
if mcts_result and mcts_result.get("best_value", 0) > 0.5:
parts.append(f"[MCTS] Best path: {mcts_result.get('best_action', 'N/A')}")
if not parts:
truncated_query = f"{query[:80]}..." if len(query) > 80 else query
return f"Insufficient confidence to answer query: '{truncated_query}'."
synthesis = " | ".join(parts)
if consensus_score > 0.7:
return f"HIGH CONSENSUS ({consensus_score:.1%}): {synthesis}"
elif consensus_score > 0.4:
return f"MODERATE CONSENSUS ({consensus_score:.1%}): {synthesis}"
else:
return f"LOW CONSENSUS ({consensus_score:.1%}): {synthesis}"
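
# Illustrative usage (a sketch; the Spaces app itself drives the framework
# through the Gradio handlers below):
#
#     demo = MultiAgentFrameworkDemo()
#     result = asyncio.run(demo.process_query("Design a rate limiter"))
#     print(result.final_response, result.consensus_score)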
# Global singletons, lazily created on first use and reused across queries
framework: MultiAgentFrameworkDemo | None = None
wandb_tracker: WandBTracker | None = None
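# Note: initialize_framework is not wired into the Gradio UI below; it is
# presumably kept for programmatic (re)initialization, e.g. from tests.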
def initialize_framework(use_hf: bool, model_id: str):
"""Initialize or reinitialize the framework."""
global framework
framework = MultiAgentFrameworkDemo(use_hf_inference=use_hf, hf_model=model_id)
return "Framework initialized successfully!"
def process_query_sync(
query: str,
use_hrm: bool,
use_trm: bool,
use_mcts: bool,
mcts_iterations: int,
exploration_weight: float,
seed: int,
enable_wandb: bool = False,
wandb_project: str = "langgraph-mcts-demo",
wandb_run_name: str = "",
):
"""Synchronous wrapper for async processing."""
global framework, wandb_tracker
if framework is None:
framework = MultiAgentFrameworkDemo()
if not query.strip():
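        # The tuple shape must match the five Gradio outputs wired up below.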
return "Please enter a query.", {}, "", {}, ""
    # A seed of 0 (the UI default) means "no seed", i.e. non-deterministic runs.
    seed_value = seed if seed > 0 else None
# Initialize W&B tracking if enabled
wandb_url = ""
if enable_wandb and is_wandb_available():
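        # The tracker is created once and reused; a project name changed in the
        # UI after the first tracked query will not take effect.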
if wandb_tracker is None:
wandb_tracker = WandBTracker(project_name=wandb_project, enabled=True)
# Start a new run
run_name = wandb_run_name if wandb_run_name.strip() else None
config = {
"query": query[:200], # Truncate for config
"use_hrm": use_hrm,
"use_trm": use_trm,
"use_mcts": use_mcts,
"mcts_iterations": mcts_iterations,
"exploration_weight": exploration_weight,
"seed": seed_value,
}
wandb_tracker.init_run(run_name=run_name, config=config)
    # Drive the async pipeline to completion. asyncio.run spins up a fresh event
    # loop, which is fine here because Gradio invokes sync handlers in worker
    # threads that have no running loop.
result = asyncio.run(
framework.process_query(
query=query,
use_hrm=use_hrm,
use_trm=use_trm,
use_mcts=use_mcts,
mcts_iterations=int(mcts_iterations),
exploration_weight=exploration_weight,
seed=seed_value,
)
)
# Format outputs
final_response = result.final_response
# Agent details
agent_details = {}
if result.hrm_result:
agent_details["HRM"] = {
"response": result.hrm_result.response,
"confidence": f"{result.hrm_result.confidence:.1%}",
"reasoning_steps": result.hrm_result.reasoning_steps,
"time_ms": result.hrm_result.execution_time_ms,
}
# Log to W&B
if enable_wandb and wandb_tracker:
wandb_tracker.log_agent_result(
"HRM",
result.hrm_result.response,
result.hrm_result.confidence,
result.hrm_result.execution_time_ms,
result.hrm_result.reasoning_steps,
)
if result.trm_result:
agent_details["TRM"] = {
"response": result.trm_result.response,
"confidence": f"{result.trm_result.confidence:.1%}",
"reasoning_steps": result.trm_result.reasoning_steps,
"time_ms": result.trm_result.execution_time_ms,
}
# Log to W&B
if enable_wandb and wandb_tracker:
wandb_tracker.log_agent_result(
"TRM",
result.trm_result.response,
result.trm_result.confidence,
result.trm_result.execution_time_ms,
result.trm_result.reasoning_steps,
)
if result.mcts_result:
agent_details["MCTS"] = result.mcts_result
# Log to W&B
if enable_wandb and wandb_tracker:
wandb_tracker.log_mcts_result(result.mcts_result)
# Log consensus and performance to W&B
if enable_wandb and wandb_tracker:
wandb_tracker.log_consensus(result.consensus_score, result.metadata["agents_used"], result.final_response)
wandb_tracker.log_performance(result.total_time_ms)
wandb_tracker.log_query_summary(query, use_hrm, use_trm, use_mcts, result.consensus_score, result.total_time_ms)
# Get run URL
wandb_url = wandb_tracker.get_run_url() or ""
        # Finish the run so each processed query maps to exactly one W&B run.
wandb_tracker.finish_run()
# Metrics
metrics = f"""
**Consensus Score:** {result.consensus_score:.1%}
**Total Processing Time:** {result.total_time_ms:.2f} ms
**Agents Used:** {", ".join(result.metadata["agents_used"])}
"""
if wandb_url:
metrics += f"\n**W&B Run:** [{wandb_url}]({wandb_url})"
# Full JSON result
full_result = {
"query": result.query,
"final_response": result.final_response,
"consensus_score": result.consensus_score,
"total_time_ms": result.total_time_ms,
"metadata": result.metadata,
"agent_details": agent_details,
"wandb_url": wandb_url,
}
return final_response, agent_details, metrics, full_result, wandb_url
def visualize_mcts_tree(mcts_result: dict) -> str:
"""Create ASCII visualization of MCTS tree."""
if not mcts_result or "tree_visualization" not in mcts_result:
return "No MCTS tree data available"
return mcts_result["tree_visualization"]
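
# Note: visualize_mcts_tree is not currently wired into the UI; any MCTS tree
# data instead surfaces through the JSON outputs on the results panel.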
# Example queries for demonstration
EXAMPLE_QUERIES = [
"What are the key factors to consider when choosing between microservices and monolithic architecture?",
"How can we optimize a Python application that processes 10GB of log files daily?",
"What is the best approach to implement rate limiting in a distributed system?",
"Should we use SQL or NoSQL database for a social media application with 1M users?",
"How to design a fault-tolerant message queue system?",
]
# Gradio Interface
with gr.Blocks(
title="LangGraph Multi-Agent MCTS Demo",
theme=gr.themes.Soft(),
css="""
.agent-box { border: 1px solid #ddd; padding: 10px; border-radius: 5px; margin: 5px 0; }
.consensus-high { color: #28a745; font-weight: bold; }
.consensus-medium { color: #ffc107; font-weight: bold; }
.consensus-low { color: #dc3545; font-weight: bold; }
""",
) as demo:
gr.Markdown(
"""
# LangGraph Multi-Agent MCTS Framework
**Proof-of-Concept Demo** - Multi-agent reasoning with Monte Carlo Tree Search
This demo showcases:
- **HRM**: Hierarchical Reasoning Module - breaks down complex queries
- **TRM**: Tree Reasoning Module - iterative refinement of responses
- **MCTS**: Monte Carlo Tree Search - strategic exploration of solution space
- **Consensus**: Agreement scoring between agents
---
"""
)
with gr.Row():
with gr.Column(scale=2):
query_input = gr.Textbox(
label="Query", placeholder="Enter your reasoning task or question...", lines=3, max_lines=10
)
gr.Markdown("**Example Queries:**")
example_dropdown = gr.Dropdown(choices=EXAMPLE_QUERIES, label="Select an example", interactive=True)
def load_example(example):
return example
example_dropdown.change(load_example, example_dropdown, query_input)
with gr.Column(scale=1):
gr.Markdown("**Agent Configuration**")
use_hrm = gr.Checkbox(label="Enable HRM (Hierarchical)", value=True)
use_trm = gr.Checkbox(label="Enable TRM (Iterative)", value=True)
use_mcts = gr.Checkbox(label="Enable MCTS", value=False)
gr.Markdown("**MCTS Parameters**")
mcts_iterations = gr.Slider(
minimum=10,
maximum=100,
value=25,
step=5,
label="Iterations",
info="More iterations = better search, but slower",
)
exploration_weight = gr.Slider(
minimum=0.1,
maximum=3.0,
value=1.414,
step=0.1,
label="Exploration Weight (C)",
                info="Higher = more exploration; lower = more exploitation",
)
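            # C = 1.414 ≈ sqrt(2) is the classic UCB1 constant. As a sketch of
            # the rule MCTSDemo presumably applies, selection maximizes
            #     value / visits + C * sqrt(ln(parent_visits) / visits)
            # over the children of the current node.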
            seed_input = gr.Number(label="Random Seed (0 = non-deterministic)", value=0, precision=0)
with gr.Accordion("Weights & Biases Tracking", open=False):
gr.Markdown(
"""
**Experiment Tracking with W&B**
Track your experiments, visualize metrics, and compare runs.
Requires W&B API key set in Space secrets as `WANDB_API_KEY`.
"""
)
with gr.Row():
enable_wandb = gr.Checkbox(
label="Enable W&B Tracking", value=False, info="Log metrics and results to Weights & Biases"
)
wandb_project = gr.Textbox(
label="Project Name", value="langgraph-mcts-demo", placeholder="Your W&B project name"
)
wandb_run_name = gr.Textbox(label="Run Name (optional)", value="", placeholder="Auto-generated if empty")
wandb_status = gr.Markdown(f"**W&B Status:** {'Available' if is_wandb_available() else 'Not installed'}")
process_btn = gr.Button("Process Query", variant="primary", size="lg")
gr.Markdown("---")
with gr.Row():
with gr.Column():
gr.Markdown("### Final Response")
final_response_output = gr.Textbox(label="Synthesized Response", lines=4, interactive=False)
gr.Markdown("### Performance Metrics")
metrics_output = gr.Markdown()
with gr.Column():
gr.Markdown("### Agent Details")
agent_details_output = gr.JSON(label="Individual Agent Results")
with gr.Accordion("Full JSON Result", open=False):
full_result_output = gr.JSON(label="Complete Framework Output")
with gr.Accordion("W&B Run Details", open=False, visible=True):
wandb_url_output = gr.Textbox(
label="W&B Run URL", interactive=False, placeholder="Enable W&B tracking to see run URL here"
)
# Wire up the processing
process_btn.click(
fn=process_query_sync,
inputs=[
query_input,
use_hrm,
use_trm,
use_mcts,
mcts_iterations,
exploration_weight,
seed_input,
enable_wandb,
wandb_project,
wandb_run_name,
],
outputs=[final_response_output, agent_details_output, metrics_output, full_result_output, wandb_url_output],
)
gr.Markdown(
"""
---
### About This Demo
This is a **proof-of-concept** demonstration of the LangGraph Multi-Agent MCTS Framework.
**Features:**
- Multi-agent orchestration with consensus scoring
- Monte Carlo Tree Search for strategic reasoning
- Configurable exploration vs exploitation trade-offs
- Deterministic results with seeded randomness
- **Weights & Biases integration** for experiment tracking
**Limitations (POC):**
- Uses mock/simplified LLM responses (not production LLM)
- Limited to demonstration scenarios
- No persistent storage or RAG
- Simplified MCTS implementation
**Full Framework:** [GitHub Repository](https://github.com/ianshank/langgraph_multi_agent_mcts)
---
*Built with LangGraph, Gradio, Weights & Biases, and Python*
"""
)
if __name__ == "__main__":
# Initialize with mock client for demo
framework = MultiAgentFrameworkDemo(use_hf_inference=False)
# Launch the demo
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)