support enum (#49)

* support enum

* fix styles
This commit is contained in:
Koudai Aono
2019-09-06 05:36:22 +09:00
committed by GitHub
parent ccc757a3d3
commit 8fd5698bbd
7 changed files with 205 additions and 7 deletions

View File

@@ -1 +1 @@
recursive-include datamodel_code_generator/model/template/pydantic/ *.jinja2
recursive-include datamodel_code_generator/model/template/ *.jinja2

View File

@@ -0,0 +1,22 @@
from typing import Any, List, Optional
from datamodel_code_generator.model import DataModel, DataModelField
from datamodel_code_generator.model.pydantic.types import get_data_type
from datamodel_code_generator.types import DataType, Types
class Enum(DataModel):
TEMPLATE_FILE_PATH = 'Enum.jinja2'
BASE_CLASS = 'enum.Enum'
def __init__(
self,
name: str,
fields: List[DataModelField],
decorators: Optional[List[str]] = None,
):
super().__init__(name=name, fields=fields, decorators=decorators)
@classmethod
def get_data_type(cls, types: Types, **kwargs: Any) -> DataType:
raise NotImplementedError

View File

@@ -0,0 +1,7 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}(Enum):
{%- for field in fields %}
{{ field.name }} = {{ field.default }}
{%- endfor -%}

View File

@@ -9,7 +9,7 @@ from ..model.base import DataModel, DataModelField, Types
def snake_to_upper_camel(word: str) -> str:
return ''.join(x.capitalize() for x in word.split('_'))
return ''.join(x[0].upper() + x[1:] for x in word.split('_'))
json_schema_data_formats: Dict[str, Dict[str, Types]] = {

View File

@@ -1,7 +1,8 @@
from typing import Callable, Dict, Iterator, List, Optional, Set, Type, Union
from typing import Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
import black
import inflect
from datamodel_code_generator.model.enum import Enum
from datamodel_code_generator.parser.base import (
JsonSchemaObject,
Parser,
@@ -29,8 +30,6 @@ def dump_templates(templates: Union[TemplateBase, List[TemplateBase]]) -> str:
def create_class_name(field_name: str) -> str:
upper_camel_name = snake_to_upper_camel(field_name)
if upper_camel_name == field_name:
upper_camel_name += '_'
return upper_camel_name
@@ -93,6 +92,10 @@ class OpenAPIParser(Parser):
yield from self.parse_object(class_name, filed)
field_class_names.add(class_name)
field_type_hint = self.get_type_name(class_name)
elif filed.enum:
enum_name = self.get_type_name(field_name)
field_type_hint, enum = self.parse_enum(enum_name, filed)
yield enum
else:
data_type = get_data_type(filed, self.data_model_type)
self.imports.append(data_type.import_)
@@ -134,6 +137,7 @@ class OpenAPIParser(Parser):
singular_name = f'{name}Item'
yield from self.parse_object(singular_name, item)
items_obj_name.append(self.get_type_name(singular_name))
print(singular_name)
else:
data_type = get_data_type(item, self.data_model_type)
items_obj_name.append(data_type.type_hint)
@@ -178,6 +182,31 @@ class OpenAPIParser(Parser):
self.created_model_names.add(name)
yield data_model_root_type
def parse_enum(self, name: str, obj: JsonSchemaObject) -> Tuple[str, TemplateBase]:
enum_fields = []
for enum_part in obj.enum: # type: ignore
if obj.type == 'string':
default = f"'{enum_part}'"
else:
default = enum_part
if obj.type == 'string':
field_name = enum_part
else:
field_name = f'{obj.type}_{enum_part}'
enum_fields.append(
self.data_model_field_type(name=field_name, default=default)
)
enum_name = name
count = 1
while enum_name in self.created_model_names:
enum_name = f'{name}_{count}'
count += 1
enum_name = create_class_name(enum_name)
self.imports.append(Import(import_='Enum', from_='enum'))
self.created_model_names.add(enum_name)
return enum_name, Enum(enum_name, fields=enum_fields)
def parse(
self, with_import: Optional[bool] = True, format_: Optional[bool] = True
) -> str:
@@ -190,6 +219,8 @@ class OpenAPIParser(Parser):
templates.extend(self.parse_object(obj_name, obj))
elif obj.is_array:
templates.extend(self.parse_array(obj_name, obj))
elif obj.enum:
templates.append(self.parse_enum(obj_name, obj)[1])
else:
templates.extend(self.parse_root_type(obj_name, obj))

View File

@@ -0,0 +1,87 @@
openapi: "3.0.0"
info:
version: 1.0.0
title: Swagger Petstore
license:
name: MIT
servers:
- url: http://petstore.swagger.io/v1
paths:
/pets:
get:
summary: List all pets
operationId: listPets
tags:
- pets
parameters:
- name: limit
in: query
description: How many items to return at one time (max 100)
required: false
schema:
type: integer
format: int32
responses:
'200':
description: A paged array of pets
headers:
x-next:
description: A link to the next page of responses
schema:
type: string
content:
application/json:
schema:
$ref: "#/components/schemas/Pets"
default:
description: unexpected error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
x-amazon-apigateway-integration:
uri:
Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${PythonVersionFunction.Arn}/invocations
passthroughBehavior: when_no_templates
httpMethod: POST
type: aws_proxy
components:
schemas:
Pet:
required:
- id
- name
properties:
id:
type: integer
format: int64
name:
type: string
tag:
type: string
Pets:
type: array
items:
$ref: "#/components/schemas/Pet"
Error:
required:
- code
- message
properties:
code:
type: integer
format: int32
message:
type: string
EnumObject:
type: object
properties:
type:
enum: ['a', 'b']
type: string
EnumRoot:
enum: ['a', 'b']
type: string
IntEnum:
enum: [1,2]
type: number

View File

@@ -111,12 +111,12 @@ class Pets(BaseModel):
}
}
},
'''class Kind_(BaseModel):
'''class Kind(BaseModel):
name: Optional[str] = None
class Pets(BaseModel):
Kind: Optional[Kind_] = None''',
Kind: Optional[Kind] = None''',
),
(
{
@@ -655,3 +655,54 @@ class Event(BaseModel):
name: Optional[str] = None
"""
)
def test_openapi_parser_parse_enum_models():
parser = OpenAPIParser(
BaseModel, CustomRootType, filename=str(DATA_PATH / 'enum_models.yaml')
)
print(parser.parse())
assert (
parser.parse()
== """from __future__ import annotations
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
class Pet(BaseModel):
id: int
name: str
tag: Optional[str] = None
class Pets(BaseModel):
__root__: List[Pet]
class Error(BaseModel):
code: int
message: str
class Type(Enum):
a = 'a'
b = 'b'
class EnumObject(BaseModel):
type: Optional[Type] = None
class EnumRoot1(Enum):
a = 'a'
b = 'b'
class IntEnum1(Enum):
number_1 = 1
number_2 = 2
"""
)