Segment Tree 线段树

线段树也是一种平衡的二叉树,不过它的主要用途并不是对单一的数据做操作,而是对针对一个区间做操作。

什么是线段树?

假设有一个数组 data, 你需要对其中的某一个区间 data[x...y] 进行操作,比如求和。你可以选择遍历一遍整个数组,这样的时间复杂度是 O(N)。但如果使用线段树,则可以将时间复杂度降低到 O(logN),代价是需要额外的O(4*N)的空间复杂度。

线段树长下图这样。这个线段树用于存储 data[0...5]整个区间的信息,方便对其中任意一个区间进行操做。每一个节点旁边的斜体数字表示当前节点在线段树中的索引,根节点为0。比如这个线段树是用来进行区间求和操作,现在需要求解 data[2...5] 的和。我们只需要从根节点向下依次找到 4号节点,5号节点,再将4号节点向上级 return, 将5号节点向上级return,最后对return上来的两个结果进行求和即可。可以看出,对整个区间的操作转换成了访问二叉树的节点。

线段树的性质

  • 线段树是一棵平衡二叉树,但不是满二叉树(从上到下,从左到右依次排列),除去倒数第一层以外是一颗满二叉树。

  • 由于是一棵平衡二叉树,所以访问每一个节点的时间复杂度都是 O(N)。

  • 假设原数组长度为N,则线段树的最大长度需要4N。简单证明如下: 假设N=2^k, 则线段树最后一层刚好排满,节点数就为N。而满二叉树的性质就是倒数第一层的节点数略大于上面所有层的节点数之和。所以整棵满二叉树的节点和为2N。 假设原数组长度很巧合,恰好N=2^(k+1)-1, 则相比于N=2^k 的层数,恰好多出一层,最后一层的节点数为2N,整棵二叉树节点数为4N。

构建一棵线段树

一些定义

假设根节点的索引为0, 则线段树的任意节点左孩子节点索引为 2*x+1, x为任意节点索引,同理右孩子节点索引为 2*x+2。整个线段树的构造过程如下:

public interface IMerger<E>{

    E merge(E a, E b);

}

IMerger 是一个接口, 里面有一个 merge 方法,给定两个索引 a, b 返回区间 [a..b]的信息。

构造方法

有两个数组 data 和 tree, data 储存原始数组,tree 为线段树。还有一个IMerger接口里面的merger方法,在这个类里面不着急实现,直接调用就好。

构造方法非常简单,将传入的数组赋值给data, 然后给 tree 数组开一个 4*N的空间,初始化tree。

public class SegmentTree<E> {

    private E[] data;
    private E[] tree;
    private IMerger<E> merger;

    public SegmentTree(E[] arr, IMerger<E> merger){
        this.merger=merger;
        data =(E[]) new Object[arr.length];
        for (int i = 0; i < arr.length; i++) {
            data[i]=arr[i];
        }
        tree=(E[]) new Object[4*arr.length];
        buildSegmentTree(0,0, data.length-1);
    }
    
    
}
    

重点看一下buildSegmentTree(int treeIdx,int l, int r) 方法,这是整个构造过程的核心。 其中 treeIdx 代表线段树的节点,l,r 分别表示当前节点所表示的区间左右索引。

当 l>=r 时说明当前区间只有一个元素,到达递归终止条件,直接给 tree[l] 赋值,再return。

否则说明还没有递归到底,需要计算左右子树的索引区间[l..mid] 和 [mid+1...r] 然后分别向左右两边递归。

当作有两颗子树构建完成了,则利用左右两颗子树构建当前节点。

private int leftChild(int idx){
    return 2*idx+1;
}

private int rightChild(int idx){
    return 2*idx+2;
}
    
private void buildSegmentTree(int treeIdx,int l, int r){
    if (l>=r){
        tree[treeIdx]=data[l];
        return;
    }
    int leftIdx=leftChild(treeIdx);
    int rightIdx=rightChild(treeIdx);

    int mid=l+(r-l)/2;
    buildSegmentTree(leftIdx,l,mid);
    buildSegmentTree(rightIdx,mid+1,r);

    tree[treeIdx]= merger.merge(tree[leftIdx],tree[rightIdx]);
}

区间查询操作

接下来的 query(int queryL, int queryR) 是整个线段树的核心方法,用于查询 [queryL...queryR]区间内的所有信息。整个查询过程从根节点向下寻找节点,找到节点后逐级向上return。

query(int treeIdx,int l, int r,int queryL, int queryR)表示 从 treeIdx 节点( [l..r] 区间内)开始查询 [queryL...queryR]区间内的信息。

若l==queryL 且 r==queryR 则说明刚好这个节点就代表了[queryL...queryR]区间内的信息,直接return。否则还需要向下面的左右子树递归查询。为此需要计算左右两个子节点的区间 [l...mid],[mid+1...r]。

若查询区间[queryL...queryR]刚好全部位于左孩子节点范围内或右孩子范围内,则只需要单独向左节点或右节点递归。

若查询区间[queryL...queryR] 部分位于左孩子节点范围内, 部分位于右孩子范围内,则只需要分别向左节点,右节点递归。再将左右节点return上来的结果进行 merge, 再return回上一层的节点。

//queryL, queryR 分别表示查询的左右区间
public E query(int queryL, int queryR){
    if (queryL<0 || queryL>= data.length || queryR<0 || queryR>= data.length ){
        throw new IllegalArgumentException(" Illegal index !");
    }

    return query(0,0, data.length-1,queryL,queryR);
}

//treeIdx 表示当前节点的索引
//l 表示当前节点所表示原数组的左索引,  r 表示当前节点所表示的原数组右索引
private E query(int treeIdx,int l, int r,int queryL, int queryR){
    if (l==queryL && r==queryR){
        return tree[treeIdx];
    }

    int mid=l+(r-l)/2;
    int leftTreeIdx=leftChild(treeIdx);
    int rightTreeIdx=rightChild(treeIdx);

    if (queryL>=mid+1 ){
        return query(rightTreeIdx,mid+1 ,r,queryL,queryR);
    }else if(queryR<=mid){
        return query(leftTreeIdx,l,mid ,queryL,queryR);
    }else {
        E leftResult=query(leftTreeIdx,l,mid,queryL,mid);
        E rightResult=query(rightTreeIdx,mid+1 ,r,mid+1,queryR);
        return merger.merge(leftResult,rightResult);
    }
}

更改原数组元素

假设需要更改原数组的内的某个元素,那么对应的线段树也需要逐层向上修改节点的值。

set(int idx,E e)表示将原数组内索引为idx的元素修改为 e。在其内部需要调用set(int treeIdx, int l, int r, int idx, E e)对线段树进行修改。其中 treeIdx表示当前节点,l, r 表示当前节点所代表的区间 [l, r]的左右索引,idx表示原数组内要修改元素的索引, e 表示修改后的元素。

若 l>=r, 说明当前区间只有一个元素,已经递归到底了,已经到达要修改的线段树最小区间节点了,直接修改 tree[reeIdx] 然后 return。

否则说明还没有找到要修改的线段树最小区间节点,则需要向当前节点的左子树或右子树继续寻找。为此需要计算当前区间的中点 mid。若 mid>=mid+1, 说明需要修改的线段树最小区间位于当前节点的右子树,则需要向右孩子继续递归,否则则向左孩子继续递归。

当左右孩子都已经完成修改时,需要重新merge左右孩子,更新当前节点 tree[treeIdx]的值,然后return。

public void set(int idx,E e){
    if (idx<0 || idx>= data.length  ){
        throw new IllegalArgumentException(" Illegal index !");
    }
    data[idx]=e;
    set(0,0,data.length-1,idx,e);
}

private void set(int treeIdx, int l, int r, int idx, E e){
    if (l>=r){
        tree[treeIdx]=e;
        return;
    }
    int mid=l+(r-l)/2;
    int leftTreeIdx=leftChild(treeIdx);
    int rightTreeIdx=rightChild(treeIdx);

    if (idx>=mid+1){
        set(rightTreeIdx,mid+1,r,idx,e);
    }else {
        set(leftTreeIdx,l,mid,idx,e);
    }
    tree[treeIdx]= merger.merge(tree[leftTreeIdx],tree[rightTreeIdx]);
}

至此,我们已经实现了一棵线段树的所有核心操作。

测试一下

重写 toString()方法

为了方便测试,我们重写了toString()方法。简单来说就是从上到下,从左到右依次输出线段树的节点。

@Override
public String toString() {
    StringBuilder res=new StringBuilder();
    res.append('[');
    for (int i = 0; i < tree.length; i++) {
        if (tree[i]!=null){
            res.append(tree[i]);
        }
        else{
            res.append("null");
        }
        if (i!= tree.length-1){
            res.append(", ");
        }
    }
    res.append(']');
    return res.toString();
}

Run起来

我们定义线段树的功能为区间求和。

public class SegmentTreeTest {

    public static void main(String[] args) {
        Integer[] nums={-2,0,3,-5,2,-1};
        SegmentTree<Integer> segTree=new SegmentTree<>(nums, new IMerger<Integer>() {
            @Override
            public Integer merge(Integer a, Integer b) {
                return a+b;
            }
        });
        System.out.println(segTree);

        segTree.set(2,6);
        System.out.println(segTree);
    }
}

运行结果:

[-3, 1, -4, -2, 3, -3, -1, -2, 0, null, null, -5, 2, null, null, null, null, null, null, null, null, null, null, null]
[0, 4, -4, -2, 6, -3, -1, -2, 0, null, null, -5, 2, null, null, null, null, null, null, null, null, null, null, null]

Last updated

Was this helpful?