From aef89d31616f787e428de8ab0511df39ecfa1af6 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Wed, 12 Apr 2023 21:43:55 +0900 Subject: [PATCH] Improve union types (#1241) * Improve union types * Add unittest * Add unittest * Add unittest * fix edge case * fix coverage --- datamodel_code_generator/model/base.py | 44 +++--- datamodel_code_generator/types.py | 127 ++++++++++++++---- .../main/main_nullable_any_of/output.py | 31 +++++ .../output.py | 29 ++++ tests/data/jsonschema/nullable_any_of.json | 41 ++++++ tests/model/test_base.py | 10 +- tests/test_main.py | 51 +++++++ tests/test_types.py | 41 ++++++ 8 files changed, 328 insertions(+), 46 deletions(-) create mode 100644 tests/data/expected/main/main_nullable_any_of/output.py create mode 100644 tests/data/expected/main/main_nullable_any_of_use_union_operator/output.py create mode 100644 tests/data/jsonschema/nullable_any_of.json create mode 100644 tests/test_types.py diff --git a/datamodel_code_generator/model/base.py b/datamodel_code_generator/model/base.py index 2b13b8a1..816ef8ac 100644 --- a/datamodel_code_generator/model/base.py +++ b/datamodel_code_generator/model/base.py @@ -23,14 +23,25 @@ from warnings import warn from jinja2 import Environment, FileSystemLoader, Template from datamodel_code_generator import cached_property -from datamodel_code_generator.imports import IMPORT_ANNOTATED, IMPORT_OPTIONAL, Import +from datamodel_code_generator.imports import ( + IMPORT_ANNOTATED, + IMPORT_OPTIONAL, + IMPORT_UNION, + Import, +) from datamodel_code_generator.reference import Reference, _BaseModel -from datamodel_code_generator.types import DataType, Nullable, chain_as_tuple +from datamodel_code_generator.types import ( + ANY, + NONE, + UNION_PREFIX, + DataType, + Nullable, + chain_as_tuple, + get_optional_type, +) TEMPLATE_DIR: Path = Path(__file__).parents[0] / 'template' -OPTIONAL: str = 'Optional' - ALL_MODEL: str = '#all#' @@ -68,7 +79,7 @@ class DataModelFieldBase(_BaseModel): if not TYPE_CHECKING: - def __init__(self, **data: Any): + def __init__(self, **data: Any) -> None: super().__init__(**data) if self.data_type.reference or self.data_type.data_types: self.data_type.parent = self @@ -83,28 +94,29 @@ class DataModelFieldBase(_BaseModel): type_hint = self.data_type.type_hint if not type_hint: - return OPTIONAL - elif self.data_type.is_optional and self.data_type.type != 'Any': + return NONE + elif self.data_type.is_optional and self.data_type.type != ANY: return type_hint elif self.nullable is not None: if self.nullable: - if self.data_type.use_union_operator: - return f'{type_hint} | None' - else: - return f'{OPTIONAL}[{type_hint}]' + return get_optional_type(type_hint, self.data_type.use_union_operator) return type_hint elif self.required: return type_hint - if self.data_type.use_union_operator: - return f'{type_hint} | None' - else: - return f'{OPTIONAL}[{type_hint}]' + return get_optional_type(type_hint, self.data_type.use_union_operator) @property def imports(self) -> Tuple[Import, ...]: + type_hint = self.type_hint + has_union = not self.data_type.use_union_operator and UNION_PREFIX in type_hint imports: List[Union[Tuple[Import], Iterator[Import]]] = [ - self.data_type.all_imports + ( + i + for i in self.data_type.all_imports + if not (not has_union and i == IMPORT_UNION) + ) ] + if ( self.nullable or (self.nullable is None and not self.required) ) and not self.data_type.use_union_operator: diff --git a/datamodel_code_generator/types.py b/datamodel_code_generator/types.py index ef49db11..69459736 100644 --- a/datamodel_code_generator/types.py +++ b/datamodel_code_generator/types.py @@ -1,7 +1,9 @@ from __future__ import annotations +import re from abc import ABC, abstractmethod from enum import Enum, auto +from functools import lru_cache from itertools import chain from typing import ( TYPE_CHECKING, @@ -14,6 +16,7 @@ from typing import ( Iterator, List, Optional, + Pattern, Sequence, Set, Tuple, @@ -45,6 +48,26 @@ from datamodel_code_generator.reference import Reference, _BaseModel T = TypeVar('T') +OPTIONAL = 'Optional' +OPTIONAL_PREFIX = f'{OPTIONAL}[' + +UNION = 'Union' +UNION_PREFIX = f'{UNION}[' +UNION_DELIMITER = ', ' +UNION_PATTERN: Pattern[str] = re.compile(r'\s*,\s*') +UNION_OPERATOR_DELIMITER = ' | ' +UNION_OPERATOR_PATTERN: Pattern[str] = re.compile(r'\s*\|\s*') +NONE = 'None' +ANY = 'Any' +LITERAL = 'Literal' +SEQUENCE = 'Sequence' +MAPPING = 'Mapping' +DICT = 'Dict' +LIST = 'List' +STANDARD_DICT = 'dict' +STANDARD_LIST = 'list' +STR = 'str' + class StrictTypes(Enum): str = 'str' @@ -79,6 +102,63 @@ def chain_as_tuple(*iterables: Iterable[T]) -> Tuple[T, ...]: return tuple(chain(*iterables)) +@lru_cache() +def _remove_none_from_type( + type_: str, split_pattern: Pattern[str], delimiter: str +) -> List[str]: + types: List[str] = [] + split_type: str = '' + inner_count: int = 0 + for part in re.split(split_pattern, type_): + if part == NONE: + continue + inner_count += part.count('[') - part.count(']') + if split_type: + split_type += delimiter + if inner_count == 0: + if split_type: + types.append(f'{split_type}{part}') + else: + types.append(part) + split_type = '' + continue + else: + split_type += part + return types + + +def _remove_none_from_union(type_: str, use_union_operator: bool) -> str: + if use_union_operator: + if not re.match(r'^\w+ | ', type_): + return type_ + return UNION_OPERATOR_DELIMITER.join( + _remove_none_from_type( + type_, UNION_OPERATOR_PATTERN, UNION_OPERATOR_DELIMITER + ) + ) + + if not type_.startswith(UNION_PREFIX): + return type_ + inner_types = _remove_none_from_type( + type_[len(UNION_PREFIX) :][:-1], UNION_PATTERN, UNION_DELIMITER + ) + + if len(inner_types) == 1: + return inner_types[0] + return f'{UNION_PREFIX}{UNION_DELIMITER.join(inner_types)}]' + + +@lru_cache() +def get_optional_type(type_: str, use_union_operator: bool) -> str: + type_ = _remove_none_from_union(type_, use_union_operator) + + if not type_ or type_ == NONE: + return NONE + if use_union_operator: + return f'{type_} | {NONE}' + return f'{OPTIONAL_PREFIX}{type_}]' + + @runtime_checkable class Modular(Protocol): @property @@ -254,15 +334,13 @@ class DataType(_BaseModel): super().__init__(**values) for type_ in self.data_types: - if type_.type == 'Any' and type_.is_optional: - if any( - t for t in self.data_types if t.type != 'Any' - ): # pragma: no cover + if type_.type == ANY and type_.is_optional: + if any(t for t in self.data_types if t.type != ANY): # pragma: no cover self.is_optional = True self.data_types = [ t for t in self.data_types - if not (t.type == 'Any' and t.is_optional) + if not (t.type == ANY and t.is_optional) ] break @@ -278,18 +356,20 @@ class DataType(_BaseModel): type_: Optional[str] = self.alias or self.type if not type_: if self.is_union: + data_types: List[str] = [] + for data_type in self.data_types: + data_type_type = data_type.type_hint + if data_type_type in data_types: # pragma: no cover + continue + data_types.append(data_type_type) if self.use_union_operator: - type_ = ' | '.join( - data_type.type_hint for data_type in self.data_types - ) + type_ = UNION_OPERATOR_DELIMITER.join(data_types) else: - type_ = f"Union[{', '.join(data_type.type_hint for data_type in self.data_types)}]" + type_ = f'{UNION_PREFIX}{UNION_DELIMITER.join(data_types)}]' elif len(self.data_types) == 1: type_ = self.data_types[0].type_hint elif self.literals: - type_ = ( - f"Literal[{', '.join(repr(literal) for literal in self.literals)}]" - ) + type_ = f"{LITERAL}[{', '.join(repr(literal) for literal in self.literals)}]" else: if self.reference: type_ = self.reference.short_name @@ -305,29 +385,26 @@ class DataType(_BaseModel): type_ = f"'{type_}'" if self.is_list: if self.use_generic_container: - list_ = 'Sequence' + list_ = SEQUENCE elif self.use_standard_collections: - list_ = 'list' + list_ = STANDARD_LIST else: - list_ = 'List' + list_ = LIST type_ = f'{list_}[{type_}]' if type_ else list_ elif self.is_dict: if self.use_generic_container: - dict_ = 'Mapping' + dict_ = MAPPING elif self.use_standard_collections: - dict_ = 'dict' + dict_ = STANDARD_DICT else: - dict_ = 'Dict' + dict_ = DICT if self.dict_key or type_: - key = self.dict_key.type_hint if self.dict_key else 'str' - type_ = f'{dict_}[{key}, {type_ or "Any"}]' + key = self.dict_key.type_hint if self.dict_key else STR + type_ = f'{dict_}[{key}, {type_ or ANY}]' else: # pragma: no cover type_ = dict_ - if self.is_optional and type_ != 'Any': - if self.use_union_operator: # pragma: no cover - type_ = f'{type_} | None' - else: - type_ = f'Optional[{type_}]' + if self.is_optional and type_ != ANY: + return get_optional_type(type_, self.use_union_operator) elif self.is_func: if self.kwargs: kwargs: str = ', '.join(f'{k}={v}' for k, v in self.kwargs.items()) diff --git a/tests/data/expected/main/main_nullable_any_of/output.py b/tests/data/expected/main/main_nullable_any_of/output.py new file mode 100644 index 00000000..bad52098 --- /dev/null +++ b/tests/data/expected/main/main_nullable_any_of/output.py @@ -0,0 +1,31 @@ +# generated by datamodel-codegen: +# filename: nullable_any_of.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel, Extra, Field + + +class ConfigItem(BaseModel): + __root__: str = Field(..., description='d2', min_length=1, title='t2') + + +class In(BaseModel): + class Config: + extra = Extra.forbid + + input_dataset_path: Optional[str] = Field( + None, description='d1', min_length=1, title='Path to the input dataset' + ) + config: Optional[ConfigItem] = None + + +class ValidatingSchemaId1(BaseModel): + class Config: + extra = Extra.forbid + + in_: Optional[In] = Field(None, alias='in') + n1: Optional[int] = None diff --git a/tests/data/expected/main/main_nullable_any_of_use_union_operator/output.py b/tests/data/expected/main/main_nullable_any_of_use_union_operator/output.py new file mode 100644 index 00000000..fb785636 --- /dev/null +++ b/tests/data/expected/main/main_nullable_any_of_use_union_operator/output.py @@ -0,0 +1,29 @@ +# generated by datamodel-codegen: +# filename: nullable_any_of.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from pydantic import BaseModel, Extra, Field + + +class ConfigItem(BaseModel): + __root__: str = Field(..., description='d2', min_length=1, title='t2') + + +class In(BaseModel): + class Config: + extra = Extra.forbid + + input_dataset_path: str | None = Field( + None, description='d1', min_length=1, title='Path to the input dataset' + ) + config: ConfigItem | None = None + + +class ValidatingSchemaId1(BaseModel): + class Config: + extra = Extra.forbid + + in_: In | None = Field(None, alias='in') + n1: int | None = None diff --git a/tests/data/jsonschema/nullable_any_of.json b/tests/data/jsonschema/nullable_any_of.json new file mode 100644 index 00000000..dc21a7f3 --- /dev/null +++ b/tests/data/jsonschema/nullable_any_of.json @@ -0,0 +1,41 @@ +{ + "type": "object", + "additionalProperties": false, + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "id1", + "title": "Validating Schema ID1", + "properties": { + "in": { + "type": "object", + "additionalProperties": false, + "properties": { + "input_dataset_path": { + "type": "string", + "minLength": 1, + "title": "Path to the input dataset", + "description": "d1" + }, + "config": { + "anyOf": [ + { + "type": "string", + "minLength": 1, + "title": "t2", + "description": "d2" + }, + { + "type": [ + "null" + ], + "title": "t3", + "description": "d3" + } + ] + } + } + }, + "n1": { + "type": "integer" + } + } +} diff --git a/tests/model/test_base.py b/tests/model/test_base.py index d0661926..f2503725 100644 --- a/tests/model/test_base.py +++ b/tests/model/test_base.py @@ -14,7 +14,7 @@ from datamodel_code_generator.types import DataType, Types class A(TemplateBase): - def __init__(self, path: Path): + def __init__(self, path: Path) -> None: self._path = path @property @@ -30,7 +30,7 @@ class B(DataModel): def get_data_type(cls, types: Types, **kwargs: Any) -> DataType: pass - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) TEMPLATE_FILE_PATH = '' @@ -135,7 +135,7 @@ def test_data_field(): ) assert field.type_hint == 'List' field = DataModelFieldBase(name='a', data_type=DataType(), required=False) - assert field.type_hint == 'Optional' + assert field.type_hint == 'None' field = DataModelFieldBase( name='a', data_type=DataType(is_list=True), @@ -147,11 +147,11 @@ def test_data_field(): field = DataModelFieldBase( name='a', data_type=DataType(), required=False, is_list=False, is_union=True ) - assert field.type_hint == 'Optional' + assert field.type_hint == 'None' field = DataModelFieldBase( name='a', data_type=DataType(), required=False, is_list=False, is_union=False ) - assert field.type_hint == 'Optional' + assert field.type_hint == 'None' field = DataModelFieldBase( name='a', data_type=DataType(is_list=True), diff --git a/tests/test_main.py b/tests/test_main.py index 7b1f84ef..8552a660 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5054,3 +5054,54 @@ def test_main_multiple_required_any_of(): with pytest.raises(SystemExit): main() + + +@freeze_time('2019-07-26') +def test_main_nullable_any_of(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(JSON_SCHEMA_DATA_PATH / 'nullable_any_of.json'), + '--output', + str(output_file), + '--field-constraints', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == (EXPECTED_MAIN_PATH / 'main_nullable_any_of' / 'output.py').read_text() + ) + + with pytest.raises(SystemExit): + main() + + +@freeze_time('2019-07-26') +def test_main_nullable_any_of_use_union_operator(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(JSON_SCHEMA_DATA_PATH / 'nullable_any_of.json'), + '--output', + str(output_file), + '--field-constraints', + '--use-union-operator', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_MAIN_PATH + / 'main_nullable_any_of_use_union_operator' + / 'output.py' + ).read_text() + ) + + with pytest.raises(SystemExit): + main() diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 00000000..1b38e693 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,41 @@ +import pytest + +from datamodel_code_generator.types import get_optional_type + + +@pytest.mark.parametrize( + 'input_,use_union_operator,expected', + [ + ('List[str]', False, 'Optional[List[str]]'), + ('List[str, int, float]', False, 'Optional[List[str, int, float]]'), + ('List[str, int, None]', False, 'Optional[List[str, int, None]]'), + ('Union[str]', False, 'Optional[str]'), + ('Union[str, int, float]', False, 'Optional[Union[str, int, float]]'), + ('Union[str, int, None]', False, 'Optional[Union[str, int]]'), + ('Union[str, int, None, None]', False, 'Optional[Union[str, int]]'), + ( + 'Union[str, int, List[str, int, None], None]', + False, + 'Optional[Union[str, int, List[str, int, None]]]', + ), + ( + 'Union[str, int, List[str, Dict[int, str | None]], None]', + False, + 'Optional[Union[str, int, List[str, Dict[int, str | None]]]]', + ), + ('List[str]', True, 'List[str] | None'), + ('List[str | int | float]', True, 'List[str | int | float] | None'), + ('List[str | int | None]', True, 'List[str | int | None] | None'), + ('str', True, 'str | None'), + ('str | int | float', True, 'str | int | float | None'), + ('str | int | None', True, 'str | int | None'), + ('str | int | None | None', True, 'str | int | None'), + ( + 'str | int | List[str | Dict[int | Union[str | None]]] | None', + True, + 'str | int | List[str | Dict[int | Union[str | None]]] | None', + ), + ], +) +def test_get_optional_type(input_: str, use_union_operator: bool, expected: str): + assert get_optional_type(input_, use_union_operator) == expected