# AES


from numpy import *
from functools import reduce


class GF256:
    """Arithmetic in the finite field GF(2^8) used by AES.

    Field elements are ints whose bits are polynomial coefficients over
    GF(2). Addition is XOR; multiplication is carry-less multiplication
    reduced modulo x^8 + x^4 + x^3 + x + 1 (0x11B).
    """

    def BinaryProd(self, x, y):
        """Return the carry-less (GF(2)[x]) product of x and y, unreduced."""
        acc = 0
        # One step per bit of y, lowest bit first; msb() is 0 for y <= 0.
        for bit in range(self.msb(y)):
            if (y >> bit) & 1:
                acc ^= x << bit
        return acc

    def msb(self, x):
        """Return the 1-based position of the highest set bit of x (0 if x <= 0)."""
        position = 0
        while x > 0:
            x >>= 1
            position += 1
        return position

    def prod(self, x, y):
        """Return x * y in GF(2^8): the carry-less product reduced mod 0x11B."""
        modulus = 0b100011011
        product = self.BinaryProd(x, y)
        degree = self.msb(product)
        # Cancel the leading term with an aligned copy of the modulus until
        # the value fits in 8 bits (msb(0) == 0, so zero falls straight through).
        while degree > 8:
            product ^= modulus << (degree - 9)
            degree = self.msb(product)
        return product

    def sum(self, x, y):
        """Return x + y in GF(2^8), which is bitwise XOR."""
        return x ^ y

    def pow(self, x, n):
        """Return x**n in GF(2^8) by left-to-right square-and-multiply."""
        top = self.msb(n)
        result = 1
        # Scan the exponent bits from position `top` down to 0.
        for bit in range(top, -1, -1):
            result = self.prod(result, result)
            if ((n >> bit) & 0x1) == 1:
                result = self.prod(result, x)
        return result

    def inv(self, x):
        """Return the multiplicative inverse of x as x^254 (so inv(0) == 0)."""
        # 254 = 0b11111110: seven square-then-multiply steps yield x^127,
        # and one final squaring yields x^254.
        acc = 1
        for _ in range(7):
            acc = self.prod(self.prod(acc, acc), x)
        return self.prod(acc, acc)

    def PrintProdTable(self, start=0, end=256):
        """Print the multiplication table for operands in range(start, end).

        Headers and entries are two-digit lowercase hex.
        """
        operands = range(start, end)

        def hexcell(value):
            return "{0:02x}".format(value)

        header = "".join(hexcell(col) + " " for col in operands)
        print(" * | ", header)
        print("---" * (end - start + 1) + "--")
        for row in operands:
            products = "".join(
                hexcell(self.prod(int(row), int(col))) + " " for col in operands)
            print(hexcell(row) + " | ", products)


class AESimpl:
    """Straightforward, unoptimized AES block cipher (FIPS-197).

    The cipher state is a 4 x Nb numpy array of byte values in which
    ``State[row, col]`` holds block byte ``row + 4 * col``. Finite-field
    arithmetic is delegated to a ``GF256`` instance (``self.K``); S-box
    values are recomputed on every use rather than tabulated.

    Args:
        Nk: key length in 32-bit words -- 4, 6 or 8 (AES-128/-192/-256).
        debug: when true, each transformation prints the intermediate
            state (and KeyExpansion prints the key schedule) to stdout.

    Raises:
        ValueError: if ``Nk`` is not 4, 6 or 8.
    """

    def __init__(self, Nk=8, debug=False):
        self.Nb = 4  # block size in 32-bit words (number of state columns)
        self.K = GF256()
        # The number of rounds Nr is determined by the key length Nk.
        if Nk == 4:
            self.Nk = 4
            self.Nr = 10
        elif Nk == 6:
            self.Nk = 6
            self.Nr = 12
        elif Nk == 8:
            self.Nk = 8
            self.Nr = 14
        else:
            raise ValueError("Nk=4,6,8")
        self.State = array(
            [[0 for i in range(0, self.Nb)]
             for j in range(0, 4)])
        self.debug = debug
        # Round constants: Rcon[i] is the word (2^(i-1), 0, 0, 0), with the
        # power taken in GF(2^8). Index 0 is an unused placeholder because
        # KeyExpansion indexes Rcon with i // Nk, which starts at 1.
        self.Rcon = [0] + \
            [self.VecToWord(array([self.K.pow(0x2, i - 1), 0x00, 0x00, 0x00]))
             for i in range(1, int(2 + self.Nb * (self.Nr + 1) / self.Nk))]
        # S-box affine map: affineConst XOR the byte rotated right by each
        # amount in bVect.
        self.affineConst = 0x63
        self.bVect = [0, 4, 5, 6, 7]
        # Right-rotation amounts of the inverse affine map (InvSubByte).
        self.bInv = [2, 5, 7]
        # First rows of the circulant MixColumns matrix and of its inverse.
        self.MixVect = [0x02, 0x03, 0x01, 0x01]
        self.InvMixVect = [0x0e, 0x0b, 0x0d, 0x09]

    # UTILITY

    def rotN(self, x, n, s):
        """Rotate the s-bit value x left by n bits, or right by -n if n < 0.

        Assumes 0 <= x < 2**s and |n| <= s (not checked).
        """
        mask = pow(2, s) - 1
        if n >= 0:
            return ((x << n) & mask) ^ (x >> (s - n))
        else:
            return (x >> -n) ^ ((x << (s + n)) & mask)

    def RotByte(self, x, n=1):
        """Rotate the byte x left by n bits (right when n is negative)."""
        return self.rotN(x, n, 8)

    def VecToWord(self, v):
        """Pack bytes v[0..3] into a 32-bit word, v[0] most significant."""
        return reduce(lambda i, j: i ^ j,
                      [v[i] << (8 * (3 - i)) for i in range(0, 4)])

    def WordToVec(self, v):
        """Unpack a 32-bit word into a numpy array of 4 bytes, MSB first."""
        mask = 0xff
        return array(
            [(v >> (8 * (3 - i))) & mask for i in range(0, 4)])

    def PrintArray(self, datum, mask="{0:02x} "):
        """Print the first 4 rows of a 2-D array, formatting entries with mask."""
        for i in range(0, 4):
            row = ""
            for j in range(0, len(datum[i])):
                v = datum[i, j]
                current = mask.format(int(v))
                row += current
            print(row)

    def PrintState(self):
        """Print the current 4 x Nb state in hex."""
        self.PrintArray(self.State)

    # SubBytes
    def SubByte(self, x):
        """S-box: GF(2^8) inverse of x followed by the affine transform."""
        def affine(ix):
            # Start with the affine constant, then XOR in rotated copies.
            r = self.affineConst
            for i in self.bVect:
                r ^= self.RotByte(ix, -i)
            return r
        ix = self.K.inv(x)
        return affine(ix)

    def InvSubByte(self, x):
        """Inverse S-box: undo the affine transform, then invert in GF(2^8)."""
        def invAffine(v):
            # Strip the affine constant, then apply the inverse matrix.
            r = v ^ self.affineConst
            r1 = 0
            for i in self.bInv:
                r1 ^= self.RotByte(r, -i)
            return r1
        return self.K.inv(invAffine(x))

    def ApplyBytes(self, f):
        """Replace every state byte b with f(b)."""
        self.State = array(
            [[f(self.State[row, col])
              for col in range(0, self.Nb)]
             for row in range(0, 4)])
        if self.debug:
            print("-----ApplyBytes-----")
            self.PrintState()

    def SubBytes(self):
        """Apply the S-box to every state byte."""
        self.ApplyBytes(self.SubByte)

    def InvSubBytes(self):
        """Apply the inverse S-box to every state byte."""
        self.ApplyBytes(self.InvSubByte)

    # Shift Rows
    def ApplyShiftRows(self, f):
        """Rotate each state row i left by f(i) positions, in place.

        A negative f(i) rotates right (Python's % keeps the index in range).
        """
        for i in range(0, 4):
            OldRow = self.State[i].copy()
            for j in range(0, self.Nb):
                self.State[i][j] = OldRow[(j + f(i)) % self.Nb]
        if self.debug:
            print("----ApplyShiftRows----")
            self.PrintState()

    def ShiftRows(self):
        """Rotate state row i left by i positions."""
        def shift(i):
            return i
        self.ApplyShiftRows(shift)

    def InvShiftRows(self):
        """Rotate state row i right by i positions (inverse of ShiftRows)."""
        def invshift(i):
            return -i
        self.ApplyShiftRows(invshift)

    # MixColumns
    def ApplyMixColumns(self, vect):
        """Multiply every state column by the circulant GF(2^8) matrix whose
        first row is vect (output byte i = sum_j vect[j] * C[(i + j) % 4])."""
        def MixColumn(C):
            r = []
            for i in range(0, 4):
                r.append(
                    reduce(lambda x, y: self.K.sum(x, y),
                           [self.K.prod(C[(i + j) % 4], vect[j])
                            for j in range(0, 4)]))
            return array(r)
        for i in range(0, self.Nb):
            self.State[:, i] = MixColumn(self.State[:, i])
        if self.debug:
            print("----ApplyMixColumns----")
            self.PrintState()

    def MixColumns(self):
        """Forward MixColumns transformation."""
        self.ApplyMixColumns(self.MixVect)

    def InvMixColumns(self):
        """Inverse MixColumns transformation."""
        self.ApplyMixColumns(self.InvMixVect)

    # AddRoundKey
    def AddRoundKey(self, Key):
        """XOR round-key word Key[c] into state column c, for each column.

        Key is a sequence of Nb 32-bit words (one round's key-schedule slice).
        """
        for i in range(0, self.Nb):
            self.State[:, i] = \
                self.WordToVec(self.VecToWord(self.State[:, i]) ^ Key[i])
        if self.debug:
            print("----AddRoundKey----")
            print("KEY:")
            self.PrintArray(transpose(array([
                self.WordToVec(Key[i]) for i in range(0, self.Nb)])))
            print("STATE:")
            self.PrintState()

    # KEY EXPANSION

    def KeyExpansion(self, Key):
        """Expand a cipher key into the key schedule (FIPS-197 KeyExpansion).

        Args:
            Key: sequence of at least 4 * Nk byte values; only the first
                4 * Nk are used.

        Returns:
            list of Nb * (Nr + 1) 32-bit words; round r uses the words
            w[r * Nb:(r + 1) * Nb].
        """
        def RotWord(x):
            # Cyclic left rotation of the word by one byte.
            v = self.WordToVec(x)
            return self.VecToWord([v[1], v[2], v[3], v[0]])

        def SubWord(w):
            # Apply the S-box to each of the word's four bytes.
            v = [self.SubByte(x) for x in self.WordToVec(w)]
            return self.VecToWord(v)

        total = self.Nb * (self.Nr + 1)
        # Key bytes are converted to Python ints so that narrow numpy dtypes
        # (e.g. uint8) cannot overflow in VecToWord's byte shifts.
        w = [self.VecToWord([int(b) for b in Key[4 * i:4 * (i + 1)]])
             for i in range(0, self.Nk)] + [0] * (total - self.Nk)
        for i in range(self.Nk, total):
            temp = w[i - 1]
            if i % self.Nk == 0:
                temp = SubWord(RotWord(temp)) ^ self.Rcon[i // self.Nk]
            elif self.Nk > 6 and i % self.Nk == 4:
                # Only for Nk == 8: an extra SubWord mid-way through each block.
                temp = SubWord(temp)
            w[i] = w[i - self.Nk] ^ temp
        if self.debug:
            q = array([self.WordToVec(i) for i in w]).transpose()
            ltable = int(self.Nb * 2)  # printed width of one round (Nb words)
            st = "EXPANDED KEY"
            print(st + "-" * (ltable * (self.Nr + 1) - len(st)))
            self.PrintArray(q, "{0:02x}")
            # Ruler alternating dash/space per round across all Nr + 1 rounds
            # (Nr + 1 is odd for every supported Nk, so it ends on a dash).
            print(("-" * ltable + " " * ltable) *
                  ((self.Nr + 1) // 2) + "-" * ltable)
            msg = ""
            for i in range(0, int(len(w) / 4)):
                msg += " " * int((ltable - 2) / 2) + \
                    "{0:02x}".format(i) + " " * int((ltable - 2) / 2)
            print(msg)
            print("-" * (ltable * (self.Nr + 1)))
        return w

    def _CheckKey(self, key):
        """Raise ValueError if key is shorter than the 4 * Nk bytes needed."""
        if len(key) < self.Nk * 4:
            raise ValueError(
                "key must be at least {0} bytes for Nk={1}, got {2}".format(
                    self.Nk * 4, self.Nk, len(key)))

    def _LoadState(self, msg):
        """Fill the state column-major from a block: State[r, c] = msg[r + 4c].

        Bytes are converted to Python ints so narrow numpy dtypes (e.g.
        uint8) cannot overflow in the word shifts and field products.
        """
        self.State = array([[int(msg[row + 4 * col])
                             for col in range(0, self.Nb)]
                            for row in range(0, 4)])

    # Actual Encryption

    def Encrypt(self, msg, key):
        """Encrypt one block (FIPS-197 Cipher).

        Args:
            msg: sequence of at least 4 * Nb byte values; only the first
                4 * Nb are used.
            key: sequence of at least 4 * Nk byte values; extra bytes are
                ignored.

        Returns:
            1-D numpy array holding the 4 * Nb ciphertext bytes.

        Raises:
            ValueError: if key is shorter than 4 * Nk bytes.
        """
        self._CheckKey(key)
        w = self.KeyExpansion(key)
        self._LoadState(msg)
        if self.debug:
            print("---INITIAL STATE---")
            self.PrintState()
        self.AddRoundKey(w[0:self.Nb])
        for i in range(1, self.Nr):
            self.SubBytes()
            self.ShiftRows()
            self.MixColumns()
            self.AddRoundKey(w[i * self.Nb:(i + 1) * self.Nb])
        if self.debug:
            print("---FINAL ROUND---")
        # The final round omits MixColumns.
        self.SubBytes()
        self.ShiftRows()
        self.AddRoundKey(w[self.Nr * self.Nb:(self.Nr + 1) * self.Nb])
        return transpose(self.State).flatten()

    # ... and decryption

    def Decrypt(self, msg, key):
        """Decrypt one block (FIPS-197 InvCipher).

        Args:
            msg: sequence of at least 4 * Nb ciphertext byte values; only
                the first 4 * Nb are used.
            key: sequence of at least 4 * Nk byte values; extra bytes are
                ignored.

        Returns:
            1-D numpy array holding the 4 * Nb plaintext bytes.

        Raises:
            ValueError: if key is shorter than 4 * Nk bytes.
        """
        self._CheckKey(key)
        w = self.KeyExpansion(key)
        self._LoadState(msg)
        if self.debug:
            print("----Input----")
            self.PrintState()
        # Undo the final round's key addition first.
        self.AddRoundKey(w[self.Nr * self.Nb:(self.Nr + 1) * self.Nb])
        for i in range(self.Nr - 1, 0, -1):
            self.InvShiftRows()
            self.InvSubBytes()
            self.AddRoundKey(w[i * self.Nb:(i + 1) * self.Nb])
            self.InvMixColumns()

        if self.debug:
            print("----FINAL ROUND----")
        self.InvShiftRows()
        self.InvSubBytes()
        self.AddRoundKey(w[0:self.Nb])
        return transpose(self.State).flatten()


def BinWeight(x):
    """Return the number of set bits (Hamming weight) of a non-negative integer.

    Accepts any integer-like value (Python or numpy int).

    Raises:
        ValueError: if x is negative -- in two's complement a negative
            integer has infinitely many set bits, and the bit-clearing loop
            used previously never terminated for such input.
        TypeError: if x is not an integer (e.g. a float).
    """
    if x < 0:
        raise ValueError(
            "BinWeight expects a non-negative integer, got {0!r}".format(x))
    return bin(x).count("1")


def CommonBits(q, v):
    """Return how many bit positions differ between two integer sequences.

    Despite the name, this is the Hamming distance: BinWeight(a ^ b) summed
    over the aligned pairs of q and v. zip() stops at the shorter sequence.
    """
    # A list (not a generator) is summed on purpose: the module's numpy star
    # import may shadow the builtin sum().
    per_pair = [BinWeight(left ^ right) for left, right in zip(q, v)]
    return sum(per_pair)


# Example vectors: these match the FIPS-197 Appendix B (AES-128) and
# Appendix C.3 (AES-256) key/plaintext pairs.
tstKey128 = list(bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c"))
tstMsg128 = list(bytes.fromhex("3243f6a8885a308d313198a2e0370734"))
tstKey256 = list(range(0x20))
tstMsg256 = list(bytes.fromhex("00112233445566778899aabbccddeeff"))

# All-zero 32-byte key (with Nk=4 only its first 16 bytes are used).
tstZ = [0] * 0x20

# Example: encrypt two blocks differing in one bit and count how many
# ciphertext bits differ.
tst1 = [0] * 0x10
tst2 = list(tst1)
tst2[0] = tst2[0] + 1
g = AESimpl(Nk=4, debug=True)
enc1 = g.Encrypt(tst1, tstZ)
enc2 = g.Encrypt(tst2, tstZ)
d = CommonBits(enc1, enc2)