320 lines
10 KiB
Python
320 lines
10 KiB
Python
|
import orjson
|
|||
|
|
|||
|
try:
|
|||
|
from collections import Mapping
|
|||
|
except: # noqa E722
|
|||
|
from collections.abc import Mapping
|
|||
|
import inspect
|
|||
|
import importlib
|
|||
|
import os
|
|||
|
import re
|
|||
|
import sys
|
|||
|
from typing import Any, Dict, List, Optional
|
|||
|
from unicodedata import normalize
|
|||
|
from distutils.version import LooseVersion
|
|||
|
import logging
|
|||
|
import datetime
|
|||
|
|
|||
|
from flask import request
|
|||
|
from flask_sqlalchemy import model
|
|||
|
from sqlalchemy import UniqueConstraint
|
|||
|
import marshmallow
|
|||
|
from marshmallow import Schema
|
|||
|
|
|||
|
from .webargs import use_kwargs as base_use_kwargs, parser
|
|||
|
|
|||
|
from flask.json import JSONEncoder
|
|||
|
|
|||
|
logger = logging.getLogger(__name__)
|
|||
|
|
|||
|
|
|||
|
class ParamsDict(dict):
    """A ``dict`` whose ``update`` returns a merged copy instead of mutating.

    Example::

        @use_kwargs(PageParams.update({...}))
        def list_users(page, page_size, order_by):
            pass

    """

    def update(self, other=None):
        """Return a new ``ParamsDict`` merged with ``other``; self is unchanged.

        ``other`` may be a Mapping or an iterable of ``(key, value)`` pairs.
        """
        merged = ParamsDict(self.copy())
        if other is not None:
            pairs = other.items() if isinstance(other, Mapping) else other
            for key, value in pairs:
                merged[key] = value
        return merged
|
|||
|
|
|||
|
|
|||
|
# Function version
|
|||
|
def row2dict(row):
    """Convert a SQLAlchemy model row to a ``{column_name: str(value)}`` dict.

    Every column value is coerced through ``str``.
    """
    result = {}
    for column in row.__table__.columns:
        result[column.name] = str(getattr(row, column.name))
    return result
|
|||
|
|
|||
|
|
|||
|
class dict2object(dict):
    """
    Dict exposing its keys as attributes, so ``getattr`` works on it.
    """

    def __getattr__(self, name: str) -> Any:
        # EAFP: missing keys surface as AttributeError, matching normal
        # attribute-access semantics.
        try:
            return self[name]
        except KeyError:
            raise AttributeError('object has no attribute {}'.format(name))

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(name, str):
            self[name] = value
        else:
            raise TypeError('key must be string type.')
|
|||
|
|
|||
|
|
|||
|
def secure_filename(filename: str) -> str:
    """Borrowed from werkzeug.utils.secure_filename.

    Pass it a filename and it will return a secure version of it. This
    filename can then safely be stored on a regular file system and passed
    to :func:`os.path.join`.

    On windows systems the function also makes sure that the file is not
    named after one of the special device files.

    >>> secure_filename(u'哈哈.zip')
    '哈哈.zip'
    >>> secure_filename('My cool movie.mov')
    'My_cool_movie.mov'
    >>> secure_filename('../../../etc/passwd')
    'etc_passwd'
    >>> secure_filename(u'i contain cool \xfcml\xe4uts.txt')
    'i_contain_cool_umlauts.txt'

    """
    # Turn path separators into spaces so no directory component survives.
    for sep in (os.path.sep, os.path.altsep):
        if sep:
            filename = filename.replace(sep, ' ')

    # Collapse whitespace runs into single underscores, then NFKD-decompose
    # so accented latin letters split into base char + combining mark.
    filename = normalize('NFKD', '_'.join(filename.split()))

    # Keep ASCII alphanumerics, CJK ideographs, '_', '.' and '-'; then trim
    # leading/trailing dots and underscores (kills '..'-style prefixes).
    allowed_re = re.compile(u'[^A-Za-z0-9\u4e00-\u9fa5_.-]')
    filename = allowed_re.sub('', filename).strip('._')

    # on nt a couple of special files are present in each folder. We
    # have to ensure that the target file is not such a filename. In
    # this case we prepend an underline
    reserved_device_names = (
        'CON', 'AUX', 'COM1', 'COM2', 'COM3', 'COM4', 'LPT1',
        'LPT2', 'LPT3', 'PRN', 'NUL',
    )
    if (os.name == 'nt' and filename
            and filename.split('.')[0].upper() in reserved_device_names):
        filename = '_' + filename

    return filename
|
|||
|
|
|||
|
|
|||
|
def _get_init_args(instance, base_class):
|
|||
|
"""Get instance's __init__ args and it's value when __call__.
|
|||
|
"""
|
|||
|
getargspec = inspect.getfullargspec
|
|||
|
argspec = getargspec(base_class.__init__)
|
|||
|
|
|||
|
defaults = argspec.defaults
|
|||
|
kwargs = {}
|
|||
|
if defaults is not None:
|
|||
|
no_defaults = argspec.args[:-len(defaults)]
|
|||
|
has_defaults = argspec.args[-len(defaults):]
|
|||
|
|
|||
|
kwargs = {k: getattr(instance, k) for k in no_defaults
|
|||
|
if k != 'self' and hasattr(instance, k)}
|
|||
|
kwargs.update({k: getattr(instance, k) if hasattr(instance, k) else
|
|||
|
getattr(instance, k, defaults[i])
|
|||
|
for i, k in enumerate(has_defaults)})
|
|||
|
assert len(kwargs) == len(argspec.args) - 1, 'exclude `self`'
|
|||
|
return kwargs
|
|||
|
|
|||
|
|
|||
|
def use_kwargs(argmap, schema_kwargs: Optional[Dict] = None, **kwargs: Any):
    """For fix ``Schema(partial=True)`` not work when used with
    ``@webargs.flaskparser.use_kwargs``. More details ``see webargs.core``.

    Args:

        argmap (marshmallow.Schema,dict,callable): Either a
            `marshmallow.Schema`, `dict` of argname ->
            `marshmallow.fields.Field` pairs, or a callable that returns a
            `marshmallow.Schema` instance.
        schema_kwargs (dict): kwargs for argmap.

    Returns:
        dict: A dictionary of parsed arguments.

    """
    schema_kwargs = schema_kwargs or {}

    argmap = parser._get_schema(argmap, request)

    # Non-partial schemas work fine with the stock decorator; bail out early.
    if not (argmap.partial or schema_kwargs.get('partial')):
        return base_use_kwargs(argmap, **kwargs)

    def factory(request):
        # Rebuild the schema per-request so `only` can be limited to the
        # fields actually present in the request payload.
        argmap_kwargs = _get_init_args(argmap, Schema)
        argmap_kwargs.update(schema_kwargs)

        # force set force_all=False
        only = parser.parse(argmap, request).keys()

        argmap_kwargs.update({
            'partial': False,  # fix missing=None not work
            'only': only or None,
            'context': {"request": request},
        })
        # NOTE: distutils (LooseVersion) was removed in Python 3.12, so the
        # major version is parsed directly from the version string instead.
        major_version = int(re.match(r'\d+', marshmallow.__version__).group())
        if major_version < 3:
            argmap_kwargs['strict'] = True

        return argmap.__class__(**argmap_kwargs)

    return base_use_kwargs(factory, **kwargs)
|
|||
|
|
|||
|
|
|||
|
def import_subs(locals_, modules_only: bool = False) -> List[str]:
    """ Auto import submodules, used in __init__.py.

    Args:

        locals_: `locals()`.
        modules_only: Only collect modules to __all__.

    Examples::

        # app/models/__init__.py
        from hobbit_core.utils import import_subs

        __all__ = import_subs(locals())

    Auto collect Model's subclass, Schema's subclass and instance.
    Others objects must defined in submodule.__all__.
    """
    package = locals_['__package__']
    path = locals_['__path__']
    top_module = sys.modules[package]

    all_ = []
    for fname in os.listdir(path[0]):
        # Only real submodules; skip the package's own __init__.
        if not fname.endswith(('.py', '.pyc')) or fname.startswith('__init__.'):
            continue

        module_name = fname.split('.')[0]
        submodule = importlib.import_module(f".{module_name}", package)
        all_.append(module_name)

        if modules_only:
            continue

        if hasattr(submodule, '__all__'):
            # Inner loop variables renamed so they no longer shadow the
            # directory-listing loop variable.
            for export_name in getattr(submodule, '__all__'):
                if not isinstance(export_name, str):
                    raise Exception(f'Invalid object {export_name} in '
                                    '__all__, must contain only strings.')
                setattr(top_module, export_name,
                        getattr(submodule, export_name))
                all_.append(export_name)
        else:
            # Re-export Model subclasses, Schema subclasses/instances and
            # classes whose name ends with ``Service``.
            for attr_name, obj in submodule.__dict__.items():
                if isinstance(obj, (model.DefaultMeta, Schema)) or \
                        (inspect.isclass(obj) and
                         (issubclass(obj, Schema) or
                          obj.__name__.endswith('Service'))):
                    setattr(top_module, attr_name, obj)
                    all_.append(attr_name)
    return all_
|
|||
|
|
|||
|
|
|||
|
def bulk_create_or_update_on_duplicate(
        db, model_cls, items, updated_at='updated_at', batch_size=500):
    """ Support MySQL and postgreSQL.
    https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html

    Args:

        db: Instance of `SQLAlchemy`.
        model_cls: Model object.
        items: List of data, example: `[{key: value}, {key: value}, ...]`.
        updated_at: Field which recording row update time.
        batch_size: Batch size is max rows per execute.

    Returns:
        dict: A dictionary contains rowcount and items_count.
    """
    if not items:
        logger.warning("bulk_create_or_update_on_duplicate save to "
                       f"{model_cls} failed, empty items")
        return {'rowcount': 0, 'items_count': 0}

    items_count = len(items)
    table_name = model_cls.__tablename__
    fields = list(items[0].keys())
    unique_keys = [c.name for i in model_cls.__table_args__ if isinstance(
        i, UniqueConstraint) for c in i]
    columns = [c.name for c in model_cls.__table__.columns if c.name not in (
        'id', 'created_at')]

    # Auto-fill the update timestamp when the model tracks one but the
    # caller did not provide it.
    if updated_at in columns and updated_at not in fields:
        fields.append(updated_at)
        updated_at_time = datetime.datetime.now()
        for item in items:
            item[updated_at] = updated_at_time

    assert set(fields) == set(columns), \
        'item fields not equal to columns in models:new: ' + \
        f'{set(fields) - set(columns)}, delete: {set(columns) - set(fields)}'

    # NULLs never compare equal in unique indexes; normalize to '' so
    # duplicates are detected.
    for item in items:
        for column in unique_keys:
            if column in item and item[column] is None:
                item[column] = ''

    engine = db.get_engine(bind=getattr(model_cls, '__bind_key__', None))
    if engine.name == 'postgresql':
        sql_on_update = ', '.join([
            f' {field} = excluded.{field}'
            for field in fields if field not in unique_keys])
        sql = f"""INSERT INTO {table_name} ({", ".join(fields)}) VALUES
            ({", ".join([f':{key}' for key in fields])})
            ON CONFLICT ({", ".join(unique_keys)}) DO UPDATE SET
            {sql_on_update}"""
    elif engine.name == 'mysql':
        # FIX: join with ', '. Each clause already wraps its column name in
        # backticks, so the previous '`, `' separator injected stray
        # backticks and produced malformed SQL.
        sql_on_update = ', '.join([
            f' `{field}` = new.{field}' for field in fields
            if field not in unique_keys])
        # NOTE(review): the `AS new` row alias requires MySQL 8.0.19+.
        sql = f"""INSERT INTO {table_name} (`{"`, `".join(fields)}`) VALUES
            ({", ".join([f':{key}' for key in fields])}) AS new
            ON DUPLICATE KEY UPDATE
            {sql_on_update}"""
    else:
        raise Exception(f'not support db: {engine.name}')

    rowcounts = 0
    while len(items) > 0:
        batch, items = items[:batch_size], items[batch_size:]
        try:
            result = db.session.execute(sql, batch, bind=engine)
        except Exception as e:
            # Log the failing statement before re-raising for the caller.
            logger.error(e, exc_info=True)
            logger.info(sql)
            raise e
        rowcounts += result.rowcount
    logger.info(f'{model_cls} save_data: rowcount={rowcounts}, '
                f'items_count: {items_count}')
    return {'rowcount': rowcounts, 'items_count': items_count}
|
|||
|
|
|||
|
|
|||
|
def orjson_serializer(obj):
    """
    Serialize ``obj`` to a JSON string with orjson.

    Note that `orjson.dumps()` return byte array, while sqlalchemy expects string, thus `decode()` call.
    """
    options = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NAIVE_UTC
    return orjson.dumps(obj, option=options).decode()
|