| """CLIP Text Encoder for text-conditional diffusion.""" |
| import torch |
| import torch.nn as nn |
| from transformers import CLIPTextModel, CLIPTokenizer |
|
|
|
|
class CLIPTextEncoder(nn.Module):
    """Wrapper around the CLIP text encoder for diffusion conditioning.

    CLIP maps images and text into a shared latent space; this module
    exposes only the text tower, producing one pooled embedding per
    prompt that can be used as the conditioning signal for a
    text-conditional diffusion model.

    Args:
        model_name: Hugging Face checkpoint identifier for the CLIP model.
        freeze: If True (default), the text encoder's parameters are
            frozen and the submodule is put in eval mode (dropout off),
            so no gradients flow through it and embeddings for a given
            prompt are deterministic.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32",
                 freeze: bool = True):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_model = CLIPTextModel.from_pretrained(model_name)
        # Remember the freeze choice: gradient computation in forward()
        # is keyed off this flag, not off train/eval mode.
        self.freeze = freeze

        if freeze:
            for param in self.text_model.parameters():
                param.requires_grad = False
            # Eval mode disables dropout so frozen embeddings are
            # deterministic even if the parent module is in train mode.
            self.text_model.eval()

        # Dimensionality of the pooled text embedding (CLIP hidden size).
        self.embedding_dim = self.text_model.config.hidden_size

    def forward(self, text_prompts):
        """Encode text prompts into pooled CLIP embeddings.

        Args:
            text_prompts: A single string or a list of strings.

        Returns:
            Tensor of shape [batch_size, embedding_dim] with one pooled
            text embedding per prompt.
        """
        if isinstance(text_prompts, str):
            text_prompts = [text_prompts]

        # CLIP's text tower has a fixed context length of 77 tokens.
        tokens = self.tokenizer(
            text_prompts,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(self.text_model.device)

        # Gradient flow depends on whether the encoder is frozen, not on
        # train/eval mode: the previous code keyed off `.training`, which
        # silently disabled gradients for a trainable encoder run in eval
        # mode, and built a useless autograd graph for a frozen encoder
        # left in train mode.
        with torch.set_grad_enabled(not self.freeze):
            outputs = self.text_model(**tokens)
            embeddings = outputs.pooler_output

        return embeddings

    def encode_batch(self, text_prompts):
        """Convenience alias for :meth:`forward`."""
        return self.forward(text_prompts)

    @property
    def device(self) -> torch.device:
        """Device on which the underlying text model's parameters live."""
        return self.text_model.device
|
|
|
|