mirror of
https://github.com/koxudaxi/datamodel-code-generator.git
synced 2024-03-18 14:54:37 +03:00
Add custom formatters (#1733)
* Support custom formatters for CodeFormatter * Add custom formatters argument * Add graphql to docs/supported-data-types.md * Add test custom formatter for custom-scalar-types.graphql; * Run poetry run scripts/format.sh * Add simple doc
This commit is contained in:
@@ -438,7 +438,10 @@ Template customization:
|
|||||||
Wrap string literal by using black `experimental-
|
Wrap string literal by using black `experimental-
|
||||||
string-processing` option (require black 20.8b0 or
|
string-processing` option (require black 20.8b0 or
|
||||||
later)
|
later)
|
||||||
--additional-imports Custom imports for output (delimited list input)
|
--additional-imports Custom imports for output (delimited list input).
|
||||||
|
For example "datetime.date,datetime.datetime"
|
||||||
|
--custom-formatters List of modules with custom formatter (delimited list input).
|
||||||
|
--custom-formatters-kwargs A file with kwargs for custom formatters.
|
||||||
|
|
||||||
OpenAPI-only options:
|
OpenAPI-only options:
|
||||||
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]
|
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]
|
||||||
|
|||||||
@@ -298,6 +298,8 @@ def generate(
|
|||||||
keep_model_order: bool = False,
|
keep_model_order: bool = False,
|
||||||
custom_file_header: Optional[str] = None,
|
custom_file_header: Optional[str] = None,
|
||||||
custom_file_header_path: Optional[Path] = None,
|
custom_file_header_path: Optional[Path] = None,
|
||||||
|
custom_formatters: Optional[List[str]] = None,
|
||||||
|
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
|
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
|
||||||
if isinstance(input_, str):
|
if isinstance(input_, str):
|
||||||
@@ -452,6 +454,8 @@ def generate(
|
|||||||
capitalise_enum_members=capitalise_enum_members,
|
capitalise_enum_members=capitalise_enum_members,
|
||||||
keep_model_order=keep_model_order,
|
keep_model_order=keep_model_order,
|
||||||
known_third_party=data_model_types.known_third_party,
|
known_third_party=data_model_types.known_third_party,
|
||||||
|
custom_formatters=custom_formatters,
|
||||||
|
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -120,7 +120,9 @@ class Config(BaseModel):
|
|||||||
def get_fields(cls) -> Dict[str, Any]:
|
def get_fields(cls) -> Dict[str, Any]:
|
||||||
return cls.__fields__
|
return cls.__fields__
|
||||||
|
|
||||||
@field_validator('aliases', 'extra_template_data', mode='before')
|
@field_validator(
|
||||||
|
'aliases', 'extra_template_data', 'custom_formatters_kwargs', mode='before'
|
||||||
|
)
|
||||||
def validate_file(cls, value: Any) -> Optional[TextIOBase]:
|
def validate_file(cls, value: Any) -> Optional[TextIOBase]:
|
||||||
if value is None or isinstance(value, TextIOBase):
|
if value is None or isinstance(value, TextIOBase):
|
||||||
return value
|
return value
|
||||||
@@ -204,6 +206,14 @@ class Config(BaseModel):
|
|||||||
values['additional_imports'] = []
|
values['additional_imports'] = []
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@model_validator(mode='before')
|
||||||
|
def validate_custom_formatters(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
if values.get('custom_formatters') is not None:
|
||||||
|
values['custom_formatters'] = values.get('custom_formatters').split(',')
|
||||||
|
else:
|
||||||
|
values['custom_formatters'] = []
|
||||||
|
return values
|
||||||
|
|
||||||
if PYDANTIC_V2:
|
if PYDANTIC_V2:
|
||||||
|
|
||||||
@model_validator(mode='after') # type: ignore
|
@model_validator(mode='after') # type: ignore
|
||||||
@@ -282,6 +292,8 @@ class Config(BaseModel):
|
|||||||
keep_model_order: bool = False
|
keep_model_order: bool = False
|
||||||
custom_file_header: Optional[str] = None
|
custom_file_header: Optional[str] = None
|
||||||
custom_file_header_path: Optional[Path] = None
|
custom_file_header_path: Optional[Path] = None
|
||||||
|
custom_formatters: Optional[List[str]] = None
|
||||||
|
custom_formatters_kwargs: Optional[TextIOBase] = None
|
||||||
|
|
||||||
def merge_args(self, args: Namespace) -> None:
|
def merge_args(self, args: Namespace) -> None:
|
||||||
set_args = {
|
set_args = {
|
||||||
@@ -391,6 +403,28 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
|
|||||||
)
|
)
|
||||||
return Exit.ERROR
|
return Exit.ERROR
|
||||||
|
|
||||||
|
if config.custom_formatters_kwargs is None:
|
||||||
|
custom_formatters_kwargs = None
|
||||||
|
else:
|
||||||
|
with config.custom_formatters_kwargs as data:
|
||||||
|
try:
|
||||||
|
custom_formatters_kwargs = json.load(data)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(
|
||||||
|
f'Unable to load custom_formatters_kwargs mapping: {e}',
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
return Exit.ERROR
|
||||||
|
if not isinstance(custom_formatters_kwargs, dict) or not all(
|
||||||
|
isinstance(k, str) and isinstance(v, str)
|
||||||
|
for k, v in custom_formatters_kwargs.items()
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
'Custom formatters kwargs mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
return Exit.ERROR
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generate(
|
generate(
|
||||||
input_=config.url or config.input or sys.stdin.read(),
|
input_=config.url or config.input or sys.stdin.read(),
|
||||||
@@ -452,6 +486,8 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
|
|||||||
keep_model_order=config.keep_model_order,
|
keep_model_order=config.keep_model_order,
|
||||||
custom_file_header=config.custom_file_header,
|
custom_file_header=config.custom_file_header,
|
||||||
custom_file_header_path=config.custom_file_header_path,
|
custom_file_header_path=config.custom_file_header_path,
|
||||||
|
custom_formatters=config.custom_formatters,
|
||||||
|
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||||
)
|
)
|
||||||
return Exit.OK
|
return Exit.OK
|
||||||
except InvalidClassNameError as e:
|
except InvalidClassNameError as e:
|
||||||
|
|||||||
@@ -387,10 +387,21 @@ template_options.add_argument(
|
|||||||
)
|
)
|
||||||
base_options.add_argument(
|
base_options.add_argument(
|
||||||
'--additional-imports',
|
'--additional-imports',
|
||||||
help='Custom imports for output (delimited list input)',
|
help='Custom imports for output (delimited list input). For example "datetime.date,datetime.datetime"',
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
base_options.add_argument(
|
||||||
|
'--custom-formatters',
|
||||||
|
help='List of modules with custom formatter (delimited list input).',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
template_options.add_argument(
|
||||||
|
'--custom-formatters-kwargs',
|
||||||
|
help='A file with kwargs for custom formatters.',
|
||||||
|
type=FileType('rt'),
|
||||||
|
)
|
||||||
|
|
||||||
# ======================================================================================
|
# ======================================================================================
|
||||||
# Options specific to OpenAPI input schemas
|
# Options specific to OpenAPI input schemas
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from importlib import import_module
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
@@ -112,6 +113,8 @@ class CodeFormatter:
|
|||||||
wrap_string_literal: Optional[bool] = None,
|
wrap_string_literal: Optional[bool] = None,
|
||||||
skip_string_normalization: bool = True,
|
skip_string_normalization: bool = True,
|
||||||
known_third_party: Optional[List[str]] = None,
|
known_third_party: Optional[List[str]] = None,
|
||||||
|
custom_formatters: Optional[List[str]] = None,
|
||||||
|
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not settings_path:
|
if not settings_path:
|
||||||
settings_path = Path().resolve()
|
settings_path = Path().resolve()
|
||||||
@@ -167,12 +170,49 @@ class CodeFormatter:
|
|||||||
settings_path=self.settings_path, **self.isort_config_kwargs
|
settings_path=self.settings_path, **self.isort_config_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.custom_formatters_kwargs = custom_formatters_kwargs or {}
|
||||||
|
self.custom_formatters = self._check_custom_formatters(custom_formatters)
|
||||||
|
|
||||||
|
def _load_custom_formatter(
|
||||||
|
self, custom_formatter_import: str
|
||||||
|
) -> CustomCodeFormatter:
|
||||||
|
import_ = import_module(custom_formatter_import)
|
||||||
|
|
||||||
|
if not hasattr(import_, 'CodeFormatter'):
|
||||||
|
raise NameError(
|
||||||
|
f'Custom formatter module `{import_.__name__}` must contains object with name Formatter'
|
||||||
|
)
|
||||||
|
|
||||||
|
formatter_class = import_.__getattribute__('CodeFormatter')
|
||||||
|
|
||||||
|
if not issubclass(formatter_class, CustomCodeFormatter):
|
||||||
|
raise TypeError(
|
||||||
|
f'The custom module {custom_formatter_import} must inherit from `datamodel-code-generator`'
|
||||||
|
)
|
||||||
|
|
||||||
|
return formatter_class(formatter_kwargs=self.custom_formatters_kwargs)
|
||||||
|
|
||||||
|
def _check_custom_formatters(
|
||||||
|
self, custom_formatters: Optional[List[str]]
|
||||||
|
) -> List[CustomCodeFormatter]:
|
||||||
|
if custom_formatters is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
self._load_custom_formatter(custom_formatter_import)
|
||||||
|
for custom_formatter_import in custom_formatters
|
||||||
|
]
|
||||||
|
|
||||||
def format_code(
|
def format_code(
|
||||||
self,
|
self,
|
||||||
code: str,
|
code: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
code = self.apply_isort(code)
|
code = self.apply_isort(code)
|
||||||
code = self.apply_black(code)
|
code = self.apply_black(code)
|
||||||
|
|
||||||
|
for formatter in self.custom_formatters:
|
||||||
|
code = formatter.apply(code)
|
||||||
|
|
||||||
return code
|
return code
|
||||||
|
|
||||||
def apply_black(self, code: str) -> str:
|
def apply_black(self, code: str) -> str:
|
||||||
@@ -200,3 +240,11 @@ class CodeFormatter:
|
|||||||
|
|
||||||
def apply_isort(self, code: str) -> str:
|
def apply_isort(self, code: str) -> str:
|
||||||
return isort.code(code, config=self.isort_config)
|
return isort.code(code, config=self.isort_config)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomCodeFormatter:
|
||||||
|
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
|
||||||
|
self.formatter_kwargs = formatter_kwargs
|
||||||
|
|
||||||
|
def apply(self, code: str) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -386,6 +386,8 @@ class Parser(ABC):
|
|||||||
keep_model_order: bool = False,
|
keep_model_order: bool = False,
|
||||||
use_one_literal_as_default: bool = False,
|
use_one_literal_as_default: bool = False,
|
||||||
known_third_party: Optional[List[str]] = None,
|
known_third_party: Optional[List[str]] = None,
|
||||||
|
custom_formatters: Optional[List[str]] = None,
|
||||||
|
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.data_type_manager: DataTypeManager = data_type_manager_type(
|
self.data_type_manager: DataTypeManager = data_type_manager_type(
|
||||||
python_version=target_python_version,
|
python_version=target_python_version,
|
||||||
@@ -502,6 +504,8 @@ class Parser(ABC):
|
|||||||
self.keep_model_order = keep_model_order
|
self.keep_model_order = keep_model_order
|
||||||
self.use_one_literal_as_default = use_one_literal_as_default
|
self.use_one_literal_as_default = use_one_literal_as_default
|
||||||
self.known_third_party = known_third_party
|
self.known_third_party = known_third_party
|
||||||
|
self.custom_formatter = custom_formatters
|
||||||
|
self.custom_formatters_kwargs = custom_formatters_kwargs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def iter_source(self) -> Iterator[Source]:
|
def iter_source(self) -> Iterator[Source]:
|
||||||
@@ -1143,6 +1147,8 @@ class Parser(ABC):
|
|||||||
self.wrap_string_literal,
|
self.wrap_string_literal,
|
||||||
skip_string_normalization=not self.use_double_quotes,
|
skip_string_normalization=not self.use_double_quotes,
|
||||||
known_third_party=self.known_third_party,
|
known_third_party=self.known_third_party,
|
||||||
|
custom_formatters=self.custom_formatter,
|
||||||
|
custom_formatters_kwargs=self.custom_formatters_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
code_formatter = None
|
code_formatter = None
|
||||||
|
|||||||
@@ -154,6 +154,8 @@ class GraphQLParser(Parser):
|
|||||||
keep_model_order: bool = False,
|
keep_model_order: bool = False,
|
||||||
use_one_literal_as_default: bool = False,
|
use_one_literal_as_default: bool = False,
|
||||||
known_third_party: Optional[List[str]] = None,
|
known_third_party: Optional[List[str]] = None,
|
||||||
|
custom_formatters: Optional[List[str]] = None,
|
||||||
|
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
source=source,
|
source=source,
|
||||||
@@ -217,6 +219,8 @@ class GraphQLParser(Parser):
|
|||||||
capitalise_enum_members=capitalise_enum_members,
|
capitalise_enum_members=capitalise_enum_members,
|
||||||
keep_model_order=keep_model_order,
|
keep_model_order=keep_model_order,
|
||||||
known_third_party=known_third_party,
|
known_third_party=known_third_party,
|
||||||
|
custom_formatters=custom_formatters,
|
||||||
|
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.data_model_scalar_type = data_model_scalar_type
|
self.data_model_scalar_type = data_model_scalar_type
|
||||||
|
|||||||
@@ -422,6 +422,8 @@ class JsonSchemaParser(Parser):
|
|||||||
capitalise_enum_members: bool = False,
|
capitalise_enum_members: bool = False,
|
||||||
keep_model_order: bool = False,
|
keep_model_order: bool = False,
|
||||||
known_third_party: Optional[List[str]] = None,
|
known_third_party: Optional[List[str]] = None,
|
||||||
|
custom_formatters: Optional[List[str]] = None,
|
||||||
|
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
source=source,
|
source=source,
|
||||||
@@ -485,6 +487,8 @@ class JsonSchemaParser(Parser):
|
|||||||
capitalise_enum_members=capitalise_enum_members,
|
capitalise_enum_members=capitalise_enum_members,
|
||||||
keep_model_order=keep_model_order,
|
keep_model_order=keep_model_order,
|
||||||
known_third_party=known_third_party,
|
known_third_party=known_third_party,
|
||||||
|
custom_formatters=custom_formatters,
|
||||||
|
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict()
|
self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict()
|
||||||
|
|||||||
@@ -218,6 +218,8 @@ class OpenAPIParser(JsonSchemaParser):
|
|||||||
capitalise_enum_members: bool = False,
|
capitalise_enum_members: bool = False,
|
||||||
keep_model_order: bool = False,
|
keep_model_order: bool = False,
|
||||||
known_third_party: Optional[List[str]] = None,
|
known_third_party: Optional[List[str]] = None,
|
||||||
|
custom_formatters: Optional[List[str]] = None,
|
||||||
|
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
source=source,
|
source=source,
|
||||||
@@ -281,6 +283,8 @@ class OpenAPIParser(JsonSchemaParser):
|
|||||||
capitalise_enum_members=capitalise_enum_members,
|
capitalise_enum_members=capitalise_enum_members,
|
||||||
keep_model_order=keep_model_order,
|
keep_model_order=keep_model_order,
|
||||||
known_third_party=known_third_party,
|
known_third_party=known_third_party,
|
||||||
|
custom_formatters=custom_formatters,
|
||||||
|
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||||
)
|
)
|
||||||
self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [
|
self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [
|
||||||
OpenAPIScope.Schemas
|
OpenAPIScope.Schemas
|
||||||
|
|||||||
23
docs/custom-formatters.md
Normal file
23
docs/custom-formatters.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Custom Code Formatters
|
||||||
|
|
||||||
|
New features of the `datamodel-code-generator` it is custom code formatters.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
To use the `--custom-formatters` option, you'll need to pass the module with your formatter. For example
|
||||||
|
|
||||||
|
**your_module.py**
|
||||||
|
```python
|
||||||
|
from datamodel_code_generator.format import CustomCodeFormatter
|
||||||
|
|
||||||
|
class CodeFormatter(CustomCodeFormatter):
|
||||||
|
def apply(self, code: str) -> str:
|
||||||
|
# processed code
|
||||||
|
return ...
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
and run the following command
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ datamodel-codegen --input {your_input_file} --output {your_output_file} --custom-formatters "{path_to_your_module}.your_module"
|
||||||
|
```
|
||||||
@@ -435,7 +435,10 @@ Template customization:
|
|||||||
Wrap string literal by using black `experimental-
|
Wrap string literal by using black `experimental-
|
||||||
string-processing` option (require black 20.8b0 or
|
string-processing` option (require black 20.8b0 or
|
||||||
later)
|
later)
|
||||||
--additional-imports Custom imports for output (delimited list input)
|
--additional-imports Custom imports for output (delimited list input).
|
||||||
|
For example "datetime.date,datetime.datetime"
|
||||||
|
--custom-formatters List of modules with custom formatter (delimited list input).
|
||||||
|
--custom-formatters-kwargs A file with kwargs for custom formatters.
|
||||||
|
|
||||||
OpenAPI-only options:
|
OpenAPI-only options:
|
||||||
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]
|
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ This code generator supports the following input formats:
|
|||||||
- JSON Schema ([JSON Schema Core](http://json-schema.org/draft/2019-09/json-schema-validation.html) /[JSON Schema Validation](http://json-schema.org/draft/2019-09/json-schema-validation.html))
|
- JSON Schema ([JSON Schema Core](http://json-schema.org/draft/2019-09/json-schema-validation.html) /[JSON Schema Validation](http://json-schema.org/draft/2019-09/json-schema-validation.html))
|
||||||
- JSON/YAML Data (it will be converted to JSON Schema)
|
- JSON/YAML Data (it will be converted to JSON Schema)
|
||||||
- Python dictionary (it will be converted to JSON Schema)
|
- Python dictionary (it will be converted to JSON Schema)
|
||||||
|
- GraphQL schema ([GraphQL Schemas and Types](https://graphql.org/learn/schema/))
|
||||||
|
|
||||||
## Implemented data types and features
|
## Implemented data types and features
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ nav:
|
|||||||
- Generate from JSON Data: jsondata.md
|
- Generate from JSON Data: jsondata.md
|
||||||
- Generate from GraphQL Schema: graphql.md
|
- Generate from GraphQL Schema: graphql.md
|
||||||
- Custom template: custom_template.md
|
- Custom template: custom_template.md
|
||||||
|
- Custom formatters: custom-formatters.md
|
||||||
- Using as module: using_as_module.md
|
- Using as module: using_as_module.md
|
||||||
- Formatting: formatting.md
|
- Formatting: formatting.md
|
||||||
- Field Constraints: field-constraints.md
|
- Field Constraints: field-constraints.md
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
# generated by datamodel-codegen:
|
||||||
|
# filename: custom-scalar-types.graphql
|
||||||
|
# timestamp: 2019-07-26T00:00:00+00:00
|
||||||
|
|
||||||
|
# a comment
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional, TypeAlias
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
Boolean: TypeAlias = bool
|
||||||
|
"""
|
||||||
|
The `Boolean` scalar type represents `true` or `false`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
ID: TypeAlias = str
|
||||||
|
"""
|
||||||
|
The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
Long: TypeAlias = str
|
||||||
|
|
||||||
|
|
||||||
|
String: TypeAlias = str
|
||||||
|
"""
|
||||||
|
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class A(BaseModel):
|
||||||
|
duration: Long
|
||||||
|
id: ID
|
||||||
|
typename__: Optional[Literal['A']] = Field('A', alias='__typename')
|
||||||
7
tests/data/python/custom_formatters/add_comment.py
Normal file
7
tests/data/python/custom_formatters/add_comment.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from datamodel_code_generator.format import CustomCodeFormatter
|
||||||
|
|
||||||
|
|
||||||
|
class CodeFormatter(CustomCodeFormatter):
|
||||||
|
"""Simple correct formatter. Adding a comment to top of code."""
|
||||||
|
def apply(self, code: str) -> str:
|
||||||
|
return f'# a comment\n{code}'
|
||||||
24
tests/data/python/custom_formatters/add_license.py
Normal file
24
tests/data/python/custom_formatters/add_license.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from datamodel_code_generator.format import CustomCodeFormatter
|
||||||
|
|
||||||
|
|
||||||
|
class CodeFormatter(CustomCodeFormatter):
|
||||||
|
"""Add a license to file from license file path."""
|
||||||
|
|
||||||
|
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
|
||||||
|
super().__init__(formatter_kwargs)
|
||||||
|
|
||||||
|
if 'license_file' not in formatter_kwargs:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
license_file_path = Path(formatter_kwargs['license_file']).resolve()
|
||||||
|
|
||||||
|
with license_file_path.open("r") as f:
|
||||||
|
license_file = f.read()
|
||||||
|
|
||||||
|
self.license_header = '\n'.join([f'# {line}' for line in license_file.split('\n')])
|
||||||
|
|
||||||
|
def apply(self, code: str) -> str:
|
||||||
|
return f'{self.license_header}\n{code}'
|
||||||
3
tests/data/python/custom_formatters/license_example.txt
Normal file
3
tests/data/python/custom_formatters/license_example.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Blah-blah
|
||||||
3
tests/data/python/custom_formatters/not_subclass.py
Normal file
3
tests/data/python/custom_formatters/not_subclass.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
class CodeFormatter:
|
||||||
|
"""Invalid formatter: is not subclass of `datamodel_code_generator.format.CustomCodeFormatter`."""
|
||||||
|
pass
|
||||||
7
tests/data/python/custom_formatters/wrong.py
Normal file
7
tests/data/python/custom_formatters/wrong.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from datamodel_code_generator.format import CustomCodeFormatter
|
||||||
|
|
||||||
|
|
||||||
|
class WrongFormatterName(CustomCodeFormatter):
|
||||||
|
"""Invalid formatter: correct name is CodeFormatter."""
|
||||||
|
def apply(self, code: str) -> str:
|
||||||
|
return f'# a comment\n{code}'
|
||||||
@@ -1,9 +1,20 @@
|
|||||||
import sys
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from datamodel_code_generator.format import CodeFormatter, PythonVersion
|
from datamodel_code_generator.format import CodeFormatter, PythonVersion
|
||||||
|
|
||||||
|
EXAMPLE_LICENSE_FILE = str(
|
||||||
|
Path(__file__).parent / 'data/python/custom_formatters/license_example.txt'
|
||||||
|
)
|
||||||
|
|
||||||
|
UN_EXIST_FORMATTER = 'tests.data.python.custom_formatters.un_exist'
|
||||||
|
WRONG_FORMATTER = 'tests.data.python.custom_formatters.wrong'
|
||||||
|
NOT_SUBCLASS_FORMATTER = 'tests.data.python.custom_formatters.not_subclass'
|
||||||
|
ADD_COMMENT_FORMATTER = 'tests.data.python.custom_formatters.add_comment'
|
||||||
|
ADD_LICENSE_FORMATTER = 'tests.data.python.custom_formatters.add_license'
|
||||||
|
|
||||||
|
|
||||||
def test_python_version():
|
def test_python_version():
|
||||||
"""Ensure that the python version used for the tests is properly listed"""
|
"""Ensure that the python version used for the tests is properly listed"""
|
||||||
@@ -28,3 +39,84 @@ def test_format_code_with_skip_string_normalization(
|
|||||||
formatted_code = formatter.format_code("a = 'b'")
|
formatted_code = formatter.format_code("a = 'b'")
|
||||||
|
|
||||||
assert formatted_code == expected_output + '\n'
|
assert formatted_code == expected_output + '\n'
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_code_un_exist_custom_formatter():
|
||||||
|
with pytest.raises(ModuleNotFoundError):
|
||||||
|
_ = CodeFormatter(
|
||||||
|
PythonVersion.PY_37,
|
||||||
|
custom_formatters=[UN_EXIST_FORMATTER],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_code_invalid_formatter_name():
|
||||||
|
with pytest.raises(NameError):
|
||||||
|
_ = CodeFormatter(
|
||||||
|
PythonVersion.PY_37,
|
||||||
|
custom_formatters=[WRONG_FORMATTER],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_code_is_not_subclass():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
_ = CodeFormatter(
|
||||||
|
PythonVersion.PY_37,
|
||||||
|
custom_formatters=[NOT_SUBCLASS_FORMATTER],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_code_with_custom_formatter_without_kwargs():
|
||||||
|
formatter = CodeFormatter(
|
||||||
|
PythonVersion.PY_37,
|
||||||
|
custom_formatters=[ADD_COMMENT_FORMATTER],
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted_code = formatter.format_code('x = 1\ny = 2')
|
||||||
|
|
||||||
|
assert formatted_code == '# a comment\nx = 1\ny = 2' + '\n'
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_code_with_custom_formatter_with_kwargs():
|
||||||
|
formatter = CodeFormatter(
|
||||||
|
PythonVersion.PY_37,
|
||||||
|
custom_formatters=[ADD_LICENSE_FORMATTER],
|
||||||
|
custom_formatters_kwargs={'license_file': EXAMPLE_LICENSE_FILE},
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted_code = formatter.format_code('x = 1\ny = 2')
|
||||||
|
|
||||||
|
assert (
|
||||||
|
formatted_code
|
||||||
|
== """# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2023 Blah-blah
|
||||||
|
#
|
||||||
|
x = 1
|
||||||
|
y = 2
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_code_with_two_custom_formatters():
|
||||||
|
formatter = CodeFormatter(
|
||||||
|
PythonVersion.PY_37,
|
||||||
|
custom_formatters=[
|
||||||
|
ADD_COMMENT_FORMATTER,
|
||||||
|
ADD_LICENSE_FORMATTER,
|
||||||
|
],
|
||||||
|
custom_formatters_kwargs={'license_file': EXAMPLE_LICENSE_FILE},
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted_code = formatter.format_code('x = 1\ny = 2')
|
||||||
|
|
||||||
|
assert (
|
||||||
|
formatted_code
|
||||||
|
== """# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2023 Blah-blah
|
||||||
|
#
|
||||||
|
# a comment
|
||||||
|
x = 1
|
||||||
|
y = 2
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|||||||
@@ -6335,3 +6335,28 @@ def test_main_graphql_additional_imports_isort_5():
|
|||||||
/ 'output_isort5.py'
|
/ 'output_isort5.py'
|
||||||
).read_text()
|
).read_text()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time('2019-07-26')
|
||||||
|
def test_main_graphql_custom_formatters():
|
||||||
|
with TemporaryDirectory() as output_dir:
|
||||||
|
output_file: Path = Path(output_dir) / 'output.py'
|
||||||
|
return_code: Exit = main(
|
||||||
|
[
|
||||||
|
'--input',
|
||||||
|
str(GRAPHQL_DATA_PATH / 'custom-scalar-types.graphql'),
|
||||||
|
'--output',
|
||||||
|
str(output_file),
|
||||||
|
'--input-file-type',
|
||||||
|
'graphql',
|
||||||
|
'--custom-formatters',
|
||||||
|
'tests.data.python.custom_formatters.add_comment',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert return_code == Exit.OK
|
||||||
|
assert (
|
||||||
|
output_file.read_text()
|
||||||
|
== (
|
||||||
|
EXPECTED_MAIN_PATH / 'main_graphql_custom_formatters' / 'output.py'
|
||||||
|
).read_text()
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user