%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
from unionfind import *
class GraphNode:
    """A graph vertex whose attributes live in a per-instance ``data`` dict."""

    def __init__(self):
        # Each node gets its own fresh, empty attribute dictionary
        self.data = dict()
def draw_2d_graph(nodes, edges, draw_nodes=True, draw_labels=False, linewidth=2):
    """
    Plot a 2D graph onto the current matplotlib axes.

    Parameters
    ----------
    nodes: list of GraphNode
        Nodes whose ``data`` dict holds 'x' and 'y' coordinates
    edges: list of (i, j, d)
        Edges as index pairs into ``nodes``; the weight d is not used here
    draw_nodes: bool
        If True, draw every node as a black dot
    draw_labels: bool
        If True, write each node's index next to it in red
    linewidth: float
        Line width used for the edges
    """
    plt.gca().set_facecolor((0.9, 0.9, 0.9))
    for (src, dst, _weight) in edges:
        p = nodes[src].data
        q = nodes[dst].data
        plt.plot([p['x'], q['x']], [p['y'], q['y']], linewidth=linewidth)
    # Node coordinates are only touched when something is actually drawn
    if draw_nodes or draw_labels:
        for idx, node in enumerate(nodes):
            if draw_nodes:
                plt.scatter(node.data['x'], node.data['y'], 100, c='k')
            if draw_labels:
                px = node.data['x'] + 0.002
                py = node.data['y'] + 0.002
                plt.text(px, py, "{}".format(idx), zorder=10, c='r', fontsize='xx-large')
def dist_of_edge(e):
    """Return the weight ``d`` of an edge tuple ``(i, j, d)``; used as a sort key."""
    weight = e[2]
    return weight
def get_mst_kruskal(nodes, edges):
    """
    Compute a minimum spanning tree with Kruskal's algorithm.

    Parameters
    ----------
    nodes: list
        Graph nodes; only len(nodes) is used, and edges refer to nodes by
        index 0..len(nodes)-1
    edges: list of (i, j, d)
        Undirected edges between node indices i and j with weight d

    Returns
    -------
    list of (i, j, d)
        The accepted edges, in nondecreasing order of weight. If the graph
        is disconnected, this is a minimum spanning forest.
    """
    edges = sorted(edges, key=dist_of_edge)
    djset = UFFast(len(nodes))
    # A spanning tree on N nodes has exactly N-1 edges; once we have that
    # many, every remaining edge would close a cycle, so we can stop early
    n_needed = len(nodes) - 1
    new_edges = []
    for e in edges:
        i, j, _ = e
        # NOTE(review): UFFast.find(i, j) is assumed (from this usage) to
        # report whether i and j are already in the same component
        if not djset.find(i, j):
            djset.union(i, j)
            new_edges.append(e)
            if len(new_edges) == n_needed:
                break
    return new_edges
For a graph that has $E$ edges and $N$ nodes, the worst-case time complexity of Kruskal's algorithm is $O(E \log E)$. This is because the dominant step is sorting the edges in increasing order of distance with a comparison-based sort. The union-find step is actually quite fast: over all $E$ edges, it takes $O(E \, \alpha(N))$ time, where $\alpha$ is the inverse Ackermann function (thanks to union by rank and path compression). For all practical purposes, $\alpha$ is a constant.
For a Euclidean MST, the nodes correspond to points in the plane and the graph is complete on all $N$ nodes, so there are $N(N-1)/2$, or $O(N^2)$, edges. This means that Kruskal's algorithm takes $O(N^2 \log (N^2)) = O(2 N^2 \log N) = O(N^2 \log N)$ time.
However, we can do better than this for Euclidean MSTs by using the graph that arises from a Delaunay triangulation of the points. The edges of a Delaunay triangulation form a planar graph with $O(N)$ edges, and it can be shown that a subset of the Delaunay edges forms a minimum spanning tree. Therefore, we can narrow down to $O(N)$ edges before applying Kruskal's algorithm. Since the Delaunay triangulation itself can be computed in $O(N \log N)$ time, the total complexity drops to $O(N \log N)$. The code below shows how to do this using SciPy.
from scipy.spatial import Delaunay
def make_delaunay_graph(N):
    """
    Sample N uniform random points in the unit square and build the graph
    formed by the edges of their Delaunay triangulation.

    Parameters
    ----------
    N: int
        Number of points (drawn from numpy's global random state)

    Returns
    -------
    nodes: list of GraphNode
        One node per point, with data {'x': ..., 'y': ...}
    edges: list of (i, j, d)
        Each undirected Delaunay edge exactly once, with i < j and d the
        Euclidean distance between points i and j
    """
    x = np.random.rand(N)
    y = np.random.rand(N)
    nodes = []
    for i in range(N):
        n = GraphNode()
        n.data = {'x': x[i], 'y': y[i]}
        nodes.append(n)
    tri = Delaunay(np.array([x, y]).T).simplices
    edges = set()
    for simplex in tri:
        for k in range(3):
            # An interior edge belongs to two triangles and may be visited as
            # (a, b) in one and (b, a) in the other; ordering the endpoints
            # makes both visits produce the same tuple so the set dedupes it
            i1, i2 = sorted((int(simplex[k]), int(simplex[(k + 1) % 3])))
            d = np.sqrt((x[i1] - x[i2])**2 + (y[i1] - y[i2])**2)
            edges.add((i1, i2, d))
    return nodes, list(edges)
# Demo: 20 random points -- full Delaunay graph (top) vs. its MST (bottom)
np.random.seed(0)
nodes, edges = make_delaunay_graph(20)
new_edges = get_mst_kruskal(nodes, edges)

plt.figure(figsize=(10, 15))
plt.subplot(2, 1, 1)
draw_2d_graph(nodes, edges)
plt.subplot(2, 1, 2)
draw_2d_graph(nodes, new_edges)
One cool application of spanning trees is the automatic creation of mazes. Build a graph whose nodes lie on a grid, with edges connecting left/right/up/down neighbors, and assign random weights to all edges. Running Kruskal's algorithm on this graph yields a set of edges that forms a maze. Since the result is a spanning tree, there is exactly one path between any pair of vertices. If we choose the start vertex and the end vertex on the boundary of the grid, we get a traditional maze.
def make_grid_graph(N, seed = 0):
    """
    Build an N x N grid graph with a random weight on every edge.

    Parameters
    ----------
    N: int
        Resolution of grid
    seed: int
        Seed passed to numpy's global random state before drawing weights

    Returns
    -------
    nodes: list of GraphNode
        N*N nodes in row-major order; node i*N + j has data {'x': j, 'y': i}
    edges: list of (i, j, d)
        Each undirected 4-neighbor grid edge exactly once, with a single
        uniform random weight d in [0, 1)
    """
    np.random.seed(seed)
    nodes = []
    for i in range(N):
        for j in range(N):
            n = GraphNode()
            n.data = {'x': j, 'y': i}
            nodes.append(n)
    edges = []
    for i in range(N):
        for j in range(N):
            idx1 = i*N + j
            # Only link to the next column and the next row so that each
            # undirected edge is created once, with one weight; the
            # left/up links are the same edges seen from the other endpoint
            if j + 1 < N:
                edges.append((idx1, idx1 + 1, np.random.rand()))
            if i + 1 < N:
                edges.append((idx1, idx1 + N, np.random.rand()))
    return nodes, edges
def draw_grid_edges_image(edges, N):
    """
    Render grid-graph edges as a (2N-1) x (2N-1) binary image and show it.

    Node k sits at column k % N and row k // N. It maps to pixel
    (2*row, 2*col), and each edge lights up the pixels between its two
    endpoints.

    Parameters
    ----------
    edges: list of (i, j, d)
        Edges between grid node indices; the weight d is not used here
    N: int
        Resolution of grid
    """
    side = 2 * N - 1
    canvas = np.zeros((side, side))
    for (a, b, _dist) in edges:
        row_a, col_a = divmod(a, N)
        row_b, col_b = divmod(b, N)
        c_lo, c_hi = sorted((col_a, col_b))
        r_lo, r_hi = sorted((row_a, row_b))
        if c_lo != c_hi:
            # Horizontal line
            canvas[2 * r_lo, 2 * c_lo:2 * c_hi + 1] = 1
        else:
            # Vertical line
            canvas[2 * r_lo:2 * r_hi + 1, 2 * c_lo] = 1
    plt.imshow(canvas)
    plt.gca().invert_yaxis()
    plt.axis('off')
# Demo: a 10x10 maze shown as a graph (left) and as a pixel image (right)
N = 10
nodes, edges = make_grid_graph(N, seed=0)
new_edges = get_mst_kruskal(nodes, edges)

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
draw_2d_graph(nodes, new_edges, linewidth=3)
plt.subplot(1, 2, 2)
draw_grid_edges_image(new_edges, N)
# Demo: a 100x100 maze rendered only as a pixel image
N = 100
plt.figure(figsize=(10, 10))
nodes, edges = make_grid_graph(N, seed=0)
new_edges = get_mst_kruskal(nodes, edges)
draw_grid_edges_image(new_edges, N)
plt.show()