mirror of
https://github.com/koxudaxi/datamodel-code-generator.git
synced 2024-03-18 14:54:37 +03:00
Improve union types (#1241)
* Improve union types * Add unittest * Add unittest * Add unittest * fix edge case * fix coverage
This commit is contained in:
@@ -23,14 +23,25 @@ from warnings import warn
|
||||
from jinja2 import Environment, FileSystemLoader, Template
|
||||
|
||||
from datamodel_code_generator import cached_property
|
||||
from datamodel_code_generator.imports import IMPORT_ANNOTATED, IMPORT_OPTIONAL, Import
|
||||
from datamodel_code_generator.imports import (
|
||||
IMPORT_ANNOTATED,
|
||||
IMPORT_OPTIONAL,
|
||||
IMPORT_UNION,
|
||||
Import,
|
||||
)
|
||||
from datamodel_code_generator.reference import Reference, _BaseModel
|
||||
from datamodel_code_generator.types import DataType, Nullable, chain_as_tuple
|
||||
from datamodel_code_generator.types import (
|
||||
ANY,
|
||||
NONE,
|
||||
UNION_PREFIX,
|
||||
DataType,
|
||||
Nullable,
|
||||
chain_as_tuple,
|
||||
get_optional_type,
|
||||
)
|
||||
|
||||
TEMPLATE_DIR: Path = Path(__file__).parents[0] / 'template'
|
||||
|
||||
OPTIONAL: str = 'Optional'
|
||||
|
||||
ALL_MODEL: str = '#all#'
|
||||
|
||||
|
||||
@@ -68,7 +79,7 @@ class DataModelFieldBase(_BaseModel):
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
if self.data_type.reference or self.data_type.data_types:
|
||||
self.data_type.parent = self
|
||||
@@ -83,28 +94,29 @@ class DataModelFieldBase(_BaseModel):
|
||||
type_hint = self.data_type.type_hint
|
||||
|
||||
if not type_hint:
|
||||
return OPTIONAL
|
||||
elif self.data_type.is_optional and self.data_type.type != 'Any':
|
||||
return NONE
|
||||
elif self.data_type.is_optional and self.data_type.type != ANY:
|
||||
return type_hint
|
||||
elif self.nullable is not None:
|
||||
if self.nullable:
|
||||
if self.data_type.use_union_operator:
|
||||
return f'{type_hint} | None'
|
||||
else:
|
||||
return f'{OPTIONAL}[{type_hint}]'
|
||||
return get_optional_type(type_hint, self.data_type.use_union_operator)
|
||||
return type_hint
|
||||
elif self.required:
|
||||
return type_hint
|
||||
if self.data_type.use_union_operator:
|
||||
return f'{type_hint} | None'
|
||||
else:
|
||||
return f'{OPTIONAL}[{type_hint}]'
|
||||
return get_optional_type(type_hint, self.data_type.use_union_operator)
|
||||
|
||||
@property
|
||||
def imports(self) -> Tuple[Import, ...]:
|
||||
type_hint = self.type_hint
|
||||
has_union = not self.data_type.use_union_operator and UNION_PREFIX in type_hint
|
||||
imports: List[Union[Tuple[Import], Iterator[Import]]] = [
|
||||
self.data_type.all_imports
|
||||
(
|
||||
i
|
||||
for i in self.data_type.all_imports
|
||||
if not (not has_union and i == IMPORT_UNION)
|
||||
)
|
||||
]
|
||||
|
||||
if (
|
||||
self.nullable or (self.nullable is None and not self.required)
|
||||
) and not self.data_type.use_union_operator:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, auto
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -14,6 +16,7 @@ from typing import (
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Pattern,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
@@ -45,6 +48,26 @@ from datamodel_code_generator.reference import Reference, _BaseModel
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
OPTIONAL = 'Optional'
|
||||
OPTIONAL_PREFIX = f'{OPTIONAL}['
|
||||
|
||||
UNION = 'Union'
|
||||
UNION_PREFIX = f'{UNION}['
|
||||
UNION_DELIMITER = ', '
|
||||
UNION_PATTERN: Pattern[str] = re.compile(r'\s*,\s*')
|
||||
UNION_OPERATOR_DELIMITER = ' | '
|
||||
UNION_OPERATOR_PATTERN: Pattern[str] = re.compile(r'\s*\|\s*')
|
||||
NONE = 'None'
|
||||
ANY = 'Any'
|
||||
LITERAL = 'Literal'
|
||||
SEQUENCE = 'Sequence'
|
||||
MAPPING = 'Mapping'
|
||||
DICT = 'Dict'
|
||||
LIST = 'List'
|
||||
STANDARD_DICT = 'dict'
|
||||
STANDARD_LIST = 'list'
|
||||
STR = 'str'
|
||||
|
||||
|
||||
class StrictTypes(Enum):
|
||||
str = 'str'
|
||||
@@ -79,6 +102,63 @@ def chain_as_tuple(*iterables: Iterable[T]) -> Tuple[T, ...]:
|
||||
return tuple(chain(*iterables))
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _remove_none_from_type(
|
||||
type_: str, split_pattern: Pattern[str], delimiter: str
|
||||
) -> List[str]:
|
||||
types: List[str] = []
|
||||
split_type: str = ''
|
||||
inner_count: int = 0
|
||||
for part in re.split(split_pattern, type_):
|
||||
if part == NONE:
|
||||
continue
|
||||
inner_count += part.count('[') - part.count(']')
|
||||
if split_type:
|
||||
split_type += delimiter
|
||||
if inner_count == 0:
|
||||
if split_type:
|
||||
types.append(f'{split_type}{part}')
|
||||
else:
|
||||
types.append(part)
|
||||
split_type = ''
|
||||
continue
|
||||
else:
|
||||
split_type += part
|
||||
return types
|
||||
|
||||
|
||||
def _remove_none_from_union(type_: str, use_union_operator: bool) -> str:
|
||||
if use_union_operator:
|
||||
if not re.match(r'^\w+ | ', type_):
|
||||
return type_
|
||||
return UNION_OPERATOR_DELIMITER.join(
|
||||
_remove_none_from_type(
|
||||
type_, UNION_OPERATOR_PATTERN, UNION_OPERATOR_DELIMITER
|
||||
)
|
||||
)
|
||||
|
||||
if not type_.startswith(UNION_PREFIX):
|
||||
return type_
|
||||
inner_types = _remove_none_from_type(
|
||||
type_[len(UNION_PREFIX) :][:-1], UNION_PATTERN, UNION_DELIMITER
|
||||
)
|
||||
|
||||
if len(inner_types) == 1:
|
||||
return inner_types[0]
|
||||
return f'{UNION_PREFIX}{UNION_DELIMITER.join(inner_types)}]'
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_optional_type(type_: str, use_union_operator: bool) -> str:
|
||||
type_ = _remove_none_from_union(type_, use_union_operator)
|
||||
|
||||
if not type_ or type_ == NONE:
|
||||
return NONE
|
||||
if use_union_operator:
|
||||
return f'{type_} | {NONE}'
|
||||
return f'{OPTIONAL_PREFIX}{type_}]'
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Modular(Protocol):
|
||||
@property
|
||||
@@ -254,15 +334,13 @@ class DataType(_BaseModel):
|
||||
super().__init__(**values)
|
||||
|
||||
for type_ in self.data_types:
|
||||
if type_.type == 'Any' and type_.is_optional:
|
||||
if any(
|
||||
t for t in self.data_types if t.type != 'Any'
|
||||
): # pragma: no cover
|
||||
if type_.type == ANY and type_.is_optional:
|
||||
if any(t for t in self.data_types if t.type != ANY): # pragma: no cover
|
||||
self.is_optional = True
|
||||
self.data_types = [
|
||||
t
|
||||
for t in self.data_types
|
||||
if not (t.type == 'Any' and t.is_optional)
|
||||
if not (t.type == ANY and t.is_optional)
|
||||
]
|
||||
break
|
||||
|
||||
@@ -278,18 +356,20 @@ class DataType(_BaseModel):
|
||||
type_: Optional[str] = self.alias or self.type
|
||||
if not type_:
|
||||
if self.is_union:
|
||||
data_types: List[str] = []
|
||||
for data_type in self.data_types:
|
||||
data_type_type = data_type.type_hint
|
||||
if data_type_type in data_types: # pragma: no cover
|
||||
continue
|
||||
data_types.append(data_type_type)
|
||||
if self.use_union_operator:
|
||||
type_ = ' | '.join(
|
||||
data_type.type_hint for data_type in self.data_types
|
||||
)
|
||||
type_ = UNION_OPERATOR_DELIMITER.join(data_types)
|
||||
else:
|
||||
type_ = f"Union[{', '.join(data_type.type_hint for data_type in self.data_types)}]"
|
||||
type_ = f'{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]'
|
||||
elif len(self.data_types) == 1:
|
||||
type_ = self.data_types[0].type_hint
|
||||
elif self.literals:
|
||||
type_ = (
|
||||
f"Literal[{', '.join(repr(literal) for literal in self.literals)}]"
|
||||
)
|
||||
type_ = f"{LITERAL}[{', '.join(repr(literal) for literal in self.literals)}]"
|
||||
else:
|
||||
if self.reference:
|
||||
type_ = self.reference.short_name
|
||||
@@ -305,29 +385,26 @@ class DataType(_BaseModel):
|
||||
type_ = f"'{type_}'"
|
||||
if self.is_list:
|
||||
if self.use_generic_container:
|
||||
list_ = 'Sequence'
|
||||
list_ = SEQUENCE
|
||||
elif self.use_standard_collections:
|
||||
list_ = 'list'
|
||||
list_ = STANDARD_LIST
|
||||
else:
|
||||
list_ = 'List'
|
||||
list_ = LIST
|
||||
type_ = f'{list_}[{type_}]' if type_ else list_
|
||||
elif self.is_dict:
|
||||
if self.use_generic_container:
|
||||
dict_ = 'Mapping'
|
||||
dict_ = MAPPING
|
||||
elif self.use_standard_collections:
|
||||
dict_ = 'dict'
|
||||
dict_ = STANDARD_DICT
|
||||
else:
|
||||
dict_ = 'Dict'
|
||||
dict_ = DICT
|
||||
if self.dict_key or type_:
|
||||
key = self.dict_key.type_hint if self.dict_key else 'str'
|
||||
type_ = f'{dict_}[{key}, {type_ or "Any"}]'
|
||||
key = self.dict_key.type_hint if self.dict_key else STR
|
||||
type_ = f'{dict_}[{key}, {type_ or ANY}]'
|
||||
else: # pragma: no cover
|
||||
type_ = dict_
|
||||
if self.is_optional and type_ != 'Any':
|
||||
if self.use_union_operator: # pragma: no cover
|
||||
type_ = f'{type_} | None'
|
||||
else:
|
||||
type_ = f'Optional[{type_}]'
|
||||
if self.is_optional and type_ != ANY:
|
||||
return get_optional_type(type_, self.use_union_operator)
|
||||
elif self.is_func:
|
||||
if self.kwargs:
|
||||
kwargs: str = ', '.join(f'{k}={v}' for k, v in self.kwargs.items())
|
||||
|
||||
31
tests/data/expected/main/main_nullable_any_of/output.py
Normal file
31
tests/data/expected/main/main_nullable_any_of/output.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: nullable_any_of.json
|
||||
# timestamp: 2019-07-26T00:00:00+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
|
||||
|
||||
class ConfigItem(BaseModel):
|
||||
__root__: str = Field(..., description='d2', min_length=1, title='t2')
|
||||
|
||||
|
||||
class In(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
input_dataset_path: Optional[str] = Field(
|
||||
None, description='d1', min_length=1, title='Path to the input dataset'
|
||||
)
|
||||
config: Optional[ConfigItem] = None
|
||||
|
||||
|
||||
class ValidatingSchemaId1(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
in_: Optional[In] = Field(None, alias='in')
|
||||
n1: Optional[int] = None
|
||||
@@ -0,0 +1,29 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: nullable_any_of.json
|
||||
# timestamp: 2019-07-26T00:00:00+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
|
||||
|
||||
class ConfigItem(BaseModel):
|
||||
__root__: str = Field(..., description='d2', min_length=1, title='t2')
|
||||
|
||||
|
||||
class In(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
input_dataset_path: str | None = Field(
|
||||
None, description='d1', min_length=1, title='Path to the input dataset'
|
||||
)
|
||||
config: ConfigItem | None = None
|
||||
|
||||
|
||||
class ValidatingSchemaId1(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.forbid
|
||||
|
||||
in_: In | None = Field(None, alias='in')
|
||||
n1: int | None = None
|
||||
41
tests/data/jsonschema/nullable_any_of.json
Normal file
41
tests/data/jsonschema/nullable_any_of.json
Normal file
@@ -0,0 +1,41 @@
|
||||
{
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||
"$id": "id1",
|
||||
"title": "Validating Schema ID1",
|
||||
"properties": {
|
||||
"in": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"input_dataset_path": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "Path to the input dataset",
|
||||
"description": "d1"
|
||||
},
|
||||
"config": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "t2",
|
||||
"description": "d2"
|
||||
},
|
||||
{
|
||||
"type": [
|
||||
"null"
|
||||
],
|
||||
"title": "t3",
|
||||
"description": "d3"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"n1": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ from datamodel_code_generator.types import DataType, Types
|
||||
|
||||
|
||||
class A(TemplateBase):
|
||||
def __init__(self, path: Path):
|
||||
def __init__(self, path: Path) -> None:
|
||||
self._path = path
|
||||
|
||||
@property
|
||||
@@ -30,7 +30,7 @@ class B(DataModel):
|
||||
def get_data_type(cls, types: Types, **kwargs: Any) -> DataType:
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
TEMPLATE_FILE_PATH = ''
|
||||
@@ -135,7 +135,7 @@ def test_data_field():
|
||||
)
|
||||
assert field.type_hint == 'List'
|
||||
field = DataModelFieldBase(name='a', data_type=DataType(), required=False)
|
||||
assert field.type_hint == 'Optional'
|
||||
assert field.type_hint == 'None'
|
||||
field = DataModelFieldBase(
|
||||
name='a',
|
||||
data_type=DataType(is_list=True),
|
||||
@@ -147,11 +147,11 @@ def test_data_field():
|
||||
field = DataModelFieldBase(
|
||||
name='a', data_type=DataType(), required=False, is_list=False, is_union=True
|
||||
)
|
||||
assert field.type_hint == 'Optional'
|
||||
assert field.type_hint == 'None'
|
||||
field = DataModelFieldBase(
|
||||
name='a', data_type=DataType(), required=False, is_list=False, is_union=False
|
||||
)
|
||||
assert field.type_hint == 'Optional'
|
||||
assert field.type_hint == 'None'
|
||||
field = DataModelFieldBase(
|
||||
name='a',
|
||||
data_type=DataType(is_list=True),
|
||||
|
||||
@@ -5054,3 +5054,54 @@ def test_main_multiple_required_any_of():
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
main()
|
||||
|
||||
|
||||
@freeze_time('2019-07-26')
|
||||
def test_main_nullable_any_of():
|
||||
with TemporaryDirectory() as output_dir:
|
||||
output_file: Path = Path(output_dir) / 'output.py'
|
||||
return_code: Exit = main(
|
||||
[
|
||||
'--input',
|
||||
str(JSON_SCHEMA_DATA_PATH / 'nullable_any_of.json'),
|
||||
'--output',
|
||||
str(output_file),
|
||||
'--field-constraints',
|
||||
]
|
||||
)
|
||||
assert return_code == Exit.OK
|
||||
assert (
|
||||
output_file.read_text()
|
||||
== (EXPECTED_MAIN_PATH / 'main_nullable_any_of' / 'output.py').read_text()
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
main()
|
||||
|
||||
|
||||
@freeze_time('2019-07-26')
|
||||
def test_main_nullable_any_of_use_union_operator():
|
||||
with TemporaryDirectory() as output_dir:
|
||||
output_file: Path = Path(output_dir) / 'output.py'
|
||||
return_code: Exit = main(
|
||||
[
|
||||
'--input',
|
||||
str(JSON_SCHEMA_DATA_PATH / 'nullable_any_of.json'),
|
||||
'--output',
|
||||
str(output_file),
|
||||
'--field-constraints',
|
||||
'--use-union-operator',
|
||||
]
|
||||
)
|
||||
assert return_code == Exit.OK
|
||||
assert (
|
||||
output_file.read_text()
|
||||
== (
|
||||
EXPECTED_MAIN_PATH
|
||||
/ 'main_nullable_any_of_use_union_operator'
|
||||
/ 'output.py'
|
||||
).read_text()
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
main()
|
||||
|
||||
41
tests/test_types.py
Normal file
41
tests/test_types.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
|
||||
from datamodel_code_generator.types import get_optional_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'input_,use_union_operator,expected',
|
||||
[
|
||||
('List[str]', False, 'Optional[List[str]]'),
|
||||
('List[str, int, float]', False, 'Optional[List[str, int, float]]'),
|
||||
('List[str, int, None]', False, 'Optional[List[str, int, None]]'),
|
||||
('Union[str]', False, 'Optional[str]'),
|
||||
('Union[str, int, float]', False, 'Optional[Union[str, int, float]]'),
|
||||
('Union[str, int, None]', False, 'Optional[Union[str, int]]'),
|
||||
('Union[str, int, None, None]', False, 'Optional[Union[str, int]]'),
|
||||
(
|
||||
'Union[str, int, List[str, int, None], None]',
|
||||
False,
|
||||
'Optional[Union[str, int, List[str, int, None]]]',
|
||||
),
|
||||
(
|
||||
'Union[str, int, List[str, Dict[int, str | None]], None]',
|
||||
False,
|
||||
'Optional[Union[str, int, List[str, Dict[int, str | None]]]]',
|
||||
),
|
||||
('List[str]', True, 'List[str] | None'),
|
||||
('List[str | int | float]', True, 'List[str | int | float] | None'),
|
||||
('List[str | int | None]', True, 'List[str | int | None] | None'),
|
||||
('str', True, 'str | None'),
|
||||
('str | int | float', True, 'str | int | float | None'),
|
||||
('str | int | None', True, 'str | int | None'),
|
||||
('str | int | None | None', True, 'str | int | None'),
|
||||
(
|
||||
'str | int | List[str | Dict[int | Union[str | None]]] | None',
|
||||
True,
|
||||
'str | int | List[str | Dict[int | Union[str | None]]] | None',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_optional_type(input_: str, use_union_operator: bool, expected: str):
|
||||
assert get_optional_type(input_, use_union_operator) == expected
|
||||
Reference in New Issue
Block a user