# Mini Shell
# test-case-name: twisted.names.test.test_dns
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for twisted.names.dns.
"""
import struct
from io import BytesIO
from typing import cast
from zope.interface.verify import verifyClass
from twisted.internet import address, task
from twisted.internet.error import CannotListenError, ConnectionDone
from twisted.names import dns
from twisted.python.failure import Failure
from twisted.python.util import FancyEqMixin, FancyStrMixin
from twisted.test import proto_helpers
from twisted.test.testutils import ComparisonTestsMixin
from twisted.trial import unittest
# Record classes from twisted.names.dns which the generic record tests in
# this module iterate over.  RoundtripDNSTests.test_hashable calls each entry
# with no arguments, so every class listed here has to be constructible that
# way.
RECORD_TYPES = [
dns.Record_NS,
dns.Record_MD,
dns.Record_MF,
dns.Record_CNAME,
dns.Record_MB,
dns.Record_MG,
dns.Record_MR,
dns.Record_PTR,
dns.Record_DNAME,
dns.Record_A,
dns.Record_SOA,
dns.Record_NULL,
dns.Record_WKS,
dns.Record_SRV,
dns.Record_AFSDB,
dns.Record_RP,
dns.Record_HINFO,
dns.Record_MINFO,
dns.Record_MX,
dns.Record_TXT,
dns.Record_AAAA,
dns.Record_A6,
dns.Record_NAPTR,
dns.Record_SSHFP,
dns.Record_TSIG,
dns.UnknownRecord,
]
class DomainStringTests(unittest.SynchronousTestCase):
    """
    Tests for L{dns.domainString}.
    """

    def test_bytes(self):
        """
        L{dns.domainString} returns L{bytes} unchanged.
        """
        self.assertEqual(
            b"twistedmatrix.com",
            dns.domainString(b"twistedmatrix.com"),
        )

    def test_native(self):
        """
        L{dns.domainString} converts a native string to L{bytes}
        if necessary.
        """
        self.assertEqual(b"example.com", dns.domainString("example.com"))

    def test_text(self):
        """
        L{dns.domainString} always converts a unicode string to L{bytes}.
        """
        self.assertEqual(b"foo.example", dns.domainString("foo.example"))

    def test_idna(self):
        """
        L{dns.domainString} encodes Unicode using IDNA.
        """
        self.assertEqual(b"xn--fwg.test", dns.domainString("\u203D.test"))

    def test_nonsense(self):
        """
        L{dns.domainString} raises L{TypeError} when it is given something
        other than a string, such as an integer or a L{dns.Name} instance.
        """
        self.assertRaises(TypeError, dns.domainString, 9000)
        self.assertRaises(TypeError, dns.domainString, dns.Name("bar.example"))
class Ord2ByteTests(unittest.TestCase):
    """
    Tests for L{dns._ord2bytes}.
    """

    def test_ord2byte(self):
        """
        L{dns._ord2bytes} accepts an integer and returns a byte string of
        length one with an ordinal value equal to the given integer.
        """
        self.assertEqual(b"\x10", dns._ord2bytes(0x10))
class Str2TimeTests(unittest.TestCase):
    """
    Tests for L{dns.str2time}.
    """

    def test_nonString(self):
        """
        When passed a non-string object, L{dns.str2time} returns it
        unmodified.
        """
        sentinel = object()
        self.assertIs(sentinel, dns.str2time(sentinel))

    def test_seconds(self):
        """
        Passed a string giving a number of seconds, L{dns.str2time} returns the
        number of seconds represented.  For example, C{"10S"} represents C{10}
        seconds.
        """
        self.assertEqual(10, dns.str2time("10S"))

    def test_minutes(self):
        """
        Like C{test_seconds}, but for the C{"M"} suffix which multiplies the
        time value by C{60} (the number of seconds in a minute!).
        """
        self.assertEqual(2 * 60, dns.str2time("2M"))

    def test_hours(self):
        """
        Like C{test_seconds}, but for the C{"H"} suffix which multiplies the
        time value by C{3600}, the number of seconds in an hour.
        """
        self.assertEqual(3 * 3600, dns.str2time("3H"))

    def test_days(self):
        """
        Like C{test_seconds}, but for the C{"D"} suffix which multiplies the
        time value by C{86400}, the number of seconds in a day.
        """
        self.assertEqual(4 * 86400, dns.str2time("4D"))

    def test_weeks(self):
        """
        Like C{test_seconds}, but for the C{"W"} suffix which multiplies the
        time value by C{604800}, the number of seconds in a week.
        """
        self.assertEqual(5 * 604800, dns.str2time("5W"))

    def test_years(self):
        """
        Like C{test_seconds}, but for the C{"Y"} suffix which multiplies the
        time value by C{31536000}, the number of seconds in a year.
        """
        self.assertEqual(6 * 31536000, dns.str2time("6Y"))

    def test_invalidPrefix(self):
        """
        If a non-integer prefix is given, L{dns.str2time} raises
        L{ValueError}.
        """
        self.assertRaises(ValueError, dns.str2time, "fooS")
class NameTests(unittest.TestCase):
    """
    Tests for L{Name}: a single domain name which can write itself into, and
    read itself back out of, the DNS message format.
    """

    def test_nonStringName(self):
        """
        Constructing a L{Name} from a value which is neither C{bytes} nor
        C{str} fails with L{TypeError}.
        """
        for invalid in (123, object(), []):
            self.assertRaises(TypeError, dns.Name, invalid)

    def test_unicodeName(self):
        """
        A unicode domain name given to L{dns.Name} is stored as C{bytes}
        encoded with the C{idna} codec.
        """
        encoded = dns.Name("\u00e9chec.example.org").name
        self.assertIsInstance(encoded, bytes)
        self.assertEqual(b"xn--chec-9oa.example.org", encoded)

    def test_decode(self):
        """
        L{Name.decode} fills in the L{Name} instance from the labels it reads
        out of the file-like object it is given.
        """
        decoded = dns.Name()
        decoded.decode(BytesIO(b"\x07example\x03com\x00"))
        self.assertEqual(decoded.name, b"example.com")

    def test_encode(self):
        """
        L{Name.encode} writes the length-prefixed labels of its name,
        followed by a terminating zero byte, to the file-like object it is
        given.
        """
        output = BytesIO()
        dns.Name(b"foo.example.com").encode(output)
        self.assertEqual(output.getvalue(), b"\x03foo\x07example\x03com\x00")

    def test_encodeWithCompression(self):
        """
        Given a compression dictionary, L{Name.encode} replaces any suffix
        already present in that dictionary with a pointer to the recorded
        offset rather than writing the labels again, and it records the
        offset of the name it has just written in the dictionary.
        """
        # Stand-in for whatever was written to the message earlier; it only
        # exists so that the stream position is not zero.
        prefix = b"some prefix to change .tell()"
        stream = BytesIO()
        stream.write(prefix)
        offsets = {b"example.com": 0x17}
        # Offset expected to be recorded for the name written below.
        newOffset = len(prefix) + dns.Message.headerSize

        dns.Name(b"foo.example.com").encode(stream, offsets)

        written = stream.getvalue()[len(prefix) :]
        self.assertEqual(b"\x03foo\xc0\x17", written)
        self.assertEqual(
            {b"example.com": 0x17, b"foo.example.com": newOffset}, offsets
        )

    def test_unknown(self):
        """
        A resource record whose type and class are not recognised is parsed
        into an L{UnknownRecord} which keeps the raw payload, and serializing
        the parsed message again yields exactly the bytes it was parsed from.
        """
        wire = (
            b"\x01\x00"  # Message ID
            b"\x00"  # answer bit, opCode nibble, auth bit, trunc bit, recursive
            # bit
            b"\x00"  # recursion bit, empty bit, authenticData bit,
            # checkingDisabled bit, response code nibble
            b"\x00\x01"  # number of queries
            b"\x00\x01"  # number of answers
            b"\x00\x00"  # number of authorities
            b"\x00\x01"  # number of additionals
            # query
            b"\x03foo\x03bar\x00"  # foo.bar
            b"\xde\xad"  # type=0xdead
            b"\xbe\xef"  # cls=0xbeef
            # 1st answer
            b"\xc0\x0c"  # foo.bar - compressed
            b"\xde\xad"  # type=0xdead
            b"\xbe\xef"  # cls=0xbeef
            b"\x00\x00\x01\x01"  # ttl=257
            b"\x00\x08somedata"  # some payload data
            # 1st additional
            b"\x03baz\x03ban\x00"  # baz.ban
            b"\x00\x01"  # type=A
            b"\x00\x01"  # cls=IN
            b"\x00\x00\x01\x01"  # ttl=257
            b"\x00\x04"  # len=4
            b"\x01\x02\x03\x04"  # 1.2.3.4
        )
        expectedQuery = dns.Query(b"foo.bar", type=0xDEAD, cls=0xBEEF)
        expectedAnswer = dns.RRHeader(
            b"foo.bar",
            type=0xDEAD,
            cls=0xBEEF,
            ttl=257,
            payload=dns.UnknownRecord(b"somedata", ttl=257),
        )
        expectedAdditional = dns.RRHeader(
            b"baz.ban",
            type=dns.A,
            cls=dns.IN,
            ttl=257,
            payload=dns.Record_A("1.2.3.4", ttl=257),
        )

        message = dns.Message()
        message.fromStr(wire)

        self.assertEqual(message.queries, [expectedQuery])
        self.assertEqual(message.answers, [expectedAnswer])
        self.assertEqual(message.additional, [expectedAdditional])
        self.assertEqual(message.toStr(), wire)

    def test_decodeWithCompression(self):
        """
        When L{Name.decode} meets a label whose leading byte has its two high
        bits set, it treats the following byte as a pointer to a label
        elsewhere in the stream and includes that label in the decoded name.
        """
        # Slightly modified version of the example from RFC 1035, section 4.1.4.
        padding = b"x" * 20
        stream = BytesIO(
            padding
            + b"\x01f\x03isi\x04arpa\x00"
            + b"\x03foo\xc0\x14"
            + b"\x03bar\xc0\x20"
        )
        stream.seek(20)
        name = dns.Name()
        # Each entry is the name expected from the next decode() call and the
        # stream position expected immediately after that name.
        expectations = [
            (b"f.isi.arpa", 32),
            (b"foo.f.isi.arpa", 38),
            (b"bar.foo.f.isi.arpa", 44),
        ]
        for expectedName, expectedPosition in expectations:
            name.decode(stream)
            self.assertEqual(expectedName, name.name)
            self.assertEqual(expectedPosition, stream.tell())

    def test_rejectCompressionLoop(self):
        """
        L{Name.decode} raises L{ValueError} when a compression pointer in the
        stream refers back to itself, making the name impossible to decode.
        """
        looping = BytesIO(b"\xc0\x00")
        self.assertRaises(ValueError, dns.Name().decode, looping)

    def test_equality(self):
        """
        Two L{Name} instances with the same L{Name.name} compare equal, and
        the comparison ignores case.
        """
        reference = dns.Name(b"foo.bar")
        self.assertEqual(reference, dns.Name(b"foo.bar"))
        self.assertEqual(reference, dns.Name(b"fOO.bar"))

    def test_inequality(self):
        """
        Two L{Name} instances with different L{Name.name} values do not
        compare equal.
        """
        self.assertNotEqual(dns.Name(b"foo.bar"), dns.Name(b"bar.foo"))
class RoundtripDNSTests(unittest.TestCase):
    """
    Tests which encode an object and then decode the result again.
    """

    names = [b"example.org", b"go-away.fish.tv", b"23strikesback.net"]

    def test_name(self):
        """
        A L{dns.Name} decoded from the bytes written by L{dns.Name.encode}
        has the name that was encoded.
        """
        for expected in self.names:
            buffer = BytesIO()
            dns.Name(expected).encode(buffer)
            buffer.seek(0, 0)
            decoded = dns.Name()
            decoded.decode(buffer)
            self.assertEqual(decoded.name, expected)

    def test_query(self):
        """
        The bytes written by L{dns.Query.encode} can be read back with
        L{dns.Query.decode} to obtain a query with the same name, type and
        class.
        """
        combinations = [
            (name, queryType, queryClass)
            for name in self.names
            for queryType in range(1, 17)
            for queryClass in range(1, 5)
        ]
        for name, queryType, queryClass in combinations:
            buffer = BytesIO()
            dns.Query(name, queryType, queryClass).encode(buffer)
            buffer.seek(0, 0)
            decoded = dns.Query()
            decoded.decode(buffer)
            self.assertEqual(decoded.name.name, name)
            self.assertEqual(decoded.type, queryType)
            self.assertEqual(decoded.cls, queryClass)

    def test_resourceRecordHeader(self):
        """
        The bytes written by L{dns.RRHeader.encode} can be read back with
        L{dns.RRHeader.decode} to obtain a header with the same name, type,
        class and TTL.
        """
        buffer = BytesIO()
        dns.RRHeader(b"test.org", 3, 4, 17).encode(buffer)
        buffer.seek(0, 0)
        header = dns.RRHeader()
        header.decode(buffer)
        self.assertEqual(header.name, dns.Name(b"test.org"))
        self.assertEqual(header.type, 3)
        self.assertEqual(header.cls, 4)
        self.assertEqual(header.ttl, 17)

    def test_resourceRecordHeaderTypeMismatch(self):
        """
        L{RRHeader()} raises L{ValueError} if the type it is given differs
        from the type of the payload it is given.
        """
        pattern = r"Payload type \(AAAA\) .* type \(A\)"
        with self.assertRaisesRegex(ValueError, pattern):
            dns.RRHeader(type=dns.A, payload=dns.Record_AAAA())

    def test_resources(self):
        """
        The bytes written by L{dns.SimpleRecord.encode} can be read back with
        L{dns.SimpleRecord.decode} to obtain a record with the same name.
        """
        samples = (
            b"this.are.test.name",
            b"will.compress.will.this.will.name.will.hopefully",
            b"test.CASE.preSErVatIOn.YeAH",
            b"a.s.h.o.r.t.c.a.s.e.t.o.t.e.s.t",
            b"singleton",
        )
        for sample in samples:
            buffer = BytesIO()
            dns.SimpleRecord(sample).encode(buffer)
            buffer.seek(0, 0)
            decoded = dns.SimpleRecord()
            decoded.decode(buffer)
            self.assertEqual(decoded.name, dns.Name(sample))

    def test_hashable(self):
        """
        Every record type can be hashed, and two default-constructed
        instances of the same type hash to the same value.
        """
        for k in RECORD_TYPES:
            # Both instances are kept alive until both have been hashed.
            first, second = k(), k()
            hk1, hk2 = hash(first), hash(second)
            self.assertEqual(hk1, hk2, f"{hk1} != {hk2} (for {k})")

    def test_Charstr(self):
        """
        The bytes written by L{dns.Charstr.encode} can be read back with
        L{dns.Charstr.decode} to obtain the same string.
        """
        for expected in self.names:
            buffer = BytesIO()
            dns.Charstr(expected).encode(buffer)
            buffer.seek(0, 0)
            decoded = dns.Charstr()
            decoded.decode(buffer)
            self.assertEqual(decoded.string, expected)

    def _recordRoundtripTest(self, record):
        """
        Encode C{record}, decode the bytes into a fresh instance of the same
        class and assert that the two compare equal.

        @type record: L{dns.IEncodable}
        @param record: A record instance to encode
        """
        buffer = BytesIO()
        record.encode(buffer)
        # The decoder is told how many bytes the encoder produced.
        encodedLength = buffer.tell()
        buffer.seek(0, 0)
        decoded = record.__class__()
        decoded.decode(buffer, encodedLength)
        self.assertEqual(record, decoded)

    def assertEncodedFormat(self, expectedEncoding, record):
        """
        Assert that C{record} encodes to exactly C{expectedEncoding}.

        @type record: L{dns.IEncodable}
        @param record: A record instance to encode

        @type expectedEncoding: L{bytes}
        @param expectedEncoding: The value which C{record.encode()}
            should produce.
        """
        buffer = BytesIO()
        record.encode(buffer)
        self.assertEqual(buffer.getvalue(), expectedEncoding)

    def test_SOA(self):
        """
        A L{dns.Record_SOA} survives being encoded and then decoded.
        """
        record = dns.Record_SOA(
            mname=b"foo",
            rname=b"bar",
            serial=12,
            refresh=34,
            retry=56,
            expire=78,
            minimum=90,
        )
        self._recordRoundtripTest(record)

    def test_A(self):
        """
        A L{dns.Record_A} survives being encoded and then decoded.
        """
        record = dns.Record_A("1.2.3.4")
        self._recordRoundtripTest(record)

    def test_NULL(self):
        """
        A L{dns.Record_NULL} survives being encoded and then decoded.
        """
        record = dns.Record_NULL(b"foo bar")
        self._recordRoundtripTest(record)

    def test_WKS(self):
        """
        A L{dns.Record_WKS} survives being encoded and then decoded.
        """
        record = dns.Record_WKS("1.2.3.4", 3, b"xyz")
        self._recordRoundtripTest(record)

    def test_AAAA(self):
        """
        A L{dns.Record_AAAA} survives being encoded and then decoded.
        """
        record = dns.Record_AAAA("::1")
        self._recordRoundtripTest(record)

    def test_A6(self):
        """
        A L{dns.Record_A6} survives being encoded and then decoded.
        """
        record = dns.Record_A6(8, "::1:2", b"foo")
        self._recordRoundtripTest(record)

    def test_SRV(self):
        """
        A L{dns.Record_SRV} survives being encoded and then decoded.
        """
        record = dns.Record_SRV(priority=1, weight=2, port=3, target=b"example.com")
        self._recordRoundtripTest(record)

    def test_SSHFP(self):
        """
        A L{dns.Record_SSHFP} survives being encoded and then decoded, and
        its encoded form is the algorithm byte, the fingerprint type byte and
        then the fingerprint.
        """
        fingerprint = (
            b"\xda\x39\xa3\xee\x5e\x6b\x4b\x0d"
            b"\x32\x55\xbf\xef\x95\x60\x18\x90\xaf\xd8\x07\x09"
        )
        record = dns.Record_SSHFP(
            algorithm=dns.Record_SSHFP.ALGORITHM_DSS,
            fingerprintType=dns.Record_SSHFP.FINGERPRINT_TYPE_SHA1,
            fingerprint=fingerprint,
        )
        self._recordRoundtripTest(record)
        self.assertEncodedFormat(b"\x02\x01" + fingerprint, record)

    def test_NAPTR(self):
        """
        The fields of a L{dns.Record_NAPTR} survive being encoded and then
        decoded.
        """
        # (order, preference, flags, service, regexp, replacement)
        samples = [
            (100, 10, b"u", b"sip+E2U", b"!^.*$!sip:information@domain.tld!", b""),
            (100, 50, b"s", b"http+I2L+I2C+I2R", b"", b"_http._tcp.gatech.edu"),
        ]
        for fields in samples:
            original = dns.Record_NAPTR(*fields)
            buffer = BytesIO()
            original.encode(buffer)
            buffer.seek(0, 0)
            decoded = dns.Record_NAPTR()
            decoded.decode(buffer)
            for attribute in ("order", "preference", "flags", "service", "regexp"):
                self.assertEqual(
                    getattr(original, attribute), getattr(decoded, attribute)
                )
            self.assertEqual(original.replacement.name, decoded.replacement.name)
            self.assertEqual(original.ttl, decoded.ttl)

    def test_AFSDB(self):
        """
        A L{dns.Record_AFSDB} survives being encoded and then decoded.
        """
        record = dns.Record_AFSDB(subtype=3, hostname=b"example.com")
        self._recordRoundtripTest(record)

    def test_RP(self):
        """
        A L{dns.Record_RP} survives being encoded and then decoded.
        """
        record = dns.Record_RP(mbox=b"alice.example.com", txt=b"example.com")
        self._recordRoundtripTest(record)

    def test_HINFO(self):
        """
        A L{dns.Record_HINFO} survives being encoded and then decoded.
        """
        record = dns.Record_HINFO(cpu=b"fast", os=b"great")
        self._recordRoundtripTest(record)

    def test_MINFO(self):
        """
        A L{dns.Record_MINFO} survives being encoded and then decoded.
        """
        record = dns.Record_MINFO(rmailbx=b"foo", emailbx=b"bar")
        self._recordRoundtripTest(record)

    def test_MX(self):
        """
        A L{dns.Record_MX} survives being encoded and then decoded.
        """
        record = dns.Record_MX(preference=1, name=b"example.com")
        self._recordRoundtripTest(record)

    def test_TSIG(self):
        """
        A L{dns.Record_TSIG} survives being encoded and then decoded, and its
        encoded form matches the expected bytes, including for a signing time
        which does not fit in 32 bits.
        """
        mac = b"\x00\x01\x02\x03\x10\x11\x12\x13" b"\x20\x21\x22\x23\x30\x31\x32\x33"

        md5Record = dns.Record_TSIG(
            algorithm="hmac-md5.sig-alg.reg.int",
            timeSigned=1515548975,
            originalID=42,
            fudge=5,
            MAC=mac,
        )
        self._recordRoundtripTest(md5Record)
        md5Wire = (
            b"\x08hmac-md5\x07sig-alg\x03reg\x03int\x00"
            b"\x00\x00\x5a\x55\x71\x2f\x00\x05\x00\x10"
            + mac
            + b"\x00\x2A\x00\x00\x00\x00"
        )
        self.assertEncodedFormat(md5Wire, md5Record)

        sha256Record = dns.Record_TSIG(
            algorithm="hmac-sha256",
            timeSigned=4511798055,  # More than 32 bits
            originalID=65535,
            error=dns.EBADTIME,
            otherData=b"\x80\x00\x00\x00\x00\x08",
            MAC=mac,
        )
        self._recordRoundtripTest(sha256Record)
        sha256Wire = (
            b"\x0Bhmac-sha256\x00"
            b"\x00\x01\x0c\xec\x93\x27\x00\x05\x00\x10"
            + mac
            + b"\xff\xff\x00\x12\x00\x06"
            b"\x80\x00\x00\x00\x00\x08"
        )
        self.assertEncodedFormat(sha256Wire, sha256Record)

    def test_TXT(self):
        """
        A L{dns.Record_TXT} survives being encoded and then decoded.
        """
        record = dns.Record_TXT(b"foo", b"bar")
        self._recordRoundtripTest(record)
# Encoded header of a message with no records in any section and only the
# authenticData (AD) flag set.  MessageTests compares it against
# dns.Message(authenticData=1).toStr() and feeds it to dns.Message.fromStr.
MESSAGE_AUTHENTIC_DATA_BYTES = (
b"\x00\x00" # ID
b"\x00" # answer bit, opCode nibble, auth bit, trunc bit, recursive bit
b"\x20" # RA, Z, AD=1, CD, RCODE
b"\x00\x00" # Query count
b"\x00\x00" # Answer count
b"\x00\x00" # Authority count
b"\x00\x00" # Additional count
)
# Same as above, but with only the checkingDisabled (CD) flag set.
# MessageTests compares it against dns.Message(checkingDisabled=1).toStr()
# and feeds it to dns.Message.fromStr.
MESSAGE_CHECKING_DISABLED_BYTES = (
b"\x00\x00" # ID
b"\x00" # answer bit, opCode nibble, auth bit, trunc bit, recursive bit
b"\x10" # RA, Z, AD, CD=1, RCODE
b"\x00\x00" # Query count
b"\x00\x00" # Answer count
b"\x00\x00" # Authority count
b"\x00\x00" # Additional count
)
class MessageTests(unittest.SynchronousTestCase):
    """
    Tests for L{twisted.names.dns.Message}.
    """

    def test_authenticDataDefault(self):
        """
        L{dns.Message.authenticData} has default value 0.
        """
        self.assertEqual(dns.Message().authenticData, 0)

    def test_authenticDataOverride(self):
        """
        L{dns.Message.__init__} accepts a C{authenticData} argument which
        is assigned to L{dns.Message.authenticData}.
        """
        self.assertEqual(dns.Message(authenticData=1).authenticData, 1)

    def test_authenticDataEncode(self):
        """
        L{dns.Message.toStr} encodes L{dns.Message.authenticData} into
        byte4 of the byte string.
        """
        self.assertEqual(
            dns.Message(authenticData=1).toStr(), MESSAGE_AUTHENTIC_DATA_BYTES
        )

    def test_authenticDataDecode(self):
        """
        L{dns.Message.fromStr} decodes byte4 and assigns bit3 to
        L{dns.Message.authenticData}.
        """
        m = dns.Message()
        m.fromStr(MESSAGE_AUTHENTIC_DATA_BYTES)
        self.assertEqual(m.authenticData, 1)

    def test_checkingDisabledDefault(self):
        """
        L{dns.Message.checkingDisabled} has default value 0.
        """
        self.assertEqual(dns.Message().checkingDisabled, 0)

    def test_checkingDisabledOverride(self):
        """
        L{dns.Message.__init__} accepts a C{checkingDisabled} argument which
        is assigned to L{dns.Message.checkingDisabled}.
        """
        self.assertEqual(dns.Message(checkingDisabled=1).checkingDisabled, 1)

    def test_checkingDisabledEncode(self):
        """
        L{dns.Message.toStr} encodes L{dns.Message.checkingDisabled} into
        byte4 of the byte string.
        """
        self.assertEqual(
            dns.Message(checkingDisabled=1).toStr(), MESSAGE_CHECKING_DISABLED_BYTES
        )

    def test_checkingDisabledDecode(self):
        """
        L{dns.Message.fromStr} decodes byte4 and assigns bit4 to
        L{dns.Message.checkingDisabled}.
        """
        m = dns.Message()
        m.fromStr(MESSAGE_CHECKING_DISABLED_BYTES)
        self.assertEqual(m.checkingDisabled, 1)

    def test_reprDefaults(self):
        """
        L{dns.Message.__repr__} omits field values and sections which are
        identical to their defaults. The id field value is always shown.
        """
        self.assertEqual("<Message id=0>", repr(dns.Message()))

    def test_reprFlagsIfSet(self):
        """
        L{dns.Message.__repr__} displays flags if they are L{True}.
        """
        m = dns.Message(
            answer=True,
            auth=True,
            trunc=True,
            recDes=True,
            recAv=True,
            authenticData=True,
            checkingDisabled=True,
        )
        self.assertEqual(
            "<Message id=0 "
            "flags=answer,auth,trunc,recDes,recAv,authenticData,"
            "checkingDisabled>",
            repr(m),
        )

    def test_reprNonDefautFields(self):
        """
        L{dns.Message.__repr__} displays field values if they differ from their
        defaults.
        """
        m = dns.Message(id=10, opCode=20, rCode=30, maxSize=40)
        self.assertEqual(
            "<Message id=10 opCode=20 rCode=30 maxSize=40>",
            repr(m),
        )

    def test_reprNonDefaultSections(self):
        """
        L{dns.Message.__repr__} displays sections which differ from their
        defaults.
        """
        m = dns.Message()
        m.queries = [1, 2, 3]
        m.answers = [4, 5, 6]
        m.authority = [7, 8, 9]
        m.additional = [10, 11, 12]
        self.assertEqual(
            "<Message id=0 "
            "queries=[1, 2, 3] "
            "answers=[4, 5, 6] "
            "authority=[7, 8, 9] "
            "additional=[10, 11, 12]>",
            repr(m),
        )

    def test_emptyMessage(self):
        """
        Parsing a message which has been truncated (here, to zero bytes)
        raises L{EOFError}.
        """
        msg = dns.Message()
        self.assertRaises(EOFError, msg.fromStr, b"")

    def test_emptyQuery(self):
        """
        Bytes representing a query message with no records in any section
        can be decoded as such.
        """
        msg = dns.Message()
        msg.fromStr(
            b"\x01\x00"  # Message ID
            b"\x00"  # answer bit, opCode nibble, auth bit, trunc bit, recursive bit
            b"\x00"  # recursion bit, empty bit, authenticData bit,
            # checkingDisabled bit, response code nibble
            b"\x00\x00"  # number of queries
            b"\x00\x00"  # number of answers
            b"\x00\x00"  # number of authorities
            b"\x00\x00"  # number of additionals
        )
        self.assertEqual(msg.id, 256)
        self.assertFalse(msg.answer, "Message was not supposed to be an answer.")
        self.assertEqual(msg.opCode, dns.OP_QUERY)
        self.assertFalse(msg.auth, "Message was not supposed to be authoritative.")
        self.assertFalse(msg.trunc, "Message was not supposed to be truncated.")
        self.assertEqual(msg.queries, [])
        self.assertEqual(msg.answers, [])
        self.assertEqual(msg.authority, [])
        self.assertEqual(msg.additional, [])

    def test_NULL(self):
        """
        A I{NULL} record with an arbitrary payload can be encoded and decoded as
        part of a L{dns.Message}.
        """
        # Every possible byte value, once each.  The local is deliberately
        # not called ``bytes`` so that the builtin is not shadowed.
        payload = b"".join(dns._ord2bytes(i) for i in range(256))
        rec = dns.Record_NULL(payload)
        rr = dns.RRHeader(b"testname", dns.NULL, payload=rec)
        msg1 = dns.Message()
        msg1.answers.append(rr)
        s = BytesIO()
        msg1.encode(s)
        s.seek(0, 0)
        msg2 = dns.Message()
        msg2.decode(s)

        self.assertIsInstance(msg2.answers[0].payload, dns.Record_NULL)
        self.assertEqual(msg2.answers[0].payload.payload, payload)

    def test_lookupRecordTypeDefault(self):
        """
        L{Message.lookupRecordType} returns C{dns.UnknownRecord} if it is
        called with an integer which doesn't correspond to any known record
        type.
        """
        # 65280 is the first value in the range reserved for private
        # use, so it shouldn't ever conflict with an officially
        # allocated value.
        self.assertIs(dns.Message().lookupRecordType(65280), dns.UnknownRecord)

    def test_nonAuthoritativeMessage(self):
        """
        The L{RRHeader} instances created by L{Message} from a non-authoritative
        message are marked as not authoritative.
        """
        buf = BytesIO()
        answer = dns.RRHeader(payload=dns.Record_A("1.2.3.4", ttl=0))
        answer.encode(buf)
        header = (
            b"\x01\x00"  # Message ID
            # answer bit, opCode nibble, auth bit, trunc bit, recursive bit
            b"\x00"
            # recursion bit, empty bit, authenticData bit,
            # checkingDisabled bit, response code nibble
            b"\x00"
            b"\x00\x00"  # number of queries
            b"\x00\x01"  # number of answers
            b"\x00\x00"  # number of authorities
            b"\x00\x00"  # number of additionals
        )
        message = dns.Message()
        message.fromStr(header + buf.getvalue())
        self.assertEqual(message.answers, [answer])
        self.assertFalse(message.answers[0].auth)

    def test_authoritativeMessage(self):
        """
        The L{RRHeader} instances created by L{Message} from an authoritative
        message are marked as authoritative.
        """
        buf = BytesIO()
        answer = dns.RRHeader(payload=dns.Record_A("1.2.3.4", ttl=0))
        answer.encode(buf)
        header = (
            b"\x01\x00"  # Message ID
            # answer bit, opCode nibble, auth bit, trunc bit, recursive bit
            b"\x04"
            # recursion bit, empty bit, authenticData bit,
            # checkingDisabled bit, response code nibble
            b"\x00"
            b"\x00\x00"  # number of queries
            b"\x00\x01"  # number of answers
            b"\x00\x00"  # number of authorities
            b"\x00\x00"  # number of additionals
        )
        message = dns.Message()
        message.fromStr(header + buf.getvalue())
        # The expected header is flagged as authoritative before comparing.
        answer.auth = True
        self.assertEqual(message.answers, [answer])
        self.assertTrue(message.answers[0].auth)
class MessageComparisonTests(ComparisonTestsMixin, unittest.SynchronousTestCase):
    """
    Tests for the rich comparison of L{dns.Message} instances.
    """

    def messageFactory(self, *args, **kwargs):
        """
        Create a L{dns.Message}.

        The L{dns.Message} constructor doesn't accept C{queries}, C{answers},
        C{authority}, C{additional} arguments, so we extract them from the
        kwargs supplied to this factory function and assign them to the message.

        @param args: Positional arguments, passed on to L{dns.Message}.
        @param kwargs: Keyword arguments.  Apart from the four section names
            listed above, these are passed on to L{dns.Message}.
        @return: A L{dns.Message} instance.
        """
        queries = kwargs.pop("queries", [])
        answers = kwargs.pop("answers", [])
        authority = kwargs.pop("authority", [])
        additional = kwargs.pop("additional", [])
        # Positional arguments are forwarded rather than silently discarded.
        m = dns.Message(*args, **kwargs)
        # A section is only assigned when a non-empty value was supplied;
        # otherwise the message keeps whatever the constructor set up.
        if queries:
            m.queries = queries
        if answers:
            m.answers = answers
        if authority:
            m.authority = authority
        if additional:
            m.additional = additional
        return m

    def test_id(self):
        """
        Two L{dns.Message} instances compare equal if they have the same id
        value.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(id=10),
            self.messageFactory(id=10),
            self.messageFactory(id=20),
        )

    def test_answer(self):
        """
        Two L{dns.Message} instances compare equal if they have the same answer
        flag.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(answer=1),
            self.messageFactory(answer=1),
            self.messageFactory(answer=0),
        )

    def test_opCode(self):
        """
        Two L{dns.Message} instances compare equal if they have the same opCode
        value.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(opCode=10),
            self.messageFactory(opCode=10),
            self.messageFactory(opCode=20),
        )

    def test_recDes(self):
        """
        Two L{dns.Message} instances compare equal if they have the same recDes
        flag.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(recDes=1),
            self.messageFactory(recDes=1),
            self.messageFactory(recDes=0),
        )

    def test_recAv(self):
        """
        Two L{dns.Message} instances compare equal if they have the same recAv
        flag.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(recAv=1),
            self.messageFactory(recAv=1),
            self.messageFactory(recAv=0),
        )

    def test_auth(self):
        """
        Two L{dns.Message} instances compare equal if they have the same auth
        flag.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(auth=1),
            self.messageFactory(auth=1),
            self.messageFactory(auth=0),
        )

    def test_rCode(self):
        """
        Two L{dns.Message} instances compare equal if they have the same rCode
        value.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(rCode=10),
            self.messageFactory(rCode=10),
            self.messageFactory(rCode=20),
        )

    def test_trunc(self):
        """
        Two L{dns.Message} instances compare equal if they have the same trunc
        flag.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(trunc=1),
            self.messageFactory(trunc=1),
            self.messageFactory(trunc=0),
        )

    def test_maxSize(self):
        """
        Two L{dns.Message} instances compare equal if they have the same
        maxSize value.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(maxSize=10),
            self.messageFactory(maxSize=10),
            self.messageFactory(maxSize=20),
        )

    def test_authenticData(self):
        """
        Two L{dns.Message} instances compare equal if they have the same
        authenticData flag.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(authenticData=1),
            self.messageFactory(authenticData=1),
            self.messageFactory(authenticData=0),
        )

    def test_checkingDisabled(self):
        """
        Two L{dns.Message} instances compare equal if they have the same
        checkingDisabled flag.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(checkingDisabled=1),
            self.messageFactory(checkingDisabled=1),
            self.messageFactory(checkingDisabled=0),
        )

    def test_queries(self):
        """
        Two L{dns.Message} instances compare equal if they have the same
        queries.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(queries=[dns.Query(b"example.com")]),
            self.messageFactory(queries=[dns.Query(b"example.com")]),
            self.messageFactory(queries=[dns.Query(b"example.org")]),
        )

    def test_answers(self):
        """
        Two L{dns.Message} instances compare equal if they have the same
        answers.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(
                answers=[dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))]
            ),
            self.messageFactory(
                answers=[dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))]
            ),
            self.messageFactory(
                answers=[dns.RRHeader(b"example.org", payload=dns.Record_A("4.3.2.1"))]
            ),
        )

    def test_authority(self):
        """
        Two L{dns.Message} instances compare equal if they have the same
        authority records.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(
                authority=[
                    dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA())
                ]
            ),
            self.messageFactory(
                authority=[
                    dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA())
                ]
            ),
            self.messageFactory(
                authority=[
                    dns.RRHeader(b"example.org", type=dns.SOA, payload=dns.Record_SOA())
                ]
            ),
        )

    def test_additional(self):
        """
        Two L{dns.Message} instances compare equal if they have the same
        additional records.
        """
        self.assertNormalEqualityImplementation(
            self.messageFactory(
                additional=[
                    dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))
                ]
            ),
            self.messageFactory(
                additional=[
                    dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))
                ]
            ),
            self.messageFactory(
                additional=[
                    dns.RRHeader(b"example.org", payload=dns.Record_A("1.2.3.4"))
                ]
            ),
        )
class TestController:
    """
    Stand-in for the DNS query processor handed to a DNSDatagramProtocol; it
    simply remembers every message delivered to it.

    @ivar messages: the list of received messages.
    @type messages: C{list} of (msg, protocol, address)
    """

    def __init__(self):
        """
        Start out with no recorded messages.
        """
        self.messages = []

    def messageReceived(self, msg, proto, addr=None):
        """
        Record the delivered message, together with the protocol and address
        it came with, so that tests can inspect it afterwards.
        """
        delivery = (msg, proto, addr)
        self.messages.append(delivery)
class DatagramProtocolTests(unittest.TestCase):
"""
Test various aspects of L{dns.DNSDatagramProtocol}.
"""
def setUp(self):
    """
    Build a L{dns.DNSDatagramProtocol} around a L{TestController}, connect
    it to a fake datagram transport and make it schedule calls on a
    L{task.Clock} which the tests advance by hand.
    """
    clock = task.Clock()
    controller = TestController()
    protocol = dns.DNSDatagramProtocol(controller)
    protocol.makeConnection(proto_helpers.FakeDatagramTransport())
    # Replaced only after the connection has been made, as before.
    protocol.callLater = clock.callLater
    self.clock = clock
    self.controller = controller
    self.proto = protocol
def test_truncatedPacket(self):
    """
    Receiving a datagram which is too short to be a message (here, an empty
    one) does not raise from C{datagramReceived}, and nothing is passed on
    to the controller.
    """
    sender = address.IPv4Address("UDP", "127.0.0.1", 12345)
    self.proto.datagramReceived(b"", sender)
    self.assertEqual(self.controller.messages, [])
def test_simpleQuery(self):
"""
Test content received after a query.
"""
d = self.proto.query(("127.0.0.1", 21345), [dns.Query(b"foo")])
self.assertEqual(len(self.proto.liveMessages.keys()), 1)
m = dns.Message()
m.id = next(iter(self.proto.liveMessages.keys()))
m.answers = [dns.RRHeader(payload=dns.Record_A(address="1.2.3.4"))]
def cb(result):
self.assertEqual(result.answers[0].payload.dottedQuad(), "1.2.3.4")
d.addCallback(cb)
self.proto.datagramReceived(m.toStr(), ("127.0.0.1", 21345))
return d
def test_queryTimeout(self):
"""
Test that query timeouts after some seconds.
"""
d = self.proto.query(("127.0.0.1", 21345), [dns.Query(b"foo")])
self.assertEqual(len(self.proto.liveMessages), 1)
self.clock.advance(10)
self.assertFailure(d, dns.DNSQueryTimeoutError)
self.assertEqual(len(self.proto.liveMessages), 0)
return d
def test_writeError(self):
"""
Exceptions raised by the transport's write method should be turned into
C{Failure}s passed to errbacks of the C{Deferred} returned by
L{DNSDatagramProtocol.query}.
"""
def writeError(message, addr):
raise RuntimeError("bar")
self.proto.transport.write = writeError
d = self.proto.query(("127.0.0.1", 21345), [dns.Query(b"foo")])
return self.assertFailure(d, RuntimeError)
def test_listenError(self):
"""
Exception L{CannotListenError} raised by C{listenUDP} should be turned
into a C{Failure} passed to errback of the C{Deferred} returned by
L{DNSDatagramProtocol.query}.
"""
def startListeningError():
raise CannotListenError(None, None, None)
self.proto.startListening = startListeningError
# Clean up transport so that the protocol calls startListening again
self.proto.transport = None
d = self.proto.query(("127.0.0.1", 21345), [dns.Query(b"foo")])
return self.assertFailure(d, CannotListenError)
def test_receiveMessageNotInLiveMessages(self):
"""
When receiving a message whose id is not in
L{DNSDatagramProtocol.liveMessages} or L{DNSDatagramProtocol.resends},
the message will be received by L{DNSDatagramProtocol.controller}.
"""
message = dns.Message()
message.id = 1
message.answers = [dns.RRHeader(payload=dns.Record_A(address="1.2.3.4"))]
self.proto.datagramReceived(message.toStr(), ("127.0.0.1", 21345))
self.assertEqual(self.controller.messages[-1][0].toStr(), message.toStr())
class TestTCPController(TestController):
    """
    A stand-in DNS query processor for a DNSProtocol which additionally
    keeps track of the protocols currently connected to it.

    @ivar connections: A list of L{DNSProtocol} instances which have
        reported C{connectionMade} to this controller and have not yet
        reported C{connectionLost}.
    """

    def __init__(self):
        super().__init__()
        self.connections = list()

    def connectionMade(self, proto):
        """
        Start tracking C{proto} as connected.
        """
        self.connections.append(proto)

    def connectionLost(self, proto):
        """
        Stop tracking C{proto}.
        """
        self.connections.remove(proto)
class DNSProtocolTests(unittest.TestCase):
    """
    Tests for the behaviour of L{dns.DNSProtocol}.
    """

    def setUp(self):
        """
        Connect a L{dns.DNSProtocol} to a string transport, route its
        messages to a L{TestTCPController} and schedule its delayed calls on
        a L{task.Clock}.
        """
        self.clock = task.Clock()
        self.controller = TestTCPController()
        self.proto = dns.DNSProtocol(self.controller)
        self.proto.makeConnection(proto_helpers.StringTransport())
        self.proto.callLater = self.clock.callLater

    def test_connectionTracking(self):
        """
        L{dns.DNSProtocol} passes itself to its controller's
        C{connectionMade} method when it is connected to a transport and to
        its controller's C{connectionLost} method when it is disconnected.
        """
        self.assertEqual(self.controller.connections, [self.proto])
        reason = Failure(ConnectionDone("Fake Connection Done"))
        self.proto.connectionLost(reason)
        self.assertEqual(self.controller.connections, [])

    def test_queryTimeout(self):
        """
        Once the clock has advanced by 60 seconds an unanswered query fails
        with L{dns.DNSQueryTimeoutError} and is no longer tracked as live.
        """
        d = self.proto.query([dns.Query(b"foo")])
        self.assertEqual(len(self.proto.liveMessages), 1)
        self.clock.advance(60)
        self.assertFailure(d, dns.DNSQueryTimeoutError)
        self.assertEqual(len(self.proto.liveMessages), 0)
        return d

    def test_simpleQuery(self):
        """
        The L{Deferred} returned by C{query} fires with the response whose
        id matches the outstanding query.
        """
        d = self.proto.query([dns.Query(b"foo")])
        self.assertEqual(len(self.proto.liveMessages), 1)

        # Answer the single outstanding query by reusing its id.
        response = dns.Message()
        response.id = next(iter(self.proto.liveMessages))
        response.answers = [dns.RRHeader(payload=dns.Record_A(address="1.2.3.4"))]

        def checkAnswer(result):
            self.assertEqual(result.answers[0].payload.dottedQuad(), "1.2.3.4")

        d.addCallback(checkAnswer)
        # The encoded message is preceded by its length as a 16-bit
        # big-endian integer.
        encoded = response.toStr()
        framed = struct.pack("!H", len(encoded)) + encoded
        self.proto.dataReceived(framed)
        return d

    def test_writeError(self):
        """
        An exception raised by the transport's C{write} method becomes a
        C{Failure} delivered to the errback chain of the C{Deferred}
        returned by L{DNSProtocol.query}.
        """

        def writeError(message):
            raise RuntimeError("bar")

        self.proto.transport.write = writeError
        d = self.proto.query([dns.Query(b"foo")])
        return self.assertFailure(d, RuntimeError)

    def test_receiveMessageNotInLiveMessages(self):
        """
        A message whose id is not in L{DNSProtocol.liveMessages} is handed
        to L{DNSProtocol.controller}.
        """
        unsolicited = dns.Message()
        unsolicited.id = 1
        unsolicited.answers = [
            dns.RRHeader(payload=dns.Record_A(address="1.2.3.4"))
        ]
        encoded = unsolicited.toStr()
        # The encoded message is preceded by its length as a 16-bit
        # big-endian integer.
        framed = struct.pack("!H", len(encoded)) + encoded
        self.proto.dataReceived(framed)
        delivered = self.controller.messages[-1][0]
        self.assertEqual(delivered.toStr(), unsolicited.toStr())
class ReprTests(unittest.TestCase):
    """
    Tests for the C{__repr__} implementation of record classes.
    """

    def test_ns(self):
        """
        The repr of a L{dns.Record_NS} shows the nameserver name and the
        record's TTL.
        """
        record = dns.Record_NS(b"example.com", 4321)
        self.assertEqual(repr(record), "<NS name=example.com ttl=4321>")

    def test_md(self):
        """
        The repr of a L{dns.Record_MD} shows the mail destination name and
        the record's TTL.
        """
        record = dns.Record_MD(b"example.com", 4321)
        self.assertEqual(repr(record), "<MD name=example.com ttl=4321>")

    def test_mf(self):
        """
        The repr of a L{dns.Record_MF} shows the mail forwarder name and the
        record's TTL.
        """
        record = dns.Record_MF(b"example.com", 4321)
        self.assertEqual(repr(record), "<MF name=example.com ttl=4321>")

    def test_cname(self):
        """
        The repr of a L{dns.Record_CNAME} shows the name held by the record
        and the record's TTL.
        """
        record = dns.Record_CNAME(b"example.com", 4321)
        self.assertEqual(repr(record), "<CNAME name=example.com ttl=4321>")

    def test_mb(self):
        """
        The repr of a L{dns.Record_MB} shows the mailbox name and the
        record's TTL.
        """
        record = dns.Record_MB(b"example.com", 4321)
        self.assertEqual(repr(record), "<MB name=example.com ttl=4321>")

    def test_mg(self):
        """
        The repr of a L{dns.Record_MG} shows the mail group member name and
        the record's TTL.
        """
        record = dns.Record_MG(b"example.com", 4321)
        self.assertEqual(repr(record), "<MG name=example.com ttl=4321>")

    def test_mr(self):
        """
        The repr of a L{dns.Record_MR} shows the mail rename domain name and
        the record's TTL.
        """
        record = dns.Record_MR(b"example.com", 4321)
        self.assertEqual(repr(record), "<MR name=example.com ttl=4321>")

    def test_ptr(self):
        """
        The repr of a L{dns.Record_PTR} shows the pointer name and the
        record's TTL.
        """
        record = dns.Record_PTR(b"example.com", 4321)
        self.assertEqual(repr(record), "<PTR name=example.com ttl=4321>")

    def test_dname(self):
        """
        The repr of a L{dns.Record_DNAME} shows the name of the non-terminal
        DNS name redirection and the record's TTL.
        """
        record = dns.Record_DNAME(b"example.com", 4321)
        self.assertEqual(repr(record), "<DNAME name=example.com ttl=4321>")

    def test_a(self):
        """
        The repr of a L{dns.Record_A} shows the dotted-quad form of its
        address and the record's TTL.
        """
        record = dns.Record_A("1.2.3.4", 567)
        self.assertEqual(repr(record), "<A address=1.2.3.4 ttl=567>")

    def test_soa(self):
        """
        The repr of a L{dns.Record_SOA} shows every one of the authority
        fields.
        """
        record = dns.Record_SOA(
            mname=b"mName",
            rname=b"rName",
            serial=123,
            refresh=456,
            retry=789,
            expire=10,
            minimum=11,
            ttl=12,
        )
        expected = (
            "<SOA mname=mName rname=rName serial=123 refresh=456 "
            "retry=789 expire=10 minimum=11 ttl=12>"
        )
        self.assertEqual(repr(record), expected)

    def test_null(self):
        """
        The repr of a L{dns.Record_NULL} shows the repr of its payload and
        the record's TTL.
        """
        record = dns.Record_NULL(b"abcd", 123)
        self.assertEqual(repr(record), "<NULL payload='abcd' ttl=123>")

    def test_wks(self):
        """
        The repr of a L{dns.Record_WKS} shows the dotted-quad form of its
        address, its IP protocol number and the record's TTL.
        """
        record = dns.Record_WKS("2.3.4.5", 7, ttl=8)
        self.assertEqual(repr(record), "<WKS address=2.3.4.5 protocol=7 ttl=8>")

    def test_aaaa(self):
        """
        The repr of a L{dns.Record_AAAA} shows the colon-separated hex form
        of its address and the record's TTL.
        """
        record = dns.Record_AAAA("8765::1234", ttl=10)
        self.assertEqual(repr(record), "<AAAA address=8765::1234 ttl=10>")

    def test_a6(self):
        """
        The repr of a L{dns.Record_A6} shows the colon-separated hex form of
        its suffix, its prefix name and the record's TTL.
        """
        record = dns.Record_A6(0, "1234::5678", b"foo.bar", ttl=10)
        self.assertEqual(
            repr(record), "<A6 suffix=1234::5678 prefix=foo.bar ttl=10>"
        )

    def test_srv(self):
        """
        The repr of a L{dns.Record_SRV} shows the target name and port along
        with the priority, weight and TTL of the record.
        """
        record = dns.Record_SRV(1, 2, 3, b"example.org", 4)
        self.assertEqual(
            repr(record),
            "<SRV priority=1 weight=2 target=example.org port=3 ttl=4>",
        )

    def test_naptr(self):
        """
        The repr of a L{dns.Record_NAPTR} shows the order, preference,
        flags, service, regular expression, replacement and TTL of the
        record.
        """
        expected = (
            "<NAPTR order=5 preference=9 flags=S service=http "
            "regexp=/foo/bar/i replacement=baz ttl=3>"
        )
        self.assertEqual(
            repr(dns.Record_NAPTR(5, 9, b"S", b"http", b"/foo/bar/i", b"baz", 3)),
            expected,
        )

    def test_afsdb(self):
        """
        The repr of a L{dns.Record_AFSDB} shows the subtype, hostname and
        TTL of the record.
        """
        record = dns.Record_AFSDB(3, b"example.org", 5)
        self.assertEqual(
            repr(record), "<AFSDB subtype=3 hostname=example.org ttl=5>"
        )

    def test_rp(self):
        """
        The repr of a L{dns.Record_RP} shows the mbox, txt and TTL fields of
        the record.
        """
        record = dns.Record_RP(b"alice.example.com", b"admin.example.com", 3)
        self.assertEqual(
            repr(record),
            "<RP mbox=alice.example.com txt=admin.example.com ttl=3>",
        )

    def test_hinfo(self):
        """
        The repr of a L{dns.Record_HINFO} shows the cpu, os and TTL fields
        of the record.
        """
        record = dns.Record_HINFO(b"sparc", b"minix", 12)
        self.assertEqual(repr(record), "<HINFO cpu='sparc' os='minix' ttl=12>")

    def test_minfo(self):
        """
        The repr of a L{dns.Record_MINFO} shows the rmailbx, emailbx and TTL
        fields of the record.
        """
        expected = (
            "<MINFO responsibility=alice.example.com errors=bob.example.com ttl=15>"
        )
        self.assertEqual(
            repr(dns.Record_MINFO(b"alice.example.com", b"bob.example.com", 15)),
            expected,
        )

    def test_mx(self):
        """
        The repr of a L{dns.Record_MX} shows the preference, name and TTL
        fields of the record.
        """
        record = dns.Record_MX(13, b"mx.example.com", 2)
        self.assertEqual(
            repr(record), "<MX preference=13 name=mx.example.com ttl=2>"
        )

    def test_txt(self):
        """
        The repr of a L{dns.Record_TXT} shows the data and ttl fields of the
        record.
        """
        record = dns.Record_TXT(b"foo", b"bar", ttl=15)
        self.assertEqual(repr(record), "<TXT data=['foo', 'bar'] ttl=15>")

    def test_spf(self):
        """
        The repr of a L{dns.Record_SPF} shows the data and ttl fields of the
        record.
        """
        record = dns.Record_SPF(b"foo", b"bar", ttl=15)
        self.assertEqual(repr(record), "<SPF data=['foo', 'bar'] ttl=15>")

    def test_unknown(self):
        """
        The repr of a L{dns.UnknownRecord} shows the data and ttl fields of
        the record.
        """
        record = dns.UnknownRecord(b"foo\x1fbar", 12)
        self.assertEqual(repr(record), "<UNKNOWN data='foo\\x1fbar' ttl=12>")
class EqualityTests(ComparisonTestsMixin, unittest.TestCase):
    """
    Tests for the equality and non-equality behavior of record classes.
    """

    def _equalityTest(self, firstValueOne, secondValueOne, valueTwo):
        """
        Delegate to C{assertNormalEqualityImplementation}.

        @param firstValueOne: A value expected to equal C{secondValueOne}.
        @param secondValueOne: Another, separately constructed value expected
            to equal C{firstValueOne}.
        @param valueTwo: A value expected to differ from the other two.
        """
        return self.assertNormalEqualityImplementation(
            firstValueOne, secondValueOne, valueTwo
        )

    def _assertFieldsCompared(self, makeReference, differing):
        """
        Run L{_equalityTest} once for each value in C{differing}, comparing
        it against two freshly built reference values.

        @param makeReference: A no-argument callable returning a new
            reference value each time it is called.
        @param differing: An iterable of values, each of which should differ
            from the reference value in exactly one field so that a field
            ignored by the comparison is detected.
        """
        for different in differing:
            self._equalityTest(makeReference(), makeReference(), different)

    def _simpleEqualityTest(self, cls):
        """
        Assert that instances of C{cls} with the same attributes compare equal
        to each other and instances with different attributes compare as not
        equal.

        @param cls: A L{dns.SimpleRecord} subclass.
        """
        self._assertFieldsCompared(
            lambda: cls(b"example.com", 123),
            [
                cls(b"example.com", 321),  # different TTL
                cls(b"example.org", 123),  # different name
            ],
        )

    def _textRecordEqualityTest(self, cls):
        """
        Assert that instances of C{cls} compare equal if and only if they
        have the same data strings and ttl.

        @param cls: A record class constructed from any number of data
            strings plus a C{ttl} keyword (L{dns.Record_TXT} or
            L{dns.Record_SPF}).
        """
        self._assertFieldsCompared(
            lambda: cls("foo", "bar", ttl=10),
            [
                cls("foo", "bar", "baz", ttl=10),  # different data length
                cls("bar", "foo", ttl=10),  # different data value
                cls("foo", "bar", ttl=100),  # different ttl
            ],
        )

    def test_charstr(self):
        """
        Two L{dns.Charstr} instances compare equal if and only if they have the
        same string value.
        """
        self._equalityTest(
            dns.Charstr(b"abc"), dns.Charstr(b"abc"), dns.Charstr(b"def")
        )

    def test_name(self):
        """
        Two L{dns.Name} instances compare equal if and only if they have the
        same name value.
        """
        self._equalityTest(dns.Name(b"abc"), dns.Name(b"abc"), dns.Name(b"def"))

    def test_rrheader(self):
        """
        Two L{dns.RRHeader} instances compare equal if and only if they have
        the same name, type, class, time to live, payload, and authoritative
        bit.
        """

        def header(name=b"example.com", address="1.2.3.4", **kwargs):
            # A header carrying a freshly built A record payload.
            return dns.RRHeader(name, payload=dns.Record_A(address), **kwargs)

        # Vary the name
        self._equalityTest(header(), header(), header(name=b"example.org"))
        # Vary the payload
        self._equalityTest(header(), header(), header(address="1.2.3.5"))
        # Vary the type.  Leave the payload as None so that we don't have to
        # provide non-equal values.
        self._equalityTest(
            dns.RRHeader(b"example.com", dns.A),
            dns.RRHeader(b"example.com", dns.A),
            dns.RRHeader(b"example.com", dns.MX),
        )
        # Vary the class.  Probably not likely to come up.  Most people use
        # the internet.
        self._equalityTest(
            header(cls=dns.IN), header(cls=dns.IN), header(cls=dns.CS)
        )
        # Vary the ttl
        self._equalityTest(header(ttl=60), header(ttl=60), header(ttl=120))
        # Vary the auth bit
        self._equalityTest(header(auth=1), header(auth=1), header(auth=0))

    def test_ns(self):
        """
        Two L{dns.Record_NS} instances compare equal if and only if they have
        the same name and TTL.
        """
        self._simpleEqualityTest(dns.Record_NS)

    def test_md(self):
        """
        Two L{dns.Record_MD} instances compare equal if and only if they have
        the same name and TTL.
        """
        self._simpleEqualityTest(dns.Record_MD)

    def test_mf(self):
        """
        Two L{dns.Record_MF} instances compare equal if and only if they have
        the same name and TTL.
        """
        self._simpleEqualityTest(dns.Record_MF)

    def test_cname(self):
        """
        Two L{dns.Record_CNAME} instances compare equal if and only if they
        have the same name and TTL.
        """
        self._simpleEqualityTest(dns.Record_CNAME)

    def test_mb(self):
        """
        Two L{dns.Record_MB} instances compare equal if and only if they have
        the same name and TTL.
        """
        self._simpleEqualityTest(dns.Record_MB)

    def test_mg(self):
        """
        Two L{dns.Record_MG} instances compare equal if and only if they have
        the same name and TTL.
        """
        self._simpleEqualityTest(dns.Record_MG)

    def test_mr(self):
        """
        Two L{dns.Record_MR} instances compare equal if and only if they have
        the same name and TTL.
        """
        self._simpleEqualityTest(dns.Record_MR)

    def test_ptr(self):
        """
        Two L{dns.Record_PTR} instances compare equal if and only if they have
        the same name and TTL.
        """
        self._simpleEqualityTest(dns.Record_PTR)

    def test_dname(self):
        """
        Two L{dns.Record_DNAME} instances compare equal if and only if they
        have the same name and TTL.
        """
        self._simpleEqualityTest(dns.Record_DNAME)

    def test_a(self):
        """
        Two L{dns.Record_A} instances compare equal if and only if they have
        the same address and TTL.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_A("1.2.3.4", 5),
            [
                dns.Record_A("1.2.3.4", 6),  # different TTL
                dns.Record_A("1.2.3.5", 5),  # different address
            ],
        )

    def test_soa(self):
        """
        Two L{dns.Record_SOA} instances compare equal if and only if they have
        the same mname, rname, serial, refresh, minimum, expire, retry, and
        ttl.
        """
        # Positional order: mname, rname, serial, refresh, retry, expire,
        # minimum, ttl.  Each variant below changes exactly one of them.
        self._assertFieldsCompared(
            lambda: dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 30),
            [
                # Vary the mname
                dns.Record_SOA(b"xname", b"rname", 123, 456, 789, 10, 20, 30),
                # Vary the rname
                dns.Record_SOA(b"mname", b"xname", 123, 456, 789, 10, 20, 30),
                # Vary the serial
                dns.Record_SOA(b"mname", b"rname", 1, 456, 789, 10, 20, 30),
                # Vary the refresh
                dns.Record_SOA(b"mname", b"rname", 123, 1, 789, 10, 20, 30),
                # Vary the retry
                dns.Record_SOA(b"mname", b"rname", 123, 456, 1, 10, 20, 30),
                # Vary the expire
                dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 1, 20, 30),
                # Vary the minimum
                dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 1, 30),
                # Vary the ttl (and nothing else, so that a comparison which
                # ignored the ttl would be caught)
                dns.Record_SOA(b"mname", b"rname", 123, 456, 789, 10, 20, 1),
            ],
        )

    def test_null(self):
        """
        Two L{dns.Record_NULL} instances compare equal if and only if they have
        the same payload and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_NULL("foo bar", 10),
            [
                dns.Record_NULL("bar foo", 10),  # different payload
                dns.Record_NULL("foo bar", 100),  # different ttl
            ],
        )

    def test_wks(self):
        """
        Two L{dns.Record_WKS} instances compare equal if and only if they have
        the same address, protocol, map, and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_WKS("1.2.3.4", 1, "foo", 2),
            [
                dns.Record_WKS("4.3.2.1", 1, "foo", 2),  # different address
                dns.Record_WKS("1.2.3.4", 100, "foo", 2),  # different protocol
                dns.Record_WKS("1.2.3.4", 1, "bar", 2),  # different map
                dns.Record_WKS("1.2.3.4", 1, "foo", 200),  # different ttl
            ],
        )

    def test_aaaa(self):
        """
        Two L{dns.Record_AAAA} instances compare equal if and only if they have
        the same address and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_AAAA("1::2", 1),
            [
                dns.Record_AAAA("2::1", 1),  # different address
                dns.Record_AAAA("1::2", 10),  # different ttl
            ],
        )

    def test_a6(self):
        """
        Two L{dns.Record_A6} instances compare equal if and only if they have
        the same prefix, prefix length, suffix, and ttl.
        """
        # Note, A6 is crazy, I'm not sure these values are actually legal.
        # Hopefully that doesn't matter for this test. -exarkun
        self._assertFieldsCompared(
            lambda: dns.Record_A6(16, "::abcd", b"example.com", 10),
            [
                # Vary the prefix length
                dns.Record_A6(32, "::abcd", b"example.com", 10),
                # Vary the suffix
                dns.Record_A6(16, "::abcd:0", b"example.com", 10),
                # Vary the prefix
                dns.Record_A6(16, "::abcd", b"example.org", 10),
                # Vary the ttl
                dns.Record_A6(16, "::abcd", b"example.com", 100),
            ],
        )

    def test_srv(self):
        """
        Two L{dns.Record_SRV} instances compare equal if and only if they have
        the same priority, weight, port, target, and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_SRV(10, 20, 30, b"example.com", 40),
            [
                # Vary the priority
                dns.Record_SRV(100, 20, 30, b"example.com", 40),
                # Vary the weight
                dns.Record_SRV(10, 200, 30, b"example.com", 40),
                # Vary the port
                dns.Record_SRV(10, 20, 300, b"example.com", 40),
                # Vary the target
                dns.Record_SRV(10, 20, 30, b"example.org", 40),
                # Vary the ttl
                dns.Record_SRV(10, 20, 30, b"example.com", 400),
            ],
        )

    def test_sshfp(self):
        """
        Two L{dns.Record_SSHFP} instances compare equal if and only if
        they have the same key type, fingerprint type, fingerprint, and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_SSHFP(1, 2, b"happyday", 40),
            [
                # Vary the key type.
                dns.Record_SSHFP(2, 2, b"happyday", 40),
                # Vary the fingerprint type.
                dns.Record_SSHFP(1, 1, b"happyday", 40),
                # Vary the fingerprint itself.
                dns.Record_SSHFP(1, 2, b"happxday", 40),
                # Vary the ttl.
                dns.Record_SSHFP(1, 2, b"happyday", 45),
            ],
        )

    def test_naptr(self):
        """
        Two L{dns.Record_NAPTR} instances compare equal if and only if they
        have the same order, preference, flags, service, regexp, replacement,
        and ttl.
        """
        # Each variant changes exactly one field relative to the reference
        # record, so a field ignored by the comparison cannot hide behind a
        # second difference.
        self._assertFieldsCompared(
            lambda: dns.Record_NAPTR(
                1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12
            ),
            [
                # Vary the order
                dns.Record_NAPTR(2, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12),
                # Vary the preference
                dns.Record_NAPTR(1, 3, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 12),
                # Vary the flags
                dns.Record_NAPTR(1, 2, b"p", b"sip+E2U", b"/foo/bar/", b"baz", 12),
                # Vary the service
                dns.Record_NAPTR(1, 2, b"u", b"http", b"/foo/bar/", b"baz", 12),
                # Vary the regexp
                dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/bar/foo/", b"baz", 12),
                # Vary the replacement
                dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"quux", 12),
                # Vary the ttl
                dns.Record_NAPTR(1, 2, b"u", b"sip+E2U", b"/foo/bar/", b"baz", 5),
            ],
        )

    def test_afsdb(self):
        """
        Two L{dns.Record_AFSDB} instances compare equal if and only if they
        have the same subtype, hostname, and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_AFSDB(1, b"example.com", 2),
            [
                dns.Record_AFSDB(2, b"example.com", 2),  # different subtype
                dns.Record_AFSDB(1, b"example.org", 2),  # different hostname
                dns.Record_AFSDB(1, b"example.com", 3),  # different ttl
            ],
        )

    def test_rp(self):
        """
        Two L{dns.Record_RP} instances compare equal if and only if they have
        the same mbox, txt, and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_RP(b"alice.example.com", b"alice is nice", 10),
            [
                # Vary the mbox
                dns.Record_RP(b"bob.example.com", b"alice is nice", 10),
                # Vary the txt
                dns.Record_RP(b"alice.example.com", b"alice is not nice", 10),
                # Vary the ttl
                dns.Record_RP(b"alice.example.com", b"alice is nice", 100),
            ],
        )

    def test_hinfo(self):
        """
        Two L{dns.Record_HINFO} instances compare equal if and only if they
        have the same cpu, os, and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_HINFO("x86-64", "plan9", 10),
            [
                dns.Record_HINFO("i386", "plan9", 10),  # different cpu
                dns.Record_HINFO("x86-64", "plan11", 10),  # different os
                dns.Record_HINFO("x86-64", "plan9", 100),  # different ttl
            ],
        )

    def test_minfo(self):
        """
        Two L{dns.Record_MINFO} instances compare equal if and only if they
        have the same rmailbx, emailbx, and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_MINFO(b"rmailbox", b"emailbox", 10),
            [
                # Vary the rmailbx
                dns.Record_MINFO(b"someplace", b"emailbox", 10),
                # Vary the emailbx
                dns.Record_MINFO(b"rmailbox", b"something", 10),
                # Vary the ttl
                dns.Record_MINFO(b"rmailbox", b"emailbox", 100),
            ],
        )

    def test_mx(self):
        """
        Two L{dns.Record_MX} instances compare equal if and only if they have
        the same preference, name, and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.Record_MX(10, b"example.org", 20),
            [
                dns.Record_MX(100, b"example.org", 20),  # different preference
                dns.Record_MX(10, b"example.net", 20),  # different name
                dns.Record_MX(10, b"example.org", 200),  # different ttl
            ],
        )

    def test_txt(self):
        """
        Two L{dns.Record_TXT} instances compare equal if and only if they have
        the same data and ttl.
        """
        self._textRecordEqualityTest(dns.Record_TXT)

    def test_spf(self):
        """
        L{dns.Record_SPF} instances compare equal if and only if they have the
        same data and ttl.
        """
        self._textRecordEqualityTest(dns.Record_SPF)

    def test_tsig(self):
        """
        L{dns.Record_TSIG} instances compare equal if and only if they have the
        same RDATA (algorithm, timestamp, MAC, etc.) and ttl.
        """
        baseargs = {
            "algorithm": "hmac-sha224",
            "timeSigned": 1515548975,
            "fudge": 5,
            "MAC": b"\x01\x02\x03\x04\x05",
            "originalID": 99,
            "error": dns.OK,
            "otherData": b"",
            "ttl": 40,
        }
        altargs = {
            "algorithm": "hmac-sha512",
            "timeSigned": 1515548875,
            "fudge": 0,
            "MAC": b"\x05\x04\x03\x02\x01",
            "originalID": 65437,
            "error": dns.EBADTIME,
            "otherData": b"\x00\x00",
            "ttl": 400,
        }
        # Replace one keyword at a time so that each field is checked in
        # isolation.
        for kw in baseargs:
            altered = baseargs.copy()
            altered[kw] = altargs[kw]
            self._equalityTest(
                dns.Record_TSIG(**altered),
                dns.Record_TSIG(**altered),
                dns.Record_TSIG(**baseargs),
            )

    def test_unknown(self):
        """
        L{dns.UnknownRecord} instances compare equal if and only if they have
        the same data and ttl.
        """
        self._assertFieldsCompared(
            lambda: dns.UnknownRecord("foo", ttl=10),
            [
                dns.UnknownRecord("foobar", ttl=10),  # different data length
                dns.UnknownRecord("bar", ttl=10),  # different data value
                dns.UnknownRecord("foo", ttl=100),  # different ttl
            ],
        )
class RRHeaderTests(unittest.TestCase):
    """
    Tests for L{twisted.names.dns.RRHeader}.
    """

    def test_negativeTTL(self):
        """
        Constructing a L{dns.RRHeader} with a negative TTL raises
        L{ValueError}.
        """
        payload = dns.Record_A("127.0.0.1")
        self.assertRaises(
            ValueError, dns.RRHeader, "example.com", dns.A, dns.IN, -1, payload
        )

    def test_nonIntegralTTL(self):
        """
        A floating point TTL passed to L{dns.RRHeader} is stored as the
        result of passing it to L{int}.
        """
        ttlAsFloat = 123.45
        payload = dns.Record_A("127.0.0.1")
        header = dns.RRHeader("example.com", dns.A, dns.IN, ttlAsFloat, payload)
        self.assertEqual(header.ttl, int(ttlAsFloat))

    def test_nonNumericTTLRaisesTypeError(self):
        """
        Constructing a L{dns.RRHeader} with a TTL string which is not a
        number raises L{ValueError} (despite the name of this test).
        """
        payload = dns.Record_A("127.0.0.1")
        notANumber = "this is not a number"
        self.assertRaises(
            ValueError,
            dns.RRHeader,
            "example.com",
            dns.A,
            dns.IN,
            notANumber,
            payload,
        )
class NameToLabelsTests(unittest.SynchronousTestCase):
    """
    Tests for L{twisted.names.dns._nameToLabels}.
    """

    def test_empty(self):
        """
        An empty name is split into a list holding one empty label.
        """
        labels = dns._nameToLabels(b"")
        self.assertEqual(labels, [b""])

    def test_onlyDot(self):
        """
        A name consisting of just a dot is split into a list holding one
        empty label.
        """
        labels = dns._nameToLabels(b".")
        self.assertEqual(labels, [b""])

    def test_withoutTrailingDot(self):
        """
        A name without a trailing dot is split into a list which ends with
        an empty label.
        """
        labels = dns._nameToLabels(b"com")
        self.assertEqual(labels, [b"com", b""])

    def test_withTrailingDot(self):
        """
        A name with a trailing dot is split into a list which ends with an
        empty label.
        """
        labels = dns._nameToLabels(b"com.")
        self.assertEqual(labels, [b"com", b""])

    def test_subdomain(self):
        """
        Every label of a subdomain name gets its own entry in the list
        returned by L{dns._nameToLabels}.
        """
        expected = [b"foo", b"bar", b"baz", b"example", b"com", b""]
        labels = dns._nameToLabels(b"foo.bar.baz.example.com.")
        self.assertEqual(labels, expected)

    def test_casePreservation(self):
        """
        The case of ascii characters in labels is left untouched by
        L{dns._nameToLabels}.
        """
        labels = dns._nameToLabels(b"EXAMPLE.COM")
        self.assertEqual(labels, [b"EXAMPLE", b"COM", b""])
def assertIsSubdomainOf(testCase, descendant, ancestor):
    """
    Fail C{testCase} unless L{dns._isSubdomainOf} reports that
    C{descendant} *is* a subdomain of C{ancestor}.

    @type testCase: L{unittest.SynchronousTestCase}
    @param testCase: The test case whose C{assertTrue} is used.

    @type descendant: L{bytes}
    @param descendant: The candidate subdomain name.

    @type ancestor: L{bytes}
    @param ancestor: The candidate superdomain name.
    """
    isSubdomain = dns._isSubdomainOf(descendant, ancestor)
    failureMessage = f"{descendant!r} is not a subdomain of {ancestor!r}"
    testCase.assertTrue(isSubdomain, failureMessage)
def assertIsNotSubdomainOf(testCase, descendant, ancestor):
    """
    Fail C{testCase} unless L{dns._isSubdomainOf} reports that
    C{descendant} *is not* a subdomain of C{ancestor}.

    @type testCase: L{unittest.SynchronousTestCase}
    @param testCase: The test case whose C{assertFalse} is used.

    @type descendant: L{bytes}
    @param descendant: The candidate subdomain name.

    @type ancestor: L{bytes}
    @param ancestor: The candidate superdomain name.
    """
    isSubdomain = dns._isSubdomainOf(descendant, ancestor)
    failureMessage = f"{descendant!r} is a subdomain of {ancestor!r}"
    testCase.assertFalse(isSubdomain, failureMessage)
class IsSubdomainOfTests(unittest.SynchronousTestCase):
    """
    Tests for L{twisted.names.dns._isSubdomainOf}.
    """

    def test_identical(self):
        """
        A domain name is treated as a subdomain of itself.
        """
        assertIsSubdomainOf(self, descendant=b"example.com", ancestor=b"example.com")

    def test_parent(self):
        """
        An immediate descendant of a name is a subdomain of that name.
        """
        assertIsSubdomainOf(
            self, descendant=b"foo.example.com", ancestor=b"example.com"
        )

    def test_distantAncestor(self):
        """
        A distant descendant of a name is a subdomain of that name.
        """
        assertIsSubdomainOf(
            self, descendant=b"foo.bar.baz.example.com", ancestor=b"com"
        )

    def test_superdomain(self):
        """
        An ancestor of a name is not a subdomain of that name.
        """
        assertIsNotSubdomainOf(
            self, descendant=b"example.com", ancestor=b"foo.example.com"
        )

    def test_sibling(self):
        """
        A sibling of a name is not a subdomain of that name.
        """
        assertIsNotSubdomainOf(
            self, descendant=b"foo.example.com", ancestor=b"bar.example.com"
        )

    def test_unrelatedCommonSuffix(self):
        """
        Sharing a textual suffix that does not fall on a label boundary
        does not make one name a subdomain of the other.
        """
        assertIsNotSubdomainOf(
            self, descendant=b"foo.myexample.com", ancestor=b"example.com"
        )

    def test_subdomainWithTrailingDot(self):
        """
        A trailing "." on the first (descendant) name does not stop it
        being recognised as a subdomain.
        """
        assertIsSubdomainOf(
            self, descendant=b"foo.example.com.", ancestor=b"example.com"
        )

    def test_superdomainWithTrailingDot(self):
        """
        A trailing "." on the second (ancestor) name does not stop the
        first name being recognised as its subdomain.
        """
        assertIsSubdomainOf(
            self, descendant=b"foo.example.com", ancestor=b"example.com."
        )

    def test_bothWithTrailingDot(self):
        """
        The subdomain relationship is recognised when both names carry a
        trailing ".".
        """
        assertIsSubdomainOf(
            self, descendant=b"foo.example.com.", ancestor=b"example.com."
        )

    def test_emptySubdomain(self):
        """
        An empty first name is not a subdomain of a non-empty second name.
        """
        assertIsNotSubdomainOf(self, descendant=b"", ancestor=b"example.com")

    def test_emptySuperdomain(self):
        """
        A non-empty first name is a subdomain of an empty second name.
        """
        assertIsSubdomainOf(self, descendant=b"foo.example.com", ancestor=b"")

    def test_caseInsensitiveComparison(self):
        """
        Labels are compared without regard to case, whichever of the two
        names is upper-cased.
        """
        mixedCasePairs = [
            (b"foo.example.com", b"EXAMPLE.COM"),
            (b"FOO.EXAMPLE.COM", b"example.com"),
        ]
        for descendant, ancestor in mixedCasePairs:
            assertIsSubdomainOf(self, descendant, ancestor)
class OPTNonStandardAttributes:
    """
    Paired byte-string and instance representations of an
    L{dns._OPTHeader} in which every attribute has a non-default value.

    Used to check that attributes really have been read from the byte
    string while decoding.
    """

    @classmethod
    def bytes(cls, excludeName=False, excludeOptions=False):
        """
        Return L{bytes} representing an encoded OPT record.

        @param excludeName: Intended to let a test supply its own,
            non-standard name. NOTE: this flag is accepted but is not
            consulted by the implementation; the root name byte is always
            emitted.
        @type excludeName: L{bool}

        @param excludeOptions: If true, leave off the trailing RDLEN field
            so that a test can append its own RDLEN and encoded variable
            options.
        @type excludeOptions: L{bool}

        @return: L{bytes} representing the encoded OPT record described by
            L{object}.
        """
        fixedFields = (
            b"\x00"  # 0 root zone
            b"\x00\x29"  # type 41
            b"\x02\x00"  # udpPayloadsize 512
            b"\x03"  # extendedRCODE 3
            b"\x04"  # version 4
            b"\x80\x00"  # DNSSEC OK 1 + Z
        )
        if excludeOptions:
            return fixedFields
        return fixedFields + b"\x00\x00"  # RDLEN 0

    @classmethod
    def object(cls):
        """
        Return a new L{dns._OPTHeader} instance.

        @return: A L{dns._OPTHeader} whose attribute values correspond to
            the encoded record returned by L{bytes}.
        """
        nonDefaultValues = {
            "udpPayloadSize": 512,
            "extendedRCODE": 3,
            "version": 4,
            "dnssecOK": True,
        }
        return dns._OPTHeader(**nonDefaultValues)
class OPTHeaderTests(ComparisonTestsMixin, unittest.TestCase):
    """
    Tests for L{twisted.names.dns._OPTHeader}.
    """

    def test_interface(self):
        """
        L{dns._OPTHeader} implements L{dns.IEncodable}.
        """
        verifyClass(dns.IEncodable, dns._OPTHeader)

    def test_name(self):
        """
        L{dns._OPTHeader.name} is an instance attribute whose value is
        fixed as the root domain.
        """
        header = dns._OPTHeader()
        self.assertEqual(header.name, dns.Name(b""))

    def test_nameReadonly(self):
        """
        L{dns._OPTHeader.name} is readonly.
        """
        header = dns._OPTHeader()
        self.assertRaises(
            AttributeError, setattr, header, "name", dns.Name(b"example.com")
        )

    def test_type(self):
        """
        L{dns._OPTHeader.type} is an instance attribute with fixed value
        41.
        """
        header = dns._OPTHeader()
        self.assertEqual(header.type, 41)

    def test_typeReadonly(self):
        """
        L{dns._OPTHeader.type} is readonly.
        """
        header = dns._OPTHeader()
        self.assertRaises(AttributeError, setattr, header, "type", dns.A)

    def test_udpPayloadSize(self):
        """
        L{dns._OPTHeader.udpPayloadSize} defaults to 4096 as
        recommended in rfc6891 section-6.2.5.
        """
        header = dns._OPTHeader()
        self.assertEqual(header.udpPayloadSize, 4096)

    def test_udpPayloadSizeOverride(self):
        """
        L{dns._OPTHeader.udpPayloadSize} can be overridden in the
        constructor.
        """
        header = dns._OPTHeader(udpPayloadSize=512)
        self.assertEqual(header.udpPayloadSize, 512)

    def test_extendedRCODE(self):
        """
        L{dns._OPTHeader.extendedRCODE} defaults to 0.
        """
        header = dns._OPTHeader()
        self.assertEqual(header.extendedRCODE, 0)

    def test_extendedRCODEOverride(self):
        """
        L{dns._OPTHeader.extendedRCODE} can be overridden in the
        constructor.
        """
        header = dns._OPTHeader(extendedRCODE=1)
        self.assertEqual(header.extendedRCODE, 1)

    def test_version(self):
        """
        L{dns._OPTHeader.version} defaults to 0.
        """
        header = dns._OPTHeader()
        self.assertEqual(header.version, 0)

    def test_versionOverride(self):
        """
        L{dns._OPTHeader.version} can be overridden in the
        constructor.
        """
        header = dns._OPTHeader(version=1)
        self.assertEqual(header.version, 1)

    def test_dnssecOK(self):
        """
        L{dns._OPTHeader.dnssecOK} defaults to False.
        """
        header = dns._OPTHeader()
        self.assertFalse(header.dnssecOK)

    def test_dnssecOKOverride(self):
        """
        L{dns._OPTHeader.dnssecOK} can be overridden in the
        constructor.
        """
        header = dns._OPTHeader(dnssecOK=True)
        self.assertTrue(header.dnssecOK)

    def test_options(self):
        """
        L{dns._OPTHeader.options} defaults to empty list.
        """
        header = dns._OPTHeader()
        self.assertEqual(header.options, [])

    def test_optionsOverride(self):
        """
        L{dns._OPTHeader.options} can be overridden in the
        constructor.
        """
        header = dns._OPTHeader(options=[(1, 1, b"\x00")])
        self.assertEqual(header.options, [(1, 1, b"\x00")])

    def test_encode(self):
        """
        L{dns._OPTHeader.encode} packs the header fields and writes
        them to a file like object passed in as an argument.
        """
        stream = BytesIO()
        OPTNonStandardAttributes.object().encode(stream)
        self.assertEqual(stream.getvalue(), OPTNonStandardAttributes.bytes())

    def test_encodeWithOptions(self):
        """
        L{dns._OPTHeader.options} is a list of L{dns._OPTVariableOption}
        instances which are packed into the rdata area of the header.
        """
        header = OPTNonStandardAttributes.object()
        header.options = [
            dns._OPTVariableOption(1, b"foobarbaz"),
            dns._OPTVariableOption(2, b"qux"),
        ]
        stream = BytesIO()
        header.encode(stream)
        expectedRdata = (
            b"\x00\x14"  # RDLEN 20
            b"\x00\x01"  # OPTION-CODE
            b"\x00\x09"  # OPTION-LENGTH
            b"foobarbaz"  # OPTION-DATA
            b"\x00\x02"  # OPTION-CODE
            b"\x00\x03"  # OPTION-LENGTH
            b"qux"  # OPTION-DATA
        )
        self.assertEqual(
            stream.getvalue(),
            OPTNonStandardAttributes.bytes(excludeOptions=True) + expectedRdata,
        )

    def test_decode(self):
        """
        L{dns._OPTHeader.decode} unpacks the header fields from a file
        like object and populates the attributes of an existing
        L{dns._OPTHeader} instance.
        """
        decodedHeader = dns._OPTHeader()
        decodedHeader.decode(BytesIO(OPTNonStandardAttributes.bytes()))
        self.assertEqual(decodedHeader, OPTNonStandardAttributes.object())

    def test_decodeAllExpectedBytes(self):
        """
        L{dns._OPTHeader.decode} reads all the bytes of the record
        that is being decoded.
        """
        stream = BytesIO(OPTNonStandardAttributes.bytes())
        decodedHeader = dns._OPTHeader()
        decodedHeader.decode(stream)
        # The read position must be at the very end of the input.
        self.assertEqual(stream.tell(), len(stream.getvalue()))

    def test_decodeOnlyExpectedBytes(self):
        """
        L{dns._OPTHeader.decode} reads only the bytes from the current
        file position to the end of the record that is being
        decoded. Trailing bytes are not consumed.
        """
        trailing = b"xxxx"
        stream = BytesIO(OPTNonStandardAttributes.bytes() + trailing)
        decodedHeader = dns._OPTHeader()
        decodedHeader.decode(stream)
        self.assertEqual(stream.tell(), len(stream.getvalue()) - len(trailing))

    def test_decodeDiscardsName(self):
        """
        L{dns._OPTHeader.decode} discards the name which is encoded in
        the supplied bytes. The name attribute of the resulting
        L{dns._OPTHeader} instance will always be L{dns.Name(b'')}.
        """
        # OPTNonStandardAttributes.bytes() always begins with the one-byte
        # root name. Slice that byte off and put a non-root name in its
        # place; if the name were appended after the record instead, the
        # decoder would only ever be offered the root name and this test
        # could not fail.
        recordWithoutName = OPTNonStandardAttributes.bytes()[1:]
        stream = BytesIO(b"\x07example\x03com\x00" + recordWithoutName)
        header = dns._OPTHeader()
        header.decode(stream)
        self.assertEqual(header.name, dns.Name(b""))

    def test_decodeRdlengthTooShort(self):
        """
        L{dns._OPTHeader.decode} raises an exception if the supplied
        RDLEN is too short.
        """
        stream = BytesIO(
            OPTNonStandardAttributes.bytes(excludeOptions=True)
            + (
                b"\x00\x05"  # RDLEN 5 Too short - should be 6
                b"\x00\x01"  # OPTION-CODE
                b"\x00\x02"  # OPTION-LENGTH
                b"\x00\x00"  # OPTION-DATA
            )
        )
        header = dns._OPTHeader()
        self.assertRaises(EOFError, header.decode, stream)

    def test_decodeRdlengthTooLong(self):
        """
        L{dns._OPTHeader.decode} raises an exception if the supplied
        RDLEN is too long.
        """
        stream = BytesIO(
            OPTNonStandardAttributes.bytes(excludeOptions=True)
            + (
                b"\x00\x07"  # RDLEN 7 Too long - should be 6
                b"\x00\x01"  # OPTION-CODE
                b"\x00\x02"  # OPTION-LENGTH
                b"\x00\x00"  # OPTION-DATA
            )
        )
        header = dns._OPTHeader()
        self.assertRaises(EOFError, header.decode, stream)

    def test_decodeWithOptions(self):
        """
        If the OPT bytes contain variable options,
        L{dns._OPTHeader.decode} will populate a list
        L{dns._OPTHeader.options} with L{dns._OPTVariableOption}
        instances.
        """
        stream = BytesIO(
            OPTNonStandardAttributes.bytes(excludeOptions=True)
            + (
                b"\x00\x14"  # RDLEN 20
                b"\x00\x01"  # OPTION-CODE
                b"\x00\x09"  # OPTION-LENGTH
                b"foobarbaz"  # OPTION-DATA
                b"\x00\x02"  # OPTION-CODE
                b"\x00\x03"  # OPTION-LENGTH
                b"qux"  # OPTION-DATA
            )
        )
        header = dns._OPTHeader()
        header.decode(stream)
        self.assertEqual(
            header.options,
            [
                dns._OPTVariableOption(1, b"foobarbaz"),
                dns._OPTVariableOption(2, b"qux"),
            ],
        )

    def test_fromRRHeader(self):
        """
        L{_OPTHeader.fromRRHeader} accepts an L{RRHeader} instance and
        returns an L{_OPTHeader} instance whose attribute values have
        been derived from the C{cls}, C{ttl} and C{payload} attributes
        of the original header.
        """
        genericHeader = dns.RRHeader(
            b"example.com",
            type=dns.OPT,
            cls=0xFFFF,
            ttl=(0xFE << 24 | 0xFD << 16 | True << 15),
            payload=dns.UnknownRecord(b"\xff\xff\x00\x03abc"),
        )
        decodedOptHeader = dns._OPTHeader.fromRRHeader(genericHeader)
        expectedOptHeader = dns._OPTHeader(
            udpPayloadSize=0xFFFF,
            extendedRCODE=0xFE,
            version=0xFD,
            dnssecOK=True,
            options=[dns._OPTVariableOption(code=0xFFFF, data=b"abc")],
        )
        self.assertEqual(decodedOptHeader, expectedOptHeader)

    def test_repr(self):
        """
        L{dns._OPTHeader.__repr__} displays the name and type and all
        the fixed and extended header values of the OPT record.
        """
        self.assertEqual(
            repr(dns._OPTHeader()),
            "<_OPTHeader "
            "name= "
            "type=41 "
            "udpPayloadSize=4096 "
            "extendedRCODE=0 "
            "version=0 "
            "dnssecOK=False "
            "options=[]>",
        )

    def test_equalityUdpPayloadSize(self):
        """
        Two L{OPTHeader} instances compare equal if they have the same
        udpPayloadSize.
        """
        self.assertNormalEqualityImplementation(
            dns._OPTHeader(udpPayloadSize=512),
            dns._OPTHeader(udpPayloadSize=512),
            dns._OPTHeader(udpPayloadSize=4096),
        )

    def test_equalityExtendedRCODE(self):
        """
        Two L{OPTHeader} instances compare equal if they have the same
        extendedRCODE.
        """
        self.assertNormalEqualityImplementation(
            dns._OPTHeader(extendedRCODE=1),
            dns._OPTHeader(extendedRCODE=1),
            dns._OPTHeader(extendedRCODE=2),
        )

    def test_equalityVersion(self):
        """
        Two L{OPTHeader} instances compare equal if they have the same
        version.
        """
        self.assertNormalEqualityImplementation(
            dns._OPTHeader(version=1),
            dns._OPTHeader(version=1),
            dns._OPTHeader(version=2),
        )

    def test_equalityDnssecOK(self):
        """
        Two L{OPTHeader} instances compare equal if they have the same
        dnssecOK flags.
        """
        self.assertNormalEqualityImplementation(
            dns._OPTHeader(dnssecOK=True),
            dns._OPTHeader(dnssecOK=True),
            dns._OPTHeader(dnssecOK=False),
        )

    def test_equalityOptions(self):
        """
        Two L{OPTHeader} instances compare equal if they have the same
        options.
        """
        self.assertNormalEqualityImplementation(
            dns._OPTHeader(options=[dns._OPTVariableOption(1, b"x")]),
            dns._OPTHeader(options=[dns._OPTVariableOption(1, b"x")]),
            dns._OPTHeader(options=[dns._OPTVariableOption(2, b"y")]),
        )
class OPTVariableOptionTests(ComparisonTestsMixin, unittest.TestCase):
    """
    Tests for L{dns._OPTVariableOption}.
    """

    def test_interface(self):
        """
        L{dns._OPTVariableOption} implements L{dns.IEncodable}.
        """
        verifyClass(dns.IEncodable, dns._OPTVariableOption)

    def test_constructorArguments(self):
        """
        The code and data arguments passed to
        L{dns._OPTVariableOption.__init__} are saved as public instance
        attributes.
        """
        option = dns._OPTVariableOption(1, b"x")
        self.assertEqual(option.code, 1)
        self.assertEqual(option.data, b"x")

    def test_repr(self):
        """
        L{dns._OPTVariableOption.__repr__} displays the code and data
        of the option.
        """
        option = dns._OPTVariableOption(1, b"x")
        self.assertEqual(repr(option), "<_OPTVariableOption code=1 data=x>")

    def test_equality(self):
        """
        Two OPTVariableOption instances compare equal if they have the same
        code and data values.
        """
        # Differ only by code
        self.assertNormalEqualityImplementation(
            dns._OPTVariableOption(1, b"x"),
            dns._OPTVariableOption(1, b"x"),
            dns._OPTVariableOption(2, b"x"),
        )
        # Differ only by data
        self.assertNormalEqualityImplementation(
            dns._OPTVariableOption(1, b"x"),
            dns._OPTVariableOption(1, b"x"),
            dns._OPTVariableOption(1, b"y"),
        )

    def test_encode(self):
        """
        L{dns._OPTVariableOption.encode} encodes the code and data
        instance attributes to a byte string which also includes the
        data length.
        """
        option = dns._OPTVariableOption(1, b"foobar")
        stream = BytesIO()
        option.encode(stream)
        expected = (
            b"\x00\x01"  # OPTION-CODE 1
            b"\x00\x06"  # OPTION-LENGTH 6
            b"foobar"  # OPTION-DATA
        )
        self.assertEqual(stream.getvalue(), expected)

    def test_decode(self):
        """
        L{dns._OPTVariableOption.decode} reads an encoded option from a
        file like object and populates the C{code} and C{data} attributes
        of the instance it is called on.
        """
        encoded = (
            b"\x00\x01"  # OPTION-CODE 1
            b"\x00\x06"  # OPTION-LENGTH 6
            b"foobar"  # OPTION-DATA
        )
        stream = BytesIO(encoded)
        option = dns._OPTVariableOption()
        option.decode(stream)
        self.assertEqual(option.code, 1)
        self.assertEqual(option.data, b"foobar")
class RaisedArgs(Exception):
    """
    An exception for fakes to raise so that a test can inspect the
    arguments the fake was called with.
    """

    def __init__(self, args, kwargs):
        """
        Record the positional and keyword arguments on the exception.

        @param args: The positional args; exposed as C{self.args}.
        @param kwargs: The keyword args; exposed as C{self.kwargs}.
        """
        # Exception stores its positional arguments as the C{args} tuple,
        # which is exactly where the captured positional args belong.
        super().__init__(*args)
        self.kwargs = kwargs
class MessageEmpty:
    """
    Byte string and constructor arguments for an empty
    L{dns._EDNSMessage}.
    """

    @classmethod
    def bytes(cls):
        """
        Bytes which are expected when encoding an instance constructed using
        C{kwargs} and which are expected to result in an identical instance
        when decoded.

        @return: The L{bytes} of a wire encoded message.
        """
        messageId = b"\x01\x00"  # id: 256
        flags = (
            b"\x97"  # QR: 1, OPCODE: 2, AA: 1, TC: 1, RD: 1
            b"\x8f"  # RA: 1, Z, RCODE: 15
        )
        sectionCounts = (
            b"\x00\x00"  # number of queries
            b"\x00\x00"  # number of answers
            b"\x00\x00"  # number of authorities
            b"\x00\x00"  # number of additionals
        )
        return messageId + flags + sectionCounts

    @classmethod
    def kwargs(cls):
        """
        Keyword constructor arguments which are expected to result in an
        instance which returns C{bytes} when encoded.

        @return: A L{dict} of keyword arguments.
        """
        return {
            "id": 256,
            "answer": True,
            "opCode": dns.OP_STATUS,
            "auth": True,
            "trunc": True,
            "recDes": True,
            "recAv": True,
            "rCode": 15,
            "ednsVersion": None,
        }
class MessageTruncated:
    """
    An empty response message whose TC (truncation) bit is set to 1.
    """

    @classmethod
    def bytes(cls):
        """
        Bytes which are expected when encoding an instance constructed using
        C{kwargs} and which are expected to result in an identical instance
        when decoded.

        @return: The L{bytes} of a wire encoded message.
        """
        messageId = b"\x01\x00"  # ID: 256
        flags = (
            b"\x82"  # QR: 1, OPCODE: 0, AA: 0, TC: 1, RD: 0
            b"\x00"  # RA: 0, Z, RCODE: 0
        )
        sectionCounts = (
            b"\x00\x00"  # Number of queries
            b"\x00\x00"  # Number of answers
            b"\x00\x00"  # Number of authorities
            b"\x00\x00"  # Number of additionals
        )
        return messageId + flags + sectionCounts

    @classmethod
    def kwargs(cls):
        """
        Keyword constructor arguments which are expected to result in an
        instance which returns C{bytes} when encoded.

        @return: A L{dict} of keyword arguments.
        """
        return {
            "id": 256,
            "answer": 1,
            "opCode": 0,
            "auth": 0,
            "trunc": 1,
            "recDes": 0,
            "recAv": 0,
            "rCode": 0,
            "ednsVersion": None,
        }
class MessageNonAuthoritative:
    """
    A minimal message whose AA (authoritative answer) bit is clear.
    """

    @classmethod
    def bytes(cls):
        """
        Bytes which are expected when encoding an instance constructed using
        C{kwargs} and which are expected to result in an identical instance
        when decoded.

        @return: The L{bytes} of a wire encoded message.
        """
        header = (
            b"\x01\x00"  # ID 256
            b"\x00"  # QR: 0, OPCODE: 0, AA: 0, TC: 0, RD: 0
            b"\x00"  # RA: 0, Z, RCODE: 0
            b"\x00\x00"  # Query count
            b"\x00\x01"  # Answer count
            b"\x00\x00"  # Authorities count
            b"\x00\x00"  # Additionals count
        )
        answer = (
            b"\x00"  # RR NAME (root)
            b"\x00\x01"  # RR TYPE 1 (A)
            b"\x00\x01"  # RR CLASS 1 (IN)
            b"\x00\x00\x00\x00"  # RR TTL
            b"\x00\x04"  # RDLENGTH 4
            b"\x01\x02\x03\x04"  # IPv4 1.2.3.4
        )
        return header + answer

    @classmethod
    def kwargs(cls):
        """
        Keyword constructor arguments which are expected to result in an
        instance which returns C{bytes} when encoded.

        @return: A L{dict} of keyword arguments.
        """
        answerRecord = dns.RRHeader(
            b"", payload=dns.Record_A("1.2.3.4", ttl=0), auth=False
        )
        return {
            "id": 256,
            "auth": 0,
            "ednsVersion": None,
            "answers": [answerRecord],
        }
class MessageAuthoritative:
    """
    A minimal message whose AA (authoritative answer) bit is set.
    """

    @classmethod
    def bytes(cls):
        """
        Bytes which are expected when encoding an instance constructed using
        C{kwargs} and which are expected to result in an identical instance
        when decoded.

        @return: The L{bytes} of a wire encoded message.
        """
        header = (
            b"\x01\x00"  # ID: 256
            b"\x04"  # QR: 0, OPCODE: 0, AA: 1, TC: 0, RD: 0
            b"\x00"  # RA: 0, Z, RCODE: 0
            b"\x00\x00"  # Query count
            b"\x00\x01"  # Answer count
            b"\x00\x00"  # Authorities count
            b"\x00\x00"  # Additionals count
        )
        answer = (
            b"\x00"  # RR NAME (root)
            b"\x00\x01"  # RR TYPE 1 (A)
            b"\x00\x01"  # RR CLASS 1 (IN)
            b"\x00\x00\x00\x00"  # RR TTL
            b"\x00\x04"  # RDLENGTH 4
            b"\x01\x02\x03\x04"  # IPv4 1.2.3.4
        )
        return header + answer

    @classmethod
    def kwargs(cls):
        """
        Keyword constructor arguments which are expected to result in an
        instance which returns C{bytes} when encoded.

        @return: A L{dict} of keyword arguments.
        """
        answerRecord = dns.RRHeader(
            b"", payload=dns.Record_A("1.2.3.4", ttl=0), auth=True
        )
        return {
            "id": 256,
            "auth": 1,
            "ednsVersion": None,
            "answers": [answerRecord],
        }
class MessageComplete:
    """
    An example of a fully populated non-edns response message.

    Contains name compression, answers, authority, and additional records.
    """

    @classmethod
    def bytes(cls):
        """
        Bytes which are expected when encoding an instance constructed using
        C{kwargs} and which are expected to result in an identical instance
        when decoded.

        @return: The L{bytes} of a wire encoded message.
        """
        header = (
            b"\x01\x00"  # ID: 256
            b"\x95"  # QR: 1, OPCODE: 2, AA: 1, TC: 0, RD: 1
            b"\x8f"  # RA: 1, Z, RCODE: 15
            b"\x00\x01"  # Query count
            b"\x00\x01"  # Answer count
            b"\x00\x01"  # Authorities count
            b"\x00\x01"  # Additionals count
        )
        # The query begins at byte 12; later names point back at it.
        query = (
            b"\x07example\x03com\x00"  # QNAME
            b"\x00\x06"  # QTYPE 6 (SOA)
            b"\x00\x01"  # QCLASS 1 (IN)
        )
        answer = (
            b"\xc0\x0c"  # RR NAME (compression ref b12)
            b"\x00\x06"  # RR TYPE 6 (SOA)
            b"\x00\x01"  # RR CLASS 1 (IN)
            b"\xff\xff\xff\xff"  # RR TTL
            b"\x00\x27"  # RDLENGTH 39
            b"\x03ns1\xc0\x0c"  # mname (ns1.example.com, compression ref b12)
            b"\x0ahostmaster\xc0\x0c"  # rname (hostmaster.example.com)
            b"\xff\xff\xff\xfe"  # Serial
            b"\x7f\xff\xff\xfd"  # Refresh
            b"\x7f\xff\xff\xfc"  # Retry
            b"\x7f\xff\xff\xfb"  # Expire
            b"\xff\xff\xff\xfa"  # Minimum
        )
        authority = (
            b"\xc0\x0c"  # RR NAME (example.com compression ref b12)
            b"\x00\x02"  # RR TYPE 2 (NS)
            b"\x00\x01"  # RR CLASS 1 (IN)
            b"\xff\xff\xff\xff"  # RR TTL
            b"\x00\x02"  # RDLENGTH
            b"\xc0\x29"  # RDATA (ns1.example.com, compression ref b41)
        )
        additional = (
            b"\xc0\x29"  # RR NAME (ns1.example.com compression ref b41)
            b"\x00\x01"  # RR TYPE 1 (A)
            b"\x00\x01"  # RR CLASS 1 (IN)
            b"\xff\xff\xff\xff"  # RR TTL
            b"\x00\x04"  # RDLENGTH
            b"\x05\x06\x07\x08"  # RDATA 5.6.7.8
        )
        return header + query + answer + authority + additional

    @classmethod
    def kwargs(cls):
        """
        Keyword constructor arguments which are expected to result in an
        instance which returns C{bytes} when encoded.

        @return: A L{dict} of keyword arguments.
        """
        soaAnswer = dns.RRHeader(
            b"example.com",
            type=dns.SOA,
            ttl=0xFFFFFFFF,
            auth=True,
            payload=dns.Record_SOA(
                ttl=0xFFFFFFFF,
                mname=b"ns1.example.com",
                rname=b"hostmaster.example.com",
                serial=0xFFFFFFFE,
                refresh=0x7FFFFFFD,
                retry=0x7FFFFFFC,
                expire=0x7FFFFFFB,
                minimum=0xFFFFFFFA,
            ),
        )
        nsAuthority = dns.RRHeader(
            b"example.com",
            type=dns.NS,
            ttl=0xFFFFFFFF,
            auth=True,
            payload=dns.Record_NS("ns1.example.com", ttl=0xFFFFFFFF),
        )
        addressAdditional = dns.RRHeader(
            b"ns1.example.com",
            type=dns.A,
            ttl=0xFFFFFFFF,
            auth=True,
            payload=dns.Record_A("5.6.7.8", ttl=0xFFFFFFFF),
        )
        return {
            "id": 256,
            "answer": 1,
            "opCode": dns.OP_STATUS,
            "auth": 1,
            "recDes": 1,
            "recAv": 1,
            "rCode": 15,
            "ednsVersion": None,
            "queries": [dns.Query(b"example.com", dns.SOA)],
            "answers": [soaAnswer],
            "authority": [nsAuthority],
            "additional": [addressAdditional],
        }
class MessageEDNSQuery:
    """
    A minimal EDNS query message.
    """

    @classmethod
    def bytes(cls):
        """
        Bytes which are expected when encoding an instance constructed using
        C{kwargs} and which are expected to result in an identical instance
        when decoded.

        @return: The L{bytes} of a wire encoded message.
        """
        header = (
            b"\x00\x00"  # ID: 0
            b"\x00"  # QR: 0, OPCODE: 0, AA: 0, TC: 0, RD: 0
            b"\x00"  # RA: 0, Z, RCODE: 0
            b"\x00\x01"  # Queries count
            b"\x00\x00"  # Answers count
            b"\x00\x00"  # Authority count
            b"\x00\x01"  # Additionals count
        )
        query = (
            b"\x03www\x07example\x03com\x00"  # QNAME
            b"\x00\x01"  # QTYPE (A)
            b"\x00\x01"  # QCLASS (IN)
        )
        # The OPT pseudo-record travels in the additional section.
        optRecord = (
            b"\x00"  # NAME (.)
            b"\x00\x29"  # TYPE (OPT 41)
            b"\x10\x00"  # UDP Payload Size (4096)
            b"\x00"  # Extended RCODE
            b"\x03"  # EDNS version
            b"\x00\x00"  # DO: False + Z
            b"\x00\x00"  # RDLENGTH
        )
        return header + query + optRecord

    @classmethod
    def kwargs(cls):
        """
        Keyword constructor arguments which are expected to result in an
        instance which returns C{bytes} when encoded.

        @return: A L{dict} of keyword arguments.
        """
        return {
            "id": 0,
            "answer": 0,
            "opCode": dns.OP_QUERY,
            "auth": 0,
            "recDes": 0,
            "recAv": 0,
            "rCode": 0,
            "ednsVersion": 3,
            "dnssecOK": False,
            "queries": [dns.Query(b"www.example.com", dns.A)],
            "additional": [],
        }
class MessageEDNSComplete:
    """
    An example of a fully populated edns response message.

    Contains name compression, answers, authority, and additional records.
    """

    @classmethod
    def bytes(cls):
        """
        Bytes which are expected when encoding an instance constructed using
        C{kwargs} and which are expected to result in an identical instance
        when decoded.

        @return: The L{bytes} of a wire encoded message.
        """
        header = (
            b"\x01\x00"  # ID: 256
            b"\x95"  # QR: 1, OPCODE: 2, AA: 1, TC: 0, RD: 1
            b"\xbf"  # RA: 1, AD: 1, CD: 1, RCODE: 15
            b"\x00\x01"  # Query count
            b"\x00\x01"  # Answer count
            b"\x00\x01"  # Authorities count
            b"\x00\x02"  # Additionals count
        )
        # The query begins at byte 12; later names point back at it.
        query = (
            b"\x07example\x03com\x00"  # QNAME
            b"\x00\x06"  # QTYPE 6 (SOA)
            b"\x00\x01"  # QCLASS 1 (IN)
        )
        answer = (
            b"\xc0\x0c"  # RR NAME (compression ref b12)
            b"\x00\x06"  # RR TYPE 6 (SOA)
            b"\x00\x01"  # RR CLASS 1 (IN)
            b"\xff\xff\xff\xff"  # RR TTL
            b"\x00\x27"  # RDLENGTH 39
            b"\x03ns1\xc0\x0c"  # mname (ns1.example.com, compression ref b12)
            b"\x0ahostmaster\xc0\x0c"  # rname (hostmaster.example.com)
            b"\xff\xff\xff\xfe"  # Serial
            b"\x7f\xff\xff\xfd"  # Refresh
            b"\x7f\xff\xff\xfc"  # Retry
            b"\x7f\xff\xff\xfb"  # Expire
            b"\xff\xff\xff\xfa"  # Minimum
        )
        authority = (
            b"\xc0\x0c"  # RR NAME (example.com compression ref b12)
            b"\x00\x02"  # RR TYPE 2 (NS)
            b"\x00\x01"  # RR CLASS 1 (IN)
            b"\xff\xff\xff\xff"  # RR TTL
            b"\x00\x02"  # RDLENGTH
            b"\xc0\x29"  # RDATA (ns1.example.com, compression ref b41)
        )
        additional = (
            b"\xc0\x29"  # RR NAME (ns1.example.com compression ref b41)
            b"\x00\x01"  # RR TYPE 1 (A)
            b"\x00\x01"  # RR CLASS 1 (IN)
            b"\xff\xff\xff\xff"  # RR TTL
            b"\x00\x04"  # RDLENGTH
            b"\x05\x06\x07\x08"  # RDATA 5.6.7.8
        )
        # The OPT pseudo-record is the second additional record.
        optRecord = (
            b"\x00"  # NAME (.)
            b"\x00\x29"  # TYPE (OPT 41)
            b"\x04\x00"  # UDP Payload Size (1024)
            b"\x00"  # Extended RCODE
            b"\x03"  # EDNS version
            b"\x80\x00"  # DO: True + Z
            b"\x00\x00"  # RDLENGTH
        )
        return header + query + answer + authority + additional + optRecord

    @classmethod
    def kwargs(cls):
        """
        Keyword constructor arguments which are expected to result in an
        instance which returns C{bytes} when encoded.

        @return: A L{dict} of keyword arguments.
        """
        soaAnswer = dns.RRHeader(
            b"example.com",
            type=dns.SOA,
            ttl=0xFFFFFFFF,
            auth=True,
            payload=dns.Record_SOA(
                ttl=0xFFFFFFFF,
                mname=b"ns1.example.com",
                rname=b"hostmaster.example.com",
                serial=0xFFFFFFFE,
                refresh=0x7FFFFFFD,
                retry=0x7FFFFFFC,
                expire=0x7FFFFFFB,
                minimum=0xFFFFFFFA,
            ),
        )
        nsAuthority = dns.RRHeader(
            b"example.com",
            type=dns.NS,
            ttl=0xFFFFFFFF,
            auth=True,
            payload=dns.Record_NS("ns1.example.com", ttl=0xFFFFFFFF),
        )
        addressAdditional = dns.RRHeader(
            b"ns1.example.com",
            type=dns.A,
            ttl=0xFFFFFFFF,
            auth=True,
            payload=dns.Record_A("5.6.7.8", ttl=0xFFFFFFFF),
        )
        return {
            "id": 256,
            "answer": 1,
            "opCode": dns.OP_STATUS,
            "auth": 1,
            "trunc": 0,
            "recDes": 1,
            "recAv": 1,
            "rCode": 15,
            "ednsVersion": 3,
            "dnssecOK": True,
            "authenticData": True,
            "checkingDisabled": True,
            "maxSize": 1024,
            "queries": [dns.Query(b"example.com", dns.SOA)],
            "answers": [soaAnswer],
            "authority": [nsAuthority],
            "additional": [addressAdditional],
        }
class MessageEDNSExtendedRCODE:
    """
    An example of an EDNS message with an extended RCODE.
    """

    @classmethod
    def bytes(cls):
        """
        Bytes which are expected when encoding an instance constructed using
        C{kwargs} and which are expected to result in an identical instance
        when decoded.

        @return: The L{bytes} of a wire encoded message.
        """
        header = (
            b"\x00\x00"
            b"\x00"
            b"\x0c"  # RA: 0, Z, RCODE: 12
            b"\x00\x00"
            b"\x00\x00"
            b"\x00\x00"
            b"\x00\x01"  # 1 additionals
        )
        # Additional OPT record
        optRecord = (
            b"\x00"
            b"\x00\x29"
            b"\x10\x00"
            b"\xab"  # Extended RCODE: 171
            b"\x00"
            b"\x00\x00"
            b"\x00\x00"
        )
        return header + optRecord

    @classmethod
    def kwargs(cls):
        """
        Keyword constructor arguments which are expected to result in an
        instance which returns C{bytes} when encoded.

        @return: A L{dict} of keyword arguments.
        """
        return {
            "id": 0,
            "answer": False,
            "opCode": dns.OP_QUERY,
            "auth": False,
            "trunc": False,
            "recDes": False,
            "recAv": False,
            # Combined OPT extended RCODE + Message RCODE
            "rCode": 0xABC,
            "ednsVersion": 0,
            "dnssecOK": False,
            "maxSize": 4096,
            "queries": [],
            "answers": [],
            "authority": [],
            "additional": [],
        }
class MessageComparable(FancyEqMixin, FancyStrMixin):
    """
    A comparable wrapper around L{dns.Message}, so that it can be used
    with some of the L{dns._EDNSMessage} tests.
    """

    # The same attribute names drive both comparison and display.
    compareAttributes = (
        "id",
        "answer",
        "opCode",
        "auth",
        "trunc",
        "recDes",
        "recAv",
        "rCode",
        "queries",
        "answers",
        "authority",
        "additional",
    )
    showAttributes = compareAttributes

    def __init__(self, original):
        """
        @param original: The object to which attribute lookups are
            forwarded.
        """
        self.original = original

    def __getattr__(self, key):
        """
        Forward lookups of attributes not found on this wrapper to the
        wrapped object.
        """
        wrapped = self.original
        return getattr(wrapped, key)
def verifyConstructorArgument(
testCase, cls, argName, defaultVal, altVal, attrName=None
):
"""
Verify that an attribute has the expected default value and that a
corresponding argument passed to a constructor is assigned to that
attribute.
@param testCase: The L{TestCase} whose assert methods will be
called.
@type testCase: L{unittest.TestCase}
@param cls: The constructor under test.
@type cls: L{type}
@param argName: The name of the constructor argument under test.
@type argName: L{str}
@param defaultVal: The expected default value of C{attrName} /
C{argName}
@type defaultVal: L{object}
@param altVal: A value which is different from the default. Used to
test that supplied constructor arguments are actually assigned to the
correct attribute.
@type altVal: L{object}
@param attrName: The name of the attribute under test if different
from C{argName}. Defaults to C{argName}
@type attrName: L{str}
"""
if attrName is None:
attrName = argName
actual = {}
expected = {"defaultVal": defaultVal, "altVal": altVal}
o = cls()
actual["defaultVal"] = getattr(o, attrName)
o = cls(**{argName: altVal})
actual["altVal"] = getattr(o, attrName)
testCase.assertEqual(expected, actual)
class ConstructorTestsMixin:
    """
    Helper methods for checking default attribute values and their
    corresponding constructor arguments.
    """

    def _verifyConstructorArgument(self, argName, defaultVal, altVal):
        """
        A simpler interface to L{verifyConstructorArgument} for testing
        Message and _EDNSMessage constructor arguments; the test case and
        the constructor (C{self.messageFactory}) are filled in.

        @param argName: The name of the constructor argument.
        @param defaultVal: The expected default value.
        @param altVal: An alternative value which is expected to be assigned
            to a correspondingly named attribute.
        """
        verifyConstructorArgument(
            self, self.messageFactory, argName, defaultVal, altVal
        )

    def _verifyConstructorFlag(self, argName, defaultVal):
        """
        A simpler interface to L{verifyConstructorArgument} for testing
        _EDNSMessage constructor flags; the alternative value is the
        negation of the default.

        @param argName: The name of the constructor flag argument
        @param defaultVal: The expected default value of the flag
        """
        assert defaultVal in (True, False)
        flippedVal = not defaultVal
        verifyConstructorArgument(
            self, self.messageFactory, argName, defaultVal, flippedVal
        )
class CommonConstructorTestsMixin:
    """
    Tests for constructor arguments and their associated attributes that are
    common to both L{twisted.names.dns._EDNSMessage} and L{dns.Message}.

    TestCase classes that use this mixin must provide a C{messageFactory}
    method which accepts any argument supported by L{dns.Message.__init__}.

    TestCases must also mixin ConstructorTestsMixin which provides some
    custom assertions for testing constructor arguments.
    """

    def test_id(self):
        """
        The C{id} attribute defaults to C{0} and can be overridden in the
        constructor.
        """
        self._verifyConstructorArgument(argName="id", defaultVal=0, altVal=1)

    def test_answer(self):
        """
        The C{answer} flag defaults to C{False} and can be overridden in
        the constructor.
        """
        self._verifyConstructorFlag(argName="answer", defaultVal=False)

    def test_opCode(self):
        """
        The C{opCode} attribute defaults to L{dns.OP_QUERY} and can be
        overridden in the constructor.
        """
        self._verifyConstructorArgument(
            argName="opCode", defaultVal=dns.OP_QUERY, altVal=dns.OP_STATUS
        )

    def test_auth(self):
        """
        The C{auth} flag defaults to C{False} and can be overridden in the
        constructor.
        """
        self._verifyConstructorFlag(argName="auth", defaultVal=False)

    def test_trunc(self):
        """
        The C{trunc} flag defaults to C{False} and can be overridden in the
        constructor.
        """
        self._verifyConstructorFlag(argName="trunc", defaultVal=False)

    def test_recDes(self):
        """
        The C{recDes} flag defaults to C{False} and can be overridden in
        the constructor.
        """
        self._verifyConstructorFlag(argName="recDes", defaultVal=False)

    def test_recAv(self):
        """
        The C{recAv} flag defaults to C{False} and can be overridden in the
        constructor.
        """
        self._verifyConstructorFlag(argName="recAv", defaultVal=False)

    def test_rCode(self):
        """
        The C{rCode} attribute defaults to C{0} and can be overridden in
        the constructor.
        """
        self._verifyConstructorArgument(argName="rCode", defaultVal=0, altVal=123)

    def test_maxSize(self):
        """
        The C{maxSize} attribute defaults to C{512} and can be overridden
        in the constructor.
        """
        self._verifyConstructorArgument(
            argName="maxSize", defaultVal=512, altVal=1024
        )

    def test_queries(self):
        """
        The C{queries} attribute defaults to C{[]}.
        """
        message = self.messageFactory()
        self.assertEqual(message.queries, [])

    def test_answers(self):
        """
        The C{answers} attribute defaults to C{[]}.
        """
        message = self.messageFactory()
        self.assertEqual(message.answers, [])

    def test_authority(self):
        """
        The C{authority} attribute defaults to C{[]}.
        """
        message = self.messageFactory()
        self.assertEqual(message.authority, [])

    def test_additional(self):
        """
        The C{additional} attribute defaults to C{[]}.
        """
        message = self.messageFactory()
        self.assertEqual(message.additional, [])
class EDNSMessageConstructorTests(
    ConstructorTestsMixin,
    CommonConstructorTestsMixin,
    unittest.SynchronousTestCase,
):
    """
    Run the constructor argument tests shared with L{dns.Message} against
    L{twisted.names.dns._EDNSMessage}.
    """

    # The class under test; the mixins call this to build messages.
    messageFactory = dns._EDNSMessage
class MessageConstructorTests(
    ConstructorTestsMixin,
    CommonConstructorTestsMixin,
    unittest.SynchronousTestCase,
):
    """
    Run the constructor argument tests shared with L{dns._EDNSMessage}
    against L{twisted.names.dns.Message}.
    """

    # The class under test; the mixins call this to build messages.
    messageFactory = dns.Message
class EDNSMessageSpecificsTests(ConstructorTestsMixin, unittest.SynchronousTestCase):
    """
    Tests for the parts of the L{dns._EDNSMessage} API which it does not
    share with L{dns.Message}.
    """

    messageFactory = dns._EDNSMessage

    def test_ednsVersion(self):
        """
        L{dns._EDNSMessage.ednsVersion} defaults to C{0} and can be overridden
        in the constructor.
        """
        self._verifyConstructorArgument("ednsVersion", defaultVal=0, altVal=None)

    def test_dnssecOK(self):
        """
        L{dns._EDNSMessage.dnssecOK} defaults to C{False} and can be overridden
        in the constructor.
        """
        self._verifyConstructorFlag("dnssecOK", defaultVal=False)

    def test_authenticData(self):
        """
        L{dns._EDNSMessage.authenticData} defaults to C{False} and can be
        overridden in the constructor.
        """
        self._verifyConstructorFlag("authenticData", defaultVal=False)

    def test_checkingDisabled(self):
        """
        L{dns._EDNSMessage.checkingDisabled} defaults to C{False} and can be
        overridden in the constructor.
        """
        self._verifyConstructorFlag("checkingDisabled", defaultVal=False)

    def test_queriesOverride(self):
        """
        L{dns._EDNSMessage.queries} can be overridden in the constructor.
        """
        supplied = [dns.Query(b"example.com")]
        expected = [dns.Query(b"example.com")]
        msg = self.messageFactory(queries=supplied)
        self.assertEqual(msg.queries, expected)

    def test_answersOverride(self):
        """
        L{dns._EDNSMessage.answers} can be overridden in the constructor.
        """
        supplied = [dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))]
        msg = self.messageFactory(answers=supplied)
        expected = [dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))]
        self.assertEqual(msg.answers, expected)

    def test_authorityOverride(self):
        """
        L{dns._EDNSMessage.authority} can be overridden in the constructor.
        """
        supplied = [
            dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA())
        ]
        msg = self.messageFactory(authority=supplied)
        expected = [
            dns.RRHeader(b"example.com", type=dns.SOA, payload=dns.Record_SOA())
        ]
        self.assertEqual(msg.authority, expected)

    def test_additionalOverride(self):
        """
        L{dns._EDNSMessage.additional} can be overridden in the constructor.
        """
        supplied = [dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))]
        msg = self.messageFactory(additional=supplied)
        expected = [dns.RRHeader(b"example.com", payload=dns.Record_A("1.2.3.4"))]
        self.assertEqual(msg.additional, expected)

    def test_reprDefaults(self):
        """
        L{dns._EDNSMessage.__repr__} leaves out fields and sections that still
        hold their default values, except for C{id} which is always shown.
        """
        defaultMessage = self.messageFactory()
        self.assertEqual("<_EDNSMessage id=0>", repr(defaultMessage))

    def test_reprFlagsIfSet(self):
        """
        L{dns._EDNSMessage.__repr__} lists every flag which is L{True}.
        """
        m = self.messageFactory(
            answer=True,
            auth=True,
            trunc=True,
            recDes=True,
            recAv=True,
            authenticData=True,
            checkingDisabled=True,
            dnssecOK=True,
        )
        expected = (
            "<_EDNSMessage id=0 "
            "flags=answer,auth,trunc,recDes,recAv,"
            "authenticData,checkingDisabled,dnssecOK>"
        )
        self.assertEqual(expected, repr(m))

    def test_reprNonDefautFields(self):
        """
        L{dns._EDNSMessage.__repr__} shows the value of each field which
        differs from its default.
        """
        m = self.messageFactory(
            id=10,
            opCode=20,
            rCode=30,
            maxSize=40,
            ednsVersion=50,
        )
        expected = (
            "<_EDNSMessage id=10 opCode=20 rCode=30 " "maxSize=40 ednsVersion=50>"
        )
        self.assertEqual(expected, repr(m))

    def test_reprNonDefaultSections(self):
        """
        L{dns._EDNSMessage.__repr__} shows each section which differs from its
        default.
        """
        m = self.messageFactory()
        m.queries = [1, 2, 3]
        m.answers = [4, 5, 6]
        m.authority = [7, 8, 9]
        m.additional = [10, 11, 12]
        expected = (
            "<_EDNSMessage id=0 "
            "queries=[1, 2, 3] answers=[4, 5, 6] "
            "authority=[7, 8, 9] additional=[10, 11, 12]>"
        )
        self.assertEqual(expected, repr(m))

    def test_fromStrCallsMessageFactory(self):
        """
        L{dns._EDNSMessage.fromStr} instantiates
        L{dns._EDNSMessage._messageFactory} and hands the supplied bytes,
        unchanged and positionally, to the C{fromStr} method of the result.
        """

        class FakeMessageFactory:
            """
            Stand-in message factory whose C{fromStr} reports its arguments.
            """

            def fromStr(self, *args, **kwargs):
                """
                Raise L{RaisedArgs} carrying whatever arguments were received.

                @param args: positional arguments
                @param kwargs: keyword arguments
                """
                raise RaisedArgs(args, kwargs)

        dummyBytes = object()
        m = dns._EDNSMessage()
        m._messageFactory = FakeMessageFactory
        e = self.assertRaises(RaisedArgs, m.fromStr, dummyBytes)
        self.assertEqual(((dummyBytes,), {}), (e.args, e.kwargs))

    def test_fromStrCallsFromMessage(self):
        """
        L{dns._EDNSMessage.fromStr} passes the message produced by
        L{dns._EDNSMessage._messageFactory} to
        L{dns._EDNSMessage._fromMessage} as its only argument.
        """

        class FakeMessageFactory:
            """
            Stand-in message whose C{fromStr} does nothing.
            """

            def fromStr(self, bytes):
                """
                Accept and ignore the bytes to be decoded.

                @param bytes: the bytes to be decoded
                """

        fakeMessage = FakeMessageFactory()

        def fakeFactory():
            return fakeMessage

        def fakeFromMessage(*args, **kwargs):
            raise RaisedArgs(args, kwargs)

        m = dns._EDNSMessage()
        m._messageFactory = fakeFactory
        m._fromMessage = fakeFromMessage
        e = self.assertRaises(RaisedArgs, m.fromStr, b"")
        self.assertEqual(((fakeMessage,), {}), (e.args, e.kwargs))

    def test_toStrCallsToMessage(self):
        """
        L{dns._EDNSMessage.toStr} calls L{dns._EDNSMessage._toMessage} without
        arguments.
        """

        def fakeToMessage(*args, **kwargs):
            raise RaisedArgs(args, kwargs)

        m = dns._EDNSMessage()
        m._toMessage = fakeToMessage
        e = self.assertRaises(RaisedArgs, m.toStr)
        self.assertEqual(((), {}), (e.args, e.kwargs))

    def test_toStrCallsToMessageToStr(self):
        """
        L{dns._EDNSMessage.toStr} returns the result of calling C{toStr} on
        the message returned by L{dns._EDNSMessage._toMessage}.
        """
        dummyBytes = object()

        class FakeMessage:
            """
            Stand-in message with a canned C{toStr} result.
            """

            def toStr(self):
                """
                Return the sentinel object created by the test.

                @return: dummyBytes
                """
                return dummyBytes

        def fakeToMessage(*args, **kwargs):
            return FakeMessage()

        m = dns._EDNSMessage()
        m._toMessage = fakeToMessage
        self.assertEqual(dummyBytes, m.toStr())
class EDNSMessageEqualityTests(ComparisonTestsMixin, unittest.SynchronousTestCase):
    """
    Equality tests for L{dns._EDNSMessage} instances.

    They are not applicable to L{dns.Message}, which does not use
    L{twisted.python.util.FancyEqMixin}.

    Every test passes C{assertNormalEqualityImplementation} two separately
    constructed messages built from identical arguments followed by a third
    message which differs in the attribute under test.
    """

    messageFactory = dns._EDNSMessage

    def test_id(self):
        """
        L{dns._EDNSMessage} instances with the same id compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(make(id=1), make(id=1), make(id=2))

    def test_answer(self):
        """
        L{dns._EDNSMessage} instances with the same answer flag compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(answer=True), make(answer=True), make(answer=False)
        )

    def test_opCode(self):
        """
        L{dns._EDNSMessage} instances with the same opCode compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(opCode=dns.OP_STATUS),
            make(opCode=dns.OP_STATUS),
            make(opCode=dns.OP_INVERSE),
        )

    def test_auth(self):
        """
        L{dns._EDNSMessage} instances with the same auth flag compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(auth=True), make(auth=True), make(auth=False)
        )

    def test_trunc(self):
        """
        L{dns._EDNSMessage} instances with the same trunc flag compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(trunc=True), make(trunc=True), make(trunc=False)
        )

    def test_recDes(self):
        """
        L{dns._EDNSMessage} instances with the same recDes flag compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(recDes=True), make(recDes=True), make(recDes=False)
        )

    def test_recAv(self):
        """
        L{dns._EDNSMessage} instances with the same recAv flag compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(recAv=True), make(recAv=True), make(recAv=False)
        )

    def test_rCode(self):
        """
        L{dns._EDNSMessage} instances with the same rCode compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(rCode=16), make(rCode=16), make(rCode=15)
        )

    def test_ednsVersion(self):
        """
        L{dns._EDNSMessage} instances with the same ednsVersion compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(ednsVersion=1), make(ednsVersion=1), make(ednsVersion=None)
        )

    def test_dnssecOK(self):
        """
        L{dns._EDNSMessage} instances with the same dnssecOK flag compare
        equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(dnssecOK=True), make(dnssecOK=True), make(dnssecOK=False)
        )

    def test_authenticData(self):
        """
        L{dns._EDNSMessage} instances with the same authenticData flag compare
        equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(authenticData=True),
            make(authenticData=True),
            make(authenticData=False),
        )

    def test_checkingDisabled(self):
        """
        L{dns._EDNSMessage} instances with the same checkingDisabled flag
        compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(checkingDisabled=True),
            make(checkingDisabled=True),
            make(checkingDisabled=False),
        )

    def test_maxSize(self):
        """
        L{dns._EDNSMessage} instances with the same maxSize compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(maxSize=2048), make(maxSize=2048), make(maxSize=1024)
        )

    def test_queries(self):
        """
        L{dns._EDNSMessage} instances with the same queries compare equal.
        """
        make = self.messageFactory
        self.assertNormalEqualityImplementation(
            make(queries=[dns.Query(b"example.com")]),
            make(queries=[dns.Query(b"example.com")]),
            make(queries=[dns.Query(b"example.org")]),
        )

    def test_answers(self):
        """
        L{dns._EDNSMessage} instances with the same answers compare equal.
        """

        def withAnswer(name, address):
            # Build a fresh message (and fresh records) on every call so that
            # equality, not identity, is what gets compared.
            header = dns.RRHeader(name, payload=dns.Record_A(address))
            return self.messageFactory(answers=[header])

        self.assertNormalEqualityImplementation(
            withAnswer(b"example.com", "1.2.3.4"),
            withAnswer(b"example.com", "1.2.3.4"),
            withAnswer(b"example.org", "4.3.2.1"),
        )

    def test_authority(self):
        """
        L{dns._EDNSMessage} instances with the same authority records compare
        equal.
        """

        def withAuthority(name):
            # Build a fresh message (and fresh records) on every call so that
            # equality, not identity, is what gets compared.
            header = dns.RRHeader(name, type=dns.SOA, payload=dns.Record_SOA())
            return self.messageFactory(authority=[header])

        self.assertNormalEqualityImplementation(
            withAuthority(b"example.com"),
            withAuthority(b"example.com"),
            withAuthority(b"example.org"),
        )

    def test_additional(self):
        """
        L{dns._EDNSMessage} instances with the same additional records compare
        equal.
        """

        def withAdditional(name):
            # Build a fresh message (and fresh records) on every call so that
            # equality, not identity, is what gets compared.
            header = dns.RRHeader(name, payload=dns.Record_A("1.2.3.4"))
            return self.messageFactory(additional=[header])

        self.assertNormalEqualityImplementation(
            withAdditional(b"example.com"),
            withAdditional(b"example.com"),
            withAdditional(b"example.org"),
        )
class StandardEncodingTestsMixin:
    """
    Tests for the encoding and decoding of various standard (not EDNS)
    messages.

    These tests should work with both L{dns._EDNSMessage} and L{dns.Message}.

    TestCase classes that use this mixin must provide a C{messageFactory}
    method which accepts any argument supported by
    L{dns._EDNSMessage.__init__}.

    EDNS specific arguments may be discarded if not supported by the message
    class under construction.
    """

    def test_emptyMessageEncode(self):
        """
        An empty message can be encoded.
        """
        self.assertEqual(
            self.messageFactory(**MessageEmpty.kwargs()).toStr(), MessageEmpty.bytes()
        )

    def test_emptyMessageDecode(self):
        """
        An empty message byte sequence can be decoded.
        """
        m = self.messageFactory()
        m.fromStr(MessageEmpty.bytes())

        self.assertEqual(m, self.messageFactory(**MessageEmpty.kwargs()))

    def test_completeQueryEncode(self):
        """
        A fully populated query message can be encoded.
        """
        self.assertEqual(
            self.messageFactory(**MessageComplete.kwargs()).toStr(),
            MessageComplete.bytes(),
        )

    def test_completeQueryDecode(self):
        """
        A fully populated message byte string can be decoded.
        """
        m = self.messageFactory()
        # The call is made for its side effect on C{m}; it is a plain
        # statement, not a one-element tuple expression.
        m.fromStr(MessageComplete.bytes())

        self.assertEqual(m, self.messageFactory(**MessageComplete.kwargs()))

    def test_NULL(self):
        """
        A I{NULL} record with an arbitrary payload can be encoded and decoded
        as part of a message.
        """
        # Every possible byte value, so that no value is special-cased by the
        # encode / decode round trip.
        payload = b"".join([dns._ord2bytes(i) for i in range(256)])
        rec = dns.Record_NULL(payload)
        rr = dns.RRHeader(b"testname", dns.NULL, payload=rec)
        msg1 = self.messageFactory()
        msg1.answers.append(rr)
        s = msg1.toStr()

        msg2 = self.messageFactory()
        msg2.fromStr(s)

        self.assertIsInstance(msg2.answers[0].payload, dns.Record_NULL)
        self.assertEqual(msg2.answers[0].payload.payload, payload)

    def test_nonAuthoritativeMessageEncode(self):
        """
        A message built from the L{MessageNonAuthoritative} arguments encodes
        to the L{MessageNonAuthoritative} bytes (i.e. with the AA bit 0).
        """
        self.assertEqual(
            self.messageFactory(**MessageNonAuthoritative.kwargs()).toStr(),
            MessageNonAuthoritative.bytes(),
        )

    def test_nonAuthoritativeMessageDecode(self):
        """
        Decoding the L{MessageNonAuthoritative} bytes yields a message equal
        to one built from the L{MessageNonAuthoritative} arguments, so the
        message and its records are marked as not authoritative.
        """
        m = self.messageFactory()
        m.fromStr(MessageNonAuthoritative.bytes())

        self.assertEqual(m, self.messageFactory(**MessageNonAuthoritative.kwargs()))

    def test_authoritativeMessageEncode(self):
        """
        A message built from the L{MessageAuthoritative} arguments encodes to
        the L{MessageAuthoritative} bytes (i.e. with the AA bit 1).
        """
        self.assertEqual(
            self.messageFactory(**MessageAuthoritative.kwargs()).toStr(),
            MessageAuthoritative.bytes(),
        )

    def test_authoritativeMessageDecode(self):
        """
        Decoding the L{MessageAuthoritative} bytes yields a message equal to
        one built from the L{MessageAuthoritative} arguments, so the message
        and its records are marked as authoritative.
        """
        m = self.messageFactory()
        m.fromStr(MessageAuthoritative.bytes())

        self.assertEqual(m, self.messageFactory(**MessageAuthoritative.kwargs()))

    def test_truncatedMessageEncode(self):
        """
        A message built from the L{MessageTruncated} arguments encodes to the
        L{MessageTruncated} bytes (i.e. with the TR bit 1).
        """
        self.assertEqual(
            self.messageFactory(**MessageTruncated.kwargs()).toStr(),
            MessageTruncated.bytes(),
        )

    def test_truncatedMessageDecode(self):
        """
        The message instance created by decoding a truncated message is marked
        as truncated.
        """
        m = self.messageFactory()
        m.fromStr(MessageTruncated.bytes())

        self.assertEqual(m, self.messageFactory(**MessageTruncated.kwargs()))
class EDNSMessageStandardEncodingTests(
    StandardEncodingTestsMixin,
    unittest.SynchronousTestCase,
):
    """
    Run the standard (non-EDNS) message encoding and decoding tests against
    L{dns._EDNSMessage}.
    """

    # The class under test; the mixin calls this to build messages.
    messageFactory = dns._EDNSMessage
class MessageStandardEncodingTests(
    StandardEncodingTestsMixin,
    unittest.SynchronousTestCase,
):
    """
    Run the standard (non-EDNS) message encoding and decoding tests against
    L{dns.Message}.
    """

    @staticmethod
    def messageFactory(**kwargs):
        """
        Build a L{dns.Message} from the keyword arguments that would be given
        to L{dns._EDNSMessage.__init__}.

        The C{queries}, C{answers}, C{authority} and C{additional} arguments
        are not passed to L{dns.Message.__init__}; they are removed from
        C{kwargs} and assigned as attributes after construction, each
        defaulting to a new empty list.

        The EDNS specific C{ednsVersion} argument is dropped.

        @param kwargs: The keyword arguments which, once the sections and
            C{ednsVersion} have been removed, are passed to
            L{dns.Message.__init__}.

        @return: The new L{dns.Message} wrapped in L{MessageComparable}.
        """
        sectionNames = ("queries", "answers", "authority", "additional")
        sections = {name: kwargs.pop(name, []) for name in sectionNames}

        kwargs.pop("ednsVersion", None)

        message = dns.Message(**kwargs)
        for name, records in sections.items():
            setattr(message, name, records)
        return MessageComparable(message)
class EDNSMessageEDNSEncodingTests(unittest.SynchronousTestCase):
    """
    Tests for the encoding and decoding of various EDNS messages.

    These tests will not work with L{dns.Message}.
    """

    messageFactory = dns._EDNSMessage

    def test_ednsMessageDecodeStripsOptRecords(self):
        """
        After L{dns._EDNSMessage.fromStr} has decoded an EDNS query, the
        C{additional} section of the message is empty; the OPT record is not
        kept there.
        """
        decoded = self.messageFactory()
        decoded.fromStr(MessageEDNSQuery.bytes())

        self.assertEqual(decoded.additional, [])

    def test_ednsMessageDecodeMultipleOptRecords(self):
        """
        An L{dns._EDNSMessage} instance created from a byte string containing
        multiple I{OPT} records will discard all the C{OPT} records.

        C{ednsVersion} will be set to L{None}.

        @see: U{https://tools.ietf.org/html/rfc6891#section-6.1.1}
        """
        source = dns.Message()
        source.additional = [
            dns._OPTHeader(version=2),
            dns._OPTHeader(version=3),
        ]
        encoded = source.toStr()

        ednsMessage = dns._EDNSMessage()
        ednsMessage.fromStr(encoded)

        self.assertIsNone(ednsMessage.ednsVersion)

    def test_fromMessageCopiesSections(self):
        """
        L{dns._EDNSMessage._fromMessage} returns an L{dns._EDNSMessage}
        instance whose queries, answers, authority and additional lists are
        copies of (not references to) the original message lists.
        """
        standardMessage = dns.Message()
        standardMessage.fromStr(MessageEDNSQuery.bytes())

        ednsMessage = dns._EDNSMessage._fromMessage(standardMessage)

        # Identity, not equality: the lists may be equal but must not be the
        # same objects.
        duplicates = [
            attrName
            for attrName in ("queries", "answers", "authority", "additional")
            if getattr(standardMessage, attrName) is getattr(ednsMessage, attrName)
        ]

        if duplicates:
            self.fail(
                "Message and _EDNSMessage shared references to the following "
                "section lists after decoding: %s" % (duplicates,)
            )

    def test_toMessageCopiesSections(self):
        """
        L{dns._EDNSMessage.toStr} makes no in place changes to the message
        instance: C{additional} is still empty afterwards.
        """
        ednsMessage = dns._EDNSMessage(ednsVersion=1)
        ednsMessage.toStr()

        self.assertEqual(ednsMessage.additional, [])

    def test_optHeaderPosition(self):
        """
        L{dns._EDNSMessage} can decode OPT records, regardless of their
        position in the additional records section.

        "The OPT RR MAY be placed anywhere within the additional data
        section."

        @see: U{https://tools.ietf.org/html/rfc6891#section-6.1.1}
        """
        # XXX: We need an _OPTHeader.toRRHeader method. See #6779.
        # Until then, round trip the OPT header through its wire encoding to
        # obtain an equivalent RRHeader.
        buffer = BytesIO()
        dns._OPTHeader(version=1).encode(buffer)
        buffer.seek(0)
        optRRHeader = dns.RRHeader()
        optRRHeader.decode(buffer)

        m = dns.Message()
        m.additional = [optRRHeader]

        # OPT record is the only additional record.
        versions = [dns._EDNSMessage._fromMessage(m).ednsVersion]

        # OPT record is followed by another record.
        m.additional.append(dns.RRHeader(type=dns.A))
        versions.append(dns._EDNSMessage._fromMessage(m).ednsVersion)

        # OPT record is also preceded by another record.
        m.additional.insert(0, dns.RRHeader(type=dns.A))
        versions.append(dns._EDNSMessage._fromMessage(m).ednsVersion)

        self.assertEqual([1] * 3, versions)

    def test_ednsDecode(self):
        """
        A message decoded by L{dns._EDNSMessage.fromStr} from the
        L{MessageEDNSComplete} bytes equals one constructed from the
        L{MessageEDNSComplete} arguments, so its EDNS specific values
        (C{ednsVersion}, etc) are derived from the supplied OPT record.
        """
        decoded = self.messageFactory()
        decoded.fromStr(MessageEDNSComplete.bytes())

        expected = self.messageFactory(**MessageEDNSComplete.kwargs())
        self.assertEqual(decoded, expected)

    def test_ednsEncode(self):
        """
        L{dns._EDNSMessage.toStr} encodes the EDNS specific values
        (C{ednsVersion}, etc) of the message into an OPT record added to the
        additional section.
        """
        message = self.messageFactory(**MessageEDNSComplete.kwargs())
        self.assertEqual(message.toStr(), MessageEDNSComplete.bytes())

    def test_extendedRcodeEncode(self):
        """
        L{dns._EDNSMessage.toStr} encodes the extended I{RCODE} (>=16) by
        assigning the lower 4bits to the message RCODE field and the upper
        4bits to the OPT pseudo record.
        """
        message = self.messageFactory(**MessageEDNSExtendedRCODE.kwargs())
        self.assertEqual(message.toStr(), MessageEDNSExtendedRCODE.bytes())

    def test_extendedRcodeDecode(self):
        """
        A message decoded by L{dns._EDNSMessage.fromStr} from the
        L{MessageEDNSExtendedRCODE} bytes equals one constructed from the
        L{MessageEDNSExtendedRCODE} arguments, so its RCODE is derived from
        the supplied OPT record.
        """
        decoded = self.messageFactory()
        decoded.fromStr(MessageEDNSExtendedRCODE.bytes())

        expected = self.messageFactory(**MessageEDNSExtendedRCODE.kwargs())
        self.assertEqual(decoded, expected)

    def test_extendedRcodeZero(self):
        """
        Note that EXTENDED-RCODE value 0 indicates that an unextended RCODE is
        in use (values 0 through 15).

        https://tools.ietf.org/html/rfc6891#section-6.1.3
        """
        ednsMessage = self.messageFactory(rCode=15, ednsVersion=0)
        standardMessage = ednsMessage._toMessage()

        optRecord = standardMessage.additional[0]
        self.assertEqual(
            (15, 0),
            (standardMessage.rCode, optRecord.extendedRCODE),
        )
class ResponseFromMessageTests(unittest.SynchronousTestCase):
    """
    Tests for L{dns._responseFromMessage}.
    """

    def test_responseFromMessageResponseType(self):
        """
        L{dns._responseFromMessage} is a constructor function: the message it
        returns is a different object from the L{dns.Message} like instance
        it was given.
        """
        request = dns.Message()
        response = dns._responseFromMessage(
            message=request,
            responseConstructor=dns.Message,
        )
        self.assertIsNot(request, response)

    def test_responseType(self):
        """
        L{dns._responseFromMessage} returns an instance of the supplied
        C{responseConstructor}, whatever the type of the original message.
        """

        class SuppliedClass:
            id = 1
            queries = []

        expectedClass = dns.Message
        response = dns._responseFromMessage(
            responseConstructor=expectedClass, message=SuppliedClass()
        )
        self.assertIsInstance(response, expectedClass)

    def test_responseId(self):
        """
        L{dns._responseFromMessage} copies the C{id} attribute of the original
        message.
        """
        request = dns.Message(id=1234)
        response = dns._responseFromMessage(
            responseConstructor=dns.Message, message=request
        )
        self.assertEqual(1234, response.id)

    def test_responseAnswer(self):
        """
        L{dns._responseFromMessage} sets the C{answer} flag of the response to
        L{True} and leaves the flag of the original message as L{False}.
        """
        request = dns.Message()
        response = dns._responseFromMessage(
            message=request,
            responseConstructor=dns.Message,
        )
        flags = (request.answer, response.answer)
        self.assertEqual((False, True), flags)

    def test_responseQueries(self):
        """
        L{dns._responseFromMessage} copies the C{queries} attribute of the
        original message.
        """
        expectedQueries = [object(), object(), object()]
        request = dns.Message()
        # Give the request its own list so that the expectation cannot be
        # altered through it.
        request.queries = expectedQueries[:]

        response = dns._responseFromMessage(
            responseConstructor=dns.Message, message=request
        )
        self.assertEqual(expectedQueries, response.queries)

    def test_responseKwargs(self):
        """
        L{dns._responseFromMessage} accepts other C{kwargs}; here the supplied
        C{rCode} is found on the returned message.
        """
        response = dns._responseFromMessage(
            responseConstructor=dns.Message, message=dns.Message(), rCode=123
        )
        self.assertEqual(123, response.rCode)
class Foo:
    """
    A sample class used by the L{dns._compactRepr} tests.

    Like L{dns.Message} and L{dns._EDNSMessage} it has flags, fields and
    sections which can be set through its initialiser.
    """

    def __init__(
        self,
        field1=1,
        field2=2,
        alwaysShowField="AS",
        flagTrue=True,
        flagFalse=False,
        section1=None,
    ):
        """
        Store the supplied flags, fields and section as public attributes.

        C{section1} becomes a new empty list when it is L{None}.
        """
        # Fields
        self.field1, self.field2 = field1, field2
        self.alwaysShowField = alwaysShowField
        # Flags
        self.flagTrue, self.flagFalse = flagTrue, flagFalse
        # Sections; a fresh list per instance rather than a shared default.
        self.section1 = [] if section1 is None else section1

    def __repr__(self) -> str:
        """
        Build the string representation with L{dns._compactRepr}.
        """
        text = dns._compactRepr(
            self,
            alwaysShow=["alwaysShowField"],
            fieldNames=["field1", "field2", "alwaysShowField"],
            flagNames=["flagTrue", "flagFalse"],
            # section2 is never set by __init__; tests assign it afterwards.
            sectionNames=["section1", "section2"],
        )
        return cast(str, text)
class CompactReprTests(unittest.SynchronousTestCase):
    """
    Tests for L{dns._compactRepr}, exercised through C{repr} of L{Foo}.
    """

    messageFactory = Foo

    def test_defaults(self):
        """
        L{dns._compactRepr} leaves out fields and sections which hold their
        default value, but shows the always-shown field and any flag which is
        True.
        """
        expected = "<Foo alwaysShowField='AS' flags=flagTrue>"
        self.assertEqual(expected, repr(self.messageFactory()))

    def test_flagsIfSet(self):
        """
        L{dns._compactRepr} lists each flag which is set to True.
        """
        m = self.messageFactory(flagTrue=True, flagFalse=True)
        expected = "<Foo alwaysShowField='AS' flags=flagTrue,flagFalse>"
        self.assertEqual(expected, repr(m))

    def test_nonDefautFields(self):
        """
        L{dns._compactRepr} shows the value of each field which differs from
        its default.
        """
        m = self.messageFactory(field1=10, field2=20)
        expected = "<Foo field1=10 field2=20 " "alwaysShowField='AS' flags=flagTrue>"
        self.assertEqual(expected, repr(m))

    def test_nonDefaultSections(self):
        """
        L{dns._compactRepr} shows each section which differs from its default.
        """
        m = self.messageFactory()
        m.section1 = [1, 1, 1]
        m.section2 = [2, 2, 2]
        expected = (
            "<Foo alwaysShowField='AS' flags=flagTrue "
            "section1=[1, 1, 1] section2=[2, 2, 2]>"
        )
        self.assertEqual(expected, repr(m))
Zerion Mini Shell 1.0