RODY/app/services/RpcService.py

155 lines
3.6 KiB
Python
Raw Permalink Normal View History

2022-11-04 17:37:08 +08:00
"""
@Time 2022/9/30 11:28
@Auth
@File RpcService.py
@IDE PyCharm
@Motto: ABC (Always Be Coding)
@Desc: RPC server (RPC服务端)
"""
import asyncio
import json
import socket
from functools import wraps
from app.schemas.TrainResult import ProcessValueList, Report
from app.utils.RedisMQTool import Task
from app.utils.StandardizedOutput import output_wrapped
from app.utils.redis_config import redis_client
# Registry of RPC-callable functions, keyed by function name.
funcs = {}


def register_function(func):
    """Register ``func`` as an RPC method and return it unchanged.

    Clients can only invoke functions registered here. The function is
    stored in ``funcs`` under its ``__name__``; registering another function
    with the same name replaces the earlier entry.

    Returning ``func`` is what makes this usable as a decorator: without it
    the decorated module-level name would be rebound to ``None``.
    """
    funcs[func.__name__] = func
    return func
def mq_send(func, *args, **kwargs):
    """Invoke ``func(*args, **kwargs)`` and print whatever it returns.

    Despite the name, nothing is sent anywhere: the result is only printed
    and this function itself returns None.
    """
    print(func(*args, **kwargs))
class TCPServer(object):
    """Minimal blocking TCP server that talks to one client at a time.

    Subclasses override :meth:`on_msg` to turn a received message (bytes)
    into a reply (bytes).
    """

    def __init__(self):
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Allow the server to be restarted right away instead of failing with
        # "Address already in use" while the old socket lingers in TIME_WAIT.
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Currently connected client; None until accept() succeeds.
        self.client_socket = None

    def bind_listen(self, port=4999):
        """Bind to all interfaces on ``port`` and listen with a backlog of 5."""
        self.sock.bind(('0.0.0.0', port))
        self.sock.listen(5)

    def accept_receive_close(self):
        """Serve one message from the current client.

        Accepts a new connection if no client is connected, reads up to 1024
        bytes, and sends back the full reply produced by :meth:`on_msg`.
        When the peer has closed the connection (``recv`` returns ``b''``),
        the client socket is closed and cleared, so the next call accepts a
        new client.
        """
        if self.client_socket is None:
            self.client_socket, _address = self.sock.accept()
        msg = self.client_socket.recv(1024)
        if not msg:
            # Peer hung up. Without this the caller's loop would keep
            # "answering" b'' on a dead socket in a busy loop.
            self.client_socket.close()
            self.client_socket = None
            return
        data = self.on_msg(msg)
        # sendall() keeps writing until every byte is sent; send() may write
        # only part of the reply.
        self.client_socket.sendall(data)

    def on_msg(self, msg):
        """Return the reply bytes for ``msg``; subclasses must override this."""
        pass
class RPCStub(object):
    """Decode JSON RPC requests and dispatch them to registered functions."""

    def __init__(self):
        # Most recently decoded request; None until a non-empty request
        # has been decoded.
        self.data = None

    def call_method(self, data):
        """Run the registered function named in ``data`` and return its result.

        :param data: UTF-8 encoded JSON object with key ``method_name`` and
            optional ``method_args`` (list) and ``method_kwargs`` (object).
            A missing or null argument entry means "no arguments".
        :return: the function's result as UTF-8 encoded JSON bytes. Empty
            input yields the JSON string ``"something wrong"``.
        :raises KeyError: if ``method_name`` is missing or not registered.
        :raises ValueError: if ``data`` is not valid UTF-8 JSON.
        """
        if not data:
            return json.dumps("something wrong").encode('utf-8')
        self.data = json.loads(data.decode('utf-8'))
        method_name = self.data['method_name']
        method_args = self.data.get('method_args') or []
        method_kwargs = self.data.get('method_kwargs') or {}
        try:
            func = funcs[method_name]
        except KeyError:
            # Only explicitly registered functions may be called remotely.
            raise KeyError(f"method {method_name!r} is not registered") from None
        res = func(*method_args, **method_kwargs)
        return json.dumps(res).encode('utf-8')
class RPCServer(TCPServer, RPCStub):
    """TCP server that answers each incoming message with an RPC result.

    Combines :class:`TCPServer` (socket handling) with :class:`RPCStub`
    (request decoding and dispatch).
    """

    def __init__(self):
        TCPServer.__init__(self)
        RPCStub.__init__(self)

    def loop(self, port=4999):
        """Listen on ``port`` (default 4999) and serve requests forever.

        If serving a request raises, the error is printed, the current
        client connection (if any) is closed, and the server goes back to
        accepting clients. Sockets are closed when the loop exits, for
        example on KeyboardInterrupt.
        """
        self.bind_listen(port)
        try:
            while True:
                try:
                    self.accept_receive_close()
                except Exception as exc:
                    # accept() itself may have failed, in which case there is
                    # no client socket to close.
                    self._drop_client()
                    print(repr(exc))
        finally:
            self._drop_client()
            self.sock.close()

    def on_msg(self, data):
        """Decode ``data`` as an RPC request and return the encoded result."""
        return self.call_method(data)

    def _drop_client(self):
        """Close and forget the current client socket, if there is one."""
        if self.client_socket is not None:
            self.client_socket.close()
            self.client_socket = None
def redisMQSend(channel="ceshi"):
    """Build a decorator that publishes a function's result over Redis.

    The decorated function runs normally. Its result is printed, then
    published as ``output_wrapped(0, 'success', result)`` through
    ``Task(...).publish_task`` on ``channel``, and finally returned to the
    caller.

    :param channel: Redis channel to publish on (defaults to the previously
        hard-coded ``"ceshi"``).
    """
    def wrapTheFunction(func):
        @wraps(func)
        def wrapped_function(*args, **kwargs):
            data = func(*args, **kwargs)
            print(data)
            Task(redis_conn=redis_client.get_redis(), channel=channel).publish_task(
                data=output_wrapped(0, 'success', data))
            # Hand the result back; previously the wrapper returned None.
            return data
        return wrapped_function
    return wrapTheFunction
@register_function
def add(a, b, c=10):
    """RPC method: print and return ``a + b + c`` (``c`` defaults to 10)."""
    # Named 'total' rather than 'sum' to avoid shadowing the builtin.
    total = a + b + c
    print(total)
    return total
@register_function
def start(param: str):
    """Example RPC method that simulates a multi-step job and tracks progress.

    Each of the 10 steps appends its index to the first process-value series
    of a ``Report`` and updates the completion percentage. Every step is
    routed through ``mq_send``, which prints the step's progress snapshot.

    :param param: arbitrary value; it is only printed.
    :return: the final report as a dict (``report.dict()``).
    """
    print(param)
    total_steps = 10
    process_value_list = ProcessValueList(name='1', value=[])
    # 'rate_of_progess' (sic) is the field name declared by the Report schema.
    report = Report(rate_of_progess=0, process_value=[process_value_list])

    def process(v: int):
        print(v)
        report.rate_of_progess = ((v + 1) / total_steps) * 100
        report.process_value[0].value.append(v)
        return report.dict()

    for i in range(total_steps):
        # mq_send(func, *args) calls func(*args) and prints the result. Using
        # it as a decorator on process() called process with no argument at
        # definition time and raised TypeError.
        mq_send(process, i)
    print(report.dict())
    return report.dict()
@register_function
def setData(data):
    """Echo RPC method: print ``data`` and hand it back unchanged."""
    echoed = data
    print(echoed)
    return echoed
if __name__ == '__main__':
    # Open the Redis connection before starting to serve requests.
    redis_client.init_redis_connect()
    server = RPCServer()
    # Port the RPC server listens on.
    server.loop(5003)