import asyncio
import os
import sys
import struct
import time


# Address of the server under test; both the DNS and NetBIOS queries go here.
try:
    SERVER = os.environ['DC_SERVER_IP']
except KeyError:
    sys.exit('DC_SERVER_IP must be set to the address of the server to query')

# "Neko" padded with spaces to 16 bytes, in NetBIOS first-level encoding
# (each nibble n becomes chr(ord('A') + n)).
NBT_NAME = b'EOGFGLGPCACACACACACACACACACACACA'

DNS_SERVER = (SERVER, 53)
NBT_SERVER = (SERVER, 137)

# Label payloads used to build the query names; each is written verbatim as
# one length-prefixed label.
LABEL = 63 * b'.'                  # 63 bytes of dots (presumably the max label size)
LABEL2 = b'.'.join([b'x'] * 32)    # 63 bytes: "x.x.x...x" (32 x's, 31 dots)
LABEL3 = b'x'                      # a single ordinary one-byte label

def construct_packet(proto, msg_id, label=LABEL, label_n=127):
    """Return the bytes of a one-question query whose name has ``label_n`` labels.

    For ``proto == 'nbt'`` the first label is ``NBT_NAME``, followed by
    ``label_n - 1`` copies of ``label``, and the question type is 0x20.
    Any other ``proto`` gives ``label_n`` copies of ``label`` with question
    type 1 (A). The question class is always 1 (IN).

    Labels are emitted exactly as given, with no length validation, so the
    resulting name can be far longer than a conforming client would build.
    ``msg_id`` is truncated to its low 16 bits for the header ID field.
    """
    if proto == 'nbt':
        labels = [NBT_NAME] + [label] * (label_n - 1)
        query_type = 0x20  # NBT_QTYPE_NETBIOS
    else:
        labels = [label] * label_n
        query_type = 1     # dns type A

    header = struct.pack(
        '!6H',
        msg_id & 0xffff,  # transaction id
        0x0100,           # flags: standard query, recursion desired
        1,                # one question
        0,                # no answers
        0,                # no authority records
        0,                # no additional records
    )
    # Every label is a one-byte length followed by the label bytes.
    qname = b''.join(len(part).to_bytes(1, 'big') + part for part in labels)
    # Zero-length root label ends the name; then QTYPE and QCLASS (IN).
    question_tail = struct.pack('!BHH', 0, query_type, 1)
    return header + qname + question_tail


# Ids of DnsBotherer instances whose query has not yet been answered (or
# failed). Whichever instance removes the last id stops the event loop.
running_connections = set()


class DnsBotherer(asyncio.DatagramProtocol):
    """UDP protocol that sends one pre-built query and waits for one reply.

    An instance is also its own protocol factory (``__call__`` returns
    ``self``), so it can be handed directly to
    ``loop.create_datagram_endpoint``. While the query is outstanding the
    instance's id is held in ``running_connections``. When that set becomes
    empty, the instance that emptied it stops ``loop``.
    """

    def __init__(self, proto, i, loop, label, label_n=127):
        self.id = f'{proto} {i}'
        print(f"{self.id} {label}")
        self.packet = construct_packet(proto, i, label=label, label_n=label_n)
        self.loop = loop
        self.transport = None
        self.connected = None
        # Monotonic clock so elapsed-time reports can't jump with wall-clock changes.
        self.start = time.monotonic()

    def connection_made(self, transport):
        self.transport = transport
        self.connected = time.monotonic()
        elapsed = self.connected - self.start
        print(f"{self.id} connected after {elapsed:.2} seconds")
        # Register before sending: sendto() reports OSErrors by calling
        # error_received() synchronously, which must find the id to remove.
        running_connections.add(self.id)
        transport.sendto(self.packet)

    def datagram_received(self, data, addr):
        elapsed = time.monotonic() - self.connected
        print(f"{self.id} received {len(data)} bytes {elapsed:.2} seconds later")
        self._finish()

    def error_received(self, exc):
        # E.g. an ICMP port-unreachable once the server stops listening.
        # Treat it as the end of this query; otherwise its id would never
        # leave running_connections and the loop would never be stopped.
        print(f"{self.id} error: {exc!r}")
        self._finish()

    def _finish(self):
        """Close this socket and stop the loop if no other query is pending."""
        running_connections.discard(self.id)
        self.transport.close()
        if not running_connections:
            self.loop.stop()

    def __call__(self):
        return self


def _run_batch(loop, endpoints):
    """Open every endpoint, then run *loop* until all sent queries are settled.

    The protocols stop the loop whenever ``running_connections`` drains. That
    can happen before the last endpoint is even open (a fast reply), so the
    loop is restarted until setup has finished *and* nothing is outstanding.
    Any error from creating an endpoint is re-raised at the end.
    """
    setup = asyncio.gather(*endpoints)

    def stop_if_idle(_):
        # Once setup ends with nothing outstanding (no sockets requested,
        # every endpoint failed, or all replies already arrived), no
        # protocol callback would stop the loop, so do it here.
        if not running_connections:
            loop.stop()

    setup.add_done_callback(stop_if_idle)
    while not setup.done() or running_connections:
        loop.run_forever()
    setup.result()


def main():
    """Send one batch of oversized queries per label, waiting for each batch.

    The optional first command-line argument is the number of sockets
    (queries) to open per label; it defaults to 1.
    """
    n_sockets = int(sys.argv[1]) if len(sys.argv) > 1 else 1

    # Protocol name -> server address for every query variant to send.
    # Add 'nbt': NBT_SERVER to also send the NetBIOS form of each query.
    targets = {'dns': DNS_SERVER}

    # asyncio.get_event_loop() is deprecated outside a running loop; create
    # and install a loop explicitly instead.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        for batch, label in enumerate([LABEL, LABEL2, LABEL3]):
            endpoints = []
            for i in range(n_sockets):
                msg_id = batch * n_sockets + i
                for proto, server in targets.items():
                    protocol = DnsBotherer(proto, msg_id, loop, label)
                    endpoints.append(
                        loop.create_datagram_endpoint(protocol, remote_addr=server)
                    )
            _run_batch(loop, endpoints)
    finally:
        loop.close()

# Only fire the queries when run as a script, not when imported.
if __name__ == '__main__':
    main()
