| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | import os |
| | import sys |
| |
|
| | from feature_utils import get_path_iterator, dump_feature |
| |
|
# Send log records to stdout with a timestamped format. The verbosity is read
# from the LOGLEVEL environment variable and falls back to INFO when unset.
_log_level = os.environ.get("LOGLEVEL", "INFO").upper()
logging.basicConfig(
    stream=sys.stdout,
    level=_log_level,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("dump_feature")
| |
|
| |
|
def main(
    model_type: str,
    tsv_path: str,
    ckpt_path: str,
    whisper_root: str,
    whisper_name: str,
    layer: int,
    nshard: int,
    rank: int,
    feat_dir: str,
    max_chunk: int,
    use_cpu: bool = False,
):
    """Build a feature reader for ``model_type`` and dump features for one shard.

    Args:
        model_type: One of ``"hubert"``, ``"data2vec"`` or ``"whisper"``.
        tsv_path: Manifest path handed to ``get_path_iterator``.
        ckpt_path: Checkpoint path. Required, and must exist on disk, for
            ``"hubert"`` and ``"data2vec"``; ignored (may be ``None``) for
            ``"whisper"``.
        whisper_root: Required for ``"whisper"``; forwarded to
            ``WhisperFeatureReader``. Ignored otherwise.
        whisper_name: Required for ``"whisper"``; forwarded to
            ``WhisperFeatureReader``. Ignored otherwise.
        layer: Layer index forwarded to the feature reader (the CLI help
            describes it as 1-based).
        nshard: Total number of shards; forwarded to ``get_path_iterator``
            and ``dump_feature``.
        rank: Shard id of this process; forwarded alongside ``nshard``.
        feat_dir: Output directory forwarded to ``dump_feature``.
        max_chunk: Forwarded to the HuBERT / data2vec readers; not used by
            the whisper reader.
        use_cpu: When true the reader is created with ``device="cpu"``,
            otherwise with ``device="cuda"``.

    Raises:
        ValueError: If ``model_type`` is unsupported, or an argument required
            by the chosen model type is missing or empty.
        FileNotFoundError: If ``ckpt_path`` is given for HuBERT / data2vec
            but does not exist.
    """
    device = "cpu" if use_cpu else "cuda"

    # Validate with real exceptions rather than ``assert``: assertions are
    # stripped under ``python -O`` and would let bad arguments reach the
    # model loaders.
    if model_type in ("hubert", "data2vec"):
        if not ckpt_path:
            raise ValueError(
                f"ckpt_path must be provided for model type {model_type}"
            )
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ckpt_path does not exist: {ckpt_path}")
    elif model_type == "whisper":
        if not (whisper_name and whisper_root):
            raise ValueError(
                "whisper_root and whisper_name must be provided for model type whisper"
            )
    else:
        raise ValueError(f"Unsupported model type {model_type}")

    # Reader modules are imported lazily so only the selected backend's
    # module gets imported.
    if model_type == "hubert":
        from hubert_feature_reader import HubertFeatureReader
        reader = HubertFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
    elif model_type == "data2vec":
        from data2vec_feature_reader import Data2vecFeatureReader
        reader = Data2vecFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
    else:
        # Only "whisper" can reach this branch after the validation above.
        from whisper_feature_reader import WhisperFeatureReader
        reader = WhisperFeatureReader(whisper_root, whisper_name, layer, device=device)

    generator, num = get_path_iterator(tsv_path, nshard, rank)
    dump_feature(reader, generator, num, nshard, rank, feat_dir)
| |
|
| |
|
if __name__ == "__main__":
    import argparse

    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered in exactly this order. Every destination name matches a
    # parameter of ``main`` so the parsed namespace can be splatted into it.
    cli_options = [
        (
            "--model_type",
            dict(
                required=True,
                type=str,
                choices=["data2vec", "hubert", "whisper"],
                help="the type of the speech encoder.",
            ),
        ),
        (
            "--tsv_path",
            dict(required=True, type=str, help="the path to the tsv file."),
        ),
        (
            "--ckpt_path",
            dict(
                required=False,
                type=str,
                default=None,
                help="path to the speech model. must provide for HuBERT and data2vec",
            ),
        ),
        (
            "--whisper_root",
            dict(
                required=False,
                type=str,
                default=None,
                help="root dir to download/store whisper model. must provide for whisper model.",
            ),
        ),
        (
            "--whisper_name",
            dict(
                required=False,
                type=str,
                default=None,
                help="name of whisper model. e.g., large-v2. must provide for whisper model.",
            ),
        ),
        (
            "--layer",
            dict(
                required=True,
                type=int,
                help="which layer of the model. this is 1-based.",
            ),
        ),
        (
            "--feat_dir",
            dict(
                required=True,
                type=str,
                help="the output dir to save the representations.",
            ),
        ),
        (
            "--nshard",
            dict(required=False, type=int, default=1, help="total number of shards."),
        ),
        (
            "--rank",
            dict(required=False, type=int, default=0, help="shard id of this process."),
        ),
        (
            "--max_chunk",
            dict(type=int, default=1600000, help="max number of frames of each batch."),
        ),
        (
            "--use_cpu",
            dict(
                default=False,
                action="store_true",
                help="whether use cpu instead of gpu.",
            ),
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    logger.info(args)

    # Hand the parsed options to main() as keyword arguments.
    main(**vars(args))
| |
|