# Source code for matmmextract.inference.captioner_gemini

"""
matmmextract.inference.captioner_gemini
=======================================
Generate per-panel sub-captions using Gemini, from figure captions and
reference sentences extracted during the XML extraction step.


Input
-----
A CSV with columns: ``downloaded_image_name``, ``caption``,
``reference_sentences``, ``download_status``.

Output
------
One JSON file per unique figure in *output_dir*, named
``<figure_base>.json`` (crops such as ``img8_A`` and ``img8_B`` share
``img8.json``), with schema::

    {
        "panels": [
            {
                "panel":                    "a",
                "visualization_category":  "Microscopy",
                "visualization_subtype":   "SEM",
                "subcaption":              "...",
                "summary":                 "..."
            },
            ...
        ]
    }
"""

from __future__ import annotations

import argparse
import csv
import json
import os
import time
from dataclasses import dataclass, field
from pathlib import Path

# ---------------------------------------------------------------------------
# Taxonomy
# ---------------------------------------------------------------------------

# Broad figure categories accepted for ``visualization_category``.
# The entries mirror the keys of SUBTYPES_BY_CATEGORY, in the same order.
VISUALIZATION_CATEGORIES: list[str] = [
    "Microscopy",
    "Diffraction",
    "Spectroscopy",
    "Thermal Analysis",
    "Phase/Equilibrium Diagram",
    "Mechanical Test",
    "Electrochemistry",
    "Magnetic/Electronic",
    "Optical/Photonic",
    "Tomography/3D",
    "Compositional Map",
    "Simulation",
    "Machine Learning",
    "Crystal Structure",
    "Generic Plot",
    "Schematic/Diagram",
    "Photograph",
    "Table",
    "other",
]

# Allowed ``visualization_subtype`` values for each category.
# Keys appear in the same order as VISUALIZATION_CATEGORIES.  Insertion order
# matters: it fixes the order of VISUALIZATION_SUBTYPES and of the lines in
# _TAXONOMY_BLOCK, and therefore the text of the prompt sent to the model.
SUBTYPES_BY_CATEGORY: dict[str, list[str]] = {
    "Microscopy": [
        "SEM", "TEM", "STEM", "HAADF-STEM", "BF-TEM", "DF-TEM",
        "Optical Micrograph", "Confocal Microscopy", "AFM",
        "Fluorescence Microscopy", "Live/Dead Staining",
    ],
    "Diffraction": [
        "XRD Pattern", "SAED", "EBSD Map", "Pole Figure",
        "Inverse Pole Figure", "Neutron Diffraction", "Synchrotron Diffraction",
    ],
    "Spectroscopy": [
        "XPS Spectrum", "Raman Spectrum", "FTIR Spectrum", "EDX Spectrum",
        "EELS Spectrum", "NMR Spectrum", "Mass Spectrum", "UV-Vis Spectrum",
        "Photoluminescence Spectrum", "XANES Spectrum", "EXAFS Spectrum",
        "Mössbauer Spectrum",
    ],
    "Thermal Analysis": ["DSC Curve", "TGA Curve", "DMA Curve", "TMA Curve"],
    "Phase/Equilibrium Diagram": [
        "Binary Phase Diagram", "Ternary Phase Diagram", "TTT Diagram",
        "CCT Diagram", "CALPHAD Diagram", "Pourbaix Diagram",
    ],
    "Mechanical Test": [
        "Stress-Strain Curve", "Load-Displacement Curve", "Nanoindentation Curve",
        "Hardness Map", "Fatigue/S-N Curve", "Creep Curve",
        "Fracture Toughness Plot", "DIC Strain Map", "Wear/Tribology Plot",
    ],
    "Electrochemistry": [
        "Cyclic Voltammogram", "Charge-Discharge Curve", "Capacity Retention Plot",
        "Coulombic Efficiency Plot", "Nyquist Plot", "Bode Plot", "Tafel Plot",
        "Polarization Curve", "GITT/PITT Curve", "Rate Capability Plot",
    ],
    "Magnetic/Electronic": [
        "M-H Hysteresis Loop", "M-T Curve", "ZFC/FC Curve", "I-V Curve",
        "C-V Curve", "Band Structure", "Density of States", "Hall Effect Plot",
    ],
    "Optical/Photonic": [
        "Absorbance Spectrum", "Transmittance Spectrum", "Reflectance Spectrum",
        "EQE/IQE Plot", "J-V Curve", "Ellipsometry Plot", "Refractive Index Plot",
    ],
    "Tomography/3D": [
        "APT Reconstruction", "Micro-CT", "FIB-SEM Tomography", "3D Reconstruction",
    ],
    "Compositional Map": [
        "EDS Map", "WDS Map", "EBSD IPF Map", "Elemental Distribution Map",
    ],
    "Simulation": [
        "DFT Result", "MD Snapshot", "MD Trajectory", "Phase-Field Simulation",
        "FEA/FEM Result", "Monte Carlo Result",
    ],
    "Machine Learning": [
        "Parity Plot", "Confusion Matrix", "ROC Curve", "Learning Curve",
        "Feature Importance Plot", "SHAP Plot", "t-SNE/UMAP/PCA Plot",
    ],
    "Crystal Structure": ["Unit Cell", "Atomic Model", "Supercell"],
    "Generic Plot": [
        "Bar Chart", "Scatter Plot", "Line Graph", "Box Plot", "Contour Plot",
        "Heatmap", "Radar Chart", "Ashby Plot", "Arrhenius Plot", "Histogram",
    ],
    "Schematic/Diagram": [
        "Process Schematic", "Flowchart", "Mechanism Diagram", "Experimental Setup",
    ],
    "Photograph": ["Sample Photo", "Equipment Photo", "In-situ Photo"],
    "Table": ["Data Table"],
    # Catch-all bucket; the prompt tells the model to use it only as a last resort.
    "other": ["other"],
}

# Flat list of every subtype across all categories, de-duplicated while
# keeping first-seen order (dict keys preserve insertion order).
VISUALIZATION_SUBTYPES: list[str] = list(
    {
        subtype: None
        for subtype_group in SUBTYPES_BY_CATEGORY.values()
        for subtype in subtype_group
    }
)

# Bullet list ("- <category>: <subtype>, <subtype>, ...") with one line per
# category; substituted for ``{taxonomy}`` in the prompt template.
_TAXONOMY_BLOCK: str = "\n".join(
    "- " + category + ": " + ", ".join(subtypes)
    for category, subtypes in SUBTYPES_BY_CATEGORY.items()
)

# Prompt sent to the model once per figure.  It is filled with ``str.format``
# using exactly three placeholders: ``{caption}``, ``{reference_sentences}``
# and ``{taxonomy}``.  Any literal brace added to this text must be doubled
# (``{{`` / ``}}``) or ``format`` will fail.  A trailing backslash inside the
# literal joins the next physical line into the same paragraph.
_PROMPT_TEMPLATE = """\
You are a scientific figure analysis expert covering all of materials science \
(metals/alloys, ceramics, polymers, composites, semiconductors, batteries, \
catalysts, biomaterials, 2D materials, etc.).

Using only the figure caption and the reference sentences from the paper, \
generate detailed sub-captions for each panel of the figure.

Figure caption:
{caption}

Reference sentences from the paper:
{reference_sentences}

INSTRUCTIONS:
- Generate a sub-caption for each panel (a, b, c, …); use panel id "main" for single-panel figures.
- Classify each panel with TWO fields, chosen ONLY from the allowed taxonomy below:
    * visualization_category: the broad category.
    * visualization_subtype: the specific technique/plot type, which MUST belong to the chosen category.
- Both fields are strict enums. Use "other" only if absolutely nothing matches.
- For plots/graphs: state axes, specimen/condition, test method, and any key trend or value from the caption or reference.
- For images/maps: state material, technique, orientation/view, processing condition, and key observation or structural feature.
- The caption is the ground truth — never contradict it.
- Use a reference sentence only if it clearly supports the caption content.
- Group-level references apply to all panels; references naming a specific panel letter apply only to that panel.
- Do not start every sub-caption identically — vary sentence openings.
- Expand the caption into a precise scientific description — do not copy it verbatim.
- For each panel, write a summary of 30-50 words using only that panel's subcaption and directly related references.

ALLOWED TAXONOMY (category → valid subtypes):
{taxonomy}
"""


# ---------------------------------------------------------------------------
# Result container
# ---------------------------------------------------------------------------

@dataclass
class CaptionResult:
    """Summary counters returned by :func:`captioner`."""

    # Number of figures selected for processing in this run (the pending list).
    n_total: int = 0
    # Figures whose model response was obtained and parsed as JSON.
    n_success: int = 0
    # Figures for which every attempt failed.
    n_error: int = 0
    # Only set on the "nothing to do" early exit, to the number of CSV rows read.
    n_skipped: int = 0
    # Output directory, stored as a string.
    output_dir: str = ""

# ---------------------------------------------------------------------------
# Core API
# ---------------------------------------------------------------------------

def _figure_base(name: str) -> str:
    """Return the identifier shared by all crops of one figure.

    The first ``_``-separated token that looks like ``img<digits>`` (prefix
    matched case-insensitively) is returned unchanged, e.g. ``img8_A`` ->
    ``img8``.  If no token matches, the text before the first underscore is
    returned.
    """
    for token in name.split("_"):
        if token.lower().startswith("img") and token[3:].isdigit():
            return token
    return name.split("_")[0]


def _build_response_schema(types):
    """Build the structured-output schema for one figure.

    *types* is the ``google.genai.types`` module.  It is passed in because
    the SDK is imported lazily inside :func:`captioner`.
    """
    panel_schema = types.Schema(
        type=types.Type.OBJECT,
        required=[
            "panel",
            "visualization_category",
            "visualization_subtype",
            "subcaption",
            "summary",
        ],
        properties={
            "panel": types.Schema(type=types.Type.STRING),
            "visualization_category": types.Schema(
                type=types.Type.STRING,
                enum=VISUALIZATION_CATEGORIES,
            ),
            # NOTE(review): unlike the category, the subtype is not
            # enum-constrained here although the prompt calls it a strict
            # enum -- confirm whether that is intentional.
            "visualization_subtype": types.Schema(
                type=types.Type.STRING,
            ),
            "subcaption": types.Schema(type=types.Type.STRING),
            "summary": types.Schema(type=types.Type.STRING),
        },
    )
    return types.Schema(
        type=types.Type.OBJECT,
        required=["panels"],
        properties={
            "panels": types.Schema(type=types.Type.ARRAY, items=panel_schema),
        },
    )


def captioner(
    csv_path: str | Path,
    output_dir: str | Path,
    api_key: str | None = None,
    model_name: str = "gemini-3.1-flash-lite",
    max_tokens: int = 4096,
    max_retries: int = 4,
    overwrite: bool = False,
    requests_per_minute: int | None = None,
    verbose: bool = True,
) -> CaptionResult:
    """Generate sub-captions for every successfully downloaded figure.

    Rows are de-duplicated by (figure base, caption, reference sentences), so
    crops of the same figure (``img8_A``, ``img8_B``, ...) trigger a single
    request.  The result is written to ``<output_dir>/<figure base>.json``.

    If every attempt for a figure fails, a JSON object with a single
    ``"error"`` key is written instead.  Because that file then exists, a
    later run will not retry the figure unless ``overwrite=True``.

    Parameters
    ----------
    csv_path:
        Figure CSV with columns ``downloaded_image_name``, ``caption``,
        ``reference_sentences``, ``download_status``.
    output_dir:
        Directory where one JSON per figure is written (created if missing).
    api_key:
        Google Gemini API key. Falls back to ``GOOGLE_API_KEY`` env var.
    model_name:
        Gemini model string.
    max_tokens:
        Max output tokens per request.
    max_retries:
        Maximum number of attempts per figure.  Values below 1 are treated
        as 1 so that each figure is tried at least once.
    overwrite:
        Re-generate even if the output JSON already exists.
    requests_per_minute:
        If set to a positive number, throttle API calls to this rate by
        sleeping between requests. If ``None`` (default), no extra
        throttling beyond retry back-off.
    verbose:
        Print progress.

    Returns
    -------
    CaptionResult

    Raises
    ------
    ImportError
        If the ``google-genai`` package is not installed.
    EnvironmentError
        If no API key is passed and ``GOOGLE_API_KEY`` is not set.
    """
    try:
        from google import genai
        from google.genai import types
    except ImportError as exc:
        raise ImportError(
            "google-genai is required: pip install google-genai"
        ) from exc

    api_key = api_key or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise EnvironmentError(
            "Gemini API key required. Set GOOGLE_API_KEY or pass api_key=."
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    throttled = bool(requests_per_minute) and requests_per_minute > 0
    min_interval = 60.0 / requests_per_minute if throttled else None
    # None until the first request, so the first call is never delayed.
    last_request: float | None = None
    attempts = max(1, max_retries)

    client = genai.Client(api_key=api_key)
    generate_config = types.GenerateContentConfig(
        temperature=0.1,
        max_output_tokens=max_tokens,
        response_mime_type="application/json",
        response_schema=_build_response_schema(types),
    )

    with open(csv_path, newline="", encoding="utf-8") as fh:
        rows = list(csv.DictReader(fh))

    # Deduplicate by figure base: img8_A, img8_B, img8_C -> one call for img8.
    # DictReader fills missing cells with None, hence the "or ''" guards.
    seen_figures: dict[tuple[str, str, str], tuple[str, dict]] = {}
    for row in rows:
        name = (row.get("downloaded_image_name") or "").strip()
        if not name:
            continue
        if row.get("download_status", "success") not in ("success", ""):
            continue
        fig_base = _figure_base(name)
        key = (
            fig_base,
            row.get("caption") or "",
            row.get("reference_sentences") or "",
        )
        seen_figures.setdefault(key, (fig_base, row))

    unique_figures = list(seen_figures.values())
    pending = [
        (fig_base, row)
        for fig_base, row in unique_figures
        if overwrite or not (output_dir / f"{fig_base}.json").exists()
    ]

    result = CaptionResult(
        n_total=len(pending),
        output_dir=str(output_dir),
    )

    if verbose:
        print(f"[captioner] {len(rows)} crops → {len(unique_figures)} unique figures | {len(pending)} pending | model={model_name}")

    if not pending:
        if verbose:
            print("[captioner] Nothing to do.")
        result.n_skipped = len(rows)
        return result

    for i, (fig_base, row) in enumerate(pending, start=1):
        image_name = row["downloaded_image_name"]
        prompt = _PROMPT_TEMPLATE.format(
            caption=row.get("caption") or "",
            reference_sentences=row.get("reference_sentences") or "",
            taxonomy=_TAXONOMY_BLOCK,
        )

        last_error = None
        parsed = None
        for attempt in range(attempts):
            if min_interval is not None:
                if last_request is not None:
                    wait = min_interval - (time.monotonic() - last_request)
                    if wait > 0:
                        time.sleep(wait)
                last_request = time.monotonic()
            try:
                response = client.models.generate_content(
                    model=model_name,
                    contents=prompt,
                    config=generate_config,
                )
                parsed = json.loads(response.text)
                break
            except Exception as exc:
                # Deliberately broad: any API, transport or JSON failure is
                # retried, and the last one is recorded in the output file.
                last_error = exc
                if attempt < attempts - 1:
                    time.sleep(2 ** attempt)  # exponential back-off

        if parsed is None:
            parsed = {"error": str(last_error)}
            result.n_error += 1
            if verbose:
                print(f" [{i}/{len(pending)}] ERROR {image_name}: {last_error}")
        else:
            result.n_success += 1

        if verbose and (i % 10 == 0 or i == len(pending)):
            pct = i / len(pending) * 100
            print(
                f" [{pct:5.1f}%] {i}/{len(pending)} "
                f"ok={result.n_success} err={result.n_error}"
            )

        out_path = output_dir / f"{fig_base}.json"
        # ensure_ascii=False emits raw non-ASCII text, so the encoding must
        # be fixed rather than left to the platform default.
        out_path.write_text(
            json.dumps(parsed, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )

    if verbose:
        print(f"\n[captioner] done — success={result.n_success} errors={result.n_error}")
    return result
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the captioner.

    Parameters
    ----------
    argv:
        Argument list to parse. ``None`` (default) parses ``sys.argv[1:]``.
    """
    # The text is passed as ``description``; passed positionally it would
    # replace the program name in the usage line.
    p = argparse.ArgumentParser(description="Generate sub-captions with Gemini")
    p.add_argument("--csv", required=True, help="Figure CSV path")
    p.add_argument("--output-dir", required=True)
    p.add_argument("--api-key", default=os.environ.get("GOOGLE_API_KEY"))
    p.add_argument("--model", default="gemini-3.1-flash-lite")
    p.add_argument("--max-tokens", type=int, default=4096)
    p.add_argument("--max-retries", type=int, default=4,
                   help="Maximum attempts per figure")
    p.add_argument("--rpm", type=int, default=None,
                   help="Throttle API calls to this many requests per minute")
    p.add_argument("--overwrite", action="store_true")
    return p.parse_args(argv)


def main() -> None:
    """CLI entry point: parse arguments and run :func:`captioner`."""
    args = _parse_args()
    captioner(
        csv_path=args.csv,
        output_dir=args.output_dir,
        api_key=args.api_key,
        model_name=args.model,
        max_tokens=args.max_tokens,
        max_retries=args.max_retries,
        overwrite=args.overwrite,
        requests_per_minute=args.rpm,
    )


if __name__ == "__main__":
    main()