mirror of https://github.com/dantetemplar/pdf-extraction-agenda.git
synced 2025-03-17 21:12:24 +03:00

Add pdf-extraction-agenda package
6  .gitignore  vendored  Normal file

@@ -0,0 +1,6 @@
.venv
.python-version
data/
dist/
.idea/
*.egg-info

31  pyproject.toml  Normal file

@@ -0,0 +1,31 @@
[project]
name = "pdf-extraction-agenda"
version = "0.1.0"
description = "Overview of pipelines related to PDF document processing"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "colorlog>=6.9.0",
    "datasets>=3.3.2",
    "huggingface-hub[hf-transfer]>=0.29.2",
    "pandas>=2.2.3",
    "pydantic>=2.10.6",
    "rapidfuzz>=3.12.2",
    "tabulate>=0.9.0",
    "tqdm>=4.67.1",
]

[tool.ruff]
line-length = 120
lint.ignore = ["PLR"]
lint.extend-select = ["I", "UP", "PL"]
target-version = "py312"

[dependency-groups]
docling = [
    "docling>=2.25.2",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

0  src/pdf_extraction_agenda/__init__.py  Normal file

85  src/pdf_extraction_agenda/datasets_.py  Normal file

@@ -0,0 +1,85 @@
import glob
import os
import tarfile
from pathlib import Path
from typing import Protocol

from datasets import Dataset, load_dataset
from huggingface_hub import snapshot_download
from logging_ import logger
from pydantic import BaseModel, ValidationError


class OlmoOCRResponse(BaseModel):
    """OCRed Page Information"""

    primary_language: str
    is_rotation_valid: bool
    rotation_correction: int
    is_table: bool
    is_diagram: bool
    natural_text: str  # Extracted text from PDF


def parse_response(example: dict, warn: bool = True) -> tuple[bool, OlmoOCRResponse | None]:
    try:
        return False, OlmoOCRResponse.model_validate_json(example["response"])
    except ValidationError as e:
        if warn:
            logger.warning(f"Malformed response for {example.get('id')}\n{e}")
        return True, None


def extract_tarballs(source_dir: str | os.PathLike, destination_dir: str | os.PathLike) -> None:
    """Extracts all tarball files from the source directory into the destination directory."""
    os.makedirs(destination_dir, exist_ok=True)

    tarballs = glob.glob(os.path.join(source_dir, "*.tar*"))  # Matches .tar, .tar.gz, .tar.bz2, etc.
    for tarball in tarballs:
        try:
            with tarfile.open(tarball, "r:*") as tar:
                tar.extractall(path=destination_dir, filter="fully_trusted")
        except Exception as e:
            logger.info(f"Failed to extract {tarball}: {e}")


class IdToPathProto(Protocol):
    def __call__(self, id: str, warn: bool = False) -> Path | None:
        """Converts an ID to a file path."""
        pass


def prepare_olmocr_dataset() -> tuple[Dataset, IdToPathProto]:
    dataset = load_dataset("allenai/olmOCR-mix-0225", "00_documents", split="eval_s2pdf")
    path_to_snaphot = snapshot_download(
        repo_id="dantetemplar/pdf-extraction-agenda", repo_type="dataset", allow_patterns=["*.tar.gz"]
    )
    source_tarball_dir = os.path.join(path_to_snaphot, "data", "olmOCR-mix-0225")
    destination_dir = Path("data/olmOCR-mix-0225-extracted")

    extract_tarballs(source_tarball_dir, destination_dir)

    def id_to_path(id: str, warn: bool = False) -> Path | None:
        path = destination_dir / f"{id}.pdf"
        if path.exists():
            return path
        else:
            if warn:
                logger.warning(f"File {path} not found")
            return None

    return dataset, id_to_path


def main():
    dataset, id_to_path = prepare_olmocr_dataset()

    for s in dataset:
        path = id_to_path(s["id"], warn=True)
        malformed, response = parse_response(s, warn=True)
        if malformed:
            continue


if __name__ == "__main__":
    main()
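
Note (not part of the commit): a small sketch of the record schema that parse_response validates above. The field values here are made up; only the field names come from OlmoOCRResponse, and it assumes you run from src/pdf_extraction_agenda/, matching the flat imports used in this commit.

    from datasets_ import parse_response

    example = {
        "id": "some-doc-id",  # hypothetical id
        "response": (
            '{"primary_language": "en", "is_rotation_valid": true, "rotation_correction": 0,'
            ' "is_table": false, "is_diagram": false, "natural_text": "Hello, PDF!"}'
        ),
    }
    malformed, resp = parse_response(example)
    assert not malformed
    print(resp.natural_text)  # -> Hello, PDF!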
							
								
								
									
24  src/pdf_extraction_agenda/logging.yaml  Normal file

@@ -0,0 +1,24 @@
version: 1
disable_existing_loggers: False
formatters:
  src:
    "()": colorlog.ColoredFormatter
    format: '[%(asctime)s] [%(log_color)s%(levelname)s%(reset)s] [%(cyan)sFile "%(relativePath)s", line %(lineno)d%(reset)s] %(message)s'
  default:
    "()": colorlog.ColoredFormatter
    format: '[%(asctime)s] [%(log_color)s%(levelname)s%(reset)s] [%(name)s] %(message)s'
handlers:
  src:
    formatter: src
    class: logging.StreamHandler
    stream: ext://sys.stdout
  default:
    formatter: default
    class: logging.StreamHandler
    stream: ext://sys.stdout
loggers:
  src:
    level: INFO
    handlers:
      - src
    propagate: no
							
								
								
									
22  src/pdf_extraction_agenda/logging_.py  Normal file

@@ -0,0 +1,22 @@
__all__ = ["logger"]

import logging.config
import os

import yaml


class RelativePathFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        record.relativePath = os.path.relpath(record.pathname)
        return True


logging_yaml = os.path.join(os.path.dirname(__file__), "logging.yaml")

with open(logging_yaml) as f:
    config = yaml.safe_load(f)
    logging.config.dictConfig(config)

logger = logging.getLogger("src")
logger.addFilter(RelativePathFilter())
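
Note (not part of the commit): a minimal usage sketch of the module above, assuming the same working directory as the other flat imports in this commit.

    from logging_ import logger

    logger.info("dataset ready")
    # rendered by the "src" formatter from logging.yaml, roughly:
    # [2025-03-17 21:12:24,000] [INFO] [File "your_script.py", line 4] dataset ready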
							
								
								
									
52  src/pdf_extraction_agenda/main.py  Normal file

@@ -0,0 +1,52 @@
from os import PathLike
from typing import Literal, NewType, Protocol, assert_never

import pandas as pd
from datasets_ import parse_response, prepare_olmocr_dataset
from metrics import calc_nid
from tqdm import tqdm


class PipelineProto(Protocol):
    def __call__(self, path: str | PathLike) -> str:
        """Runs the pipeline on the given path and returns the md result."""
        pass


EvaluationResult = NewType("EvaluationResult", pd.DataFrame)


def evaluate_pipeline(run_pipeline: PipelineProto) -> EvaluationResult:
    dataset, id_to_path = prepare_olmocr_dataset()

    metrics_raw = []

    for s in tqdm(dataset):
        path = id_to_path(s["id"], warn=True)
        malformed, response = parse_response(s, warn=True)
        if malformed:
            continue

        md_result = run_pipeline(path)
        nid = calc_nid(response.natural_text, md_result)
        metrics_raw.append({"nid": nid})

    metrics_df = pd.DataFrame(metrics_raw)
    return EvaluationResult(metrics_df)


def main(pipeline: Literal["docling"]):
    if pipeline == "docling":
        from pipeline_docling import run_docling_pipeline

        run_pipeline = run_docling_pipeline
    else:
        assert_never(pipeline)

    metrics_df = evaluate_pipeline(run_pipeline)

    print(metrics_df)


if __name__ == "__main__":
    main("docling")
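
Note (not part of the commit): main() only prints the raw per-document DataFrame; summarizing the scores is left to the caller. A sketch of that, again assuming you run from src/pdf_extraction_agenda/ with the docling group installed:

    from main import evaluate_pipeline
    from pipeline_docling import run_docling_pipeline

    result = evaluate_pipeline(run_docling_pipeline)
    print(result["nid"].mean())      # mean NID over all successfully parsed documents
    print(result["nid"].describe())  # count, quartiles, min/max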
							
								
								
									
31  src/pdf_extraction_agenda/metrics.py  Normal file

@@ -0,0 +1,31 @@
from rapidfuzz import fuzz


def _normalize_text(text: str) -> str:
    """Normalize text for comparison."""
    if text is None:
        return ""
    return text.strip().lower()


def calc_nid(gt_text: str, pred_text: str) -> float:
    """Calculate the Normalized Indel score between the gt and pred text.
    Args:
        gt_text (str): The string of gt text to compare.
        pred_text (str): The string of pred text to compare.
    Returns:
        float: The nid score between gt and pred text. [0., 1.]
    """
    gt_text = _normalize_text(gt_text)
    pred_text = _normalize_text(pred_text)

    # if gt and pred are both empty, return 1
    if len(gt_text) == 0 and len(pred_text) == 0:
        score = 1
    # if pred is empty while gt is not, return 0
    elif len(gt_text) > 0 and len(pred_text) == 0:
        score = 0
    else:
        score = fuzz.ratio(gt_text, pred_text) / 100  # fuzz.ratio is in [0, 100]; scale to [0, 1] as documented

    return score
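
Note (not part of the commit): sanity checks that follow directly from the branches above.

    from metrics import calc_nid

    # identical after strip()/lower() -> perfect score
    assert calc_nid("  Hello, World  ", "hello, world") == 1.0
    # non-empty ground truth, empty prediction -> 0
    assert calc_nid("some text", "") == 0
    # both empty -> treated as a perfect match
    assert calc_nid("", "") == 1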
							
								
								
									
10  src/pdf_extraction_agenda/pipeline_docling.py  Normal file

@@ -0,0 +1,10 @@
from os import PathLike

from docling.document_converter import DocumentConverter

document_converter = DocumentConverter()


def run_docling_pipeline(path: str | PathLike) -> str:
    result = document_converter.convert(path)
    return result.document.export_to_markdown()
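
Note (not part of the commit): any callable matching the PipelineProto shape from main.py can be evaluated the same way. A hypothetical sibling pipeline, with pypdf as an assumed extra dependency, might look like:

    # pipeline_pypdf.py (hypothetical): plain text extraction, no layout reconstruction
    from os import PathLike

    from pypdf import PdfReader


    def run_pypdf_pipeline(path: str | PathLike) -> str:
        reader = PdfReader(path)
        # Concatenate page text; empty string for pages with no extractable text.
        return "\n\n".join(page.extract_text() or "" for page in reader.pages)

Passing such a function to evaluate_pipeline would produce the same per-document NID DataFrame as the docling pipeline.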