python+opencv实现目标跟踪

发布时间:2019-09-06 08:51:59 编辑:auto 阅读(3614)

    OpenCV 3.0 起(在 contrib 模块中)提供了一些比较实用的目标追踪算法,这里参照官方示例,用 Python 封装了一个追踪器类。

    程序只能在安装了 OpenCV 3.0 以上版本及对应 contrib 模块的 Python 解释器中运行(例如通过 pip 安装 opencv-contrib-python)。

     

    #encoding=utf-8
    
    import cv2
    from items import MessageItem
    import time
    import numpy as np
    '''
    监视者模块,负责入侵检测,目标跟踪
    '''
    class WatchDog(object):
        """Intrusion detector: finds motion by differencing frames against a stored background.

        The background is a blurred grayscale copy of a reference frame. The detector
        counts as "working" only while such a background is set.
        """

        # Gaussian kernel used to smooth frames before differencing.
        BLUR_KSIZE = (21, 21)
        # Per-pixel absolute difference above which a pixel is treated as changed.
        DIFF_THRESHOLD = 25
        # Contours smaller than this area are ignored as noise.
        MIN_CONTOUR_AREA = 1500

        def __init__(self, frame=None):
            """Create the detector, optionally using *frame* (BGR) as the initial background."""
            self._background = None
            if frame is not None:
                self._background = self._preprocess(frame)
            # Elliptical kernel used to dilate the thresholded difference mask.
            self.es = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))

        def _preprocess(self, frame):
            """Convert a BGR frame to blurred grayscale, the form used for background comparison."""
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            return cv2.GaussianBlur(gray, self.BLUR_KSIZE, 0)

        def isWorking(self):
            """Return True while a background frame is set."""
            return self._background is not None

        def startWorking(self, frame):
            """Take *frame* as the new background; a None frame is ignored."""
            if frame is not None:
                self._background = self._preprocess(frame)

        def stopWorking(self):
            """Drop the background so the detector stops working."""
            self._background = None

        def analyze(self, frame):
            """Detect motion in *frame* relative to the background.

            Returns None when *frame* is None or no background is set. Otherwise draws
            the bounding box of the largest changed region (by box area) onto *frame*.
            It then returns MessageItem(frame, {"coord": [box], "msg": None}), where box is
            ((x1, y1), (x2, y2)), or None if no region passed the area filter.
            """
            if frame is None or self._background is None:
                return None
            sample_frame = self._preprocess(frame)
            diff = cv2.absdiff(self._background, sample_frame)
            diff = cv2.threshold(diff, self.DIFF_THRESHOLD, 255, cv2.THRESH_BINARY)[1]
            diff = cv2.dilate(diff, self.es, iterations=2)
            # OpenCV 3.x returns (image, contours, hierarchy); 4.x returns (contours, hierarchy).
            # Index -2 selects the contours in both.
            cnts = cv2.findContours(diff.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
            bigC = None
            bigMulti = 0
            for c in cnts:
                if cv2.contourArea(c) < self.MIN_CONTOUR_AREA:
                    continue
                (x, y, w, h) = cv2.boundingRect(c)
                if w * h > bigMulti:
                    bigMulti = w * h
                    bigC = ((x, y), (x + w, y + h))
            if bigC:
                cv2.rectangle(frame, bigC[0], bigC[1], (255, 0, 0), 2, 1)
            message = {"coord": [bigC], "msg": None}
            return MessageItem(frame, message)
    
    class Tracker(object):
        '''
        Tracker module: follows one user-selected target across frames using an
        OpenCV single-object tracker.
        '''

        # OpenCV factory function name for each supported tracker type (used on OpenCV >= 3.3).
        _FACTORY_NAMES = {
            'BOOSTING': 'TrackerBoosting_create',
            'MIL': 'TrackerMIL_create',
            'KCF': 'TrackerKCF_create',
            'TLD': 'TrackerTLD_create',
            'MEDIANFLOW': 'TrackerMedianFlow_create',
            'GOTURN': 'TrackerGOTURN_create',
        }

        def __init__(self, tracker_type="BOOSTING", draw_coord=True):
            '''
            Create the underlying OpenCV tracker.

            tracker_type: one of 'BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'GOTURN'.
            draw_coord: if True, track() draws the tracked box onto each frame.

            If the type is unknown or not available in the installed OpenCV build
            (on OpenCV >= 3.3), self.tracker is None and initWorking() raises.
            '''
            self.tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'GOTURN']
            self.tracker_type = tracker_type
            self.isWorking = False
            self.draw_coord = draw_coord
            # Last box as (x, y, w, h); set by initWorking() and updated by track().
            self.coord = None
            self.tracker = self._create_tracker(tracker_type)

        @staticmethod
        def _parse_version(version):
            '''Return (major, minor) as ints from an OpenCV version string such as "4.5.1-dev".'''
            import re
            match = re.match(r"(\d+)\.(\d+)", version)
            if match is None:
                raise ValueError("unrecognised OpenCV version string: %r" % (version,))
            return int(match.group(1)), int(match.group(2))

        @classmethod
        def _create_tracker(cls, tracker_type):
            '''Instantiate the OpenCV tracker for *tracker_type*, or return None if unavailable.'''
            # Compare (major, minor): checking only the minor number misroutes e.g. 4.2 to the pre-3.3 API.
            if cls._parse_version(cv2.__version__) < (3, 3):
                # Before 3.3 a single factory took the tracker name as a string.
                return cv2.Tracker_create(tracker_type)
            name = cls._FACTORY_NAMES.get(tracker_type)
            if name is None:
                return None
            # Newer OpenCV 4.x builds expose some trackers (e.g. BOOSTING, TLD, MEDIANFLOW)
            # only under cv2.legacy, so fall back to that namespace when present.
            for namespace in (cv2, getattr(cv2, "legacy", None)):
                factory = getattr(namespace, name, None)
                if factory is not None:
                    return factory()
            return None

        def initWorking(self, frame, box):
            '''
            Initialise tracking.

            frame: the frame on which the target is selected
            box: the target region as (x, y, w, h)

            Raises Exception if no tracker could be created or OpenCV reports init failure.
            '''
            if getattr(self, "tracker", None) is None:
                raise Exception("追踪器未初始化")
            status = self.tracker.init(frame, box)
            # Legacy trackers return a bool; the newer (4.5.1+) API returns None on success,
            # so only an explicit falsy result counts as failure.
            if status is not None and not status:
                raise Exception("追踪器工作初始化失败")
            self.coord = box
            self.isWorking = True

        def track(self, frame):
            '''
            Advance the tracker by one frame.

            Returns MessageItem(frame, message). message is None if tracking has not
            started or the tracker lost the target. Otherwise it is
            {"coord": [((x1, y1), (x2, y2))], "msg": "is tracking"}. The box is drawn onto
            *frame* when draw_coord is set.
            '''
            message = None
            if self.isWorking:
                status, self.coord = self.tracker.update(frame)
                if status:
                    x, y, w, h = self.coord
                    p1 = (int(x), int(y))
                    p2 = (int(x + w), int(y + h))
                    message = {"coord": [(p1, p2)]}
                    if self.draw_coord:
                        cv2.rectangle(frame, p1, p2, (255, 0, 0), 2, 1)
                    # Report tracking status regardless of whether the box is drawn.
                    message['msg'] = "is tracking"
            return MessageItem(frame, message)
    
    class ObjectTracker(object):
        '''
        Cascade-classifier detector: draws a rectangle around every detection in a frame.
        '''
        def __init__(self, dataSet, scale_factor=1.03, min_neighbors=5):
            '''
            dataSet: path to the cascade file passed to cv2.CascadeClassifier.
            scale_factor, min_neighbors: forwarded to detectMultiScale; the defaults
                match the values previously hard-coded in track().

            Raises ValueError if the cascade could not be loaded.
            '''
            self.cascade = cv2.CascadeClassifier(dataSet)
            # CascadeClassifier does not raise on a bad path; it stays empty and only
            # fails later inside detectMultiScale, so check up front.
            if self.cascade.empty():
                raise ValueError("failed to load cascade classifier: %r" % (dataSet,))
            self.scale_factor = scale_factor
            self.min_neighbors = min_neighbors

        def track(self, frame):
            '''Detect objects in a BGR *frame*, draw their boxes onto it, and return the frame.'''
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = self.cascade.detectMultiScale(gray, self.scale_factor, self.min_neighbors)
            for (x, y, w, h) in faces:
                cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
            return frame
    
    if __name__ == '__main__' :
        # Demo: select a region on the first webcam frame, then follow it with a KCF tracker.
        # Press Esc to quit.
        tracker = Tracker(tracker_type="KCF")
        video = cv2.VideoCapture(0)
        try:
            ok, frame = video.read()
            if not ok:
                raise SystemExit("无法读取摄像头画面")
            bbox = cv2.selectROI(frame, False)
            # selectROI returns a zero-size box when the selection is cancelled.
            if bbox[2] == 0 or bbox[3] == 0:
                raise SystemExit("未选择追踪区域")
            tracker.initWorking(frame, bbox)
            while True:
                ok, frame = video.read()
                if not ok:
                    # Stop rather than spin forever once the camera stops delivering frames.
                    break
                item = tracker.track(frame)
                cv2.imshow("track", item.getFrame())
                k = cv2.waitKey(1) & 0xff
                if k == 27:  # Esc
                    break
        finally:
            # Always free the camera and close the preview windows, even on error.
            video.release()
            cv2.destroyAllWindows()

     

    #encoding=utf-8
    import json
    from utils import IOUtil
    '''
    信息封装类
    '''
    class MessageItem(object):
        """Pair an image frame with an accompanying message payload."""

        def __init__(self, frame, message):
            # frame is expected in BGR channel order (the encoders below flip the last axis to RGB).
            self._frame, self._message = frame, message

        def getFrame(self):
            """Return the stored frame unchanged."""
            return self._frame

        def getMessage(self):
            """Return the message payload exactly as it was passed in (not serialised)."""
            return self._message

        def getBinaryFrame(self):
            """Encode the frame to image bytes after reversing its last axis (BGR -> RGB)."""
            rgb_view = self._frame[..., ::-1]
            return IOUtil.array_to_bytes(rgb_view)

        def getBase64Frame(self):
            """Return the encoded frame as base64 bytes."""
            encoded_image = self.getBinaryFrame()
            return IOUtil.bytes_to_base64(encoded_image)

        def getBase64FrameByte(self):
            """Return the base64-encoded frame as a bytes object."""
            b64 = self.getBase64Frame()
            return bytes(b64)

        def getJson(self):
            """Serialise {"frame": <base64 text>, "message": <payload>} to a JSON string."""
            frame_text = self.getBase64Frame().decode()
            payload = {}
            payload["frame"] = frame_text
            payload["message"] = self.getMessage()
            return json.dumps(payload)
    

     

    运行后,在弹出的第一帧画面上用鼠标框选要追踪的区域,追踪过程中按 Esc 键退出;这里测试的是使用 KCF 算法的追踪器。上面的第二段代码即主程序中 `from items import MessageItem` 所引用的 items.py 模块。

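    更新:之前忘记贴出 utils 模块(即下面这段代码),给大家造成困扰,深表歉意。注意:utils 模块通过 `from settings import *` 读取 LOG_NAME、LOG_FORMATTER、LOG_DIR、LOG_FILE 等日志配置常量,settings.py 文中未给出,需要自行创建。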

    #encoding=utf-8
    import time
    import numpy
    import base64
    import os
    import logging
    import sys
    from settings import *
    from PIL import Image
    from io import BytesIO
    
    #工具类
    class IOUtil(object):
        '''Static helpers for image encoding, packet framing, directories and box geometry.'''

        @staticmethod
        def array_to_bytes(pic, formatter="jpeg", quality=70):
            '''
            Encode a numpy image array into compressed image bytes via PIL.

            :param pic: numpy array accepted by PIL.Image.fromarray
            :param formatter: image format name passed to PIL (default "jpeg")
            :param quality: encoder quality; lower values give smaller output
            :return: the encoded bytes
            '''
            # The context manager closes the stream even if encoding raises.
            with BytesIO() as stream:
                Image.fromarray(pic).save(stream, format=formatter, quality=quality)
                return stream.getvalue()

        @staticmethod
        def bytes_to_base64(byte):
            '''
            Base64-encode *byte*.

            :param byte: raw bytes
            :return: base64-encoded bytes
            '''
            return base64.b64encode(byte)

        @staticmethod
        def transport_rgb(frame):
            '''
            Swap BGR <-> RGB by reversing the last axis (returns a view, no copy).
            '''
            return frame[..., ::-1]

        @staticmethod
        def byte_to_package(bytes, cmd, var=1):
            '''
            Prefix one frame's binary payload with a 12-byte header for transport.

            The header is three big-endian unsigned 32-bit ints: (var, payload length, cmd).
            :param bytes: payload bytes (name kept for caller compatibility; shadows the builtin here)
            :param cmd: command code
            :param var: first header field, presumably a protocol version (default 1)
            :return: header followed by the payload
            '''
            import struct  # struct is not imported at module level in this file
            head_pack = struct.pack("!3I", var, len(bytes), cmd)
            return head_pack + bytes

        @staticmethod
        def mkdir(filePath):
            '''
            Create directory *filePath*, including missing parents, if nothing exists there yet.
            '''
            if not os.path.exists(filePath):
                # exist_ok tolerates another process creating it between the check and this call.
                os.makedirs(filePath, exist_ok=True)

        @staticmethod
        def countCenter(box):
            '''
            Return the integer centre (x, y) of a box given as two corner points.

            The corners may be in any order; the half-extent is added to the smaller coordinate.
            '''
            (x1, y1), (x2, y2) = box[0], box[1]
            cx = min(x1, x2) + int(abs(x1 - x2) * 0.5)
            cy = min(y1, y2) + int(abs(y1 - y2) * 0.5)
            return (cx, cy)

        @staticmethod
        def countBox(center):
            '''
            Convert two corner points ((x1, y1), (x2, y2)) into (x, y, width, height).
            '''
            return (center[0][0], center[0][1], center[1][0] - center[0][0], center[1][1] - center[0][1])

        @staticmethod
        def getImageFileName():
            '''Return a local-time timestamped file name such as "2019_09_06_08_51_59.png".'''
            return time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + '.png'
    
    # Build the module-wide logger, writing to both a log file and stdout.
    # LOG_NAME, LOG_FORMATTER, LOG_DIR and LOG_FILE are presumably supplied by `from settings import *`.
    logger = logging.getLogger(LOG_NAME)
    formatter = logging.Formatter(LOG_FORMATTER)
    IOUtil.mkdir(LOG_DIR)
    # os.path.join adds a separator only when LOG_DIR lacks one, so "logs" and "logs/" both work
    # (plain concatenation turned "logs" + "app.log" into "logsapp.log").
    file_handler = logging.FileHandler(os.path.join(LOG_DIR, LOG_FILE), encoding='utf-8')
    file_handler.setFormatter(formatter)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)

     

关键字