mirror of
https://github.com/koxudaxi/datamodel-code-generator.git
synced 2024-03-18 14:54:37 +03:00
767 lines
26 KiB
Python
767 lines
26 KiB
Python
import re
|
|
from collections import defaultdict
|
|
from contextlib import contextmanager
|
|
from enum import Enum, auto
|
|
from functools import lru_cache
|
|
from itertools import zip_longest
|
|
from keyword import iskeyword
|
|
from pathlib import Path, PurePath
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
AbstractSet,
|
|
Any,
|
|
Callable,
|
|
ClassVar,
|
|
DefaultDict,
|
|
Dict,
|
|
Generator,
|
|
List,
|
|
Mapping,
|
|
NamedTuple,
|
|
Optional,
|
|
Pattern,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
from urllib.parse import ParseResult, urlparse
|
|
|
|
import inflect
|
|
import pydantic
|
|
from packaging import version
|
|
from pydantic import BaseModel
|
|
|
|
from datamodel_code_generator.util import (
|
|
PYDANTIC_V2,
|
|
ConfigDict,
|
|
cached_property,
|
|
model_validator,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from pydantic.typing import DictStrAny
|
|
|
|
|
|
class _BaseModel(BaseModel):
    """Pydantic base model with class-level excluded and pass-through fields.

    ``_exclude_fields`` names are always added to the ``exclude`` argument of
    ``dict()``.  ``_pass_fields`` names, when supplied to the constructor, are
    re-assigned with ``setattr`` after the regular pydantic initialisation.
    """

    _exclude_fields: ClassVar[Set[str]] = set()
    _pass_fields: ClassVar[Set[str]] = set()

    if not TYPE_CHECKING:

        def __init__(self, **values: Any) -> None:
            super().__init__(**values)
            for field_name in self._pass_fields:
                if field_name not in values:
                    continue
                # Re-apply the raw value given by the caller.
                setattr(self, field_name, values[field_name])

    if not TYPE_CHECKING:
        if PYDANTIC_V2:

            def dict(
                self,
                *,
                include: Union[
                    AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], None
                ] = None,
                exclude: Union[
                    AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], None
                ] = None,
                by_alias: bool = False,
                exclude_unset: bool = False,
                exclude_defaults: bool = False,
                exclude_none: bool = False,
            ) -> 'DictStrAny':
                """Dump the model via ``model_dump``, always excluding ``_exclude_fields``."""
                merged_exclude = set(exclude or ()) | self._exclude_fields
                return self.model_dump(
                    include=include,
                    exclude=merged_exclude,
                    by_alias=by_alias,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )

        else:

            def dict(
                self,
                *,
                include: Union[
                    AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], None
                ] = None,
                exclude: Union[
                    AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], None
                ] = None,
                by_alias: bool = False,
                skip_defaults: Optional[bool] = None,
                exclude_unset: bool = False,
                exclude_defaults: bool = False,
                exclude_none: bool = False,
            ) -> 'DictStrAny':
                """Dump the model via the parent ``dict``, always excluding ``_exclude_fields``."""
                merged_exclude = set(exclude or ()) | self._exclude_fields
                return super().dict(
                    include=include,
                    exclude=merged_exclude,
                    by_alias=by_alias,
                    skip_defaults=skip_defaults,
                    exclude_unset=exclude_unset,
                    exclude_defaults=exclude_defaults,
                    exclude_none=exclude_none,
                )
|
|
|
|
|
|
class Reference(_BaseModel):
    """A named reference to a model, keyed by its resolved ``path``."""

    path: str
    original_name: str = ''
    name: str
    duplicate_name: Optional[str] = None
    loaded: bool = True
    source: Optional[Any] = None
    children: List[Any] = []
    # ``children`` is never included in ``dict()`` output.
    _exclude_fields: ClassVar[Set[str]] = {'children'}

    @model_validator(mode='before')
    def validate_original_name(cls, values: Any) -> Any:
        """
        Fall back to `name` when `original_name` is missing or empty.
        """
        if not isinstance(values, dict):  # pragma: no cover
            return values
        current = values.get('original_name')
        if not current:
            values['original_name'] = values.get('name', current)
        return values

    if PYDANTIC_V2:
        # TODO[pydantic]: The following keys were removed: `copy_on_model_validation`.
        # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
        model_config = ConfigDict(
            arbitrary_types_allowed=True,
            ignored_types=(cached_property,),
            revalidate_instances='never',
        )
    else:

        class Config:
            arbitrary_types_allowed = True
            keep_untouched = (cached_property,)
            # Older pydantic releases take a bool here; 1.9.2+ takes 'none'.
            copy_on_model_validation = (
                False
                if version.parse(pydantic.VERSION) < version.parse('1.9.2')
                else 'none'
            )

    @property
    def short_name(self) -> str:
        """Return the part of ``name`` after the last dot (or all of it)."""
        return self.name.rpartition('.')[2]
|
|
|
|
|
|
# Default suffix used by `get_singular_name` and `ModelResolver` when the
# caller does not supply one.
SINGULAR_NAME_SUFFIX: str = 'Item'
|
|
|
|
# Matches strings that start with '#' followed by any character other than '/'.
ID_PATTERN: Pattern[str] = re.compile(r'^#[^/].*')
|
|
|
|
# Generic type variable shared by the value arguments of `context_variable`.
T = TypeVar('T')
|
|
|
|
|
|
@contextmanager
def context_variable(
    setter: Callable[[T], None], current_value: T, new_value: T
) -> Generator[None, None, None]:
    """Temporarily apply ``new_value`` through ``setter``.

    ``setter(new_value)`` is called on entry and ``setter(current_value)`` is
    called on exit, even when the managed body raises.
    """
    setter(new_value)
    try:
        yield
    finally:
        # Restore the value the caller reported as current at entry time.
        setter(current_value)
|
|
|
|
|
|
# First pass of `camel_to_snake`: a non-underscore character followed by a
# capitalised word (one upper-case letter plus one or more lower-case letters).
_UNDER_SCORE_1: Pattern[str] = re.compile(r'([^_])([A-Z][a-z]+)')
|
|
# Second pass of `camel_to_snake`: a lower-case letter or digit directly
# followed by an upper-case letter.
_UNDER_SCORE_2: Pattern[str] = re.compile('([a-z0-9])([A-Z])')
|
|
|
|
|
|
@lru_cache()
def camel_to_snake(string: str) -> str:
    """Convert a camel-case string to lower-case snake case.

    Underscores are inserted at the boundaries matched by ``_UNDER_SCORE_1``
    and then ``_UNDER_SCORE_2``; the result is lower-cased.
    """
    first_pass = _UNDER_SCORE_1.sub(r'\1_\2', string)
    second_pass = _UNDER_SCORE_2.sub(r'\1_\2', first_pass)
    return second_pass.lower()
|
|
|
|
|
|
class FieldNameResolver:
    """Turn arbitrary schema names into valid Python field names."""

    def __init__(
        self,
        aliases: Optional[Mapping[str, str]] = None,
        snake_case_field: bool = False,
        empty_field_name: Optional[str] = None,
        original_delimiter: Optional[str] = None,
        special_field_name_prefix: Optional[str] = None,
        remove_special_field_name_prefix: bool = False,
        capitalise_enum_members: bool = False,
    ):
        # Keep a private copy so later changes to the caller's mapping are not seen.
        self.aliases: Mapping[str, str] = dict(aliases) if aliases is not None else {}
        self.empty_field_name: str = empty_field_name or '_'
        self.snake_case_field = snake_case_field
        self.original_delimiter: Optional[str] = original_delimiter
        if special_field_name_prefix is None:
            special_field_name_prefix = 'field'
        self.special_field_name_prefix: Optional[str] = special_field_name_prefix
        self.remove_special_field_name_prefix: bool = remove_special_field_name_prefix
        self.capitalise_enum_members: bool = capitalise_enum_members

    @classmethod
    def _validate_field_name(cls, field_name: str) -> bool:
        """Hook for subclasses; the base implementation accepts every name."""
        return True

    def get_valid_name(
        self,
        name: str,
        excludes: Optional[Set[str]] = None,
        ignore_snake_case_field: bool = False,
        upper_camel: bool = False,
    ) -> str:
        """Return a usable identifier derived from ``name``.

        An empty name is replaced by ``empty_field_name``, a leading ``#`` is
        dropped, non-word characters become underscores, and a leading digit
        or underscore is handled with ``special_field_name_prefix``.  Keywords
        get a trailing underscore, and a numeric suffix is appended until the
        result is accepted and not listed in ``excludes``.
        """
        name = name or self.empty_field_name
        if name[0] == '#':
            name = name[1:] or self.empty_field_name

        apply_snake_case = self.snake_case_field and not ignore_snake_case_field
        if apply_snake_case and self.original_delimiter is not None:
            name = snake_to_upper_camel(name, delimiter=self.original_delimiter)

        name = re.sub(r'[¹²³⁴⁵⁶⁷⁸⁹]|\W', '_', name)
        if name[0].isnumeric():
            name = f'{self.special_field_name_prefix}_{name}'

        # We should avoid having a field begin with an underscore, as it
        # causes pydantic to consider it as private
        if name.startswith('_'):
            if self.remove_special_field_name_prefix:
                name = name.lstrip('_')
            else:
                name = f'{self.special_field_name_prefix}{name}'

        if self.capitalise_enum_members or apply_snake_case:
            name = camel_to_snake(name)

        if iskeyword(name) or not self._validate_field_name(name):
            name += '_'

        if upper_camel:
            candidate = snake_to_upper_camel(name)
        elif self.capitalise_enum_members:
            candidate = name.upper()
        else:
            candidate = name

        attempt = 1
        while (
            (not candidate.isidentifier() and self._validate_field_name(candidate))
            or iskeyword(candidate)
            or (excludes and candidate in excludes)
        ):
            # Retry with an increasing numeric suffix on the normalised name.
            candidate = f'{name}{attempt}' if upper_camel else f'{name}_{attempt}'
            attempt += 1
        return candidate

    def get_valid_field_name_and_alias(
        self, field_name: str, excludes: Optional[Set[str]] = None
    ) -> Tuple[str, Optional[str]]:
        """Return ``(python_name, alias)`` for ``field_name``.

        A configured alias is used as-is; otherwise the name is normalised and
        the alias is ``None`` when normalisation changed nothing.
        """
        if field_name in self.aliases:
            return self.aliases[field_name], field_name
        valid_name = self.get_valid_name(field_name, excludes=excludes)
        alias = field_name if field_name != valid_name else None
        return valid_name, alias
|
|
|
|
|
|
class PydanticFieldNameResolver(FieldNameResolver):
    """Field-name resolver that rejects names already present on ``BaseModel``."""

    @classmethod
    def _validate_field_name(cls, field_name: str) -> bool:
        # TODO: Support Pydantic V2
        clashes_with_base_model = hasattr(BaseModel, field_name)
        return not clashes_with_base_model
|
|
|
|
|
|
class EnumFieldNameResolver(FieldNameResolver):
    """Field-name resolver for enum members; never yields the name ``mro``."""

    def get_valid_name(
        self,
        name: str,
        excludes: Optional[Set[str]] = None,
        ignore_snake_case_field: bool = False,
        upper_camel: bool = False,
    ) -> str:
        """Resolve ``name`` as the parent does, with ``mro`` renamed and excluded."""
        safe_name = 'mro_' if name == 'mro' else name
        reserved = {'mro'} | (excludes or set())
        return super().get_valid_name(
            name=safe_name,
            excludes=reserved,
            ignore_snake_case_field=ignore_snake_case_field,
            upper_camel=upper_camel,
        )
|
|
|
|
|
|
class ModelType(Enum):
    """Kinds of generated model; used as the key for choosing a field-name resolver."""

    PYDANTIC = auto()
    ENUM = auto()
    CLASS = auto()
|
|
|
|
|
|
# Resolver class used for each model type unless the caller overrides it.
DEFAULT_FIELD_NAME_RESOLVERS: Dict[ModelType, Type[FieldNameResolver]] = {
    ModelType.ENUM: EnumFieldNameResolver,
    ModelType.PYDANTIC: PydanticFieldNameResolver,
    ModelType.CLASS: FieldNameResolver,
}
|
|
|
|
|
|
class ClassName(NamedTuple):
    """A resolved class name paired with an optional duplicate name."""

    # The class name to use.
    name: str
    # The name before de-duplication, or None when it was not changed.
    duplicate_name: Optional[str]
|
|
|
|
|
|
def get_relative_path(base_path: PurePath, target_path: PurePath) -> PurePath:
    """Express ``target_path`` relative to ``base_path``.

    Identical paths yield ``Path('.')`` and a non-absolute target is returned
    unchanged.  Otherwise the shared leading parts are skipped, one ``..`` is
    emitted for every remaining base part, and the remaining target parts are
    appended.
    """
    if base_path == target_path:
        return Path('.')
    if not target_path.is_absolute():
        return target_path

    ups: int = 0
    downs: List[str] = []
    for base_part, target_part in zip_longest(base_path.parts, target_path.parts):
        # Parts only count as shared while the two paths have not yet diverged.
        still_shared = not ups and base_part == target_part
        if still_shared:
            continue
        if base_part or not target_part:
            ups += 1
        if target_part:
            downs.append(target_part)
    return Path(*(['..'] * ups), *downs)
|
|
|
|
|
|
class ModelResolver:
    """Registry of `Reference` objects keyed by resolved reference path.

    Besides storing references, the resolver tracks the current root, base
    path, base URL and root id, all of which influence how `resolve_ref`
    normalises a raw reference, and it produces valid, optionally unique,
    class and field names.
    """

    def __init__(
        self,
        exclude_names: Optional[Set[str]] = None,
        duplicate_name_suffix: Optional[str] = None,
        base_url: Optional[str] = None,
        singular_name_suffix: Optional[str] = None,
        aliases: Optional[Mapping[str, str]] = None,
        snake_case_field: bool = False,
        empty_field_name: Optional[str] = None,
        custom_class_name_generator: Optional[Callable[[str], str]] = None,
        base_path: Optional[Path] = None,
        field_name_resolver_classes: Optional[
            Dict[ModelType, Type[FieldNameResolver]]
        ] = None,
        original_field_name_delimiter: Optional[str] = None,
        special_field_name_prefix: Optional[str] = None,
        remove_special_field_name_prefix: bool = False,
        capitalise_enum_members: bool = False,
    ) -> None:
        """Create a resolver.

        ``field_name_resolver_classes`` overrides entries of
        ``DEFAULT_FIELD_NAME_RESOLVERS``; one resolver instance is built per
        model type.  ``capitalise_enum_members`` is only forwarded to the
        ``ModelType.ENUM`` resolver.  ``base_path`` defaults to the current
        working directory.
        """
        self.references: Dict[str, Reference] = {}
        self._current_root: Sequence[str] = []
        self._root_id: Optional[str] = None
        self._root_id_base_path: Optional[str] = None
        # Maps a joined current root to {id: resolved reference}.
        self.ids: DefaultDict[str, Dict[str, str]] = defaultdict(dict)
        self.after_load_files: Set[str] = set()
        self.exclude_names: Set[str] = exclude_names or set()
        self.duplicate_name_suffix: Optional[str] = duplicate_name_suffix
        self._base_url: Optional[str] = base_url
        self.singular_name_suffix: str = (
            singular_name_suffix
            if isinstance(singular_name_suffix, str)
            else SINGULAR_NAME_SUFFIX
        )
        merged_field_name_resolver_classes = DEFAULT_FIELD_NAME_RESOLVERS.copy()
        if field_name_resolver_classes:  # pragma: no cover
            merged_field_name_resolver_classes.update(field_name_resolver_classes)
        self.field_name_resolvers: Dict[ModelType, FieldNameResolver] = {
            k: v(
                aliases=aliases,
                snake_case_field=snake_case_field,
                empty_field_name=empty_field_name,
                original_delimiter=original_field_name_delimiter,
                special_field_name_prefix=special_field_name_prefix,
                remove_special_field_name_prefix=remove_special_field_name_prefix,
                capitalise_enum_members=capitalise_enum_members
                if k == ModelType.ENUM
                else False,
            )
            for k, v in merged_field_name_resolver_classes.items()
        }
        self.class_name_generator = (
            custom_class_name_generator or self.default_class_name_generator
        )
        self._base_path: Path = base_path or Path.cwd()
        self._current_base_path: Optional[Path] = self._base_path

    @property
    def current_base_path(self) -> Optional[Path]:
        """Base path against which local file references are resolved."""
        return self._current_base_path

    def set_current_base_path(self, base_path: Optional[Path]) -> None:
        """Replace the current base path."""
        self._current_base_path = base_path

    @property
    def base_url(self) -> Optional[str]:
        """Base URL used to resolve references, if any."""
        return self._base_url

    def set_base_url(self, base_url: Optional[str]) -> None:
        """Replace the base URL."""
        self._base_url = base_url

    @contextmanager
    def current_base_path_context(
        self, base_path: Optional[Path]
    ) -> Generator[None, None, None]:
        """Temporarily switch the current base path.

        A truthy ``base_path`` is joined onto the resolver's own base path and
        resolved before being applied; a falsy one is applied as given.
        """
        if base_path:
            base_path = (self._base_path / base_path).resolve()
        with context_variable(
            self.set_current_base_path, self.current_base_path, base_path
        ):
            yield

    @contextmanager
    def base_url_context(self, base_url: str) -> Generator[None, None, None]:
        """Temporarily switch the base URL, but only if one is already set."""
        if self._base_url:
            with context_variable(self.set_base_url, self.base_url, base_url):
                yield
        else:
            yield

    @property
    def current_root(self) -> Sequence[str]:
        """Path parts identifying the document currently being processed."""
        return self._current_root

    def set_current_root(self, current_root: Sequence[str]) -> None:
        """Replace the current root."""
        self._current_root = current_root

    @contextmanager
    def current_root_context(
        self, current_root: Sequence[str]
    ) -> Generator[None, None, None]:
        """Temporarily switch the current root."""
        with context_variable(self.set_current_root, self.current_root, current_root):
            yield

    @property
    def root_id(self) -> Optional[str]:
        """Identifier of the root document, if any."""
        return self._root_id

    @property
    def root_id_base_path(self) -> Optional[str]:
        """Everything before the last '/' of the root id, or None."""
        return self._root_id_base_path

    def set_root_id(self, root_id: Optional[str]) -> None:
        """Store ``root_id`` and derive ``root_id_base_path`` from it."""
        if root_id and '/' in root_id:
            self._root_id_base_path = root_id.rsplit('/', 1)[0]
        else:
            self._root_id_base_path = None

        self._root_id = root_id

    def add_id(self, id_: str, path: Sequence[str]) -> None:
        """Register ``id_`` under the current root, pointing at the resolved ``path``."""
        self.ids['/'.join(self.current_root)][id_] = self.resolve_ref(path)

    def resolve_ref(self, path: Union[Sequence[str], str]) -> str:
        """Normalise ``path`` into the key used by ``self.references``.

        A sequence is first joined with `join_path`.  Local file references
        are rewritten relative to the resolver's base path, registered ids are
        looked up in ``self.ids`` (raising ``KeyError`` when unknown), bare
        fragments are prefixed with the current root, and the base URL or
        root id is applied when set.
        """
        if isinstance(path, str):
            joined_path = path
        else:
            joined_path = self.join_path(path)
        if joined_path == '#':
            return f"{'/'.join(self.current_root)}#"
        if (
            self.current_base_path
            and not self.base_url
            and joined_path[0] != '#'
            and not is_url(joined_path)
        ):
            # resolve local file path
            file_path, *object_part = joined_path.split('#', 1)
            resolved_file_path = Path(self.current_base_path, file_path).resolve()
            joined_path = get_relative_path(
                self._base_path, resolved_file_path
            ).as_posix()
            if object_part:
                joined_path += f'#{object_part[0]}'
        if ID_PATTERN.match(joined_path):
            ref: str = self.ids['/'.join(self.current_root)][joined_path]
        else:
            if '#' not in joined_path:
                joined_path += '#'
            elif joined_path[0] == '#':
                joined_path = f'{"/".join(self.current_root)}{joined_path}'

            # Everything before the first '#' is the file part of the reference.
            file_path = joined_path[: joined_path.index('#')]
            ref = joined_path
            if self.root_id_base_path and not (
                is_url(joined_path) or Path(self._base_path, file_path).is_file()
            ):
                ref = f'{self.root_id_base_path}/{ref}'

        if self.base_url:
            from .http import join_url

            joined_url = join_url(self.base_url, ref)
            if '#' in joined_url:
                return joined_url
            return f'{joined_url}#'

        if is_url(ref):
            file_part, path_part = ref.split('#', 1)
            if file_part == self.root_id:
                return f'{"/".join(self.current_root)}#{path_part}'
            target_url: ParseResult = urlparse(file_part)
            if not (self.root_id and self.current_base_path):
                return ref
            root_id_url: ParseResult = urlparse(self.root_id)
            if (target_url.scheme, target_url.netloc) == (
                root_id_url.scheme,
                root_id_url.netloc,
            ):  # pragma: no cover
                # Same origin as the root id: prefer a matching local file.
                target_url_path = Path(target_url.path)
                relative_target_base = get_relative_path(
                    Path(root_id_url.path).parent, target_url_path.parent
                )
                target_path = (
                    self.current_base_path / relative_target_base / target_url_path.name
                )
                if target_path.exists():
                    return f'{target_path.resolve().relative_to(self._base_path)}#{path_part}'

        return ref

    def is_after_load(self, ref: str) -> bool:
        """Tell whether the file part of ``ref`` is listed in ``after_load_files``.

        Always False for URLs, when no current base path is set, and for
        references that are not external.
        """
        if is_url(ref) or not self.current_base_path:
            return False
        file_part, *_ = ref.split('#', 1)
        absolute_path = Path(self._base_path, file_part).resolve().as_posix()
        if self.is_external_root_ref(ref) or self.is_external_ref(ref):
            return absolute_path in self.after_load_files
        return False  # pragma: no cover

    @staticmethod
    def is_external_ref(ref: str) -> bool:
        """True when ``ref`` contains '#' but does not start with it."""
        return '#' in ref and ref[0] != '#'

    @staticmethod
    def is_external_root_ref(ref: str) -> bool:
        """True when ``ref`` ends with '#'; an empty string yields False."""
        return ref.endswith('#')

    @staticmethod
    def join_path(path: Sequence[str]) -> str:
        """Join non-empty parts with '/', fold '/#' into '#', and ensure a '#'."""
        joined_path = '/'.join(p for p in path if p).replace('/#', '#')
        if '#' not in joined_path:
            joined_path += '#'
        return joined_path

    def add_ref(self, ref: str, resolved: bool = False) -> Reference:
        """Return the reference for ``ref``, creating an unloaded one if needed.

        ``ref`` is resolved with `resolve_ref` unless ``resolved`` is True.
        The original name comes from the last '/'-separated segment of
        ``ref``; for a reference whose resolved path ends with '#', the stem
        of that segment without its trailing character is used.
        """
        if not resolved:
            path = self.resolve_ref(ref)
        else:
            path = ref
        reference = self.references.get(path)
        if reference:
            return reference
        split_ref = ref.rsplit('/', 1)
        if len(split_ref) == 1:
            original_name = Path(
                split_ref[0][:-1] if self.is_external_root_ref(path) else split_ref[0]
            ).stem
        else:
            original_name = (
                Path(split_ref[1][:-1]).stem
                if self.is_external_root_ref(path)
                else split_ref[1]
            )
        name = self.get_class_name(original_name, unique=False).name
        reference = Reference(
            path=path,
            original_name=original_name,
            name=name,
            loaded=False,
        )

        self.references[path] = reference
        return reference

    def add(
        self,
        path: Sequence[str],
        original_name: str,
        *,
        class_name: bool = False,
        singular_name: bool = False,
        unique: bool = True,
        singular_name_suffix: Optional[str] = None,
        loaded: bool = False,
    ) -> Reference:
        """Register or update the reference stored under ``path``.

        An existing reference is returned unchanged (apart from being marked
        loaded when ``loaded`` is True) if ``original_name`` is empty or
        already matches its ``original_name`` or ``name``.  Otherwise a name
        is computed, as a class name when ``class_name`` is True and as a
        field name otherwise, and the reference is updated or created.
        """
        joined_path = self.join_path(path)
        reference: Optional[Reference] = self.references.get(joined_path)
        if reference:
            if loaded and not reference.loaded:
                reference.loaded = True
            if (
                not original_name
                or original_name == reference.original_name
                or original_name == reference.name
            ):
                return reference
        name = original_name
        duplicate_name: Optional[str] = None
        if class_name:
            name, duplicate_name = self.get_class_name(
                name=name,
                unique=unique,
                reserved_name=reference.name if reference else None,
                singular_name=singular_name,
                singular_name_suffix=singular_name_suffix,
            )
        else:
            # TODO: create a validate for module name
            name = self.get_valid_field_name(name, model_type=ModelType.CLASS)
            if singular_name:  # pragma: no cover
                name = get_singular_name(
                    name, singular_name_suffix or self.singular_name_suffix
                )
            elif unique:  # pragma: no cover
                unique_name = self._get_unique_name(name)
                # NOTE(review): get_class_name records duplicate_name when the
                # unique name DIFFERS from the original; this `==` looks
                # inverted. Left unchanged pending confirmation.
                if unique_name == name:
                    duplicate_name = name
                name = unique_name
        if reference:
            reference.original_name = original_name
            reference.name = name
            reference.loaded = loaded
            reference.duplicate_name = duplicate_name
        else:
            reference = Reference(
                path=joined_path,
                original_name=original_name,
                name=name,
                loaded=loaded,
                duplicate_name=duplicate_name,
            )
            self.references[joined_path] = reference
        return reference

    def get(self, path: Union[Sequence[str], str]) -> Optional[Reference]:
        """Return the reference registered for ``path``, or None."""
        return self.references.get(self.resolve_ref(path))

    def delete(self, path: Union[Sequence[str], str]) -> None:
        """Remove the reference registered for ``path``; do nothing if absent."""
        self.references.pop(self.resolve_ref(path), None)

    def default_class_name_generator(self, name: str) -> str:
        """Default class-name generator: a valid upper-camel-case name."""
        # TODO: create a validate for class name
        return self.field_name_resolvers[ModelType.CLASS].get_valid_name(
            name, ignore_snake_case_field=True, upper_camel=True
        )

    def get_class_name(
        self,
        name: str,
        unique: bool = True,
        reserved_name: Optional[str] = None,
        singular_name: bool = False,
        singular_name_suffix: Optional[str] = None,
    ) -> ClassName:
        """Build a class name, and the name it displaced, from ``name``.

        For a dotted name only the last segment goes through the class-name
        generator; the leading segments are validated as field names and kept
        as a dotted prefix.  With ``unique`` the class name is de-duplicated
        unless it equals ``reserved_name``.
        """
        if '.' in name:
            split_name = name.split('.')
            prefix = '.'.join(
                # TODO: create a validate for class name
                self.field_name_resolvers[ModelType.CLASS].get_valid_name(
                    n, ignore_snake_case_field=True
                )
                for n in split_name[:-1]
            )
            prefix += '.'
            class_name = split_name[-1]
        else:
            prefix = ''
            class_name = name

        class_name = self.class_name_generator(class_name)

        if singular_name:
            class_name = get_singular_name(
                class_name, singular_name_suffix or self.singular_name_suffix
            )
        duplicate_name: Optional[str] = None
        if unique:
            if reserved_name == class_name:
                # NOTE: this early return does not apply the dotted prefix.
                return ClassName(name=class_name, duplicate_name=duplicate_name)

            unique_name = self._get_unique_name(class_name, camel=True)
            if unique_name != class_name:
                duplicate_name = class_name
                class_name = unique_name
        return ClassName(name=f'{prefix}{class_name}', duplicate_name=duplicate_name)

    def _get_unique_name(self, name: str, camel: bool = False) -> str:
        """Return ``name``, suffixed if needed so it clashes with no known name.

        Known names are the names of all registered references plus
        ``exclude_names``.  Parts are joined with '' when ``camel`` is True
        and with '_' otherwise.
        """
        unique_name: str = name
        count: int = 1
        reference_names = {
            r.name for r in self.references.values()
        } | self.exclude_names
        while unique_name in reference_names:
            if self.duplicate_name_suffix:
                # The first attempt uses the bare suffix: `count - 1` is 0 and
                # falsy parts are dropped by the join below.
                name_parts: List[Union[str, int]] = [
                    name,
                    self.duplicate_name_suffix,
                    count - 1,
                ]
            else:
                name_parts = [name, count]
            delimiter = '' if camel else '_'
            unique_name = delimiter.join(str(p) for p in name_parts if p)
            count += 1
        return unique_name

    @classmethod
    def validate_name(cls, name: str) -> bool:
        """True when ``name`` is an identifier and not a Python keyword."""
        return name.isidentifier() and not iskeyword(name)

    def get_valid_field_name(
        self,
        name: str,
        excludes: Optional[Set[str]] = None,
        model_type: ModelType = ModelType.PYDANTIC,
    ) -> str:
        """Delegate to the resolver registered for ``model_type``."""
        return self.field_name_resolvers[model_type].get_valid_name(name, excludes)

    def get_valid_field_name_and_alias(
        self,
        field_name: str,
        excludes: Optional[Set[str]] = None,
        model_type: ModelType = ModelType.PYDANTIC,
    ) -> Tuple[str, Optional[str]]:
        """Delegate to the resolver registered for ``model_type``."""
        return self.field_name_resolvers[model_type].get_valid_field_name_and_alias(
            field_name, excludes
        )
|
|
|
|
|
|
@lru_cache()
def get_singular_name(name: str, suffix: str = SINGULAR_NAME_SUFFIX) -> str:
    """Return the singular form of ``name``.

    When ``inflect_engine.singular_noun`` returns ``False`` for ``name``, the
    result is ``name`` with ``suffix`` appended.
    """
    singular = inflect_engine.singular_noun(name)
    return f'{name}{suffix}' if singular is False else singular
|
|
|
|
|
|
@lru_cache()
def snake_to_upper_camel(word: str, delimiter: str = '_') -> str:
    """Convert a delimiter-separated word to UpperCamelCase.

    Args:
        word: The text to convert.
        delimiter: Separator between the parts of ``word``; may be longer
            than one character.

    Returns:
        The parts joined with the first character of each upper-cased.  Empty
        parts are skipped.  When ``word`` starts with ``delimiter`` the result
        keeps a single leading ``'_'``.

    Raises:
        ValueError: If ``delimiter`` is empty (raised by ``str.split``).
    """
    prefix = ''
    if word.startswith(delimiter):
        prefix = '_'
        # Strip the whole delimiter, not just one character, so that
        # multi-character delimiters leave no residue in the first part.
        word = word[len(delimiter):]

    return prefix + ''.join(x[0].upper() + x[1:] for x in word.split(delimiter) if x)
|
|
|
|
|
|
def is_url(ref: str) -> bool:
    """Return True when ``ref`` begins with an ``http://`` or ``https://`` scheme."""
    return ref.startswith('https://') or ref.startswith('http://')
|
|
|
|
|
|
# Shared inflect engine; `get_singular_name` calls its `singular_noun` method.
inflect_engine = inflect.engine()
|