# NOTE: removed scraped viewer chrome ("320 lines / 10 KiB / Python") that
# was accidentally captured at the top of this module.
import orjson
|
||
|
||
try:
|
||
from collections import Mapping
|
||
except: # noqa E722
|
||
from collections.abc import Mapping
|
||
import inspect
|
||
import importlib
|
||
import os
|
||
import re
|
||
import sys
|
||
from typing import Any, Dict, List, Optional
|
||
from unicodedata import normalize
|
||
from distutils.version import LooseVersion
|
||
import logging
|
||
import datetime
|
||
|
||
from flask import request
|
||
from flask_sqlalchemy import model
|
||
from sqlalchemy import UniqueConstraint
|
||
import marshmallow
|
||
from marshmallow import Schema
|
||
|
||
from .webargs import use_kwargs as base_use_kwargs, parser
|
||
|
||
from flask.json import JSONEncoder
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ParamsDict(dict):
    """Dict whose ``update`` returns a merged copy instead of mutating.

    Example::

        @use_kwargs(PageParams.update({...}))
        def list_users(page, page_size, order_by):
            pass

    """

    def update(self, other=None):
        """Return a new ``ParamsDict`` with *other* merged into a copy of self.

        ``other`` may be a Mapping or an iterable of ``(key, value)`` pairs.
        ``self`` is never modified.
        """
        merged = ParamsDict(self.copy())
        if other is None:
            return merged
        pairs = other.items() if isinstance(other, Mapping) else other
        for key, value in pairs:
            merged[key] = value
        return merged
|
||
|
||
|
||
# Function version
|
||
def row2dict(row):
    """Convert a SQLAlchemy row object to ``{column_name: str(value)}``."""
    result = {}
    for column in row.__table__.columns:
        result[column.name] = str(getattr(row, column.name))
    return result
|
||
|
||
|
||
class dict2object(dict):
    """Dict subclass that exposes its keys as attributes.

    Missing keys raise ``AttributeError`` on attribute access; attribute
    assignment writes through to the underlying dict.
    """

    def __getattr__(self, name: str) -> Any:
        if name not in self:
            raise AttributeError('object has no attribute {}'.format(name))
        return self[name]

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(name, str):
            self[name] = value
        else:
            raise TypeError('key must be string type.')
|
||
|
||
|
||
def secure_filename(filename: str) -> str:
    """Borrowed from werkzeug.utils.secure_filename.

    Pass it a filename and it will return a secure version of it. This
    filename can then safely be stored on a regular file system and passed
    to :func:`os.path.join`.

    On windows systems the function also makes sure that the file is not
    named after one of the special device files.

    >>> secure_filename(u'哈哈.zip')
    '哈哈.zip'
    >>> secure_filename('My cool movie.mov')
    'My_cool_movie.mov'
    >>> secure_filename('../../../etc/passwd')
    'etc_passwd'
    >>> secure_filename(u'i contain cool \xfcml\xe4uts.txt')
    'i_contain_cool_umlauts.txt'

    """
    # Turn path separators into spaces so path components collapse below.
    for separator in (os.path.sep, os.path.altsep):
        if separator:
            filename = filename.replace(separator, ' ')

    # Collapse whitespace runs into single underscores, then NFKD-normalize
    # (decomposes e.g. umlauts into base letter + combining mark).
    filename = normalize('NFKD', '_'.join(filename.split()))

    # Keep only ASCII alphanumerics, CJK ideographs, '_', '.' and '-',
    # then trim leading/trailing dots and underscores.
    disallowed = re.compile(u'[^A-Za-z0-9\u4e00-\u9fa5_.-]')
    filename = disallowed.sub('', filename).strip('._')

    # on nt a couple of special files are present in each folder. We
    # have to ensure that the target file is not such a filename. In
    # this case we prepend an underline
    windows_device_files = (
        'CON', 'AUX', 'COM1', 'COM2', 'COM3', 'COM4', 'LPT1',
        'LPT2', 'LPT3', 'PRN', 'NUL',
    )
    is_device = (os.name == 'nt' and filename and
                 filename.split('.')[0].upper() in windows_device_files)
    if is_device:
        filename = '_' + filename

    return filename
|
||
|
||
|
||
def _get_init_args(instance, base_class):
|
||
"""Get instance's __init__ args and it's value when __call__.
|
||
"""
|
||
getargspec = inspect.getfullargspec
|
||
argspec = getargspec(base_class.__init__)
|
||
|
||
defaults = argspec.defaults
|
||
kwargs = {}
|
||
if defaults is not None:
|
||
no_defaults = argspec.args[:-len(defaults)]
|
||
has_defaults = argspec.args[-len(defaults):]
|
||
|
||
kwargs = {k: getattr(instance, k) for k in no_defaults
|
||
if k != 'self' and hasattr(instance, k)}
|
||
kwargs.update({k: getattr(instance, k) if hasattr(instance, k) else
|
||
getattr(instance, k, defaults[i])
|
||
for i, k in enumerate(has_defaults)})
|
||
assert len(kwargs) == len(argspec.args) - 1, 'exclude `self`'
|
||
return kwargs
|
||
|
||
|
||
def use_kwargs(argmap, schema_kwargs: Optional[Dict] = None, **kwargs: Any):
    """Fix for ``Schema(partial=True)`` not working when used with
    ``@webargs.flaskparser.use_kwargs``. More details see ``webargs.core``.

    Args:

        argmap (marshmallow.Schema,dict,callable): Either a
            `marshmallow.Schema`, `dict` of argname ->
            `marshmallow.fields.Field` pairs, or a callable that returns a
            `marshmallow.Schema` instance.
        schema_kwargs (dict): kwargs for argmap.

    Returns:
        dict: A dictionary of parsed arguments.

    """
    schema_kwargs = schema_kwargs or {}

    # Resolve dict / callable argmaps into a concrete Schema instance.
    argmap = parser._get_schema(argmap, request)

    if not (argmap.partial or schema_kwargs.get('partial')):
        # Not a partial schema: the stock webargs decorator handles it.
        return base_use_kwargs(argmap, **kwargs)

    def factory(request):
        # Rebuild the schema per request, preserving the original schema's
        # constructor arguments, so `only` can be narrowed to the fields the
        # payload actually contains.
        argmap_kwargs = _get_init_args(argmap, Schema)
        argmap_kwargs.update(schema_kwargs)

        # force set force_all=False: parse once to learn which fields the
        # request actually supplies.
        only = parser.parse(argmap, request).keys()

        argmap_kwargs.update({
            'partial': False,  # fix missing=None not work
            'only': only or None,
            'context': {"request": request},
        })
        # marshmallow 2.x needs strict=True to raise on validation errors.
        # NOTE(review): distutils.LooseVersion is deprecated (removed with
        # distutils in 3.12); consider packaging.version — left unchanged.
        if tuple(LooseVersion(marshmallow.__version__).version)[0] < 3:
            argmap_kwargs['strict'] = True

        return argmap.__class__(**argmap_kwargs)

    return base_use_kwargs(factory, **kwargs)
|
||
|
||
|
||
def import_subs(locals_, modules_only: bool = False) -> List[str]:
    """Auto import submodules, used in __init__.py.

    Args:

        locals_: `locals()`.
        modules_only: Only collect modules to __all__.

    Examples::

        # app/models/__init__.py
        from hobbit_core.utils import import_subs

        __all__ = import_subs(locals())

    Auto collect Model's subclass, Schema's subclass and instance.
    Others objects must defined in submodule.__all__.
    """
    package = locals_['__package__']
    path = locals_['__path__']
    # NOTE(review): 'top_mudule' looks like a typo of 'top_module'; kept
    # as-is since it is only a local name.
    top_mudule = sys.modules[package]

    all_ = []
    for name in os.listdir(path[0]):
        # Only consider real submodule files; skip __init__ itself and
        # anything that is not Python source/bytecode.
        if not name.endswith(('.py', '.pyc')) or name.startswith('__init__.'):
            continue

        module_name = name.split('.')[0]
        submodule = importlib.import_module(f".{module_name}", package)
        all_.append(module_name)

        if modules_only:
            continue

        if hasattr(submodule, '__all__'):
            # Re-export exactly what the submodule declares public.
            for name in getattr(submodule, '__all__'):
                if not isinstance(name, str):
                    raise Exception(f'Invalid object {name} in __all__, '
                                    f'must contain only strings.')
                setattr(top_mudule, name, getattr(submodule, name))
                all_.append(name)
        else:
            # No explicit __all__: auto-collect SQLAlchemy models
            # (DefaultMeta instances), Schema subclasses/instances, and
            # classes whose names end in 'Service'.
            for name, obj in submodule.__dict__.items():
                if isinstance(obj, (model.DefaultMeta, Schema)) or \
                        (inspect.isclass(obj) and
                         (issubclass(obj, Schema) or
                          obj.__name__.endswith('Service'))):
                    setattr(top_mudule, name, obj)
                    all_.append(name)
    return all_
|
||
|
||
|
||
def bulk_create_or_update_on_duplicate(
        db, model_cls, items, updated_at='updated_at', batch_size=500):
    """Bulk upsert rows. Supports MySQL and PostgreSQL.
    https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html

    Args:

        db: Instance of `SQLAlchemy`.
        model_cls: Model object.
        items: List of data, example: `[{key: value}, {key: value}, ...]`.
            NOTE: the dicts are mutated in place (`updated_at` is injected
            and ``None`` unique-key values are replaced with ``''``).
        updated_at: Field which recording row update time.
        batch_size: Batch size is max rows per execute.

    Returns:
        dict: A dictionary contains rowcount and items_count.
    """
    if not items:
        logger.warning("bulk_create_or_update_on_duplicate save to "
                       f"{model_cls} failed, empty items")
        return {'rowcount': 0, 'items_count': 0}

    items_count = len(items)
    table_name = model_cls.__tablename__
    fields = list(items[0].keys())
    # Columns covered by UniqueConstraints drive conflict detection.
    unique_keys = [c.name for i in model_cls.__table_args__ if isinstance(
        i, UniqueConstraint) for c in i]
    columns = [c.name for c in model_cls.__table__.columns if c.name not in (
        'id', 'created_at')]

    # Auto-fill the update timestamp when the model has such a column but
    # the caller did not provide it.
    if updated_at in columns and updated_at not in fields:
        fields.append(updated_at)
        updated_at_time = datetime.datetime.now()
        for item in items:
            item[updated_at] = updated_at_time

    assert set(fields) == set(columns), \
        'item fields not equal to columns in models:new: ' + \
        f'{set(fields) - set(columns)}, delete: {set(columns) - set(fields)}'

    # NULLs never compare equal in unique indexes, so normalize them to ''
    # to make ON CONFLICT / ON DUPLICATE KEY actually fire.
    for item in items:
        for column in unique_keys:
            if column in item and item[column] is None:
                item[column] = ''

    engine = db.get_engine(bind=getattr(model_cls, '__bind_key__', None))
    if engine.name == 'postgresql':
        sql_on_update = ', '.join([
            f' {field} = excluded.{field}'
            for field in fields if field not in unique_keys])
        sql = f"""INSERT INTO {table_name} ({", ".join(fields)}) VALUES
            ({", ".join([f':{key}' for key in fields])})
            ON CONFLICT ({", ".join(unique_keys)}) DO UPDATE SET
            {sql_on_update}"""
    elif engine.name == 'mysql':
        # BUGFIX: assignments were previously joined with '`, `', which
        # injected stray backticks between them and produced invalid SQL
        # whenever more than one non-unique field had to be updated.
        # Uses the `VALUES ... AS new` alias syntax (MySQL 8.0.19+).
        sql_on_update = ', '.join([
            f' `{field}` = new.{field}' for field in fields
            if field not in unique_keys])
        sql = f"""INSERT INTO {table_name} (`{"`, `".join(fields)}`) VALUES
            ({", ".join([f':{key}' for key in fields])}) AS new
            ON DUPLICATE KEY UPDATE
            {sql_on_update}"""
    else:
        raise Exception(f'not support db: {engine.name}')

    rowcounts = 0
    while items:
        batch, items = items[:batch_size], items[batch_size:]
        try:
            result = db.session.execute(sql, batch, bind=engine)
        except Exception as e:
            logger.error(e, exc_info=True)
            logger.info(sql)
            raise e
        rowcounts += result.rowcount
    logger.info(f'{model_cls} save_data: rowcount={rowcounts}, '
                f'items_count: {items_count}')
    return {'rowcount': rowcounts, 'items_count': items_count}
|
||
|
||
|
||
def orjson_serializer(obj):
    """Serialize *obj* to a JSON string via orjson.

    ``orjson.dumps()`` returns bytes while SQLAlchemy expects a string,
    hence the trailing ``decode()``.
    """
    options = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NAIVE_UTC
    return orjson.dumps(obj, option=options).decode()
|