Many works have tried to reduce the quadratic computational cost of the attention mechanism to linear. In Mamba2, the authors follow the idea of "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" and characterize attention in a more general form, namely any Y = f(QK^T) · V, rather than discussing only Softmax self-attention.
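To make that general form concrete, the sketch below (my own illustration, not code from the Mamba2 paper) shows causal linear attention in the spirit of that paper: once softmax is replaced by a feature map φ applied to Q and K separately, the output can be accumulated left to right with a constant-size state, so the per-token cost no longer grows with sequence length. The feature map elu(·)+1 and the normalization follow Katharopoulos et al.; all names here are illustrative.

```python
import torch
import torch.nn.functional as F

def causal_linear_attention(Q, K, V):
    """Y_t = phi(q_t)^T (sum_{s<=t} phi(k_s) v_s^T) / (phi(q_t)^T sum_{s<=t} phi(k_s)).

    Q, K: (B, L, d_k), V: (B, L, d_v). The running sums S and z act as an RNN state,
    so the whole computation is O(L) instead of O(L^2)."""
    phi = lambda t: F.elu(t) + 1           # positive feature map
    Q, K = phi(Q), phi(K)
    B, L, d_k = Q.shape
    d_v = V.shape[-1]
    S = torch.zeros(B, d_k, d_v)           # running sum of phi(k_s) v_s^T
    z = torch.zeros(B, d_k)                # running sum of phi(k_s), for normalization
    out = []
    for t in range(L):
        S = S + K[:, t, :, None] * V[:, t, None, :]
        z = z + K[:, t]
        num = torch.einsum("bd, bdv -> bv", Q[:, t], S)
        den = torch.einsum("bd, bd -> b", Q[:, t], z).clamp(min=1e-6)
        out.append(num / den[:, None])
    return torch.stack(out, dim=1)         # (B, L, d_v)
```

Mamba2's SSD framework can be read as a generalization of this recurrence, where the state update gets a per-token, data-dependent decay instead of a plain running sum.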
    if self.config.use_cuda:
        try:
            from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
            self.selective_scan_cuda = selective_scan_fn
        except ImportError:
            print("Failed to import mamba_ssm. Falling back to mamba.py.")
            self.config.use_cuda = False

def _apply_layernorms(self, dt, B, C):
    if self.dt_layernorm is not None:
        dt = self.dt_layernorm(dt)
    if self.B_layernorm is not None:
        B = self.B_layernorm(B)
    if self.C_layernorm is not None:
        C = self.C_layernorm(C)
    return dt, B, C
def forward(self, x):
    # x : (B, L, D)
    # y : (B, L, D)

    _, L, _ = x.shape

    # project the input up and split it into the x branch and the gating branch z
    xz = self.in_proj(x)        # (B, L, 2*ED)
    x, z = xz.chunk(2, dim=-1)  # (B, L, ED), (B, L, ED)
    # x branch
    # Transpose x so its shape fits conv1d, apply the depthwise convolution
    # and keep only the first L outputs, then transpose x back.
    x = x.transpose(1, 2)           # (B, ED, L)
    x = self.conv1d(x)[:, :, :L]    # depthwise convolution over time, with a short filter
    x = x.transpose(1, 2)           # (B, L, ED)
    x = F.silu(x)
    y = self.ssm(x, z)
    if self.config.use_cuda:
        output = self.out_proj(y)  # (B, L, D)
        return output  # the rest of the operations are done in the ssm function (fused with the CUDA pscan)
    # z branch
    z = F.silu(z)
    output = y * z
    output = self.out_proj(output)  # (B, L, D)
    return output
def ssm(self, x, z):
    # x : (B, L, ED)
    # y : (B, L, ED)

    # Recover A from A_log via exponentiation
    A = -torch.exp(self.A_log.float())  # (ED, N)
    D = self.D.float()
    # Split out the three parameters delta, B, C
    deltaBC = self.x_proj(x)  # (B, L, dt_rank+2*N)
    delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1)  # (B, L, dt_rank), (B, L, N), (B, L, N)
    delta, B, C = self._apply_layernorms(delta, B, C)
    # Project the low-rank delta back up to its full dimension
    delta = self.dt_proj.weight @ delta.transpose(1, 2)  # (ED, dt_rank) @ (B, L, dt_rank) -> (B, ED, L)
    # here we just apply the matrix mul operation of delta = softplus(dt_proj(delta))
    # the rest will be applied later (fused if using cuda)
    if self.config.use_cuda:
        # these are unfortunately needed for the selective_scan_cuda function
        x = x.transpose(1, 2)
        B = B.transpose(1, 2)
        C = C.transpose(1, 2)
        z = z.transpose(1, 2)
        # "softplus" + "bias" + "y * silu(z)" operations are fused;
        # here we call the official CUDA implementation to get that fusion
        y = self.selective_scan_cuda(x, delta, A, B, C, D, z=z, delta_softplus=True, delta_bias=self.dt_proj.bias.float())
        y = y.transpose(1, 2)  # (B, L, ED)
    else:
        delta = delta.transpose(1, 2)
        delta = F.softplus(delta + self.dt_proj.bias)
        if self.config.pscan:
            y = self.selective_scan(x, delta, A, B, C, D)
        else:
            y = self.selective_scan_seq(x, delta, A, B, C, D)
    return y
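The comment above notes that only the matrix multiplication of delta = softplus(dt_proj(delta)) is applied inline, with the bias add and softplus deferred (and fused into the CUDA kernel when available). A throwaway check, with made-up sizes, that the two-stage version matches the one-shot projection:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, L, dt_rank, ED = 2, 8, 4, 16                           # hypothetical sizes
dt_proj = nn.Linear(dt_rank, ED, bias=True)
delta_lowrank = torch.randn(B, L, dt_rank)

# One-shot version
ref = F.softplus(dt_proj(delta_lowrank))                  # (B, L, ED)

# Two-stage version, as in ssm(): weight matmul now ...
delta = dt_proj.weight @ delta_lowrank.transpose(1, 2)    # (B, ED, L)
# ... bias + softplus later (the non-CUDA branch above)
delta = F.softplus(delta.transpose(1, 2) + dt_proj.bias)  # (B, L, ED)

assert torch.allclose(ref, delta, atol=1e-6)
```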
# Hand-rolled parallel scan
def selective_scan(self, x, delta, A, B, C, D):
    # x : (B, L, ED)
    # Δ : (B, L, ED)
    # A : (ED, N)
    # B : (B, L, N)
    # C : (B, L, N)
    # D : (ED)
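Before looking at the parallel scan, it helps to see the recurrence it has to reproduce. The sketch below is a standalone version of what the sequential fallback (selective_scan_seq in the code path above) computes: discretize A and B with delta, then run h_t = Ā_t ⊙ h_{t-1} + B̄_t x_t one step at a time and read out with C. It is written as a free function, under the same shape conventions as the comments above, so it is a reference sketch rather than the repository's exact method.

```python
import torch

def selective_scan_sequential(x, delta, A, B, C, D):
    """Reference recurrence for the selective scan (O(L) sequential steps).

    x, delta : (B, L, ED); A : (ED, N); B, C : (B, L, N); D : (ED)"""
    # Discretization: A_bar = exp(delta * A), B_bar * x = (delta * B) * x
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                    # (B, L, ED, N)
    BX = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)    # (B, L, ED, N)

    h = torch.zeros(x.shape[0], A.shape[0], A.shape[1], dtype=x.dtype, device=x.device)
    hs = []
    for t in range(x.shape[1]):
        h = deltaA[:, t] * h + BX[:, t]    # h_t = A_bar_t * h_{t-1} + B_bar_t * x_t
        hs.append(h)
    hs = torch.stack(hs, dim=1)            # (B, L, ED, N)

    y = (hs @ C.unsqueeze(-1)).squeeze(3)  # y_t = C_t h_t, (B, L, ED)
    return y + D * x                       # skip connection through D
```

The parallel scan (pscan) computes exactly these h_t, but in O(log L) sequential steps by exploiting the associativity of the update.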
A ResidualBlock is composed of a MambaBlock, a normalization, and a residual connection:

ResidualBlock(x) = mamba(norm(x)) + x

This leaves us with the MambaBlock: its input x is (B, L, D) and its output y is also (B, L, D).
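A minimal sketch of that wrapper, assuming a MambaBlock class like the one whose forward/ssm methods are dissected above; the RMSNorm here is the usual pre-normalization paired with Mamba blocks:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal RMS normalization, as commonly paired with Mamba blocks."""
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class ResidualBlock(nn.Module):
    """ResidualBlock(x) = mamba(norm(x)) + x"""
    def __init__(self, config):
        super().__init__()
        self.mixer = MambaBlock(config)      # the block dissected above (assumed defined)
        self.norm = RMSNorm(config.d_model)  # pre-normalization

    def forward(self, x):
        # x, output : (B, L, D)
        return self.mixer(self.norm(x)) + x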
@dataclass
class Mamba2Config:
    d_model: int
    n_layer: int = 24
    d_state: int = 128
    d_conv: int = 4
    expand: int = 2
    headdim: int = 64     # head dimension (P) of the multi-head SSM
    chunk_size: int = 64  # matrix partition size (Q)
    # needed for language-model tokenization
    vocab_size: int = 50277
    pad_vocab_size_multiple: int = 16
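The remaining sizes are derived from these fields; in the minimal Mamba2 implementation this happens in a `__post_init__` roughly like the following (d_inner = expand * d_model, one SSM head per headdim channels, and the vocabulary padded up to a multiple of pad_vocab_size_multiple):

```python
    def __post_init__(self):
        self.d_inner = self.expand * self.d_model      # ED = E * D
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim     # number of SSM heads
        # pad the vocabulary so the embedding / LM-head shapes are GPU friendly
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (
                self.pad_vocab_size_multiple
                - self.vocab_size % self.pad_vocab_size_multiple
            )
```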
    # Reshape and truncate x, B, C so that a single conv1d covers them, then split them apart again
    # Pad or truncate xBC seqlen to d_conv
    conv_state = F.pad(
        rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
    )
    # Split x into heads of size p for the multi-head SSD
    x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
    y, ssm_state = ssd(
        x * dt.unsqueeze(-1),
        A * dt,
        rearrange(B, "b l n -> b l 1 n"),
        rearrange(C, "b l n -> b l 1 n"),
        self.args.chunk_size,
        device=self.device,
    )
    y = y + x * self.D.unsqueeze(-1)
    y = rearrange(y, "b l h p -> b l (h p)")
    y = self.norm(y, z)
    y = self.out_proj(y)
    h = InferenceCache(conv_state, ssm_state)
    return y, h
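The h returned here bundles everything the next decoding step needs: the last d_conv inputs for the depthwise convolution, and the per-head SSM state. A sketch of that container (matching the InferenceCache.alloc call used when processing a prompt below); the exact field shapes are my reading of the code around it:

```python
from typing import NamedTuple

import torch
from torch import Tensor

class InferenceCache(NamedTuple):
    conv_state: Tensor  # (batch, d_inner + 2 * d_state, d_conv)
    ssm_state: Tensor   # (batch, nheads, headdim, d_state)

    @staticmethod
    def alloc(batch_size: int, args: "Mamba2Config", device=None):
        # zero-initialized cache for a fresh sequence
        return InferenceCache(
            torch.zeros(
                batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
            ),
            torch.zeros(
                batch_size, args.nheads, args.headdim, args.d_state, device=device
            ),
        )
```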
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
    """Take a single inference step for the current input and hidden state

    Arguments
        u: (batch, 1, d_model)
        h: initial/running hidden state

    Return (y, h)
        y: (batch, 1, d_model)
        h: updated hidden state
    """
    assert u.shape[1] == 1, "Only one token can be decoded per inference step"
    x, B, C = torch.split(
        xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
    )
    A = -torch.exp(self.A_log)  # (nheads,)
    # SSM step
    dt = F.softplus(dt + self.dt_bias)  # (batch, nheads)
    dA = torch.exp(dt * A)              # (batch, nheads)
    x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
    dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
    h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
    y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
    y = y + rearrange(self.D, "h -> h 1") * x
    y = rearrange(y, "b h p -> b (h p)")
    y = self.norm(y, z)
    y = self.out_proj(y)
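The two einsums above are easy to unpack: the first builds a rank-1 (per head) update dt · x ⊗ B to the state, the second reads the state out with C. A throwaway check with made-up sizes, purely to illustrate what they compute:

```python
import torch

b, h, p, n = 2, 3, 4, 5          # hypothetical batch / heads / headdim / state sizes
dt = torch.rand(b, h)            # per-head step size (after softplus)
x = torch.randn(b, h, p)         # current input, split into heads
B = torch.randn(b, n)
C = torch.randn(b, n)
state = torch.randn(b, h, p, n)  # previous SSM state
dA = torch.exp(-dt)              # per-head decay

# dBx[b,h,p,n] = dt[b,h] * B[b,n] * x[b,h,p] -- a rank-1 outer product per head
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
assert torch.allclose(dBx, dt[:, :, None, None] * x[..., None] * B[:, None, None, :], atol=1e-6)

# state update and readout: h_t = dA * h_{t-1} + dBx, y_t = h_t C_t
new_state = state * dA[:, :, None, None] + dBx
y = torch.einsum("bhpn, bn -> bhp", new_state, C)
assert torch.allclose(y, (new_state * C[:, None, None, :]).sum(-1), atol=1e-6)
```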
@staticmethod
def from_pretrained(huggingface_model_id: str, device: Device = None):
    from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
    from transformers.utils.hub import cached_file
    config_path = cached_file(huggingface_model_id, CONFIG_NAME)
    assert config_path, "Failed to get huggingface config file"
    state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
    assert state_dict_path, "Failed to get huggingface state dict file"
    # Process prompt
    # The input sequence to forward (non-inference path) must have a length that is a multiple of chunk_size.
    # We split out the excess tokens so that n_chunked tokens can be processed by one forward call and
    # process the rest in multiple inference steps.
    n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
    if n_chunked > 0:
        _, h = self(prefix[:n_chunked].unsqueeze(0), None)
    else:
        h = [
            InferenceCache.alloc(1, self.args, device=self.device)
            for _ in range(self.args.n_layer)
        ]
    for i in range(n_chunked, prefix.shape[0]):
        _, h = self(prefix[i : i + 1].unsqueeze(0), h)
    # Generate
    for _ in range(max_new_length):
        with torch.no_grad():
            out, h = self(tokens, h)
        logits = out[0, -1]
        if temperature != 1.0:
            logits = logits / temperature
        if top_k > 0:
            indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
            logits[indices_to_remove] = -torch.inf
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # drop tokens outside the nucleus (cumulative probability above top_p)
            sorted_indices_to_remove = cum_probs > top_p
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = False
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = -torch.inf
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        if next_token.item() == eos_token_id:
            return
        tokens = next_token.unsqueeze(0)
        yield cast(int, next_token.item()), h
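Putting from_pretrained and generate together, a hypothetical driver might look like the following. The class name Mamba2LMHeadModel, the checkpoint id, and the GPT-NeoX tokenizer are assumptions on my part (the NeoX tokenizer is what the official Mamba checkpoints use); adjust them to whatever the surrounding code actually defines:

```python
from transformers import AutoTokenizer

# assumed names: Mamba2LMHeadModel is the class containing from_pretrained/generate above
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = Mamba2LMHeadModel.from_pretrained("state-spaces/mamba2-130m", device="cuda")

prompt = "Mamba is a new sequence model that"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids[0].to("cuda")

print(prompt, end="")
for token_id, _hidden in model.generate(input_ids, max_new_length=50, top_k=50, top_p=0.9):
    print(tokenizer.decode([token_id]), end="", flush=True)
```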
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
    """Structured State Space Duality (SSD) - the core of Mamba-2

    This is almost the exact same minimal SSD code from the blog post.

    Arguments
        x: (batch, seqlen, n_heads, d_head)
        A: (batch, seqlen, n_heads)
        B: (batch, seqlen, n_heads, d_state)
        C: (batch, seqlen, n_heads, d_state)

    Return
        y: (batch, seqlen, n_heads, d_head)
    """
    assert x.shape[1] % chunk_size == 0
    # Rearrange into chunks
    # Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
    # This is not implemented and left as an exercise for the reader 😜
    x, A, B, C = [
        rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
    ]
    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)
    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A, device=device))
    Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
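The segsum helper used here (and again in step 3) computes, for every pair (l, s) with s ≤ l, the sum A_{s+1} + … + A_l, and -inf above the diagonal, so that exp(segsum(A)) is exactly the lower-triangular decay matrix L of the 1-semiseparable form. A sketch along the lines of the official minimal code:

```python
import torch
from einops import repeat

def segsum(x, device=None):
    """Stable segment sum: out[..., l, s] = x[..., s+1] + ... + x[..., l] for s <= l, else -inf."""
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
    return x_segsum.masked_fill(~mask, -torch.inf)
```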
    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
    new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]
    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")

    return Y, final_state
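Finally, a quick way to convince yourself that the four steps above reproduce the plain recurrence is to compare ssd() against a naive per-token loop on small, made-up sizes. The reference below is my own sketch: h_t = exp(A_t) h_{t-1} + B_t ⊗ x_t with readout C_t, which is the convention ssd uses (the Δt scaling and the D skip connection are applied by the caller, as in the forward pass above):

```python
import torch

def ssd_reference(x, A, B, C):
    """Naive sequential SSM matching ssd(); x: (b, l, h, p), A: (b, l, h), B/C: (b, l, h, n)."""
    b, l, h, p = x.shape
    n = B.shape[-1]
    state = torch.zeros(b, h, p, n, dtype=x.dtype)
    ys = []
    for t in range(l):
        # decay the state, then add the rank-1 input contribution
        state = state * torch.exp(A[:, t])[:, :, None, None] \
            + torch.einsum("bhp, bhn -> bhpn", x[:, t], B[:, t])
        ys.append(torch.einsum("bhpn, bhn -> bhp", state, C[:, t]))
    return torch.stack(ys, dim=1), state

b, l, h, p, n, chunk = 1, 8, 2, 3, 4, 4   # hypothetical sizes; chunk must divide l
x = torch.randn(b, l, h, p)
A = -torch.rand(b, l, h)                  # per-step log-decays, negative as after A * dt
B = torch.randn(b, l, h, n)
C = torch.randn(b, l, h, n)

y_chunked, final_state = ssd(x, A, B, C, chunk_size=chunk)
y_ref, final_ref = ssd_reference(x, A, B, C)
assert torch.allclose(y_chunked, y_ref, atol=1e-4)
assert torch.allclose(final_state, final_ref, atol=1e-4)
```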