当前位置:首页>文章>数据结构>线段树详解:原理、构建、区间查询与更新(Python 实现)

线段树详解:原理、构建、区间查询与更新(Python 实现)

线段树详解教程

一、什么是线段树?

线段树也叫区间树;线段树是一种二叉搜索树,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点;

二、为什么要使用线段树?

在解释这个问题先让我们看一个经典的问题,区间染色;

假设有一面墙,长度为n,每次选择一段墙来进行染色,如下图所示:

线段树详解:原理、构建、区间查询与更新(Python 实现)

首先定义一端长度为n的墙;

线段树详解:原理、构建、区间查询与更新(Python 实现)

然后将4-9这个区间染成黄色;

线段树详解:原理、构建、区间查询与更新(Python 实现)

在将7-15染成绿色;

线段树详解:原理、构建、区间查询与更新(Python 实现)

再将1-5染成蓝色;

线段树详解:原理、构建、区间查询与更新(Python 实现)

再将6-12染成红色;

最后问题来了:

  1. 问m次操作后我们可以看见多少种颜色?
  2. m次操作后,我们可以在[i, j]区间看见多少种颜色?

不管怎样,对于这个问题来说,我们关注的是一个一个的区间,其实对于这个问题只需要两种操作就可以搞定,一个是染色操作(也就是所谓的区间更新),一个是查询操作(区间查询);但是由于我们这个是使用的列表来定义的一面墙,所以查询和更新的操作的时间复杂度都为O(n)级别的,如果我们直接使用列表进行操作系统的系统开销就太大了;因为对于这个问题我们关注点的是一个一个的区间,所以线段树在这里就有了用武之地了;

接下来我们在看一个计算机领域经典的区间查询问题

线段树详解:原理、构建、区间查询与更新(Python 实现)

查询一个区间[i, j]的最大值 最小值或者该区间的数字之和;

之前我们的数据结构都是对单个的数据进行操作,显然是不适用于这这样的场景的;这里如果我们将这个问题换成对区间内的数据进行操作会是怎么样的?其实这个问题的本质就是基于区间的统计查询。

放在现在的互联网环境下也有很多这样的问题需要使用到区间查询的操作来完成?

比如:一个电商网站去年一年注册的用户中消费最高的用户是谁?消费最少的用户是谁?

其实这样的问题我们需要注意的是:我们关注的仍然是动态的情况,在这种情况下我们使用线段树会是一个很好的选择;因为动态统计的时候会伴有两个操作一个更新一个查询,如果使用列表来进行操作会极大的增加性能开销;

接下来我们来对比一下线段树和列表在对这类问题方面的性能开销:

线段树详解:原理、构建、区间查询与更新(Python 实现)

三、线段树概念详解

以上的实例我们都可以将列表直接转换成线段树,

线段树详解:原理、构建、区间查询与更新(Python 实现)

这个一个列表,通过以上的实例我们不难发现在线段树中是没有添加和删除操作的,所以转换以后的线段树是这样的:

线段树详解:原理、构建、区间查询与更新(Python 实现)

前边也说过线段树是二分搜索树的一种,不同的地方在与线段树每个节点存储的是一个区间,我们以求区间数字之和的问题来进行解析,上图中每一个节点存储一个区间的数据,例如根节点存储的就是整个列表的数据,左右两个子节点存储的就是[0 - 3][4 - 7]这个区间的数据,以此类推;如果我们要操作这些区间数据我们只需要找到这些节点即可,以列表中[4 - 7]这个区间为例:

线段树详解:原理、构建、区间查询与更新(Python 实现)

对于线段树来说有时候我们需要查询的区间需要进行一次合成的操作,比如说在这个实例中我们要查询区间[2 - 5]的数据,就会变成下边的情况:

线段树详解:原理、构建、区间查询与更新(Python 实现)

我们需要先找到A[2 - 3]和A[4 - 5]这两个节点,然后对这两个节点在进行合并的操作;

所以在大数量的情况下,如果我们要操作区间数据的话,使用线段树可以很高效的解决我们的问题,而不会像使用列表那样需要先遍历一边所有的元素,这就是线段树的优势所在。

四、线段树的基础表示

  1. 在之前的例子中我们看到了线段树是一个二分搜索树,但是线段树不一定是一颗满的二叉树,比如说如下图所示:

线段树详解:原理、构建、区间查询与更新(Python 实现)

这是一位,我们的列表是10个元素的列表,根节点保存着10个元素的数据,根节点下边的左右子节点分别保存5个元素,但是左右子节点在往下分的话因为5不能被整除的元素就导致其下边的节点必然会出现元素个数不一致的情况,就像这样的情况:

线段树详解:原理、构建、区间查询与更新(Python 实现)

从图示中可以看出,这个线段树的叶子节点不一定是在最下边一层的,这也表示线段树不一定是满的二叉树同时也不一定是完全二叉树,但是线段树确实一颗平衡二叉树;

平衡二叉树:从根节点开始到叶子节点的深度的差值(最大深度和最小深度)不超过1,从这里我们也可以得出我们的堆也是一颗平衡二叉树,因为完全二叉树本身就是一种平衡二叉树;

  1. 平衡二叉树的好处就是不会像二分搜索树那样退化成一个链表,在平衡二叉树上进行搜索是非常高效的;

  2. 我们可以使用列表来表示平衡二叉树,因为我们可以将列表A看成是一个完全二叉树,虽然它最底层的叶子节点有一些是没有的,我们可以将这些节点看成None,这样这个平衡二叉树就会变成完全二叉树,而完全二叉树我们完全可以使用列表来表示;

  3. 那么问题就来了,如果这个区间有n个元素,我们使用列表表示需要多少个节点?

线段树详解:原理、构建、区间查询与更新(Python 实现)

如图所示,如果区间中有n个元素的话,我们的空间需要2n,因为上边的空间是下边空间之和,但是这有一个最坏的情况,就是我们的列表是奇数的,比如说大小为5这样的情况,那么此时的线段树存储空间就应该是下边这样的:

线段树详解:原理、构建、区间查询与更新(Python 实现)

因为区间元素个数为n的话,我们需要2n的空间来进行存储,如果多出一个元素的话我们就需要4n的空间来进行存储,最后我们的线段树模型就应该像下边图示的那样:

线段树详解:原理、构建、区间查询与更新(Python 实现)

我们将空的节点全部复制为None,这样就可以让它满足完全二叉树的定义;

注意:在这里我们是浪费了一些存储空间的,由于我们定义的线段树是没有插入操作的也就是说是静态的,那么我们为了性能是完全可以牺牲掉这些空间的;

代码实现线段树基础:

class SegmentTree:

    def __init__(self, arr, merger=None):
        """
        初始化线段树
        :param arr: 输入数组
        :param merger: 合并函数,用于合并区间元素
        """
        self.merger = merger
        self.data = arr[:]  # 复制数组
        self.tree = [None] * (4 * len(arr))  # 线段树数组

        if len(arr) > 0:
            self._build_segment_tree(0, 0, len(self.data) - 1)

    def get_size(self):
        """
        获取数组长度
        :return: 数组长度
        """
        return len(self.data)

    def get(self, index):
        """
        获取当前索引上的元素
        :param index: 索引
        :return: 元素值
        """
        if index < 0 or index >= len(self.data):
            raise IndexError("Index is illegal.")
        return self.data[index]

    def _left_child(self, index):
        """
        返回完全二叉树的列表表示中,一个索引上元素的左节点索引
        :param index: 当前索引
        :return: 左子节点索引
        """
        return 2 * index + 1

    def _right_child(self, index):
        """
        返回完全二叉树的列表表示中,一个索引上元素的右节点索引
        :param index: 当前索引
        :return: 右子节点索引
        """
        return 2 * index + 2

五、创建线段树

  1. 首先以求和为例我们先来看一下线段树的模型

线段树详解:原理、构建、区间查询与更新(Python 实现)

在这个图示中,我们的列表长度为10,所以线段树的根节点存储的就是10个元素的和,下边的左右子节点以及各个节点存储的都是相应区间元素的和,如果要创建这样的一个线段树我们就需要使用递归的方法进行创建;

代码实现创建过程

def _build_segment_tree(self, tree_index, l, r):
    """
    在tree_index的位置创建表示区间[l....r]的线段树
    :param tree_index: 树中的索引
    :param l: 区间左边界
    :param r: 区间右边界
    """
    if l == r:
        self.tree[tree_index] = self.data[l]
        return

    left_tree_index = self._left_child(tree_index)
    right_tree_index = self._right_child(tree_index)

    mid = l + (r - l) // 2

    self._build_segment_tree(left_tree_index, l, mid)
    self._build_segment_tree(right_tree_index, mid + 1, r)

    if self.merger:
        self.tree[tree_index] = self.merger(self.tree[left_tree_index], 
                                           self.tree[right_tree_index])

def __str__(self):
    """
    重写__str__方法
    :return: 字符串表示
    """
    res = []
    res.append('[')
    for i in range(len(self.tree)):
        if self.tree[i] is not None:
            res.append(str(self.tree[i]))
        else:
            res.append("None")

        if i != len(self.tree) - 1:
            res.append(", ")

    res.append(']')
    return ''.join(res)

测试

def main():
    nums = [-2, 0, 3, -5, 2, -1]

    # 创建求和的线段树
    seg_tree = SegmentTree(nums, lambda a, b: a + b)

    print(seg_tree)

if __name__ == "__main__":
    main()

六、线段树中的区间查询

基于递归我们可以很轻松的实现线段树的区间查询操作,

线段树详解:原理、构建、区间查询与更新(Python 实现)

比如说基于这个线段树我们要查询区间为2-5的统计信息;

线段树详解:原理、构建、区间查询与更新(Python 实现)

我们可以先从根节点进行查询;

线段树详解:原理、构建、区间查询与更新(Python 实现)

根据图示我们可以得出要查询[2,5]这个区间的数据就需要查询根节点下左右两个节点的元素,左节点查询[2,3],右节点查询[4,5];

线段树详解:原理、构建、区间查询与更新(Python 实现)

由于左右[2,3]和[4,5]这两个节点都有父节点,在这里我们就可以使用递归进行查询;

代码实现线段树区间的查询操作:

def query(self, query_l, query_r):
    """
    返回区间[query_l, query_r]的值
    :param query_l: 查询区间左边界
    :param query_r: 查询区间右边界
    :return: 查询结果
    """
    if (query_l < 0 or query_l >= len(self.data) or 
        query_r < 0 or query_r >= len(self.data) or 
        query_l > query_r):
        raise IndexError("Index is illegal.")

    return self._query(0, 0, len(self.data) - 1, query_l, query_r)

def _query(self, tree_index, l, r, query_l, query_r):
    """
    线段树区间查询的核心方法
    在以tree_index为根的线段树中[l...r]的范围里,查找[query_l...query_r]的值
    :param tree_index: 树节点索引
    :param l: 当前节点表示区间的左边界
    :param r: 当前节点表示区间的右边界
    :param query_l: 查询区间左边界
    :param query_r: 查询区间右边界
    :return: 查询结果
    """
    if l == query_l and r == query_r:
        return self.tree[tree_index]

    mid = l + (r - l) // 2
    left_tree_index = self._left_child(tree_index)
    right_tree_index = self._right_child(tree_index)

    if query_l >= mid + 1:
        return self._query(right_tree_index, mid + 1, r, query_l, query_r)
    elif query_r <= mid:
        return self._query(left_tree_index, l, mid, query_l, query_r)

    left_result = self._query(left_tree_index, l, mid, query_l, mid)
    right_result = self._query(right_tree_index, mid + 1, r, mid + 1, query_r)

    return self.merger(left_result, right_result)

测试:

def main():
    nums = [-2, 0, 3, -5, 2, -1]

    seg_tree = SegmentTree(nums, lambda a, b: a + b)
    print(seg_tree)

    print(seg_tree.query(0, 2))
    print(seg_tree.query(2, 5))
    print(seg_tree.query(0, 5))

if __name__ == "__main__":
    main()

七、线段树中的更新的操作

1️⃣ 代码实现更新操作

def set(self, index, e):
    """
    将index位置上的值更新为e
    :param index: 要更新的索引
    :param e: 新的值
    """
    if index < 0 or index >= len(self.data):
        raise IndexError("Index is illegal.")

    self.data[index] = e
    self._set(0, 0, len(self.data) - 1, index, e)

def _set(self, tree_index, l, r, index, e):
    """
    在以tree_index为根的线段树中,更新index的值为e
    :param tree_index: 树节点索引
    :param l: 当前节点表示区间的左边界
    :param r: 当前节点表示区间的右边界
    :param index: 要更新的索引
    :param e: 新的值
    """
    if l == r:
        self.tree[tree_index] = e
        return

    mid = l + (r - l) // 2
    left_tree_index = self._left_child(tree_index)
    right_tree_index = self._right_child(tree_index)

    if index >= mid + 1:
        self._set(right_tree_index, mid + 1, r, index, e)
    else:
        self._set(left_tree_index, l, mid, index, e)

    if self.merger:
        self.tree[tree_index] = self.merger(self.tree[left_tree_index], 
                                           self.tree[right_tree_index])

完整的线段树类实现

class SegmentTree:

    def __init__(self, arr, merger=None):
        """
        初始化线段树
        :param arr: 输入数组
        :param merger: 合并函数,用于合并区间元素
        """
        self.merger = merger
        self.data = arr[:]  # 复制数组
        self.tree = [None] * (4 * len(arr))  # 线段树数组

        if len(arr) > 0:
            self._build_segment_tree(0, 0, len(self.data) - 1)

    def get_size(self):
        """获取数组长度"""
        return len(self.data)

    def get(self, index):
        """获取当前索引上的元素"""
        if index < 0 or index >= len(self.data):
            raise IndexError("Index is illegal.")
        return self.data[index]

    def _left_child(self, index):
        """返回左子节点索引"""
        return 2 * index + 1

    def _right_child(self, index):
        """返回右子节点索引"""
        return 2 * index + 2

    def _build_segment_tree(self, tree_index, l, r):
        """构建线段树"""
        if l == r:
            self.tree[tree_index] = self.data[l]
            return

        left_tree_index = self._left_child(tree_index)
        right_tree_index = self._right_child(tree_index)

        mid = l + (r - l) // 2

        self._build_segment_tree(left_tree_index, l, mid)
        self._build_segment_tree(right_tree_index, mid + 1, r)

        if self.merger:
            self.tree[tree_index] = self.merger(self.tree[left_tree_index], 
                                               self.tree[right_tree_index])

    def query(self, query_l, query_r):
        """区间查询"""
        if (query_l < 0 or query_l >= len(self.data) or 
            query_r < 0 or query_r >= len(self.data) or 
            query_l > query_r):
            raise IndexError("Index is illegal.")

        return self._query(0, 0, len(self.data) - 1, query_l, query_r)

    def _query(self, tree_index, l, r, query_l, query_r):
        """递归查询实现"""
        if l == query_l and r == query_r:
            return self.tree[tree_index]

        mid = l + (r - l) // 2
        left_tree_index = self._left_child(tree_index)
        right_tree_index = self._right_child(tree_index)

        if query_l >= mid + 1:
            return self._query(right_tree_index, mid + 1, r, query_l, query_r)
        elif query_r <= mid:
            return self._query(left_tree_index, l, mid, query_l, query_r)

        left_result = self._query(left_tree_index, l, mid, query_l, mid)
        right_result = self._query(right_tree_index, mid + 1, r, mid + 1, query_r)

        return self.merger(left_result, right_result)

    def set(self, index, e):
        """更新操作"""
        if index < 0 or index >= len(self.data):
            raise IndexError("Index is illegal.")

        self.data[index] = e
        self._set(0, 0, len(self.data) - 1, index, e)

    def _set(self, tree_index, l, r, index, e):
        """递归更新实现"""
        if l == r:
            self.tree[tree_index] = e
            return

        mid = l + (r - l) // 2
        left_tree_index = self._left_child(tree_index)
        right_tree_index = self._right_child(tree_index)

        if index >= mid + 1:
            self._set(right_tree_index, mid + 1, r, index, e)
        else:
            self._set(left_tree_index, l, mid, index, e)

        if self.merger:
            self.tree[tree_index] = self.merger(self.tree[left_tree_index], 
                                               self.tree[right_tree_index])

    def __str__(self):
        """字符串表示"""
        res = []
        res.append('[')
        for i in range(len(self.tree)):
            if self.tree[i] is not None:
                res.append(str(self.tree[i]))
            else:
                res.append("None")

            if i != len(self.tree) - 1:
                res.append(", ")

        res.append(']')
        return ''.join(res)

# 测试代码
def main():
    nums = [-2, 0, 3, -5, 2, -1]

    # 创建求和的线段树
    seg_tree = SegmentTree(nums, lambda a, b: a + b)
    print(seg_tree)

    # 测试查询
    print(seg_tree.query(0, 2))  # 查询[0,2]区间和
    print(seg_tree.query(2, 5))  # 查询[2,5]区间和  
    print(seg_tree.query(0, 5))  # 查询[0,5]区间和

    # 测试更新
    seg_tree.set(1, 10)  # 将索引1的值更新为10
    print(seg_tree.query(0, 2))  # 再次查询[0,2]区间和

if __name__ == "__main__":
    main()
数据结构

优先队列与堆(Heap)详解:概念、实现、Heapify 与应用

2025-8-18 10:09:09

数据结构

Trie字典树详解 - 基础原理与代码实现

2025-8-20 9:40:43

搜索