add flake8 to pre-commit (#276)

Co-authored-by: Jeremy Dyer <jdye64@gmail.com>
Edward Kim
2024-12-13 09:11:31 -08:00
committed by GitHub
parent 1269ff5008
commit 362a63f57c
45 changed files with 145 additions and 180 deletions

View File

@@ -14,3 +14,9 @@ repos:
hooks:
- id: black
args: ["--line-length=120"]
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
- id: flake8
args: ["--max-line-length=120", "--extend-ignore=E203,E266,F403,F405"]

View File

@@ -1,6 +1,5 @@
import json
import pandas as pd
import math
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

View File

@@ -116,11 +116,11 @@ def click_validate_task(ctx, param, value):
new_task_id = f"{task_id}_{task_options.document_type}"
new_task = [(new_task_id, ExtractTask(**task_options.dict()))]
if task_options.extract_tables == True:
if task_options.extract_tables is True:
subtask_options = check_schema(TableExtractionSchema, {}, "table_data_extract", "{}")
new_task.append(("table_data_extract", TableExtractionTask(**subtask_options.dict())))
if task_options.extract_charts == True:
if task_options.extract_charts is True:
subtask_options = check_schema(ChartExtractionSchema, {}, "chart_data_extract", "{}")
new_task.append(("chart_data_extract", ChartExtractionTask(**subtask_options.dict())))

View File

@@ -390,7 +390,7 @@ class NvIngestClient:
# Attempt to fetch the job result
result = self._fetch_job_result(job_id, timeout, data_only=False)
return result, job_id
except TimeoutError as err:
except TimeoutError:
if verbose:
logger.info(
f"Job {job_id} is not ready. "
@@ -401,7 +401,7 @@ class NvIngestClient:
time.sleep(retry_delay) # Wait before retrying
except (RuntimeError, Exception) as err:
# For any other error, log and break out of the retry loop
logger.error(f"Error while fetching result for job ID {job_id}: {e}")
logger.error(f"Error while fetching result for job ID {job_id}: {err}")
return None, job_id
logger.error(f"Max retries exceeded for job {job_id}.")
return None, job_id
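Two pyflakes checks show up in this hunk: F841 (the "err" bound in the TimeoutError handler was never used, so the binding is dropped) and F821 (the log message referenced "e", a name that does not appear to be bound in that handler, so renaming it to "err" fixes a real latent bug). A minimal sketch of the F821 case, using a hypothetical _fetch stand-in:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def _fetch(job_id: str) -> None:
        # hypothetical stand-in for the real fetch call
        raise RuntimeError("broker unavailable")

    job_id = "42"
    try:
        _fetch(job_id)
    except RuntimeError as err:
        # Formatting with {e} here would be reported statically as F821 (undefined name)
        # and would raise NameError if the handler ever ran.
        logger.error(f"Error while fetching result for job ID {job_id}: {err}")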
@@ -419,19 +419,23 @@ class NvIngestClient:
del self._job_index_to_job_spec[job_id]
except concurrent.futures.TimeoutError:
logger.error(
f"Timeout while fetching result for job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}"
f"Timeout while fetching result for job ID {job_id}: "
f"{self._job_index_to_job_spec[job_id].source_id}"
)
except json.JSONDecodeError as e:
logger.error(
f"Decoding while processing job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}\n{e}"
f"Decoding while processing job ID {job_id}: "
f"{self._job_index_to_job_spec[job_id].source_id}\n{e}"
)
except RuntimeError as e:
logger.error(
f"Error while processing job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}\n{e}"
f"Error while processing job ID {job_id}: "
f"{self._job_index_to_job_spec[job_id].source_id}\n{e}"
)
except Exception as e:
logger.error(
f"Error while fetching result for job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}\n{e}"
f"Error while fetching result for job ID {job_id}: "
f"{self._job_index_to_job_spec[job_id].source_id}\n{e}"
)
return results
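The remaining changes in this file are line-length fixes for the configured 120-character limit (E501): each long f-string becomes adjacent literals inside the existing parentheses, which Python concatenates back into the identical message at compile time. A minimal sketch:

    job_id = "42"
    source_id = "report.pdf"

    # Adjacent string literals inside parentheses are joined at compile time, so the
    # wrapped form renders exactly the same message as the single-line original.
    message = (
        f"Timeout while fetching result for job ID {job_id}: "
        f"{source_id}"
    )
    assert message == f"Timeout while fetching result for job ID {job_id}: {source_id}"
    print(message)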

View File

@@ -17,8 +17,6 @@ from typing import Any
import httpx
import requests
import re
from nv_ingest_client.message_clients import MessageBrokerClientBase
from nv_ingest_client.message_clients.simple.simple_client import ResponseSchema
@@ -212,7 +210,9 @@ class RestClient(MessageBrokerClientBase):
# Terminal response code; return error ResponseSchema
return ResponseSchema(
response_code=1,
response_reason=f"Terminal response code {response_code} received when fetching JobSpec: {job_id}",
response_reason=(
f"Terminal response code {response_code} received when fetching JobSpec: {job_id}"
),
response=result.text,
)
else:
@@ -341,7 +341,8 @@ class RestClient(MessageBrokerClientBase):
"""
backoff_delay = min(2**existing_retries, self._max_backoff)
logger.debug(
f"Retry #: {existing_retries} of max_retries: {self.max_retries} | current backoff_delay: {backoff_delay}s of max_backoff: {self._max_backoff}s"
f"Retry #: {existing_retries} of max_retries: {self.max_retries} | "
f"current backoff_delay: {backoff_delay}s of max_backoff: {self._max_backoff}s"
)
if self.max_retries > 0 and existing_retries < self.max_retries:

View File

@@ -234,9 +234,9 @@ class SimpleClient(MessageBrokerClientBase):
return ResponseSchema(**final_response)
except (ConnectionError, socket.error, BrokenPipeError) as e:
except (ConnectionError, socket.error, BrokenPipeError):
pass
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))
@@ -309,9 +309,9 @@ class SimpleClient(MessageBrokerClientBase):
else:
return ResponseSchema(**final_response)
except (ConnectionError, socket.error, BrokenPipeError) as e:
except (ConnectionError, socket.error, BrokenPipeError):
pass
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))
@@ -346,7 +346,7 @@ class SimpleClient(MessageBrokerClientBase):
return ResponseSchema(**response)
except (ConnectionError, socket.error, BrokenPipeError) as e:
return ResponseSchema(response_code=1, response_reason=f"Connection error: {e}")
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))
@@ -375,7 +375,7 @@ class SimpleClient(MessageBrokerClientBase):
try:
sock.sendall(total_length.to_bytes(8, "big"))
sock.sendall(data)
except (socket.error, BrokenPipeError) as e:
except (socket.error, BrokenPipeError):
raise ConnectionError("Failed to send data.")
def _recv(self, sock: socket.socket) -> str:
@@ -407,7 +407,7 @@ class SimpleClient(MessageBrokerClientBase):
if not data_bytes:
raise ConnectionError("Incomplete message received.")
return data_bytes.decode("utf-8")
except (socket.error, BrokenPipeError, ConnectionError) as e:
except (socket.error, BrokenPipeError, ConnectionError):
raise ConnectionError("Failed to receive data.")
def _recv_exact(self, sock: socket.socket, num_bytes: int) -> Optional[bytes]:
@@ -434,8 +434,8 @@ class SimpleClient(MessageBrokerClientBase):
if not packet:
return None
data.extend(packet)
except socket.timeout as e:
except socket.timeout:
return None
except Exception as e:
except Exception:
return None
return bytes(data)
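Dropping "as e" where the exception object is never used silences pyflakes' F841 (local variable assigned but never used) without changing behaviour. A minimal sketch with hypothetical helpers, using socket.timeout as in the hunk above:

    import socket

    def recv_flagged() -> None:
        try:
            raise socket.timeout("no data")
        except socket.timeout as e:   # F841: "e" is assigned but never used
            return None

    def recv_clean() -> None:
        try:
            raise socket.timeout("no data")
        except socket.timeout:        # same behaviour, no unused binding
            return None

    recv_flagged()
    recv_clean()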

View File

@@ -47,7 +47,7 @@ class CaptionTask(Task):
info += "Image Caption Task:\n"
if self._api_key:
info += f" api_key: [redacted]\n"
info += " api_key: [redacted]\n"
if self._endpoint_url:
info += f" endpoint_url: {self._endpoint_url}\n"
if self._prompt:
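F541 flags f-strings that contain no placeholders; dropping the "f" prefix leaves the resulting string unchanged. A small sketch:

    api_key = "secret-token"
    info = "Image Caption Task:\n"
    if api_key:
        # f"  api_key: [redacted]\n" has no {placeholders}, so flake8 reports F541;
        # the plain literal below is byte-for-byte the same string.
        info += "  api_key: [redacted]\n"
    print(info)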

View File

@@ -103,4 +103,4 @@ def main(json_files, output_file):
if __name__ == "__main__":
process_json_files()
main()

View File

@@ -262,7 +262,7 @@ def check_ingest_result(json_payload: Dict) -> typing.Tuple[bool, str]:
source_id = (
json_payload.get("data", [])[0].get("metadata", {}).get("source_metadata", {}).get("source_name", "")
)
except Exception as e:
except Exception:
source_id = ""
description = f"[{source_id}]: {json_payload.get('status', '')}\n"
@@ -349,7 +349,7 @@ def create_job_specs_for_batch(files_batch: List[str]) -> List[JobSpec]:
>>> client = NvIngestClient()
>>> job_specs = create_job_specs_for_batch(files_batch)
>>> print(job_specs)
[nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb468bb0>, <nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb469270>]
[nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb468bb0>, <nv_ingest_client.primitives.jobs.job_spec.JobSpec object at 0x743acb469270>] # noqa: E501,W505
See Also
--------

View File

@@ -2,15 +2,15 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Dict, Optional
from typing import List
import httpx
import json
import os
import asyncio
import json
import logging
import os
from typing import Dict
from typing import List
from typing import Optional
import httpx
logger = logging.getLogger(__name__)
@@ -50,7 +50,8 @@ class AsyncZipkinClient:
if response.status_code == 404:
attempt += 1
logger.info(
f"Attempt {attempt}/{self._max_retries} for trace_id: {trace_id} failed with 404. Retrying in {self._retry_delay} seconds..."
f"Attempt {attempt}/{self._max_retries} for trace_id: {trace_id} failed with 404. "
f"Retrying in {self._retry_delay} seconds..."
)
await asyncio.sleep(self._retry_delay)
else:

View File

@@ -41,7 +41,7 @@ def run_ingestor():
).embed()
try:
results = ingestor.ingest()
_ = ingestor.ingest()
logger.info("Ingestion completed successfully.")
except Exception as e:
logger.error(f"Ingestion failed: {e}")

View File

@@ -2,8 +2,10 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import click
import json
import logging
import os
from morpheus.config import Config
from morpheus.config import CppConfig
@@ -20,6 +22,7 @@ from nv_ingest.util.schema.schema_validator import validate_schema
from nv_ingest.util.pipeline.stage_builders import *
logger = logging.getLogger(__name__)
local_log_level = os.getenv("INGEST_LOG_LEVEL", "INFO")
if local_log_level in ("DEFAULT",):
local_log_level = "INFO"

View File

@@ -2,6 +2,20 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import traceback
@@ -25,47 +39,6 @@ from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_base64
from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import traceback
from datetime import datetime
from typing import List, Dict
from typing import Optional
from typing import Tuple
from wand.image import Image as WandImage
from PIL import Image
import io
import numpy as np
from nv_ingest.extraction_workflows.pdf.doughnut_utils import crop_image
import nv_ingest.util.nim.yolox as yolox_utils
from nv_ingest.schemas.image_extractor_schema import ImageExtractorSchema
from nv_ingest.schemas.metadata_schema import AccessLevelEnum
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.nim.helpers import create_inference_client
from nv_ingest.util.pdf.metadata_aggregators import (
CroppedImageWithContent,
construct_image_metadata_from_pdf_image,
construct_image_metadata_from_base64,
)
from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata
logger = logging.getLogger(__name__)
YOLOX_MAX_BATCH_SIZE = 8
@@ -457,7 +430,7 @@ def image_data_extractor(
config=kwargs.get("image_extraction_config"),
trace_info=trace_info,
)
logger.debug(f"Extracted table/chart data from image")
logger.debug("Extracted table/chart data from image")
for _, table_chart_data in tables_and_charts:
extracted_data.append(
construct_table_and_chart_metadata(

View File

@@ -27,7 +27,6 @@ import numpy as np
import pypdfium2 as libpdfium
import nv_ingest.util.nim.yolox as yolox_utils
import nv_ingest.util.nim.yolox as yolox_utils
from nv_ingest.schemas.metadata_schema import AccessLevelEnum
from nv_ingest.schemas.metadata_schema import TableFormatEnum
from nv_ingest.schemas.metadata_schema import TextTypeEnum
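The doubled yolox_utils import is what pyflakes reports as F811 (redefinition of an unused name); the duplicate test definitions removed or renamed further down in this commit address the same check. A two-line sketch of the finding:

    import json
    import json  # F811: redefinition of unused "json" from line 1 - the duplicate is simply removed

    print(json.dumps({"ok": True}))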

View File

@@ -170,7 +170,8 @@ def push_to_broker(
broker_client: MessageBrokerClientBase, response_channel: str, json_payloads: List[str], retry_count: int = 2
) -> None:
"""
Attempts to push a JSON payload to a message broker channel, retrying on failure up to a specified number of attempts.
Attempts to push a JSON payload to a message broker channel, retrying on failure up to a specified number of
attempts.
Parameters
----------

View File

@@ -18,12 +18,10 @@ from morpheus.utils.module_ids import WRITE_TO_VECTOR_DB
from morpheus.utils.module_utils import ModuleLoaderFactory
from morpheus.utils.module_utils import register_module
from morpheus_llm.service.vdb.milvus_client import DATA_TYPE_MAP
from morpheus_llm.service.vdb.milvus_vector_db_service import MilvusVectorDBService
from morpheus_llm.service.vdb.utils import VectorDBServiceFactory
from morpheus_llm.service.vdb.vector_db_service import VectorDBService
from mrc.core import operators as ops
from pymilvus import BulkInsertState
from pymilvus import Collection
from pymilvus import connections
from pymilvus import utility
@@ -74,8 +72,7 @@ def _bulk_ingest(
]
uri_parsed = urlparse(milvus_uri)
conn = connections.connect(host=uri_parsed.hostname, port=uri_parsed.port)
collection = Collection(name=collection_name)
_ = connections.connect(host=uri_parsed.hostname, port=uri_parsed.port)
task_ids = []
for file in batch_files:

View File

@@ -65,14 +65,14 @@ def fetch_and_process_messages(client, validated_config: MessageBrokerTaskSource
if job.response_code != 0:
continue
logger.debug(f"Received ResponseSchema, converting to dict")
logger.debug("Received ResponseSchema, converting to dict")
job = json.loads(job.response)
else:
logger.debug(f"Received something not a ResponseSchema")
logger.debug("Received something not a ResponseSchema")
ts_fetched = datetime.now()
yield process_message(job, ts_fetched)
except TimeoutError as err:
except TimeoutError:
continue
except Exception as err:
logger.error(

View File

@@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
MODULE_NAME = "image_storage"
MODULE_NAMESPACE = "nv_ingest"
# TODO: Move these into microservice_entrypoint.py to populate the stage and validate them using the pydantic schema on startup.
# TODO: Move these into microservice_entrypoint.py to populate the stage and validate them using the pydantic schema
# on startup.
_DEFAULT_ENDPOINT = os.environ.get("MINIO_INTERNAL_ADDRESS", "minio:9000")
_DEFAULT_READ_ADDRESS = os.environ.get("MINIO_PUBLIC_ADDRESS", "http://minio:9000")
_DEFAULT_BUCKET_NAME = os.environ.get("MINIO_BUCKET", "nv-ingest")

View File

@@ -4,9 +4,11 @@
import logging
from typing import Union, Optional
from typing import Optional
from pydantic import BaseModel, Field, Extra
from pydantic import BaseModel
from pydantic import Extra
from pydantic import Field
logger = logging.getLogger(__name__)

View File

@@ -11,7 +11,6 @@
import json
import logging
import os
import uuid
from json import JSONDecodeError
from typing import Any

View File

@@ -45,7 +45,8 @@ def decode_and_extract(
validated_config : Any
Configuration object that contains `image_config`. Used if the `image` method is selected.
default : str, optional
The default extraction method to use if the specified method in `task_props` is not available (default is "image").
The default extraction method to use if the specified method in `task_props` is not available
(default is "image").
Returns
-------
@@ -138,7 +139,8 @@ def process_image(
-------
Tuple[pd.DataFrame, Dict[str, Any]]
A tuple containing:
- A pandas DataFrame with the processed image content, including columns 'document_type', 'metadata', and 'uuid'.
- A pandas DataFrame with the processed image content, including columns 'document_type', 'metadata',
and 'uuid'.
- A dictionary with trace information collected during processing.
Raises

View File

@@ -139,15 +139,16 @@ class MultiProcessingBaseStage(SinglePortStage):
forwards the task to a global multi-process worker pool where the heavy-lifting occurs.
3. **Global Worker Pool**: The work is executed in parallel across multiple process engines via the worker pool.
Each process engine applies the `process_fn` to the task data, which includes a pandas DataFrame and task-specific arguments.
Each process engine applies the `process_fn` to the task data, which includes a pandas DataFrame and
task-specific arguments.
4. **Response Queue**: After the work is completed by the worker pool, the results are pushed into a response queue.
5. **Post-Processing and Emission**: The results from the response queue are post-processed, reconstructed into their
original format, and emitted from an observable source for further downstream processing or final output.
5. **Post-Processing and Emission**: The results from the response queue are post-processed, reconstructed into
their original format, and emitted from an observable source for further downstream processing or final output.
This design enhances parallelism and resource utilization across multiple processes, especially for tasks that involve
heavy computations, such as large DataFrame operations.
This design enhances parallelism and resource utilization across multiple processes, especially for tasks that
involve heavy computations, such as large DataFrame operations.
"""
def __init__(

View File

@@ -186,7 +186,7 @@ def _extract_chart_data(
return df, {"trace_info": trace_info}
except Exception as e:
except Exception:
logger.error("Error occurred while extracting chart data.", exc_info=True)
raise
finally:

View File

@@ -139,8 +139,6 @@ def _extract_table_data(
stage_config = validated_config.stage_config
paddle_infer_protocol = stage_config.paddle_infer_protocol.lower()
# Obtain paddle_version
# Assuming that the grpc endpoint is at index 0
paddle_endpoint = stage_config.paddle_endpoints[1]
@@ -169,7 +167,7 @@ def _extract_table_data(
return df, {"trace_info": trace_info}
except Exception as e:
except Exception:
logger.error("Error occurred while extracting table data.", exc_info=True)
raise
finally:

View File

@@ -44,7 +44,8 @@ def decode_and_extract(
validated_config : Any
Configuration object that contains `pdfium_config`. Used if the `pdfium` method is selected.
default : str, optional
The default extraction method to use if the specified method in `task_props` is not available (default is "pdfium").
The default extraction method to use if the specified method in `task_props` is not available
(default is "pdfium").
Returns
-------

View File

@@ -6,12 +6,9 @@ import functools
import logging
import os
import traceback
import typing
import uuid
from typing import Any
from typing import Dict
import mrc
import pandas as pd
from minio import Minio
from morpheus.config import Config
@@ -35,11 +32,9 @@ def upload_embeddings(df: pd.DataFrame, params: Dict[str, Any]) -> pd.DataFrame:
Identify contents (e.g., images) within a dataframe and uploads the data to MinIO.
The image metadata in the metadata column is updated with the URL of the uploaded data.
"""
dimension = params.get("dim", 1024)
access_key = params.get("access_key", None)
secret_key = params.get("secret_key", None)
content_types = params.get("content_types")
endpoint = params.get("endpoint", _DEFAULT_ENDPOINT)
bucket_name = params.get("bucket_name", _DEFAULT_BUCKET_NAME)
bucket_path = params.get("bucket_path", "embeddings")
@@ -73,7 +68,6 @@ def upload_embeddings(df: pd.DataFrame, params: Dict[str, Any]) -> pd.DataFrame:
)
for idx, row in df.iterrows():
uu_id = row["uuid"]
metadata = row["metadata"].copy()
metadata["embedding_metadata"] = {}
metadata["embedding_metadata"]["uploaded_embedding_url"] = bucket_path

View File

@@ -2,8 +2,6 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import base64
import io
import logging
from functools import partial
from typing import Any
@@ -14,13 +12,11 @@ from typing import Tuple
import pandas as pd
import requests
from morpheus.config import Config
from PIL import Image
from nv_ingest.schemas.image_caption_extraction_schema import ImageCaptionExtractionSchema
from nv_ingest.schemas.metadata_schema import ContentTypeEnum
from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
from nv_ingest.util.image_processing.transforms import scale_image_to_encoding_size
from nv_ingest.util.tracing.tagging import traceable_func
logger = logging.getLogger(__name__)

View File

@@ -145,9 +145,12 @@ class RedisClient(MessageBrokerClientBase):
-------
Tuple[Optional[Dict[str, Any]], Optional[int], Optional[int]]
A tuple containing:
- message: A dictionary containing the decoded message if successful, or None if no message was retrieved.
- fragment: An integer representing the fragment number of the message, or None if no fragment was found.
- fragment_count: An integer representing the total number of message fragments, or None if no fragment count was found.
- message: A dictionary containing the decoded message if successful,
or None if no message was retrieved.
- fragment: An integer representing the fragment number of the message,
or None if no fragment was found.
- fragment_count: An integer representing the total number of message fragments,
or None if no fragment count was found.
Raises
------

View File

@@ -69,8 +69,6 @@ class SimpleMessageBrokerHandler(socketserver.BaseRequestHandler):
# Validate and extract common fields
queue_name = request_data.get("queue_name")
message = request_data.get("message")
timeout = request_data.get("timeout", 100)
# Initialize the queue and its lock if necessary
if queue_name:

View File

@@ -11,9 +11,7 @@ import socket
import json
import time
import logging
from typing import Optional, Union
from pydantic import BaseModel
from typing import Optional
from nv_ingest.schemas.message_brokers.response_schema import ResponseSchema
from nv_ingest_client.message_clients.client_base import MessageBrokerClientBase
@@ -230,9 +228,9 @@ class SimpleClient(MessageBrokerClientBase):
return ResponseSchema(**final_response)
except (ConnectionError, socket.error, BrokenPipeError) as e:
except (ConnectionError, socket.error, BrokenPipeError):
pass
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))
@@ -305,9 +303,9 @@ class SimpleClient(MessageBrokerClientBase):
else:
return ResponseSchema(**final_response)
except (ConnectionError, socket.error, BrokenPipeError) as e:
except (ConnectionError, socket.error, BrokenPipeError):
pass
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))
@@ -342,7 +340,7 @@ class SimpleClient(MessageBrokerClientBase):
return ResponseSchema(**response)
except (ConnectionError, socket.error, BrokenPipeError) as e:
return ResponseSchema(response_code=1, response_reason=f"Connection error: {e}")
except json.JSONDecodeError as e:
except json.JSONDecodeError:
return ResponseSchema(response_code=1, response_reason="Invalid JSON response from server.")
except Exception as e:
return ResponseSchema(response_code=1, response_reason=str(e))
@@ -371,7 +369,7 @@ class SimpleClient(MessageBrokerClientBase):
try:
sock.sendall(total_length.to_bytes(8, "big"))
sock.sendall(data)
except (socket.error, BrokenPipeError) as e:
except (socket.error, BrokenPipeError):
raise ConnectionError("Failed to send data.")
def _recv(self, sock: socket.socket) -> str:
@@ -403,7 +401,7 @@ class SimpleClient(MessageBrokerClientBase):
if not data_bytes:
raise ConnectionError("Incomplete message received.")
return data_bytes.decode("utf-8")
except (socket.error, BrokenPipeError, ConnectionError) as e:
except (socket.error, BrokenPipeError, ConnectionError):
raise ConnectionError("Failed to receive data.")
def _recv_exact(self, sock: socket.socket, num_bytes: int) -> Optional[bytes]:
@@ -430,8 +428,8 @@ class SimpleClient(MessageBrokerClientBase):
if not packet:
return None
data.extend(packet)
except socket.timeout as e:
except socket.timeout:
return None
except Exception as e:
except Exception:
return None
return bytes(data)

View File

@@ -4,16 +4,9 @@
import logging
import re
from math import ceil
from math import floor
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
from nv_ingest.util.image_processing.transforms import numpy_to_base64
ACCEPTED_TEXT_CLASSES = set(
[
"Text",

View File

@@ -5,7 +5,8 @@
import logging
import re
import time
from typing import Optional, Any
from typing import Any
from typing import Optional
from typing import Tuple
import backoff
@@ -271,7 +272,8 @@ class NimClient:
if status_code in [429, 503]:
# Warn and attempt to retry
logger.warning(
f"Received HTTP {status_code} ({response.reason}) from {self.model_interface.name()}. Retrying..."
f"Received HTTP {status_code} ({response.reason}) from "
f"{self.model_interface.name()}. Retrying..."
)
if attempt == max_retries:
# No more retries left
@@ -290,7 +292,10 @@ class NimClient:
return response.json()
except requests.Timeout:
err_msg = f"HTTP request timed out during {self.model_interface.name()} inference after {self.timeout} seconds"
err_msg = (
f"HTTP request timed out during {self.model_interface.name()} "
f"inference after {self.timeout} seconds"
)
logger.error(err_msg)
raise TimeoutError(err_msg)

View File

@@ -91,7 +91,7 @@ def test_apply_dedup(should_filter, expected0, expected1, expected2):
payload_list = []
for _ in range(3):
payload_list.append(valid_image_dedup_payload(f"test", 1, 1))
payload_list.append(valid_image_dedup_payload("test", 1, 1))
extracted_df = pd.DataFrame(payload_list, columns=["document_type", "metadata"])
extracted_gdf = cudf.from_pandas(extracted_df)

View File

@@ -98,11 +98,6 @@ def test_image_metadata_schema_invalid_type():
ImageMetadataSchema(image_type=3.14) # Using a float value
def test_image_metadata_schema_invalid_type():
with pytest.raises(ValidationError):
ImageMetadataSchema(image_type=3.14)
# Test cases for TableMetadataSchema
@pytest.mark.parametrize("table_format", ["html", "markdown", "latex", "image"])
def test_table_metadata_schema_defaults(table_format):

View File

@@ -23,6 +23,6 @@ def test_otel_meter_schema_custom_values():
assert schema.broker_client.host == "custom_host", "Custom host value for redis_client should be respected."
assert schema.broker_client.port == 12345, "Custom port value for redis_client should be respected."
assert (
schema.broker_client.broker_params["use_ssl"] == True
schema.broker_client.broker_params["use_ssl"] is True
), "Custom use_ssl value for broker_client should be True."
assert schema.raise_on_failure is True, "Custom value for raise_on_failure should be respected."

View File

@@ -31,7 +31,7 @@ def test_redis_task_sink_schema_custom_values():
assert schema.broker_client.host == "custom_host", "Custom host value for broker_client should be respected."
assert schema.broker_client.port == 12345, "Custom port value for broker_client should be respected."
assert (
schema.broker_client.broker_params["use_ssl"] == True
schema.broker_client.broker_params["use_ssl"] is True
), "Custom use_ssl value for redis_client should be True."
assert schema.raise_on_failure is True, "Custom value for raise_on_failure should be respected."
assert schema.progress_engines == 10, "Custom value for progress_engines should be respected."

View File

@@ -33,7 +33,7 @@ def test_redis_task_source_schema_custom_values():
assert schema.broker_client.host == "custom_host", "Custom host value for redis_client should be respected."
assert schema.broker_client.port == 12345, "Custom port value for redis_client should be respected."
assert (
schema.broker_client.broker_params["use_ssl"] == True
schema.broker_client.broker_params["use_ssl"] is True
), "Custom use_ssl value for redis_client should be True."
assert schema.task_queue == "custom_queue", "Custom value for task_queue should be respected."
assert schema.progress_engines == 10, "Custom value for progress_engines should be respected."

View File

@@ -1,12 +1,12 @@
import pytest
from unittest.mock import Mock
from unittest.mock import patch
import pytest
import requests
import pandas as pd
from unittest.mock import Mock, patch
from nv_ingest.stages.nim.chart_extraction import _update_metadata
from nv_ingest.stages.nim.chart_extraction import _extract_chart_data
import requests
MODULE_UNDER_TEST = "nv_ingest.stages.nim.chart_extraction"

View File

@@ -1,20 +1,17 @@
import pytest
import base64
import cv2
import numpy as np
import pandas as pd
from unittest.mock import Mock, patch
from io import BytesIO
from unittest.mock import Mock
from unittest.mock import patch
import cv2
import numpy as np
import pandas as pd
import pytest
import requests
from PIL import Image
from nv_ingest.stages.nim.table_extraction import _update_metadata, _extract_table_data
from nv_ingest.stages.nim.table_extraction import _extract_table_data
from nv_ingest.stages.nim.table_extraction import _update_metadata
from nv_ingest.util.nim.helpers import NimClient
from nv_ingest.util.nim.paddle import PaddleOCRModelInterface

View File

@@ -10,8 +10,6 @@ import pytest
import requests
from PIL import Image
MODULE_UNDER_TEST = "nv_ingest.stages.transforms.image_caption_extraction"
import pandas as pd
from nv_ingest.schemas.metadata_schema import ContentTypeEnum
@@ -19,6 +17,8 @@ from nv_ingest.stages.transforms.image_caption_extraction import _generate_capti
from nv_ingest.stages.transforms.image_caption_extraction import _prepare_dataframes_mod
from nv_ingest.stages.transforms.image_caption_extraction import caption_extract_stage
MODULE_UNDER_TEST = "nv_ingest.stages.transforms.image_caption_extraction"
def generate_base64_png_image() -> str:
"""Helper function to generate a base64-encoded PNG image."""

View File

@@ -488,18 +488,6 @@ def test_create_inference_client_http_endpoint_whitespace_no_infer_protocol(mock
# Preprocess image for paddle
@pytest.fixture
def sample_image():
"""
Returns a sample image array of shape (height, width, channels) with random pixel values.
"""
height, width = 800, 600 # Example dimensions
image = np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
return image
def test_preprocess_image_paddle_version_none(sample_image):
"""
Test that when paddle_version is None, the function returns the input image unchanged.

View File

@@ -143,6 +143,8 @@ def test_parse_output_http_valid(model_interface):
{
"type": "table",
"bboxes": [{"xmin": 0.1, "ymin": 0.1, "xmax": 0.2, "ymax": 0.2, "confidence": 0.9}],
},
{
"type": "chart",
"bboxes": [{"xmin": 0.3, "ymin": 0.3, "xmax": 0.4, "ymax": 0.4, "confidence": 0.8}],
},
@@ -152,6 +154,8 @@ def test_parse_output_http_valid(model_interface):
{
"type": "table",
"bboxes": [{"xmin": 0.15, "ymin": 0.15, "xmax": 0.25, "ymax": 0.25, "confidence": 0.85}],
},
{
"type": "chart",
"bboxes": [{"xmin": 0.35, "ymin": 0.35, "xmax": 0.45, "ymax": 0.45, "confidence": 0.75}],
},
@@ -166,7 +170,7 @@ def test_parse_output_http_valid(model_interface):
data = {"scaling_factors": scaling_factors}
parsed_output = model_interface.parse_output(response, "http", data)
assert isinstance(parsed_output, np.ndarray)
assert parsed_output.shape == (2, 2, 85)
assert parsed_output.shape == (2, 3, 85)
assert parsed_output.dtype == np.float32
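The second axis of the parsed output counts detections per page, so adding a "chart" entry to each page's mock response moves the expectation from (2, 2, 85) to (2, 3, 85); the 85-wide last axis (the per-detection feature vector) is unchanged. A minimal shape check:

    import numpy as np

    # 2 pages, 3 detections per page after the extra "chart" entries, 85 values per detection
    parsed = np.zeros((2, 3, 85), dtype=np.float32)
    assert parsed.shape == (2, 3, 85)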

View File

@@ -181,7 +181,7 @@ def test_store_embed_task_no_args(ingestor):
assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[0], StoreEmbedTask)
def test_store_task_some_args(ingestor):
def test_store_task_some_args_extra_param(ingestor):
ingestor.store_embed(params={"extra_param": "extra"})
task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0]

View File

@@ -146,7 +146,7 @@ def test_init_with_files(mocker, job_spec_fixture):
assert len(batch_job_spec._file_type_to_job_spec["pdf"]) > 0
def test_add_task_to_specific_document_type(batch_job_spec_fixture):
def test_add_task_to_specific_document_type_batch_job_spec(batch_job_spec_fixture):
task = MockTask()
# Add task to jobs with document_type 'pdf'
@@ -248,7 +248,7 @@ def test_add_task_to_all_documents():
assert dedup_task in job_specs[0]._tasks
def test_add_task_to_specific_document_type():
def test_add_task_to_specific_document_type_job_spec():
batch_job_spec = BatchJobSpec([JobSpec(document_type="pdf"), JobSpec(document_type="txt")])
embed_task = EmbedTask()

View File

@@ -32,7 +32,8 @@ def test_extract_task_str_representation(document_type, extract_method, extract_
f"extract text: {extract_text}",
f"extract images: {extract_images}",
f"extract tables: {extract_tables}",
f"extract charts: {extract_tables}", # If extract_charts is not specified, it defaults to the same value as extract_tables.
f"extract charts: {extract_tables}", # If extract_charts is not specified,
# it defaults to the same value as extract_tables.
"text depth: document", # Assuming this is a fixed value for all instances
]
@@ -107,7 +108,8 @@ def test_extract_task_initialization(extract_method, extract_text, extract_image
@pytest.mark.parametrize(
"document_type, extract_method, extract_text, extract_images, extract_tables, extract_tables_method, paddle_output_format",
"document_type, extract_method, extract_text, extract_images, extract_tables, extract_tables_method,"
"paddle_output_format",
[
("pdf", "tika", True, False, False, "yolox", "pseudo_markdown"),
("docx", "haystack", False, True, True, "python_docx", "simple"),
@@ -142,7 +144,8 @@ def test_extract_task_to_dict_basic(
"extract_images": extract_images,
"extract_tables": extract_tables,
"extract_tables_method": extract_tables_method,
"extract_charts": extract_tables, # If extract_charts is not specified, it defaults to the same value as extract_tables.
"extract_charts": extract_tables, # If extract_charts is not specified,
# it defaults to the same value as extract_tables.
"text_depth": "document",
"paddle_output_format": paddle_output_format,
},
@@ -153,7 +156,10 @@ def test_extract_task_to_dict_basic(
@pytest.mark.parametrize(
"document_type, extract_method, extract_text, extract_images, extract_tables, extract_tables_method, extract_charts, paddle_output_format",
(
"document_type, extract_method, extract_text, extract_images, extract_tables, extract_tables_method,"
"extract_charts, paddle_output_format"
),
[
("pdf", "tika", True, False, False, "yolox", False, "pseudo_markdown"),
("docx", "haystack", False, True, True, "python_docx", False, "simple"),