Close Window

Data Structures and Algorithms

Chapter 7 Search Trees

Show Source |    | About   «  7.5. The AVL Tree   ::   Contents   ::   7.7. The Splay Tree (optional)  »

7.6. Red-Black Trees (code only)

For information about red-black trees, see the lecture handout “2-3 trees and red-black trees” under Theme 4 (Search trees).

Here is an implementation of (left-leaning) red-black trees, first in Java and then in Python:

// A dictionary implemented using an red-black tree.
public class RedBlackMap<K extends Comparable<K>, V> implements Map<K, V> {
    Node root = null;   // The root of the red black tree.
    int treeSize = 0;   // The size of the tree.

    // A node in an red-black tree.
    class Node {
        K key;
        V value;
        Node left;
        Node right;
        boolean isRed;

        Node(boolean isRed, K key, V value, Node left, Node right) {
            this.key = key;
            this.value = value;
            this.left = left;
            this.right = right;
            this.isRed = isRed;

    // Check if a node is red. 'null' is always black.
    boolean isRed(Node node) {
        if (node == null) return false;
        return node.isRed;

    // Check that the invariant holds.
    void checkInvariant() {
        if (isRed(root))
            throw new AssertionError("red root");
        ArrayList<K> keys = new ArrayList<>();
        iteratorHelper(root, keys);
        if (keys.size() != treeSize)
            throw new AssertionError("wrong tree size");
        checkInvariantHelper(root, null, null);

    // Recursive helper method for 'check_invariant'.
    // Checks that the node is the root of a valid red-black tree, and that
    // all keys k satisfy lo < k < hi. The test lo < k is skipped
    // if lo is None, and k < hi is skipped if hi is None.
    // Returns the "black height" of the tree.
    int checkInvariantHelper(Node node, K lo, K hi) {
        if (node == null) return 0;

        if (lo != null && node.key.compareTo(lo) <= 0)
            throw new AssertionError("key too small");
        if (hi != null && node.key.compareTo(hi) >= 0)
            throw new AssertionError("key too big");

        if (isRed(node.right))
            throw new AssertionError("red right child");
        if (isRed(node) && isRed(node.left))
            throw new AssertionError("red node with red left child");

        // Keys in the left subtree should be < node.key
        // Keys in the right subtree should be > node.key
        int h1 = checkInvariantHelper(node.left, lo, node.key);
        int h2 = checkInvariantHelper(node.right, node.key, hi);
        if (h1 != h2)
            throw new AssertionError("unbalanced tree");

        return h1 + (isRed(node) ? 0 : 1);

    // Return true if there are no keys.
    public boolean isEmpty() {
        return root == null;

    // Return the number of keys.
    public int size() {
        return treeSize;

    // Return true if the key has an associated value.
    public boolean containsKey(K key) {
        return get(key) != null;

    // Look up a key.
    public V get(K key) {
        return getHelper(root, key);

    // Recursive helper method for 'get'.
    V getHelper(Node node, K key) {
        if (node == null)
            return null;
        if (node.key.compareTo(key) > 0)
            return getHelper(node.left, key);
        else if (node.key.compareTo(key) < 0)
            return getHelper(node.right, key);
        else // node.key == key
            return node.value;

    // Add a key-value pair, or update the value associated with an existing key.
    // Returns the previous value associated with the key,
    // or null if the key wasn't previously present.
    public V put(K key, V value) {
        root = putHelper(root, key, value);
        if (isRed(root))
            root.isRed = false;
        if (oldValue == null)
        return oldValue;

    // Recursive helper method for 'put'.
    Node putHelper(Node node, K key, V value) {
        if (node == null) {
            oldValue = null;
            return new Node(true, key, value, null, null);
        } else if (node.key.compareTo(key) > 0) {
            node.left = putHelper(node.left, key, value);
        } else if (node.key.compareTo(key) < 0) {
            node.right = putHelper(node.right, key, value);
        } else { // node.key == key
            oldValue = node.value;
            node.value = value;
        return rebalance(node);

    // Used by put, remove, putHelper and removeHelper,
    // in order to return the value previously stored in the node.
    private V oldValue;

    // Delete a key.
    public V remove(K key) {
        throw new UnsupportedOperationException("remove is not implemented yet");

    // Repair the red-black invariant by rebalancing the node.
    Node rebalance(Node node) {
        if (node == null) return node;

        // Skew
        if (isRed(node.right))
            node = rotateLeft(node);

        // Split part 1
        if (isRed(node.left) && isRed(node.left.left))
            node = rotateRight(node);

        // Split part 2
        if (isRed(node.left) && isRed(node.right)) {
            node.left.isRed = false;
            node.right.isRed = false;
            node.isRed = true;
        return node;

    Node rotateLeft(Node node) {
        // Left rotation.
        //    x                 y
        //   / \               / \
        //  A   y     ===>    x   C
        //     / \           / \
        //    B   C         A   B
        // Variables are named according to the picture above.
        Node x = node;
        Node A = x.left;
        Node y = x.right;
        Node B = y.left;
        Node C = y.right;
        // We also swap x's and y's colours (e.g. if x was black before, then y will be black afterwards).
        return new Node(x.isRed, y.key, y.value, new Node(y.isRed, x.key, x.value, A, B), C);

    Node rotateRight(Node node) {
        // Right rotation.
        //      x              y
        //     / \            / \
        //    y   C   ===>   A   x
        //   / \                / \
        //  A   B              B   C
        // Variables are named according to the picture above.
        Node x = node;
        Node y = x.left;
        Node A = y.left;
        Node B = y.right;
        Node C = x.right;
        // We also swap x's and y's colours (e.g. if x was black before, then y will be black afterwards).
        return new Node(x.isRed, y.key, y.value, A, new Node(y.isRed, x.key, x.value, B, C));

    // Iterate through all keys.
    // This is called when the user writes 'for (K key: bst) { ... }.'
    public Iterator<K> iterator() {
        // The easiest way to solve this is to add all keys to an
        // ArrayList, then iterate through that.
        ArrayList<K> keys = new ArrayList<>();
        iteratorHelper(root, keys);
        return keys.iterator();

    // Recursive helper method for 'iterator'
    void iteratorHelper(Node node, ArrayList<K> keys) {
        if (node == null) return;
        iteratorHelper(node.left, keys);
        iteratorHelper(node.right, keys);
# Python has no Java-style inner classes (a nested class cannot see the
# enclosing instance), so the tree node class is defined standalone here.
class Node:
    """A node in a red-black tree."""

    def __init__(self, is_red, key, value, left = None, right = None):
        # Colour of the node (True = red); read via Node.is_red.
        self._is_red = is_red
        self.key = key
        self.value = value
        self.left = left
        self.right = right

    def is_red(self):
        """Return True if the node is red, False if it is black.

        This is also called as Node.is_red(node) where node may be None;
        None (an empty subtree) always counts as black."""
        if self is None:
            return False
        return self._is_red

class RedBlackMap(Map):
    """A dictionary implemented using a left-leaning red-black tree.

    The tree maintains these invariants (verified by check_invariant):
      - the root is black;
      - no node has a red right child;
      - no red node has a red left child;
      - both subtrees of every node have the same "black height"
        (the number of black nodes on any path down to a missing child).

    None is used to mean "no value" (see get, containsKey and put),
    so None should not be stored as a value.
    """

    def __init__(self):
        self.root = None    # The root Node, or None for an empty tree.
        self.treeSize = 0   # The number of keys in the tree.

    def check_invariant(self):
        """Check that the invariant holds; raise AssertionError if it does not."""
        assert not Node.is_red(self.root), "red root"
        keys = list(self)
        assert len(keys) == self.treeSize, "wrong tree size"
        self.check_invariant_helper(self.root, None, None)

    @staticmethod
    def check_invariant_helper(node, lo, hi):
        """Recursive helper method for 'check_invariant'.
        Checks that the node is the root of a valid red-black tree, and that
        all keys k satisfy lo < k < hi. The test lo < k is skipped
        if lo is None, and k < hi is skipped if hi is None.
        Returns the "black height" of the tree."""

        if node is None: return 0

        assert lo is None or node.key > lo, "key too small"
        assert hi is None or node.key < hi, "key too big"

        assert not Node.is_red(node.right), "red right child"

        assert not (Node.is_red(node) and Node.is_red(node.left)), "red node with red left child"

        # Keys in the left subtree should be < node.key
        # Keys in the right subtree should be > node.key
        h1 = RedBlackMap.check_invariant_helper(node.left, lo, node.key)
        h2 = RedBlackMap.check_invariant_helper(node.right, node.key, hi)
        assert h1 == h2, "unbalanced tree"

        # Red nodes do not contribute to the black height.
        return h1 + (0 if Node.is_red(node) else 1)

    def isEmpty(self):
        """Return true if there are no keys."""
        return self.root is None

    def size(self):
        """Return the number of keys."""
        return self.treeSize

    def containsKey(self, key):
        """Return true if the key has an associated (non-None) value."""
        return self.get(key) is not None

    def get(self, key):
        """Look up a key. Returns None if the key is not present."""
        return self.get_helper(self.root, key)

    @staticmethod
    def get_helper(node, key):
        """Recursive helper method for 'get'.
        Returns the value stored under key in the subtree rooted at node,
        or None if the key is not present."""
        if node is None:
            return None
        elif node.key > key:
            return RedBlackMap.get_helper(node.left, key)
        elif node.key < key:
            return RedBlackMap.get_helper(node.right, key)
        else: # node.key == key
            return node.value

    def put(self, key, value):
        """Add a key-value pair, or update the value associated with an existing key.
        Returns the value previously associated with the key,
        or None if the key was not present."""
        self.root, old_value = self.put_helper(self.root, key, value)
        # The root of a red-black tree is always black.
        if Node.is_red(self.root):
            self.root._is_red = False
        # put_helper returns old_value None exactly when it created a new node.
        if old_value is None:
            self.treeSize += 1
        return old_value

    @staticmethod
    def put_helper(node, key, value):
        """Recursive helper method for 'put'.
        Returns the updated node, and the value previously associated with the key."""
        if node is None:
            # New nodes are always inserted red; rebalance fixes any violations.
            return Node(True, key, value, None, None), None
        elif node.key > key:
            node.left, old_value = RedBlackMap.put_helper(node.left, key, value)
        elif node.key < key:
            node.right, old_value = RedBlackMap.put_helper(node.right, key, value)
        else: # node.key == key
            old_value = node.value
            node.value = value
        return RedBlackMap.rebalance(node), old_value

    def remove(self, key):
        """Delete a key.
        Not implemented yet!"""
        raise NotImplementedError("remove is not implemented yet")

    @staticmethod
    def rebalance(node):
        """Repair the red-black invariant by rebalancing the node.
        Returns the (possibly different) root of the rebalanced subtree."""
        if node is None: return None
        # Skew: a red right child is rotated to the left.
        if Node.is_red(node.right):
            node = RedBlackMap.rotate_left(node)

        # Split part 1: two red left links in a row are rotated to the right.
        if Node.is_red(node.left) and Node.is_red(node.left.left):
            node = RedBlackMap.rotate_right(node)

        # Split part 2: two red children are recoloured black and the red is passed up.
        if Node.is_red(node.left) and Node.is_red(node.right):
            node.left._is_red = False
            node.right._is_red = False
            node._is_red = True

        return node

    @staticmethod
    def rotate_left(node):
        r"""Left rotation.

           x                 y
          / \               / \
         A   y     ===>    x   C
            / \           / \
           B   C         A   B
        """
        # Variables are named according to the picture above.
        x = node
        A = x.left
        y = x.right
        B = y.left
        C = y.right

        # We also swap x's and y's colours
        # (e.g. if x was black before, then y will be black afterwards).
        return Node(is_red = x.is_red(), key = y.key, value = y.value,
                    left =
                        Node(is_red = y.is_red(), key = x.key, value = x.value,
                             left = A, right = B),
                    right = C)

    @staticmethod
    def rotate_right(node):
        r"""Right rotation.

             x              y
            / \            / \
           y   C   ===>   A   x
          / \                / \
         A   B              B   C
        """
        # Variables are named according to the picture above.
        x = node
        y = x.left
        A = y.left
        B = y.right
        C = x.right

        # We also swap x's and y's colours
        # (e.g. if x was black before, then y will be black afterwards).
        return Node(is_red = x.is_red(), key = y.key, value = y.value,
                    left = A,
                    right =
                        Node(is_red = y.is_red(), key = x.key, value = x.value,
                             left = B, right = C))

    def __iter__(self):
        """Iterate through all keys in ascending order.
        This is called when the user writes 'for key in bst: ...'."""
        return self.iter_helper(self.root)

    @staticmethod
    def iter_helper(node):
        """Recursive helper method for '__iter__': yields the keys of the
        subtree rooted at node using an in-order traversal."""

        # This method is a generator:
        # Generators are an easy way to make iterators
        if node is not None:
            yield from RedBlackMap.iter_helper(node.left)
            yield node.key
            yield from RedBlackMap.iter_helper(node.right)

    def __getitem__(self, key):
        """This is called when the user writes 'x = bst[key]'."""
        return self.get(key)

    def __setitem__(self, key, value):
        """This is called when the user writes 'bst[key] = value'."""
        self.put(key, value)

    def __contains__(self, key):
        """This is called when the user writes 'key in bst'."""
        return self.containsKey(key)

    def __delitem__(self, key):
        """This is called when the user writes 'del bst[key]'."""
        self.remove(key)

   «  7.5. The AVL Tree   ::   Contents   ::   7.7. The Splay Tree (optional)  »

Close Window