"""Evaluate GravityLLM scene-generation outputs on a JSONL validation set.

For each validation row the script generates a Spatial9Scene JSON payload and
scores it on four axes: JSON parseability, schema validity, object-budget
compliance, and anchor-rule compliance. A JSON report is written to disk and
printed to stdout.
"""

import argparse
import json
import re
from pathlib import Path
from typing import Dict, Tuple

import torch
from datasets import load_dataset
from jsonschema import Draft7Validator
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fixed instruction prefix prepended to every dataset prompt before generation.
SYSTEM_PREFIX = (
    "You are GravityLLM, a Spatial9 scene generation model. "
    "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
    "Do not return markdown. Do not explain your answer.\n\n"
)


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the evaluation run."""
    parser = argparse.ArgumentParser(description="Evaluate GravityLLM outputs on a JSONL validation set.")
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--data_file", type=str, default="data/valid.jsonl")
    parser.add_argument("--schema_path", type=Path, default=Path("schemas/scene.schema.json"))
    parser.add_argument("--max_new_tokens", type=int, default=900)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--limit", type=int, default=0, help="0 means evaluate all rows.")
    parser.add_argument("--report_path", type=Path, default=Path("reports/eval_report.json"))
    return parser.parse_args()


def load_model_and_tokenizer(model_dir: str):
    """Load a causal LM (PEFT adapter or plain checkpoint) plus its tokenizer.

    Tries the PEFT loader first; if *model_dir* is not a PEFT checkpoint (or
    the adapter load fails for any reason), falls back to a plain
    ``AutoModelForCausalLM``. Returns ``(model, tokenizer)`` with the model
    in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Generation requires a pad token; reusing EOS is the standard fallback.
        tokenizer.pad_token = tokenizer.eos_token

    # Shared loading kwargs: bf16 + auto device placement only when a GPU exists.
    dtype = torch.bfloat16 if torch.cuda.is_available() else None
    device_map = "auto" if torch.cuda.is_available() else None
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
    except Exception:
        # Not a PEFT checkpoint: load as a full (possibly merged) model instead.
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
    model.eval()
    return model, tokenizer


def format_prompt(raw_prompt: str) -> str:
    """Wrap a raw dataset prompt with the system prefix and an OUTPUT marker.

    A leading "gravityllm:" tag (case-insensitive) is stripped if present.
    """
    raw_prompt = raw_prompt.strip()
    if raw_prompt.lower().startswith("gravityllm:"):
        raw_prompt = raw_prompt.split(":", 1)[1].strip()
    return SYSTEM_PREFIX + raw_prompt + "\n\nOUTPUT:\n"


def extract_first_json(text: str) -> str:
    """Return the outermost ``{...}`` span in *text*, or the stripped text if none.

    The greedy match spans from the first "{" to the last "}", which is the
    right behavior for a single JSON object possibly surrounded by chatter.
    """
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    return match.group(0).strip() if match else text.strip()


def validate_schema(schema, output_text: str) -> Tuple[bool, Dict]:
    """Parse *output_text* as JSON and validate it against *schema*.

    Returns ``(is_valid, parsed_payload)``. Raises ``json.JSONDecodeError``
    when the text is not valid JSON.
    """
    data = json.loads(output_text)
    validator = Draft7Validator(schema)
    errors = sorted(validator.iter_errors(data), key=lambda e: list(e.path))
    return len(errors) == 0, data


def check_budget(input_payload: Dict, scene_payload: Dict) -> bool:
    """True when the scene respects the input's ``max_objects`` budget.

    A missing budget always passes.
    """
    max_objects = input_payload.get("max_objects")
    if max_objects is None:
        return True
    return len(scene_payload.get("objects", [])) <= max_objects


def check_anchor_rules(input_payload: Dict, scene_payload: Dict) -> bool:
    """True when every anchor rule in the input is satisfied by the scene.

    An anchor rule pins ``az_deg``/``el_deg``/``dist_m`` of the scene object
    whose ``class`` matches the rule's ``track_class``. Comparison is exact
    float equality: anchors are expected to be copied verbatim into the scene.
    NOTE(review): if the scene contains duplicate classes, only the last one
    is checked — confirm that duplicates cannot occur upstream.
    """
    objects = {obj["class"]: obj for obj in scene_payload.get("objects", [])}
    for rule in input_payload.get("rules", []):
        if rule.get("type") != "anchor":
            continue
        obj = objects.get(rule.get("track_class"))
        if obj is None:
            return False
        for field in ("az_deg", "el_deg", "dist_m"):
            if float(obj[field]) != float(rule[field]):
                return False
    return True


def generate_scene(model, tokenizer, prompt_text: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Sample one completion for *prompt_text* and extract its JSON payload."""
    inputs = tokenizer(prompt_text, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Strip the prompt at the TOKEN level. Decoding the prompt separately and
    # slicing the decoded string by character count (the previous approach) is
    # fragile: decode normalization can change the prompt's decoded length and
    # corrupt the completion boundary.
    prompt_len = inputs["input_ids"].shape[1]
    raw_completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    return extract_first_json(raw_completion)


def main() -> None:
    """Run the evaluation loop and write/print the aggregate JSON report."""
    args = parse_args()
    schema = json.loads(args.schema_path.read_text(encoding="utf-8"))

    ds = load_dataset("json", data_files=args.data_file, split="train")
    if args.limit > 0:
        ds = ds.select(range(min(args.limit, len(ds))))

    model, tokenizer = load_model_and_tokenizer(args.model_dir)

    total = len(ds)
    parse_ok = 0
    schema_ok = 0
    budget_ok = 0
    anchor_ok = 0
    samples = []
    for row in ds:
        prompt_text = format_prompt(row["prompt"])
        generated = generate_scene(
            model, tokenizer, prompt_text, args.max_new_tokens, args.temperature, args.top_p
        )
        sample_report = {"prompt": row["prompt"], "generated": generated}
        try:
            # validate_schema parses the JSON itself and raises on bad JSON,
            # so a separate json.loads() pre-check would parse everything twice.
            valid, gen_scene = validate_schema(schema, generated)
            parse_ok += 1
            sample_report["schema_valid"] = valid
            if valid:
                schema_ok += 1
                # Reconstruct the input payload from the prompt for rule checks.
                prompt_payload_text = row["prompt"].split("INPUT:\n", 1)[1]
                input_payload = json.loads(prompt_payload_text)
                # Evaluate each rule exactly once; reuse for both tally and report.
                budget_pass = check_budget(input_payload, gen_scene)
                anchor_pass = check_anchor_rules(input_payload, gen_scene)
                if budget_pass:
                    budget_ok += 1
                if anchor_pass:
                    anchor_ok += 1
                sample_report["budget_pass"] = budget_pass
                sample_report["anchor_pass"] = anchor_pass
        except Exception as exc:
            sample_report["error"] = str(exc)
        samples.append(sample_report)

    report = {
        "examples": total,
        "json_parse_rate": round(parse_ok / total, 4) if total else 0.0,
        "schema_valid_rate": round(schema_ok / total, 4) if total else 0.0,
        "budget_pass_rate": round(budget_ok / total, 4) if total else 0.0,
        "anchor_pass_rate": round(anchor_ok / total, 4) if total else 0.0,
        "samples": samples[:10],
    }
    args.report_path.parent.mkdir(parents=True, exist_ok=True)
    args.report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(json.dumps(report, indent=2))


if __name__ == "__main__":
    main()