mirror of
https://github.com/koxudaxi/datamodel-code-generator.git
synced 2024-03-18 14:54:37 +03:00
Add custom formatters (#1733)
* Support custom formatters for CodeFormatter * Add custom formatters argument * Add graphql to docs/supported-data-types.md * Add test custom formatter for custom-scalar-types.graphql; * Run poetry run scripts/format.sh * Add simple doc
This commit is contained in:
@@ -438,7 +438,10 @@ Template customization:
|
||||
Wrap string literal by using black `experimental-
|
||||
string-processing` option (require black 20.8b0 or
|
||||
later)
|
||||
--additional-imports Custom imports for output (delimited list input)
|
||||
--additional-imports Custom imports for output (delimited list input).
|
||||
For example "datetime.date,datetime.datetime"
|
||||
--custom-formatters List of modules with custom formatter (delimited list input).
|
||||
--custom-formatters-kwargs A file with kwargs for custom formatters.
|
||||
|
||||
OpenAPI-only options:
|
||||
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]
|
||||
|
||||
@@ -298,6 +298,8 @@ def generate(
|
||||
keep_model_order: bool = False,
|
||||
custom_file_header: Optional[str] = None,
|
||||
custom_file_header_path: Optional[Path] = None,
|
||||
custom_formatters: Optional[List[str]] = None,
|
||||
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
|
||||
if isinstance(input_, str):
|
||||
@@ -452,6 +454,8 @@ def generate(
|
||||
capitalise_enum_members=capitalise_enum_members,
|
||||
keep_model_order=keep_model_order,
|
||||
known_third_party=data_model_types.known_third_party,
|
||||
custom_formatters=custom_formatters,
|
||||
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -120,7 +120,9 @@ class Config(BaseModel):
|
||||
def get_fields(cls) -> Dict[str, Any]:
|
||||
return cls.__fields__
|
||||
|
||||
@field_validator('aliases', 'extra_template_data', mode='before')
|
||||
@field_validator(
|
||||
'aliases', 'extra_template_data', 'custom_formatters_kwargs', mode='before'
|
||||
)
|
||||
def validate_file(cls, value: Any) -> Optional[TextIOBase]:
|
||||
if value is None or isinstance(value, TextIOBase):
|
||||
return value
|
||||
@@ -204,6 +206,14 @@ class Config(BaseModel):
|
||||
values['additional_imports'] = []
|
||||
return values
|
||||
|
||||
@model_validator(mode='before')
|
||||
def validate_custom_formatters(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if values.get('custom_formatters') is not None:
|
||||
values['custom_formatters'] = values.get('custom_formatters').split(',')
|
||||
else:
|
||||
values['custom_formatters'] = []
|
||||
return values
|
||||
|
||||
if PYDANTIC_V2:
|
||||
|
||||
@model_validator(mode='after') # type: ignore
|
||||
@@ -282,6 +292,8 @@ class Config(BaseModel):
|
||||
keep_model_order: bool = False
|
||||
custom_file_header: Optional[str] = None
|
||||
custom_file_header_path: Optional[Path] = None
|
||||
custom_formatters: Optional[List[str]] = None
|
||||
custom_formatters_kwargs: Optional[TextIOBase] = None
|
||||
|
||||
def merge_args(self, args: Namespace) -> None:
|
||||
set_args = {
|
||||
@@ -391,6 +403,28 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
|
||||
)
|
||||
return Exit.ERROR
|
||||
|
||||
if config.custom_formatters_kwargs is None:
|
||||
custom_formatters_kwargs = None
|
||||
else:
|
||||
with config.custom_formatters_kwargs as data:
|
||||
try:
|
||||
custom_formatters_kwargs = json.load(data)
|
||||
except json.JSONDecodeError as e:
|
||||
print(
|
||||
f'Unable to load custom_formatters_kwargs mapping: {e}',
|
||||
file=sys.stderr,
|
||||
)
|
||||
return Exit.ERROR
|
||||
if not isinstance(custom_formatters_kwargs, dict) or not all(
|
||||
isinstance(k, str) and isinstance(v, str)
|
||||
for k, v in custom_formatters_kwargs.items()
|
||||
):
|
||||
print(
|
||||
'Custom formatters kwargs mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
|
||||
file=sys.stderr,
|
||||
)
|
||||
return Exit.ERROR
|
||||
|
||||
try:
|
||||
generate(
|
||||
input_=config.url or config.input or sys.stdin.read(),
|
||||
@@ -452,6 +486,8 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
|
||||
keep_model_order=config.keep_model_order,
|
||||
custom_file_header=config.custom_file_header,
|
||||
custom_file_header_path=config.custom_file_header_path,
|
||||
custom_formatters=config.custom_formatters,
|
||||
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||
)
|
||||
return Exit.OK
|
||||
except InvalidClassNameError as e:
|
||||
|
||||
@@ -387,10 +387,21 @@ template_options.add_argument(
|
||||
)
|
||||
base_options.add_argument(
|
||||
'--additional-imports',
|
||||
help='Custom imports for output (delimited list input)',
|
||||
help='Custom imports for output (delimited list input). For example "datetime.date,datetime.datetime"',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
base_options.add_argument(
|
||||
'--custom-formatters',
|
||||
help='List of modules with custom formatter (delimited list input).',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
template_options.add_argument(
|
||||
'--custom-formatters-kwargs',
|
||||
help='A file with kwargs for custom formatters.',
|
||||
type=FileType('rt'),
|
||||
)
|
||||
|
||||
# ======================================================================================
|
||||
# Options specific to OpenAPI input schemas
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
||||
from warnings import warn
|
||||
@@ -112,6 +113,8 @@ class CodeFormatter:
|
||||
wrap_string_literal: Optional[bool] = None,
|
||||
skip_string_normalization: bool = True,
|
||||
known_third_party: Optional[List[str]] = None,
|
||||
custom_formatters: Optional[List[str]] = None,
|
||||
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if not settings_path:
|
||||
settings_path = Path().resolve()
|
||||
@@ -167,12 +170,49 @@ class CodeFormatter:
|
||||
settings_path=self.settings_path, **self.isort_config_kwargs
|
||||
)
|
||||
|
||||
self.custom_formatters_kwargs = custom_formatters_kwargs or {}
|
||||
self.custom_formatters = self._check_custom_formatters(custom_formatters)
|
||||
|
||||
def _load_custom_formatter(
|
||||
self, custom_formatter_import: str
|
||||
) -> CustomCodeFormatter:
|
||||
import_ = import_module(custom_formatter_import)
|
||||
|
||||
if not hasattr(import_, 'CodeFormatter'):
|
||||
raise NameError(
|
||||
f'Custom formatter module `{import_.__name__}` must contains object with name Formatter'
|
||||
)
|
||||
|
||||
formatter_class = import_.__getattribute__('CodeFormatter')
|
||||
|
||||
if not issubclass(formatter_class, CustomCodeFormatter):
|
||||
raise TypeError(
|
||||
f'The custom module {custom_formatter_import} must inherit from `datamodel-code-generator`'
|
||||
)
|
||||
|
||||
return formatter_class(formatter_kwargs=self.custom_formatters_kwargs)
|
||||
|
||||
def _check_custom_formatters(
|
||||
self, custom_formatters: Optional[List[str]]
|
||||
) -> List[CustomCodeFormatter]:
|
||||
if custom_formatters is None:
|
||||
return []
|
||||
|
||||
return [
|
||||
self._load_custom_formatter(custom_formatter_import)
|
||||
for custom_formatter_import in custom_formatters
|
||||
]
|
||||
|
||||
def format_code(
|
||||
self,
|
||||
code: str,
|
||||
) -> str:
|
||||
code = self.apply_isort(code)
|
||||
code = self.apply_black(code)
|
||||
|
||||
for formatter in self.custom_formatters:
|
||||
code = formatter.apply(code)
|
||||
|
||||
return code
|
||||
|
||||
def apply_black(self, code: str) -> str:
|
||||
@@ -200,3 +240,11 @@ class CodeFormatter:
|
||||
|
||||
def apply_isort(self, code: str) -> str:
|
||||
return isort.code(code, config=self.isort_config)
|
||||
|
||||
|
||||
class CustomCodeFormatter:
|
||||
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
|
||||
self.formatter_kwargs = formatter_kwargs
|
||||
|
||||
def apply(self, code: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -386,6 +386,8 @@ class Parser(ABC):
|
||||
keep_model_order: bool = False,
|
||||
use_one_literal_as_default: bool = False,
|
||||
known_third_party: Optional[List[str]] = None,
|
||||
custom_formatters: Optional[List[str]] = None,
|
||||
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.data_type_manager: DataTypeManager = data_type_manager_type(
|
||||
python_version=target_python_version,
|
||||
@@ -502,6 +504,8 @@ class Parser(ABC):
|
||||
self.keep_model_order = keep_model_order
|
||||
self.use_one_literal_as_default = use_one_literal_as_default
|
||||
self.known_third_party = known_third_party
|
||||
self.custom_formatter = custom_formatters
|
||||
self.custom_formatters_kwargs = custom_formatters_kwargs
|
||||
|
||||
@property
|
||||
def iter_source(self) -> Iterator[Source]:
|
||||
@@ -1143,6 +1147,8 @@ class Parser(ABC):
|
||||
self.wrap_string_literal,
|
||||
skip_string_normalization=not self.use_double_quotes,
|
||||
known_third_party=self.known_third_party,
|
||||
custom_formatters=self.custom_formatter,
|
||||
custom_formatters_kwargs=self.custom_formatters_kwargs,
|
||||
)
|
||||
else:
|
||||
code_formatter = None
|
||||
|
||||
@@ -154,6 +154,8 @@ class GraphQLParser(Parser):
|
||||
keep_model_order: bool = False,
|
||||
use_one_literal_as_default: bool = False,
|
||||
known_third_party: Optional[List[str]] = None,
|
||||
custom_formatters: Optional[List[str]] = None,
|
||||
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
source=source,
|
||||
@@ -217,6 +219,8 @@ class GraphQLParser(Parser):
|
||||
capitalise_enum_members=capitalise_enum_members,
|
||||
keep_model_order=keep_model_order,
|
||||
known_third_party=known_third_party,
|
||||
custom_formatters=custom_formatters,
|
||||
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||
)
|
||||
|
||||
self.data_model_scalar_type = data_model_scalar_type
|
||||
|
||||
@@ -422,6 +422,8 @@ class JsonSchemaParser(Parser):
|
||||
capitalise_enum_members: bool = False,
|
||||
keep_model_order: bool = False,
|
||||
known_third_party: Optional[List[str]] = None,
|
||||
custom_formatters: Optional[List[str]] = None,
|
||||
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
source=source,
|
||||
@@ -485,6 +487,8 @@ class JsonSchemaParser(Parser):
|
||||
capitalise_enum_members=capitalise_enum_members,
|
||||
keep_model_order=keep_model_order,
|
||||
known_third_party=known_third_party,
|
||||
custom_formatters=custom_formatters,
|
||||
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||
)
|
||||
|
||||
self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict()
|
||||
|
||||
@@ -218,6 +218,8 @@ class OpenAPIParser(JsonSchemaParser):
|
||||
capitalise_enum_members: bool = False,
|
||||
keep_model_order: bool = False,
|
||||
known_third_party: Optional[List[str]] = None,
|
||||
custom_formatters: Optional[List[str]] = None,
|
||||
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
source=source,
|
||||
@@ -281,6 +283,8 @@ class OpenAPIParser(JsonSchemaParser):
|
||||
capitalise_enum_members=capitalise_enum_members,
|
||||
keep_model_order=keep_model_order,
|
||||
known_third_party=known_third_party,
|
||||
custom_formatters=custom_formatters,
|
||||
custom_formatters_kwargs=custom_formatters_kwargs,
|
||||
)
|
||||
self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [
|
||||
OpenAPIScope.Schemas
|
||||
|
||||
23
docs/custom-formatters.md
Normal file
23
docs/custom-formatters.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Custom Code Formatters
|
||||
|
||||
New features of the `datamodel-code-generator` it is custom code formatters.
|
||||
|
||||
## Usage
|
||||
To use the `--custom-formatters` option, you'll need to pass the module with your formatter. For example
|
||||
|
||||
**your_module.py**
|
||||
```python
|
||||
from datamodel_code_generator.format import CustomCodeFormatter
|
||||
|
||||
class CodeFormatter(CustomCodeFormatter):
|
||||
def apply(self, code: str) -> str:
|
||||
# processed code
|
||||
return ...
|
||||
|
||||
```
|
||||
|
||||
and run the following command
|
||||
|
||||
```sh
|
||||
$ datamodel-codegen --input {your_input_file} --output {your_output_file} --custom-formatters "{path_to_your_module}.your_module"
|
||||
```
|
||||
@@ -435,8 +435,11 @@ Template customization:
|
||||
Wrap string literal by using black `experimental-
|
||||
string-processing` option (require black 20.8b0 or
|
||||
later)
|
||||
--additional-imports Custom imports for output (delimited list input)
|
||||
|
||||
--additional-imports Custom imports for output (delimited list input).
|
||||
For example "datetime.date,datetime.datetime"
|
||||
--custom-formatters List of modules with custom formatter (delimited list input).
|
||||
--custom-formatters-kwargs A file with kwargs for custom formatters.
|
||||
|
||||
OpenAPI-only options:
|
||||
--openapi-scopes {schemas,paths,tags,parameters} [{schemas,paths,tags,parameters} ...]
|
||||
Scopes of OpenAPI model generation (default: schemas)
|
||||
|
||||
@@ -6,6 +6,7 @@ This code generator supports the following input formats:
|
||||
- JSON Schema ([JSON Schema Core](http://json-schema.org/draft/2019-09/json-schema-validation.html) /[JSON Schema Validation](http://json-schema.org/draft/2019-09/json-schema-validation.html))
|
||||
- JSON/YAML Data (it will be converted to JSON Schema)
|
||||
- Python dictionary (it will be converted to JSON Schema)
|
||||
- GraphQL schema ([GraphQL Schemas and Types](https://graphql.org/learn/schema/))
|
||||
|
||||
## Implemented data types and features
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ nav:
|
||||
- Generate from JSON Data: jsondata.md
|
||||
- Generate from GraphQL Schema: graphql.md
|
||||
- Custom template: custom_template.md
|
||||
- Custom formatters: custom-formatters.md
|
||||
- Using as module: using_as_module.md
|
||||
- Formatting: formatting.md
|
||||
- Field Constraints: field-constraints.md
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: custom-scalar-types.graphql
|
||||
# timestamp: 2019-07-26T00:00:00+00:00
|
||||
|
||||
# a comment
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
Boolean: TypeAlias = bool
|
||||
"""
|
||||
The `Boolean` scalar type represents `true` or `false`.
|
||||
"""
|
||||
|
||||
|
||||
ID: TypeAlias = str
|
||||
"""
|
||||
The `ID` scalar type represents a unique identifier, often used to refetch an object or as key for a cache. The ID type appears in a JSON response as a String; however, it is not intended to be human-readable. When expected as an input type, any string (such as `"4"`) or integer (such as `4`) input value will be accepted as an ID.
|
||||
"""
|
||||
|
||||
|
||||
Long: TypeAlias = str
|
||||
|
||||
|
||||
String: TypeAlias = str
|
||||
"""
|
||||
The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text.
|
||||
"""
|
||||
|
||||
|
||||
class A(BaseModel):
|
||||
duration: Long
|
||||
id: ID
|
||||
typename__: Optional[Literal['A']] = Field('A', alias='__typename')
|
||||
7
tests/data/python/custom_formatters/add_comment.py
Normal file
7
tests/data/python/custom_formatters/add_comment.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from datamodel_code_generator.format import CustomCodeFormatter
|
||||
|
||||
|
||||
class CodeFormatter(CustomCodeFormatter):
|
||||
"""Simple correct formatter. Adding a comment to top of code."""
|
||||
def apply(self, code: str) -> str:
|
||||
return f'# a comment\n{code}'
|
||||
24
tests/data/python/custom_formatters/add_license.py
Normal file
24
tests/data/python/custom_formatters/add_license.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Any, Dict
|
||||
from pathlib import Path
|
||||
|
||||
from datamodel_code_generator.format import CustomCodeFormatter
|
||||
|
||||
|
||||
class CodeFormatter(CustomCodeFormatter):
|
||||
"""Add a license to file from license file path."""
|
||||
|
||||
def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
|
||||
super().__init__(formatter_kwargs)
|
||||
|
||||
if 'license_file' not in formatter_kwargs:
|
||||
raise ValueError()
|
||||
|
||||
license_file_path = Path(formatter_kwargs['license_file']).resolve()
|
||||
|
||||
with license_file_path.open("r") as f:
|
||||
license_file = f.read()
|
||||
|
||||
self.license_header = '\n'.join([f'# {line}' for line in license_file.split('\n')])
|
||||
|
||||
def apply(self, code: str) -> str:
|
||||
return f'{self.license_header}\n{code}'
|
||||
3
tests/data/python/custom_formatters/license_example.txt
Normal file
3
tests/data/python/custom_formatters/license_example.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Blah-blah
|
||||
3
tests/data/python/custom_formatters/not_subclass.py
Normal file
3
tests/data/python/custom_formatters/not_subclass.py
Normal file
@@ -0,0 +1,3 @@
|
||||
class CodeFormatter:
|
||||
"""Invalid formatter: is not subclass of `datamodel_code_generator.format.CustomCodeFormatter`."""
|
||||
pass
|
||||
7
tests/data/python/custom_formatters/wrong.py
Normal file
7
tests/data/python/custom_formatters/wrong.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from datamodel_code_generator.format import CustomCodeFormatter
|
||||
|
||||
|
||||
class WrongFormatterName(CustomCodeFormatter):
|
||||
"""Invalid formatter: correct name is CodeFormatter."""
|
||||
def apply(self, code: str) -> str:
|
||||
return f'# a comment\n{code}'
|
||||
@@ -1,9 +1,20 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from datamodel_code_generator.format import CodeFormatter, PythonVersion
|
||||
|
||||
EXAMPLE_LICENSE_FILE = str(
|
||||
Path(__file__).parent / 'data/python/custom_formatters/license_example.txt'
|
||||
)
|
||||
|
||||
UN_EXIST_FORMATTER = 'tests.data.python.custom_formatters.un_exist'
|
||||
WRONG_FORMATTER = 'tests.data.python.custom_formatters.wrong'
|
||||
NOT_SUBCLASS_FORMATTER = 'tests.data.python.custom_formatters.not_subclass'
|
||||
ADD_COMMENT_FORMATTER = 'tests.data.python.custom_formatters.add_comment'
|
||||
ADD_LICENSE_FORMATTER = 'tests.data.python.custom_formatters.add_license'
|
||||
|
||||
|
||||
def test_python_version():
|
||||
"""Ensure that the python version used for the tests is properly listed"""
|
||||
@@ -28,3 +39,84 @@ def test_format_code_with_skip_string_normalization(
|
||||
formatted_code = formatter.format_code("a = 'b'")
|
||||
|
||||
assert formatted_code == expected_output + '\n'
|
||||
|
||||
|
||||
def test_format_code_un_exist_custom_formatter():
|
||||
with pytest.raises(ModuleNotFoundError):
|
||||
_ = CodeFormatter(
|
||||
PythonVersion.PY_37,
|
||||
custom_formatters=[UN_EXIST_FORMATTER],
|
||||
)
|
||||
|
||||
|
||||
def test_format_code_invalid_formatter_name():
|
||||
with pytest.raises(NameError):
|
||||
_ = CodeFormatter(
|
||||
PythonVersion.PY_37,
|
||||
custom_formatters=[WRONG_FORMATTER],
|
||||
)
|
||||
|
||||
|
||||
def test_format_code_is_not_subclass():
|
||||
with pytest.raises(TypeError):
|
||||
_ = CodeFormatter(
|
||||
PythonVersion.PY_37,
|
||||
custom_formatters=[NOT_SUBCLASS_FORMATTER],
|
||||
)
|
||||
|
||||
|
||||
def test_format_code_with_custom_formatter_without_kwargs():
|
||||
formatter = CodeFormatter(
|
||||
PythonVersion.PY_37,
|
||||
custom_formatters=[ADD_COMMENT_FORMATTER],
|
||||
)
|
||||
|
||||
formatted_code = formatter.format_code('x = 1\ny = 2')
|
||||
|
||||
assert formatted_code == '# a comment\nx = 1\ny = 2' + '\n'
|
||||
|
||||
|
||||
def test_format_code_with_custom_formatter_with_kwargs():
|
||||
formatter = CodeFormatter(
|
||||
PythonVersion.PY_37,
|
||||
custom_formatters=[ADD_LICENSE_FORMATTER],
|
||||
custom_formatters_kwargs={'license_file': EXAMPLE_LICENSE_FILE},
|
||||
)
|
||||
|
||||
formatted_code = formatter.format_code('x = 1\ny = 2')
|
||||
|
||||
assert (
|
||||
formatted_code
|
||||
== """# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Blah-blah
|
||||
#
|
||||
x = 1
|
||||
y = 2
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def test_format_code_with_two_custom_formatters():
|
||||
formatter = CodeFormatter(
|
||||
PythonVersion.PY_37,
|
||||
custom_formatters=[
|
||||
ADD_COMMENT_FORMATTER,
|
||||
ADD_LICENSE_FORMATTER,
|
||||
],
|
||||
custom_formatters_kwargs={'license_file': EXAMPLE_LICENSE_FILE},
|
||||
)
|
||||
|
||||
formatted_code = formatter.format_code('x = 1\ny = 2')
|
||||
|
||||
assert (
|
||||
formatted_code
|
||||
== """# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Blah-blah
|
||||
#
|
||||
# a comment
|
||||
x = 1
|
||||
y = 2
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -6335,3 +6335,28 @@ def test_main_graphql_additional_imports_isort_5():
|
||||
/ 'output_isort5.py'
|
||||
).read_text()
|
||||
)
|
||||
|
||||
|
||||
@freeze_time('2019-07-26')
|
||||
def test_main_graphql_custom_formatters():
|
||||
with TemporaryDirectory() as output_dir:
|
||||
output_file: Path = Path(output_dir) / 'output.py'
|
||||
return_code: Exit = main(
|
||||
[
|
||||
'--input',
|
||||
str(GRAPHQL_DATA_PATH / 'custom-scalar-types.graphql'),
|
||||
'--output',
|
||||
str(output_file),
|
||||
'--input-file-type',
|
||||
'graphql',
|
||||
'--custom-formatters',
|
||||
'tests.data.python.custom_formatters.add_comment',
|
||||
]
|
||||
)
|
||||
assert return_code == Exit.OK
|
||||
assert (
|
||||
output_file.read_text()
|
||||
== (
|
||||
EXPECTED_MAIN_PATH / 'main_graphql_custom_formatters' / 'output.py'
|
||||
).read_text()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user