# Pin GPU enumeration to physical PCI bus order and restrict this process
# to four specific devices. Must run before any CUDA-aware import.
import os

os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1,2,5,7",
})
| |
|
|
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from transformers import AutoConfig, GPT2LMHeadModel, AutoModel, AutoModelForCausalLM |
| from transformers import Trainer, TrainingArguments |
| from datasets import Dataset, DatasetDict, concatenate_datasets, Sequence, Value |
| from torch.nn import functional as F |
| from tqdm import tqdm |
| import time |
| import torch |
| import wandb |
| import random |
| import string |
| from eval_model import evaluate_model |
|
|
def process(text):
    """Normalize *text* for biasing-word matching.

    Lowercases, removes all punctuation except apostrophes (so
    contractions like "don't" survive), and trims leading/trailing
    spaces.

    Fix over the original: the hand-rolled trimming loop indexed
    ``text[0]``/``text[-1]`` and raised IndexError whenever the input
    was empty or became empty after punctuation removal (e.g. "!!!").
    ``str.strip`` handles the empty string safely.
    """
    text = text.lower()

    # Drop every punctuation character except the apostrophe in one
    # C-level pass via str.translate.
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)

    # The original loop trimmed only the space character (not all
    # whitespace), so mirror that exactly with strip(' ').
    return text.strip(' ')
|
|
# --- Run configuration ----------------------------------------------------
dataset_name = "entity_tokenized"   # pre-tokenized entity dataset on disk
tokenizer_path = "./../tokenizer"   # local tokenizer directory
max_length = 2048                   # max sequence length fed to the model

# Number of biasing words placed in each prompt (used by tokenize() below).
n_bwords = 25

# Load the pre-tokenized entity dataset and keep only the model inputs;
# the listed columns are raw/intermediate fields not needed for training.
dataset = Dataset.load_from_disk(dataset_name)
dataset = dataset.remove_columns(["audio_tokens", "raw_text", "transcript", "entities", "prompt"])
# Downcast input_ids / attention_mask to int32 / int8 to shrink the
# on-disk and in-memory footprint before training.
feat = dataset.features.copy()
feat["input_ids"] = Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None)
feat["attention_mask"] = Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)
dataset = dataset.cast(feat)
# Hold out 2.5% of the examples as the test split.
dataset = dataset.train_test_split(test_size=0.025)
|
|
# Pre-tokenized LibriSpeech ASR dataset from a cached HF hub snapshot;
# token_type_ids is dropped because GPT-2-style causal LMs don't use it.
asr_dataset = DatasetDict.load_from_disk("/root/.cache/huggingface/hub/models--darshanmakwana--storage/snapshots/b6e4caa73046e02ad19b48b39c097ba7b9980210/ASR/tokenized_librispeech/").remove_columns(["token_type_ids"])

# Tokenizer setup: pad with id 0 ("<|padding|>"), right-side padding so
# the transcript prefix stays intact at the start of each sequence.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
tokenizer.padding_side = "right"

# Special tokens delimiting the biasing-word prompt region.
num_new_tokens = tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
# Entity-tag tokens for the entity-detection task.
tokenizer.add_tokens(["<|entity:PER|>", "<|entity:LOC|>", "<|entity:ORG|>", "<|entity|>", "<|detectentities|>"])

# Biasing vocabulary: one rare word per line, normalized with process()
# so it matches the normalization applied to transcripts.
with open("./../prompting/blist/all_rare_words.txt") as fin:
    rarewords = [process(word.strip()) for word in fin]
|
|
def tokenize(element):
    """Build one biased-prompt training example from *element*.

    Sequence layout:
        <|audio:...|>* <|startofprompt|> w1<|sepofprompt|>w2...
        <|endofprompt|><|startoftranscript|> text <|endoftranscript|>

    The prompt always carries exactly ``n_bwords`` biasing words: the
    element's own b_words padded with random rare words when there are
    too few, or a random subset of b_words when there are too many.
    """
    bias_words = element["b_words"]
    deficit = n_bwords - len(bias_words)
    if deficit > 0:
        prompt_words = bias_words + random.sample(rarewords, deficit)
    else:
        prompt_words = random.sample(bias_words, n_bwords)
    random.shuffle(prompt_words)

    parts = [
        "".join(f"<|audio:{tkn}|>" for tkn in element["audio_tokens"]),
        "<|startofprompt|>",
        "<|sepofprompt|>".join(prompt_words),
        "<|endofprompt|><|startoftranscript|>",
        element["text"],
        "<|endoftranscript|>",
    ]

    encoded = tokenizer("".join(parts), truncation=True, max_length=max_length, padding="max_length")
    return {"input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"]}
|
|
# Tokenize LibriSpeech into biased-prompt examples; drop the original
# columns so the result carries only input_ids / attention_mask, aligned
# with the other training datasets.
p_dataset = DatasetDict.load_from_disk("./../libripseech_tokenized")
prompt_dataset = p_dataset.map(
    tokenize, batched=False, remove_columns = p_dataset["train.clean.100"].column_names
)
|
|
print("Total Vocab Size:", len(tokenizer))

# Warm-start from the prompting checkpoint and grow the embedding matrix
# to cover the special tokens added above.
model = GPT2LMHeadModel.from_pretrained("./../models/checkpoint-prompting")
model.resize_token_embeddings(len(tokenizer))

from transformers import DataCollatorForLanguageModeling

# Causal-LM collator: copies input_ids into labels and masks padding
# positions with -100 (matches the ignore_index used in compute_loss).
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
|
|
# Training hyperparameters, passed verbatim to TrainingArguments.
config = {
    "output_dir": "./out",
    "max_steps": 20000,                  # fixed-step budget (no epoch count)
    "per_device_train_batch_size": 5,
    "per_device_eval_batch_size": 5,
    "gradient_accumulation_steps": 1,
    "eval_strategy": "steps",
    "save_strategy": "steps",
    "eval_steps": 500,                   # also triggers WER/CER eval (see GPTTrainer)
    "logging_steps": 1,
    "logging_first_step": True,
    "save_total_limit": 5,
    "load_best_model_at_end": True,
    "save_steps": 1000,
    "lr_scheduler_type": "cosine",
    "learning_rate": 1e-4,
    "warmup_steps": 10,
    "weight_decay": 0.01,
    "report_to": "wandb",
    "fp16": True
}


from argparse import Namespace


# NOTE(review): `args` is never used below — only `train_args` is; kept
# as-is to avoid changing module state other code might rely on.
args = Namespace(**config)
train_args = TrainingArguments(**config)


wandb.init(project="multi_modal_exps", name="entity")
|
|
class GPTTrainer(Trainer):
    """Trainer with an explicit shifted cross-entropy loss and ASR metric
    (WER/CER) logging to wandb after every evaluation pass."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Next-token-prediction loss.

        Shifts logits/labels by one position and computes token-level
        cross entropy, ignoring positions labelled -100 (padding, as
        produced by DataCollatorForLanguageModeling).

        ``**kwargs`` absorbs extra arguments newer transformers versions
        pass to compute_loss (e.g. ``num_items_in_batch``) so the
        override stays compatible.

        Fix over the original: removed the leftover per-step debug
        ``print`` of logits/labels extrema — it spammed logs on every
        step and the torch.max/min calls forced a GPU sync.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Predict token t+1 from tokens <= t.
        labels = labels[:, 1:]
        logits = logits[:, :-1, :]

        loss = F.cross_entropy(
            torch.reshape(logits, (-1, logits.size(-1))),
            torch.reshape(labels, (-1,)),
            ignore_index=-100,
        )

        return (loss, outputs) if return_outputs else loss

    @torch.no_grad()
    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, ignore_keys=None, metric_key_prefix="eval"):
        """Run the standard evaluation loop, then compute and log ASR
        metrics for the model under training."""
        eval_output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)

        # Fix over the original: evaluate the trainer's own model
        # (self.model) instead of the module-level `model` global, which
        # may differ from the (possibly wrapped/moved) model being trained.
        wer, cer, b_wer, u_wer = evaluate_model(self.model)

        wandb.log({
            "Word Error Rate": wer,
            "Char Error Rate": cer,
            "Biased Word Error Rate": b_wer,
            "Unbiased Word Error Rate": u_wer,
        })

        return eval_output
|
|
# Train jointly on three sources: the entity dataset, plain tokenized
# LibriSpeech ASR, and the biased-prompt LibriSpeech dataset built above.
# All three expose only input_ids / attention_mask, so they concatenate.
trainer = GPTTrainer(
    model = model,
    tokenizer = tokenizer,
    args = train_args,
    data_collator = data_collator,
    train_dataset = concatenate_datasets([dataset["train"], asr_dataset["train.clean.100"], prompt_dataset["train.clean.100"]]),
    eval_dataset = concatenate_datasets([dataset["test"], asr_dataset["validation.clean"], prompt_dataset["validation.clean"]]),
)


trainer.train()