# Source-page residue kept for provenance (presumably author / commit message / short hash):
# ianshank
# feat: add personality output and bug fixes
# 40ee6b4
"""
Hierarchical Reasoning Model (HRM) Agent.
Implements the HRM architecture with:
- H-Module: High-level planning and decomposition
- L-Module: Low-level execution and refinement
- Adaptive Computation Time (ACT) for dynamic depth
- Halting mechanism based on confidence thresholds
Based on: "Hierarchical Reasoning for Compositional Generalization"
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..training.system_config import HRMConfig
@dataclass
class SubProblem:
    """One node of a hierarchical problem decomposition.

    Attributes:
        level: Hierarchy level (0 = root, higher = more abstract).
        description: Natural language description of the subproblem.
        state: Latent state representation.
        parent_id: ID of the parent subproblem, or ``None`` when unset.
        confidence: Confidence in this decomposition; defaults to 0.0.
    """

    level: int
    description: str
    state: torch.Tensor
    parent_id: int | None = None
    confidence: float = 0.0
@dataclass
class HRMOutput:
    """Result bundle produced by one HRM forward pass.

    Attributes:
        final_state: Final processed state.
        subproblems: Hierarchical decomposition (empty when none was
            requested).
        halt_step: Step at which halting occurred.
        total_ponder_cost: Total computation cost (for training). This is a
            tensor when it has been accumulated from per-step ponder-cost
            tensors, so it can take part in a loss; a plain float is also
            accepted.
        convergence_path: Confidence at each step.
    """

    final_state: torch.Tensor
    subproblems: list[SubProblem]
    halt_step: int
    # Annotated as a union: the accumulated value is a 0-dim tensor once at
    # least one per-step cost tensor has been added to the initial 0.0.
    total_ponder_cost: torch.Tensor | float
    convergence_path: list[float]
class AdaptiveComputationTime(nn.Module):
    """
    Adaptive Computation Time (ACT) halting unit.

    Maps every position's hidden vector to a halting probability in [0, 1]
    and derives a scalar ponder cost from those probabilities.
    """

    def __init__(self, hidden_dim: int, epsilon: float = 0.01):
        """
        Args:
            hidden_dim: Size of the last dimension of the hidden states.
            epsilon: Stored on the instance as ``self.epsilon``; it is not
                read anywhere inside this class.
        """
        super().__init__()
        self.epsilon = epsilon
        # Halting unit: two-layer MLP ending in a sigmoid, so its output is
        # already a probability rather than a logit.
        self.halt_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute halting probabilities.

        Args:
            hidden_states: [batch, seq, hidden_dim]

        Returns:
            halt_probs: [batch, seq] probability of halting
            ponder_cost: 0-dim tensor (still attached to the autograd graph)
                equal to the per-sample sum of halting probabilities over
                the sequence, averaged over the batch
        """
        # The sigmoid is part of halt_fc, so these are probabilities.
        halt_probs = self.halt_fc(hidden_states).squeeze(-1)  # [batch, seq]
        # Sum over sequence positions, then average across the batch.
        ponder_cost = halt_probs.sum(dim=-1).mean()
        return halt_probs, ponder_cost
class HModule(nn.Module):
    """
    H-Module: the high-level planning / abstract reasoning layer.

    A post-norm transformer-style layer (self-attention followed by a
    feed-forward network, each with a residual connection and LayerNorm)
    plus a separate head that produces subproblem representations.
    """

    def __init__(self, config: HRMConfig):
        """
        Args:
            config: Provides ``h_dim`` (model width) and ``dropout``. The
                attention uses a fixed 8 heads, so ``h_dim`` has to be
                divisible by 8.
        """
        super().__init__()
        self.config = config
        width = config.h_dim
        drop_p = config.dropout

        # Relational reasoning across sequence positions.
        self.attention = nn.MultiheadAttention(
            embed_dim=width,
            num_heads=8,
            dropout=drop_p,
            batch_first=True,
        )

        # Position-wise feed-forward network with a 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(width, width * 4),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(width * 4, width),
            nn.Dropout(drop_p),
        )

        # One LayerNorm after each residual connection.
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)

        # Head that emits the subproblem structure.
        self.decompose_head = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run one round of high-level reasoning.

        Args:
            x: [batch, seq, h_dim] input tensor

        Returns:
            [batch, seq, h_dim] processed tensor
        """
        # Attention sub-layer: residual add, then normalise.
        attended = self.attention(x, x, x)[0]
        after_attention = self.norm1(x + attended)

        # Feed-forward sub-layer: residual add, then normalise.
        return self.norm2(after_attention + self.ffn(after_attention))

    def decompose(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` through the decomposition head to subproblem representations."""
        head = self.decompose_head
        return head(x)
class LModule(nn.Module):
    """
    L-Module: the low-level execution layer.

    Projects H-module states down to ``l_dim``, runs them through a GRU and
    an output MLP, and projects the result back up to ``h_dim`` so it can be
    fed back to the H-module.
    """

    def __init__(self, config: HRMConfig):
        """
        Args:
            config: Provides ``h_dim``, ``l_dim``, ``num_l_layers`` and
                ``dropout``.
        """
        super().__init__()
        self.config = config
        high_dim = config.h_dim
        low_dim = config.l_dim
        depth = config.num_l_layers

        # Down-projection: H-module width -> L-module width.
        self.h_to_l = nn.Linear(high_dim, low_dim)

        # Sequential processing. Inter-layer dropout is only enabled for a
        # stacked GRU; a single layer gets 0.
        inter_layer_dropout = config.dropout if depth > 1 else 0
        self.gru = nn.GRU(
            input_size=low_dim,
            hidden_size=low_dim,
            num_layers=depth,
            dropout=inter_layer_dropout,
            batch_first=True,
        )

        # Output MLP with a 2x expansion.
        self.output_proj = nn.Sequential(
            nn.Linear(low_dim, low_dim * 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(low_dim * 2, low_dim),
        )

        # Up-projection: L-module width -> H-module width.
        self.l_to_h = nn.Linear(low_dim, high_dim)

    def forward(self, x: torch.Tensor, h_context: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Execute low-level processing.

        Args:
            x: [batch, seq, h_dim] input from the H-module
            h_context: Optional hidden state, passed to the GRU as its
                initial hidden state

        Returns:
            output: [batch, seq, l_dim] processed output
            l_to_h: [batch, seq, h_dim] back-projection for the H-module
        """
        # Down-project and run the recurrent pass; the GRU's final hidden
        # state is not used.
        recurrent_out = self.gru(self.h_to_l(x), h_context)[0]

        output = self.output_proj(recurrent_out)
        # Feedback signal in H-module dimension.
        return output, self.l_to_h(output)
class HRMAgent(nn.Module):
    """
    Complete Hierarchical Reasoning Model agent.

    Alternates a stack of H-modules (planning) with one L-module (execution)
    inside an outer refinement loop, and uses an ACT halting unit to decide
    when to stop.
    """

    def __init__(self, config: HRMConfig, device: str = "cpu"):
        """
        Build all sub-modules and move them to ``device``.

        Args:
            config: Read here for ``h_dim``, ``num_h_layers`` and
                ``ponder_epsilon``; ``forward`` additionally reads
                ``max_outer_steps`` and ``halt_threshold``.
            device: Device the parameters are moved to at construction.
        """
        super().__init__()
        self.config = config
        self.device = device

        # Input embedding
        self.input_proj = nn.Linear(config.h_dim, config.h_dim)

        # Core modules
        self.h_module = nn.ModuleList([HModule(config) for _ in range(config.num_h_layers)])
        self.l_module = LModule(config)

        # Adaptive computation time
        self.act = AdaptiveComputationTime(config.h_dim, config.ponder_epsilon)

        # State integration: fuses the H state with the L-module feedback.
        self.integrate = nn.Sequential(
            nn.Linear(config.h_dim * 2, config.h_dim),
            nn.LayerNorm(config.h_dim),
            nn.GELU(),
        )

        self.to(device)

    def forward(
        self,
        x: torch.Tensor,
        max_steps: int | None = None,
        return_decomposition: bool = False,
    ) -> HRMOutput:
        """
        Process input through hierarchical reasoning.

        Each outer step runs every H-layer, queries the halting unit, and,
        unless the halting threshold is reached, runs the L-module and
        integrates its feedback into the H state. Halting is decided on the
        mean halting probability over the whole batch and sequence, so every
        batch item stops at the same step.

        Args:
            x: [batch, seq, h_dim] input tensor
            max_steps: Maximum outer loop steps. ``None`` selects
                ``config.max_outer_steps``. A value below 1 runs no
                refinement step: the result is then the input projection
                with ``halt_step == 0``.
            return_decomposition: Whether to collect subproblem entries
                (up to three sequence positions per step).

        Returns:
            HRMOutput containing final state and optional decomposition
        """
        # Unpacking also enforces a 3-D input.
        _, seq_len, _ = x.shape
        # Compare against None explicitly so an explicit 0 is not silently
        # replaced by the configured default.
        if max_steps is None:
            max_steps = self.config.max_outer_steps

        # Initial projection
        h_state = self.input_proj(x)

        # Tracking
        subproblems: list[SubProblem] = []
        convergence_path: list[float] = []
        total_ponder_cost = 0.0
        # Defined before the loop so a zero-iteration run cannot hit an
        # unbound loop variable when building the output.
        halt_step = 0

        # Outer loop: iterative refinement
        for step in range(max_steps):
            halt_step = step + 1

            # H-module: high-level planning
            for h_layer in self.h_module:
                h_state = h_layer(h_state)

            # Check halting condition
            halt_probs, ponder_cost = self.act(h_state)
            # Stays a tensor so the cost remains usable in a training loss.
            total_ponder_cost += ponder_cost

            # Average halting probability across batch and sequence
            avg_halt_prob = halt_probs.mean().item()
            convergence_path.append(avg_halt_prob)

            # Generate subproblem decomposition if requested
            if return_decomposition:
                # Only the first H-layer's decomposition head is used.
                subproblem_repr = self.h_module[0].decompose(h_state)
                # Simplified: the first (up to) three sequence positions.
                for i in range(min(3, seq_len)):
                    subproblems.append(
                        SubProblem(
                            level=step,
                            description=f"Subproblem at step {step}, position {i}",
                            state=subproblem_repr[:, i, :].detach(),
                            confidence=halt_probs[:, i].mean().item(),
                        )
                    )

            # Halt if confident enough
            if avg_halt_prob >= self.config.halt_threshold:
                break

            # L-module: low-level execution; only its feedback is used here.
            _, l_feedback = self.l_module(h_state)

            # Integrate L-module feedback
            h_state = self.integrate(torch.cat([h_state, l_feedback], dim=-1))

        return HRMOutput(
            final_state=h_state,
            subproblems=subproblems,
            halt_step=halt_step,
            total_ponder_cost=total_ponder_cost,
            convergence_path=convergence_path,
        )

    async def decompose_problem(self, query: str, state: torch.Tensor) -> list[SubProblem]:
        """
        Decompose a problem into hierarchical subproblems.

        Args:
            query: Natural language problem description
            state: Initial state representation; a 2-D tensor is treated as
                [seq, dim] and given a leading batch dimension of 1

        Returns:
            List of subproblems, in the order they were produced, with
            descriptions rewritten to include ``query``
        """
        # Ensure state is batched
        if state.dim() == 2:
            state = state.unsqueeze(0)  # [1, seq, dim]

        # Call the module itself rather than .forward() so registered
        # forward hooks are honoured.
        output = self(state, return_decomposition=True)

        # Add query context to subproblems
        for i, sp in enumerate(output.subproblems):
            sp.description = f"{query} -> Level {sp.level} Subproblem {i}"

        return output.subproblems

    def get_parameter_count(self) -> int:
        """Return total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# Training utilities
class HRMLoss(nn.Module):
    """
    Combined loss for HRM training.

    Weighted sum of:
    - Task loss (e.g., cross-entropy for classification)
    - Ponder cost regularization (encourages efficiency)
    - Consistency loss (penalizes decreases in the convergence path)
    """

    def __init__(
        self,
        task_weight: float = 1.0,
        ponder_weight: float = 0.01,
        consistency_weight: float = 0.1,
    ):
        """
        Args:
            task_weight: Multiplier for the task loss.
            ponder_weight: Multiplier for the ponder cost.
            consistency_weight: Multiplier for the consistency loss.
        """
        super().__init__()
        self.task_weight = task_weight
        self.ponder_weight = ponder_weight
        self.consistency_weight = consistency_weight

    def forward(
        self,
        hrm_output: HRMOutput,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        task_loss_fn: nn.Module,
    ) -> tuple[torch.Tensor, dict]:
        """
        Compute combined loss.

        Args:
            hrm_output: Output from HRM forward pass; its
                ``total_ponder_cost``, ``convergence_path`` and ``halt_step``
                are read
            predictions: Model predictions
            targets: Ground truth targets
            task_loss_fn: Loss function for the task, called as
                ``task_loss_fn(predictions, targets)``

        Returns:
            total_loss: Combined loss tensor
            loss_dict: Plain Python numbers for logging, under the keys
                "total", "task", "ponder", "consistency" and "halt_step"
        """
        # Task loss
        task_loss = task_loss_fn(predictions, targets)

        # Ponder cost (encourages efficiency); may be a tensor or a float.
        ponder_loss = hrm_output.total_ponder_cost

        # Consistency loss (encourages monotonic convergence).
        # NOTE: this tensor is rebuilt from Python floats, so the term
        # contributes to the loss value but carries no gradient.
        path = hrm_output.convergence_path
        if len(path) > 1:
            conv_tensor = torch.tensor(path, device=task_loss.device)
            diffs = conv_tensor[1:] - conv_tensor[:-1]
            consistency_loss = F.relu(-diffs).mean()  # Penalize decreases
        else:
            consistency_loss = torch.tensor(0.0, device=task_loss.device)

        # Combine losses
        total_loss = (
            self.task_weight * task_loss
            + self.ponder_weight * ponder_loss
            + self.consistency_weight * consistency_loss
        )

        loss_dict = {
            "total": total_loss.item(),
            "task": task_loss.item(),
            # float() handles both a 0-dim tensor and a plain number, and
            # keeps the logging dict from holding on to the autograd graph.
            "ponder": float(ponder_loss),
            "consistency": consistency_loss.item(),
            "halt_step": hrm_output.halt_step,
        }

        return total_loss, loss_dict
def create_hrm_agent(config: HRMConfig, device: str = "cpu") -> HRMAgent:
    """
    Build an HRMAgent and re-initialize its weights.

    Linear layers get Xavier-uniform weights and zero biases; GRU parameters
    get orthogonal weights and zero biases. All other modules keep the
    initialization they were constructed with.

    Args:
        config: HRM configuration
        device: Device to place model on

    Returns:
        Initialized HRMAgent
    """

    def _reset_parameters(module):
        # Linear layers: Xavier weights, zeroed bias when one exists.
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return

        if not isinstance(module, nn.GRU):
            return

        # GRU: dispatch on the parameter name.
        for param_name, tensor in module.named_parameters():
            if "weight" in param_name:
                nn.init.orthogonal_(tensor)
            elif "bias" in param_name:
                nn.init.zeros_(tensor)

    model = HRMAgent(config, device)
    model.apply(_reset_parameters)
    return model