import pycuda.autoinit  # noqa: F401 -- creates and activates a CUDA context on import
import pycuda.driver as cuda
import tensorrt as trt
import numpy as np

# ONNX models must be parsed into an explicit-batch network.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def GiB(val):
    """Convert a size in GiB to bytes."""
    return val * (1 << 30)


def ONNX_to_TRT(onnx_model_path=None, trt_engine_path=None, fp16_mode=False):
    """
    Build a CUDA engine from an ONNX model and save the engine file
    (fixed input shapes only). Written for TensorRT V8.

    onnx_model_path: path of the ONNX weights to load
    trt_engine_path: path where the serialized TRT engine is saved
    fp16_mode: if True, build the engine for FP16 inference
    """
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)
    parser = trt.OnnxParser(network, TRT_LOGGER)

    config = builder.create_builder_config()
    config.max_workspace_size = GiB(1)  # scratch memory available to the builder
    if fp16_mode:
        config.set_flag(trt.BuilderFlag.FP16)
    with open(onnx_model_path, 'rb') as model:
        if not parser.parse(model.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError('Failed to parse the ONNX file')
    serialized_engine = builder.build_serialized_network(network, config)
    assert serialized_engine is not None, 'Engine build failed'

    with open(trt_engine_path, 'wb') as f:
        f.write(serialized_engine)  # write the serialized engine to disk

    print('TensorRT file in ' + trt_engine_path)
    print('============ONNX->TensorRT SUCCESS============')
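
# A minimal conversion sketch (commented out so the module stays import-safe);
# 'model.onnx' and 'model.trt' are hypothetical placeholder paths:
#
#   ONNX_to_TRT(onnx_model_path='model.onnx',
#               trt_engine_path='model.trt',
#               fp16_mode=True)

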
class TrtModel:
    """
    TensorRT inference wrapper.
    """
    def __init__(self, trt_path):
        self.ctx = cuda.Device(0).make_context()
        stream = cuda.Stream()
        TRT_LOGGER = trt.Logger(trt.Logger.INFO)
        runtime = trt.Runtime(TRT_LOGGER)

        # Deserialize the engine from file
        with open(trt_path, "rb") as f:
            engine = runtime.deserialize_cuda_engine(f.read())
        context = engine.create_execution_context()

        host_inputs = []
        cuda_inputs = []
        host_outputs = []
        cuda_outputs = []
        bindings = []

        for binding in engine:
            print('binding:', binding, engine.get_binding_shape(binding))
            # For an explicit-batch engine max_batch_size is always 1, so this
            # is simply the element count of the binding.
            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            # Allocate host (page-locked) and device buffers
            host_mem = cuda.pagelocked_empty(size, dtype)
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)
            # Append the device buffer address to the bindings list.
            bindings.append(int(cuda_mem))
            # Append to the appropriate list.
            if engine.binding_is_input(binding):
                self.input_w = engine.get_binding_shape(binding)[-1]
                self.input_h = engine.get_binding_shape(binding)[-2]
                host_inputs.append(host_mem)
                cuda_inputs.append(cuda_mem)
            else:
                host_outputs.append(host_mem)
                cuda_outputs.append(cuda_mem)

        # Store
        self.stream = stream
        self.context = context
        self.engine = engine
        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings
        self.batch_size = engine.max_batch_size  # always 1 for explicit batch

    def __call__(self, img_np_nchw):
        """
        Run TensorRT inference.
        :param img_np_nchw: input image as an NCHW numpy array
        """
        self.ctx.push()

        # Restore
        stream = self.stream
        context = self.context
        host_inputs = self.host_inputs
        cuda_inputs = self.cuda_inputs
        host_outputs = self.host_outputs
        cuda_outputs = self.cuda_outputs
        bindings = self.bindings

        # Copy the input into the page-locked host buffer, then to the device
        # (assumes a single input binding).
        np.copyto(host_inputs[0], img_np_nchw.ravel())
        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
        # execute_async_v2 is the explicit-batch API; the deprecated
        # execute_async(batch_size=...) call is only valid for implicit batch.
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
        stream.synchronize()
        self.ctx.pop()
        # Flat output of the first binding (assumes a single output binding).
        return host_outputs[0]

    def destroy(self):
        # Remove any context from the top of the context stack, deactivating it.
        self.ctx.pop()
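

if __name__ == '__main__':
    # Minimal end-to-end sketch, assuming a single-input / single-output model.
    # 'model.onnx', 'model.trt' and the random test tensor are hypothetical
    # placeholders -- substitute your own model and preprocessing.
    ONNX_to_TRT(onnx_model_path='model.onnx',
                trt_engine_path='model.trt',
                fp16_mode=True)
    model = TrtModel('model.trt')
    img = np.random.rand(1, 3, model.input_h, model.input_w).astype(np.float32)
    output = model(img)  # flat numpy array from the first output binding
    print('output size:', output.shape)
    model.destroy()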