from numpy import *
from functools import reduce
class GF256:
    """Arithmetic in GF(2^8) with reduction polynomial 0b100011011 (0x11b).

    Field elements are plain non-negative integers whose bits are the
    polynomial coefficients; addition is XOR and multiplication is a
    carry-less product followed by reduction modulo the polynomial.
    """

    def BinaryProd(self, x, y):
        """Return the carry-less (XOR) shift-and-add product of x and y.

        The result is NOT reduced; it may be wider than 8 bits.
        """
        acc = 0
        shifted = x
        remaining = y
        while remaining > 0:
            # Add (XOR) the current multiple of x when the low bit is set.
            acc ^= shifted * (remaining & 1)
            shifted = shifted << 1
            remaining = remaining >> 1
        return acc

    def msb(self, x):
        """Return the number of significant bits of x (0 when x <= 0)."""
        bits = 0
        while x > 0:
            x >>= 1
            bits += 1
        return bits

    def prod(self, x, y):
        """Return the field product of x and y, reduced below 2**8."""
        modulus = 0b100011011
        value = self.BinaryProd(x, y)
        while True:
            width = self.msb(value)
            if width <= 8:
                # Covers the zero product as well (msb(0) == 0).
                return value
            # Align the 9-bit modulus with the leading bit and cancel it.
            value ^= modulus << (width - 9)

    def sum(self, x, y):
        """Return the field sum of x and y (bitwise XOR)."""
        return x ^ y

    def pow(self, x, n):
        """Return x raised to the n-th power by left-to-right square-and-multiply."""
        top = self.msb(n)
        result = 1
        # Scan exponent bits from position `top` (always 0) down to bit 0.
        for shift in range(top, -1, -1):
            result = self.prod(result, result)
            if (n >> shift) & 0x1:
                result = self.prod(result, x)
        return result

    def inv(self, x):
        """Return the multiplicative inverse of x, computed as x**254.

        254 = 0b11111110: seven square-and-multiply steps build x**127, and
        one final squaring doubles the exponent.  inv(0) evaluates to 0.
        """
        acc = 1
        for _ in range(7):
            acc = self.prod(self.prod(acc, acc), x)
        return self.prod(acc, acc)

    def PrintProdTable(self, start=0, end=256):
        """Print the multiplication table for operands in range(start, end)."""

        def as_hex(value):
            return "{0:02x}".format(value)

        operands = range(start, end)
        header = "".join(as_hex(col) + " " for col in operands)
        print(" * | " + " " + header)
        print("---" * (end - start + 1) + "--")
        for row in operands:
            cells = "".join(
                as_hex(self.prod(int(row), int(col))) + " "
                for col in operands
            )
            print(as_hex(row) + " | " + " " + cells)
class AESimpl:
    """Didactic AES block cipher computed from explicit GF(2^8) arithmetic.

    The cipher state is a 4 x Nb integer array stored in ``self.State``.
    S-box values are derived on the fly (field inversion followed by an
    affine map) instead of being read from a table.  All word/byte packing is
    done with Python integers so that narrow NumPy dtypes (e.g. ``uint8``
    messages or keys) cannot overflow during shifts.
    """

    def __init__(self, Nk=8, debug=False):
        """Configure the cipher for a key of ``Nk`` 32-bit words.

        Args:
            Nk: key length in words; must be 4, 6 or 8 (giving 10, 12 or 14
                rounds respectively).
            debug: when true, each transformation prints the state.

        Raises:
            ValueError: if ``Nk`` is not 4, 6 or 8.
        """
        self.Nb = 4
        self.K = GF256()
        for words, rounds in ((4, 10), (6, 12), (8, 14)):
            if Nk == words:
                self.Nk = words
                self.Nr = rounds
                break
        else:
            raise ValueError("Nk=4,6,8")
        self.State = array([[0] * self.Nb for _ in range(4)])
        self.debug = debug
        # Round constants: Rcon[i] carries 2**(i-1) (in the field) in its most
        # significant byte.  Index 0 is a placeholder because Rcon is 1-based.
        rcon_count = 2 + (self.Nb * (self.Nr + 1)) // self.Nk
        self.Rcon = [0] + [
            self.K.pow(0x2, i - 1) << 24 for i in range(1, rcon_count)
        ]
        # Affine map of the S-box: constant plus the right-rotation amounts
        # XORed together (forward map) and their counterparts for the inverse.
        self.affineConst = 0x63
        self.bVect = [0, 4, 5, 6, 7]
        self.bInv = [2, 5, 7]
        # First rows of the circulant MixColumns matrix and of its inverse.
        self.MixVect = [0x02, 0x03, 0x01, 0x01]
        self.InvMixVect = [0x0E, 0x0B, 0x0D, 0x09]

    # UTILITY
    def rotN(self, x, n, s):
        """Rotate the ``s``-bit value ``x`` left by ``n`` (right if n < 0)."""
        # Built with a shift rather than pow(): `from numpy import *` may
        # shadow the builtin pow.
        mask = (1 << s) - 1
        if n >= 0:
            return ((x << n) & mask) ^ (x >> (s - n))
        return (x >> -n) ^ ((x << (s + n)) & mask)

    def RotByte(self, x, n=1):
        """Rotate the byte ``x`` left by ``n`` bits (right if n < 0)."""
        return self.rotN(x, n, 8)

    def VecToWord(self, v):
        """Pack four bytes ``v[0..3]`` into a 32-bit word, ``v[0]`` highest.

        Each byte is converted to a Python int before shifting so that
        fixed-width NumPy scalars cannot overflow.
        """
        word = 0
        for k in range(4):
            word ^= int(v[k]) << (8 * (3 - k))
        return word

    def WordToVec(self, v):
        """Split a 32-bit word into an array of four bytes, highest first."""
        word = int(v)
        return array([(word >> (8 * (3 - k))) & 0xFF for k in range(4)])

    def PrintArray(self, datum, mask="{0:02x} "):
        """Print the first four rows of ``datum``, one formatted row per line."""
        for i in range(0, 4):
            print("".join(mask.format(int(v)) for v in datum[i]))

    def PrintState(self):
        """Print the current cipher state."""
        self.PrintArray(self.State)

    # SubBytes
    def SubByte(self, x):
        """Return the S-box image of byte ``x``: inversion, then affine map."""
        inverse = self.K.inv(int(x))
        result = self.affineConst
        for shift in self.bVect:
            result ^= self.RotByte(inverse, -shift)
        return result

    def InvSubByte(self, x):
        """Return the inverse S-box image of byte ``x``."""
        # Undo the affine map first: strip the constant, apply the inverse
        # rotation set, then invert in the field.
        stripped = int(x) ^ self.affineConst
        linear = 0
        for shift in self.bInv:
            linear ^= self.RotByte(stripped, -shift)
        return self.K.inv(linear)

    def ApplyBytes(self, f):
        """Replace every state byte ``b`` with ``f(b)``."""
        self.State = array(
            [
                [f(self.State[r, c]) for c in range(self.Nb)]
                for r in range(4)
            ]
        )
        if self.debug:
            print("-----ApplyBytes-----")
            self.PrintState()

    def SubBytes(self):
        """Apply the S-box to every state byte."""
        self.ApplyBytes(self.SubByte)

    def InvSubBytes(self):
        """Apply the inverse S-box to every state byte."""
        self.ApplyBytes(self.InvSubByte)

    # Shift Rows
    def ApplyShiftRows(self, f):
        """Cyclically shift row ``r`` of the state left by ``f(r)`` positions."""
        for r in range(4):
            previous = self.State[r].copy()
            offset = f(r)
            for c in range(self.Nb):
                # Python's % is non-negative, so negative offsets shift right.
                self.State[r][c] = previous[(c + offset) % self.Nb]
        if self.debug:
            print("----ApplyShiftRows----")
            self.PrintState()

    def ShiftRows(self):
        """Shift row ``r`` left by ``r`` positions."""
        self.ApplyShiftRows(lambda r: r)

    def InvShiftRows(self):
        """Shift row ``r`` right by ``r`` positions (inverse of ShiftRows)."""
        self.ApplyShiftRows(lambda r: -r)

    # MixColumns
    def ApplyMixColumns(self, vect):
        """Multiply every state column by the circulant matrix of ``vect``."""

        def mix_column(column):
            mixed = []
            for i in range(4):
                acc = 0
                for j in range(4):
                    term = self.K.prod(int(column[(i + j) % 4]), vect[j])
                    acc = self.K.sum(acc, term)
                mixed.append(acc)
            return array(mixed)

        for c in range(self.Nb):
            self.State[:, c] = mix_column(self.State[:, c])
        if self.debug:
            print("----ApplyMixColumns----")
            self.PrintState()

    def MixColumns(self):
        """Apply the forward column mixing."""
        self.ApplyMixColumns(self.MixVect)

    def InvMixColumns(self):
        """Apply the inverse column mixing."""
        self.ApplyMixColumns(self.InvMixVect)

    # AddRoundKey
    def AddRoundKey(self, Key):
        """XOR state column ``i`` with the 32-bit word ``Key[i]``."""
        for c in range(self.Nb):
            self.State[:, c] = self.WordToVec(
                self.VecToWord(self.State[:, c]) ^ Key[c]
            )
        if self.debug:
            print("----AddRoundKey----")
            print("KEY:")
            self.PrintArray(
                transpose(
                    array([self.WordToVec(Key[i]) for i in range(0, 4)])
                )
            )
            print("STATE:")
            self.PrintState()

    # KEY EXPANSION
    def KeyExpansion(self, Key):
        """Expand ``Key`` into ``Nb * (Nr + 1)`` round-key words.

        Args:
            Key: indexable/sliceable sequence of at least ``4 * Nk`` bytes;
                only the first ``4 * Nk`` are used.

        Returns:
            A list of 32-bit words; round ``r`` uses words
            ``[r * Nb, (r + 1) * Nb)``.
        """

        def rot_word(word):
            b = self.WordToVec(word)
            return self.VecToWord([b[1], b[2], b[3], b[0]])

        def sub_word(word):
            return self.VecToWord(
                [self.SubByte(b) for b in self.WordToVec(word)]
            )

        total = self.Nb * (self.Nr + 1)
        w = [
            self.VecToWord(Key[4 * i : 4 * (i + 1)])
            for i in range(self.Nk)
        ]
        w += [0] * (total - self.Nk)
        for i in range(self.Nk, total):
            temp = w[i - 1]
            if i % self.Nk == 0:
                temp = sub_word(rot_word(temp)) ^ self.Rcon[i // self.Nk]
            elif self.Nk > 6 and i % self.Nk == 4:
                # Extra substitution used only for the 8-word key schedule.
                temp = sub_word(temp)
            w[i] = w[i - self.Nk] ^ temp
        if self.debug:
            self._print_expanded_key(w)
        return w

    def _print_expanded_key(self, w):
        """Print the expanded key as a 4-row byte table with a word-group ruler."""
        q = array([self.WordToVec(word) for word in w]).transpose()
        ltable = int(self.Nb * 2)
        st = "EXPANDED KEY"
        print(st + "-" * (ltable * (self.Nr + 1) - len(st)))
        self.PrintArray(q, "{0:02x}")
        print(
            ("-" * ltable + " " * ltable)
            * ((self.Nb) * ((self.Nr + 1) // 8))
            + "-" * ltable
        )
        pad = " " * ((ltable - 2) // 2)
        msg = "".join(
            pad + "{0:02x}".format(i) + pad for i in range(0, len(w) // 4)
        )
        print(msg)
        print("-" * (ltable * (self.Nr + 1)))

    def _check_key(self, key):
        """Raise ValueError when ``key`` is shorter than ``4 * Nk`` bytes."""
        needed = self.Nk * 4
        if len(key) < needed:
            raise ValueError(
                "key must hold at least {0} bytes for Nk={1}, got {2}".format(
                    needed, self.Nk, len(key)
                )
            )

    def _load_state(self, msg):
        """Fill the state column by column from the first ``4 * Nb`` bytes."""
        # int() keeps the state in the default integer dtype even when msg is
        # a narrow NumPy array, so later field arithmetic cannot overflow.
        self.State = array(
            [
                [int(msg[r + c * self.Nb]) for c in range(self.Nb)]
                for r in range(4)
            ]
        )

    def _round_key(self, w, rnd):
        """Return the ``Nb`` words of the expanded key ``w`` for round ``rnd``."""
        return w[rnd * self.Nb : (rnd + 1) * self.Nb]

    # Actual Encryption
    def Encrypt(self, msg, key):
        """Encrypt one block.

        Args:
            msg: indexable sequence holding the plaintext bytes; the first
                ``4 * Nb`` are used.
            key: sequence of at least ``4 * Nk`` key bytes.

        Returns:
            A flat NumPy array with the ``4 * Nb`` ciphertext bytes.

        Raises:
            ValueError: if ``key`` is shorter than ``4 * Nk`` bytes.
        """
        self._check_key(key)
        w = self.KeyExpansion(key)
        self._load_state(msg)
        if self.debug:
            print("---INITIAL STATE---")
            self.PrintState()
        self.AddRoundKey(self._round_key(w, 0))
        for rnd in range(1, self.Nr):
            self.SubBytes()
            self.ShiftRows()
            self.MixColumns()
            self.AddRoundKey(self._round_key(w, rnd))
        if self.debug:
            print("---FINAL ROUND---")
        # The final round has no MixColumns step.
        self.SubBytes()
        self.ShiftRows()
        self.AddRoundKey(self._round_key(w, self.Nr))
        return transpose(self.State).flatten()

    # ... and decryption
    def Decrypt(self, msg, key):
        """Decrypt one block; the exact inverse of :meth:`Encrypt`.

        Args:
            msg: indexable sequence holding the ciphertext bytes; the first
                ``4 * Nb`` are used.
            key: sequence of at least ``4 * Nk`` key bytes.

        Returns:
            A flat NumPy array with the ``4 * Nb`` plaintext bytes.

        Raises:
            ValueError: if ``key`` is shorter than ``4 * Nk`` bytes.
        """
        self._check_key(key)
        w = self.KeyExpansion(key)
        self._load_state(msg)
        if self.debug:
            print("----Input----")
            self.PrintState()
        # Undo the final encryption round key first.
        self.AddRoundKey(self._round_key(w, self.Nr))
        for rnd in range(self.Nr - 1, 0, -1):
            self.InvShiftRows()
            self.InvSubBytes()
            self.AddRoundKey(self._round_key(w, rnd))
            self.InvMixColumns()
        if self.debug:
            print("----FINAL ROUND----")
        self.InvShiftRows()
        self.InvSubBytes()
        self.AddRoundKey(self._round_key(w, 0))
        return transpose(self.State).flatten()
def BinWeight(x):
    """Return the number of set bits (Hamming weight) of the integer ``x``.

    Raises:
        ValueError: if ``x`` is negative.  A negative Python int has
            infinitely many leading one bits, and the previous
            clear-lowest-bit loop never terminated on such input.
    """
    if x < 0:
        raise ValueError("BinWeight requires a non-negative integer")
    return bin(x).count("1")
def CommonBits(q, v):
    """Return the total number of bit positions in which ``q`` and ``v`` differ.

    Elements are paired with ``zip`` (so the longer input is truncated), each
    pair is XORed and the set bits of the result are counted.  Despite the
    name, this is the count of differing bits, not of shared ones.
    """
    weights = []
    for left, right in zip(q, v):
        weights.append(BinWeight(left ^ right))
    # Keep a list argument: under `from numpy import *`, `sum` is numpy's.
    return sum(weights)
# Test vectors, each stored as a list of byte values.
# 16-byte key and 16-byte message for a 128-bit-key run.
tstKey128 = list(bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c"))
tstMsg128 = list(bytes.fromhex("3243f6a8885a308d313198a2e0370734"))
# 32-byte key 00 01 .. 1f and 16-byte message 00 11 22 .. ff.
tstKey256 = list(range(0x20))
tstMsg256 = [0x11 * k for k in range(0x10)]
# 32 zero bytes, usable as an all-zero key of any supported length.
tstZ = [0] * 0x20
# Example: encrypt two 16-byte blocks that differ in a single bit under the
# all-zero key (Nk=4 uses the first 16 bytes of tstZ), with debug output on,
# then count how many ciphertext bits differ between the two results.
tst1 = [0] * 0x10
tst2 = list(tst1)
tst2[0] += 1
g = AESimpl(4, True)
enc1 = g.Encrypt(tst1, tstZ)
enc2 = g.Encrypt(tst2, tstZ)
d = CommonBits(enc1, enc2)