|
|
|
|
| """ |
| Universal Checkpoint Loader for ASA Models |
| |
Loads a saved checkpoint into either the training harness or the analysis harness.
| |
| Repository: https://github.com/DigitalDaimyo/AddressedStateAttention |
| """ |
|
|
from typing import Any, Dict, Literal, Optional, Tuple

import torch
|
|
|
|
# Public API of this module: only the checkpoint loader is exported.
__all__ = ["load_asm_checkpoint"]
|
|
|
|
def load_asm_checkpoint(
    checkpoint_path: str,
    mode: Literal["train", "analysis"] = "train",
    device: Optional[str] = None,
) -> Tuple[Any, Any, Dict]:
    """
    Universal ASM checkpoint loader.

    The checkpoint is read onto the CPU and validated. The model is then
    rebuilt through the harness selected by ``mode``, its weights are loaded,
    and it is moved to ``device`` in eval mode.

    Args:
        checkpoint_path: Path to .pt checkpoint file. The file must contain a
            dict with a ``"cfg"`` entry (keyword arguments for
            ``ASMTrainConfig``) and a ``"model"`` entry (a state dict).
        mode: "train" (efficient) or "analysis" (intervention harness)
        device: Device to load on (defaults to cuda if available)

    Returns:
        model: Loaded ASMLanguageModel (in eval mode, on ``device``)
        cfg: ASMTrainConfig object
        ckpt: Full checkpoint dict (for step, loss metadata)

    Raises:
        ValueError: If ``mode`` is not "train" or "analysis".
        TypeError: If the checkpoint file does not contain a dict.
        KeyError: If the checkpoint lacks the "cfg" or "model" entry.

    Note:
        ``torch.load`` unpickles the file, so only load checkpoints from
        trusted sources.

    Example:
        >>> model, cfg, ckpt = load_asm_checkpoint(
        ...     "best.pt", mode="analysis", device="cuda"
        ... )
        >>> print(f"Step {ckpt['step']}, Loss {ckpt['val_loss']:.3f}")
    """
    # Validate before doing any work. Without this check, any unrecognised
    # mode (e.g. the typo "training") silently fell through to the analysis
    # harness.
    if mode not in ("train", "analysis"):
        raise ValueError(
            f"mode must be 'train' or 'analysis', got {mode!r}"
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Map every tensor to CPU first so GPU-saved checkpoints open on any
    # machine. The model is moved to `device` once its weights are loaded.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Expected checkpoint to be a dict, got {type(ckpt).__name__}"
        )

    cfg_dict = ckpt.get("cfg")
    if cfg_dict is None:
        raise KeyError(f"Missing 'cfg' key. Available: {list(ckpt.keys())}")

    # Check for the weights before importing the harness and building the
    # model, so a malformed checkpoint fails fast.
    state_dict = ckpt.get("model")
    if state_dict is None:
        raise KeyError(f"Missing 'model' key. Available: {list(ckpt.keys())}")

    # Both harnesses expose the same builder API. Importing lazily means only
    # the selected one is loaded.
    if mode == "train":
        from .training import ASMTrainConfig, build_model_from_cfg
    else:
        from .analysis import ASMTrainConfig, build_model_from_cfg

    cfg = ASMTrainConfig(**cfg_dict)
    model = build_model_from_cfg(cfg)

    # strict=False tolerates key mismatches. This is presumably because the two
    # harnesses' parameter layouts may differ slightly (TODO confirm). The
    # counts are reported so that a partially loaded model is not missed.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if missing:
        print(f"⚠ Missing keys: {len(missing)}")
    if unexpected:
        print(f"⚠ Unexpected keys: {len(unexpected)}")

    model = model.to(device).eval()

    return model, cfg, ckpt
|
|