r"""
Convert a plain-text results summary into a LaTeX table.

Features
- Parses dataset blocks introduced by a header line such as: --- Dataset: QM9 ---
- Reads Val metrics as "mean ± std"
- Reads Test metrics as either "mean [low, high]" (confidence interval) or "mean ± std"
- Option --test-intervals {ci, pm}:
    ci -> keep "mean [low, high]" strings (the two Test columns become plain text columns)
    pm -> convert each CI to "mean ± half-width" so the values fit siunitx S columns
- One \multirow cell per dataset; booktabs rules; siunitx S columns; optional renaming and bolding

Usage
    python latex_table_from_txt.py \
        --input results.txt --output table.tex \
        --test-intervals pm \
        --rename "polyatomic=PACTNet (ECC)" \
        --bold-contains "PACTNet" \
        --val-dec 3 --test-dec 4
"""
|
|
import argparse
import re
from pathlib import Path

import pandas as pd
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding every option below. ``--booktabs`` and
        ``--no-booktabs`` share the ``booktabs`` destination (default True).
    """
    parser = argparse.ArgumentParser(
        description="Convert TXT performance summary to LaTeX table (booktabs + siunitx + multirow)."
    )
    # (flags, keyword arguments) in registration order. The order matters for
    # --help output and for the shared ``booktabs`` destination.
    option_specs = [
        (("--input", "-i"), dict(type=Path, required=True, help="Input TXT file")),
        (("--output", "-o"), dict(type=Path, required=True, help="Output .tex file")),
        (
            ("--caption",),
            dict(
                default="Comprehensive performance comparison across all datasets and models.",
                help="LaTeX caption",
            ),
        ),
        (("--label",), dict(default="tab:full_results", help="LaTeX label")),
        (
            ("--val-dec",),
            dict(type=int, default=3, help="Decimal places for Val metrics (mean & std)"),
        ),
        (
            ("--test-dec",),
            dict(type=int, default=4, help="Decimal places for Test metrics (mean & std)"),
        ),
        (
            ("--no-fixed-decimals",),
            dict(
                action="store_true",
                help="Use raw decimals as provided (don't round to fixed places)",
            ),
        ),
        (
            ("--table-formats",),
            dict(
                nargs=4,
                default=["2.3(4)", "2.3(4)", "2.4(4)", "2.4(4)"],
                help="siunitx table-format for Val RMSE, Val MAE, Test RMSE, Test MAE",
            ),
        ),
        (("--font-size",), dict(default="\\small", help="LaTeX font size inside the table")),
        (("--width",), dict(default="\\textwidth", help="Width for \\resizebox")),
        (("--no-resize",), dict(action="store_true", help="Disable \\resizebox wrapper")),
        (("--booktabs",), dict(action="store_true", default=True, help="Use booktabs rules")),
        (
            ("--no-booktabs",),
            dict(dest="booktabs", action="store_false", help="Disable booktabs rules"),
        ),
        (
            ("--test-intervals",),
            dict(
                choices=["ci", "pm"],
                default="pm",
                help="For Test metrics with CIs: keep CIs (ci) or convert to ± half-width (pm)",
            ),
        ),
        (
            ("--bold-contains",),
            dict(default=None, help="Regex to bold any row where model/rep cell matches"),
        ),
        (
            ("--rename",),
            dict(
                nargs="*",
                default=[],
                help='Rename patterns like old=new (regex on the "Model (Rep.)" cell)',
            ),
        ),
        (("--dataset-order",), dict(nargs="*", help="Optional explicit dataset order")),
        (
            ("--sort-by",),
            dict(
                nargs="*",
                default=None,
                help="Sort keys within each dataset, e.g., --sort-by model representation",
            ),
        ),
        (
            ("--ascending",),
            dict(nargs="*", type=int, help="Ascending flags matching --sort-by, e.g. 1 0"),
        ),
    ]
    for flags, kwargs in option_specs:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
|
|
|
|
def fmt_unc(mean, std, fixed_decimals: bool, dec_places: int) -> str:
    """Format a ``mean ± std`` pair as a LaTeX/siunitx table cell.

    Args:
        mean: Central value; ``None`` or NaN means the metric is missing.
        std: Uncertainty; ``None`` or NaN means only the mean is known.
        fixed_decimals: When True, round both numbers to ``dec_places``
            decimals. When False, print them as given with trailing zeros
            trimmed (scientific notation is left untouched).
        dec_places: Number of decimals used when ``fixed_decimals`` is True.

    Returns:
        ``"<mean> \\pm <std>"``, just ``"<mean>"`` when ``std`` is missing, or
        a brace-protected em dash when ``mean`` is missing.
    """

    def tidy(x) -> str:
        s = f"{x}"
        if "e" in s or "E" in s:
            return s
        if "." in s:
            s = s.rstrip("0").rstrip(".")
        return s

    def num(x) -> str:
        return f"{float(x):.{dec_places}f}" if fixed_decimals else tidy(x)

    if pd.isna(mean):
        # Braces make siunitx S columns treat the dash as text instead of
        # trying to parse it as a number (same as the brace-wrapped headers).
        return r"{\textemdash}"
    if pd.isna(std):
        # A bare mean (e.g. "1.234", which parse_metric accepts) is still shown.
        return num(mean)
    return f"{num(mean)} \\pm {num(std)}"
|
|
|
|
def build_model_rep(name: str) -> tuple[str, str, str]:
    """Turn a raw ``model_rep`` identifier into display strings.

    ``'gat_ecfp'`` becomes ``('GAT', 'ECFP', 'GAT (ECFP)')``. Only the first
    underscore separates the model from the representation. Without one, the
    representation is ``''`` and the label is just the model. The model part
    is upper-cased only when purely alphabetic (``'gcn2'`` stays as is); the
    representation is always upper-cased.
    """
    model, _sep, rep = name.strip().partition("_")
    if model.isalpha():
        model = model.upper()
    rep = rep.upper()
    if not rep:
        return model, rep, model
    return model, rep, f"{model} ({rep})"
|
|
|
|
def apply_renames(s: str, mapping: dict) -> str:
    """Apply each ``pattern -> replacement`` regex substitution in *mapping* to *s*.

    Substitutions run in the mapping's iteration order, each on the output of
    the previous one, so later patterns see earlier replacements.
    """
    renamed = s
    for pattern, replacement in mapping.items():
        renamed = re.compile(pattern).sub(replacement, renamed)
    return renamed
|
|
|
|
# A decimal number: optional sign, digits with an optional fraction (or a bare
# fraction such as ".5"), and an optional exponent such as "e-05".
_NUM = r"[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?"
# "mean ± std"; "+/-" and "+-" are accepted as ASCII spellings of "±".
_PM_RE = re.compile(rf"({_NUM})\s*(?:±|\+/-|\+-)\s*({_NUM})")
# "mean [low, high]"
_CI_RE = re.compile(rf"({_NUM})\s*\[\s*({_NUM})\s*,\s*({_NUM})\s*\]")
# A bare "mean" and nothing else.
_PLAIN_RE = re.compile(rf"({_NUM})$")


def parse_metric(cell: str):
    """Parse one metric cell from the TXT summary.

    Accepts:
      - ``'1.234 ± 0.056'`` (also ``'+/-'`` or ``'+-'`` instead of ``'±'``)
      - ``'1.234 [1.111, 1.345]'``
      - ``'1.234'``
    Numbers may use scientific notation (``'1.2e-05'``) or omit the leading
    zero (``'.5'``).

    Returns:
        dict with keys ``mean``, ``std``, ``ci_low``, ``ci_high``. Parts not
        present in the cell are ``None``; all are ``None`` if nothing parses.
    """
    s = cell.strip()
    result = {"mean": None, "std": None, "ci_low": None, "ci_high": None}

    m = _PM_RE.match(s)
    if m:
        result.update(mean=float(m.group(1)), std=float(m.group(2)))
        return result

    m = _CI_RE.match(s)
    if m:
        result.update(
            mean=float(m.group(1)),
            ci_low=float(m.group(2)),
            ci_high=float(m.group(3)),
        )
        return result

    m = _PLAIN_RE.match(s)
    if m:
        result["mean"] = float(m.group(1))
    return result
|
|
|
|
# Columns produced by parse_txt, in output order (also used for empty input so
# downstream code can still select "dataset" etc.).
_RESULT_COLUMNS = [
    "dataset",
    "model",
    "representation",
    "label",
    "val_rmse_mean",
    "val_rmse_std",
    "val_mae_mean",
    "val_mae_std",
    "test_rmse_mean",
    "test_rmse_std",
    "test_rmse_ci_low",
    "test_rmse_ci_high",
    "test_mae_mean",
    "test_mae_std",
    "test_mae_ci_low",
    "test_mae_ci_high",
]
_DATASET_HEADER_RE = re.compile(r"---\s*Dataset:\s*(.+?)\s*---")
_TABLE_HEADER_RE = re.compile(r"\|\s*Val RMSE")
# Rule lines once "|" is removed: dashes, plus "=", "+" and ":" used by
# ASCII/Markdown table styles.
_SEPARATOR_RE = re.compile(r"[-=+:\s]{5,}$")


def _extract_table_lines(body: str) -> list[str]:
    """Return the data rows that follow the '| Val RMSE' header in one dataset block."""
    table_lines = []
    after_header = False
    for line in body.splitlines():
        if _TABLE_HEADER_RE.search(line):
            after_header = True
            continue
        if not after_header:
            continue
        # The table ends at the statistics section or at the first blank line.
        if line.strip().startswith("--- Statistical") or not line.strip():
            break
        if _SEPARATOR_RE.match(line.replace("|", "")):
            continue
        if "|" in line:
            table_lines.append(line)
    return table_lines


def _row_from_line(dataset: str, line: str):
    """Convert one 'name | val_rmse | val_mae | test_rmse | test_mae' row to a dict, or None."""
    # Strip outer pipes so Markdown-style rows ("| gat_ecfp | ... |") keep the
    # model name in the first cell instead of an empty string.
    parts = [p.strip() for p in line.strip().strip("|").split("|")]
    if len(parts) < 5:
        return None
    name, *cells = parts[:5]
    val_rmse, val_mae, test_rmse, test_mae = (parse_metric(c) for c in cells)
    model, rep, label = build_model_rep(name)
    return {
        "dataset": dataset,
        "model": model,
        "representation": rep,
        "label": label,
        "val_rmse_mean": val_rmse["mean"],
        "val_rmse_std": val_rmse["std"],
        "val_mae_mean": val_mae["mean"],
        "val_mae_std": val_mae["std"],
        "test_rmse_mean": test_rmse["mean"],
        "test_rmse_std": test_rmse["std"],
        "test_rmse_ci_low": test_rmse["ci_low"],
        "test_rmse_ci_high": test_rmse["ci_high"],
        "test_mae_mean": test_mae["mean"],
        "test_mae_std": test_mae["std"],
        "test_mae_ci_low": test_mae["ci_low"],
        "test_mae_ci_high": test_mae["ci_high"],
    }


def parse_txt(path: Path) -> pd.DataFrame:
    """Parse a TXT performance summary into a tidy DataFrame.

    The file is split at ``--- Dataset: <name> ---`` headers. Inside each
    block, the rows after the line containing ``| Val RMSE`` are read until a
    blank line or a ``--- Statistical`` line. Bytes that are not valid UTF-8
    are dropped.

    Returns:
        One row per model line, with the columns listed in ``_RESULT_COLUMNS``
        (an empty frame with those columns when nothing is found). Missing
        metric parts are ``None``/NaN.
    """
    text = path.read_text(encoding="utf-8", errors="ignore")
    headers = list(_DATASET_HEADER_RE.finditer(text))

    rows = []
    for i, header in enumerate(headers):
        end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
        dataset = header.group(1).strip()
        for line in _extract_table_lines(text[header.start():end]):
            row = _row_from_line(dataset, line)
            if row is not None:
                rows.append(row)
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)
|
|
|
|
# Characters that are special in LaTeX text mode, mapped to their escaped form.
_LATEX_SPECIALS = {"&": r"\&", "%": r"\%", "#": r"\#", "_": r"\_"}


def _latex_escape(text) -> str:
    """Escape LaTeX text-mode specials in a plain-text cell.

    Strings that already contain a backslash or ``$`` are treated as
    hand-written LaTeX (e.g. produced via ``--rename``) and returned unchanged.
    """
    text = str(text)
    if "\\" in text or "$" in text:
        return text
    return "".join(_LATEX_SPECIALS.get(ch, ch) for ch in text)


def _sort_frame(df: pd.DataFrame, args) -> pd.DataFrame:
    """Apply --sort-by/--ascending, then --dataset-order, preserving earlier order."""
    if df.empty:
        return df
    if args.sort_by:
        if args.ascending is None:
            asc = [True] * len(args.sort_by)
        elif len(args.ascending) != len(args.sort_by):
            raise SystemExit(
                f"--ascending expects {len(args.sort_by)} flag(s) to match --sort-by, "
                f"got {len(args.ascending)}"
            )
        else:
            asc = [bool(int(a)) for a in args.ascending]
        df = df.sort_values(by=args.sort_by, ascending=asc, kind="stable")
    if args.dataset_order:
        # Datasets not listed become NaN categories and are placed last.
        cat = pd.Categorical(df["dataset"], categories=args.dataset_order, ordered=True)
        # A stable sort keeps the --sort-by order inside each dataset; the
        # default quicksort does not guarantee that.
        df = (
            df.assign(_dataset=cat)
            .sort_values("_dataset", kind="stable")
            .drop(columns="_dataset")
        )
    return df


def _build_colspec(args) -> str:
    """Build the tabular column spec; ci mode turns the two Test columns into text columns."""
    formats = args.table_formats
    if args.test_intervals == "ci":
        cols = [
            f"S[table-format={formats[0]}]",
            f"S[table-format={formats[1]}]",
            "l",
            "l",
        ]
    else:
        cols = [f"S[table-format={tf}]" for tf in formats]
    return "@{}ll " + " ".join(cols) + "@{}"


def _parse_renames(pairs) -> dict:
    """Turn ``old=new`` strings into an ordered {regex: replacement} mapping."""
    mapping = {}
    for kv in pairs or []:
        old, sep, new = kv.partition("=")
        if not sep:
            raise SystemExit(f"--rename expects old=new, got {kv!r}")
        mapping[old] = new
    return mapping


def _format_test_cell(mean, std, lo, hi, *, mode, fixed_decimals, dec_places) -> str:
    """Format one Test metric cell according to --test-intervals (``ci`` or ``pm``).

    Missing values can arrive as ``None`` or NaN (pandas turns ``None`` into
    NaN in numeric columns), so every check uses ``pd.isna``.
    """
    has_ci = not (pd.isna(lo) or pd.isna(hi))
    if mode == "ci" and has_ci and not pd.isna(mean):
        if fixed_decimals:
            mean_s, lo_s, hi_s = (f"{float(v):.{dec_places}f}" for v in (mean, lo, hi))
        else:
            mean_s, lo_s, hi_s = (f"{v}" for v in (mean, lo, hi))
        return f"{mean_s} [{lo_s}, {hi_s}]"
    if pd.isna(std) and has_ci:
        # Express the CI as a symmetric ± half-width.
        std = (float(hi) - float(lo)) / 2.0
    cell = fmt_unc(mean, std, fixed_decimals=fixed_decimals, dec_places=dec_places)
    if mode == "ci":
        # In ci mode the Test columns are plain ``l`` columns, where \pm is
        # only valid in math mode.
        cell = cell.replace(r"\pm", r"$\pm$")
    return cell


def main():
    """Read the TXT summary given by --input and write a LaTeX table to --output."""
    args = parse_args()
    df = _sort_frame(parse_txt(args.input), args)

    rename_map = _parse_renames(args.rename)
    bold_re = re.compile(args.bold_contains) if args.bold_contains else None
    fixed = not args.no_fixed_decimals

    lines = [r"\begin{table}[h]", r"\centering"]
    if args.font_size:
        lines.append(f"{args.font_size} % Font size")
    lines.append(r"\caption{" + args.caption + r"}")
    lines.append(r"\label{" + args.label + r"}")
    lines.append(r"% siunitx settings")
    lines.append(r"\sisetup{separate-uncertainty, table-align-text-post=false}")

    inner_begin = r"\begin{tabular}{" + _build_colspec(args) + r"}"
    if args.no_resize:
        lines.append(inner_begin)
    else:
        lines.append(r"\resizebox{" + args.width + r"}{!}{" + inner_begin)

    if args.booktabs:
        lines.append(r"\toprule")
    lines.append(
        r"\textbf{Dataset} & \textbf{Model (Rep.)} & {Val RMSE} & {Val MAE} & {Test RMSE} & {Test MAE} \\"
    )
    if args.booktabs:
        lines.append(r"\midrule")

    # groupby needs the "dataset" column, which an empty parse may lack.
    groups = df.groupby("dataset", sort=False) if not df.empty else []
    for dataset, g in groups:
        n = len(g)
        for i, (_, row) in enumerate(g.iterrows()):
            # Renames and bold matching see the readable label; LaTeX escaping
            # is applied only when the cell is emitted.
            cell_model = apply_renames(row["label"], rename_map)
            do_bold = bool(bold_re.search(cell_model)) if bold_re else False

            val_rmse = fmt_unc(
                row["val_rmse_mean"],
                row["val_rmse_std"],
                fixed_decimals=fixed,
                dec_places=args.val_dec,
            )
            val_mae = fmt_unc(
                row["val_mae_mean"],
                row["val_mae_std"],
                fixed_decimals=fixed,
                dec_places=args.val_dec,
            )
            test_rmse = _format_test_cell(
                row["test_rmse_mean"],
                row["test_rmse_std"],
                row["test_rmse_ci_low"],
                row["test_rmse_ci_high"],
                mode=args.test_intervals,
                fixed_decimals=fixed,
                dec_places=args.test_dec,
            )
            test_mae = _format_test_cell(
                row["test_mae_mean"],
                row["test_mae_std"],
                row["test_mae_ci_low"],
                row["test_mae_ci_high"],
                mode=args.test_intervals,
                fixed_decimals=fixed,
                dec_places=args.test_dec,
            )

            # Only the first row of each dataset carries the \multirow cell.
            if i == 0:
                dataset_cell = rf"\multirow{{{n}}}{{*}}{{{_latex_escape(dataset)}}}"
            else:
                dataset_cell = ""
            cells = [_latex_escape(cell_model), val_rmse, val_mae, test_rmse, test_mae]
            if do_bold:
                # NOTE(review): bolding numbers inside siunitx S columns may need
                # extra siunitx font-detection options depending on the siunitx
                # version — confirm against the target preamble.
                cells = [rf"\bfseries {c}" for c in cells]
            lines.append(" & ".join([dataset_cell, *cells]) + r" \\")

    if args.booktabs:
        lines.append(r"\bottomrule")
    lines.append(r"\end{tabular}")
    if not args.no_resize:
        lines.append("}")
    lines.append(r"\end{table}")

    # Terminate the file with a newline so it can be \input or concatenated cleanly.
    args.output.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
|
|
|
# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|