Mirror of https://github.com/dantetemplar/pdf-extraction-agenda.git (synced 2025-03-17 21:12:24 +03:00).
Commit: "Add pdf-extraction-agenda package". This commit contains the following files:
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# Virtual environment created by uv / venv
.venv
# Pinned interpreter version (uv / pyenv)
.python-version
# Downloaded datasets and extracted PDFs
data/
# Build artifacts
dist/
# JetBrains IDE settings
.idea/
# Packaging metadata
*.egg-info
|
||||||
31
pyproject.toml
Normal file
31
pyproject.toml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
[project]
name = "pdf-extraction-agenda"
version = "0.1.0"
description = "Overview of pipelines related to PDF document processing"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "colorlog>=6.9.0",
    "datasets>=3.3.2",
    "huggingface-hub[hf-transfer]>=0.29.2",
    "pandas>=2.2.3",
    "pydantic>=2.10.6",
    "rapidfuzz>=3.12.2",
    "tabulate>=0.9.0",
    "tqdm>=4.67.1",
]

[tool.ruff]
line-length = 120
# PLR (pylint refactor) rules are too opinionated for this project
lint.ignore = ["PLR"]
# I = isort, UP = pyupgrade, PL = pylint
lint.extend-select = ["I", "UP", "PL"]
target-version = "py312"

# Optional dependency groups (PEP 735); install with `uv sync --group docling`
[dependency-groups]
docling = [
    "docling>=2.25.2",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
|
||||||
0
src/pdf_extraction_agenda/__init__.py
Normal file
0
src/pdf_extraction_agenda/__init__.py
Normal file
85
src/pdf_extraction_agenda/datasets_.py
Normal file
85
src/pdf_extraction_agenda/datasets_.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from datasets import Dataset, load_dataset
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from logging_ import logger
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class OlmoOCRResponse(BaseModel):
    """OCRed Page Information"""

    # Dominant language detected on the page
    primary_language: str
    # Whether the page orientation was judged correct
    is_rotation_valid: bool
    # Rotation (degrees) needed to correct the page orientation
    rotation_correction: int
    # Page is predominantly a table
    is_table: bool
    # Page is predominantly a diagram/figure
    is_diagram: bool
    natural_text: str  # Extracted text from PDF
|
||||||
|
|
||||||
|
|
||||||
|
def parse_response(example: dict, warn: bool = True) -> tuple[bool, OlmoOCRResponse | None]:
    """Validate the raw JSON ``response`` field of a dataset example.

    Returns:
        ``(malformed, parsed)`` — ``malformed`` is True when validation failed,
        in which case ``parsed`` is None.
    """
    try:
        parsed = OlmoOCRResponse.model_validate_json(example["response"])
    except ValidationError as e:
        if warn:
            logger.warning(f"Malformed response for {example.get('id')}\n{e}")
        return True, None
    return False, parsed
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tarballs(source_dir: str | os.PathLike, destination_dir: str | os.PathLike) -> None:
|
||||||
|
"""Extracts all tarball files from the source directory into the destination directory."""
|
||||||
|
os.makedirs(destination_dir, exist_ok=True)
|
||||||
|
|
||||||
|
tarballs = glob.glob(os.path.join(source_dir, "*.tar*")) # Matches .tar, .tar.gz, .tar.bz2, etc.
|
||||||
|
for tarball in tarballs:
|
||||||
|
try:
|
||||||
|
with tarfile.open(tarball, "r:*") as tar:
|
||||||
|
tar.extractall(path=destination_dir, filter="fully_trusted")
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"Failed to extract {tarball}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class IdToPathProto(Protocol):
|
||||||
|
def __call__(self, id: str, warn: bool = False) -> Path | None:
|
||||||
|
"""Converts an ID to a file path."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_olmocr_dataset() -> tuple[Dataset, IdToPathProto]:
    """Load the olmOCR eval split and extract its PDFs locally.

    Returns:
        The HF dataset together with a resolver mapping a sample id to the
        extracted PDF path (None when the PDF is missing).
    """
    dataset = load_dataset("allenai/olmOCR-mix-0225", "00_documents", split="eval_s2pdf")
    snapshot_path = snapshot_download(
        repo_id="dantetemplar/pdf-extraction-agenda", repo_type="dataset", allow_patterns=["*.tar.gz"]
    )
    tarball_dir = os.path.join(snapshot_path, "data", "olmOCR-mix-0225")
    extracted_dir = Path("data/olmOCR-mix-0225-extracted")
    extract_tarballs(tarball_dir, extracted_dir)

    def id_to_path(id: str, warn: bool = False) -> Path | None:
        path = extracted_dir / f"{id}.pdf"
        if not path.exists():
            if warn:
                logger.warning(f"File {path} not found")
            return None
        return path

    return dataset, id_to_path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Sanity-check the dataset: resolve each PDF path and parse each response."""
    dataset, id_to_path = prepare_olmocr_dataset()

    for sample in dataset:
        id_to_path(sample["id"], warn=True)  # emits a warning for missing PDFs
        malformed, _response = parse_response(sample, warn=True)
        if malformed:
            continue
|
||||||
|
|
||||||
|
|
||||||
|
# Run the dataset sanity check when executed as a script.
if __name__ == "__main__":
    main()
|
||||||
24
src/pdf_extraction_agenda/logging.yaml
Normal file
24
src/pdf_extraction_agenda/logging.yaml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# dictConfig schema (logging.config) with colorlog-based colored formatters.
version: 1
disable_existing_loggers: False
formatters:
  # "src" formatter shows the file path + line number; %(relativePath)s is
  # injected by RelativePathFilter in logging_.py
  src:
    "()": colorlog.ColoredFormatter
    format: '[%(asctime)s] [%(log_color)s%(levelname)s%(reset)s] [%(cyan)sFile "%(relativePath)s", line %(lineno)d%(reset)s] %(message)s'
  default:
    "()": colorlog.ColoredFormatter
    format: '[%(asctime)s] [%(log_color)s%(levelname)s%(reset)s] [%(name)s] %(message)s'
handlers:
  src:
    formatter: src
    class: logging.StreamHandler
    stream: ext://sys.stdout
  default:
    formatter: default
    class: logging.StreamHandler
    stream: ext://sys.stdout
loggers:
  # Project logger used throughout the package (logging.getLogger("src"))
  src:
    level: INFO
    handlers:
      - src
    # Do not also emit through the root logger's handlers
    propagate: no
|
||||||
22
src/pdf_extraction_agenda/logging_.py
Normal file
22
src/pdf_extraction_agenda/logging_.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
__all__ = ["logger"]
|
||||||
|
|
||||||
|
import logging.config
|
||||||
|
import os
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
class RelativePathFilter(logging.Filter):
    """Attach ``relativePath`` (pathname relative to CWD) to every record.

    The ``src`` formatter in logging.yaml references ``%(relativePath)s``.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        # Enrich the record; always keep it (never filter anything out).
        record.relativePath = os.path.relpath(record.pathname)
        return True
|
||||||
|
|
||||||
|
|
||||||
|
# Locate logging.yaml next to this module so configuration works regardless of CWD.
logging_yaml = os.path.join(os.path.dirname(__file__), "logging.yaml")

# Configure the logging tree once at import time from the YAML dictConfig schema.
with open(logging_yaml) as f:
    config = yaml.safe_load(f)
logging.config.dictConfig(config)

# Shared project logger ("src" matches the logger name in logging.yaml);
# the filter supplies %(relativePath)s for the "src" formatter.
logger = logging.getLogger("src")
logger.addFilter(RelativePathFilter())
|
||||||
52
src/pdf_extraction_agenda/main.py
Normal file
52
src/pdf_extraction_agenda/main.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from os import PathLike
|
||||||
|
from typing import Literal, NewType, Protocol, assert_never
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from datasets_ import parse_response, prepare_olmocr_dataset
|
||||||
|
from metrics import calc_nid
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineProto(Protocol):
|
||||||
|
def __call__(self, path: str | PathLike) -> str:
|
||||||
|
"""Runs the pipeline on the given path and returns the md result."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Distinct alias marking a DataFrame that holds per-sample evaluation metrics.
EvaluationResult = NewType("EvaluationResult", pd.DataFrame)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_pipeline(run_pipeline: PipelineProto) -> EvaluationResult:
    """Run *run_pipeline* over the olmOCR eval set and collect per-sample metrics.

    Args:
        run_pipeline: Callable converting a PDF path into markdown text.

    Returns:
        EvaluationResult: DataFrame with one row per successfully processed
        sample (currently only the ``nid`` similarity score).
    """
    dataset, id_to_path = prepare_olmocr_dataset()

    metrics_raw = []

    for s in tqdm(dataset):
        path = id_to_path(s["id"], warn=True)
        malformed, response = parse_response(s, warn=True)
        # BUG FIX: also skip samples whose PDF is missing — the original passed
        # path=None straight into run_pipeline, which would crash downstream.
        if malformed or path is None:
            continue

        md_result = run_pipeline(path)
        nid = calc_nid(response.natural_text, md_result)
        metrics_raw.append({"nid": nid})

    metrics_df = pd.DataFrame(metrics_raw)
    return EvaluationResult(metrics_df)
|
||||||
|
|
||||||
|
|
||||||
|
def main(pipeline: Literal["docling"]):
    """Evaluate the selected extraction pipeline and print its metrics table."""
    if pipeline != "docling":
        assert_never(pipeline)
    # Imported lazily: docling lives in an optional dependency group.
    from pipeline_docling import run_docling_pipeline

    results = evaluate_pipeline(run_docling_pipeline)

    print(results)
|
||||||
|
|
||||||
|
|
||||||
|
# Evaluate the docling pipeline when executed as a script.
if __name__ == "__main__":
    main("docling")
|
||||||
31
src/pdf_extraction_agenda/metrics.py
Normal file
31
src/pdf_extraction_agenda/metrics.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from rapidfuzz import fuzz
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_text(text: str) -> str:
|
||||||
|
"""Normalize text for comparison."""
|
||||||
|
if text is None:
|
||||||
|
return ""
|
||||||
|
return text.strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def calc_nid(gt_text: str, pred_text: str) -> float:
    """Calculate the Normalized Indel score between the gt and pred text.

    Args:
        gt_text (str): The string of gt text to compare.
        pred_text (str): The string of pred text to compare.

    Returns:
        float: The nid score between gt and pred text. [0., 1.]
    """
    gt_text = _normalize_text(gt_text)
    pred_text = _normalize_text(pred_text)

    # if gt and pred are both empty, they match perfectly
    if len(gt_text) == 0 and len(pred_text) == 0:
        score = 1.0
    # if pred is empty while gt is not, there is no overlap at all
    elif len(gt_text) > 0 and len(pred_text) == 0:
        score = 0.0
    else:
        # BUG FIX: fuzz.ratio returns a percentage in [0, 100]; divide by 100 so
        # the result matches the documented [0., 1.] range and is consistent with
        # the empty-input constants above.
        score = fuzz.ratio(gt_text, pred_text) / 100.0

    return score
|
||||||
10
src/pdf_extraction_agenda/pipeline_docling.py
Normal file
10
src/pdf_extraction_agenda/pipeline_docling.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from os import PathLike
|
||||||
|
|
||||||
|
from docling.document_converter import DocumentConverter
|
||||||
|
|
||||||
|
# Module-level converter so docling's (expensive) initialization happens once
# per process instead of once per document.
document_converter = DocumentConverter()


def run_docling_pipeline(path: str | PathLike) -> str:
    """Convert the PDF at *path* with docling and return the markdown export."""
    result = document_converter.convert(path)
    return result.document.export_to_markdown()
|
||||||
Reference in New Issue
Block a user