Support JSON input files (#1777)
* Add csv loader tests
* Add txt loader tests
* Add json input support
* Remove temp path constraint
* Reuse loader code
* Semver
* Set file pattern automatically based on type, if empty
* Remove pattern from smoke test config
* Spelling

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
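With this change, a directory of JSON documents can be indexed like the existing csv and text inputs. A minimal usage sketch against the surface the new unit tests exercise (`InputConfig`, `InputFileType.json`, `create_input`); the explicit `file_pattern` mirrors the tests, and a full `GraphRagConfig` run derives it from the file type automatically:

import asyncio

from graphrag.config.enums import InputFileType, InputType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.factory import create_input

config = InputConfig(
    type=InputType.file,
    file_type=InputFileType.json,
    file_pattern=".*\\.json$",
    base_dir="input",  # hypothetical directory of .json files
)

# Each JSON object becomes one DataFrame row (id/text/title/creation_date).
documents = asyncio.run(create_input(config=config))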
semver change note (new file):

@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add support for JSON input files."
+}
spell-check dictionary:

@@ -188,6 +188,8 @@ upvote
 # Misc
 Arxiv
 kwds
+jsons
+txts
 
 # Dulce
 astrotechnician
@@ -257,7 +257,7 @@ class InputDefaults:
     storage_account_blob_url: None = None
     container_name: None = None
     encoding: str = "utf-8"
-    file_pattern: str = ".*\\.txt$"
+    file_pattern: str = ""
     file_filter: None = None
     text_column: str = "text"
     title_column: None = None
@@ -34,6 +34,8 @@ class InputFileType(str, Enum):
     """The CSV input type."""
     text = "text"
     """The text input type."""
+    json = "json"
+    """The JSON input type."""
 
     def __repr__(self):
        """Get a string representation."""
default settings.yml template:

@@ -70,10 +70,8 @@ embed_text:
 
 input:
   type: {graphrag_config_defaults.input.type.value} # or blob
-  file_type: {graphrag_config_defaults.input.file_type.value} # or csv
+  file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
   base_dir: "{graphrag_config_defaults.input.base_dir}"
   file_encoding: {graphrag_config_defaults.input.encoding}
-  file_pattern: ".*\\\\.txt$$"
 
 chunks:
   size: {graphrag_config_defaults.chunks.size}
@@ -166,6 +166,14 @@ class GraphRagConfig(BaseModel):
     )
     """The input configuration."""
 
+    def _validate_input_pattern(self) -> None:
+        """Validate the input file pattern based on the specified type."""
+        if len(self.input.file_pattern) == 0:
+            if self.input.file_type == defs.InputFileType.text:
+                self.input.file_pattern = ".*\\.txt$"
+            else:
+                self.input.file_pattern = f".*\\.{self.input.file_type.value}$"
+
     embed_graph: EmbedGraphConfig = Field(
         description="Graph embedding configuration.",
         default=EmbedGraphConfig(),
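When no pattern is configured, the validator derives one from the file type: text keeps the historical `.*\.txt$` pattern, while csv and json get their patterns from the enum value. A standalone sketch of the derivation (plain Python, not library code):

def derive_pattern(file_type: str) -> str:
    # text is special-cased because the extension is "txt", not "text"
    return ".*\\.txt$" if file_type == "text" else f".*\\.{file_type}$"

assert derive_pattern("json") == ".*\\.json$"
assert derive_pattern("csv") == ".*\\.csv$"
assert derive_pattern("text") == ".*\\.txt$"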
@@ -336,6 +344,7 @@ class GraphRagConfig(BaseModel):
         """Validate the model configuration."""
         self._validate_root_dir()
         self._validate_models()
+        self._validate_input_pattern()
         self._validate_reporting_base_dir()
         self._validate_output_base_dir()
         self._validate_multi_output_base_dirs()
graphrag/index/input/csv.py:

@@ -4,24 +4,19 @@
 """A module containing load method definition."""
 
 import logging
-import re
-from io import BytesIO
 
 import pandas as pd
 
 from graphrag.config.models.input_config import InputConfig
-from graphrag.index.utils.hashing import gen_sha512_hash
+from graphrag.index.input.util import load_files, process_data_columns
 from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 log = logging.getLogger(__name__)
 
-DEFAULT_FILE_PATTERN = re.compile(r"(?P<filename>[^\\/]).csv$")
-
-input_type = "csv"
-
 
-async def load(
+async def load_csv(
     config: InputConfig,
     progress: ProgressLogger | None,
     storage: PipelineStorage,
@@ -39,61 +34,12 @@ async def load(
         data[[*additional_keys]] = data.apply(
             lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
         )
-        if "id" not in data.columns:
-            data["id"] = data.apply(lambda x: gen_sha512_hash(x, x.keys()), axis=1)
-        if config.text_column is not None and "text" not in data.columns:
-            if config.text_column not in data.columns:
-                log.warning(
-                    "text_column %s not found in csv file %s",
-                    config.text_column,
-                    path,
-                )
-            else:
-                data["text"] = data.apply(lambda x: x[config.text_column], axis=1)
-        if config.title_column is not None:
-            if config.title_column not in data.columns:
-                log.warning(
-                    "title_column %s not found in csv file %s",
-                    config.title_column,
-                    path,
-                )
-            else:
-                data["title"] = data.apply(lambda x: x[config.title_column], axis=1)
-        else:
-            data["title"] = data.apply(lambda _: path, axis=1)
+        data = process_data_columns(data, config, path)
 
         creation_date = await storage.get_creation_date(path)
         data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
 
         return data
 
-    file_pattern = (
-        re.compile(config.file_pattern)
-        if config.file_pattern is not None
-        else DEFAULT_FILE_PATTERN
-    )
-    files = list(
-        storage.find(
-            file_pattern,
-            progress=progress,
-            file_filter=config.file_filter,
-        )
-    )
-
-    if len(files) == 0:
-        msg = f"No CSV files found in {config.base_dir}"
-        raise ValueError(msg)
-
-    files_loaded = []
-
-    for file, group in files:
-        try:
-            files_loaded.append(await load_file(file, group))
-        except Exception:  # noqa: BLE001 (catching Exception is fine here)
-            log.warning("Warning! Error loading csv file %s. Skipping...", file)
-
-    log.info("Found %d csv files, loading %d", len(files), len(files_loaded))
-    result = pd.concat(files_loaded)
-    total_files_log = f"Total number of unfiltered csv rows: {len(result)}"
-    log.info(total_files_log)
-    return result
+    return await load_files(load_file, config, storage, progress)
graphrag/index/input/factory.py:

@@ -10,12 +10,11 @@ from typing import cast
 
 import pandas as pd
 
-from graphrag.config.enums import InputType
+from graphrag.config.enums import InputFileType, InputType
 from graphrag.config.models.input_config import InputConfig
-from graphrag.index.input.csv import input_type as csv
-from graphrag.index.input.csv import load as load_csv
-from graphrag.index.input.text import input_type as text
-from graphrag.index.input.text import load as load_text
+from graphrag.index.input.csv import load_csv
+from graphrag.index.input.json import load_json
+from graphrag.index.input.text import load_text
 from graphrag.logger.base import ProgressLogger
 from graphrag.logger.null_progress import NullProgressLogger
 from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage

@@ -23,8 +22,9 @@ from graphrag.storage.file_pipeline_storage import FilePipelineStorage
 
 log = logging.getLogger(__name__)
 loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
-    text: load_text,
-    csv: load_csv,
+    InputFileType.text: load_text,
+    InputFileType.csv: load_csv,
+    InputFileType.json: load_json,
 }
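Dispatch in the factory is now a plain dict lookup keyed on the `InputFileType` enum. A hedged sketch of that lookup pattern (hypothetical helper; the real `create_input` also wires up storage and progress logging, omitted here):

async def dispatch(config, progress=None, storage=None):
    # Mirror of the loaders-dict dispatch in factory.py.
    loader = loaders.get(config.file_type)
    if loader is None:
        msg = f"Unknown input type: {config.file_type}"
        raise ValueError(msg)
    return await loader(config, progress, storage)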
graphrag/index/input/json.py (new file, 49 lines):
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing load method definition."""
+
+import json
+import logging
+
+import pandas as pd
+
+from graphrag.config.models.input_config import InputConfig
+from graphrag.index.input.util import load_files, process_data_columns
+from graphrag.logger.base import ProgressLogger
+from graphrag.storage.pipeline_storage import PipelineStorage
+
+log = logging.getLogger(__name__)
+
+
+async def load_json(
+    config: InputConfig,
+    progress: ProgressLogger | None,
+    storage: PipelineStorage,
+) -> pd.DataFrame:
+    """Load json inputs from a directory."""
+    log.info("Loading json files from %s", config.base_dir)
+
+    async def load_file(path: str, group: dict | None) -> pd.DataFrame:
+        if group is None:
+            group = {}
+        text = await storage.get(path, encoding=config.encoding)
+        as_json = json.loads(text)
+        # json file could just be a single object, or an array of objects
+        rows = as_json if isinstance(as_json, list) else [as_json]
+        data = pd.DataFrame(rows)
+
+        additional_keys = group.keys()
+        if len(additional_keys) > 0:
+            data[[*additional_keys]] = data.apply(
+                lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
+            )
+
+        data = process_data_columns(data, config, path)
+
+        creation_date = await storage.get_creation_date(path)
+        data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
+
+        return data
+
+    return await load_files(load_file, config, storage, progress)
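The loader accepts either a single JSON object or an array of objects per file; a bare object is wrapped in a one-element list so both shapes normalize to DataFrame rows. A self-contained illustration (sample strings taken from the new test fixtures):

import json

import pandas as pd

single = '{"title": "Hi", "text": "I\'m here"}'
array = '[{"title": "Hello", "text": "Hi how are you today?"}]'

for text in (single, array):
    as_json = json.loads(text)
    rows = as_json if isinstance(as_json, list) else [as_json]
    print(pd.DataFrame(rows).shape)  # (1, 2) for both shapes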
graphrag/index/input/text.py:

@@ -4,64 +4,34 @@
 """A module containing load method definition."""
 
 import logging
-import re
 from pathlib import Path
-from typing import Any
 
 import pandas as pd
 
 from graphrag.config.models.input_config import InputConfig
+from graphrag.index.input.util import load_files
 from graphrag.index.utils.hashing import gen_sha512_hash
 from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage
 
-DEFAULT_FILE_PATTERN = re.compile(
-    r".*[\\/](?P<source>[^\\/]+)[\\/](?P<year>\d{4})-(?P<month>\d{2})-(?P<day>\d{2})_(?P<author>[^_]+)_\d+\.txt"
-)
-input_type = "text"
 log = logging.getLogger(__name__)
 
 
-async def load(
+async def load_text(
     config: InputConfig,
     progress: ProgressLogger | None,
     storage: PipelineStorage,
 ) -> pd.DataFrame:
     """Load text inputs from a directory."""
 
-    async def load_file(
-        path: str, group: dict | None = None, _encoding: str = "utf-8"
-    ) -> dict[str, Any]:
+    async def load_file(path: str, group: dict | None = None) -> pd.DataFrame:
         if group is None:
             group = {}
-        text = await storage.get(path, encoding="utf-8")
+        text = await storage.get(path, encoding=config.encoding)
         new_item = {**group, "text": text}
         new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
         new_item["title"] = str(Path(path).name)
         new_item["creation_date"] = await storage.get_creation_date(path)
-        return new_item
+        return pd.DataFrame([new_item])
 
-    files = list(
-        storage.find(
-            re.compile(config.file_pattern),
-            progress=progress,
-            file_filter=config.file_filter,
-        )
-    )
-    if len(files) == 0:
-        msg = f"No text files found in {config.base_dir}"
-        raise ValueError(msg)
-    found_files = f"found text files from {config.base_dir}, found {files}"
-    log.info(found_files)
-
-    files_loaded = []
-
-    for file, group in files:
-        try:
-            files_loaded.append(await load_file(file, group))
-        except Exception:  # noqa: BLE001 (catching Exception is fine here)
-            log.warning("Warning! Error loading file %s. Skipping...", file)
-
-    log.info("Found %d files, loading %d", len(files), len(files_loaded))
-
-    return pd.DataFrame(files_loaded)
+    return await load_files(load_file, config, storage, progress)
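Returning a one-row DataFrame per text file (instead of a plain dict) is what lets all three loaders share the `pd.concat` path in `load_files`. A quick standalone check of that shape (placeholder values, not the real sha512 hash):

import pandas as pd

new_item = {"text": "Hi how are you today?", "id": "<sha512>", "title": "input.txt"}
frame = pd.DataFrame([new_item])
print(frame.shape)  # (1, 3): one row per file, concat-friendly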
graphrag/index/input/util.py (new file, 89 lines):
@@ -0,0 +1,89 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Shared column processing for structured input files."""
+
+import logging
+import re
+from typing import Any
+
+import pandas as pd
+
+from graphrag.config.models.input_config import InputConfig
+from graphrag.index.utils.hashing import gen_sha512_hash
+from graphrag.logger.base import ProgressLogger
+from graphrag.storage.pipeline_storage import PipelineStorage
+
+log = logging.getLogger(__name__)
+
+
+async def load_files(
+    loader: Any,
+    config: InputConfig,
+    storage: PipelineStorage,
+    progress: ProgressLogger | None,
+) -> pd.DataFrame:
+    """Load files from storage and apply a loader function."""
+    files = list(
+        storage.find(
+            re.compile(config.file_pattern),
+            progress=progress,
+            file_filter=config.file_filter,
+        )
+    )
+
+    if len(files) == 0:
+        msg = f"No {config.file_type} files found in {config.base_dir}"
+        raise ValueError(msg)
+
+    files_loaded = []
+
+    for file, group in files:
+        try:
+            files_loaded.append(await loader(file, group))
+        except Exception as e:  # noqa: BLE001 (catching Exception is fine here)
+            log.warning("Warning! Error loading file %s. Skipping...", file)
+            log.warning("Error: %s", e)
+
+    log.info(
+        "Found %d %s files, loading %d", len(files), config.file_type, len(files_loaded)
+    )
+    result = pd.concat(files_loaded)
+    total_files_log = (
+        f"Total number of unfiltered {config.file_type} rows: {len(result)}"
+    )
+    log.info(total_files_log)
+    return result
+
+
+def process_data_columns(
+    documents: pd.DataFrame, config: InputConfig, path: str
+) -> pd.DataFrame:
+    """Process configured data columns of a DataFrame."""
+    if "id" not in documents.columns:
+        documents["id"] = documents.apply(
+            lambda x: gen_sha512_hash(x, x.keys()), axis=1
+        )
+    if config.text_column is not None and "text" not in documents.columns:
+        if config.text_column not in documents.columns:
+            log.warning(
+                "text_column %s not found in csv file %s",
+                config.text_column,
+                path,
+            )
+        else:
+            documents["text"] = documents.apply(lambda x: x[config.text_column], axis=1)
+    if config.title_column is not None:
+        if config.title_column not in documents.columns:
+            log.warning(
+                "title_column %s not found in csv file %s",
+                config.title_column,
+                path,
+            )
+        else:
+            documents["title"] = documents.apply(
+                lambda x: x[config.title_column], axis=1
+            )
+    else:
+        documents["title"] = documents.apply(lambda _: path, axis=1)
+    return documents
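For a frame that already carries `title` and `text` columns (as the CSV and JSON fixtures do), `process_data_columns` only fills in the missing `id`. A quick check, assuming the module paths above:

import pandas as pd

from graphrag.config.enums import InputFileType, InputType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.util import process_data_columns

config = InputConfig(
    type=InputType.file,
    file_type=InputFileType.csv,
    file_pattern=".*\\.csv$",
    title_column="title",
)
df = pd.DataFrame([{"title": "Hello", "text": "Hi how are you today?"}])
out = process_data_columns(df, config, "input.csv")
print(sorted(out.columns))  # ['id', 'text', 'title']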
tests/fixtures/min-csv/settings.yml (vendored):

@@ -34,7 +34,6 @@ vector_store:
 
 input:
   file_type: csv
-  file_pattern: ".*\\.csv$$"
 
 snapshots:
   embeddings: True
tests/unit/indexing/input/__init__.py (new file, 2 lines):

@@ -0,0 +1,2 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
tests/unit/indexing/input/data/multiple-csvs/input1.csv (new file, 3 lines):

@@ -0,0 +1,3 @@
+title,text
+Hello,Hi how are you today?
+Goodbye,I'm outta here

tests/unit/indexing/input/data/multiple-csvs/input2.csv (new file, 2 lines):

@@ -0,0 +1,2 @@
+title,text
+Adios,See you later

tests/unit/indexing/input/data/multiple-csvs/input3.csv (new file, 2 lines):

@@ -0,0 +1,2 @@
+title,text
+Hi,I'm here
tests/unit/indexing/input/data/multiple-jsons/input1.json (new file, 10 lines):

@@ -0,0 +1,10 @@
+[{
+    "title": "Hello",
+    "text": "Hi how are you today?"
+}, {
+    "title": "Goodbye",
+    "text": "I'm outta here"
+}, {
+    "title": "Adios",
+    "text": "See you later"
+}]
second multiple-jsons fixture (file header not shown in this view):

@@ -0,0 +1,4 @@
+{
+    "title": "Hi",
+    "text": "I'm here"
+}
tests/unit/indexing/input/data/multiple-txts/input1.txt (new file, 1 line):

@@ -0,0 +1 @@
+Hi how are you today?

tests/unit/indexing/input/data/multiple-txts/input2.txt (new file, 1 line):

@@ -0,0 +1 @@
+I'm outta here
tests/unit/indexing/input/data/one-csv/input.csv (new file, 3 lines):

@@ -0,0 +1,3 @@
+title,text
+Hello,Hi how are you today?
+Goodbye,I'm outta here
tests/unit/indexing/input/data/one-json-multiple-objects/input.json (file header not shown; path per the new tests):

@@ -0,0 +1,10 @@
+[{
+    "title": "Hello",
+    "text": "Hi how are you today?"
+}, {
+    "title": "Goodbye",
+    "text": "I'm outta here"
+}, {
+    "title": "Adios",
+    "text": "See you later"
+}]
tests/unit/indexing/input/data/one-json-one-object/input.json (file header not shown; path per the new tests):

@@ -0,0 +1,4 @@
+{
+    "title": "Hello",
+    "text": "Hi how are you today?"
+}
tests/unit/indexing/input/data/one-txt/input.txt (new file, 1 line):

@@ -0,0 +1 @@
+Hi how are you today?
tests/unit/indexing/input/test_csv_loader.py (new file, 56 lines):

@@ -0,0 +1,56 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.config.enums import InputFileType, InputType
+from graphrag.config.models.input_config import InputConfig
+from graphrag.index.input.factory import create_input
+
+
+async def test_csv_loader_one_file():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.csv,
+        file_pattern=".*\\.csv$",
+        base_dir="tests/unit/indexing/input/data/one-csv",
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (2, 4)
+    assert documents["title"].iloc[0] == "input.csv"
+
+
+async def test_csv_loader_one_file_with_title():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.csv,
+        file_pattern=".*\\.csv$",
+        base_dir="tests/unit/indexing/input/data/one-csv",
+        title_column="title",
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (2, 4)
+    assert documents["title"].iloc[0] == "Hello"
+
+
+async def test_csv_loader_one_file_with_metadata():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.csv,
+        file_pattern=".*\\.csv$",
+        base_dir="tests/unit/indexing/input/data/one-csv",
+        title_column="title",
+        metadata=["title"],
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (2, 5)
+    assert documents["metadata"][0] == {"title": "Hello"}
+
+
+async def test_csv_loader_multiple_files():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.csv,
+        file_pattern=".*\\.csv$",
+        base_dir="tests/unit/indexing/input/data/multiple-csvs",
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (4, 4)
tests/unit/indexing/input/test_json_loader.py (new file, 69 lines):

@@ -0,0 +1,69 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.config.enums import InputFileType, InputType
+from graphrag.config.models.input_config import InputConfig
+from graphrag.index.input.factory import create_input
+
+
+async def test_json_loader_one_file_one_object():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.json,
+        file_pattern=".*\\.json$",
+        base_dir="tests/unit/indexing/input/data/one-json-one-object",
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (1, 4)
+    assert documents["title"].iloc[0] == "input.json"
+
+
+async def test_json_loader_one_file_multiple_objects():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.json,
+        file_pattern=".*\\.json$",
+        base_dir="tests/unit/indexing/input/data/one-json-multiple-objects",
+    )
+    documents = await create_input(config=config)
+    print(documents)
+    assert documents.shape == (3, 4)
+    assert documents["title"].iloc[0] == "input.json"
+
+
+async def test_json_loader_one_file_with_title():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.json,
+        file_pattern=".*\\.json$",
+        base_dir="tests/unit/indexing/input/data/one-json-one-object",
+        title_column="title",
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (1, 4)
+    assert documents["title"].iloc[0] == "Hello"
+
+
+async def test_json_loader_one_file_with_metadata():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.json,
+        file_pattern=".*\\.json$",
+        base_dir="tests/unit/indexing/input/data/one-json-one-object",
+        title_column="title",
+        metadata=["title"],
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (1, 5)
+    assert documents["metadata"][0] == {"title": "Hello"}
+
+
+async def test_json_loader_multiple_files():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.json,
+        file_pattern=".*\\.json$",
+        base_dir="tests/unit/indexing/input/data/multiple-jsons",
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (4, 4)
tests/unit/indexing/input/test_txt_loader.py (new file, 43 lines):

@@ -0,0 +1,43 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.config.enums import InputFileType, InputType
+from graphrag.config.models.input_config import InputConfig
+from graphrag.index.input.factory import create_input
+
+
+async def test_txt_loader_one_file():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.text,
+        file_pattern=".*\\.txt$",
+        base_dir="tests/unit/indexing/input/data/one-txt",
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (1, 4)
+    assert documents["title"].iloc[0] == "input.txt"
+
+
+async def test_txt_loader_one_file_with_metadata():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.text,
+        file_pattern=".*\\.txt$",
+        base_dir="tests/unit/indexing/input/data/one-txt",
+        metadata=["title"],
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (1, 5)
+    # unlike csv, we cannot set the title to anything other than the filename
+    assert documents["metadata"][0] == {"title": "input.txt"}
+
+
+async def test_txt_loader_multiple_files():
+    config = InputConfig(
+        type=InputType.file,
+        file_type=InputFileType.text,
+        file_pattern=".*\\.txt$",
+        base_dir="tests/unit/indexing/input/data/multiple-txts",
+    )
+    documents = await create_input(config=config)
+    assert documents.shape == (2, 4)