import os

import torch

from transformers import (
    GPT2Config,
    RobertaTokenizerFast,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

import datasets
from datasets import IterableDataset, disable_caching
disable_caching()

from conditional_gpt2_model import ConditionalGPT2LMHeadModel


# encoder checkpoint whose hidden states condition the GPT-2 decoder
ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"
TOKENIZER_MAX_LEN = 256

# each .hf file on disk is split into this many contiguous sub-shards while streaming
DATA_SUBSHARDS = 10

# fill these in before running
DATA_DIR = None          # directory holding the precomputed .hf dataset files
TRAINER_SAVE_DIR = None  # where Trainer writes checkpoints

assert DATA_DIR is not None, "data directory must be specified"
assert TRAINER_SAVE_DIR is not None, "trainer save directory must be specified"

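
# The training loop below expects each .hf file to carry 'input_ids' (tokenized
# SMILES) and 'encoder_hidden_states' (a precomputed embedding per molecule).
# The helper below is a minimal sketch of how such a file could be produced; it
# is never called in this script, and the masked-mean pooling is an assumption,
# not necessarily how the original shards were built.
def precompute_shard_sketch(smiles_list, save_path):
    from transformers import RobertaModel

    encoder = RobertaModel.from_pretrained(ENCODER_MODEL_NAME).eval()
    tok = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME,
                                               model_max_length=TOKENIZER_MAX_LEN)
    batch = tok(smiles_list, padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        hidden = encoder(**batch).last_hidden_state        # (batch, seq, d_model)

    # masked mean over token states -> one conditioning vector per molecule
    mask = batch['attention_mask'].unsqueeze(-1)           # (batch, seq, 1)
    pooled = (hidden * mask).sum(1) / mask.sum(1)          # (batch, d_model)

    datasets.Dataset.from_dict({
        'input_ids': batch['input_ids'].tolist(),
        'encoder_hidden_states': pooled.tolist(),
    }).save_to_disk(save_path)
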

def gen_dataset():
    """Stream training examples from the .hf files in DATA_DIR, one shard at a time."""
    data_filenames = sorted(i for i in os.listdir(DATA_DIR) if i.endswith('.hf'))

    for filename in data_filenames:
        dataset = datasets.Dataset.load_from_disk(f'{DATA_DIR}/{filename}')

        # drop everything except the two columns the model consumes
        keep_cols = ['input_ids', 'encoder_hidden_states']
        dataset = dataset.remove_columns([i for i in dataset.column_names
                                          if i not in keep_cols]).with_format("torch")

        # split each file into contiguous sub-shards so only a slice of it is
        # materialized at a time
        shards = [dataset.shard(num_shards=DATA_SUBSHARDS, index=index, contiguous=True)
                  for index in range(DATA_SUBSHARDS)]

        for shard in shards:
            for example in shard:
                # prepend a sequence axis so the conditioning embedding collates
                # into (batch, 1, d_model) downstream
                example['encoder_hidden_states'] = example['encoder_hidden_states'][None, :]
                yield example


# wrap the generator as a streaming dataset so the corpus never has to fit in memory
dataset = IterableDataset.from_generator(gen_dataset)
dataset = dataset.with_format("torch")


tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME,
                                                 model_max_length=TOKENIZER_MAX_LEN)
# mlm=False gives causal-LM collation: labels are a copy of input_ids with
# padding positions set to -100
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
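
# Optional sanity check (commented out so the script doesn't consume the stream
# at import time): collate a few examples and confirm keys and shapes. The extra
# 'encoder_hidden_states' key is carried through collation alongside the
# input_ids / attention_mask / labels the collator itself builds.
# it = iter(dataset)
# sample_batch = collator([next(it) for _ in range(4)])
# print({k: tuple(v.shape) for k, v in sample_batch.items()})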

# small GPT-2 decoder; add_cross_attention inserts a cross-attention block into
# every layer so the decoder can attend to the precomputed encoder states
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=TOKENIZER_MAX_LEN,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    n_layer=6,
    n_head=8,
    add_cross_attention=True,
)

# GPT-2 LM head variant (defined in conditional_gpt2_model.py) wired so each
# batch's encoder_hidden_states reach the cross-attention layers
model = ConditionalGPT2LMHeadModel(config)

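
# Shape sketch for one training batch (assuming one pooled embedding per
# molecule, as in the precompute sketch above):
#   input_ids:             (batch, seq)         decoder SMILES tokens
#   encoder_hidden_states: (batch, 1, d_model)  conditioning vector, seq-len 1
#   labels:                (batch, seq)         input_ids with pads set to -100
# Each of the 6 decoder layers runs self-attention over the tokens, then
# cross-attention over the single encoder state.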

args = TrainingArguments(
    output_dir=TRAINER_SAVE_DIR,
    per_device_train_batch_size=192,
    logging_steps=25,
    gradient_accumulation_steps=8,   # effective batch: 192 * 8 = 1536 sequences per optimizer step
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1000,
    lr_scheduler_type="cosine",
    learning_rate=1e-5,
    save_steps=200,
    save_total_limit=30,
    fp16=True,
    push_to_hub=False,
    max_steps=50000,   # required: the streaming dataset has no length, so epochs can't size the run
)


trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collator,
    train_dataset=dataset,
)

trainer.train()
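
# After training, conditional sampling could look like the sketch below
# (hypothetical usage: it assumes ConditionalGPT2LMHeadModel.generate forwards
# encoder_hidden_states into the cross-attention layers, and that `pooled` is a
# (1, d_model) embedding for one molecule, computed as in the precompute sketch).
# cond = pooled[:, None, :]                      # (1, 1, d_model)
# out = model.generate(
#     torch.tensor([[tokenizer.bos_token_id]]),
#     encoder_hidden_states=cond,
#     do_sample=True,
#     max_length=TOKENIZER_MAX_LEN,
# )
# print(tokenizer.decode(out[0], skip_special_tokens=True))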