| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch.nn as nn |
| |
|
| | from repcodec.modules.decoder import Decoder |
| | from repcodec.modules.encoder import Encoder |
| | from repcodec.modules.projector import Projector |
| | from repcodec.modules.quantizer import Quantizer |
| |
|
| |
|
class RepCodec(nn.Module):
    """Encoder -> projector -> quantizer -> decoder pipeline.

    This class only wires four project submodules together (``Encoder``,
    ``Projector``, ``Quantizer``, ``Decoder``); the numerical behaviour lives
    in those classes.

    Args:
        input_channels: Forwarded to ``Encoder(input_channels=...)``. Also kept
            on ``self.input_channels``, which nothing in this class reads
            (presumably used by external callers -- verify).
        output_channels: Forwarded to ``Decoder(output_channels=...)``.
        encode_channels: Forwarded to ``Encoder(encode_channels=...)``.
        decode_channels: Forwarded to ``Decoder(decode_channels=...)``.
        code_dim: Shared latent width: passed as ``code_dim`` to the
            projector, the quantizer and the decoder.
        codebook_num: Forwarded to ``Quantizer(codebook_num=...)``.
        codebook_size: Forwarded to ``Quantizer(codebook_size=...)``.
        bias: Forwarded to the encoder and decoder only. The projector is
            always built with ``bias=False`` regardless of this value.
        enc_ratios, dec_ratios: Forwarded as ``channel_ratios``.
        enc_strides, dec_strides: Forwarded as ``strides``.
        enc_kernel_size, dec_kernel_size: Forwarded as ``kernel_size``.
        enc_block_dilations, dec_block_dilations: Forwarded as
            ``block_dilations``.
        enc_block_kernel_size, dec_block_kernel_size: Forwarded as
            ``unit_kernel_size``.
    """

    def __init__(
        self,
        input_channels=768,
        output_channels=768,
        encode_channels=768,
        decode_channels=768,
        code_dim=768,
        codebook_num=1,
        codebook_size=1024,
        bias=True,
        enc_ratios=(1, 1),
        dec_ratios=(1, 1),
        enc_strides=(1, 1),
        dec_strides=(1, 1),
        enc_kernel_size=3,
        dec_kernel_size=3,
        enc_block_dilations=(1, 1),
        enc_block_kernel_size=3,
        dec_block_dilations=(1, 1),
        dec_block_kernel_size=3
    ):
        super(RepCodec, self).__init__()

        self.input_channels = input_channels

        # NOTE: keep the build order encoder -> decoder -> projector ->
        # quantizer. It fixes the submodule registration order seen by
        # parameters()/state_dict(). If the submodules initialise their
        # weights randomly, it also fixes the order in which they draw from
        # the RNG.
        self.encoder = Encoder(
            input_channels=input_channels,
            encode_channels=encode_channels,
            kernel_size=enc_kernel_size,
            channel_ratios=enc_ratios,
            strides=enc_strides,
            block_dilations=enc_block_dilations,
            unit_kernel_size=enc_block_kernel_size,
            bias=bias,
        )
        self.decoder = Decoder(
            code_dim=code_dim,
            output_channels=output_channels,
            decode_channels=decode_channels,
            kernel_size=dec_kernel_size,
            channel_ratios=dec_ratios,
            strides=dec_strides,
            block_dilations=dec_block_dilations,
            unit_kernel_size=dec_block_kernel_size,
            bias=bias,
        )

        # The projector's input width is whatever the encoder reports it
        # emits, so the two stay in sync without a separate argument.
        encoder_width = self.encoder.out_channels
        self.projector = Projector(
            input_channels=encoder_width,
            code_dim=code_dim,
            kernel_size=3,
            stride=1,
            bias=False,  # fixed; deliberately independent of ``bias``
        )
        self.quantizer = Quantizer(
            code_dim=code_dim,
            codebook_num=codebook_num,
            codebook_size=codebook_size,
        )

    def forward(self, x):
        """Run encode -> project -> quantize -> decode on ``x``.

        Args:
            x: Input handed straight to the encoder (layout presumably
                (batch, channels, time) -- TODO confirm against Encoder).

        Returns:
            Tuple ``(y, zq, z, vqloss, perplexity)``, where:
            ``y`` is the decoder output; ``zq`` is the quantizer output that
            was fed to the decoder; ``z`` is the projector output before
            quantization; ``vqloss`` and ``perplexity`` are the second and
            third values returned by the quantizer.
        """
        latent = self.projector(self.encoder(x))
        quantized, vq_loss, codebook_ppl = self.quantizer(latent)
        reconstruction = self.decoder(quantized)
        return reconstruction, quantized, latent, vq_loss, codebook_ppl
| |
|