线段树(Segment Tree)是常用的用来维护区间信息的数据结构,其可以在 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。本文介绍线段树的实现方法。
线段树的一个典型应用是 LeetCode 307. Range Sum Query - Mutable。如果在不可变的数组上区间求和,我们可以利用前缀和。但是在可变数组上区间求和,前缀和每次更新的复杂度将会达到 。相比之下,线段树的效率高得多,而且支持的查询也更多。
线段树举例
我们假设要在数组 nums
上建立一个线段树,那么线段树的每个结点都代表一个区间的计算值,该计算值可以是区间求和、区间最大值和最小值等等。举例来说,对于数组 [3,5,9,4]
,我们要做区间求和线段树,则该树如下图:
可以看到,每个结点包含两部分信息:区间和值。除叶子结点外,每个结点都一定有左右子结点,分别承载半个区间。这一特性决定了线段树可以用数组来表示。因为数组 [3,5,9,4]
的长度正好是 2 的次方,所以其线段树是一颗满二叉树。
那么,一个线段树是否一定是完全二叉树?不一定。比如对于一个长度为 5 的数组,其线段树就不是完全二叉树,大家可以自己尝试画一画。
线段树数组长度
下一步,我们需要确定表示线段树的数组 container
的长度。
正如上面的例子,最简单的思路就是把 nums
的长度凑够 2 的次方,作为线段树最后一层的结点个数,看看此时整棵满二叉树有多少个结点,即为 container
需要多少空间,这个空间肯定大于等于线段树的需求。
对于一个 层()满二叉树,其结点总个数为: 且其第 层结点个数为 。
那么,假设我们的数组有 个数字,我们将其凑够满二叉树的最后一层,那么最后一层的结点个数就有: 我们取满足该条件的最小 ,即: 那么,能容纳该区间的线段树所需数组大小为: 所以我们可以编写线段树的构造函数。
class SegmentTree
{
#container;
#numsLength;
/**
* @param {number[]} nums
*/
constructor(nums)
{
const k = nums.length;
const size = 2 ** (Math.ceil(Math.log2(k)) + 1) - 1;
this.#container = new Array(size);
this.#numsLength = k;
this.#build(nums, 0, nums.length - 1, 0);
}
}
merge
函数
我们首先要确定线段树每个结点的值的来源,即如何从左右子结点得到当前结点的值。比如,如果是用于区间求和的线段树,其 merge
函数如下:
class SegmentTree
{
#container;
#numsLength;
/**
* @param {number} a
* @param {number} b
* @returns {number}
*/
#merge(a, b)
{
return a + b;
}
}
这一部分功能可以用独立的函数实现,也可以作为参数从构造函数传入。
线段树的构造
很经典地,二叉树相关的算法很多都采用递归方法。对于线段树构造方法 build()
,思路很简单:
传入一个结点 root
,
- 如果
root
的区间仍然可二分,就根据区间和下标信息分别计算出左右子结点的区间和下标信息,然后递归构造左右子树,最后从左右子结点得到当前结点的值; - 如果
root
的区间不可二分,那么就从原数组中复制得到当前结点的值。
代码如下:
class SegmentTree
{
#container;
#numsLength;
/**
* `#container[rootIndex]` 当中的数字代表原数组 `[rootStartIndex, rootEndIndex]` 区间 merge 后的结果
* @param {number[]} nums
* @param {number} rootStartIndex
* @param {number} rootEndIndex
* @param {number} rootIndex
*/
#build(nums, rootStartIndex, rootEndIndex, rootIndex)
{
if (rootStartIndex === rootEndIndex)
{
this.#container[rootIndex] = nums[rootStartIndex];
}
else
{
const leftChildIndex = 2 * rootIndex + 1;
const rightChildIndex = 2 * rootIndex + 2;
const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);
// 递归构造左右结点
this.#build(nums, rootStartIndex, midIndex, leftChildIndex);
this.#build(nums, midIndex + 1, rootEndIndex, rightChildIndex);
// 构造当前结点
this.#container[rootIndex] = this.#merge(
this.#container[leftChildIndex],
this.#container[rightChildIndex]
);
}
}
}
线段树的单点修改
单点修改传入 以修改 nums
数组中的值。类似地,单点修改也采用递归思路:
- 如果
root
的区间不可再二分,那么修改root
的值为 即可; - 否则,查看 落在 root 哪个子结点的区间中,进行递归;
- 子结点修改完成后,修改当前结点的值。
代码为:
class SegmentTree
{
#container;
#numsLength;
/**
*
* @param {number} index
* @param {number} val
*/
set(index, val)
{
this.#setHelper(index, val, 0, this.#numsLength - 1, 0);
}
/**
* @param {number} index - 要修改的 nums 上的下标
* @param {number} val
* @param {number} rootStartIndex - root 结点代表的 nums 上的区间起点
* @param {number} rootEndIndex - root 结点代表的 nums 上的区间终点
* @param {number} rootIndex - root 结点在 container 上的下标
*/
#setHelper(index, val, rootStartIndex, rootEndIndex, rootIndex)
{
if (rootStartIndex === rootEndIndex)
{
this.#container[rootIndex] = val;
}
else
{
const leftChildIndex = 2 * rootIndex + 1;
const rightChildIndex = 2 * rootIndex + 2;
const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);
// 被修改的下标在左半边
if (index <= midIndex)
{
this.#setHelper(index, val, rootStartIndex, midIndex, leftChildIndex);
}
// 被修改的下标在右半边
else
{
this.#setHelper(index, val, midIndex + 1, rootEndIndex, rightChildIndex);
}
// 更新当前结点值
this.#container[rootIndex] = this.#merge(
this.#container[leftChildIndex],
this.#container[rightChildIndex]
);
}
}
}
线段树的查询
线段树的查询通常是查询一个区间 。那么给定一个结点 root
,就会有以下三种情况:
- 被查询区间就是
root
代表的区间,返回root
的值即可; - 被查询区间整个都在
root
左子结点代表的区间内,在左子结点上递归查询; - 被查询区间整个都在
root
右子结点代表的区间内,在右子结点上递归查询; - 被查询区间一部分在
root
左子结点代表的区间内,一部分在root
右子结点代表的区间内,需要分割区间分别查询后合并。
写成代码就是:
class SegmentTree
{
#container;
#numsLength;
/**
* @param {number} startIndex - 要查询的在 nums 上的区间起点
* @param {number} endIndex - 要查询的在 nums 上的区间终点
*/
query(startIndex, endIndex)
{
return this.#queryHelper(0, this.#numsLength - 1, 0, startIndex, endIndex);
}
/**
* @param {number} rootStartIndex - root 结点代表的 nums 上的区间起点
* @param {number} rootEndIndex - root 结点代表的 nums 上的区间终点
* @param {number} rootContainerIndex - root 结点在 container 上的下标
* @param {number} queryStartIndex - 要查询的在 nums 上的区间起点
* @param {number} queryEndIndex - 要查询的在 nums 上的区间终点
* @returns {number}
*/
#queryHelper(rootStartIndex, rootEndIndex, rootContainerIndex,
queryStartIndex, queryEndIndex)
{
// 区间对应,当前 root 就是要找的结点
if (rootStartIndex === queryStartIndex
&& rootEndIndex === queryEndIndex)
{
return this.#container[rootContainerIndex];
}
else
{
const leftChildIndex = 2 * rootContainerIndex + 1;
const rightChildIndex = 2 * rootContainerIndex + 2;
const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);
// 在左半个区间
if (queryEndIndex <= midIndex)
{
return this.#queryHelper(rootStartIndex, midIndex, leftChildIndex, queryStartIndex, queryEndIndex);
}
// 在右半个区间
else if (queryStartIndex > midIndex)
{
return this.#queryHelper(midIndex + 1, rootEndIndex, rightChildIndex, queryStartIndex, queryEndIndex);
}
// 横跨两个区间
else
{
return this.#merge(
this.#queryHelper(rootStartIndex, midIndex, leftChildIndex, queryStartIndex, midIndex),
this.#queryHelper(midIndex + 1, rootEndIndex, rightChildIndex, midIndex + 1, queryEndIndex),
);
}
}
}
}
完整代码
将以上代码和在一起,可以得到:
class SegmentTree
{
#container;
#numsLength;
/**
* @param {number[]} nums
*/
constructor(nums)
{
const k = nums.length;
const size = 2 ** (Math.ceil(Math.log2(k)) + 1) - 1;
this.#container = new Array(size);
this.#numsLength = k;
this.#build(nums, 0, nums.length - 1, 0);
}
/**
* `#container[rootIndex]` 当中的数字代表原数组 `[rootStartIndex, rootEndIndex]` 区间 merge 后的结果
* @param {number[]} nums
* @param {number} rootStartIndex
* @param {number} rootEndIndex
* @param {number} rootIndex
*/
#build(nums, rootStartIndex, rootEndIndex, rootIndex)
{
if (rootStartIndex === rootEndIndex)
{
this.#container[rootIndex] = nums[rootStartIndex];
}
else
{
const leftChildIndex = 2 * rootIndex + 1;
const rightChildIndex = 2 * rootIndex + 2;
const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);
// 递归构造左右结点
this.#build(nums, rootStartIndex, midIndex, leftChildIndex);
this.#build(nums, midIndex + 1, rootEndIndex, rightChildIndex);
// 构造当前结点
this.#container[rootIndex] = this.#merge(
this.#container[leftChildIndex],
this.#container[rightChildIndex]
);
}
}
/**
*
* @param {number} index
* @param {number} val
*/
set(index, val)
{
this.#setHelper(index, val, 0, this.#numsLength - 1, 0);
}
/**
* @param {number} index - 要修改的 nums 上的下标
* @param {number} val
* @param {number} rootStartIndex - root 结点代表的 nums 上的区间起点
* @param {number} rootEndIndex - root 结点代表的 nums 上的区间终点
* @param {number} rootIndex - root 结点在 container 上的下标
*/
#setHelper(index, val, rootStartIndex, rootEndIndex, rootIndex)
{
if (rootStartIndex === rootEndIndex)
{
this.#container[rootIndex] = val;
}
else
{
const leftChildIndex = 2 * rootIndex + 1;
const rightChildIndex = 2 * rootIndex + 2;
const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);
// 被修改的下标在左半边
if (index <= midIndex)
{
this.#setHelper(index, val, rootStartIndex, midIndex, leftChildIndex);
}
// 被修改的下标在右半边
else
{
this.#setHelper(index, val, midIndex + 1, rootEndIndex, rightChildIndex);
}
// 更新当前结点值
this.#container[rootIndex] = this.#merge(
this.#container[leftChildIndex],
this.#container[rightChildIndex]
);
}
}
/**
* @param {number} startIndex - 要查询的在 nums 上的区间起点
* @param {number} endIndex - 要查询的在 nums 上的区间终点
*/
query(startIndex, endIndex)
{
return this.#queryHelper(0, this.#numsLength - 1, 0, startIndex, endIndex);
}
/**
* @param {number} rootStartIndex - root 结点代表的 nums 上的区间起点
* @param {number} rootEndIndex - root 结点代表的 nums 上的区间终点
* @param {number} rootContainerIndex - root 结点在 container 上的下标
* @param {number} queryStartIndex - 要查询的在 nums 上的区间起点
* @param {number} queryEndIndex - 要查询的在 nums 上的区间终点
* @returns {number}
*/
#queryHelper(rootStartIndex, rootEndIndex, rootContainerIndex,
queryStartIndex, queryEndIndex)
{
// 区间对应,当前 root 就是要找的结点
if (rootStartIndex === queryStartIndex
&& rootEndIndex === queryEndIndex)
{
return this.#container[rootContainerIndex];
}
else
{
const leftChildIndex = 2 * rootContainerIndex + 1;
const rightChildIndex = 2 * rootContainerIndex + 2;
const midIndex = rootStartIndex + Math.floor((rootEndIndex - rootStartIndex) / 2);
// 在左半个区间
if (queryEndIndex <= midIndex)
{
return this.#queryHelper(rootStartIndex, midIndex, leftChildIndex, queryStartIndex, queryEndIndex);
}
// 在右半个区间
else if (queryStartIndex > midIndex)
{
return this.#queryHelper(midIndex + 1, rootEndIndex, rightChildIndex, queryStartIndex, queryEndIndex);
}
// 横跨两个区间
else
{
return this.#merge(
this.#queryHelper(rootStartIndex, midIndex, leftChildIndex, queryStartIndex, midIndex),
this.#queryHelper(midIndex + 1, rootEndIndex, rightChildIndex, midIndex + 1, queryEndIndex),
);
}
}
}
/**
* @param {number} a
* @param {number} b
* @returns {number}
*/
#merge(a, b)
{
return a + b;
}
}
显然,线段树的各部分都用到了二分的思想,因此时间复杂度都为 。
线段树用于开篇提到的题目当中,可以得到不错的运行效率:
Accepted
15/15 cases passed (605 ms)
Your runtime beats 90.24 % of javascript submissions
Your memory usage beats 91.87 % of javascript submissions (75 MB)
参考文献
- https://oi-wiki.org/ds/seg/