DianLi/x64/Release/models/mask_fill/tifIO.py

68 lines
2.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from osgeo import gdal
# from tqdm import tqdm
def ReadTif(tif_path):
dataset = gdal.Open(tif_path)
width = dataset.RasterXSize
height = dataset.RasterYSize
geotrans = list(dataset.GetGeoTransform()) # 仿射矩阵
proj = dataset.GetProjection() # 地图投影信息
data = dataset.ReadAsArray(0, 0, width, height) # 将数据写成数组,对应栅格矩阵
del dataset # 关闭对象文件dataset
return proj, geotrans, data, width, height
def writeTif(fileroute, im_proj, im_geotrans, im_data):
# 判断栅格数据的数据类型
"""
GDAL中的GDALDataType是一个枚举型其中的值为
GDT_Unknown : 未知数据类型
GDT_Byte : 8bit正整型 (C++中对应unsigned char)
GDT_UInt16 : 16bit正整型 (C++中对应 unsigned short)
GDT_Int16 : 16bit整型 (C++中对应 short 或 short int)
GDT_UInt32 : 32bit 正整型 (C++中对应unsigned long)
GDT_Int32 : 32bit整型 (C++中对应int 或 long 或 long int)
GDT_Float32 : 32bit 浮点型 (C++中对应float)
GDT_Float64 : 64bit 浮点型 (C++中对应double)
GDT_CInt16 : 16bit复整型 (?)
GDT_CInt32 : 32bit复整型 (?)
GDT_CFloat32 : 32bit复浮点型 (?)
GDT_CFloat64 : 64bit复浮点型 (?)
"""
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_UInt16
else:
datatype = gdal.GDT_Float32
# 判读数组维数
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
else:
im_bands, (im_height, im_width) = 1, im_data.shape
im_data = im_data.reshape(im_bands,im_height, im_width)
# 创建文件
driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间
dataset = driver.Create(fileroute, im_width,im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
dataset.SetProjection(im_proj) # 写入投影
if im_bands == 1:
dataset.GetRasterBand(1).WriteArray(im_data[0]) # 写入数组数据
else:
#for i in range(im_bands):
for i in tqdm(range(im_bands)):
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
del dataset
def np2gdal(hwc): # hwc2chw
chw = hwc.swapaxes(2, 0).swapaxes(1, 2) # h,w,c to c,h,w
return chw
def gdal2np(chw): #chw2hwc
hwc = chw.swapaxes(1, 0).swapaxes(1, 2) # h,w,c
return hwc