| """ |
| ColVec1 - ColVec1 retrieval wrapper for late interaction. |
| """ |
|
|
| import glob |
| import json |
| import os |
| from typing import ClassVar, List, Optional |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoModelForImageTextToText, PreTrainedModel |
|
|
| from .configuration_colvec1 import ColVec1Config |
|
|
|
|
class ColVec1PreTrainedModel(PreTrainedModel):
    """Common `PreTrainedModel` base shared by ColVec1 model classes."""

    # Configuration class paired with this model family.
    config_class = ColVec1Config
    # Prefix used to refer to the base model.
    base_model_prefix = "colvec1"
    supports_gradient_checkpointing = True
    # Empty at class level; `ColVec1` loaders replace it per instance with
    # the wrapped backbone's tied keys, prefixed by "vlm.".
    _tied_weights_keys: ClassVar[List[str]] = []
|
|
|
|
class ColVec1(ColVec1PreTrainedModel):
    """
    Retrieval model wrapper for ColVec1 checkpoints.

    It loads the upstream model with `AutoModelForImageTextToText`, then adds
    a projection head to produce L2-normalized retrieval embeddings.

    Attributes:
        vlm: The wrapped backbone. It is ``None`` right after ``__init__`` and
            is attached by `from_pretrained` / `_load_merged`.
        embedding_proj_layer: Linear layer mapping ``config.text_hidden_size``
            to ``config.embed_dim``.
    """

    main_input_name: ClassVar[str] = "input_ids"

    def __init__(self, config: ColVec1Config):
        """Build the projection head; the backbone is attached later by the loaders."""
        super().__init__(config)
        self.config = config
        # Placeholder until a loader assigns the real backbone module.
        self.vlm = None
        self.embedding_proj_layer = nn.Linear(config.text_hidden_size, config.embed_dim)
        self.post_init()

    # ------------------------------------------------------------------
    # Private helpers shared by `from_pretrained` and `_load_merged`.
    # ------------------------------------------------------------------

    @staticmethod
    def _vlm_load_kwargs(trust_remote_code, torch_dtype, device_map, attn_impl) -> dict:
        """Assemble keyword arguments for the backbone loader, omitting unset options."""
        vlm_kwargs = {"trust_remote_code": trust_remote_code}
        if torch_dtype is not None:
            vlm_kwargs["torch_dtype"] = torch_dtype
        if device_map is not None:
            vlm_kwargs["device_map"] = device_map
        if attn_impl is not None:
            vlm_kwargs["attn_implementation"] = attn_impl
        return vlm_kwargs

    @staticmethod
    def _prefixed_tied_keys(vlm):
        """Return the backbone's `_tied_weights_keys` with every key prefixed by ``"vlm."``."""
        tied = getattr(vlm, "_tied_weights_keys", None)
        if isinstance(tied, dict):
            return {f"vlm.{k}": f"vlm.{v}" for k, v in tied.items()}
        if isinstance(tied, (list, tuple, set)):
            return [f"vlm.{k}" for k in tied]
        return []

    @staticmethod
    def _is_merged_checkpoint(name_or_path, trust_remote_code) -> bool:
        """Report whether the checkpoint's config declares ``model_type == "colvec1"``."""
        config_path = os.path.join(name_or_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, encoding="utf-8") as f:
                raw = json.load(f)
            return raw.get("model_type") == "colvec1"

        # No local config.json: ask AutoConfig to resolve the identifier.
        from transformers import AutoConfig

        try:
            hub_config = AutoConfig.from_pretrained(
                name_or_path,
                trust_remote_code=trust_remote_code,
            )
        except Exception:
            # Deliberate best-effort probe: any failure here just means we
            # fall back to treating the identifier as a plain base checkpoint,
            # whose own loader will surface a real error if there is one.
            return False
        return getattr(hub_config, "model_type", None) == "colvec1"

    def _finalize_after_load(self, torch_dtype) -> None:
        """Cast/move the projection head next to the backbone and mirror its tied keys."""
        if torch_dtype is not None:
            self.embedding_proj_layer = self.embedding_proj_layer.to(torch_dtype)
        if hasattr(self.vlm, "device"):
            self.embedding_proj_layer = self.embedding_proj_layer.to(self.vlm.device)
        self._tied_weights_keys = self._prefixed_tied_keys(self.vlm)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run the backbone and project its final hidden states to retrieval embeddings.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Optional mask; when given, the embeddings are
                multiplied by it (broadcast over the last dimension).
            pixel_values: Optional image input forwarded to the backbone.
            **kwargs: Extra arguments forwarded to the backbone.
                ``output_hidden_states`` and ``return_dict`` are discarded
                because this method always forces both to ``True``.

        Returns:
            L2-normalized (along the last dimension) projected embeddings,
            multiplied by ``attention_mask`` when one is provided.
            Presumably shaped ``(batch, seq_len, embed_dim)`` - confirm
            against the backbone's hidden-state layout.

        Raises:
            RuntimeError: If no backbone has been attached yet.
        """
        if self.vlm is None:
            raise RuntimeError(
                "ColVec1 has no backbone attached (self.vlm is None); "
                "load the model with ColVec1.from_pretrained(...) before calling forward()."
            )

        # These two are set explicitly below; drop caller-supplied duplicates
        # so the backbone call does not receive the same keyword twice.
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("return_dict", None)

        vlm_outputs = self.vlm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )

        # Pick the final hidden states from whichever field the output exposes.
        if hasattr(vlm_outputs, "hidden_states") and vlm_outputs.hidden_states is not None:
            last_hidden_states = vlm_outputs.hidden_states[-1]
        elif hasattr(vlm_outputs, "last_hidden_state"):
            last_hidden_states = vlm_outputs.last_hidden_state
        else:
            last_hidden_states = vlm_outputs[0]

        # Match the projection weight dtype before the matmul.
        embeddings = self.embedding_proj_layer(
            last_hidden_states.to(self.embedding_proj_layer.weight.dtype)
        )
        embeddings = nn.functional.normalize(embeddings, p=2, dim=-1)

        if attention_mask is not None:
            # Zero out embeddings at positions where the mask is 0.
            embeddings = embeddings * attention_mask.unsqueeze(-1)

        return embeddings

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        embed_dim: int = 128,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[str] = None,
        attn_impl: Optional[str] = None,
        **kwargs,
    ):
        """
        Load a ColVec1 model from either a merged ColVec1 checkpoint or a base VLM.

        If the checkpoint's config reports ``model_type == "colvec1"`` the
        call is delegated to `_load_merged`. Otherwise the identifier is
        loaded as the backbone and a new projection head is created.

        Args:
            pretrained_model_name_or_path: Local directory or model identifier.
            embed_dim: Output size of the projection head; overridden by
                ``config.embed_dim`` when a ``config`` kwarg carries one.
            torch_dtype: Optional dtype for the backbone and projection head.
                A ``dtype`` kwarg is used as a fallback when this is ``None``.
            device_map: Optional device map forwarded to the backbone loader.
            attn_impl: Optional value forwarded as ``attn_implementation``.
            **kwargs: Recognised keys are ``config``, ``dtype``,
                ``trust_remote_code`` (default ``True``) and, for base
                checkpoints, ``quantization_config``.

        Returns:
            A `ColVec1` instance with its backbone attached.
        """
        if torch_dtype is None:
            torch_dtype = kwargs.pop("dtype", None)

        config = kwargs.pop("config", None)
        if config is not None and hasattr(config, "embed_dim"):
            embed_dim = config.embed_dim

        # A ColVec1 config passed in directly already identifies a merged
        # checkpoint; otherwise inspect the checkpoint's own config.
        is_merged = config is not None and getattr(config, "model_type", None) == "colvec1"
        if not is_merged:
            is_merged = cls._is_merged_checkpoint(
                pretrained_model_name_or_path,
                kwargs.get("trust_remote_code", True),
            )

        if is_merged:
            return cls._load_merged(
                pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                device_map=device_map,
                attn_impl=attn_impl,
                **kwargs,
            )

        vlm_kwargs = cls._vlm_load_kwargs(
            kwargs.pop("trust_remote_code", True), torch_dtype, device_map, attn_impl
        )
        if "quantization_config" in kwargs:
            vlm_kwargs["quantization_config"] = kwargs.pop("quantization_config")

        vlm = AutoModelForImageTextToText.from_pretrained(pretrained_model_name_or_path, **vlm_kwargs)

        # Prefer the text sub-config's hidden size; otherwise use the
        # top-level one, with 2560 as the last-resort default.
        if hasattr(vlm.config, "text_config") and hasattr(vlm.config.text_config, "hidden_size"):
            text_hidden_size = vlm.config.text_config.hidden_size
        else:
            text_hidden_size = getattr(vlm.config, "hidden_size", 2560)

        model_config = ColVec1Config(
            embed_dim=embed_dim,
            text_hidden_size=text_hidden_size,
            padding_side="left",
        )
        model = cls(model_config)
        model.vlm = vlm
        # Replace the head built in __init__ (which went through post_init)
        # with a freshly constructed one.
        # NOTE(review): presumably to keep torch's default Linear init - confirm.
        model.embedding_proj_layer = nn.Linear(model_config.text_hidden_size, model_config.embed_dim)
        model._finalize_after_load(torch_dtype)
        return model

    @classmethod
    def _load_merged(
        cls,
        path: str,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[str] = None,
        attn_impl: Optional[str] = None,
        **kwargs,
    ):
        """
        Load a merged ColVec1 checkpoint (dense VLM weights + embedding_proj_layer).

        The backbone architecture is instantiated from
        ``config.base_model_name_or_path``, then every ``model*.safetensors``
        file under ``path`` is loaded on top of it (non-strict).

        Args:
            path: Local directory, or an identifier resolved with
                ``huggingface_hub.snapshot_download``.
            torch_dtype: Optional dtype for the backbone and projection head.
            device_map: Optional device map forwarded to the backbone loader.
            attn_impl: Optional value forwarded as ``attn_implementation``.
            **kwargs: ``trust_remote_code`` (default ``True``) is honoured;
                other keys are ignored.

        Returns:
            A `ColVec1` instance populated from the merged weights.

        Raises:
            ValueError: If the config lacks ``base_model_name_or_path``.
            FileNotFoundError: If no ``model*.safetensors`` file is found.
        """
        from safetensors.torch import load_file

        if not os.path.isdir(path):
            from huggingface_hub import snapshot_download

            path = snapshot_download(path)

        config = ColVec1Config.from_pretrained(path)
        base_name = config.base_model_name_or_path
        if base_name is None:
            raise ValueError(
                f"Merged ColVec1 config at {path} is missing 'base_model_name_or_path'. "
                "This field is required to know which VLM architecture to instantiate."
            )

        # Respect the caller's trust_remote_code choice, consistent with the
        # non-merged path in `from_pretrained`; the default stays True.
        vlm_kwargs = cls._vlm_load_kwargs(
            kwargs.pop("trust_remote_code", True), torch_dtype, device_map, attn_impl
        )
        vlm = AutoModelForImageTextToText.from_pretrained(base_name, **vlm_kwargs)

        model = cls(config)
        model.vlm = vlm

        safetensor_files = sorted(glob.glob(os.path.join(path, "model*.safetensors")))
        if not safetensor_files:
            raise FileNotFoundError(f"No model*.safetensors files found in {path}")

        state_dict = {}
        for sf in safetensor_files:
            state_dict.update(load_file(sf))

        # Non-strict so that keys legitimately absent from the file do not
        # abort loading, but a missing projection head must not go unnoticed:
        # it would leave the retrieval embeddings randomly initialised.
        load_result = model.load_state_dict(state_dict, strict=False)
        missing_proj = [
            key for key in load_result.missing_keys if key.startswith("embedding_proj_layer.")
        ]
        if missing_proj:
            import warnings

            warnings.warn(
                f"Merged ColVec1 checkpoint at {path} does not contain {missing_proj}; "
                "the projection head keeps its random initialisation.",
                stacklevel=2,
            )

        model._finalize_after_load(torch_dtype)
        return model

    # ------------------------------------------------------------------
    # Delegation to the wrapped backbone
    # ------------------------------------------------------------------

    def tie_weights(self, *args, **kwargs):
        """Delegate weight tying to the backbone; a no-op while no backbone is attached."""
        if self.vlm is None:
            # Reached from post_init() inside __init__, before a loader sets vlm.
            return None
        try:
            return self.vlm.tie_weights(*args, **kwargs)
        except TypeError:
            # Retry without arguments for backbones whose tie_weights takes none.
            return self.vlm.tie_weights()

    def get_input_embeddings(self):
        """Return the backbone's input embedding module."""
        return self.vlm.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the backbone's input embedding module."""
        self.vlm.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the backbone's output embedding module."""
        return self.vlm.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the backbone's output embedding module."""
        self.vlm.set_output_embeddings(new_embeddings)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        mean_resizing: bool = True,
    ) -> nn.Embedding:
        """
        Resize the backbone's token embeddings and sync ``vocab_size`` in its config.

        Args:
            new_num_tokens: Forwarded to the backbone's resize method.
            pad_to_multiple_of: Forwarded to the backbone's resize method.
            mean_resizing: Forwarded to the backbone's resize method.

        Returns:
            The embedding module returned by the backbone.
        """
        model_embeds = self.vlm.resize_token_embeddings(
            new_num_tokens=new_num_tokens,
            pad_to_multiple_of=pad_to_multiple_of,
            mean_resizing=mean_resizing,
        )

        # Keep both the nested text config and the top-level config in step
        # with the new embedding table size, whichever of them exist.
        if hasattr(self.vlm.config, "text_config"):
            self.vlm.config.text_config.vocab_size = model_embeds.num_embeddings
        if hasattr(self.vlm.config, "vocab_size"):
            self.vlm.config.vocab_size = model_embeds.num_embeddings
        return model_embeds

    @property
    def device(self):
        """Device of the first parameter yielded by ``self.parameters()``."""
        return next(self.parameters()).device

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on the backbone, if one is attached and supports it."""
        if self.vlm is not None and hasattr(self.vlm, "gradient_checkpointing_enable"):
            self.vlm.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on the backbone, if one is attached and supports it."""
        if self.vlm is not None and hasattr(self.vlm, "gradient_checkpointing_disable"):
            self.vlm.gradient_checkpointing_disable()
|
|
|
|
# Public names exported by a star-import of this module.
__all__ = [
    "ColVec1",
    "ColVec1PreTrainedModel",
]
|
|