add flake8 to pre-commit (#276)
Co-authored-by: Jeremy Dyer <jdye64@gmail.com>
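A minimal usage sketch for the new hook (assuming pre-commit is already installed locally; these commands describe a typical workflow and are not part of the change below):

pre-commit install                      # register the git hooks defined in .pre-commit-config.yaml
pre-commit run flake8 --all-files       # run only the new flake8 hook against the whole repo

Running flake8 directly with the same options configured below (flake8 --max-line-length=120 --extend-ignore=E203,E266,F403,F405) should report the same findings.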
@@ -14,3 +14,9 @@ repos:
hooks:
- id: black
args: ["--line-length=120"]

- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
- id: flake8
args: ["--max-line-length=120", "--extend-ignore=E203,E266,F403,F405"]

@@ -1,6 +1,5 @@
import json
import pandas as pd
import math

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

@@ -116,11 +116,11 @@ def click_validate_task(ctx, param, value):
new_task_id = f"{task_id}_{task_options.document_type}"
new_task = [(new_task_id, ExtractTask(**task_options.dict()))]

if task_options.extract_tables == True:
if task_options.extract_tables is True:
subtask_options = check_schema(TableExtractionSchema, {}, "table_data_extract", "{}")
new_task.append(("table_data_extract", TableExtractionTask(**subtask_options.dict())))

if task_options.extract_charts == True:
if task_options.extract_charts is True:
subtask_options = check_schema(ChartExtractionSchema, {}, "chart_data_extract", "{}")
new_task.append(("chart_data_extract", ChartExtractionTask(**subtask_options.dict())))

@@ -390,7 +390,7 @@ class NvIngestClient:
# Attempt to fetch the job result
result = self._fetch_job_result(job_id, timeout, data_only=False)
return result, job_id
except TimeoutError as err:
except TimeoutError:
if verbose:
logger.info(
f"Job {job_id} is not ready. "

@@ -401,7 +401,7 @@ class NvIngestClient:
time.sleep(retry_delay) # Wait before retrying
except (RuntimeError, Exception) as err:
# For any other error, log and break out of the retry loop
logger.error(f"Error while fetching result for job ID {job_id}: {e}")
logger.error(f"Error while fetching result for job ID {job_id}: {err}")
return None, job_id
logger.error(f"Max retries exceeded for job {job_id}.")
return None, job_id

@@ -419,19 +419,23 @@ class NvIngestClient:
del self._job_index_to_job_spec[job_id]
except concurrent.futures.TimeoutError:
logger.error(
f"Timeout while fetching result for job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}"
f"Timeout while fetching result for job ID {job_id}: "
f"{self._job_index_to_job_spec[job_id].source_id}"
)
except json.JSONDecodeError as e:
logger.error(
f"Decoding while processing job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}\n{e}"
f"Decoding while processing job ID {job_id}: "
f"{self._job_index_to_job_spec[job_id].source_id}\n{e}"
)
except RuntimeError as e:
logger.error(
f"Error while processing job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}\n{e}"
f"Error while processing job ID {job_id}: "
f"{self._job_index_to_job_spec[job_id].source_id}\n{e}"
)
except Exception as e:
logger.error(
f"Error while fetching result for job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}\n{e}"
f"Error while fetching result for job ID {job_id}: "
f"{self._job_index_to_job_spec[job_id].source_id}\n{e}"
)

return results

@@ -17,8 +17,6 @@ from typing import Any

import httpx
import requests
import re

from nv_ingest_client.message_clients import MessageBrokerClientBase
from nv_ingest_client.message_clients.simple.simple_client import ResponseSchema

@@ -212,7 +210,9 @@ class RestClient(MessageBrokerClientBase):
# Terminal response code; return error ResponseSchema
return ResponseSchema(
response_code=1,
response_reason=f"Terminal response code {response_code} received when fetching JobSpec: {job_id}",
response_reason=(
f"Terminal response code {response_code} received when fetching JobSpec: {job_id}"
),
response=result.text,
)
else:

@@ -341,7 +341,8 @@ class RestClient(MessageBrokerClientBase):
"""
backoff_delay = min(2**existing_retries, self._max_backoff)
logger.debug(
f"Retry #: {existing_retries} of max_retries: {self.max_retries} | current backoff_delay: {backoff_delay}s of max_backoff: {self._max_backoff}s"
f"Retry #: {existing_retries} of max_retries: {self.max_retries} | "
f"current backoff_delay: {backoff_delay}s of max_backoff: {self._max_backoff}s"
)

if self.max_retries > 0 and existing_retries < self.max_retries:

@@ -234,9 +234,9 @@ class SimpleClient(MessageBrokerClientBase):

return ResponseSchema(**final_response)

except (ConnectionError, socket.error, BrokenPipeError) as e:
except (ConnectionError, socket.error, BrokenPipeError):
pass
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))

@@ -309,9 +309,9 @@ class SimpleClient(MessageBrokerClientBase):
else:
return ResponseSchema(**final_response)

except (ConnectionError, socket.error, BrokenPipeError) as e:
except (ConnectionError, socket.error, BrokenPipeError):
pass
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))

@@ -346,7 +346,7 @@ class SimpleClient(MessageBrokerClientBase):
return ResponseSchema(**response)
except (ConnectionError, socket.error, BrokenPipeError) as e:
return ResponseSchema(response_code=1, response_reason=f"Connection error: {e}")
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))

@@ -375,7 +375,7 @@ class SimpleClient(MessageBrokerClientBase):
try:
sock.sendall(total_length.to_bytes(8, "big"))
sock.sendall(data)
except (socket.error, BrokenPipeError) as e:
except (socket.error, BrokenPipeError):
raise ConnectionError("Failed to send data.")

def _recv(self, sock: socket.socket) -> str:

@@ -407,7 +407,7 @@ class SimpleClient(MessageBrokerClientBase):
if not data_bytes:
raise ConnectionError("Incomplete message received.")
return data_bytes.decode("utf-8")
except (socket.error, BrokenPipeError, ConnectionError) as e:
except (socket.error, BrokenPipeError, ConnectionError):
raise ConnectionError("Failed to receive data.")

def _recv_exact(self, sock: socket.socket, num_bytes: int) -> Optional[bytes]:

@@ -434,8 +434,8 @@ class SimpleClient(MessageBrokerClientBase):
if not packet:
return None
data.extend(packet)
except socket.timeout as e:
except socket.timeout:
return None
except Exception as e:
except Exception:
return None
return bytes(data)

@@ -47,7 +47,7 @@ class CaptionTask(Task):
info += "Image Caption Task:\n"

if self._api_key:
info += f" api_key: [redacted]\n"
info += " api_key: [redacted]\n"
if self._endpoint_url:
info += f" endpoint_url: {self._endpoint_url}\n"
if self._prompt:

@@ -103,4 +103,4 @@ def main(json_files, output_file):

if __name__ == "__main__":
process_json_files()
main()

@@ -262,7 +262,7 @@ def check_ingest_result(json_payload: Dict) -> typing.Tuple[bool, str]:
source_id = (
json_payload.get("data", [])[0].get("metadata", {}).get("source_metadata", {}).get("source_name", "")
)
except Exception as e:
except Exception:
source_id = ""

description = f"[{source_id}]: {json_payload.get('status', '')}\n"

@@ -349,7 +349,7 @@ def create_job_specs_for_batch(files_batch: List[str]) -> List[JobSpec]:
>>> client = NvIngestClient()
>>> job_specs = create_job_specs_for_batch(files_batch)
>>> print(job_specs)
[nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb468bb0>, <nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb469270>]
[nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb468bb0>, <nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb469270>] # noqa: E501,W505

See Also
--------

@@ -2,15 +2,15 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Dict, Optional
from typing import List
import httpx
import json
import os
import asyncio
import json
import logging
import os
from typing import Dict
from typing import List
from typing import Optional

import httpx

logger = logging.getLogger(__name__)

@@ -50,7 +50,8 @@ class AsyncZipkinClient:
if response.status_code == 404:
attempt += 1
logger.info(
f"Attempt {attempt}/{self._max_retries} for trace_id: {trace_id} failed with 404. Retrying in {self._retry_delay} seconds..."
f"Attempt {attempt}/{self._max_retries} for trace_id: {trace_id} failed with 404. "
f"Retrying in {self._retry_delay} seconds..."
)
await asyncio.sleep(self._retry_delay)
else:

@@ -41,7 +41,7 @@ def run_ingestor():
).embed()

try:
results = ingestor.ingest()
_ = ingestor.ingest()
logger.info("Ingestion completed successfully.")
except Exception as e:
logger.error(f"Ingestion failed: {e}")

@@ -2,8 +2,10 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import click
import json
import logging
import os

from morpheus.config import Config
from morpheus.config import CppConfig

@@ -20,6 +22,7 @@ from nv_ingest.util.schema.schema_validator import validate_schema
from nv_ingest.util.pipeline.stage_builders import *

logger = logging.getLogger(__name__)

local_log_level = os.getenv("INGEST_LOG_LEVEL", "INFO")
if local_log_level in ("DEFAULT",):
local_log_level = "INFO"

@@ -2,6 +2,20 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import logging
import traceback

@@ -25,47 +39,6 @@ from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_base64
from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata

# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import traceback
from datetime import datetime

from typing import List, Dict
from typing import Optional
from typing import Tuple

from wand.image import Image as WandImage
from PIL import Image
import io

import numpy as np

from nv_ingest.extraction_workflows.pdf.doughnut_utils import crop_image
import nv_ingest.util.nim.yolox as yolox_utils
from nv_ingest.schemas.image_extractor_schema import ImageExtractorSchema
from nv_ingest.schemas.metadata_schema import AccessLevelEnum
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.nim.helpers import create_inference_client
from nv_ingest.util.pdf.metadata_aggregators import (
CroppedImageWithContent,
construct_image_metadata_from_pdf_image,
construct_image_metadata_from_base64,
)
from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata

logger = logging.getLogger(__name__)

YOLOX_MAX_BATCH_SIZE = 8

@@ -457,7 +430,7 @@ def image_data_extractor(
config=kwargs.get("image_extraction_config"),
trace_info=trace_info,
)
logger.debug(f"Extracted table/chart data from image")
logger.debug("Extracted table/chart data from image")
for _, table_chart_data in tables_and_charts:
extracted_data.append(
construct_table_and_chart_metadata(

@@ -27,7 +27,6 @@ import numpy as np
import pypdfium2 as libpdfium
import nv_ingest.util.nim.yolox as yolox_utils

import nv_ingest.util.nim.yolox as yolox_utils
from nv_ingest.schemas.metadata_schema import AccessLevelEnum
from nv_ingest.schemas.metadata_schema import TableFormatEnum
from nv_ingest.schemas.metadata_schema import TextTypeEnum

@@ -170,7 +170,8 @@ def push_to_broker(
broker_client: MessageBrokerClientBase, response_channel: str, json_payloads: List[str], retry_count: int = 2
) -> None:
"""
Attempts to push a JSON payload to a message broker channel, retrying on failure up to a specified number of attempts.
Attempts to push a JSON payload to a message broker channel, retrying on failure up to a specified number of
attempts.

Parameters
----------

@@ -18,12 +18,10 @@ from morpheus.utils.module_ids import WRITE_TO_VECTOR_DB
from morpheus.utils.module_utils import ModuleLoaderFactory
from morpheus.utils.module_utils import register_module
from morpheus_llm.service.vdb.milvus_client import DATA_TYPE_MAP
from morpheus_llm.service.vdb.milvus_vector_db_service import MilvusVectorDBService
from morpheus_llm.service.vdb.utils import VectorDBServiceFactory
from morpheus_llm.service.vdb.vector_db_service import VectorDBService
from mrc.core import operators as ops
from pymilvus import BulkInsertState
from pymilvus import Collection
from pymilvus import connections
from pymilvus import utility

@@ -74,8 +72,7 @@ def _bulk_ingest(
]

uri_parsed = urlparse(milvus_uri)
conn = connections.connect(host=uri_parsed.hostname, port=uri_parsed.port)
collection = Collection(name=collection_name)
_ = connections.connect(host=uri_parsed.hostname, port=uri_parsed.port)

task_ids = []
for file in batch_files:

@@ -65,14 +65,14 @@ def fetch_and_process_messages(client, validated_config: MessageBrokerTaskSource
if job.response_code != 0:
continue

logger.debug(f"Received ResponseSchema, converting to dict")
logger.debug("Received ResponseSchema, converting to dict")
job = json.loads(job.response)
else:
logger.debug(f"Received something not a ResponseSchema")
logger.debug("Received something not a ResponseSchema")

ts_fetched = datetime.now()
yield process_message(job, ts_fetched)
except TimeoutError as err:
except TimeoutError:
continue
except Exception as err:
logger.error(

@@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
MODULE_NAME = "image_storage"
MODULE_NAMESPACE = "nv_ingest"

# TODO: Move these into microservice_entrypoint.py to populate the stage and validate them using the pydantic schema on startup.
# TODO: Move these into microservice_entrypoint.py to populate the stage and validate them using the pydantic schema
# on startup.
_DEFAULT_ENDPOINT = os.environ.get("MINIO_INTERNAL_ADDRESS", "minio:9000")
_DEFAULT_READ_ADDRESS = os.environ.get("MINIO_PUBLIC_ADDRESS", "http://minio:9000")
_DEFAULT_BUCKET_NAME = os.environ.get("MINIO_BUCKET", "nv-ingest")

@@ -4,9 +4,11 @@

import logging
from typing import Union, Optional
from typing import Optional

from pydantic import BaseModel, Field, Extra
from pydantic import BaseModel
from pydantic import Extra
from pydantic import Field

logger = logging.getLogger(__name__)

@@ -11,7 +11,6 @@
import json
import logging
import os
import uuid
from json import JSONDecodeError
from typing import Any

@@ -45,7 +45,8 @@ def decode_and_extract(
validated_config : Any
Configuration object that contains `image_config`. Used if the `image` method is selected.
default : str, optional
The default extraction method to use if the specified method in `task_props` is not available (default is "image").
The default extraction method to use if the specified method in `task_props` is not available
(default is "image").

Returns
-------

@@ -138,7 +139,8 @@ def process_image(
-------
Tuple[pd.DataFrame, Dict[str, Any]]
A tuple containing:
- A pandas DataFrame with the processed image content, including columns 'document_type', 'metadata', and 'uuid'.
- A pandas DataFrame with the processed image content, including columns 'document_type', 'metadata',
and 'uuid'.
- A dictionary with trace information collected during processing.

Raises

@@ -139,15 +139,16 @@ class MultiProcessingBaseStage(SinglePortStage):
forwards the task to a global multi-process worker pool where the heavy-lifting occurs.

3. **Global Worker Pool**: The work is executed in parallel across multiple process engines via the worker pool.
Each process engine applies the `process_fn` to the task data, which includes a pandas DataFrame and task-specific arguments.
Each process engine applies the `process_fn` to the task data, which includes a pandas DataFrame and
task-specific arguments.

4. **Response Queue**: After the work is completed by the worker pool, the results are pushed into a response queue.

5. **Post-Processing and Emission**: The results from the response queue are post-processed, reconstructed into their
original format, and emitted from an observable source for further downstream processing or final output.
5. **Post-Processing and Emission**: The results from the response queue are post-processed, reconstructed into
their original format, and emitted from an observable source for further downstream processing or final output.

This design enhances parallelism and resource utilization across multiple processes, especially for tasks that involve
heavy computations, such as large DataFrame operations.
This design enhances parallelism and resource utilization across multiple processes, especially for tasks that
involve heavy computations, such as large DataFrame operations.
"""

def __init__(

@@ -186,7 +186,7 @@ def _extract_chart_data(

return df, {"trace_info": trace_info}

except Exception as e:
except Exception:
logger.error("Error occurred while extracting chart data.", exc_info=True)
raise
finally:

@@ -139,8 +139,6 @@ def _extract_table_data(

stage_config = validated_config.stage_config

paddle_infer_protocol = stage_config.paddle_infer_protocol.lower()

# Obtain paddle_version
# Assuming that the grpc endpoint is at index 0
paddle_endpoint = stage_config.paddle_endpoints[1]

@@ -169,7 +167,7 @@ def _extract_table_data(

return df, {"trace_info": trace_info}

except Exception as e:
except Exception:
logger.error("Error occurred while extracting table data.", exc_info=True)
raise
finally:

@@ -44,7 +44,8 @@ def decode_and_extract(
validated_config : Any
Configuration object that contains `pdfium_config`. Used if the `pdfium` method is selected.
default : str, optional
The default extraction method to use if the specified method in `task_props` is not available (default is "pdfium").
The default extraction method to use if the specified method in `task_props` is not available
(default is "pdfium").

Returns
-------

@@ -6,12 +6,9 @@ import functools
import logging
import os
import traceback
import typing
import uuid
from typing import Any
from typing import Dict

import mrc
import pandas as pd
from minio import Minio
from morpheus.config import Config

@@ -35,11 +32,9 @@ def upload_embeddings(df: pd.DataFrame, params: Dict[str, Any]) -> pd.DataFrame:
Identify contents (e.g., images) within a dataframe and uploads the data to MinIO.
The image metadata in the metadata column is updated with the URL of the uploaded data.
"""
dimension = params.get("dim", 1024)
access_key = params.get("access_key", None)
secret_key = params.get("secret_key", None)

content_types = params.get("content_types")
endpoint = params.get("endpoint", _DEFAULT_ENDPOINT)
bucket_name = params.get("bucket_name", _DEFAULT_BUCKET_NAME)
bucket_path = params.get("bucket_path", "embeddings")

@@ -73,7 +68,6 @@ def upload_embeddings(df: pd.DataFrame, params: Dict[str, Any]) -> pd.DataFrame:
)

for idx, row in df.iterrows():
uu_id = row["uuid"]
metadata = row["metadata"].copy()
metadata["embedding_metadata"] = {}
metadata["embedding_metadata"]["uploaded_embedding_url"] = bucket_path

@@ -2,8 +2,6 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import base64
import io
import logging
from functools import partial
from typing import Any

@@ -14,13 +12,11 @@ from typing import Tuple
import pandas as pd
import requests
from morpheus.config import Config
from PIL import Image

from nv_ingest.schemas.image_caption_extraction_schema import ImageCaptionExtractionSchema
from nv_ingest.schemas.metadata_schema import ContentTypeEnum
from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
from nv_ingest.util.image_processing.transforms import scale_image_to_encoding_size
from nv_ingest.util.tracing.tagging import traceable_func

logger = logging.getLogger(__name__)

@@ -145,9 +145,12 @@ class RedisClient(MessageBrokerClientBase):
-------
Tuple[Optional[Dict[str, Any]], Optional[int], Optional[int]]
A tuple containing:
- message: A dictionary containing the decoded message if successful, or None if no message was retrieved.
- fragment: An integer representing the fragment number of the message, or None if no fragment was found.
- fragment_count: An integer representing the total number of message fragments, or None if no fragment count was found.
- message: A dictionary containing the decoded message if successful,
or None if no message was retrieved.
- fragment: An integer representing the fragment number of the message,
or None if no fragment was found.
- fragment_count: An integer representing the total number of message fragments,
or None if no fragment count was found.

Raises
------

@@ -69,8 +69,6 @@ class SimpleMessageBrokerHandler(socketserver.BaseRequestHandler):

# Validate and extract common fields
queue_name = request_data.get("queue_name")
message = request_data.get("message")
timeout = request_data.get("timeout", 100)

# Initialize the queue and its lock if necessary
if queue_name:

@@ -11,9 +11,7 @@ import socket
import json
import time
import logging
from typing import Optional, Union

from pydantic import BaseModel
from typing import Optional

from nv_ingest.schemas.message_brokers.response_schema import ResponseSchema
from nv_ingest_client.message_clients.client_base import MessageBrokerClientBase

@@ -230,9 +228,9 @@ class SimpleClient(MessageBrokerClientBase):

return ResponseSchema(**final_response)

except (ConnectionError, socket.error, BrokenPipeError) as e:
except (ConnectionError, socket.error, BrokenPipeError):
pass
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))

@@ -305,9 +303,9 @@ class SimpleClient(MessageBrokerClientBase):
else:
return ResponseSchema(**final_response)

except (ConnectionError, socket.error, BrokenPipeError) as e:
except (ConnectionError, socket.error, BrokenPipeError):
pass
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))

@@ -342,7 +340,7 @@ class SimpleClient(MessageBrokerClientBase):
return ResponseSchema(**response)
except (ConnectionError, socket.error, BrokenPipeError) as e:
return ResponseSchema(response_code=1, response_reason=f"Connection error: {e}")
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))

@@ -371,7 +369,7 @@ class SimpleClient(MessageBrokerClientBase):
try:
sock.sendall(total_length.to_bytes(8, "big"))
sock.sendall(data)
except (socket.error, BrokenPipeError) as e:
except (socket.error, BrokenPipeError):
raise ConnectionError("Failed to send data.")

def _recv(self, sock: socket.socket) -> str:

@@ -403,7 +401,7 @@ class SimpleClient(MessageBrokerClientBase):
if not data_bytes:
raise ConnectionError("Incomplete message received.")
return data_bytes.decode("utf-8")
except (socket.error, BrokenPipeError, ConnectionError) as e:
except (socket.error, BrokenPipeError, ConnectionError):
raise ConnectionError("Failed to receive data.")

def _recv_exact(self, sock: socket.socket, num_bytes: int) -> Optional[bytes]:

@@ -430,8 +428,8 @@ class SimpleClient(MessageBrokerClientBase):
if not packet:
return None
data.extend(packet)
except socket.timeout as e:
except socket.timeout:
return None
except Exception as e:
except Exception:
return None
return bytes(data)

@@ -4,16 +4,9 @@

import logging
import re
from math import ceil
from math import floor
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np

from nv_ingest.util.image_processing.transforms import numpy_to_base64

ACCEPTED_TEXT_CLASSES = set(
[
"Text",

@@ -5,7 +5,8 @@
import logging
import re
import time
from typing import Optional, Any
from typing import Any
from typing import Optional
from typing import Tuple

import backoff

@@ -271,7 +272,8 @@ class NimClient:
if status_code in [429, 503]:
# Warn and attempt to retry
logger.warning(
f"Received HTTP {status_code} ({response.reason}) from {self.model_interface.name()}. Retrying..."
f"Received HTTP {status_code} ({response.reason}) from "
f"{self.model_interface.name()}. Retrying..."
)
if attempt == max_retries:
# No more retries left

@@ -290,7 +292,10 @@ class NimClient:
return response.json()

except requests.Timeout:
err_msg = f"HTTP request timed out during {self.model_interface.name()} inference after {self.timeout} seconds"
err_msg = (
f"HTTP request timed out during {self.model_interface.name()} "
f"inference after {self.timeout} seconds"
)
logger.error(err_msg)
raise TimeoutError(err_msg)

@@ -91,7 +91,7 @@ def test_apply_dedup(should_filter, expected0, expected1, expected2):

payload_list = []
for _ in range(3):
payload_list.append(valid_image_dedup_payload(f"test", 1, 1))
payload_list.append(valid_image_dedup_payload("test", 1, 1))

extracted_df = pd.DataFrame(payload_list, columns=["document_type", "metadata"])
extracted_gdf = cudf.from_pandas(extracted_df)

@@ -98,11 +98,6 @@ def test_image_metadata_schema_invalid_type():
ImageMetadataSchema(image_type=3.14) # Using a float value

def test_image_metadata_schema_invalid_type():
with pytest.raises(ValidationError):
ImageMetadataSchema(image_type=3.14)

# Test cases for TableMetadataSchema
@pytest.mark.parametrize("table_format", ["html", "markdown", "latex", "image"])
def test_table_metadata_schema_defaults(table_format):

@@ -23,6 +23,6 @@ def test_otel_meter_schema_custom_values():
assert schema.broker_client.host == "custom_host", "Custom host value for redis_client should be respected."
assert schema.broker_client.port == 12345, "Custom port value for redis_client should be respected."
assert (
schema.broker_client.broker_params["use_ssl"] == True
schema.broker_client.broker_params["use_ssl"] is True
), "Custom use_ssl value for broker_client should be True."
assert schema.raise_on_failure is True, "Custom value for raise_on_failure should be respected."

@@ -31,7 +31,7 @@ def test_redis_task_sink_schema_custom_values():
assert schema.broker_client.host == "custom_host", "Custom host value for broker_client should be respected."
assert schema.broker_client.port == 12345, "Custom port value for broker_client should be respected."
assert (
schema.broker_client.broker_params["use_ssl"] == True
schema.broker_client.broker_params["use_ssl"] is True
), "Custom use_ssl value for redis_client should be True."
assert schema.raise_on_failure is True, "Custom value for raise_on_failure should be respected."
assert schema.progress_engines == 10, "Custom value for progress_engines should be respected."

@@ -33,7 +33,7 @@ def test_redis_task_source_schema_custom_values():
assert schema.broker_client.host == "custom_host", "Custom host value for redis_client should be respected."
assert schema.broker_client.port == 12345, "Custom port value for redis_client should be respected."
assert (
schema.broker_client.broker_params["use_ssl"] == True
schema.broker_client.broker_params["use_ssl"] is True
), "Custom use_ssl value for redis_client should be True."
assert schema.task_queue == "custom_queue", "Custom value for task_queue should be respected."
assert schema.progress_engines == 10, "Custom value for progress_engines should be respected."

@@ -1,12 +1,12 @@
import pytest
from unittest.mock import Mock
from unittest.mock import patch

import pytest
import requests
import pandas as pd
from unittest.mock import Mock, patch

from nv_ingest.stages.nim.chart_extraction import _update_metadata
from nv_ingest.stages.nim.chart_extraction import _extract_chart_data
import requests

MODULE_UNDER_TEST = "nv_ingest.stages.nim.chart_extraction"

@@ -1,20 +1,17 @@
import pytest
import base64
import cv2
import numpy as np
import pandas as pd

from unittest.mock import Mock, patch
from io import BytesIO
from unittest.mock import Mock
from unittest.mock import patch

import cv2
import numpy as np
import pandas as pd
import pytest
import requests
from PIL import Image

from nv_ingest.stages.nim.table_extraction import _update_metadata, _extract_table_data
from nv_ingest.stages.nim.table_extraction import _extract_table_data
from nv_ingest.stages.nim.table_extraction import _update_metadata
from nv_ingest.util.nim.helpers import NimClient
from nv_ingest.util.nim.paddle import PaddleOCRModelInterface

@@ -10,8 +10,6 @@ import pytest
import requests
from PIL import Image

MODULE_UNDER_TEST = "nv_ingest.stages.transforms.image_caption_extraction"

import pandas as pd

from nv_ingest.schemas.metadata_schema import ContentTypeEnum

@@ -19,6 +17,8 @@ from nv_ingest.stages.transforms.image_caption_extraction import _generate_capti
from nv_ingest.stages.transforms.image_caption_extraction import _prepare_dataframes_mod
from nv_ingest.stages.transforms.image_caption_extraction import caption_extract_stage

MODULE_UNDER_TEST = "nv_ingest.stages.transforms.image_caption_extraction"

def generate_base64_png_image() -> str:
"""Helper function to generate a base64-encoded PNG image."""

@@ -488,18 +488,6 @@ def test_create_inference_client_http_endpoint_whitespace_no_infer_protocol(mock

# Preprocess image for paddle

@pytest.fixture
def sample_image():
"""
Returns a sample image array of shape (height, width, channels) with random pixel values.
"""
height, width = 800, 600 # Example dimensions
image = np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
return image

def test_preprocess_image_paddle_version_none(sample_image):
"""
Test that when paddle_version is None, the function returns the input image unchanged.

@@ -143,6 +143,8 @@ def test_parse_output_http_valid(model_interface):
{
"type": "table",
"bboxes": [{"xmin": 0.1, "ymin": 0.1, "xmax": 0.2, "ymax": 0.2, "confidence": 0.9}],
},
{
"type": "chart",
"bboxes": [{"xmin": 0.3, "ymin": 0.3, "xmax": 0.4, "ymax": 0.4, "confidence": 0.8}],
},

@@ -152,6 +154,8 @@ def test_parse_output_http_valid(model_interface):
{
"type": "table",
"bboxes": [{"xmin": 0.15, "ymin": 0.15, "xmax": 0.25, "ymax": 0.25, "confidence": 0.85}],
},
{
"type": "chart",
"bboxes": [{"xmin": 0.35, "ymin": 0.35, "xmax": 0.45, "ymax": 0.45, "confidence": 0.75}],
},

@@ -166,7 +170,7 @@ def test_parse_output_http_valid(model_interface):
data = {"scaling_factors": scaling_factors}
parsed_output = model_interface.parse_output(response, "http", data)
assert isinstance(parsed_output, np.ndarray)
assert parsed_output.shape == (2, 2, 85)
assert parsed_output.shape == (2, 3, 85)
assert parsed_output.dtype == np.float32

@@ -181,7 +181,7 @@ def test_store_embed_task_no_args(ingestor):
assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[0], StoreEmbedTask)

def test_store_task_some_args(ingestor):
def test_store_task_some_args_extra_param(ingestor):
ingestor.store_embed(params={"extra_param": "extra"})

task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0]

@@ -146,7 +146,7 @@ def test_init_with_files(mocker, job_spec_fixture):
assert len(batch_job_spec._file_type_to_job_spec["pdf"]) > 0

def test_add_task_to_specific_document_type(batch_job_spec_fixture):
def test_add_task_to_specific_document_type_batch_job_spec(batch_job_spec_fixture):
task = MockTask()

# Add task to jobs with document_type 'pdf'

@@ -248,7 +248,7 @@ def test_add_task_to_all_documents():
assert dedup_task in job_specs[0]._tasks

def test_add_task_to_specific_document_type():
def test_add_task_to_specific_document_type_job_spec():
batch_job_spec = BatchJobSpec([JobSpec(document_type="pdf"), JobSpec(document_type="txt")])

embed_task = EmbedTask()

@@ -32,7 +32,8 @@ def test_extract_task_str_representation(document_type, extract_method, extract_
f"extract text: {extract_text}",
f"extract images: {extract_images}",
f"extract tables: {extract_tables}",
f"extract charts: {extract_tables}", # If extract_charts is not specified, it defaults to the same value as extract_tables.
f"extract charts: {extract_tables}", # If extract_charts is not specified,
# it defaults to the same value as extract_tables.
"text depth: document", # Assuming this is a fixed value for all instances
]

@@ -107,7 +108,8 @@ def test_extract_task_initialization(extract_method, extract_text, extract_image

@pytest.mark.parametrize(
"document_type, extract_method, extract_text, extract_images, extract_tables, extract_tables_method, paddle_output_format",
"document_type, extract_method, extract_text, extract_images, extract_tables, extract_tables_method,"
"paddle_output_format",
[
("pdf", "tika", True, False, False, "yolox", "pseudo_markdown"),
("docx", "haystack", False, True, True, "python_docx", "simple"),

@@ -142,7 +144,8 @@ def test_extract_task_to_dict_basic(
"extract_images": extract_images,
"extract_tables": extract_tables,
"extract_tables_method": extract_tables_method,
"extract_charts": extract_tables, # If extract_charts is not specified, it defaults to the same value as extract_tables.
"extract_charts": extract_tables, # If extract_charts is not specified,
# it defaults to the same value as extract_tables.
"text_depth": "document",
"paddle_output_format": paddle_output_format,
},

@@ -153,7 +156,10 @@ def test_extract_task_to_dict_basic(

@pytest.mark.parametrize(
"document_type, extract_method, extract_text, extract_images, extract_tables, extract_tables_method, extract_charts, paddle_output_format",
(
"document_type, extract_method, extract_text, extract_images, extract_tables, extract_tables_method,"
"extract_charts, paddle_output_format"
),
[
("pdf", "tika", True, False, False, "yolox", False, "pseudo_markdown"),
("docx", "haystack", False, True, True, "python_docx", False, "simple"),