detect_plate/detect_demo.py
2024-08-07 09:32:38 +08:00

223 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: UTF-8 -*-
import argparse
import time
import os
import cv2
import torch
from numpy import random
import copy
import numpy as np
from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import check_img_size, non_max_suppression_face, scale_coords
from utils.torch_utils import time_synchronized
from utils.cv_puttext import cv2ImgAddText
from plate_recognition.plate_rec import get_plate_result,allFilePath,cv_imread
from plate_recognition.double_plate_split_merge import get_split_merge
# Colour palette used when drawing: draw_result indexes it by landmark position
# (0-3) for the corner dots and by class label for the box and its text.
# The tuples are handed straight to the OpenCV drawing calls.
# NOTE(review): presumably in BGR order, matching the BGR images handled
# elsewhere in this file - confirm.
clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)]
def load_model(weights, device):
    """Load the detection network and return it.

    ``weights`` is forwarded unchanged to ``attempt_load`` together with
    ``map_location=device``; whatever ``attempt_load`` produces is returned
    as-is (the original inline note describes it as an FP32 model).
    """
    return attempt_load(weights, map_location=device)
def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
    """Map four landmark points from the ``img1_shape`` frame to ``img0_shape``.

    ``coords`` is indexed as ``[:, 0..7]`` holding x1, y1, x2, y2, x3, y3,
    x4, y4. It is modified in place and the same object is returned. The
    column views must support in-place ``clamp_`` (e.g. a torch tensor).

    Args:
        img1_shape: (height, width) of the frame the coordinates are in.
        coords: landmark coordinates, one row per detection.
        img0_shape: shape of the target image; index 0 is height, 1 is width.
        ratio_pad: optional ``((gain, ...), (pad_x, pad_y))``. When omitted,
            gain and padding are derived from the two shapes.

    Returns:
        ``coords`` after removing the padding, dividing by the gain and
        clamping every point to ``[0, width]`` x ``[0, height]``.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:
        # gain = old / new; take the tighter of the two axis ratios.
        height_ratio = img1_shape[0] / img0_shape[0]
        width_ratio = img1_shape[1] / img0_shape[1]
        gain = min(height_ratio, width_ratio)
        # (width padding, height padding) added on each side.
        pad_w = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_h = (img1_shape[0] - img0_shape[0] * gain) / 2
        pad = (pad_w, pad_h)

    x_cols = [0, 2, 4, 6]
    y_cols = [1, 3, 5, 7]

    coords[:, x_cols] -= pad[0]  # undo x padding
    coords[:, y_cols] -= pad[1]  # undo y padding
    coords[:, :10] /= gain

    # Clamp each point into the target image (a fifth point, if any, is not clamped).
    max_y, max_x = img0_shape[0], img0_shape[1]
    for col in x_cols:
        coords[:, col].clamp_(0, max_x)
    for col in y_cols:
        coords[:, col].clamp_(0, max_y)
    return coords
def get_plate_rec_landmark(img, xyxy, conf, landmarks, class_num, device):
    """Pack one detection into a plain-Python result dictionary.

    Args:
        img: accepted for interface compatibility; not read. (It used to be
            unpacked as ``h, w, c = img.shape`` for an unused value, which
            made the function fail on 2-D grayscale images.)
        xyxy: box corners ``[x1, y1, x2, y2]``; each is truncated with ``int``.
        conf: detection confidence; accepted but not stored in the result.
        landmarks: flat sequence ``[x1, y1, ..., x4, y4]`` of the four points.
        class_num: plate type; converted with ``int``. Per the original
            comment, 0 means a single-line plate and 1 a double-line plate.
        device: accepted for interface compatibility; not read.

    Returns:
        ``{'rect': [x1, y1, x2, y2], 'landmarks': [[x, y], ...], 'class': int}``.
        ``rect`` holds ints. ``landmarks`` holds four ``[x, y]`` pairs whose
        values are truncated to whole numbers but stored as floats.
    """
    rect = [int(xyxy[i]) for i in range(4)]

    # Truncate like int() but keep float storage, so the output format
    # (a list of float pairs) is the same as before.
    points = [
        [float(int(landmarks[2 * i])), float(int(landmarks[2 * i + 1]))]
        for i in range(4)
    ]

    return {
        'rect': rect,
        'landmarks': points,
        'class': int(class_num),
    }
def detect_plate(model, orgimg, device, img_size, conf_thres=0.3, iou_thres=0.5):
    """Run the plate detector on one image and return its detections.

    Args:
        model: detection network; called as ``model(tensor)[0]`` and queried
            for ``model.stride``.
        orgimg: source image array (the original comments describe it as
            BGR). It is not modified here.
        device: torch device the input tensor is moved to.
        img_size: target size of the longer image side for inference.
        conf_thres: confidence threshold passed to the NMS step
            (default 0.3, the value previously hard-coded).
        iou_thres: IoU threshold passed to the NMS step
            (default 0.5, the value previously hard-coded).

    Returns:
        A list with one dictionary per detection, as built by
        ``get_plate_rec_landmark``; coordinates are rescaled to ``orgimg``.

    Raises:
        AssertionError: if ``orgimg`` is ``None``.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # exception type and message are unchanged.
    if orgimg is None:
        raise AssertionError('Image Not Found ')

    h0, w0 = orgimg.shape[:2]  # original height / width
    img0 = copy.deepcopy(orgimg)
    r = img_size / max(h0, w0)  # scale so the longer side equals img_size
    if r != 1:
        interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
        img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)

    imgsz = check_img_size(img_size, s=model.stride.max())
    img = letterbox(img0, new_shape=imgsz)[0]

    # Reverse the channel axis (BGR -> RGB per the original note) and move
    # channels first; copy() makes the array contiguous for torch.from_numpy.
    img = img[:, :, ::-1].transpose(2, 0, 1).copy()

    img = torch.from_numpy(img).to(device)
    img = img.float()
    img /= 255.0  # 0-255 -> 0.0-1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add batch dimension

    # Inference only: without no_grad a model with trainable parameters
    # builds an autograd graph and the later .numpy() calls would fail.
    with torch.no_grad():
        pred = model(img)[0]
        pred = non_max_suppression_face(pred, conf_thres, iou_thres)

    dict_list = []
    for det in pred:  # one entry per image in the batch
        if len(det) == 0:
            continue

        # Columns as indexed below: 0-3 box, 4 confidence, 5-12 landmarks, 13 class.
        det[:, :4] = scale_coords(img.shape[2:], det[:, :4], orgimg.shape).round()
        det[:, 5:13] = scale_coords_landmarks(img.shape[2:], det[:, 5:13], orgimg.shape).round()

        det = det.cpu()  # one host transfer per image instead of several per row
        for j in range(det.size()[0]):
            xyxy = det[j, :4].tolist()
            conf = det[j, 4].numpy()
            landmarks = det[j, 5:13].tolist()
            class_num = det[j, 13].numpy()
            dict_list.append(
                get_plate_rec_landmark(orgimg, xyxy, conf, landmarks, class_num, device)
            )
    return dict_list
def draw_result(orgimg, dict_list):
    """Draw every detection onto ``orgimg`` and return the image.

    For each entry of ``dict_list`` the box is enlarged by 5% of its width
    and 11% of its height on each side, clipped to the image, and drawn
    together with the four landmark dots and the class label as text.

    Args:
        orgimg: image array to draw on; modified in place.
        dict_list: detections with the keys ``'rect'`` (``[x1, y1, x2, y2]``),
            ``'landmarks'`` (``[x, y]`` pairs) and ``'class'`` (index into
            ``clors``). The entries are only read: the padded box lives in
            local variables, so repeated calls do not grow ``'rect'``.

    Returns:
        ``orgimg``, the same object that was passed in.
    """
    img_h, img_w = orgimg.shape[0], orgimg.shape[1]

    for result in dict_list:
        rect = result['rect']
        x1, y1, x2, y2 = rect[0], rect[1], rect[2], rect[3]

        padding_w = 0.05 * (x2 - x1)
        padding_h = 0.11 * (y2 - y1)
        left = max(0, int(x1 - padding_w))
        top = max(0, int(y1 - padding_h))
        right = min(img_w, int(x2 + padding_w))
        bottom = min(img_h, int(y2 + padding_h))

        label = result['class']
        box_color = clors[label]

        # Landmark dots: one colour per corner.
        for point, color in zip(result['landmarks'][:4], clors):
            cv2.circle(orgimg, (int(point[0]), int(point[1])), 5, color, -1)

        # Bounding box and class label. The text is drawn on ``orgimg``; the
        # previous code referenced an undefined local ``img`` and only worked
        # when a global of that name happened to exist.
        cv2.rectangle(orgimg, (left, top), (right, bottom), box_color, 2)
        cv2.putText(orgimg, str(label), (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 2)

    return orgimg
if __name__ == '__main__':
    # Command-line demo: detect plates in one image or in every image found
    # under a directory, draw the results and save them to the output folder.
    parser = argparse.ArgumentParser()
    parser.add_argument('--detect_model', nargs='+', type=str, default='weights/plate_detect.pt', help='model.pt path(s)')  # detection model
    parser.add_argument('--image_path', type=str, default=r'D:\Project\ChePai\test\images\val', help='source')
    parser.add_argument('--img_size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--output', type=str, default='result1', help='output directory')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    opt = parser.parse_args()
    print(opt)

    save_path = opt.output
    # makedirs also handles nested output paths and an already existing folder.
    os.makedirs(save_path, exist_ok=True)

    detect_model = load_model(opt.detect_model, device)  # initialise the detection model

    count = 0      # images processed so far
    time_all = 0   # accumulated time of all processed images except the first
    time_begin = time.time()

    if not os.path.isfile(opt.image_path):  # directory
        file_list = []
        allFilePath(opt.image_path, file_list)
        for img_path in file_list:
            print(count, img_path)
            time_b = time.time()
            img = cv_imread(img_path)
            if img is None:
                continue
            if img.shape[-1] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
            dict_list = detect_plate(detect_model, img, device, opt.img_size)
            ori_img = draw_result(img, dict_list)
            save_img_path = os.path.join(save_path, os.path.basename(img_path))
            time_gap = time.time() - time_b
            # The first processed image is left out of the average.
            if count:
                time_all += time_gap
            cv2.imwrite(save_img_path, ori_img)
            count += 1
    else:  # single image
        print(count, opt.image_path, end=" ")
        img = cv_imread(opt.image_path)
        if img is None:
            raise SystemExit(f"cannot read image: {opt.image_path}")
        if img.shape[-1] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        dict_list = detect_plate(detect_model, img, device, opt.img_size)
        ori_img = draw_result(img, dict_list)
        save_img_path = os.path.join(save_path, os.path.basename(opt.image_path))
        cv2.imwrite(save_img_path, ori_img)

    # Average over the images that were actually timed. The old expression
    # divided by len(file_list) - 1, which is undefined in single-image mode,
    # zero for a one-file folder, and wrong when unreadable files are skipped.
    timed_images = max(count - 1, 0)
    average_time = time_all / timed_images if timed_images else 0.0
    print(f"sumTime time is {time.time()-time_begin} s, average pic time is {average_time}")