# Start-up banner for the proof-of-concept proxy below. The proxy forwards
# client queries to an upstream resolver, and answers SERVFAIL when a query or
# response carries names that fail hostname validation.
BANNER = """
##################################################
# Proof-of-concept DNS hostname validation proxy #
# Filters rrset key, cname and ptr records       #
# FOR TESTING ONLY                               #
# do not use in production environments          #
##################################################
"""

print(BANNER)

import socket
import re
import sys
import traceback

import dns.message

# Resource-record TYPE codes whose rdata targets are validated (RFC 1035 3.2.2).
RDTYPE_CNAME, RDTYPE_PTR = 5, 12

# RCODE placed in replies when validation fails: "Server failure" (RFC 1035 4.1.1).
RCODE_SERVFAIL = 2

# A single permitted label: ASCII letters, digits and hyphens only. The empty
# label is accepted too (presumably for the trailing root label of an
# absolute dnspython Name). The pattern is anchored with "\Z" rather than "$":
# "$" also matches just before a trailing newline, so b"evil\n" used to pass.
RE_LABEL = re.compile(rb'^[a-zA-Z0-9\-]*\Z')


def checkname(name):
    """Return True if every label of *name* matches RE_LABEL, else False.

    *name* must expose a ``labels`` iterable of bytes (as ``dns.name.Name``
    does). Checking stops at the first label that fails.
    """
    return all(RE_LABEL.match(label) for label in name.labels)

def checkrecord(record):
    """Validate one rrset: its owner name and, for CNAME/PTR, every target.

    Returns False as soon as any checked name fails checkname(); otherwise
    True. Record types other than CNAME and PTR only have their owner name
    checked; their rdata is not inspected.
    """
    if not checkname(record.name):
        return False

    # CNAME and PTR rdata both carry the referenced domain name in `.target`.
    if record.rdtype in (RDTYPE_CNAME, RDTYPE_PTR):
        return all(checkname(rdata.target) for rdata in record.items)

    return True
    
def checkpacket(packet):
    """Return True only if every name carried by *packet* validates.

    Question names go through checkname(). Every rrset in the answer,
    authority and additional sections goes through checkrecord(). Evaluation
    stops at the first failure.
    """
    questions_ok = all(checkname(question.name) for question in packet.question)
    if not questions_ok:
        return False

    sections = (packet.answer, packet.authority, packet.additional)
    return all(checkrecord(rrset) for section in sections for rrset in section)

def _servfail(message):
    """Turn *message* into an empty SERVFAIL reply and return its wire bytes.

    The question section and the existing flags (apart from the RCODE) are
    kept. The answer, authority and additional sections are emptied.
    """
    message.set_rcode(RCODE_SERVFAIL)
    message.answer.clear()
    message.authority.clear()
    message.additional.clear()
    return message.to_wire()


# Command line: argv[1] is the upstream resolver (IPv4 address or hostname);
# the optional argv[2] is the local IPv4 address to listen on. Both use UDP
# port 53.
if len(sys.argv) < 2:
    print("Usage: " + sys.argv[0] + " <upstream> [<listen_addr>]")
    sys.exit(1)

# Resolve once so the tuple compares equal to the numeric address recvfrom()
# reports. With an unresolved hostname, upstream answers would never match and
# would be misread as client queries and sent straight back upstream.
upstream = (socket.gethostbyname(sys.argv[1]), 53)

listen = ("0.0.0.0", 53)
if len(sys.argv) >= 3:
    listen = (sys.argv[2], 53)

# DNS transaction ID of each forwarded query -> address of the client that
# sent it.
# NOTE(review): keyed on the TXID alone. Two clients using the same ID overwrite
# each other, and entries for queries the upstream never answers are never
# evicted. Acceptable only for this proof of concept.
txid2client = {}

sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)

print("binding to " + listen[0] + ":" + str(listen[1]))
sock.bind(listen)

while True:

    pkt, remote = sock.recvfrom(0xFFFF)

    try:
        parsed = dns.message.from_wire(pkt)

        # Anything arriving from the upstream address is treated as an answer
        # to a forwarded query; everything else is a client query.
        is_response = remote == upstream
        txid = parsed.id
        packet_ok = False
        question_s = ""

        # Default deny: if validation itself raises, packet_ok stays False and
        # the packet gets a SERVFAIL instead of being forwarded.
        try:
            packet_ok = checkpacket(parsed)
            question_s = str(parsed.question[0])
        except Exception:
            traceback.print_exc()

        print(remote[0] + ":" + str(remote[1]) + " " + question_s
              + " FLAGS=" + hex(parsed.flags) + " OK=" + str(packet_ok))

        if is_response:
            client = txid2client.pop(txid, None)
            if client is None:
                print("Error: TXID " + str(txid) + " not found")
            elif packet_ok:
                sock.sendto(pkt, client)
            else:
                sock.sendto(_servfail(parsed), client)
        elif packet_ok:
            txid2client[txid] = remote
            sock.sendto(pkt, upstream)
        else:
            # Reject the query locally. 0x8180 = QR | RD | RA. Note that RD is
            # hard-coded rather than copied from the query.
            parsed.flags = 0x8180
            sock.sendto(_servfail(parsed), remote)

    except Exception:
        # Best effort: log and keep serving after a malformed packet or a send
        # error, but let KeyboardInterrupt / SystemExit stop the proxy.
        traceback.print_exc()