# File size: 8,991 Bytes
import argparse
import inspect
import json
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    set_seed,
)
# Fixed instruction header that format_prompt() prepends to every prompt.
# One sentence per element keeps the wording easy to review; the final
# element ends with a blank line that separates the header from the
# per-example prompt text.
SYSTEM_PREFIX = "".join(
    (
        "You are GravityLLM, a Spatial9 scene generation model. ",
        "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. ",
        "Do not return markdown. Do not explain your answer. ",
        "Respect hard constraints such as object budgets, anchor positions, and low-end centering.\n\n",
    )
)
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for a GravityLLM fine-tuning run.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            an explicit list lets tests and notebooks drive the parser.

    Returns:
        The parsed options namespace.

    Raises:
        SystemExit: on invalid arguments, including passing both ``--bf16``
            and ``--fp16``. The two mixed-precision modes are mutually
            exclusive, and rejecting the pair here fails before any
            tokenization or model loading happens.
    """
    parser = argparse.ArgumentParser(description="Fine-tune GravityLLM for Spatial9 scene generation.")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--train_file", type=str, default="data/train.jsonl")
    parser.add_argument("--valid_file", type=str, default="data/valid.jsonl")
    parser.add_argument("--output_dir", type=str, default="outputs/GravityLLM-Qwen2.5-1.5B-S9")
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--num_train_epochs", type=float, default=1.0)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--eval_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    parser.add_argument("--warmup_ratio", type=float, default=0.03)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=200)
    parser.add_argument("--eval_steps", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lora", action="store_true", help="Enable LoRA adapters.")
    parser.add_argument("--qlora", action="store_true", help="Enable 4-bit QLoRA training.")
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    # bf16 and fp16 select competing mixed-precision modes. Without this group
    # both could be set: prepare_model() would silently prefer bf16, and the
    # run would only fail later when TrainingArguments rejects the pair.
    precision = parser.add_mutually_exclusive_group()
    precision.add_argument("--bf16", action="store_true")
    precision.add_argument("--fp16", action="store_true")
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    return parser.parse_args(argv)
def load_jsonl(file_path: str):
    """Read a JSON Lines file into a Hugging Face ``Dataset``.

    The generic ``json`` builder loads ``file_path`` as its single data file,
    which is exposed as the ``train`` split; that split is returned directly.
    """
    records = load_dataset(path="json", data_files=file_path, split="train")
    return records
def format_prompt(raw_prompt: str) -> str:
    """Render a raw prompt into the fixed training/inference template.

    Surrounding whitespace is removed. If the prompt starts with a
    ``gravityllm:`` tag (compared case-insensitively), everything up to and
    including the first colon is dropped, and the remainder is stripped again.
    The result is ``SYSTEM_PREFIX`` + prompt + a trailing ``OUTPUT:`` cue,
    after which the completion text follows.
    """
    text = raw_prompt.strip()
    if text.lower().startswith("gravityllm:"):
        # partition() keeps everything after the first colon, i.e. the real prompt body.
        _, _, text = text.partition(":")
        text = text.strip()
    return f"{SYSTEM_PREFIX}{text}\n\nOUTPUT:\n"
def tokenize_example(example: Dict[str, str], tokenizer, max_length: int) -> Dict[str, List[int]]:
    """Tokenize one prompt/completion record for supervised causal-LM training.

    The prompt (rendered by ``format_prompt``) is masked out of the loss with
    label ``-100``. Only the stripped completion plus a single EOS token are
    supervised.

    Truncation gives the completion priority over the prompt. When the pair
    exceeds ``max_length`` tokens, the leading prompt tokens are dropped so
    that the completion and its EOS stay intact. Plain right-truncation has
    two problems:

    * it trains the model on unterminated, invalid JSON;
    * a prompt that fills the whole budget leaves every label at ``-100``,
      so the example supervises nothing.

    Only a completion that is itself longer than ``max_length`` is
    right-truncated; in that case no prompt tokens remain.

    Args:
        example: Record with string ``"prompt"`` and ``"completion"`` fields.
        tokenizer: Hugging Face tokenizer exposing ``eos_token_id``.
        max_length: Maximum number of tokens in the returned sequence.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` (all ones) and ``labels``.
        The three lists have equal length, at most ``max_length``.

    Raises:
        ValueError: if the tokenizer defines no EOS token, because completions
            could not be terminated.
    """
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        raise ValueError("Tokenizer has no EOS token; completions cannot be terminated.")

    prompt_ids = tokenizer(format_prompt(example["prompt"]), add_special_tokens=False)["input_ids"]
    # Append the EOS id itself rather than tokenizing the EOS *text*. The text
    # only maps to one token when the tokenizer matches special-token strings
    # in its input.
    completion_ids = tokenizer(example["completion"].strip(), add_special_tokens=False)["input_ids"]
    completion_ids = list(completion_ids) + [eos_token_id]

    budget = max(max_length, 0)
    prompt_room = max(budget - len(completion_ids), 0)
    if len(prompt_ids) > prompt_room:
        # Keep the prompt's tail. It ends with the "OUTPUT:" cue that sits
        # directly before the completion.
        prompt_ids = prompt_ids[len(prompt_ids) - prompt_room:]
    # Only has an effect when the completion alone exceeds the budget.
    completion_ids = completion_ids[:budget]

    input_ids = prompt_ids + completion_ids
    labels = [-100] * len(prompt_ids) + completion_ids
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }
@dataclass
class CausalDataCollator:
    """Right-pad a list of tokenized examples into one rectangular batch.

    Each example is extended to the longest ``input_ids`` in the batch:

    * ``input_ids`` with ``pad_token_id``;
    * ``attention_mask`` with ``0``;
    * ``labels`` with ``label_pad_token_id``.

    The ``label_pad_token_id`` default of ``-100`` is the same value
    ``tokenize_example`` uses to mask prompt tokens. Every field of the
    returned dict is a ``torch.long`` tensor.
    """

    pad_token_id: int
    label_pad_token_id: int = -100

    def __call__(self, features):
        longest = max(len(feature["input_ids"]) for feature in features)
        # Fill value per output key; dict order fixes the order of the batch keys.
        fill_for = {
            "input_ids": self.pad_token_id,
            "attention_mask": 0,
            "labels": self.label_pad_token_id,
        }

        def _pad_column(key):
            # The pad amount always follows input_ids, so all three keys stay aligned.
            fill = fill_for[key]
            return [
                feature[key] + [fill] * (longest - len(feature["input_ids"]))
                for feature in features
            ]

        return {key: torch.tensor(_pad_column(key), dtype=torch.long) for key in fill_for}
def prepare_model(args: argparse.Namespace):
    """Load the base causal LM and optionally attach LoRA adapters.

    Reads ``args.model``, ``bf16``, ``fp16``, ``qlora``, ``lora``, ``lora_r``,
    ``lora_alpha`` and ``lora_dropout``.

    * ``--qlora``: weights load in 4-bit NF4 with double quantization.
      Compute dtype is bf16 with ``--bf16`` and fp16 otherwise. Placement
      uses ``device_map="auto"``, and the model then goes through
      ``prepare_model_for_kbit_training``.
    * ``--lora`` (without ``--qlora``): input embeddings are made to require
      grad, so adapters train correctly under gradient checkpointing.
    * ``--lora`` or ``--qlora``: LoRA adapters are added to all linear layers
      and the trainable-parameter summary is printed.

    Returns:
        The (possibly PEFT-wrapped) model, with ``config.use_cache`` disabled.
    """
    model_kwargs = {}
    if args.qlora:
        compute_dtype = torch.bfloat16 if args.bf16 else torch.float16
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )
        model_kwargs["device_map"] = "auto"

    # NOTE(security): trust_remote_code=True executes Python code shipped with
    # the model repository. Only point --model at repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else None),
        trust_remote_code=True,
        **model_kwargs,
    )
    # The KV cache only helps generation and is incompatible with the gradient
    # checkpointing that main() enables for training.
    model.config.use_cache = False

    if args.qlora:
        model = prepare_model_for_kbit_training(model)
    elif args.lora:
        # main() trains with gradient_checkpointing=True. With frozen base weights
        # the embedding output does not require grad, so (reentrant) checkpointed
        # blocks produce activations without a grad_fn. The LoRA weights inside
        # them then get no gradient, or backward() fails. The k-bit preparation
        # above handles this for QLoRA; do the same for plain LoRA.
        model.enable_input_require_grads()

    if args.lora or args.qlora:
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules="all-linear",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    return model
def _tokenize_split(dataset, tokenizer, max_length: int, desc: str):
    """Map ``tokenize_example`` over ``dataset``, replacing its raw columns with token fields."""
    return dataset.map(
        lambda row: tokenize_example(row, tokenizer, max_length),
        remove_columns=dataset.column_names,
        desc=desc,
    )


def _build_training_arguments(args: argparse.Namespace, has_eval: bool) -> TrainingArguments:
    """Translate parsed CLI options into the ``TrainingArguments`` for this run."""
    # transformers renamed ``evaluation_strategy`` to ``eval_strategy`` (4.41),
    # and later releases reject the old keyword with a TypeError. Pass whichever
    # name the installed version accepts.
    accepted = inspect.signature(TrainingArguments).parameters
    strategy_key = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
    return TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_strategy="steps",
        bf16=args.bf16,
        fp16=args.fp16,
        report_to="none",
        gradient_checkpointing=True,
        lr_scheduler_type="cosine",
        # Paged (bitsandbytes) AdamW for adapter runs; plain AdamW for full fine-tuning.
        optim="paged_adamw_32bit" if (args.lora or args.qlora) else "adamw_torch",
        max_grad_norm=1.0,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_private_repo=args.hub_private_repo,
        hub_strategy="end" if args.push_to_hub else "every_save",
        **{strategy_key: "steps" if has_eval else "no"},
    )


def _write_json(path: Path, payload) -> None:
    """Write ``payload`` to ``path`` as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main() -> None:
    """Run one GravityLLM fine-tuning job end to end.

    Steps:

    1. Parse CLI options and tokenize the train JSONL file. The validation
       file is tokenized too when it exists; otherwise evaluation is
       disabled.
    2. Build the model via ``prepare_model`` and train it with ``Trainer``.
    3. Save the model to ``--output_dir``. On the main process only, also
       save the tokenizer, ``training_metrics.json`` and ``run_config.json``.
    4. Optionally push the result to the Hugging Face Hub.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    output_dir = Path(args.output_dir)
    set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, trust_remote_code=True)
    # Right padding matches CausalDataCollator, which appends pad tokens at the end.
    tokenizer.padding_side = "right"
    if tokenizer.pad_token is None:
        # Padded positions get attention 0 and label -100, so reusing EOS is harmless.
        tokenizer.pad_token = tokenizer.eos_token

    train_ds = _tokenize_split(load_jsonl(args.train_file), tokenizer, args.max_length, "Tokenizing train set")
    valid_ds = None
    if args.valid_file and Path(args.valid_file).exists():
        valid_ds = _tokenize_split(load_jsonl(args.valid_file), tokenizer, args.max_length, "Tokenizing valid set")

    model = prepare_model(args)
    training_args = _build_training_arguments(args, has_eval=valid_ds is not None)

    # ``Trainer(tokenizer=...)`` was renamed to ``processing_class=`` (4.46) and
    # the old keyword is deprecated; use whichever the installed version accepts.
    tokenizer_key = "processing_class" if "processing_class" in inspect.signature(Trainer).parameters else "tokenizer"
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        data_collator=CausalDataCollator(pad_token_id=tokenizer.pad_token_id),
        **{tokenizer_key: tokenizer},
    )

    train_result = trainer.train()
    # Trainer.save_model decides internally which process writes, so it is called on every rank.
    trainer.save_model(args.output_dir)

    # Plain file writes are limited to the main process so distributed ranks
    # do not race on the same files.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(args.output_dir)
        _write_json(output_dir / "training_metrics.json", train_result.metrics)
        run_meta = vars(args).copy()
        run_meta["train_examples"] = len(train_ds)
        run_meta["valid_examples"] = len(valid_ds) if valid_ds is not None else 0
        _write_json(output_dir / "run_config.json", run_meta)

    if args.push_to_hub:
        # Only LoRA/QLoRA runs produce an adapter; full fine-tuning pushes the whole model.
        artifact = "adapter" if (args.lora or args.qlora) else "model"
        trainer.push_to_hub(commit_message=f"Add GravityLLM fine-tuned {artifact}")
    print(f"Training complete. Artifacts saved to: {args.output_dir}")
if __name__ == "__main__":
    # Executed as a script rather than imported: start the fine-tuning run.
    main()