| """Text-conditional U-Net for diffusion.""" |
| import torch |
| import torch.nn as nn |
| import math |
| import config |
|
|
|
|
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    Maps a tensor of timesteps ``t`` of shape ``[B]`` to a ``[B, dim]`` tensor.
    The first ``dim // 2`` columns hold ``sin(t * f_k)`` and the next
    ``dim // 2`` hold ``cos(t * f_k)``. The frequencies ``f_k`` are spaced
    geometrically from 1 down to 1/10000.

    Args:
        dim: Output embedding width. Odd widths get one trailing zero column,
            so the output always has exactly ``dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed timesteps ``t`` (shape ``[B]``) into a ``[B, dim]`` tensor."""
        half_dim = self.dim // 2
        # Guard the denominator: for dim in {2, 3}, half_dim - 1 == 0 and the
        # plain formula would divide by zero.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        args = t[:, None] * freqs[None, :]
        emb = torch.cat([args.sin(), args.cos()], dim=1)
        if self.dim % 2 == 1:
            # sin/cos halves cover only 2 * (dim // 2) columns; pad so that
            # downstream Linear(dim, ...) layers receive the width they expect.
            emb = nn.functional.pad(emb, (0, 1))
        return emb
|
|
|
|
class ResBlock(nn.Module):
    """Pre-activation residual block conditioned on time and, optionally, text.

    Computes
    ``conv2(act(norm2(conv1(act(norm1(x))) + time + text))) + skip(x)``.
    The time and text vectors are linearly projected to ``out_ch`` and
    broadcast over the spatial dimensions.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        time_dim: Width of the time embedding passed to ``forward``.
        text_dim: Width of the text embedding. If falsy, no text projection is
            created and any ``text_emb`` given to ``forward`` is ignored.
    """

    def __init__(self, in_ch, out_ch, time_dim, text_dim=None):
        super().__init__()
        self.time_mlp = nn.Linear(time_dim, out_ch)
        self.text_mlp = nn.Linear(text_dim, out_ch) if text_dim else None

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(self._num_groups(in_ch), in_ch)
        self.norm2 = nn.GroupNorm(self._num_groups(out_ch), out_ch)
        self.act = nn.SiLU()

        # 1x1 projection on the residual path only when channel counts differ.
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    @staticmethod
    def _num_groups(channels, max_groups=8):
        """Return the largest group count <= ``max_groups`` that divides ``channels``.

        GroupNorm requires ``channels % num_groups == 0``. A plain
        ``min(8, channels)`` violates this for counts like 12 or 20. For every
        count where ``min(8, channels)`` was valid, this returns the same value.
        """
        for groups in range(min(max_groups, channels), 0, -1):
            if channels % groups == 0:
                return groups
        return 1

    def forward(self, x, t_emb, text_emb=None):
        """Apply the block.

        Args:
            x: Feature map, presumably ``[B, in_ch, H, W]``.
            t_emb: Time embedding ``[B, time_dim]``.
            text_emb: Optional text embedding ``[B, text_dim]``. It is only
                used when the block was built with a ``text_dim``.

        Returns:
            Tensor ``[B, out_ch, H, W]``. The 3x3 convs with padding=1 keep
            the spatial size.
        """
        h = self.conv1(self.act(self.norm1(x)))

        # Per-sample conditioning vectors are broadcast across H and W.
        h = h + self.time_mlp(t_emb)[:, :, None, None]
        if self.text_mlp is not None and text_emb is not None:
            h = h + self.text_mlp(text_emb)[:, :, None, None]

        h = self.conv2(self.act(self.norm2(h)))
        return h + self.skip(x)
|
|
|
|
class TextConditionedUNet(nn.Module):
    """Three-level U-Net for diffusion, conditioned on timestep and text embedding.

    Every ResBlock receives the MLP-processed sinusoidal time embedding and,
    when given, the MLP-projected text embedding. Channel widths and the
    time-embedding width come from ``config.CHANNELS`` and ``config.TIME_DIM``.

    Args:
        text_dim: Width of the incoming text embedding (described by the
            original docstring as CLIP text embeddings). Defaults to 512.
        in_channels: Number of image channels in ``x``. The output has the
            same number of channels. Defaults to 1.
    """

    def __init__(self, text_dim=512, in_channels=1):
        super().__init__()
        self.text_dim = text_dim
        self.in_channels = in_channels
        ch = config.CHANNELS
        tdim = config.TIME_DIM

        # Timestep -> sinusoidal features -> two-layer MLP.
        self.time_emb = TimeEmbedding(tdim)
        self.time_mlp = nn.Sequential(
            nn.Linear(tdim, tdim),
            nn.SiLU(),
            nn.Linear(tdim, tdim),
        )

        # Learned projection of the text embedding (width preserved).
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )

        # Encoder: ch -> 2ch -> 4ch, with a 2x max-pool between levels.
        self.down1 = ResBlock(in_channels, ch, tdim, text_dim)
        self.down2 = ResBlock(ch, ch * 2, tdim, text_dim)
        self.down3 = ResBlock(ch * 2, ch * 4, tdim, text_dim)

        # Bottleneck, applied after a third pool.
        self.mid = ResBlock(ch * 4, ch * 4, tdim, text_dim)

        # Decoder: input widths include the concatenated skip connection.
        self.up3 = ResBlock(ch * 8, ch * 2, tdim, text_dim)  # 4ch (mid) + 4ch (h3)
        self.up2 = ResBlock(ch * 4, ch, tdim, text_dim)      # 2ch + 2ch (h2)
        self.up1 = ResBlock(ch * 2, ch, tdim, text_dim)      # ch + ch (h1)

        self.out = nn.Conv2d(ch, in_channels, 1)

        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, t, text_emb=None):
        """
        Args:
            x: [B, in_channels, H, W] noisy images.
            t: [B] timesteps.
            text_emb: [B, text_dim] text embeddings, or None to run every
                block with time conditioning only.

        Returns:
            Tensor [B, in_channels, H, W].
        """
        t_emb = self.time_mlp(self.time_emb(t))
        if text_emb is not None:
            text_emb = self.text_proj(text_emb)

        h1 = self.down1(x, t_emb, text_emb)
        h2 = self.down2(self.pool(h1), t_emb, text_emb)
        h3 = self.down3(self.pool(h2), t_emb, text_emb)

        h = self.mid(self.pool(h3), t_emb, text_emb)

        h = self.up3(torch.cat([self._upsample_to(h, h3), h3], dim=1), t_emb, text_emb)
        h = self.up2(torch.cat([self._upsample_to(h, h2), h2], dim=1), t_emb, text_emb)
        h = self.up1(torch.cat([self._upsample_to(h, h1), h1], dim=1), t_emb, text_emb)

        return self.out(h)

    def _upsample_to(self, h, ref):
        """Upsample ``h`` 2x, then resize it to ``ref``'s spatial size if they differ.

        MaxPool2d floors odd sizes. When H or W is not divisible by 8, a plain
        2x upsample no longer matches the skip tensor and ``torch.cat`` would
        fail. For divisible sizes, the extra resize is never triggered.
        """
        h = self.upsample(h)
        if h.shape[-2:] != ref.shape[-2:]:
            h = nn.functional.interpolate(h, size=ref.shape[-2:], mode='nearest')
        return h
|
|
|
|
if __name__ == "__main__":
    # Smoke test: one forward pass on random inputs, then check that the
    # output has the same shape as the input image batch.
    print("Testing Text-Conditioned U-Net...")
    model = TextConditionedUNet(text_dim=512)

    batch_size = 2
    x = torch.randn(batch_size, 1, 64, 64)
    t = torch.randint(0, 1000, (batch_size,))
    text_emb = torch.randn(batch_size, 512)

    # Gradients are not needed for a shape check.
    with torch.no_grad():
        out = model(x, t, text_emb)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {out.shape}")
    # Fail loudly instead of unconditionally reporting success.
    assert out.shape == x.shape, (
        f"expected output shape {tuple(x.shape)}, got {tuple(out.shape)}"
    )
    print("✅ Model test passed!")