Add custom formatters (#1733)

* Support custom formatters for CodeFormatter

* Add custom formatters argument

* Add graphql to docs/supported-data-types.md

* Add test

custom formatter for custom-scalar-types.graphql;

* Run poetry run scripts/format.sh

* Add simple doc
This commit is contained in:
Denis Artyushin
2023-11-24 18:59:32 +03:00
committed by GitHub
parent 3e0f0aaea5
commit a36ce94f2d
21 changed files with 351 additions and 5 deletions

View File

@@ -438,7 +438,10 @@ Template customization:
Wrap string literal by using black `experimental-
string-processing` option (require black 20.8b0 or
later)
--additional-imports Custom imports for output (delimited list input)
--additional-imports Custom imports for output (delimited list input).
For example "datetime.date,datetime.datetime"
--custom-formatters List of modules with custom formatter (delimited list input).
--custom-formatters-kwargs A file with kwargs for custom formatters.
OpenAPI-only options:
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]

View File

@@ -298,6 +298,8 @@ def generate(
keep_model_order: bool = False,
custom_file_header: Optional[str] = None,
custom_file_header_path: Optional[Path] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
if isinstance(input_, str):
@@ -452,6 +454,8 @@ def generate(
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=data_model_types.known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
**kwargs,
)

View File

@@ -120,7 +120,9 @@ class Config(BaseModel):
def get_fields(cls) -> Dict[str, Any]:
return cls.__fields__
@field_validator('aliases', 'extra_template_data', mode='before')
@field_validator(
'aliases', 'extra_template_data', 'custom_formatters_kwargs', mode='before'
)
def validate_file(cls, value: Any) -> Optional[TextIOBase]:
if value is None or isinstance(value, TextIOBase):
return value
@@ -204,6 +206,14 @@ class Config(BaseModel):
values['additional_imports'] = []
return values
@model_validator(mode='before')
def validate_custom_formatters(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values.get('custom_formatters') is not None:
values['custom_formatters'] = values.get('custom_formatters').split(',')
else:
values['custom_formatters'] = []
return values
if PYDANTIC_V2:
@model_validator(mode='after') # type: ignore
@@ -282,6 +292,8 @@ class Config(BaseModel):
keep_model_order: bool = False
custom_file_header: Optional[str] = None
custom_file_header_path: Optional[Path] = None
custom_formatters: Optional[List[str]] = None
custom_formatters_kwargs: Optional[TextIOBase] = None
def merge_args(self, args: Namespace) -> None:
set_args = {
@@ -391,6 +403,28 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
)
return Exit.ERROR
if config.custom_formatters_kwargs is None:
custom_formatters_kwargs = None
else:
with config.custom_formatters_kwargs as data:
try:
custom_formatters_kwargs = json.load(data)
except json.JSONDecodeError as e:
print(
f'Unable to load custom_formatters_kwargs mapping: {e}',
file=sys.stderr,
)
return Exit.ERROR
if not isinstance(custom_formatters_kwargs, dict) or not all(
isinstance(k, str) and isinstance(v, str)
for k, v in custom_formatters_kwargs.items()
):
print(
'Custom formatters kwargs mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
file=sys.stderr,
)
return Exit.ERROR
try:
generate(
input_=config.url or config.input or sys.stdin.read(),
@@ -452,6 +486,8 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
keep_model_order=config.keep_model_order,
custom_file_header=config.custom_file_header,
custom_file_header_path=config.custom_file_header_path,
custom_formatters=config.custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
)
return Exit.OK
except InvalidClassNameError as e:

View File

@@ -387,10 +387,21 @@ template_options.add_argument(
)
base_options.add_argument(
'--additional-imports',
help='Custom imports for output (delimited list input)',
help='Custom imports for output (delimited list input). For example "datetime.date,datetime.datetime"',
type=str,
default=None,
)
base_options.add_argument(
'--custom-formatters',
help='List of modules with custom formatter (delimited list input).',
type=str,
default=None,
)
template_options.add_argument(
'--custom-formatters-kwargs',
help='A file with kwargs for custom formatters.',
type=FileType('rt'),
)
# ======================================================================================
# Options specific to OpenAPI input schemas

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from enum import Enum
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from warnings import warn
@@ -112,6 +113,8 @@ class CodeFormatter:
wrap_string_literal: Optional[bool] = None,
skip_string_normalization: bool = True,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
if not settings_path:
settings_path = Path().resolve()
@@ -167,12 +170,49 @@ class CodeFormatter:
settings_path=self.settings_path, **self.isort_config_kwargs
)
self.custom_formatters_kwargs = custom_formatters_kwargs or {}
self.custom_formatters = self._check_custom_formatters(custom_formatters)
def _load_custom_formatter(
self, custom_formatter_import: str
) -> CustomCodeFormatter:
import_ = import_module(custom_formatter_import)
if not hasattr(import_, 'CodeFormatter'):
raise NameError(
f'Custom formatter module `{import_.__name__}` must contains object with name Formatter'
)
formatter_class = import_.__getattribute__('CodeFormatter')
if not issubclass(formatter_class, CustomCodeFormatter):
raise TypeError(
f'The custom module {custom_formatter_import} must inherit from `datamodel-code-generator`'
)
return formatter_class(formatter_kwargs=self.custom_formatters_kwargs)
def _check_custom_formatters(
self, custom_formatters: Optional[List[str]]
) -> List[CustomCodeFormatter]:
if custom_formatters is None:
return []
return [
self._load_custom_formatter(custom_formatter_import)
for custom_formatter_import in custom_formatters
]
def format_code(
self,
code: str,
) -> str:
code = self.apply_isort(code)
code = self.apply_black(code)
for formatter in self.custom_formatters:
code = formatter.apply(code)
return code
def apply_black(self, code: str) -> str:
@@ -200,3 +240,11 @@ class CodeFormatter:
def apply_isort(self, code: str) -> str:
return isort.code(code, config=self.isort_config)
class CustomCodeFormatter:
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
self.formatter_kwargs = formatter_kwargs
def apply(self, code: str) -> str:
raise NotImplementedError

View File

@@ -386,6 +386,8 @@ class Parser(ABC):
keep_model_order: bool = False,
use_one_literal_as_default: bool = False,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
self.data_type_manager: DataTypeManager = data_type_manager_type(
python_version=target_python_version,
@@ -502,6 +504,8 @@ class Parser(ABC):
self.keep_model_order = keep_model_order
self.use_one_literal_as_default = use_one_literal_as_default
self.known_third_party = known_third_party
self.custom_formatter = custom_formatters
self.custom_formatters_kwargs = custom_formatters_kwargs
@property
def iter_source(self) -> Iterator[Source]:
@@ -1143,6 +1147,8 @@ class Parser(ABC):
self.wrap_string_literal,
skip_string_normalization=not self.use_double_quotes,
known_third_party=self.known_third_party,
custom_formatters=self.custom_formatter,
custom_formatters_kwargs=self.custom_formatters_kwargs,
)
else:
code_formatter = None

View File

@@ -154,6 +154,8 @@ class GraphQLParser(Parser):
keep_model_order: bool = False,
use_one_literal_as_default: bool = False,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
source=source,
@@ -217,6 +219,8 @@ class GraphQLParser(Parser):
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
)
self.data_model_scalar_type = data_model_scalar_type

View File

@@ -422,6 +422,8 @@ class JsonSchemaParser(Parser):
capitalise_enum_members: bool = False,
keep_model_order: bool = False,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
source=source,
@@ -485,6 +487,8 @@ class JsonSchemaParser(Parser):
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
)
self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict()

View File

@@ -218,6 +218,8 @@ class OpenAPIParser(JsonSchemaParser):
capitalise_enum_members: bool = False,
keep_model_order: bool = False,
known_third_party: Optional[List[str]] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
source=source,
@@ -281,6 +283,8 @@ class OpenAPIParser(JsonSchemaParser):
capitalise_enum_members=capitalise_enum_members,
keep_model_order=keep_model_order,
known_third_party=known_third_party,
custom_formatters=custom_formatters,
custom_formatters_kwargs=custom_formatters_kwargs,
)
self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [
OpenAPIScope.Schemas

23
docs/custom-formatters.md Normal file
View File

@@ -0,0 +1,23 @@
# Custom Code Formatters
New features of the `datamodel-code-generator` it is custom code formatters.
## Usage
To use the `--custom-formatters` option, you'll need to pass the module with your formatter. For example
**your_module.py**
```python
from datamodel_code_generator.format import CustomCodeFormatter
class CodeFormatter(CustomCodeFormatter):
def apply(self, code: str) -> str:
# processed code
return ...
```
and run the following command
```sh
$ datamodel-codegen --input {your_input_file} --output {your_output_file} --custom-formatters "{path_to_your_module}.your_module"
```

View File

@@ -435,8 +435,11 @@ Template customization:
Wrap string literal by using black `experimental-
string-processing` option (require black 20.8b0 or
later)
--additional-imports Custom imports for output (delimited list input)
--additional-imports Custom imports for output (delimited list input).
For example "datetime.date,datetime.datetime"
--custom-formatters List of modules with custom formatter (delimited list input).
--custom-formatters-kwargs A file with kwargs for custom formatters.
OpenAPI-only options:
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]
Scopes of OpenAPI model generation (default: schemas)

View File

@@ -6,6 +6,7 @@ This code generator supports the following input formats:
- JSON Schema ([JSON Schema Core](http://json-schema.org/draft/2019-09/json-schema-validation.html) /[JSON Schema Validation](http://json-schema.org/draft/2019-09/json-schema-validation.html))
- JSON/YAML Data (it will be converted to JSON Schema)
- Python dictionary (it will be converted to JSON Schema)
- GraphQL schema ([GraphQL Schemas and Types](https://graphql.org/learn/schema/))
## Implemented data types and features

View File

@@ -26,6 +26,7 @@ nav:
- Generate from JSON Data: jsondata.md
- Generate from GraphQL Schema: graphql.md
- Custom template: custom_template.md
- Custom formatters: custom-formatters.md
- Using as module: using_as_module.md
- Formatting: formatting.md
- Field Constraints: field-constraints.md

View File

@@ -0,0 +1,37 @@
# generated by datamodel-codegen:
# filename: custom-scalar-types.graphql
# timestamp: 2019-07-26T00:00:00+00:00
# a comment
from __future__ import annotations
from typing import Optional, TypeAlias
from pydantic import BaseModel, Field
from typing_extensions import Literal
Boolean: TypeAlias = bool
"""
The `Boolean` scalar type represents `true` or `false`.
"""
ID: TypeAlias = str
"""
The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID.
"""
Long: TypeAlias = str
String: TypeAlias = str
"""
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
"""
class A(BaseModel):
duration: Long
id: ID
typename__: Optional[Literal['A']] = Field('A', alias='__typename')

View File

@@ -0,0 +1,7 @@
from datamodel_code_generator.format import CustomCodeFormatter
class CodeFormatter(CustomCodeFormatter):
"""Simple correct formatter. Adding a comment to top of code."""
def apply(self, code: str) -> str:
return f'# a comment\n{code}'

View File

@@ -0,0 +1,24 @@
from typing import Any, Dict
from pathlib import Path
from datamodel_code_generator.format import CustomCodeFormatter
class CodeFormatter(CustomCodeFormatter):
"""Add a license to file from license file path."""
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
super().__init__(formatter_kwargs)
if 'license_file' not in formatter_kwargs:
raise ValueError()
license_file_path = Path(formatter_kwargs['license_file']).resolve()
with license_file_path.open("r") as f:
license_file = f.read()
self.license_header = '\n'.join([f'# {line}' for line in license_file.split('\n')])
def apply(self, code: str) -> str:
return f'{self.license_header}\n{code}'

View File

@@ -0,0 +1,3 @@
MIT License
Copyright (c) 2023 Blah-blah

View File

@@ -0,0 +1,3 @@
class CodeFormatter:
"""Invalid formatter: is not subclass of `datamodel_code_generator.format.CustomCodeFormatter`."""
pass

View File

@@ -0,0 +1,7 @@
from datamodel_code_generator.format import CustomCodeFormatter
class WrongFormatterName(CustomCodeFormatter):
"""Invalid formatter: correct name is CodeFormatter."""
def apply(self, code: str) -> str:
return f'# a comment\n{code}'

View File

@@ -1,9 +1,20 @@
import sys
from pathlib import Path
import pytest
from datamodel_code_generator.format import CodeFormatter, PythonVersion
EXAMPLE_LICENSE_FILE = str(
Path(__file__).parent / 'data/python/custom_formatters/license_example.txt'
)
UN_EXIST_FORMATTER = 'tests.data.python.custom_formatters.un_exist'
WRONG_FORMATTER = 'tests.data.python.custom_formatters.wrong'
NOT_SUBCLASS_FORMATTER = 'tests.data.python.custom_formatters.not_subclass'
ADD_COMMENT_FORMATTER = 'tests.data.python.custom_formatters.add_comment'
ADD_LICENSE_FORMATTER = 'tests.data.python.custom_formatters.add_license'
def test_python_version():
"""Ensure that the python version used for the tests is properly listed"""
@@ -28,3 +39,84 @@ def test_format_code_with_skip_string_normalization(
formatted_code = formatter.format_code("a = 'b'")
assert formatted_code == expected_output + '\n'
def test_format_code_un_exist_custom_formatter():
with pytest.raises(ModuleNotFoundError):
_ = CodeFormatter(
PythonVersion.PY_37,
custom_formatters=[UN_EXIST_FORMATTER],
)
def test_format_code_invalid_formatter_name():
with pytest.raises(NameError):
_ = CodeFormatter(
PythonVersion.PY_37,
custom_formatters=[WRONG_FORMATTER],
)
def test_format_code_is_not_subclass():
with pytest.raises(TypeError):
_ = CodeFormatter(
PythonVersion.PY_37,
custom_formatters=[NOT_SUBCLASS_FORMATTER],
)
def test_format_code_with_custom_formatter_without_kwargs():
formatter = CodeFormatter(
PythonVersion.PY_37,
custom_formatters=[ADD_COMMENT_FORMATTER],
)
formatted_code = formatter.format_code('x = 1\ny = 2')
assert formatted_code == '# a comment\nx = 1\ny = 2' + '\n'
def test_format_code_with_custom_formatter_with_kwargs():
formatter = CodeFormatter(
PythonVersion.PY_37,
custom_formatters=[ADD_LICENSE_FORMATTER],
custom_formatters_kwargs={'license_file': EXAMPLE_LICENSE_FILE},
)
formatted_code = formatter.format_code('x = 1\ny = 2')
assert (
formatted_code
== """# MIT License
#
# Copyright (c) 2023 Blah-blah
#
x = 1
y = 2
"""
)
def test_format_code_with_two_custom_formatters():
formatter = CodeFormatter(
PythonVersion.PY_37,
custom_formatters=[
ADD_COMMENT_FORMATTER,
ADD_LICENSE_FORMATTER,
],
custom_formatters_kwargs={'license_file': EXAMPLE_LICENSE_FILE},
)
formatted_code = formatter.format_code('x = 1\ny = 2')
assert (
formatted_code
== """# MIT License
#
# Copyright (c) 2023 Blah-blah
#
# a comment
x = 1
y = 2
"""
)

View File

@@ -6335,3 +6335,28 @@ def test_main_graphql_additional_imports_isort_5():
/ 'output_isort5.py'
).read_text()
)
@freeze_time('2019-07-26')
def test_main_graphql_custom_formatters():
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(GRAPHQL_DATA_PATH / 'custom-scalar-types.graphql'),
'--output',
str(output_file),
'--input-file-type',
'graphql',
'--custom-formatters',
'tests.data.python.custom_formatters.add_comment',
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (
EXPECTED_MAIN_PATH / 'main_graphql_custom_formatters' / 'output.py'
).read_text()
)