线段树

本文会探讨什么是线段树,线段树代码模块,以及相应的练习题。

什么是线段树

线段树是用利用完全二叉树的形式存储区间信息,并提供区间更改、查询的一种数据结构,例如下图中用来存储列表 [10,11,12,13,14][10, 11, 12, 13, 14] 区间和的线段树结构:


从上图中,叶子节点对应列表中的每个元素,而第 ii 个节点对应着列表中 [lefti,righti][\text{left}_i, \text{right}_i]这样的一个封闭区间的和。

线段树一般用来存储区间的和、最大值、最小值等信息。

本质上,线段树是利用二分法,构建了线段树索引和原始数组区间的一对一映射关系。

当我们对任意一个区间进行查询时,我们可以从根节点出发,访问我们要存储要查询区间的节点。例如,在下图中,我们访问 [5,5][5,5] 区间的区间和:


对于要查询的区间 [5,5][5, 5],从根节点出发,由于根节点中位数是 77,我们需要向左节点走…

线段树的组件

现在我们讨论线段树实现所需要的组件。在进行讨论之前,我们先给出一份修改数组并查询区间和的代码,接着再来解释每个组件的功能。

class SegmentationTree:

    def __init__(self, nums):
        self.values = defaultdict(int)
        self.__bulid(1, 0, len(nums)-1, nums)

    # 初始化    
    def __bulid(self, idx, l, r, nums):
        if l == r:
            self.values[idx] = nums[l]
            return 
        mid =  (r - l) // 2 + l
        self.__bulid(idx<<1, l, mid, nums)
        self.__bulid(idx<<1|1, mid+1, r, nums)
        self.pushup(idx)
    
    # 上推
    def pushup(self, idx):
        self.values[idx] = self.values[idx<<1] +  self.values[idx<<1|1]
    
    # 修改
    def update(self, idx, l, r, pos, val):
        if l > pos or r < pos:
            return
        if l == r:
            self.values[idx] = val
            # self.lazy = val
            return 
        mid = (r - l) // 2 + l
        # self.pushdown(idx)
        if pos <= mid:
            self.update(idx<<1, l, mid, pos, val)
        if pos > mid:
            self.update(idx<<1|1, mid+1, r, pos, val)
            
        self.pushup(idx)

    # 查询
    def query(self, idx, l, r, start, end):
        if l > end or r < start:
            return 0
        if start <= l and r <= end:
            return self.values[idx]
        mid = (r - l) // 2 + l
        ans = 0
        if start <= mid:
            ans += self.query(idx<<1, l, mid, start, end)
        if mid < end:
            ans += self.query(idx<<1|1 , mid+1, r, start, end)
        return ans


class NumArray:

    def __init__(self, nums: List[int]):
        self.tree = SegmentationTree(nums)
        self.n = len(nums)

    def update(self, index: int, val: int) -> None:
        self.tree.update(1, 0, self.n-1, index, val)

    def sumRange(self, left: int, right: int) -> int:
        return self.tree.query(1, 0, self.n-1, left, right)

初始化

在进行修改之前,我们要先在每个叶子节点上存储原有的数组值。

当数组的索引对应的区间长度为 11(即不可再分)时,我们就找到了叶子节点。由于区间和原始数组的索引含义相同,因此我们可以为每个叶子节点赋值。

这里我们采用的是递归的方法,类似树的遍历,再找到叶子节点后,为其前辈节点赋值。

def __bulid(self, idx, l, r, nums):
    if l == r:
        self.values[idx] = nums[l]
        return 
    mid =  (r - l) // 2 + l
    self.__bulid(idx<<1, l, mid, nums)
    self.__bulid(idx<<1|1, mid+1, r, nums)
    self.pushup(idx)

上推

我们在子节点修改的值,需要将修改信息逆向返回根节点,来维持涉及该后代的节点保持同步更新。

def pushup(self, idx):
    self.values[idx] = self.values[idx<<1] +  self.values[idx<<1|1]

这里的父节点和子节点的关系相对简单,父节点为子节点的加和。

更新

在该题中,因为我们只对叶子节点进行更新,因此只有当我们找到叶子节点时,进行修改。然后通过二分法判断我们该进入那个分治,最后的pushup语句则保证路径上节点同步修改信息。

def update(self, idx, l, r, pos, val):
    if l > pos or r < pos:
        return
    if l == r:
        self.values[idx] = val
        # self.lazy = val
        return 
    mid = (r - l) // 2 + l
    # self.pushdown(idx)
    if pos <= mid:
        self.update(idx<<1, l, mid, pos, val)
    if pos > mid:
        self.update(idx<<1|1, mid+1, r, pos, val)
        
    self.pushup(idx)

当然,实际中我们不太可能仅仅对单点进行更新,对于区间更新,参考以下模板。

def update(self, idx, l, r, start, end, val):
    if l > end or r < start:
        return 
    if start <= l and r <= end:
        if self.heights[idx] < val:
            self.heights[idx] = val
        if self.lazy[idx] < val:
            self.lazy[idx] = val
        return 
    mid = (r - l) // 2 + l
    self.pushdown(idx)
    if start <= mid:
        self.update(idx<<1, l, mid, start ,end, val)
    if end > mid:
        self.update(idx<<1|1, mid+1, r, start, end, val)
    self.pushup(idx)

这里有一个pushdown函数,我们会在下文中进行解释。

查询

查询和修改类似,只不过我们要把符合要求的结构返回,这里也是用到了一个递归的结构。

def query(self, idx, l, r, start, end):
    if l > end or r < start:
        return 0
    if start <= l and r <= end:
        return self.values[idx]
    mid = (r - l) // 2 + l
    ans = 0
    if start <= mid:
        ans += self.query(idx<<1, l, mid, start, end)
    if mid < end:
        ans += self.query(idx<<1|1 , mid+1, r, start, end)
    return ans

下推

这一节中我们会讨论一下,为什么会有lazy标签和pushdown函数。

例如,我们在下图中,对 [3,4][3, 4] 区间的每个元素加上 55。根据前面的区间更新代码,我们可以得到在线段树第 33 个节点,也就是 d[3]d[3] 我们就可以找到了区间,并给这个区间的值加上 1010(个数 ×\times 增量)。

注意此时,我们的子节点信息 d[6]d[6]d[7]d[7] 是没有更改的,如果不做处理,那么我们对 d[6]d[6] 进行查询,就会返回错误答案。

我们当然可以将区间的信息 下传 到每一个子节点,那这样在糟糕的情况下得到的时间复杂度是 O(n)O(n)

另一种方法是,我们使用一个 tag,叫做懒惰标签,来存储这个区间的信息。只有当我们要对这个区间的子节点进行查询时,我们再把这个累计的区间修改信息下传给子节点。


def pushdown(self, idx):
    if self.lazy[idx]:
        self.lazy[idx << 1] = self.lazy[idx]
        self.lazy[idx << 1|1] = self.lazy[idx]
        self.tree[idx << 1] = self.lazy[idx]
        self.tree[idx << 1|1] = self.lazy[idx]
        self.lazy[idx] = 0
    return

小结

这一节中介绍的除了查询、上推和修改之外,不是每个组件都是必须的。

对于主要的 修改/查询 功能,我们通过递归和二分的方法,找到修改区间/节点,来修改信息,结尾的 pushup 方法是保证我们的修改能够更新到路径中的父节点。具体的 pushup 采用什么样的方式,以来的是题目中父节点和子节点的关系。

笼统的讲,

  • 按照修改的范围,我们可以分为区间修改和单点修改两种;
  • 按照修改的方式,我们可以分为覆盖式修改和增量修改。(通常,覆盖修改不需要懒标签,因为父区间的值就已经包含了修改信息)

以上均是幻想时间👀,多做题多练习。

例题

思路和参考 首先轮廓只可能是在纵坐标的变化点出现,这里有两个含义:
  1. 纵坐标的变化对应着轮廓在 x 这一点 max y 的改变;
  2. x 一定是建筑的左右断点。

这样我们就用扫描线的方法,从左向右查询这一点上的最大值和前一个最大值进行比较,如果有变化就存储。

于是现在的问题变成如何查询某一点的最大值,对应着线段树中的单点查询。整体的流程是:

  1. 对建筑构建线段树,存储区间最大值,注意的pushdown函数,以及右端点范围;
  2. 从左向右扫描线。
from collections import defaultdict


class SegmentationTree:
    def __init__(self) -> None:
        self.heights = defaultdict(int)
        self.lazy = defaultdict(int)
    
    def update(self, idx, l, r, start, end, val):
        if l > end or r < start:
            return 
        if start <= l and r <= end:
            # print(self.heights[idx], val)
            if self.heights[idx] < val:
                self.heights[idx] = val
            if self.lazy[idx] < val:
                self.lazy[idx] = val
            return 
        mid = (r - l) // 2 + l
        self.pushdown(idx)
        if start <= mid:
            self.update(idx<<1, l, mid, start ,end, val)
        if end > mid:
            self.update(idx<<1|1, mid+1, r, start, end, val)
        self.pushup(idx)

    def pushdown(self, idx):
        if self.lazy[idx]:
            if self.heights[idx<<1] < self.lazy[idx]:
                self.heights[idx<<1] = self.lazy[idx]
            if self.lazy[idx<<1] < self.lazy[idx]:
                self.lazy[idx<<1] = self.lazy[idx]
            if self.heights[idx<<1|1] < self.lazy[idx]:
                self.heights[idx<<1|1] = self.lazy[idx]
            if self.lazy[idx<<1|1] < self.lazy[idx]:
                self.lazy[idx<<1|1] = self.lazy[idx]
            self.lazy[idx] = 0
    
    def pushup(self, idx):
        self.heights[idx] = max(self.heights[idx<<1], self.heights[idx<<1|1])

    def query(self, idx, l, r, start, end):
        if l > end or r < start:
            return -1
        if start <= l and r <= end:
            return self.heights[idx]
        mid = (r-l) //2 + l
        self.pushdown(idx)
        ans = -1
        if start <= mid:
            ans = max(ans, self.query(idx<<1, l, mid, start, end))
        if mid < end:
            ans = max(ans, self.query(idx<<1|1, mid+1, r, start ,end))
        return ans
    
class Solution:
    def getSkyline(self, buildings: List[List[int]]) -> List[List[int]]:
        # 本质是记录x区间的最大值
        # 从左向右sweep,如果最大值发生了变化则记录下来

        x_list = sum([[x1, x2] for x1, x2, _ in buildings], [])
        x_list = sorted(list(set(x_list)))
        x_rank_map = {val:i for i, val in enumerate(x_list)}

        # 记录区间
        n = len(x_list)
        tree = SegmentationTree()

        for x1, x2, h in buildings:
            rank1, rank2 = x_rank_map[x1], x_rank_map[x2]
            # print(rank1, rank2, h)
            tree.update(1, 0, n-1, rank1, rank2-1, h)

        pre_max_h = -1
        ans = []
        for x in x_list:
            rank = x_rank_map[x]
            cur_max_h = tree.query(1, 0, n-1, rank, rank)
            if cur_max_h != pre_max_h:
                ans.append([x, cur_max_h])
            pre_max_h = cur_max_h
        return ans

pushdown函数对应的场景是,例如,我们左边的建筑很大,现在在右边空缺的范围添加一个低一点的建筑,由于两个区间的有重合,直接覆盖式子节点,会得到错误答案。

思路和参考

和天际线类似,同样的扫描线,不过有所区别的是,这里我们要求解的是区域,对应的是 x 覆盖的 y 范围。

这里矩形的左边对应增加,右边对应退出覆盖范围。所以我们要有一个 cnt 数组来记录当前 y 是否被覆盖。

整体的思路是,我们将 x从小到大去重排序,左边对应 cnt++\text{cnt} ++,右边对应 cnt\text{cnt} --。利用扫描线,对于下标 ii 找到 和其对应横坐标不同的下标 jj(即为横坐标长度),然后修改对应矩形在 y 轴区间的 cnt 值。 最后访问第一个节点得到整个 y 轴的覆盖区间的总长度,即得到面积。

from collections import defaultdict


class Node:
    def __init__(self, cnt=0, cover_length=0, length=0):
        self.cnt = cnt
        self.cover_length = cover_length
        self.length = length
    
    def __repr__(self) -> str:
        return "the interval is counted {}, times, and covered length is {}, max length is {}".format(self.cnt, self.cover_length, self.length)

class SegmentationTree:

    def __init__(self, nums) -> None:
        """
        叶子节点是每个y轴值(去除掉开始值)
        """
        self.tree = defaultdict(lambda : Node())
        self.__build(1, 1, len(nums)-1, nums)

    def __build(self, idx, l, r, nums):
        node = self.tree[idx]
        if l == r:
            node.length = nums[l] - nums[l-1]
            return
        mid = (l + r) >> 1
        self.__build(idx << 1, l, mid, nums)
        self.__build(idx << 1 | 1, mid+1, r, nums)
        node.length = self.tree[idx<<1].length + self.tree[idx<<1|1].length
    
    def update(self, idx, l, r, start, end, val):
        if r < start or end < l:
            return 
        node = self.tree[idx]
        if start <= l and r <= end:
            node.cnt += val
        else:
            mid = (l + r) >> 1
            if start <= mid:
                self.update(idx << 1, l, mid, start, end, val)
            if mid < end:
                self.update(idx << 1 | 1, mid+1, r, start, end, val)
        self.pushup(idx, l, r)
    
    def pushup(self, idx, l, r):
        node = self.tree[idx]
        if node.cnt > 0:
            node.cover_length = node.length
        elif l == r:
            node.cover_length = 0
        else:
            node.cover_length = self.tree[idx<<1].cover_length + self.tree[idx<<1|1].cover_length

        
class Solution:
    def rectangleArea(self, rectangles: List[List[int]]) -> int:
        y_set, sweep = set(), []

        for i, rec in enumerate(rectangles):
            x1, y1, x2, y2 = rec
            y_set.update((y1, y2))
            sweep.extend([[x1, i, 1], [x2, i, -1]])
        
        y_list = sorted(list(y_set))
        y_map = {val:i for i, val in enumerate(y_list)}
        sweep.sort()

        idx = ans = 0
        s_num = len(sweep)
        tree = SegmentationTree(y_list)
        n = len(y_list)


        while idx < s_num:
            j = idx
            while j + 1 < s_num and sweep[idx][0] == sweep[j+1][0]:
                j += 1
            
            if j+1 == s_num:
                break
            
            for i in range(idx, j + 1):
                _, rec_idx, diff = sweep[i]
                _, y1, _, y2 = rectangles[rec_idx]

                y1, y2 = y_map[y1], y_map[y2]
                tree.update(1, 1, n-1, y1+1, y2, diff)

            ans += tree.tree[1].cover_length * (sweep[j+1][0] - sweep[idx][0])

            idx = j + 1
        
        return ans % (10**9 + 7)

这里有很多tricks。首先就是 pushup 函数,如果 cnt 大于0我们返回的是区间总长度,如果叶子节点则返回 00,否则返回两个子节点的加和。对应的所有区间,即使叶子节点也要pushup

另外这里没有lazy标签,这里是用到了扫描线的性质,copy一段解析

考虑删除某一条线段。在删除线段后,需要动态地维护 sum 的值,但是删除当前线段后,剩下的线段覆盖了多大的区间是无法通过当前被覆盖的区间长度计算出来的。一种处理方法是查询子结点被覆盖的区间长度,但是子节点当前记录的 sum 值是线段还未被删除时的 sum 值,想让子节点记录的 sum 值正确,就需要把当前节点的删除操作传递到子节点,这意味着这次操作将无法通过打懒惰标记的方式偷懒,因此复杂度会上升到 O(n2)O(n^2),用了线段树和没用一样。

一个简单的理解就是,在最开始的单点查询中,我们修改了一个区间的加和值,那么后续的单点查询很麻烦。

思路和参考 在掉落方块之前,我们要查询方块对应的区间的最大值,在比较增加边长后的区间最大值和全局最大值进行比较后,我们要把这个变化后的区间最大值覆盖式更新。

因为我们先查询的最大值,同时又是增加,所以不存在天际线中 pushdown 出现的问题。

from collections import defaultdict


class SegmentationTree:

    def __init__(self) -> None:
        self.tree = defaultdict(int)  # 统计最大高度信息
        self.lazy = defaultdict(int)
    
    def update(self, idx, l, r, start, end, val):
        if l > end or r < start:
            return 
        
        if start <= l and r <= end:
            self.tree[idx] = val
            self.lazy[idx] = val
            return
        mid = (r + l) >> 1
        self.pushdown(idx)
        if start <= mid:
            self.update(idx<<1, l, mid, start, end, val)
        if mid < end:
            self.update(idx<<1|1, mid+1, r, start, end, val)
        self.pushup(idx, l, r)
    
    def pushup(self, idx, l, r):
        # if l == r:
        #     return 
        self.tree[idx] = max(self.tree[idx<<1], self.tree[idx<<1|1])
    
    def pushdown(self, idx):
        if self.lazy[idx]:
            self.lazy[idx << 1] = self.lazy[idx]
            self.lazy[idx << 1|1] = self.lazy[idx]
            self.tree[idx << 1] = self.lazy[idx]
            self.tree[idx << 1|1] = self.lazy[idx]
            self.lazy[idx] = 0
        return

    def query(self, idx, l, r, start, end):
        if l > end or r < start:
            return 0
        if start <= l and r <= end:
            return self.tree[idx]
        mid = (r - l) // 2 + l
        self.pushdown(idx)
        ans = 0
        if start <= mid:
            ans = max(ans, self.query(idx<<1, l, mid, start, end))
        if mid < end:
            ans = max(ans, self.query(idx<<1|1, mid+1, r, start, end))
        return ans

class Solution:
    def fallingSquares(self, positions: List[List[int]]) -> List[int]:
        # 对区间进行count
        x_list = sum([[x, x+l] for x, l in positions], [])
        x_list = sorted(list(set(x_list)))
        n = len(x_list) - 1

        x_rank_map = {val:i for i, val in enumerate(x_list)}
        seg_tree = SegmentationTree()

        ans = [-1]
        for x1, l in positions:
            x2 = x1 + l
            rank1, rank2 = x_rank_map[x1], x_rank_map[x2]

            # 先query这个区间的最大值,然后modify
            hx = seg_tree.query(1, 1, n, rank1+1, rank2) + l
            ans.append(max(ans[-1], hx))
            seg_tree.update(1, 1, n, rank1+1, rank2, hx)
        
        return ans[1:]
思路和参考 用上面的模板是不需要考虑动态开点问题的(因为他本是就是动态开点),最需要注意的是我们不需要懒惰标签,以及在添加和删除时赋予的值。
class SegmentationTree:

    def __init__(self):
        self.tree = defaultdict(int)
    
    def update(self, idx, l, r, start, end, val):
        if l > end or r < start:
            return 
        
        if start <= l and r <=  end:
            self.tree[idx] = val 
            return 
        mid = (r - l) // 2 + l
        self.pushdown(idx)
        if mid >= start:
            self.update(idx<<1, l, mid, start, end, val)
        if mid < end:
            self.update(idx<<1|1, mid+1, r, start, end, val)
        self.pushup(idx)
    
    def pushdown(self, idx):
        if self.tree[idx]:
            self.tree[idx<<1] = self.tree[idx] 
            self.tree[idx<<1|1] = self.tree[idx]
    
    def pushup(self, idx):
        self.tree[idx] = self.tree[idx<<1] & self.tree[idx<<1|1]
    
    def query(self, idx, l, r, start, end):
        if l> end or r < start:
            return True
        
        if start <= l and r <= end:
            return self.tree[idx] == 1

        mid = (r -l) //2 + l
        self.pushdown(idx)
        ans = True
        if start <= mid:
            ans &= self.query(idx <<1, l, mid, start, end)
        if mid < end:
            ans &= self.query(idx <<1|1, mid+1, r, start, end)
        return ans

class RangeModule:

    def __init__(self):
        self.l_idx, self.r_idx = 1, pow(10, 9)
        self.tree = SegmentationTree()

    def addRange(self, left: int, right: int) -> None:
        self.tree.update(1, self.l_idx, self.r_idx, left, right-1, 1)
        
        
    def queryRange(self, left: int, right: int) -> bool:
        return self.tree.query(1, self.l_idx, self.r_idx, left, right-1)


    def removeRange(self, left: int, right: int) -> None:
        self.tree.update(1, self.l_idx, self.r_idx, left, right-1, 2)
思路和参考 这一道题我们用到了最大公约数的性质,二分法和线段树。

首先,给定ABA\sub B的两个集合,我们可以得到

gcd(A)gcd(B).\rm{gcd}(A) \geq \rm{gcd}(B).

即最大公约数的单调性。

那么对于每一个坐标ii,我们寻找其左端点,分别找到最大公约数小于 kk 的区间和最大公约数小于等于 kk 的区间,两者之差就是长度。

线段树用来存储最大公约数的值,这里需要注意的query 函数。跟上面的模板不一样的是,这样写不用考虑空区间的返回值(因为不会查询空区间)。

二分法可以查看我写的二分法blog

from collections import defaultdict
import math

class SegmentationTree:

    def __init__(self, nums) -> None:
        self.tree = defaultdict(int)
        self.__build(1, 0, len(nums)-1, nums)
    
    def __build(self, idx, l, r, nums):
        if l == r:
            self.tree[idx] = nums[l]
            return 
        mid = (l+r) >> 1
        self.__build(idx << 1, l, mid, nums)
        self.__build(idx << 1|1, mid + 1, r, nums)
        self.pushup(idx, l, r)
    
    def pushup(self, idx, l, r):
        self.tree[idx] = math.gcd(self.tree[idx<<1], self.tree[idx<<1|1])
    
    def query(self, idx, l, r, start, end):
        # if l > end or r < start: return self.k
        if start <= l and r <= end:
            return self.tree[idx]
        mid = (r + l) >> 1
        if end <= mid:
            return self.query(idx<<1, l, mid, start, end)
        if mid < start:
            return self.query(idx<<1|1, mid+1, r, start, end)
        return math.gcd(self.query(idx<<1, l, mid, start, end), self.query(idx<<1|1, mid+1, r, start, end))
    
class Solution:
    def subarrayGCD(self, nums: List[int], k: int) -> int:
        n = len(nums)

        # 区间从0开始count,最大可取到n-1
        segment_tree = SegmentationTree(nums)
        ans = 0
        # for i, val in segment_tree.tree.items():
        #     print(i, val)
        # print(segment_tree.query(1, 0, n-1, 0, 1))

        for i, x in enumerate(nums):

            left_1, right_1 = 0, i

            # left_1左侧小于k
            while left_1 <= right_1:
                mid = (right_1 - left_1) // 2 + left_1
                if segment_tree.query(1, 0, n-1, mid, i) < k:
                    left_1 = mid + 1  # 保证左侧小于k
                else:
                    right_1 = mid - 1  # 保证右侧大于等于k
            # left_2左侧小于等于k(大于k的最小值)
            left_2, right_2 = 0, i
            while left_2 <= right_2:
                mid = (right_2 - left_2) // 2 + left_2 
                if segment_tree.query(1, 0, n-1, mid, i) <= k:
                    left_2 = mid + 1  # 保证左侧小与等于k
                else:
                    right_2 = mid - 1  # 保证右侧大于k
            # print(left_1, left_2, right_1, right_2)
            ans += right_2 - left_1 + 1
        return ans

参考资料