# Finite Fields


import numpy
from math import ceil, floor, log, sqrt
from itertools import *
from random import *
from decimal import *
from functools import reduce


def XEuclidean(u, v):
    """Extended Euclidean algorithm on integers.

    Returns an array ``[u0, u1, g]`` satisfying ``u*u0 + v*u1 == g``, where
    ``g`` is the last non-zero remainder (the gcd of u and v, up to sign).

    The arrays use ``dtype=object`` so every entry is a plain Python int:
    numpy int64 scalars would otherwise leak into the field arithmetic that
    consumes these coefficients.
    """
    U = numpy.array([1, 0, u], dtype=object)
    V = numpy.array([0, 1, v], dtype=object)
    while V[2] != 0:
        q = U[2] // V[2]
        # Row operation keeps the invariant u*row[0] + v*row[1] == row[2].
        U, V = V, U - V * q
    return U


class BaseField(object):
    """Base field Z_p (integers modulo ``char``).

    With ``char == 0`` the arithmetic falls back to ordinary numbers
    (no reduction, float inverses and square roots).
    """

    def __init__(self, n):
        self.char = n

    def __eq__(self, X):
        if not isinstance(X, BaseField):
            return NotImplemented
        return self.char == X.char

    def __hash__(self):
        # Kept consistent with __eq__.
        return hash(self.char)

    def goodrep(self, x):
        "Canonical representative r of the class [x], with 0 <= r < char (x itself when char == 0)."
        if self.char == 0:
            return x
        x %= self.char
        # Decimal's % keeps the sign of the dividend, so lift negatives.
        while x < 0:
            x += self.char
        x %= self.char
        return x

    def sum(self, x, y):
        return self.goodrep(x + y)

    def prod(self, x, y):
        return self.goodrep(x * y)

    def inv(self, x):
        """Multiplicative inverse of x in Z_p (1.0 / x when char == 0).

        Raises:
            ZeroDivisionError: if x is not invertible modulo char (e.g. x == 0).
        """
        if self.char > 0:
            SX = XEuclidean(x, self.char)
            # SX[2] is gcd(x, char); an inverse exists only when it is 1.
            if SX[2] != 1:
                raise ZeroDivisionError(
                    "{0} is not invertible modulo {1}".format(x, self.char)
                )
            # int() keeps plain Python ints even if XEuclidean yields numpy scalars.
            return self.goodrep(int(SX[0]))
        return 1.0 / x

    def random(self):
        "Uniform random element: float in [0, 1) for char 0, else an int in [0, char)."
        if self.char == 0:
            return numpy.random.random_sample()
        return randrange(0, self.char)

    def Jacobi(self, a):
        """Jacobi symbol (a / char), computed by quadratic reciprocity.

        Raises:
            ValueError: when char == 0.
        """
        if self.char == 0:
            raise ValueError("char=0 !")
        m = self.char
        a %= m
        t = 1
        while a != 0:
            while a % 2 == 0:
                # Integer division: true division would turn a into a float.
                a //= 2
                if m % 8 in (3, 5):
                    t = -t
            a, m = m, a
            if a % 4 == 3 and m % 4 == 3:
                t = -t
            a %= m
        return t if m == 1 else 0

    def evenpart(self, x):
        """Split x as 2**e * q with q odd; returns [e, q] (both integers for int x).

        Raises:
            ValueError: for x == 0, which has no such decomposition.
        """
        if x == 0:
            raise ValueError("evenpart(0) is undefined")
        e = 0
        while x % 2 == 0:
            e += 1
            x //= 2
        return [e, x]

    def sqrt(self, a):
        """Square root of a in Z_p (Tonelli-Shanks for odd p).

        Returns 0 for a == 0 (mod p), a itself for p == 2, and math.sqrt(a)
        for char 0.

        Raises:
            ValueError: if a is not a square modulo p.
        """
        p = self.char
        if p == 0:
            return sqrt(a)
        if p == 2:
            return a
        if a % p == 0:
            # b below would be 0 and the order search would never terminate.
            return 0
        if self.Jacobi(a) == -1:
            raise ValueError("a is a non-square!")
        # Find a quadratic non-residue n.
        n = p - 1
        while self.Jacobi(n) != -1:
            n = randrange(1, p)
        e, q = self.evenpart(p - 1)
        y = pow(n, q, p)
        r = e
        x = pow(a, (q - 1) // 2, p)
        b = self.prod(a, pow(x, 2, p))  # b = a**q
        x = self.prod(a, x)  # x = a**((q + 1) / 2)
        while b % p != 1:
            # Smallest m with b**(2**m) == 1.
            m = 0
            while pow(b, 2 ** m, p) != 1:
                m += 1
            if m == r:
                raise ValueError("a is a non-square -")
            t = pow(y, 2 ** (r - m - 1), p)
            y = pow(t, 2, p)
            r = m
            x = self.prod(x, t)
            b = self.prod(b, y)
        return x


class Polynomial(object):
    """Polynomials over the base field Z_p.

    A polynomial is a list of coefficients, lowest degree first: ``x[i]`` is
    the coefficient of X**i. Several methods also accept a bare int, which
    is treated as a constant.
    """

    def __init__(self, n=0):
        self.K = BaseField(n)

    def degree(self, x):
        """Degree of x; -1 for the zero polynomial (including an empty list).

        A bare int is reported as degree 0.
        """
        if type(x) == int:
            return 0
        c = len(x) - 1
        while c >= 0 and x[c] == 0:
            c -= 1
        return c

    def cut(self, p):
        """Normalize p: reduce coefficients mod char and drop trailing zeros.

        Guarantees len(p) == deg(p) + 1, except that the zero polynomial (or
        an empty list) becomes [0]. Always returns a new list.
        """
        if type(p) == int:
            return [p]
        if self.K.char > 0:
            p = [i % self.K.char for i in p]
        else:
            p = list(p)
        while len(p) > 1 and p[-1] == 0:
            p.pop()
        return p if p else [0]

    def leading(self, x):
        "Leading coefficient of x (0 for the zero polynomial)."
        t = self.degree(x)
        if t > -1:
            return x[t]
        return 0

    def prod(self, x, y):
        "Product of two polynomials, from the definition (coefficient convolution)."
        if type(x) == int:
            x = [x]
        if type(y) == int:
            y = [y]
        dx = self.degree(x)
        dy = self.degree(y)
        if dx < 0 or dy < 0:
            # A zero factor gives the zero polynomial; without this guard the
            # coefficient range below can be empty.
            return [0]

        def getTerm(z, i):
            return z[i] if i < len(z) else 0

        def getCk(k):
            # Coefficient of X**k: sum over i of x[i] * y[k - i].
            return reduce(
                lambda a, b: self.K.sum(a, b),
                [
                    self.K.prod(getTerm(x, i), getTerm(y, k - i))
                    for i in range(0, k + 1)
                ],
            )

        return self.cut([getCk(k) for k in range(0, dx + dy + 1)])

    def prodByScalar(self, c, x):
        "Multiply every coefficient of x by the scalar c."
        return self.cut([self.K.prod(c, i) for i in x])

    def sum(self, x, y):
        "Sum of two polynomials (the shorter one is padded with zeros)."
        if type(x) == int:
            x = [x]
        if type(y) == int:
            y = [y]
        n = max(len(x), len(y))
        rx = list(x) + [0] * (n - len(x))
        ry = list(y) + [0] * (n - len(y))
        return self.cut([self.K.sum(a, b) for a, b in zip(rx, ry)])

    def quorem(self, A, B, debug=False):
        """Division algorithm: returns [Q, R] with A = B*Q + R and deg R < deg B.

        Raises:
            ZeroDivisionError: if B is the zero polynomial (the loop below
                would otherwise never terminate).
        """
        # Normalize so coefficients that vanish mod p do not count toward degree.
        A = self.cut(A)
        B = self.cut(B)
        if self.degree(B) < 0:
            raise ZeroDivisionError("polynomial division by zero")
        R = A
        Q = [0]
        if debug:
            print("A   =", A, "    deg=", self.degree(A))
            print("B   =", B, "    deg=", self.degree(B))
        # Loop-invariant: inverse of B's leading coefficient.
        lead_inv = self.K.inv(self.leading(B))
        while self.degree(R) >= self.degree(B):
            l = self.K.prod(self.leading(R), lead_inv)
            shift = [0] * (self.degree(R) - self.degree(B))
            Q = self.sum(Q, shift + [l])
            # Subtract l * X**shift * B, cancelling R's leading term.
            R = self.sum(R, self.prod(shift + [-l], B))
            if debug:
                print("Q   =", Q, "  deg=", self.degree(Q))
                print("A-BQ=", R, "  deg=", self.degree(R))
        return [self.cut(Q), self.cut(R)]

    def monic(self, A):
        "Scale A so that its leading coefficient is 1 (the zero polynomial stays [0])."
        A = self.cut(A)
        if A == [0]:
            return [0]
        A = self.prodByScalar(
            self.K.inv(self.leading(A)), A
        )
        return A

    def random(self, deg):
        "List of `deg` random coefficients (degree < deg; not normalized)."
        return [self.K.random() for i in range(0, deg)]

    def XEuclidean(self, A, B):
        """Extended Euclidean algorithm: returns [S, T, D] with S*A + T*B = D.

        D is the monic gcd of A and B; S and T are scaled accordingly. The
        output order is the same whichever of A and B has the larger degree.
        """
        A = self.cut(A)
        B = self.cut(B)
        if self.degree(A) < self.degree(B):
            T, S, D = self.XEuclidean(B, A)
            return [S, T, D]
        U = [1]  # invariant: D == U*A (mod B)
        D = A
        V1 = [0]
        V3 = B
        while self.degree(V3) >= 0:
            [Q, R] = self.quorem(D, V3)
            T = self.sum(
                U, self.prodByScalar(-1, self.prod(V1, Q))
            )
            U, D, V1, V3 = V1, V3, T, R
        if self.degree(B) < 0:
            # gcd(A, 0) = A: the coefficient of B is irrelevant.
            V = [0]
        else:
            # Exact division recovers the coefficient of B from D - U*A.
            V = self.quorem(
                self.sum(D, self.prodByScalar(-1, self.prod(A, U))),
                B,
            )[0]
        j = self.K.inv(self.leading(D))
        return [self.prodByScalar(j, i) for i in (U, V, D)]


class FField(object):
    """Finite field GF(p**n) = Z_p[X] / (irred).

    Field elements are coefficient lists (lowest degree first) reduced
    modulo the monic defining polynomial ``self.irred``; ``FFE`` wraps them
    for operator syntax. Indexing ``F[k]`` enumerates the field by integer
    representation (the base-p digits of k are the coefficients).
    """

    def __init__(self, p, irred=None):
        # None replaces a mutable list default; X - 1 yields the prime field GF(p).
        if irred is None:
            irred = [-1, 1]
        self.machinery = Polynomial(p)
        mpol = self.machinery.monic(irred)
        self.irred = [
            self.machinery.K.goodrep(i) for i in mpol
        ]
        # Degree of the normalized polynomial, so a leading coefficient that
        # vanishes mod p does not inflate the dimension.
        self.dim = self.machinery.degree(self.irred)
        if self.dim < 1:
            raise ValueError("irred must have positive degree")
        self.order = p ** self.dim
        if p > 2:
            # Fixed non-square, used by the Tonelli square root in sqrt().
            self.nsq = [1]
            while self.squareP(self.nsq):
                self.nsq = self._random()

    def __eq__(self, X):
        if not isinstance(X, FField):
            return NotImplemented
        return (
            self.irred == X.irred
            and self.machinery.K == X.machinery.K
        )

    def __len__(self):
        return self.order

    def elt(self, x):
        "Wrap x as an element (FFE) of this field; FFE inputs are returned unchanged."
        if isinstance(x, FFE):
            return x
        return FFE(self, x)

    def zerop(self, x):
        return self.machinery.cut(x) == [0]

    def sum(self, x, y):
        return self.machinery.sum(x, y)

    def prod(self, x, y):
        "Product of x and y reduced modulo irred."
        if self.zerop(x) or self.zerop(y):
            return [0]
        product = self.machinery.prod(x, y)
        return self.machinery.quorem(product, self.irred)[1]

    def inv(self, x):
        """Multiplicative inverse of x modulo irred.

        Raises:
            ZeroDivisionError: if x is zero in the field.
        """
        if type(x) == int:
            x = [x]
        # Reduce first: afterwards deg(x) < deg(irred), the case in which the
        # first entry returned by XEuclidean(x, irred) is the Bezout
        # coefficient of x.
        x = self.machinery.quorem(x, self.irred)[1]
        if self.zerop(x):
            raise ZeroDivisionError("inverse of zero in a finite field")
        return self.machinery.XEuclidean(x, self.irred)[0]

    def _random(self):
        return self.machinery.random(self.dim)

    def random(self):
        return self.elt(self._random())

    def pow(self, x, n):
        "x**n by left-to-right square-and-multiply; a negative n uses the inverse."
        n = int(n)
        if n == 0:
            return [1]
        if n < 0:
            return self.pow(self.inv(x), -n)
        C = [1]
        for bit in bin(n)[2:]:
            C = self.prod(C, C)
            if bit == "1":
                C = self.prod(C, x)
        return self.machinery.cut(C)

    def int(self, x):
        """Integer representation of x: sum of x[i] * p**i, returned as a Decimal.

        The sum is accumulated with Python ints and converted once: Decimal
        arithmetic rounds to the context precision (28 digits by default),
        while the Decimal constructor is exact.
        """
        p = self.machinery.K.char
        total = 0
        for i, coeff in enumerate(x):
            total += int(coeff) * p ** i
        return Decimal(total)

    def __getitem__(self, x):
        """Element whose integer representation is x; slices return lists of elements."""
        if isinstance(x, slice):
            return [self[h] for h in range(*x.indices(self.order))]
        if x >= self.order:
            raise IndexError("list index out of range")
        if isinstance(x, Decimal):
            # Exact integer digits; Decimal // and % are limited by precision.
            x = int(x)
        p = self.machinery.K.char
        r = [0] * self.dim
        for i in range(0, self.dim):
            r[i] = int(x % p)
            # Floor division: true division loses precision for large x.
            x = x // p
        return self.elt(self.machinery.cut(r))

    def __getslice__(self, i, j):
        # Python 2 slicing hook; Python 3 slices go through __getitem__.
        r = []
        for h in range(i, j):
            r.append(self[h])
        return r

    def squareP(self, x):
        "True when x is a square in the field (always True in characteristic 2)."
        if self.machinery.K.char == 2:
            return True
        if type(x) == int:
            x = [x]
        x = self.machinery.cut(x)
        if x == [1] or x == [0]:
            return True
        # order - 1 = 2**e * q; x is a square iff x**(q * 2**i) == 1 for some i < e.
        [e, q] = self.machinery.K.evenpart(self.order - 1)
        w = self.pow(x, q)
        for _ in range(0, e):
            if w == [1]:
                return True
            w = self.pow(w, 2)
        return False

    def sqrt(self, x):
        """A square root of x.

        Raises:
            ValueError: if x is not a square (odd characteristic).
        """
        if type(x) == int:
            x = [x]
        x = self.machinery.cut(x)
        # EASY CASES
        if x == [1] or x == [0]:
            return x
        elif self.dim == 1:
            return [self.machinery.K.sqrt(x[0])]
        elif self.machinery.K.char == 2:
            # Squaring is a bijection here; x**(order/2) is the unique root.
            return self.pow(x, self.order // 2)
        # ODD CHARACTERISTIC - TONELLI
        if not (self.squareP(x)):
            raise ValueError("sqrt - non-square element!")
        e = 0
        g = self.nsq
        [s, t] = self.machinery.K.evenpart(self.order - 1)
        t = int(t)
        for i in range(2, s + 1):
            h = self.prod(x, self.pow(g, -e))
            if self.pow(h, (self.order - 1) // 2 ** i) != [1]:
                e += 2 ** (i - 1)
        # Now x * g**(-e) has odd order and e is even.
        h = self.prod(x, self.pow(g, -e))
        b = self.prod(
            self.pow(g, e // 2), self.pow(h, (t + 1) // 2)
        )
        return b

    def unrepr(self, x):
        "Convert x to an FFE: ints are integer representations, other values go to FFE()."
        if isinstance(x, FFE):
            return x
        elif type(x) == int:
            return self[x]
        else:
            return self.elt(x)

    def solve(self, b, c):
        """Solve x**2 + b*x + c = 0 for FFE coefficients b and c.

        Returns a list of roots (FFE), or False when there is no root.
        """
        one = self.elt([1])
        if b.zeroP():
            # x**2 = -c
            cp = -c
            if not cp.squareP():
                return False
            r = cp.sqrt()
            if self.machinery.K.char == 2 or r.zeroP():
                # Single (double) root.
                return [r]
            return [r, -r]
        # Odd characteristic: quadratic formula.
        if self.machinery.K.char > 2:
            delta = b ** 2 - self.elt([4]) * c
            if not (delta.squareP()):
                return False
            two = self.elt([2])
            dsq = delta.sqrt() / two
            bm2 = -b / two
            return [bm2 + dsq, bm2 - dsq]

        # Even characteristic
        def H(z):
            # Half trace: sum of z**(2**(2i)) for i = 0 .. (dim - 1) / 2.
            return reduce(
                lambda l, m: l + m,
                [
                    z ** pow(2, 2 * i)
                    for i in range(0, (self.dim - 1) // 2 + 1)
                ],
            )

        def Tr(z):
            # Absolute trace: sum of z**(2**i) for i = 0 .. dim - 1.
            return reduce(
                lambda l, m: l + m,
                [z ** pow(2, i) for i in range(0, self.dim)],
            )

        if b == one:
            # x**2 + x + c = 0 is solvable iff Tr(c) == 0.
            if not (Tr(c) == self.elt([0])):
                return False
            if self.dim % 2 == 1:
                s = H(c)
            else:
                delta = one
                while Tr(delta) == self.elt([0]):
                    delta = self.random()
                s = reduce(
                    lambda l, m: l + m,
                    [
                        (c ** pow(2, i))
                        * reduce(
                            lambda l, m: l + m,
                            [
                                delta ** pow(2, j)
                                for j in range(i + 1, self.dim)
                            ],
                        )
                        for i in range(0, self.dim - 1)
                    ],
                )
            return [s, one + s]

        # Substitute x = b*y: y**2 + y + c / b**2 = 0.
        s0 = self.solve(one, c * (b ** (-2)))
        if s0:
            return [i * b for i in s0]
        return False

    def prodTable(self):
        "Print the multiplication table of the field."
        lenelt = int(
            (self.dim)
            * (
                ceil(
                    log(float(self.machinery.K.char))
                    / log(10.0)
                )
                + 1
            )
            + 2
        )
        fstring = "{0:>" + str(lenelt) + "}"

        def label(z):
            # Compact, right-aligned rendering of an element.
            return fstring.format(
                str(z).replace("#F", "").replace(" ", "")
            )

        print(fstring.format(" * ") + "|", end=" ")
        for x in self:
            print(label(x), end=" ")
        print()
        print("-" * ((lenelt + 1) * (len(self) + 1) + 1))
        for x in self:
            print(label(x) + "|", end=" ")
            for y in self:
                print(label(x * y), end=" ")
            print()


class FFE(object):
    "Finite field element: a normalized coefficient list belonging to the field FF."

    def __init__(self, FF, val=None):
        if not isinstance(FF, FField):
            raise TypeError(
                "Param should be a finite field"
            )
        self.FF = FF
        # None replaces a shared mutable default; the zero element is used.
        self.set([0] if val is None else val)

    def random(self):
        "Replace the value with a random field element (normalized like set())."
        self.val = self.FF.machinery.cut(self.FF._random())

    def __len__(self):
        return len(self.val)

    def __getitem__(self, x):
        return self.val[x]

    def __eq__(self, x):
        if type(x) is not type(self):
            return False
        if not x.FF == self.FF:
            return False
        return self.val == x.val

    def __add__(self, x):
        R = FFE(self.FF)
        R.val = self.FF.sum(self.val, x.val)
        return R

    def __mul__(self, x):
        R = FFE(self.FF)
        R.val = self.FF.prod(self.val, x.val)
        return R

    def __neg__(self):
        R = FFE(self.FF)
        char = self.FF.machinery.K.char
        R.val = [(char - i) % char for i in self.val]
        return R

    def __sub__(self, x):
        return self + (-x)

    def __truediv__(self, x):
        R = FFE(self.FF)
        R.val = self.FF.prod(self.val, self.FF.inv(x.val))
        return R

    def __div__(self, x):
        # Python 2 spelling of "/"; delegates to __truediv__.
        return self.__truediv__(x)

    def __pow__(self, n):
        R = FFE(self.FF)
        R.val = self.FF.pow(self.val, n)
        return R

    def __repr__(self):
        return "#F" + str(self.val)

    def __int__(self):
        return int(self.FF.int(self.val))

    def __hash__(self):
        # Equal values give equal integer representations.
        return self.__int__()

    def __Decimal__(self):
        return Decimal(self.FF.int(self.val))

    def copy(self):
        R = FFE(self.FF)
        R.val = self.val
        return R

    def zero(self):
        self.val = [0]

    def zeroP(self):
        return self.val == [0]

    def squareP(self):
        return self.FF.squareP(self.val)

    def sqrt(self):
        return self.FF.elt(self.FF.sqrt(self.val))

    def set(self, x):
        """Set the value from an FFE, an integer representation (int/Decimal),
        a coefficient list or tuple, or a repr string such as "#F[1, 0, 1]".

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(x, FFE):
            self.val = x.val
            self.FF = x.FF
        elif isinstance(x, (int, Decimal)):
            self.val = self.FF[x].val
        elif isinstance(x, (list, tuple)):
            self.val = self.FF.machinery.cut(list(x))
        elif isinstance(x, str):
            import ast

            # literal_eval accepts only literals; eval would execute any
            # expression embedded in the string.
            xx = x.replace("#F[", "[")
            self.val = self.FF.machinery.cut(ast.literal_eval(xx))
        else:
            raise TypeError(
                "cannot build a field element from {0!r}".format(x)
            )


# EXAMPLES
# Defining polynomials are coefficient lists, lowest degree first.

# Characteristic 2
p, poly = 2, [1, 1, 0, 1]  # 1 + x + x^3
GF2 = FField(p)
GF8 = FField(p, poly)
poly = [1, 0, 1, 1]  # 1 + x^2 + x^3
GF8b = FField(p, poly)

# Characteristic 3
p, poly = 3, [2, 2, 1]  # 2 + 2x + x^2
GF3 = FField(3)
GF9 = FField(p, poly)
poly = [1, 2, 0, 1]  # 1 + 2x + x^3
GF27 = FField(p, poly)

# Characteristic 5
p = 5
GF5 = FField(p)
poly = [2, 4, 1]  # 2 + 4x + x^2
GF25 = FField(p, poly)

# Characteristic 7
p, poly = 7, [3, -1, 1]  # 3 - x + x^2
GF7 = FField(p)
GF49 = FField(p, poly)

# Characteristic 11
p = 11
GF11 = FField(p)
poly = [2 ** 6, 2, 0, 1]  # 64 + 2x + x^3
GF1331 = FField(p, poly)

# AES field GF(2^8): 1 + x + x^3 + x^4 + x^8
AESp, AESpoly = 2, [1, 1, 0, 1, 1, 0, 0, 0, 1]
GF256 = FField(AESp, AESpoly)