import asyncio
import os
import sys
import struct
import time


# Host to target. Read from the environment at import time, so importing this
# module raises KeyError when SERVER_IP is unset.
SERVER = os.environ['SERVER_IP']
# NetBIOS first-level-encoded name: every character becomes two letters,
# 'A' + high nibble and 'A' + low nibble. "EOGFGLGP" decodes to "Neko" and
# each trailing "CA" pair decodes to a space used as padding.
NBT_NAME = b'EOGFGLGPCACACACACACACACACACACACA'  # "Neko" (space-padded)

DNS_SERVER = (SERVER, 53)   # DNS port
NBT_SERVER = (SERVER, 137)  # NetBIOS name service port


def construct_packet(proto, msg_id):
    """Build a one-question query whose QNAME consists of 127 labels.

    When *proto* is ``'nbt'`` the first label is NBT_NAME and the question
    type is 0x20 (NetBIOS); any other *proto* yields a DNS type-A question.
    All remaining labels are 63 dots each, so the encoded name is far longer
    than the 255-byte name limit of RFC 1035 -- presumably intentional for
    this tool (confirm).

    :param proto: ``'nbt'`` for a NetBIOS query, anything else for DNS.
    :param msg_id: transaction ID, packed as an unsigned 16-bit header field.
    :return: the raw query packet as bytes.
    """
    padding_label = b'.' * 63
    if proto != 'nbt':
        labels = [padding_label] * 127
        query_type = 1     # dns type A
    else:
        labels = [NBT_NAME]
        labels.extend([padding_label] * 126)
        query_type = 0x20  # NBT_QTYPE_NETBIOS

    # ID, flags (query, recursion desired), QDCOUNT=1, ANCOUNT=0,
    # NSCOUNT=0, ARCOUNT=0.
    header = struct.pack('!6H', msg_id, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000)
    # Zero-length root label ending the name, then QTYPE and QCLASS=IN.
    question_tail = struct.pack('!BHH', 0x00, query_type, 0x0001)

    # Each label is emitted as a one-byte length prefix followed by its bytes.
    qname = b''.join(bytes((len(label),)) + label for label in labels)
    return header + qname + question_tail


# Ids (f'{proto} {i}') of queries still waiting for an answer. Shared by all
# DnsBotherer instances; the last one to finish stops the event loop.
running_connections = set()


class DnsBotherer(asyncio.DatagramProtocol):
    """Datagram protocol that sends one query and waits for one answer.

    When the transport connects, the packet from ``construct_packet(proto, i)``
    is sent and the instance id is added to ``running_connections``. Either a
    reply or a socket error reported by the transport removes the id and
    closes the transport. Once no ids remain, *loop* is stopped.

    Subclassing ``asyncio.DatagramProtocol`` provides default versions of the
    protocol callbacks (such as ``connection_lost``) that the transport may
    invoke.
    """

    def __init__(self, proto, i, loop):
        super().__init__()
        self.id = f'{proto} {i}'
        self.packet = construct_packet(proto, i)
        self.loop = loop
        self.start = time.time()
        self.connected = None   # set in connection_made
        self.transport = None   # set in connection_made

    def connection_made(self, transport):
        self.transport = transport
        self.connected = time.time()
        elapsed = self.connected - self.start
        print(f"{self.id} connected after {elapsed} seconds")
        # Register before sending. If sendto() fails synchronously, the
        # transport calls error_received() immediately, and _finish() must
        # find the id in order to remove it.
        running_connections.add(self.id)
        transport.sendto(self.packet)

    def datagram_received(self, data, addr):
        elapsed = time.time() - self.start
        print(f"{self.id} received {len(data)} bytes after {elapsed} seconds")
        self._finish()

    def error_received(self, exc):
        # Example: ICMP port unreachable on the connected socket. No answer
        # will arrive, so treat the query as finished rather than waiting
        # forever (previously this callback did not exist at all).
        elapsed = time.time() - self.start
        print(f"{self.id} got error {exc!r} after {elapsed} seconds")
        self._finish()

    def _finish(self):
        """Forget this query, close its socket, and stop the loop if last."""
        # discard() rather than remove(): finishing twice must not raise.
        running_connections.discard(self.id)
        if self.transport is not None:
            self.transport.close()
        if not running_connections:
            self.loop.stop()


def botherer_generator(proto, loop):
    """Yield an endless sequence of zero-argument ``DnsBotherer`` factories.

    The n-th factory (counting from 0) builds ``DnsBotherer(proto, n, loop)``.
    Each index is bound when its factory is created, so factories may be
    stored and called later, in any order.
    """
    i = 0
    while True:
        # Bind i now. A plain closure would read the generator's *current* i
        # at call time, so factories kept past the next iteration would all
        # share the latest index.
        yield lambda i=i: DnsBotherer(proto, i, loop)
        i += 1


def main():
    """Send ``n_sockets`` NBT/DNS query pairs and wait for every reply.

    The socket count comes from the first command-line argument (default 3).
    One endpoint per query is connected to NBT_SERVER or DNS_SERVER. The loop
    then runs until the protocols stop it, which happens when
    ``running_connections`` becomes empty.
    """
    if len(sys.argv) > 1:
        n_sockets = int(sys.argv[1])
    else:
        n_sockets = 3

    # get_event_loop() with no running loop is deprecated and raises in
    # Python 3.14, so create and install a loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Pre-register every expected id, using the same format as DnsBotherer.id.
    # Otherwise a fast reply can empty the set while later endpoints are still
    # being created. The resulting loop.stop() either makes run_until_complete()
    # raise RuntimeError or is swallowed there, and the final run_forever()
    # is then left with nobody to stop it.
    running_connections.update(
        f'{proto} {i}' for i in range(n_sockets) for proto in ('nbt', 'dns')
    )

    try:
        for _, nbt_factory, dns_factory in zip(range(n_sockets),
                                               botherer_generator('nbt', loop),
                                               botherer_generator('dns', loop)):
            loop.run_until_complete(
                loop.create_datagram_endpoint(nbt_factory,
                                              remote_addr=NBT_SERVER))
            loop.run_until_complete(
                loop.create_datagram_endpoint(dns_factory,
                                              remote_addr=DNS_SERVER))

        # All replies may already be in, or n_sockets may be <= 0. In either
        # case run_forever() would hang, because nothing would call stop().
        if running_connections:
            loop.run_forever()
    finally:
        loop.close()

# Only fire packets when executed as a script, not when imported.
if __name__ == '__main__':
    main()
