| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch.nn as nn |
| |
|
| | from repcodec.layers.vq_module import ResidualVQ |
| |
|
| |
|
class Quantizer(nn.Module):
    """Thin wrapper that routes tensors through a ``ResidualVQ`` codebook.

    ``forward``, ``inference`` and ``encode`` swap axes 1 and 2 of their
    input before handing it to the codebook; ``forward`` and ``inference``
    swap the quantized result back, ``encode`` does not.
    """

    def __init__(
        self,
        code_dim: int,
        codebook_num: int,
        codebook_size: int,
    ):
        """Build the underlying ``ResidualVQ``.

        Args:
            code_dim: forwarded to ``ResidualVQ`` as ``dim``.
            codebook_num: forwarded to ``ResidualVQ`` as ``num_quantizers``.
            codebook_size: forwarded to ``ResidualVQ`` as ``codebook_size``.
        """
        super().__init__()
        self.codebook = ResidualVQ(
            dim=code_dim,
            num_quantizers=codebook_num,
            codebook_size=codebook_size,
        )

    def initial(self):
        """Delegate to the codebook's own ``initial()`` routine."""
        self.codebook.initial()

    def forward(self, z):
        """Quantize ``z`` and return ``(zq, vqloss, perplexity)``.

        The codebook is called on ``z`` with axes 1 and 2 swapped; the
        quantized tensor is swapped back before being returned, while the
        other two values are passed through untouched.
        """
        # assumes z is a 3-D tensor, e.g. (batch, channels, time) -- TODO confirm
        swapped = z.transpose(1, 2)
        quantized, vqloss, perplexity = self.codebook(swapped)
        return quantized.transpose(1, 2), vqloss, perplexity

    def inference(self, z):
        """Return ``(zq, indices)`` from the codebook's ``forward_index``.

        Same axis handling as ``forward``: swap axes 1 and 2 on the way in
        and swap the quantized tensor back on the way out.
        """
        swapped = z.transpose(1, 2)
        quantized, indices = self.codebook.forward_index(swapped)
        return quantized.transpose(1, 2), indices

    def encode(self, z):
        """Return ``(zq, indices)`` using ``forward_index(..., flatten_idx=True)``.

        Unlike ``inference``, the quantized tensor is returned exactly as
        the codebook produced it, i.e. it is *not* transposed back.
        """
        # NOTE(review): the missing back-transpose differs from inference();
        # looks deliberate, but confirm against callers before changing it.
        swapped = z.transpose(1, 2)
        return self.codebook.forward_index(swapped, flatten_idx=True)

    def decode(self, indices):
        """Return whatever the codebook's ``lookup`` yields for ``indices``."""
        return self.codebook.lookup(indices)
| |
|