# GravityLLM / evaluate.py
# Uploaded by lzanardos9 ("Upload 20 files", commit b7720f0, verified).
import argparse
import json
import re
from pathlib import Path
from typing import Dict, Tuple
import torch
from datasets import load_dataset
from jsonschema import Draft7Validator
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
# Instruction header prepended to every prompt: tells the model to emit raw
# Spatial9Scene JSON only (no markdown fences, no explanations).
SYSTEM_PREFIX = (
"You are GravityLLM, a Spatial9 scene generation model. "
"Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
"Do not return markdown. Do not explain your answer.\n\n"
)
def parse_args() -> argparse.Namespace:
    """Parse command-line options for one evaluation run."""
    cli = argparse.ArgumentParser(description="Evaluate GravityLLM outputs on a JSONL validation set.")
    cli.add_argument("--model_dir", type=str, required=True)
    cli.add_argument("--data_file", type=str, default="data/valid.jsonl")
    cli.add_argument("--schema_path", type=Path, default=Path("schemas/scene.schema.json"))
    cli.add_argument("--max_new_tokens", type=int, default=900)
    cli.add_argument("--temperature", type=float, default=0.2)
    cli.add_argument("--top_p", type=float, default=0.9)
    cli.add_argument("--limit", type=int, default=0, help="0 means evaluate all rows.")
    cli.add_argument("--report_path", type=Path, default=Path("reports/eval_report.json"))
    return cli.parse_args()
def load_model_and_tokenizer(model_dir: str):
    """Load tokenizer and model from *model_dir*.

    Tries a PEFT adapter checkpoint first; on any failure, falls back to a
    plain causal-LM checkpoint. On CUDA machines the model loads in bfloat16
    with automatic device placement; on CPU both settings stay at defaults.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Generation needs a pad token; reuse EOS when none is configured.
        tokenizer.pad_token = tokenizer.eos_token

    use_gpu = torch.cuda.is_available()
    load_kwargs = dict(
        torch_dtype=torch.bfloat16 if use_gpu else None,
        device_map="auto" if use_gpu else None,
        trust_remote_code=True,
    )
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    except Exception:
        model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    model.eval()
    return model, tokenizer
def format_prompt(raw_prompt: str) -> str:
    """Wrap a dataset prompt with the system prefix and an OUTPUT marker.

    A leading "gravityllm:" tag (any case) is stripped before wrapping.
    """
    text = raw_prompt.strip()
    if text.lower().startswith("gravityllm:"):
        # Drop everything up to and including the first colon.
        _, _, remainder = text.partition(":")
        text = remainder.strip()
    return f"{SYSTEM_PREFIX}{text}\n\nOUTPUT:\n"
def extract_first_json(text: str) -> str:
    """Return the first balanced top-level JSON object embedded in *text*.

    The previous implementation used a greedy regex (``\\{.*\\}`` with
    DOTALL), which spans from the first ``{`` to the *last* ``}`` in the
    text. Any trailing brace -- a second JSON object, stray prose -- made the
    extracted span invalid JSON. This version scans brace depth, ignoring
    braces that appear inside JSON string literals (including escapes), and
    stops at the matching close brace of the first object.

    Falls back to the stripped input when no balanced object is found, which
    matches the old no-match behavior.
    """
    start = text.find("{")
    if start == -1:
        return text.strip()
    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            # Inside a JSON string: only track escapes and the closing quote.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1].strip()
    # Unbalanced braces: return the stripped input as a best effort.
    return text.strip()
def validate_schema(schema, output_text: str) -> Tuple[bool, Dict]:
    """Parse *output_text* as JSON and check it against *schema*.

    Returns (is_valid, parsed_payload). Propagates ``json.JSONDecodeError``
    when the text is not valid JSON; the caller handles that case.
    """
    payload = json.loads(output_text)
    error_list = sorted(
        Draft7Validator(schema).iter_errors(payload),
        key=lambda err: list(err.path),
    )
    return not error_list, payload
def check_budget(input_payload: Dict, scene_payload: Dict) -> bool:
    """True when the scene respects the optional ``max_objects`` budget."""
    limit = input_payload.get("max_objects")
    if limit is None:
        # No budget constraint requested for this input.
        return True
    object_count = len(scene_payload.get("objects", []))
    return object_count <= limit
def check_anchor_rules(input_payload: Dict, scene_payload: Dict) -> bool:
    """True when every ``anchor`` rule in the input is satisfied by the scene.

    An anchor rule pins a track class to exact ``az_deg``/``el_deg``/``dist_m``
    coordinates; the generated object for that class must exist and match all
    three fields exactly (after float conversion).

    Fix over the original: model output is untrusted, and an object or rule
    missing a field (or holding a non-numeric value) used to raise KeyError /
    ValueError, which the caller recorded as a generic error. Those cases now
    count as a failed anchor check instead of raising.

    Note: if several objects share a class, only the last one is considered
    (dict construction keeps the last duplicate) -- same as the original.
    """
    objects_by_class = {obj.get("class"): obj for obj in scene_payload.get("objects", [])}
    for rule in input_payload.get("rules", []):
        if rule.get("type") != "anchor":
            continue
        klass = rule.get("track_class")
        if klass is None:
            # Malformed rule: nothing sensible can be anchored.
            return False
        obj = objects_by_class.get(klass)
        if obj is None:
            return False
        for field in ("az_deg", "el_deg", "dist_m"):
            try:
                if float(obj[field]) != float(rule[field]):
                    return False
            except (KeyError, TypeError, ValueError):
                # Missing or non-numeric coordinate: anchor not satisfied.
                return False
    return True
def generate_scene(model, tokenizer, prompt_text: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Sample one completion for *prompt_text* and return its first JSON object.

    Samples with the given temperature/top_p and strips the prompt from the
    output before handing the completion to ``extract_first_json``.

    Fix over the original: the completion was recovered by re-decoding the
    prompt tokens and slicing the decoded string by ``len(prompt_prefix)``.
    ``decode()`` does not always round-trip the prompt text exactly (special
    tokens, whitespace normalization), so that character offset could drop or
    duplicate characters at the boundary. Slicing the generated token ids at
    the prompt length is exact.
    """
    inputs = tokenizer(prompt_text, return_tensors="pt")
    if torch.cuda.is_available():
        # Move tensors to wherever device_map placed the model.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # generate() returns prompt + completion; keep only the new tokens.
    prompt_length = inputs["input_ids"].shape[1]
    completion_ids = outputs[0][prompt_length:]
    raw_completion = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    return extract_first_json(raw_completion)
def main() -> None:
    """Run the evaluation loop and write an aggregate JSON report.

    For each validation row: format the prompt, sample a scene, then score it
    on four metrics -- JSON parseability, schema validity, object-budget
    compliance, and anchor-rule compliance. Aggregate rates plus the first 10
    per-sample records are written to ``--report_path`` and printed.

    Fix over the original: ``check_budget`` and ``check_anchor_rules`` were
    each called twice per row (once for the counter, once for the sample
    record); each is now evaluated once and reused. The unused
    ``gen_data = json.loads(...)`` local is dropped (the call still runs as
    the parse check).
    """
    args = parse_args()
    schema = json.loads(args.schema_path.read_text(encoding="utf-8"))
    ds = load_dataset("json", data_files=args.data_file, split="train")
    if args.limit > 0:
        ds = ds.select(range(min(args.limit, len(ds))))
    model, tokenizer = load_model_and_tokenizer(args.model_dir)

    total = len(ds)
    parse_ok = 0
    schema_ok = 0
    budget_ok = 0
    anchor_ok = 0
    samples = []
    for row in ds:
        prompt_text = format_prompt(row["prompt"])
        generated = generate_scene(model, tokenizer, prompt_text, args.max_new_tokens, args.temperature, args.top_p)
        sample_report = {"prompt": row["prompt"], "generated": generated}
        try:
            json.loads(generated)  # parse check only; validate_schema re-parses
            parse_ok += 1
            valid, gen_scene = validate_schema(schema, generated)
            if valid:
                schema_ok += 1
                # Reconstruct input payload from prompt for simple rule checks.
                prompt_payload_text = row["prompt"].split("INPUT:\n", 1)[1]
                input_payload = json.loads(prompt_payload_text)
                budget_pass = check_budget(input_payload, gen_scene)
                anchor_pass = check_anchor_rules(input_payload, gen_scene)
                if budget_pass:
                    budget_ok += 1
                if anchor_pass:
                    anchor_ok += 1
                sample_report["schema_valid"] = True
                sample_report["budget_pass"] = budget_pass
                sample_report["anchor_pass"] = anchor_pass
            else:
                sample_report["schema_valid"] = False
        except Exception as exc:
            # Boundary handler: record any per-row failure (bad JSON, missing
            # INPUT marker, etc.) in the sample report instead of aborting.
            sample_report["error"] = str(exc)
        samples.append(sample_report)

    report = {
        "examples": total,
        "json_parse_rate": round(parse_ok / total, 4) if total else 0.0,
        "schema_valid_rate": round(schema_ok / total, 4) if total else 0.0,
        "budget_pass_rate": round(budget_ok / total, 4) if total else 0.0,
        "anchor_pass_rate": round(anchor_ok / total, 4) if total else 0.0,
        "samples": samples[:10],  # cap sample detail to keep the report small
    }
    args.report_path.parent.mkdir(parents=True, exist_ok=True)
    args.report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(json.dumps(report, indent=2))
if __name__ == "__main__":
main()