| import os |
| from PIL import Image |
| import torchvision.transforms as T |
| from torch.utils.data import Dataset |
|
|
class MultiStyleDataset(Dataset):
    """Image dataset whose samples are labeled by the style folder they live in.

    Expected layout::

        styles_root_dir/
            <style_a>/img0.png
            <style_a>/img1.jpg
            <style_b>/...

    Each immediate subdirectory of ``styles_root_dir`` is one style. Hidden
    entries (names starting with ``"."``) are ignored at both levels, as are
    non-directories at the top level and non-files inside a style folder.

    Attributes:
        styles: List of ``(image_path, style_id)`` pairs, one per sample.
        style_ids: Mapping from style folder name to its integer id.
        transform: torchvision pipeline applied to each loaded image.
    """

    def __init__(self, styles_root_dir, img_size=256):
        """Scan ``styles_root_dir`` for style folders and build the transform.

        Args:
            styles_root_dir: Directory whose subdirectories are style folders.
            img_size: Passed to ``T.Resize``; an int matches the shorter image
                side to this length while keeping the aspect ratio.
        """
        self.styles, self.style_ids = self._collect_samples(styles_root_dir)

        self.transform = T.Compose([
            # NOTE(review): an int size only fixes the shorter side, so
            # non-square inputs give non-square tensors; the default
            # DataLoader collate_fn needs equal sizes to batch them.
            T.Resize(img_size),
            T.ToTensor(),
            # Maps ToTensor's [0, 1] output to [-1, 1] per channel.
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    @staticmethod
    def _collect_samples(styles_root_dir):
        """Return ``(samples, style_ids)`` discovered under ``styles_root_dir``.

        Style folders are sorted by name so ids are deterministic and
        contiguous (``0 .. num_styles - 1``) regardless of the order in which
        the filesystem lists entries.
        """
        style_names = sorted(
            name
            for name in os.listdir(styles_root_dir)
            if not name.startswith(".")
            and os.path.isdir(os.path.join(styles_root_dir, name))
        )
        style_ids = {name: idx for idx, name in enumerate(style_names)}

        samples = []
        for name in style_names:
            style_path = os.path.join(styles_root_dir, name)
            for img_name in sorted(os.listdir(style_path)):
                img_path = os.path.join(style_path, img_name)
                # Skip hidden files (e.g. .DS_Store) and nested directories,
                # which would otherwise fail later in Image.open.
                if img_name.startswith(".") or not os.path.isfile(img_path):
                    continue
                samples.append((img_path, style_ids[name]))
        return samples, style_ids

    def __len__(self):
        """Return the number of images across all styles."""
        return len(self.styles)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, style_id)``."""
        img_path, style_id = self.styles[idx]
        # The context manager closes the underlying file; convert() returns a
        # new image, so ``img`` remains usable after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        return self.transform(img), style_id