155 lines
3.6 KiB
Python
155 lines
3.6 KiB
Python
|
"""
|
|||
|
@Time : 2022/9/30 11:28
|
|||
|
@Auth : 东
|
|||
|
@File :RpcService.py
|
|||
|
@IDE :PyCharm
|
|||
|
@Motto:ABC(Always Be Coding)
|
|||
|
@Desc:RPC服务端
|
|||
|
|
|||
|
"""
|
|||
|
import asyncio
|
|||
|
import json
|
|||
|
import socket
|
|||
|
from functools import wraps
|
|||
|
|
|||
|
from app.schemas.TrainResult import ProcessValueList, Report
|
|||
|
from app.utils.RedisMQTool import Task
|
|||
|
from app.utils.StandardizedOutput import output_wrapped
|
|||
|
from app.utils.redis_config import redis_client
|
|||
|
|
|||
|
# Registry of RPC-callable functions, keyed by function name. Only functions
# added here (via @register_function) can be invoked by clients.
funcs = {}


def register_function(func):
    """Register ``func`` as a server-side RPC method and return it unchanged.

    Clients may only call functions that have been registered. The function is
    stored under its ``__name__``; registering another function with the same
    name replaces the earlier entry.

    Returning ``func`` keeps the decorated name usable at module level
    (without it, ``@register_function`` would rebind the name to None).
    """
    funcs[func.__name__] = func
    return func
|
|||
|
|
|||
|
|
|||
|
def mq_send(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` immediately and print whatever it returns.

    Despite the name, nothing is published to a message queue here; the result
    is only written to stdout. This is a plain call helper, not a decorator.
    Always returns None.
    """
    print(func(*args, **kwargs))
|
|||
|
|
|||
|
|
|||
|
|
|||
|
class TCPServer(object):
    """Minimal blocking TCP server that talks to one client connection at a time.

    Subclasses implement :meth:`on_msg` to turn a received message into the
    bytes that are sent back to the client.
    """

    def __init__(self):
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Allow an immediate restart on the same port instead of failing with
        # "Address already in use" while old connections sit in TIME_WAIT.
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Currently connected client, or None while waiting for a new one.
        self.client_socket = None

    def bind_listen(self, port=4999, host='0.0.0.0'):
        """Bind the listening socket to ``(host, port)`` and start listening.

        :param port: TCP port to listen on.
        :param host: local address to bind; the default '0.0.0.0' accepts
            connections on all IPv4 interfaces.
        """
        self.sock.bind((host, port))
        self.sock.listen(5)

    def accept_receive_close(self):
        """Handle one message from the current client, accepting a client first if needed.

        Reads a single ``recv(1024)`` chunk, passes it to :meth:`on_msg` and
        sends the reply back. When the client has closed its side of the
        connection (``recv`` returns b''), the client socket is closed and reset
        so that the next call accepts a new client.

        NOTE(review): there is no message framing, so a request larger than
        1024 bytes (or split across reads) reaches ``on_msg`` incomplete.
        """
        if self.client_socket is None:
            self.client_socket, _address = self.sock.accept()
        msg = self.client_socket.recv(1024)
        if not msg:
            # Peer closed the connection: release it instead of repeatedly
            # answering a dead socket.
            self.client_socket.close()
            self.client_socket = None
            return
        data = self.on_msg(msg)
        # send() may transmit only part of the buffer; sendall() sends it all.
        self.client_socket.sendall(data)

    def on_msg(self, msg):
        """Return the reply bytes for ``msg``; subclasses must override this."""
        raise NotImplementedError
|
|||
|
|
|||
|
|
|||
|
class RPCStub(object):
    """Dispatch JSON-encoded RPC requests to functions registered in ``funcs``.

    A request is UTF-8 encoded JSON of the form::

        {"method_name": str, "method_args": list, "method_kwargs": dict}

    ``method_args`` and ``method_kwargs`` are optional. The most recently
    decoded request is kept in ``self.data``.
    """

    def __init__(self):
        # Last decoded request payload (None until a request has been handled).
        self.data = None

    def call_method(self, data):
        """Decode a request, call the named registered function, and encode its result.

        :param data: raw request bytes received from the client.
        :return: the function's return value as UTF-8 encoded JSON bytes, or
            the JSON string "something wrong" when ``data`` is empty.
        :raises KeyError: if ``method_name`` is missing or not registered.
        :raises ValueError: if ``data`` is not valid UTF-8 or JSON
            (UnicodeDecodeError / json.JSONDecodeError).
        """
        if not data:
            return json.dumps("something wrong").encode('utf-8')
        self.data = json.loads(data.decode('utf-8'))
        method_name = self.data['method_name']
        # Arguments are optional so simple calls need no empty placeholders.
        method_args = self.data.get('method_args', [])
        method_kwargs = self.data.get('method_kwargs', {})
        try:
            method = funcs[method_name]
        except KeyError:
            raise KeyError('method %r is not registered' % (method_name,)) from None
        res = method(*method_args, **method_kwargs)
        return json.dumps(res).encode('utf-8')
|
|||
|
|
|||
|
|
|||
|
class RPCServer(TCPServer, RPCStub):
    """TCP server that answers each received message by dispatching it as an RPC call."""

    def __init__(self):
        TCPServer.__init__(self)
        RPCStub.__init__(self)

    def loop(self, port=4999):
        """Listen on ``port`` and serve requests forever (until interrupted).

        Any exception raised while accepting, receiving, dispatching or replying
        is printed, and the current client connection (if any) is dropped, so a
        single bad request does not stop the server. Sockets are closed when the
        loop exits, e.g. on KeyboardInterrupt.
        """
        self.bind_listen(port)
        try:
            while True:
                try:
                    self.accept_receive_close()
                except Exception as exc:
                    # The failure may happen before any client was accepted
                    # (e.g. in accept()), so there may be nothing to close.
                    if self.client_socket is not None:
                        self.client_socket.close()
                        self.client_socket = None
                    # Report the actual error, not the Exception class.
                    print(repr(exc))
        finally:
            if self.client_socket is not None:
                self.client_socket.close()
                self.client_socket = None
            self.sock.close()

    def on_msg(self, data):
        """Answer one raw request by delegating to RPCStub.call_method."""
        return self.call_method(data)
|
|||
|
|
|||
|
|
|||
|
def redisMQSend(channel="ceshi"):
    """Build a decorator that publishes a function's result to a Redis channel.

    Each call of the decorated function prints its return value, wraps it with
    ``output_wrapped(0, 'success', data)``, publishes it via
    ``Task(...).publish_task`` on ``channel``, and then returns it to the caller.

    :param channel: Redis channel to publish to; defaults to "ceshi", the
        channel that was previously hard-coded.
    """
    def wrapTheFunction(func):
        @wraps(func)
        def wrapped_function(*args, **kwargs):
            data = func(*args, **kwargs)
            print(data)
            Task(redis_conn=redis_client.get_redis(), channel=channel).publish_task(
                data=output_wrapped(0, 'success', data))
            # Return the result so that decorating a function does not change
            # what its callers receive.
            return data

        return wrapped_function

    return wrapTheFunction
|
|||
|
|
|||
|
|
|||
|
@register_function
def add(a, b, c=10):
    """Example RPC method: print and return ``a + b + c`` (``c`` defaults to 10)."""
    # Named ``total`` to avoid shadowing the built-in ``sum``.
    total = a + b + c
    print(total)
    return total
|
|||
|
|
|||
|
|
|||
|
@register_function
def start(param: str):
    """Example RPC method that builds a ten-step progress report.

    Prints ``param``. Each of the ten steps sets ``rate_of_progess`` to the
    completed percentage and appends the step index to the first
    process-value series. The step result is emitted through ``mq_send``.

    :return: the final report as a dict (``Report.dict()``).
    """
    print(param)
    process_value_list = ProcessValueList(name='1', value=[])
    report = Report(rate_of_progess=0, process_value=[process_value_list])

    def process(v: int):
        """Apply step ``v`` (0-based) to the report and return a snapshot of it."""
        print(v)
        # Percentage complete after step v of 10.
        report.rate_of_progess = ((v + 1) / 10) * 100
        report.process_value[0].value.append(v)
        return report.dict()

    for i in range(10):
        # mq_send(func, *args) calls func right away, so it cannot be used as a
        # decorator here: @mq_send would call process() with no argument at
        # definition time and raise TypeError.
        mq_send(process, i)
    print(report.dict())
    return report.dict()
|
|||
|
|
|||
|
|
|||
|
@register_function
def setData(data):
    """Echo RPC method: print ``data`` and hand it back unchanged."""
    received = data
    print(received)
    return received
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Initialise the Redis connection first; redisMQSend publishes through
    # redis_client.get_redis().
    redis_client.init_redis_connect()
    port = 5003  # TCP port the RPC server listens on
    server = RPCServer()
    server.loop(port)
|