Spaces:
Running
Running
| """ | |
| Hierarchical Reasoning Model (HRM) Agent. | |
| Implements the HRM architecture with: | |
| - H-Module: High-level planning and decomposition | |
| - L-Module: Low-level execution and refinement | |
| - Adaptive Computation Time (ACT) for dynamic depth | |
| - Halting mechanism based on confidence thresholds | |
| Based on: "Hierarchical Reasoning for Compositional Generalization" | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ..training.system_config import HRMConfig | |
| class SubProblem: | |
| """Represents a decomposed subproblem in the hierarchy.""" | |
| level: int # Hierarchy level (0 = root, higher = more abstract) | |
| description: str # Natural language description | |
| state: torch.Tensor # Latent state representation | |
| parent_id: int | None = None # Parent subproblem ID | |
| confidence: float = 0.0 # Confidence in this decomposition | |
| class HRMOutput: | |
| """Output from HRM processing.""" | |
| final_state: torch.Tensor # Final processed state | |
| subproblems: list[SubProblem] # Hierarchical decomposition | |
| halt_step: int # Step at which halting occurred | |
| total_ponder_cost: float # Total computation cost (for training) | |
| convergence_path: list[float] # Confidence at each step | |
| class AdaptiveComputationTime(nn.Module): | |
| """ | |
| Adaptive Computation Time (ACT) mechanism for dynamic depth. | |
| Allows the model to "ponder" longer on difficult problems by | |
| dynamically adjusting the number of processing steps. | |
| """ | |
| def __init__(self, hidden_dim: int, epsilon: float = 0.01): | |
| super().__init__() | |
| self.epsilon = epsilon | |
| # Halting unit: predicts probability of halting | |
| self.halt_fc = nn.Sequential( | |
| nn.Linear(hidden_dim, hidden_dim // 2), | |
| nn.ReLU(), | |
| nn.Linear(hidden_dim // 2, 1), | |
| nn.Sigmoid(), | |
| ) | |
| def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, float]: | |
| """ | |
| Compute halting probabilities. | |
| Args: | |
| hidden_states: [batch, seq, hidden_dim] | |
| Returns: | |
| halt_probs: [batch, seq] probability of halting | |
| ponder_cost: Scalar cost for training | |
| """ | |
| # Compute halting probabilities | |
| halt_logits = self.halt_fc(hidden_states) # [batch, seq, 1] | |
| halt_probs = halt_logits.squeeze(-1) # [batch, seq] | |
| # Ponder cost is the expected number of steps | |
| ponder_cost = halt_probs.sum(dim=-1).mean() | |
| return halt_probs, ponder_cost | |
| class HModule(nn.Module): | |
| """ | |
| H-Module: High-level planning and abstract reasoning. | |
| Responsible for: | |
| - Decomposing problems into subproblems | |
| - Abstract planning and strategy | |
| - Coordinating L-module executions | |
| """ | |
| def __init__(self, config: HRMConfig): | |
| super().__init__() | |
| self.config = config | |
| # Multi-head self-attention for relational reasoning | |
| self.attention = nn.MultiheadAttention( | |
| embed_dim=config.h_dim, | |
| num_heads=8, | |
| dropout=config.dropout, | |
| batch_first=True, | |
| ) | |
| # Feed-forward network | |
| self.ffn = nn.Sequential( | |
| nn.Linear(config.h_dim, config.h_dim * 4), | |
| nn.GELU(), | |
| nn.Dropout(config.dropout), | |
| nn.Linear(config.h_dim * 4, config.h_dim), | |
| nn.Dropout(config.dropout), | |
| ) | |
| # Layer normalization | |
| self.norm1 = nn.LayerNorm(config.h_dim) | |
| self.norm2 = nn.LayerNorm(config.h_dim) | |
| # Decomposition head: outputs subproblem structure | |
| self.decompose_head = nn.Sequential( | |
| nn.Linear(config.h_dim, config.h_dim), | |
| nn.ReLU(), | |
| nn.Linear(config.h_dim, config.h_dim), | |
| ) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Process input through high-level reasoning. | |
| Args: | |
| x: [batch, seq, h_dim] input tensor | |
| Returns: | |
| [batch, seq, h_dim] processed tensor | |
| """ | |
| # Self-attention for relational reasoning | |
| attn_out, _ = self.attention(x, x, x) | |
| x = self.norm1(x + attn_out) | |
| # Feed-forward processing | |
| ffn_out = self.ffn(x) | |
| x = self.norm2(x + ffn_out) | |
| return x | |
| def decompose(self, x: torch.Tensor) -> torch.Tensor: | |
| """Generate subproblem representations.""" | |
| return self.decompose_head(x) | |
| class LModule(nn.Module): | |
| """ | |
| L-Module: Low-level execution and concrete operations. | |
| Responsible for: | |
| - Executing concrete operations | |
| - Processing individual subproblems | |
| - Generating intermediate results | |
| """ | |
| def __init__(self, config: HRMConfig): | |
| super().__init__() | |
| self.config = config | |
| # Projection from H-module to L-module dimension | |
| self.h_to_l = nn.Linear(config.h_dim, config.l_dim) | |
| # GRU for sequential processing | |
| self.gru = nn.GRU( | |
| input_size=config.l_dim, | |
| hidden_size=config.l_dim, | |
| num_layers=config.num_l_layers, | |
| dropout=config.dropout if config.num_l_layers > 1 else 0, | |
| batch_first=True, | |
| ) | |
| # Output projection | |
| self.output_proj = nn.Sequential( | |
| nn.Linear(config.l_dim, config.l_dim * 2), | |
| nn.ReLU(), | |
| nn.Dropout(config.dropout), | |
| nn.Linear(config.l_dim * 2, config.l_dim), | |
| ) | |
| # Back-projection to H-module dimension | |
| self.l_to_h = nn.Linear(config.l_dim, config.h_dim) | |
| def forward(self, x: torch.Tensor, h_context: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Execute low-level processing. | |
| Args: | |
| x: [batch, seq, h_dim] input from H-module | |
| h_context: Optional hidden state | |
| Returns: | |
| output: [batch, seq, l_dim] processed output | |
| l_to_h: [batch, seq, h_dim] back-projection to H-module | |
| """ | |
| # Project to L-module dimension | |
| x_l = self.h_to_l(x) | |
| # Sequential processing | |
| gru_out, _ = self.gru(x_l, h_context) | |
| # Output processing | |
| output = self.output_proj(gru_out) | |
| # Back-project to H-module dimension for feedback | |
| feedback = self.l_to_h(output) | |
| return output, feedback | |
| class HRMAgent(nn.Module): | |
| """ | |
| Complete Hierarchical Reasoning Model agent. | |
| Combines H-module and L-module with ACT for adaptive computation. | |
| """ | |
| def __init__(self, config: HRMConfig, device: str = "cpu"): | |
| super().__init__() | |
| self.config = config | |
| self.device = device | |
| # Input embedding | |
| self.input_proj = nn.Linear(config.h_dim, config.h_dim) | |
| # Core modules | |
| self.h_module = nn.ModuleList([HModule(config) for _ in range(config.num_h_layers)]) | |
| self.l_module = LModule(config) | |
| # Adaptive computation time | |
| self.act = AdaptiveComputationTime(config.h_dim, config.ponder_epsilon) | |
| # State integration | |
| self.integrate = nn.Sequential( | |
| nn.Linear(config.h_dim * 2, config.h_dim), | |
| nn.LayerNorm(config.h_dim), | |
| nn.GELU(), | |
| ) | |
| self.to(device) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| max_steps: int | None = None, | |
| return_decomposition: bool = False, | |
| ) -> HRMOutput: | |
| """ | |
| Process input through hierarchical reasoning. | |
| Args: | |
| x: [batch, seq, h_dim] input tensor | |
| max_steps: Maximum outer loop steps (defaults to config) | |
| return_decomposition: Whether to return subproblem decomposition | |
| Returns: | |
| HRMOutput containing final state and optional decomposition | |
| """ | |
| batch_size, seq_len, _ = x.shape | |
| max_steps = max_steps or self.config.max_outer_steps | |
| # Initial projection | |
| h_state = self.input_proj(x) | |
| # Tracking | |
| subproblems = [] | |
| convergence_path = [] | |
| total_ponder_cost = 0.0 | |
| # Outer loop: iterative refinement | |
| for step in range(max_steps): | |
| # H-module: high-level planning | |
| for h_layer in self.h_module: | |
| h_state = h_layer(h_state) | |
| # Check halting condition | |
| halt_probs, ponder_cost = self.act(h_state) | |
| total_ponder_cost += ponder_cost | |
| # Average halting probability across sequence | |
| avg_halt_prob = halt_probs.mean().item() | |
| convergence_path.append(avg_halt_prob) | |
| # Generate subproblem decomposition if requested | |
| if return_decomposition: | |
| subproblem_repr = self.h_module[0].decompose(h_state) | |
| # Create subproblem entries (simplified) | |
| for i in range(min(3, seq_len)): # Top 3 subproblems | |
| subproblems.append( | |
| SubProblem( | |
| level=step, | |
| description=f"Subproblem at step {step}, position {i}", | |
| state=subproblem_repr[:, i, :].detach(), | |
| confidence=halt_probs[:, i].mean().item(), | |
| ) | |
| ) | |
| # Halt if confident enough | |
| if avg_halt_prob >= self.config.halt_threshold: | |
| break | |
| # L-module: low-level execution | |
| l_output, l_feedback = self.l_module(h_state) | |
| # Integrate L-module feedback | |
| h_state = self.integrate(torch.cat([h_state, l_feedback], dim=-1)) | |
| return HRMOutput( | |
| final_state=h_state, | |
| subproblems=subproblems, | |
| halt_step=step + 1, | |
| total_ponder_cost=total_ponder_cost, | |
| convergence_path=convergence_path, | |
| ) | |
| async def decompose_problem(self, query: str, state: torch.Tensor) -> list[SubProblem]: | |
| """ | |
| Decompose a problem into hierarchical subproblems. | |
| Args: | |
| query: Natural language problem description | |
| state: Initial state representation | |
| Returns: | |
| List of subproblems in hierarchical order | |
| """ | |
| # Ensure state is batched | |
| if state.dim() == 2: | |
| state = state.unsqueeze(0) # [1, seq, dim] | |
| # Forward pass with decomposition | |
| output = self.forward(state, return_decomposition=True) | |
| # Add query context to subproblems | |
| for i, sp in enumerate(output.subproblems): | |
| sp.description = f"{query} -> Level {sp.level} Subproblem {i}" | |
| return output.subproblems | |
| def get_parameter_count(self) -> int: | |
| """Return total number of trainable parameters.""" | |
| return sum(p.numel() for p in self.parameters() if p.requires_grad) | |
| # Training utilities | |
| class HRMLoss(nn.Module): | |
| """ | |
| Combined loss for HRM training. | |
| Includes: | |
| - Task loss (e.g., cross-entropy for classification) | |
| - Ponder cost regularization (encourages efficiency) | |
| - Consistency loss (encourages stable convergence) | |
| """ | |
| def __init__( | |
| self, | |
| task_weight: float = 1.0, | |
| ponder_weight: float = 0.01, | |
| consistency_weight: float = 0.1, | |
| ): | |
| super().__init__() | |
| self.task_weight = task_weight | |
| self.ponder_weight = ponder_weight | |
| self.consistency_weight = consistency_weight | |
| def forward( | |
| self, | |
| hrm_output: HRMOutput, | |
| predictions: torch.Tensor, | |
| targets: torch.Tensor, | |
| task_loss_fn: nn.Module, | |
| ) -> tuple[torch.Tensor, dict]: | |
| """ | |
| Compute combined loss. | |
| Args: | |
| hrm_output: Output from HRM forward pass | |
| predictions: Model predictions | |
| targets: Ground truth targets | |
| task_loss_fn: Loss function for the task | |
| Returns: | |
| total_loss: Combined loss | |
| loss_dict: Dictionary of individual loss components | |
| """ | |
| # Task loss | |
| task_loss = task_loss_fn(predictions, targets) | |
| # Ponder cost (encourages efficiency) | |
| ponder_loss = hrm_output.total_ponder_cost | |
| # Consistency loss (encourages monotonic convergence) | |
| if len(hrm_output.convergence_path) > 1: | |
| conv_tensor = torch.tensor(hrm_output.convergence_path) | |
| # Penalize non-monotonic increases | |
| diffs = conv_tensor[1:] - conv_tensor[:-1] | |
| consistency_loss = F.relu(-diffs).mean() # Penalize decreases | |
| else: | |
| consistency_loss = torch.tensor(0.0) | |
| # Combine losses | |
| total_loss = ( | |
| self.task_weight * task_loss + self.ponder_weight * ponder_loss + self.consistency_weight * consistency_loss | |
| ) | |
| loss_dict = { | |
| "total": total_loss.item(), | |
| "task": task_loss.item(), | |
| "ponder": ponder_loss, | |
| "consistency": consistency_loss.item(), | |
| "halt_step": hrm_output.halt_step, | |
| } | |
| return total_loss, loss_dict | |
| def create_hrm_agent(config: HRMConfig, device: str = "cpu") -> HRMAgent: | |
| """ | |
| Factory function to create and initialize HRM agent. | |
| Args: | |
| config: HRM configuration | |
| device: Device to place model on | |
| Returns: | |
| Initialized HRMAgent | |
| """ | |
| agent = HRMAgent(config, device) | |
| # Initialize weights | |
| def init_weights(m): | |
| if isinstance(m, nn.Linear): | |
| nn.init.xavier_uniform_(m.weight) | |
| if m.bias is not None: | |
| nn.init.zeros_(m.bias) | |
| elif isinstance(m, nn.GRU): | |
| for name, param in m.named_parameters(): | |
| if "weight" in name: | |
| nn.init.orthogonal_(param) | |
| elif "bias" in name: | |
| nn.init.zeros_(param) | |
| agent.apply(init_weights) | |
| return agent | |