[PYTHON] Implementing a Union-Find equivalent in 10 lines

**[Disjoint-set data structure (Union-Find)](https://ja.wikipedia.org/wiki/%E7%B4%A0%E9%9B%86%E5%90%88%E3%83%87%E3%83%BC%E3%82%BF%E6%A7%8B%E9%80%A0)** is a data structure that partitions elements into groups that do not overlap one another. Its standard implementation is the Union-Find tree. You probably already know all this.

There are many examples of Union-Find tree implementations on Qiita, but two questions have bothered me for a long time:

  1. The recursive call in find() makes the code hard to follow. Can it be written more clearly and concisely?
  2. Wouldn't it actually be faster to use the built-in collection types?
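For reference, the recursion itself is not essential to the tree approach. Below is a minimal sketch of a path-compressing find() written with two loops instead of recursion. It assumes the same `par` parent list as the tree implementation in the comparison code; the standalone function `find(par, x)` is a hypothetical helper, not part of either implementation.

```python
def find(par, x):
    """Hypothetical iterative find(): return the root of x and compress the path.

    Assumes par[i] is the parent of i and par[root] == root, as in the
    UnionFind class used in the comparison code.
    """
    # First pass: follow parent pointers up to the root.
    root = x
    while par[root] != root:
        root = par[root]
    # Second pass: re-point every node on the path directly at the root,
    # which is what the recursive version does on its way back up.
    while x != root:
        parent = par[x]
        par[x] = root
        x = parent
    return root


# A chain 3 -> 2 -> 1 -> 0 collapses into a star rooted at 0.
par = [0, 0, 1, 2]
print(find(par, 3))  # 0
print(par)           # [0, 0, 0, 0]
```

Whether this reads more clearly than the recursive version is a matter of taste. It does avoid Python's recursion limit on deep trees, although with union by rank the tree depth stays logarithmic anyway.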

10-line implementation

So I wrote a short implementation in Python using dict and frozenset. It is 10 lines, including blank lines.

class EasyUnionFind:
    def __init__(self, n):
        self._groups = {x: frozenset([x]) for x in range(n)}

    def union(self, x, y):
        group = self._groups[x] | self._groups[y]
        self._groups.update((c, group) for c in group)

    def groups(self):
        return frozenset(self._groups.values())
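To make the mechanics concrete: union() builds the merged frozenset and re-points every member's dict entry at that single object, so all members of a group always share the same frozenset. The usage sketch below repeats the 10-line class and adds a hypothetical `same()` method (not one of the original 10 lines) that relies on this shared-object property to test connectivity with `is`.

```python
class EasyUnionFind:
    def __init__(self, n):
        # Every element starts in its own singleton group.
        self._groups = {x: frozenset([x]) for x in range(n)}

    def union(self, x, y):
        # Build the merged group, then point every member at that one object.
        group = self._groups[x] | self._groups[y]
        self._groups.update((c, group) for c in group)

    def groups(self):
        return frozenset(self._groups.values())

    # Hypothetical helper, not part of the 10-line version: members of one
    # group share the same frozenset object, so an identity check suffices.
    def same(self, x, y):
        return self._groups[x] is self._groups[y]


uf = EasyUnionFind(5)
uf.union(0, 1)
uf.union(3, 4)
uf.union(1, 4)
print(sorted(sorted(g) for g in uf.groups()))  # [[0, 1, 3, 4], [2]]
print(uf.same(0, 3))  # True
print(uf.same(0, 2))  # False
```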

Comparison

Let's compare it with a Union-Find tree implementation. The comparison target is the [implementation](https://www.kumilog.net/entry/union-find) introduced in the most-liked article returned by a Qiita search for "Union Find Python".

**The result: this 10-line implementation is slower.** The gap also appears to widen as the number of elements grows. Sorry.

| Element count | Union-Find tree implementation | This 10-line implementation | Ratio (10-line / tree) |
|---:|---:|---:|---:|
| 1000 | 0.72 s | 1.18 s | 1.63 |
| 2000 | 1.46 s | 2.45 s | 1.68 |
| 4000 | 2.93 s | 5.14 s | 1.76 |
| 8000 | 6.01 s | 10.96 s | 1.82 |

Even so, it is less than twice as slow as the Union-Find tree in these runs, so it may still be useful in some situations.
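A rough way to see why the gap can grow with n: each union() in the 10-line version copies both frozensets and rewrites one dict entry per member of the merged group, so its cost is proportional to the size of that group, whereas the tree version with union by rank and path compression takes nearly constant amortized time per operation. The sketch below counts those dict writes for a hypothetical worst-case input in which element 0 absorbs the other elements one at a time. The benchmark's random pairs rarely build groups this large, so this illustrates the trend rather than explaining the measured timings.

```python
def count_dict_writes(n):
    """Count dict entries rewritten by the 10-line union() logic when
    element 0 absorbs 1, 2, ..., n-1 one at a time (a hypothetical
    worst-case input, not the random pairs used in the benchmark)."""
    groups = {x: frozenset([x]) for x in range(n)}
    writes = 0
    for i in range(1, n):
        group = groups[0] | groups[i]             # copies both sets
        groups.update((c, group) for c in group)  # one write per member
        writes += len(group)
    return writes


for n in (500, 1000, 2000):
    print(n, count_dict_writes(n))
# prints: 500 125249 / 1000 500499 / 2000 2000999
```

For this input the count is n(n+1)/2 - 1, so doubling n roughly quadruples the work.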

Comparison code and execution result

Code:

import random
import timeit
import sys
import platform


class EasyUnionFind:
    """Implementation using dict and frozenset."""
    def __init__(self, n):
        self._groups = {x: frozenset([x]) for x in range(n)}

    def union(self, x, y):
        group = self._groups[x] | self._groups[y]
        self._groups.update((c, group) for c in group)

    def groups(self):
        return frozenset(self._groups.values())


class UnionFind(object):
    """
    Typical Union-Find tree implementation.
    Copied from the example implementation at https://www.kumilog.net/entry/union-find,
    removing the member functions not needed here and adding groups().
    """
    def __init__(self, n=1):
        self.par = [i for i in range(n)]
        self.rank = [0 for _ in range(n)]
        self.size = [1 for _ in range(n)]
        self.n = n

    def find(self, x):
        if self.par[x] == x:
            return x
        else:
            self.par[x] = self.find(self.par[x])
            return self.par[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)
        if x != y:
            if self.rank[x] < self.rank[y]:
                x, y = y, x
            if self.rank[x] == self.rank[y]:
                self.rank[x] += 1
            self.par[y] = x
            self.size[x] += self.size[y]

    def groups(self):
        groups = {}
        for i in range(self.n):
            groups.setdefault(self.find(i), []).append(i)
        return frozenset(frozenset(group) for group in groups.values())


def test1():
    """Check that both implementations produce the same groups; raise AssertionError on any mismatch."""
    print("===== TEST1 =====")
    random.seed(20200228)
    n = 2000
    for _ in range(1000):
        elements = range(n)
        pairs = [
            (random.choice(elements), random.choice(elements))
            for _ in range(n // 2)
        ]
        uf1 = UnionFind(n)
        uf2 = EasyUnionFind(n)
        for x, y in pairs:
            uf1.union(x, y)
            uf2.union(x, y)
        assert uf1.groups() == uf2.groups()
    print('ok')
    print()


def test2():
    """
    Print the time each implementation takes as the number of elements increases.
    """
    print("===== TEST2 =====")
    random.seed(20200228)

    def execute_union_find(klass, n, test_datum):
        for pairs in test_datum:
            uf = klass(n)
            for x, y in pairs:
                uf.union(x, y)

    timeit_number = 1
    for n in [1000, 2000, 4000, 8000]:
        print(f"n={n}")
        test_datum = []
        for _ in range(1000):
            elements = range(n)
            pairs = [
                (random.choice(elements), random.choice(elements))
                for _ in range(n // 2)
            ]
            test_datum.append(pairs)

        t = timeit.timeit(lambda: execute_union_find(UnionFind, n, test_datum), number=timeit_number)
        print("UnionFind", t)

        t = timeit.timeit(lambda: execute_union_find(EasyUnionFind, n, test_datum), number=timeit_number)
        print("EasyUnionFind", t)
        print()

def main():
    print(sys.version)
    print(platform.platform())
    print()
    test1()
    test2()

if __name__ == "__main__":
    main()

Execution result:

3.7.6 (default, Dec 30 2019, 19:38:28)
[Clang 11.0.0 (clang-1100.0.33.16)]
Darwin-18.7.0-x86_64-i386-64bit

===== TEST1 =====
ok

===== TEST2 =====
n=1000
UnionFind 0.7220867589999997
EasyUnionFind 1.1789850389999987

n=2000
UnionFind 1.460918638999999
EasyUnionFind 2.4546459260000013

n=4000
UnionFind 2.925022847000001
EasyUnionFind 5.142797402000003

n=8000
UnionFind 6.01257184
EasyUnionFind 10.963117657000005
