# MixNet: mixers and a mix network built on hybrid RSA + AES encryption.


from Crypto.PublicKey import RSA


class Mixer(object):
    """Generic mix node that uses an RSA key pair as its public key.

    A mixer collects decrypted messages in ``self.store`` and, once it holds
    at least ``self.lines`` of them, shuffles one batch and emits it.
    """

    def __init__(self, lines, RSAobject="", dbg=False):
        """Create a mixer.

        Args:
            lines: Number of input/output lines, i.e. the batch size.
            RSAobject: RSA key object of this mixer. When left at the
                default ``""`` a fresh 1024-bit key is generated.
            dbg: When true, print diagnostic output.
        """
        if RSAobject == "":
            print("Mixer: generating RSA key")
            RSAobject = RSA.generate(1024)
        self.lines = lines
        self.CryptoMachinery = RSAobject
        self.store = []
        self.dbg = dbg

    def GetPk(self):
        """Return the public part of this mixer's RSA key."""
        return self.CryptoMachinery.publickey()

    def GenPerm(self):
        """Return a random permutation of ``range(self.lines)`` as a list."""
        from random import SystemRandom

        # A single shuffle replaces the former rejection-sampling loop; the
        # OS entropy source keeps the mixing order unpredictable.
        perm = list(range(self.lines))
        SystemRandom().shuffle(perm)
        return perm

    def encEntry(self, x):
        """Encrypt one entry with a hybrid RSA + AES (ECB) scheme.

        A random session integer is encrypted with the RSA public key; its
        SHA-256 digest is the AES key used to encrypt ``repr(x)``.

        Args:
            x: Object to encrypt; it is serialised through ``repr``.

        Returns:
            A two-item list ``[repr(rsa_ciphertext), base64_aes_ciphertext]``.
        """
        from random import SystemRandom
        from Crypto.Cipher import AES
        from hashlib import sha256
        import base64

        # NOTE(review): size() is presumably a bit count (it is the exponent
        # used to bound the session integer) yet it is also used as the byte
        # length below; decEntry does the same, so both derive the same key.
        size = self.CryptoMachinery.size()
        # The session secret must come from a cryptographic RNG.
        sk = SystemRandom().randrange(0, 2 ** size)
        skb = sk.to_bytes(size, byteorder="big")
        AESK = sha256(skb).digest()
        ECB = AES.new(AESK, AES.MODE_ECB)
        rx = repr(x).encode("utf-8")
        encr = repr([rx]).encode("utf-8")
        # Pad with '^' up to a multiple of the 16-byte block; between 1 and
        # 16 pad bytes are always appended.
        t = 16 - len(encr) % 16
        encr += b"^" * t
        return [
            repr(self.GetPk().encrypt(sk, 0)),
            base64.b64encode(ECB.encrypt(encr)),
        ]

    def decEntry(self, x):
        """Decrypt one entry produced by :meth:`encEntry`.

        Args:
            x: Two-item sequence ``[repr(rsa_ciphertext), base64_ciphertext]``.

        Returns:
            The object that was originally passed to ``encEntry``.
        """
        import re
        from Crypto.Cipher import AES
        from hashlib import sha256
        import base64

        # SECURITY NOTE(review): eval() is applied to received data here and
        # below; a crafted message can execute arbitrary code. Consider a
        # restricted parser (e.g. ast.literal_eval) once the wire format is
        # confirmed to contain only literals.
        sk = self.CryptoMachinery.decrypt(eval(x[0])[0])
        skb = sk.to_bytes(
            self.CryptoMachinery.size(), byteorder="big"
        )
        AESK = sha256(skb).digest()
        ECB = AES.new(AESK, AES.MODE_ECB)
        recv = ECB.decrypt(base64.b64decode(x[1]))
        # Strip the trailing '^' padding that follows the closing bracket.
        recv = eval(
            re.sub(r"](\^+)$", "]", recv.decode("utf-8"))
        )[0]
        return eval(recv)

    def encrypt(self, datum):
        """Encrypt every element of ``datum``; return the list of results."""
        return [self.encEntry(x) for x in datum]

    def decrypt(self, datum):
        """Decrypt every element of ``datum``; return the list of results."""
        return [self.decEntry(x) for x in datum]

    def load(self, msg):
        """Decrypt ``msg`` and append the result to the mixer's store."""
        if self.dbg:
            print("Received msg=", msg)
        actual = self.decrypt(msg)
        if self.dbg:
            print("Loaded msg=", actual)
        self.store.append(actual)

    def mix(self):
        """Shuffle the first ``self.lines`` stored messages in place.

        The store must hold at least ``self.lines`` messages; ``operate``
        checks this before calling.
        """
        perm = self.GenPerm()
        if self.dbg:
            print("Permutation: ", perm)
            # The permutation is only retained when debugging.
            self.perm = perm
        self.store[0 : self.lines] = [
            self.store[x] for x in perm
        ]

    def PrintStatus(self):
        """Print the current content of the store."""
        print(self.store)

    def operate(self):
        """Emit one shuffled batch if enough messages are stored.

        Returns:
            ``False`` when fewer than ``self.lines`` messages are stored;
            otherwise the list of ``self.lines`` shuffled messages, which
            are removed from the store.
        """
        if len(self.store) < self.lines:
            return False
        self.mix()
        outdata = self.store[0 : self.lines]
        self.store = self.store[self.lines :]
        return outdata

class MixNet(object):
    """A network of mixers with source-routed, layered-encrypted messages."""

    def __init__(
        self, lines, mixer_nos=4, pathlen=4, dbg=False
    ):
        """Set up the network and create its mixers.

        Args:
            lines: Batch size handed to every mixer.
            mixer_nos: Number of mixers to create.
            pathlen: Length of randomly chosen paths.
            dbg: When true, print diagnostic output.
        """
        self.Mixers = [Mixer(lines) for _ in range(mixer_nos)]
        self.pathlen = pathlen
        self.loaded = 0  # accepted messages not yet delivered
        self.dbg = dbg

    def __getitem__(self, i):
        """Return the mixer at index ``i``."""
        return self.Mixers[i]

    def PrepareMsg(self, text, mixer_list=None):
        """Wrap ``text`` in one encryption layer per hop (source routing).

        Args:
            text: Payload of the message.
            mixer_list: Mixer indices to traverse, first hop first. ``None``
                or an empty sequence selects a random path of
                ``self.pathlen`` hops; ``-1`` builds a dummy message on a
                random path.

        Returns:
            ``[first_hop, encrypted_payload]``, or the bare
            ``[tag, text]`` pair when the resulting path is empty.
        """
        from random import randrange

        dummy = mixer_list == -1
        if dummy:
            mixer_list = []
        if mixer_list is None or len(mixer_list) == 0:
            n = len(self.Mixers)
            mixer_list = [
                randrange(0, n) for _ in range(self.pathlen)
            ]
        if dummy:
            message = ["DUMMY", ""]
        else:
            message = ["EOM", text]
        if self.dbg:
            print("Path assegnato=", mixer_list)
        # Encrypt from the last hop backwards so the first hop peels the
        # outermost layer.
        for DST in mixer_list[::-1]:
            if self.dbg and not dummy:
                print("Codifica per Mixer numero=", DST)
                print("Message=", message)
            message = [
                DST,
                self.Mixers[DST].encrypt(message),
            ]
        return message

    def _route(self, Msg):
        """Route one message; return ``(delivered, text)``.

        ``delivered`` is true only for an end-of-message marker, in which
        case ``text`` is its payload. Other messages are dropped (dummy) or
        loaded into the destination mixer.
        """
        tag = Msg[0]
        if tag == "EOM":
            return True, Msg[1]
        if tag == "DUMMY":
            return False, None
        self.Mixers[tag].load(Msg[1])
        return False, None

    def RouteMsg(self, Msg):
        """Route one message.

        Returns:
            The payload when ``Msg`` is a final ("EOM") message, otherwise
            ``False``.
        """
        delivered, text = self._route(Msg)
        return text if delivered else False

    def AcceptMsg(self, Msg, mixer_list=None):
        """Accept a message into the mix network and send it to its first hop.

        Args:
            Msg: Payload to send.
            mixer_list: Optional explicit path, see :meth:`PrepareMsg`.
        """
        M = self.PrepareMsg(Msg, mixer_list)
        self.loaded += 1
        print("Messaggio: ", Msg)
        print("EXT -> ", M[0])
        self.RouteMsg(M)

    def FireDummy(self):
        """Inject a dummy message on a random path."""
        D = self.PrepareMsg("", -1)
        self.RouteMsg(D)

    def PrintStatus(self):
        """Print the store of every mixer."""
        for i, mixer in enumerate(self.Mixers):
            print("Mixer ", i, ":")
            mixer.PrintStatus()
            print("---")

    def OperateMixers(self):
        """Run every mixer once and route whatever it emits.

        If no mixer emitted a batch, a dummy message is injected so the
        network keeps moving.

        Returns:
            List of payloads delivered during this pass.
        """
        output = []
        any_fired = False
        for i, mixer in enumerate(self.Mixers):
            batch = mixer.operate()
            if batch is False:
                continue
            any_fired = True
            for msg in batch:
                print("Route: ", i, "->", msg[0])
                delivered, text = self._route(msg)
                # Use the explicit flag rather than the payload's truth
                # value, so empty or falsy payloads are still delivered and
                # counted.
                if delivered:
                    output.append(text)
        if not any_fired:
            print("Inject Dummy message")
            self.FireDummy()
        if self.dbg:
            self.PrintStatus()
        self.loaded -= len(output)
        print("---")
        return output

    def WaitForOutput(self):
        """Operate the mixers until at least one message is delivered.

        Returns:
            ``[]`` immediately when no message is pending, otherwise the
            first non-empty list returned by :meth:`OperateMixers`.
        """
        if self.loaded == 0:
            return []
        r = []
        while r == []:
            r = self.OperateMixers()
        return r

    def OutAll(self):
        """Drain the network; return the list of delivered batches."""
        out = []
        r = self.WaitForOutput()
        while r != []:
            out.append(r)
            r = self.WaitForOutput()
        return out

    def GetTstMessages(self, t):
        """Accept ``t`` numbered test messages on random paths."""
        for i in range(t):
            self.AcceptMsg("Test message #" + repr(i))