# AES


from numpy import *
from functools import reduce


class GF256:
    """Arithmetic in GF(2^8) modulo the polynomial 0x11B.

    Field elements are integers whose bits are the coefficients of a
    polynomial over GF(2).  Addition is XOR; multiplication is a carry-less
    product followed by reduction modulo ``0b100011011``.
    """

    def BinaryProd(self, x, y):
        """Return the unreduced carry-less (polynomial) product of x and y.

        Returns 0 when ``y`` is not positive, because the loop only consumes
        the bits of a positive ``y``.
        """
        r = 0
        t = x
        while y > 0:
            # Add the current shifted copy of x when the low bit of y is set.
            r ^= t * (y % 2)
            t = t << 1
            y = y // 2
        return r

    def msb(self, x):
        """Return the number of bits needed to represent x (0 if x <= 0)."""
        # int() lets numpy integer scalars use the built-in bit_length.
        return int(x).bit_length() if x > 0 else 0

    def prod(self, x, y):
        """Return the field product of x and y, reduced to at most 8 bits."""
        poly = 0b100011011
        c = self.BinaryProd(x, y)
        if c == 0:
            return c
        nbits = self.msb(c)
        while nbits > 8:
            # Align the 9-bit modulus with the top bit of c and cancel it.
            c ^= poly << (nbits - 9)
            nbits = self.msb(c)
        return c

    def sum(self, x, y):
        """Return the field sum of x and y (bitwise XOR)."""
        return x ^ y

    def pow(self, x, n):
        """Return x raised to the integer power n by square-and-multiply.

        A negative ``n`` is evaluated as ``inv(x) ** -n``.  Because ``inv``
        maps 0 to 0, a negative power of 0 is 0 as well.  ``pow(x, 0)`` is 1.
        """
        if n < 0:
            x = self.inv(x)
            n = -n
        result = 1
        # Scan the exponent from its most significant bit downwards.
        for shift in range(self.msb(n) - 1, -1, -1):
            result = self.prod(result, result)
            if (n >> shift) & 0x1:
                result = self.prod(result, x)
        return result

    def inv(self, x):
        """Return the multiplicative inverse of x, computed as x**254.

        254 is 0b11111110, so seven "square then multiply" steps followed by
        one final squaring produce the exponent.  ``inv(0)`` is 0.
        """
        C = 1
        for _ in range(0, 7):
            C = self.prod(C, C)
            C = self.prod(C, x)
        C = self.prod(C, C)
        return C

    def PrintProdTable(self, start=0, end=256):
        """Print the multiplication table for operands in ``range(start, end)``.

        Every value is printed as two lowercase hexadecimal digits.
        """

        def numToStr(x):
            return "{0:02x}".format(x)

        # Header row listing the column operands.
        print(" * | ", end=" ")
        for y in range(start, end):
            print(numToStr(y), end=" ")
        print()
        print("---" * (end - start + 1) + "--")
        # One row per left operand.
        for x in range(start, end):
            print(numToStr(x) + " | ", end=" ")
            for y in range(start, end):
                print(
                    numToStr(self.prod(int(x), int(y))),
                    end=" ",
                )
            print()


class AESimpl:
    """AES block cipher computed from GF(2^8) arithmetic, without lookup tables.

    The cipher state is a 4 x Nb numpy integer matrix stored in
    ``self.State``; every round step mutates it in place.  Round-key words
    are 32-bit integers whose most significant byte corresponds to row 0 of
    a state column.

    Bytes are converted to plain Python ints before they are shifted, so
    keys and messages may be lists, ``bytes`` or numpy arrays of any integer
    dtype without overflowing.
    """

    def __init__(self, Nk=8, debug=False):
        """Configure the cipher for a key of ``Nk`` 32-bit words.

        Args:
            Nk: key length in words: 4, 6 or 8 (10, 12 or 14 rounds).
            debug: when true, each round step prints the resulting state.

        Raises:
            ValueError: if ``Nk`` is not 4, 6 or 8.
        """
        self.Nb = 4
        self.K = GF256()
        if Nk == 4:
            self.Nk = 4
            self.Nr = 10
        elif Nk == 6:
            self.Nk = 6
            self.Nr = 12
        elif Nk == 8:
            self.Nk = 8
            self.Nr = 14
        else:
            raise ValueError("Nk=4,6,8")
        self.State = array([[0] * self.Nb for _ in range(4)])
        self.debug = debug
        # Round constants: Rcon[i] carries 2**(i-1) (in GF(2^8)) in its top
        # byte.  Index 0 is a placeholder because Rcon is used from index 1.
        rcon_count = 2 + (self.Nb * (self.Nr + 1)) // self.Nk
        self.Rcon = [0] + [
            self.VecToWord([self.K.pow(0x2, i - 1), 0x00, 0x00, 0x00])
            for i in range(1, rcon_count)
        ]
        # Constant and right-rotation amounts of the S-box affine map.
        self.affineConst = 0x63
        self.bVect = [0, 4, 5, 6, 7]
        # Right-rotation amounts of the inverse affine map.
        self.bInv = [2, 5, 7]
        # First rows of the (inverse) MixColumns matrices.
        self.MixVect = [0x02, 0x03, 0x01, 0x01]
        self.InvMixVect = [0x0E, 0x0B, 0x0D, 0x09]

    # UTILITY

    def rotN(self, x, n, s):
        """Rotate the s-bit value x left by n bits (right when n < 0).

        ``x`` is expected to fit in ``s`` bits already.
        """
        # Built with a shift so the result does not depend on which ``pow``
        # the module-level star import leaves in scope.
        mask = (1 << s) - 1
        if n >= 0:
            return ((x << n) & mask) ^ (x >> (s - n))
        return (x >> -n) ^ ((x << (s + n)) & mask)

    def RotByte(self, x, n=1):
        """Rotate the byte x left by n bits (right when n < 0)."""
        return self.rotN(x, n, 8)

    def VecToWord(self, v):
        """Pack the four bytes ``v[0..3]`` into one word, ``v[0]`` on top.

        Elements are cast to Python ints first; shifting a narrow numpy
        integer (for instance uint8) by up to 24 bits would overflow.
        """
        word = 0
        for i in range(4):
            word ^= int(v[i]) << (8 * (3 - i))
        return word

    def WordToVec(self, v):
        """Split a word into a numpy array of its four bytes, top byte first."""
        mask = 0xFF
        return array([(v >> (8 * (3 - i))) & mask for i in range(0, 4)])

    def PrintArray(self, datum, mask="{0:02x} "):
        """Print the first four rows of ``datum``, formatting cells with ``mask``."""
        for i in range(0, 4):
            print("".join(mask.format(int(v)) for v in datum[i]))

    def PrintState(self):
        """Print the current state matrix in hexadecimal."""
        self.PrintArray(self.State)

    def _LoadState(self, msg):
        """Return the 4 x Nb state for ``msg``, filled column by column."""
        # int() keeps the state in numpy's default integer dtype even when
        # msg is a narrow-dtype array, so later shifts cannot overflow.
        return array(
            [
                [int(msg[row + col * self.Nb]) for col in range(self.Nb)]
                for row in range(4)
            ]
        )

    def _CheckKey(self, key):
        """Raise ValueError when ``key`` holds fewer than 4 * Nk bytes."""
        needed = self.Nk * 4
        if len(key) < needed:
            raise ValueError(
                "key too short: need at least {0} bytes (Nk={1})".format(
                    needed, self.Nk
                )
            )

    # SubBytes
    def SubByte(self, x):
        """Return the S-box image of byte x: field inverse, then affine map."""
        ix = self.K.inv(int(x))
        # Start from the affine constant and fold in the rotated copies.
        r = self.affineConst
        for i in self.bVect:
            r ^= self.RotByte(ix, -i)
        return r

    def InvSubByte(self, x):
        """Return the inverse S-box image of byte x."""
        # Strip the affine constant, then undo the affine matrix.
        r = int(x) ^ self.affineConst
        r1 = 0
        for i in self.bInv:
            r1 ^= self.RotByte(r, -i)
        return self.K.inv(r1)

    def ApplyBytes(self, f):
        """Replace every state byte b with ``f(b)``."""
        self.State = array(
            [
                [f(self.State[row, col]) for col in range(self.Nb)]
                for row in range(4)
            ]
        )
        if self.debug:
            print("-----ApplyBytes-----")
            self.PrintState()

    def SubBytes(self):
        """Apply the S-box to every state byte."""
        self.ApplyBytes(self.SubByte)

    def InvSubBytes(self):
        """Apply the inverse S-box to every state byte."""
        self.ApplyBytes(self.InvSubByte)

    # Shift Rows
    def ApplyShiftRows(self, f):
        """Rotate state row i left by ``f(i)`` positions (cyclically)."""
        for i in range(4):
            OldRow = self.State[i].copy()
            for j in range(self.Nb):
                self.State[i][j] = OldRow[(j + f(i)) % self.Nb]
        if self.debug:
            print("----ApplyShiftRows----")
            self.PrintState()

    def ShiftRows(self):
        """Rotate row i of the state left by i positions."""
        self.ApplyShiftRows(lambda i: i)

    def InvShiftRows(self):
        """Rotate row i of the state right by i positions."""
        self.ApplyShiftRows(lambda i: -i)

    # MixColumns
    def ApplyMixColumns(self, vect):
        """Multiply each state column by the circulant matrix whose first row is ``vect``."""

        def MixColumn(C):
            mixed = []
            for i in range(4):
                acc = 0
                for j in range(4):
                    acc = self.K.sum(
                        acc, self.K.prod(C[(i + j) % 4], vect[j])
                    )
                mixed.append(acc)
            return array(mixed)

        for i in range(self.Nb):
            self.State[:, i] = MixColumn(self.State[:, i])
        if self.debug:
            print("----ApplyMixColumns----")
            self.PrintState()

    def MixColumns(self):
        """Apply the forward column mix."""
        self.ApplyMixColumns(self.MixVect)

    def InvMixColumns(self):
        """Apply the inverse column mix."""
        self.ApplyMixColumns(self.InvMixVect)

    # AddRoundKey
    def AddRoundKey(self, Key):
        """XOR state column i with the round-key word ``Key[i]``."""
        for i in range(self.Nb):
            self.State[:, i] = self.WordToVec(
                self.VecToWord(self.State[:, i]) ^ Key[i]
            )
        if self.debug:
            print("----AddRoundKey----")
            print("KEY:")
            self.PrintArray(
                transpose(
                    array([self.WordToVec(Key[i]) for i in range(self.Nb)])
                )
            )
            print("STATE:")
            self.PrintState()

    # KEY EXPANSION

    def _PrintExpandedKey(self, w):
        """Print the expanded key ``w`` as a byte table with round indices."""
        q = array([self.WordToVec(word) for word in w]).transpose()
        ltable = int(self.Nb * 2)
        st = "EXPANDED KEY"
        print(st + "-" * (ltable * (self.Nr + 1) - len(st)))
        self.PrintArray(q, "{0:02x}")
        print(
            ("-" * ltable + " " * ltable)
            * ((self.Nb) * int((self.Nr + 1) / 8))
            + "-" * ltable
        )
        pad = " " * int((ltable - 2) / 2)
        print(
            "".join(
                pad + "{0:02x}".format(i) + pad
                for i in range(0, int(len(w) / 4))
            )
        )
        print("-" * (ltable * (self.Nr + 1)))

    def KeyExpansion(self, Key):
        """Expand ``Key`` into ``Nb * (Nr + 1)`` round-key words.

        Only the first ``4 * Nk`` bytes of ``Key`` are read.

        Returns:
            A list of integer words; words ``r*Nb .. (r+1)*Nb - 1`` form the
            key of round ``r``.
        """

        def RotWord(x):
            v = self.WordToVec(x)
            return self.VecToWord([v[1], v[2], v[3], v[0]])

        def SubWord(word):
            return self.VecToWord(
                [self.SubByte(b) for b in self.WordToVec(word)]
            )

        total = self.Nb * (self.Nr + 1)
        w = [
            self.VecToWord(Key[4 * i : 4 * (i + 1)])
            for i in range(self.Nk)
        ]
        for i in range(self.Nk, total):
            temp = w[i - 1]
            if i % self.Nk == 0:
                temp = SubWord(RotWord(temp)) ^ self.Rcon[i // self.Nk]
            elif self.Nk > 6 and i % self.Nk == 4:
                # Extra substitution used only for the 8-word key schedule.
                temp = SubWord(temp)
            w.append(w[i - self.Nk] ^ temp)
        if self.debug:
            self._PrintExpandedKey(w)
        return w

    # Actual Encryption

    def Encrypt(self, msg, key):
        """Encrypt one block.

        Args:
            msg: indexable sequence of byte values; the first ``4 * Nb``
                entries are used.
            key: sequence of at least ``4 * Nk`` byte values; extra bytes
                are ignored.

        Returns:
            A flat numpy array holding the ciphertext bytes column by column.

        Raises:
            ValueError: if ``key`` has fewer than ``4 * Nk`` bytes.
        """
        self._CheckKey(key)
        w = self.KeyExpansion(key)
        self.State = self._LoadState(msg)
        if self.debug:
            print("---INITIAL STATE---")
            self.PrintState()
        self.AddRoundKey(w[0 : self.Nb])
        for i in range(1, self.Nr):
            self.SubBytes()
            self.ShiftRows()
            self.MixColumns()
            self.AddRoundKey(w[i * self.Nb : (i + 1) * self.Nb])
        if self.debug:
            print("---FINAL ROUND---")
        # The last round has no column mix.
        self.SubBytes()
        self.ShiftRows()
        self.AddRoundKey(
            w[self.Nr * self.Nb : (self.Nr + 1) * self.Nb]
        )
        return transpose(self.State).flatten()

    # ... and decryption

    def Decrypt(self, msg, key):
        """Decrypt one block; arguments and result mirror ``Encrypt``.

        Raises:
            ValueError: if ``key`` has fewer than ``4 * Nk`` bytes.
        """
        self._CheckKey(key)
        w = self.KeyExpansion(key)
        self.State = self._LoadState(msg)
        if self.debug:
            print("----Input----")
            self.PrintState()
        # Undo the final encryption round key first.
        self.AddRoundKey(
            w[self.Nr * self.Nb : (self.Nr + 1) * self.Nb]
        )
        for i in range(self.Nr - 1, 0, -1):
            self.InvShiftRows()
            self.InvSubBytes()
            self.AddRoundKey(w[i * self.Nb : (i + 1) * self.Nb])
            self.InvMixColumns()

        if self.debug:
            print("----FINAL ROUND----")
        self.InvShiftRows()
        self.InvSubBytes()
        self.AddRoundKey(w[0 : self.Nb])
        return transpose(self.State).flatten()


def BinWeight(x):
    """Return the number of set bits in x.

    Each pass clears the lowest set bit, so the loop runs once per 1-bit.
    """
    remaining = x
    weight = 0
    while remaining:
        weight += 1
        # x & (x - 1) drops the least significant set bit.
        remaining &= remaining - 1
    return weight


def CommonBits(q, v):
    """Return the total number of bit positions in which q and v differ.

    Elements are paired with ``zip`` (so the longer input is truncated),
    XORed, and the set bits of each XOR are counted.  Despite the name,
    this counts differing bits, not shared ones.
    """
    # Kept as a list: ``sum`` may be numpy's here, which expects a sequence.
    differing = [BinWeight(a ^ b) for a, b in zip(q, v)]
    return sum(differing)


# Sample keys and blocks, as lists of byte values.
# NOTE(review): these appear to match the FIPS-197 example vectors - confirm.
tstKey128 = list(bytes.fromhex("2b7e1516 28aed2a6 abf71588 09cf4f3c"))
tstMsg128 = list(bytes.fromhex("3243f6a8 885a308d 313198a2 e0370734"))

# 32-byte key 00 01 02 ... 1f and block 00 11 22 ... ff.
tstKey256 = list(range(0x20))
tstMsg256 = list(bytes.fromhex("00112233 44556677 8899aabb ccddeeff"))

# 32 zero bytes.
tstZ = [0] * 0x20

# Example: encrypt two blocks that differ only in the lowest bit of their
# first byte, using the first 16 bytes of the all-zero key, and count how
# many ciphertext bits differ.  The cipher runs with debug output enabled.
tst1 = [0] * 0x10
tst2 = list(tst1)
tst2[0] += 1
g = AESimpl(4, True)
enc1 = g.Encrypt(tst1, tstZ)
enc2 = g.Encrypt(tst2, tstZ)
d = CommonBits(enc1, enc2)