|
|
| """ |
| Addressed State Attention (ASA) - Analysis Harness |
| |
| Research implementation with mechanistic intervention capabilities. |
| For efficient training without interventions, use asm_training.py instead. |
| |
| Features: |
| - Slot-mask causal interventions (slot_mask, slot_mask_where, slot_mask_scope) |
| - Refinement decomposition (orthogonal/parallel gating) |
| - Per-head geometry logging |
| - Configurable information storage (info_level, info_cfg) |
| |
| Checkpoint Compatibility: |
| All parameter/buffer names match asm_training.py for weight sharing. |
| Do NOT rename: slot_keys, Wk_write, Wv_write, Wq_read, out_proj, |
| _alibi_slopes, _alibi_strength_param, _content_read_gamma_raw, |
| slot_in/slot_q/slot_k/slot_v/slot_out, _slotspace_gate_raw, |
| rope/rope_slotspace buffers. |
| |
| Repository: https://github.com/DigitalDaimyo/AddressedStateAttention |
| Paper: https://github.com/DigitalDaimyo/AddressedStateAttention/tree/main/paper_drafts |
| """ |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Optional, Dict, Tuple, List |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| __all__ = [ |
| 'AddressedStateAttention', |
| 'ASMBlock', |
| 'ASMLanguageModel', |
| 'ASMTrainConfig', |
| 'build_model_from_cfg', |
| ] |
|
|
|
|
| |
|
|
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: |
| x1 = x[..., ::2] |
| x2 = x[..., 1::2] |
| return torch.stack((-x2, x1), dim=-1).flatten(-2) |
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    ``inv_freq`` holds ``dim / 2`` inverse frequencies and is registered as a
    non-persistent buffer (it is not written to checkpoints). The most recently
    built table is memoised in plain attributes, keyed by length and device.

    NOTE(review): the tables are laid out as ``cat([freqs, freqs])`` (lanes
    ``i`` and ``i + dim/2`` share a frequency) while ``_rotate_half`` in this
    file pairs adjacent lanes ``(2i, 2i + 1)``. The combination is left exactly
    as it is so outputs stay identical for existing checkpoints -- confirm it
    is intended before changing either side.
    """

    def __init__(self, dim: int, base: float = 10000.0):
        """Build the inverse-frequency buffer for an even ``dim``."""
        super().__init__()
        assert dim % 2 == 0, "RoPE requires even dim"
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cos_cached = None
        self._sin_cached = None
        self._t_cached = None
        self._device_cached = None

    def get_cos_sin(self, T: int, device, dtype):
        """Return ``(cos, sin)`` tables of shape ``[1, 1, T, dim]`` in ``dtype``.

        A cached table is reused when it was built on the same ``device`` and
        covers at least ``T`` positions (rows ``0..T-1`` do not depend on the
        table length, so a prefix slice is exact). Tables created inside
        ``torch.inference_mode()`` are not handed out once inference mode is
        off, because autograd refuses to save inference tensors for backward.
        """
        cos, sin = self._cos_cached, self._sin_cached
        reusable = (
            cos is not None
            and sin is not None
            and self._t_cached is not None
            and self._t_cached >= T
            and self._device_cached == device
            and (torch.is_inference_mode_enabled() or not cos.is_inference())
        )
        if reusable:
            return (
                cos[:, :, :T, :].to(dtype=dtype),
                sin[:, :, :T, :].to(dtype=dtype),
            )

        t = torch.arange(T, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("t,f->tf", t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        # Leading singleton axes broadcast over batch and heads.
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        self._t_cached = T
        self._device_cached = device
        self._cos_cached = cos
        self._sin_cached = sin
        return cos.to(dtype=dtype), sin.to(dtype=dtype)
|
|
|
|
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Combine ``x`` with its pair-rotated copy: ``x * cos + rotate(x) * sin``.

    ``cos`` and ``sin`` must broadcast against ``x``.
    """
    rotated = _rotate_half(x)
    return x * cos + rotated * sin
|
|
|
|
def alibi_slopes(num_heads: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """Return one ALiBi slope per head as a 1-D tensor of length ``num_heads``.

    For a power-of-two head count the slopes form a geometric sequence whose
    first term equals its ratio. Otherwise the sequence for the largest power
    of two below ``num_heads`` is extended with every other slope of the
    sequence for twice that power.
    """

    def _geometric(count):
        first = 2.0 ** (-(2.0 ** -(math.log2(count) - 3)))
        return [first * (first ** idx) for idx in range(count)]

    def _slopes(count):
        if math.log2(count).is_integer():
            return _geometric(count)
        lower_pow2 = 2 ** math.floor(math.log2(count))
        # Fill the remaining heads from the next power of two, skipping odd entries.
        extra = _slopes(2 * lower_pow2)[0::2][: count - lower_pow2]
        return _geometric(lower_pow2) + extra

    return torch.tensor(_slopes(num_heads), device=device, dtype=dtype)
|
|
|
|
| def _inv_softplus(y: torch.Tensor) -> torch.Tensor: |
| return torch.log(torch.expm1(y)) |
|
|
|
|
def phi(x: torch.Tensor) -> torch.Tensor:
    """Elementwise feature map ``elu(x) + 1`` (non-negative for every input)."""
    return 1.0 + F.elu(x)
|
|
|
|
| |
|
|
| class AddressedStateAttention(nn.Module): |
| """ |
| Addressed State Attention (ASA) — unified research harness. |
| |
| Core mechanism |
| -------------- |
| * Prefix-softmax WRITE into K learned slots (streaming, O(T)) |
| * READ routing from tokens → slots (softmax / top-k / external) |
| * Content-conditioned READ term (gamma-weighted) |
| * RoPE on write keys (geometry) |
| * ALiBi bias on write logits (prefix-friendly) |
| |
| Slot-space refinement |
| --------------------- |
| * Causal linear attention in a low-dim slot-address coordinate space |
| * Produces per-token signed weights over slots |
| * Decoded through the same streaming slot-state basis |
| * Gated by learnable ``slotspace_gate`` (softplus) |
| |
| Causal intervention (slot mask) |
| ------------------------------- |
| * ``slot_mask`` [K] float/bool, 1=keep 0=mask |
| * ``slot_mask_where`` "read" | "content_read_only" | "slotspace_only" |
| * ``slot_mask_scope`` "all" | "last_pos_only" |
| |
| Refine-delta intervention (instance attrs, NO-OP by default) |
| ---------------------------------------------------------------- |
| * ``_intv_mode`` "off" | "delta_par" | "delta_orth" | "orth_gate" | … |
| * Decomposes refine delta into parallel / orthogonal vs base output |
| * See User Guide for configuration details. |
| |
| Refine-geometry logging (NO output change) |
| ------------------------------------------------ |
| * ``_log_refine_geom = True`` enables per-head geometry vectors in info dict. |
| |
| Info storage |
| ------------ |
| * ``info_level`` "basic" | "logits" | "full" |
| * ``info_cfg`` dict controlling which tensors to store, downsampling, CPU offload. |
| """ |
|
|
| |
|
|
def __init__(
    self,
    embed_dim: int,
    num_heads: int = 8,
    num_slots: int = 8,
    dropout: float = 0.1,

    # read / write softmax behaviour and state precision
    read_temperature: float = 1.0,
    write_temperature: float = 1.0,
    state_fp32: bool = True,
    slot_dropout: float = 0.0,
    normalize_k: bool = False,

    # RoPE on write keys
    use_rope_keys: bool = True,
    rope_base: float = 10000.0,

    # ALiBi bias on write logits
    use_alibi_write: bool = True,
    alibi_strength_init: float = 0.1,
    learn_alibi_strength: bool = True,
    min_strength: float = 0.0,

    # content-conditioned read term
    use_content_read: bool = True,
    content_read_init: float = -4.0,
    content_read_max_gamma: float = 3.0,

    # slot-space refinement
    use_slotspace_refine: bool = True,
    slotspace_dim: int = 32,
    slotspace_gate_init: float = -4.0,
    slotspace_dropout: float = 0.05,
    slotspace_signed_weights: bool = True,

    # RoPE inside the slot-space refinement
    use_rope_slotspace: bool = True,
    rope_base_slotspace: float = 100000.0,

    # time-axis chunk lengths for the streaming loops
    write_chunk_size: int = 128,
    slotspace_chunk_size: int = 128,
):
    """Create parameters, buffers and default intervention settings.

    Parameters
    ----------
    embed_dim : model width; must be divisible by ``num_heads``.
    num_heads : number of heads; ``head_dim = embed_dim // num_heads``.
    num_slots : number of learned slot keys per head.
    dropout : probability of the ``nn.Dropout`` stored as ``self.dropout``.
    read_temperature, write_temperature : stored as floats; ``forward``
        divides the read / write logits by them (floored at 1e-6).
    state_fp32 : if True, ``forward`` keeps its streaming state in float32
        whenever the input is not already float32.
    slot_dropout : in training mode, probability of zeroing a whole slot key.
    normalize_k : if True, ``forward`` L2-normalises the write keys.
    use_rope_keys, rope_base : build a ``RotaryEmbedding(head_dim)`` used on
        the write keys.
    use_alibi_write : register ALiBi slopes (zeros when disabled).
    alibi_strength_init, learn_alibi_strength, min_strength : the strength
        is either a fixed float or ``softplus(param) + min_strength`` with
        ``param`` initialised so the result equals ``alibi_strength_init``.
    use_content_read, content_read_init, content_read_max_gamma : raw
        (pre-softplus) init of the content-read gamma and its upper clamp.
    use_slotspace_refine, slotspace_dim, slotspace_gate_init,
    slotspace_dropout, slotspace_signed_weights : configuration of the
        slot-space refinement projections, its raw (pre-softplus) gate, its
        dropout, and whether slot weights use tanh (True) or softmax (False).
    use_rope_slotspace, rope_base_slotspace : RoPE on the slot-space q/k;
        only active when the refinement itself is enabled and requires an
        even ``slotspace_dim``.
    write_chunk_size, slotspace_chunk_size : chunk lengths used by the
        chunked loops in ``forward``.
    """
    super().__init__()
    assert embed_dim % num_heads == 0
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.num_slots = num_slots
    self.head_dim = embed_dim // num_heads

    self.dropout = nn.Dropout(dropout)

    self.read_temperature = float(read_temperature)
    self.write_temperature = float(write_temperature)
    self.state_fp32 = bool(state_fp32)
    self.slot_dropout = float(slot_dropout)
    self.normalize_k = bool(normalize_k)
    # Optional read-routing override (callable or tensor); None disables it.
    self.routing_override = None

    self.use_rope_keys = bool(use_rope_keys)
    self.use_alibi_write = bool(use_alibi_write)
    self.learn_alibi_strength = bool(learn_alibi_strength)
    self.min_strength = float(min_strength)

    self.use_content_read = bool(use_content_read)
    self.content_read_max_gamma = float(content_read_max_gamma)

    self.use_slotspace_refine = bool(use_slotspace_refine)
    self.slotspace_dim = int(slotspace_dim)
    self.slotspace_dropout = nn.Dropout(float(slotspace_dropout))
    self.slotspace_signed_weights = bool(slotspace_signed_weights)

    self.write_chunk_size = int(write_chunk_size)
    self.slotspace_chunk_size = int(slotspace_chunk_size)

    # Learned slot keys, one [num_slots, head_dim] table per head,
    # scaled by 1/sqrt(head_dim).
    self.slot_keys = nn.Parameter(
        torch.randn(num_heads, num_slots, self.head_dim) / math.sqrt(self.head_dim)
    )

    # Token projections (bias-free). Names are part of the checkpoint format.
    self.Wk_write = nn.Linear(embed_dim, embed_dim, bias=False)
    self.Wv_write = nn.Linear(embed_dim, embed_dim, bias=False)
    self.Wq_read = nn.Linear(embed_dim, embed_dim, bias=False)
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    # RoPE for write keys (None when disabled).
    self.rope = RotaryEmbedding(self.head_dim, base=rope_base) if self.use_rope_keys else None

    # ALiBi slopes: non-persistent buffer, zeros when ALiBi is disabled.
    if self.use_alibi_write:
        self.register_buffer("_alibi_slopes", alibi_slopes(num_heads), persistent=False)
    else:
        self.register_buffer("_alibi_slopes", torch.zeros(num_heads), persistent=False)

    if self.use_alibi_write and self.learn_alibi_strength:
        # Store the strength in inverse-softplus space; the clamp keeps the
        # argument strictly positive so the inverse is defined.
        init = torch.tensor(float(alibi_strength_init) - self.min_strength).clamp_min(1e-8)
        self._alibi_strength_param = nn.Parameter(_inv_softplus(init))
    else:
        self._alibi_strength_param = None
        # NOTE(review): indentation was lost in this copy of the file; this
        # assignment is assumed to belong to the `else` branch (it is only
        # read when the strength is not learnable) -- confirm.
        self.alibi_strength = float(alibi_strength_init)

    # Raw (pre-softplus) gamma for the content-conditioned read term.
    if self.use_content_read:
        self._content_read_gamma_raw = nn.Parameter(torch.tensor(float(content_read_init)))
    else:
        self._content_read_gamma_raw = None

    # Slot-space refinement: num_slots -> slotspace_dim -> num_slots.
    # Slot-space RoPE is forced off when the refinement itself is disabled.
    self.use_rope_slotspace = bool(use_rope_slotspace) and bool(self.use_slotspace_refine)
    if self.use_slotspace_refine:
        self.slot_in = nn.Linear(num_slots, self.slotspace_dim, bias=False)
        self.slot_q = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
        self.slot_k = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
        self.slot_v = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
        self.slot_out = nn.Linear(self.slotspace_dim, num_slots, bias=False)
        # Raw (pre-softplus) gate on the refinement delta.
        self._slotspace_gate_raw = nn.Parameter(torch.tensor(float(slotspace_gate_init)))
        if self.use_rope_slotspace:
            assert (self.slotspace_dim % 2) == 0, "use_rope_slotspace requires even slotspace_dim"
            self.rope_slotspace = RotaryEmbedding(self.slotspace_dim, base=float(rope_base_slotspace))
        else:
            self.rope_slotspace = None
    else:
        self.slot_in = None
        self.slot_q = self.slot_k = self.slot_v = None
        self.slot_out = None
        self._slotspace_gate_raw = None
        self.rope_slotspace = None

    # Refine-delta intervention settings; "off" leaves the delta untouched.
    self._intv_mode: str = "off"
    self._intv_beta: float = 1.0
    self._intv_score_kind: str = "orth_frac"
    self._intv_tau_kind: str = "pctl"
    self._intv_tau: float = 0.15
    self._intv_tau_pctl: float = 75.0
    self._intv_mask_mode: str = "soft"
    self._intv_soft_temp: float = 0.05
    self._intv_par_beta: float = 1.0
    self._intv_head_mask: Optional[torch.Tensor] = None
    self._intv_score_clip_pctl: float = 99.0

    # Per-head refine-geometry logging switch (off by default).
    self._log_refine_geom: bool = False
|
|
| |
|
|
| def _alibi_strength(self, dtype, device) -> torch.Tensor: |
| if not (self.use_alibi_write and self.learn_alibi_strength): |
| return torch.tensor(self.alibi_strength, dtype=dtype, device=device) |
| return (F.softplus(self._alibi_strength_param) + self.min_strength).to(dtype=dtype, device=device) |
|
|
| def _content_read_gamma(self, dtype, device) -> torch.Tensor: |
| if not self.use_content_read: |
| return torch.tensor(0.0, dtype=dtype, device=device) |
| g = F.softplus(self._content_read_gamma_raw) |
| if self.content_read_max_gamma is not None and self.content_read_max_gamma > 0: |
| g = g.clamp(max=self.content_read_max_gamma) |
| return g.to(dtype=dtype, device=device) |
|
|
| def _slotspace_gate(self, dtype, device) -> torch.Tensor: |
| if not self.use_slotspace_refine: |
| return torch.tensor(0.0, dtype=dtype, device=device) |
| return F.softplus(self._slotspace_gate_raw).to(dtype=dtype, device=device) |
|
|
| |
|
|
| @staticmethod |
| def _safe_exp_sub_max(s: torch.Tensor, m: torch.Tensor) -> torch.Tensor: |
| diff = s - m |
| diff = diff.masked_fill(~torch.isfinite(m), float("-inf")) |
| return torch.exp(diff) |
|
|
| |
|
|
| def _resolve_slot_mask( |
| self, |
| slot_mask: Optional[torch.Tensor], |
| *, |
| B: int, H: int, L: int, K: int, |
| device, dtype, scope: str, |
| ) -> Optional[torch.Tensor]: |
| """Expand [K] mask → [B,H,L,K]. Falls back to self.slot_mask attr.""" |
| if slot_mask is None: |
| slot_mask = getattr(self, "slot_mask", None) |
| if slot_mask is None: |
| return None |
| sm = slot_mask.to(device=device, dtype=dtype) |
| if sm.ndim != 1 or sm.numel() != K: |
| raise ValueError(f"slot_mask must be shape [K]={K}, got {tuple(sm.shape)}") |
| sm = sm.view(1, 1, 1, K) |
| if scope == "all": |
| return sm.expand(B, H, L, K) |
| if scope == "last_pos_only": |
| out = torch.ones((B, H, L, K), device=device, dtype=dtype) |
| out[:, :, -1:, :] = sm.expand(B, H, 1, K) |
| return out |
| raise ValueError(f"Unknown slot_mask_scope={scope!r}") |
|
|
| @staticmethod |
| def _apply_hard_mask_and_renorm(w: torch.Tensor, keep: torch.Tensor) -> torch.Tensor: |
| w = w * keep.to(w.dtype) |
| return w / w.sum(dim=-1, keepdim=True).clamp_min(1e-8) |
|
|
| |
|
|
| @staticmethod |
| def default_info_cfg() -> Dict: |
| """Return default info_cfg dict. Copy and modify before passing to forward().""" |
| return dict( |
| store_read_weights=True, |
| store_read_logits=True, |
| store_write_logits=True, |
| store_slot_state_norm=True, |
| store_out1=False, |
| store_delta=False, |
| store_slot_w=False, |
| detach_to_cpu=False, |
| time_stride=1, |
| batch_stride=1, |
| ) |
|
|
| @staticmethod |
| def _store_tensor( |
| t: Optional[torch.Tensor], *, cfg: Dict, kind: str, |
| ) -> Optional[torch.Tensor]: |
| """Downsample + detach (+ optional CPU offload).""" |
| if t is None: |
| return None |
| bstride = int(cfg.get("batch_stride", 1)) |
| tstride = int(cfg.get("time_stride", 1)) |
| to_cpu = bool(cfg.get("detach_to_cpu", False)) |
| x = t |
| if x.dim() >= 1 and bstride > 1: |
| x = x[::bstride] |
| if x.dim() == 4 and tstride > 1: |
| if kind == "bhtk": |
| x = x[:, :, ::tstride, :] |
| elif kind == "bhkt": |
| x = x[:, :, :, ::tstride] |
| x = x.detach() |
| if to_cpu: |
| x = x.to("cpu", non_blocking=True) |
| return x |
|
|
| |
|
|
| def _compute_read_weights( |
| self, |
| *, |
| read_logits: torch.Tensor, |
| read_logits_key: torch.Tensor, |
| read_logits_content: Optional[torch.Tensor], |
| routing_mode: str, |
| routing_topk: int, |
| read_weights_override: Optional[torch.Tensor], |
| routing_noise: Optional[str], |
| routing_noise_scale: float, |
| rtemp: float, |
| sm: Optional[torch.Tensor], |
| slot_mask_where: str, |
| B: int, H: int, L: int, K: int, |
| T_total: int, |
| t0: int, t1: int, |
| q_read_c: torch.Tensor, |
| slot_keys: torch.Tensor, |
| slot_state_t: torch.Tensor, |
| valid: Optional[torch.Tensor], |
| state_dtype, |
| ) -> torch.Tensor: |
| """Compute read weights for one write-chunk. Handles noise, overrides, masks.""" |
| |
| if routing_noise is not None: |
| if routing_noise == "gumbel": |
| u = torch.rand_like(read_logits) |
| g = -torch.log(-torch.log(u.clamp_min(1e-8)).clamp_min(1e-8)) |
| read_logits = read_logits + routing_noise_scale * g |
| elif routing_noise == "gaussian": |
| read_logits = read_logits + routing_noise_scale * torch.randn_like(read_logits) |
| else: |
| raise ValueError(f"Unknown routing_noise={routing_noise}") |
|
|
| |
| if self.routing_override is not None: |
| if callable(self.routing_override): |
| ctx = dict( |
| t0=t0, t1=t1, B=B, H=H, T=T_total, K=K, d=self.head_dim, |
| rtemp=rtemp, state_dtype=state_dtype, |
| q_read_c=q_read_c, slot_keys=slot_keys, |
| slot_state_t=slot_state_t, valid=valid, |
| ) |
| read_w = self.routing_override( |
| t0, t1, read_logits, read_logits_key, read_logits_content, ctx, |
| ) |
| else: |
| read_w = self.routing_override[:, :, t0:t1, :].to(read_logits.dtype) |
| read_w = torch.nan_to_num(read_w, nan=0.0, posinf=0.0, neginf=0.0) |
| read_w = read_w.clamp_min(0.0) |
| read_w = read_w / read_w.sum(dim=-1, keepdim=True).clamp_min(1e-8) |
|
|
| else: |
| if routing_mode == "softmax": |
| read_w = torch.softmax(read_logits / rtemp, dim=-1) |
| elif routing_mode == "top1": |
| top = read_logits.argmax(dim=-1) |
| read_w = F.one_hot(top, num_classes=K).to(read_logits.dtype) |
| elif routing_mode == "topk": |
| kk = max(1, min(K, int(routing_topk))) |
| vals, idx = torch.topk(read_logits, k=kk, dim=-1) |
| masked = torch.full_like(read_logits, float("-inf")) |
| masked.scatter_(-1, idx, vals) |
| read_w = torch.softmax(masked / rtemp, dim=-1) |
| elif routing_mode == "external": |
| if read_weights_override is None: |
| raise ValueError("routing_mode='external' requires read_weights_override") |
| if read_weights_override.shape[-2] == T_total: |
| read_w = read_weights_override[:, :, t0:t1, :] |
| else: |
| read_w = read_weights_override |
| read_w = read_w / read_w.sum(dim=-1, keepdim=True).clamp_min(1e-8) |
| else: |
| raise ValueError(f"Unknown routing_mode={routing_mode}") |
|
|
| |
| if slot_mask_where == "read" and sm is not None: |
| read_w = self._apply_hard_mask_and_renorm(read_w, (sm > 0.0)) |
|
|
| return read_w |
|
|
| |
|
|
| def _apply_refine_intervention( |
| self, |
| out1: torch.Tensor, |
| delta: torch.Tensor, |
| slot_w: Optional[torch.Tensor], |
| ): |
| """Decompose refine delta into par/orth vs base output, optionally gate.""" |
| eps = 1e-8 |
| B, H, L, d = out1.shape |
|
|
| |
| hm = getattr(self, "_intv_head_mask", None) |
| if hm is not None: |
| hm = hm.to(device=out1.device).view(1, H, 1, 1).to(dtype=out1.dtype) |
|
|
| out1_norm2 = (out1 * out1).sum(dim=-1, keepdim=True).clamp_min(eps) |
| alpha = (delta * out1).sum(dim=-1, keepdim=True) / out1_norm2 |
| delta_par = alpha * out1 |
| delta_orth = delta - delta_par |
|
|
| logs = None |
|
|
| |
| if getattr(self, "_log_refine_geom", False): |
| out1n = out1.norm(dim=-1).clamp_min(eps) |
| dn = delta.norm(dim=-1).clamp_min(eps) |
| dparn = delta_par.norm(dim=-1) |
| dorthn = delta_orth.norm(dim=-1) |
| a = alpha.squeeze(-1) |
| logs = dict( |
| geom_alpha_mean=a.mean(dim=(0, 2)), |
| geom_alpha_abs=a.abs().mean(dim=(0, 2)), |
| geom_sign_pos=(a > 0).float().mean(dim=(0, 2)), |
| geom_orth_frac=(dorthn / dn).mean(dim=(0, 2)), |
| geom_d_ratio=(dn / out1n).mean(dim=(0, 2)), |
| geom_dpar_ratio=(dparn / dn).mean(dim=(0, 2)), |
| ) |
|
|
| mode = getattr(self, "_intv_mode", "off") |
| if mode is None or mode == "off": |
| return delta, logs |
|
|
| |
| if mode == "delta_par": |
| delta_mod = delta_par |
| logs = logs or {} |
| logs["alpha"] = alpha.squeeze(-1) |
|
|
| elif mode == "delta_orth": |
| delta_mod = delta_orth |
| logs = logs or {} |
| logs["alpha"] = alpha.squeeze(-1) |
|
|
| elif mode == "delta_par_plus_orth": |
| delta_mod = delta_par + delta_orth |
| logs = logs or {} |
| logs["alpha"] = alpha.squeeze(-1) |
|
|
| elif mode == "orth_gate": |
| beta = float(getattr(self, "_intv_beta", 1.0)) |
| sk = getattr(self, "_intv_score_kind", "orth_frac") |
| out1n = out1.norm(dim=-1).clamp_min(eps) |
| dorthn = delta_orth.norm(dim=-1) |
| dn = delta.norm(dim=-1).clamp_min(eps) |
|
|
| if sk == "orth_ratio": |
| score = dorthn / out1n |
| elif sk == "orth_frac": |
| score = dorthn / dn |
| elif sk == "alpha_abs": |
| score = alpha.abs().squeeze(-1) |
| elif sk == "slot_peaked": |
| if slot_w is None: |
| raise ValueError("score_kind='slot_peaked' requires slot_w") |
| p = torch.softmax(slot_w.float(), dim=-1).clamp_min(1e-8) |
| Hrw = -(p * p.log()).sum(dim=-1) |
| K = p.shape[-1] |
| score = (1.0 - Hrw / max(1e-8, math.log(K))).to(dtype=out1.dtype) |
| else: |
| raise ValueError(f"Unknown _intv_score_kind={sk}") |
|
|
| |
| clip_p = getattr(self, "_intv_score_clip_pctl", None) |
| if clip_p is not None: |
| clip_p = float(clip_p) |
| if 0.0 < clip_p < 100.0: |
| smax = torch.quantile(score.detach().flatten(), clip_p / 100.0).to(score.dtype) |
| score = torch.clamp(score, max=smax) |
|
|
| |
| tk = getattr(self, "_intv_tau_kind", "pctl") |
| if tk == "abs": |
| tau = torch.tensor(float(getattr(self, "_intv_tau", 0.15)), |
| device=score.device, dtype=score.dtype) |
| elif tk == "pctl": |
| tau = torch.quantile( |
| score.detach().flatten(), |
| float(getattr(self, "_intv_tau_pctl", 75.0)) / 100.0, |
| ).to(score.dtype) |
| else: |
| raise ValueError(f"Unknown _intv_tau_kind={tk}") |
|
|
| |
| mm = getattr(self, "_intv_mask_mode", "soft") |
| if mm == "hard": |
| mask = (score > tau).to(out1.dtype) |
| elif mm == "soft": |
| temp = max(1e-6, float(getattr(self, "_intv_soft_temp", 0.05))) |
| mask = torch.sigmoid((score - tau) / temp).to(out1.dtype) |
| else: |
| raise ValueError(f"Unknown _intv_mask_mode={mm}") |
|
|
| par_beta = float(getattr(self, "_intv_par_beta", 1.0)) |
| delta_mod = par_beta * delta_par + beta * mask.unsqueeze(-1) * delta_orth |
|
|
| logs = logs or {} |
| logs.update(dict( |
| score=score, tau=tau, mask=mask, |
| alpha=alpha.squeeze(-1), |
| out1_norm=out1n, |
| dpar_norm=delta_par.norm(dim=-1), |
| dorth_norm=dorthn, |
| )) |
| else: |
| raise ValueError(f"Unknown _intv_mode={mode}") |
|
|
| |
| if hm is not None: |
| delta_mod = hm * delta_mod + (1.0 - hm) * delta |
| logs = logs or {} |
| logs["head_mask"] = hm.squeeze(0).squeeze(-1).squeeze(-1).detach() |
|
|
| return delta_mod, logs |
|
|
| |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| return_info: bool = False, |
| |
| |
| routing_mode: str = "softmax", |
| routing_topk: int = 2, |
| read_weights_override: Optional[torch.Tensor] = None, |
| routing_noise: Optional[str] = None, |
| routing_noise_scale: float = 1.0, |
| |
| |
| slot_mask: Optional[torch.Tensor] = None, |
| slot_mask_where: str = "read", |
| slot_mask_scope: str = "all", |
| |
| |
| info_level: str = "full", |
| info_cfg: Optional[Dict] = None, |
| ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]: |
| """ |
| Parameters |
| ---------- |
| x : [B, T, C] |
| attention_mask : [B, T] optional padding mask (1=valid, 0=pad) |
| return_info : if True, return diagnostics dict as second element |
| routing_mode : "softmax" | "top1" | "topk" | "external" |
| routing_topk : k for topk mode |
| read_weights_override : [B,H,T,K] or [B,H,L,K] for external routing |
| routing_noise : None | "gumbel" | "gaussian" |
| routing_noise_scale : scale for routing noise |
| slot_mask : [K] where 1=keep, 0=mask |
| slot_mask_where : "read" | "content_read_only" | "slotspace_only" |
| slot_mask_scope : "all" | "last_pos_only" |
| info_level : "basic" | "logits" | "full" |
| info_cfg : dict (see default_info_cfg()) |
| |
| Returns |
| ------- |
| (output, info) where info is None if return_info=False. |
| """ |
|
|
| B, T, C = x.shape |
| H, K, d = self.num_heads, self.num_slots, self.head_dim |
|
|
| |
| if info_cfg is None: |
| info_cfg = self.default_info_cfg() |
| store_read_weights = bool(info_cfg.get("store_read_weights", True)) |
| store_read_logits = bool(info_cfg.get("store_read_logits", True)) and info_level in ("logits", "full") |
| store_write_logits = bool(info_cfg.get("store_write_logits", True)) and info_level == "full" |
| store_slot_norm = bool(info_cfg.get("store_slot_state_norm", True)) and info_level == "full" |
| store_out1 = bool(info_cfg.get("store_out1", False)) and return_info |
| store_delta = bool(info_cfg.get("store_delta", False)) and return_info |
| store_slot_w = bool(info_cfg.get("store_slot_w", False)) and return_info |
|
|
| |
| k_write = self.Wk_write(x).view(B, T, H, d).transpose(1, 2) |
| v_write = self.Wv_write(x).view(B, T, H, d).transpose(1, 2) |
| q_read = self.Wq_read(x).view(B, T, H, d).transpose(1, 2) |
|
|
| if self.normalize_k: |
| k_write = F.normalize(k_write, dim=-1, eps=1e-8) |
|
|
| if self.use_rope_keys: |
| cos, sin = self.rope.get_cos_sin(T, device=x.device, dtype=k_write.dtype) |
| k_write = apply_rope(k_write, cos, sin) |
|
|
| |
| slot_keys = self.slot_keys |
| if self.training and self.slot_dropout > 0.0: |
| drop = (torch.rand((H, K), device=x.device) < self.slot_dropout) |
| slot_keys = slot_keys * (~drop).to(slot_keys.dtype).unsqueeze(-1) |
|
|
| |
| write_logits_raw = torch.einsum("hkd,bhtd->bhkt", slot_keys, k_write) / math.sqrt(d) |
| state_dtype = torch.float32 if (self.state_fp32 and x.dtype != torch.float32) else x.dtype |
| write_logits = write_logits_raw.to(state_dtype) / max(1e-6, self.write_temperature) |
|
|
| |
| alibi_bias_applied = None |
| if self.use_alibi_write: |
| strength = self._alibi_strength(dtype=state_dtype, device=x.device) |
| slopes = self._alibi_slopes.to(device=x.device, dtype=state_dtype) * strength |
| pos_i = torch.arange(T, device=x.device, dtype=state_dtype) |
| alibi_bias = slopes.view(1, H, 1, 1) * pos_i.view(1, 1, 1, T) |
| write_logits = write_logits + alibi_bias |
| alibi_bias_applied = alibi_bias |
|
|
| |
| if attention_mask is not None: |
| valid = attention_mask.to(dtype=torch.bool) |
| write_logits = write_logits.masked_fill(~valid.view(B, 1, 1, T), float("-inf")) |
| else: |
| valid = None |
|
|
| |
| |
| |
| content_read_gamma = self._content_read_gamma(dtype=q_read.dtype, device=x.device) |
| rtemp = max(1e-6, self.read_temperature) |
|
|
| out_h = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) |
|
|
| out1_full = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) if store_out1 else None |
| delta_full = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) if store_delta else None |
| slot_w_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) if store_slot_w else None |
|
|
| need_rw = bool(self.use_slotspace_refine) or (return_info and store_read_weights) |
| read_weights = torch.empty((B, H, T, K), device=x.device, dtype=q_read.dtype) if need_rw else None |
|
|
| slot_state_norm_t = ( |
| torch.empty((B, H, T, K), device=x.device, dtype=torch.float32) |
| if (return_info and store_slot_norm) else None |
| ) |
|
|
| if return_info and store_read_logits: |
| read_logits_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) |
| read_logits_key_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) |
| read_logits_content_full = ( |
| torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) if self.use_content_read else None |
| ) |
| else: |
| read_logits_full = read_logits_key_full = read_logits_content_full = None |
|
|
| |
| denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype) |
| numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype) |
| m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype) |
|
|
| WRITE_CHUNK = self.write_chunk_size |
|
|
| for t0 in range(0, T, WRITE_CHUNK): |
| t1 = min(T, t0 + WRITE_CHUNK) |
| L = t1 - t0 |
|
|
| wlog_c = write_logits[:, :, :, t0:t1] |
| m_c, _ = torch.cummax(wlog_c, dim=-1) |
| m_new = torch.maximum(m_state.unsqueeze(-1), m_c) |
|
|
| scale = torch.exp(m_state.unsqueeze(-1) - m_new) |
| denom_c = denom_state.unsqueeze(-1) * scale |
| numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1) |
|
|
| w_new = self._safe_exp_sub_max(wlog_c, m_new) |
| denom_c = denom_c + torch.cumsum(w_new, dim=-1) |
|
|
| v_c = v_write[:, :, t0:t1, :].to(state_dtype) |
| add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2) |
| numer_c = numer_c + add |
|
|
| slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1) |
| slot_state_t = slot_state_c.permute(0, 1, 3, 2, 4).contiguous() |
|
|
| |
| q_read_c = q_read[:, :, t0:t1, :] |
| read_logits_key = torch.einsum("bhld,hkd->bhlk", q_read_c, slot_keys) / math.sqrt(d) |
|
|
| read_logits_content = None |
| if self.use_content_read: |
| read_logits_content = torch.einsum( |
| "bhld,bhlkd->bhlk", q_read_c, slot_state_t.to(q_read_c.dtype), |
| ) / math.sqrt(d) |
|
|
| |
| sm = self._resolve_slot_mask( |
| slot_mask, B=B, H=H, L=L, K=K, |
| device=x.device, dtype=read_logits_key.dtype, scope=slot_mask_scope, |
| ) |
|
|
| |
| if slot_mask_where == "read": |
| if sm is not None: |
| read_logits_key = read_logits_key.masked_fill(sm <= 0.0, float("-inf")) |
| if self.use_content_read and read_logits_content is not None: |
| read_logits_content = read_logits_content.masked_fill(sm <= 0.0, float("-inf")) |
| elif slot_mask_where == "content_read_only": |
| if sm is not None and self.use_content_read and read_logits_content is not None: |
| read_logits_content = read_logits_content.masked_fill(sm <= 0.0, 0.0) |
| elif slot_mask_where == "slotspace_only": |
| pass |
| else: |
| raise ValueError(f"Unknown slot_mask_where={slot_mask_where!r}") |
|
|
| |
| rl = read_logits_key |
| if self.use_content_read and read_logits_content is not None: |
| rl = rl + content_read_gamma.to(rl.dtype) * read_logits_content |
|
|
| if return_info and store_read_logits: |
| read_logits_full[:, :, t0:t1, :] = rl.to(state_dtype) |
| read_logits_key_full[:, :, t0:t1, :] = read_logits_key.to(state_dtype) |
| if self.use_content_read and read_logits_content_full is not None: |
| read_logits_content_full[:, :, t0:t1, :] = read_logits_content.to(state_dtype) |
|
|
| |
| read_w_c = self._compute_read_weights( |
| read_logits=rl, read_logits_key=read_logits_key, |
| read_logits_content=read_logits_content, |
| routing_mode=routing_mode, routing_topk=routing_topk, |
| read_weights_override=read_weights_override, |
| routing_noise=routing_noise, routing_noise_scale=routing_noise_scale, |
| rtemp=rtemp, sm=sm, slot_mask_where=slot_mask_where, |
| B=B, H=H, L=L, K=K, T_total=T, t0=t0, t1=t1, |
| q_read_c=q_read_c, slot_keys=slot_keys, |
| slot_state_t=slot_state_t, valid=valid, |
| state_dtype=state_dtype, |
| ) |
|
|
| if read_weights is not None: |
| read_weights[:, :, t0:t1, :] = read_w_c |
|
|
| |
| out_h[:, :, t0:t1, :] = torch.einsum( |
| "bhlk,bhlkd->bhld", read_w_c.to(state_dtype), slot_state_t.to(state_dtype), |
| ) |
|
|
| if out1_full is not None: |
| out1_full[:, :, t0:t1, :] = out_h[:, :, t0:t1, :] |
|
|
| if slot_state_norm_t is not None: |
| slot_state_norm_t[:, :, t0:t1, :] = slot_state_t.to(torch.float32).norm(dim=-1) |
|
|
| m_state = m_new[:, :, :, -1] |
| denom_state = denom_c[:, :, :, -1] |
| numer_state = numer_c[:, :, :, -1, :] |
|
|
| |
| |
| |
| slotspace_delta_norm_mean = None |
| intv_logs_acc: Optional[Dict] = None |
| intv_logs_count = 0 |
|
|
| if self.use_slotspace_refine: |
| slotspace_dtype = state_dtype |
| M = self.slotspace_dim |
| assert read_weights is not None |
|
|
| u = self.slot_in(read_weights.to(slotspace_dtype)) |
| q_s = self.slot_q(u) |
| k_s = self.slot_k(u) |
| v_s = self.slot_v(u) |
|
|
| if self.use_rope_slotspace: |
| cos_s, sin_s = self.rope_slotspace.get_cos_sin(T, device=x.device, dtype=q_s.dtype) |
| q_s = apply_rope(q_s, cos_s, sin_s) |
| k_s = apply_rope(k_s, cos_s, sin_s) |
|
|
| qf = phi(q_s) |
| kf = phi(k_s) |
|
|
| if valid is not None: |
| vmask = valid.view(B, 1, T, 1).to(slotspace_dtype) |
| qf = qf * vmask |
| kf = kf * vmask |
| v_s = v_s * vmask |
|
|
| u2 = torch.empty((B, H, T, M), device=x.device, dtype=slotspace_dtype) |
| S_state = torch.zeros((B, H, M, M), device=x.device, dtype=slotspace_dtype) |
| Z_state = torch.zeros((B, H, M), device=x.device, dtype=slotspace_dtype) |
|
|
| SS_CHUNK = self.slotspace_chunk_size |
| for t0 in range(0, T, SS_CHUNK): |
| t1 = min(T, t0 + SS_CHUNK) |
| qf_c = qf[:, :, t0:t1, :] |
| kf_c = kf[:, :, t0:t1, :] |
| v_c = v_s[:, :, t0:t1, :] |
|
|
| kv = torch.einsum("bhlm,bhln->bhlmn", kf_c, v_c) |
| S_c = torch.cumsum(kv, dim=2) + S_state.unsqueeze(2) |
| Z_c = (torch.cumsum(kf_c, dim=2) + Z_state.unsqueeze(2)).clamp_min(1e-8) |
|
|
| num = torch.einsum("bhlm,bhlmn->bhln", qf_c, S_c) |
| den = torch.einsum("bhlm,bhlm->bhl", qf_c, Z_c).unsqueeze(-1).clamp_min(1e-8) |
| u2[:, :, t0:t1, :] = num / den |
|
|
| S_state = S_c[:, :, -1, :, :] |
| Z_state = Z_c[:, :, -1, :] |
|
|
| u2 = self.slotspace_dropout(u2) |
| slot_w = self.slot_out(u2) |
|
|
| if slot_w_full is not None: |
| slot_w_full[:] = slot_w.to(state_dtype) |
|
|
| if self.slotspace_signed_weights: |
| slot_w_eff = torch.tanh(slot_w) |
| else: |
| slot_w_eff = torch.softmax(slot_w, dim=-1) |
|
|
| |
| if slot_mask_where == "slotspace_only": |
| sm_full = self._resolve_slot_mask( |
| slot_mask, B=B, H=H, L=T, K=K, |
| device=x.device, dtype=slot_w_eff.dtype, scope=slot_mask_scope, |
| ) |
| if sm_full is not None: |
| slot_w_eff = slot_w_eff * (sm_full > 0.0).to(slot_w_eff.dtype) |
| if not self.slotspace_signed_weights: |
| slot_w_eff = slot_w_eff / slot_w_eff.sum(dim=-1, keepdim=True).clamp_min(1e-8) |
|
|
| gate = self._slotspace_gate(dtype=state_dtype, device=x.device).to(state_dtype) |
|
|
| |
| denom_state2 = torch.zeros((B, H, K), device=x.device, dtype=state_dtype) |
| numer_state2 = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype) |
| m_state2 = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype) |
|
|
| delta_norm_sum = torch.zeros((), device=x.device, dtype=torch.float32) |
| delta_norm_count = 0 |
|
|
| for t0 in range(0, T, WRITE_CHUNK): |
| t1 = min(T, t0 + WRITE_CHUNK) |
| Lc = t1 - t0 |
|
|
| wlog_c = write_logits[:, :, :, t0:t1] |
| m_c, _ = torch.cummax(wlog_c, dim=-1) |
| m_new = torch.maximum(m_state2.unsqueeze(-1), m_c) |
|
|
| scale = torch.exp(m_state2.unsqueeze(-1) - m_new) |
| denom_c = denom_state2.unsqueeze(-1) * scale |
| numer_c = numer_state2.unsqueeze(-2) * scale.unsqueeze(-1) |
|
|
| w_new = self._safe_exp_sub_max(wlog_c, m_new) |
| denom_c = denom_c + torch.cumsum(w_new, dim=-1) |
|
|
| v_c = v_write[:, :, t0:t1, :].to(state_dtype) |
| add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2) |
| numer_c = numer_c + add |
|
|
| slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1) |
| slot_state_t2 = slot_state_c.permute(0, 1, 3, 2, 4).contiguous() |
|
|
| slot_w_c = slot_w_eff[:, :, t0:t1, :].to(state_dtype) |
| delta_c = torch.einsum("bhlk,bhlkd->bhld", slot_w_c, slot_state_t2.to(state_dtype)) |
|
|
| delta = gate * delta_c |
|
|
| if delta_full is not None: |
| delta_full[:, :, t0:t1, :] = delta |
|
|
| |
| slot_w_for_score = slot_w[:, :, t0:t1, :] if store_slot_w else None |
| delta_mod, logs = self._apply_refine_intervention( |
| out1=out_h[:, :, t0:t1, :], delta=delta, slot_w=slot_w_for_score, |
| ) |
|
|
| out_h[:, :, t0:t1, :] = out_h[:, :, t0:t1, :] + delta_mod |
|
|
| |
| if logs is not None and return_info: |
| if intv_logs_acc is None: |
| intv_logs_acc = {} |
| for klog, v in logs.items(): |
| if torch.is_tensor(v): |
| vv = v.detach().to(torch.float32) |
| intv_logs_acc[klog] = vv if vv.ndim == 1 else vv.mean() |
| intv_logs_count = 1 |
| else: |
| for klog, v in logs.items(): |
| if torch.is_tensor(v) and klog in intv_logs_acc: |
| vv = v.detach().to(torch.float32) |
| intv_logs_acc[klog] = intv_logs_acc[klog] + (vv if vv.ndim == 1 else vv.mean()) |
| intv_logs_count += 1 |
|
|
| delta_norm_sum = delta_norm_sum + delta.detach().to(torch.float32).norm(dim=-1).sum() |
| delta_norm_count += B * H * Lc |
|
|
| m_state2 = m_new[:, :, :, -1] |
| denom_state2 = denom_c[:, :, :, -1] |
| numer_state2 = numer_c[:, :, :, -1, :] |
|
|
| slotspace_delta_norm_mean = (delta_norm_sum / max(1, delta_norm_count)).detach().cpu() |
|
|
| |
| |
| |
| out = out_h.transpose(1, 2).contiguous().view(B, T, C) |
| out = self.out_proj(out) |
| out = self.dropout(out) |
|
|
| |
| info = None |
| if return_info: |
| info = { |
| "content_read_gamma": content_read_gamma.detach().to(torch.float32).cpu(), |
| "routing_mode": routing_mode, |
| "slot_mask_where": slot_mask_where, |
| "slot_mask_scope": slot_mask_scope, |
| "intv_mode": getattr(self, "_intv_mode", "off"), |
| } |
|
|
| if alibi_bias_applied is not None and info_level == "full": |
| info["alibi_bias_applied"] = self._store_tensor(alibi_bias_applied.to(torch.float32), cfg=info_cfg, kind="other") |
|
|
| if self.use_alibi_write and self.learn_alibi_strength: |
| info["alibi_strength"] = self._alibi_strength(dtype=torch.float32, device=x.device).detach().cpu() |
|
|
| if self.use_slotspace_refine: |
| info["slotspace_gate"] = self._slotspace_gate(dtype=torch.float32, device=x.device).detach().cpu() |
| info["use_rope_slotspace"] = torch.tensor(bool(self.use_rope_slotspace)) |
| if slotspace_delta_norm_mean is not None: |
| info["slotspace_delta_norm"] = slotspace_delta_norm_mean |
|
|
| |
| if store_read_weights and read_weights is not None: |
| info["read_weights"] = self._store_tensor(read_weights, cfg=info_cfg, kind="bhtk") |
| else: |
| info["read_weights"] = None |
|
|
| |
| if store_slot_norm and slot_state_norm_t is not None: |
| s = slot_state_norm_t.permute(0, 1, 3, 2).contiguous() |
| info["slot_state_norm"] = self._store_tensor(s, cfg=info_cfg, kind="bhkt") |
| else: |
| info["slot_state_norm"] = None |
|
|
| |
| if store_read_logits and read_logits_full is not None: |
| info["read_logits"] = self._store_tensor(read_logits_full.to(torch.float32), cfg=info_cfg, kind="bhtk") |
| info["read_logits_key"] = self._store_tensor(read_logits_key_full.to(torch.float32), cfg=info_cfg, kind="bhtk") |
| info["read_logits_content"] = ( |
| self._store_tensor(read_logits_content_full.to(torch.float32), cfg=info_cfg, kind="bhtk") |
| if read_logits_content_full is not None else None |
| ) |
| else: |
| info["read_logits"] = info["read_logits_key"] = info["read_logits_content"] = None |
|
|
| |
| if store_write_logits and info_level == "full": |
| info["write_logits_raw"] = self._store_tensor(write_logits_raw, cfg=info_cfg, kind="bhkt") |
| info["write_logits"] = self._store_tensor(write_logits.to(torch.float32), cfg=info_cfg, kind="bhkt") |
| else: |
| info["write_logits_raw"] = info["write_logits"] = None |
|
|
| |
| info["out1"] = self._store_tensor(out1_full.to(torch.float32), cfg=info_cfg, kind="other") if out1_full is not None else None |
| info["delta"] = self._store_tensor(delta_full.to(torch.float32), cfg=info_cfg, kind="other") if delta_full is not None else None |
| info["slot_w"] = self._store_tensor(slot_w_full.to(torch.float32), cfg=info_cfg, kind="bhtk") if slot_w_full is not None else None |
|
|
| |
| if intv_logs_acc is not None and intv_logs_count > 0: |
| for klog, v in intv_logs_acc.items(): |
| info[klog] = (v / float(intv_logs_count)).detach().cpu() |
|
|
| |
| for alias_from, alias_to in [ |
| ("score", "intv_score_mean"), ("mask", "intv_mask_mean"), |
| ("tau", "intv_tau"), ("alpha", "intv_alpha_mean"), |
| ("out1_norm", "intv_out1_norm_mean"), |
| ("dpar_norm", "intv_dpar_norm_mean"), |
| ("dorth_norm", "intv_dorth_norm_mean"), |
| ]: |
| if alias_from in intv_logs_acc: |
| val = info.get(alias_from) |
| if torch.is_tensor(val) and val.ndim != 1: |
| info[alias_to] = val |
|
|
| return out, info |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
@dataclass
class ASMTrainConfig:
    """Flat bundle of hyperparameters for an ASM training/evaluation run.

    Only the model-architecture fields (``vocab_size`` .. ``slotspace_chunk_size``,
    plus ``max_seq_len``) are read by ``build_model_from_cfg``; the data,
    optimisation, evaluation and path fields are not used anywhere in this
    module and are presumably consumed by the training script.

    Field order is part of the positional-constructor interface; do not
    reorder.
    """

    # ---- Dataset / tokenizer identifiers ----
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "gpt2"

    # ---- Sequence windows and RNG seed ----
    # max_seq_len is also forwarded to the model (size of the optional
    # absolute-position table).
    max_seq_len: int = 256
    stride_frac_val: float = 0.50
    seed: int = 1337

    # ---- Batching and sample budgets ----
    micro_batch_size: int = 2
    grad_accum_steps: int = 8
    train_samples_target: int = 100_000_000
    val_samples_target: int = 25_000

    # ---- Optimiser and schedule ----
    # NOTE(review): batch_size coexists with micro_batch_size * grad_accum_steps;
    # which one the trainer honours is not visible from this file.
    batch_size: int = 64
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)
    grad_clip: float = 1.0
    warmup_steps: int = 1_000
    total_steps: int = 75_000
    eval_interval: int = 1_000
    log_interval: int = 100

    # ---- Model size ----
    vocab_size: int = 50257
    embed_dim: int = 384
    num_layers: int = 23
    num_heads: int = 8
    num_slots: int = 32
    mlp_ratio: float = 4.0
    dropout: float = 0.1
    tie_weights: bool = True

    # ---- ASA core options ----
    read_temperature: float = 1.0
    write_temperature: float = 1.0
    slot_dropout: float = 0.05
    state_fp32: bool = True
    normalize_k: bool = False

    # ---- Positional handling ----
    use_abs_pos: bool = False
    use_rope_keys: bool = True
    rope_base: float = 10000.0
    use_alibi_write: bool = True
    alibi_strength_init: float = 0.1
    learn_alibi_strength: bool = True
    min_strength: float = 0.0

    # ---- Content-based read ----
    use_content_read: bool = True
    content_read_init: float = -4.0
    content_read_max_gamma: float = 3.0

    # ---- Slot-space refinement ----
    # These defaults differ from the ASMBlock / ASMLanguageModel constructor
    # defaults (slotspace_dim=32, slotspace_gate_init=-10.0, dropout 0.0).
    use_slotspace_refine: bool = True
    slotspace_dim: int = 64
    slotspace_gate_init: float = -4.0
    slotspace_dropout: float = 0.05
    slotspace_signed_weights: bool = True

    # ---- RoPE inside slot space ----
    use_rope_slotspace: bool = True
    rope_base_slotspace: float = 100000.0

    # ---- Chunk sizes / compilation ----
    write_chunk_size: int = 128
    slotspace_chunk_size: int = 128
    enable_compiled: bool = False

    # ---- Evaluation / analytics ----
    eval_max_batches: int = 150
    analytics_last_k: int = 32

    # ---- Output and cache locations ----
    # NOTE(review): the cache filename mentions 1024 while max_seq_len
    # defaults to 256 -- confirm the two are meant to be independent.
    output_dir: str = "./drive/MyDrive/asm_outputs"
    tag: str = "asm_wikitext"
    cache_dir: str = "./drive/MyDrive/asm_caches"
    val_windows_cache: str = "./drive/MyDrive/asm_val_cache_windows_1024.pkl"
|
|
|
|
| |
| |
| |
class ASMBlock(nn.Module):
    """Pre-norm residual block: an ASA sub-layer followed by a bias-free MLP.

    Computes ``x = x + asa(norm1(x))`` and then ``x = x + mlp(norm2(x))``.
    Every constructor argument other than ``mlp_ratio`` is forwarded to
    ``AddressedStateAttention`` unchanged (``dropout`` is additionally used
    for the two MLP dropout layers).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_slots: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,
        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
    ):
        super().__init__()

        # Sub-modules are registered in the order norm1, asa, norm2, mlp so
        # that state_dict key order and RNG consumption stay stable.
        self.norm1 = nn.LayerNorm(embed_dim)

        asa_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            slotspace_chunk_size=slotspace_chunk_size,
        )
        self.asa = AddressedStateAttention(**asa_kwargs)

        self.norm2 = nn.LayerNorm(embed_dim)

        # Feed-forward: Linear -> GELU -> Dropout -> Linear -> Dropout, no biases.
        mlp_width = int(embed_dim * mlp_ratio)
        feed_forward = [
            nn.Linear(embed_dim, mlp_width, bias=False),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_width, embed_dim, bias=False),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
        routing_mode: str = "softmax",
        routing_topk: int = 2,
        read_weights_override: Optional[torch.Tensor] = None,
        routing_noise: Optional[str] = None,
        routing_noise_scale: float = 1.0,
        slot_mask: Optional[torch.Tensor] = None,
        slot_mask_where: str = "read",
        slot_mask_scope: str = "all",
        info_level: str = "full",
        info_cfg: Optional[Dict] = None,
    ):
        """Run the block on ``x`` and return ``(x_out, info)``.

        All keyword arguments besides ``x`` are handed to the ASA layer
        untouched; ``info`` is whatever that layer returns as its second
        output.
        """
        asa_call_kwargs = dict(
            attention_mask=attention_mask,
            return_info=return_info,
            routing_mode=routing_mode,
            routing_topk=routing_topk,
            read_weights_override=read_weights_override,
            routing_noise=routing_noise,
            routing_noise_scale=routing_noise_scale,
            slot_mask=slot_mask,
            slot_mask_where=slot_mask_where,
            slot_mask_scope=slot_mask_scope,
            info_level=info_level,
            info_cfg=info_cfg,
        )
        attn_out, info = self.asa(self.norm1(x), **asa_call_kwargs)

        # Two residual additions, each fed by its own LayerNorm.
        after_attn = x + attn_out
        block_out = after_attn + self.mlp(self.norm2(after_attn))
        return block_out, info
|
|
|
|
| |
| |
| |
class ASMLanguageModel(nn.Module):
    """Language model built from a stack of ``ASMBlock`` layers.

    Pipeline: token embedding (+ optional learned absolute positions) ->
    dropout -> ``num_layers`` x ``ASMBlock`` -> LayerNorm -> linear LM head.
    With ``tie_weights=True`` the LM head shares its weight tensor with the
    token embedding.

    All ASA-related constructor arguments are forwarded unchanged to every
    ``ASMBlock``.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 384,
        num_layers: int = 6,
        num_heads: int = 8,
        num_slots: int = 8,
        max_seq_len: int = 1024,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.05,
        normalize_k: bool = False,
        tie_weights: bool = True,
        use_abs_pos: bool = False,
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,
        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.use_abs_pos = bool(use_abs_pos)

        self.tok = nn.Embedding(vocab_size, embed_dim)
        # The absolute-position table only exists when requested; it caps the
        # usable sequence length at max_seq_len (enforced in forward).
        self.pos = nn.Embedding(max_seq_len, embed_dim) if self.use_abs_pos else None
        self.drop = nn.Dropout(dropout)

        # Every block receives an identical configuration.
        block_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            slotspace_chunk_size=slotspace_chunk_size,
        )
        self.blocks = nn.ModuleList(
            [ASMBlock(**block_kwargs) for _ in range(num_layers)]
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        if tie_weights:
            # Share one Parameter between input embedding and output head.
            self.lm_head.weight = self.tok.weight

        # Runs after tying, so the shared matrix is initialised too.
        # nn.Module.apply visits every sub-module, including the Linear /
        # LayerNorm layers nested inside each block.
        self.apply(self._init)

    def _init(self, m):
        """Initialise one sub-module (callback for ``nn.Module.apply``).

        Linear and Embedding weights are drawn from N(0, 0.02^2); LayerNorm
        is reset to weight=1, bias=0. Other module types are left untouched.
        """
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
        routing_mode: str = "softmax",
        routing_topk: int = 2,
        read_weights_override: Optional[torch.Tensor] = None,
        routing_noise: Optional[str] = None,
        routing_noise_scale: float = 1.0,
        slot_mask: Optional[torch.Tensor] = None,
        slot_mask_where: str = "read",
        slot_mask_scope: str = "all",
        info_level: str = "full",
        info_cfg: Optional[Dict] = None,
    ):
        """Compute next-token logits for ``input_ids``.

        Args:
            input_ids: integer tensor of shape ``(B, T)``.
            attention_mask, routing_*, read_weights_override, slot_mask*,
            info_level, info_cfg: forwarded unchanged to *every* block.
            return_info: when True, also return the per-block info objects.

        Returns:
            ``logits`` of shape ``(B, T, vocab_size)``, or
            ``(logits, infos)`` when ``return_info`` is True, where ``infos``
            has one entry per block in layer order.

        Raises:
            IndexError: if absolute positions are enabled and ``T`` exceeds
                ``max_seq_len``.
        """
        B, T = input_ids.shape

        x = self.tok(input_ids)
        if self.use_abs_pos:
            # Fail early with a readable message instead of the opaque
            # "index out of range in self" (CPU) or device-side assert (CUDA)
            # the embedding lookup would produce. IndexError matches the type
            # the CPU lookup already raises.
            if T > self.max_seq_len:
                raise IndexError(
                    f"sequence length {T} exceeds max_seq_len={self.max_seq_len} "
                    "of the absolute position embedding (use_abs_pos=True)"
                )
            pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)
            x = x + self.pos(pos)
        x = self.drop(x)

        block_kwargs = dict(
            attention_mask=attention_mask,
            return_info=return_info,
            routing_mode=routing_mode,
            routing_topk=routing_topk,
            read_weights_override=read_weights_override,
            routing_noise=routing_noise,
            routing_noise_scale=routing_noise_scale,
            slot_mask=slot_mask,
            slot_mask_where=slot_mask_where,
            slot_mask_scope=slot_mask_scope,
            info_level=info_level,
            info_cfg=info_cfg,
        )

        infos: List[Optional[Dict[str, torch.Tensor]]] = []
        for blk in self.blocks:
            x, info = blk(x, **block_kwargs)
            if return_info:
                infos.append(info)

        x = self.norm(x)
        logits = self.lm_head(x)
        return (logits, infos) if return_info else logits
|
|
|
|
| |
| |
| |
def build_model_from_cfg(cfg: ASMTrainConfig) -> ASMLanguageModel:
    """Instantiate an ``ASMLanguageModel`` from the model fields of ``cfg``.

    Each listed attribute is read from ``cfg`` and passed to the constructor
    under the same keyword name. Fields of ``cfg`` that are not listed (data,
    optimiser, evaluation, paths) are ignored here.
    """
    model_field_names = (
        "vocab_size",
        "embed_dim",
        "num_layers",
        "num_heads",
        "num_slots",
        "max_seq_len",
        "mlp_ratio",
        "dropout",
        "read_temperature",
        "write_temperature",
        "state_fp32",
        "slot_dropout",
        "normalize_k",
        "tie_weights",
        "use_abs_pos",
        "use_rope_keys",
        "rope_base",
        "use_alibi_write",
        "alibi_strength_init",
        "learn_alibi_strength",
        "min_strength",
        "use_content_read",
        "content_read_init",
        "content_read_max_gamma",
        "use_slotspace_refine",
        "slotspace_dim",
        "slotspace_gate_init",
        "slotspace_dropout",
        "slotspace_signed_weights",
        "use_rope_slotspace",
        "rope_base_slotspace",
        "write_chunk_size",
        "slotspace_chunk_size",
    )
    model_kwargs = {name: getattr(cfg, name) for name in model_field_names}
    return ASMLanguageModel(**model_kwargs)
|
|