Batched evaluation

This commit is contained in:
Ruslan Bel'kov
2025-03-12 00:21:22 +03:00
parent f6a26d4bac
commit 16bd03ad11
2 changed files with 35 additions and 10 deletions

View File

@@ -1,36 +1,60 @@
from os import PathLike
from typing import Literal, NewType, Protocol, assert_never
from typing import Literal, NewType, Protocol, assert_never, overload
import pandas as pd
from tqdm.auto import tqdm
from .datasets_ import parse_response, prepare_olmocr_dataset
from .metrics import calc_nid
class PipelineProto(Protocol):
def __call__(self, path: str | PathLike) -> str:
def __call__(self, path: str | PathLike) -> str | None:
"""Runs the pipeline on the given path and returns the md result."""
pass
class PipelineBatchedProto(Protocol):
def __call__(self, paths: list[str | PathLike], tqdm_pbar) -> list[str | None]:
"""Runs the pipeline on the given paths and returns the md results."""
pass
EvaluationResult = NewType("EvaluationResult", pd.DataFrame)
def evaluate_pipeline(run_pipeline: PipelineProto) -> EvaluationResult:
@overload
def evaluate_pipeline(run_pipeline: PipelineProto, mode: Literal["single"]) -> EvaluationResult: ...
@overload
def evaluate_pipeline(run_pipeline: PipelineBatchedProto, mode: Literal["all"]) -> EvaluationResult: ...
def evaluate_pipeline(
run_pipeline: PipelineProto | PipelineBatchedProto, mode: Literal["singe", "all"]
) -> EvaluationResult:
from tqdm.auto import tqdm
dataset, id_to_path = prepare_olmocr_dataset()
metrics_raw = []
if mode == "all":
paths = [id_to_path(s["id"], warn=True) for s in dataset]
with tqdm(total=len(paths), desc="Processing files") as pbar:
run_pipeline: PipelineBatchedProto
md_results = run_pipeline(paths, pbar)
elif mode == "single":
paths = [id_to_path(s["id"], warn=True) for s in dataset]
md_results = [run_pipeline(path) for path in paths]
else:
assert_never(mode)
for s in tqdm(dataset):
path = id_to_path(s["id"], warn=True)
for s, md_result in zip(dataset, md_results):
malformed, response = parse_response(s, warn=True)
if malformed:
continue
md_result = run_pipeline(path)
nid = calc_nid(response.natural_text, md_result)
metrics_raw.append({"nid": nid})
metrics_raw.append({"id": s["id"], "nid": nid})
metrics_df = pd.DataFrame(metrics_raw)
return EvaluationResult(metrics_df)
@@ -44,7 +68,7 @@ def main(pipeline: Literal["docling"]):
else:
assert_never(pipeline)
metrics_df = evaluate_pipeline(run_pipeline)
metrics_df = evaluate_pipeline(run_pipeline, mode="single")
print(metrics_df)

1
uv.lock generated
View File

@@ -1218,6 +1218,7 @@ requires-dist = [
]
[package.metadata.requires-dev]
dev = []
docling = [{ name = "docling", specifier = ">=2.25.2" }]
[[package]]