From 8fd5698bbdb375c091d64a5e20dd67af6914c3f8 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Fri, 6 Sep 2019 05:36:22 +0900 Subject: [PATCH] support enum (#49) * support enum * fix styles --- MANIFEST.in | 2 +- datamodel_code_generator/model/enum.py | 22 +++++ .../model/template/Enum.jinja2 | 7 ++ datamodel_code_generator/parser/base.py | 2 +- datamodel_code_generator/parser/openapi.py | 37 +++++++- tests/data/enum_models.yaml | 87 +++++++++++++++++++ tests/parser/test_openapi.py | 55 +++++++++++- 7 files changed, 205 insertions(+), 7 deletions(-) create mode 100644 datamodel_code_generator/model/enum.py create mode 100644 datamodel_code_generator/model/template/Enum.jinja2 create mode 100644 tests/data/enum_models.yaml diff --git a/MANIFEST.in b/MANIFEST.in index 97ae125d..343b3526 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include datamodel_code_generator/model/template/pydantic/ *.jinja2 +recursive-include datamodel_code_generator/model/template/ *.jinja2 diff --git a/datamodel_code_generator/model/enum.py b/datamodel_code_generator/model/enum.py new file mode 100644 index 00000000..b072532f --- /dev/null +++ b/datamodel_code_generator/model/enum.py @@ -0,0 +1,22 @@ +from typing import Any, List, Optional + +from datamodel_code_generator.model import DataModel, DataModelField +from datamodel_code_generator.model.pydantic.types import get_data_type +from datamodel_code_generator.types import DataType, Types + + +class Enum(DataModel): + TEMPLATE_FILE_PATH = 'Enum.jinja2' + BASE_CLASS = 'enum.Enum' + + def __init__( + self, + name: str, + fields: List[DataModelField], + decorators: Optional[List[str]] = None, + ): + super().__init__(name=name, fields=fields, decorators=decorators) + + @classmethod + def get_data_type(cls, types: Types, **kwargs: Any) -> DataType: + raise NotImplementedError diff --git a/datamodel_code_generator/model/template/Enum.jinja2 b/datamodel_code_generator/model/template/Enum.jinja2 new file mode 100644 index 00000000..b080e778 --- /dev/null +++ b/datamodel_code_generator/model/template/Enum.jinja2 @@ -0,0 +1,7 @@ +{% for decorator in decorators -%} +{{ decorator }} +{% endfor -%} +class {{ class_name }}(Enum): +{%- for field in fields %} + {{ field.name }} = {{ field.default }} +{%- endfor -%} diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index abeb6f27..5377af39 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -9,7 +9,7 @@ from ..model.base import DataModel, DataModelField, Types def snake_to_upper_camel(word: str) -> str: - return ''.join(x.capitalize() for x in word.split('_')) + return ''.join(x[0].upper() + x[1:] for x in word.split('_')) json_schema_data_formats: Dict[str, Dict[str, Types]] = { diff --git a/datamodel_code_generator/parser/openapi.py b/datamodel_code_generator/parser/openapi.py index 57abe4a8..03f99a0e 100644 --- a/datamodel_code_generator/parser/openapi.py +++ b/datamodel_code_generator/parser/openapi.py @@ -1,7 +1,8 @@ -from typing import Callable, Dict, Iterator, List, Optional, Set, Type, Union +from typing import Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union import black import inflect +from datamodel_code_generator.model.enum import Enum from datamodel_code_generator.parser.base import ( JsonSchemaObject, Parser, @@ -29,8 +30,6 @@ def dump_templates(templates: Union[TemplateBase, List[TemplateBase]]) -> str: def create_class_name(field_name: str) -> str: upper_camel_name = snake_to_upper_camel(field_name) - if upper_camel_name == field_name: - upper_camel_name += '_' return upper_camel_name @@ -93,6 +92,10 @@ class OpenAPIParser(Parser): yield from self.parse_object(class_name, filed) field_class_names.add(class_name) field_type_hint = self.get_type_name(class_name) + elif filed.enum: + enum_name = self.get_type_name(field_name) + field_type_hint, enum = self.parse_enum(enum_name, filed) + yield enum else: data_type = get_data_type(filed, self.data_model_type) self.imports.append(data_type.import_) @@ -134,6 +137,7 @@ class OpenAPIParser(Parser): singular_name = f'{name}Item' yield from self.parse_object(singular_name, item) items_obj_name.append(self.get_type_name(singular_name)) + print(singular_name) else: data_type = get_data_type(item, self.data_model_type) items_obj_name.append(data_type.type_hint) @@ -178,6 +182,31 @@ class OpenAPIParser(Parser): self.created_model_names.add(name) yield data_model_root_type + def parse_enum(self, name: str, obj: JsonSchemaObject) -> Tuple[str, TemplateBase]: + enum_fields = [] + + for enum_part in obj.enum: # type: ignore + if obj.type == 'string': + default = f"'{enum_part}'" + else: + default = enum_part + if obj.type == 'string': + field_name = enum_part + else: + field_name = f'{obj.type}_{enum_part}' + enum_fields.append( + self.data_model_field_type(name=field_name, default=default) + ) + enum_name = name + count = 1 + while enum_name in self.created_model_names: + enum_name = f'{name}_{count}' + count += 1 + enum_name = create_class_name(enum_name) + self.imports.append(Import(import_='Enum', from_='enum')) + self.created_model_names.add(enum_name) + return enum_name, Enum(enum_name, fields=enum_fields) + def parse( self, with_import: Optional[bool] = True, format_: Optional[bool] = True ) -> str: @@ -190,6 +219,8 @@ class OpenAPIParser(Parser): templates.extend(self.parse_object(obj_name, obj)) elif obj.is_array: templates.extend(self.parse_array(obj_name, obj)) + elif obj.enum: + templates.append(self.parse_enum(obj_name, obj)[1]) else: templates.extend(self.parse_root_type(obj_name, obj)) diff --git a/tests/data/enum_models.yaml b/tests/data/enum_models.yaml new file mode 100644 index 00000000..8b334af6 --- /dev/null +++ b/tests/data/enum_models.yaml @@ -0,0 +1,87 @@ +openapi: "3.0.0" +info: + version: 1.0.0 + title: Swagger Petstore + license: + name: MIT +servers: + - url: http://petstore.swagger.io/v1 +paths: + /pets: + get: + summary: List all pets + operationId: listPets + tags: + - pets + parameters: + - name: limit + in: query + description: How many items to return at one time (max 100) + required: false + schema: + type: integer + format: int32 + responses: + '200': + description: A paged array of pets + headers: + x-next: + description: A link to the next page of responses + schema: + type: string + content: + application/json: + schema: + $ref: "#/components/schemas/Pets" + default: + description: unexpected error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + x-amazon-apigateway-integration: + uri: + Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${PythonVersionFunction.Arn}/invocations + passthroughBehavior: when_no_templates + httpMethod: POST + type: aws_proxy +components: + schemas: + Pet: + required: + - id + - name + properties: + id: + type: integer + format: int64 + name: + type: string + tag: + type: string + Pets: + type: array + items: + $ref: "#/components/schemas/Pet" + Error: + required: + - code + - message + properties: + code: + type: integer + format: int32 + message: + type: string + EnumObject: + type: object + properties: + type: + enum: ['a', 'b'] + type: string + EnumRoot: + enum: ['a', 'b'] + type: string + IntEnum: + enum: [1,2] + type: number \ No newline at end of file diff --git a/tests/parser/test_openapi.py b/tests/parser/test_openapi.py index e09c78c6..17d347d0 100644 --- a/tests/parser/test_openapi.py +++ b/tests/parser/test_openapi.py @@ -111,12 +111,12 @@ class Pets(BaseModel): } } }, - '''class Kind_(BaseModel): + '''class Kind(BaseModel): name: Optional[str] = None class Pets(BaseModel): - Kind: Optional[Kind_] = None''', + Kind: Optional[Kind] = None''', ), ( { @@ -655,3 +655,54 @@ class Event(BaseModel): name: Optional[str] = None """ ) + + +def test_openapi_parser_parse_enum_models(): + parser = OpenAPIParser( + BaseModel, CustomRootType, filename=str(DATA_PATH / 'enum_models.yaml') + ) + print(parser.parse()) + assert ( + parser.parse() + == """from __future__ import annotations + +from enum import Enum +from typing import List, Optional + +from pydantic import BaseModel + + +class Pet(BaseModel): + id: int + name: str + tag: Optional[str] = None + + +class Pets(BaseModel): + __root__: List[Pet] + + +class Error(BaseModel): + code: int + message: str + + +class Type(Enum): + a = 'a' + b = 'b' + + +class EnumObject(BaseModel): + type: Optional[Type] = None + + +class EnumRoot1(Enum): + a = 'a' + b = 'b' + + +class IntEnum1(Enum): + number_1 = 1 + number_2 = 2 +""" + )