"""MixNet: a network of mixers that forward source-routed, layer-encrypted messages."""


from Crypto.PublicKey import RSA


class Mixer(object):
    """Generic mixer that uses an RSA key as its public key.

    Incoming messages are decrypted with the mixer's private key and
    buffered in ``store``; ``operate`` releases them ``lines`` at a time in
    a random order.
    """

    def __init__(self, lines, RSAobject=None, dbg=False, key_bits=1024):
        """Create a mixer.

        :param lines: number of messages collected before a batch is released.
        :param RSAobject: the mixer's RSA key object.  When ``None`` or ``""``
            (the historical default) a fresh key is generated.
        :param dbg: print debugging output when true.
        :param key_bits: size in bits of a freshly generated key.
        """
        # Test the "no key" sentinels by identity/type first, so a real key
        # object is never compared with a str (a key class's __eq__ may read
        # key attributes that a str does not have).
        if RSAobject is None or (isinstance(RSAobject, str) and RSAobject == ""):
            print("Mixer: generating RSA key")
            RSAobject = RSA.generate(key_bits)
        self.lines = lines
        self.CryptoMachinery = RSAobject
        self.store = []
        self.dbg = dbg

    def GetPk(self):
        """Return the public part of the mixer's key."""
        return self.CryptoMachinery.publickey()

    def GenPerm(self):
        """Return a uniformly random permutation of ``range(self.lines)``."""
        from random import SystemRandom
        # The batch order is what hides the input/output correspondence, so
        # draw it from the OS CSPRNG.  A shuffle is O(n), unlike rejection
        # sampling with list membership tests.
        perm = list(range(self.lines))
        SystemRandom().shuffle(perm)
        return perm

    def encEntry(self, x):
        """Encrypt one entry with hybrid RSA + AES (ECB).

        A random session key ``sk`` is encrypted with the mixer's public RSA
        key.  The AES key is the SHA-256 digest of ``sk``'s big-endian bytes.
        It encrypts ``repr([repr(x).encode('utf-8')])``, padded with ``b'^'``
        to a multiple of 16 bytes.

        :param x: a value whose ``repr`` is a Python literal.
        :returns: ``[repr(rsa_ciphertext), base64_aes_ciphertext]``.
        """
        import base64
        import secrets
        from Crypto.Cipher import AES
        from hashlib import sha256
        size = self.CryptoMachinery.size()
        # The session key must be unpredictable: use a CSPRNG, not ``random``.
        sk = secrets.randbelow(2 ** size)
        # sk < 2**size, so ``size`` bytes always suffice; decEntry uses the
        # same length, so both sides derive the same AES key.
        skb = sk.to_bytes(size, byteorder='big')
        aes_key = sha256(skb).digest()
        cipher = AES.new(aes_key, AES.MODE_ECB)
        plain = repr([repr(x).encode('utf-8')]).encode('utf-8')
        # Always append 1..16 pad bytes.  The literal ends in ']', so the
        # trailing run of '^' is unambiguous when stripped on decryption.
        plain += b'^' * (16 - len(plain) % 16)
        return [repr(self.GetPk().encrypt(sk, 0)),
                base64.b64encode(cipher.encrypt(plain))]

    def decEntry(self, x):
        """Decrypt one entry produced by ``encEntry``.

        :param x: ``[repr(rsa_ciphertext), base64_aes_ciphertext]``.
        :returns: the original value.
        """
        import ast
        import base64
        from Crypto.Cipher import AES
        from hashlib import sha256
        # Entries are supplied by senders.  ast.literal_eval accepts only
        # Python literals (all that encEntry produces), whereas eval would
        # execute arbitrary code embedded in a crafted entry.
        sk = self.CryptoMachinery.decrypt(ast.literal_eval(x[0])[0])
        skb = sk.to_bytes(self.CryptoMachinery.size(), byteorder='big')
        aes_key = sha256(skb).digest()
        cipher = AES.new(aes_key, AES.MODE_ECB)
        recv = cipher.decrypt(base64.b64decode(x[1])).decode('utf-8')
        # Drop the '^' padding added by encEntry; the literal itself ends in ']'.
        inner = ast.literal_eval(recv.rstrip('^'))[0]
        return ast.literal_eval(inner.decode('utf-8'))

    def encrypt(self, datum):
        """Encrypt every entry of ``datum`` separately; return the list."""
        return [self.encEntry(x) for x in datum]

    def decrypt(self, datum):
        """Decrypt every entry of ``datum`` separately; return the list."""
        return [self.decEntry(x) for x in datum]

    def load(self, msg):
        """Decrypt ``msg`` (a list of encrypted entries) and buffer the result."""
        if self.dbg:
            print("Received msg=", msg)
        actual = self.decrypt(msg)
        if self.dbg:
            print("Loaded msg=", actual)
        self.store.append(actual)

    def mix(self):
        """Reorder the first ``lines`` buffered messages with a random permutation."""
        perm = self.GenPerm()
        if self.dbg:
            print("Permutation: ", perm)
            self.perm = perm  # kept for inspection in debug mode only
        self.store[0:self.lines] = [self.store[x] for x in perm]

    def PrintStatus(self):
        """Print the buffered messages."""
        print(self.store)

    def operate(self):
        """Release one mixed batch, but only once enough messages are buffered.

        :returns: ``False`` while fewer than ``lines`` messages are buffered;
            otherwise the list of the first ``lines`` messages after mixing.
            Those messages are removed from the buffer.
        """
        if len(self.store) < self.lines:
            return False
        self.mix()
        outdata = self.store[:self.lines]
        self.store = self.store[self.lines:]
        return outdata


class MixNet(object):
    """A network of mixers carrying source-routed, layer-encrypted messages."""

    def __init__(self, lines, mixer_nos=4, pathlen=4, dbg=False):
        """Create ``mixer_nos`` mixers, each releasing batches of ``lines``.

        :param pathlen: number of hops used when a path is chosen at random.
        :param dbg: print debugging output when true.
        """
        self.Mixers = [Mixer(lines) for _ in range(mixer_nos)]
        self.pathlen = pathlen
        self.loaded = 0  # messages accepted but not yet delivered at an exit
        self.dbg = dbg

    def __getitem__(self, i):
        return self.Mixers[i]

    def PrepareMsg(self, text, mixer_list=None):
        """Build a source-routed message that traverses ``mixer_list``.

        :param text: payload delivered at the end of the path.
        :param mixer_list: indices of the mixers to traverse, in order.
            ``None`` or an empty sequence picks ``pathlen`` random mixers.
            ``-1`` builds a dummy message (payload discarded at the end) on a
            random path.
        :returns: ``[first_mixer_index, encrypted_payload]``.
        """
        from random import randrange
        dummy = mixer_list == -1
        if dummy or mixer_list is None:
            mixer_list = []
        if len(mixer_list) == 0:
            n = len(self.Mixers)
            mixer_list = [randrange(0, n) for _ in range(self.pathlen)]
        if dummy:
            message = ['DUMMY', ""]
        else:
            message = ['EOM', text]
        if self.dbg:
            print("Path assegnato=", mixer_list)
        # Wrap from the last hop outward, so the first mixer on the path
        # removes the outermost layer.
        for DST in reversed(mixer_list):
            if self.dbg and not dummy:
                print("Codifica per Mixer numero=", DST)
                print("Message=", message)
            message = [DST, self.Mixers[DST].encrypt(message)]
        return message

    def RouteMsg(self, Msg):
        """Deliver ``Msg`` one hop.

        :returns: the payload for an ``'EOM'`` (exit) message; ``False`` for a
            dummy message or after loading ``Msg[1]`` into mixer ``Msg[0]``.
        """
        if Msg[0] == 'EOM':
            return Msg[1]
        if Msg[0] == 'DUMMY':
            return False
        self.Mixers[Msg[0]].load(Msg[1])
        return False

    def AcceptMsg(self, Msg, mixer_list=None):
        """Accept the text ``Msg`` into the network along ``mixer_list``.

        ``mixer_list`` has the same meaning as in ``PrepareMsg``.
        """
        M = self.PrepareMsg(Msg, mixer_list)
        self.loaded += 1
        print("Messaggio: ", Msg)
        print("EXT -> ", M[0])
        self.RouteMsg(M)

    def FireDummy(self):
        """Inject a dummy message on a random path."""
        D = self.PrepareMsg("", -1)
        self.RouteMsg(D)

    def PrintStatus(self):
        """Print the buffer of every mixer."""
        for i, mixer in enumerate(self.Mixers):
            print("Mixer ", i, ":")
            mixer.PrintStatus()
            print("---")

    def OperateMixers(self):
        """Run every mixer once and route the batches they release.

        If no mixer released a batch, a dummy message is injected.

        :returns: the list of payloads that reached an exit in this round.
        """
        output = []
        allres = False
        for i, mixer in enumerate(self.Mixers):
            rr = mixer.operate()
            allres |= rr is not False
            if rr:
                for msg in rr:
                    print("Route: ", i, "->", msg[0])
                    x = self.RouteMsg(msg)
                    # Identify exits by tag, not by truthiness.  Otherwise an
                    # empty or falsy payload is dropped, ``loaded`` never
                    # reaches 0, and WaitForOutput loops forever.
                    if msg[0] == 'EOM':
                        output.append(x)
        if not allres:
            print("Inject Dummy message")
            self.FireDummy()
        if self.dbg:
            self.PrintStatus()
        self.loaded -= len(output)
        print("---")
        return output

    def WaitForOutput(self):
        """Run rounds until some payload exits; ``[]`` if nothing is in flight."""
        if self.loaded == 0:
            return []
        r = []
        while r == []:
            r = self.OperateMixers()
        return r

    def OutAll(self):
        """Collect every in-flight payload as a list of per-round output lists."""
        out = []
        r = self.WaitForOutput()
        while r != []:
            out.append(r)
            r = self.WaitForOutput()
        return out

    def GetTstMessages(self, t):
        """Accept ``t`` numbered test messages on random paths."""
        for i in range(t):
            self.AcceptMsg("Test message #" + repr(i))