Support options on pyproject.toml (#190)

* support for pyproject.toml

* add unittest

* remove comments

* fix format

* update documents
This commit is contained in:
Koudai Aono
2020-08-03 01:27:15 +09:00
committed by GitHub
parent 4d8ed032dd
commit 3e0522fdd7
8 changed files with 287 additions and 40 deletions

View File

@@ -10,10 +10,14 @@ import sys
from argparse import ArgumentParser, FileType, Namespace
from collections import defaultdict
from enum import IntEnum
from io import TextIOBase
from pathlib import Path
from typing import Any, DefaultDict, Dict, Optional, Sequence
from typing import Any, DefaultDict, Dict, Mapping, Optional, Sequence, TextIO
import argcomplete
import black
import toml
from pydantic import BaseModel
from datamodel_code_generator import (
DEFAULT_BASE_CLASS,
@@ -43,23 +47,16 @@ signal.signal(signal.SIGINT, sig_int_handler)
arg_parser = ArgumentParser()
arg_parser.add_argument(
'--input',
help='Input file (default: stdin)',
type=FileType('rt'),
default=sys.stdin,
'--input', help='Input file (default: stdin)', type=FileType('rt'),
)
arg_parser.add_argument(
'--input-file-type',
help='Input file type (default: auto)',
choices=[i.value for i in InputFileType],
default='auto',
)
arg_parser.add_argument('--output', help='Output file (default: stdout)')
arg_parser.add_argument(
'--base-class',
help='Base Class (default: pydantic.BaseModel)',
type=str,
default=DEFAULT_BASE_CLASS,
'--base-class', help='Base Class (default: pydantic.BaseModel)', type=str,
)
arg_parser.add_argument(
'--field-constraints',
@@ -85,7 +82,6 @@ arg_parser.add_argument(
'--target-python-version',
help='target python version (default: 3.7)',
choices=['3.6', '3.7'],
default='3.7',
)
arg_parser.add_argument(
'--validation', help='Enable validation (Only OpenAPI)', action='store_true'
@@ -94,6 +90,32 @@ arg_parser.add_argument('--debug', help='show debug message', action='store_true
arg_parser.add_argument('--version', help='show version', action='store_true')
class Config(BaseModel):
class Config:
validate_assignment = True
arbitrary_types_allowed = (TextIOBase,)
input_file_type: InputFileType = InputFileType.Auto
output: Optional[Path]
debug: bool = False
target_python_version: PythonVersion = PythonVersion.PY_37
base_class: str = DEFAULT_BASE_CLASS
custom_template_dir: Optional[str]
extra_template_data: Optional[TextIOBase]
validation: bool = False
field_constraints: bool = False
snake_case_field: bool = False
strip_default_none: bool = False
aliases: Optional[TextIOBase]
def merge_args(self, args: Namespace) -> None:
for field_name in self.__fields__:
arg = getattr(args, field_name)
if arg is None:
continue
setattr(self, field_name, arg)
def main(args: Optional[Sequence[str]] = None) -> Exit:
"""Main function."""
@@ -111,12 +133,39 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
print(version)
exit(0)
if namespace.debug: # pragma: no cover
root = black.find_project_root((Path().resolve(),))
pyproject_toml_path = root / "pyproject.toml"
if pyproject_toml_path.is_file():
pyproject_toml: Dict[str, Any] = {
k.replace('-', '_'): v
for k, v in toml.load(str(pyproject_toml_path))
.get('tool', {})
.get('datamodel-codegen', {})
.items()
}
else:
pyproject_toml = {}
config = Config.parse_obj(pyproject_toml)
config.merge_args(namespace)
if namespace.input:
input_name: str = namespace.input.name
input_text: str = namespace.input.read()
elif 'input' in pyproject_toml:
input_path = Path(pyproject_toml['input'])
input_name = input_path.name
input_text = input_path.read_text()
else:
input_name = '<stdin>'
input_text = sys.stdin.read()
if config.debug: # pragma: no cover
enable_debug_message()
extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]]
if namespace.extra_template_data is not None:
with namespace.extra_template_data as data:
if config.extra_template_data is not None:
with config.extra_template_data as data:
try:
extra_template_data = json.load(
data, object_hook=lambda d: defaultdict(dict, **d)
@@ -127,8 +176,8 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
else:
extra_template_data = None
if namespace.aliases is not None:
with namespace.aliases as data:
if config.aliases is not None:
with config.aliases as data:
try:
aliases = json.load(data)
except json.JSONDecodeError as e:
@@ -147,18 +196,18 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
try:
generate(
input_name=namespace.input.name,
input_text=namespace.input.read(),
input_file_type=InputFileType(namespace.input_file_type),
output=Path(namespace.output) if namespace.output is not None else None,
target_python_version=PythonVersion(namespace.target_python_version),
base_class=namespace.base_class,
custom_template_dir=namespace.custom_template_dir,
input_name=input_name,
input_text=input_text,
input_file_type=config.input_file_type,
output=Path(config.output) if config.output is not None else None,
target_python_version=config.target_python_version,
base_class=config.base_class,
custom_template_dir=config.custom_template_dir,
extra_template_data=extra_template_data,
validation=namespace.validation,
field_constraints=namespace.field_constraints,
snake_case_field=namespace.snake_case_field,
strip_default_none=namespace.strip_default_none,
validation=config.validation,
field_constraints=config.field_constraints,
snake_case_field=config.snake_case_field,
strip_default_none=config.strip_default_none,
aliases=aliases,
)
return Exit.OK

12
docs/pyproject_toml.md Normal file
View File

@@ -0,0 +1,12 @@
datamodel-code-generator has a lot of command-line options.
The options are supported on `pyproject.toml`.
Example `pyproject.toml`:
```toml
[tool.datamodel-codegen]
field-constraints = true
snake-case-field = true
strip-default-none = false
target-python-version = "3.7"
```

View File

@@ -27,6 +27,7 @@ nav:
- Generate from JSON Data: jsondata.md
- Formatting: formatting.md
- Field Constraints: field-constraints.md
- pyproject.toml: pyproject_toml.md
- Development-Contributing: development-contributing.md
plugins:

View File

@@ -2,10 +2,6 @@
# filename: api.yaml
# timestamp: 2019-07-26T00:00:00+00:00
from __future__ import (
annotations,
)
from typing import (
List,
Optional,
@@ -25,7 +21,7 @@ class Pet(BaseModel):
class Pets(BaseModel):
__root__: List[Pet]
__root__: List["Pet"]
class User(BaseModel):
@@ -35,7 +31,7 @@ class User(BaseModel):
class Users(BaseModel):
__root__: List[User]
__root__: List["User"]
class Id(BaseModel):
@@ -79,7 +75,7 @@ class Api(BaseModel):
class Apis(BaseModel):
__root__: List[Api]
__root__: List["Api"]
class Event(BaseModel):
@@ -88,5 +84,5 @@ class Event(BaseModel):
class Result(BaseModel):
event: Optional[
Event
"Event"
] = None

View File

@@ -0,0 +1,69 @@
# generated by datamodel-codegen:
# filename: api.yaml
# timestamp: 2019-07-26T00:00:00+00:00
from __future__ import annotations
from typing import List, Optional
from pydantic import AnyUrl, BaseModel, Field
class Pet(BaseModel):
id: int
name: str
tag: Optional[str] = None
class Pets(BaseModel):
__root__: List[Pet]
class User(BaseModel):
id: int
name: str
tag: Optional[str] = None
class Users(BaseModel):
__root__: List[User]
class Id(BaseModel):
__root__: str
class Rules(BaseModel):
__root__: List[str]
class Error(BaseModel):
code: int
message: str
class Api(BaseModel):
apiKey: Optional[str] = Field(
None, description='To be used as a dataset parameter value'
)
apiVersionNumber: Optional[str] = Field(
None, description='To be used as a version parameter value'
)
apiUrl: Optional[AnyUrl] = Field(
None, description="The URL describing the dataset's fields"
)
apiDocumentationUrl: Optional[AnyUrl] = Field(
None, description='A URL to the API console for each API'
)
class Apis(BaseModel):
__root__: List[Api]
class Event(BaseModel):
name: Optional[str] = None
class Result(BaseModel):
event: Optional[Event] = None

View File

@@ -0,0 +1,69 @@
# generated by datamodel-codegen:
# filename: <stdin>
# timestamp: 2019-07-26T00:00:00+00:00
from __future__ import annotations
from typing import List, Optional
from pydantic import AnyUrl, BaseModel, Field
class Pet(BaseModel):
id: int
name: str
tag: Optional[str] = None
class Pets(BaseModel):
__root__: List[Pet]
class User(BaseModel):
id: int
name: str
tag: Optional[str] = None
class Users(BaseModel):
__root__: List[User]
class Id(BaseModel):
__root__: str
class Rules(BaseModel):
__root__: List[str]
class Error(BaseModel):
code: int
message: str
class Api(BaseModel):
apiKey: Optional[str] = Field(
None, description='To be used as a dataset parameter value'
)
apiVersionNumber: Optional[str] = Field(
None, description='To be used as a version parameter value'
)
apiUrl: Optional[AnyUrl] = Field(
None, description="The URL describing the dataset's fields"
)
apiDocumentationUrl: Optional[AnyUrl] = Field(
None, description='A URL to the API console for each API'
)
class Apis(BaseModel):
__root__: List[Api]
class Event(BaseModel):
name: Optional[str] = None
class Result(BaseModel):
event: Optional[Event] = None

View File

@@ -1,3 +1,13 @@
[tool.black]
skip-string-normalization = false
line-length = 30
[tool.datamodel-codegen]
input = "INPUT_PATH"
output = "OUTPUT_PATH"
input_file_type = 'openapi'
validation = true
field-constraints = true
snake-case-field = true
strip-default-none = true
target-python-version = "3.6"

View File

@@ -1,3 +1,4 @@
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
@@ -385,11 +386,37 @@ def test_main_custom_template_dir(capsys: CaptureFixture) -> None:
@freeze_time('2019-07-26')
def test_pyproject():
current_dir = os.getcwd()
with TemporaryDirectory() as output_dir:
output_dir = Path(output_dir)
pyproject_toml = Path(DATA_PATH) / "project" / "pyproject.toml"
shutil.copy(pyproject_toml, output_dir)
output_file: Path = output_dir / 'output.py'
pyproject_toml_path = Path(DATA_PATH) / "project" / "pyproject.toml"
pyproject_toml = (
pyproject_toml_path.read_text()
.replace('INPUT_PATH', str(OPEN_API_DATA_PATH / 'api.yaml'))
.replace('OUTPUT_PATH', str(output_file))
)
(output_dir / 'pyproject.toml').write_text(pyproject_toml)
os.chdir(output_dir)
return_code: Exit = main([])
assert return_code == Exit.OK
assert (
output_file.read_text()
== (EXPECTED_MAIN_PATH / 'pyproject' / 'output.py').read_text()
)
os.chdir(current_dir)
with pytest.raises(SystemExit):
main()
@freeze_time('2019-07-26')
def test_pyproject_not_found():
current_dir = os.getcwd()
with TemporaryDirectory() as output_dir:
output_dir = Path(output_dir)
output_file: Path = output_dir / 'output.py'
os.chdir(output_dir)
return_code: Exit = main(
[
'--input',
@@ -401,11 +428,25 @@ def test_pyproject():
assert return_code == Exit.OK
assert (
output_file.read_text()
== (EXPECTED_MAIN_PATH / 'pyproject' / 'output.py').read_text()
== (EXPECTED_MAIN_PATH / 'pyproject_not_found' / 'output.py').read_text()
)
os.chdir(current_dir)
with pytest.raises(SystemExit):
main()
@freeze_time('2019-07-26')
def test_stdin(monkeypatch):
with TemporaryDirectory() as output_dir:
output_dir = Path(output_dir)
output_file: Path = output_dir / 'output.py'
monkeypatch.setattr('sys.stdin', (OPEN_API_DATA_PATH / 'api.yaml').open())
return_code: Exit = main(
['--output', str(output_file),]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (EXPECTED_MAIN_PATH / 'stdin' / 'output.py').read_text()
)
@freeze_time('2019-07-26')