Example filtering DNS proxy

Below is a simple example of a DNS proxy that filters out domain names that are not valid hostnames.

Download source code: filter-proxy.py

# Startup banner, printed once when the script starts so that it is obvious
# this is a proof-of-concept tool.  Every line of the box is 50 columns wide.
BANNER = """
##################################################
# Proof-of-concept DNS hostname validation proxy #
# Filters rrset key, cname and ptr records       #
# FOR TESTING ONLY                               #
# do not use in production environments          #
##################################################
"""

print(BANNER)

import socket
import re
import sys
import traceback

import dns.message

# Numeric DNS codes compared against / written into dnspython messages below.
RDTYPE_CNAME = 5  # CNAME (canonical name) record type
RDTYPE_PTR = 12  # PTR (domain name pointer) record type
RCODE_SERVFAIL = 2  # SERVFAIL ("server failure") response code

# A single hostname label: ASCII letters, digits and hyphens only.  The `*`
# also accepts the empty label (presumably the root label that terminates an
# absolute dnspython name).  \A and \Z anchor to the true ends of the label;
# `$` would also match just before a trailing newline and let a label such as
# b"host\n" slip through the filter.
RE_LABEL = re.compile(rb'\A[a-zA-Z0-9-]*\Z')

def checkname(name):
    """Return True if every label of *name* is a valid hostname label.

    :param name: a name object whose ``labels`` attribute is an iterable of
        ``bytes`` labels (a ``dns.name.Name`` as used in this script).
    :returns: ``True`` if all labels match ``RE_LABEL``, otherwise ``False``.
    """
    return all(RE_LABEL.match(label) is not None for label in name.labels)

def checkrecord(record):
    """Validate one rrset: its owner name and, for CNAME/PTR, every target.

    :param record: an rrset exposing ``name``, ``rdtype`` and ``items``.
    :returns: ``True`` when every checked name passes ``checkname``,
        otherwise ``False``.
    """
    if not checkname(record.name):
        return False

    # Only CNAME and PTR rdatas have their ``target`` name inspected; any
    # other record type is judged on its owner name alone.
    if record.rdtype not in (RDTYPE_CNAME, RDTYPE_PTR):
        return True

    return all(checkname(rdata.target) for rdata in record.items)
    
def checkpacket(packet):
    """Return True if every name carried by *packet* passes validation.

    Checks the question names first, then every rrset of the answer,
    authority and additional sections (via ``checkrecord``), stopping at the
    first failure.

    :param packet: a parsed DNS message (``dns.message.Message``).
    :returns: ``True`` if nothing was rejected, otherwise ``False``.
    """
    if not all(checkname(entry.name) for entry in packet.question):
        return False
    sections = (packet.answer, packet.authority, packet.additional)
    return all(checkrecord(rrset) for section in sections for rrset in section)

# --- command-line handling --------------------------------------------------
# Usage: <script> <upstream> [<listen_addr>]  (both on UDP port 53)
if len(sys.argv) < 2:
    print("Usage: " + sys.argv[0] + " <upstream> [<listen_addr>]")
    sys.exit(1)

# Resolve the upstream once so it can be compared with the numeric address
# that recvfrom() reports.  Comparing against an unresolved hostname never
# matches, so upstream replies would be mistaken for client queries.
try:
    upstream = (socket.gethostbyname(sys.argv[1]), 53)
except OSError as err:
    print("Error: cannot resolve upstream " + sys.argv[1] + ": " + str(err))
    sys.exit(1)

listen = ("0.0.0.0", 53)
if len(sys.argv) >= 3:
    listen = (sys.argv[2], 53)

# DNS transaction ID of each forwarded query -> address of the client that
# sent it, so the upstream reply can be routed back.
# NOTE(review): entries for queries that never get an answer are never
# removed, and two clients using the same TXID overwrite each other's entry.
txid2client = {}


def _servfail_wire(message):
    """Turn *message* into an empty SERVFAIL reply and return its wire form.

    The rcode is set to SERVFAIL and the answer, authority and additional
    sections are emptied; the question section is kept.  *message* is
    modified in place.
    """
    message.set_rcode(RCODE_SERVFAIL)
    message.answer.clear()
    message.authority.clear()
    message.additional.clear()
    return message.to_wire()


with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) as sock:

    print("binding to " + listen[0] + ":" + str(listen[1]))
    sock.bind(listen)

    while True:

        pkt, remote = sock.recvfrom(0xFFFF)

        # Any failure while handling one datagram is logged and the proxy
        # moves on to the next one.  `except Exception` (not a bare except)
        # lets KeyboardInterrupt/SystemExit stop the proxy.
        try:

            parsed = dns.message.from_wire(pkt)

            # Anything arriving from the upstream address is treated as a
            # reply; everything else is a client query.
            is_response = remote == upstream
            txid        = parsed.id
            packet_ok   = False
            question_s  = ""

            try:
                packet_ok = checkpacket(parsed)
                if parsed.question:
                    question_s = str(parsed.question[0])
            except Exception:
                # A packet that cannot be checked is treated as not OK.
                traceback.print_exc()

            print(remote[0] + ":" + str(remote[1]) + " " + question_s + " FLAGS=" + hex(parsed.flags) + " OK=" + str(packet_ok))

            if is_response:
                # Reply from upstream: route it back to the client that asked,
                # unchanged if it passed the check, as SERVFAIL otherwise.
                client = txid2client.pop(txid, None)
                if client is None:
                    print("Error: TXID " + str(txid) + " not found")
                elif packet_ok:
                    sock.sendto(pkt, client)
                else:
                    sock.sendto(_servfail_wire(parsed), client)
            elif packet_ok:
                # Valid client query: remember who asked, then forward it.
                txid2client[txid] = remote
                sock.sendto(pkt, upstream)
            else:
                # Rejected client query: answer it directly with SERVFAIL.
                # 0x8180 = QR (response) | RD | RA; this replaces the client's
                # original flags, and set_rcode then fills in the rcode.
                parsed.flags = 0x8180
                sock.sendto(_servfail_wire(parsed), remote)

        except Exception:
            traceback.print_exc()