1
0
mirror of https://github.com/Tautulli/Tautulli.git synced 2025-03-12 04:35:40 -07:00

Bump dnspython from 2.6.1 to 2.7.0 ()

* Bump dnspython from 2.6.1 to 2.7.0

Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.6.1 to 2.7.0.
- [Release notes](https://github.com/rthalley/dnspython/releases)
- [Changelog](https://github.com/rthalley/dnspython/blob/main/doc/whatsnew.rst)
- [Commits](https://github.com/rthalley/dnspython/compare/v2.6.1...v2.7.0)

---
updated-dependencies:
- dependency-name: dnspython
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update dnspython==2.7.0

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2024-11-19 10:00:50 -08:00 committed by GitHub
parent 0836fb902c
commit feca713b76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
56 changed files with 1382 additions and 665 deletions

@ -26,6 +26,10 @@ class NullContext:
class Socket: # pragma: no cover
def __init__(self, family: int, type: int):
self.family = family
self.type = type
async def close(self):
pass
@ -46,9 +50,6 @@ class Socket: # pragma: no cover
class DatagramSocket(Socket): # pragma: no cover
def __init__(self, family: int):
self.family = family
async def sendto(self, what, destination, timeout):
raise NotImplementedError

@ -42,7 +42,7 @@ class _DatagramProtocol:
if exc is None:
# EOF we triggered. Is there a better way to do this?
try:
raise EOFError
raise EOFError("EOF")
except EOFError as e:
self.recvfrom.set_exception(e)
else:
@ -64,7 +64,7 @@ async def _maybe_wait_for(awaitable, timeout):
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, family, transport, protocol):
super().__init__(family)
super().__init__(family, socket.SOCK_DGRAM)
self.transport = transport
self.protocol = protocol
@ -99,7 +99,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, af, reader, writer):
self.family = af
super().__init__(af, socket.SOCK_STREAM)
self.reader = reader
self.writer = writer
@ -197,7 +197,7 @@ if dns._features.have("doh"):
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None:
if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver

@ -32,6 +32,9 @@ def _version_check(
package, minimum = requirement.split(">=")
try:
version = importlib.metadata.version(package)
# This shouldn't happen, but it apparently can.
if version is None:
return False
except Exception:
return False
t_version = _tuple_from_text(version)
@ -82,10 +85,10 @@ def force(feature: str, enabled: bool) -> None:
_requirements: Dict[str, List[str]] = {
### BEGIN generated requirements
"dnssec": ["cryptography>=41"],
"dnssec": ["cryptography>=43"],
"doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
"doq": ["aioquic>=0.9.25"],
"idna": ["idna>=3.6"],
"doq": ["aioquic>=1.0.0"],
"idna": ["idna>=3.7"],
"trio": ["trio>=0.23"],
"wmi": ["wmi>=1.5.1"],
### END generated requirements

@ -30,13 +30,16 @@ _lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket):
def __init__(self, socket):
super().__init__(socket.family)
self.socket = socket
def __init__(self, sock):
super().__init__(sock.family, socket.SOCK_DGRAM)
self.socket = sock
async def sendto(self, what, destination, timeout):
with _maybe_timeout(timeout):
return await self.socket.sendto(what, destination)
if destination is None:
return await self.socket.send(what)
else:
return await self.socket.sendto(what, destination)
raise dns.exception.Timeout(
timeout=timeout
) # pragma: no cover lgtm[py/unreachable-statement]
@ -61,7 +64,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
class StreamSocket(dns._asyncbackend.StreamSocket):
def __init__(self, family, stream, tls=False):
self.family = family
super().__init__(family, socket.SOCK_STREAM)
self.stream = stream
self.tls = tls
@ -171,7 +174,7 @@ if dns._features.have("doh"):
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None:
if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
@ -205,7 +208,7 @@ class Backend(dns._asyncbackend.Backend):
try:
if source:
await s.bind(_lltuple(source, af))
if socktype == socket.SOCK_STREAM:
if socktype == socket.SOCK_STREAM or destination is not None:
connected = False
with _maybe_timeout(timeout):
await s.connect(_lltuple(destination, af))

@ -19,10 +19,12 @@
import base64
import contextlib
import random
import socket
import struct
import time
from typing import Any, Dict, Optional, Tuple, Union
import urllib.parse
from typing import Any, Dict, Optional, Tuple, Union, cast
import dns.asyncbackend
import dns.exception
@ -37,9 +39,11 @@ import dns.transaction
from dns._asyncbackend import NullContext
from dns.query import (
BadResponse,
HTTPVersion,
NoDOH,
NoDOQ,
UDPMode,
_check_status,
_compute_times,
_make_dot_ssl_context,
_matches_destination,
@ -338,7 +342,7 @@ async def _read_exactly(sock, count, expiration):
while count > 0:
n = await sock.recv(count, _timeout(expiration))
if n == b"":
raise EOFError
raise EOFError("EOF")
count = count - len(n)
s = s + n
return s
@ -500,6 +504,20 @@ async def tls(
return response
def _maybe_get_resolver(
resolver: Optional["dns.asyncresolver.Resolver"],
) -> "dns.asyncresolver.Resolver":
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.asyncresolver
resolver = dns.asyncresolver.Resolver()
return resolver
async def https(
q: dns.message.Message,
where: str,
@ -515,7 +533,8 @@ async def https(
verify: Union[bool, str] = True,
bootstrap_address: Optional[str] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None,
family: Optional[int] = socket.AF_UNSPEC,
family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
@ -529,26 +548,65 @@ async def https(
parameters, exceptions, and return type of this method.
"""
if not have_doh:
raise NoDOH # pragma: no cover
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")
wire = q.to_wire()
try:
af = dns.inet.af_for_address(where)
except ValueError:
af = None
transport = None
headers = {"accept": "application/dns-message"}
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path)
url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6:
url = "https://[{}]:{}{}".format(where, port, path)
url = f"https://[{where}]:{port}{path}"
else:
url = where
extensions = {}
if bootstrap_address is None:
# pylint: disable=possibly-used-before-assignment
parsed = urllib.parse.urlparse(url)
if parsed.hostname is None:
raise ValueError("no hostname in URL")
if dns.inet.is_address(parsed.hostname):
bootstrap_address = parsed.hostname
extensions["sni_hostname"] = parsed.hostname
if parsed.port is not None:
port = parsed.port
if http_version == HTTPVersion.H3 or (
http_version == HTTPVersion.DEFAULT and not have_doh
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
assert parsed.hostname is not None # for mypy
answers = await resolver.resolve_name(parsed.hostname, family)
bootstrap_address = random.choice(list(answers.addresses()))
return await _http3(
q,
bootstrap_address,
url,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
verify=verify,
post=post,
)
if not have_doh:
raise NoDOH # pragma: no cover
# pylint: disable=possibly-used-before-assignment
if client and not isinstance(client, httpx.AsyncClient):
raise ValueError("session parameter must be an httpx.AsyncClient")
# pylint: enable=possibly-used-before-assignment
wire = q.to_wire()
headers = {"accept": "application/dns-message"}
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
backend = dns.asyncbackend.get_default_backend()
if source is None:
@ -557,24 +615,23 @@ async def https(
else:
local_address = source
local_port = source_port
transport = backend.get_transport_class()(
local_address=local_address,
http1=True,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
if client:
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
cm = httpx.AsyncClient(
http1=True, http2=True, verify=verify, transport=transport
transport = backend.get_transport_class()(
local_address=local_address,
http1=h1,
http2=h2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport)
async with cm as the_client:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
@ -586,23 +643,33 @@ async def https(
}
)
response = await backend.wait_for(
the_client.post(url, headers=headers, content=wire), timeout
the_client.post(
url,
headers=headers,
content=wire,
extensions=extensions,
),
timeout,
)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
response = await backend.wait_for(
the_client.get(url, headers=headers, params={"dns": twire}), timeout
the_client.get(
url,
headers=headers,
params={"dns": twire},
extensions=extensions,
),
timeout,
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError(
"{} responded with status code {}"
"\nResponse body: {!r}".format(
where, response.status_code, response.content
)
f"{where} responded with status code {response.status_code}"
f"\nResponse body: {response.content!r}"
)
r = dns.message.from_wire(
response.content,
@ -617,6 +684,181 @@ async def https(
return r
async def _http3(
q: dns.message.Message,
where: str,
url: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
hostname: Optional[str] = None,
post: bool = True,
) -> dns.message.Message:
if not dns.quic.have_quic:
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
url_parts = urllib.parse.urlparse(url)
hostname = url_parts.hostname
if url_parts.port is not None:
port = url_parts.port
q.id = 0
wire = q.to_wire()
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=hostname, h3=True
) as the_manager:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
# note that send_h3() does not need await
stream.send_h3(url, wire, post)
wire = await stream.receive(_remaining(expiration))
_check_status(stream.headers(), where, wire)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def quic(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
hostname: Optional[str] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
if server_hostname is not None and hostname is None:
hostname = server_hostname
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context,
verify_mode=verify,
server_name=server_hostname,
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
async def _inbound_xfr(
txn_manager: dns.transaction.TransactionManager,
s: dns.asyncbackend.Socket,
query: dns.message.Message,
serial: Optional[int],
timeout: Optional[float],
expiration: float,
) -> Any:
"""Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
is_udp = s.type == socket.SOCK_DGRAM
if is_udp:
udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
await udp_sock.sendto(wire, None, _timeout(expiration))
else:
tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
tcpmsg = struct.pack("!H", len(wire)) + wire
await tcp_sock.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
timeout = _timeout(mexpiration)
(rwire, _) = await udp_sock.recvfrom(65535, timeout)
else:
ldata = await _read_exactly(tcp_sock, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(tcp_sock, l, mexpiration)
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
if query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
async def inbound_xfr(
where: str,
txn_manager: dns.transaction.TransactionManager,
@ -642,139 +884,30 @@ async def inbound_xfr(
(query, serial) = dns.xfr.make_query(txn_manager)
else:
serial = dns.xfr.extract_serial_from_query(query)
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
if not backend:
backend = dns.asyncbackend.get_default_backend()
(_, expiration) = _compute_times(lifetime)
retry = True
while retry:
retry = False
if is_ixfr and udp_mode != UDPMode.NEVER:
sock_type = socket.SOCK_DGRAM
is_udp = True
else:
sock_type = socket.SOCK_STREAM
is_udp = False
if not backend:
backend = dns.asyncbackend.get_default_backend()
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
s = await backend.make_socket(
af, sock_type, 0, stuple, dtuple, _timeout(expiration)
af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
if is_udp:
await s.sendto(wire, dtuple, _timeout(expiration))
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
await s.sendall(tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
destination = _lltuple((where, port), af)
while True:
timeout = _timeout(mexpiration)
(rwire, from_address) = await s.recvfrom(65535, timeout)
if _matches_destination(
af, from_address, destination, True
):
break
else:
ldata = await _read_exactly(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = await _read_exactly(s, l, mexpiration)
is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
assert is_udp # should not happen if we used TCP!
if udp_mode == UDPMode.ONLY:
raise
done = True
retry = True
udp_mode = UDPMode.NEVER
continue
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
try:
async for _ in _inbound_xfr(
txn_manager, s, query, serial, timeout, expiration
):
pass
return
except dns.xfr.UseTCP:
if udp_mode == UDPMode.ONLY:
raise
async def quic(
q: dns.message.Message,
where: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
DNS-over-QUIC.
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
the default, then dnspython will use the default backend.
See :py:func:`dns.query.quic()` for the documentation of the other
parameters, exceptions, and return type of this method.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
if connection:
cfactory = dns.quic.null_factory
mfactory = dns.quic.null_factory
the_connection = connection
else:
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
async with cfactory() as context:
async with mfactory(
context, verify_mode=verify, server_name=server_hostname
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
await stream.send(wire, True)
wire = await stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
s = await backend.make_socket(
af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
)
async with s:
async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
pass

@ -118,6 +118,7 @@ def key_id(key: Union[DNSKEY, CDNSKEY]) -> int:
"""
rdata = key.to_wire()
assert rdata is not None # for mypy
if key.algorithm == Algorithm.RSAMD5:
return (rdata[-3] << 8) + rdata[-2]
else:
@ -224,7 +225,7 @@ def make_ds(
if isinstance(algorithm, str):
algorithm = DSDigest[algorithm.upper()]
except Exception:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
if validating:
check = policy.ok_to_validate_ds
else:
@ -240,14 +241,15 @@ def make_ds(
elif algorithm == DSDigest.SHA384:
dshash = hashlib.sha384()
else:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
if isinstance(name, str):
name = dns.name.from_text(name, origin)
wire = name.canonicalize().to_wire()
assert wire is not None
kwire = key.to_wire(origin=origin)
assert wire is not None and kwire is not None # for mypy
dshash.update(wire)
dshash.update(key.to_wire(origin=origin))
dshash.update(kwire)
digest = dshash.digest()
dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest
@ -323,6 +325,7 @@ def _get_rrname_rdataset(
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
# pylint: disable=possibly-used-before-assignment
public_cls = get_algorithm_cls_from_dnskey(key).public_cls
try:
public_key = public_cls.from_dnskey(key)
@ -387,6 +390,7 @@ def _validate_rrsig(
data = _make_rrsig_signature_data(rrset, rrsig, origin)
# pylint: disable=possibly-used-before-assignment
for candidate_key in candidate_keys:
if not policy.ok_to_validate(candidate_key):
continue
@ -484,6 +488,7 @@ def _sign(
verify: bool = False,
policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None,
deterministic: bool = True,
) -> RRSIG:
"""Sign RRset using private key.
@ -523,6 +528,10 @@ def _sign(
names in the rrset (including its owner name) must be absolute; otherwise the
specified origin will be used to make names absolute when signing.
*deterministic*, a ``bool``. If ``True``, the default, use deterministic
(reproducible) signatures when supported by the algorithm used for signing.
Currently, this only affects ECDSA.
Raises ``DeniedByPolicy`` if the signature is denied by policy.
"""
@ -580,6 +589,7 @@ def _sign(
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
# pylint: disable=possibly-used-before-assignment
if isinstance(private_key, GenericPrivateKey):
signing_key = private_key
else:
@ -589,7 +599,7 @@ def _sign(
except UnsupportedAlgorithm:
raise TypeError("Unsupported key algorithm")
signature = signing_key.sign(data, verify)
signature = signing_key.sign(data, verify, deterministic)
return cast(RRSIG, rrsig_template.replace(signature=signature))
@ -629,7 +639,9 @@ def _make_rrsig_signature_data(
rrname, rdataset = _get_rrname_rdataset(rrset)
data = b""
data += rrsig.to_wire(origin=signer)[:18]
wire = rrsig.to_wire(origin=signer)
assert wire is not None # for mypy
data += wire[:18]
data += rrsig.signer.to_digestable(signer)
# Derelativize the name before considering labels.
@ -686,6 +698,7 @@ def _make_dnskey(
algorithm = Algorithm.make(algorithm)
# pylint: disable=possibly-used-before-assignment
if isinstance(public_key, GenericPublicKey):
return public_key.to_dnskey(flags=flags, protocol=protocol)
else:
@ -832,7 +845,7 @@ def make_ds_rdataset(
if isinstance(algorithm, str):
algorithm = DSDigest[algorithm.upper()]
except Exception:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
_algorithms.add(algorithm)
if rdataset.rdtype == dns.rdatatype.CDS:
@ -950,6 +963,7 @@ def default_rrset_signer(
lifetime: Optional[int] = None,
policy: Optional[Policy] = None,
origin: Optional[dns.name.Name] = None,
deterministic: bool = True,
) -> None:
"""Default RRset signer"""
@ -975,6 +989,7 @@ def default_rrset_signer(
signer=signer,
policy=policy,
origin=origin,
deterministic=deterministic,
)
txn.add(rrset.name, rrset.ttl, rrsig)
@ -991,6 +1006,7 @@ def sign_zone(
nsec3: Optional[NSEC3PARAM] = None,
rrset_signer: Optional[RRsetSigner] = None,
policy: Optional[Policy] = None,
deterministic: bool = True,
) -> None:
"""Sign zone.
@ -1030,6 +1046,10 @@ def sign_zone(
function requires two arguments: transaction and RRset. If the not specified,
``dns.dnssec.default_rrset_signer`` will be used.
*deterministic*, a ``bool``. If ``True``, the default, use deterministic
(reproducible) signatures when supported by the algorithm used for signing.
Currently, this only affects ECDSA.
Returns ``None``.
"""
@ -1056,6 +1076,9 @@ def sign_zone(
else:
cm = zone.writer()
if zone.origin is None:
raise ValueError("no zone origin")
with cm as _txn:
if add_dnskey:
if dnskey_ttl is None:
@ -1081,6 +1104,7 @@ def sign_zone(
lifetime=lifetime,
policy=policy,
origin=zone.origin,
deterministic=deterministic,
)
return _sign_zone_nsec(zone, _txn, _rrset_signer)

@ -26,6 +26,7 @@ AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
if _have_cryptography:
# pylint: disable=possibly-used-before-assignment
algorithms.update(
{
(Algorithm.RSAMD5, None): PrivateRSAMD5,
@ -59,7 +60,7 @@ def get_algorithm_cls(
if cls:
return cls
raise UnsupportedAlgorithm(
'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
f'algorithm "{Algorithm.to_text(algorithm)}" not supported by dnspython'
)

@ -65,7 +65,12 @@ class GenericPrivateKey(ABC):
pass
@abstractmethod
def sign(self, data: bytes, verify: bool = False) -> bytes:
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign DNSSEC data"""
@abstractmethod

@ -68,7 +68,12 @@ class PrivateDSA(CryptographyPrivateKey):
key_cls = dsa.DSAPrivateKey
public_cls = PublicDSA
def sign(self, data: bytes, verify: bool = False) -> bytes:
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 2536, section 3."""
public_dsa_key = self.key.public_key()
if public_dsa_key.key_size > 1024:

@ -47,9 +47,17 @@ class PrivateECDSA(CryptographyPrivateKey):
key_cls = ec.EllipticCurvePrivateKey
public_cls = PublicECDSA
def sign(self, data: bytes, verify: bool = False) -> bytes:
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 6605, section 4."""
der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash))
algorithm = ec.ECDSA(
self.public_cls.chosen_hash, deterministic_signing=deterministic
)
der_signature = self.key.sign(data, algorithm)
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
signature = int.to_bytes(
dsa_r, length=self.public_cls.octets, byteorder="big"

@ -29,7 +29,12 @@ class PublicEDDSA(CryptographyPublicKey):
class PrivateEDDSA(CryptographyPrivateKey):
public_cls: Type[PublicEDDSA]
def sign(self, data: bytes, verify: bool = False) -> bytes:
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 8080, section 4."""
signature = self.key.sign(data)
if verify:

@ -56,7 +56,12 @@ class PrivateRSA(CryptographyPrivateKey):
public_cls = PublicRSA
default_public_exponent = 65537
def sign(self, data: bytes, verify: bool = False) -> bytes:
def sign(
self,
data: bytes,
verify: bool = False,
deterministic: bool = True,
) -> bytes:
"""Sign using a private key per RFC 3110, section 3."""
signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
if verify:

@ -52,6 +52,8 @@ class OptionType(dns.enum.IntEnum):
CHAIN = 13
#: EDE (extended-dns-error)
EDE = 15
#: REPORTCHANNEL
REPORTCHANNEL = 18
@classmethod
def _maximum(cls):
@ -222,7 +224,7 @@ class ECSOption(Option): # lgtm[py/missing-equals]
self.addrdata = self.addrdata[:-1] + last
def to_text(self) -> str:
return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen)
return f"ECS {self.address}/{self.srclen} scope/{self.scopelen}"
@staticmethod
def from_text(text: str) -> Option:
@ -255,10 +257,10 @@ class ECSOption(Option): # lgtm[py/missing-equals]
ecs_text = tokens[0]
elif len(tokens) == 2:
if tokens[0] != optional_prefix:
raise ValueError('could not parse ECS from "{}"'.format(text))
raise ValueError(f'could not parse ECS from "{text}"')
ecs_text = tokens[1]
else:
raise ValueError('could not parse ECS from "{}"'.format(text))
raise ValueError(f'could not parse ECS from "{text}"')
n_slashes = ecs_text.count("/")
if n_slashes == 1:
address, tsrclen = ecs_text.split("/")
@ -266,18 +268,16 @@ class ECSOption(Option): # lgtm[py/missing-equals]
elif n_slashes == 2:
address, tsrclen, tscope = ecs_text.split("/")
else:
raise ValueError('could not parse ECS from "{}"'.format(text))
raise ValueError(f'could not parse ECS from "{text}"')
try:
scope = int(tscope)
except ValueError:
raise ValueError(
"invalid scope " + '"{}": scope must be an integer'.format(tscope)
)
raise ValueError("invalid scope " + f'"{tscope}": scope must be an integer')
try:
srclen = int(tsrclen)
except ValueError:
raise ValueError(
"invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen)
"invalid srclen " + f'"{tsrclen}": srclen must be an integer'
)
return ECSOption(address, srclen, scope)
@ -430,10 +430,65 @@ class NSIDOption(Option):
return cls(parser.get_remaining())
class CookieOption(Option):
def __init__(self, client: bytes, server: bytes):
super().__init__(dns.edns.OptionType.COOKIE)
self.client = client
self.server = server
if len(client) != 8:
raise ValueError("client cookie must be 8 bytes")
if len(server) != 0 and (len(server) < 8 or len(server) > 32):
raise ValueError("server cookie must be empty or between 8 and 32 bytes")
def to_wire(self, file: Any = None) -> Optional[bytes]:
if file:
file.write(self.client)
if len(self.server) > 0:
file.write(self.server)
return None
else:
return self.client + self.server
def to_text(self) -> str:
client = binascii.hexlify(self.client).decode()
if len(self.server) > 0:
server = binascii.hexlify(self.server).decode()
else:
server = ""
return f"COOKIE {client}{server}"
@classmethod
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
) -> Option:
return cls(parser.get_bytes(8), parser.get_remaining())
class ReportChannelOption(Option):
# RFC 9567
def __init__(self, agent_domain: dns.name.Name):
super().__init__(OptionType.REPORTCHANNEL)
self.agent_domain = agent_domain
def to_wire(self, file: Any = None) -> Optional[bytes]:
return self.agent_domain.to_wire(file)
def to_text(self) -> str:
return "REPORTCHANNEL " + self.agent_domain.to_text()
@classmethod
def from_wire_parser(
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
) -> Option:
return cls(parser.get_name())
_type_to_class: Dict[OptionType, Any] = {
OptionType.ECS: ECSOption,
OptionType.EDE: EDEOption,
OptionType.NSID: NSIDOption,
OptionType.COOKIE: CookieOption,
OptionType.REPORTCHANNEL: ReportChannelOption,
}
@ -512,5 +567,6 @@ KEEPALIVE = OptionType.KEEPALIVE
PADDING = OptionType.PADDING
CHAIN = OptionType.CHAIN
EDE = OptionType.EDE
REPORTCHANNEL = OptionType.REPORTCHANNEL
### END generated OptionType constants

@ -81,7 +81,7 @@ class DNSException(Exception):
if kwargs:
assert (
set(kwargs.keys()) == self.supp_kwargs
), "following set of keyword args is required: %s" % (self.supp_kwargs)
), f"following set of keyword args is required: {self.supp_kwargs}"
return kwargs
def _fmt_kwargs(self, **kwargs):

@ -54,7 +54,7 @@ def from_text(text: str) -> Tuple[int, int, int]:
elif c.isdigit():
cur += c
else:
raise dns.exception.SyntaxError("Could not parse %s" % (c))
raise dns.exception.SyntaxError(f"Could not parse {c}")
if state == 0:
raise dns.exception.SyntaxError("no stop value specified")

@ -143,9 +143,7 @@ def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
if m is not None:
b = dns.ipv4.inet_aton(m.group(2))
btext = (
"{}:{:02x}{:02x}:{:02x}{:02x}".format(
m.group(1).decode(), b[0], b[1], b[2], b[3]
)
f"{m.group(1).decode()}:{b[0]:02x}{b[1]:02x}:{b[2]:02x}{b[3]:02x}"
).encode()
#
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to

@ -18,9 +18,10 @@
"""DNS Messages"""
import contextlib
import enum
import io
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import dns.edns
import dns.entropy
@ -161,6 +162,7 @@ class Message:
self.index: IndexType = {}
self.errors: List[MessageError] = []
self.time = 0.0
self.wire: Optional[bytes] = None
@property
def question(self) -> List[dns.rrset.RRset]:
@ -220,16 +222,16 @@ class Message:
s = io.StringIO()
s.write("id %d\n" % self.id)
s.write("opcode %s\n" % dns.opcode.to_text(self.opcode()))
s.write("rcode %s\n" % dns.rcode.to_text(self.rcode()))
s.write("flags %s\n" % dns.flags.to_text(self.flags))
s.write(f"opcode {dns.opcode.to_text(self.opcode())}\n")
s.write(f"rcode {dns.rcode.to_text(self.rcode())}\n")
s.write(f"flags {dns.flags.to_text(self.flags)}\n")
if self.edns >= 0:
s.write("edns %s\n" % self.edns)
s.write(f"edns {self.edns}\n")
if self.ednsflags != 0:
s.write("eflags %s\n" % dns.flags.edns_to_text(self.ednsflags))
s.write(f"eflags {dns.flags.edns_to_text(self.ednsflags)}\n")
s.write("payload %d\n" % self.payload)
for opt in self.options:
s.write("option %s\n" % opt.to_text())
s.write(f"option {opt.to_text()}\n")
for name, which in self._section_enum.__members__.items():
s.write(f";{name}\n")
for rrset in self.section_from_number(which):
@ -645,6 +647,7 @@ class Message:
if multi:
self.tsig_ctx = ctx
wire = r.get_wire()
self.wire = wire
if prepend_length:
wire = len(wire).to_bytes(2, "big") + wire
return wire
@ -912,6 +915,14 @@ class Message:
self.flags &= 0x87FF
self.flags |= dns.opcode.to_flags(opcode)
def get_options(self, otype: dns.edns.OptionType) -> List[dns.edns.Option]:
"""Return the list of options of the specified type."""
return [option for option in self.options if option.otype == otype]
def extended_errors(self) -> List[dns.edns.EDEOption]:
"""Return the list of Extended DNS Error (EDE) options in the message"""
return cast(List[dns.edns.EDEOption], self.get_options(dns.edns.OptionType.EDE))
def _get_one_rr_per_rrset(self, value):
# What the caller picked is fine.
return value
@ -1192,9 +1203,9 @@ class _WireReader:
if rdtype == dns.rdatatype.OPT:
self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
elif rdtype == dns.rdatatype.TSIG:
if self.keyring is None:
if self.keyring is None or self.keyring is True:
raise UnknownTSIGKey("got signed message without keyring")
if isinstance(self.keyring, dict):
elif isinstance(self.keyring, dict):
key = self.keyring.get(absolute_name)
if isinstance(key, bytes):
key = dns.tsig.Key(absolute_name, key, rd.algorithm)
@ -1203,19 +1214,20 @@ class _WireReader:
else:
key = self.keyring
if key is None:
raise UnknownTSIGKey("key '%s' unknown" % name)
self.message.keyring = key
self.message.tsig_ctx = dns.tsig.validate(
self.parser.wire,
key,
absolute_name,
rd,
int(time.time()),
self.message.request_mac,
rr_start,
self.message.tsig_ctx,
self.multi,
)
raise UnknownTSIGKey(f"key '{name}' unknown")
if key:
self.message.keyring = key
self.message.tsig_ctx = dns.tsig.validate(
self.parser.wire,
key,
absolute_name,
rd,
int(time.time()),
self.message.request_mac,
rr_start,
self.message.tsig_ctx,
self.multi,
)
self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd)
else:
rrset = self.message.find_rrset(
@ -1251,6 +1263,7 @@ class _WireReader:
factory = _message_factory_from_opcode(dns.opcode.from_flags(flags))
self.message = factory(id=id)
self.message.flags = dns.flags.Flag(flags)
self.message.wire = self.parser.wire
self.initialize_message(self.message)
self.one_rr_per_rrset = self.message._get_one_rr_per_rrset(
self.one_rr_per_rrset
@ -1290,8 +1303,10 @@ def from_wire(
) -> Message:
"""Convert a DNS wire format message into a message object.
*keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message
is signed.
*keyring*, a ``dns.tsig.Key``, ``dict``, ``bool``, or ``None``, the key or keyring
to use if the message is signed. If ``None`` or ``True``, then trying to decode
a message with a TSIG will fail as it cannot be validated. If ``False``, then
TSIG validation is disabled.
*request_mac*, a ``bytes`` or ``None``. If the message is a response to a
TSIG-signed request, *request_mac* should be set to the MAC of that request.
@ -1811,6 +1826,16 @@ def make_query(
return m
class CopyMode(enum.Enum):
"""
How should sections be copied when making an update response?
"""
NOTHING = 0
QUESTION = 1
EVERYTHING = 2
def make_response(
query: Message,
recursion_available: bool = False,
@ -1818,13 +1843,14 @@ def make_response(
fudge: int = 300,
tsig_error: int = 0,
pad: Optional[int] = None,
copy_mode: Optional[CopyMode] = None,
) -> Message:
"""Make a message which is a response for the specified query.
The message returned is really a response skeleton; it has all of the infrastructure
required of a response, but none of the content.
The response's question section is a shallow copy of the query's question section,
so the query's question RRsets should not be changed.
Response section(s) which are copied are shallow copies of the matching section(s)
in the query, so the query's RRsets should not be changed.
*query*, a ``dns.message.Message``, the query to respond to.
@ -1837,25 +1863,44 @@ def make_response(
*tsig_error*, an ``int``, the TSIG error.
*pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise
if not ``None`` add padding bytes to make the message size a multiple of *pad*.
Note that if padding is non-zero, an EDNS PADDING option will always be added to the
if not ``None`` add padding bytes to make the message size a multiple of *pad*. Note
that if padding is non-zero, an EDNS PADDING option will always be added to the
message. If ``None``, add padding following RFC 8467, namely if the request is
padded, pad the response to 468 otherwise do not pad.
*copy_mode*, a ``dns.message.CopyMode`` or ``None``, determines how sections are
copied. The default, ``None`` copies sections according to the default for the
message's opcode, which is currently ``dns.message.CopyMode.QUESTION`` for all
opcodes. ``dns.message.CopyMode.QUESTION`` copies only the question section.
``dns.message.CopyMode.EVERYTHING`` copies all sections other than OPT or TSIG
records, which are created appropriately if needed. ``dns.message.CopyMode.NOTHING``
copies no sections; note that this mode is for server testing purposes and is
otherwise not recommended for use. In particular, ``dns.message.is_response()``
will be ``False`` if you create a response this way and the rcode is not
``FORMERR``, ``SERVFAIL``, ``NOTIMP``, or ``REFUSED``.
Returns a ``dns.message.Message`` object whose specific class is appropriate for the
query. For example, if query is a ``dns.update.UpdateMessage``, response will be
too.
query. For example, if query is a ``dns.update.UpdateMessage``, the response will
be one too.
"""
if query.flags & dns.flags.QR:
raise dns.exception.FormError("specified query message is not a query")
factory = _message_factory_from_opcode(query.opcode())
opcode = query.opcode()
factory = _message_factory_from_opcode(opcode)
response = factory(id=query.id)
response.flags = dns.flags.QR | (query.flags & dns.flags.RD)
if recursion_available:
response.flags |= dns.flags.RA
response.set_opcode(query.opcode())
response.question = list(query.question)
response.set_opcode(opcode)
if copy_mode is None:
copy_mode = CopyMode.QUESTION
if copy_mode != CopyMode.NOTHING:
response.question = list(query.question)
if copy_mode == CopyMode.EVERYTHING:
response.answer = list(query.answer)
response.authority = list(query.authority)
response.additional = list(query.additional)
if query.edns >= 0:
if pad is None:
# Set response padding per RFC 8467

@ -59,11 +59,11 @@ class NameRelation(dns.enum.IntEnum):
@classmethod
def _maximum(cls):
return cls.COMMONANCESTOR
return cls.COMMONANCESTOR # pragma: no cover
@classmethod
def _short_name(cls):
return cls.__name__
return cls.__name__ # pragma: no cover
# Backwards compatibility
@ -277,6 +277,7 @@ class IDNA2008Codec(IDNACodec):
raise NoIDNA2008
try:
if self.uts_46:
# pylint: disable=possibly-used-before-assignment
label = idna.uts46_remap(label, False, self.transitional)
return idna.alabel(label)
except idna.IDNAError as e:

@ -168,12 +168,14 @@ class DoHNameserver(Nameserver):
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
want_get: bool = False,
http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
self.http_version = http_version
def kind(self):
return "DoH"
@ -214,6 +216,7 @@ class DoHNameserver(Nameserver):
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
http_version=self.http_version,
)
async def async_query(
@ -238,6 +241,7 @@ class DoHNameserver(Nameserver):
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
http_version=self.http_version,
)

@ -23,11 +23,13 @@ import enum
import errno
import os
import os.path
import random
import selectors
import socket
import struct
import time
from typing import Any, Dict, Optional, Tuple, Union
import urllib.parse
from typing import Any, Dict, Optional, Tuple, Union, cast
import dns._features
import dns.exception
@ -129,7 +131,7 @@ if _have_httpx:
family=socket.AF_UNSPEC,
**kwargs,
):
if resolver is None:
if resolver is None and bootstrap_address is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.resolver
@ -217,7 +219,7 @@ def _wait_for(fd, readable, writable, _, expiration):
if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
return True
sel = _selector_class()
sel = selectors.DefaultSelector()
events = 0
if readable:
events |= selectors.EVENT_READ
@ -235,26 +237,6 @@ def _wait_for(fd, readable, writable, _, expiration):
raise dns.exception.Timeout
def _set_selector_class(selector_class):
# Internal API. Do not use.
global _selector_class
_selector_class = selector_class
if hasattr(selectors, "PollSelector"):
# Prefer poll() on platforms that support it because it has no
# limits on the maximum value of a file descriptor (plus it will
# be more efficient for high values).
#
# We ignore typing here as we can't say _selector_class is Any
# on python < 3.8 due to a bug.
_selector_class = selectors.PollSelector # type: ignore
else:
_selector_class = selectors.SelectSelector # type: ignore
def _wait_for_readable(s, expiration):
_wait_for(s, True, False, True, expiration)
@ -355,6 +337,36 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
raise
def _maybe_get_resolver(
resolver: Optional["dns.resolver.Resolver"],
) -> "dns.resolver.Resolver":
# We need a separate method for this to avoid overriding the global
# variable "dns" with the as-yet undefined local variable "dns"
# in https().
if resolver is None:
# pylint: disable=import-outside-toplevel,redefined-outer-name
import dns.resolver
resolver = dns.resolver.Resolver()
return resolver
class HTTPVersion(enum.IntEnum):
"""Which version of HTTP should be used?
DEFAULT will select the first version from the list [2, 1.1, 3] that
is available.
"""
DEFAULT = 0
HTTP_1 = 1
H1 = 1
HTTP_2 = 2
H2 = 2
HTTP_3 = 3
H3 = 3
def https(
q: dns.message.Message,
where: str,
@ -370,7 +382,8 @@ def https(
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
resolver: Optional["dns.resolver.Resolver"] = None,
family: Optional[int] = socket.AF_UNSPEC,
family: int = socket.AF_UNSPEC,
http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
@ -420,27 +433,66 @@ def https(
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
and AAAA records will be retrieved.
*http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.
Returns a ``dns.message.Message``.
"""
(af, _, the_source) = _destination_and_source(
where, port, source, source_port, False
)
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = f"https://{where}:{port}{path}"
elif af == socket.AF_INET6:
url = f"https://[{where}]:{port}{path}"
else:
url = where
extensions = {}
if bootstrap_address is None:
# pylint: disable=possibly-used-before-assignment
parsed = urllib.parse.urlparse(url)
if parsed.hostname is None:
raise ValueError("no hostname in URL")
if dns.inet.is_address(parsed.hostname):
bootstrap_address = parsed.hostname
extensions["sni_hostname"] = parsed.hostname
if parsed.port is not None:
port = parsed.port
if http_version == HTTPVersion.H3 or (
http_version == HTTPVersion.DEFAULT and not have_doh
):
if bootstrap_address is None:
resolver = _maybe_get_resolver(resolver)
assert parsed.hostname is not None # for mypy
answers = resolver.resolve_name(parsed.hostname, family)
bootstrap_address = random.choice(list(answers.addresses()))
return _http3(
q,
bootstrap_address,
url,
timeout,
port,
source,
source_port,
one_rr_per_rrset,
ignore_trailing,
verify=verify,
post=post,
)
if not have_doh:
raise NoDOH # pragma: no cover
if session and not isinstance(session, httpx.Client):
raise ValueError("session parameter must be an httpx.Client")
wire = q.to_wire()
(af, _, the_source) = _destination_and_source(
where, port, source, source_port, False
)
transport = None
headers = {"accept": "application/dns-message"}
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path)
elif af == socket.AF_INET6:
url = "https://[{}]:{}{}".format(where, port, path)
else:
url = where
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
# set source port and source address
@ -450,21 +502,22 @@ def https(
else:
local_address = the_source[0]
local_port = the_source[1]
transport = _HTTPTransport(
local_address=local_address,
http1=True,
http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
else:
cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
transport = _HTTPTransport(
local_address=local_address,
http1=h1,
http2=h2,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
resolver=resolver,
family=family,
)
cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
@ -475,20 +528,30 @@ def https(
"content-length": str(len(wire)),
}
)
response = session.post(url, headers=headers, content=wire, timeout=timeout)
response = session.post(
url,
headers=headers,
content=wire,
timeout=timeout,
extensions=extensions,
)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
twire = wire.decode() # httpx does a repr() if we give it bytes
response = session.get(
url, headers=headers, timeout=timeout, params={"dns": twire}
url,
headers=headers,
timeout=timeout,
params={"dns": twire},
extensions=extensions,
)
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
if response.status_code < 200 or response.status_code > 299:
raise ValueError(
"{} responded with status code {}"
"\nResponse body: {}".format(where, response.status_code, response.content)
f"{where} responded with status code {response.status_code}"
f"\nResponse body: {response.content}"
)
r = dns.message.from_wire(
response.content,
@ -503,6 +566,81 @@ def https(
return r
def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
if headers is None:
raise KeyError
for header, value in headers:
if header == name:
return value
raise KeyError
def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
value = _find_header(headers, b":status")
if value is None:
raise SyntaxError("no :status header in response")
status = int(value)
if status < 0:
raise SyntaxError("status is negative")
if status < 200 or status > 299:
error = ""
if len(wire) > 0:
try:
error = ": " + wire.decode()
except Exception:
pass
raise ValueError(f"{peer} responded with status code {status}{error}")
def _http3(
q: dns.message.Message,
where: str,
url: str,
timeout: Optional[float] = None,
port: int = 853,
source: Optional[str] = None,
source_port: int = 0,
one_rr_per_rrset: bool = False,
ignore_trailing: bool = False,
verify: Union[bool, str] = True,
hostname: Optional[str] = None,
post: bool = True,
) -> dns.message.Message:
if not dns.quic.have_quic:
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
url_parts = urllib.parse.urlparse(url)
hostname = url_parts.hostname
if url_parts.port is not None:
port = url_parts.port
q.id = 0
wire = q.to_wire()
manager = dns.quic.SyncQuicManager(
verify_mode=verify, server_name=hostname, h3=True
)
with manager:
connection = manager.connect(where, port, source, source_port)
(start, expiration) = _compute_times(timeout)
with connection.make_stream(timeout) as stream:
stream.send_h3(url, wire, post)
wire = stream.receive(_remaining(expiration))
_check_status(stream.headers(), where, wire)
finish = time.time()
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
)
r.time = max(finish - start, 0.0)
if not q.is_response(r):
raise BadResponse
return r
def _udp_recv(sock, max_size, expiration):
"""Reads a datagram from the socket.
A Timeout exception will be raised if the operation is not completed
@ -855,7 +993,7 @@ def _net_read(sock, count, expiration):
try:
n = sock.recv(count)
if n == b"":
raise EOFError
raise EOFError("EOF")
count -= len(n)
s += n
except (BlockingIOError, ssl.SSLWantReadError):
@ -1023,6 +1161,7 @@ def tcp(
cm = _make_socket(af, socket.SOCK_STREAM, source)
with cm as s:
if not sock:
# pylint: disable=possibly-used-before-assignment
_connect(s, destination, expiration)
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(
@ -1188,6 +1327,7 @@ def quic(
ignore_trailing: bool = False,
connection: Optional[dns.quic.SyncQuicConnection] = None,
verify: Union[bool, str] = True,
hostname: Optional[str] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-QUIC.
@ -1212,17 +1352,21 @@ def quic(
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
received message.
*connection*, a ``dns.quic.SyncQuicConnection``. If provided, the
connection to use to send the query.
*connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use
to send the query.
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
of the server is done using the default CA bundle; if ``False``, then no
verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification.
*server_hostname*, a ``str`` containing the server's hostname. The
default is ``None``, which means that no hostname is known, and if an
SSL context is created, hostname checking will be disabled.
*hostname*, a ``str`` containing the server's hostname or ``None``. The default is
``None``, which means that no hostname is known, and if an SSL context is created,
hostname checking will be disabled. This value is ignored if *url* is not
``None``.
*server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility
only, and has the same meaning as *hostname*.
Returns a ``dns.message.Message``.
"""
@ -1230,6 +1374,9 @@ def quic(
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
if server_hostname is not None and hostname is None:
hostname = server_hostname
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.SyncQuicConnection
@ -1238,9 +1385,7 @@ def quic(
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
the_connection = connection
else:
manager = dns.quic.SyncQuicManager(
verify_mode=verify, server_name=server_hostname
)
manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
the_manager = manager # for type checking happiness
with manager:
@ -1264,6 +1409,70 @@ def quic(
return r
class UDPMode(enum.IntEnum):
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
NEVER means "never use UDP; always use TCP"
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
"""
NEVER = 0
TRY_FIRST = 1
ONLY = 2
def _inbound_xfr(
txn_manager: dns.transaction.TransactionManager,
s: socket.socket,
query: dns.message.Message,
serial: Optional[int],
timeout: Optional[float],
expiration: float,
) -> Any:
"""Given a socket, does the zone transfer."""
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
is_udp = s.type == socket.SOCK_DGRAM
if is_udp:
_udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
_net_write(s, tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
(rwire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = _net_read(s, l, mexpiration)
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
done = inbound.process_message(r)
yield r
tsig_ctx = r.tsig_ctx
if query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
def xfr(
where: str,
zone: Union[dns.name.Name, str],
@ -1333,134 +1542,52 @@ def xfr(
Returns a generator of ``dns.message.Message`` objects.
"""
class DummyTransactionManager(dns.transaction.TransactionManager):
def __init__(self, origin, relativize):
self.info = (origin, relativize, dns.name.empty if relativize else origin)
def origin_information(self):
return self.info
def get_class(self) -> dns.rdataclass.RdataClass:
raise NotImplementedError # pragma: no cover
def reader(self):
raise NotImplementedError # pragma: no cover
def writer(self, replacement: bool = False) -> dns.transaction.Transaction:
class DummyTransaction:
def nop(self, *args, **kw):
pass
def __getattr__(self, _):
return self.nop
return cast(dns.transaction.Transaction, DummyTransaction())
if isinstance(zone, str):
zone = dns.name.from_text(zone)
rdtype = dns.rdatatype.RdataType.make(rdtype)
q = dns.message.make_query(zone, rdtype, rdclass)
if rdtype == dns.rdatatype.IXFR:
rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial)
q.authority.append(rrset)
rrset = q.find_rrset(
q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True
)
soa = dns.rdata.from_text("IN", "SOA", ". . %u 0 0 0 0" % serial)
rrset.add(soa, 0)
if keyring is not None:
q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
wire = q.to_wire()
(af, destination, source) = _destination_and_source(
where, port, source, source_port
)
(_, expiration) = _compute_times(lifetime)
tm = DummyTransactionManager(zone, relativize)
if use_udp and rdtype != dns.rdatatype.IXFR:
raise ValueError("cannot do a UDP AXFR")
sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
with _make_socket(af, sock_type, source) as s:
(_, expiration) = _compute_times(lifetime)
_connect(s, destination, expiration)
l = len(wire)
if use_udp:
_udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", l) + wire
_net_write(s, tcpmsg, expiration)
done = False
delete_mode = True
expecting_SOA = False
soa_rrset = None
if relativize:
origin = zone
oname = dns.name.empty
else:
origin = None
oname = zone
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if use_udp:
(wire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
wire = _net_read(s, l, mexpiration)
is_ixfr = rdtype == dns.rdatatype.IXFR
r = dns.message.from_wire(
wire,
keyring=q.keyring,
request_mac=q.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=True,
one_rr_per_rrset=is_ixfr,
)
rcode = r.rcode()
if rcode != dns.rcode.NOERROR:
raise TransferError(rcode)
tsig_ctx = r.tsig_ctx
answer_index = 0
if soa_rrset is None:
if not r.answer or r.answer[0].name != oname:
raise dns.exception.FormError("No answer or RRset not for qname")
rrset = r.answer[0]
if rrset.rdtype != dns.rdatatype.SOA:
raise dns.exception.FormError("first RRset is not an SOA")
answer_index = 1
soa_rrset = rrset.copy()
if rdtype == dns.rdatatype.IXFR:
if dns.serial.Serial(soa_rrset[0].serial) <= serial:
#
# We're already up-to-date.
#
done = True
else:
expecting_SOA = True
#
# Process SOAs in the answer section (other than the initial
# SOA in the first message).
#
for rrset in r.answer[answer_index:]:
if done:
raise dns.exception.FormError("answers after final SOA")
if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
if expecting_SOA:
if rrset[0].serial != serial:
raise dns.exception.FormError("IXFR base serial mismatch")
expecting_SOA = False
elif rdtype == dns.rdatatype.IXFR:
delete_mode = not delete_mode
#
# If this SOA RRset is equal to the first we saw then we're
# finished. If this is an IXFR we also check that we're
# seeing the record in the expected part of the response.
#
if rrset == soa_rrset and (
rdtype == dns.rdatatype.AXFR
or (rdtype == dns.rdatatype.IXFR and delete_mode)
):
done = True
elif expecting_SOA:
#
# We made an IXFR request and are expecting another
# SOA RR, but saw something else, so this must be an
# AXFR response.
#
rdtype = dns.rdatatype.AXFR
expecting_SOA = False
if done and q.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
yield r
class UDPMode(enum.IntEnum):
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
NEVER means "never use UDP; always use TCP"
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
"""
NEVER = 0
TRY_FIRST = 1
ONLY = 2
yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
def inbound_xfr(
@ -1514,65 +1641,25 @@ def inbound_xfr(
(query, serial) = dns.xfr.make_query(txn_manager)
else:
serial = dns.xfr.extract_serial_from_query(query)
rdtype = query.question[0].rdtype
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
(af, destination, source) = _destination_and_source(
where, port, source, source_port
)
(_, expiration) = _compute_times(lifetime)
retry = True
while retry:
retry = False
if is_ixfr and udp_mode != UDPMode.NEVER:
sock_type = socket.SOCK_DGRAM
is_udp = True
else:
sock_type = socket.SOCK_STREAM
is_udp = False
with _make_socket(af, sock_type, source) as s:
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
with _make_socket(af, socket.SOCK_DGRAM, source) as s:
_connect(s, destination, expiration)
if is_udp:
_udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", len(wire)) + wire
_net_write(s, tcpmsg, expiration)
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
done = False
tsig_ctx = None
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or (
expiration is not None and mexpiration > expiration
):
mexpiration = expiration
if is_udp:
(rwire, _) = _udp_recv(s, 65535, mexpiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)
rwire = _net_read(s, l, mexpiration)
r = dns.message.from_wire(
rwire,
keyring=query.keyring,
request_mac=query.mac,
xfr=True,
origin=origin,
tsig_ctx=tsig_ctx,
multi=(not is_udp),
one_rr_per_rrset=is_ixfr,
)
try:
done = inbound.process_message(r)
except dns.xfr.UseTCP:
assert is_udp # should not happen if we used TCP!
if udp_mode == UDPMode.ONLY:
raise
done = True
retry = True
udp_mode = UDPMode.NEVER
continue
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
try:
for _ in _inbound_xfr(
txn_manager, s, query, serial, timeout, expiration
):
pass
return
except dns.xfr.UseTCP:
if udp_mode == UDPMode.ONLY:
raise
with _make_socket(af, socket.SOCK_STREAM, source) as s:
_connect(s, destination, expiration)
for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
pass

@ -1,5 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import List, Tuple
import dns._features
import dns.asyncbackend
@ -73,3 +75,6 @@ else: # pragma: no cover
class SyncQuicConnection: # type: ignore
def make_stream(self) -> Any:
raise NotImplementedError
Headers = List[Tuple[bytes, bytes]]

@ -43,12 +43,26 @@ class AsyncioQuicStream(BaseQuicStream):
raise dns.exception.Timeout
self._expecting = 0
async def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.seen_end():
return
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
raise dns.exception.Timeout
async def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
await self.wait_for(2, expiration)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size, expiration)
return self._buffer.get(size)
if self._connection.is_h3():
await self.wait_for_end(expiration)
return self._buffer.get_all()
else:
await self.wait_for(2, expiration)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size, expiration)
return self._buffer.get(size)
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
@ -83,6 +97,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._wake_timer = asyncio.Condition()
self._receiver_task = None
self._sender_task = None
self._wake_pending = False
async def _receiver(self):
try:
@ -104,19 +119,24 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._connection.receive_datagram(datagram, address, time.time())
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now.
async with self._wake_timer:
self._wake_timer.notify_all()
await self._wakeup()
except Exception:
pass
finally:
self._done = True
async with self._wake_timer:
self._wake_timer.notify_all()
await self._wakeup()
self._handshake_complete.set()
async def _wakeup(self):
self._wake_pending = True
async with self._wake_timer:
self._wake_timer.notify_all()
async def _wait_for_wake_timer(self):
async with self._wake_timer:
await self._wake_timer.wait()
if not self._wake_pending:
await self._wake_timer.wait()
self._wake_pending = False
async def _sender(self):
await self._socket_created.wait()
@ -140,9 +160,28 @@ class AsyncioQuicConnection(AsyncQuicConnection):
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
if self.is_h3():
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
@ -161,8 +200,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
async with self._wake_timer:
self._wake_timer.notify_all()
await self._wakeup()
def run(self):
if self._closed:
@ -189,8 +227,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._connection.close()
# sender might be blocked on this, so set it
self._socket_created.set()
async with self._wake_timer:
self._wake_timer.notify_all()
await self._wakeup()
try:
await self._receiver_task
except asyncio.CancelledError:
@ -203,8 +240,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
class AsyncioQuicManager(AsyncQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True

@ -1,12 +1,16 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import base64
import copy
import functools
import socket
import struct
import time
import urllib
from typing import Any, Optional
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
@ -51,6 +55,12 @@ class Buffer:
self._buffer = self._buffer[amount:]
return data
def get_all(self):
assert self.seen_end()
data = self._buffer
self._buffer = b""
return data
class BaseQuicStream:
def __init__(self, connection, stream_id):
@ -58,10 +68,18 @@ class BaseQuicStream:
self._stream_id = stream_id
self._buffer = Buffer()
self._expecting = 0
self._headers = None
self._trailers = None
def id(self):
return self._stream_id
def headers(self):
return self._headers
def trailers(self):
return self._trailers
def _expiration_from_timeout(self, timeout):
if timeout is not None:
expiration = time.time() + timeout
@ -77,16 +95,51 @@ class BaseQuicStream:
return timeout
# Subclass must implement receive() as sync / async and which returns a message
# or raises UnexpectedEOF.
# or raises.
# Subclass must implement send() as sync / async and which takes a message and
# an EOF indicator.
def send_h3(self, url, datagram, post=True):
if not self._connection.is_h3():
raise SyntaxError("cannot send H3 to a non-H3 connection")
url_parts = urllib.parse.urlparse(url)
path = url_parts.path.encode()
if post:
method = b"POST"
else:
method = b"GET"
path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
headers = [
(b":method", method),
(b":scheme", url_parts.scheme.encode()),
(b":authority", url_parts.netloc.encode()),
(b":path", path),
(b"accept", b"application/dns-message"),
]
if post:
headers.extend(
[
(b"content-type", b"application/dns-message"),
(b"content-length", str(len(datagram)).encode()),
]
)
self._connection.send_headers(self._stream_id, headers, not post)
if post:
self._connection.send_data(self._stream_id, datagram, True)
def _encapsulate(self, datagram):
if self._connection.is_h3():
return datagram
l = len(datagram)
return struct.pack("!H", l) + datagram
def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end)
try:
return self._expecting > 0 and self._buffer.have(self._expecting)
return (
self._expecting > 0 and self._buffer.have(self._expecting)
) or self._buffer.seen_end
except UnexpectedEOF:
return True
@ -97,7 +150,13 @@ class BaseQuicStream:
class BaseQuicConnection:
def __init__(
self, connection, address, port, source=None, source_port=0, manager=None
self,
connection,
address,
port,
source=None,
source_port=0,
manager=None,
):
self._done = False
self._connection = connection
@ -106,6 +165,10 @@ class BaseQuicConnection:
self._closed = False
self._manager = manager
self._streams = {}
if manager.is_h3():
self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
else:
self._h3_conn = None
self._af = dns.inet.af_for_address(address)
self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0:
@ -120,9 +183,18 @@ class BaseQuicConnection:
else:
self._source = None
def is_h3(self):
return self._h3_conn is not None
def close_stream(self, stream_id):
del self._streams[stream_id]
def send_headers(self, stream_id, headers, is_end=False):
self._h3_conn.send_headers(stream_id, headers, is_end)
def send_data(self, stream_id, data, is_end=False):
self._h3_conn.send_data(stream_id, data, is_end)
def _get_timer_values(self, closed_is_special=True):
now = time.time()
expiration = self._connection.get_timer()
@ -148,17 +220,25 @@ class AsyncQuicConnection(BaseQuicConnection):
class BaseQuicManager:
def __init__(self, conf, verify_mode, connection_factory, server_name=None):
def __init__(
self, conf, verify_mode, connection_factory, server_name=None, h3=False
):
self._connections = {}
self._connection_factory = connection_factory
self._session_tickets = {}
self._tokens = {}
self._h3 = h3
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
verify_path = verify_mode
verify_mode = True
if h3:
alpn_protocols = ["h3"]
else:
alpn_protocols = ["doq", "doq-i03"]
conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=["doq", "doq-i03"],
alpn_protocols=alpn_protocols,
verify_mode=verify_mode,
server_name=server_name,
)
@ -167,7 +247,13 @@ class BaseQuicManager:
self._conf = conf
def _connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
):
connection = self._connections.get((address, port))
if connection is not None:
@ -189,9 +275,24 @@ class BaseQuicManager:
)
else:
session_ticket_handler = None
if want_token:
try:
token = self._tokens.pop((address, port))
# We found a token, so make a configuration that uses it.
conf = copy.copy(conf)
conf.token = token
except KeyError:
# No token
pass
# Whether or not we found a token, we want a handler to save # one.
token_handler = functools.partial(self.save_token, address, port)
else:
token_handler = None
qconn = aioquic.quic.connection.QuicConnection(
configuration=conf,
session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
)
lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time())
@ -207,6 +308,9 @@ class BaseQuicManager:
except KeyError:
pass
def is_h3(self):
return self._h3
def save_session_ticket(self, address, port, ticket):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
@ -218,6 +322,17 @@ class BaseQuicManager:
del self._session_tickets[key]
self._session_tickets[(address, port)] = ticket
def save_token(self, address, port, token):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._tokens)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._tokens[key]
self._tokens[(address, port)] = token
class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):

@ -21,11 +21,9 @@ from dns.quic._common import (
UnexpectedEOF,
)
# Avoid circularity with dns.query
if hasattr(selectors, "PollSelector"):
_selector_class = selectors.PollSelector # type: ignore
else:
_selector_class = selectors.SelectSelector # type: ignore
# Function used to create a socket. Can be overridden if needed in special
# situations.
socket_factory = socket.socket
class SyncQuicStream(BaseQuicStream):
@ -46,14 +44,29 @@ class SyncQuicStream(BaseQuicStream):
raise dns.exception.Timeout
self._expecting = 0
def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.seen_end():
return
with self._wake_up:
if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
self.wait_for(2, expiration)
with self._lock:
(size,) = struct.unpack("!H", self._buffer.get(2))
self.wait_for(size, expiration)
with self._lock:
return self._buffer.get(size)
if self._connection.is_h3():
self.wait_for_end(expiration)
with self._lock:
return self._buffer.get_all()
else:
self.wait_for(2, expiration)
with self._lock:
(size,) = struct.unpack("!H", self._buffer.get(2))
self.wait_for(size, expiration)
with self._lock:
return self._buffer.get(size)
def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
@ -81,7 +94,7 @@ class SyncQuicStream(BaseQuicStream):
class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0)
if self._source is not None:
try:
self._socket.bind(
@ -118,7 +131,7 @@ class SyncQuicConnection(BaseQuicConnection):
def _worker(self):
try:
sel = _selector_class()
sel = selectors.DefaultSelector()
sel.register(self._socket, selectors.EVENT_READ, self._read)
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
while not self._done:
@ -140,6 +153,7 @@ class SyncQuicConnection(BaseQuicConnection):
finally:
with self._lock:
self._done = True
self._socket.close()
# Ensure anyone waiting for this gets woken up.
self._handshake_complete.set()
@ -150,10 +164,29 @@ class SyncQuicConnection(BaseQuicConnection):
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(event.data, event.end_stream)
if self.is_h3():
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(h3_event.data, h3_event.stream_ended)
else:
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
@ -170,6 +203,18 @@ class SyncQuicConnection(BaseQuicConnection):
self._connection.send_stream_data(stream, data, is_end)
self._send_wakeup.send(b"\x01")
def send_headers(self, stream_id, headers, is_end=False):
with self._lock:
super().send_headers(stream_id, headers, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def send_data(self, stream_id, data, is_end=False):
with self._lock:
super().send_data(stream_id, data, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def run(self):
if self._closed:
return
@ -203,16 +248,24 @@ class SyncQuicConnection(BaseQuicConnection):
class SyncQuicManager(BaseQuicManager):
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
self._lock = threading.Lock()
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
):
with self._lock:
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
address, port, source, source_port, want_session_ticket, want_token
)
if start:
connection.run()
@ -226,6 +279,10 @@ class SyncQuicManager(BaseQuicManager):
with self._lock:
super().save_session_ticket(address, port, ticket)
def save_token(self, address, port, token):
with self._lock:
super().save_token(address, port, token)
def __enter__(self):
return self

@ -36,16 +36,27 @@ class TrioQuicStream(BaseQuicStream):
await self._wake_up.wait()
self._expecting = 0
async def wait_for_end(self):
while True:
if self._buffer.seen_end():
return
async with self._wake_up:
await self._wake_up.wait()
async def receive(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
await self.wait_for(2)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
return self._buffer.get(size)
if self._connection.is_h3():
await self.wait_for_end()
return self._buffer.get_all()
else:
await self.wait_for(2)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
return self._buffer.get(size)
raise dns.exception.Timeout
async def send(self, datagram, is_end=False):
@ -115,6 +126,7 @@ class TrioQuicConnection(AsyncQuicConnection):
await self._socket.send(datagram)
finally:
self._done = True
self._socket.close()
self._handshake_complete.set()
async def _handle_events(self):
@ -124,9 +136,28 @@ class TrioQuicConnection(AsyncQuicConnection):
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
if self.is_h3():
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
@ -183,9 +214,14 @@ class TrioQuicConnection(AsyncQuicConnection):
class TrioQuicManager(AsyncQuicManager):
def __init__(
self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
self,
nursery,
conf=None,
verify_mode=ssl.CERT_REQUIRED,
server_name=None,
h3=False,
):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
self._nursery = nursery
def connect(

@ -214,7 +214,7 @@ class Rdata:
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
canonicalize: bool = False,
) -> bytes:
) -> None:
raise NotImplementedError # pragma: no cover
def to_wire(
@ -223,14 +223,19 @@ class Rdata:
compress: Optional[dns.name.CompressType] = None,
origin: Optional[dns.name.Name] = None,
canonicalize: bool = False,
) -> bytes:
) -> Optional[bytes]:
"""Convert an rdata to wire format.
Returns a ``bytes`` or ``None``.
Returns a ``bytes`` if no output file was specified, or ``None`` otherwise.
"""
if file:
return self._to_wire(file, compress, origin, canonicalize)
# We call _to_wire() and then return None explicitly instead of
# of just returning the None from _to_wire() as mypy's func-returns-value
# unhelpfully errors out with "error: "_to_wire" of "Rdata" does not return
# a value (it only ever returns None)"
self._to_wire(file, compress, origin, canonicalize)
return None
else:
f = io.BytesIO()
self._to_wire(f, compress, origin, canonicalize)
@ -253,8 +258,9 @@ class Rdata:
Returns a ``bytes``.
"""
return self.to_wire(origin=origin, canonicalize=True)
wire = self.to_wire(origin=origin, canonicalize=True)
assert wire is not None # for mypy
return wire
def __repr__(self):
covers = self.covers()
@ -434,15 +440,11 @@ class Rdata:
continue
if key not in parameters:
raise AttributeError(
"'{}' object has no attribute '{}'".format(
self.__class__.__name__, key
)
f"'{self.__class__.__name__}' object has no attribute '{key}'"
)
if key in ("rdclass", "rdtype"):
raise AttributeError(
"Cannot overwrite '{}' attribute '{}'".format(
self.__class__.__name__, key
)
f"Cannot overwrite '{self.__class__.__name__}' attribute '{key}'"
)
# Construct the parameter list. For each field, use the value in
@ -646,13 +648,14 @@ _rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType],
{}
)
_module_prefix = "dns.rdtypes"
_dynamic_load_allowed = True
def get_rdata_class(rdclass, rdtype):
def get_rdata_class(rdclass, rdtype, use_generic=True):
cls = _rdata_classes.get((rdclass, rdtype))
if not cls:
cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype))
if not cls:
if not cls and _dynamic_load_allowed:
rdclass_text = dns.rdataclass.to_text(rdclass)
rdtype_text = dns.rdatatype.to_text(rdtype)
rdtype_text = rdtype_text.replace("-", "_")
@ -670,12 +673,36 @@ def get_rdata_class(rdclass, rdtype):
_rdata_classes[(rdclass, rdtype)] = cls
except ImportError:
pass
if not cls:
if not cls and use_generic:
cls = GenericRdata
_rdata_classes[(rdclass, rdtype)] = cls
return cls
def load_all_types(disable_dynamic_load=True):
"""Load all rdata types for which dnspython has a non-generic implementation.
Normally dnspython loads DNS rdatatype implementations on demand, but in some
specialized cases loading all types at an application-controlled time is preferred.
If *disable_dynamic_load*, a ``bool``, is ``True`` then dnspython will not attempt
to use its dynamic loading mechanism if an unknown type is subsequently encountered,
and will simply use the ``GenericRdata`` class.
"""
# Load class IN and ANY types.
for rdtype in dns.rdatatype.RdataType:
get_rdata_class(dns.rdataclass.IN, rdtype, False)
# Load the one non-ANY implementation we have in CH. Everything
# else in CH is an ANY type, and we'll discover those on demand but won't
# have to import anything.
get_rdata_class(dns.rdataclass.CH, dns.rdatatype.A, False)
if disable_dynamic_load:
# Now disable dynamic loading so any subsequent unknown type immediately becomes
# GenericRdata without a load attempt.
global _dynamic_load_allowed
_dynamic_load_allowed = False
def from_text(
rdclass: Union[dns.rdataclass.RdataClass, str],
rdtype: Union[dns.rdatatype.RdataType, str],

@ -160,7 +160,7 @@ class Rdataset(dns.set.Set):
return s[:100] + "..."
return s
return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self)
return "[" + ", ".join(f"<{maybe_truncate(str(rr))}>" for rr in self) + "]"
def __repr__(self):
if self.covers == 0:
@ -248,12 +248,8 @@ class Rdataset(dns.set.Set):
# (which is meaningless anyway).
#
s.write(
"{}{}{} {}\n".format(
ntext,
pad,
dns.rdataclass.to_text(rdclass),
dns.rdatatype.to_text(self.rdtype),
)
f"{ntext}{pad}{dns.rdataclass.to_text(rdclass)} "
f"{dns.rdatatype.to_text(self.rdtype)}\n"
)
else:
for rd in self:

@ -105,6 +105,8 @@ class RdataType(dns.enum.IntEnum):
CAA = 257
AVC = 258
AMTRELAY = 260
RESINFO = 261
WALLET = 262
TA = 32768
DLV = 32769
@ -125,7 +127,7 @@ class RdataType(dns.enum.IntEnum):
if text.find("-") >= 0:
try:
return cls[text.replace("-", "_")]
except KeyError:
except KeyError: # pragma: no cover
pass
return _registered_by_text.get(text)
@ -326,6 +328,8 @@ URI = RdataType.URI
CAA = RdataType.CAA
AVC = RdataType.AVC
AMTRELAY = RdataType.AMTRELAY
RESINFO = RdataType.RESINFO
WALLET = RdataType.WALLET
TA = RdataType.TA
DLV = RdataType.DLV

@ -75,8 +75,9 @@ class GPOS(dns.rdata.Rdata):
raise dns.exception.FormError("bad longitude")
def to_text(self, origin=None, relativize=True, **kw):
return "{} {} {}".format(
self.latitude.decode(), self.longitude.decode(), self.altitude.decode()
return (
f"{self.latitude.decode()} {self.longitude.decode()} "
f"{self.altitude.decode()}"
)
@classmethod

@ -37,9 +37,7 @@ class HINFO(dns.rdata.Rdata):
self.os = self._as_bytes(os, True, 255)
def to_text(self, origin=None, relativize=True, **kw):
return '"{}" "{}"'.format(
dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os)
)
return f'"{dns.rdata._escapify(self.cpu)}" "{dns.rdata._escapify(self.os)}"'
@classmethod
def from_text(

@ -48,7 +48,7 @@ class HIP(dns.rdata.Rdata):
for server in self.servers:
servers.append(server.choose_relativity(origin, relativize))
if len(servers) > 0:
text += " " + " ".join((x.to_unicode() for x in servers))
text += " " + " ".join(x.to_unicode() for x in servers)
return "%u %s %s%s" % (self.algorithm, hit, key, text)
@classmethod

@ -38,11 +38,12 @@ class ISDN(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
if self.subaddress:
return '"{}" "{}"'.format(
dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress)
return (
f'"{dns.rdata._escapify(self.address)}" '
f'"{dns.rdata._escapify(self.subaddress)}"'
)
else:
return '"%s"' % dns.rdata._escapify(self.address)
return f'"{dns.rdata._escapify(self.address)}"'
@classmethod
def from_text(

@ -44,7 +44,7 @@ def _exponent_of(what, desc):
exp = i - 1
break
if exp is None or exp < 0:
raise dns.exception.SyntaxError("%s value out of bounds" % desc)
raise dns.exception.SyntaxError(f"{desc} value out of bounds")
return exp
@ -83,10 +83,10 @@ def _encode_size(what, desc):
def _decode_size(what, desc):
exponent = what & 0x0F
if exponent > 9:
raise dns.exception.FormError("bad %s exponent" % desc)
raise dns.exception.FormError(f"bad {desc} exponent")
base = (what & 0xF0) >> 4
if base > 9:
raise dns.exception.FormError("bad %s base" % desc)
raise dns.exception.FormError(f"bad {desc} base")
return base * pow(10, exponent)
@ -184,10 +184,9 @@ class LOC(dns.rdata.Rdata):
or self.horizontal_precision != _default_hprec
or self.vertical_precision != _default_vprec
):
text += " {:0.2f}m {:0.2f}m {:0.2f}m".format(
self.size / 100.0,
self.horizontal_precision / 100.0,
self.vertical_precision / 100.0,
text += (
f" {self.size / 100.0:0.2f}m {self.horizontal_precision / 100.0:0.2f}m"
f" {self.vertical_precision / 100.0:0.2f}m"
)
return text

@ -44,7 +44,7 @@ class NSEC(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
next = self.next.choose_relativity(origin, relativize)
text = Bitmap(self.windows).to_text()
return "{}{}".format(next, text)
return f"{next}{text}"
@classmethod
def from_text(

@ -0,0 +1,24 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class RESINFO(dns.rdtypes.txtbase.TXTBase):
"""RESINFO record"""

@ -37,7 +37,7 @@ class RP(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
mbox = self.mbox.choose_relativity(origin, relativize)
txt = self.txt.choose_relativity(origin, relativize)
return "{} {}".format(str(mbox), str(txt))
return f"{str(mbox)} {str(txt)}"
@classmethod
def from_text(

@ -69,7 +69,7 @@ class TKEY(dns.rdata.Rdata):
dns.rdata._base64ify(self.key, 0),
)
if len(self.other) > 0:
text += " %s" % (dns.rdata._base64ify(self.other, 0))
text += f" {dns.rdata._base64ify(self.other, 0)}"
return text

@ -0,0 +1,9 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import dns.immutable
import dns.rdtypes.txtbase
@dns.immutable.immutable
class WALLET(dns.rdtypes.txtbase.TXTBase):
"""WALLET record"""

@ -36,7 +36,7 @@ class X25(dns.rdata.Rdata):
self.address = self._as_bytes(address, True, 255)
def to_text(self, origin=None, relativize=True, **kw):
return '"%s"' % dns.rdata._escapify(self.address)
return f'"{dns.rdata._escapify(self.address)}"'
@classmethod
def from_text(

@ -51,6 +51,7 @@ __all__ = [
"OPENPGPKEY",
"OPT",
"PTR",
"RESINFO",
"RP",
"RRSIG",
"RT",
@ -63,6 +64,7 @@ __all__ = [
"TSIG",
"TXT",
"URI",
"WALLET",
"X25",
"ZONEMD",
]

@ -37,7 +37,7 @@ class A(dns.rdata.Rdata):
def to_text(self, origin=None, relativize=True, **kw):
domain = self.domain.choose_relativity(origin, relativize)
return "%s %o" % (domain, self.address)
return f"{domain} {self.address:o}"
@classmethod
def from_text(

@ -36,7 +36,7 @@ class NSAP(dns.rdata.Rdata):
self.address = self._as_bytes(address)
def to_text(self, origin=None, relativize=True, **kw):
return "0x%s" % binascii.hexlify(self.address).decode()
return f"0x{binascii.hexlify(self.address).decode()}"
@classmethod
def from_text(

@ -36,7 +36,7 @@ class EUIBase(dns.rdata.Rdata):
self.eui = self._as_bytes(eui)
if len(self.eui) != self.byte_len:
raise dns.exception.FormError(
"EUI%s rdata has to have %s bytes" % (self.byte_len * 8, self.byte_len)
f"EUI{self.byte_len * 8} rdata has to have {self.byte_len} bytes"
)
def to_text(self, origin=None, relativize=True, **kw):
@ -49,16 +49,16 @@ class EUIBase(dns.rdata.Rdata):
text = tok.get_string()
if len(text) != cls.text_len:
raise dns.exception.SyntaxError(
"Input text must have %s characters" % cls.text_len
f"Input text must have {cls.text_len} characters"
)
for i in range(2, cls.byte_len * 3 - 1, 3):
if text[i] != "-":
raise dns.exception.SyntaxError("Dash expected at position %s" % i)
raise dns.exception.SyntaxError(f"Dash expected at position {i}")
text = text.replace("-", "")
try:
data = binascii.unhexlify(text.encode())
except (ValueError, TypeError) as ex:
raise dns.exception.SyntaxError("Hex decoding error: %s" % str(ex))
raise dns.exception.SyntaxError(f"Hex decoding error: {str(ex)}")
return cls(rdclass, rdtype, data)
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):

@ -35,6 +35,7 @@ class ParamKey(dns.enum.IntEnum):
ECH = 5
IPV6HINT = 6
DOHPATH = 7
OHTTP = 8
@classmethod
def _maximum(cls):
@ -396,6 +397,36 @@ class ECHParam(Param):
file.write(self.ech)
@dns.immutable.immutable
class OHTTPParam(Param):
# We don't ever expect to instantiate this class, but we need
# a from_value() and a from_wire_parser(), so we just return None
# from the class methods when things are OK.
@classmethod
def emptiness(cls):
return Emptiness.ALWAYS
@classmethod
def from_value(cls, value):
if value is None or value == "":
return None
else:
raise ValueError("ohttp with non-empty value")
def to_text(self):
raise NotImplementedError # pragma: no cover
@classmethod
def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
if parser.remaining() != 0:
raise dns.exception.FormError
return None
def to_wire(self, file, origin=None): # pylint: disable=W0613
raise NotImplementedError # pragma: no cover
_class_for_key = {
ParamKey.MANDATORY: MandatoryParam,
ParamKey.ALPN: ALPNParam,
@ -404,6 +435,7 @@ _class_for_key = {
ParamKey.IPV4HINT: IPv4HintParam,
ParamKey.ECH: ECHParam,
ParamKey.IPV6HINT: IPv6HintParam,
ParamKey.OHTTP: OHTTPParam,
}

@ -50,6 +50,8 @@ class TXTBase(dns.rdata.Rdata):
self.strings: Tuple[bytes] = self._as_tuple(
strings, lambda x: self._as_bytes(x, True, 255)
)
if len(self.strings) == 0:
raise ValueError("the list of strings must not be empty")
def to_text(
self,
@ -60,7 +62,7 @@ class TXTBase(dns.rdata.Rdata):
txt = ""
prefix = ""
for s in self.strings:
txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s))
txt += f'{prefix}"{dns.rdata._escapify(s)}"'
prefix = " "
return txt

@ -231,7 +231,7 @@ def weighted_processing_order(iterable):
total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas)
while len(rdatas) > 1:
r = random.uniform(0, total)
for n, rdata in enumerate(rdatas):
for n, rdata in enumerate(rdatas): # noqa: B007
weight = rdata._processing_weight() or _no_weight
if weight > r:
break

@ -36,6 +36,7 @@ import dns.ipv4
import dns.ipv6
import dns.message
import dns.name
import dns.rdata
import dns.nameserver
import dns.query
import dns.rcode
@ -45,7 +46,7 @@ import dns.rdtypes.svcbbase
import dns.reversename
import dns.tsig
if sys.platform == "win32":
if sys.platform == "win32": # pragma: no cover
import dns.win32util
@ -83,7 +84,7 @@ class NXDOMAIN(dns.exception.DNSException):
else:
msg = "The DNS query name does not exist"
qnames = ", ".join(map(str, qnames))
return "{}: {}".format(msg, qnames)
return f"{msg}: {qnames}"
@property
def canonical_name(self):
@ -96,7 +97,7 @@ class NXDOMAIN(dns.exception.DNSException):
cname = response.canonical_name()
if cname != qname:
return cname
except Exception:
except Exception: # pragma: no cover
# We can just eat this exception as it means there was
# something wrong with the response.
pass
@ -154,7 +155,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
"""Turn a resolution errors trace into a list of text."""
texts = []
for err in errors:
texts.append("Server {} answered {}".format(err[0], err[3]))
texts.append(f"Server {err[0]} answered {err[3]}")
return texts
@ -162,7 +163,7 @@ class LifetimeTimeout(dns.exception.Timeout):
"""The resolution lifetime expired."""
msg = "The resolution lifetime expired."
fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1]
fmt = f"{msg[:-1]} after {{timeout:.3f}} seconds: {{errors}}"
supp_kwargs = {"timeout", "errors"}
# We do this as otherwise mypy complains about unexpected keyword argument
@ -211,7 +212,7 @@ class NoNameservers(dns.exception.DNSException):
"""
msg = "All nameservers failed to answer the query."
fmt = "%s {query}: {errors}" % msg[:-1]
fmt = f"{msg[:-1]} {{query}}: {{errors}}"
supp_kwargs = {"request", "errors"}
# We do this as otherwise mypy complains about unexpected keyword argument
@ -297,7 +298,7 @@ class Answer:
def __len__(self) -> int:
return self.rrset and len(self.rrset) or 0
def __iter__(self):
def __iter__(self) -> Iterator[dns.rdata.Rdata]:
return self.rrset and iter(self.rrset) or iter(tuple())
def __getitem__(self, i):
@ -334,7 +335,7 @@ class HostAnswers(Answers):
answers[dns.rdatatype.A] = v4
return answers
# Returns pairs of (address, family) from this result, potentiallys
# Returns pairs of (address, family) from this result, potentially
# filtering by address family.
def addresses_and_families(
self, family: int = socket.AF_UNSPEC
@ -347,7 +348,7 @@ class HostAnswers(Answers):
answer = self.get(dns.rdatatype.AAAA)
elif family == socket.AF_INET:
answer = self.get(dns.rdatatype.A)
else:
else: # pragma: no cover
raise NotImplementedError(f"unknown address family {family}")
if answer:
for rdata in answer:
@ -938,7 +939,7 @@ class BaseResolver:
self.reset()
if configure:
if sys.platform == "win32":
if sys.platform == "win32": # pragma: no cover
self.read_registry()
elif filename:
self.read_resolv_conf(filename)
@ -947,7 +948,7 @@ class BaseResolver:
"""Reset all resolver configuration to the defaults."""
self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
if len(self.domain) == 0:
if len(self.domain) == 0: # pragma: no cover
self.domain = dns.name.root
self._nameservers = []
self.nameserver_ports = {}
@ -1040,7 +1041,7 @@ class BaseResolver:
# setter logic, with additonal checking and enrichment.
self.nameservers = nameservers
def read_registry(self) -> None:
def read_registry(self) -> None: # pragma: no cover
"""Extract resolver configuration from the Windows registry."""
try:
info = dns.win32util.get_dns_info() # type: ignore
@ -1205,9 +1206,7 @@ class BaseResolver:
enriched_nameservers.append(enriched_nameserver)
else:
raise ValueError(
"nameservers must be a list or tuple (not a {})".format(
type(nameservers)
)
f"nameservers must be a list or tuple (not a {type(nameservers)})"
)
return enriched_nameservers
@ -1431,7 +1430,7 @@ class Resolver(BaseResolver):
elif family == socket.AF_INET6:
v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
return HostAnswers.make(v6=v6)
elif family != socket.AF_UNSPEC:
elif family != socket.AF_UNSPEC: # pragma: no cover
raise NotImplementedError(f"unknown address family {family}")
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
@ -1515,7 +1514,7 @@ class Resolver(BaseResolver):
nameservers = dns._ddr._get_nameservers_sync(answer, timeout)
if len(nameservers) > 0:
self.nameservers = nameservers
except Exception:
except Exception: # pragma: no cover
pass
@ -1640,7 +1639,7 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
return get_default_resolver().canonical_name(name)
def try_ddr(lifetime: float = 5.0) -> None:
def try_ddr(lifetime: float = 5.0) -> None: # pragma: no cover
"""Try to update the default resolver's nameservers using Discovery of Designated
Resolvers (DDR). If successful, the resolver will subsequently use
DNS-over-HTTPS or DNS-over-TLS for future queries.
@ -1926,7 +1925,7 @@ def _getnameinfo(sockaddr, flags=0):
family = socket.AF_INET
tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, socket.SOL_TCP, 0)
if len(tuples) > 1:
raise socket.error("sockaddr resolved to multiple addresses")
raise OSError("sockaddr resolved to multiple addresses")
addr = tuples[0][4][0]
if flags & socket.NI_DGRAM:
pname = "udp"
@ -1961,7 +1960,7 @@ def _getfqdn(name=None):
(name, _, _) = _gethostbyaddr(name)
# Python's version checks aliases too, but our gethostbyname
# ignores them, so we do so here as well.
except Exception:
except Exception: # pragma: no cover
pass
return name

@ -21,10 +21,11 @@ import itertools
class Set:
"""A simple set class.
This class was originally used to deal with sets being missing in
ancient versions of python, but dnspython will continue to use it
as these sets are based on lists and are thus indexable, and this
ability is widely used in dnspython applications.
This class was originally used to deal with python not having a set class, and
originally the class used lists in its implementation. The ordered and indexable
nature of RRsets and Rdatasets is unfortunately widely used in dnspython
applications, so for backwards compatibility sets continue to be a custom class, now
based on an ordered dictionary.
"""
__slots__ = ["items"]
@ -43,7 +44,7 @@ class Set:
self.add(item) # lgtm[py/init-calls-subclass]
def __repr__(self):
return "dns.set.Set(%s)" % repr(list(self.items.keys()))
return f"dns.set.Set({repr(list(self.items.keys()))})" # pragma: no cover
def add(self, item):
"""Add an item to the set."""

@ -528,7 +528,7 @@ class Tokenizer:
if value < 0 or value > 65535:
if base == 8:
raise dns.exception.SyntaxError(
"%o is not an octal unsigned 16-bit integer" % value
f"{value:o} is not an octal unsigned 16-bit integer"
)
else:
raise dns.exception.SyntaxError(

@ -486,7 +486,7 @@ class Transaction:
if exact:
raise DeleteNotExact(f"{method}: missing rdataset")
else:
self._delete_rdataset(name, rdtype, covers)
self._checked_delete_rdataset(name, rdtype, covers)
return
else:
rdataset = self._rdataset_from_args(method, True, args)
@ -529,8 +529,6 @@ class Transaction:
def _end(self, commit):
self._check_ended()
if self._ended:
raise AlreadyEnded
try:
self._end_transaction(commit)
finally:

@ -73,7 +73,7 @@ def from_text(text: str) -> int:
elif c == "s":
total += current
else:
raise BadTTL("unknown unit '%s'" % c)
raise BadTTL(f"unknown unit '{c}'")
current = 0
need_digit = True
if not current == 0:

@ -20,9 +20,9 @@
#: MAJOR
MAJOR = 2
#: MINOR
MINOR = 6
MINOR = 7
#: MICRO
MICRO = 1
MICRO = 0
#: RELEASELEVEL
RELEASELEVEL = 0x0F
#: SERIAL

@ -13,8 +13,8 @@ if sys.platform == "win32":
# Keep pylint quiet on non-windows.
try:
WindowsError is None # pylint: disable=used-before-assignment
except KeyError:
_ = WindowsError # pylint: disable=used-before-assignment
except NameError:
WindowsError = Exception
if dns._features.have("wmi"):
@ -44,6 +44,7 @@ if sys.platform == "win32":
if _have_wmi:
class _WMIGetter(threading.Thread):
# pylint: disable=possibly-used-before-assignment
def __init__(self):
super().__init__()
self.info = DnsInfo()
@ -82,32 +83,21 @@ if sys.platform == "win32":
def __init__(self):
self.info = DnsInfo()
def _determine_split_char(self, entry):
#
# The windows registry irritatingly changes the list element
# delimiter in between ' ' and ',' (and vice-versa) in various
# versions of windows.
#
if entry.find(" ") >= 0:
split_char = " "
elif entry.find(",") >= 0:
split_char = ","
else:
# probably a singleton; treat as a space-separated list.
split_char = " "
return split_char
def _split(self, text):
# The windows registry has used both " " and "," as a delimiter, and while
# it is currently using "," in Windows 10 and later, updates can seemingly
# leave a space in too, e.g. "a, b". So we just convert all commas to
# spaces, and use split() in its default configuration, which splits on
# all whitespace and ignores empty strings.
return text.replace(",", " ").split()
def _config_nameservers(self, nameservers):
split_char = self._determine_split_char(nameservers)
ns_list = nameservers.split(split_char)
for ns in ns_list:
for ns in self._split(nameservers):
if ns not in self.info.nameservers:
self.info.nameservers.append(ns)
def _config_search(self, search):
split_char = self._determine_split_char(search)
search_list = search.split(split_char)
for s in search_list:
for s in self._split(search):
s = _config_domain(s)
if s not in self.info.search:
self.info.search.append(s)
@ -164,7 +154,7 @@ if sys.platform == "win32":
lm,
r"SYSTEM\CurrentControlSet\Control\Network"
r"\{4D36E972-E325-11CE-BFC1-08002BE10318}"
r"\%s\Connection" % guid,
rf"\{guid}\Connection",
)
try:
@ -177,7 +167,7 @@ if sys.platform == "win32":
raise ValueError # pragma: no cover
device_key = winreg.OpenKey(
lm, r"SYSTEM\CurrentControlSet\Enum\%s" % pnp_id
lm, rf"SYSTEM\CurrentControlSet\Enum\{pnp_id}"
)
try:
@ -232,7 +222,7 @@ if sys.platform == "win32":
self._config_fromkey(key, False)
finally:
key.Close()
except EnvironmentError:
except OSError:
break
finally:
interfaces.Close()

@ -33,7 +33,7 @@ class TransferError(dns.exception.DNSException):
"""A zone transfer response got a non-zero rcode."""
def __init__(self, rcode):
message = "Zone transfer error: %s" % dns.rcode.to_text(rcode)
message = f"Zone transfer error: {dns.rcode.to_text(rcode)}"
super().__init__(message)
self.rcode = rcode

@ -230,7 +230,7 @@ class Reader:
try:
rdtype = dns.rdatatype.from_text(token.value)
except Exception:
raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value)
raise dns.exception.SyntaxError(f"unknown rdatatype '{token.value}'")
try:
rd = dns.rdata.from_text(
@ -251,9 +251,7 @@ class Reader:
# We convert them to syntax errors so that we can emit
# helpful filename:line info.
(ty, va) = sys.exc_info()[:2]
raise dns.exception.SyntaxError(
"caught exception {}: {}".format(str(ty), str(va))
)
raise dns.exception.SyntaxError(f"caught exception {str(ty)}: {str(va)}")
if not self.default_ttl_known and rdtype == dns.rdatatype.SOA:
# The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
@ -281,41 +279,41 @@ class Reader:
# Sometimes there are modifiers in the hostname. These come after
# the dollar sign. They are in the form: ${offset[,width[,base]]}.
# Make names
mod = ""
sign = "+"
offset = "0"
width = "0"
base = "d"
g1 = is_generate1.match(side)
if g1:
mod, sign, offset, width, base = g1.groups()
if sign == "":
sign = "+"
g2 = is_generate2.match(side)
if g2:
mod, sign, offset = g2.groups()
if sign == "":
sign = "+"
width = 0
base = "d"
g3 = is_generate3.match(side)
if g3:
mod, sign, offset, width = g3.groups()
if sign == "":
sign = "+"
base = "d"
else:
g2 = is_generate2.match(side)
if g2:
mod, sign, offset = g2.groups()
if sign == "":
sign = "+"
width = "0"
base = "d"
else:
g3 = is_generate3.match(side)
if g3:
mod, sign, offset, width = g3.groups()
if sign == "":
sign = "+"
base = "d"
if not (g1 or g2 or g3):
mod = ""
sign = "+"
offset = 0
width = 0
base = "d"
offset = int(offset)
width = int(width)
ioffset = int(offset)
iwidth = int(width)
if sign not in ["+", "-"]:
raise dns.exception.SyntaxError("invalid offset sign %s" % sign)
raise dns.exception.SyntaxError(f"invalid offset sign {sign}")
if base not in ["d", "o", "x", "X", "n", "N"]:
raise dns.exception.SyntaxError("invalid type %s" % base)
raise dns.exception.SyntaxError(f"invalid type {base}")
return mod, sign, offset, width, base
return mod, sign, ioffset, iwidth, base
def _generate_line(self):
# range lhs [ttl] [class] type rhs [ comment ]
@ -377,7 +375,7 @@ class Reader:
if not token.is_identifier():
raise dns.exception.SyntaxError
except Exception:
raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value)
raise dns.exception.SyntaxError(f"unknown rdatatype '{token.value}'")
# rhs (required)
rhs = token.value
@ -412,8 +410,8 @@ class Reader:
lzfindex = _format_index(lindex, lbase, lwidth)
rzfindex = _format_index(rindex, rbase, rwidth)
name = lhs.replace("$%s" % (lmod), lzfindex)
rdata = rhs.replace("$%s" % (rmod), rzfindex)
name = lhs.replace(f"${lmod}", lzfindex)
rdata = rhs.replace(f"${rmod}", rzfindex)
self.last_name = dns.name.from_text(
name, self.current_origin, self.tok.idna_codec
@ -445,7 +443,7 @@ class Reader:
# helpful filename:line info.
(ty, va) = sys.exc_info()[:2]
raise dns.exception.SyntaxError(
"caught exception %s: %s" % (str(ty), str(va))
f"caught exception {str(ty)}: {str(va)}"
)
self.txn.add(name, ttl, rd)
@ -528,7 +526,7 @@ class Reader:
self.default_ttl_known,
)
)
self.current_file = open(filename, "r")
self.current_file = open(filename)
self.tok = dns.tokenizer.Tokenizer(self.current_file, filename)
self.current_origin = new_origin
elif c == "$GENERATE":

@ -7,7 +7,7 @@ cheroot==10.0.1
cherrypy==18.10.0
cloudinary==1.41.0
distro==1.9.0
dnspython==2.6.1
dnspython==2.7.0
facebook-sdk==3.1.0
future==1.0.0
ga4mp==2.0.4