package com.mj.tree; import java.util.Comparator; @SuppressWarnings("unchecked") public class BST extends BinaryTree { private Comparator comparator; public BST() { this(null); } public BST(Comparator comparator) { this.comparator = comparator; } public void add(E element) { elementNotNullCheck(element); // 添加第一个节点 if (root == null) { root = createNode(element, null); size++; // 新添加节点之后的处理 afterAdd(root); return; } // 添加的不是第一个节点 // 找到父节点 Node parent = root; Node node = root; int cmp = 0; do { cmp = compare(element, node.element); parent = node; if (cmp > 0) { node = node.right; } else if (cmp < 0) { node = node.left; } else { // 相等 node.element = element; return; } } while (node != null); // 看看插入到父节点的哪个位置 Node newNode = createNode(element, parent); if (cmp > 0) { parent.right = newNode; } else { parent.left = newNode; } size++; // 新添加节点之后的处理 afterAdd(newNode); } /** * 添加node之后的调整 * @param node 新添加的节点 */ protected void afterAdd(Node node) { } /** * 删除node之后的调整 * @param node 被删除的节点 或者 用以取代被删除节点的子节点(当被删除节点的度为1) */ protected void afterRemove(Node node) { } public void remove(E element) { remove(node(element)); } public boolean contains(E element) { return node(element) != null; } private void remove(Node node) { if (node == null) return; size--; if (node.hasTwoChildren()) { // 度为2的节点 // 找到后继节点 Node s = successor(node); // 用后继节点的值覆盖度为2的节点的值 node.element = s.element; // 删除后继节点 node = s; } // 删除node节点(node的度必然是1或者0) Node replacement = node.left != null ? node.left : node.right; if (replacement != null) { // node是度为1的节点 // 更改parent replacement.parent = node.parent; // 更改parent的left、right的指向 if (node.parent == null) { // node是度为1的节点并且是根节点 root = replacement; } else if (node == node.parent.left) { node.parent.left = replacement; } else { // node == node.parent.right node.parent.right = replacement; } // 删除节点之后的处理 afterRemove(replacement); } else if (node.parent == null) { // node是叶子节点并且是根节点 root = null; // 删除节点之后的处理 afterRemove(node); } else { // node是叶子节点,但不是根节点 if (node == node.parent.left) { node.parent.left = null; } else { // node == node.parent.right node.parent.right = null; } // 删除节点之后的处理 afterRemove(node); } } private Node node(E element) { Node node = root; while (node != null) { int cmp = compare(element, node.element); if (cmp == 0) return node; if (cmp > 0) { node = node.right; } else { // cmp < 0 node = node.left; } } return null; } /** * @return 返回值等于0,代表e1和e2相等;返回值大于0,代表e1大于e2;返回值小于于0,代表e1小于e2 */ private int compare(E e1, E e2) { if (comparator != null) { return comparator.compare(e1, e2); } return ((Comparable)e1).compareTo(e2); } private void elementNotNullCheck(E element) { if (element == null) { throw new IllegalArgumentException("element must not be null"); } } }