Improve union types (#1241)

* Improve union types

* Add unittest

* Add unittest

* Add unittest

* fix edge case

* fix coverage
This commit is contained in:
Koudai Aono
2023-04-12 21:43:55 +09:00
committed by GitHub
parent f21c4a1ae7
commit aef89d3161
8 changed files with 328 additions and 46 deletions

View File

@@ -23,14 +23,25 @@ from warnings import warn
from jinja2 import Environment, FileSystemLoader, Template
from datamodel_code_generator import cached_property
from datamodel_code_generator.imports import IMPORT_ANNOTATED, IMPORT_OPTIONAL, Import
from datamodel_code_generator.imports import (
IMPORT_ANNOTATED,
IMPORT_OPTIONAL,
IMPORT_UNION,
Import,
)
from datamodel_code_generator.reference import Reference, _BaseModel
from datamodel_code_generator.types import DataType, Nullable, chain_as_tuple
from datamodel_code_generator.types import (
ANY,
NONE,
UNION_PREFIX,
DataType,
Nullable,
chain_as_tuple,
get_optional_type,
)
TEMPLATE_DIR: Path = Path(__file__).parents[0] / 'template'
OPTIONAL: str = 'Optional'
ALL_MODEL: str = '#all#'
@@ -68,7 +79,7 @@ class DataModelFieldBase(_BaseModel):
if not TYPE_CHECKING:
def __init__(self, **data: Any):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
if self.data_type.reference or self.data_type.data_types:
self.data_type.parent = self
@@ -83,28 +94,29 @@ class DataModelFieldBase(_BaseModel):
type_hint = self.data_type.type_hint
if not type_hint:
return OPTIONAL
elif self.data_type.is_optional and self.data_type.type != 'Any':
return NONE
elif self.data_type.is_optional and self.data_type.type != ANY:
return type_hint
elif self.nullable is not None:
if self.nullable:
if self.data_type.use_union_operator:
return f'{type_hint} | None'
else:
return f'{OPTIONAL}[{type_hint}]'
return get_optional_type(type_hint, self.data_type.use_union_operator)
return type_hint
elif self.required:
return type_hint
if self.data_type.use_union_operator:
return f'{type_hint} | None'
else:
return f'{OPTIONAL}[{type_hint}]'
return get_optional_type(type_hint, self.data_type.use_union_operator)
@property
def imports(self) -> Tuple[Import, ...]:
type_hint = self.type_hint
has_union = not self.data_type.use_union_operator and UNION_PREFIX in type_hint
imports: List[Union[Tuple[Import], Iterator[Import]]] = [
self.data_type.all_imports
(
i
for i in self.data_type.all_imports
if not (not has_union and i == IMPORT_UNION)
)
]
if (
self.nullable or (self.nullable is None and not self.required)
) and not self.data_type.use_union_operator:

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import re
from abc import ABC, abstractmethod
from enum import Enum, auto
from functools import lru_cache
from itertools import chain
from typing import (
TYPE_CHECKING,
@@ -14,6 +16,7 @@ from typing import (
Iterator,
List,
Optional,
Pattern,
Sequence,
Set,
Tuple,
@@ -45,6 +48,26 @@ from datamodel_code_generator.reference import Reference, _BaseModel
T = TypeVar('T')
OPTIONAL = 'Optional'
OPTIONAL_PREFIX = f'{OPTIONAL}['
UNION = 'Union'
UNION_PREFIX = f'{UNION}['
UNION_DELIMITER = ', '
UNION_PATTERN: Pattern[str] = re.compile(r'\s*,\s*')
UNION_OPERATOR_DELIMITER = ' | '
UNION_OPERATOR_PATTERN: Pattern[str] = re.compile(r'\s*\|\s*')
NONE = 'None'
ANY = 'Any'
LITERAL = 'Literal'
SEQUENCE = 'Sequence'
MAPPING = 'Mapping'
DICT = 'Dict'
LIST = 'List'
STANDARD_DICT = 'dict'
STANDARD_LIST = 'list'
STR = 'str'
class StrictTypes(Enum):
str = 'str'
@@ -79,6 +102,63 @@ def chain_as_tuple(*iterables: Iterable[T]) -> Tuple[T, ...]:
return tuple(chain(*iterables))
@lru_cache()
def _remove_none_from_type(
type_: str, split_pattern: Pattern[str], delimiter: str
) -> List[str]:
types: List[str] = []
split_type: str = ''
inner_count: int = 0
for part in re.split(split_pattern, type_):
if part == NONE:
continue
inner_count += part.count('[') - part.count(']')
if split_type:
split_type += delimiter
if inner_count == 0:
if split_type:
types.append(f'{split_type}{part}')
else:
types.append(part)
split_type = ''
continue
else:
split_type += part
return types
def _remove_none_from_union(type_: str, use_union_operator: bool) -> str:
if use_union_operator:
if not re.match(r'^\w+ | ', type_):
return type_
return UNION_OPERATOR_DELIMITER.join(
_remove_none_from_type(
type_, UNION_OPERATOR_PATTERN, UNION_OPERATOR_DELIMITER
)
)
if not type_.startswith(UNION_PREFIX):
return type_
inner_types = _remove_none_from_type(
type_[len(UNION_PREFIX) :][:-1], UNION_PATTERN, UNION_DELIMITER
)
if len(inner_types) == 1:
return inner_types[0]
return f'{UNION_PREFIX}{UNION_DELIMITER.join(inner_types)}]'
@lru_cache()
def get_optional_type(type_: str, use_union_operator: bool) -> str:
type_ = _remove_none_from_union(type_, use_union_operator)
if not type_ or type_ == NONE:
return NONE
if use_union_operator:
return f'{type_} | {NONE}'
return f'{OPTIONAL_PREFIX}{type_}]'
@runtime_checkable
class Modular(Protocol):
@property
@@ -254,15 +334,13 @@ class DataType(_BaseModel):
super().__init__(**values)
for type_ in self.data_types:
if type_.type == 'Any' and type_.is_optional:
if any(
t for t in self.data_types if t.type != 'Any'
): # pragma: no cover
if type_.type == ANY and type_.is_optional:
if any(t for t in self.data_types if t.type != ANY): # pragma: no cover
self.is_optional = True
self.data_types = [
t
for t in self.data_types
if not (t.type == 'Any' and t.is_optional)
if not (t.type == ANY and t.is_optional)
]
break
@@ -278,18 +356,20 @@ class DataType(_BaseModel):
type_: Optional[str] = self.alias or self.type
if not type_:
if self.is_union:
data_types: List[str] = []
for data_type in self.data_types:
data_type_type = data_type.type_hint
if data_type_type in data_types: # pragma: no cover
continue
data_types.append(data_type_type)
if self.use_union_operator:
type_ = ' | '.join(
data_type.type_hint for data_type in self.data_types
)
type_ = UNION_OPERATOR_DELIMITER.join(data_types)
else:
type_ = f"Union[{', '.join(data_type.type_hint for data_type in self.data_types)}]"
type_ = f'{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]'
elif len(self.data_types) == 1:
type_ = self.data_types[0].type_hint
elif self.literals:
type_ = (
f"Literal[{', '.join(repr(literal) for literal in self.literals)}]"
)
type_ = f"{LITERAL}[{', '.join(repr(literal) for literal in self.literals)}]"
else:
if self.reference:
type_ = self.reference.short_name
@@ -305,29 +385,26 @@ class DataType(_BaseModel):
type_ = f"'{type_}'"
if self.is_list:
if self.use_generic_container:
list_ = 'Sequence'
list_ = SEQUENCE
elif self.use_standard_collections:
list_ = 'list'
list_ = STANDARD_LIST
else:
list_ = 'List'
list_ = LIST
type_ = f'{list_}[{type_}]' if type_ else list_
elif self.is_dict:
if self.use_generic_container:
dict_ = 'Mapping'
dict_ = MAPPING
elif self.use_standard_collections:
dict_ = 'dict'
dict_ = STANDARD_DICT
else:
dict_ = 'Dict'
dict_ = DICT
if self.dict_key or type_:
key = self.dict_key.type_hint if self.dict_key else 'str'
type_ = f'{dict_}[{key}, {type_ or "Any"}]'
key = self.dict_key.type_hint if self.dict_key else STR
type_ = f'{dict_}[{key}, {type_ or ANY}]'
else: # pragma: no cover
type_ = dict_
if self.is_optional and type_ != 'Any':
if self.use_union_operator: # pragma: no cover
type_ = f'{type_} | None'
else:
type_ = f'Optional[{type_}]'
if self.is_optional and type_ != ANY:
return get_optional_type(type_, self.use_union_operator)
elif self.is_func:
if self.kwargs:
kwargs: str = ', '.join(f'{k}={v}' for k, v in self.kwargs.items())

View File

@@ -0,0 +1,31 @@
# generated by datamodel-codegen:
# filename: nullable_any_of.json
# timestamp: 2019-07-26T00:00:00+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Extra, Field
class ConfigItem(BaseModel):
__root__: str = Field(..., description='d2', min_length=1, title='t2')
class In(BaseModel):
class Config:
extra = Extra.forbid
input_dataset_path: Optional[str] = Field(
None, description='d1', min_length=1, title='Path to the input dataset'
)
config: Optional[ConfigItem] = None
class ValidatingSchemaId1(BaseModel):
class Config:
extra = Extra.forbid
in_: Optional[In] = Field(None, alias='in')
n1: Optional[int] = None

View File

@@ -0,0 +1,29 @@
# generated by datamodel-codegen:
# filename: nullable_any_of.json
# timestamp: 2019-07-26T00:00:00+00:00
from __future__ import annotations
from pydantic import BaseModel, Extra, Field
class ConfigItem(BaseModel):
__root__: str = Field(..., description='d2', min_length=1, title='t2')
class In(BaseModel):
class Config:
extra = Extra.forbid
input_dataset_path: str | None = Field(
None, description='d1', min_length=1, title='Path to the input dataset'
)
config: ConfigItem | None = None
class ValidatingSchemaId1(BaseModel):
class Config:
extra = Extra.forbid
in_: In | None = Field(None, alias='in')
n1: int | None = None

View File

@@ -0,0 +1,41 @@
{
"type": "object",
"additionalProperties": false,
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "id1",
"title": "Validating Schema ID1",
"properties": {
"in": {
"type": "object",
"additionalProperties": false,
"properties": {
"input_dataset_path": {
"type": "string",
"minLength": 1,
"title": "Path to the input dataset",
"description": "d1"
},
"config": {
"anyOf": [
{
"type": "string",
"minLength": 1,
"title": "t2",
"description": "d2"
},
{
"type": [
"null"
],
"title": "t3",
"description": "d3"
}
]
}
}
},
"n1": {
"type": "integer"
}
}
}

View File

@@ -14,7 +14,7 @@ from datamodel_code_generator.types import DataType, Types
class A(TemplateBase):
def __init__(self, path: Path):
def __init__(self, path: Path) -> None:
self._path = path
@property
@@ -30,7 +30,7 @@ class B(DataModel):
def get_data_type(cls, types: Types, **kwargs: Any) -> DataType:
pass
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
TEMPLATE_FILE_PATH = ''
@@ -135,7 +135,7 @@ def test_data_field():
)
assert field.type_hint == 'List'
field = DataModelFieldBase(name='a', data_type=DataType(), required=False)
assert field.type_hint == 'Optional'
assert field.type_hint == 'None'
field = DataModelFieldBase(
name='a',
data_type=DataType(is_list=True),
@@ -147,11 +147,11 @@ def test_data_field():
field = DataModelFieldBase(
name='a', data_type=DataType(), required=False, is_list=False, is_union=True
)
assert field.type_hint == 'Optional'
assert field.type_hint == 'None'
field = DataModelFieldBase(
name='a', data_type=DataType(), required=False, is_list=False, is_union=False
)
assert field.type_hint == 'Optional'
assert field.type_hint == 'None'
field = DataModelFieldBase(
name='a',
data_type=DataType(is_list=True),

View File

@@ -5054,3 +5054,54 @@ def test_main_multiple_required_any_of():
with pytest.raises(SystemExit):
main()
@freeze_time('2019-07-26')
def test_main_nullable_any_of():
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(JSON_SCHEMA_DATA_PATH / 'nullable_any_of.json'),
'--output',
str(output_file),
'--field-constraints',
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (EXPECTED_MAIN_PATH / 'main_nullable_any_of' / 'output.py').read_text()
)
with pytest.raises(SystemExit):
main()
@freeze_time('2019-07-26')
def test_main_nullable_any_of_use_union_operator():
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(JSON_SCHEMA_DATA_PATH / 'nullable_any_of.json'),
'--output',
str(output_file),
'--field-constraints',
'--use-union-operator',
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (
EXPECTED_MAIN_PATH
/ 'main_nullable_any_of_use_union_operator'
/ 'output.py'
).read_text()
)
with pytest.raises(SystemExit):
main()

41
tests/test_types.py Normal file
View File

@@ -0,0 +1,41 @@
import pytest
from datamodel_code_generator.types import get_optional_type
@pytest.mark.parametrize(
'input_,use_union_operator,expected',
[
('List[str]', False, 'Optional[List[str]]'),
('List[str, int, float]', False, 'Optional[List[str, int, float]]'),
('List[str, int, None]', False, 'Optional[List[str, int, None]]'),
('Union[str]', False, 'Optional[str]'),
('Union[str, int, float]', False, 'Optional[Union[str, int, float]]'),
('Union[str, int, None]', False, 'Optional[Union[str, int]]'),
('Union[str, int, None, None]', False, 'Optional[Union[str, int]]'),
(
'Union[str, int, List[str, int, None], None]',
False,
'Optional[Union[str, int, List[str, int, None]]]',
),
(
'Union[str, int, List[str, Dict[int, str | None]], None]',
False,
'Optional[Union[str, int, List[str, Dict[int, str | None]]]]',
),
('List[str]', True, 'List[str] | None'),
('List[str | int | float]', True, 'List[str | int | float] | None'),
('List[str | int | None]', True, 'List[str | int | None] | None'),
('str', True, 'str | None'),
('str | int | float', True, 'str | int | float | None'),
('str | int | None', True, 'str | int | None'),
('str | int | None | None', True, 'str | int | None'),
(
'str | int | List[str | Dict[int | Union[str | None]]] | None',
True,
'str | int | List[str | Dict[int | Union[str | None]]] | None',
),
],
)
def test_get_optional_type(input_: str, use_union_operator: bool, expected: str):
assert get_optional_type(input_, use_union_operator) == expected