# Source code for matmmextract.preprocess.pipeline

"""
matmmextract.preprocess.pipeline
====================================
Load, deduplicate, and filter Scopus export CSVs down to the set of
papers that will be fetched from publisher APIs, then organize the
downloaded XML and image files.

Steps 1–6 below are the notebook's first half (everything before the XML
download); steps 7–10 are the file-management cells that follow the
download. Both are expressed here as a callable API.

Notebook cells covered
-----------------------
1. Load multiple CSVs and concatenate
2. Inspect publisher counts (the notebook plotted these as a bar chart)
3. Identify and drop duplicate DOIs
4. Filter by publisher regex (default: Elsevier)
5. Filter by Open Access type
6. Save filtered CSVs
7. Move qualifying XMLs into a working directory (green open access)
8. Copy CC BY XMLs into a separate directory (post-license-check)
9. Flatten downloaded image sub-folders into a single flat folder
10. Prune images that have no matching .txt caption file

Cells NOT included (publisher-specific; they live in separate modules):
- The figure extractor cell (→ elsevier/extractor.py)
- The CC license diagnostics cell (→ preprocess/cc_license.py)
"""

from __future__ import annotations

import shutil
from pathlib import Path
from typing import TypeAlias
from collections.abc import Iterable

import pandas as pd

FileArg: TypeAlias = str | Path | Iterable[str | Path]

# ---------------------------------------------------------------------------
# Constants  (match notebook defaults)
# ---------------------------------------------------------------------------

ELSEVIER_PATTERN: str = r"\bElsevier\b"

OPEN_ACCESS_ALLOWED: frozenset[str] = frozenset({
    "All Open Access",
    "All Open Access; Green Open Access",
    "All Open Access; Bronze Open Access",
    "All Open Access; Bronze Open Access; Green Open Access",
})

IMAGE_EXTS: frozenset[str] = frozenset(
    {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".gif", ".webp"}
)


# ---------------------------------------------------------------------------
# 1. Load
# ---------------------------------------------------------------------------

def _clean_doi(doi: str) -> str:
    """Return *doi* as a bare DOI name, without a resolver-URL or ``doi:`` prefix.

    Surrounding whitespace is trimmed, then one leading prefix of the form
    ``http(s)://doi.org/``, ``http(s)://dx.doi.org/`` or ``doi:`` is removed
    (matched case-insensitively; the DOI itself keeps its original case).
    Any other text is returned unchanged apart from the whitespace trim.

    Parameters
    ----------
    doi:
        DOI string, possibly prefixed. The value is passed through
        :func:`str` first, so callers should filter out missing values
        (``NaN``/``None``) before calling — they would become ``"nan"`` /
        ``"None"``.

    Returns
    -------
    str
        The cleaned DOI.
    """
    import re

    stripped = re.sub(
        r"^(?:https?://(?:dx\.)?doi\.org/|doi:\s*)",
        "",
        str(doi).strip(),
        flags=re.IGNORECASE,
    )
    # Trim again in case whitespace followed the removed prefix.
    return stripped.strip()


def load_csvs(filenames: FileArg, **read_csv_kwargs) -> pd.DataFrame:
    """Read one or several Scopus export CSVs into a single DataFrame.

    Parameters
    ----------
    filenames:
        Either one path (``str`` or :class:`~pathlib.Path`) or an iterable of
        such paths; files are read in iteration order.
    **read_csv_kwargs:
        Forwarded unchanged to every :func:`pandas.read_csv` call.

    Returns
    -------
    pd.DataFrame
        All rows stacked in file order, with a fresh ``0 … N-1`` index.

    Raises
    ------
    ValueError
        If *filenames* yields no paths.

    Examples
    --------
    >>> df = load_csvs(["Al_alloy_10k.csv", "Ni_alloy_10k.csv"])
    >>> df = load_csvs(Path("data").glob("*.csv"))
    """
    # A lone str would otherwise be iterated character by character, so a
    # single path is wrapped in a list before looping.
    paths = [filenames] if isinstance(filenames, (str, Path)) else filenames

    tables = [pd.read_csv(csv_path, **read_csv_kwargs) for csv_path in paths]
    if len(tables) == 0:
        raise ValueError("No files provided to load_csvs().")

    return pd.concat(tables, ignore_index=True)
# --------------------------------------------------------------------------- # 2. Inspect # ---------------------------------------------------------------------------
def publisher_counts(
    df: pd.DataFrame,
    top_n: int = 10,
    publisher_col: str = "Publisher",
) -> pd.Series:
    """Return how often each publisher occurs, most frequent first.

    Mirrors the ``merged_df["Publisher"].value_counts().head(10)`` cell.

    Parameters
    ----------
    df:
        Scopus export data.
    top_n:
        Maximum number of publishers to return.
    publisher_col:
        Column holding the publisher name. Defaults to ``"Publisher"``, the
        same default :func:`filter_by_publisher` uses.

    Returns
    -------
    pd.Series
        Publisher name → row count, at most *top_n* entries. Rows with a
        missing publisher are not counted (``value_counts`` drops NaN).
    """
    return df[publisher_col].value_counts().head(top_n)
# --------------------------------------------------------------------------- # 3. Deduplication # ---------------------------------------------------------------------------
def find_duplicate_dois(df: pd.DataFrame, doi_col: str = "DOI") -> pd.DataFrame:
    """List every row that shares its DOI with at least one other row.

    Mirrors:
    ``merged_df[merged_df["DOI"].notna() & merged_df["DOI"].duplicated(keep=False)]``

    All copies of a repeated DOI are returned, not just the extras. The
    result is sorted by *doi_col* so the copies sit next to each other.
    Rows with a missing DOI are never reported, however many of them there
    are.
    """
    dois = df[doi_col]
    repeated = dois.duplicated(keep=False)
    present = dois.notna()
    return df.loc[present & repeated].sort_values(doi_col)
def drop_duplicate_dois(
    df: pd.DataFrame,
    doi_col: str = "DOI",
    keep: str = "first",
    verbose: bool = True,
) -> pd.DataFrame:
    """Remove duplicated DOIs, keeping one row per DOI.

    Mirrors:
    ``merged_df.drop_duplicates(subset=["DOI"], keep="first").reset_index(drop=True)``
    except that rows with a missing DOI are all kept. They carry no DOI, so
    they are not duplicates of one another.

    DOIs are first normalized with :func:`_clean_doi`: surrounding whitespace
    is trimmed and a ``https://doi.org/`` prefix is removed. A prefixed and an
    unprefixed copy of the same DOI therefore count as duplicates. The input
    frame is not modified; the returned frame carries the normalized DOIs.

    Parameters
    ----------
    df:
        Input rows.
    doi_col:
        Column holding the DOI.
    keep:
        Which copy of a duplicated DOI survives: ``"first"``, ``"last"``, or
        ``False`` to drop every copy. This matches
        :meth:`pandas.DataFrame.drop_duplicates`.
    verbose:
        Print a one-line summary.

    Returns
    -------
    pd.DataFrame
        Deduplicated rows in their original order, index reset to 0 … N-1.
    """
    if doi_col in df.columns:
        df = df.copy()
        # na_action="ignore" leaves missing DOIs as NaN instead of "nan".
        df[doi_col] = df[doi_col].map(_clean_doi, na_action="ignore")

    dupes = find_duplicate_dois(df, doi_col)

    dois = df[doi_col]
    # Only rows that actually have a DOI can be redundant. A plain
    # drop_duplicates(subset=[doi_col]) treats all missing DOIs as equal and
    # would collapse every DOI-less paper into a single row.
    redundant = dois.notna() & dois.duplicated(keep=keep)
    clean = df[~redundant].reset_index(drop=True)

    if verbose:
        print(f"Duplicate DOI rows: {len(dupes)} | {len(df)} → {len(clean)} rows after dedup")
    return clean
# --------------------------------------------------------------------------- # 4. Publisher filter # ---------------------------------------------------------------------------
def filter_by_publisher(
    df: pd.DataFrame,
    pattern: str = ELSEVIER_PATTERN,
    publisher_col: str = "Publisher",
    case: bool = False,
    verbose: bool = True,
) -> pd.DataFrame:
    """Keep rows whose publisher matches the regular expression *pattern*.

    Mirrors:
    ``merged_df[merged_df["Publisher"].fillna("").str.contains(elsevier_pattern, ...)]``

    Parameters
    ----------
    df:
        Input rows.
    pattern:
        Regular expression searched for anywhere in the publisher name.
    publisher_col:
        Column holding the publisher name.
    case:
        Match case-sensitively when True. The default (False) ignores case.
    verbose:
        Print a one-line summary.

    Returns
    -------
    pd.DataFrame
        Matching rows in their original order, index reset to 0 … N-1.
    """
    # Missing publishers become "", so they are kept only if *pattern* can
    # match an empty string.
    publishers = df[publisher_col].fillna("")
    # na=False turns any remaining non-string cell into a non-match instead
    # of NaN, which boolean indexing cannot use.
    mask = publishers.str.contains(pattern, case=case, regex=True, na=False)
    result = df[mask].reset_index(drop=True)
    if verbose:
        print(f"Publisher filter '{pattern}': {len(df)} → {len(result)} rows")
    return result
# --------------------------------------------------------------------------- # 5. Open-access filter # ---------------------------------------------------------------------------
def filter_open_access(
    df: pd.DataFrame,
    allowed: frozenset[str] | set[str] | list[str] = OPEN_ACCESS_ALLOWED,
    oa_col: str = "Open Access",
    verbose: bool = True,
) -> pd.DataFrame:
    """Keep rows whose Open Access tag is exactly one of *allowed*.

    Mirrors: ``green_df = df[df["Open Access"].isin(allowed_open_access)]``

    Parameters
    ----------
    df:
        Input rows.
    allowed:
        Accepted tag values, compared by exact string equality. A single
        string is treated as a one-element collection.
    oa_col:
        Column holding the Open Access tag. Rows where it is missing are
        dropped.
    verbose:
        Print a one-line summary.

    Returns
    -------
    pd.DataFrame
        Matching rows in their original order, index reset to 0 … N-1.
    """
    # Series.isin rejects a bare str, so wrap it.
    if isinstance(allowed, str):
        allowed = {allowed}
    result = df[df[oa_col].isin(allowed)].reset_index(drop=True)
    if verbose:
        print(f"Open-access filter: {len(df)} → {len(result)} rows")
    return result
# --------------------------------------------------------------------------- # 6. Save # ---------------------------------------------------------------------------
[docs] def save_csv(df: pd.DataFrame, path: str | Path, verbose: bool = True) -> Path: """Save *df* to CSV (no index). Creates parent directories if needed.""" path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) df.to_csv(path, index=False) if verbose: print(f"Saved {len(df)} rows → {path}") return path.resolve()
# --------------------------------------------------------------------------- # 7. Move qualifying XMLs (green open-access subset) # ---------------------------------------------------------------------------
[docs] def move_xmls_by_doi( dois: pd.Series | list[str], source_dir: str | Path, target_dir: str | Path, verbose: bool = True, ) -> tuple[list[str], list[str]]: """Move XML files from *source_dir* to *target_dir* for the given DOIs. Mirrors the notebook cell: ``for doi in dois: shutil.move(source_dir / doi.replace("/","_")+".xml", ...)`` Parameters ---------- dois: Iterable of DOI strings (from ``green_df["DOI"]``). source_dir: Directory currently holding all XML files. target_dir: Destination directory (created if it doesn't exist). Returns ------- moved : list[str] — filenames successfully moved missing : list[str] — DOIs whose XML was not found """ source_dir = Path(source_dir) target_dir = Path(target_dir) target_dir.mkdir(parents=True, exist_ok=True) clean_dois = ( pd.Series(dois).dropna().astype(str).str.strip() .pipe(lambda s: s[s != ""]) .drop_duplicates().tolist() ) moved, missing = [], [] for doi in clean_dois: xml_name = doi.replace("/", "_") + ".xml" src = source_dir / xml_name if src.exists(): shutil.move(str(src), target_dir / xml_name) moved.append(xml_name) else: missing.append(doi) if verbose: print(f"move_xmls_by_doi: {len(clean_dois)} DOIs | {len(moved)} moved | {len(missing)} missing") return moved, missing
# --------------------------------------------------------------------------- # 8. Copy CC BY XMLs # ---------------------------------------------------------------------------
[docs] def copy_xmls_by_filename( filenames: Iterable[str], source_dir: str | Path, target_dir: str | Path, verbose: bool = True, ) -> tuple[int, list[str]]: """Copy specific XML files (by filename) to *target_dir*. Mirrors: ``for fname in cc_by_files: shutil.copy2(source_dir/fname, cc_by_xml_dir/fname)`` Returns ------- copied : int missing : list[str] — filenames not found in source_dir """ source_dir = Path(source_dir) target_dir = Path(target_dir) target_dir.mkdir(parents=True, exist_ok=True) copied, missing = 0, [] for fname in filenames: src = source_dir / fname if src.exists(): shutil.copy2(src, target_dir / fname) copied += 1 else: missing.append(fname) if verbose: print(f"copy_xmls_by_filename: {copied} copied | {len(missing)} missing") return copied, missing
# --------------------------------------------------------------------------- # 9. Flatten image sub-folders # ---------------------------------------------------------------------------
def flatten_images(
    source_root: str | Path,
    target_dir: str | Path,
    limit: int | None = None,
    image_exts: frozenset[str] = IMAGE_EXTS,
    verbose: bool = True,
) -> dict[str, int]:
    """Copy images from nested per-paper folders into one flat directory.

    Mirrors the notebook cell that scanned ``alloy_elsevier_contents``
    sub-folders and copied everything into ``alloy_images_flat``.

    Each per-paper folder is searched recursively. The search root is the
    folder's ``images/`` sub-folder if it has one, otherwise the folder
    itself. Every file whose lower-cased suffix is in *image_exts* is copied
    to ``<target_dir>/<paper-folder-name>_<image-file-name>``.

    A destination that already exists is left untouched and counted as
    skipped, so re-running only copies new images. Two images with the same
    file name under one paper folder map to the same destination; only the
    first in sorted path order is copied.

    Parameters
    ----------
    source_root:
        Root whose immediate children are per-paper folders. If *target_dir*
        is one of those children, it is not treated as a paper folder.
    target_dir:
        Flat output directory (created if it doesn't exist).
    limit:
        If set, process only the first *limit* sub-folders (sorted).
    image_exts:
        Suffixes (lower-case, with leading dot) treated as images.
    verbose:
        Print a one-line summary.

    Returns
    -------
    dict with keys ``"copied"`` and ``"skipped"``.
    """
    source_root = Path(source_root)
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target_resolved = target_dir.resolve()

    # Without this check, a target_dir inside source_root would be scanned as
    # a paper folder and its own flattened images re-copied into itself.
    subfolders = sorted(
        p for p in source_root.iterdir()
        if p.is_dir() and p.resolve() != target_resolved
    )
    if limit is not None:
        subfolders = subfolders[:limit]

    copied = skipped = 0
    for folder in subfolders:
        images_sub = folder / "images"
        image_dir = images_sub if images_sub.exists() else folder
        # sorted() materialises the walk before copying starts. This makes
        # the name-clash winner deterministic, and files written during the
        # loop are never picked up.
        for img_path in sorted(image_dir.rglob("*")):
            if not (img_path.is_file() and img_path.suffix.lower() in image_exts):
                continue
            dst = target_dir / f"{folder.name}_{img_path.name}"
            if dst.exists():
                skipped += 1
            else:
                shutil.copy2(img_path, dst)
                copied += 1

    if verbose:
        print(f"flatten_images: {len(subfolders)} folders | {copied} copied | {skipped} skipped")
    return {"copied": copied, "skipped": skipped}
# --------------------------------------------------------------------------- # 10. Prune images without matching captions # ---------------------------------------------------------------------------
def prune_images_without_captions(
    image_dir: str | Path,
    caption_ext: str = ".txt",
    image_exts: frozenset[str] = IMAGE_EXTS,
    dry_run: bool = False,
    verbose: bool = True,
) -> dict[str, int | list[str]]:
    """Delete image files that have no matching caption file.

    Mirrors:
    ``for img in images: if not img.with_suffix(".txt").exists(): img.unlink()``

    Only files directly inside *image_dir* are considered (no recursion). An
    image ``<stem><suffix>`` is kept when ``<stem><caption_ext>`` exists in
    the same directory; otherwise it is deleted.

    Parameters
    ----------
    image_dir:
        Flat directory holding images and their caption files.
    caption_ext:
        Suffix of caption files. The leading dot is optional, so ``"txt"``
        and ``".txt"`` are equivalent.
    image_exts:
        Suffixes (lower-case, with leading dot) treated as images.
    dry_run:
        Report what *would* be deleted without actually deleting.
    verbose:
        Print a one-line summary.

    Returns
    -------
    dict
        ``"kept"`` (int), ``"removed"`` (int; deleted, or would-be-deleted on
        a dry run), and ``"removed_files"`` (list of the affected image file
        names, in sorted order).
    """
    image_dir = Path(image_dir)
    # Path.with_suffix() raises ValueError for a suffix without a leading dot.
    if caption_ext and not caption_ext.startswith("."):
        caption_ext = "." + caption_ext

    # Sorted so removed_files has a stable order across filesystems.
    images = sorted(
        p for p in image_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_exts
    )

    kept = 0
    removed_files: list[str] = []
    for img in images:
        if img.with_suffix(caption_ext).exists():
            kept += 1
            continue
        removed_files.append(img.name)
        if not dry_run:
            img.unlink()
    removed = len(removed_files)

    if verbose:
        action = "would remove" if dry_run else "removed"
        print(f"prune_images: {len(images)} total | {kept} kept | {removed} {action}")
    return {"kept": kept, "removed": removed, "removed_files": removed_files}
# --------------------------------------------------------------------------- # Convenience: full CSV preprocessing in one call # ---------------------------------------------------------------------------
def run(
    csv_files: FileArg,
    publisher_pattern: str = ELSEVIER_PATTERN,
    open_access_allowed: frozenset[str] = OPEN_ACCESS_ALLOWED,
    output_dir: str | Path = Path("output"),
    verbose: bool = True,
) -> dict[str, pd.DataFrame]:
    """Run the full CSV preprocessing pipeline.

    Load → dedup → filter publisher → filter open-access → save.

    Parameters
    ----------
    csv_files:
        Input CSV path(s): a single path or any iterable of paths, including
        one-shot iterators such as ``Path("data").glob("*.csv")``.
    publisher_pattern:
        Regex for publisher filter.
    open_access_allowed:
        Allowed Open Access tag values.
    output_dir:
        Directory for output CSVs.
    verbose:
        Print progress.

    Returns
    -------
    dict with keys:
        ``"raw"``         — concatenated raw data
        ``"deduped"``     — after DOI dedup
        ``"publisher"``   — after publisher filter (saved as publisher_filtered.csv)
        ``"open_access"`` — after OA filter (saved as open_access_filtered.csv)

    Examples
    --------
    >>> result = run(["Al_alloy_10k.csv", "Ni_alloy_10k.csv"])
    >>> green_df = result["open_access"]
    """
    # Materialise the inputs once. A generator would otherwise be exhausted
    # by load_csvs(), and the file count printed below would report 0.
    if isinstance(csv_files, (str, Path)):
        csv_list = [csv_files]
    else:
        csv_list = list(csv_files)

    raw = load_csvs(csv_list)
    if verbose:
        print(f"Loaded {len(raw):,} rows from {len(csv_list)} file(s)")
        print(publisher_counts(raw).to_string())

    deduped = drop_duplicate_dois(raw, verbose=verbose)
    pub_filtered = filter_by_publisher(deduped, pattern=publisher_pattern, verbose=verbose)
    oa_filtered = filter_open_access(pub_filtered, allowed=open_access_allowed, verbose=verbose)

    output_dir = Path(output_dir)
    save_csv(pub_filtered, output_dir / "publisher_filtered.csv", verbose=verbose)
    save_csv(oa_filtered, output_dir / "open_access_filtered.csv", verbose=verbose)

    return {
        "raw": raw,
        "deduped": deduped,
        "publisher": pub_filtered,
        "open_access": oa_filtered,
    }