数据结构之线段树

2022-05-10数据结构与算法

线段树(Segment Tree)是常用的用来维护区间信息的数据结构,其可以在 O(log n)O(log\ n) 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。本文介绍线段树的实现方法。

线段树的一个典型应用是 LeetCode 307. Range Sum Query - Mutable。如果在不可变的数组上区间求和,我们可以利用前缀和。但是在可变数组上区间求和,前缀和每次更新的复杂度将会达到 O(n)O(n)。相比之下,线段树的效率高得多,而且支持的查询也更多。

线段树举例

我们假设要在数组 nums 上建立一个线段树,那么线段树的每个结点都代表一个区间的计算值,该计算值可以是区间求和、区间最大值和最小值等等。举例来说,对于数组 [3,5,9,4],我们要做区间求和线段树,则该树如下图:

线段树举例

可以看到,每个结点包含两部分信息:区间和值。除叶子结点外,每个结点都一定有左右子结点,分别承载半个区间。这一特性决定了线段树可以用数组来表示。因为数组 [3,5,9,4] 的长度正好是 2 的次方,所以其线段树是一颗满二叉树

那么,一个线段树是否一定是完全二叉树不一定。比如对于一个长度为 5 的数组,其线段树就不是完全二叉树,大家可以自己尝试画一画。

线段树数组长度

下一步,我们需要确定表示线段树的数组 container 的长度。

正如上面的例子,最简单的思路就是把 nums 的长度凑够 2 的次方,作为线段树最后一层的结点个数,看看此时整棵满二叉树有多少个结点,即为 container 需要多少空间,这个空间肯定大于等于线段树的需求。

对于一个 nn 层(nN+n\in N^+满二叉树,其结点总个数为: i=0n12i=12n12=2n1 \sum_{i=0}^{n-1}2^i = \frac{1-2^{n}}{1-2}=2^{n}-1 且其第 nn 层结点个数为 2n12^{n-1}

那么,假设我们的数组有 kk 个数字,我们将其凑够满二叉树的最后一层,那么最后一层的结点个数就有: 2n1k 2^{n-1}\geq k 我们取满足该条件的最小 nn,即: n=log2k+1 n=\ulcorner log_{2}k \urcorner + 1 那么,能容纳该区间的线段树所需数组大小为: 2log2k+11 2^{\ulcorner log_{2}k \urcorner + 1}-1 所以我们可以编写线段树的构造函数。

class SegmentTree
{
    #container;
    #numsLength;

    /**
     * @param {number[]} nums 
     */
    constructor(nums)
    {
        const k = nums.length;
        const size = 2 ** (Math.ceil(Math.log2(k)) + 1) - 1;
        this.#container = new Array(size);
        this.#numsLength = k;

        this.#build(nums, 0, nums.length - 1, 0);
    }
}

merge 函数

我们首先要确定线段树每个结点的值的来源,即如何从左右子结点得到当前结点的值。比如,如果是用于区间求和的线段树,其 merge 函数如下:

class SegmentTree
{
    #container;
    #numsLength;

    /**
     * @param {number} a 
     * @param {number} b 
     * @returns {number}
     */
    #merge(a, b)
    {
        return a + b;
    }
}

这一部分功能可以用独立的函数实现,也可以作为参数从构造函数传入。

线段树的构造

很经典地,二叉树相关的算法很多都采用递归方法。对于线段树构造方法 build(),思路很简单:

传入一个结点 root

  • 如果 root 的区间仍然可二分,就根据区间和下标信息分别计算出左右子结点的区间和下标信息,然后递归构造左右子树,最后从左右子结点得到当前结点的值;
  • 如果 root 的区间不可二分,那么就从原数组中复制得到当前结点的值。

代码如下:

class SegmentTree
{
    #container;
    #numsLength;

    /**
     * `#container[rootIndex]` 当中的数字代表原数组 `[rootStartIndex, rootEndIndex]` 区间 merge 后的结果
     * @param {number[]} nums
     * @param {number} rootStartIndex 
     * @param {number} rootEndIndex 
     * @param {number} rootIndex 
     */
    #build(nums, rootStartIndex, rootEndIndex, rootIndex)
    {
        if (rootStartIndex === rootEndIndex)
        {
            this.#container[rootIndex] = nums[rootStartIndex];
        }
        else
        {
            const leftChildIndex = 2 * rootIndex + 1;
            const rightChildIndex = 2 * rootIndex + 2;

            const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);

            // 递归构造左右结点
            this.#build(nums, rootStartIndex, midIndex, leftChildIndex);
            this.#build(nums, midIndex + 1, rootEndIndex, rightChildIndex);

            // 构造当前结点
            this.#container[rootIndex] = this.#merge(
                this.#container[leftChildIndex],
                this.#container[rightChildIndex]
            );
        }
    }
}

线段树的单点修改

单点修改传入 (index,value)(index, value) 以修改 nums 数组中的值。类似地,单点修改也采用递归思路:

  • 如果 root 的区间不可再二分,那么修改 root 的值为 valuevalue 即可;
  • 否则,查看 indexindex 落在 root 哪个子结点的区间中,进行递归;
  • 子结点修改完成后,修改当前结点的值。

代码为:

class SegmentTree
{
    #container;
    #numsLength;

    /**
     * 
     * @param {number} index 
     * @param {number} val 
     */
    set(index, val)
    {
        this.#setHelper(index, val, 0, this.#numsLength - 1, 0);
    }

    /**
     * @param {number} index - 要修改的 nums 上的下标
     * @param {number} val
     * @param {number} rootStartIndex - root 结点代表的 nums 上的区间起点
     * @param {number} rootEndIndex - root 结点代表的 nums 上的区间终点
     * @param {number} rootIndex - root 结点在 container 上的下标
     */
    #setHelper(index, val, rootStartIndex, rootEndIndex, rootIndex)
    {
        if (rootStartIndex === rootEndIndex)
        {
            this.#container[rootIndex] = val;
        }
        else
        {
            const leftChildIndex = 2 * rootIndex + 1;
            const rightChildIndex = 2 * rootIndex + 2;

            const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);

            // 被修改的下标在左半边
            if (index <= midIndex)
            {
                this.#setHelper(index, val, rootStartIndex, midIndex, leftChildIndex);
            }
            // 被修改的下标在右半边
            else
            {
                this.#setHelper(index, val, midIndex + 1, rootEndIndex, rightChildIndex);
            }

            // 更新当前结点值
            this.#container[rootIndex] = this.#merge(
                this.#container[leftChildIndex],
                this.#container[rightChildIndex]
            );
        }
    }
}

线段树的查询

线段树的查询通常是查询一个区间 [startIndex,endIndex][startIndex, endIndex]。那么给定一个结点 root,就会有以下三种情况:

  1. 被查询区间就是 root 代表的区间,返回 root 的值即可;
  2. 被查询区间整个都在 root 左子结点代表的区间内,在左子结点上递归查询;
  3. 被查询区间整个都在 root 右子结点代表的区间内,在右子结点上递归查询;
  4. 被查询区间一部分在 root 左子结点代表的区间内,一部分在 root 右子结点代表的区间内,需要分割区间分别查询后合并。

写成代码就是:

class SegmentTree
{
    #container;
    #numsLength;

    /**
     * @param {number} startIndex - 要查询的在 nums 上的区间起点
     * @param {number} endIndex - 要查询的在 nums 上的区间终点
     */
    query(startIndex, endIndex)
    {
        return this.#queryHelper(0, this.#numsLength - 1, 0, startIndex, endIndex);
    }

    /**
     * @param {number} rootStartIndex - root 结点代表的 nums 上的区间起点
     * @param {number} rootEndIndex - root 结点代表的 nums 上的区间终点
     * @param {number} rootContainerIndex - root 结点在 container 上的下标
     * @param {number} queryStartIndex - 要查询的在 nums 上的区间起点
     * @param {number} queryEndIndex - 要查询的在 nums 上的区间终点
     * @returns {number}
     */
    #queryHelper(rootStartIndex, rootEndIndex, rootContainerIndex,
        queryStartIndex, queryEndIndex)
    {
        // 区间对应,当前 root 就是要找的结点
        if (rootStartIndex === queryStartIndex
        && rootEndIndex === queryEndIndex)
        {
            return this.#container[rootContainerIndex];
        }
        else
        {
            const leftChildIndex = 2 * rootContainerIndex + 1;
            const rightChildIndex = 2 * rootContainerIndex + 2;

            const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);

            // 在左半个区间
            if (queryEndIndex <= midIndex)
            {
                return this.#queryHelper(rootStartIndex, midIndex, leftChildIndex, queryStartIndex, queryEndIndex);
            }
            // 在右半个区间
            else if (queryStartIndex > midIndex)
            {
                return this.#queryHelper(midIndex + 1, rootEndIndex, rightChildIndex, queryStartIndex, queryEndIndex);
            }
            // 横跨两个区间
            else
            {
                return this.#merge(
                    this.#queryHelper(rootStartIndex, midIndex, leftChildIndex, queryStartIndex, midIndex),
                    this.#queryHelper(midIndex + 1, rootEndIndex, rightChildIndex, midIndex + 1, queryEndIndex),
                );
            }
        }
    }
}

完整代码

将以上代码和在一起,可以得到:

class SegmentTree
{
    #container;
    #numsLength;

    /**
     * @param {number[]} nums 
     */
    constructor(nums)
    {
        const k = nums.length;
        const size = 2 ** (Math.ceil(Math.log2(k)) + 1) - 1;
        this.#container = new Array(size);
        this.#numsLength = k;

        this.#build(nums, 0, nums.length - 1, 0);
    }

    /**
     * `#container[rootIndex]` 当中的数字代表原数组 `[rootStartIndex, rootEndIndex]` 区间 merge 后的结果
     * @param {number[]} nums
     * @param {number} rootStartIndex 
     * @param {number} rootEndIndex 
     * @param {number} rootIndex 
     */
    #build(nums, rootStartIndex, rootEndIndex, rootIndex)
    {
        if (rootStartIndex === rootEndIndex)
        {
            this.#container[rootIndex] = nums[rootStartIndex];
        }
        else
        {
            const leftChildIndex = 2 * rootIndex + 1;
            const rightChildIndex = 2 * rootIndex + 2;

            const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);

            // 递归构造左右结点
            this.#build(nums, rootStartIndex, midIndex, leftChildIndex);
            this.#build(nums, midIndex + 1, rootEndIndex, rightChildIndex);

            // 构造当前结点
            this.#container[rootIndex] = this.#merge(
                this.#container[leftChildIndex],
                this.#container[rightChildIndex]
            );
        }
    }

    /**
     * 
     * @param {number} index 
     * @param {number} val 
     */
    set(index, val)
    {
        this.#setHelper(index, val, 0, this.#numsLength - 1, 0);
    }

    /**
     * @param {number} index - 要修改的 nums 上的下标
     * @param {number} val
     * @param {number} rootStartIndex - root 结点代表的 nums 上的区间起点
     * @param {number} rootEndIndex - root 结点代表的 nums 上的区间终点
     * @param {number} rootIndex - root 结点在 container 上的下标
     */
    #setHelper(index, val, rootStartIndex, rootEndIndex, rootIndex)
    {
        if (rootStartIndex === rootEndIndex)
        {
            this.#container[rootIndex] = val;
        }
        else
        {
            const leftChildIndex = 2 * rootIndex + 1;
            const rightChildIndex = 2 * rootIndex + 2;

            const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);

            // 被修改的下标在左半边
            if (index <= midIndex)
            {
                this.#setHelper(index, val, rootStartIndex, midIndex, leftChildIndex);
            }
            // 被修改的下标在右半边
            else
            {
                this.#setHelper(index, val, midIndex + 1, rootEndIndex, rightChildIndex);
            }

            // 更新当前结点值
            this.#container[rootIndex] = this.#merge(
                this.#container[leftChildIndex],
                this.#container[rightChildIndex]
            );
        }
    }

    /**
     * @param {number} startIndex - 要查询的在 nums 上的区间起点
     * @param {number} endIndex - 要查询的在 nums 上的区间终点
     */
    query(startIndex, endIndex)
    {
        return this.#queryHelper(0, this.#numsLength - 1, 0, startIndex, endIndex);
    }

    /**
     * @param {number} rootStartIndex - root 结点代表的 nums 上的区间起点
     * @param {number} rootEndIndex - root 结点代表的 nums 上的区间终点
     * @param {number} rootContainerIndex - root 结点在 container 上的下标
     * @param {number} queryStartIndex - 要查询的在 nums 上的区间起点
     * @param {number} queryEndIndex - 要查询的在 nums 上的区间终点
     * @returns {number}
     */
    #queryHelper(rootStartIndex, rootEndIndex, rootContainerIndex,
        queryStartIndex, queryEndIndex)
    {
        // 区间对应,当前 root 就是要找的结点
        if (rootStartIndex === queryStartIndex
        && rootEndIndex === queryEndIndex)
        {
            return this.#container[rootContainerIndex];
        }
        else
        {
            const leftChildIndex = 2 * rootContainerIndex + 1;
            const rightChildIndex = 2 * rootContainerIndex + 2;

            const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);

            // 在左半个区间
            if (queryEndIndex <= midIndex)
            {
                return this.#queryHelper(rootStartIndex, midIndex, leftChildIndex, queryStartIndex, queryEndIndex);
            }
            // 在右半个区间
            else if (queryStartIndex > midIndex)
            {
                return this.#queryHelper(midIndex + 1, rootEndIndex, rightChildIndex, queryStartIndex, queryEndIndex);
            }
            // 横跨两个区间
            else
            {
                return this.#merge(
                    this.#queryHelper(rootStartIndex, midIndex, leftChildIndex, queryStartIndex, midIndex),
                    this.#queryHelper(midIndex + 1, rootEndIndex, rightChildIndex, midIndex + 1, queryEndIndex),
                );
            }
        }
    }

    /**
     * @param {number} a 
     * @param {number} b 
     * @returns {number}
     */
    #merge(a, b)
    {
        return a + b;
    }
}

显然,线段树的各部分都用到了二分的思想,因此时间复杂度都为 O(log n)O(log\ n)

线段树用于开篇提到的题目当中,可以得到不错的运行效率:

Accepted
    15/15 cases passed (605 ms)
    Your runtime beats 90.24 % of javascript submissions
    Your memory usage beats 91.87 % of javascript submissions (75 MB)

参考文献

  1. https://oi-wiki.org/ds/seg/