# AES

```python
from numpy import *
from functools import reduce

class GF256:
    """Arithmetic in the finite field GF(2^8) as used by AES.

    Elements are ints in ``range(256)`` read as polynomials over GF(2)
    (bit k is the coefficient of x^k).  Addition is XOR; multiplication is
    carry-less multiplication reduced modulo ``POLY``.
    """

    # Reduction polynomial x^8 + x^4 + x^3 + x + 1 (0x11B).
    POLY = 0b100011011

    def BinaryProd(self, x, y):
        """Return the carry-less (XOR) product of ``x`` and ``y``, unreduced.

        Shift-and-add over GF(2): for every set bit of ``y`` the
        correspondingly shifted copy of ``x`` is XORed into the result.
        """
        r = 0
        t = x
        while y > 0:
            if y & 1:
                r ^= t
            t <<= 1
            y >>= 1
        return r

    def msb(self, x):
        """Return the number of significant bits of ``x`` (0 when ``x <= 0``).

        This is the 1-based position of the highest set bit, i.e.
        ``int.bit_length`` for positive ``x``.
        """
        return int(x).bit_length() if x > 0 else 0

    def prod(self, x, y):
        """Return ``x * y`` in GF(2^8): the carry-less product reduced by ``POLY``."""
        c = self.BinaryProd(x, y)
        nbits = self.msb(c)
        # Cancel the leading term with a suitably shifted copy of POLY until
        # the degree drops below 8 (at most 8 significant bits remain).
        while nbits > 8:
            c ^= self.POLY << (nbits - 9)
            nbits = self.msb(c)
        return c

    def sum(self, x, y):
        """Return ``x + y`` in GF(2^8), which is bitwise XOR."""
        return x ^ y

    def pow(self, x, n):
        """Return ``x`` raised to the integer power ``n`` in GF(2^8).

        Uses left-to-right square-and-multiply over the bits of ``n``.  A
        negative ``n`` is evaluated as ``inv(x) ** -n``.  Because this class
        defines ``inv(0) == 0``, ``pow(0, n)`` is 0 for every ``n < 0``.
        """
        if n < 0:
            # Previously the bit scan misread negative exponents
            # (pow(x, -1) returned x); route them through the inverse.
            return self.pow(self.inv(x), -n)
        C = 1
        for i in range(self.msb(n) - 1, -1, -1):
            C = self.prod(C, C)
            if (n >> i) & 1:
                C = self.prod(C, x)
        return C

    def inv(self, x):
        """Return the multiplicative inverse of ``x``; ``inv(0)`` is 0.

        Uses x^-1 = x^254 (the multiplicative group has order 255).
        254 = 0b11111110, so seven square-and-multiply steps build x^127 and
        a final squaring yields x^254.
        """
        C = 1
        for _ in range(7):
            C = self.prod(self.prod(C, C), x)
        return self.prod(C, C)

    def PrintProdTable(self, start=0, end=256):
        """Print the multiplication table for elements ``start`` .. ``end - 1`` in hex."""
        cols = range(start, end)
        print(" * |  " + "".join("{0:02x} ".format(y) for y in cols))
        print("---" * (end - start + 1) + "--")
        for x in range(start, end):
            print("{0:02x} |  ".format(x)
                  + "".join("{0:02x} ".format(self.prod(x, y)) for y in cols))

class AESimpl:
    """Straightforward (unoptimized) AES block cipher over a 4 x Nb byte state.

    ``State`` is a 4 x Nb numpy array of byte values.  A 16-byte block is
    loaded column by column (``State[r, c] = block[r + 4*c]``).  Expanded-key
    words are 32-bit ints whose most significant byte belongs to row 0.
    """

    def __init__(self, Nk=8, debug=False):
        """Set up the cipher for a key of ``Nk`` 32-bit words.

        Args:
            Nk: key length in words: 4, 6 or 8 (10, 12 or 14 rounds).
            debug: if true, each transformation prints the state.

        Raises:
            ValueError: if ``Nk`` is not 4, 6 or 8.
        """
        rounds = {4: 10, 6: 12, 8: 14}
        if Nk not in rounds:
            raise ValueError("Nk=4,6,8")
        self.Nb = 4
        self.Nk = int(Nk)
        self.Nr = rounds[Nk]
        self.K = GF256()
        self.State = array([[0 for i in range(self.Nb)] for j in range(4)])
        self.debug = debug
        # Rcon[i] is the word [x^(i-1), 0, 0, 0]; index 0 is a placeholder
        # because KeyExpansion only reads Rcon[i // Nk] with i >= Nk.
        self.Rcon = [0] + [
            self.VecToWord(array([self.K.pow(0x2, i - 1), 0x00, 0x00, 0x00]))
            for i in range(1, self.Nb * (self.Nr + 1) // self.Nk + 2)]
        # S-box affine map: XOR of the byte rotated right by each amount in
        # bVect, plus the constant 0x63.
        self.affineConst = 0x63
        self.bVect = [0, 4, 5, 6, 7]
        # Inverse affine map: XOR of (byte ^ 0x63) rotated right by each amount in bInv.
        self.bInv = [2, 5, 7]
        # First row of the circulant MixColumns / InvMixColumns matrices.
        self.MixVect = [0x02, 0x03, 0x01, 0x01]
        self.InvMixVect = [0x0e, 0x0b, 0x0d, 0x09]

    # UTILITY

    def rotN(self, x, n, s):
        """Rotate the ``s``-bit value ``x`` left by ``n`` bits (right by ``-n`` if ``n < 0``).

        Expects ``0 <= x < 2**s`` and ``|n| <= s``.
        """
        # Shift-based mask avoids the name ``pow``, which a numpy star import may shadow.
        mask = (1 << s) - 1
        if n >= 0:
            return ((x << n) & mask) ^ (x >> (s - n))
        return (x >> -n) ^ ((x << (s + n)) & mask)

    def RotByte(self, x, n=1):
        """Rotate the byte ``x`` left by ``n`` bits (right if ``n`` is negative)."""
        return self.rotN(x, n, 8)

    def VecToWord(self, v):
        """Pack bytes ``v[0..3]`` into a 32-bit word, ``v[0]`` most significant."""
        return reduce(lambda i, j: i ^ j,
                      [v[i] << (8 * (3 - i)) for i in range(0, 4)])

    def WordToVec(self, v):
        """Unpack a 32-bit word into a numpy array of 4 bytes, most significant first."""
        return array([(v >> (8 * (3 - i))) & 0xFF for i in range(0, 4)])

    def PrintArray(self, datum, fmt="{0:02x} "):
        """Print each row of the 2-D array ``datum``, formatting every entry with ``fmt``.

        NOTE(review): the original signature was lost; the default format
        is a reconstruction (callers that need a specific layout pass one).
        """
        for row in datum:
            print("".join(fmt.format(int(v)) for v in row))

    def PrintState(self):
        """Print the current state matrix."""
        self.PrintArray(self.State)

    # SubBytes
    def SubByte(self, x):
        """Apply the AES S-box to byte ``x``: field inverse followed by the affine map."""
        ix = self.K.inv(x)
        r = self.affineConst
        for i in self.bVect:
            r ^= self.RotByte(ix, -i)
        return r

    def InvSubByte(self, x):
        """Apply the inverse S-box: undo the affine map, then take the field inverse."""
        r = x ^ self.affineConst
        r1 = 0
        for i in self.bInv:
            r1 ^= self.RotByte(r, -i)
        return self.K.inv(r1)

    def ApplyBytes(self, f):
        """Replace every state byte ``b`` with ``f(b)``."""
        self.State = array([[f(v) for v in row] for row in self.State])
        if self.debug:
            print("-----ApplyBytes-----")
            self.PrintState()

    def SubBytes(self):
        self.ApplyBytes(self.SubByte)

    def InvSubBytes(self):
        self.ApplyBytes(self.InvSubByte)

    # Shift Rows
    def ApplyShiftRows(self, f):
        """Rotate row ``i`` of the state left by ``f(i)`` positions (right if negative)."""
        for i in range(4):
            # roll(row, -k)[j] == row[(j + k) % Nb]
            self.State[i] = roll(self.State[i], -f(i))
        if self.debug:
            print("----ApplyShiftRows----")
            self.PrintState()

    def ShiftRows(self):
        self.ApplyShiftRows(lambda i: i)

    def InvShiftRows(self):
        self.ApplyShiftRows(lambda i: -i)

    # MixColumns
    def ApplyMixColumns(self, vect):
        """Multiply every state column by the circulant matrix whose first row is ``vect``."""
        def MixColumn(C):
            return array([
                reduce(self.K.sum,
                       [self.K.prod(C[(i + j) % 4], vect[j]) for j in range(4)])
                for i in range(4)])
        for i in range(self.Nb):
            self.State[:, i] = MixColumn(self.State[:, i])
        if self.debug:
            print("----ApplyMixColumns----")
            self.PrintState()

    def MixColumns(self):
        self.ApplyMixColumns(self.MixVect)

    def InvMixColumns(self):
        self.ApplyMixColumns(self.InvMixVect)

    def AddRoundKey(self, Key):
        """XOR word ``Key[i]`` into state column ``i`` for each of the Nb columns."""
        for i in range(self.Nb):
            self.State[:, i] = self.State[:, i] ^ self.WordToVec(Key[i])
        if self.debug:
            print("KEY:")
            self.PrintArray(transpose(array([
                self.WordToVec(Key[i]) for i in range(self.Nb)])))
            print("STATE:")
            self.PrintState()

    # KEY EXPANSION

    def KeyExpansion(self, Key):
        """Return the key schedule: a list of Nb*(Nr+1) 32-bit words.

        Only the first 4*Nk bytes of ``Key`` are read.
        """
        def RotWord(x):
            v = self.WordToVec(x)
            return self.VecToWord([v[1], v[2], v[3], v[0]])

        def SubWord(x):
            return self.VecToWord([self.SubByte(b) for b in self.WordToVec(x)])

        total = self.Nb * (self.Nr + 1)
        w = [self.VecToWord(Key[4 * i:4 * (i + 1)]) for i in range(self.Nk)]
        for i in range(self.Nk, total):
            temp = w[i - 1]
            if i % self.Nk == 0:
                temp = SubWord(RotWord(temp)) ^ self.Rcon[i // self.Nk]
            elif self.Nk > 6 and i % self.Nk == 4:
                # Extra SubWord step used only for 256-bit keys.
                temp = SubWord(temp)
            w.append(w[i - self.Nk] ^ temp)
        if self.debug:
            self._PrintKeySchedule(w)
        return w

    def _PrintKeySchedule(self, w):
        """Print the expanded key as a byte matrix with one labeled band per round key."""
        q = transpose(array([self.WordToVec(x) for x in w]))
        ltable = 2 * self.Nb            # characters per round key (Nb words x 2 hex digits)
        nkeys = len(w) // self.Nb       # number of round keys (Nr + 1)
        st = "EXPANDED KEY"
        print(st + "-" * (ltable * nkeys - len(st)))
        self.PrintArray(q, "{0:02x}")
        # Alternate dashed / blank bands so each round key is visually delimited.
        print("".join(("-" if k % 2 == 0 else " ") * ltable for k in range(nkeys)))
        pad = " " * ((ltable - 2) // 2)
        print("".join(pad + "{0:02x}".format(k) + pad for k in range(nkeys)))
        print("-" * (ltable * nkeys))

    def _CheckKey(self, key):
        """Raise ValueError unless ``key`` has at least 4*Nk bytes."""
        if len(key) < self.Nk * 4:
            raise ValueError("key must have at least {0} bytes (Nk={1})".format(
                self.Nk * 4, self.Nk))

    def _LoadState(self, msg):
        """Load the first 4*Nb bytes of ``msg`` into State, column by column."""
        if len(msg) < 4 * self.Nb:
            raise ValueError("msg must have at least {0} bytes".format(4 * self.Nb))
        self.State = array([[msg[r + 4 * c] for c in range(self.Nb)]
                            for r in range(4)])

    def _RoundKey(self, w, rnd):
        """Return the Nb schedule words used by round ``rnd``."""
        return w[rnd * self.Nb:(rnd + 1) * self.Nb]

    # Actual Encryption

    def Encrypt(self, msg, key):
        """Encrypt one block.

        Args:
            msg: at least 4*Nb byte values (extra bytes are ignored).
            key: at least 4*Nk byte values (extra bytes are ignored).

        Returns:
            1-D numpy array of 4*Nb ciphertext bytes.

        Raises:
            ValueError: if ``key`` or ``msg`` is too short.
        """
        self._CheckKey(key)
        w = self.KeyExpansion(key)
        self._LoadState(msg)
        if self.debug:
            print("---INITIAL STATE---")
            self.PrintState()
        # Initial whitening with round key 0 (missing before, so the output
        # did not match standard AES).
        self.AddRoundKey(self._RoundKey(w, 0))
        for rnd in range(1, self.Nr):
            self.SubBytes()
            self.ShiftRows()
            self.MixColumns()
            self.AddRoundKey(self._RoundKey(w, rnd))
        if self.debug:
            print("---FINAL ROUND---")
        self.SubBytes()
        self.ShiftRows()
        self.AddRoundKey(self._RoundKey(w, self.Nr))
        return transpose(self.State).flatten()

    # ... and decryption

    def Decrypt(self, msg, key):
        """Decrypt one block; the exact inverse of ``Encrypt`` for the same key.

        Args:
            msg: at least 4*Nb ciphertext bytes (extra bytes are ignored).
            key: at least 4*Nk byte values (extra bytes are ignored).

        Returns:
            1-D numpy array of 4*Nb plaintext bytes.

        Raises:
            ValueError: if ``key`` or ``msg`` is too short.
        """
        self._CheckKey(key)
        w = self.KeyExpansion(key)
        self._LoadState(msg)
        if self.debug:
            print("----Input----")
            self.PrintState()
        # Invert final round
        self.AddRoundKey(self._RoundKey(w, self.Nr))
        for rnd in range(self.Nr - 1, 0, -1):
            self.InvShiftRows()
            self.InvSubBytes()
            self.AddRoundKey(self._RoundKey(w, rnd))
            self.InvMixColumns()
        if self.debug:
            print("----FINAL ROUND----")
        self.InvShiftRows()
        self.InvSubBytes()
        # Undo the initial whitening with round key 0.
        self.AddRoundKey(self._RoundKey(w, 0))
        return transpose(self.State).flatten()

def BinWeight(x):
    """Return the Hamming weight (number of set bits) of the non-negative integer ``x``.

    Raises:
        ValueError: if ``x`` is negative. The previous bit-clearing loop
            never terminated for negative Python ints.
    """
    if x < 0:
        raise ValueError("BinWeight requires a non-negative integer")
    return bin(x).count("1")

def CommonBits(q, v):
    """Return the number of bit positions in which the sequences ``q`` and ``v`` differ.

    Despite the name, this is the Hamming *distance*: each pair of elements
    is XORed and the set bits of the result are counted with ``BinWeight``.

    Raises:
        ValueError: if ``q`` and ``v`` have different lengths. ``zip`` would
            otherwise silently drop the extra elements and under-count.
    """
    if len(q) != len(v):
        raise ValueError("CommonBits requires sequences of equal length")
    # A list (not a generator) is passed on purpose: the file's
    # `from numpy import *` may shadow `sum` with numpy's version.
    return sum([BinWeight(a ^ b) for a, b in zip(q, v)])

# Test vectors. tstKey128/tstMsg128 and tstKey256/tstMsg256 appear to be the
# FIPS-197 example key/plaintext pairs (AES-128 and AES-256); they are not
# used by the example below.
tstKey128 = [
    0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
    0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c,
]
tstMsg128 = [
    0x32, 0x43, 0xf6, 0xa8, 0x88, 0x5a, 0x30, 0x8d,
    0x31, 0x31, 0x98, 0xa2, 0xe0, 0x37, 0x07, 0x34,
]
tstKey256 = list(range(0x20))                  # 00 01 02 ... 1f
tstMsg256 = [0x11 * k for k in range(0x10)]    # 00 11 22 ... ff

# All-zero 32-byte key.
tstZ = [0] * 0x20

# Example: encrypt two plaintexts differing in one bit and count how many
# ciphertext bits differ (CommonBits XORs and counts set bits).
tst1 = [0] * 0x10
tst2 = list(tst1)
tst2[0] += 1
g = AESimpl(4, True)   # Nk=4 with debug tracing; only the first 16 key bytes are read
enc1 = g.Encrypt(tst1, tstZ)
enc2 = g.Encrypt(tst2, tstZ)
d = CommonBits(enc1, enc2)

```