Source code for matmmextract.inference.detector

"""
matmmextract.inference.detector
===================================
Batch panel detection using a YOLO fine-tuned checkpoint (Ultralytics).

Ported from ``infer.py`` (YOLO version) — globals → parameters, callable as API.

Flow
----
images → ``detect()`` → one JSON file per image + ``_summary.json`` in *output_dir*

Each per-image JSON::

    {
        "meta":         { weights, conf, iou, imgsz, timestamp, id2label },
        "file":         "img123.jpg",
        "path":         "<image_dir>/img123.jpg",
        "width":        W,
        "height":       H,
        "n_detections": N,
        "detections":   [
            { "bbox": [x1,y1,x2,y2], "score": 0.91,
              "label_id": 0, "label_name": "A" },
            ...
        ]
    }
"""

from __future__ import annotations

import argparse
import datetime
import json
import os
from dataclasses import dataclass, field
from pathlib import Path

# File extensions (compared lower-cased) that find_images() treats as images.
IMAGE_EXTS: frozenset[str] = frozenset(
    (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
)

# Default class-id -> label map: ids 0..19 are the letters "A".."T",
# followed by the two extra classes "single" (20) and "common" (21).
ID2LABEL: dict[int, str] = {
    **{idx: chr(ord("A") + idx) for idx in range(20)},
    20: "single",
    21: "common",
}


# ---------------------------------------------------------------------------
# Result container
# ---------------------------------------------------------------------------

@dataclass
class DetectionResult:
    """Outcome of one :func:`detect` run.

    Every field has an "empty" default; :func:`detect` fills them in as it
    processes the image folder.
    """

    # How many input images were found in *image_dir*.
    n_images: int = 0
    # How many per-image JSON files were written successfully.
    n_saved: int = 0
    # Directory the JSON files were written to, as a string.
    output_dir: str = ""
    # Image paths whose processing raised an exception.
    failed: list[str] = field(default_factory=list)
# --------------------------------------------------------------------------- # Helpers # ---------------------------------------------------------------------------
[docs] def find_images(folder: str | Path) -> list[str]: folder = Path(folder) return [ str(folder / f) for f in sorted(os.listdir(folder)) if Path(f).suffix.lower() in IMAGE_EXTS ]
def _is_gdrive_url(s: str) -> bool: return "drive.google.com" in s or "docs.google.com" in s def _is_hf_url(s: str) -> bool: return "huggingface.co" in s def _is_hf_repo_id(s: str) -> bool: """ Matches bare repo IDs like "username/model-name" or "username/model-name:filename.pt" (no scheme, no slashes beyond one). """ if _is_url(s) or _is_gdrive_url(s): return False # repo_id pattern: exactly one "/" between two non-empty segments, # optionally followed by ":filename" core = s.split(":", 1)[0] parts = core.split("/") return len(parts) == 2 and all(parts) and not s.startswith((".", "/")) def _is_url(s: str) -> bool: return str(s).startswith(("http://", "https://")) def _download_gdrive(url: str, cache_dir: str | Path = ".weights_cache") -> str: """Download a Google Drive file and return the local path. Accepts any of: - https://drive.google.com/file/d/<ID>/view?usp=sharing - https://drive.google.com/open?id=<ID> - https://drive.google.com/uc?id=<ID> """ try: import gdown except ImportError: raise ImportError("gdown is required for Google Drive downloads: pip install gdown") import re # Extract file ID from any Drive URL format patterns = [ r"/file/d/([a-zA-Z0-9_-]+)", r"[?&]id=([a-zA-Z0-9_-]+)", r"open\?id=([a-zA-Z0-9_-]+)", ] file_id = None for pat in patterns: m = re.search(pat, url) if m: file_id = m.group(1) break if not file_id: raise ValueError(f"Could not extract Google Drive file ID from URL: {url}") cache_dir = Path(cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) # Check if already downloaded existing = list(cache_dir.glob(f"{file_id}*.pt")) if existing: print(f"[detector] using cached weights: {existing[0]}") return str(existing[0]) dest = cache_dir / f"{file_id}.pt" print(f"[detector] downloading from Google Drive (id={file_id}) → {dest}") gdown.download(id=file_id, output=str(dest), quiet=False) if not dest.exists(): raise RuntimeError(f"Download failed — file not found at {dest}") return str(dest) def _download_hf(source: str, cache_dir: str | 
Path = ".weights_cache") -> str: """Download a .pt file from a Hugging Face Hub repo and return the local path. Accepts: - A full URL: https://huggingface.co/<user>/<repo>/resolve/main/best.pt - A bare repo ID: "<user>/<repo>" (auto-detects the .pt filename in the repo) - A repo ID with filename: "<user>/<repo>:best.pt" """ try: from huggingface_hub import hf_hub_download, list_repo_files except ImportError: raise ImportError( "huggingface_hub is required: pip install huggingface_hub" ) import re repo_id = source filename = None if _is_hf_url(source): # https://huggingface.co/<user>/<repo>/resolve/main/<filename> # https://huggingface.co/<user>/<repo>/blob/main/<filename> m = re.search( r"huggingface\.co/([^/]+/[^/]+)/(?:resolve|blob)/[^/]+/(.+?)(?:\?.*)?$", source, ) if m: repo_id, filename = m.group(1), m.group(2) else: # Fallback: just the repo, e.g. https://huggingface.co/<user>/<repo> m2 = re.search(r"huggingface\.co/([^/]+/[^/]+)", source) if not m2: raise ValueError(f"Could not parse Hugging Face URL: {source}") repo_id = m2.group(1) elif ":" in source: repo_id, filename = source.split(":", 1) cache_dir = Path(cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) if filename is None: # Auto-detect: find the first .pt file in the repo try: files = list_repo_files(repo_id) except Exception as exc: raise RuntimeError(f"Could not list files in HF repo '{repo_id}': {exc}") pt_files = [f for f in files if f.endswith(".pt")] if not pt_files: raise FileNotFoundError(f"No .pt files found in Hugging Face repo: {repo_id}") # Prefer best.pt / last.pt if present preferred = [f for f in pt_files if Path(f).name in ("best.pt", "last.pt")] filename = preferred[0] if preferred else sorted(pt_files)[0] print(f"[detector] auto-selected '{filename}' from repo '{repo_id}'") print(f"[detector] downloading from Hugging Face: {repo_id} / {filename}") local_path = hf_hub_download( repo_id=repo_id, filename=filename, cache_dir=str(cache_dir), ) return local_path
[docs] def resolve_weights(checkpoint: str | Path, cache_dir: str | Path = ".weights_cache") -> str: """Resolve weights to a local .pt path. Accepts: - A direct local .pt file path - A local directory containing best.pt / last.pt - A Google Drive share URL (downloaded and cached in *cache_dir*) """ s = str(checkpoint) # Google Drive URL if _is_gdrive_url(s): return _download_gdrive(s, cache_dir=cache_dir) # Hugging Face Hub URL or bare repo ID ("user/repo" or "user/repo:file.pt") if _is_hf_url(s) or _is_hf_repo_id(s): return _download_hf(s, cache_dir=cache_dir) # Other URL — not supported yet if _is_url(s): raise ValueError( f"URL '{s}' is not a Google Drive or Hugging Face link. " f"Only Google Drive and Hugging Face URLs are currently supported." ) # Local path p = Path(checkpoint) if p.is_file() and p.suffix == ".pt": return str(p) for name in ("best.pt", "last.pt"): candidate = p / name if candidate.exists(): return str(candidate) pts = list(p.glob("*.pt")) if pts: return str(sorted(pts)[0]) raise FileNotFoundError( f"No .pt weights found in '{checkpoint}'. " f"Expected best.pt / last.pt inside the folder, or a direct .pt path." )
# --------------------------------------------------------------------------- # Core API # ---------------------------------------------------------------------------
def detect(
    image_dir: str | Path,
    output_dir: str | Path,
    checkpoint: str | Path = "best.pt",
    conf: float = 0.6,
    iou: float = 0.4,
    imgsz: int = 1024,
    device: str = "",
    id2label: dict[int, str] | None = None,
    write_summary: bool = True,
    weights_cache_dir: str | Path = ".weights_cache",
    verbose: bool = True,
) -> DetectionResult:
    """Run YOLO panel detection on every image in *image_dir*.

    Parameters
    ----------
    image_dir:
        Directory of input images (see :func:`find_images`).
    output_dir:
        Directory where one ``.json`` per image is written, plus
        ``_summary.json`` containing all records combined. Each image's file
        is normally named ``<stem>.json``. If that name is already taken
        (case-insensitively), ``<filename>.json`` is used instead, then
        ``<filename>_<n>.json``. This happens when another image shares the
        stem (``a.jpg`` / ``a.png``), or when an image is called ``_summary.*``
        while a summary is written. No record silently overwrites another.
    checkpoint:
        Path to a YOLO ``.pt`` file, a directory containing ``best.pt`` /
        ``last.pt``, a Google Drive share URL
        (e.g. ``https://drive.google.com/file/d/<ID>/view``), or a Hugging
        Face Hub reference. The latter is either a full URL
        (``https://huggingface.co/<user>/<repo>/resolve/main/best.pt``) or a
        bare repo ID (``"<user>/<repo>"`` or ``"<user>/<repo>:best.pt"``).
    conf:
        Confidence threshold (default 0.6).
    iou:
        NMS IoU threshold (default 0.4).
    imgsz:
        Inference image size (default 1024).
    device:
        Device string passed to ``model.predict``: ``""`` for auto, ``"cpu"``,
        ``"0"``, ``"0,1"``, etc.
    id2label:
        Override the default label map. If ``None``, uses :data:`ID2LABEL`.
        Class ids missing from the map are labelled with ``str(id)``.
    write_summary:
        Write ``_summary.json`` combining all records (default True).
    weights_cache_dir:
        Local directory to cache downloaded weights (default ``.weights_cache``).
    verbose:
        Print progress.

    Returns
    -------
    DetectionResult
        Counts of images found and JSON files saved, plus the images that failed.
        A failing image is reported and skipped; it does not abort the batch.

    Raises
    ------
    ImportError
        If ``ultralytics`` is not installed.
    RuntimeError
        If *image_dir* contains no images.
    """
    try:
        from ultralytics import YOLO
    except ImportError as exc:
        raise ImportError("ultralytics is required: pip install ultralytics") from exc
    from PIL import Image
    from tqdm import tqdm

    # List the inputs before resolving weights: resolving may download a
    # large checkpoint, which is wasted work when the folder has no images.
    image_paths = find_images(image_dir)
    if not image_paths:
        raise RuntimeError(f"No images found in: {image_dir}")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    weights = resolve_weights(checkpoint, cache_dir=weights_cache_dir)
    label_map = id2label if id2label is not None else ID2LABEL
    # JSON object keys must be strings; built once and reused in every record.
    meta_id2label = {str(k): v for k, v in label_map.items()}

    if verbose:
        print(f"[detector] weights={weights}")
        print(f"[detector] conf={conf} iou={iou} imgsz={imgsz} device={device!r}")

    model = YOLO(weights)

    if verbose:
        print(f"[detector] {len(image_paths)} images\n")

    result = DetectionResult(n_images=len(image_paths), output_dir=str(output_dir))
    all_records: list[dict] = []
    # Output names already used (case-folded). Reserving the summary name
    # keeps an image called "_summary.*" from being overwritten at the end.
    taken: set[str] = {"_summary.json"} if write_summary else set()

    for img_path in tqdm(image_paths, desc="Detecting", disable=not verbose):
        try:
            # Only the header is needed for the size; the context manager
            # closes the file instead of leaving it to the garbage collector.
            with Image.open(img_path) as im:
                W, H = im.size

            results = model.predict(
                source=img_path,
                conf=conf,
                iou=iou,
                imgsz=imgsz,
                device=device,
                verbose=False,
            )
            detections = _detections_from_result(results[0], label_map)

            record = {
                "meta": {
                    "weights": weights,
                    "conf": conf,
                    "iou": iou,
                    "imgsz": imgsz,
                    "timestamp": datetime.datetime.now().isoformat(timespec="seconds"),
                    "id2label": meta_id2label,
                },
                "file": os.path.basename(img_path),
                "path": img_path,
                "width": W,
                "height": H,
                "n_detections": len(detections),
                "detections": detections,
            }

            out_path = output_dir / _claim_json_name(img_path, taken)
            out_path.write_text(json.dumps(record, indent=2))
            all_records.append(record)
            result.n_saved += 1
        except Exception as exc:  # best-effort: one bad image must not abort the batch
            result.failed.append(img_path)
            if verbose:
                print(f"\n[detector] ERROR {img_path}: {exc}")

    if write_summary:
        summary_path = output_dir / "_summary.json"
        summary_path.write_text(json.dumps(all_records, indent=2))
        if verbose:
            print(f"[detector] summary → {summary_path}")

    if verbose:
        print(f"[detector] saved {result.n_saved} JSON files → {output_dir}")
        if result.failed:
            print(f"[detector] {len(result.failed)} failed: {result.failed}")

    return result


def _detections_from_result(r, label_map: dict[int, str]) -> list[dict]:
    """Convert one Ultralytics result's boxes into JSON-ready detection dicts."""
    boxes = r.boxes
    if boxes is None or not len(boxes):
        return []
    return [
        {
            "bbox": [round(v, 2) for v in box],
            "score": round(score, 6),
            "label_id": cls,
            "label_name": label_map.get(cls, str(cls)),
        }
        for box, score, cls in zip(
            boxes.xyxy.tolist(),
            boxes.conf.tolist(),
            boxes.cls.int().tolist(),
        )
    ]


def _claim_json_name(img_path: str, taken: set[str]) -> str:
    """Pick an unused ``.json`` file name for *img_path* and record it in *taken*.

    Tries ``<stem>.json``, then ``<filename>.json``, then ``<filename>_<n>.json``.
    Names are compared case-folded so they stay distinct on case-insensitive
    file systems.
    """
    p = Path(img_path)
    for name in (f"{p.stem}.json", f"{p.name}.json"):
        if name.casefold() not in taken:
            taken.add(name.casefold())
            return name
    n = 1
    while f"{p.name}_{n}.json".casefold() in taken:
        n += 1
    name = f"{p.name}_{n}.json"
    taken.add(name.casefold())
    return name
# --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def _parse_args() -> argparse.Namespace: p = argparse.ArgumentParser("MatDetect — YOLO inference → JSON") p.add_argument("--image-dir", required=True) p.add_argument("--output-dir", required=True) p.add_argument("--weights", default="best.pt", help="Path to YOLO .pt weights or folder containing best.pt") p.add_argument("--conf", type=float, default=0.6) p.add_argument("--iou", type=float, default=0.4) p.add_argument("--imgsz", type=int, default=1024) p.add_argument("--device", default="", help="Device: '' auto, 'cpu', '0', '0,1', etc.") return p.parse_args() def main() -> None: args = _parse_args() detect( image_dir=args.image_dir, output_dir=args.output_dir, checkpoint=args.weights, conf=args.conf, iou=args.iou, imgsz=args.imgsz, device=args.device, ) if __name__ == "__main__": main()