"""Elliptic Curves"""


from FField import *
from numpy import *


class ECurve(object):
    """Elliptic curve over the field ``FField(p, irred)``.

    The equation is not fixed at construction time: call one of
    ``EquWeierstrass``, ``EquNonSSing`` or ``EquSSing`` to install the
    equation-specific point check, addition, inversion and point-finding
    routines.  Points are ``ECP`` instances; the point at infinity is encoded
    as the string ``"#PO"``.
    """

    def __init__(self, p, irred=None):
        """Build the base field and reset the lazily computed caches.

        :param p: first argument forwarded to ``FField`` (presumably the
            field characteristic — see ``F.machinery.K.char`` checks below).
        :param irred: coefficient list forwarded to ``FField``; defaults to
            a fresh ``[-1, 1]`` on every call (no shared mutable default).
        """
        if irred is None:
            irred = [-1, 1]
        self.F = FField(p, irred)

        # Equation-specific callables, installed by the Equ* methods.
        # (Stored under private names so they do not shadow the public
        # ``check``/``sum``/``inv``/``findP`` methods.)
        self._check = False
        self._inv = False
        self._sum = False
        self._findP = False
        # These values are actually computed when needed.
        self.computedsize = False
        self.computedpoints = False
        # Optional zero-argument callable returning the number of points.
        self.countpoints = False

    def check(self, P):
        """Return True if ``P`` lies on the curve.

        ``P`` may be an ``ECP``, a raw ``[x, y]`` pair or ``"#PO"``; it is
        normalised to an ``ECP`` before the equation-specific check runs.
        """
        return self._check(self.ecp(P))

    def sum(self, x, y):
        """Return the group sum ``x + y`` as an ``ECP``."""
        return self.ecp(self._sum(x, y))

    def inv(self, x):
        """Return the group inverse ``-x`` as an ``ECP``."""
        return self.ecp(self._inv(x))

    def findP(self, x):
        """Return the list of curve points with abscissa ``x``, or False."""
        pts = self._findP(x)
        if not pts:
            return False
        return [self.ecp(pt) for pt in pts]

    def EquWeierstrass(self, a, b):
        """Install the short Weierstrass equation ``y**2 = x**3 + a*x + b``.

        :raises ValueError: if the field characteristic is 2 or 3, or if
            ``4*a**3 + 27*b**2 == 0`` (singular curve).
        """
        # Equation in Weierstrass form.
        self.a = self.F.elt(a)
        self.b = self.F.elt(b)
        if self.F.machinery.K.char in [2, 3]:
            raise ValueError("char != 2,3")
        self.equation = (
            "y**2=x**3+" + str(a) + "x+" + str(b)
        )
        # Non-singularity test: the curve discriminant is a non-zero
        # multiple of 4a^3 + 27b^2.
        self.Delta = (
            self.F.elt([4]) * (self.a ** 3)
            + self.F.elt([27]) * self.b ** 2
        )
        if self.Delta == self.F.elt([0]):
            raise ValueError("Delta==0")

        def lhs(y):
            return y ** 2

        def rhs(x):
            return x ** 3 + self.a * x + self.b

        def check(P):
            if P.Pt == "#PO":
                return True
            [x, y] = P.Pt
            return lhs(y) == rhs(x)

        def findP(x):
            r = rhs(x)
            if not (r.squareP()):
                return False
            if r.zeroP():
                # Single point (x, 0).
                y = [[x, r]]
            else:
                y0 = r.sqrt()
                y = [[x, y0], [x, self.F.elt([-1]) * y0]]
            return y

        def lam(P, Q):
            # Slope of the chord (P != Q) or tangent (P == Q); "#PO" when the
            # line is vertical.
            if P[0].val != Q[0].val:
                return (Q[1] - P[1]) / (Q[0] - P[0])
            if P == Q:
                if P[1].zeroP():
                    return "#PO"
                return (
                    self.a + self.F.elt([3]) * (P[0] ** 2)
                ) / (self.F.elt([2]) * P[1])
            return "#PO"

        def sum(P, Q):
            if P == "#PO":
                return Q
            if Q == "#PO":
                return P
            l = lam(P, Q)
            if l == "#PO":
                return self.ecp("#PO")
            x3 = l ** 2 - P[0] - Q[0]
            y3 = l * (P[0] - x3) - P[1]
            return self.ecp([x3, y3])

        def inv(P):
            if P.Pt == "#PO":
                return self.ecp("#PO")
            return self.ecp([P[0], -P[1]])

        # Bind functions
        self._check = check
        self._sum = sum
        self._inv = inv
        self._findP = findP

    def EquNonSSing(self, a, b):
        """Install the char-2 non-supersingular equation
        ``y**2 + x*y = x**3 + a*x**2 + b``.

        :raises ValueError: if the characteristic is not 2, or if ``b == 0``.
        """
        if self.F.machinery.K.char != 2:
            raise ValueError("char == 2!")
        self.a = self.F.elt(a)
        self.b = self.F.elt(b)
        if self.b.zeroP():
            raise ValueError("b!=0")
        self.equation = (
            "y**2+xy=x**3+" + str(a) + "x**2+" + str(b)
        )

        def lhs(P):
            return P[1] ** 2 + P[0] * P[1]

        def rhs(x):
            return x ** 3 + self.a * x ** 2 + self.b

        def check(P):
            if P.Pt == "#PO":
                return True
            return lhs(P) == rhs(P[0])

        def findP(x):
            r = rhs(x)
            if r.zeroP():
                # y*(y + x) == 0  ->  y == 0 or y == x.
                return [
                    self.ecp(i)
                    for i in [[x, self.F.elt([0])], [x, x]]
                ]
            yv = self.F.solve(x, r)
            if yv:
                return [self.ecp([x, y]) for y in yv]
            else:
                return False

        def sum(P, Q):
            def kappa(P, Q):
                # Chord slope; vertical line (same x) gives infinity.
                if P[0] == Q[0]:
                    return self.ecp("#PO")
                return (P[1] + Q[1]) / (P[0] + Q[0])

            def mu(P):
                # Tangent slope; vertical tangent when x == 0.
                if P[0].zeroP():
                    return self.ecp("#PO")
                return P[0] + P[1] / P[0]

            def double(P):
                m = mu(P)
                if isinstance(m, ECP):
                    return self.ecp("#PO")
                x3 = m ** 2 + m + self.a
                y3 = P[0] ** 2 + (m + self.F.elt([1])) * x3
                return self.ecp([x3, y3])

            def add(P, Q):
                k = kappa(P, Q)
                if isinstance(k, ECP):
                    return self.ecp("#PO")
                x3 = (k ** 2 + k) + P[0] + Q[0] + self.a
                y3 = k * (P[0] + x3) + x3 + P[1]
                return self.ecp([x3, y3])

            if P.Pt == "#PO":
                return Q
            elif Q.Pt == "#PO":
                return P
            elif P == Q:
                return double(P)
            else:
                return add(P, Q)

        def inv(P):
            # -(x, y) == (x, x + y); points with x == 0 are self-inverse.
            if P.Pt == "#PO":
                return P
            if P.Pt[0].zeroP():
                return P
            Q = [P[0], P[0] + P[1]]
            return self.ecp(Q)

        self._check = check
        self._sum = sum
        self._inv = inv
        self._findP = findP

    def EquSSing(self, a, b, c):
        """Install the char-2 supersingular equation
        ``y**2 + c*y = x**3 + a*x + b``.

        :raises ValueError: if the characteristic is not 2, or if ``c == 0``.
        """
        if self.F.machinery.K.char != 2:
            raise ValueError("char == 2!")
        # Validate c as a field element (same zeroP() test used for b in
        # EquNonSSing) before mutating any state.
        c_elt = self.F.elt(c)
        if c_elt.zeroP():
            raise ValueError("c!=0")
        self.a = self.F.elt(a)
        self.b = self.F.elt(b)
        self.c = c_elt
        self.equation = (
            "y**2+"
            + str(c)
            + "y=x**3+"
            + str(a)
            + "x+"
            + str(b)
        )

        def lhs(y):
            return y ** 2 + self.c * y

        def rhs(x):
            return x ** 3 + self.a * x + self.b

        def check(P):
            if P.Pt == "#PO":
                return True
            [x, y] = P.Pt
            return lhs(y) == rhs(x)

        def findP(x):
            r = rhs(x)
            yv = self.F.solve(self.c, r)
            if yv:
                return [self.ecp([x, y]) for y in yv]
            else:
                return False

        def sum(P, Q):
            def kappa(P, Q):
                if P[0] == Q[0]:
                    return "#PO"
                return (P[1] + Q[1]) / (P[0] + Q[0])

            def eta(P):
                # Tangent slope; always defined because c != 0.
                return (P[0] ** 2 + self.a) / self.c

            def double(P):
                e = eta(P.Pt)
                x3 = e ** 2
                y3 = e * (P.Pt[0] + x3) + P.Pt[1] + self.c
                return self.ecp([x3, y3])

            def add(P, Q):
                k = kappa(P.Pt, Q.Pt)
                if k == "#PO":
                    return self.ecp("#PO")
                x3 = k ** 2 + P[0] + Q[0]
                y3 = k * (P[0] + x3) + P[1] + self.c
                return self.ecp([x3, y3])

            if P == "#PO":
                return Q
            elif Q == "#PO":
                return P
            elif P == Q:
                return double(P)
            else:
                return add(P, Q)

        def inv(P):
            # -(x, y) == (x, y + c).
            if P.Pt == "#PO":
                return self.ecp("#PO")
            x3 = P[0]
            y3 = P[1] + self.c
            return self.ecp([x3, y3])

        self._check = check
        self._sum = sum
        self._inv = inv
        self._findP = findP

    def allPoints(self):
        """Enumerate (and cache) every curve point, including ``"#PO"``.

        Iterates over all field elements, collects the points returned by
        ``findP``, appends the point at infinity and sorts the result.
        """
        if self.computedpoints:
            return self.computedpoints
        points = []
        for x in self.F:
            found = self.findP(x)
            if found:
                points.extend(self.ecp(pt) for pt in found)
        points.append(self.ecp("#PO"))
        points.sort()
        self.computedpoints = points
        self.computedsize = len(points)
        return points

    def mul(self, k, P):
        """Return ``k * P`` using left-to-right double-and-add.

        Negative ``k`` multiplies the inverse of ``P``; ``k == 0`` yields the
        point at infinity as an ``ECP``.
        """
        if k == 0:
            return self.ecp("#PO")
        P = self.ecp(P)
        if k < 0:
            return self.mul(-k, self.inv(P))
        k = int(k)
        C = self.ecp("#PO")
        # Scan the bits of k from the most significant one down.
        for bit in range(k.bit_length() - 1, -1, -1):
            C = C + C
            if (k >> bit) & 1:
                C = C + P
        return C

    def random(self):
        """Return a random affine point of the curve.

        Draws random abscissas from ``F`` until one lies on the curve.
        NOTE(review): loops forever if the curve has no affine points.
        """
        from random import randrange

        while True:
            xp = self.F.random()
            # Test the installed point finder, not the (always truthy)
            # bound method ``self.findP``.
            if self._findP:
                y = self.findP(xp)
                if y:
                    return self.ecp(y[randrange(0, len(y))])
            else:
                for i in self.F[0 : len(self.F)]:
                    P = self.ecp([xp, i])
                    if self.check(P):
                        return P

    def Orbit(self, P):
        """Return ``[P, 2P, ..., "#PO"]``, the cyclic subgroup generated by P."""
        res = []
        X = P
        while not (X == "#PO"):
            res.append(X)
            X = self.sum(X, P)
        res.append("#PO")
        return res

    def Order(self, P):
        """Return the order of ``P`` (length of its orbit)."""
        return len(self.Orbit(P))

    def __len__(self):
        """Number of points on the curve (cached).

        Uses ``self.countpoints`` when it has been set to a callable,
        otherwise enumerates all points.
        """
        if self.computedsize:
            return self.computedsize
        if self.countpoints:
            return self.countpoints()
        self.allPoints()
        return self.computedsize

    def __getitem__(self, i):
        """Return the ``i``-th point of the sorted point list."""
        if not (self.computedpoints):
            self.allPoints()
        return self.ecp(self.computedpoints[i])

    def __str__(self):
        return self.equation

    def __eq__(self, oth):
        # Curves compare equal when their equation strings match.
        if type(self) == type(oth):
            return self.equation == str(oth)
        else:
            return False

    def ecp(self, P):
        """Wrap ``P`` as an ``ECP`` on this curve (``ECP`` passes through)."""
        if isinstance(P, ECP):
            return P
        else:
            return ECP(self, P)

    def Test(self):
        """Run brute-force sanity checks on the curve, printing progress.

        Checks the Hasse-Weil bound, inverses, that each ``x + .`` permutes
        the points, and associativity. Returns True if all pass.
        """
        print("Equation: " + self.equation)
        pts = [str(i) for i in self]
        pts.sort()

        def TestHasseWeil():
            # |#E - q - 1| <= 2*sqrt(q)
            print(str(len(self)) + " ", end=" ")
            if abs(len(self) - self.F.order - 1) > 2 * (
                sqrt(self.F.order)
            ):
                return False
            else:
                return True

        def TestInverse():
            t0 = set([self.check(-x) for x in self])
            if False in t0:
                return False
            for x in self:
                y = -x
                z = x + y
                print(str(x) + "+" + str(y) + "=" + str(z))
                if not (z == self.ecp("#PO")):
                    return False
            return True

        def TestSum():
            for x in self:
                print(str(x) + "+*", end=" ")
                t2 = [str(x + y) for y in self]
                t2.sort()
                if not (t2 == pts):
                    print()
                    print(pts)
                    print()
                    print(t2)
                    return False
                print(" induces a permutation!")
            return True

        def TestAssoc():
            # Compares (x+y)+z with (y+z)+x, i.e. associativity combined
            # with commutativity.
            for x in self:
                print("x=" + str(x) + "  ", end=" ")
                for y in self:
                    print(str(x) + "+" + str(y) + "+...")
                    for z in self:
                        R0 = x + y
                        R0 = R0 + z
                        R1 = y + z
                        R1 = R1 + x
                        if not (R0 == R1):
                            print("FAILS")
                            return False
            return True

        print("Hasse-Weil ...", end=" ")
        if TestHasseWeil():
            print("OK")
        else:
            print("FAIL")
            return False
        print("Inverse ...", end=" ")
        if TestInverse():
            print("OK")
        else:
            print("FAIL")
            return False
        print("Composition ...", end=" ")
        if TestSum():
            print("OK")
        else:
            print("FAIL")
            return False
        print("Associativity ...", end=" ")
        if TestAssoc():
            print("OK")
        else:
            print("FAIL")
            return False
        return True


class ECP(object):
    """Point on an elliptic curve.

    ``Pt`` is either the string ``"#PO"`` (point at infinity) or an
    ``[x, y]`` pair of field elements of ``EC.F``.
    """

    def __init__(self, EC, Pt="#PO"):
        if not isinstance(EC, ECurve):
            raise TypeError(
                "Param should be an elliptic curve"
            )
        self.EC = EC
        self.Pt = Pt

    def set(self, W):
        """Set this point to ``W`` after checking that it lies on ``EC``.

        ``W`` may be ``"#PO"``, an ``[x, y]`` pair, another ``ECP`` or a
        ``"#P..."`` string as produced by ``repr``.

        :raises ValueError: if ``W`` is not on the curve.
        """
        if isinstance(W, ECP):
            W = W.Pt
        if isinstance(W, str):
            if W == "#PO":
                self.Pt = W
                return
            # SECURITY NOTE(review): eval() runs arbitrary code; only pass
            # trusted strings (e.g. this module's own repr output) here.
            W = eval(W.replace("#P", ""))
        # Wrap the raw coordinates so the curve's check sees an ECP.
        if self.EC.check(ECP(self.EC, W)):
            self.Pt = W
        else:
            raise ValueError("P not on EC")

    def emb(self, x, i=0):
        """Set this point to the ``i``-th curve point with abscissa ``x``.

        When ``findP`` returns a single point, ``i`` is ignored.

        :raises ValueError: if no curve point has abscissa ``x``.
        """
        tmp = self.EC.findP(x)
        if not (tmp):
            raise ValueError("P not on EC")
        chosen = tmp[0] if len(tmp) == 1 else tmp[i]
        # findP returns ECP objects; store the raw coordinates, not a nested
        # point.
        self.Pt = self.EC.ecp(chosen).Pt

    def random(self):
        """Replace this point by a random point of ``EC``."""
        self.Pt = self.EC.ecp(self.EC.random()).Pt

    def order(self):
        """Return the order of this point in the curve group."""
        return self.EC.Order(self)

    def __repr__(self):
        if self.Pt == "#PO":
            return str(self.Pt)
        return "#P" + str(self.Pt)

    def __add__(self, Q):
        return self.EC.sum(self, Q)

    def __gt__(self, Q):
        # Arbitrary but total order on the textual form, used for sorting.
        return str(self.Pt) > str(Q.Pt)

    def __lt__(self, Q):
        return str(self.Pt) < str(Q.Pt)

    def __sub__(self, Q):
        return self.EC.sum(self, self.EC.inv(Q))

    def __pow__(self, n):
        """Return ``n * self`` (scalar multiplication, written as a power)."""
        if n < 0:
            P0 = -self
            n = -n
        elif n == 0:
            P0 = self.EC.ecp("#PO")
        else:
            P0 = self
        # Always hand back an ECP, even when mul yields the raw "#PO".
        return self.EC.ecp(self.EC.mul(n, P0))

    def __eq__(self, X):
        """Compare with another point or with the ``"#PO"`` marker.

        Affine points compare by the ``.val`` of their coordinates; the
        point at infinity equals only itself (or ``"#PO"``).
        """
        if isinstance(X, str):
            return X == "#PO" and self.Pt == "#PO"
        if not isinstance(X, ECP):
            return False
        self_inf = self.Pt == "#PO"
        other_inf = X.Pt == "#PO"
        if self_inf or other_inf:
            # Avoid indexing coordinates of the point at infinity.
            return self_inf and other_inf
        return (
            X[0].val == self[0].val
            and X[1].val == self[1].val
        )

    def __getitem__(self, i):
        """Return coordinate ``i`` (0 = x, 1 = y)."""
        if i > 1:
            raise IndexError("P=P(x,y)")
        if self.Pt == "#PO":
            return self.EC.ecp("#PO")
        return self.Pt[i]

    def __neg__(self):
        return self.EC.inv(self)

    def copy(self):
        """Return a new ECP on the same curve with the same coordinates."""
        R = ECP(self.EC)
        R.set(self.Pt)
        return R

    def zero(self):
        """Reset this point to the point at infinity."""
        self.Pt = "#PO"


# SAMPLES
# Example curves constructed at import time and exposed as module globals.
# In each ECurve(p, irred) call, p and irred are forwarded to FField; the
# coefficient ordering of irred is defined by FField (not visible here) —
# TODO confirm. Uncomment a .Test() line to run that curve's self-checks.

# Short Weierstrass curve, a = [1], b = [1] (EquWeierstrass rejects char 2, 3).
E = ECurve(5, [8, 8, 0, 1])
E.EquWeierstrass([1], [1])
# E.Test()

# Char-2 non-supersingular curve y**2+xy=x**3+a*x**2+b with a = [0], b = [1].
F = ECurve(2, [1, 1, 0, 0, 1])
F.EquNonSSing([0], [1])
# F.Test()

# Char-2 supersingular curve y**2+c*y=x**3+a*x+b.
G = ECurve(2, [1, 0, 1, 0, 0, 1])
G.EquSSing([1, 1], [0, 1], [1, 1])
# G.Test()

# Same field as G, non-supersingular model.
# NOTE(review): passes plain ints where the other samples pass lists —
# relies on F.elt accepting ints; confirm.
L = ECurve(2, [1, 0, 1, 0, 0, 1])
L.EquNonSSing(1, 1)

# Weierstrass curve using the default irred argument of ECurve.
H = ECurve(29)
H.EquWeierstrass([1], [2])
# H.Test()

# Large char-2 field (irred list of 165 coefficients), non-supersingular model.
M = ECurve(2, [1, 0, 0, 1, 0, 0, 1, 1] + [0] * 156 + [1])
M.EquNonSSing([1], [1])

# Supersingular (N) and non-supersingular (Np) models over the same field.
N = ECurve(2, [1, 1, 0, 1])
N.EquSSing([1], [1], [0, 1])

Np = ECurve(2, [1, 1, 0, 1])
Np.EquNonSSing([1], [1])