import math import json import torch import torch.nn as nn import torch.nn.functional as F from dataclasses import dataclass from einops import rearrange, repeat, einsum from typing import Union @dataclass class ModelArgs: d_model: int n_layer: int vocab_size: int d_state: int = 16 expand: int = 2 dt_rank: Union[int, str] = 'auto' d_conv: int = 4 pad_vocab_size_multiple: int = 8 conv_bias: bool = True bias: bool = False def __post_init__(self): self.d_inner = int(self.expand * self.d_model) if self.dt_rank == 'auto': self.dt_rank = math.ceil(self.d_model / 16) if self.vocab_size % self.pad_vocab_size_multiple != 0: self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple) class MambaBlock_CD(nn.Module): def __init__(self, d_model,d_conv, d_state, bias = True, conv_bias = True): """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" super().__init__() # self.args = args self.norm = RMSNorm(d_model=d_model) self.d_inner = 2 * d_model self.dt_rank = math.ceil(d_model / 16) self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias) self.mlp_1 = nn.Linear(self.d_inner, d_model) self.mlp_2 = nn.Linear(self.d_inner, d_model) self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias, kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1, ) # x_proj takes in `x` and outputs the input-specific Δ, B, C self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False) # dt_proj projects Δ from dt_rank to d_in self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) A = repeat(torch.arange(1, d_state + 1), 'n -> d n', d=self.d_inner) self.A_log = nn.Parameter(torch.log(A)) self.D = nn.Parameter(torch.ones(self.d_inner)) self.D_p = nn.Parameter(torch.ones(self.d_inner)) self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias) def forward(self, t1,t2): ee1 = t1 ee2 = t2 (b, l, d) = t1.shape t1 = self.norm(t1) t1_and_res = self.in_proj(t1) # shape (b, l, 2 * d_in) (t1, res1) = t1_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1) t1 = rearrange(t1, 'b l d_in -> b d_in l') t1 = self.conv1d(t1)[:, :, :l] t1 = rearrange(t1, 'b d_in l -> b l d_in') t1 = F.silu(t1) (b, l, d) = t2.shape t2 = self.norm(t2) t2_and_res = self.in_proj(t2) # shape (b, l, 2 * d_in) (t2, res2) = t2_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1) t2 = rearrange(t2, 'b l d_in -> b d_in l') t2 = self.conv1d(t2)[:, :, :l] t2 = rearrange(t2, 'b d_in l -> b l d_in') t2 = F.silu(t2) y1,y2 = self.cssm(t1,t2) y1 = y1 * F.silu(res1) y2 = y2 * F.silu(res2) output1 = self.out_proj(y1) output2 = self.out_proj(y2) return output1 + ee1, output2 + ee2 def cssm(self, t1, t2): (d_in, n) = self.A_log.shape A = -torch.exp(self.A_log.float()) # shape (d_in, n) D = self.D.float() t1_dbl = self.x_proj(t1) # (b, l, dt_rank + 2*n) (delta, B, C) = t1_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n) delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) A_prim = -torch.exp(self.A_log.float()) # shape (d_in, n) D_prim = self.D_p.float() t2_dbl = self.x_proj(t2) # (b, l, dt_rank + 2*n) (delta, B_prim, C_prim) = t2_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n) delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) y = self.selective_scan(t1,t2, delta, A, B, C, D, A_prim, B_prim, C_prim, D_prim) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2] return y def selective_scan(self, t1,t2, delta, A, B, C, D, A_prim, B_prim, C_prim, D_prim): (b, l, d_in) = t1.shape n = A.shape[1] deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n')) deltaB_u = einsum(delta, B, t1, 'b l d_in, b l n, b l d_in -> b l d_in n') deltaB_u_prim = einsum(delta, B_prim, t2, 'b l d_in, b l n, b l d_in -> b l d_in n') x = torch.zeros((b, d_in, n), device=deltaA.device) ys = [] for i in range(l): x = deltaA[:, i] * x + torch.abs(deltaB_u[:, i] - deltaB_u_prim[:,i]) y1 = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') ys.append(y1) y1 = torch.stack(ys, dim=1) # shape (b, l, d_in) y1 = y1 + t1 * D (b, l, d_in) = t2.shape n = A_prim.shape[1] deltaA_prim = torch.exp(einsum(delta, A_prim, 'b l d_in, d_in n -> b l d_in n')) # deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n') x = torch.zeros((b, d_in, n), device=deltaA.device) ys = [] for i in range(l): x = deltaA_prim[:, i] * x + torch.abs(deltaB_u[:, i] - deltaB_u_prim[:,i]) y2 = einsum(x, C_prim[:, i, :], 'b d_in n, b n -> b d_in') ys.append(y2) y2 = torch.stack(ys, dim=1) # shape (b, l, d_in) y2 = y2 + t2 * D_prim return y1 ,y2 class MambaCSSM(nn.Module): def __init__(self, num_layers, d_model,d_conv, d_state, bias = True, conv_bias = True ): super().__init__() self.layers = nn.ModuleList([MambaBlock_CD(d_model,d_conv, d_state, bias = True, conv_bias = True) for _ in range(num_layers)]) def forward(self, t1,t2): for layer in self.layers: t1,t2 = layer(t1,t2) return t1,t2 class MambaBlock(nn.Module): def __init__(self, d_model,d_conv, d_state, bias = True, conv_bias = True): """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" super().__init__() # self.args = args self.d_inner = 2 * d_model self.dt_rank = math.ceil(d_model / 16) self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias) self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias, kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1, ) # x_proj takes in `x` and outputs the input-specific Δ, B, C self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False) # dt_proj projects Δ from dt_rank to d_in self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) A = repeat(torch.arange(1, d_state + 1), 'n -> d n', d=self.d_inner) self.A_log = nn.Parameter(torch.log(A)) self.D = nn.Parameter(torch.ones(self.d_inner)) self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias) def forward(self, x): """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1]. Args: x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...) Returns: output: shape (b, l, d) Official Implementation: class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119 mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 """ (b, l, d) = x.shape x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in) (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1) x = rearrange(x, 'b l d_in -> b d_in l') x = self.conv1d(x)[:, :, :l] x = rearrange(x, 'b d_in l -> b l d_in') x = F.silu(x) y = self.ssm(x) y = y * F.silu(res) output = self.out_proj(y) return output def ssm(self, x): """Runs the SSM. See: - Algorithm 2 in Section 3.2 in the Mamba paper [1] - run_SSM(A, B, C, u) in The Annotated S4 [2] Args: x: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...) Returns: output: shape (b, l, d_in) Official Implementation: mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 """ (d_in, n) = self.A_log.shape # Compute ∆ A B C D, the state space parameters. # A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) # ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, # and is why Mamba is called **selective** state spaces) A = -torch.exp(self.A_log.float()) # shape (d_in, n) D = self.D.float() x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n) (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n) delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2] return y def selective_scan(self, u, delta, A, B, C, D): """Does selective scan algorithm. See: - Section 2 State Space Models in the Mamba paper [1] - Algorithm 2 in Section 3.2 in the Mamba paper [1] - run_SSM(A, B, C, u) in The Annotated S4 [2] This is the classic discrete state space formula: x(t + 1) = Ax(t) + Bu(t) y(t) = Cx(t) + Du(t) except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t). Args: u: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...) delta: shape (b, l, d_in) A: shape (d_in, n) B: shape (b, l, n) C: shape (b, l, n) D: shape (d_in,) Returns: output: shape (b, l, d_in) Official Implementation: selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86 Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly. """ (b, l, d_in) = u.shape n = A.shape[1] # Discretize continuous parameters (A, B) # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1]) # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: # "A is the more important term and the performance doesn't change much with the simplification on B" deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n')) deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n') # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) # Note that the below is sequential, while the official implementation does a much faster parallel scan that # is additionally hardware-aware (like FlashAttention). x = torch.zeros((b, d_in, n), device=deltaA.device) ys = [] for i in range(l): x = deltaA[:, i] * x + deltaB_u[:, i] y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') ys.append(y) y = torch.stack(ys, dim=1) # shape (b, l, d_in) y = y + u * D return y class RMSNorm(nn.Module): def __init__(self, d_model: int, eps: float = 1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(d_model)) def forward(self, x): output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight return output