**[Disjoint-set data structure (Union-Find)](https://ja.wikipedia.org/wiki/%E7%B4%A0%E9%9B%86%E5%90%88%E3%83%87%E3%83%BC%E3%82%BF%E6%A7%8B%E9%80%A0)** is a data structure that partitions elements into mutually disjoint groups. Its standard implementation is the Union-Find tree, as everyone probably already knows.
There are many Union-Find tree implementations on Qiita, but one thing has long bothered me: with all the `find()` machinery, couldn't it be written more clearly and more concisely? So I wrote a short implementation in Python using dict and frozenset. It is 10 lines long, including blank lines.
class EasyUnionFind:
    """Union-Find built from a dict that maps each element to its group's frozenset."""

    def __init__(self, n):
        # Start with n singleton groups {0}, {1}, ..., {n-1}.
        self._groups = {x: frozenset([x]) for x in range(n)}

    def union(self, x, y):
        """Merge the groups containing x and y (no-op if already joined)."""
        gx, gy = self._groups[x], self._groups[y]
        # O(1) membership test: skip rebuilding and rewriting every member
        # entry when x and y already belong to the same group.
        if y in gx:
            return
        group = gx | gy
        self._groups.update((c, group) for c in group)

    def groups(self):
        """Return all groups as a frozenset of frozensets."""
        return frozenset(self._groups.values())
Let's compare it with a Union-Find tree implementation. The comparison target is the [implementation introduced](https://www.kumilog.net/entry/union-find) in the most-liked article found by searching Qiita for "Union Find Python".
**The result: this 10-line implementation is slower.** Moreover, the gap seems to grow as the number of elements increases. Sorry.
| Element count | Union-Find tree implementation | This 10-line implementation | Ratio of execution time | 
|---|---|---|---|
| 1000 | 0.72 seconds | 1.17 seconds | 1.63 | 
| 2000 | 1.46 seconds | 2.45 seconds | 1.68 | 
| 4000 | 2.93 seconds | 5.14 seconds | 1.76 | 
| 8000 | 6.01 seconds | 11.0 seconds | 1.82 | 
Even so, at these sizes it is less than twice as slow as the Union-Find tree, so it may still be useful in some cases.
Benchmark code:
import random
import timeit
import sys
import platform
class EasyUnionFind:
    """
    Implementation using dict and frozenset.

    Each element maps to the frozenset of every element in its group.
    """

    def __init__(self, n):
        # Start with n singleton groups {0}, {1}, ..., {n-1}.
        self._groups = {x: frozenset([x]) for x in range(n)}

    def union(self, x, y):
        """Merge the groups containing x and y (no-op if already joined)."""
        gx, gy = self._groups[x], self._groups[y]
        # O(1) membership test: skip rebuilding and rewriting every member
        # entry when x and y already belong to the same group.
        if y in gx:
            return
        group = gx | gy
        self._groups.update((c, group) for c in group)

    def groups(self):
        """Return all groups as a frozenset of frozensets."""
        return frozenset(self._groups.values())
class UnionFind(object):
    """
    Typical Union-Find tree implementation (union by rank, path compression).

    Adapted from the example implementation at
    https://www.kumilog.net/entry/union-find ; member functions not needed
    for this comparison were removed and groups() was added.
    """

    def __init__(self, n=1):
        # par[i]: parent of i; a root is its own parent.
        self.par = list(range(n))
        # rank[i]: rank of the tree rooted at i, used to decide which root absorbs the other.
        self.rank = [0] * n
        # size[i]: number of elements in the tree rooted at i (kept up to date for roots).
        self.size = [1] * n
        self.n = n

    def find(self, x):
        """Return the root of x's tree, pointing every node on the path at it."""
        par = self.par
        root = x
        while par[root] != root:
            root = par[root]
        # Second pass: re-parent each node on the walked path directly to root.
        node = x
        while node != root:
            parent = par[node]
            par[node] = root
            node = parent
        return root

    def union(self, x, y):
        """Merge the trees containing x and y; do nothing if they already share a root."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        # Keep root_x as the higher-ranked root so the shallower tree is attached below it.
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        self.par[root_y] = root_x
        self.size[root_x] += self.size[root_y]

    def groups(self):
        """Return every group as a frozenset of frozensets of elements."""
        members = {}
        for i in range(self.n):
            root = self.find(i)
            if root not in members:
                members[root] = []
            members[root].append(i)
        return frozenset(map(frozenset, members.values()))
def test1():
    """
    Check that both implementations produce the same groups.

    Runs 1000 random trials on 2000 elements. Each trial applies n // 2
    random union operations to both implementations and raises
    AssertionError on the first trial whose groups differ.
    """
    print("===== TEST1 =====")
    random.seed(20200228)
    n = 2000
    elements = range(n)  # loop-invariant, so build it once
    for trial in range(1000):
        pairs = [
            (random.choice(elements), random.choice(elements))
            for _ in range(n // 2)
        ]
        uf1 = UnionFind(n)
        uf2 = EasyUnionFind(n)
        for x, y in pairs:
            uf1.union(x, y)
            uf2.union(x, y)
        # Explicit raise rather than an `assert` statement: asserts are
        # stripped under `python -O`, which would make this check silently pass.
        if uf1.groups() != uf2.groups():
            raise AssertionError(f"group mismatch in trial {trial}")
    print('ok')
    print()
def test2():
    """
    Time both implementations while doubling the number of elements.

    For each n in 1000, 2000, 4000 and 8000, build 1000 random lists of
    n // 2 union pairs. Both implementations get the same lists. Print the
    time each implementation needs to process all of them.
    """
    print("===== TEST2 =====")
    random.seed(20200228)

    def run_all(uf_class, size, cases):
        # A fresh structure per case, so every run starts from singletons.
        for case in cases:
            uf = uf_class(size)
            for a, b in case:
                uf.union(a, b)

    repetitions = 1
    for n in (1000, 2000, 4000, 8000):
        print(f"n={n}")
        universe = range(n)
        cases = [
            [(random.choice(universe), random.choice(universe)) for _pair in range(n // 2)]
            for _case in range(1000)
        ]
        candidates = (("UnionFind", UnionFind), ("EasyUnionFind", EasyUnionFind))
        for label, uf_class in candidates:
            # timeit calls the lambda immediately, so the loop variable is bound correctly.
            elapsed = timeit.timeit(lambda: run_all(uf_class, n, cases), number=repetitions)
            print(label, elapsed)
        print()
def main():
    """Print the Python and platform details, then run the correctness test and the benchmark."""
    print(sys.version)
    print(platform.platform(), end="\n\n")
    for run in (test1, test2):
        run()


if __name__ == "__main__":
    main()
Execution result:
3.7.6 (default, Dec 30 2019, 19:38:28)
[Clang 11.0.0 (clang-1100.0.33.16)]
Darwin-18.7.0-x86_64-i386-64bit
===== TEST1 =====
ok
===== TEST2 =====
n=1000
UnionFind 0.7220867589999997
EasyUnionFind 1.1789850389999987
n=2000
UnionFind 1.460918638999999
EasyUnionFind 2.4546459260000013
n=4000
UnionFind 2.925022847000001
EasyUnionFind 5.142797402000003
n=8000
UnionFind 6.01257184
EasyUnionFind 10.963117657000005
Recommended Posts