Python实现最大堆(大顶堆)

发布时间:2019-08-26 07:17:13 编辑:auto 阅读(1985)

    最大堆是指最大的元素在堆顶的堆。

    Python自带的heapq模块实现的是最小堆,没有提供最大堆的实现。虽然有些文章通过把元素取反再放入堆,出堆时再取反,把问题转换为最小堆问题也能间接实现最大堆,但是这样的实现只适合数值型的元素,不适合自定义类型。

    下面给出实现代码:

    # -*- coding: UTF-8 -*- 
                    
    import random
    
    
    class MaxHeap(object):
        """Binary max-heap kept in a list; the largest item sits at index 0.

        Items are compared with ``<``, ``>`` and ``>=`` only, so any mutually
        comparable values (numbers, tuples, custom classes) can be stored.
        """

        def __init__(self):
            # Backing array in heap order: the children of index i are at
            # 2*i + 1 and 2*i + 2.
            self._data = []
            # Number of stored items; kept equal to len(self._data).
            self._count = 0

        def size(self):
            """Return the number of items currently in the heap."""
            return self._count

        def isEmpty(self):
            """Return True when the heap holds no items."""
            return self._count == 0

        def add(self, item):
            """Insert ``item`` into the heap."""
            self._data.append(item)
            self._count += 1
            self._shiftup(self._count - 1)

        def pop(self):
            """Remove and return the largest item, or None if the heap is empty."""
            if self._count == 0:
                return None
            top = self._data[0]
            # Physically drop the last slot so the list length always matches
            # _count. Leaving it behind made a later add() append past the live
            # region and sift the wrong element, losing the new item.
            last = self._data.pop()
            self._count -= 1
            if self._count > 0:
                # Move the former last item to the root and sink it into place.
                self._data[0] = last
                self._shiftDown(0)
            return top

        def _shiftup(self, index):
            """Move self._data[index] up until it is not greater than its parent."""
            parent = (index - 1) >> 1
            while index > 0 and self._data[parent] < self._data[index]:
                self._data[parent], self._data[index] = self._data[index], self._data[parent]
                index = parent
                parent = (index - 1) >> 1

        def _shiftDown(self, index):
            """Move self._data[index] down until it is not smaller than its children."""
            j = (index << 1) + 1
            while j < self._count:
                # At least a left child exists; pick the larger of the children.
                if j + 1 < self._count and self._data[j + 1] > self._data[j]:
                    j += 1
                if self._data[index] >= self._data[j]:
                    # Already at least as large as both children: done.
                    break
                self._data[index], self._data[j] = self._data[j], self._data[index]
                index = j
                j = (index << 1) + 1
    
    # 元素是数值类型                
    def testIntValue():
        """Check MaxHeap against sorted() on ten random batches of distinct ints."""
        for _ in range(10):
            iLen = random.randint(1, 300)
            samples = random.sample(range(iLen * 100), iLen)
            print('\nlen =', iLen)

            heap = MaxHeap()
            print('_data:\t   ', samples)
            descending = sorted(samples, reverse=True)
            print('dataSorted:', descending)
            for value in samples:
                heap.add(value)

            # Popping everything must reproduce the descending order exactly.
            popped = []
            for expected in descending:
                actual = heap.pop()
                popped.append(actual)
                print('{0}, expected: {1}, actual: {2}'.format(expected == actual, expected, actual))
                assert expected == actual, ""
            print('dataSorted:', descending)
            print('heapData:  ', popped)
            
    # 元素是元组类型
    def testTupleValue():
        """Check MaxHeap with (int, str) tuples; the int decides the order."""
        for _ in range(10):
            iLen = random.randint(1, 300)
            keys = random.sample(range(iLen * 100), iLen)
            # Map each distinct int key to its string form.
            allData = {key: str(key) for key in keys}
            print('\nlen =', iLen)
            print('allData: ', allData)

            heap = MaxHeap()
            descending = sorted(allData.items(), key=lambda pair: pair[0], reverse=True)
            print('dataSorted:', descending)
            for key, text in allData.items():
                # Tuples compare by their first element first.
                heap.add((key, text))

            popped = []
            for expected in descending:
                actual = heap.pop()
                popped.append(actual)
                print('{0}, expected: {1}, actual: {2}'.format(expected == actual, expected, actual))
                assert expected == actual, ""
            print('dataSorted:', descending)
            print('heapData:  ', popped)
            
    # 元素是自定义类    
    def testClassValue():
        """Check MaxHeap with instances of a custom class ordered by a value field."""

        class Model4Test(object):
            """Heap item ordered by ``value``; ``sUid`` is only a label.

            MaxHeap compares items with ``<``, ``>`` and ``>=``, so all four
            ordering methods are defined explicitly. ``==`` is left as the
            default identity comparison. No ``__cmp__`` is defined: Python 3
            ignores it, and the previous version called a non-existent
            ``super.cmp``.
            """

            def __init__(self, sUid, value):
                self._sUid = sUid
                self._value = value

            def getUid(self):
                return self._sUid

            def getValue(self):
                return self._value

            def __lt__(self, other):  # operator <
                return self.getValue() < other.getValue()

            def __le__(self, other):  # operator <=
                return self.getValue() <= other.getValue()

            def __gt__(self, other):  # operator >
                return self.getValue() > other.getValue()

            def __ge__(self, other):  # operator >=
                return self.getValue() >= other.getValue()

            def __str__(self):
                return '({0}, {1})'.format(self._value, self._sUid)

        for _ in range(10):
            iLen = random.randint(1, 300)
            listData = random.sample(range(iLen * 100), iLen)
            allData = [Model4Test(str(value), value) for value in listData]
            print('allData:   ', [str(e) for e in allData])
            print('\nlen =', iLen)

            oMaxHeap = MaxHeap()
            arrDataSorted = sorted(allData, reverse=True)
            print('dataSorted:', [str(e) for e in arrDataSorted])
            for item in allData:
                oMaxHeap.add(item)

            # Values are distinct, so identity equality is enough to verify
            # that each pop returns the expected object.
            heapData = []
            for iExpected in arrDataSorted:
                iActual = oMaxHeap.pop()
                heapData.append(iActual)
                print('{0}, expected: {1}, actual: {2}'.format(iExpected == iActual, iExpected, iActual))
                assert iExpected == iActual, ""
            print('dataSorted:', [str(e) for e in arrDataSorted])
            print('heapData:  ', [str(e) for e in heapData])
                            
    if __name__ == '__main__':
        # Run every self-check in order: ints, tuples, then custom objects.
        for run_case in (testIntValue, testTupleValue, testClassValue):
            run_case()

关键字