# Source code for matmmextract.inference.dataset_builder

"""
matmmextract.inference.dataset_builder
==========================================
Link cropped panel images with their JSON sub-captions to produce
the final ``linked_dataset.csv``.

JSON lookup key
---------------
Each JSON file is named after the full crop stem:
    img11_A.json      ← written by captioner for crop img11_A.jpg
    img10_single.json ← written by captioner for crop img10_single.jpg

The linker looks up the JSON by full crop stem, then finds the matching
panel inside the JSON by the panel letter/key.

Filename pattern matched
------------------------
``imgXXXX_single.jpg``      →  looks for imgXXXX_single.json, panel "main"
``imgXXXX_single_2.jpg``    →  looks for imgXXXX_single_2.json, panel "main"
``imgXXXX_A.jpg``           →  looks for imgXXXX_A.json, panel "a"
``imgXXXX_A_2.jpg``         →  looks for imgXXXX_A_2.json, panel "a"

Output columns
--------------
image_filename, image_id, panel_suffix, variant, json_panel,
visualization_category, visualization_subtype, subcaption, summary,
matched
"""

from __future__ import annotations

import argparse
import csv
import json
import os
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path

# Crop filename pattern, matched case-insensitively:
#   "<image id>_<"single" | one letter>[_<digits>].jpg"
# Groups: 1 = image id (may itself contain underscores, since ".+" backtracks),
#         2 = "single" or the panel letter,
#         3 = optional numeric variant (None when absent).
IMG_RE = re.compile(
    r"^(.+)_(single|[A-Z])(?:_(\d+))?\.jpg$",
    re.IGNORECASE,
)

# Column order of the linked CSV written by ``build``; every row dict built
# there carries exactly these keys.
CSV_FIELDS = [
    "image_filename",
    "image_id",
    "panel_suffix",
    "variant",
    "json_panel",
    "visualization_category",
    "visualization_subtype",
    "subcaption",
    "summary",
    "matched",
]


@dataclass
class BuildResult:
    """Counters and output locations reported by ``build``.

    Every counter starts at zero and both path fields start as empty
    strings, so ``BuildResult()`` is a valid "nothing processed" result.
    """

    # Entries listed in the images directory (before filename filtering).
    n_images: int = 0
    # Rows whose JSON file and panel were both found.
    n_matched: int = 0
    # Rows for which no JSON file was found.
    n_no_json: int = 0
    # Rows whose JSON exists but lacks the requested panel.
    n_no_panel: int = 0
    # Directory entries whose names did not match the crop filename pattern.
    n_skipped_pattern: int = 0
    # Matched rows that have both a non-empty subcaption and summary.
    n_has_caption: int = 0
    # Path of the CSV that was written, as a string.
    output_csv: str = ""
    # Path of the build log that was written, as a string.
    log_path: str = ""
def _load_json_index(json_dir: Path) -> dict[str, dict]: """Load all JSON files into {stem: parsed_json} dict.""" index: dict[str, dict] = {} for fname in os.listdir(json_dir): if not fname.endswith(".json"): continue stem = fname[:-5] # e.g. "img11_A" try: with open(json_dir / fname, encoding="utf-8") as fh: index[stem] = json.load(fh) except json.JSONDecodeError: pass return index def _build_panel_lookup(json_obj: dict) -> dict[str, dict]: """ Return {panel_key_lower: panel_data} for simple keys only. Accepts: "main", single letters (a-z). Skips complex keys like "a1", "a1-a6", "d-e". """ lookup: dict[str, dict] = {} for panel in json_obj.get("panels", []): key = panel.get("panel", "").strip() if key == "main" or (len(key) == 1 and key.isalpha()): lookup[key.lower()] = panel return lookup
def build(
    images_dir: str | Path,
    json_dir: str | Path,
    output_csv: str | Path = "linked_dataset.csv",
    log_path: str | Path = "build_dataset.log",
    verbose: bool = True,
) -> BuildResult:
    """Link cropped images with sub-caption JSON files.

    Every entry of *images_dir* whose name matches ``IMG_RE`` becomes one CSV
    row with columns ``CSV_FIELDS``.

    The row's caption fields come from the JSON named after the crop's full
    stem (``img11_A.jpg`` → ``img11_A.json``; ``img11_A_2.jpg`` →
    ``img11_A_2.json``). Within that JSON, the panel used is the one whose
    key equals the crop's panel letter, or ``"main"`` for ``_single`` crops.
    If no per-crop JSON exists, a JSON named after the bare image id
    (``img11.json``) is tried instead. Rows are written even when nothing
    was linked; their ``matched`` column is then ``False``.

    Parameters
    ----------
    images_dir:
        Directory of cropped panel images (output of cropper.run).
    json_dir:
        Directory of per-crop sub-caption JSON files (output of
        captioner.run or captioner_azure.run). Each JSON is named after
        the crop stem: img11_A.json.
    output_csv:
        Path for the final linked CSV. Parent directories are created.
    log_path:
        Path for the human-readable build log. Parent directories are
        created.
    verbose:
        Print the log text to stdout.

    Returns
    -------
    BuildResult
        Per-outcome counts plus the CSV and log paths as strings.
    """
    images_dir = Path(images_dir)
    json_dir = Path(json_dir)
    output_csv = Path(output_csv)
    log_path = Path(log_path)

    # index: "img11_A" → parsed JSON
    json_index = _load_json_index(json_dir)
    image_files = sorted(os.listdir(images_dir))

    rows: list[dict] = []
    skipped_pattern: list[str] = []
    no_json: list[str] = []
    panel_not_found: list[str] = []
    matched: list[str] = []

    for fname in image_files:
        m = IMG_RE.match(fname)
        if not m:
            skipped_pattern.append(fname)
            continue

        img_id = m.group(1)
        panel_letter = m.group(2)
        variant = m.group(3) or ""
        # Full crop stem, e.g. "img11_A_2" for "img11_A_2.jpg". The captioner
        # names each JSON after it, so it is the primary lookup key.
        crop_stem = os.path.splitext(fname)[0]

        if panel_letter.lower() == "single":
            panel_suffix = "single"
            lookup_key = "main"
        else:
            panel_suffix = (
                f"{panel_letter.upper()}_{variant}" if variant else panel_letter.upper()
            )
            lookup_key = panel_letter.lower()

        row: dict = {
            "image_filename": fname,
            "image_id": img_id,
            "panel_suffix": panel_suffix,
            "variant": variant,
            "json_panel": "",
            "visualization_category": "",
            "visualization_subtype": "",
            "subcaption": "",
            "summary": "",
            "matched": False,
        }

        # The bare image id never equals a per-crop JSON name ("img11" vs
        # "img11_A"), so it is only a fallback for figure-level JSON files.
        json_obj = json_index.get(crop_stem)
        if json_obj is None:
            json_obj = json_index.get(img_id)

        if json_obj is None:
            no_json.append(fname)
        else:
            panel_data = _build_panel_lookup(json_obj).get(lookup_key)
            if panel_data:
                row["json_panel"] = panel_data.get("panel", "")
                row["visualization_category"] = panel_data.get("visualization_category", "")
                row["visualization_subtype"] = panel_data.get("visualization_subtype", "")
                row["subcaption"] = panel_data.get("subcaption", "")
                row["summary"] = panel_data.get("summary", "")
                row["matched"] = True
                matched.append(fname)
            else:
                panel_not_found.append(fname)

        rows.append(row)

    # Write CSV
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    with open(output_csv, "w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=CSV_FIELDS)
        writer.writeheader()
        writer.writerows(rows)

    # Stats
    total_rows = len(rows)
    n_matched = len(matched)
    n_no_json = len(no_json)
    n_no_panel = len(panel_not_found)
    n_skipped = len(skipped_pattern)
    n_has_caption = sum(
        1 for r in rows if r["matched"] and r["subcaption"] and r["summary"]
    )

    def pct(n: int) -> float:
        # Share of processed rows; 0 when nothing was processed.
        return n / total_rows * 100 if total_rows else 0

    log_lines = [
        f"build_dataset — {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "=" * 60,
        f"Total image files scanned : {len(image_files)}",
        f"Skipped (bad filename) : {n_skipped}",
        f"Processed rows : {total_rows}",
        "",
        "── Image → JSON linkage ─────────────────────────────────",
        f" Matched (image + panel) : {n_matched:>6} ({pct(n_matched):.1f}%)",
        f" No JSON file : {n_no_json:>6} ({pct(n_no_json):.1f}%)",
        f" JSON found, panel missing: {n_no_panel:>6} ({pct(n_no_panel):.1f}%)",
        "",
        "── Caption linkage (of total processed) ─────────────────",
        f" Has subcaption + summary : {n_has_caption:>6} ({pct(n_has_caption):.1f}%)",
        "",
        "── Output ───────────────────────────────────────────────",
        f" CSV : {output_csv}",
        f" Log : {log_path}",
    ]
    if skipped_pattern:
        # Show at most 20 offending names to keep the log readable.
        log_lines += ["", "── Skipped filenames (bad pattern) ──────────────────────"]
        log_lines += [f" {name}" for name in skipped_pattern[:20]]
        if len(skipped_pattern) > 20:
            log_lines.append(f" ... and {len(skipped_pattern) - 20} more")

    log_text = "\n".join(log_lines)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    log_path.write_text(log_text + "\n", encoding="utf-8")
    if verbose:
        print(log_text)

    return BuildResult(
        n_images=len(image_files),
        n_matched=n_matched,
        n_no_json=n_no_json,
        n_no_panel=n_no_panel,
        n_skipped_pattern=n_skipped,
        n_has_caption=n_has_caption,
        output_csv=str(output_csv),
        log_path=str(log_path),
    )
def _parse_args() -> argparse.Namespace: p = argparse.ArgumentParser("Link crops + sub-captions → dataset CSV") p.add_argument("--images-dir", required=True) p.add_argument("--json-dir", required=True) p.add_argument("--output-csv", default="linked_dataset.csv") p.add_argument("--log", default="build_dataset.log") return p.parse_args() def main() -> None: args = _parse_args() build( images_dir=args.images_dir, json_dir=args.json_dir, output_csv=args.output_csv, log_path=args.log, ) if __name__ == "__main__": main()