pyqt5 实现使用瓦片显示超大影像

pyqt5 实现使用瓦片显示超大影像,第1张

这里只列出代码

基本思路就是将影像分割成较小的瓦片,然后循环显示到QGraphicsItem中,还涉及到瓦片缩放、平移、范围选取和点选取。

import pyqtgraph as pg
from PyQt5.QtCore import Qt, QRectF, pyqtSignal
from PyQt5.QtGui import QColor, QPalette, QPixmap
from PyQt5.QtWidgets import QGraphicsView, QGraphicsScene

from view.image.ImageItem import ImageItem


class ImageView(QGraphicsView):
    """Graphics view that shows one image through a tiled ``ImageItem``.

    The view owns a ``QGraphicsScene``; ``showLayer`` stores the image and
    ``showImage`` builds a fresh ``ImageItem`` for it and puts it in the scene.
    """
    # Signal handed to the item, which emits it from its hover handler.
    signal_pixel_selected = pyqtSignal(str, int)
    # Progress signal handed to the item.
    signal_progress = pyqtSignal(object, object)

    item: ImageItem = None
    pixmap: QPixmap = None
    # Mapping between image names and items.  The class-level default is kept
    # for backward compatibility; each instance gets its own dict in __init__
    # so that views no longer share one mutable mapping.
    item_dict = {}

    def __init__(self):
        super(ImageView, self).__init__()
        # The original code instantiated pg.ImageView() and pg.ImageItem()
        # here and discarded them immediately; those throwaway objects were
        # never used, so they are no longer created.
        self.item_dict = {}
        self.QImg = None
        self.image = None
        self.image_array = None
        self.background_color = QColor()
        self.background_color.setNamedColor('#ffffff')
        # NOTE(review): the attributes ``palette`` and ``scene`` shadow the Qt
        # methods of the same name on this instance; they are kept because
        # external code may read them as attributes.
        self.palette = QPalette()
        self.palette.setColor(QPalette.Window, self.background_color)
        self.__init_view()

    def __init_view(self):
        """Apply the background palette, centre alignment and create the scene."""
        self.setAutoFillBackground(True)
        self.setPalette(self.palette)
        self.setAlignment(Qt.AlignCenter)
        # Create the scene that will hold the image item.
        self.scene = QGraphicsScene()

    def showImage(self):
        """Display ``self.image`` by building a new ``ImageItem`` for it."""
        self.setSceneRect(QRectF(-(self.width() / 2), -(self.height() / 2), self.width() - 10, self.height() - 10))
        self.setScene(self.scene)
        self.item = ImageItem(self.image, self.signal_pixel_selected, self.signal_progress)
        self.item.setQGraphicsViewWH(self.width(), self.height())
        # Remove whatever was shown before, then add the new item.
        self.scene.clear()
        self.scene.addItem(self.item)

    def showLayer(self, image):
        """Store ``image`` and show it; ``None`` only clears the stored image.

        :param image: image object to display, or ``None``
        """
        self.image = image
        if image is None:
            return
        self.showImage()

    def cleanLayer(self, image):
        """Remove everything from the scene.

        :param image: unused; kept for interface compatibility
        """
        self.scene.clear()
        # The scene has just been emptied, so drop the reference as well;
        # otherwise setItemStatus() would keep talking to a removed item.
        self.item = None

    def setItemStatus(self, status, tool):
        """Forward the interaction status and the current tool to the item.

        :param status: status value passed to ``ImageItem.setStatus``
        :param tool: tool passed to ``ImageItem.setCurrentTool``
        """
        if self.item is not None:
            self.item.setStatus(status)
            self.item.setCurrentTool(tool)

 影像显示控件继承了QGraphicsItem,在这里实现缩放和平移

from PyQt5.QtCore import Qt, QPoint, QRectF
from PyQt5.QtGui import QPainter, QPixmap, QColor, QPen, QImage
from PyQt5.QtWidgets import QStyleOptionGraphicsItem, QWidget, QGraphicsItem, QApplication

from image import OpticalImage
from system.System import System


class ImageItem(QGraphicsItem):

    def __init__(self, image: OpticalImage, signal_pixel_selected, signal_progress):
        """
        Set up the item state, then build the pixmap tiles for ``image``.

        :param image: image whose tiles are displayed
        :param signal_pixel_selected: signal emitted from the hover handler
        :param signal_progress: progress signal stored on the item
        """
        super(ImageItem, self).__init__()
        # Image being displayed and the signals supplied by the owner.
        self.image = image
        self.signal_pixel_selected = signal_pixel_selected
        self.signal_progress = signal_progress

        # Horizontal / vertical offset of the image inside the item.
        self.w_offset = self.h_offset = 0
        # Pixel position currently under the mouse.
        self.w_cur_pos = self.h_cur_pos = 0

        # Interaction status and the tool currently using the item.
        self.image_view_status = System.v_image_show
        self.currentTool = System.v_image_show

        # Finished selections, and the selection currently being drawn.
        self.show_rect = []
        self.show_point = []
        self.current_paint_point = None
        self.current_paint_rect = None

        # Button recorded by the last mouse press.
        self.mouse_click_button = None
        # Pixmap tiles keyed like the image's show-layer dictionary.
        self.pixmap_tiles = {}

        # Current scale and the default (fit-to-view) scale.
        self.m_scaleValue = self.m_scaleDafault = 1
        self.m_isMove = self.m_isRect = False
        self.m_startPos = None

        self.setAcceptDrops(True)
        self.setAcceptHoverEvents(True)

        self.setStatus(System.v_image_show)
        self.create_pixmap_tiles()

    def boundingRect(self):
        """
        Return the item rectangle, centred on the item origin.

        Also refreshes ``w_offset`` / ``h_offset`` (half the image size,
        truncated to int), which the other methods read.

        :return: QRectF covering the whole image
        """
        img_w, img_h = self.image.width, self.image.height
        self.w_offset, self.h_offset = int(img_w / 2), int(img_h / 2)
        return QRectF(-self.w_offset, -self.h_offset, img_w, img_h)

    def paint(self, painter, const, widget=None):
        """
        Paint the item: first the image tiles, then the selection overlays.

        :param painter: painter to draw with
        :param const: style option, forwarded to ``paintShowChecked``
        :param widget: optional widget, forwarded to ``paintShowChecked``
        :return:
        """
        self.painter_tiles(painter)
        self.paintShowChecked(painter, const, widget)

    def painter_tiles(self, painter):
        """
        Draw every available pixmap tile at its position inside the item.

        Tiles whose pixmap is ``None`` are skipped.

        :param painter: painter to draw with
        :return:
        """
        tile_edge = self.image.tiles.size
        for (row, col), tile_pixmap in self.pixmap_tiles.items():
            if tile_pixmap is None:
                continue
            # Tile index times tile size, shifted by the current offsets.
            left = col * tile_edge - self.w_offset
            top = row * tile_edge - self.h_offset
            painter.drawPixmap(left, top, tile_pixmap)

    def create_pixmap_tiles(self):
        """
        Build one pixmap per tile of the image's current show layer.

        The show layer is requested once and reused; the original asked for
        it twice (once for the length, once for the loop), which made the
        image assemble the whole tile dictionary two times.

        :return:
        """
        show_layer = self.image.get_show_layer()
        length = len(show_layer)
        for cur, (key, value) in enumerate(show_layer.items(), start=1):
            self.pixmap_tiles[key] = self.create_pixmap(value)
            System.signal.signal_progress.emit("正在创建瓦片:", cur * 100 / length)
            # Let the event loop run between tiles.
            QApplication.processEvents()
        print("pixmap瓦片创建完成")

    def create_pixmap(self, image_array):
        """
        Build a QPixmap from an image array.

        :param image_array: 2-D array (drawn as 8-bit grayscale) or 3-D array
            whose last axis has 1 or 3 components (grayscale / RGB888);
            assumes 8-bit samples — TODO confirm against the tile producer
        :return: the pixmap, or ``None`` when the array layout is not supported
        """
        shape = image_array.shape
        if len(shape) == 2:
            height, width = shape
            components = 1
        elif len(shape) == 3:
            height, width, components = shape
        else:
            # Neither a 2-D nor a 3-D array: nothing sensible to display.
            return None

        if components == 1:
            # Grayscale image.
            image_format = QImage.Format_Grayscale8
        elif components == 3:
            # RGB image.
            image_format = QImage.Format_RGB888
        else:
            # Unsupported number of components: do not display.
            return None

        # Keep the byte buffer referenced until the pixmap has been created:
        # a QImage built from external data requires that data to stay valid
        # for as long as the QImage is used.
        buffer = image_array.tobytes()
        q_image = QImage(buffer, width, height, components * width, image_format)
        return QPixmap.fromImage(q_image)

    def paintEvent(self):
        """Placeholder that does nothing; drawing is done in ``paint``."""
        pass

    # 将主界面的控件QGraphicsView的width和height传进本类中,并根据图像的长宽和控件的长宽的比例来使图片缩放到适合控件的大小
    def setQGraphicsViewWH(self, nwidth: int, nheight: int):
        """
        Scale the item so that the whole image fits into the view.

        The smaller of the two view/image ratios becomes both the default
        scale and the current scale.

        :param nwidth: width of the hosting view
        :param nheight: height of the hosting view
        """
        fit_scale = min(nwidth / self.image.width, nheight / self.image.height)
        self.m_scaleDafault = fit_scale
        self.setScale(fit_scale)
        self.m_scaleValue = fit_scale

    # 重置图片大小
    def resetItemPos(self):
        """Restore the default (fit-to-view) scale and move the item to (0, 0)."""
        default_scale = self.m_scaleDafault
        self.m_scaleValue = default_scale
        self.setScale(default_scale)
        self.setPos(0, 0)

    def getScaleValue(self):
        """Return the current scale factor of the item."""
        return self.m_scaleValue

    # 画矩形
    def paintShowChecked(self, painter: QPainter, const: QStyleOptionGraphicsItem, widget: QWidget):
        """
        Draw the selection overlays: the rectangle / point being drawn
        (when there is one) followed by every stored rectangle / point.

        :param painter: painter to draw with
        :param const: style option, forwarded to ``paintPoint``
        :param widget: widget, forwarded to ``paintPoint``
        """
        rects = list(self.show_rect)
        if self.current_paint_rect is not None:
            # The in-progress rectangle is drawn first so it updates live.
            rects.insert(0, self.current_paint_rect)
        for rect in rects:
            self.paintRect(rect, painter)

        points = list(self.show_point)
        if self.current_paint_point is not None:
            points.insert(0, self.current_paint_point)
        for point in points:
            self.paintPoint(point, painter, const, widget)

    def paintPoint(self, point, painter: QPainter, const: QStyleOptionGraphicsItem, widget: QWidget):
        """
        Draw one selected point in red.

        :param point: point to draw
        :param painter: painter to draw with
        :param const: unused
        :param widget: unused
        :return:
        """
        # The pen width is divided by the current scale factor.
        pen = QPen(QColor(255, 0, 0), 10 / self.m_scaleValue)
        painter.setPen(pen)
        painter.drawPoint(point)

    def paintRect(self, rect, painter: QPainter):
        """
        Draw one selection rectangle: red outline plus a translucent fill.

        :param rect: rectangle to draw; ``None`` is ignored
        :param painter: painter to draw with
        :return:
        """
        if rect is not None:
            # The outline width is divided by the current scale factor.
            outline = QPen(QColor(255, 0, 0), 1 / self.m_scaleValue)
            painter.setPen(outline)
            painter.drawRect(rect)
            painter.fillRect(rect, QColor(200, 200, 0, 50))

    # 设置图像控件的操作状态
    def setStatus(self, status):
        """
        Switch the interaction status of the item and adjust the cursor.

        ``System.v_image_choose_cancel_rect`` is not kept as a status: it
        clears all selections and the previous status is restored.

        :param status: one of the ``System.v_image_*`` status values
        """
        previous_status = self.image_view_status
        self.image_view_status = status

        if status == System.v_image_show:
            # Browsing mode: open hand, rectangle selections are dropped.
            self.setCursor(Qt.OpenHandCursor)
            self.show_rect = []
            self.current_paint_rect = None
            self.update(self.boundingRect())
            return

        if status == System.v_image_choose_rect:
            self.setCursor(Qt.CrossCursor)
            return

        if status == System.v_image_choose_multi_rect:
            self.setCursor(Qt.CrossCursor)
            self.multi_rect = True
            return

        if status == System.v_image_choose_point:
            self.setCursor(Qt.CrossCursor)
            return

        if status == System.v_image_choose_cancel_rect:
            # Cancelling keeps the selection mode that was active before.
            self.image_view_status = previous_status
            self.setCursor(Qt.CrossCursor)
            self.show_rect = []
            self.current_paint_rect = None
            self.show_point = []
            self.current_paint_point = None
            self.update(self.boundingRect())

    # 设置当前使用图像控件的工具类
    def setCurrentTool(self, tool):
        """Remember the tool that is currently using this item."""
        self.currentTool = tool

    # 鼠标点击事件
    def mousePressEvent(self, event):
        """
        Record the pressed button and start the action for the current status.

        :param event: mouse event delivered to the item
        """
        pressed = event.buttons()
        self.mouse_click_button = pressed
        status = self.image_view_status

        if pressed == Qt.LeftButton:
            if status == System.v_image_show:
                # Browsing mode: remember where the press started.
                self.m_startPos = event.pos()
                self.setCursor(Qt.ClosedHandCursor)
            elif status == System.v_image_choose_rect:
                # Single-rectangle mode: a new press discards stored rectangles.
                self.m_startPos = event.pos()
                self.show_rect = []
            elif status == System.v_image_choose_multi_rect:
                self.m_startPos = event.pos()
            elif status == System.v_image_choose_point:
                self.current_paint_point = event.pos()
        elif pressed == Qt.RightButton:
            print("单击鼠标右键")
        elif pressed == Qt.MidButton:
            # Middle button: remember the start position for panning.
            self.m_startPos = event.pos()
            self.setCursor(Qt.ClosedHandCursor)

    '''滚轮滚动事件'''

    def wheelEvent(self, event):
        """
        Zoom the item by 10% per wheel event, keeping the point under the
        mouse in place.

        Nothing happens when zooming in at a scale of 1000 or more, or when
        zooming out at a scale of 0.01 or less.

        :param event: wheel event delivered to the item
        """
        delta = event.delta()
        at_upper_limit = delta > 0 and self.m_scaleValue >= 1000
        at_lower_limit = delta < 0 and self.m_scaleValue <= 0.01
        if at_upper_limit or at_lower_limit:
            return

        pos = event.pos()
        previous_scale = self.m_scaleValue
        zoom_in = delta > 0

        self.m_scaleValue *= 1.1 if zoom_in else 0.9
        self.setScale(self.m_scaleValue)

        # Shift the item so the zoom appears centred on the mouse position.
        direction = -1 if zoom_in else 1
        self.moveBy(direction * pos.x() * previous_scale * 0.1,
                    direction * pos.y() * previous_scale * 0.1)

    '''鼠标双击事件(单击)'''

    def mouseDoubleClickEvent(self, event):
        """Reset scale and position of the item on a double click."""
        self.resetItemPos()

    '''鼠标键释放事件'''

    def mouseReleaseEvent(self, event):
        """
        Finish the action started by the matching mouse press.

        In the rectangle modes the rectangle being drawn is stored only when
        one actually exists, and the in-progress rectangle is then cleared.
        The original appended ``current_paint_rect`` unconditionally, so a
        click without a drag stored ``None`` or re-stored the previous
        rectangle, and the finished rectangle kept being painted twice.

        :param event: mouse event delivered to the item
        """
        status = self.image_view_status
        if self.mouse_click_button == Qt.LeftButton:
            if status == System.v_image_show:
                self.setCursor(Qt.OpenHandCursor)
            elif status in (System.v_image_choose_rect, System.v_image_choose_multi_rect):
                if self.current_paint_rect is not None:
                    self.show_rect.append(self.current_paint_rect)
                    self.current_paint_rect = None
                    self.update(self.boundingRect())
            elif status == System.v_image_choose_point:
                self.show_point.append(self.createPoint(self.current_paint_point))
                self.update(self.boundingRect())
        elif self.mouse_click_button == Qt.MidButton:
            # Middle button released: restore the cursor of the current status.
            self.setCursor(Qt.OpenHandCursor)
            self.setStatus(self.image_view_status)
        self.mouse_click_button = None

    '''鼠标移动事件'''

    def mouseMoveEvent(self, event):
        """
        Pan the image or update the rectangle being drawn while dragging.

        Changes compared with the original:

        * nothing is done when no start position has been recorded (the
          original raised ``AttributeError`` on ``m_startPos.x()``, e.g. when
          a right-button drag was the first interaction in a rectangle mode);
        * the rectangle is only updated while the left button is held;
        * a left-button drag in browsing mode pans the image, which is what
          ``mousePressEvent`` prepares for (start position + closed-hand
          cursor) but the original never carried out.

        :param event: mouse event delivered to the item
        """
        if self.m_startPos is None:
            return

        button = self.mouse_click_button
        status = self.image_view_status
        panning = button == Qt.MidButton or (
            button == Qt.LeftButton and status == System.v_image_show)

        if panning:
            # The drag distance is multiplied by the current scale before
            # being handed to moveBy().
            delta = (event.pos() - self.m_startPos) * self.m_scaleValue
            self.moveBy(delta.x(), delta.y())
        elif button == Qt.LeftButton and status in (
                System.v_image_choose_rect, System.v_image_choose_multi_rect):
            # Rectangle selection: rebuild the rectangle from start to cursor.
            pos = event.pos()
            self.current_paint_rect = self.createRect(
                int(self.m_startPos.x()), int(self.m_startPos.y()),
                int(pos.x()), int(pos.y()))
            self.update(self.boundingRect())

    def createPoint(self, point):
        """
        Return the point to store for a selection; currently the input itself.

        :param point: point picked by the user
        :return: the same ``point`` object
        """
        return point

    # 根据坐标画出矩形
    def createRect(self, startX, startY, endX, endY):
        """
        Build the rectangle spanned by two corner points.

        The corners may be given in any order.  The original used
        ``min(start - 1, end)`` for the origin, which shifted the rectangle
        by one unit whenever the drag went right/down while leaving it
        untouched for drags going left/up; the origin is now the true minimum.

        :param startX: x of the first corner
        :param startY: y of the first corner
        :param endX: x of the opposite corner
        :param endY: y of the opposite corner
        :return: QRectF covering the area between both corners
        """
        left = min(startX, endX)
        top = min(startY, endY)
        return QRectF(left, top, abs(startX - endX), abs(startY - endY))

    # 鼠标滑过事件
    def hoverMoveEvent(self, event):
        """
        Track the pixel under the mouse and emit its description.

        :param event: hover event delivered to the item
        """
        col, row = self.getViewPoint2ImagePos(event.pos())
        self.w_cur_pos, self.h_cur_pos = col, row
        description = self.image.getPointInfo(row, col)
        self.signal_pixel_selected.emit(description, 5000)

    def getViewPoint2ImagePos(self, pos):
        """
        Convert an item position to an image pixel position.

        :param pos: position in item coordinates
        :return: tuple ``(x, y)``: the position shifted by the current
            offsets and truncated to int
        """
        image_x = int(pos.x() + self.w_offset)
        image_y = int(pos.y() + self.h_offset)
        return image_x, image_y

    def getCheckImagePos(self):
        """
        Return the geographic coordinates of every selected point.

        ``getViewPoint2ImagePos`` returns ``(x, y)``, i.e. (column, row).
        The original unpacked that result as ``h, w`` and therefore passed
        swapped coordinates to ``getPointGeo(h, w)``; the values are now
        unpacked in the order in which they are returned.

        :return: list with one ``getPointGeo`` result per selected point
        """
        points = []
        for point in self.show_point:
            w, h = self.getViewPoint2ImagePos(point)
            points.append(self.image.getPointGeo(h, w))
        return points

    # 获取当前显示图像在矩形框内的切片
    def getSelectedRectImage(self):
        """
        Return the image content inside every stored selection rectangle.

        ``None`` entries in ``show_rect`` (which a click without a drag can
        leave behind) are skipped instead of raising ``AttributeError``.

        :return: list with one ``getImageByRange`` result per rectangle, or
            ``None`` when ``show_rect`` itself is ``None``
        """
        if self.show_rect is None:
            return None
        selected_image = []
        for rect in self.show_rect:
            if rect is None:
                continue
            # Shift from item coordinates to image pixel coordinates.
            ltx = rect.x() + self.w_offset
            lty = rect.y() + self.h_offset
            rbx = ltx + rect.width()
            rby = lty + rect.height()
            selected_image.append(self.image.getImageByRange(int(ltx), int(lty), int(rbx), int(rby)))
        return selected_image

    def getSelectPoint(self):
        """
        Return the selected points.

        :return: the ``show_point`` list itself (not a copy)
        """
        return self.show_point

    # 矩形框选择完成的调用
    def rectSelected(self):
        """Hook for a finished rectangle selection; the visible body only
        returns early when there is no rectangle list or no image."""
        if self.show_rect is None or self.image is None:
            return
import numpy as np

from image.RsImage import RSImage


class OpticalImage(RSImage):
    """
    Optical remote-sensing image.
    """

    def __init__(self, path, fileType):
        """
        Open the image and prepare its display tiles.

        :param path: path of the image file
        :param fileType: image type, forwarded to ``RSImage``
        """
        super(OpticalImage, self).__init__(path, fileType)
        # Show at most the first three bands.
        self.show_layers_index = np.arange(min(3, self.bands))
        self.tiles.init_show_tiles(True)

    def get_show_layer(self):
        """Return the display tiles restricted to the bands currently shown."""
        return self.tiles.get_show_tiles(self.show_layers_index)

    def getImageByRange(self, x, y, tx, ty):
        """
        Return the image content inside the given range.

        :param x: start x
        :param y: start y
        :param tx: end x
        :param ty: end y
        :return: result of ``Tiles.get_range`` for that range
        """
        return self.tiles.get_range(x, y, tx, ty)

 影像信息类

# -*- coding: utf-8 -*-
"""
遥感影像处理类
包括光学遥感影像和雷达遥感影像
"""
import os

import gdal
import numpy as np
import osr

from image.ImageParams import ImageParams
from image.Tiles import Tiles
from view.dock.layer.ImageLayerTreeItem import ImageLayerTreeItemPos


class RSImage:
    """
    遥感影像处理基类
    """
    # 文件名称
    imageName = None
    # 文件路径
    imagePath = None
    # 图片内容
    dataset = None
    # 影像图层数量
    bands = 0
    # 影像宽度(像素)
    width = 0
    # 影像高度(像素)
    height = 0
    # 影像所有图层
    allLayer = None
    # 影像类型
    fileType = None
    # 影像坐标参数
    adfGeoTransform = None
    # 影像控件参考信息
    projection = None
    # 影像状态
    image_status = None
    # 影像信息
    image_info = {}
    # 当前显示的图层索引
    show_layers_index = [0, 1, 2]
    # 影像瓦片
    tiles: Tiles = None
    # 影像信息
    image_params = None

    def __init__(self, path, fileType):
        """
        Open the image at ``path`` and build its parameter tree.

        When the file cannot be opened, ``image_status`` is ``False`` and the
        parameter tree is not built.  The original built it anyway and failed
        with ``AttributeError`` because no projection had been read.

        :param path: path of the image file
        :param fileType: image type, forwarded to the tile model
        """
        self.fileType = fileType
        self.__readImage(path)
        if self.image_status:
            self.init_image_params()

    def __readImage(self, path):
        """
        Open the file with GDAL and read its basic properties.

        Sets ``image_status`` to ``False`` and returns early when the file
        cannot be opened; otherwise fills in size, band count, geotransform
        and projection, and creates the tile model.

        :param path: path of the image file
        :return:
        """
        print("打开文件" + path)
        self.imagePath = path
        self.imageName = os.path.basename(path)

        dataset = gdal.Open(path)
        self.dataset = dataset
        if dataset is None:
            print("文件打开失败")
            self.image_status = False
            return

        self.image_status = True
        self.bands = dataset.RasterCount
        self.width = dataset.RasterXSize
        self.height = dataset.RasterYSize
        # Geotransform of the image (origin and resolution).
        self.adfGeoTransform = dataset.GetGeoTransform()
        # Projection of the image.
        self.projection = osr.SpatialReference(dataset.GetProjection())
        self.tiles = Tiles(dataset, 1000, self.fileType)

    def get_show_layer(self):
        """Placeholder that returns ``None``; ``OpticalImage`` overrides it."""
        pass

    # 获取指定的图层数组
    def getBandByIndex(self, index):
        """
        Return the array of one band.

        :param index: band index, passed on to ``Tiles.get_band``
        :return: result of ``Tiles.get_band``
        """
        return self.tiles.get_band(index)

    # 根据像素位置获取当前坐标点
    def getPointGeo(self, h, w):
        """
        Convert a pixel position to coordinates using the geotransform.

        :param h: pixel row
        :param w: pixel column
        :return: list ``[x, y]``
        """
        gt = self.adfGeoTransform
        geo_x = gt[0] + w * gt[1] + h * gt[2]
        geo_y = gt[3] + w * gt[4] + h * gt[5]
        return [geo_x, geo_y]

    # 根据像素位置获取当前的像素值
    def getPointValue(self, h, w):
        """
        Return the pixel value at a pixel position.

        :param h: pixel row
        :param w: pixel column
        :return: result of ``Tiles.get_pix``, which is called with (w, h)
        """
        return self.tiles.get_pix(w, h)

    # 获取指定范围内的图像内容
    def getImageByRange(self, x, y, tx, ty):
        """
        Placeholder for reading the image content inside a range; returns
        ``None`` here and is overridden by ``OpticalImage``.

        :param x: start x
        :param y: start y
        :param tx: end x
        :param ty: end y
        :return:
        """
        pass

    def getPointInfo(self, h, w):
        """
        Describe the pixel at a position as a formatted string.

        Uses the built-in ``str`` instead of ``np.str`` (a plain alias that
        has been removed from NumPy) and also rejects negative positions,
        which the original treated as inside the image.

        :param h: pixel row
        :param w: pixel column
        :return: text with pixel value, coordinates and unit, or an
            out-of-range message
        """
        if h < 0 or w < 0 or h >= self.height or w >= self.width:
            return "超出影像范围"

        point_value = self.getPointValue(h, w)
        if self.bands != 1:
            # Multi-band image: keep only the bands currently shown.
            point_value = point_value[self.show_layers_index]
        geo_value = self.getPointGeo(h, w)

        unit_name = self.projection.GetAttrValue('UNIT', 0)
        unit = self.projection.GetAttrValue('UNIT', 1)
        if unit_name is None:
            unit_name = "无单位"
        if unit is None:
            unit = "无"
        return ('像素值:' + str(point_value) + ' 坐标:' + str(geo_value) + " "
                + str(unit_name) + "(" + str(unit) + ")")

    def get_image_info(self):
        """
        Collect basic and geographic information about the image.

        A fresh dictionary is built on every call and stored on the instance.
        The original wrote into the class-level ``image_info`` dict, which is
        shared by all images, so entries of one image could show up in the
        info of another.

        :return: dict with the basic info and, when a geotransform is
            available, the coordinate info
        """
        info = {
            "基础信息": {"路径": self.imagePath, "名称": self.imageName, "宽度": self.width,
                     "高度": self.height, "图层": self.bands},
        }
        gt = self.adfGeoTransform
        if gt is not None:
            info["坐标信息"] = {
                "起始坐标": {"X": gt[0], "Y": abs(gt[3])},
                "分辨率": {"X": gt[1], "Y": abs(gt[5])},
                "幅宽": {"X": self.width * abs(gt[1]), "Y": self.height * abs(gt[5])},
            }
        self.image_info = info
        return info

    def init_image_params(self):
        """Build the ``ImageParams`` tree: base info, start position,
        resolution and extent of the image."""
        node = ImageParams.TYPE_NODE
        real = ImageParams.TYPE_FLOAT

        root = ImageParams("ImageParams", "影像信息", node)
        self.image_params = root

        base_info = root.add_child("BaseInfo", "基础信息", node)
        base_info.add_child("ImageName", "影像名称", ImageParams.TYPE_STR, self.imageName)
        base_info.add_child("ImagePath", "影像路径", ImageParams.TYPE_STR, self.imagePath)
        base_info.add_child("ImageWidth", "影像宽度", ImageParams.TYPE_INT, self.width, "像素")
        base_info.add_child("ImageHeight", "影像高度", ImageParams.TYPE_INT, self.height, "像素")

        geo_info = root.add_child("GeoInfo", "地理信息", node)
        start_pos = geo_info.add_child("StartPos", "起始坐标", node)

        # Unit name and unit value come from the projection's UNIT attribute.
        unit_name = self.projection.GetAttrValue('UNIT', 0)
        pos_unit = self.projection.GetAttrValue('UNIT', 1)
        if pos_unit in (None, ""):
            pos_unit = "未获取"

        gt = self.adfGeoTransform
        start_pos.add_child("StartPosX", "X", real, gt[0], pos_unit, unit_name)
        start_pos.add_child("StartPosY", "Y", real, gt[3], pos_unit, unit_name)

        resolution = geo_info.add_child("Resolution", "分辨率", node)
        resolution.add_child("ResolutionX", "X轴分辨率", real, gt[1], pos_unit, unit_name)
        resolution.add_child("ResolutionY", "Y轴分辨率", real, gt[5], pos_unit, unit_name)

        size = geo_info.add_child("ImageSize", "影像幅宽", node)
        size.add_child("SizeX", "X轴幅宽", real, self.width * abs(gt[1]), pos_unit, unit_name)
        size.add_child("SizeY", "Y轴幅宽", real, self.height * abs(gt[5]), pos_unit, unit_name)

    def start_image_static(self):
        """
        Start collecting the histogram statistics of the image by
        delegating to ``Tiles.start_image_static``.

        :return:
        """
        self.tiles.start_image_static()

 影像属性,不重要

# -*- coding: utf-8 -*-

class ImageParams:
    """One node of a parameter tree: a named, typed value with an optional
    unit and named child nodes."""

    # Type tags for ``p_type``.
    TYPE_STR = 1
    TYPE_INT = 2
    TYPE_FLOAT = 3
    TYPE_LIST = 10
    TYPE_DICT = 11
    TYPE_TABLE = 12
    TYPE_NODE = 30

    def __init__(self, name, label, p_type, value=None, unit=None, unit_name=None):
        """
        :param name: parameter name
        :param label: display name
        :param p_type: one of the ``TYPE_*`` tags
        :param value: parameter value
        :param unit: unit; ``None`` is stored as an empty string
        :param unit_name: unit name; ``None`` is stored as an empty string
        """
        self.name = name
        self.label = label
        self.value = value
        self.p_type = p_type
        self.unit = "" if unit is None else unit
        self.unit_name = "" if unit_name is None else unit_name
        # Child parameters keyed by their name.
        self.children = {}

    def add_child(self, name, label, p_type, value=None, unit=None, unit_name=None):
        """
        Create a child node and register it under ``name``.

        :param name: parameter name of the child (also its key)
        :param label: display name
        :param p_type: one of the ``TYPE_*`` tags
        :param value: parameter value
        :param unit: unit
        :param unit_name: unit name
        :return: the newly created child node
        """
        node = ImageParams(name, label, p_type, value, unit, unit_name)
        self.children[name] = node
        return node


if __name__ == '__main__':
    # Scratch code: fills a plain dict via update(); it does not use ImageParams.
    par = {}
    par.update({"key": "value"})
    pass

 瓦片类:

# -*- coding: utf-8 -*-
import _thread
import math
from collections import Counter

import numpy as np
from PyQt5.QtWidgets import QApplication

from system.System import System
from util import MathUtil


class Tiles:
    """
    影像瓦片模型
    """

    def __init__(self, dataset, size=1000, image_type=None):
        """
        Initialise the tile model and cut the dataset into tiles.

        :param dataset: source dataset of the image
        :param size: edge length of one tile, default 1000
        :param image_type: image type, compared against ``System.m_image_type_*``
        """
        self.dataset = dataset
        self.image_type = image_type
        self.size = size

        # Source tiles and the tiles prepared for display.
        self.tiles_dict_source = {}
        self.tiles_dict_show = {}

        self.bands = dataset.RasterCount
        self.width = dataset.RasterXSize
        self.height = dataset.RasterYSize

        self.max_value = -99999
        self.min_value = 99999

        # Number of tiles horizontally and vertically.
        self.w_t = math.ceil(self.width / self.size)
        self.h_t = math.ceil(self.height / self.size)

        # Per-band [max, min] pairs, preset to sentinel values.
        sentinel = np.array([-999999999, 999999999], dtype=float)
        self.native_extremum_array = np.tile(sentinel, (self.bands, 1))
        self.extremum_array = np.tile(sentinel, (self.bands, 1))
        self.static = {}

        self.__init_tiles()

    def __init_tiles(self):
        """
        Read the dataset tile by tile into ``tiles_dict_source``.

        Tiles are keyed by ``(row index, column index)``; tiles on the right
        and bottom border are cut to the image size.  For optical images the
        band extrema are computed afterwards.

        :return:
        """
        edge = self.size
        for row in range(self.h_t):
            top = row * edge
            rows = min(edge, self.height - top)
            for col in range(self.w_t):
                left = col * edge
                cols = min(edge, self.width - left)
                self.tiles_dict_source[(row, col)] = self.dataset.ReadAsArray(left, top, cols, rows)
            System.signal.signal_progress.emit("正在切分瓦片:", (row + 1) * 100 / self.h_t)
            QApplication.processEvents()
        if System.m_image_type_optical == self.image_type:
            self.image_extremum()

    def image_extremum(self):
        """
        Update ``native_extremum_array`` with the per-band maximum and
        minimum over all source tiles.

        A 2-D tile is treated as the single band itself.  The original always
        indexed ``value[band]``, which for a 2-D tile selects only its first
        row, so single-band extrema were computed from one row per tile.
        """
        length = len(self.tiles_dict_source)
        for cur, tile in enumerate(self.tiles_dict_source.values(), start=1):
            for band in range(self.bands):
                layer = tile if tile.ndim == 2 else tile[band]
                max_old, min_old = self.native_extremum_array[band]
                # Each row stores [max, min]; keep the old value unless the
                # tile extends the range.
                self.native_extremum_array[band] = [max(max_old, np.max(layer)),
                                                    min(min_old, np.min(layer))]
            System.signal.signal_progress.emit("影像分析:", (cur / length) * 100)
            QApplication.processEvents()

    def start_image_static(self):
        """
        统计影像信息
        :return:
        """
        if len(self.static) == 0:
            _thread.start_new_thread(self.image_static, ())

    def image_static(self):
        """
        Count the pixel values of every band over all source tiles, sort each
        band's counts by pixel value and emit the result.

        A 2-D tile is treated as the single band itself; the original indexed
        ``value[band]``, which for a 2-D tile counts only its first row.
        Bands without any counted tile are sorted as empty dictionaries
        instead of raising ``AttributeError``.
        """
        length = len(self.tiles_dict_source)
        for cur, tile in enumerate(self.tiles_dict_source.values(), start=1):
            for band in range(self.bands):
                layer = tile if tile.ndim == 2 else tile[band]
                band_static = self.static.get(band)
                if band_static is None:
                    band_static = {}
                tiles_static = Counter(layer.flatten())
                self.static[band] = MathUtil.dict_add_or_plus(band_static, tiles_static)
            # Report progress after each tile.
            System.signal.signal_progress.emit("初始化影像直方图信息:", (cur / length) * 100)
            QApplication.processEvents()
        for band_index in range(self.bands):
            # Sort the counts by pixel value.
            band = self.static.get(band_index, {})
            self.static[band_index] = dict(sorted(band.items(), key=lambda item: item[0]))
        System.signal.signal_image_static.emit(self.static)

    def __get_tiles_index(self, p_w, p_h):
        """
        根据影像的实际坐标获取所在的瓦片索引
        :param p_w:
        :param p_h:
        :return:
        """
        w_t_i = math.floor(p_w / self.size)
        h_t_i = math.floor(p_h / self.size)
        return h_t_i, w_t_i

    def __set_extremum(self, native=True, max_pix=255, min_pix=0):
        if native:
            self.extremum_array = self.native_extremum_array
        else:
            for band in range(self.bands):
                self.extremum_array[band] = [max_pix, min_pix]

    def set_extremum(self, max_value, min_value):
        """
        设置影像的显示范围
        :param max_value:
        :param min_value:
        :return:
        """
        self.max_value = max_value
        self.min_value = min_value

    def get_band(self, index):
        band = self.dataset.GetRasterBand(index + 1)
        return band.ReadAsArray(0, 0, self.width, self.height)

    def get_pix(self, p_w, p_h):
        """
        获取像素值
        :param p_w: 宽度坐标
        :param p_h: 高度坐标
        :return:
        """
        tiles = self.get_tiles(p_w, p_h)
        if tiles is None:
            print("未找到瓦片信息")
            return np.zeros([self.bands])
        t_w = p_w % self.size
        t_h = p_h % self.size
        if self.bands == 1:
            return tiles[t_h, t_w]
        else:
            return tiles[:, t_h, t_w]

    def get_range(self, s_w, s_h, t_w, t_h):
        """
        获取选择的范围
        :param s_w: 起始宽度坐标
        :param s_h: 起始高度坐标
        :param t_w: 结束宽度坐标
        :param t_h: 结束高度坐标
        :return:
        """
        select_range = self.dataset.ReadAsArray(s_w, s_h, t_w - s_w, t_h - s_h)
        shape = select_range.shape
        if len(shape) == 2:
            return select_range
        else:
            return np.rollaxis(select_range, 0, 3)

    def get_radar_range(self, s_w, s_h, t_w, t_h):
        """
        雷达影像获取范围
        :param s_w:
        :param s_h:
        :param t_w:
        :param t_h:
        :return:
        """
        select_range = self.dataset.ReadAsArray(s_w, s_h, t_w - s_w, t_h - s_h)
        return self.radar_modulo(select_range)

    def get_tiles(self, p_w, p_h):
        """
        根据影像的实际坐标获取所在的瓦片
        :param p_w:
        :param p_h:
        :return:
        """
        index = self.__get_tiles_index(p_w, p_h)
        return self.tiles_dict_source.get(index)

    def init_show_tiles(self, native=True, max_pix=255, min_pix=0):
        """
        Prepare the display tiles from the source tiles.

        Optical tiles are stretched with ``stretch_extremum``; radar tiles go
        through ``radar_show`` with the fixed range 255..0.  Tiles of any
        other image type are not converted.

        :param native: forwarded to the extremum setup
        :param max_pix: forwarded to the extremum setup
        :param min_pix: forwarded to the extremum setup
        :return:
        """
        self.__set_extremum(native, max_pix, min_pix)
        total = len(self.tiles_dict_source)
        for done, (key, tile) in enumerate(self.tiles_dict_source.items(), start=1):
            if System.m_image_type_optical == self.image_type:
                self.tiles_dict_show[key] = self.stretch_extremum(tile)
            elif System.m_image_type_radar == self.image_type:
                self.tiles_dict_show[key] = self.radar_show(tile, 255, 0)
            # Report progress after each tile.
            System.signal.signal_progress.emit("初始化影像显示:", (done / total) * 100)
            QApplication.processEvents()

    def get_show_tiles(self, show_layer_index):
        """
        根据图层索引获取显示的图层
        :param show_layer_index: 图层索引数组,python原生的格式
        :return:
        """
        show_dict = {}
        for key, value in self.tiles_dict_show.items():
            if self.bands == 1 or self.bands < len(show_layer_index):
                show_tiles = value
            else:
                show_tiles = value[:, :, show_layer_index]
            show_dict[key] = show_tiles
        return show_dict

    def stretch_extremum(self, array):
        """
        影像拉伸-将影像值拉伸到0-255之间
        :param array:
        :return:
        """
        show_tiles = np.empty(array.shape, dtype='uint8')
        for band in range(self.bands):
            max_value = self.extremum_array[band][0]
            min_value = self.extremum_array[band][1]
            coefficient = (max_value - min_value) / 255
            if self.bands == 1:
                coefficient_array = array[:, :] / coefficient
            else:
                coefficient_array = array[band, :, :] / coefficient
            coefficient_array = coefficient_array - min_value / coefficient
            show_layer = coefficient_array.astype('uint8')
            if self.bands == 1:
                show_tiles = show_layer
            else:
                show_tiles[band] = show_layer
        if self.bands == 1:
            return show_tiles
        else:
            return np.rollaxis(show_tiles, 0, 3)

    def radar_show(self, array, max_pix, min_pix):
        """
        对雷达影像取模
        :param min_pix:
        :param max_pix:
        :param array:
        :return:
        """
        modulo_layer = self.radar_modulo(array)
        show_layer = np.where(modulo_layer > max_pix, max_pix, modulo_layer)
        show_layer = np.where(show_layer < min_pix, min_pix, show_layer)
        return show_layer.astype('uint8')

    def radar_modulo(self, array):
        """
        雷达影像取模
        :param array:
        :return:
        """
        band1 = array[0, :, :].astype('float')
        band2 = array[1, :, :].astype('float')
        modulo_layer = np.sqrt(np.square(band1) / 2 + np.square(band2) / 2)
        return modulo_layer
# -*- coding: utf-8 -*-
import math


def get_average(records):
    """
    Arithmetic mean of ``records``.
    """
    total = sum(records)
    count = len(records)
    return total / count


def get_variance(records):
    """
    Population variance of ``records``: the mean squared deviation from
    the average.
    """
    mean = get_average(records)
    squared_deviations = [(value - mean) ** 2 for value in records]
    return sum(squared_deviations) / len(records)


def get_standard_deviation(records):
    """
    Standard deviation of ``records``: the square root of the variance.
    """
    return math.sqrt(get_variance(records))


def get_rms(records):
    """
    Root mean square of ``records``.
    """
    mean_square = sum([value ** 2 for value in records]) / len(records)
    return math.sqrt(mean_square)


def get_mse(records_real, records_predict):
    """
    Mean squared error between real and predicted values.

    Returns ``None`` when the two sequences differ in length.
    """
    if len(records_real) != len(records_predict):
        return None
    squared_errors = [(real - predicted) ** 2
                      for real, predicted in zip(records_real, records_predict)]
    return sum(squared_errors) / len(records_real)


def get_rmse(records_real, records_predict):
    """
    Root mean squared error: the square root of ``get_mse``.

    Returns ``None`` only when ``get_mse`` returns ``None`` (sequences of
    different length).  The original tested ``if mse:``, so a perfect
    prediction (mse == 0) was reported as ``None`` instead of ``0.0``.
    """
    mse = get_mse(records_real, records_predict)
    if mse is None:
        return None
    return math.sqrt(mse)


def get_mae(records_real, records_predict):
    """
    Mean absolute error between real and predicted values.

    Returns ``None`` when the two sequences differ in length.
    """
    if len(records_real) != len(records_predict):
        return None
    absolute_errors = [abs(real - predicted)
                       for real, predicted in zip(records_real, records_predict)]
    return sum(absolute_errors) / len(records_real)


def dict_add_or_plus(dict1, dict2):
    """
    Merge two dictionaries into a new one, adding the values of shared keys.

    Keys present in only one dictionary keep their value.  The original
    decided by truthiness, so a falsy value such as ``0`` that existed in
    only one dictionary came out as ``None``; membership is tested now.

    :param dict1: first mapping
    :param dict2: second mapping
    :return: new dict; neither input is modified
    """
    dict_new = dict(dict1)
    for key, value in dict2.items():
        dict_new[key] = dict_new[key] + value if key in dict_new else value
    return dict_new

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/918929.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-05-16
下一篇 2022-05-16

发表评论

登录后才能评论

评论列表(0条)

保存