detect_plate/detect_demo.py

223 lines
8.4 KiB
Python
Raw Normal View History

2024-08-07 09:32:38 +08:00
# -*- coding: UTF-8 -*-
import argparse
import time
import os
import cv2
import torch
from numpy import random
import copy
import numpy as np
from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import check_img_size, non_max_suppression_face, scale_coords
from utils.torch_utils import time_synchronized
from utils.cv_puttext import cv2ImgAddText
from plate_recognition.plate_rec import get_plate_result,allFilePath,cv_imread
from plate_recognition.double_plate_split_merge import get_split_merge
# Drawing palette used by draw_result: entry i colours the i-th landmark
# point, and the entry at the plate's class label colours its box and text.
# The tuples are passed straight to OpenCV drawing calls.
clors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (0, 255, 255),
]
def load_model(weights, device):
    """Load the detection network from ``weights`` onto ``device``.

    Thin wrapper around ``attempt_load``; ``device`` is forwarded as its
    ``map_location`` argument and the loaded model is returned unchanged.
    """
    return attempt_load(weights, map_location=device)  # load FP32 model
def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
    """Map landmark coordinates from the network input back to the source image.

    ``coords`` is modified in place and also returned. Its first eight
    columns are read as ``x1, y1, x2, y2, x3, y3, x4, y4``.

    Args:
        img1_shape: (height, width) of the image the coordinates refer to.
        coords: 2-D tensor of landmark coordinates (must support ``clamp_``).
        img0_shape: shape of the target image; index 0 is height, 1 is width.
        ratio_pad: optional ``((gain, ...), (pad_x, pad_y))`` overriding the
            gain/padding that would otherwise be derived from the two shapes.

    Returns:
        The same ``coords`` object, un-padded, rescaled and clamped to the
        bounds of ``img0_shape``.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:
        # gain = old / new, taking the tighter of the two axes
        scale_h = img1_shape[0] / img0_shape[0]
        scale_w = img1_shape[1] / img0_shape[1]
        gain = min(scale_h, scale_w)
        # (width padding, height padding) added on each side
        pad = (
            (img1_shape[1] - img0_shape[1] * gain) / 2,
            (img1_shape[0] - img0_shape[0] * gain) / 2,
        )

    x_cols = [0, 2, 4, 6]
    y_cols = [1, 3, 5, 7]
    coords[:, x_cols] -= pad[0]  # remove x padding
    coords[:, y_cols] -= pad[1]  # remove y padding
    coords[:, :10] /= gain

    # Even columns are x values (bounded by width), odd columns are y values
    # (bounded by height). Single-column indexing yields a view, so the
    # in-place clamp writes through to ``coords``.
    for col in range(8):
        upper = img0_shape[1] if col % 2 == 0 else img0_shape[0]
        coords[:, col].clamp_(0, upper)
    return coords
def get_plate_rec_landmark(img, xyxy, conf, landmarks, class_num, device):
    """Pack one detection into a plain dictionary.

    Args:
        img: source image. Unused; kept for interface compatibility.
        xyxy: box corners ``[x1, y1, x2, y2]``; each is truncated to ``int``.
        conf: detection confidence. Unused; kept for interface compatibility.
        landmarks: flat sequence ``[x1, y1, ..., x4, y4]`` of four key points.
        class_num: plate type; anything accepted by ``int()``.
        device: unused; kept for interface compatibility.

    Returns:
        dict with keys
        ``'rect'``      -> ``[x1, y1, x2, y2]`` as ints,
        ``'landmarks'`` -> four ``[x, y]`` pairs, integer-truncated but stored
                           as floats,
        ``'class'``     -> plate type as int (0 = single-layer plate,
                           1 = double-layer plate).
    """
    # The image shape is deliberately not unpacked here: nothing in this
    # function needs it, and unpacking three values fails on 2-D images.
    rect = [int(xyxy[k]) for k in range(4)]
    points = [
        [float(int(landmarks[2 * i])), float(int(landmarks[2 * i + 1]))]
        for i in range(4)
    ]
    return {
        'rect': rect,
        'landmarks': points,
        'class': int(class_num),  # 0 = single-layer plate, 1 = double-layer plate
    }
def detect_plate(model, orgimg, device, img_size, conf_thres=0.3, iou_thres=0.5):
    """Run the plate detector on one image and return the detections.

    Args:
        model: detection model; must expose ``stride`` and be callable on a
            batched float tensor.
        orgimg: source image as an H x W x C array (the channel flip below
            treats it as BGR). It is not modified.
        device: torch device the input tensor is moved to.
        img_size: target size (pixels) of the longer image side.
        conf_thres: confidence threshold passed to NMS.
        iou_thres: IoU threshold passed to NMS.

    Returns:
        A list with one dict per detection, as built by
        ``get_plate_rec_landmark`` (keys ``'rect'``, ``'landmarks'``,
        ``'class'``), with coordinates expressed in ``orgimg`` pixels.

    Raises:
        AssertionError: if ``orgimg`` is ``None``.
    """
    # Validate before touching the image.
    assert orgimg is not None, 'Image Not Found '

    h0, w0 = orgimg.shape[:2]  # original height / width
    r = img_size / max(h0, w0)  # scale factor bringing the longer side to img_size
    if r != 1:
        interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
        # cv2.resize allocates its own output, so no prior copy is needed.
        img0 = cv2.resize(orgimg, (int(w0 * r), int(h0 * r)), interpolation=interp)
    else:
        img0 = copy.deepcopy(orgimg)

    imgsz = check_img_size(img_size, s=model.stride.max())  # check img_size
    img = letterbox(img0, new_shape=imgsz)[0]

    # BGR -> RGB and HWC -> CHW; copy() makes the reversed view contiguous.
    img = img[:, :, ::-1].transpose(2, 0, 1).copy()
    img = torch.from_numpy(img).to(device)
    img = img.float()  # uint8 to fp32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add batch dimension

    # Pure inference: do not record an autograd graph.
    with torch.no_grad():
        pred = model(img)[0]

    # Apply NMS
    pred = non_max_suppression_face(pred, conf_thres, iou_thres)

    dict_list = []
    for det in pred:  # detections per image
        if not len(det):
            continue
        # Rescale boxes and landmarks from network-input size to orgimg size.
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], orgimg.shape).round()
        det[:, 5:13] = scale_coords_landmarks(img.shape[2:], det[:, 5:13], orgimg.shape).round()
        for row in det:
            # Row layout used here: 0-3 box, 4 confidence, 5-12 landmarks, 13 class.
            xyxy = row[:4].tolist()
            conf = row[4].cpu().numpy()
            landmarks = row[5:13].tolist()
            class_num = row[13].cpu().numpy()
            dict_list.append(
                get_plate_rec_landmark(orgimg, xyxy, conf, landmarks, class_num, device)
            )
    return dict_list
def draw_result(orgimg, dict_list):
    """Draw every detection onto ``orgimg`` (in place) and return it.

    For each entry the four landmark points are drawn as filled circles, and
    a slightly enlarged bounding box plus the class label are drawn in the
    colour selected by the class.

    Args:
        orgimg: image array to draw on; modified in place.
        dict_list: detections with keys ``'rect'`` (``[x1, y1, x2, y2]``),
            ``'landmarks'`` (four ``[x, y]`` pairs) and ``'class'`` (int).
            The dictionaries are not modified.

    Returns:
        ``orgimg``, the same object that was passed in.
    """
    img_h, img_w = orgimg.shape[:2]
    for result in dict_list:
        x1, y1, x2, y2 = result['rect']
        # Enlarge the box by 5% of its width and 11% of its height per side,
        # clipped to the image. Kept in locals so the caller's 'rect' is not
        # changed and repeated calls do not pad it again.
        padding_w = 0.05 * (x2 - x1)
        padding_h = 0.11 * (y2 - y1)
        left = max(0, int(x1 - padding_w))
        top = max(0, int(y1 - padding_h))
        right = min(img_w, int(x2 + padding_w))
        bottom = min(img_h, int(y2 + padding_h))

        landmarks = result['landmarks']
        label = result['class']
        # Wrap around the palette so an unexpected class id cannot raise IndexError.
        box_color = clors[label % len(clors)]

        for i in range(4):  # key points
            cv2.circle(orgimg, (int(landmarks[i][0]), int(landmarks[i][1])), 5, clors[i], -1)
        cv2.rectangle(orgimg, (left, top), (right, bottom), box_color, 2)  # draw box
        # Draw on the image that was passed in, not on a global named ``img``.
        cv2.putText(orgimg, str(label), (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)
    return orgimg
if __name__ == '__main__':
    # Command-line demo: detect plates in one image or in every file that
    # allFilePath collects under a directory, and write annotated copies to
    # the output directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--detect_model', nargs='+', type=str, default='weights/plate_detect.pt', help='model.pt path(s)')  # detection model
    parser.add_argument('--image_path', type=str, default=r'D:\Project\ChePai\test\images\val', help='source')
    parser.add_argument('--img_size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--output', type=str, default='result1', help='output directory')
    opt = parser.parse_args()
    print(opt)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device =torch.device("cpu")

    save_path = opt.output
    if not os.path.exists(save_path):
        os.makedirs(save_path)  # also creates missing parent directories

    detect_model = load_model(opt.detect_model, device)  # initialise the detection model

    def _detect_and_draw(image):
        """Convert ``image`` to 3-channel BGR if needed, detect plates, draw them."""
        # NOTE(review): cv_imread is not visible here; if it can return a
        # single-channel array, the GRAY2BGR branch keeps the channel
        # indexing in detect_plate from failing - confirm.
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        elif image.shape[-1] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        plates = detect_plate(detect_model, image, device, opt.img_size)
        return draw_result(image, plates)

    count = 0  # number of images successfully processed
    time_all = 0
    time_begin = time.time()
    if not os.path.isfile(opt.image_path):  # directory
        file_list = []
        allFilePath(opt.image_path, file_list)
        for img_path in file_list:
            print(count, img_path)
            time_b = time.time()
            img = cv_imread(img_path)
            if img is None:
                continue
            ori_img = _detect_and_draw(img)
            time_gap = time.time() - time_b
            # The first processed image is excluded from the average
            # (presumably to leave out model warm-up).
            if count:
                time_all += time_gap
            save_img_path = os.path.join(save_path, os.path.basename(img_path))
            cv2.imwrite(save_img_path, ori_img)
            count += 1
    else:  # single image
        print(count, opt.image_path, end=" ")
        img = cv_imread(opt.image_path)
        if img is None:
            raise SystemExit(f"Failed to read image: {opt.image_path}")
        ori_img = _detect_and_draw(img)
        save_img_path = os.path.join(save_path, os.path.basename(opt.image_path))
        cv2.imwrite(save_img_path, ori_img)

    # Average over the images that were actually timed (all processed images
    # except the first). This avoids using file_list, which does not exist in
    # single-image mode, and avoids dividing by zero.
    timed_images = max(count - 1, 0)
    average_time = time_all / timed_images if timed_images else 0.0
    print(f"sumTime time is {time.time()-time_begin} s, average pic time is {average_time}")