RODY/app/file_tool.py
552068321@qq.com 6f7de660aa first commit
2022-11-04 17:37:08 +08:00

254 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import shutil
from math import ceil
from typing import List, Optional, Union
#from ai_platform.common.config import settings
# from ai_platform.model.crud import image_label_crud as ilc, project_list_crud as plc, \
# image_dataset_curd as idc
#from ai_platform.common.logger import logger
# from ai_platform.model.database import session
from app.core.common_utils import logger
from app.json_util import write_info
# root_path = settings.root_path
# root_path = '/home/wd/server/ai_platform/data_set/'
#db = session
def delete_file(files: List[str]):
    """
    Remove every path in *files* that currently exists; paths that do not exist are skipped.

    Existence is checked immediately before each removal, so a path listed twice is only
    removed once.

    :param files: paths of the files to remove
    :return: None
    """
    for target in files:
        if not os.path.exists(target):
            continue
        os.remove(target)
def get_file_then_delete_file(path: str):
    """
    Delete every file found by ``os.walk`` beneath the directory that contains *path*.

    *path* itself does not need to exist; only its parent directory is used. The walk is
    recursive, and only the file entries are removed (via ``delete_file``); the directories
    themselves are left in place. If the parent directory does not exist, nothing is deleted.

    :param path: a file path whose parent directory should be emptied
    :return: the parent directory of *path*
    """
    target_dir = os.path.dirname(path)
    if not os.path.exists(target_dir):
        return target_dir
    doomed = [
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(target_dir)
        for name in names
    ]
    delete_file(files=doomed)
    return target_dir
def delete_dir_file(files: List[str], json_files: List[str]):
    """
    Empty the train, val and test folders under ``trained/`` before a new split is copied in.

    The image folders are derived from ``files[0]`` by replacing ``ori/images`` with
    ``trained/images/<split>``. When *json_files* is non-empty, the label folders are derived
    the same way from ``json_files[0]`` (``ori/labels`` -> ``trained/labels/<split>``).
    Each folder is emptied with ``get_file_then_delete_file``.

    :param files: image paths; only ``files[0]`` is used, so the list must not be empty
    :param json_files: label paths; may be empty, in which case no label folder is touched
    :return: the train, val and test image folders, in that order, each with a trailing '/'
    """
    splits = ('train', 'val', 'test')
    logger.info('删除图片数据')
    sample_img = files[0]
    image_dirs = [
        get_file_then_delete_file(sample_img.replace('ori/images', 'trained/images/' + split))
        for split in splits
    ]
    if not json_files:
        logger.info('无json数据')
    else:
        logger.info('删除json数据')
        sample_label = json_files[0]
        for split in splits:
            get_file_then_delete_file(sample_label.replace('ori/labels', 'trained/labels/' + split))
    return [folder + '/' for folder in image_dirs]
def _copy_into(src: str, dst: str):
    """Copy file *src* to *dst*, creating the parent directory of *dst* if it is missing."""
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    shutil.copyfile(src, dst)


def _copy_split(img_files: List[str], json_lookup: set, split: str):
    """
    Copy the images in *img_files* that still exist into ``trained/images/<split>``.

    For each copied image whose label ``ori/labels/<stem>.json`` is in *json_lookup*, the
    label is also copied into ``trained/labels/<split>``.
    """
    for img in img_files:
        if not os.path.exists(img):
            continue
        _copy_into(img, img.replace('ori/images', 'trained/images/' + split))
        # Swap only the 'ori/images' segment. Replacing every 'images' substring would also
        # rewrite project directories or file names that contain 'images', and the label
        # would then never be found.
        label = os.path.splitext(img)[0].replace('ori/images', 'ori/labels') + '.json'
        if label in json_lookup:
            _copy_into(label, label.replace('ori/labels', 'trained/labels/' + split))


def mv_file(train_files: List[str], test_files: List[str], r_v_rate: Optional[float] = 0.9,
            t_t_rate: Optional[float] = 0.9):
    """
    Copy images and their JSON labels from ``ori/`` into the ``trained/`` train/val/test folders.

    Paths ending in ``.json`` are labels; every other path is an image. Each image under
    ``ori/images`` is copied to ``trained/images/<split>``. Its label (same stem, ``.json``
    suffix, under ``ori/labels``) is copied to ``trained/labels/<split>`` when it is among
    the given label paths. The target folders are emptied first via ``delete_dir_file``.
    The input lists are not modified.

    :param train_files: training candidates (images and labels mixed); split into train and val
    :param test_files: test-set files (images and labels mixed)
    :param r_v_rate: fraction of the (possibly capped) training images that go to train;
        the rest go to val
    :param t_t_rate: maximum share of all images that may be training candidates; surplus
        training images are moved to the test set. ``None`` disables the cap.
    :return: the train, val and test image folders (each with a trailing '/'), as returned
        by ``delete_dir_file``
    """
    train_img_files = [f for f in train_files if not f.endswith('.json')]
    train_json_files = [f for f in train_files if f.endswith('.json')]
    test_img_files = [f for f in test_files if not f.endswith('.json')]
    # One lookup set for every split. Images moved from train to test keep their labels in
    # train_files, so the test split must be able to see those labels too.
    json_lookup = {f for f in train_files if f.endswith('.json')}
    json_lookup.update(f for f in test_files if f.endswith('.json'))

    train_len_all = len(train_img_files)
    if t_t_rate is not None:
        len_all = train_len_all + len(test_img_files)
        # Cap the train share at t_t_rate and move the surplus into the test set.
        # Previously the check compared the *test* share against t_t_rate, which made it
        # effectively a no-op. It also appended the surplus to the caller's test_files
        # after test_img_files had been built, so those images were never copied.
        if len_all > 0 and train_len_all / len_all > t_t_rate:
            train_len_all = ceil(len_all * t_t_rate)
            test_img_files = test_img_files + train_img_files[train_len_all:]
    train_len = ceil(train_len_all * r_v_rate)
    t_files = train_img_files[:train_len]  # train split
    val_files = train_img_files[train_len:train_len_all]  # validation split

    # Empty the target folders first so files from an earlier split do not linger.
    target_path = delete_dir_file(files=train_img_files, json_files=train_json_files)

    _copy_split(t_files, json_lookup, 'train')
    _copy_split(val_files, json_lookup, 'val')
    _copy_split(test_img_files, json_lookup, 'test')
    return target_path
def _collect_image_label_files(images_dir: str, labels_dir: str, names: List[str]) -> List[str]:
    """
    Return the image paths for *names* that are regular files, each followed by its label.

    The label is ``<labels_dir><stem>.json`` and is only included when that file exists.
    """
    files = []
    for img in names:
        path = images_dir + img
        # Skip sub-directories and any other entry that cannot be copied as an image.
        if not os.path.isfile(path):
            continue
        files.append(path)
        label = labels_dir + os.path.splitext(img)[0] + '.json'
        # Queue only labels that exist. Listing missing labels made mv_file crash in copyfile.
        if os.path.isfile(label):
            files.append(label)
    return files


def get_file(ori_path: str, type_list: Union[object,str]):
    """
    Split a project's ``ori`` images into train/val/test folders and write its label classes.

    The entries of ``<ori_path>images/`` are sorted, so the result does not depend on the
    arbitrary ``os.listdir`` order. The first entry seeds the test set and the rest are
    training candidates. ``mv_file`` then rebalances the sets to the forced 9:1 train/test
    ratio and copies the files. The classes are written as ``{"classes": type_list}`` via
    ``write_info`` to ``<ori_path without trailing '/'>/img_label_type``.

    :param ori_path: the project's ``ori`` directory, e.g. ``/data_set/<proj>/ori/``;
        a missing trailing '/' is tolerated
    :param type_list: label class names (must be JSON-serialisable)
    :return: False if there is not at least one test and one training image; otherwise the
        folder list from ``mv_file`` with the written ``img_label_type.json`` path appended
    """
    base = ori_path if ori_path.endswith('/') else ori_path + '/'
    images_dir = base + 'images/'
    labels_dir = base + 'labels/'
    imgs = sorted(os.listdir(images_dir))

    # Train/test ratio is forced to 9:1: one image seeds the test set and mv_file rebalances.
    test_files = _collect_image_label_files(images_dir, labels_dir, imgs[:1])
    train_files = _collect_image_label_files(images_dir, labels_dir, imgs[1:])
    if not train_files or not test_files:
        return False
    target_path = mv_file(train_files=train_files, test_files=test_files)

    # Round-trip through JSON so write_info receives only JSON-native types
    # (e.g. a tuple of classes becomes a list).
    type_dict = {'classes': type_list}
    path = os.path.dirname(base) + '/img_label_type'
    write_info(file_name=path, file_info=json.loads(json.dumps(type_dict)))
    target_path.append(path + '.json')
    return target_path
# def get_file_path(proj_no: str):
# """
# 识别算法,给算法传递图片路径
# :param proj_no:
# :return:
# """
# path = root_path + '/' + proj_no
# img_path = path
# # 创建他们所需的文件夹
# vgg_path = path + '/vgg'
# if not os.path.exists(vgg_path):
# # vgg不存在创建
# train_path = vgg_path + '/train'
# test_path = vgg_path + '/test'
# os.makedirs(train_path)
# os.makedirs(test_path)
# # 生成标签
# img_types = ilc.get_label_by_proj_no(proj_no=proj_no, db=db)
# type_list = []
# for img_type in img_types:
# type_list.append(img_type.lebel_type)
# type_dict = {'classes': type_list}
# str_json = json.dumps(type_dict)
# path = root_path + proj_no + '/img_label_type'
# write_info(file_name=path, file_info=json.loads(str_json))
# return img_path, path + '.json'
if __name__ == '__main__':
    # Manual run: python file_tool.py <ori_path> [class_name ...]
    # e.g. python file_tool.py D:/pythonProject/DeepLearnAiPlatform/data_set/3148803620347904/ori/ cat dog
    # get_file takes ori_path and type_list; the former get_file(proj_no=...) raised TypeError.
    import sys

    if len(sys.argv) < 2:
        sys.exit('usage: python file_tool.py <ori_path> [class_name ...]')
    s = get_file(ori_path=sys.argv[1], type_list=sys.argv[2:])
    print(s)