7.6. Red-Black Trees (code only)¶
For information about red-black trees, see the lecture handout “2-3 trees and red-black trees” under Theme 4 (Search trees).
Here is an implementation of red-black trees:
// A dictionary implemented using an red-black tree.
public class RedBlackMap<K extends Comparable<K>, V> implements Map<K, V> {
Node root = null; // The root of the red black tree.
int treeSize = 0; // The size of the tree.
// A node in an red-black tree.
class Node {
K key;
V value;
Node left;
Node right;
boolean isRed;
Node(boolean isRed, K key, V value, Node left, Node right) {
this.key = key;
this.value = value;
this.left = left;
this.right = right;
this.isRed = isRed;
}
}
// Check if a node is red. 'null' is always black.
boolean isRed(Node node) {
if (node == null) return false;
return node.isRed;
}
// Check that the invariant holds.
void checkInvariant() {
if (isRed(root))
throw new AssertionError("red root");
ArrayList<K> keys = new ArrayList<>();
iteratorHelper(root, keys);
if (keys.size() != treeSize)
throw new AssertionError("wrong tree size");
checkInvariantHelper(root, null, null);
}
// Recursive helper method for 'check_invariant'.
// Checks that the node is the root of a valid red-black tree, and that
// all keys k satisfy lo < k < hi. The test lo < k is skipped
// if lo is None, and k < hi is skipped if hi is None.
// Returns the "black height" of the tree.
int checkInvariantHelper(Node node, K lo, K hi) {
if (node == null) return 0;
if (lo != null && node.key.compareTo(lo) <= 0)
throw new AssertionError("key too small");
if (hi != null && node.key.compareTo(hi) >= 0)
throw new AssertionError("key too big");
if (isRed(node.right))
throw new AssertionError("red right child");
if (isRed(node) && isRed(node.left))
throw new AssertionError("red node with red left child");
// Keys in the left subtree should be < node.key
// Keys in the right subtree should be > node.key
int h1 = checkInvariantHelper(node.left, lo, node.key);
int h2 = checkInvariantHelper(node.right, node.key, hi);
if (h1 != h2)
throw new AssertionError("unbalanced tree");
return h1 + (isRed(node) ? 0 : 1);
}
// Return true if there are no keys.
public boolean isEmpty() {
return root == null;
}
// Return the number of keys.
public int size() {
return treeSize;
}
// Return true if the key has an associated value.
public boolean containsKey(K key) {
return get(key) != null;
}
// Look up a key.
public V get(K key) {
return getHelper(root, key);
}
// Recursive helper method for 'get'.
V getHelper(Node node, K key) {
if (node == null)
return null;
if (node.key.compareTo(key) > 0)
return getHelper(node.left, key);
else if (node.key.compareTo(key) < 0)
return getHelper(node.right, key);
else // node.key == key
return node.value;
}
// Add a key-value pair, or update the value associated with an existing key.
// Returns the previous value associated with the key,
// or null if the key wasn't previously present.
public V put(K key, V value) {
root = putHelper(root, key, value);
if (isRed(root))
root.isRed = false;
if (oldValue == null)
treeSize++;
return oldValue;
}
// Recursive helper method for 'put'.
Node putHelper(Node node, K key, V value) {
if (node == null) {
oldValue = null;
return new Node(true, key, value, null, null);
} else if (node.key.compareTo(key) > 0) {
node.left = putHelper(node.left, key, value);
} else if (node.key.compareTo(key) < 0) {
node.right = putHelper(node.right, key, value);
} else { // node.key == key
oldValue = node.value;
node.value = value;
}
return rebalance(node);
}
// Used by put, remove, putHelper and removeHelper,
// in order to return the value previously stored in the node.
private V oldValue;
// Delete a key.
public V remove(K key) {
throw new UnsupportedOperationException("remove is not implemented yet");
}
// Repair the red-black invariant by rebalancing the node.
Node rebalance(Node node) {
if (node == null) return node;
// Skew
if (isRed(node.right))
node = rotateLeft(node);
// Split part 1
if (isRed(node.left) && isRed(node.left.left))
node = rotateRight(node);
// Split part 2
if (isRed(node.left) && isRed(node.right)) {
node.left.isRed = false;
node.right.isRed = false;
node.isRed = true;
}
return node;
}
Node rotateLeft(Node node) {
// Left rotation.
//
// x y
// / \ / \
// A y ===> x C
// / \ / \
// B C A B
//
// Variables are named according to the picture above.
Node x = node;
Node A = x.left;
Node y = x.right;
Node B = y.left;
Node C = y.right;
// We also swap x's and y's colours (e.g. if x was black before, then y will be black afterwards).
return new Node(x.isRed, y.key, y.value, new Node(y.isRed, x.key, x.value, A, B), C);
}
Node rotateRight(Node node) {
// Right rotation.
//
// x y
// / \ / \
// y C ===> A x
// / \ / \
// A B B C
//
// Variables are named according to the picture above.
Node x = node;
Node y = x.left;
Node A = y.left;
Node B = y.right;
Node C = x.right;
// We also swap x's and y's colours (e.g. if x was black before, then y will be black afterwards).
return new Node(x.isRed, y.key, y.value, A, new Node(y.isRed, x.key, x.value, B, C));
}
// Iterate through all keys.
// This is called when the user writes 'for (K key: bst) { ... }.'
public Iterator<K> iterator() {
// The easiest way to solve this is to add all keys to an
// ArrayList, then iterate through that.
ArrayList<K> keys = new ArrayList<>();
iteratorHelper(root, keys);
return keys.iterator();
}
// Recursive helper method for 'iterator'
void iteratorHelper(Node node, ArrayList<K> keys) {
if (node == null) return;
iteratorHelper(node.left, keys);
keys.add(node.key);
iteratorHelper(node.right, keys);
}
}
# Python does not have internal classes, so we have to make the tree node class standalone.
class Node:
"""A node in a red-black tree."""
def __init__(self, is_red, key, value, left = None, right = None):
self._is_red = is_red
self.key = key
self.value = value
self.left = left
self.right = right
def is_red(self):
if self is None:
return False
else:
return self._is_red
class RedBlackMap(Map):
"""A dictionary implemented using a binary search tree."""
def __init__(self):
self.root = None
self.treeSize = 0
def check_invariant(self):
"""Check that the invariant holds."""
assert not Node.is_red(self.root), "red root"
keys = list(self)
assert len(keys) == self.treeSize, "wrong tree size"
self.check_invariant_helper(self.root, None, None)
@staticmethod
def check_invariant_helper(node, lo, hi):
"""Recurive helper method for 'check_invariant'.
Checks that the node is the root of a valid red-black tree, and that
all keys k satisfy lo < k < hi. The test lo < k is skipped
if lo is None, and k < hi is skipped if hi is None.
Returns the "black height" of the tree."""
if node is None: return 0
assert lo is None or node.key > lo, "key too small"
assert hi is None or node.key < hi, "key too big"
assert not Node.is_red(node.right), "red right child"
assert not (Node.is_red(node) and Node.is_red(node.left)), "red node with red left child"
# Keys in the left subtree should be < node.key
# Keys in the right subtree should be > node.key
h1 = RedBlackMap.check_invariant_helper(node.left, lo, node.key)
h2 = RedBlackMap.check_invariant_helper(node.right, node.key, hi)
assert h1 == h2, "unbalanced tree"
return h1 + (0 if Node.is_red(node) else 1)
def isEmpty(self):
"""Return true if there are no keys."""
return self.root is None
def size(self):
"""Return the number of keys."""
return self.treeSize
def containsKey(self, key):
"""Return true if the key has an associated value."""
return self.get(key) is not None
def get(self, key):
"""Look up a key."""
return self.get_helper(self.root, key)
@staticmethod
def get_helper(node, key):
"""Helper method for 'get'."""
if node is None:
return None
elif node.key > key:
return RedBlackMap.get_helper(node.left, key)
elif node.key < key:
return RedBlackMap.get_helper(node.right, key)
else:
return node.value
def put(self, key, value):
"""Add a key-value pair, or update the value associated with an existing key.
Returns the value previously associated with the key,
or None if the key was not present."""
self.root, old_value = self.put_helper(self.root, key, value)
if Node.is_red(self.root):
self.root._is_red = False
if old_value is None:
self.treeSize += 1
return old_value
@staticmethod
def put_helper(node, key, value):
"""Recursive helper method for 'put'.
Returns the updated node, and the value previously associated with the key."""
if node is None:
return Node(True, key, value, None, None), None
elif node.key > key:
node.left, old_value = RedBlackMap.put_helper(node.left, key, value)
elif node.key < key:
node.right, old_value = RedBlackMap.put_helper(node.right, key, value)
else: # node.key == key
old_value = node.value
node.value = value
return RedBlackMap.rebalance(node), old_value
def remove(self, key):
"""Delete a key.
Not implemented yet!"""
raise NotImplementedError("remove is not implemented yet")
@staticmethod
def rebalance(node):
if node is None: return None
# Skew
if Node.is_red(node.right):
node = RedBlackMap.rotate_left(node)
# Split part 1
if Node.is_red(node.left) and Node.is_red(node.left.left):
node = RedBlackMap.rotate_right(node)
# Split part 2
if Node.is_red(node.left) and Node.is_red(node.right):
node.left._is_red = False
node.right._is_red = False
node._is_red = True
return node
@staticmethod
def rotate_left(node):
"""
Left rotation.
x y
/ \ / \
A y ===> x C
/ \ / \
B C A B
"""
# Variables are named according to the picture above.
x = node
A = x.left
y = x.right
B = y.left
C = y.right
# We also swap x's and y's colours
# (e.g. if x was black before, then y will be black afterwards).
return Node(is_red = x.is_red(), key = y.key, value = y.value,
left =
Node(is_red = y.is_red(), key = x.key, value = x.value,
left = A, right = B),
right = C)
@staticmethod
def rotate_right(node):
"""
Right rotation.
x y
/ \ / \
y C ===> A x
/ \ / \
A B B C
"""
# Variables are named according to the picture above.
x = node
y = x.left
A = y.left
B = y.right
C = x.right
# We also swap x's and y's colours
# (e.g. if x was black before, then y will be black afterwards).
return Node(is_red = x.is_red(), key = y.key, value = y.value,
left = A,
right =
Node(is_red = y.is_red(), key = x.key, value = x.value,
left = B, right = C))
def __iter__(self):
"""Iterate through all keys.
This is called when the user writes 'for key in bst: ...'."""
return self.iter_helper(self.root)
@staticmethod
def iter_helper(node):
"""Helper method for '__iter__'."""
# This method is a generator:
# https://docs.python.org/3/howto/functional.html#generators
# Generators are an easy way to make iterators
if node is None:
return
else:
for key in RedBlackMap.iter_helper(node.left):
yield key
yield node.key
for key in RedBlackMap.iter_helper(node.right):
yield key
def __getitem__(self, key):
"""This is called when the user writes 'x = bst[key]'."""
return self.get(key)
def __setitem__(self, key, value):
"""This is called when the user writes 'bst[key] = value'."""
self.put(key, value)
def __contains__(self, key):
"""This is called when the user writes 'key in bst'."""
return self.containsKey(key)
def __delitem__(self, key):
"""This is called when the user writes 'del bst[key]'."""
self.remove(key)