This commit is contained in:
bluerelay
2019-02-10 16:17:52 -08:00
commit 07e1e731b7
25 changed files with 2374 additions and 0 deletions

252
.gitignore vendored Normal file
View File

@@ -0,0 +1,252 @@
# Build results
[Dd]ebug/
[Dd]ebugPublic/
[Rr]elease/
[Rr]eleases/
x64/
x86/
bld/
[Bb]in/
[Oo]bj/
[Ll]og/
# Visual Studio 2015 cache/options directory
.vs/
# Uncomment if you have tasks that create the project's static files in wwwroot
#wwwroot/
# MSTest test Results
[Tt]est[Rr]esult*/
[Bb]uild[Ll]og.*
# NUNIT
*.VisualState.xml
TestResult.xml
# Build Results of an ATL Project
[Dd]ebugPS/
[Rr]eleasePS/
dlldata.c
# DNX
project.lock.json
project.fragment.lock.json
artifacts/
*_i.c
*_p.c
*_i.h
*.ilk
*.meta
*.obj
*.pch
*.pdb
*.pgc
*.pgd
*.rsp
*.sbr
*.tlb
*.tli
*.tlh
*.tmp
*.tmp_proj
*.log
*.vspscc
*.vssscc
.builds
*.pidb
*.svclog
*.scc
# Chutzpah Test files
_Chutzpah*
# Visual C++ cache files
ipch/
*.aps
*.ncb
*.opendb
*.opensdf
*.sdf
*.cachefile
*.VC.db
*.VC.VC.opendb
# Visual Studio profiler
*.psess
*.vsp
*.vspx
*.sap
# TFS 2012 Local Workspace
$tf/
# Guidance Automation Toolkit
*.gpState
# ReSharper is a .NET coding add-in
_ReSharper*/
*.[Rr]e[Ss]harper
*.DotSettings.user
# JustCode is a .NET coding add-in
.JustCode
# TeamCity is a build add-in
_TeamCity*
# DotCover is a Code Coverage Tool
*.dotCover
# NCrunch
_NCrunch_*
.*crunch*.local.xml
nCrunchTemp_*
# MightyMoose
*.mm.*
AutoTest.Net/
# Web workbench (sass)
.sass-cache/
# Installshield output folder
[Ee]xpress/
# DocProject is a documentation generator add-in
DocProject/buildhelp/
DocProject/Help/*.HxT
DocProject/Help/*.HxC
DocProject/Help/*.hhc
DocProject/Help/*.hhk
DocProject/Help/*.hhp
DocProject/Help/Html2
DocProject/Help/html
# Click-Once directory
publish/
# Publish Web Output
*.[Pp]ublish.xml
*.azurePubxml
# TODO: Comment the next line if you want to checkin your web deploy settings
# but database connection strings (with potential passwords) will be unencrypted
#*.pubxml
*.publishproj
# Microsoft Azure Web App publish settings. Comment the next line if you want to
# checkin your Azure Web App publish settings, but sensitive information contained
# in these scripts will be unencrypted
PublishScripts/
# NuGet Packages
*.nupkg
# The packages folder can be ignored because of Package Restore
**/packages/*
# except build/, which is used as an MSBuild target.
!**/packages/build/
# Uncomment if necessary however generally it will be regenerated when needed
#!**/packages/repositories.config
# NuGet v3's project.json files produces more ignoreable files
*.nuget.props
*.nuget.targets
# Microsoft Azure Build Output
csx/
*.build.csdef
# Microsoft Azure Emulator
ecf/
rcf/
# Windows Store app package directories and files
AppPackages/
BundleArtifacts/
Package.StoreAssociation.xml
_pkginfo.txt
# Visual Studio cache files
# files ending in .cache can be ignored
*.[Cc]ache
# but keep track of directories ending in .cache
!*.[Cc]ache/
# Others
ClientBin/
~$*
*~
*.dbmdl
*.dbproj.schemaview
*.jfm
*.pfx
*.publishsettings
node_modules/
orleans.codegen.cs
# Since there are multiple workflows, uncomment next line to ignore bower_components
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
#bower_components/
# RIA/Silverlight projects
Generated_Code/
# Backup & report files from converting an old project file
# to a newer Visual Studio version. Backup files are not needed,
# because we have git ;-)
_UpgradeReport_Files/
Backup*/
UpgradeLog*.XML
UpgradeLog*.htm
# SQL Server files
*.mdf
*.ldf
# Business Intelligence projects
*.rdl.data
*.bim.layout
*.bim_*.settings
# Microsoft Fakes
FakesAssemblies/
# GhostDoc plugin setting file
*.GhostDoc.xml
# Node.js Tools for Visual Studio
.ntvs_analysis.dat
# Visual Studio 6 build log
*.plg
# Visual Studio 6 workspace options file
*.opt
# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
**/*.DesktopClient/ModelManifest.xml
**/*.Server/GeneratedArtifacts
**/*.Server/ModelManifest.xml
_Pvt_Extensions
# Paket dependency manager
.paket/paket.exe
paket-files/
# FAKE - F# Make
.fake/
# JetBrains Rider
.idea/
*.sln.iml
# CodeRush
.cr/
# Python Tools for Visual Studio (PTVS)
__pycache__/
*.pyc
.pytest_cache/
.env
.vscode/settings.json

19
LICENSE Normal file
View File

@@ -0,0 +1,19 @@
Copyright (c) 2018 The Python Packaging Authority
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

3
README.md Normal file
View File

@@ -0,0 +1,3 @@
# windyquery - A non-blocking Python PostgreSQL query builder
Windyquery is a non-blocking PostgreSQL query builder with Asyncio.

21
setup.py Normal file
View File

@@ -0,0 +1,21 @@
import setuptools
with open("README.md", "r") as fh:
long_description = fh.read()
setuptools.setup(
name="windyquery",
version="0.0.1",
author="windymile.it",
author_email="windymile.it@gmail.com",
description="A non-blocking PostgreSQL query builder using Asyncio",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/bluerelay/windyquery",
packages=setuptools.find_packages(),
classifiers=[
"Programming Language :: Python :: 3.6",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
)

4
windyquery/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .index import Connection
from .index import DB
from .index import Schema
from .model import Model

882
windyquery/builder.py Normal file
View File

@@ -0,0 +1,882 @@
import re
import json
from asyncpg import utils
from .parser import CRUD
class Builder:
"""collect building blocks - items; choose cmposer; exec final sql from composer"""
def __init__(self, pool):
self.composer = None
self.items = {}
self.pool = pool
self._first = False
def reset(self):
self.composer = None
self.items = {}
self._first = False
def _stripItems(self, items):
return [val.strip() if isinstance(val, str) else val for val in items]
async def toSql(self):
sql, args = self.composer.compose(self.items)
async with self.pool.acquire() as connection:
sqlStr = await utils._mogrify(connection, sql, args)
return sqlStr
async def exec(self):
sql, args = self.composer.compose(self.items)
async with self.pool.acquire() as connection:
if self._first:
return await connection.fetchrow(sql, *args)
else:
return await connection.fetch(sql, *args)
def __await__(self):
return self.exec().__await__()
def raw(self, query, args):
if not isinstance(self.composer, Raw):
self.composer = Raw()
self.items['sql'] = query
self.items['args'] = args
return self
def table(self, table):
self.items['table'] = table.strip()
return self
def first(self):
self._first = True
return self
def select(self, *items):
if not isinstance(self.composer, Select):
self.composer = Select()
if 'select' not in self.items:
self.items['select'] = []
self.items['select'] += self._stripItems(items)
return self
def select_raw(self, *items):
if not isinstance(self.composer, SelectRaw):
self.composer = SelectRaw()
if 'select' not in self.items:
self.items['select'] = []
self.items['select'] += self._stripItems(items)
return self
def join(self, *clause):
if 'join' not in self.items:
self.items['join'] = []
self.items['join'].append(self._stripItems(clause))
return self
def where(self, *clause):
if 'where' not in self.items:
self.items['where'] = []
self.items['where'].append(self._stripItems(clause))
return self
def where_raw(self, query):
if 'where_raw' not in self.items:
self.items['where_raw'] = []
self.items['where_raw'].append(query)
return self
def order_by(self, *items):
if 'order_by' not in self.items:
self.items['order_by'] = []
self.items['order_by'] += self._stripItems(items)
return self
def order_by_raw(self, query):
if 'order_by_raw' not in self.items:
self.items['order_by_raw'] = []
self.items['order_by_raw'].append(query)
return self
def group_by(self, *items):
if 'group_by' not in self.items:
self.items['group_by'] = []
self.items['group_by'] += self._stripItems(items)
return self
def limit(self, val):
self.items['limit'] = int(val)
return self
def update(self, fields):
if not isinstance(self.composer, Update):
self.composer = Update()
if 'update' not in self.items:
self.items['update'] = {}
self.items['update'].update(fields)
return self
def update_from(self, tableName):
self.items['update_from'] = tableName.strip()
return self
def insert(self, *rows):
if not isinstance(self.composer, Insert):
self.composer = Insert()
if 'insert_keys' not in self.items:
self.items['insert_keys'] = self._stripItems(rows[0].keys())
if 'insert_values' not in self.items:
self.items['insert_values'] = []
for row in rows:
self.items['insert_values'].append(list(row.values()))
return self
def returning(self, *items):
if 'returning' not in self.items:
self.items['returning'] = []
self.items['returning'] += self._stripItems(items)
return self
def insertRaw(self, sql, args):
if not isinstance(self.composer, InsertRaw):
self.composer = InsertRaw()
self.items['sql'] = sql
self.items['args'] = args
return self
def delete(self):
if not isinstance(self.composer, Delete):
self.composer = Delete()
return self
async def create(self, items):
self.composer = Create()
self.items = items
return await self.exec()
async def create_index(self, items):
self.composer = CreateIndex()
self.items = items
return await self.exec()
async def rename_table(self, table_name_old, table_name_new):
self.composer = RenameTable()
self.items = {'table': table_name_old, 'rename': table_name_new}
return await self.exec()
async def drop_table(self, table_name, ifExists=False):
self.composer = DropTable()
self.items = {'table': table_name, 'if_exists': ifExists}
return await self.exec()
async def alter(self, items):
self.composer = Alter()
self.items = items
return await self.exec()
async def alter_index(self, items):
self.composer = AlterIndex()
self.items = items
return await self.exec()
async def drop_index(self, index_name):
self.composer = DropIndex()
self.items = {'index': index_name}
return await self.exec()
async def drop_primary_key(self, table_name, pkey):
self.composer = DropPrimaryKey()
self.items = {'table': table_name, 'pkey': pkey}
return await self.exec()
async def drop_constraint(self, table_name, constraint_name):
self.composer = DropConstraint()
self.items = {'table': table_name, 'constraint_name': constraint_name}
return await self.exec()
class Statement:
"""prepare sql scaffold, each step compose a piece in sql, add args and add build parser tree"""
def __init__(self):
self.sql = ''
self.parser = None
self.args = []
self.steps = []
def compose(self, items):
self.sql = ''
self.parser = CRUD()
self.args = []
for step in self.steps:
step = getattr(self, step)
step(items)
params = self.parser.compile()
return self.sql.format(**params), self.args
def get_field_sql(self, item):
# check table.field
pair = item.split('.')
if len(pair) > 2:
raise UserWarning("Identifier has more than two dots")
if len(pair) == 2:
full_field = self.parser('full_field')
full_field.add(self.parser('identifier', pair[0]))
full_field.add(self.get_field_sql(pair[1]))
return full_field
# field
if '->>' in item:
field, *attrs = item.split('->')
item_sql = self.parser('jsonb_text')
item_sql.add(self.parser('identifier', field))
for attr in attrs:
attr = attr.lstrip('>')
item_sql.add(self.parser('literal', attr))
elif '->' in item:
field, *attrs = item.split('->')
item_sql = self.parser('jsonb')
item_sql.add(self.parser('identifier', field))
for attr in attrs:
item_sql.add(self.parser('literal', attr))
elif item == '*':
item_sql = self.parser('const', '*')
else:
item_sql = self.parser('identifier', item)
return item_sql
class Raw(Statement):
def __init__(self):
super().__init__()
self.steps = ['raw']
def raw(self, items):
self.sql = items['sql']
self.args = items['args']
return self
class Select(Statement):
def __init__(self):
super().__init__()
self.steps = ['select', 'join', 'where', 'order_by', 'group_by', 'limit']
def select(self, items):
self.sql = 'SELECT {select} FROM {table}'
self.parser('identifier', items['table'], data_key='table')
select = self.parser('select', data_key='select')
fields = items['select']
if len(fields) == 0:
fields = ['*']
for field in fields:
if isinstance(field, str):
select.add(self.select_item(field))
else:
raise UserWarning("Not implemented")
return self
def select_item(self, item):
# check field AS alias
pair = re.split('\s+as\s+', item, flags=re.IGNORECASE)
if len(pair) > 2:
raise UserWarning("Too many 'AS' in select() argument")
if len(pair) == 2:
select_as = self.parser('select_as')
select_as.add(self.get_field_sql(pair[0]))
select_as.add(self.parser('identifier', pair[1]))
return select_as
return self.get_field_sql(item)
def join(self, items):
if 'join' in items:
self.sql += ' {join}'
join = self.parser('join', data_key='join')
for clause in items['join']:
if len(clause) != 4:
raise UserWarning("Invalid arguments for join()")
join.add(self.join_item(*clause))
return self
def join_item(self, table_name, left, op, right):
join_item = self.parser('join_item')
join_item.add(self.parser('identifier', table_name))
join_item.add(self.get_field_sql(left))
join_item.add(self.parser('const', op))
join_item.add(self.get_field_sql(right))
return join_item
def where(self, items):
where = None
if 'where' in items:
self.sql += ' WHERE {where}'
where = self.parser('where', data_key='where')
for clause in items['where']:
if len(clause) == 2:
if isinstance(clause[1], list):
where.add(self.where_item(clause[0], 'IN', clause[1]))
else:
where.add(self.where_item(clause[0], '=', clause[1]))
elif len(clause) == 3:
where.add(self.where_item(*clause))
else:
raise UserWarning("Invalid arguments for DB.where")
if 'where_raw' in items:
if where is None:
self.sql += ' WHERE {where}'
where = self.parser('where', data_key='where')
for query in items['where_raw']:
where.add(self.parser('raw', query))
return self
def where_item(self, field, op, val):
if op.upper() == 'IN':
where_item = self.parser('where_in_item')
where_item.add(self.get_field_sql(field))
idx = []
for v in val:
self.args.append(v)
idx.append(len(self.args))
where_item.add(idx)
elif op.upper() == 'NOT IN':
where_item = self.parser('where_not_in_item')
where_item.add(self.get_field_sql(field))
idx = []
for v in val:
self.args.append(v)
idx.append(len(self.args))
where_item.add(idx)
else:
where_item = self.parser('where_item')
where_item.add(op)
where_item.add(self.get_field_sql(field))
if '->' in field and '->>' not in field:
val = json.dumps(val)
if val is None:
where_item.add(None)
else:
self.args.append(val)
where_item.add(len(self.args))
return where_item
def order_by(self, items):
order_by = None
if 'order_by' in items:
self.sql += ' ORDER BY {order_by}'
order_by = self.parser('order_by', data_key='order_by')
for item in items['order_by']:
if isinstance(item, str):
order_by.add(self.order_by_item(item))
else:
raise UserWarning("order_by(): invalid arguments")
if 'order_by_raw' in items:
if order_by is None:
self.sql += ' ORDER BY {order_by}'
order_by = self.parser('order_by', data_key='order_by')
for query in items['order_by_raw']:
order_by.add(self.parser('raw', query))
return self
def order_by_item(self, item):
if item.upper().endswith(' ASC'):
order_by_item = self.parser('order_by_item')
order_by_item.add(self.get_field_sql(item[:-4].strip()))
order_by_item.add('ASC')
elif item.upper().endswith(' DESC'):
order_by_item = self.parser('order_by_item')
order_by_item.add(self.get_field_sql(item[:-5].strip()))
order_by_item.add('DESC')
else:
order_by_item = self.parser('order_by_item', self.get_field_sql(item))
return order_by_item
def group_by(self, items):
if 'group_by' in items:
self.sql += ' GROUP BY {group_by}'
group_by = self.parser('group_by', data_key='group_by')
for item in items['group_by']:
if isinstance(item, str):
group_by.add(self.get_field_sql(item))
else:
raise UserWarning("group_by(): invalid arguments")
return self
def limit(self, items):
if 'limit' in items:
self.sql += ' LIMIT {limit}'
self.args.append(items['limit'])
self.parser('limit', len(self.args), data_key='limit')
return self
class SelectRaw(Select):
def select(self, items):
self.sql = 'SELECT {select} FROM {table}'
self.parser('identifier', items['table'], data_key='table')
select = self.parser('select', data_key='select')
fields = items['select']
if len(fields) == 0:
fields = ['*']
for field in fields:
if isinstance(field, str):
select.add(self.parser('raw', field))
else:
raise UserWarning("Not implemented")
return self
class Update(Select):
def __init__(self):
super().__init__()
self.steps = ['update', 'where']
def update(self, items):
self.sql = 'UPDATE {table} SET {update}'
self.parser('identifier', items['table'], data_key='table')
from_table = None
if 'update_from' in items:
self.sql += ' FROM {update_from}'
self.parser('identifier', items['update_from'], data_key='update_from')
from_table = items['update_from']+'.'
update = self.parser('update', data_key='update')
for item, value in items['update'].items():
if isinstance(item, str):
item = item.strip()
if from_table is not None and isinstance(value, str) and from_table in value:
update.add(self.update_from_item(item, value.strip()))
else:
update.add(self.update_item(item, value))
else:
raise UserWarning("Not implemented")
return self
def update_from_item(self, item, val):
if '->>' in item:
raise UserWarning("Not implemented")
elif '->' in item:
raise UserWarning("Not implemented")
else:
update_from_item = self.parser('update_from_item')
update_from_item.add(self.get_field_sql(item))
update_from_item.add(self.get_field_sql(val))
return update_from_item
def update_item(self, item, val):
if '->>' in item:
raise UserWarning("->> is not valid for update()")
elif '->' in item:
return self.update_jsonb(item, val)
else:
update_item = self.parser('update_item')
update_item.add(self.get_field_sql(item))
if isinstance(val, dict):
val = json.dumps(val)
if val is None:
update_item.add(None)
else:
self.args.append(val)
update_item.add(len(self.args))
return update_item
def update_jsonb(self, item, val):
update_jsonb = self.parser('update_jsonb')
field, *attrs = item.split('->')
update_jsonb.add(self.parser('identifier', field))
for attr in attrs:
update_jsonb.add(attr)
update_jsonb.add(val)
return update_jsonb
def where(self, items):
if 'where' in items:
self.sql += ' WHERE {where}'
from_table = None
if 'update_from' in items:
from_table = items['update_from']+'.'
where = self.parser('where', data_key='where')
for clause in items['where']:
if len(clause) == 2:
field, value = clause
op = '='
elif len(clause) == 3:
field, op, value = clause
else:
raise UserWarning("Invalid arguments for DB.where")
if from_table is not None and (from_table in field or isinstance(value, str) and from_table in value):
where.add(self.where_from_item(field, op, value))
else:
where.add(self.where_item(field, op, value))
return self
def where_from_item(self, field, op, value):
where_from_item = self.parser('where_from_item')
where_from_item.add(self.get_field_sql(field))
where_from_item.add(op)
where_from_item.add(self.get_field_sql(value))
return where_from_item
class Insert(Select):
def __init__(self):
super().__init__()
self.steps = ['insert', 'returning']
def insert(self, items):
self.sql = 'INSERT INTO {table} ({keys}) VALUES {values}'
self.parser('identifier', items['table'], data_key='table')
insert_keys = self.parser('insert_keys', data_key='keys')
for key in items['insert_keys']:
insert_keys.add(self.parser('identifier', key))
insert_values = self.parser('insert_values', data_key='values')
for values in items['insert_values']:
insert_value = self.parser('insert_value')
for value in values:
if isinstance(value, dict):
value = json.dumps(value)
self.args.append(value)
insert_value.add(len(self.args))
insert_values.add(insert_value)
return self
def returning(self, items):
if 'returning' in items:
self.sql += ' RETURNING {returning}'
returning = self.parser('returning', data_key='returning')
fields = items['returning']
if len(fields) == 0:
fields = ['*']
for field in fields:
if isinstance(field, str):
returning.add(self.get_field_sql(field))
else:
raise UserWarning("Invalid field in RETURNING - {}".format(field))
return self
class InsertRaw(Select):
def __init__(self):
super().__init__()
self.steps = ['insert']
def insert(self, items):
self.sql = 'INSERT INTO {table} ' + items['sql']
self.args = items['args']
self.parser('identifier', items['table'], data_key='table')
return self
class Create(Statement):
def __init__(self):
super().__init__()
self.steps = ['create']
def create(self, items):
self.sql = 'CREATE TABLE {table} ({columns})'
self.parser('identifier', items['table'], data_key='table')
create_columns = self.parser('create_columns', data_key='columns')
for column in items['columns']:
create_column = self.parser('create_column')
create_column.add(self.parser('identifier', column['name']))
create_column.add(self.parser('const', column['type']))
create_column.add(bool(column['nullable']))
create_column.add(self.create_column_default(column['default']))
create_column.add(column['primary_key'])
create_columns.add(create_column)
for unique_cols in items['uniques']:
unique_columns = self.parser('unique_columns')
for unique_col in unique_cols:
unique_columns.add(self.parser('identifier', unique_col))
create_columns.add(unique_columns)
if items['primary']:
primary_columns = self.parser('primary_columns')
for primary_col in items['primary']:
primary_columns.add(self.parser('identifier', primary_col))
create_columns.add(primary_columns)
return self
def create_column_default(self, default_val):
if default_val is None:
return None
if isinstance(default_val, bool):
val = 'TRUE' if default_val else 'FALSE'
return self.parser('const', val)
elif isinstance(default_val, int):
return default_val
elif default_val == "NULL" or default_val == 'NOW()':
return self.parser('const', default_val)
else:
if isinstance(default_val, dict):
default_val = json.dumps(default_val)
return self.parser('literal', default_val)
class CreateIndex(Statement):
def __init__(self):
super().__init__()
self.steps = ['create_index']
def create_index(self, items):
self.parser('identifier', items['table'], data_key='table')
index_col = 'create_index_col'
idx_cols = items['index_columns']
self.sql = 'CREATE INDEX ON {table}({'+index_col+'})'
index_columns = self.parser('index_columns', data_key=index_col)
for idx_col in idx_cols:
index_columns.add(self.parser('identifier', idx_col))
return self
class RenameTable(Statement):
def __init__(self):
super().__init__()
self.steps = ['rename']
def rename(self, items):
self.sql = 'ALTER TABLE {table} RENAME TO {rename}'
self.parser('identifier', items['table'], data_key='table')
self.parser('identifier', items['rename'], data_key='rename')
return self
class DropTable(Statement):
def __init__(self):
super().__init__()
self.steps = ['drop']
def drop(self, items):
if items['if_exists']:
self.sql = 'DROP TABLE IF EXISTS {table}'
else:
self.sql = 'DROP TABLE {table}'
self.parser('identifier', items['table'], data_key='table')
return self
class DropIndex(Statement):
def __init__(self):
super().__init__()
self.steps = ['drop']
def drop(self, items):
self.sql = 'DROP INDEX {index}'
self.parser('identifier', items['index'], data_key='index')
return self
class DropPrimaryKey(Statement):
def __init__(self):
super().__init__()
self.steps = ['drop']
def drop(self, items):
self.sql = 'ALTER TABLE {table} DROP CONSTRAINT {pkey}'
self.parser('identifier', items['table'], data_key='table')
self.parser('identifier', items['pkey'], data_key='pkey')
return self
class DropConstraint(Statement):
def __init__(self):
super().__init__()
self.steps = ['drop']
def drop(self, items):
self.sql = 'ALTER TABLE {table} DROP CONSTRAINT {constraint_name}'
self.parser('identifier', items['table'], data_key='table')
self.parser('identifier', items['constraint_name'], data_key='constraint_name')
return self
class Alter(Create):
def __init__(self):
super().__init__()
self.steps = ['alter']
def alter(self, items):
self.sql = 'ALTER TABLE {table} {actions}'
self.parser('identifier', items['table'], data_key='table')
alter_actions = self.parser('alter_actions', data_key='actions')
for column in items['columns']:
alter_column = self.parser('alter_column')
alter_column.add(column['action'])
alter_column.add(self.parser('identifier', column['name']))
if column['type'] is not None:
alter_column.add(self.parser('const', column['type']))
else:
alter_column.add(None)
alter_column.add(bool(column['nullable']))
alter_column.add(self.create_column_default(column['default']))
alter_column.add(column['primary_key'])
alter_actions.add(alter_column)
if items['primary']:
add_primary_columns = self.parser('add_primary_columns')
for primary_col in items['primary']:
add_primary_columns.add(self.parser('identifier', primary_col))
alter_actions.add(add_primary_columns)
return self
class AlterIndex(Statement):
def __init__(self):
super().__init__()
self.steps = ['alter_index']
def alter_index(self, items):
self.parser('identifier', items['table'], data_key='table')
idx = 1
# index
for idx_cols in items['indexes']:
index_col = 'create_index_col_'+str(idx)
idx += 1
if self.sql:
self.sql += '; '
self.sql += 'CREATE INDEX ON {table}({'+index_col+'})'
index_columns = self.parser('index_columns', data_key=index_col)
for idx_col in idx_cols:
index_columns.add(self.parser('identifier', idx_col))
# unique index
for idx_cols in items['uniques']:
index_col = 'create_index_col_'+str(idx)
idx += 1
if self.sql:
self.sql += '; '
self.sql += 'CREATE UNIQUE INDEX ON {table}({'+index_col+'})'
index_columns = self.parser('index_columns', data_key=index_col)
for idx_col in idx_cols:
index_columns.add(self.parser('identifier', idx_col))
return self
class Delete(Select):
def __init__(self):
super().__init__()
self.steps = ['delete', 'where']
def delete(self, items):
self.sql = 'DELETE FROM {table}'
self.parser('identifier', items['table'], data_key='table')
return self
class ColumnBuilder:
def __init__(self, name, type=None, nullable=True, default=None, unsigned=None, primary_key=False):
self.column = {}
self.column['name'] = name.strip()
self.column['type'] = type.strip().upper() if type is not None else type
self.column['nullable'] = nullable
self.column['default'] = default
self.column['unsigned'] = unsigned
self.column['primary_key'] = primary_key
self.column['action'] = 'add'
self.uniques = None
self.indexes = None
def __call__(self, items):
items['columns'].append(self.column)
if self.indexes:
items['indexes'].append(self.indexes)
if self.uniques:
items['uniques'].append(self.uniques)
def serial(self):
self.column["type"] = "SERIAL"
return self
def bigserial(self):
self.column["type"] = "BIGSERIAL"
return self
def text(self):
self.column["type"] = "VARCHAR"
return self
def string(self, length=255):
self.column["type"] = "VARCHAR (%d)" % (int(length))
return self
def integer(self):
self.column["type"] = "INTEGER"
return self
def bigint(self):
self.column["type"] = "BIGINT"
return self
def numeric(self, precision, scale):
self.column["type"] = "NUMERIC(%d, %d)" % (int(precision), int(scale))
return self
def timestamp(self):
self.column["type"] = "TIMESTAMP"
return self
def timestamptz(self):
self.column["type"] = "TIMESTAMPTZ"
return self
def boolean(self):
self.column["type"] = "BOOLEAN"
return self
def jsonb(self):
self.column["type"] = "JSONB"
return self
def nullable(self, isNullable=True):
self.column["nullable"] = isNullable
return self
def primary_key(self):
self.column["primary_key"] = True
return self
def unique(self):
self.uniques = [self.column['name']]
return self
def index(self):
self.indexes = [self.column['name']]
return self
def default(self, val):
if val is None:
val = 'NULL'
self.column["default"] = val
return self
# ALTER TABLE
def drop(self):
self.column['action'] = 'drop'
return self
class IndexBuilder:
def __init__(self, columns, type='index'):
self.type = type
self.columns = [c.strip() for c in columns]
def __call__(self, items):
if self.type == 'unique':
items['uniques'].append(self.columns)
elif self.type == 'primary':
items['primary'] = self.columns
elif self.type == 'index':
items['indexes'].append(self.columns)
else:
raise UserWarning("Unknown index type: {}".format(self.type))

139
windyquery/index.py Normal file
View File

@@ -0,0 +1,139 @@
import asyncpg
from .builder import Builder
from .builder import ColumnBuilder
from .builder import IndexBuilder
class Connection:
"""The base class for DB connection"""
def __init__(self):
self.default = ''
self.builder = None
self.conn_pools = {}
async def connect(self, connection_name, config, default=False, min_size=10, max_size=10, max_queries=50000, max_inactive_connection_lifetime=300.0, setup=None, init=None, loop=None, connection_class=asyncpg.connection.Connection, **connect_kwargs):
if connection_name in self.conn_pools:
raise UserWarning("connection: {} already exists".format(connection_name))
dsn= "postgresql://{}:{}@{}:{}/{}".format(
config['username'],
config['password'],
config['host'],
config['port'],
config['database']
)
self.conn_pools[connection_name] = await asyncpg.create_pool(
dsn=dsn,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
setup=setup,
init=init,
loop=loop,
connection_class=connection_class,
**connect_kwargs)
if default:
self.default = connection_name
async def stop(self):
for name, pool in self.conn_pools.items():
await pool.close()
pool.terminate()
self.default = ''
self.conn_pools = {}
def get_builder(self):
if self.builder is None:
pool = self.conn_pools[self.default]
builder = Builder(pool)
else:
builder = self.builder
self.builder = None
return builder
def connection(self, connection_name):
if connection_name not in self.conn_pools:
raise UserWarning("connection: {} does not exists".format(connection_name))
pool = self.conn_pools[connection_name]
self.builder = Builder(pool)
return self
def raw(self, query, args=None):
if args is None:
args = []
return self.get_builder().raw(query, args)
class DB(Connection):
"""DB class for CRUD"""
def table(self, table_name):
return self.get_builder().table(table_name)
class Schema(Connection):
"""DB class for migrations"""
def get_column_items(self, table_name, columns):
items = {}
items['table'] = table_name.strip()
items["columns"] = []
items["primary"] = None
items["uniques"] = []
items["indexes"] = []
for col in columns:
col(items)
return items
async def create(self, table_name, *columns):
items = self.get_column_items(table_name, columns)
builder = self.get_builder()
await builder.create(items)
for idx_cols in items['indexes']:
builder.reset()
items['index_columns'] = idx_cols
await builder.create_index(items)
async def rename(self, table_name_old, table_name_new):
return await self.get_builder().rename_table(table_name_old, table_name_new)
async def drop(self, table_name):
return await self.get_builder().drop_table(table_name)
async def drop_if_exists(self, table_name):
return await self.get_builder().drop_table(table_name, True)
async def table(self, table_name, *columns):
items = self.get_column_items(table_name, columns)
builder = self.get_builder()
if len(items['columns']) > 0 or items['primary'] is not None:
await builder.alter(items)
if len(items['indexes']) > 0 or len(items['uniques']) > 0:
builder.reset()
await builder.alter_index(items)
async def dropIndex(self, index_name):
return await self.get_builder().drop_index(index_name)
async def dropPrimaryKey(self, table_name, pkey=None):
if pkey is None:
pkey = table_name+'_pkey'
return await self.get_builder().drop_primary_key(table_name, pkey)
async def dropConstraint(self, table_name, constraint_name):
return await self.get_builder().drop_constraint(table_name, constraint_name)
def column(cls, name, **attr):
return ColumnBuilder(name, **attr)
def primary_key(cls, *columns):
return IndexBuilder(columns, type='primary')
def unique(cls, *columns):
return IndexBuilder(columns, type='unique')
def index(cls, *columns):
return IndexBuilder(columns, type='index')

193
windyquery/model.py Normal file
View File

@@ -0,0 +1,193 @@
import asyncio
import json
from json.decoder import JSONDecodeError
import re
from rx.subjects import BehaviorSubject
from .index import DB, Builder
class Event:
"""hold events"""
db = BehaviorSubject(None)
class Column:
"""represents a DB column"""
def __init__(self, name, type, ordinal):
self.name = name
self.type = type
self.ordinal = ordinal
class ModelMeta(type):
"""add table name and columns to Model"""
def __init__(cls, name, bases, attr_dict):
super().__init__(name, bases, attr_dict)
if cls.__name__ == 'Model':
return
initialized = False
def setup(db):
nonlocal cls
nonlocal initialized
cls.db = db
if initialized:
return
if not hasattr(cls, 'table'):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', cls.__name__)
s2 = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
cls.table = s2 if s2.endswith('s') else s2+'s'
if hasattr(cls, 'columns'):
cols = cls.columns
else:
cols = asyncio.get_event_loop().run_until_complete(cls.db.raw("SELECT column_name, ordinal_position, data_type FROM information_schema.columns WHERE table_name = '{}'".format(cls.table)))
cls.columns = tuple(Column(col['column_name'], col['data_type'], col['ordinal_position']) for col in cols)
if not hasattr(cls, 'id'):
row = asyncio.get_event_loop().run_until_complete(cls.db.raw("SELECT column_name FROM information_schema.key_column_usage AS c LEFT JOIN information_schema.table_constraints AS t ON t.constraint_name = c.constraint_name WHERE t.table_name = '{}' AND t.constraint_type = 'PRIMARY KEY'".format(cls.table)).first())
if row and 'column_name' in row:
cls.id = row['column_name']
else:
cls.id = None
initialized = True
Event.db.filter(lambda db: db is not None).subscribe(setup)
class Model(metaclass=ModelMeta):
"""represents a DB record with querry functions"""
connection = None
def __getitem__(self, k):
return getattr(self, k)
def __setitem__(self, k, v):
setattr(self, k, v)
def __delitem__(self, k):
delattr(self, k)
def __len__(self):
return len(self.__dict__)
def __iter__(self):
return iter(self.__dict__.items())
def __contains__(self, item):
return item in self.__dict__
@classmethod
def set(cls, instance, record):
record = dict(record)
for col in cls.columns:
if col.name in record:
val = record[col.name]
if col.type == 'jsonb' and isinstance(val, str):
try:
val = json.loads(val)
except JSONDecodeError:
pass
setattr(instance, col.name, val)
return instance
def __init__(self, *records, **record):
cls = type(self)
if cls.id is not None:
setattr(self, cls.id, None)
for r in records:
cls.set(self, r)
cls.set(self, record)
@classmethod
def builder(cls):
if cls.connection is None:
pool = cls.db.conn_pools[cls.db.default]
else:
pool = cls.db.conn_pools[cls.connection]
return ModelBuilder(pool, cls).table(cls.table).select()
@classmethod
def all(cls):
"""Retrieve all instances"""
return cls.builder()
@classmethod
def find(cls, id):
"""Retrieve a model instance by id"""
builder = cls.builder().where(cls.id, id)
if not isinstance(id, list):
builder.first()
return builder
@classmethod
def where(cls, *args):
"""Retrieve a model instance by where"""
return cls.builder().where(*args)
async def save(self):
"""Save to DB"""
record = self.__dict__.copy()
for col in type(self).columns:
if col.type == 'jsonb' and col.name in record:
record[col.name] = json.dumps(record[col.name])
builder = type(self).builder()
id_name = type(self).id
id = getattr(self, id_name) if id_name else None
if id:
await builder.where(id_name, id).update(record)
else:
if id_name:
del record[id_name]
model = await builder.insert(record).returning().first()
if model:
type(self).set(self, model)
return self
async def delete(self):
"""Delete from DB"""
builder = type(self).builder()
id_name = type(self).id
id = getattr(self, id_name) if id_name else None
if id:
await builder.where(id_name, id).delete()
class ModelBuilder(Builder):
"""wrap the Builder class to return a Model instance after exec"""
def __init__(self, pool, model_cls):
super().__init__(pool)
self.model_cls = model_cls
self._new_if_not_found = False
def reset(self):
super().reset()
self._new_if_not_found = False
async def exec(self):
if self.composer is None:
raise UserWarning("SQL Builder is not complete")
rows = await super().exec()
if isinstance(rows, list):
result = [self.model_cls.set(self.model_cls(), row) for row in rows]
else:
if rows:
result = self.model_cls.set(self.model_cls(), rows)
else:
result = self.model_cls() if self._new_if_not_found else None
self.reset()
return result
def first_or_new(self):
self._first = True
self._new_if_not_found = True
return self

211
windyquery/parser.py Normal file
View File

@@ -0,0 +1,211 @@
import json
import re
from asyncpg import utils
class InvalidSqlError(RuntimeError):
"""raise when can not build a valid sql"""
class CRUD:
"""Base class for synthesize CRUD query"""
def __init__(self):
self.data = {}
def compile(self):
sqlParams = {}
for name, sql in self.data.items():
sqlParams[name] = sql()
return sqlParams
def __call__(self, sql, *args, data_key=None):
if data_key is not None:
if data_key not in self.data:
self.data[data_key] = SQL(sql, *args)
return self.data[data_key]
else:
return SQL(sql, *args)
class SQL:
"""class used to represent SQL fragment"""
def __init__(self, action, sql=None):
self.action = getattr(self, action)
self.sql = sql
def add(self, sql):
if self.sql is None:
self.sql = []
self.sql.append(sql)
def __call__(self):
if isinstance(self.sql, list):
return self.action(*(s() if isinstance(s, SQL) else s for s in self.sql))
elif isinstance(self.sql, SQL):
return self.action(self.sql())
else:
return self.action(self.sql)
def identifier(self, var):
return utils._quote_ident(var)
def literal(self, var):
return utils._quote_literal(var)
def const(self, var):
allowed = ['*', '=', 'SERIAL', 'BIGSERIAL', 'VARCHAR', 'INTEGER', 'BIGINT', 'NUMERIC', 'TIMESTAMP', 'TIMESTAMPTZ', 'BOOLEAN', 'JSONB', 'TRUE', 'FALSE', 'NULL', 'NOW()']
if var not in allowed and not re.match(r'VARCHAR \(\d+\)', var) and not re.match(r'NUMERIC\(\d+, \d+\)', var):
raise InvalidSqlError("not allowed to use raw string: {}".format(var))
return var
def raw(self, var):
return var
def jsonb_text(self, field, *attrs):
attrs = list(attrs)
prefix = '->'.join([field] + attrs[:-1])
return '{}->>{}'.format(prefix, attrs[-1])
def jsonb(self, field, *attrs):
attrs = list(attrs)
return '->'.join([field]+attrs)
def select(self, *items):
return ', '.join(items)
def select_as(self, field, alias):
return '{} AS {}'.format(field, alias)
def full_field(self, table_name, field):
return '{}.{}'.format(table_name, field)
def where(self, *items):
return ' AND '.join(items)
def where_item(self, op, field, idx):
allowed = ['=', '<', '<=', '>', '>=', 'IN', 'NOT IN', 'IS', 'IS NOT', 'LIKE']
op = op.upper()
if op not in allowed:
raise InvalidSqlError('invalid operator in where clause: {}'.format(op))
if idx is None:
return '{} {} NULL'.format(field, op)
else:
return '{} {} ${}'.format(field, op, idx)
def where_in_item(self, field, idx):
padded = []
for id in idx:
padded.append('${}'.format(id))
place_hodlers = ', '.join(padded)
return '{} IN ({})'.format(field, place_hodlers)
def where_not_in_item(self, field, idx):
padded = []
for id in idx:
padded.append('${}'.format(id))
place_hodlers = ', '.join(padded)
return '{} NOT IN ({})'.format(field, place_hodlers)
def join(self, *items):
return ' '.join(items)
def join_item(self, table, left_expr, join_op, right_expr):
return "JOIN {} ON {} {} {}".format(table, left_expr, join_op, right_expr)
def order_by(self, *items):
return ', '.join(items)
def order_by_item(self, field, dir=None):
if dir is None:
return field
else:
if dir not in ['ASC', 'DESC']:
raise InvalidSqlError('invalid order by dir {}'.format(dir))
return "{} {}".format(field, dir)
def group_by(self, *items):
return ', '.join(items)
def limit(self, idx):
return "${}".format(idx)
def update(self, *items):
return ', '.join(items)
def update_item(self, field, idx):
if idx is None:
return '{} = NULL'.format(field)
else:
return "{} = ${}".format(field, idx)
def update_from_item(self, field, from_field):
return "{} = {}".format(field, from_field)
def where_from_item(self, field, op, value):
allowed = ['=', '<', '<=', '>', '>=', 'IN', 'IS']
op = op.upper()
if op not in allowed:
raise InvalidSqlError('invalid operator in where clause: {}'.format(op))
return "{} {} {}".format(field, op, value)
def update_jsonb(self, field, *attrs):
update = attrs[-1]
for attr in reversed(attrs[:-1]):
update = {attr: update}
return "{} = COALESCE({}, '{}') || '{}'".format(field, field, json.dumps({}), json.dumps(update))
def insert_keys(self, *keys):
return ', '.join(keys)
def insert_value(self, *index):
insertStr = ', '.join('${}'.format(idx) for idx in index)
return '({})'.format(insertStr)
def insert_values(self, *values):
return ', '.join(values)
def returning(self, *values):
return ', '.join(values)
def create_columns(self, *columns):
return ', '.join(columns)
def create_column(self, name, type, nullable, default, primary_key):
s = "{} {}".format(name, type)
if not nullable:
s += " NOT NULL"
if default is not None:
s += ' DEFAULT {}'.format(default)
if primary_key:
s += " PRIMARY KEY"
return s
def unique_columns(self, *columns):
return 'UNIQUE ({})'.format(', '.join(columns))
def primary_columns(self, *columns):
return 'PRIMARY KEY ({})'.format(', '.join(columns))
def index_columns(self, *columns):
return ', '.join(columns)
def alter_actions(self, *columns):
return ', '.join(columns)
def alter_column(self, action, name, type, nullable, default, primary_key):
if action == 'drop':
s = 'DROP COLUMN IF EXISTS {}'.format(name)
else: # default to ADD
s = "ADD COLUMN {} {}".format(name, type)
if not nullable:
s += " NOT NULL"
if default is not None:
s += ' DEFAULT {}'.format(default)
if primary_key:
s += " PRIMARY KEY"
return s
def add_primary_columns(self, *columns):
return 'ADD PRIMARY KEY ({})'.format(', '.join(columns))

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,60 @@
import asyncio
import pytest
from windyquery import DB, Schema
@pytest.fixture(scope="module")
def config():
class Config:
DB_HOST = "localhost",
DB_PORT = "5432",
DB_TEST = "db_test",
DB_USER = "tester name",
DB_PASS = "tester password"
yield Config
@pytest.fixture(scope="module")
def db(config):
app_db = DB()
async def init_db():
return await app_db.connect('db_test', {
'host': config.DB_HOST,
'port': config.DB_PORT,
'database': config.DB_TEST,
'username': config.DB_USER,
'password': config.DB_PASS
}, default=True)
asyncio.get_event_loop().run_until_complete(init_db())
yield app_db
asyncio.get_event_loop().run_until_complete(app_db.stop())
@pytest.fixture(scope="module")
def schema(config):
app_schema = Schema()
async def init_schema():
return await app_schema.connect('db_test', {
'host': config.DB_HOST,
'port': config.DB_PORT,
'database': config.DB_TEST,
'username': config.DB_USER,
'password': config.DB_PASS
}, default=True)
asyncio.get_event_loop().run_until_complete(init_schema())
yield app_schema
asyncio.get_event_loop().run_until_complete(app_schema.stop())
@pytest.fixture(scope="module")
def model(config):
from windyquery.model import Event
model_db = DB()
conn_coro = model_db.connect('db_test', {
'host': config.DB_HOST,
'port': config.DB_PORT,
'database': config.DB_TEST,
'username': config.DB_USER,
'password': config.DB_PASS
}, default=True)
asyncio.get_event_loop().run_until_complete(conn_coro)
Event.db.on_next(model_db)
yield model_db
asyncio.get_event_loop().run_until_complete(model_db.stop())

View File

@@ -0,0 +1,50 @@
import asyncio
loop = asyncio.get_event_loop()
def test_add_column(schema):
async def add_col():
return await schema.table('users',
schema.column('test_col').string(50).nullable(False).default('test default'),
)
async def drop_col():
return await schema.table('users',
schema.column('test_col').drop(),
)
loop.run_until_complete(add_col())
loop.run_until_complete(drop_col())
assert 1 == 1
def test_alter_unique_index(schema):
async def add_unique():
return await schema.table('boards',
schema.unique('user_id', 'location'),
)
async def drop_unique():
return await schema.dropIndex('boards_user_id_location_idx')
loop.run_until_complete(add_unique())
loop.run_until_complete(drop_unique())
assert 1 == 1
def test_alter_index(schema):
async def add_index():
return await schema.table('boards',
schema.index('user_id', 'location'),
)
async def drop_index():
return await schema.dropIndex('boards_user_id_location_idx')
loop.run_until_complete(add_index())
loop.run_until_complete(drop_index())
assert 1 == 1
def test_alter_primary_key(schema):
async def add_primary():
return await schema.table('cards_copy',
schema.primary_key('id', 'board_id'),
)
async def drop_primary():
return await schema.dropPrimaryKey('cards_copy')
loop.run_until_complete(add_primary())
loop.run_until_complete(drop_primary())
assert 1 == 1

View File

@@ -0,0 +1,87 @@
import asyncio
loop = asyncio.get_event_loop()
def test_create_user_tmp(schema):
async def create_user_tmp():
return await schema.create('users_tmp',
schema.column('id').serial().primary_key(),
schema.column('email').string().nullable(False).unique(),
schema.column('password').string().nullable(False),
schema.column('registered_on').timestamp().nullable(False).default("NOW()"),
schema.column('admin').boolean().nullable(False).default(False)
)
async def drop_table():
return await schema.drop('users_tmp')
loop.run_until_complete(create_user_tmp())
loop.run_until_complete(drop_table())
assert 1 == 1
def test_create_unique_index(schema):
async def create_unique_index():
return await schema.create('users_tmp',
schema.column('id').serial().primary_key(),
schema.column('name').string().nullable(False),
schema.column('user_id').integer().nullable(False),
schema.unique('user_id', 'name'),
schema.column('created_at').timestamp().nullable(False).default("NOW()"),
schema.column('updated_at').timestamp().nullable(False).default("NOW()"),
schema.column('deleted_at').timestamp().nullable(True)
)
async def drop_table():
return await schema.drop('users_tmp')
loop.run_until_complete(create_unique_index())
loop.run_until_complete(drop_table())
assert 1 == 1
def test_create_primary_key(schema):
async def create_user_tmp():
return await schema.create('users_tmp2',
schema.column('name').string().nullable(False),
schema.column('email').string().nullable(False),
schema.primary_key('name', 'email'),
schema.column('password').string().nullable(False),
schema.column('registered_on').timestamp().nullable(False).default("NOW()"),
schema.column('admin').boolean().nullable(False).default(False)
)
async def drop_table():
return await schema.drop('users_tmp2')
loop.run_until_complete(create_user_tmp())
loop.run_until_complete(drop_table())
assert 1 == 1
def test_create_index_key(schema):
async def create_user_tmp():
return await schema.create('users_tmp3',
schema.column('name').string().nullable(False),
schema.column('email').string().nullable(False),
schema.index('name', 'email'),
schema.column('password').string().nullable(False),
schema.column('registered_on').timestamp().nullable(False).default("NOW()"),
schema.column('admin').boolean().nullable(False).default(False)
)
async def drop_table():
return await schema.drop('users_tmp3')
loop.run_until_complete(create_user_tmp())
loop.run_until_complete(drop_table())
assert 1 == 1
def test_drop_nonexists(schema):
async def drop_table():
return await schema.drop_if_exists('not_exist_table')
loop.run_until_complete(drop_table())
assert 1 == 1
def test_create_jsonb(schema):
async def create_jsonb():
return await schema.create('cards_tmp',
schema.column('id').integer().nullable(False),
schema.column('board_id').integer().nullable(False),
schema.column('data').jsonb()
)
async def drop_table():
return await schema.drop('cards_tmp')
loop.run_until_complete(create_jsonb())
loop.run_until_complete(drop_table())
assert 1 == 1

View File

@@ -0,0 +1,15 @@
import asyncio
import string
import random
loop = asyncio.get_event_loop()
def test_delete(db):
test_id = 99999
test_name = 'delete'
loop.run_until_complete(db.table('test').insert({'id': test_id, 'name': test_name}))
row = loop.run_until_complete(db.table('test').select().where('id', test_id).first())
assert row['name'] == test_name
loop.run_until_complete(db.table('test').where('id', test_id).delete())
row = loop.run_until_complete(db.table('test').select().where('id', test_id).first())
assert row is None

View File

@@ -0,0 +1,11 @@
import asyncio
loop = asyncio.get_event_loop()
def test_group_by(db):
async def group_by():
return await db.table('boards').select('user_id').group_by('user_id')
rows = loop.run_until_complete(group_by())
assert len(rows) == 2
assert rows[0]['user_id'] == 2

View File

@@ -0,0 +1,31 @@
import asyncio
import string
import random
loop = asyncio.get_event_loop()
def test_insert_user(db):
async def insert_user(email1, email2):
return await db.table('users').insert({'email': email1, 'password': 'my precious'}, {'email': email2, 'password': 'my precious'})
async def get_user(email):
return await db.table('users').select().where('email', email).first()
email1 = ''.join(random.choice(string.ascii_letters) for i in range(6))
email2 = ''.join(random.choice(string.ascii_letters) for i in range(6))
loop.run_until_complete(insert_user(email1, email2))
row = loop.run_until_complete(get_user(email1))
assert row['email'] == email1
row = loop.run_until_complete(get_user(email2))
assert row['email'] == email2
def test_insert_jsonb(db):
async def insert_jsonb(test_id):
return await db.table('cards').insert({'id': test_id, 'board_id': random.randint(1, 100), 'data': {'name': 'hi{}'.format(test_id), 'address': {'city': 'Chicago', 'state': 'IL'}}})
async def get_jsonb(test_id):
return await db.table('cards').select('data->>name AS name').where('id', test_id).first()
test_id = random.randint(1, 10000)
loop.run_until_complete(insert_jsonb(test_id))
row = loop.run_until_complete(get_jsonb(test_id))
assert row['name'] == 'hi'+str(test_id)

View File

@@ -0,0 +1,22 @@
import asyncio
import string
import random
loop = asyncio.get_event_loop()
def test_insert(db):
test_id = 9998
name = ''.join(random.choice(string.ascii_letters) for i in range(6))
loop.run_until_complete(db.table('test').insertRaw(
'("id", "name") SELECT $1, $2 WHERE NOT EXISTS (SELECT "id" FROM test WHERE "id" = $1)', [test_id, name]
))
# insert it again but should fail
loop.run_until_complete(db.table('test').insertRaw(
'("id", "name") SELECT $1, $2 WHERE NOT EXISTS (SELECT "id" FROM test WHERE "id" = $1)', [test_id, name]
))
rows = loop.run_until_complete(db.table('test').select().where('id', test_id))
assert len(rows) == 1
row = rows[0]
assert row['name'] == name
loop.run_until_complete(db.table('test').where('id', test_id).delete())

View File

@@ -0,0 +1,19 @@
import asyncio
loop = asyncio.get_event_loop()
def test_simple_join(db):
async def simple_join():
return await db.table('cards').join(
'boards', 'cards.board_id', '=', 'boards.id'
).join(
'users', 'boards.user_id', '=', 'users.id'
).select(
'users.email', 'boards.*'
).where("users.id", 2).where('users.admin', '=', True).first()
row = loop.run_until_complete(simple_join())
assert row['email'] == 'test@example.com'
assert row['location'] == 'south door'

View File

@@ -0,0 +1,117 @@
import asyncio
from windyquery import Model
loop = asyncio.get_event_loop()
def test_table_name(model):
class User(Model):
pass
assert User.table == 'users'
class AdminUser(Model):
pass
assert SearchTemplateSchedule.table == 'admin_users'
class Custom(Model):
table = 'my_custom'
assert Custom.table == 'my_custom'
def test_empty_model(model):
class User(Model):
pass
user = User()
user.id = 8
assert user.id == 8
user = User(id=9, email='test@test.com')
assert user.email == 'test@test.com'
user = User({'id': 10, 'email': 'test@example.com'})
assert user.email == 'test@example.com'
user = User({'id': 10, 'email': 'test@example.com'}, email='testoveride@example.com')
assert user.email == 'testoveride@example.com'
def test_find(model):
class User(Model):
pass
user = loop.run_until_complete(User.find(2))
assert user.email == 'test@example.com'
users = loop.run_until_complete(User.find([1, 2]))
assert len(users) == 2
assert users[1].email == 'test@example.com' or users[1].email == 'test2@example.com'
def test_selected_colums(model):
class User(Model):
pass
user = loop.run_until_complete(User.find(2).select('email'))
assert user.email == 'test@example.com'
assert not hasattr(user, 'admin')
def test_where(model):
class User(Model):
pass
user = loop.run_until_complete(User.where("email", 'test@example.com').first())
assert user.id == 2
users = loop.run_until_complete(User.where("email", 'test@example.com'))
assert len(users) == 1
assert users[0].id == 2
user = loop.run_until_complete(User.where("email", 'no_such_email').first())
assert user is None
users = loop.run_until_complete(User.where("email", 'no_such_email'))
assert users == []
def test_where_none(model):
class Card(Model):
pass
card = loop.run_until_complete(Card.where("board_id", None).first())
assert card is None
card = loop.run_until_complete(Card.where("board_id", None))
assert card == []
def test_cls_id(model):
class User(Model):
pass
assert User.id == 'id'
class Country(Model):
table = 'country'
assert Country.id == 'numeric_code'
def test_save(model):
class User(Model):
pass
user = loop.run_until_complete(User.where("email", 'test@example.com').first())
new_val = 'north door' if user.password == 'south door' else 'south door'
user.password = new_val
user = loop.run_until_complete(user.save())
assert user.password == new_val
def test_save_new(model):
class User(Model):
pass
user = User(email='tmp@example.com', password='tmp_password')
user = loop.run_until_complete(user.save())
assert user.id > 0
loop.run_until_complete(user.delete())
user = loop.run_until_complete(User.find(user.id))
assert user is None
def test_multi_results(model):
class User(Model):
pass
users = loop.run_until_complete(User.where("password", 'secret'))
assert len(users) == 2
assert users[0].email == 'insert_multi_1'
assert users[1].email == 'insert_multi_2'
def test_all(model):
class Board(Model):
pass
results = loop.run_until_complete(Board.all())
assert len(results) == 3
assert results[1].location == 'bedroom'
def test_first_or_new(model):
class AdminUser(Model):
pass
admin = loop.run_until_complete(AdminUser.where('id', 78901).where('name', 'not_exist_testing_name').first_or_new())
assert admin is not None
assert isinstance(admin, AdminUser)
assert admin.id is None

View File

@@ -0,0 +1,35 @@
import asyncio
from windyquery import Model
loop = asyncio.get_event_loop()
def test_read_jsonb(model):
class Card(Model):
pass
card = loop.run_until_complete(Card.find(5))
assert card.data == {
"finished": False,
"name": "Hang paintings",
"tags": [
"Improvements",
"Office"
]
}
def test_write_jsonb(model):
class Card(Model):
pass
card = Card(board_id=1, data=[1,'hi',2])
card = loop.run_until_complete(card.save())
assert card.id > 0
card = loop.run_until_complete(Card.find(card.id))
assert card.data == [1,'hi',2]
loop.run_until_complete(Card.where('id', card.id).delete())
card = Card(board_id=1, data='plain string as json')
card = loop.run_until_complete(card.save())
assert card.id > 0
card = loop.run_until_complete(Card.find(card.id))
assert card.data == 'plain string as json'
loop.run_until_complete(Card.where('id', card.id).delete())

View File

@@ -0,0 +1,17 @@
import asyncio
loop = asyncio.get_event_loop()
def test_raw_select(db):
async def raw_select():
return await db.raw('SELECT * FROM cards WHERE board_id = $1', [5]).first()
row = loop.run_until_complete(raw_select())
assert row['id'] == 9247
def test_select_raw(db):
async def select_raw():
return await db.table('cards').select_raw('ROUND(AVG(board_id),1) AS avg_id, COUNT(1) AS copies').where('id', [4,5,6]).first()
row = loop.run_until_complete(select_raw())
assert row['avg_id'] == 2.0
assert row['copies'] == 3

View File

@@ -0,0 +1,48 @@
import asyncio
loop = asyncio.get_event_loop()
def test_db_test_connected(db):
assert 'db_test' in db.conn_pools
def test_raw_connection(db):
async def select_cards():
async with db.conn_pools['db_test'].acquire() as connection:
return await connection.fetchrow('SELECT * FROM test')
row = loop.run_until_complete(select_cards())
assert row['name'] == 'test'
def test_select_by_builder_toSql(db):
async def select_by_builder():
return await db.table('test').select().toSql()
sql = loop.run_until_complete(select_by_builder())
assert sql == 'SELECT * FROM "test"'
def test_select_by_builder(db):
async def select_by_builder():
return await db.table('test').select().first()
row = loop.run_until_complete(select_by_builder())
assert row['name'] == 'test'
def test_select_with_alias(db):
async def select_with_alias():
return await db.table('test').select('test.id AS name1', 'test.name').first()
row = loop.run_until_complete(select_with_alias())
assert row['name1'] == 1
assert row['name'] == 'test'
def test_select_with_jsonb(db):
async def select_with_jsonb():
return await db.table('cards').select('data->name AS name', 'data->>name AS name_text', 'data->tags AS tags', 'data->finished').where('id', 2).first()
row = loop.run_until_complete(select_with_jsonb())
assert row['name'] == '"Wash dishes"'
assert row['name_text'] == 'Wash dishes'
assert row['tags'] == '["Clean", "Kitchen"]'
assert row['?column?'] == 'false'
def test_select_nested_jsonb(db):
async def select_nested_jsonb():
return await db.table('cards').select('data->address->>city AS city').where('id', 8).first()
row = loop.run_until_complete(select_nested_jsonb())
assert row['city'] == 'Chicago'

View File

@@ -0,0 +1,36 @@
import asyncio
loop = asyncio.get_event_loop()
def test_order_by(db):
async def order_by():
return await db.table('users').select().order_by('id ASC', 'email ASC', 'password DESC').first()
row = loop.run_until_complete(order_by())
assert row['email'] == 'test@example.com'
assert row['id'] == 1
def test_order_by_with_table(db):
async def order_by_with_table():
return await db.table('users').select().order_by('users.id ASC', 'users.email ASC', 'password DESC').first()
row = loop.run_until_complete(order_by_with_table())
assert row['email'] == 'test@example.com'
assert row['id'] == 1
def test_order_by_with_jsonb(db):
async def order_by_with_jsonb():
return await db.table('cards').select('data->>name AS name').order_by('cards.data->name', 'id DESC').first()
row = loop.run_until_complete(order_by_with_jsonb())
assert row['name'] == 'Cook lunch'
def test_limit(db):
async def limit():
return await db.table('cards').select().limit(3)
rows = loop.run_until_complete(limit())
assert len(rows) == 3
def test_limit_str(db):
async def limit_str():
return await db.table('cards').select().limit('5')
rows = loop.run_until_complete(limit_str())
assert len(rows) == 5

View File

@@ -0,0 +1,58 @@
import asyncio
loop = asyncio.get_event_loop()
def test_simple_update(db):
async def get_board_id():
return await db.table('cards').select('board_id').where('id', 2).first()
async def update_board_id(test_id):
return await db.table('cards').where('id', 2).update({'board_id': test_id})
loop.run_until_complete(update_board_id(3))
row = loop.run_until_complete(get_board_id())
assert row['board_id'] == 3
loop.run_until_complete(update_board_id(2))
row = loop.run_until_complete(get_board_id())
assert row['board_id'] == 2
def test_simple_update_jsonb(db):
async def get_cards_city():
return await db.table('cards').select('data->address->>city AS city').where('id', 8).first()
async def update_cards_city(city):
return await db.table('cards').where('id', 8).update({'data->address->city': city})
loop.run_until_complete(update_cards_city('New York'))
row = loop.run_until_complete(get_cards_city())
assert row['city'] == 'New York'
loop.run_until_complete(update_cards_city('Chicago'))
row = loop.run_until_complete(get_cards_city())
assert row['city'] == 'Chicago'
def test_path_jsonb(db):
async def get_cards_skill():
return await db.table('cards').select('data->skill->>java AS java').where('id', 8).first()
async def update_cards_skill(level):
return await db.table('cards').where('id', 8).update({'data->skill->java': level})
loop.run_until_complete(update_cards_skill('Good'))
row = loop.run_until_complete(get_cards_skill())
assert row['java'] == 'Good'
def test_set_whole_jsonb(db):
async def get_cards_address():
return await db.table('cards').select('data->address->>city AS city').where('id', 7).first()
async def update_cards_address():
return await db.table('cards').where('id', 7).update({'data': {'address': {'city': 'New York'}}})
loop.run_until_complete(update_cards_address())
row = loop.run_until_complete(get_cards_address())
assert row['city'] == 'New York'
def test_update_from(db):
async def get_pass(user_id):
return await db.table('users').select('password').where('id', user_id).first()
async def update_pass(user_id):
return await db.table('users').update({'password': 'boards.location'}).update_from('boards').where('boards.user_id', 'users.id').where('users.id', user_id)
loop.run_until_complete(update_pass(1))
row = loop.run_until_complete(get_pass(1))
assert row['password'] == 'dining room'
loop.run_until_complete(update_pass(2))
row = loop.run_until_complete(get_pass(2))
assert row['password'] == 'south door'

View File

@@ -0,0 +1,43 @@
import asyncio
loop = asyncio.get_event_loop()
def test_single_where(db):
async def single_where():
return await db.table('test').select().where("name", 'test').first()
row = loop.run_until_complete(single_where())
assert row['name'] == 'test'
assert row['id'] == 1
def test_josnb_where(db):
async def jsonb_where():
return await db.table('cards').select('id', 'data->tags').where("data->name", 'Cook lunch').first()
row = loop.run_until_complete(jsonb_where())
assert row['id'] == 3
assert row['?column?'] == '["Cook", "Kitchen", "Tacos"]'
def test_josnb_text_where(db):
async def jsonb_where():
return await db.table('cards').select('id', 'data->>tags').where("data->name", 'Cook lunch').first()
row = loop.run_until_complete(jsonb_where())
assert row['id'] == 3
assert row['?column?'] == '["Cook", "Kitchen", "Tacos"]'
def test_multi_where(db):
async def jsonb_where():
return await db.table('cards').select('id', 'data->>tags').where("data->>name", 'Cook lunch').where('board_id', '=', 3).first()
row = loop.run_until_complete(jsonb_where())
assert row['?column?'] == '["Cook", "Kitchen", "Tacos"]'
def test_where_in(db):
async def where_in():
return await db.table('boards').select().where("id", 'IN', [1, 3])
rows = loop.run_until_complete(where_in())
assert len(rows) == 2
def test_where_list(db):
async def where_list():
return await db.table('boards').select().where("id", [1, 3])
rows = loop.run_until_complete(where_list())
assert len(rows) == 2