[PYTHON] Understanding and implementing the Tonelli-Shanks algorithm (2)

Introduction

The previous article is Understanding and implementing the Tonelli-Shanks algorithm (1).

Implementation

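First, let me restate the conditions assumed this time: we solve $ x^2 \equiv n \ {\rm mod}\ p $ for $ x $, where $ p $ is an odd prime and $ 0 < x < p $ [^1].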

If $ n $ is a multiple of $ p $, then $ n \equiv 0 \ {\rm mod}\ p $ can hardly be called a quadratic residue, so I exclude this case here (the code raises an exception for it).

Legendre symbol

\left(\begin{array}{c} n \\ p \end{array}\right) := n^{\frac{p-1}{2}} \equiv \begin{cases} 1 \ {\rm mod}\ p & \Leftrightarrow n \text{ is a quadratic residue} \\ -1 \ {\rm mod}\ p & \Leftrightarrow n \text{ is a quadratic nonresidue} \end{cases}

This is the definition from last time. Of course, if $ n $ is a multiple of $ p $, then $\left(\begin{array}{c} n \\ p \end{array}\right) = 0$, so that case is handled as an exception here.

Note that pow(a, b, c) computes $ a^b \ {\rm mod}\ c $ [^2].
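Three-argument pow never builds the full power $ a^b $; it works with residues modulo c throughout. As a rough illustration, here is a minimal square-and-multiply sketch. The name mod_pow is my own, and CPython's actual implementation differs in details.

```python
def mod_pow(base: int, exp: int, mod: int) -> int:
    """Right-to-left binary exponentiation: base**exp % mod for exp >= 0, mod >= 1."""
    if mod == 1:
        return 0
    result = 1
    base %= mod
    while exp > 0:
        if exp & 1:                   # lowest bit of exp is set: multiply it in
            result = result * base % mod
        base = base * base % mod      # square the base for the next bit
        exp >>= 1
    return result


# Agrees with the built-in three-argument pow and with a ** b % c,
# but only ever stores numbers smaller than mod ** 2.
assert mod_pow(5, 20, 41) == pow(5, 20, 41) == 5 ** 20 % 41
```

Each loop iteration consumes one bit of the exponent, so the number of multiplications grows with the bit length of b rather than with b itself.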

python


def legendre_symbol(n, p):
    ls = pow(n, (p - 1) // 2, p)
    if ls == 1:
        return 1
    # pow returns a value in the range 0 to p - 1, so -1 mod p appears as p - 1
    elif ls == p - 1:
        return -1
    else:
        # ls ==0, that is, when n is a multiple of p
        raise Exception('n:{} = 0 mod p:{}'.format(n, p))
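Euler's criterion can be checked against a brute-force list of squares for small primes. The sketch below repeats legendre_symbol (raising ValueError instead of a bare Exception, which is my choice) and compares it with the set of values x * x % p. The helper names residues_by_squaring and check_euler_criterion are hypothetical.

```python
def legendre_symbol(n, p):
    ls = pow(n, (p - 1) // 2, p)
    if ls == 1:
        return 1
    elif ls == p - 1:
        return -1
    raise ValueError('n:{} = 0 mod p:{}'.format(n, p))


def residues_by_squaring(p):
    """Set of nonzero quadratic residues mod p, found by squaring 1..p-1."""
    return {x * x % p for x in range(1, p)}


def check_euler_criterion(p):
    squares = residues_by_squaring(p)
    for n in range(1, p):
        expected = 1 if n in squares else -1
        assert legendre_symbol(n, p) == expected, (n, p)
    # Exactly half of the nonzero residues are quadratic residues.
    assert len(squares) == (p - 1) // 2


for p in (3, 5, 7, 11, 13, 41, 97):
    check_euler_criterion(p)
```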

When p leaves a remainder of 3 when divided by 4

x = \pm n^{\frac{p+1}{4}}

This was the solution derived last time. I also define a check_sqrt function that confirms the answer is correct.

python


# In principle, this assertion should never fail
def check_sqrt(x, n, p):
    assert(pow(x, 2, p) == n % p)

def modular_sqrt(n, p):
    if p % 4 == 3:
        x = pow(n, (p + 1) // 4, p)
        check_sqrt(x, n, p)
        return [x, p - x]
    else:
        # Explained below (Tonelli-Shanks)
        pass
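Why does this formula work? If $ n $ is a quadratic residue, then $ n^{\frac{p-1}{2}} \equiv 1 $, so $ x^2 = n^{\frac{p+1}{2}} = n \cdot n^{\frac{p-1}{2}} \equiv n \ {\rm mod}\ p $. The condition $ p \equiv 3 \ {\rm mod}\ 4 $ is exactly what makes $ \frac{p+1}{4} $ an integer. Here is a minimal sketch that checks this exhaustively for a few small primes. It uses a hypothetical helper sqrt_3_mod_4 that raises ValueError for nonresidues instead of asserting.

```python
def sqrt_3_mod_4(n, p):
    """Square roots of a nonzero quadratic residue n modulo a prime p with p % 4 == 3."""
    assert p % 4 == 3
    x = pow(n, (p + 1) // 4, p)
    # x * x == n only holds when n is a quadratic residue mod p.
    if x * x % p != n % p:
        raise ValueError('{} is not a quadratic residue mod {}'.format(n, p))
    return [x, p - x]


# Exhaustive check: every nonzero square mod p gets its roots back.
for p in (3, 7, 11, 19, 23, 43):
    for y in range(1, p):
        n = y * y % p
        assert y in sqrt_3_mod_4(n, p)
```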

When p leaves a remainder of 1 when divided by 4

Here is the implementation of the Tonelli-Shanks algorithm.

Step 1.

p-1 = Q \cdot 2^S

Write $ p - 1 $ in this form, where $ Q $ is odd and $ S $ is a positive integer.

Since uppercase names in Python conventionally denote constants, I use lowercase q and s.

python


def modular_sqrt(n, p):
    ...
    else:
        # Step 1.
        q, s = p - 1, 0
        while q % 2 == 0:
            q //= 2
            s += 1
    ...
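For example, $ p = 41 $ gives $ p - 1 = 40 = 5 \cdot 2^3 $, so $ Q = 5 $ and $ S = 3 $. The sketch below wraps Step 1 in a hypothetical helper. As an alternative of my own, it also reads $ S $ off the lowest set bit of $ p - 1 $ using int.bit_length.

```python
def split_powers_of_two(p):
    """Return (q, s) with p - 1 == q * 2**s and q odd (Step 1, loop version)."""
    q, s = p - 1, 0
    while q % 2 == 0:
        q //= 2
        s += 1
    return q, s


def split_with_bits(p):
    """Same result using bit tricks: m & -m isolates the lowest set bit of m."""
    m = p - 1
    s = (m & -m).bit_length() - 1  # index of the lowest set bit
    return m >> s, s


for p in (5, 13, 17, 41, 97):
    q, s = split_powers_of_two(p)
    assert q % 2 == 1 and q * 2 ** s == p - 1
    assert split_with_bits(p) == (q, s)
```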

Step 2.

Select any $ z $ that is a quadratic nonresidue (for example, at random).

As mentioned last time, exactly half of the nonzero residues modulo $ p $ are quadratic nonresidues, so in practice a brute-force search starting from $ 2 $ finds one quickly.

We don't start from $ 1 $ because $ x^2 \equiv 1 \ {\rm mod}\ p $ always has the solution $ x = 1 $ for any $ p $; in other words, $ 1 $ is always a quadratic residue.

python


def modular_sqrt(n, p):
    ...
    else:
        # Step 1.
        q, s = p - 1, 0
        while q % 2 == 0:
            q //= 2
            s += 1
        
        # Step 2.
        z = 2
        while legendre_symbol(z, p) != -1:
            z += 1
    ...
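Here is a sketch of the deterministic scan from Step 2 and a randomized variant using the standard random module. The helper names are hypothetical. Both test for a nonresidue with Euler's criterion directly. For $ p = 41 $ the scan returns 3, because 2 is a quadratic residue modulo 41.

```python
import random


def first_nonresidue(p):
    """Smallest z >= 2 with (z/p) == -1, for an odd prime p (Step 2 as in the article)."""
    z = 2
    while pow(z, (p - 1) // 2, p) != p - 1:
        z += 1
    return z


def random_nonresidue(p, rng=random):
    """Random z in [2, p) with (z/p) == -1; each draw succeeds with probability about 1/2."""
    while True:
        z = rng.randrange(2, p)
        if pow(z, (p - 1) // 2, p) == p - 1:
            return z
```

The scan is reproducible and usually stops after a handful of candidates. The random version matches the textbook description of Step 2. Either choice gives a valid $ z $.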

Step 3.

\begin{cases} M_0 = S \\ c_0 = z^Q \\ t_0 = n^Q \\ R_0 = n^{\frac{Q+1}{2}} \end{cases}

These are the same as last time. As before, I use all-lowercase names.

python


def modular_sqrt(n, p):
    ...
    else:
        ...
        # Step 2.
        z = 2
        while legendre_symbol(z, p) != -1:
            z += 1

        # Step 3.
        m, c, t, r = s, pow(z, q, p), pow(n, q, p), pow(n, (q + 1) // 2, p)
    ...
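These initial values satisfy three invariants (mod p) that Step 4 preserves: $ R^2 \equiv n t $, $ c^{2^{M-1}} \equiv -1 $, and $ t^{2^{M-1}} \equiv 1 $. The first follows from $ R_0^2 = n^{Q+1} = n \cdot t_0 $, the second from $ c_0^{2^{S-1}} = z^{\frac{p-1}{2}} \equiv -1 $, and the third from $ t_0^{2^{S-1}} = n^{\frac{p-1}{2}} \equiv 1 $. Once $ t $ reaches 1, the first invariant gives $ R^2 \equiv n $. Here is a minimal sketch that checks them numerically. step3_init and check_invariants are hypothetical helpers.

```python
def step3_init(n, p):
    """Steps 1-3 for an odd prime p; returns (m, c, t, r)."""
    q, s = p - 1, 0
    while q % 2 == 0:
        q //= 2
        s += 1
    z = 2
    while pow(z, (p - 1) // 2, p) != p - 1:
        z += 1
    return s, pow(z, q, p), pow(n, q, p), pow(n, (q + 1) // 2, p)


def check_invariants(m, c, t, r, n, p):
    assert r * r % p == n * t % p              # R^2 = n t
    assert pow(c, 2 ** (m - 1), p) == p - 1    # c has order exactly 2^M
    assert pow(t, 2 ** (m - 1), p) == 1        # order of t divides 2^(M-1)


m, c, t, r = step3_init(5, 41)
check_invariants(m, c, t, r, 5, 41)
```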

Step 4.

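  1. If $ t_i \equiv 1 $, then $ \pm R_i $ are the square roots of $ n $; exit the loop.

  2. Otherwise, update the values as follows:

\begin{cases} M_{i+1} = \text{the minimum } j \ (0 < j < M_i) \text{ such that } \left(t_i\right)^{2^{j}} \equiv 1 \\ c_{i+1} = \left(c_i\right)^{2^{M_i - M_{i+1}}} \\ t_{i+1} = t_i \left(c_i\right)^{2^{M_i - M_{i+1}}} \\ R_{i+1} = R_i \left(c_i\right)^{2^{M_i - M_{i+1} - 1}} \end{cases}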

Here are two additional remarks on implementing this step.

First remark.

To find $ M_{i+1} $, we try $ j = 1, 2, \cdots $ in order. A naive implementation like the one below computes $ (t_i)^2 $ and checks whether it equals 1. If not, it computes $ (t_i)^4 $ from scratch and checks again, and so on. Isn't that a little wasteful? (m_update corresponds to $ M_{i+1} $.)

python


for j in range(1, m):
    if pow(t, pow(2, j), p) == 1:
        m_update = j
        break

Once $ (t_i)^2 $ has been computed, squaring it gives $ (t_i)^4 $, so let's reuse the previous value [^3].

python


pow_t = pow(t, 2, p)
for j in range(1, m):
    if pow_t == 1:
        m_update = j
        break
    pow_t = pow(pow_t, 2, p)
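To confirm that the two loops pick the same $ j $, they can be wrapped in small hypothetical helpers. This is a minimal sketch that assumes the search range $ 0 < j < m $ from Step 4. It returns None when no such $ j $ exists, which should not happen inside the algorithm when $ n $ is a quadratic residue and $ t \neq 1 $.

```python
def least_j_naive(t, m, p):
    """Recompute t^(2^j) from scratch for each j (the first version)."""
    for j in range(1, m):
        if pow(t, 2 ** j, p) == 1:
            return j
    return None


def least_j(t, m, p):
    """Reuse the previous value by squaring it (the second version)."""
    pow_t = t * t % p
    for j in range(1, m):
        if pow_t == 1:
            return j
        pow_t = pow_t * pow_t % p
    return None


# Both versions agree for every t modulo 97 (97 - 1 = 3 * 2^5, so m starts at 5).
assert all(least_j(t, 5, 97) == least_j_naive(t, 5, 97) for t in range(1, 97))
```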

Second remark.

b_i = \left(c_i\right)^{2^{M_i - M_{i+1}-1}}

With this definition, the update can be written neatly as follows. This notation is also used on Wikipedia.

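\begin{cases} M_{i+1} = \text{the minimum } j \ (0 < j < M_i) \text{ such that } \left(t_i\right)^{2^{j}} \equiv 1 \\ c_{i+1} = \left(b_i\right)^2 \\ t_{i+1} = t_i \left(b_i\right)^2 \\ R_{i+1} = R_i b_i \end{cases}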

Last time I omitted $ b_i $, because I thought introducing more variables would be confusing.
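With $ b_i $, it is also easy to see why the invariants $ R^2 \equiv n t $, $ c^{2^{M-1}} \equiv -1 $, and $ t^{2^{M-1}} \equiv 1 $ survive each update. This short derivation is my own addition. All congruences are mod p. It uses the fact that $ \left(t_i\right)^{2^{M_{i+1}-1}} \equiv -1 $: by the minimality of $ M_{i+1} $ (and $ t_i \neq 1 $), this value is a square root of 1 other than 1, and modulo a prime the only such root is $ -1 $.

```latex
\begin{aligned}
R_{i+1}^2 &= R_i^2 \, b_i^2 \equiv n \, t_i \, b_i^2 = n \, t_{i+1} \\
b_i^{2^{M_{i+1}}} &= c_i^{2^{M_i - M_{i+1} - 1} \cdot 2^{M_{i+1}}} = c_i^{2^{M_i - 1}} \equiv -1 \\
c_{i+1}^{2^{M_{i+1} - 1}} &= b_i^{2^{M_{i+1}}} \equiv -1 \\
t_{i+1}^{2^{M_{i+1} - 1}} &= t_i^{2^{M_{i+1} - 1}} \, b_i^{2^{M_{i+1}}} \equiv (-1)(-1) = 1
\end{aligned}
```

Since $ M_{i+1} < M_i $ at every step, the loop must end with $ t \equiv 1 $ after at most $ S $ updates.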

Based on the above two points, you can write the code as follows.

python


def modular_sqrt(n, p):
    ...
    else:
        ...
        # Step 3.
        m, c, t, r = s, pow(z, q, p), pow(n, q, p), pow(n, (q + 1) // 2, p)
        
        # Step 4.
        while t != 1:
            pow_t = pow(t, 2, p)
            for j in range(1, m):
                if pow_t == 1:
                    m_update = j
                    break
                pow_t = pow(pow_t, 2, p)
            b = pow(c, 1 << (m - m_update - 1), p)
            b2 = b * b % p
            m, c, t, r = m_update, b2, t * b2 % p, r * b % p

        # Verify the answer
        check_sqrt(r, n, p)
        return [r, p - r]

Implementation note: updating $ c_{i+1}, t_{i+1}, R_{i+1} $ uses both $ M_i $ and $ M_{i+1} $, so after setting m_update = j, do not overwrite m right away. The code above assigns all four values in a single tuple assignment.
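To see the updates in action, here is a sketch that records (m, c, t, r) after Step 3 and after each pass of Step 4 for n = 5 and p = 41, the example used in the whole source code. tonelli_shanks_trace is a hypothetical helper. It finds m_update with a while loop instead of for/break, which I believe is equivalent when n is a quadratic residue.

```python
def tonelli_shanks_trace(n, p):
    """Tonelli-Shanks for p % 4 == 1, recording (m, c, t, r) at each stage.

    Assumes p is an odd prime and n is a nonzero quadratic residue mod p.
    """
    q, s = p - 1, 0
    while q % 2 == 0:
        q //= 2
        s += 1
    z = 2
    while pow(z, (p - 1) // 2, p) != p - 1:
        z += 1
    m, c, t, r = s, pow(z, q, p), pow(n, q, p), pow(n, (q + 1) // 2, p)
    trace = [(m, c, t, r)]
    while t != 1:
        # Smallest j >= 1 with t^(2^j) == 1, by repeated squaring.
        pow_t, m_update = t * t % p, 1
        while pow_t != 1:
            pow_t = pow_t * pow_t % p
            m_update += 1
        b = pow(c, 1 << (m - m_update - 1), p)
        m, c, t, r = m_update, b * b % p, t * b * b % p, r * b % p
        trace.append((m, c, t, r))
    return trace


for row in tonelli_shanks_trace(5, 41):
    print(row)
# (3, 38, 9, 2)
# (2, 9, 40, 35)
# (1, 40, 1, 28)
```

M drops from 3 to 2 to 1, and the final r = 28 matches the first element returned by modular_sqrt(5, 41).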

Other notes

In fact, whether $ p $ is prime can be checked in polynomial time [^4].

python


from gmpy2 import is_prime

is_prime(p)

Or

python


from Crypto.Util.number import isPrime

isPrime(p)

Either one performs a fast (probabilistic) primality test.

Neither module is part of the standard library, so you need to install them with pip3 (the packages are gmpy2 and pycryptodome, respectively).
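If neither package is available, a pure-Python Miller-Rabin test is a common substitute. This is a minimal sketch that assumes inputs of moderate size. With the first 13 primes (2 through 41) as bases, Miller-Rabin is reported to be exact for all n below about 3.3 × 10^24; beyond that, treat the result as probabilistic. The name is_probable_prime is hypothetical and is not part of either library.

```python
def is_probable_prime(n: int) -> bool:
    """Miller-Rabin with the first 13 primes as bases (pure Python, no dependencies)."""
    bases = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41)
    if n < 2:
        return False
    for b in bases:                 # trial division by the bases themselves
        if n % b == 0:
            return n == b
    d, s = n - 1, 0                 # n - 1 = d * 2^s with d odd (same idea as Step 1)
    while d % 2 == 0:
        d //= 2
        s += 1
    for a in bases:
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = x * x % n
            if x == n - 1:
                break
        else:
            return False            # a is a witness: n is definitely composite
    return True
```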

Whole source code

python


#!/usr/bin/env python3

from Crypto.Util.number import isPrime
# from gmpy2 import is_prime

def legendre_symbol(n, p):
    ls = pow(n, (p - 1) // 2, p)
    if ls == 1:
        return 1
    elif ls == p - 1:
        return -1
    else:
        # in case ls == 0
        raise Exception('n:{} = 0 mod p:{}'.format(n, p))

def check_sqrt(x, n, p):
    assert(pow(x, 2, p) == n % p)

def modular_sqrt(n:int, p:int) -> list:
    if type(n) != int or type(p) != int:
        raise TypeError('n and p must be integers')

    if p < 3:
        raise Exception('p must be equal to or more than 3')

    if not isPrime(p):
        raise Exception('p must be a prime number. {} is a composite number'.format(p))

    if legendre_symbol(n, p) == -1:
        raise Exception('n={} is a quadratic nonresidue modulo p={}'.format(n, p))

    if p % 4 == 3:
        x = pow(n, (p + 1) // 4, p)
        check_sqrt(x, n, p)
        return [x, p - x]
    
    # Tonelli-Shanks
    q, s = p - 1, 0
    while q % 2 == 0:
        q //= 2
        s += 1
    z = 2
    while legendre_symbol(z, p) != -1:
        z += 1
    m, c, t, r = s, pow(z, q, p), pow(n, q, p), pow(n, (q + 1) // 2, p)
    while t != 1:
        pow_t = pow(t, 2, p)
        for j in range(1, m):
            if pow_t == 1:
                m_update = j
                break
            pow_t = pow(pow_t, 2, p)
        b = pow(c, 1 << (m - m_update - 1), p)
        b2 = b * b % p
        m, c, t, r = m_update, b2, t * b2 % p, r * b % p
    check_sqrt(r, n, p)
    return [r, p - r]

print(modular_sqrt(5, 41))
# => [28, 13]
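Finally, here is a brute-force cross-check for small primes. This sketch restates the algorithm as a hypothetical modular_sqrt_plain without the type and primality checks, so it needs no Crypto import. It compares the result with trying every x in 1..p-1, and it assumes p is an odd prime and n a nonzero quadratic residue.

```python
def modular_sqrt_plain(n, p):
    """Same algorithm as modular_sqrt above, minus the input checks."""
    n %= p
    if p % 4 == 3:
        x = pow(n, (p + 1) // 4, p)
    else:
        q, s = p - 1, 0
        while q % 2 == 0:
            q //= 2
            s += 1
        z = 2
        while pow(z, (p - 1) // 2, p) != p - 1:
            z += 1
        m, c, t, x = s, pow(z, q, p), pow(n, q, p), pow(n, (q + 1) // 2, p)
        while t != 1:
            pow_t, m_update = t * t % p, 1
            while pow_t != 1:
                pow_t = pow_t * pow_t % p
                m_update += 1
            b = pow(c, 1 << (m - m_update - 1), p)
            m, c, t, x = m_update, b * b % p, t * b * b % p, x * b % p
    return [x, p - x]


def brute_sqrt(n, p):
    """All x in 1..p-1 with x^2 == n (mod p), in ascending order."""
    return [x for x in range(1, p) if x * x % p == n % p]


small_primes = [p for p in range(3, 300) if all(p % d for d in range(2, int(p ** 0.5) + 1))]
for p in small_primes:
    for n in {y * y % p for y in range(1, p)}:
        assert sorted(modular_sqrt_plain(n, p)) == brute_sqrt(n, p), (n, p)
```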

[^1]: If $ x $ with $ 0 < x < p $ is a solution, then adding or subtracting any multiple of $ p $ trivially gives another solution, so this article restricts attention to $ 0 < x < p $.

[^2]: In Python you can also write a ** b % c, but the pow function is faster, because it reduces modulo c at every step instead of first computing the huge number a ** b.

[^3]: For example, if $ 0 < j < 10 $, the former may call the pow function up to 18 times, while the latter calls it only about 9 times. This is a very rough way to evaluate computational cost, but from this point of view I adopted the latter. In reality there is almost no difference, since the algorithm itself is fast either way.

[^4]: "$ p $ is not prime (= is composite)" and "the factorization of $ p $ is known" are different things; here I am talking about the former. If the latter could be done in polynomial time, the security of schemes such as RSA would collapse. (Strictly speaking, this is still an open problem; it is only widely believed that factoring cannot be done in polynomial time.)
