Files
pre-commit-ci[bot] 862a98cb7e [pre-commit.ci] pre-commit autoupdate (#1883)
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.2.2 → v0.3.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.2.2...v0.3.2)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-03-16 14:42:22 +09:00

265 lines
8.4 KiB
Python

from __future__ import annotations
from enum import Enum
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from warnings import warn
import black
import isort
from datamodel_code_generator.util import cached_property, load_toml
try:
import black.mode
except ImportError: # pragma: no cover
black.mode = None
class PythonVersion(Enum):
    """Python releases that can be selected as a target version.

    Each member's value is the dotted ``major.minor`` string. The ``has_*``
    properties report which typing features are treated as available for
    that release.
    """

    PY_36 = '3.6'
    PY_37 = '3.7'
    PY_38 = '3.8'
    PY_39 = '3.9'
    PY_310 = '3.10'
    PY_311 = '3.11'
    PY_312 = '3.12'

    def _is_at_least(self, major: int, minor: int) -> bool:
        """Return True when this member is ``major.minor`` or newer."""
        # Compare numerically: as plain strings '3.10' would sort before '3.9'.
        own_major, own_minor = map(int, self.value.split('.'))
        return (own_major, own_minor) >= (major, minor)

    @cached_property
    def _is_py_38_or_later(self) -> bool:  # pragma: no cover
        return self._is_at_least(3, 8)

    @cached_property
    def _is_py_39_or_later(self) -> bool:  # pragma: no cover
        return self._is_at_least(3, 9)

    @cached_property
    def _is_py_310_or_later(self) -> bool:  # pragma: no cover
        return self._is_at_least(3, 10)

    @cached_property
    def _is_py_311_or_later(self) -> bool:  # pragma: no cover
        return self._is_at_least(3, 11)

    @property
    def has_literal_type(self) -> bool:
        """True for 3.8 and newer."""
        return self._is_py_38_or_later

    @property
    def has_union_operator(self) -> bool:  # pragma: no cover
        """True for 3.10 and newer."""
        return self._is_py_310_or_later

    @property
    def has_annotated_type(self) -> bool:
        """True for 3.9 and newer."""
        return self._is_py_39_or_later

    @property
    def has_typed_dict(self) -> bool:
        """True for 3.8 and newer."""
        return self._is_py_38_or_later

    @property
    def has_typed_dict_non_required(self) -> bool:
        """True for 3.11 and newer."""
        return self._is_py_311_or_later
# Maps each PythonVersion to the ``black.TargetVersion`` member with the
# matching name (``PY_38`` -> ``PY38``). Versions for which the installed
# black defines no such member are left out of the mapping.
if TYPE_CHECKING:

    class _TargetVersion(Enum): ...

    BLACK_PYTHON_VERSION: Dict[PythonVersion, _TargetVersion]
else:
    BLACK_PYTHON_VERSION: Dict[PythonVersion, black.TargetVersion] = {
        version: target
        for version, target in (
            (
                v,
                getattr(black.TargetVersion, f'PY{v.name.split("_")[-1]}', None),
            )
            for v in PythonVersion
        )
        if target is not None
    }
def is_supported_in_black(python_version: PythonVersion) -> bool:  # pragma: no cover
    """Return True when ``python_version`` has an entry in ``BLACK_PYTHON_VERSION``."""
    supported_versions = BLACK_PYTHON_VERSION.keys()
    return python_version in supported_versions
def black_find_project_root(sources: Sequence[Path]) -> Path:
    """Return the project root that black resolves for ``sources``.

    ``black.find_project_root`` may hand back either a bare path or a tuple
    whose first item is the path; both shapes are reduced to the path.
    """
    if TYPE_CHECKING:
        from typing import Iterable, Tuple, Union

        def _find_project_root(
            srcs: Union[Sequence[str], Iterable[str]],
        ) -> Union[Tuple[Path, str], Path]: ...

    else:
        from black import find_project_root as _find_project_root

    # black is given the sources as a tuple of strings.
    source_strings = tuple(map(str, sources))
    located = _find_project_root(source_strings)
    return located[0] if isinstance(located, tuple) else located
class CodeFormatter:
    """Format source code with isort, black and optional custom formatters.

    Black options are read from the ``[tool.black]`` table of the
    ``pyproject.toml`` located at the project root that black resolves for
    ``settings_path``. ``format_code`` applies isort first, then black, then
    every custom formatter in the order given.
    """

    def __init__(
        self,
        python_version: PythonVersion,
        settings_path: Optional[Path] = None,
        wrap_string_literal: Optional[bool] = None,
        skip_string_normalization: bool = True,
        known_third_party: Optional[List[str]] = None,
        custom_formatters: Optional[List[str]] = None,
        custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Build the black mode, the isort configuration and custom formatters.

        Args:
            python_version: Target version; looked up in ``BLACK_PYTHON_VERSION``
                (a ``KeyError`` is raised when it has no entry there).
            settings_path: Directory used to locate ``pyproject.toml`` and the
                isort settings. Defaults to the resolved current directory.
            wrap_string_literal: When not None, overrides the string-processing
                setting that would otherwise be read from the black config.
            skip_string_normalization: Combined with the config's
                ``skip-string-normalization`` entry to decide black's
                ``string_normalization`` flag.
            known_third_party: Forwarded to isort as ``known_third_party``.
            custom_formatters: Import paths of modules that each expose a
                ``CodeFormatter`` class derived from ``CustomCodeFormatter``.
            custom_formatters_kwargs: Passed to every custom formatter as
                ``formatter_kwargs``.
        """
        if not settings_path:
            settings_path = Path().resolve()

        root = black_find_project_root((settings_path,))
        path = root / 'pyproject.toml'
        if path.is_file():
            pyproject_toml = load_toml(path)
            config = pyproject_toml.get('tool', {}).get('black', {})
        else:
            config = {}

        black_kwargs: Dict[str, Any] = {}
        # Releases before 24.1.0 are driven through the
        # ``experimental_string_processing`` keyword; from 24.1.0 on the
        # preview/unstable feature flags are used instead.
        legacy_black = self._black_older_than(24, 1)
        if wrap_string_literal is not None:
            experimental_string_processing = wrap_string_literal
        elif legacy_black:
            experimental_string_processing = config.get(
                'experimental-string-processing'
            )
        else:
            experimental_string_processing = config.get('preview', False) and (
                config.get('unstable', False)
                or 'string_processing' in config.get('enable-unstable-feature', [])
            )

        if experimental_string_processing is not None:  # pragma: no cover
            if black.__version__.startswith('19.'):  # type: ignore
                warn(
                    f"black doesn't support `experimental-string-processing` option"  # type: ignore
                    f' for wrapping string literal in {black.__version__}'
                )
            elif legacy_black:
                black_kwargs['experimental_string_processing'] = (
                    experimental_string_processing
                )
            elif experimental_string_processing:
                black_kwargs['preview'] = True
                black_kwargs['unstable'] = config.get('unstable', False)
                black_kwargs['enabled_features'] = {
                    black.mode.Preview.string_processing
                }

        if TYPE_CHECKING:
            self.black_mode: black.FileMode
        else:
            self.black_mode = black.FileMode(
                target_versions={BLACK_PYTHON_VERSION[python_version]},
                line_length=config.get('line-length', black.DEFAULT_LINE_LENGTH),
                string_normalization=not skip_string_normalization
                or not config.get('skip-string-normalization', True),
                **black_kwargs,
            )

        self.settings_path: str = str(settings_path)

        self.isort_config_kwargs: Dict[str, Any] = {}
        if known_third_party:
            self.isort_config_kwargs['known_third_party'] = known_third_party

        # isort 4.x is called through ``isort.SortImports`` (see ``apply_isort``)
        # and is not given an ``isort.Config`` object.
        if isort.__version__.startswith('4.'):
            self.isort_config = None
        else:
            self.isort_config = isort.Config(
                settings_path=self.settings_path, **self.isort_config_kwargs
            )

        self.custom_formatters_kwargs = custom_formatters_kwargs or {}
        self.custom_formatters = self._check_custom_formatters(custom_formatters)

    @staticmethod
    def _black_older_than(major: int, minor: int) -> bool:
        """Return True when the installed black predates ``major.minor``.

        The leading ``major.minor`` of ``black.__version__`` is compared
        numerically, so the result does not depend on how many digits each
        component has. If the version string does not start with two
        dot-separated numbers, the plain string comparison is used instead.
        """
        import re  # local import keeps the module's import block untouched

        version = black.__version__  # type: ignore
        parsed = re.match(r'(\d+)\.(\d+)', version)
        if parsed is None:
            return version < f'{major}.{minor}.0'
        return (int(parsed.group(1)), int(parsed.group(2))) < (major, minor)

    def _load_custom_formatter(
        self, custom_formatter_import: str
    ) -> CustomCodeFormatter:
        """Import a module and instantiate its ``CodeFormatter`` class.

        Raises:
            NameError: The module has no ``CodeFormatter`` attribute.
            TypeError: ``CodeFormatter`` is not a ``CustomCodeFormatter``
                subclass.
        """
        import_ = import_module(custom_formatter_import)

        if not hasattr(import_, 'CodeFormatter'):
            raise NameError(
                f'Custom formatter module `{import_.__name__}` must contains object with name CodeFormatter'
            )

        formatter_class = getattr(import_, 'CodeFormatter')

        # The isinstance guard makes a non-class attribute fail with the
        # descriptive message below rather than issubclass's generic one.
        if not (
            isinstance(formatter_class, type)
            and issubclass(formatter_class, CustomCodeFormatter)
        ):
            raise TypeError(
                f'The custom module {custom_formatter_import} must inherit from `datamodel-code-generator`'
            )

        return formatter_class(formatter_kwargs=self.custom_formatters_kwargs)

    def _check_custom_formatters(
        self, custom_formatters: Optional[List[str]]
    ) -> List[CustomCodeFormatter]:
        """Load every listed custom formatter; ``None`` yields an empty list."""
        if custom_formatters is None:
            return []

        return [
            self._load_custom_formatter(custom_formatter_import)
            for custom_formatter_import in custom_formatters
        ]

    def format_code(
        self,
        code: str,
    ) -> str:
        """Run isort, then black, then each custom formatter over ``code``."""
        code = self.apply_isort(code)
        code = self.apply_black(code)

        for formatter in self.custom_formatters:
            code = formatter.apply(code)

        return code

    def apply_black(self, code: str) -> str:
        """Format ``code`` with black using the mode built in ``__init__``."""
        return black.format_str(
            code,
            mode=self.black_mode,
        )

    # ``apply_isort`` is chosen once, at class-definition time, according to
    # the installed isort's major version.
    if TYPE_CHECKING:

        def apply_isort(self, code: str) -> str: ...

    else:
        if isort.__version__.startswith('4.'):

            def apply_isort(self, code: str) -> str:
                """Sort imports through the isort 4.x ``SortImports`` API."""
                return isort.SortImports(
                    file_contents=code,
                    settings_path=self.settings_path,
                    **self.isort_config_kwargs,
                ).output

        else:

            def apply_isort(self, code: str) -> str:
                """Sort imports through ``isort.code`` with the stored config."""
                return isort.code(code, config=self.isort_config)
class CustomCodeFormatter:
    """Base class for user-supplied formatters.

    Subclasses override ``apply``; the keyword arguments given at
    construction are kept on ``formatter_kwargs``.
    """

    def __init__(self, formatter_kwargs: Dict[str, Any]) -> None:
        """Store ``formatter_kwargs`` on the instance."""
        self.formatter_kwargs = formatter_kwargs

    def apply(self, code: str) -> str:
        """Return the formatted version of ``code``; must be overridden."""
        raise NotImplementedError()