| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | import logging |
| | import math |
| | from dataclasses import dataclass, field |
| | from typing import Optional |
| |
|
| | from omegaconf import II |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.distributed as dist |
| |
|
| | from fairseq.modules import EMAModule, EMAModuleConfig |
| | from fairseq.data.data_utils import compute_mask_indices |
| | from fairseq.models import BaseFairseqModel, register_model |
| | from fairseq.models.wav2vec import ( |
| | ConvFeatureExtractionModel, |
| | Wav2Vec2Config, |
| | TransformerEncoder, |
| | ) |
| | from fairseq.modules import ( |
| | GradMultiply, |
| | LayerNorm, |
| | ) |
| | from fairseq.utils import index_put |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class Data2VecAudioConfig(Wav2Vec2Config):
    """Configuration for the data2vec audio model.

    Extends ``Wav2Vec2Config`` with the regression-loss settings, the target
    normalisation switches, the EMA teacher schedule and the variance
    thresholds used to abort a collapsed run.
    """

    # ---- regression loss -------------------------------------------------
    loss_beta: float = field(
        default=0,
        metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"},
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8,
        metadata={"help": "how many layers to average"},
    )

    # ---- target normalisation --------------------------------------------
    # The "*_target_layer" switches are applied to each teacher layer output
    # before averaging; the "*_targets" switches are applied to the average.
    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False
    batch_norm_target_layer: bool = False
    group_norm_target_layer: bool = False

    # ---- EMA teacher -----------------------------------------------------
    ema_decay: float = field(
        default=0.999,
        metadata={"help": "initial ema decay rate"},
    )
    ema_end_decay: float = field(
        default=0.9999,
        metadata={"help": "final ema decay rate"},
    )

    # Update step from which the model uses ema_end_decay instead of the
    # annealed value; interpolated from the optimization config.
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer"},
    )
    ema_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )

    # ---- training bookkeeping --------------------------------------------
    max_update: int = II("optimization.max_update")

    min_target_var: float = field(
        default=0.1,
        metadata={"help": "stop training if target var falls below this"},
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )
| |
|
| |
|
def get_annealed_rate(start, end, curr_step, total_steps):
    """Linearly anneal a rate from ``start`` to ``end``.

    Returns ``start`` at ``curr_step == 0`` and ``end`` at
    ``curr_step == total_steps``. ``total_steps`` must be non-zero.
    """
    return end - (end - start) * (1 - curr_step / total_steps)
| |
|
| |
|
@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
class Data2VecAudioModel(BaseFairseqModel):
    """data2vec model for audio.

    A convolutional feature extractor feeds a Transformer encoder (the
    student). During pre-training an EMA copy of the encoder (or of the whole
    model, see ``cfg.ema_transformer_only``) acts as the teacher: the student
    sees masked input and regresses, at the masked positions, the average of
    the teacher's selected layer outputs computed from unmasked input.
    """

    def __init__(self, cfg: Data2VecAudioConfig):
        """Build all sub-modules from ``cfg``.

        The EMA teacher is NOT created here; it is built lazily by the first
        call to :meth:`set_num_updates`.
        """
        super().__init__()
        self.cfg = cfg

        # NOTE(security): eval() executes arbitrary Python contained in the
        # config string, so only trusted configs must be loaded. It is kept
        # because the accepted layer-spec syntax is defined by this call.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        # First element of the last conv spec is used as the extractor width.
        self.extractor_embed = feature_enc_layers[-1][0]

        self.ema = None
        self.embed = cfg.encoder_embed_dim

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)

        # Time-step masking parameters.
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        # Channel masking parameters.
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        # Learned vector written into every masked time step.
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.extractor_embed)

        # Regression head applied to the student output at masked positions.
        self.final_proj = nn.Linear(self.embed, self.embed)

        self.num_updates = 0

    def make_ema_teacher(self):
        """Create the EMA teacher and store it in ``self.ema``.

        When ``cfg.ema_layers_only`` is set this also forces
        ``cfg.ema_transformer_only`` to True (mutating the config) and passes
        the encoder's ``pos_conv`` parameters as ``skip_keys``.
        """
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_layers_only:
            self.cfg.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")

        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        """Record the update count and maintain the EMA teacher.

        On the first call with pre-training modules present the teacher is
        built. On later calls, while in training mode, the decay is annealed
        from ``cfg.ema_decay`` to ``cfg.ema_end_decay`` over
        ``cfg.ema_anneal_end_step`` updates and the teacher is stepped.
        """
        super().set_num_updates(num_updates)

        if self.ema is None and self.final_proj is not None:
            logger.info("making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            # The teacher is only stepped while the decay is below 1.
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the module state, adding the teacher's fp32 params as ``_ema``."""
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Restore the teacher from the ``_ema`` entry, then load the rest.

        The ``_ema`` entry is removed from ``state_dict`` before delegating to
        the parent implementation. When a teacher exists the entry is
        required (checked with ``assert``).
        """
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecAudioConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def _sample_channel_mask(self, x):
        """Sample a channel mask for ``x`` and broadcast it over the time axis.

        Args:
            x: 3-D tensor unpacked as ``(B, T, C)``.

        Returns:
            Tensor on ``x.device`` of shape ``(B, T, C)`` (an expanded view of
            a ``(B, 1, C)`` mask) produced by ``compute_mask_indices``.
        """
        B, T, C = x.shape
        channel_mask = compute_mask_indices(
            (B, C),
            None,
            self.mask_channel_prob,
            self.mask_channel_length,
            self.mask_channel_selection,
            self.mask_channel_other,
            no_overlap=self.no_mask_channel_overlap,
            min_space=self.mask_channel_min_space,
        )
        return (
            torch.from_numpy(channel_mask)
            .to(x.device)
            .unsqueeze(1)
            .expand(-1, T, -1)
        )

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        """Mask time steps and/or channels of ``x``.

        Args:
            x: 3-D tensor unpacked as ``(B, T, C)``.
            padding_mask: forwarded to ``compute_mask_indices`` for the time
                mask; may be None.
            mask_indices: optional precomputed time mask; sampled when None
                and ``mask_prob > 0``.
            mask_channel_indices: optional precomputed channel mask. It is
                only honoured when channels are masked AFTER time masking;
                with ``mask_channel_before`` a fresh mask is always sampled.

        Returns:
            ``(x, mask_indices)``; ``mask_indices`` is None when
            ``mask_prob`` is not positive.
        """
        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = self._sample_channel_mask(x)
            # NOTE(review): this write is a direct in-place assignment while
            # the other masked writes below go through index_put — confirm
            # whether that difference is intentional.
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    x.shape[:2],
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = self._sample_channel_mask(x)
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers

        Applies ``floor((length - kernel) / stride + 1)`` once per entry of
        ``cfg.conv_feature_layers`` (kernel at index 1, stride at index 2 of
        each entry) and returns the result as a long tensor.
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        # NOTE(security): eval() on the config string; trusted configs only.
        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for layer_cfg in conv_cfg_list:
            input_lengths = _conv_out_length(
                input_lengths, layer_cfg[1], layer_cfg[2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        """Run the student and, unless ``features_only``, compute the loss.

        Args:
            source: input passed to the convolutional feature extractor.
            padding_mask: optional mask over ``source``; non-zero entries mark
                padding. It is converted to the feature resolution internally.
            mask: whether to apply :meth:`apply_mask` to the student input.
            features_only: return encoder outputs without computing targets.
            layer: forwarded to the encoder.
            mask_indices: optional precomputed time mask.
            mask_channel_indices: optional precomputed channel mask.
            padding_count: unused in this method.

        Returns:
            With ``features_only``: dict with ``x``, ``padding_mask`` and
            ``layer_results``. Otherwise: dict with ``losses`` (key
            ``regression``), ``sample_size``, ``target_var``, ``pred_var``
            and, when a teacher exists, ``ema_decay``.

        Raises:
            Exception: after more than 5000 updates, if the target or
                prediction variance drops below the configured minimum.

        Note:
            The loss path dereferences ``self.ema`` and ``self.final_proj``;
            both are None after :meth:`remove_pretraining_modules`, and
            ``self.ema`` is None until :meth:`set_num_updates` has run. It
            also indexes with ``mask_indices``, which is None when ``mask``
            is False.
        """
        features = source

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            # A non-positive multiplier disables gradients for the extractor.
            with torch.no_grad():
                features = self.feature_extractor(features)

        # presumably (B, C, T) -> (B, T, C); the layer norm below normalises
        # over extractor_embed — TODO confirm extractor output layout.
        features = features.transpose(1, 2)

        features = self.layer_norm(features)

        orig_padding_mask = padding_mask

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # Map raw input lengths to lengths after the conv stack.
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # Mark the last valid frame of every row; the reversed cumulative
            # sum then sets every position up to that frame to 1, so the
            # inverted result is True exactly on padded frames.
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        # Teacher input: captured before input dropout and before masking.
        pre_encoder_features = None
        if self.cfg.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }

        # ---- teacher targets (no gradients) ------------------------------
        with torch.no_grad():
            self.ema.model.eval()

            if self.cfg.ema_transformer_only:
                # Teacher is an encoder copy: feed it the unmasked features.
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                # Teacher is a full-model copy: feed it the raw source.
                y = self.ema.model.extract_features(
                    source=source,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            # Element 2 of each layer result is used as the target tensor.
            target_layer_results = [
                layer_out[2] for layer_out in y["layer_results"]
            ]

            # presumably layer outputs are time-major (T, B, C); the permute
            # below would then give (B, C, T) — TODO confirm.
            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results
                ]
                permuted = True

            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]

            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]

            if permuted:
                # Swap the last two axes back after the channel-wise norms.
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results
                ]

            if self.cfg.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]

            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

            # Target = mean over the selected teacher layers.
            y = sum(target_layer_results) / len(target_layer_results)

            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.cfg.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            # If no permute happened above, the first two axes still need to
            # be swapped so that y lines up with x for the mask indexing.
            if not permuted:
                y = y.transpose(0, 1)

            y = y[mask_indices]

        # ---- student prediction and loss ---------------------------------
        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)

        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)

        result["losses"]["regression"] = loss.sum() * scale

        # One sample per masked position.
        result["sample_size"] = loss.numel()

        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        # Abort when targets or predictions collapse (checked only after the
        # first 5000 updates).
        if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
            message = f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            logger.error(message)
            raise Exception(message)
        if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
            message = f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            logger.error(message)
            raise Exception(message)

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_var(y):
        """Return the mean per-feature standard deviation of ``y``.

        ``y`` is flattened to ``(-1, y.size(-1))``. When ``torch.distributed``
        is initialised, the count, sum and sum of squares are all-reduced so
        the statistic covers every worker; otherwise ``Tensor.var`` is used.
        A constant ``1e-6`` is added before the square root.
        """
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            # Create the count on y's device instead of hard-coding CUDA, so
            # all three reduced tensors live on the same device.
            zc = torch.tensor(y.size(0), device=y.device)
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            # Unbiased variance from the reduced moments.
            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(
        self, source, padding_mask, mask=False, layer=None
    ):
        """Run :meth:`forward` with ``features_only=True`` and return its dict."""
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        """Drop the regression head and the teacher.

        Args:
            last_layer: when given, keep only encoder layers with index
                ``<= last_layer``.
        """
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                enc_layer
                for i, enc_layer in enumerate(self.encoder.layers)
                if i <= last_layer
            )
| |
|