Source code for matmmextract.springer.extractor

from __future__ import annotations

import argparse
import os
import re
import warnings
from pathlib import Path

import pandas as pd
from bs4 import BeautifulSoup, NavigableString, Tag, XMLParsedAsHTMLWarning

from ..shared.sentence_utils import split_sentences

warnings.filterwarnings("ignore", category=XMLParsedAsHTMLWarning)



# Element names (compared against lower-cased, prefix-stripped tag names)
# that the extractor treats as paragraphs, figures and cross-references.
PARA_TAGS: frozenset[str] = frozenset(("p", "para", "simple-para"))
FIG_TAGS: frozenset[str] = frozenset(("fig",))
XREF_TAGS: frozenset[str] = frozenset(("xref",))

# Elements whose subtree is left out when paragraph text is reconstructed.
# Checked against both the full tag name and the prefix-stripped name.
SKIP_TEXT_TAGS: frozenset[str] = frozenset((
    "math", "mml:math", "mml:semantics",
    "disp-formula", "inline-formula",
    "tex-math", "alternatives",
))

# Case-insensitive matcher for the phrase "graphical abstract".
GA_RE = re.compile(r"graphical\s+abstract", flags=re.IGNORECASE)


def _tag_name(node) -> str:
    """Return the lower-cased tag name of *node*.

    Anything that is not a ``Tag``, or a ``Tag`` without a name, yields
    the empty string.
    """
    if isinstance(node, Tag) and node.name:
        return node.name.lower()
    return ""


def _local_name(node) -> str:
    """Return the tag name of *node* with any leading ``prefix:`` removed.

    Only the part before the first colon is dropped, so ``a:b:c`` becomes
    ``b:c``.
    """
    qualified = _tag_name(node)
    _prefix, colon, remainder = qualified.partition(":")
    return remainder if colon else qualified


def _attr_value(tag, wanted: str) -> str:
    """Return the value of attribute *wanted* on *tag* as a string.

    An attribute matches when its full name equals *wanted* or when the
    name with everything up to the first colon removed equals *wanted*
    (so ``xlink:href`` matches ``href``).  The first match in ``tag.attrs``
    iteration order wins.  List values are joined with single spaces.
    Returns ``""`` when no attribute matches.
    """
    for name, raw in tag.attrs.items():
        unprefixed = name.split(":", 1)[-1]
        if wanted not in (name, unprefixed):
            continue
        if isinstance(raw, list):
            return " ".join(map(str, raw))
        return str(raw)
    return ""


def _clean_markers(text: str) -> str:
    """Strip ``__XREF_<rid>__`` sentinels from *text* and tidy punctuation.

    Each sentinel is removed together with the whitespace that follows it,
    the result is trimmed, and any whitespace directly before one of
    ``, . ; : ! ?`` is dropped.
    """
    without_markers = re.sub(r"__XREF_.*?__\s*", "", text)
    trimmed = without_markers.strip()
    return re.sub(r"\s+([,.;:!?])", r"\1", trimmed)


def _normalize_caption(text: str) -> str:
    """Collapse whitespace in *text* and de-duplicate a repeated figure label.

    When the caption starts with two or more consecutive ``Fig. N`` /
    ``Figure N`` labels (as happens when label and caption both carry it),
    the run is replaced by its last occurrence.
    """
    single_spaced = re.sub(r"\s+", " ", text)
    compact = single_spaced.strip()
    return re.sub(
        r"^(Fig(?:ure)?\.?\s*\d+[A-Za-z]?\s*){2,}",
        r"\1",
        compact,
    )


def _figure_number_from_id(fig_id: str | None) -> int | None:
    m = re.search(r"(\d+)", fig_id or "")
    return int(m.group(1)) if m else None


def _is_figure_xref(xref) -> bool:
    """Decide whether the cross-reference element *xref* points at a figure.

    If ``ref-type`` is ``fig`` or ``figure`` (case-insensitive) the answer
    is simply whether ``rid`` is non-empty.  Otherwise the ``rid`` itself
    must contain a token such as ``Fig3``, ``fig2a`` or ``F1``.
    """
    declared_type = _attr_value(xref, "ref-type").lower()
    target = _attr_value(xref, "rid")
    if declared_type in {"fig", "figure"}:
        return bool(target)
    looks_like_figure = re.search(r"\b(fig|f)\d+[a-z]?\b", target, re.IGNORECASE)
    return looks_like_figure is not None


def _extract_fallback_figure_numbers(sentence: str) -> set[int]:
    """Extract figure numbers from untagged mentions (e.g. 'Figs. 2-5').

    Every ``Fig``/``Figure``/``Figs.`` mention in *sentence* is followed by
    a "phrase" of digits, letters, brackets, commas, whitespace and dashes.
    The phrase is cut at the first reporting verb (``show``, ``presents``,
    ...) so that numbers after the verb are not mistaken for figure
    numbers.  Within the remaining phrase:

    * ``a-b`` ranges (hyphen, en dash or em dash) are expanded inclusively
      when ``a <= b`` and the range spans at most 50 numbers beyond ``a``;
    * every individual run of digits is added as well.

    Returns the set of all figure numbers found (empty if there are none).
    """
    numbers: set[int] = set()
    for match in re.finditer(
        r"\bFig(?:ure)?s?\.?\s+([0-9A-Za-z\(\)\[\],\s\-–—]+)",
        sentence, re.IGNORECASE,
    ):
        # ``maxsplit`` is passed by keyword: giving it positionally is
        # deprecated since Python 3.13.
        phrase = re.split(
            r"\b(?:show|shows|present|presents|illustrate|illustrates|"
            r"display|displays|compare|compares|summarize|summarizes)\b",
            match.group(1), maxsplit=1, flags=re.IGNORECASE,
        )[0]
        for start, end in re.findall(r"(\d+)\s*[-–—]\s*(\d+)", phrase):
            lo, hi = int(start), int(end)
            # Guard against reversed or implausibly wide "ranges".
            if lo <= hi and hi - lo <= 50:
                numbers.update(range(lo, hi + 1))
        numbers.update(int(number) for number in re.findall(r"\d+", phrase))
    return numbers


def resolve_image_url(fig_tag) -> str:
    """Resolve image location from JATS ``<graphic xlink:href="...">``, etc.

    Descendants of *fig_tag* are scanned in document order; the first
    ``graphic``, ``inline-graphic`` or ``media`` element that carries a
    non-empty ``href`` attribute (namespace prefix ignored) supplies the
    result, stripped of surrounding whitespace.  Returns ``""`` when no
    such element exists.
    """
    image_elements = {"graphic", "inline-graphic", "media"}
    candidate_hrefs = (
        _attr_value(node, "href")
        for node in fig_tag.find_all(True)
        if _local_name(node) in image_elements
    )
    return next((href.strip() for href in candidate_hrefs if href), "")
def _is_graphical_abstract(fig_tag) -> bool:
    """Report whether *fig_tag* looks like a graphical abstract.

    True when the phrase "graphical abstract" (any case) occurs in the
    figure's ``id`` attribute, or in the text of the first ``label`` or
    first ``caption`` element found beneath it.
    """
    if GA_RE.search(_attr_value(fig_tag, "id")):
        return True
    for wanted in ("label", "caption"):
        # Bind ``wanted`` as a default so the lambda does not depend on the
        # loop variable's later value.
        node = fig_tag.find(lambda t, wanted=wanted: _local_name(t) == wanted)
        if node and GA_RE.search(node.get_text(" ", strip=True)):
            return True
    return False
def extract_figures(root) -> dict[str, dict]:
    """Return ``{fig_id: {caption, image_url, fig_num, is_graphical_abstract}}``.

    The figure id is the element's ``id`` attribute; when that is empty it
    falls back to the text of the first ``label`` found anywhere under the
    figure, with non-word characters removed.  Figures for which no id can
    be derived are skipped.  The caption is the text of the direct-child
    ``label`` followed by the direct-child ``caption``, normalised by
    ``_normalize_caption``.  A later figure with the same id replaces an
    earlier one.
    """

    def direct_child_text(fig, wanted: str) -> str:
        # Text of the immediate child element named ``wanted``, or "".
        node = fig.find(lambda t: _local_name(t) == wanted, recursive=False)
        return node.get_text(" ", strip=True) if node else ""

    collected: dict[str, dict] = {}
    for fig in root.find_all(lambda t: _local_name(t) in FIG_TAGS):
        fig_id = _attr_value(fig, "id").strip()
        if not fig_id:
            # No usable id attribute: derive one from any nested label.
            any_label = fig.find(lambda t: _local_name(t) == "label")
            if any_label:
                fig_id = re.sub(r"\W+", "", any_label.get_text(" ", strip=True))
        if not fig_id:
            continue

        pieces = [
            direct_child_text(fig, "label"),
            direct_child_text(fig, "caption"),
        ]
        caption_full = _normalize_caption(" ".join(p for p in pieces if p))

        collected[fig_id] = {
            "caption": caption_full,
            "image_url": resolve_image_url(fig),
            "fig_num": _figure_number_from_id(fig_id),
            "is_graphical_abstract": _is_graphical_abstract(fig),
        }
    return collected
def _get_marked_text(para_tag) -> tuple[str, dict[str, str]]:
    """Reconstruct paragraph text with ``__XREF_<rid>__`` sentinel markers.

    The children of *para_tag* are walked depth-first.  Text nodes are kept,
    subtrees of elements listed in ``SKIP_TEXT_TAGS`` are dropped, and each
    cross-reference element contributes its own text, preceded by a
    ``__XREF_<rid>__`` marker when it is a figure reference with a non-empty
    ``rid``.  Fragments are joined with spaces and whitespace is collapsed.

    Returns ``(text, marker_map)`` where ``marker_map`` maps each emitted
    marker to its ``rid``.
    """
    markers: dict[str, str] = {}

    def fragments(node):
        # Yield the text fragments contributed by ``node`` in document order.
        if isinstance(node, NavigableString):
            yield str(node)
            return
        if not isinstance(node, Tag):
            return
        local = _local_name(node)
        if _tag_name(node) in SKIP_TEXT_TAGS or local in SKIP_TEXT_TAGS:
            return
        if local not in XREF_TAGS:
            for child in node.children:
                yield from fragments(child)
            return
        # Cross-reference: emit a marker for figure refs, then the xref text;
        # its children are not descended into separately.
        rid = _attr_value(node, "rid")
        if rid and _is_figure_xref(node):
            marker = f"__XREF_{rid}__"
            markers[marker] = rid
            yield marker
        yield node.get_text(" ", strip=True)

    collected = [
        piece
        for child in para_tag.children
        for piece in fragments(child)
    ]
    text = re.sub(r"\s+", " ", " ".join(collected)).strip()
    return text, markers
def extract_reference_sentences(root, figures: dict[str, dict]) -> dict[str, list[str]]:
    """Find body sentences that cite each figure.

    Paragraphs are taken from the ``body`` element when one exists,
    otherwise from *root*.  A sentence containing ``__XREF_<rid>__`` markers
    is attributed to every known figure id named in those markers.  Only in
    paragraphs with no figure markers at all are sentences matched by
    number, using untagged "Fig. N" mentions.

    Returns ``{fig_id: [sentence, ...]}`` with markers removed from the
    sentences and no duplicates per figure; every id in *figures* is
    present, possibly with an empty list.
    """
    ref_map: dict[str, list[str]] = {fid: [] for fid in figures}

    # Index figure ids by number for the untagged-mention fallback.
    num_to_fids: dict[int, list[str]] = {}
    for fid, data in figures.items():
        fig_num = data["fig_num"]
        if fig_num is None:
            continue
        num_to_fids.setdefault(fig_num, []).append(fid)

    def remember(fid: str, sentence: str) -> None:
        # Record ``sentence`` for a known figure id, once.
        bucket = ref_map.get(fid)
        if bucket is not None and sentence not in bucket:
            bucket.append(sentence)

    body = root.find(lambda t: _local_name(t) == "body")
    search_root = body or root

    for para in search_root.find_all(lambda t: _local_name(t) in PARA_TAGS):
        text_marked, marker_map = _get_marked_text(para)
        if not text_marked:
            continue
        for sentence in split_sentences(text_marked):
            cited = re.findall(r"__XREF_(.*?)__", sentence)
            clean = _clean_markers(sentence)
            if cited:
                for marker_content in cited:
                    # One marker may carry several space-separated rids.
                    for rid in marker_content.split():
                        remember(rid, clean)
            elif not marker_map:
                for number in _extract_fallback_figure_numbers(clean):
                    for fid in num_to_fids.get(number, ()):
                        remember(fid, clean)
    return ref_map
def process_file(xml_path: str | Path) -> list[dict]:
    """Extract all figure rows from a single Springer JATS XML file.

    The file is parsed with BeautifulSoup's ``lxml`` parser.  One dict is
    returned per figure with the keys ``xml_file``, ``figure_id``,
    ``caption``, ``image_url``, ``is_graphical_abstract``,
    ``num_references`` and ``reference_sentences``.  ``num_references``
    counts every citing sentence, while ``reference_sentences`` joins at
    most the first five with ``" || "``.
    """
    with open(xml_path, "rb") as fh:
        soup = BeautifulSoup(fh, "lxml")

    figures = extract_figures(soup)
    ref_map = extract_reference_sentences(soup, figures)
    fname = os.path.basename(xml_path)

    def build_row(fid: str, data: dict) -> dict:
        # Assemble the output record for one figure.
        sentences = ref_map.get(fid, [])
        return {
            "xml_file": fname,
            "figure_id": fid,
            "caption": data["caption"],
            "image_url": data["image_url"],
            "is_graphical_abstract": data["is_graphical_abstract"],
            "num_references": len(sentences),
            "reference_sentences": " || ".join(sentences[:5]),
        }

    return [build_row(fid, data) for fid, data in figures.items()]
[docs] def extract_all( xml_dir: str | Path, output_csv: str | Path | None = None, verbose: bool = True, ) -> tuple[pd.DataFrame, list[tuple[str, str]]]: """Process every XML file in *xml_dir* and return a figures DataFrame. Parameters ---------- xml_dir: Directory containing Springer ``.xml`` files. output_csv: If provided, write the DataFrame here. verbose: Print per-file progress and summary. Returns ------- df : pd.DataFrame One row per figure. errors : list of (filename, error_message) """ xml_dir = Path(xml_dir) xml_files = sorted(f for f in xml_dir.iterdir() if f.suffix == ".xml") if verbose: print(f"Found {len(xml_files)} XML files in '{xml_dir}'") all_rows: list[dict] = [] errors: list[tuple[str, str]] = [] for path in xml_files: try: rows = process_file(path) all_rows.extend(rows) if verbose: n_refs = sum(1 for r in rows if r["num_references"] > 0) n_ga = sum(1 for r in rows if r["is_graphical_abstract"]) print(f" OK {path.name} figs={len(rows)} with_refs={n_refs} ga={n_ga}") except Exception as exc: errors.append((path.name, str(exc))) if verbose: print(f" ERR {path.name} ERROR: {exc}") df = pd.DataFrame(all_rows) if output_csv is not None: df.to_csv(output_csv, index=False) if verbose: print(f"\n{len(all_rows)} rows → {output_csv}") return df, errors
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options ``--xml-dir`` and ``--output-csv``."""
    parser = argparse.ArgumentParser(
        description="Extract figures from Springer JATS XML files.",
    )
    parser.add_argument("--xml-dir", default="alloys_springer")
    parser.add_argument("--output-csv", default="springer_figure_details.csv")
    return parser.parse_args()


def main() -> None:
    """Command-line entry point: run ``extract_all`` with the parsed options."""
    options = _parse_args()
    extract_all(xml_dir=options.xml_dir, output_csv=options.output_csv)


if __name__ == "__main__":
    main()