| import os, sys |
| import torch |
|
|
| |
| root_path = os.path.abspath('.') |
| sys.path.append(root_path) |
| from opt import opt |
| from architecture.rrdb import RRDBNet |
| from architecture.grl import GRL |
| from architecture.dat import DAT |
| from architecture.swinir import SwinIR |
| from architecture.cunet import UNet_Full |
|
|
|
|
def load_rrdb(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load the RRDB model from Real-ESRGAN.

    Args:
        generator_weight_PATH (str): The path to the checkpoint file.
        scale (int): The scaling factor passed to RRDBNet.
        print_options (bool): Whether to print the options stored under the
            checkpoint's 'opt' entry (if present) to show which settings were used.
    Returns:
        generator (torch.nn.Module): The RRDBNet generator, in eval mode, on CUDA.
    Raises:
        ValueError: If the checkpoint contains none of the supported weight
            entries ('params_ema', 'params', 'model_state_dict').
    '''
    # Load onto the CPU first so checkpoints saved on any GPU index can be read;
    # the model is moved to CUDA after the weights are loaded.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Supported weight entries, checked in priority order (EMA weights first).
    for state_key in ('params_ema', 'params', 'model_state_dict'):
        if state_key in checkpoint_g:
            weight = checkpoint_g[state_key]
            break
    else:
        # Raise instead of os._exit(0): that would report success, skip cleanup
        # and could drop the (unflushed) error message.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")

    # Strip the "_orig_mod." prefix (added to state-dict keys when the model was
    # wrapped by torch.compile) so the keys match the plain module. The dict is
    # mutated in place to keep any metadata attached to it.
    prefix = "_orig_mod."
    for old_key in [key for key in weight if key.startswith(prefix)]:
        weight[old_key[len(prefix):]] = weight.pop(old_key)

    generator = RRDBNet(3, 3, scale=scale)
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    if print_options and 'opt' in checkpoint_g:
        options = checkpoint_g['opt']
        for key in options:
            print(f'{key} : {options[key]}')

    return generator
|
|
|
|
def load_cunet(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load the CUNET model from Real-CUGAN.

    Args:
        generator_weight_PATH (str): The path to the checkpoint file.
        scale (int): The scaling factor; only 2 is supported.
        print_options (bool): Whether to print the options stored under the
            checkpoint's 'opt' entry (if present) to show which settings were used.
    Returns:
        generator (torch.nn.Module): The UNet_Full generator, in eval mode, on CUDA.
    Raises:
        NotImplementedError: If scale is not 2.
        ValueError: If the checkpoint has no 'model_state_dict' entry.
    '''
    if scale != 2:
        raise NotImplementedError("We only support 2x in CUNET")

    # Load onto the CPU first so checkpoints saved on any GPU index can be read;
    # the model is moved to CUDA after the weights are loaded.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0): that would report success, skip cleanup
        # and could drop the (unflushed) error message.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")
    weight = checkpoint_g['model_state_dict']

    # Report the stored training metadata; either entry may be missing.
    loss = checkpoint_g.get("lowest_generator_weight", "NAN")
    iteration = checkpoint_g.get("iteration", "NAN")
    print(f"the generator weight is {loss} at iteration {iteration}")

    # Strip the "_orig_mod." prefix (added to state-dict keys when the model was
    # wrapped by torch.compile) so the keys match the plain module. The dict is
    # mutated in place to keep any metadata attached to it.
    prefix = "_orig_mod."
    for old_key in [key for key in weight if key.startswith(prefix)]:
        weight[old_key[len(prefix):]] = weight.pop(old_key)

    generator = UNet_Full()
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    if print_options and 'opt' in checkpoint_g:
        options = checkpoint_g['opt']
        for key in options:
            print(f'{key} : {options[key]}')

    return generator
|
|
def load_grl(generator_weight_PATH, scale=4):
    ''' A simpler API to load the GRL model.

    Args:
        generator_weight_PATH (str): The path to the checkpoint file.
        scale (int): Scale factor passed to GRL as `upscale` (usually 4).
    Returns:
        generator (torch.nn.Module): The GRL generator, in eval mode, on CUDA.
    Raises:
        ValueError: If the checkpoint has no 'model_state_dict' entry.
    '''
    # Load onto the CPU first so checkpoints saved on any GPU index can be read;
    # the model is moved to CUDA after the weights are loaded.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0): that would report success, skip cleanup
        # and could drop the (unflushed) error message.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")
    weight = checkpoint_g['model_state_dict']

    # Strip the "_orig_mod." prefix (added to state-dict keys when the model was
    # wrapped by torch.compile) so the keys match the plain module, as the other
    # loaders in this file do.
    prefix = "_orig_mod."
    for old_key in [key for key in weight if key.startswith(prefix)]:
        weight[old_key[len(prefix):]] = weight.pop(old_key)

    # Architecture hyper-parameters are fixed here and must match the checkpoint
    # being loaded; only the upscale factor is configurable.
    generator = GRL(
        upscale=scale,
        img_size=64,
        window_size=8,
        depths=[4, 4, 4, 4],
        embed_dim=64,
        num_heads_window=[2, 2, 2, 2],
        num_heads_stripe=[2, 2, 2, 2],
        mlp_ratio=2,
        qkv_proj_type="linear",
        anchor_proj_type="avgpool",
        anchor_window_down_factor=2,
        out_proj_type="linear",
        conv_type="1conv",
        upsampler="nearest+conv",
    )
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report trainable parameters in millions.
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator
|
|
|
|
|
|
def load_dat(generator_weight_PATH, scale=4):
    ''' A simpler API to load the DAT model.

    Args:
        generator_weight_PATH (str): The path to the checkpoint file.
        scale (int): Scale factor passed to DAT as `upscale` (usually 4).
    Returns:
        generator (torch.nn.Module): The DAT generator, in eval mode, on CUDA.
    Raises:
        ValueError: If the checkpoint has no 'model_state_dict' entry.
    '''
    # Load onto the CPU first so checkpoints saved on any GPU index can be read;
    # the model is moved to CUDA after the weights are loaded.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0): that would report success, skip cleanup
        # and could drop the (unflushed) error message.
        raise ValueError(f"This weight is not supported: {generator_weight_PATH}")
    weight = checkpoint_g['model_state_dict']

    # Strip the "_orig_mod." prefix (added to state-dict keys when the model was
    # wrapped by torch.compile) so the keys match the plain module, as the other
    # loaders in this file do.
    prefix = "_orig_mod."
    for old_key in [key for key in weight if key.startswith(prefix)]:
        weight[old_key[len(prefix):]] = weight.pop(old_key)

    # Architecture hyper-parameters are fixed here and must match the checkpoint
    # being loaded. `upscale` now honours the `scale` argument (it was hard-coded
    # to 4, silently ignoring the parameter).
    generator = DAT(
        upscale=scale,
        in_chans=3,
        img_size=64,
        img_range=1.,
        depth=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        expansion_factor=2,
        resi_connection='1conv',
        split_size=[8, 16],
        upsampler='pixelshuffledirect',
    )
    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report trainable parameters in millions.
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator
|
|