用 Rust 实现一个 AVL 树

什么是 AVL 树以及 AVL 树的一些原理就不介绍了,我们直接讲如何用 Rust 来实现它。

结构体及 trait 定义

参考一般 AVL 树的实现,我定义了一个结构体 TreeNode 以及一个类型别名 AvlTreeNode

pub type AvlTreeNode<T> = Option<Box<TreeNode<T>>>;

#[derive(Clone, Debug)]
pub struct TreeNode<T: PartialOrd> {
    val: T,
    height: i32,
    left: AvlTreeNode<T>,
    right: AvlTreeNode<T>,
}

使用类型别名可以了少打几个字母以及好看一点。

还定义了一个 trait AvlTree

pub trait AvlTree<T: PartialOrd> {
    fn new(val: T) -> Self;
    fn height(&self) -> i32;
    fn insert(&mut self, val: T);
    fn delete(&mut self, val: T) -> Self;
}

因为懒,所以这里没有实现查找之类的操作。

我们知道要使 AVL 树才插入和删除后仍然保持平衡需要进行一些旋转操作,但是我不想把这些操作作为公共方法暴露出去,当然一般情况下我们可以为结构体实现私有方法,但是我们这里要为 AvlTreeNode 实现 AvlTree,而 AvlTreeNode 是一个 Option 的类型别名,不能直接为它实现方法。

所以我又定义了一个私有 trait __AvlTree 用来为 AvlTreeNode 实现这些私有方法:

trait __AvlTree<T: PartialOrd> {
    fn rotate_ll(&mut self);
    fn rotate_rr(&mut self);
    fn rotate_lr(&mut self);
    fn rotate_rl(&mut self);
    fn update_height(&mut self);
    fn balance_factor(&self) -> i32;
    fn do_insert(&mut self, val: T) -> InnerResult;
    fn do_delete(&mut self, val: &mut DeleteValue<T>) -> InnerResult;
}

其中 do_insertdo_delete 是用来在内部做递归操作的,它们返回操作的结果,但是这些结果仅仅作为内部结果:

enum InnerResult {
    Left,     //在左子树完成插入
    Right,    //在右子树完成插入
    Unknown,  //树的平衡性未知
    Balanced, //树已确定平衡
}

之所以这样做是为了在 do_insert 或者 do_delete 返回的时候直接拿到递归调用的一些有用信息,而不必去做多余的比较和检查。

AvlTreeNode 实现 AvlTree

四种旋转

首先是左左情况,假设我们现在已经有一颗 AVL 树如下,当我们再向树的左子树的左子树插入 0(或者 2)时,这颗树就会失去平衡:

insert_0.png

我们可以使用下面的步骤完成一次单旋转,使其恢复平衡:

left-left-rotate_0.png

left-left-rotate_1.png

对应的 Rust 实现如下:

// use core::mem::swap;

// impl<T: PartialOrd> __AvlTree<T> for AvlTreeNode<T>
fn rotate_ll(&mut self) {
    match self {
        Some(root) => {
            let left = &mut root.left.take();
            match left {
                Some(x) => {
                    swap(&mut root.left, &mut x.right);
                    self.update_height();
                    swap(self, &mut x.right);
                    swap(self, left);
                    self.update_height();
                }
                None => unreachable!(),
            }
        }
        None => unreachable!(),
    }
}

fn update_height(&mut self) {
    match self {
        None => return,
        Some(x) => x.height = max(x.left.height(), x.right.height()) + 1,
    }
}

// impl<T: PartialOrd> AvlTree<T> for AvlTreeNode<T> {
fn height(&self) -> i32 {
    match self {
        None => 0,
        Some(x) => x.height,
    }
}

对于右右的情况它与左左情况是左右对称的,只要把 rotate_ll 中的所有 leftright 互换就好了。

左右的情况,先对要平衡的节点的左子树做右右情况的单旋转,再对自身做左左情况的单旋转:

// impl<T: PartialOrd> __AvlTree<T> for AvlTreeNode<T>
fn rotate_lr(&mut self) {
    match self {
        Some(root) => {
            root.left.rotate_rr();
            self.rotate_ll();
        }
        None => unreachable!(),
    }
}

同理,右左的情况先对右子树做左左情况的单旋转,再对自身做右右情况的单旋转:

// impl<T: PartialOrd> __AvlTree<T> for AvlTreeNode<T>
fn rotate_rl(&mut self) {
    match self {
        Some(root) => {
            root.right.rotate_ll();
            self.rotate_rr();
        }
        None => unreachable!(),
    }
}

实现插入操作

// impl<T: PartialOrd> __AvlTree<T> for AvlTreeNode<T>
fn do_insert(&mut self, val: T) -> InnerResult {
    match self {
        //直接插入新节点
        None => {
            *self = Self::new(val);
            Unknown
        }
        //递归插入
        Some(root) => {
            //重复数据
            if val == root.val {
                Balanced
            //进入左子树递归插入
            } else if val < root.val {
                match root.left.do_insert(val) {
                    Balanced => Balanced,
                    x => {
                        if self.balance_factor() == 2 {
                            match x {
                                Left => self.rotate_ll(),
                                Right => self.rotate_lr(),
                                _ => unreachable!(),
                            }
                            Balanced
                        } else {
                            if self.height() == {
                                self.update_height();
                                self.height()
                            } {
                                Balanced
                            } else {
                                Left
                            }
                        }
                    }
                }
            //进入右子树递归插入
            } else {
                match root.right.do_insert(val) {
                    Balanced => Balanced,
                    x => {
                        if self.balance_factor() == -2 {
                            match x {
                                Left => self.rotate_rl(),
                                Right => self.rotate_rr(),
                                _ => unreachable!(),
                            }
                            Balanced
                        } else {
                            if self.height() == {
                                self.update_height();
                                self.height()
                            } {
                                Balanced
                            } else {
                                Right
                            }
                        }
                    }
                }
            }
        }
    }
}

fn balance_factor(&self) -> i32 {
    match self {
        None => 0,
        Some(x) => x.left.height() - x.right.height(),
    }
}

// impl<T: PartialOrd> AvlTree<T> for AvlTreeNode<T>
fn new(val: T) -> Self {
    Some(Box::new(TreeNode {
        val,
        height: 1,
        left: None,
        right: None,
    }))
}

fn insert(&mut self, val: T) {
    self.do_insert(val);
}

这里我将空节点的高度定为了 0,也就是说一个没有子节点的节点(叶子节点)的高度是 1,同时 balance_factor 返回的平衡因子是左子树的高度减右子树的高度。

我们来细看在左子树递归插入这一段:

//进入左子树递归插入
match root.left.do_insert(val) {
    //已经确定整颗树处于平衡
    Balanced => Balanced,
    x => {
        //失去平衡,由于本次是在左子树递归插入,所以失去平衡时平衡因子必为 2
        if self.balance_factor() == 2 {
            match x {
                //在左子树的左子树完成了插入
                Left => self.rotate_ll(),
                //在左子树的右子树完成了插入
                Right => self.rotate_lr(),
                //返回 `Unknown` 的时候当前节点必定是平衡的
                _ => unreachable!(),
            }
            //已经完成了旋转,可以确定整颗树已经平衡
            Balanced
        } else {
            //如果当前节点高度未发生变化,可以确定整颗树已平衡
            if self.height() == {
                self.update_height();
                self.height()
            } {
                Balanced
            //否则整颗树的平衡性未知,告诉上一层调用插入是在它的左子树完成的
            } else {
                Left
            }
        }
    }
}

root.left 的递归插入如果新建了节点会返回 Unknown,表示此时树的平衡性未知,上一层调用拿到这个返回值,会检查平衡因子,此时平衡因子必然是 0 或者 ±1,所以又会检查节点高度是否发生变化,如果没有发生变化,说明树已平衡,否则再上一层返回 Left, 告诉再上一层拿插入是在在它的左子树完成的。

进入右子树递归插入的逻辑同理。

实现删除操作

删除操作比插入稍微复杂一点,但原理也还算简单,首先查找需要删除的节点,如果找不到,那当然啥都不需要做,返回一个 None,如果找到这个节点并且这个节点是叶子节点,那我们可以直接删除它,然后检查树的平衡性并重新平衡就可以了,先管这种情形叫情形一。

如果要删除的这个节点只有一个子树,我们可以用它唯一的子节点替换它,再重新平衡树,这是情形二。显然它可以和情形一合并,因为没有子节点实际上可以看作有两个空节点,这个时候我可以用它的任意一个空的子节点替换它。

最后是要删除的节点同时拥有左右子树的情况,这个时候我们实际上有两个选择,即用它左子树的最大节点或者它右子树的最小节点替换它,所以这个问题可以转换为删除并取出左子树的最大节点或右子树的最小节点,然后将取出的节点与真正要删除的节点进行交换。实际操作的时候,为了尽可能小的破坏树的平衡性,我们可以选择两个子树中高度更高的那颗进行删除操作,并且做交换的时候可以只交换节点保存的值,这是情形三。

可以看出,情形三中删除子树的最小或最大节点实际上对应前面的情形一或二。

随之而来的问题就是如何删除并取出子树的最小或者最大节点,最容易想到的方法就是先查找其最小或最大值,然后再用 do_delete 去删除,但是很明显这种方法需要做两次查找。我希望这步操作可以只查找一次,所以又定义了一个枚举:

enum DeleteValue<T: PartialOrd> {
    Min,                 //匹配最小节点
    Max,                 //匹配最大节点
    Val(T),              //匹配给定值
    Del(AvlTreeNode<T>), //返回被删除节点
}

这个枚举可以直接表示树中的最小值和最大值,以及常规的值,顺便还可以用来返回被删除的节点。同时为了方便操作,我为它实现了 PartialOrd<Box<TreeNode<T>>>,这样它甚至可以和节点直接比较大小而不用取节点的 val 字段。

impl<T: PartialOrd> PartialEq<Box<TreeNode<T>>> for DeleteValue<T> {
    fn eq(&self, other: &Box<TreeNode<T>>) -> bool {
        match self {
            Min => other.left.is_none(),
            Max => other.right.is_none(),
            Val(v) => v == &other.val,
            _ => false,
        }
    }
}

impl<T: PartialOrd> PartialOrd<Box<TreeNode<T>>> for DeleteValue<T> {
    fn partial_cmp(&self, other: &Box<TreeNode<T>>) -> Option<Ordering> {
        match self {
            Min => Some(Ordering::Less),
            Max => Some(Ordering::Greater),
            Val(v) => v.partial_cmp(&other.val),
            _ => None,
        }
    }
}

基于上面的原理,最后 do_deletedelete 的实现如下:

// impl<T: PartialOrd> __AvlTree<T> for AvlTreeNode<T>
fn do_delete(&mut self, val: &mut DeleteValue<T>) -> InnerResult {
    match self {
        None => {
            *val = Del(None);
            Balanced
        }
        Some(root) => {
            //保存当前节点高度
            let height = root.height;

            //删除当前节点
            if val == root {
                if root.left.is_some() {
                    //左右子树均非空
                    if root.right.is_some() {
                        if root.left.height() > root.right.height() {
                            *val = Max;
                            root.left.do_delete(val);
                            match val {
                                Del(Some(x)) => {
                                    swap(&mut root.val, &mut x.val);
                                }
                                _ => unreachable!(),
                            }
                        } else {
                            *val = Min;
                            root.right.do_delete(val);
                            match val {
                                Del(Some(x)) => {
                                    swap(&mut root.val, &mut x.val);
                                }
                                _ => unreachable!(),
                            }
                        }
                    //左子树非空,右子树为空
                    } else {
                        let mut left = root.left.take();
                        swap(self, &mut left);
                        *val = Del(left);
                    }
                //左子树为空,右子树非空或为空
                } else {
                    let mut right = root.right.take();
                    swap(self, &mut right);
                    *val = Del(right);
                }
                self.update_height();
            //进入左子树递归删除
            } else if val < root {
                match root.left.do_delete(val) {
                    Balanced => return Balanced,
                    Unknown => {
                        if self.balance_factor() == -2 {
                            let right = self.as_ref().unwrap().right.as_ref().unwrap();
                            if right.left.height() > right.right.height() {
                                self.rotate_rl();
                            } else {
                                self.rotate_rr();
                            }
                        } else {
                            self.update_height();
                        }
                    }
                    _ => unreachable!(),
                }
            //进入右子树递归删除
            } else {
                match root.right.do_delete(val) {
                    Balanced => return Balanced,
                    Unknown => {
                        if self.balance_factor() == 2 {
                            let left = self.as_ref().unwrap().left.as_ref().unwrap();
                            if left.left.height() >= left.right.height() {
                                self.rotate_ll();
                            } else {
                                self.rotate_lr();
                            }
                        } else {
                            self.update_height();
                        }
                    }
                    _ => unreachable!(),
                }
            }

            //如果节点高度未发生变化,则可确定树已平衡
            if height == self.height() {
                Balanced
            } else {
                Unknown
            }
        }
    }
}

// impl<T: PartialOrd> AvlTree<T> for AvlTreeNode<T>
fn delete(&mut self, val: T) -> Self {
    let mut val = Val(val);
    self.do_delete(&mut val);
    match val {
        Del(x) => x,
        _ => unreachable!(),
    }
}

完整的实现在 https://github.com/nanpuyue/avl_tree

标签: Rust, 数据结构

添加新评论