| import os |
| import requests |
| import tarfile |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from PIL import Image |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import json |
| import math |
| from tqdm import tqdm |
| from transformers import BertTokenizer, BertModel |
| import gradio as gr |
|
|
| |
class Config:
    """Global hyper-parameters and dataset locations for training and sampling.

    Everything is a class attribute, so values can be read either as
    ``Config.x`` or through an instance.  Instantiating the class creates
    ``data_dir`` on disk.
    """

    # Run on the GPU when PyTorch can see one, otherwise on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Side length (pixels) of the square training/sampling images.  The U-Net
    # pools twice by 2 and upsamples twice by 2, so keep it divisible by 4.
    image_size = 64
    batch_size = 32
    num_epochs = 50
    learning_rate = 1e-4
    # Number of diffusion steps T in the noise schedule.
    timesteps = 1000
    # Width of the text embedding fed to the U-Net; must match the text
    # encoder's output width (bert-base-uncased hidden size).
    text_embed_dim = 768
    # Choices offered by the "Number of images" dropdown in the UI.
    num_images_options = [1, 4, 6]

    coco_images_url = "http://images.cocodataset.org/zips/train2017.zip"
    coco_annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    data_dir = "./coco_data"
    images_dir = os.path.join(data_dir, "train2017")
    # The dataset reads ann["caption"], which exists only in the *captions*
    # annotation file; instances_train2017.json holds boxes/segmentations
    # and has no captions.
    annotations_path = os.path.join(data_dir, "annotations", "captions_train2017.json")

    def __init__(self):
        os.makedirs(self.data_dir, exist_ok=True)
|
|
# Module-wide configuration instance; constructing it creates config.data_dir.
config = Config()
|
|
| |
def _download_file(url, dest_path, chunk_size=1 << 20):
    """Stream ``url`` to ``dest_path`` unless ``dest_path`` already exists.

    Bytes are written to ``dest_path + ".part"`` and renamed only once the
    transfer has finished, so an interrupted download is never mistaken for
    a complete archive on the next run.

    Raises:
        requests.HTTPError: if the server answers with a non-2xx status.
    """
    if os.path.exists(dest_path):
        return
    tmp_path = dest_path + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        # content-length may be missing; tqdm then shows a running byte count.
        total = int(response.headers.get("content-length", 0)) or None
        with open(tmp_path, "wb") as f, tqdm(
            total=total,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            desc=os.path.basename(dest_path),
        ) as bar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
                    bar.update(len(chunk))
    os.replace(tmp_path, dest_path)


def download_and_extract_coco():
    """Download and unpack COCO train2017 images and annotations into ``config.data_dir``.

    Does nothing when both ``config.images_dir`` and
    ``config.annotations_path`` already exist.  Archives already fully
    downloaded are not fetched again; both archives are (re-)extracted
    otherwise.

    Raises:
        requests.HTTPError: if a download returns a non-2xx status.
        zipfile.BadZipFile: if an archive on disk is not a valid zip.
    """
    # The COCO archives are .zip files; tarfile has no "r:zip" mode and
    # cannot read them.
    import zipfile

    if os.path.exists(config.images_dir) and os.path.exists(config.annotations_path):
        print("COCO dataset already downloaded")
        return

    print("Downloading COCO dataset...")

    images_zip_path = os.path.join(config.data_dir, "train2017.zip")
    _download_file(config.coco_images_url, images_zip_path)

    annotations_zip_path = os.path.join(config.data_dir, "annotations_trainval2017.zip")
    _download_file(config.coco_annotations_url, annotations_zip_path)

    print("Extracting images...")
    with zipfile.ZipFile(images_zip_path) as archive:
        archive.extractall(config.data_dir)

    print("Extracting annotations...")
    with zipfile.ZipFile(annotations_zip_path) as archive:
        archive.extractall(config.data_dir)

    print("COCO dataset ready")
|
|
# Runs at import time: ensures the COCO train2017 images and annotations are
# present under config.data_dir before the rest of the module runs.
# NOTE(review): importing this module only to serve the demo still triggers
# the full COCO download check/download; consider moving this under the
# __main__ guard.
download_and_extract_coco()
|
|
| |
class TextEncoder(nn.Module):
    """Frozen bert-base-uncased encoder mapping each caption to one vector.

    The vector is the final-layer hidden state at token position 0 (the
    [CLS] token the BERT tokenizer prepends), shape ``(len(texts), hidden)``.
    """

    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        # The encoder is never trained: freeze its weights and keep dropout
        # disabled so embeddings are deterministic.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

    @torch.no_grad()
    def forward(self, texts):
        """Encode a batch of captions.

        Args:
            texts: list of strings (sequences are truncated to 64 tokens).

        Returns:
            Tensor of shape ``(batch, hidden)`` on the encoder's device.
        """
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=64)
        # Follow the weights rather than the global config, so the encoder
        # works wherever it was moved with .to(...).
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :]
|
|
# Shared frozen text encoder, used by both train() and generate().
text_encoder = TextEncoder().to(config.device)
|
|
| |
class ConditionalUNet(nn.Module):
    """Noise-prediction network conditioned on a text embedding and a timestep.

    Layout:

    - Encoder: a 3->64 conv, then two DownBlocks (each halves H and W),
      ending at 256 channels and 1/4 resolution.
    - Bottleneck: the conditioning vector (projected text embedding plus
      timestep embedding) is broadcast to every spatial position,
      concatenated to the features, and mixed back to 256 channels by a
      per-pixel linear layer.
    - Decoder: two UpBlocks (each doubles H and W) and a final conv back
      to 3 channels.

    H and W must therefore be divisible by 4.  There are no encoder->decoder
    skip connections.
    """

    # Width of the sinusoidal timestep features and of the conditioning vector.
    _COND_DIM = 256

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.down1 = DownBlock(64, 128)
        self.down2 = DownBlock(128, 256)

        self.text_proj = nn.Linear(config.text_embed_dim, self._COND_DIM)
        # Timestep conditioning.  Without it the network cannot tell how
        # noisy its input is, which the reverse process depends on.
        self.time_mlp = nn.Sequential(
            nn.Linear(self._COND_DIM, self._COND_DIM),
            nn.ReLU(),
            nn.Linear(self._COND_DIM, self._COND_DIM),
        )
        self.merge = nn.Linear(256 + self._COND_DIM, 256)

        self.up1 = UpBlock(256, 128)
        self.up2 = UpBlock(128, 64)
        self.final = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    @staticmethod
    def _timestep_embedding(t, dim):
        """Return sinusoidal features of integer timesteps.

        Args:
            t: 1-D tensor of timesteps, shape ``(batch,)``.
            dim: even embedding width.

        Returns:
            Float tensor ``(batch, dim)`` = ``[sin(t * f_i), cos(t * f_i)]``.
            The frequencies ``f_i`` are geometrically spaced from 1 down
            towards 1/10000.
        """
        half = dim // 2
        exponents = torch.arange(half, device=t.device, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(10000.0) * exponents)
        args = t.float()[:, None] * freqs[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    def forward(self, x, t, text_emb):
        """Predict the noise that was added to ``x``.

        Args:
            x: noisy images, ``(batch, 3, H, W)`` with H and W divisible by 4.
            t: integer timesteps, ``(batch,)``.
            text_emb: text embeddings, ``(batch, config.text_embed_dim)``.

        Returns:
            Predicted noise, same shape as ``x``.
        """
        x1 = F.relu(self.conv1(x))
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        cond = self.text_proj(text_emb) + self.time_mlp(
            self._timestep_embedding(t, self._COND_DIM)
        )
        # Broadcast the (batch, C) conditioning vector over the bottleneck grid.
        cond = cond[:, :, None, None].expand(-1, -1, x3.size(2), x3.size(3))

        h = torch.cat([x3, cond], dim=1)
        # nn.Linear acts on the last dimension: move channels last, mix the
        # channels per pixel, then move them back.
        h = self.merge(h.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        h = self.up1(h)
        h = self.up2(h)
        return self.final(h)
|
|
class DownBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages followed by 2x2 max-pooling.

    Maps ``(N, in_ch, H, W)`` to ``(N, out_ch, H // 2, W // 2)``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first stage changes the width to out_ch; the second keeps it.
        for stage_in in (in_ch, out_ch):
            stages.extend((
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ))
        stages.append(nn.MaxPool2d(2))
        # The attribute name and layer order determine the state_dict keys
        # (conv.0.weight, conv.1.running_mean, ...).
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.conv(x)
        return features
|
|
class UpBlock(nn.Module):
    """2x bilinear upsampling followed by two (3x3 conv -> BatchNorm -> ReLU) stages.

    Maps ``(N, in_ch, H, W)`` to ``(N, out_ch, 2 * H, 2 * W)``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        layers = []
        # The first stage changes the width to out_ch; the second keeps it.
        for layer_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(layer_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(self.up(x))
|
|
| |
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` noise variances spaced linearly from ``beta_start`` to ``beta_end``.

    The defaults (1e-4 to 0.02) are the linear schedule from the DDPM paper.

    Args:
        timesteps: number of diffusion steps T.
        beta_start: beta at the first step.
        beta_end: beta at the last step.

    Returns:
        1-D float tensor of length ``timesteps`` (on the CPU).
    """
    return torch.linspace(beta_start, beta_end, timesteps)


# Precomputed schedule tensors, indexed by timestep t.  The function above is
# called here at import time, so it must be defined before these lines.
betas = linear_beta_schedule(config.timesteps).to(config.device)
alphas = 1. - betas                                  # alpha_t = 1 - beta_t
alphas_cumprod = torch.cumprod(alphas, dim=0)        # alpha_bar_t = prod_{s<=t} alpha_s
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
|
|
def forward_diffusion_sample(x_0, t, device=config.device):
    """Sample ``x_t ~ q(x_t | x_0)`` in closed form.

    Computes ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`` with
    ``noise ~ N(0, I)`` drawn fresh for every element.

    Args:
        x_0: clean batch; the first dimension is the batch.
        t: integer timestep(s), shape ``(batch,)`` (a scalar also works).
        device: unused.  It is kept for backward compatibility, since the
            coefficients follow ``x_0``'s device instead.

    Returns:
        ``(x_t, noise)``, both shaped like ``x_0``.
    """
    noise = torch.randn_like(x_0)
    # Index the schedule on its own device, then move the selected
    # coefficients to x_0's device.  This avoids CPU/GPU mismatches between
    # t, the schedule and the data.
    idx = torch.as_tensor(t, device=sqrt_alphas_cumprod.device)
    # (batch, 1, 1, ...) so the coefficients broadcast over every
    # non-batch dimension of x_0, whatever its rank.
    coeff_shape = (-1,) + (1,) * (x_0.dim() - 1)
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[idx].to(x_0.device).view(coeff_shape)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[idx].to(x_0.device).view(coeff_shape)
    return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
|
|
| |
class CocoDataset(Dataset):
    """Image/caption pairs read from a COCO-style captions annotation file.

    Only images that have at least one caption are kept.  Each image is
    paired with the first caption listed for it in the file.

    Args:
        root_dir: directory holding the files named by ``images[i]["file_name"]``.
        annotations_file: JSON file with a top-level ``"images"`` list (items
            need ``"id"`` and ``"file_name"``) and a top-level
            ``"annotations"`` list (items need ``"image_id"`` and
            ``"caption"``).
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # JSON files are UTF-8; don't depend on the platform's locale default.
        with open(annotations_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Group captions by image id, preserving their order in the file.
        image_id_to_captions = {}
        for ann in data['annotations']:
            image_id_to_captions.setdefault(ann['image_id'], []).append(ann['caption'])

        self.images = [img for img in data['images'] if img['id'] in image_id_to_captions]
        self.captions = [image_id_to_captions[img['id']][0] for img in self.images]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, caption)``: an RGB PIL image (or ``transform(image)``) and its str caption."""
        img_path = os.path.join(self.root_dir, self.images[idx]['file_name'])
        # convert() yields an independent, fully decoded copy, so the source
        # file can be closed right away rather than whenever it is GC'd.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        caption = self.captions[idx]

        if self.transform:
            image = self.transform(image)

        return image, caption
|
|
| |
# Training preprocessing, applied in this order:
# 1. Resize to image_size x image_size (the aspect ratio is not preserved).
# 2. Convert to a float tensor in [0, 1].
# 3. Map each channel to [-1, 1] via (x - 0.5) / 0.5.
# generate() undoes step 3 with (x + 1) / 2 before converting back to PIL.
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
|
|
| |
# The denoising network and its optimizer.  Only the U-Net's parameters are
# optimized; the text encoder's weights are frozen.
model = ConditionalUNet().to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
|
|
| |
def train(num_workers=0):
    """Train ``model`` to predict the noise added by ``forward_diffusion_sample``.

    Runs ``config.num_epochs`` epochs over the COCO captions dataset, using
    an MSE loss between predicted and true noise.  Prints the loss every
    100 batches and saves ``model_epoch_{epoch}.pth`` after each epoch.

    Args:
        num_workers: DataLoader worker processes used for image decoding.
            0 loads in the main process, as before.
    """
    dataset = CocoDataset(config.images_dir, config.annotations_path, transform)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    # generate() puts the model in eval mode.  Switch back so BatchNorm
    # normalizes with batch statistics and updates its running stats.
    model.train()

    for epoch in range(config.num_epochs):
        for batch_idx, (images, captions) in enumerate(tqdm(dataloader)):
            images = images.to(config.device)

            # One embedding per caption from the frozen text encoder.
            text_emb = text_encoder(captions)

            # An independent random diffusion step for each sample.
            t = torch.randint(0, config.timesteps, (images.size(0),), device=config.device)

            # Corrupt the images, keeping the exact noise as the target.
            x_noisy, noise = forward_diffusion_sample(images, t)

            pred_noise = model(x_noisy, t, text_emb)

            loss = F.mse_loss(pred_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")
|
|
| |
@torch.no_grad()
def generate(prompt, num_images=1):
    """Sample images for ``prompt`` with the DDPM reverse process.

    Starts from Gaussian noise and applies ``config.timesteps`` denoising
    steps, using ``model``'s noise prediction conditioned on the text
    embedding of ``prompt``.

    Args:
        prompt: text description.
        num_images: number of images to sample; anything ``int()`` accepts,
            since UI widgets may pass it as a string.

    Returns:
        A list of ``num_images`` RGB PIL images, each
        ``config.image_size`` x ``config.image_size``.

    Raises:
        ValueError: if ``num_images`` is less than 1.
    """
    num_images = int(num_images)
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    was_training = model.training
    model.eval()
    try:
        # Encode the prompt once and reuse it for every sample in the batch.
        text_emb = text_encoder([prompt]).repeat(num_images, 1)
        x = torch.randn(
            (num_images, 3, config.image_size, config.image_size),
            device=config.device,
        )

        for t in reversed(range(config.timesteps)):
            t_tensor = torch.full((num_images,), t, device=config.device, dtype=torch.long)
            pred_noise = model(x, t_tensor, text_emb)
            # 0-d schedule values for step t; they broadcast over the batch.
            alpha_t = alphas[t]
            alpha_cumprod_t = alphas_cumprod[t]
            beta_t = betas[t]

            # No fresh noise on the last step, so the final x is the mean.
            noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)

            # x_{t-1} = (x_t - (1 - a_t) / sqrt(1 - abar_t) * eps) / sqrt(a_t)
            #           + sqrt(beta_t) * z
            x = (1 / torch.sqrt(alpha_t)) * (
                x - ((1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)) * pred_noise
            ) + torch.sqrt(beta_t) * noise
    finally:
        # Don't leave the shared model in eval mode for a later train() call.
        if was_training:
            model.train()

    # Undo the [-1, 1] normalization used in training.
    x = (torch.clamp(x, -1, 1) + 1) / 2

    to_pil = transforms.ToPILImage()
    return [to_pil(img.cpu()) for img in x]
|
|
| |
def generate_and_display(prompt, num_images):
    """Generate images for ``prompt`` and lay them out in a single-row figure.

    Args:
        prompt: text description passed to ``generate``.
        num_images: number of images to sample, passed to ``generate``.

    Returns:
        A matplotlib ``Figure`` with one axis per image and all axes hidden.
    """
    images = generate(prompt, num_images)

    # squeeze=False always returns a 2-D array of axes, so a single loop
    # handles both the one-image and many-image cases.
    fig, axes = plt.subplots(
        1, len(images), figsize=(5 * len(images), 5), squeeze=False
    )
    for ax, img in zip(axes[0], images):
        ax.imshow(img)
        ax.axis('off')
    fig.tight_layout()
    # This runs once per UI request.  Unclosed pyplot figures pile up for the
    # life of the process, so release the figure from pyplot's registry; the
    # returned Figure object can still be rendered and saved.
    plt.close(fig)
    return fig
|
|
# Gradio UI: a prompt box and an image-count dropdown side by side, plus a
# button that renders generate_and_display's figure into a plot component.
with gr.Blocks() as demo:
    gr.Markdown("## GPUDiff-V1: powerful diffusion image generator!")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Enter image description...")
        num_select = gr.Dropdown(choices=config.num_images_options, value=1, label="Number of images")
    generate_btn = gr.Button("Generate")
    output = gr.Plot()

    generate_btn.click(
        fn=generate_and_display,
        inputs=[prompt_input, num_select],
        outputs=output
    )
|
|
if __name__ == "__main__":
    # Train for config.num_epochs, saving a checkpoint after each epoch, then
    # serve the Gradio demo with the freshly trained in-memory model.
    # NOTE(review): every launch retrains from scratch; nothing here loads a
    # saved model_epoch_*.pth checkpoint.
    train()
    demo.launch()