support enum (#49)

* support enum

* fix styles
This commit is contained in:
Koudai Aono
2019-09-06 05:36:22 +09:00
committed by GitHub
parent ccc757a3d3
commit 8fd5698bbd
7 changed files with 205 additions and 7 deletions

View File

@@ -1 +1 @@
recursive-include datamodel_code_generator/model/template/pydantic/ *.jinja2 recursive-include datamodel_code_generator/model/template/ *.jinja2

View File

@@ -0,0 +1,22 @@
from typing import Any, List, Optional
from datamodel_code_generator.model import DataModel, DataModelField
from datamodel_code_generator.model.pydantic.types import get_data_type
from datamodel_code_generator.types import DataType, Types
class Enum(DataModel):
TEMPLATE_FILE_PATH = 'Enum.jinja2'
BASE_CLASS = 'enum.Enum'
def __init__(
self,
name: str,
fields: List[DataModelField],
decorators: Optional[List[str]] = None,
):
super().__init__(name=name, fields=fields, decorators=decorators)
@classmethod
def get_data_type(cls, types: Types, **kwargs: Any) -> DataType:
raise NotImplementedError

View File

@@ -0,0 +1,7 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}(Enum):
{%- for field in fields %}
{{ field.name }} = {{ field.default }}
{%- endfor -%}

View File

@@ -9,7 +9,7 @@ from ..model.base import DataModel, DataModelField, Types
def snake_to_upper_camel(word: str) -> str: def snake_to_upper_camel(word: str) -> str:
return ''.join(x.capitalize() for x in word.split('_')) return ''.join(x[0].upper() + x[1:] for x in word.split('_'))
json_schema_data_formats: Dict[str, Dict[str, Types]] = { json_schema_data_formats: Dict[str, Dict[str, Types]] = {

View File

@@ -1,7 +1,8 @@
from typing import Callable, Dict, Iterator, List, Optional, Set, Type, Union from typing import Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
import black import black
import inflect import inflect
from datamodel_code_generator.model.enum import Enum
from datamodel_code_generator.parser.base import ( from datamodel_code_generator.parser.base import (
JsonSchemaObject, JsonSchemaObject,
Parser, Parser,
@@ -29,8 +30,6 @@ def dump_templates(templates: Union[TemplateBase, List[TemplateBase]]) -> str:
def create_class_name(field_name: str) -> str: def create_class_name(field_name: str) -> str:
upper_camel_name = snake_to_upper_camel(field_name) upper_camel_name = snake_to_upper_camel(field_name)
if upper_camel_name == field_name:
upper_camel_name += '_'
return upper_camel_name return upper_camel_name
@@ -93,6 +92,10 @@ class OpenAPIParser(Parser):
yield from self.parse_object(class_name, filed) yield from self.parse_object(class_name, filed)
field_class_names.add(class_name) field_class_names.add(class_name)
field_type_hint = self.get_type_name(class_name) field_type_hint = self.get_type_name(class_name)
elif filed.enum:
enum_name = self.get_type_name(field_name)
field_type_hint, enum = self.parse_enum(enum_name, filed)
yield enum
else: else:
data_type = get_data_type(filed, self.data_model_type) data_type = get_data_type(filed, self.data_model_type)
self.imports.append(data_type.import_) self.imports.append(data_type.import_)
@@ -134,6 +137,7 @@ class OpenAPIParser(Parser):
singular_name = f'{name}Item' singular_name = f'{name}Item'
yield from self.parse_object(singular_name, item) yield from self.parse_object(singular_name, item)
items_obj_name.append(self.get_type_name(singular_name)) items_obj_name.append(self.get_type_name(singular_name))
print(singular_name)
else: else:
data_type = get_data_type(item, self.data_model_type) data_type = get_data_type(item, self.data_model_type)
items_obj_name.append(data_type.type_hint) items_obj_name.append(data_type.type_hint)
@@ -178,6 +182,31 @@ class OpenAPIParser(Parser):
self.created_model_names.add(name) self.created_model_names.add(name)
yield data_model_root_type yield data_model_root_type
def parse_enum(self, name: str, obj: JsonSchemaObject) -> Tuple[str, TemplateBase]:
enum_fields = []
for enum_part in obj.enum: # type: ignore
if obj.type == 'string':
default = f"'{enum_part}'"
else:
default = enum_part
if obj.type == 'string':
field_name = enum_part
else:
field_name = f'{obj.type}_{enum_part}'
enum_fields.append(
self.data_model_field_type(name=field_name, default=default)
)
enum_name = name
count = 1
while enum_name in self.created_model_names:
enum_name = f'{name}_{count}'
count += 1
enum_name = create_class_name(enum_name)
self.imports.append(Import(import_='Enum', from_='enum'))
self.created_model_names.add(enum_name)
return enum_name, Enum(enum_name, fields=enum_fields)
def parse( def parse(
self, with_import: Optional[bool] = True, format_: Optional[bool] = True self, with_import: Optional[bool] = True, format_: Optional[bool] = True
) -> str: ) -> str:
@@ -190,6 +219,8 @@ class OpenAPIParser(Parser):
templates.extend(self.parse_object(obj_name, obj)) templates.extend(self.parse_object(obj_name, obj))
elif obj.is_array: elif obj.is_array:
templates.extend(self.parse_array(obj_name, obj)) templates.extend(self.parse_array(obj_name, obj))
elif obj.enum:
templates.append(self.parse_enum(obj_name, obj)[1])
else: else:
templates.extend(self.parse_root_type(obj_name, obj)) templates.extend(self.parse_root_type(obj_name, obj))

View File

@@ -0,0 +1,87 @@
openapi: "3.0.0"
info:
version: 1.0.0
title: Swagger Petstore
license:
name: MIT
servers:
- url: http://petstore.swagger.io/v1
paths:
/pets:
get:
summary: List all pets
operationId: listPets
tags:
- pets
parameters:
- name: limit
in: query
description: How many items to return at one time (max 100)
required: false
schema:
type: integer
format: int32
responses:
'200':
description: A paged array of pets
headers:
x-next:
description: A link to the next page of responses
schema:
type: string
content:
application/json:
schema:
$ref: "#/components/schemas/Pets"
default:
description: unexpected error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
x-amazon-apigateway-integration:
uri:
Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${PythonVersionFunction.Arn}/invocations
passthroughBehavior: when_no_templates
httpMethod: POST
type: aws_proxy
components:
schemas:
Pet:
required:
- id
- name
properties:
id:
type: integer
format: int64
name:
type: string
tag:
type: string
Pets:
type: array
items:
$ref: "#/components/schemas/Pet"
Error:
required:
- code
- message
properties:
code:
type: integer
format: int32
message:
type: string
EnumObject:
type: object
properties:
type:
enum: ['a', 'b']
type: string
EnumRoot:
enum: ['a', 'b']
type: string
IntEnum:
enum: [1,2]
type: number

View File

@@ -111,12 +111,12 @@ class Pets(BaseModel):
} }
} }
}, },
'''class Kind_(BaseModel): '''class Kind(BaseModel):
name: Optional[str] = None name: Optional[str] = None
class Pets(BaseModel): class Pets(BaseModel):
Kind: Optional[Kind_] = None''', Kind: Optional[Kind] = None''',
), ),
( (
{ {
@@ -655,3 +655,54 @@ class Event(BaseModel):
name: Optional[str] = None name: Optional[str] = None
""" """
) )
def test_openapi_parser_parse_enum_models():
parser = OpenAPIParser(
BaseModel, CustomRootType, filename=str(DATA_PATH / 'enum_models.yaml')
)
print(parser.parse())
assert (
parser.parse()
== """from __future__ import annotations
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
class Pet(BaseModel):
id: int
name: str
tag: Optional[str] = None
class Pets(BaseModel):
__root__: List[Pet]
class Error(BaseModel):
code: int
message: str
class Type(Enum):
a = 'a'
b = 'b'
class EnumObject(BaseModel):
type: Optional[Type] = None
class EnumRoot1(Enum):
a = 'a'
b = 'b'
class IntEnum1(Enum):
number_1 = 1
number_2 = 2
"""
)