import numpy
from math import ceil, floor, exp, log, sqrt
from decimal import *
from random import *
from EuclideanGCD import *
from Pollard import *
from hashlib import sha512
from Crypto.Cipher import AES
from functools import reduce
# Set the current decimal context to 600 significant digits; Decimal
# arithmetic performed with the default context (e.g. Decimal division and
# pow) is rounded to this precision.
getcontext().prec = 600
# UTILITY FUNCTIONS
def EuclideanGCD(a, b):
    """Return the greatest common divisor of the integers a and b.

    Iterative Euclidean algorithm. The result is always non-negative:
    EuclideanGCD(x, 0) == abs(x) and EuclideanGCD(0, 0) == 0.
    """
    # abs() keeps negative inputs from cycling forever, and the `while b`
    # test returns the other argument instead of dividing by zero.
    a, b = abs(a), abs(b)
    while b:
        a, b = b, a % b
    return a
def LCM2(a, b):
    """Return the least common multiple of the integers a and b.

    The result is non-negative; LCM2(x, 0) == LCM2(0, x) == 0.
    """
    if a == 0 or b == 0:
        # The gcd would be 0 (or the gcd routine would divide by zero);
        # the lcm with 0 is 0 by convention.
        return 0
    # Divide before multiplying to keep the intermediate value small.
    return abs(a // EuclideanGCD(a, b) * b)
def LCM(L):
    return reduce(lambda x, y: LCM2(x, y), L)
def EulerPhi(L):
    """Return Euler's totient phi(N) for N given by its factorisation.

    L is a sequence of (prime, exponent) pairs with N = prod(p**k), and
    phi(N) = prod(p**(k-1) * (p-1)). N itself is printed as a side effect.
    """
    # Exact Python-int arithmetic: numpy.multiply.reduce works in int64 when
    # every entry fits, and then silently wraps when the product overflows.
    pairs = [(int(x[0]), int(x[1])) for x in L]
    N = reduce(lambda acc, v: acc * v, [p ** k for p, k in pairs], 1)
    Factors = [(p ** (k - 1)) * (p - 1) for p, k in pairs]
    print(N)
    return reduce(lambda acc, v: acc * v, Factors, 1)
def EulerPhi2(p, q):
    """Return (p-1)*(q-1), i.e. phi(p*q) when p and q are distinct primes."""
    p_less_one = p - 1
    q_less_one = q - 1
    return p_less_one * q_less_one
def Carmichael(L):
    """Return Carmichael's function lambda(N) for N given by its factorisation.

    L is a sequence of (prime, exponent) pairs with N = prod(p**k).
    For each prime power:
      * lambda(2**k) = 2**(k-2) when k >= 3 (the group (Z/2^k)^* is not cyclic);
      * otherwise lambda(p**k) = p**(k-1) * (p-1), which equals phi(p**k).
    lambda(N) is the lcm of the prime-power values.
    """
    Factors = [
        2 ** (x[1] - 2) if x[0] == 2 and x[1] >= 3
        else (x[0] ** (x[1] - 1)) * (x[0] - 1)
        for x in L
    ]
    return LCM(Factors)
def Carmichael2(p, q):
    """Return lambda(p*q) = lcm(p-1, q-1) for distinct primes p and q.

    Delegates to Carmichael with the factorisation p**1 * q**1. It does not
    check that p != q; for p == q it returns p-1, not lambda(p**2).
    """
    factorisation = ((p, 1), (q, 1))
    return Carmichael(factorisation)
def InvZn(x, N):
    """Return the inverse of x modulo N, or False if it does not exist.

    XEuclidean is not defined in this file (it presumably comes from one of
    the star imports). Judging by its use here, item 0 of its result is the
    Bezout coefficient of x and item 2 is gcd(x, N). TODO: confirm.
    """
    bezout = XEuclidean(x, N)
    if bezout[2] <= 1:
        # Shift the coefficient into the canonical range [0, N).
        return (bezout[0] + N) % N
    return False
def powSign(a, b, N):
    """Compute a**b mod N, also accepting a zero or negative exponent b.

    For b <= 0 the result is InvZn(a**(-b) mod N, N), i.e. the modular
    inverse of a**|b|. It is False when that inverse does not exist.
    """
    if b <= 0:
        return InvZn(pow(int(a), -int(b), int(N)), N)
    return pow(int(a), int(b), int(N))
# CRT = Chinese Remainder Theorem
def ChineseRemThm(L):
    """Solve a system of congruences with the Chinese Remainder Theorem.

    L is a sequence of pairs (a, m), each meaning x = a (mod m). The moduli
    must be pairwise coprime. Returns the unique x in [0, M), where M is the
    product of the moduli. Returns False if some M/m is not invertible
    modulo m, the failure convention used by InvZn/RSAKeyGen in this file.
    Prints the input and the partial terms as a side effect.
    """
    print("L=", L, "\n")
    # Exact Python-int product: numpy.multiply.reduce silently wraps around
    # once a product of int64-sized moduli exceeds 2**63.
    M = reduce(lambda acc, v: acc * v, [int(x[1]) for x in L], 1)
    P = []
    for x in L:
        a, m = int(x[0]), int(x[1])
        Mi = M // m
        inv = InvZn(Mi, m)
        if inv is False:
            # Moduli not pairwise coprime: previously this term silently
            # became 0 and a wrong result was returned.
            return False
        P.append((a * Mi * inv) % M)
    print(P)
    return sum(P) % M
# Knuth, TAOCP vol. 2, p. 462 (modified)
def SimplePowModN(m, e, N):
    """Left-to-right binary exponentiation: return m**e mod N.

    Scans the bits of e from the most significant one, squaring at every
    step and multiplying by m when the bit is 1. A non-positive e is treated
    as having no bits, so the result is 1 % N.
    """
    # Number of binary digits of e (0 when e <= 0).
    width = 0
    rest = e
    while rest > 0:
        rest >>= 1
        width += 1
    acc = 1
    for shift in range(width, -1, -1):
        acc = acc * acc % N
        if (e >> shift) & 1:
            acc = acc * m % N
    return acc
# RSA using Euler's phi
def RSAKeyGen(p, q, e):
    """Derive the RSA private exponent from the public exponent e via phi(N).

    N = p*q and phi = EulerPhi2(p, q) = (p-1)*(q-1). d is item 0 of
    XEuclidean(e, phi) shifted into [0, phi); k is item 1, shifted the same
    way. Prints the modulus and the values "Phi,d,k".

    Returns [e, d, N], or False when XEuclidean reports a gcd greater than 1
    (e is not invertible modulo phi).
    """
    modulus = p * q
    totient = EulerPhi2(p, q)
    print("Modulus =", modulus, "\n")
    # XEuclidean is not defined in this file. From its use here, items 0 and
    # 1 appear to be Bezout coefficients of (e, totient) and item 2 their gcd.
    bezout = XEuclidean(e, totient)
    if bezout[2] > 1:
        return False
    d = (bezout[0] + totient) % totient
    k = (bezout[1] + totient) % totient
    print("Phi,d,k =", totient, ",", d, ",", k)
    return [e, d, modulus]
# RSA using Carmichael's lambda
def CaKeyGen(p, q, e):
    """Derive the RSA private exponent from e using Carmichael's lambda(N).

    Same procedure as RSAKeyGen, but the exponent is inverted modulo
    Carmichael2(p, q) = lcm(p-1, q-1) instead of phi(p*q). Prints the modulus
    and the values "Lambda,d,k".

    Returns [e, d, N], or False when XEuclidean reports a gcd greater than 1
    (e is not invertible modulo lambda).
    """
    modulus = p * q
    lam_n = Carmichael2(p, q)
    print("Modulus =", modulus, "\n")
    bezout = XEuclidean(e, lam_n)
    if bezout[2] <= 1:
        # Shift both coefficients into [0, lambda).
        d = (bezout[0] + lam_n) % lam_n
        k = (bezout[1] + lam_n) % lam_n
        print("Lambda,d,k =", lam_n, ",", d, ",", k)
        return [e, d, modulus]
    return False
def RSAEnc(m, e, N):
    """RSA encryption: return m**e mod N, with all arguments coerced to int."""
    base, exponent, modulus = int(m), int(e), int(N)
    return pow(base, exponent, modulus)
def RSADec(c, d, N):
    """RSA decryption: return c**d mod N, with all arguments coerced to int."""
    cipher, priv_exp, modulus = int(c), int(d), int(N)
    return pow(cipher, priv_exp, modulus)
# Decryption via the CRT
def RSACRTDec(c, d, p, q):
    """RSA decryption using the Chinese Remainder Theorem (Garner's formula).

    Computes m = c**d mod p*q from two half-size exponentiations,
    c**(d mod p-1) mod p and c**(d mod q-1) mod q, then recombines them.
    p and q must be the prime factors of the modulus. q must be prime for
    the Fermat inverse below to be valid.
    """
    dp = d % (p - 1)
    dq = d % (q - 1)
    # p^-1 mod q via Fermat's little theorem (q prime).
    pm = pow(p, q - 2, q)
    mp = pow(c, dp, p)
    mq = pow(c, dq, q)
    # Garner recombination of (mp mod p, mq mod q). Both residues are
    # required: reusing mp for the q-component gives a wrong plaintext.
    # The result lies in [0, p*q) because Python's % is non-negative.
    return mp + p * (((mq - mp) * pm) % q)
# Square root modulo p
def sqrt(a, p):
    """Return a square root of a modulo p, or False.

    NOTE: this definition shadows ``math.sqrt`` imported at the top of the
    file.

    Returns False when Jacobi(a, p) != 1. Jacobi is not defined in this
    file; assuming it is the usual Jacobi symbol, this includes a divisible
    by p. Otherwise returns x with x*x = a (mod p), choosing the method by
    p mod 8:
      * p = 3 or 7 (mod 8): x = a^((p+1)/4);
      * p = 5 (mod 8): x = a^((p+3)/8), multiplied by 2^((p-1)/4) if needed;
      * otherwise (p = 1 mod 8 for odd p): a Tonelli-Shanks-style search
        using a random quadratic non-residue d.
    Assumes p is an odd prime (TODO confirm with callers). For p = 2,
    randrange(2, 1) would raise.
    """
    # Sqrt mod p
    if Jacobi(a, p) != 1:
        return False
    a = a % p
    if p % 8 in [3, 7]:
        # p = 3 (mod 4): x^2 = a^((p+1)/2) = a * a^((p-1)/2) = a.
        x = pow(a, (p + 1) // 4, p)
        return x
    if p % 8 == 5:
        # Candidate root: its square is a * a^((p-1)/4), i.e. a or -a mod p.
        x = pow(a, (p + 3) // 8, p)
        c = pow(x, 2, p)
        if c % p != a % p:
            # The square was -a. 2^((p-1)/4) is a square root of -1 because
            # 2 is a non-residue when p = 5 (mod 8).
            x = (x * pow(2, (p - 1) // 4, p)) % p
        return x
    # Remaining case: draw random d until it is a quadratic non-residue.
    d = 1
    while Jacobi(d, p) != -1:
        d = randrange(2, p - 1)
    # Write p - 1 = 2^s * t with t odd.
    t = p - 1
    s = 0
    while t % 2 == 0:
        s += 1
        t = t // 2
    A = pow(a, t, p)
    D = pow(d, t, p)
    # Determine m bit by bit so that A * D^m = 1 (mod p). Bit i is set when
    # the current product raised to 2^(s-1-i) is -1 (mod p).
    m = 0
    for i in range(0, s):
        if (
            pow(A * pow(D, m, p), pow(2, s - 1 - i), p) % p
            == -1 % p
        ):
            m = m + pow(2, i)
    print(m)  # prints the exponent m; looks like leftover debug output
    # x^2 = a^(t+1) * D^m = a * (A * D^m) = a. m is even because a is a residue.
    x = (pow(a, (t + 1) // 2, p) * pow(D, m // 2, p)) % p
    return x
# Functions defining the hash (mask generator) used by the OAEP code
def GETstream(r, bits):
    """Return a large integer derived deterministically from the seed r.

    str(r) is hashed with SHA-512. Bytes 0..31 of the digest key AES-256 in
    OFB mode and bytes 33..48 are the IV. A constant 16-byte block is
    encrypted bits // 128 + 1 times, and the concatenated output is read as
    one little-endian integer of at least `bits` bits. Callers reduce it
    modulo a bound to use it as a mask.
    """

    def GENstreamAES(r, bits):
        # Must be bytes: PyCryptodome's encrypt() rejects str in Python 3.
        TXT = b"0" * 16
        rd = sha512(str(r).encode("utf-8")).digest()
        OFB = AES.new(rd[0:32], AES.MODE_OFB, iv=rd[33:49])
        res = bytearray()
        for _ in range(0, int(bits // (16 * 8)) + 1):
            res += OFB.encrypt(TXT)
        return res

    def StreamToN(r):
        print("r=", r, "\n")
        # Interpret every byte of the stream as one little-endian integer.
        # Indexing str(r) would walk the textual repr "bytearray(b'...')"
        # rather than the bytes themselves.
        return int.from_bytes(bytes(r), "little")

    return StreamToN(GENstreamAES(r, bits))
# CODIFICA OAEP
def OAEPenc(m, e, N):
    mbit = int(log(N) / log(2) + 1)
    r = randrange(1, N)
    g = GETstream(r, mbit) % N
    m0 = m ^ g
    m1 = r ^ (GETstream(m0, mbit) % N)
    # print "m0=",m0
    # print "m1=",m1
    return (RSAEnc(m0, e, N), RSAEnc(m1, e, N))
# DECODIFICA OAEP
def OAEPdec(c, d, N):
    mbit = int(log(N) / log(2) + 1)
    m0 = RSADec(c[0], d, N)
    m1 = RSADec(c[1], d, N)
    print("m0=", m0)
    print("m1=", m1)
    r = m1 ^ (GETstream(m0, mbit) % N)
    g = GETstream(r, mbit) % N
    m = m0 ^ g
    return m
# Examples
# "Classic" RSA: private exponent computed from phi(N)
p = 633825300114114700748351602943
q = 1267650600228229401496703205653
e = 13
SK = RSAKeyGen(p, q, e)
N = SK[2]
d = SK[1]
m = 239828359
# Round trip: prints True when decryption recovers m.
print("\n", m == RSADec(RSAEnc(m, e, N), d, N), "\n")
# Same round trip with the OAEP-style padding. The padding is randomised,
# so the experiment is repeated 4 times ("CODIFICA" = encoding).
for i in range(0, 4):
    cc = OAEPenc(m, e, N)
    print("CODIFICA=", cc)
    dd = OAEPdec(cc, d, N)
    print(dd == m)
# RSA with lambda: p - 1 and q - 1 below share many small prime factors, so
# lambda(N) = lcm(p-1, q-1) is much smaller than phi(N). CaKeyGen's d
# (reduced modulo lambda) can therefore be far smaller than RSAKeyGen's d
# (reduced modulo phi). Primality of p and q is assumed, not checked here.
p = (2 ** 50) * (3 ** 8) * (5 ** 5) * (7 ** 2) * 11 + 1
q = (2 ** 52) * (3 ** 10) * (5 ** 6) * 7 * 11 * 17 + 1
N = p * q
e = 13
d = RSAKeyGen(p, q, e)[1]
dp = CaKeyGen(p, q, e)[1]
# Print the base-10 orders of magnitude of the two private exponents.
print("d  ~ 10**", int(log(d * 1.0) / log(10)))
print("dp ~ 10**", int(log(dp * 1.0) / log(10)))
m = 12489237523578
enc = RSAEnc(m, e, N)
dec1 = RSADec(enc, d, N)
dec2 = RSADec(enc, dp, N)
# Prints True when both private exponents recover m.
print(m == dec1 == dec2)
# Small modulus: recover the plaintext by factoring N
def DecodeByFactoring(enc, e, N):
    """Break RSA with a small modulus: factor N, rebuild d, decrypt enc.

    PollardRho is not defined in this file; it is expected to return a
    nontrivial factor of N (TODO confirm). Prints the recovered p, q and d
    between separator lines.
    """
    factor = PollardRho(N)
    cofactor = N // factor
    priv = RSAKeyGen(factor, cofactor, e)[1]
    separator = "------"
    print(separator)
    print("p=", factor)
    print("q=", cofactor)
    print("d=", priv)
    print(separator)
    return RSADec(enc, priv, N)
p = 1247893
q = 5247899
e = 13
# d is computed here (and printed by RSAKeyGen), but the attack below does
# not use it: it rebuilds the key from N alone.
d = RSAKeyGen(p, q, e)[1]
N = p * q
m = 12341
enc = RSAEnc(m, e, N)
dc = DecodeByFactoring(enc, e, N)
# Prints True when factoring N recovered the plaintext.
print(m == dc)
# Small message: m**e < N, so the ciphertext is m**e over the integers
def SmallMSG(c, e):
    """Recover m from c = m**e when no modular reduction took place.

    Returns the integer e-th root of c (the floor of c**(1/e)) as a
    Decimal, so SmallMSG(m**e, e) == m exactly for any size of c. Callers
    can verify success with root**e == c. Prints c as a side effect.
    """
    print("c=", c)
    # Exact integer arithmetic: Decimal(c) ** (1/e) rounds 1/e (e.g. 1/3) to
    # the context precision, which can give 391.999... instead of 392.
    c = int(c)
    e = int(e)
    if c < 2:
        return Decimal(c)
    # Integer Newton iteration for the floor e-th root, started from a
    # power of two that is guaranteed to be >= the root.
    x = 1 << -(-c.bit_length() // e)
    while True:
        y = ((e - 1) * x + c // x ** (e - 1)) // e
        if y >= x:
            return Decimal(x)
        x = y
# Example: e = 5 and m = 392, so m**5 (about 9.3e12) is far below N and the
# ciphertext equals m**5 over the integers.
p = 633825300114114700748351602943
q = 1267650600228229401496703205653
e = 5
SK = RSAKeyGen(p, q, e)
N = SK[2]
m = 392
enc = RSAEnc(m, e, N)
dec = SmallMSG(enc, e)
# Prints True when the integer 5th root recovered m.
print(dec == m)
# Broadcast attack -> shared exponent (same message, same small e, several moduli)
def SharedExp(ML, e):
    """Recover m that was encrypted with the same exponent e under several moduli.

    ML is a sequence of pairs (enc_i, N_i). The CRT combines the ciphertexts
    into m**e modulo prod(N_i). When m**e is below that product, the CRT
    value is m**e itself and SmallMSG extracts its e-th root.
    Prints the combined value as a side effect.
    """
    me = ChineseRemThm(ML)
    print("M**e=", me)
    return SmallMSG(me, e)
# Setup: five RSA moduli, all used with the same public exponent e = 5
e = 5
p1 = 633825300114114700748351602943
q1 = 1267650600228229401496703205653
N1 = p1 * q1
p2 = 381520424476945831628649898931
q2 = 2220446049250313080847263336181640719
N2 = p2 * q2
p3 = 1224809639974238708818962962512535510581441
q3 = 452592555681759518058893560348969204658413
N3 = p3 * q3
p4 = 3572361449924862900721975307328228572528663
q4 = 2713293577442931584119786007232630749134861
N4 = p4 * q4
p5 = 79483005834479742509794422283529355413
q5 = 75800901255111448774141714366483
N5 = p5 * q5
# Private keys: computed (and printed by RSAKeyGen) but not used by the attack.
K1 = RSAKeyGen(p1, q1, e)
K2 = RSAKeyGen(p2, q2, e)
K3 = RSAKeyGen(p3, q3, e)
K4 = RSAKeyGen(p4, q4, e)
K5 = RSAKeyGen(p5, q5, e)
# Message (m**5 is larger than each single N_i, but not their product)
m = 1211202812833592356234
# Encrypt the same message under every modulus
e1 = RSAEnc(m, e, N1)
e2 = RSAEnc(m, e, N2)
e3 = RSAEnc(m, e, N3)
e4 = RSAEnc(m, e, N4)
e5 = RSAEnc(m, e, N5)
# Break the system: combine the five ciphertexts with the CRT, then take the 5th root
dd = SharedExp(
    ((e1, N1), (e2, N2), (e3, N3), (e4, N4), (e5, N5)), e
)
print(dd == m)
# Broadcast attack -> shared modulus (same message, same N, two exponents)
def SharedModulus(enc1, enc2, e1, e2, N):
    """Recover m from enc1 = m**e1 and enc2 = m**e2 (mod N) without any private key.

    With coefficients a1*e1 + a2*e2 = gcd(e1, e2) = 1, the plaintext is
    m = enc1**a1 * enc2**a2 (mod N). A negative coefficient is handled by
    powSign through a modular inverse.

    Returns False, the failure convention used by InvZn/RSAKeyGen in this
    file, when the exponents are not coprime or a needed inverse does not
    exist. Previously these cases silently returned a wrong value.
    """
    # Items 0/1 are used as Bezout coefficients and item 2 as the gcd, as
    # elsewhere in this file (XEuclidean is defined outside it).
    [a1, a2, g] = XEuclidean(e1, e2)
    if g > 1:
        # Only m**g could be recovered; the attack needs coprime exponents.
        return False
    f1 = powSign(enc1, a1, N)
    f2 = powSign(enc2, a2, N)
    if f1 is False or f2 is False:
        # A ciphertext shares a factor with N: no modular inverse.
        return False
    Rmsg = (f1 * f2) % N
    return Rmsg
p = 633825300114114700748351602943
q = 1267650600228229401496703205653
N = p * q
# Two public exponents sharing the same modulus (gcd(13, 29) = 1)
e1 = 13
e2 = 29
# Decryption keys (computed and printed by RSAKeyGen, not used by the attack)
d1 = RSAKeyGen(p, q, e1)[1]
d2 = RSAKeyGen(p, q, e2)[1]
# Message
m = 219412
# Encrypt the same message with both exponents
enc1 = RSAEnc(m, e1, N)
enc2 = RSAEnc(m, e2, N)
# Prints True when the common-modulus attack recovers m.
print(SharedModulus(enc1, enc2, e1, e2, N) == m)
# Related messages with e = 3: m2 = a*m1 + b, where a and b are known to the attacker
p = 633825300114114700748351602943
q = 1267650600228229401496703205653
N = p * q
e = 3
a = 5
b = 124
m1 = 21495234
m2 = a * m1 + b
enc1 = RSAEnc(m1, e, N)
enc2 = RSAEnc(m2, e, N)
def RelatedMSG3(c1, c2, a, b, modulus=None):
    """Related-message attack for e = 3 (closed-form Franklin-Reiter).

    Given c1 = m1**3 and c2 = (a*m1 + b)**3 modulo N, recovers
        m1 = b*(c2 + 2*a**3*c1 - b**3) / (a*(c2 - a**3*c1 + 2*b**3))  (mod N)
    and returns the pair (m1, a*m1 + b).

    modulus: the RSA modulus N. For backward compatibility it defaults to
    the module-level global N.
    Returns False if the denominator is not invertible modulo N.
    """
    if modulus is None:
        modulus = N
    # Reduce first so InvZn always receives a value in [0, N).
    denominator = (a * (c2 - a ** 3 * c1 + 2 * b ** 3)) % modulus
    inv = InvZn(denominator, modulus)
    if inv is False:
        return False
    m1 = (b * (c2 + 2 * a ** 3 * c1 - b ** 3) * inv) % modulus
    return m1, a * m1 + b
# Run the related-message attack; d now holds the recovered pair (m1, m2).
d = RelatedMSG3(enc1, enc2, a, b)
print(d == (m1, m2))