"""Fine-tune GravityLLM for Spatial9 scene generation.

Trains a causal LM (optionally with LoRA/QLoRA adapters) to map music
constraints and stem features to Spatial9Scene JSON.
"""

import argparse
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    set_seed,
)
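
# Each line of --train_file / --valid_file is expected to be a JSON object with
# "prompt" and "completion" fields, e.g. (values illustrative, not real data):
#   {"prompt": "gravityllm: tempo=122 stems=drums,bass", "completion": "{\"objects\": [...]}"}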

# Prepended to every training prompt; the trailing "\n\n" separates it from the task body.
SYSTEM_PREFIX = (
    "You are GravityLLM, a Spatial9 scene generation model. "
    "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
    "Do not return markdown. Do not explain your answer. "
    "Respect hard constraints such as object budgets, anchor positions, and low-end centering.\n\n"
)

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Fine-tune GravityLLM for Spatial9 scene generation.")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--train_file", type=str, default="data/train.jsonl")
    parser.add_argument("--valid_file", type=str, default="data/valid.jsonl")
    parser.add_argument("--output_dir", type=str, default="outputs/GravityLLM-Qwen2.5-1.5B-S9")
    parser.add_argument("--max_length", type=int, default=2048)

    # Optimization hyperparameters.
    parser.add_argument("--num_train_epochs", type=float, default=1.0)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--eval_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    parser.add_argument("--warmup_ratio", type=float, default=0.03)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=200)
    parser.add_argument("--eval_steps", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)

    # Adapter options.
    parser.add_argument("--lora", action="store_true", help="Enable LoRA adapters.")
    parser.add_argument("--qlora", action="store_true", help="Enable 4-bit QLoRA training.")
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)

    # Mixed precision.
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--fp16", action="store_true")

    # Hugging Face Hub upload.
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    return parser.parse_args()
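
# Example invocation for a single-GPU QLoRA run (flags defined above; the
# script name is just a placeholder for however this file is saved):
#   python finetune.py --qlora --bf16 --train_file data/train.jsonl --valid_file data/valid.jsonl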

def load_jsonl(file_path: str):
    # The "json" builder reads JSON Lines files natively.
    return load_dataset("json", data_files=file_path, split="train")

def format_prompt(raw_prompt: str) -> str:
    raw_prompt = raw_prompt.strip()
    # Tolerate records whose prompt still carries a leading "gravityllm:" tag.
    if raw_prompt.lower().startswith("gravityllm:"):
        raw_prompt = raw_prompt.split(":", 1)[1].strip()
    return SYSTEM_PREFIX + raw_prompt + "\n\nOUTPUT:\n"
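
# For example: format_prompt("gravityllm: tempo=120") returns
# SYSTEM_PREFIX + "tempo=120\n\nOUTPUT:\n".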

def tokenize_example(example: Dict[str, str], tokenizer, max_length: int) -> Dict[str, List[int]]:
    prompt_text = format_prompt(example["prompt"])
    completion_text = example["completion"].strip()

    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    completion_ids = tokenizer(completion_text + tokenizer.eos_token, add_special_tokens=False)["input_ids"]

    input_ids = prompt_ids + completion_ids
    # Mask the prompt with -100 so the loss is computed only on the completion.
    labels = [-100] * len(prompt_ids) + completion_ids

    # Hard truncation: an over-long example loses the tail of its completion
    # (including EOS), so max_length should comfortably exceed typical scene JSON.
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]

    attention_mask = [1] * len(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
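
# Worked example with made-up token ids: prompt ids [5, 9] plus completion ids
# [7, 2, eos] give labels [-100, -100, 7, 2, eos], so only the completion
# positions contribute to the loss.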

@dataclass
class CausalDataCollator:
    """Right-pads a batch to its longest sequence; labels are padded with -100 so padding is ignored by the loss."""

    pad_token_id: int
    label_pad_token_id: int = -100

    def __call__(self, features):
        max_len = max(len(f["input_ids"]) for f in features)

        input_ids = []
        attention_mask = []
        labels = []

        for f in features:
            pad_len = max_len - len(f["input_ids"])
            input_ids.append(f["input_ids"] + [self.pad_token_id] * pad_len)
            attention_mask.append(f["attention_mask"] + [0] * pad_len)
            labels.append(f["labels"] + [self.label_pad_token_id] * pad_len)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
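
# -100 works as the label pad because it is the default ignore_index of
# torch.nn.CrossEntropyLoss, which the model's causal-LM loss uses.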

def prepare_model(args: argparse.Namespace):
    model_kwargs = {}
    if args.qlora:
        compute_dtype = torch.bfloat16 if args.bf16 else torch.float16
        # 4-bit NF4 quantization with double quantization (the standard QLoRA recipe).
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )
        model_kwargs["device_map"] = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else None),
        trust_remote_code=True,
        **model_kwargs,
    )
    # The KV cache is useless during training and conflicts with gradient checkpointing.
    model.config.use_cache = False

    if args.qlora:
        # Casts norm/embedding layers and enables input grads for k-bit training.
        model = prepare_model_for_kbit_training(model)
    elif args.lora:
        # With frozen base weights, gradient checkpointing needs the input
        # embeddings to require grad, or backward raises a no-grad error.
        model.enable_input_require_grads()

    if args.lora or args.qlora:
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules="all-linear",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    return model
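
# Note on the LoRA defaults: adapter updates are scaled by lora_alpha / lora_r,
# so r=16 with alpha=32 applies them at an effective scale of 2.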

def main() -> None:
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, trust_remote_code=True)
    tokenizer.padding_side = "right"
    if tokenizer.pad_token is None:
        # Reusing EOS as the pad token is safe here: pad positions carry -100 labels.
        tokenizer.pad_token = tokenizer.eos_token

    train_ds = load_jsonl(args.train_file)
    valid_ds = load_jsonl(args.valid_file) if args.valid_file and Path(args.valid_file).exists() else None

    train_ds = train_ds.map(
        lambda row: tokenize_example(row, tokenizer, args.max_length),
        remove_columns=train_ds.column_names,
        desc="Tokenizing train set",
    )
    if valid_ds is not None:
        valid_ds = valid_ds.map(
            lambda row: tokenize_example(row, tokenizer, args.max_length),
            remove_columns=valid_ds.column_names,
            desc="Tokenizing valid set",
        )

    model = prepare_model(args)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        evaluation_strategy="steps" if valid_ds is not None else "no",
        save_strategy="steps",
        bf16=args.bf16,
        fp16=args.fp16,
        report_to="none",
        gradient_checkpointing=True,
        lr_scheduler_type="cosine",
        # bitsandbytes' paged AdamW avoids optimizer-state memory spikes during adapter training.
        optim="paged_adamw_32bit" if (args.lora or args.qlora) else "adamw_torch",
        max_grad_norm=1.0,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_private_repo=args.hub_private_repo,
        # hub_strategy is ignored when push_to_hub is False, so set it unconditionally.
        hub_strategy="end",
    )
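
    # With the defaults (per-device batch 1, accumulation 16), each optimizer
    # step sees an effective batch of 16 sequences per device.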

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        data_collator=CausalDataCollator(pad_token_id=tokenizer.pad_token_id),
        tokenizer=tokenizer,
    )

    train_result = trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    # Persist final training metrics alongside the checkpoint.
    metrics = train_result.metrics
    with open(Path(args.output_dir) / "training_metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    # Record the exact run configuration for reproducibility.
    run_meta = vars(args).copy()
    run_meta["train_examples"] = len(train_ds)
    run_meta["valid_examples"] = len(valid_ds) if valid_ds is not None else 0
    with open(Path(args.output_dir) / "run_config.json", "w", encoding="utf-8") as f:
        json.dump(run_meta, f, indent=2)

    if args.push_to_hub:
        trainer.push_to_hub(commit_message="Add GravityLLM fine-tuned adapter")
    print(f"Training complete. Artifacts saved to: {args.output_dir}")


if __name__ == "__main__":
    main()
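
# To sample scenes from a finished adapter run later (illustrative; assumes the
# default --output_dir and a LoRA/QLoRA run):
#   from peft import AutoPeftModelForCausalLM
#   model = AutoPeftModelForCausalLM.from_pretrained("outputs/GravityLLM-Qwen2.5-1.5B-S9")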