| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | import os |
| | from collections import defaultdict |
| |
|
| | import torch |
| | from tensorboardX import SummaryWriter |
| | from tqdm import tqdm |
| |
|
# Module-level logger used by the trainer for progress and evaluation messages.
logger = logging.getLogger(name="repcodec_train")
| |
|
| |
|
class Trainer:
    """Training loop for a RepCodec model.

    Optimizes ``model["repcodec"]`` with ``optimizer["repcodec"]`` and
    ``scheduler["repcodec"]`` on batches from ``data_loader["train"]``. It
    periodically logs averaged losses to TensorBoard, evaluates on
    ``data_loader["dev"]`` and saves checkpoints under ``config["outdir"]``.

    Config keys read: ``outdir``, ``train_max_steps`` (optional, default 0),
    ``lambda_repr_reconstruct_loss``, ``lambda_vq_loss``, ``grad_norm``,
    ``log_interval_steps``, ``eval_interval_steps`` and
    ``save_interval_steps``.
    """

    def __init__(
        self,
        steps: int,
        epochs: int,
        data_loader: dict,
        model: dict,
        criterion: dict,
        optimizer: dict,
        scheduler: dict,
        config: dict,
        device=torch.device("cpu"),
    ):
        """Store the training components and open a TensorBoard writer.

        Args:
            steps: Global step counter to start from.
            epochs: Epoch counter to start from.
            data_loader: Iterables keyed by ``"train"`` and ``"dev"``.
            model: Modules keyed by name; ``"repcodec"`` is the one trained.
            criterion: Loss callables; ``"repr_reconstruct_loss"`` is used.
            optimizer: Optimizers keyed like ``model``.
            scheduler: LR schedulers keyed like ``model``. They are stepped
                once per training step.
            config: Training configuration (see the class docstring).
            device: Device every batch is moved to.
        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        # Running sums of logged scalars, keyed by TensorBoard tag.
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.train_max_steps = config.get("train_max_steps", 0)
        # Reset in run(); initialised here so helpers never hit a missing
        # attribute.
        self.finish_train = False

    def _codec_loss(self, batch, mode):
        """Forward ``batch`` through repcodec and return the total codec loss.

        Records the perplexity, VQ loss, reconstruction loss and the total
        under ``mode`` ("train" or "eval").
        """
        x = batch.to(self.device)
        # The model output is unpacked as (y_, zq, z, vqloss, perplexity);
        # zq and z are not needed for the loss.
        y_, _zq, _z, vqloss, perplexity = self.model["repcodec"](x)
        self._perplexity(perplexity, mode=mode)
        codec_loss = self._vq_loss(vqloss, mode=mode) + self._metric_loss(
            y_, x, mode=mode
        )
        self._record_loss("codec_loss", codec_loss, mode=mode)
        return codec_loss

    def _train_step(self, batch):
        """Run one optimization step on ``batch`` and advance the step counter."""
        codec_loss = self._codec_loss(batch, mode="train")
        self._update_repcodec(codec_loss)

        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    @torch.no_grad()
    def _eval_step(self, batch):
        """Compute and record evaluation losses for ``batch`` without gradients."""
        self._codec_loss(batch, mode="eval")

    def run(self):
        """Train epoch after epoch until ``train_max_steps`` is reached."""
        self.finish_train = False
        self.tqdm = tqdm(
            initial=self.steps, total=self.train_max_steps, desc="[train]"
        )
        try:
            while True:
                self._train_epoch()
                # Stop once a training step has flagged completion.
                if self.finish_train:
                    break
        finally:
            # Close the progress bar even when training raises.
            self.tqdm.close()
        logger.info("Finished training.")

    def save_checkpoint(self, checkpoint_path: str):
        """Save model, optimizer and scheduler state plus step/epoch counters.

        Missing parent directories are created. A bare filename (no
        directory component) is written to the current working directory.
        """
        state_dict = {
            "model": {
                "repcodec": self.model["repcodec"].state_dict()
            },
            "optimizer": {
                "repcodec": self.optimizer["repcodec"].state_dict(),
            },
            "scheduler": {
                "repcodec": self.scheduler["repcodec"].state_dict(),
            },
            "steps": self.steps,
            "epochs": self.epochs,
        }

        checkpoint_dir = os.path.dirname(checkpoint_path)
        # dirname() is "" for a bare filename and os.makedirs("") raises.
        # exist_ok avoids a race between checking and creating the directory.
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(
        self,
        checkpoint_path: str,
        strict: bool = True,
        load_only_params: bool = False
    ):
        """Restore state written by ``save_checkpoint``.

        Args:
            checkpoint_path: Path to the checkpoint file.
            strict: Forwarded to ``load_state_dict`` of the repcodec model.
            load_only_params: If True, restore only the model weights and
                keep the current steps, epochs, optimizer and scheduler.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.model["repcodec"].load_state_dict(
            state_dict["model"]["repcodec"], strict=strict
        )

        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer["repcodec"].load_state_dict(
                state_dict["optimizer"]["repcodec"]
            )
            self.scheduler["repcodec"].load_state_dict(
                state_dict["scheduler"]["repcodec"]
            )

    def _train_epoch(self):
        """Run one pass over ``data_loader["train"]``.

        Returns early if training finishes mid-epoch.

        Raises:
            RuntimeError: If the training loader yields no batches. Otherwise
                ``run()`` would loop forever without making progress.
        """
        train_steps_per_epoch = 0
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            self._train_step(batch)

            # Periodic logging / evaluation / checkpointing.
            self._check_log_interval()
            self._check_eval_interval()
            self._check_save_interval()

            if self.finish_train:
                return

        if train_steps_per_epoch == 0:
            raise RuntimeError("Training data loader yielded no batches.")

        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        if train_steps_per_epoch > 200:
            logger.info(
                f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
                f"({self.train_steps_per_epoch} steps per epoch)."
            )

    def _eval_epoch(self):
        """Evaluate on ``data_loader["dev"]`` and log averaged losses.

        Modules are switched to eval mode for the pass. They are always
        switched back to train mode, and the eval sums are reset, even if
        evaluation raises.
        """
        logger.info(f"(Steps: {self.steps}) Start evaluation.")
        for module in self.model.values():
            module.eval()

        # Initialised so an empty dev loader does not leave it unbound.
        eval_steps_per_epoch = 0
        try:
            for eval_steps_per_epoch, batch in enumerate(
                tqdm(self.data_loader["dev"], desc="[eval]"), 1
            ):
                self._eval_step(batch)

            logger.info(
                f"(Steps: {self.steps}) Finished evaluation "
                f"({eval_steps_per_epoch} steps per epoch)."
            )

            # Turn sums into per-batch averages. Only _eval_step adds eval
            # entries, so with zero batches this loop has nothing to divide.
            for key in self.total_eval_loss:
                self.total_eval_loss[key] /= eval_steps_per_epoch
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
                )

            self._write_to_tensorboard(self.total_eval_loss)
        finally:
            self.total_eval_loss = defaultdict(float)
            for module in self.model.values():
                module.train()

    def _metric_loss(self, predict_y, natural_y, mode='train'):
        """Return the weighted reconstruction loss and record it under ``mode``."""
        repr_reconstruct_loss = self.criterion["repr_reconstruct_loss"](predict_y, natural_y)
        # Out-of-place scaling: an in-place `*=` would modify the criterion's
        # output tensor, which autograd rejects if it was saved for backward.
        repr_reconstruct_loss = (
            repr_reconstruct_loss * self.config["lambda_repr_reconstruct_loss"]
        )
        self._record_loss("reconstruct_loss", repr_reconstruct_loss, mode=mode)
        return repr_reconstruct_loss

    def _update_repcodec(self, repr_loss):
        """Backpropagate ``repr_loss`` and step the repcodec optimizer and scheduler.

        Gradients are clipped to ``config["grad_norm"]`` when it is positive.
        """
        self.optimizer["repcodec"].zero_grad()
        repr_loss.backward()
        if self.config["grad_norm"] > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model["repcodec"].parameters(),
                self.config["grad_norm"],
            )
        self.optimizer["repcodec"].step()
        self.scheduler["repcodec"].step()

    def _record_loss(self, name: str, loss, mode='train'):
        """Add ``loss`` to the running sum tagged ``"{mode}/{name}"``.

        Raises:
            NotImplementedError: If ``mode`` is neither "train" nor "eval".
        """
        if torch.is_tensor(loss):
            loss = loss.item()

        if mode == 'train':
            self.total_train_loss[f"train/{name}"] += loss
        elif mode == 'eval':
            self.total_eval_loss[f"eval/{name}"] += loss
        else:
            raise NotImplementedError(f"Mode ({mode}) is not supported!")

    def _write_to_tensorboard(self, loss):
        """Write every tag/value pair in ``loss`` at the current step."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        """Save a checkpoint every ``save_interval_steps`` (non-zero) steps."""
        if self.steps and (self.steps % self.config["save_interval_steps"] == 0):
            self.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        """Run an evaluation pass every ``eval_interval_steps`` steps."""
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        """Every ``log_interval_steps`` steps, log averaged train losses and reset them."""
        if self.steps % self.config["log_interval_steps"] == 0:
            for key in self.total_train_loss:
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)

            # Start a fresh accumulation window.
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        """Set and return ``finish_train`` (True once ``steps >= train_max_steps``)."""
        self.finish_train = self.steps >= self.train_max_steps
        return self.finish_train

    def _perplexity(self, perplexity, label=None, mode='train'):
        """Record codebook perplexity.

        A multi-element tensor is recorded as one tag per element
        (``ppl_0``, ``ppl_1``, ...). ``_record_loss`` already prefixes the
        mode, so the name built here must not include it.
        """
        name = f"ppl_{label}" if label else "ppl"
        if torch.numel(perplexity) > 1:
            for idx, ppl in enumerate(perplexity.flatten().tolist()):
                self._record_loss(f"{name}_{idx}", ppl, mode=mode)
        else:
            self._record_loss(name, perplexity, mode=mode)

    def _vq_loss(self, vqloss, label=None, mode='train'):
        """Sum ``vqloss``, scale it by ``lambda_vq_loss``, record and return it."""
        # The mode prefix is added by _record_loss.
        name = f"vqloss_{label}" if label else "vqloss"
        vqloss = torch.sum(vqloss) * self.config["lambda_vq_loss"]
        self._record_loss(name, vqloss, mode=mode)
        return vqloss
| |
|