"""
matmmextract.inference.captioner_azure
==========================================
Generate per-panel sub-captions using Azure-hosted models via the
OpenAI-compatible API (Mistral, Llama, GPT, etc.).
Differences from :mod:`~matmmextract.inference.captioner_gemini` (Gemini):

- Uses ``openai.OpenAI`` with a custom ``base_url`` pointing at Azure.
- Response format is a JSON schema dict (OpenAI structured outputs).
- The command-line interface defaults to the columns ``image_name`` and
  ``reference``; the Python function :func:`captioner` defaults to
  ``downloaded_image_name`` and ``reference_sentences``. Both can be
  overridden.
- Re-runs figures whose already-written JSON file contains an ``"error"`` key.
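Input CSV columns required
--------------------------
An image-name column, ``caption``, and a reference-sentence column (default
names are listed above; every column name can be overridden). An optional
``download_status`` column is honoured: rows whose status is neither
``success`` nor empty are skipped.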
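Output
------
One JSON file per unique figure in *output_dir*, named ``<figure_id>.json``,
where ``figure_id`` is the image name up to its first ``_`` (so crops
``img8_A`` and ``img8_B`` both map to ``img8.json``). Crops that share a
figure ID, caption, and reference are captioned with a single request.
A figure whose request fails on every attempt gets ``{"error": "..."}``
instead and is re-run next time. Successful output looks like::

    {
      "panels": [
        {
          "panel": "a",
          "visualization_category": "Microscopy",
          "visualization_subtype": "SEM",
          "subcaption": "...",
          "summary": "..."
        },
        ...
      ]
    }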
"""
from __future__ import annotations
import argparse
import csv
import json
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
# ---------------------------------------------------------------------------
# Taxonomy — defined once in captioner_gemini.py and imported here so both captioners share it
# ---------------------------------------------------------------------------
from matmmextract.inference.captioner_gemini import (
VISUALIZATION_CATEGORIES,
SUBTYPES_BY_CATEGORY,
_TAXONOMY_BLOCK,
)
# ---------------------------------------------------------------------------
# Azure defaults
# ---------------------------------------------------------------------------
# Empty on purpose: no endpoint is bundled, so one must be supplied explicitly.
DEFAULT_AZURE_ENDPOINT: str = ""
# Deployment name passed as ``model=`` to the chat-completions call.
DEFAULT_MODEL: str = "Mistral-Large-3"
# Output-token cap per request.
DEFAULT_MAX_TOKENS: int = 4_096
# Number of request attempts per figure before an error JSON is written.
DEFAULT_MAX_RETRIES: int = 4
# ---------------------------------------------------------------------------
# Prompt (slightly different wording from the Gemini version)
# ---------------------------------------------------------------------------
# User prompt sent once per figure. captioner() fills it with str.format():
#   {caption}              - caption text from the CSV row
#   {reference_sentences}  - reference-sentence text from the CSV row
#   {taxonomy}             - _TAXONOMY_BLOCK imported from captioner_gemini
# Any literal brace added to this text must be doubled ({{ / }}) or .format() fails.
# NOTE: the prompt calls both classification fields "strict enums", but
# _RESPONSE_FORMAT only enforces an enum on visualization_category;
# visualization_subtype is constrained by this prompt alone.
_PROMPT_TEMPLATE = """\
You are a scientific figure analysis expert specializing in materials science.
Using only the figure caption and the reference sentences from the paper, \
generate detailed sub-captions for each panel of the figure.
Figure caption:
{caption}
Reference sentences from the paper:
{reference_sentences}
INSTRUCTIONS:
- Generate a sub-caption for each panel (a, b, c, …); use panel id "main" for single-panel figures.
- Classify each panel with TWO fields, chosen ONLY from the allowed taxonomy below:
* visualization_category: the broad category.
* visualization_subtype: the specific technique/plot type, which MUST belong to the chosen category.
- Both fields are strict enums. Use "other" only if absolutely nothing matches.
- For plots/graphs: state axes, specimen/condition, test method, and any key trend or value from the caption or reference.
- For images/maps: state material, technique, orientation/view, processing condition, and key observation or structural feature.
- The caption is the ground truth — never contradict it.
- Use a reference sentence only if it clearly supports the caption content.
- Group-level references apply to all panels; references naming a specific panel letter apply only to that panel.
- Do not start every sub-caption identically — vary sentence openings.
- Expand the caption into a precise scientific description — do not copy it verbatim.
- For each panel, write a summary of 40-60 words using only that panel's subcaption and directly related references. Do not include "panel a" or "the figure" in the summary.
ALLOWED TAXONOMY (category -> valid subtypes):
{taxonomy}
"""
# ---------------------------------------------------------------------------
# JSON schema for structured output (OpenAI / Azure format)
# ---------------------------------------------------------------------------
# Every field a panel object must carry, in schema order.
_PANEL_FIELDS = (
    "panel",
    "visualization_category",
    "visualization_subtype",
    "subcaption",
    "summary",
)

# All panel fields are plain strings, except the category, which is
# restricted to the shared taxonomy. Re-assigning an existing key keeps its
# position, so the property order matches _PANEL_FIELDS.
_PANEL_PROPERTIES = {name: {"type": "string"} for name in _PANEL_FIELDS}
_PANEL_PROPERTIES["visualization_category"] = {
    "type": "string",
    "enum": VISUALIZATION_CATEGORIES,
}

# Schema of one element of the "panels" array.
_PANEL_SCHEMA = {
    "type": "object",
    "required": list(_PANEL_FIELDS),
    "additionalProperties": False,
    "properties": _PANEL_PROPERTIES,
}

# ``response_format`` argument for chat.completions.create(): strict JSON
# schema whose root object holds a single "panels" array.
_RESPONSE_FORMAT = {
    "type": "json_schema",
    "json_schema": {
        "name": "figure_panels",
        "strict": True,
        "schema": {
            "type": "object",
            "required": ["panels"],
            "additionalProperties": False,
            "properties": {
                "panels": {"type": "array", "items": _PANEL_SCHEMA},
            },
        },
    },
}
# ---------------------------------------------------------------------------
# Result container
# ---------------------------------------------------------------------------
@dataclass
class CaptionResult:
    """Summary counts returned by :func:`captioner` for one run."""

    n_total: int = 0  # figures queued for (re)captioning in this run
    n_success: int = 0  # figures whose response was parsed and written
    n_error: int = 0  # figures that failed every attempt (error JSON written)
    n_skipped: int = 0  # entries that were not (re)captioned in this run
    output_dir: str = ""  # directory the JSON files are written to
# ---------------------------------------------------------------------------
# Core API
# ---------------------------------------------------------------------------
def _figure_base(image_name: str) -> str:
    """Return the figure ID encoded in a crop name: the text before the first ``_``.

    ``'img8_A' -> 'img8'``, ``'img10_single' -> 'img10'``, ``'img3' -> 'img3'``.
    """
    return image_name.split("_", 1)[0]


def _cell(row: dict, col: str) -> str:
    """Return ``row[col]``, or ``""`` if the column is absent or empty.

    :class:`csv.DictReader` fills the missing trailing fields of a short row
    with ``None``, so ``row.get(col, "")`` alone could still return ``None``.
    """
    return row.get(col) or ""


def _group_figures(
    rows: list[dict],
    image_name_col: str,
    caption_col: str,
    reference_col: str,
) -> list[dict]:
    """Collapse per-crop CSV rows into one entry per figure.

    Rows are keyed by ``(figure ID, caption, reference)``: crops of one figure
    share a single API call, while rows whose text differs are never merged.
    Each entry holds the first ``row`` seen, its ``figure_id`` and all
    ``crop_names``. Rows with a blank image name are dropped.
    """
    groups: dict[tuple[str, str, str], dict] = {}
    for row in rows:
        image_name = _cell(row, image_name_col).strip()
        if not image_name:
            continue
        fig_base = _figure_base(image_name)
        key = (fig_base, _cell(row, caption_col), _cell(row, reference_col))
        group = groups.setdefault(
            key, {"row": row, "figure_id": fig_base, "crop_names": []}
        )
        group["crop_names"].append(image_name)
    return list(groups.values())


def _needs_run(figure_id: str, row: dict, output_dir: Path, overwrite: bool) -> bool:
    """Decide whether a figure must be (re)captioned.

    Figures with an empty ID, or whose ``download_status`` is neither
    ``"success"`` nor empty, are skipped. Otherwise the figure runs when
    *overwrite* is set, or when ``<figure_id>.json`` is missing, unreadable,
    not a JSON object, or holds an ``"error"`` entry from an earlier run.
    """
    if not figure_id:
        return False
    if row.get("download_status", "success") not in ("success", ""):
        return False
    if overwrite:
        return True
    try:
        existing = json.loads(
            (output_dir / f"{figure_id}.json").read_text(encoding="utf-8")
        )
    except (OSError, ValueError):
        # Missing, unreadable or corrupt output (JSONDecodeError and
        # UnicodeDecodeError are ValueErrors): regenerate it.
        return True
    return not isinstance(existing, dict) or "error" in existing


def captioner(
    csv_path: str | Path,
    output_dir: str | Path,
    api_key: str | None = None,
    azure_endpoint: str = DEFAULT_AZURE_ENDPOINT,
    model_name: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    max_retries: int = DEFAULT_MAX_RETRIES,
    overwrite: bool = False,
    image_name_col: str = "downloaded_image_name",
    caption_col: str = "caption",
    reference_col: str = "reference_sentences",
    requests_per_minute: int | None = None,
    verbose: bool = True,
) -> CaptionResult:
    """Generate sub-captions for every figure in *csv_path* using an Azure model.

    Rows are grouped into figures by the image-name prefix before the first
    ``_`` (plus caption and reference text). One chat-completion request is
    made per figure, and the parsed JSON is written to
    ``<output_dir>/<figure_id>.json``. Figures that fail every attempt get
    ``{"error": ...}`` and are retried on the next run.

    Parameters
    ----------
    csv_path:
        CSV with an image-name, a caption and a reference column (names set
        by the ``*_col`` parameters). An optional ``download_status`` column
        skips rows whose status is neither ``"success"`` nor empty.
    output_dir:
        Directory where one JSON per figure is written (created if missing).
    api_key:
        Azure API key. Falls back to ``AZURE_API_KEY`` env var.
    azure_endpoint:
        Azure OpenAI-compatible endpoint URL. Must be non-empty whenever at
        least one figure needs captioning.
    model_name:
        Model deployment name (e.g. ``"Mistral-Large-3"``, ``"gpt-4o"``).
    max_tokens:
        Max output tokens. Sent as ``max_completion_tokens`` for models whose
        name starts with ``gpt``/``o1``/``o3``, otherwise as ``max_tokens``.
    max_retries:
        Total request attempts per figure; values below 1 are treated as 1.
        Attempts are separated by a ``2 ** attempt`` second back-off.
    overwrite:
        Re-generate even if output JSON already exists and has no error.
    image_name_col / caption_col / reference_col:
        Column names; defaults are ``"downloaded_image_name"``, ``"caption"``
        and ``"reference_sentences"``.
    requests_per_minute:
        If set, throttle API calls (retries included) to this rate by
        sleeping between requests. If ``None`` (default), no extra
        throttling beyond retry back-off.
    verbose:
        Print progress.

    Returns
    -------
    CaptionResult
        ``n_total`` figures queued, ``n_success``/``n_error`` outcomes, and
        ``n_skipped`` figures that did not need a run.

    Raises
    ------
    ImportError
        If the ``openai`` package is not installed.
    EnvironmentError
        If no API key is passed or found in ``AZURE_API_KEY``.
    ValueError
        If *azure_endpoint* is empty and at least one figure needs a run.
    """
    import warnings

    try:
        from openai import OpenAI
    except ImportError as exc:
        raise ImportError("openai is required: pip install openai") from exc

    api_key = api_key or os.environ.get("AZURE_API_KEY")
    if not api_key:
        raise EnvironmentError(
            "Azure API key required. Set AZURE_API_KEY or pass api_key=."
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(csv_path, newline="", encoding="utf-8") as fh:
        rows = list(csv.DictReader(fh))

    groups = _group_figures(rows, image_name_col, caption_col, reference_col)
    if verbose:
        print(f"[captioner_azure] {len(rows)} crops → {len(groups)} unique figures")
        for g in groups:
            if len(g["crop_names"]) > 1:
                print(f" → {g['figure_id']}: {len(g['crop_names'])} crops ({', '.join(g['crop_names'])})")

    # Groups with different caption/reference text can still share a figure
    # ID, and therefore the same output file: the later one would overwrite
    # the earlier one.
    seen_ids: set[str] = set()
    clashes: set[str] = set()
    for g in groups:
        fid = g["figure_id"]
        if fid and fid in seen_ids:
            clashes.add(fid)
        seen_ids.add(fid)
    if clashes:
        warnings.warn(
            "[captioner_azure] rows with different caption/reference text share a "
            "figure ID and may overwrite the same output file: "
            + ", ".join(sorted(clashes)),
            stacklevel=2,
        )

    pending = [
        g for g in groups
        if _needs_run(g["figure_id"], g["row"], output_dir, overwrite)
    ]
    result = CaptionResult(
        n_total=len(pending),
        n_skipped=len(groups) - len(pending),
        output_dir=str(output_dir),
    )
    if not pending:
        if verbose:
            print("[captioner_azure] Nothing to do.")
        return result

    # Validated only now, so a no-op run does not need an endpoint. With an
    # empty base_url every request would fail and burn through all retries.
    if not azure_endpoint:
        raise ValueError(
            "Azure endpoint required. Pass azure_endpoint= (CLI: --endpoint)."
        )
    client = OpenAI(base_url=azure_endpoint, api_key=api_key)

    # Mistral/Llama use max_tokens; OpenAI models use max_completion_tokens
    is_openai_model = model_name.lower().startswith(("gpt", "o1", "o3"))
    token_kwarg = {
        ("max_completion_tokens" if is_openai_model else "max_tokens"): max_tokens
    }
    min_interval = 60.0 / requests_per_minute if requests_per_minute else None
    last_request_at = 0.0
    attempts = max(1, max_retries)

    for i, group in enumerate(pending, start=1):
        fig_base = group["figure_id"]
        row = group["row"]
        if verbose:
            print(f" [{i}/{len(pending)}] {fig_base} ...", end=" ", flush=True)
        prompt = _PROMPT_TEMPLATE.format(
            caption=_cell(row, caption_col),
            reference_sentences=_cell(row, reference_col),
            taxonomy=_TAXONOMY_BLOCK,
        )

        parsed: dict | None = None
        last_error: Exception | None = None
        for attempt in range(attempts):
            if min_interval is not None:
                # Keep consecutive requests at least min_interval seconds apart.
                wait = min_interval - (time.monotonic() - last_request_at)
                if wait > 0:
                    time.sleep(wait)
                last_request_at = time.monotonic()
            try:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.1,
                    response_format=_RESPONSE_FORMAT,
                    **token_kwarg,
                )
                candidate = json.loads(response.choices[0].message.content)
                if not isinstance(candidate, dict):
                    raise ValueError(
                        f"expected a JSON object, got {type(candidate).__name__}"
                    )
            except Exception as exc:  # API, network and parse errors are all retried
                last_error = exc
                if verbose:
                    print(f"\n attempt {attempt + 1}/{attempts} failed: {exc}", flush=True)
                if attempt < attempts - 1:
                    time.sleep(2 ** attempt)
            else:
                parsed = candidate
                break

        if parsed is None:
            parsed = {"error": str(last_error)}
            result.n_error += 1
            if verbose:
                print(f"ERROR: {last_error}", flush=True)
        else:
            result.n_success += 1
            if verbose:
                print(f"ok (panels={len(parsed.get('panels', []))})", flush=True)

        # Explicit UTF-8: ensure_ascii=False keeps symbols such as "µm" or "°C",
        # and _needs_run reads these files back as UTF-8.
        out_path = output_dir / f"{fig_base}.json"
        out_path.write_text(
            json.dumps(parsed, indent=2, ensure_ascii=False), encoding="utf-8"
        )

    if verbose:
        print(
            f"\n[captioner_azure] done — "
            f"success={result.n_success} errors={result.n_error}"
        )
    return result
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for :func:`main`.

    *argv* defaults to ``None``, which makes argparse read ``sys.argv[1:]``.
    The CLI column defaults (``image_name`` / ``reference``) differ from the
    keyword defaults of :func:`captioner`.
    """
    # The text must go in ``description=``: the first positional argument of
    # ArgumentParser is ``prog``, which would replace the program name in usage.
    p = argparse.ArgumentParser(
        description="Generate sub-captions via Azure OpenAI"
    )
    p.add_argument("--csv", required=True, help="input CSV path")
    p.add_argument(
        "--output-dir", required=True, help="directory for the per-figure JSON files"
    )
    p.add_argument(
        "--api-key",
        default=os.environ.get("AZURE_API_KEY"),
        help="Azure API key (default: $AZURE_API_KEY)",
    )
    p.add_argument(
        "--endpoint",
        default=DEFAULT_AZURE_ENDPOINT,
        help="Azure OpenAI-compatible endpoint URL",
    )
    p.add_argument("--model", default=DEFAULT_MODEL, help="model deployment name")
    p.add_argument("--max-tokens", type=int, default=DEFAULT_MAX_TOKENS)
    p.add_argument(
        "--max-retries",
        type=int,
        default=DEFAULT_MAX_RETRIES,
        help="request attempts per figure",
    )
    p.add_argument(
        "--requests-per-minute",
        type=int,
        default=None,
        help="throttle API calls to this rate (default: no throttling)",
    )
    p.add_argument("--overwrite", action="store_true", help="regenerate existing outputs")
    p.add_argument("--quiet", action="store_true", help="suppress progress output")
    p.add_argument("--image-name-col", default="image_name")
    p.add_argument("--caption-col", default="caption")
    p.add_argument("--reference-col", default="reference")
    return p.parse_args(argv)


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: parse *argv* and run :func:`captioner`."""
    args = _parse_args(argv)
    captioner(
        csv_path=args.csv,
        output_dir=args.output_dir,
        api_key=args.api_key,
        azure_endpoint=args.endpoint,
        model_name=args.model,
        max_tokens=args.max_tokens,
        max_retries=args.max_retries,
        overwrite=args.overwrite,
        image_name_col=args.image_name_col,
        caption_col=args.caption_col,
        reference_col=args.reference_col,
        requests_per_minute=args.requests_per_minute,
        verbose=not args.quiet,
    )


if __name__ == "__main__":
    main()