mirror of
https://github.com/dantetemplar/pdf-extraction-agenda.git
synced 2025-03-17 21:12:24 +03:00
Add pdf-extraction-agenda package
This commit is contained in:
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
.venv
|
||||
.python-version
|
||||
data/
|
||||
dist/
|
||||
.idea/
|
||||
*.egg-info
|
||||
31
pyproject.toml
Normal file
31
pyproject.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[project]
|
||||
name = "pdf-extraction-agenda"
|
||||
version = "0.1.0"
|
||||
description = "Overview of pipelines related to PDF document processing"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"colorlog>=6.9.0",
|
||||
"datasets>=3.3.2",
|
||||
"huggingface-hub[hf-transfer]>=0.29.2",
|
||||
"pandas>=2.2.3",
|
||||
"pydantic>=2.10.6",
|
||||
"rapidfuzz>=3.12.2",
|
||||
"tabulate>=0.9.0",
|
||||
"tqdm>=4.67.1",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
lint.ignore = ["PLR"]
|
||||
lint.extend-select = ["I", "UP", "PL"]
|
||||
target-version = "py312"
|
||||
|
||||
[dependency-groups]
|
||||
docling = [
|
||||
"docling>=2.25.2",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
0
src/pdf_extraction_agenda/__init__.py
Normal file
0
src/pdf_extraction_agenda/__init__.py
Normal file
85
src/pdf_extraction_agenda/datasets_.py
Normal file
85
src/pdf_extraction_agenda/datasets_.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import glob
|
||||
import os
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from logging_ import logger
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
|
||||
class OlmoOCRResponse(BaseModel):
    """OCRed page information.

    Pydantic schema for the JSON document that ``parse_response`` validates
    from a dataset sample's ``"response"`` entry.
    """

    # Page-level metadata fields.
    primary_language: str
    is_rotation_valid: bool
    rotation_correction: int
    is_table: bool
    is_diagram: bool

    # Extracted text from PDF.
    natural_text: str
|
||||
|
||||
|
||||
def parse_response(example: dict, warn: bool = True) -> tuple[bool, OlmoOCRResponse | None]:
    """Validate the JSON stored under ``example["response"]``.

    Args:
        example: Dataset sample; must contain a ``"response"`` key holding a JSON string.
            Its optional ``"id"`` entry is only used in the warning message.
        warn: When true, log a warning if the JSON does not match ``OlmoOCRResponse``.

    Returns:
        ``(malformed, response)``: ``(False, parsed_model)`` on success and
        ``(True, None)`` when validation fails.
    """
    raw_response = example["response"]
    try:
        parsed = OlmoOCRResponse.model_validate_json(raw_response)
    except ValidationError as err:
        if warn:
            logger.warning(f"Malformed response for {example.get('id')}\n{err}")
        return True, None
    return False, parsed
|
||||
|
||||
|
||||
def extract_tarballs(source_dir: str | os.PathLike, destination_dir: str | os.PathLike) -> None:
|
||||
"""Extracts all tarball files from the source directory into the destination directory."""
|
||||
os.makedirs(destination_dir, exist_ok=True)
|
||||
|
||||
tarballs = glob.glob(os.path.join(source_dir, "*.tar*")) # Matches .tar, .tar.gz, .tar.bz2, etc.
|
||||
for tarball in tarballs:
|
||||
try:
|
||||
with tarfile.open(tarball, "r:*") as tar:
|
||||
tar.extractall(path=destination_dir, filter="fully_trusted")
|
||||
except Exception as e:
|
||||
logger.info(f"Failed to extract {tarball}: {e}")
|
||||
|
||||
|
||||
class IdToPathProto(Protocol):
|
||||
def __call__(self, id: str, warn: bool = False) -> Path | None:
|
||||
"""Converts an ID to a file path."""
|
||||
pass
|
||||
|
||||
|
||||
def prepare_olmocr_dataset() -> tuple[Dataset, IdToPathProto]:
    """Load the olmOCR evaluation split and make its PDFs available locally.

    Loads the ``eval_s2pdf`` split of the ``00_documents`` configuration of
    ``allenai/olmOCR-mix-0225``, downloads the ``*.tar.gz`` files of the
    ``dantetemplar/pdf-extraction-agenda`` dataset repository and extracts the
    tarballs found under ``data/olmOCR-mix-0225`` of that snapshot into
    ``data/olmOCR-mix-0225-extracted`` (relative to the current working directory).

    Returns:
        ``(dataset, id_to_path)`` where ``id_to_path(id, warn=False)`` returns the path
        ``<extracted dir>/<id>.pdf`` if that file exists and ``None`` otherwise.
    """
    extracted_dir = Path("data/olmOCR-mix-0225-extracted")

    dataset = load_dataset("allenai/olmOCR-mix-0225", "00_documents", split="eval_s2pdf")
    snapshot_root = snapshot_download(
        repo_id="dantetemplar/pdf-extraction-agenda",
        repo_type="dataset",
        allow_patterns=["*.tar.gz"],
    )
    extract_tarballs(os.path.join(snapshot_root, "data", "olmOCR-mix-0225"), extracted_dir)

    def id_to_path(id: str, warn: bool = False) -> Path | None:
        candidate = extracted_dir / f"{id}.pdf"
        if candidate.exists():
            return candidate
        # Missing file: optionally tell the user, then signal absence with None.
        if warn:
            logger.warning(f"File {candidate} not found")
        return None

    return dataset, id_to_path
|
||||
|
||||
|
||||
def main():
    """Prepare the dataset and walk it once, warning about unusable samples.

    Nothing is returned or stored: each sample is only checked for a missing
    PDF file and for a malformed response, both of which are logged as warnings.
    """
    dataset, id_to_path = prepare_olmocr_dataset()

    for sample in dataset:
        # Both calls are made purely for the warnings they emit; results are discarded.
        id_to_path(sample["id"], warn=True)
        parse_response(sample, warn=True)


if __name__ == "__main__":
    main()
|
||||
24
src/pdf_extraction_agenda/logging.yaml
Normal file
24
src/pdf_extraction_agenda/logging.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
version: 1
|
||||
disable_existing_loggers: False
|
||||
formatters:
|
||||
src:
|
||||
"()": colorlog.ColoredFormatter
|
||||
format: '[%(asctime)s] [%(log_color)s%(levelname)s%(reset)s] [%(cyan)sFile "%(relativePath)s", line %(lineno)d%(reset)s] %(message)s'
|
||||
default:
|
||||
"()": colorlog.ColoredFormatter
|
||||
format: '[%(asctime)s] [%(log_color)s%(levelname)s%(reset)s] [%(name)s] %(message)s'
|
||||
handlers:
|
||||
src:
|
||||
formatter: src
|
||||
class: logging.StreamHandler
|
||||
stream: ext://sys.stdout
|
||||
default:
|
||||
formatter: default
|
||||
class: logging.StreamHandler
|
||||
stream: ext://sys.stdout
|
||||
loggers:
|
||||
src:
|
||||
level: INFO
|
||||
handlers:
|
||||
- src
|
||||
propagate: no
|
||||
22
src/pdf_extraction_agenda/logging_.py
Normal file
22
src/pdf_extraction_agenda/logging_.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Public API of this module: only the configured ``logger`` is exported.
__all__ = ["logger"]
|
||||
|
||||
import logging.config
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class RelativePathFilter(logging.Filter):
    """Logging filter that adds a ``relativePath`` attribute to every record."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Annotate ``record`` and let it through.

        ``relativePath`` is ``record.pathname`` made relative to the current
        working directory via ``os.path.relpath``.

        Returns:
            Always ``True``; this filter never drops a record.
        """
        relative = os.path.relpath(record.pathname)
        record.relativePath = relative
        return True
|
||||
|
||||
|
||||
# The logging configuration file sits next to this module.
logging_yaml = os.path.join(os.path.dirname(__file__), "logging.yaml")

with open(logging_yaml) as f:
    config = yaml.safe_load(f)

# Configure logging once the file has been read and closed.
logging.config.dictConfig(config)

# The "src" logger gets the filter that supplies the ``relativePath`` record attribute.
logger = logging.getLogger("src")
logger.addFilter(RelativePathFilter())
|
||||
52
src/pdf_extraction_agenda/main.py
Normal file
52
src/pdf_extraction_agenda/main.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from os import PathLike
|
||||
from typing import Literal, NewType, Protocol, assert_never
|
||||
|
||||
import pandas as pd
|
||||
from datasets_ import parse_response, prepare_olmocr_dataset
|
||||
from metrics import calc_nid
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class PipelineProto(Protocol):
|
||||
def __call__(self, path: str | PathLike) -> str:
|
||||
"""Runs the pipeline on the given path and returns the md result."""
|
||||
pass
|
||||
|
||||
|
||||
# Distinct static type for the metrics DataFrame returned by evaluate_pipeline.
EvaluationResult = NewType("EvaluationResult", pd.DataFrame)
|
||||
|
||||
|
||||
def evaluate_pipeline(run_pipeline: PipelineProto) -> EvaluationResult:
    """Run ``run_pipeline`` over the olmOCR evaluation set and score every usable sample.

    A sample is skipped (after a warning from the respective helper) when its
    response cannot be parsed or when its PDF file is not available locally.

    Args:
        run_pipeline: Callable that takes the path of a PDF and returns its Markdown text.

    Returns:
        A DataFrame with one row per evaluated sample and a single ``nid`` column
        holding the ``calc_nid`` score between the reference ``natural_text`` and
        the pipeline output.
    """
    dataset, id_to_path = prepare_olmocr_dataset()

    metrics_raw = []

    for s in tqdm(dataset):
        path = id_to_path(s["id"], warn=True)
        malformed, response = parse_response(s, warn=True)
        # id_to_path returns None for a missing PDF; handing that to the pipeline
        # would crash the whole evaluation, so treat it like a malformed sample.
        if malformed or path is None:
            continue

        md_result = run_pipeline(path)
        nid = calc_nid(response.natural_text, md_result)
        metrics_raw.append({"nid": nid})

    # Explicit columns keep the "nid" column present even if every sample was skipped.
    metrics_df = pd.DataFrame(metrics_raw, columns=["nid"])
    return EvaluationResult(metrics_df)
|
||||
|
||||
|
||||
def main(pipeline: Literal["docling"]):
|
||||
if pipeline == "docling":
|
||||
from pipeline_docling import run_docling_pipeline
|
||||
|
||||
run_pipeline = run_docling_pipeline
|
||||
else:
|
||||
assert_never(pipeline)
|
||||
|
||||
metrics_df = evaluate_pipeline(run_pipeline)
|
||||
|
||||
print(metrics_df)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main("docling")
|
||||
31
src/pdf_extraction_agenda/metrics.py
Normal file
31
src/pdf_extraction_agenda/metrics.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from rapidfuzz import fuzz
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
|
||||
"""Normalize text for comparison."""
|
||||
if text is None:
|
||||
return ""
|
||||
return text.strip().lower()
|
||||
|
||||
|
||||
def calc_nid(gt_text: str, pred_text: str) -> float:
    """Calculate the Normalized Indel score between the gt and pred text.

    Both texts are normalized with ``_normalize_text`` before comparison.

    Args:
        gt_text (str): The string of gt text to compare.
        pred_text (str): The string of pred text to compare.

    Returns:
        float: The nid score between gt and pred text. [0., 1.]
    """
    gt_text = _normalize_text(gt_text)
    pred_text = _normalize_text(pred_text)

    # if gt and pred is empty, return 1
    if not gt_text and not pred_text:
        return 1.0

    # if pred is empty while gt is not, return 0
    if not pred_text:
        return 0.0

    # fuzz.ratio reports the similarity as a percentage in [0, 100]; rescale it so
    # that it lives on the same [0, 1] scale as the special cases above.
    return fuzz.ratio(gt_text, pred_text) / 100.0
|
||||
10
src/pdf_extraction_agenda/pipeline_docling.py
Normal file
10
src/pdf_extraction_agenda/pipeline_docling.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from os import PathLike
|
||||
|
||||
from docling.document_converter import DocumentConverter
|
||||
|
||||
# Module-level converter instance, created once at import time and used by run_docling_pipeline.
document_converter = DocumentConverter()
|
||||
|
||||
|
||||
def run_docling_pipeline(path: str | PathLike) -> str:
    """Convert the document at ``path`` with docling and return it as Markdown.

    Args:
        path: Location of the document handed to the module-level ``document_converter``.

    Returns:
        The Markdown export of the converted document.
    """
    return document_converter.convert(path).document.export_to_markdown()
|
||||
Reference in New Issue
Block a user