mirror of
https://github.com/Tautulli/Tautulli.git
synced 2025-03-12 04:35:40 -07:00
Bump dnspython from 2.6.1 to 2.7.0 (#2440)
* Bump dnspython from 2.6.1 to 2.7.0 Bumps [dnspython](https://github.com/rthalley/dnspython) from 2.6.1 to 2.7.0. - [Release notes](https://github.com/rthalley/dnspython/releases) - [Changelog](https://github.com/rthalley/dnspython/blob/main/doc/whatsnew.rst) - [Commits](https://github.com/rthalley/dnspython/compare/v2.6.1...v2.7.0) --- updated-dependencies: - dependency-name: dnspython dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Update dnspython==2.7.0 --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci]
This commit is contained in:
parent
0836fb902c
commit
feca713b76
lib/dns
_asyncbackend.py_asyncio_backend.py_features.py_trio_backend.pyasyncquery.pydnssec.py
requirements.txtdnssecalgs
edns.pyexception.pygrange.pyipv6.pymessage.pyname.pynameserver.pyquery.pyquic
rdata.pyrdataset.pyrdatatype.pyrdtypes
resolver.pyset.pytokenizer.pytransaction.pyttl.pyversion.pywin32util.pyxfr.pyzonefile.py@ -26,6 +26,10 @@ class NullContext:
|
||||
|
||||
|
||||
class Socket: # pragma: no cover
|
||||
def __init__(self, family: int, type: int):
|
||||
self.family = family
|
||||
self.type = type
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
@ -46,9 +50,6 @@ class Socket: # pragma: no cover
|
||||
|
||||
|
||||
class DatagramSocket(Socket): # pragma: no cover
|
||||
def __init__(self, family: int):
|
||||
self.family = family
|
||||
|
||||
async def sendto(self, what, destination, timeout):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -42,7 +42,7 @@ class _DatagramProtocol:
|
||||
if exc is None:
|
||||
# EOF we triggered. Is there a better way to do this?
|
||||
try:
|
||||
raise EOFError
|
||||
raise EOFError("EOF")
|
||||
except EOFError as e:
|
||||
self.recvfrom.set_exception(e)
|
||||
else:
|
||||
@ -64,7 +64,7 @@ async def _maybe_wait_for(awaitable, timeout):
|
||||
|
||||
class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||
def __init__(self, family, transport, protocol):
|
||||
super().__init__(family)
|
||||
super().__init__(family, socket.SOCK_DGRAM)
|
||||
self.transport = transport
|
||||
self.protocol = protocol
|
||||
|
||||
@ -99,7 +99,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||
|
||||
class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||
def __init__(self, af, reader, writer):
|
||||
self.family = af
|
||||
super().__init__(af, socket.SOCK_STREAM)
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
|
||||
@ -197,7 +197,7 @@ if dns._features.have("doh"):
|
||||
family=socket.AF_UNSPEC,
|
||||
**kwargs,
|
||||
):
|
||||
if resolver is None:
|
||||
if resolver is None and bootstrap_address is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.asyncresolver
|
||||
|
||||
|
@ -32,6 +32,9 @@ def _version_check(
|
||||
package, minimum = requirement.split(">=")
|
||||
try:
|
||||
version = importlib.metadata.version(package)
|
||||
# This shouldn't happen, but it apparently can.
|
||||
if version is None:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
t_version = _tuple_from_text(version)
|
||||
@ -82,10 +85,10 @@ def force(feature: str, enabled: bool) -> None:
|
||||
|
||||
_requirements: Dict[str, List[str]] = {
|
||||
### BEGIN generated requirements
|
||||
"dnssec": ["cryptography>=41"],
|
||||
"dnssec": ["cryptography>=43"],
|
||||
"doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
|
||||
"doq": ["aioquic>=0.9.25"],
|
||||
"idna": ["idna>=3.6"],
|
||||
"doq": ["aioquic>=1.0.0"],
|
||||
"idna": ["idna>=3.7"],
|
||||
"trio": ["trio>=0.23"],
|
||||
"wmi": ["wmi>=1.5.1"],
|
||||
### END generated requirements
|
||||
|
@ -30,13 +30,16 @@ _lltuple = dns.inet.low_level_address_tuple
|
||||
|
||||
|
||||
class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||
def __init__(self, socket):
|
||||
super().__init__(socket.family)
|
||||
self.socket = socket
|
||||
def __init__(self, sock):
|
||||
super().__init__(sock.family, socket.SOCK_DGRAM)
|
||||
self.socket = sock
|
||||
|
||||
async def sendto(self, what, destination, timeout):
|
||||
with _maybe_timeout(timeout):
|
||||
return await self.socket.sendto(what, destination)
|
||||
if destination is None:
|
||||
return await self.socket.send(what)
|
||||
else:
|
||||
return await self.socket.sendto(what, destination)
|
||||
raise dns.exception.Timeout(
|
||||
timeout=timeout
|
||||
) # pragma: no cover lgtm[py/unreachable-statement]
|
||||
@ -61,7 +64,7 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
|
||||
|
||||
class StreamSocket(dns._asyncbackend.StreamSocket):
|
||||
def __init__(self, family, stream, tls=False):
|
||||
self.family = family
|
||||
super().__init__(family, socket.SOCK_STREAM)
|
||||
self.stream = stream
|
||||
self.tls = tls
|
||||
|
||||
@ -171,7 +174,7 @@ if dns._features.have("doh"):
|
||||
family=socket.AF_UNSPEC,
|
||||
**kwargs,
|
||||
):
|
||||
if resolver is None:
|
||||
if resolver is None and bootstrap_address is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.asyncresolver
|
||||
|
||||
@ -205,7 +208,7 @@ class Backend(dns._asyncbackend.Backend):
|
||||
try:
|
||||
if source:
|
||||
await s.bind(_lltuple(source, af))
|
||||
if socktype == socket.SOCK_STREAM:
|
||||
if socktype == socket.SOCK_STREAM or destination is not None:
|
||||
connected = False
|
||||
with _maybe_timeout(timeout):
|
||||
await s.connect(_lltuple(destination, af))
|
||||
|
@ -19,10 +19,12 @@
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import random
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
import dns.asyncbackend
|
||||
import dns.exception
|
||||
@ -37,9 +39,11 @@ import dns.transaction
|
||||
from dns._asyncbackend import NullContext
|
||||
from dns.query import (
|
||||
BadResponse,
|
||||
HTTPVersion,
|
||||
NoDOH,
|
||||
NoDOQ,
|
||||
UDPMode,
|
||||
_check_status,
|
||||
_compute_times,
|
||||
_make_dot_ssl_context,
|
||||
_matches_destination,
|
||||
@ -338,7 +342,7 @@ async def _read_exactly(sock, count, expiration):
|
||||
while count > 0:
|
||||
n = await sock.recv(count, _timeout(expiration))
|
||||
if n == b"":
|
||||
raise EOFError
|
||||
raise EOFError("EOF")
|
||||
count = count - len(n)
|
||||
s = s + n
|
||||
return s
|
||||
@ -500,6 +504,20 @@ async def tls(
|
||||
return response
|
||||
|
||||
|
||||
def _maybe_get_resolver(
|
||||
resolver: Optional["dns.asyncresolver.Resolver"],
|
||||
) -> "dns.asyncresolver.Resolver":
|
||||
# We need a separate method for this to avoid overriding the global
|
||||
# variable "dns" with the as-yet undefined local variable "dns"
|
||||
# in https().
|
||||
if resolver is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.asyncresolver
|
||||
|
||||
resolver = dns.asyncresolver.Resolver()
|
||||
return resolver
|
||||
|
||||
|
||||
async def https(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
@ -515,7 +533,8 @@ async def https(
|
||||
verify: Union[bool, str] = True,
|
||||
bootstrap_address: Optional[str] = None,
|
||||
resolver: Optional["dns.asyncresolver.Resolver"] = None,
|
||||
family: Optional[int] = socket.AF_UNSPEC,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
http_version: HTTPVersion = HTTPVersion.DEFAULT,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
||||
|
||||
@ -529,26 +548,65 @@ async def https(
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
if not have_doh:
|
||||
raise NoDOH # pragma: no cover
|
||||
if client and not isinstance(client, httpx.AsyncClient):
|
||||
raise ValueError("session parameter must be an httpx.AsyncClient")
|
||||
|
||||
wire = q.to_wire()
|
||||
try:
|
||||
af = dns.inet.af_for_address(where)
|
||||
except ValueError:
|
||||
af = None
|
||||
transport = None
|
||||
headers = {"accept": "application/dns-message"}
|
||||
if af is not None and dns.inet.is_address(where):
|
||||
if af == socket.AF_INET:
|
||||
url = "https://{}:{}{}".format(where, port, path)
|
||||
url = f"https://{where}:{port}{path}"
|
||||
elif af == socket.AF_INET6:
|
||||
url = "https://[{}]:{}{}".format(where, port, path)
|
||||
url = f"https://[{where}]:{port}{path}"
|
||||
else:
|
||||
url = where
|
||||
|
||||
extensions = {}
|
||||
if bootstrap_address is None:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
if parsed.hostname is None:
|
||||
raise ValueError("no hostname in URL")
|
||||
if dns.inet.is_address(parsed.hostname):
|
||||
bootstrap_address = parsed.hostname
|
||||
extensions["sni_hostname"] = parsed.hostname
|
||||
if parsed.port is not None:
|
||||
port = parsed.port
|
||||
|
||||
if http_version == HTTPVersion.H3 or (
|
||||
http_version == HTTPVersion.DEFAULT and not have_doh
|
||||
):
|
||||
if bootstrap_address is None:
|
||||
resolver = _maybe_get_resolver(resolver)
|
||||
assert parsed.hostname is not None # for mypy
|
||||
answers = await resolver.resolve_name(parsed.hostname, family)
|
||||
bootstrap_address = random.choice(list(answers.addresses()))
|
||||
return await _http3(
|
||||
q,
|
||||
bootstrap_address,
|
||||
url,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
verify=verify,
|
||||
post=post,
|
||||
)
|
||||
|
||||
if not have_doh:
|
||||
raise NoDOH # pragma: no cover
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
if client and not isinstance(client, httpx.AsyncClient):
|
||||
raise ValueError("session parameter must be an httpx.AsyncClient")
|
||||
# pylint: enable=possibly-used-before-assignment
|
||||
|
||||
wire = q.to_wire()
|
||||
headers = {"accept": "application/dns-message"}
|
||||
|
||||
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
|
||||
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
|
||||
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
|
||||
if source is None:
|
||||
@ -557,24 +615,23 @@ async def https(
|
||||
else:
|
||||
local_address = source
|
||||
local_port = source_port
|
||||
transport = backend.get_transport_class()(
|
||||
local_address=local_address,
|
||||
http1=True,
|
||||
http2=True,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
resolver=resolver,
|
||||
family=family,
|
||||
)
|
||||
|
||||
if client:
|
||||
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
|
||||
else:
|
||||
cm = httpx.AsyncClient(
|
||||
http1=True, http2=True, verify=verify, transport=transport
|
||||
transport = backend.get_transport_class()(
|
||||
local_address=local_address,
|
||||
http1=h1,
|
||||
http2=h2,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
resolver=resolver,
|
||||
family=family,
|
||||
)
|
||||
|
||||
cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport)
|
||||
|
||||
async with cm as the_client:
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
|
||||
# GET and POST examples
|
||||
@ -586,23 +643,33 @@ async def https(
|
||||
}
|
||||
)
|
||||
response = await backend.wait_for(
|
||||
the_client.post(url, headers=headers, content=wire), timeout
|
||||
the_client.post(
|
||||
url,
|
||||
headers=headers,
|
||||
content=wire,
|
||||
extensions=extensions,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
else:
|
||||
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
||||
twire = wire.decode() # httpx does a repr() if we give it bytes
|
||||
response = await backend.wait_for(
|
||||
the_client.get(url, headers=headers, params={"dns": twire}), timeout
|
||||
the_client.get(
|
||||
url,
|
||||
headers=headers,
|
||||
params={"dns": twire},
|
||||
extensions=extensions,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
||||
# status codes
|
||||
if response.status_code < 200 or response.status_code > 299:
|
||||
raise ValueError(
|
||||
"{} responded with status code {}"
|
||||
"\nResponse body: {!r}".format(
|
||||
where, response.status_code, response.content
|
||||
)
|
||||
f"{where} responded with status code {response.status_code}"
|
||||
f"\nResponse body: {response.content!r}"
|
||||
)
|
||||
r = dns.message.from_wire(
|
||||
response.content,
|
||||
@ -617,6 +684,181 @@ async def https(
|
||||
return r
|
||||
|
||||
|
||||
async def _http3(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
url: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 853,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
verify: Union[bool, str] = True,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
hostname: Optional[str] = None,
|
||||
post: bool = True,
|
||||
) -> dns.message.Message:
|
||||
if not dns.quic.have_quic:
|
||||
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
|
||||
|
||||
url_parts = urllib.parse.urlparse(url)
|
||||
hostname = url_parts.hostname
|
||||
if url_parts.port is not None:
|
||||
port = url_parts.port
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
|
||||
|
||||
async with cfactory() as context:
|
||||
async with mfactory(
|
||||
context, verify_mode=verify, server_name=hostname, h3=True
|
||||
) as the_manager:
|
||||
the_connection = the_manager.connect(where, port, source, source_port)
|
||||
(start, expiration) = _compute_times(timeout)
|
||||
stream = await the_connection.make_stream(timeout)
|
||||
async with stream:
|
||||
# note that send_h3() does not need await
|
||||
stream.send_h3(url, wire, post)
|
||||
wire = await stream.receive(_remaining(expiration))
|
||||
_check_status(stream.headers(), where, wire)
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = max(finish - start, 0.0)
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
||||
async def quic(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 853,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
connection: Optional[dns.quic.AsyncQuicConnection] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
hostname: Optional[str] = None,
|
||||
server_hostname: Optional[str] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending an asynchronous query via
|
||||
DNS-over-QUIC.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.query.quic()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
if not dns.quic.have_quic:
|
||||
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
|
||||
|
||||
if server_hostname is not None and hostname is None:
|
||||
hostname = server_hostname
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
the_connection: dns.quic.AsyncQuicConnection
|
||||
if connection:
|
||||
cfactory = dns.quic.null_factory
|
||||
mfactory = dns.quic.null_factory
|
||||
the_connection = connection
|
||||
else:
|
||||
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
|
||||
|
||||
async with cfactory() as context:
|
||||
async with mfactory(
|
||||
context,
|
||||
verify_mode=verify,
|
||||
server_name=server_hostname,
|
||||
) as the_manager:
|
||||
if not connection:
|
||||
the_connection = the_manager.connect(where, port, source, source_port)
|
||||
(start, expiration) = _compute_times(timeout)
|
||||
stream = await the_connection.make_stream(timeout)
|
||||
async with stream:
|
||||
await stream.send(wire, True)
|
||||
wire = await stream.receive(_remaining(expiration))
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = max(finish - start, 0.0)
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
||||
async def _inbound_xfr(
|
||||
txn_manager: dns.transaction.TransactionManager,
|
||||
s: dns.asyncbackend.Socket,
|
||||
query: dns.message.Message,
|
||||
serial: Optional[int],
|
||||
timeout: Optional[float],
|
||||
expiration: float,
|
||||
) -> Any:
|
||||
"""Given a socket, does the zone transfer."""
|
||||
rdtype = query.question[0].rdtype
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
origin = txn_manager.from_wire_origin()
|
||||
wire = query.to_wire()
|
||||
is_udp = s.type == socket.SOCK_DGRAM
|
||||
if is_udp:
|
||||
udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
|
||||
await udp_sock.sendto(wire, None, _timeout(expiration))
|
||||
else:
|
||||
tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
|
||||
tcpmsg = struct.pack("!H", len(wire)) + wire
|
||||
await tcp_sock.sendall(tcpmsg, expiration)
|
||||
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
|
||||
done = False
|
||||
tsig_ctx = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if is_udp:
|
||||
timeout = _timeout(mexpiration)
|
||||
(rwire, _) = await udp_sock.recvfrom(65535, timeout)
|
||||
else:
|
||||
ldata = await _read_exactly(tcp_sock, 2, mexpiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
rwire = await _read_exactly(tcp_sock, l, mexpiration)
|
||||
r = dns.message.from_wire(
|
||||
rwire,
|
||||
keyring=query.keyring,
|
||||
request_mac=query.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=(not is_udp),
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
done = inbound.process_message(r)
|
||||
yield r
|
||||
tsig_ctx = r.tsig_ctx
|
||||
if query.keyring and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
|
||||
|
||||
async def inbound_xfr(
|
||||
where: str,
|
||||
txn_manager: dns.transaction.TransactionManager,
|
||||
@ -642,139 +884,30 @@ async def inbound_xfr(
|
||||
(query, serial) = dns.xfr.make_query(txn_manager)
|
||||
else:
|
||||
serial = dns.xfr.extract_serial_from_query(query)
|
||||
rdtype = query.question[0].rdtype
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
origin = txn_manager.from_wire_origin()
|
||||
wire = query.to_wire()
|
||||
af = dns.inet.af_for_address(where)
|
||||
stuple = _source_tuple(af, source, source_port)
|
||||
dtuple = (where, port)
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
(_, expiration) = _compute_times(lifetime)
|
||||
retry = True
|
||||
while retry:
|
||||
retry = False
|
||||
if is_ixfr and udp_mode != UDPMode.NEVER:
|
||||
sock_type = socket.SOCK_DGRAM
|
||||
is_udp = True
|
||||
else:
|
||||
sock_type = socket.SOCK_STREAM
|
||||
is_udp = False
|
||||
if not backend:
|
||||
backend = dns.asyncbackend.get_default_backend()
|
||||
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
|
||||
s = await backend.make_socket(
|
||||
af, sock_type, 0, stuple, dtuple, _timeout(expiration)
|
||||
af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
|
||||
)
|
||||
async with s:
|
||||
if is_udp:
|
||||
await s.sendto(wire, dtuple, _timeout(expiration))
|
||||
else:
|
||||
tcpmsg = struct.pack("!H", len(wire)) + wire
|
||||
await s.sendall(tcpmsg, expiration)
|
||||
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
|
||||
done = False
|
||||
tsig_ctx = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if is_udp:
|
||||
destination = _lltuple((where, port), af)
|
||||
while True:
|
||||
timeout = _timeout(mexpiration)
|
||||
(rwire, from_address) = await s.recvfrom(65535, timeout)
|
||||
if _matches_destination(
|
||||
af, from_address, destination, True
|
||||
):
|
||||
break
|
||||
else:
|
||||
ldata = await _read_exactly(s, 2, mexpiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
rwire = await _read_exactly(s, l, mexpiration)
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
r = dns.message.from_wire(
|
||||
rwire,
|
||||
keyring=query.keyring,
|
||||
request_mac=query.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=(not is_udp),
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
try:
|
||||
done = inbound.process_message(r)
|
||||
except dns.xfr.UseTCP:
|
||||
assert is_udp # should not happen if we used TCP!
|
||||
if udp_mode == UDPMode.ONLY:
|
||||
raise
|
||||
done = True
|
||||
retry = True
|
||||
udp_mode = UDPMode.NEVER
|
||||
continue
|
||||
tsig_ctx = r.tsig_ctx
|
||||
if not retry and query.keyring and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
try:
|
||||
async for _ in _inbound_xfr(
|
||||
txn_manager, s, query, serial, timeout, expiration
|
||||
):
|
||||
pass
|
||||
return
|
||||
except dns.xfr.UseTCP:
|
||||
if udp_mode == UDPMode.ONLY:
|
||||
raise
|
||||
|
||||
|
||||
async def quic(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 853,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
connection: Optional[dns.quic.AsyncQuicConnection] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
backend: Optional[dns.asyncbackend.Backend] = None,
|
||||
server_hostname: Optional[str] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending an asynchronous query via
|
||||
DNS-over-QUIC.
|
||||
|
||||
*backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
|
||||
the default, then dnspython will use the default backend.
|
||||
|
||||
See :py:func:`dns.query.quic()` for the documentation of the other
|
||||
parameters, exceptions, and return type of this method.
|
||||
"""
|
||||
|
||||
if not dns.quic.have_quic:
|
||||
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
the_connection: dns.quic.AsyncQuicConnection
|
||||
if connection:
|
||||
cfactory = dns.quic.null_factory
|
||||
mfactory = dns.quic.null_factory
|
||||
the_connection = connection
|
||||
else:
|
||||
(cfactory, mfactory) = dns.quic.factories_for_backend(backend)
|
||||
|
||||
async with cfactory() as context:
|
||||
async with mfactory(
|
||||
context, verify_mode=verify, server_name=server_hostname
|
||||
) as the_manager:
|
||||
if not connection:
|
||||
the_connection = the_manager.connect(where, port, source, source_port)
|
||||
(start, expiration) = _compute_times(timeout)
|
||||
stream = await the_connection.make_stream(timeout)
|
||||
async with stream:
|
||||
await stream.send(wire, True)
|
||||
wire = await stream.receive(_remaining(expiration))
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = max(finish - start, 0.0)
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
s = await backend.make_socket(
|
||||
af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
|
||||
)
|
||||
async with s:
|
||||
async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
|
||||
pass
|
||||
|
@ -118,6 +118,7 @@ def key_id(key: Union[DNSKEY, CDNSKEY]) -> int:
|
||||
"""
|
||||
|
||||
rdata = key.to_wire()
|
||||
assert rdata is not None # for mypy
|
||||
if key.algorithm == Algorithm.RSAMD5:
|
||||
return (rdata[-3] << 8) + rdata[-2]
|
||||
else:
|
||||
@ -224,7 +225,7 @@ def make_ds(
|
||||
if isinstance(algorithm, str):
|
||||
algorithm = DSDigest[algorithm.upper()]
|
||||
except Exception:
|
||||
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
|
||||
raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
|
||||
if validating:
|
||||
check = policy.ok_to_validate_ds
|
||||
else:
|
||||
@ -240,14 +241,15 @@ def make_ds(
|
||||
elif algorithm == DSDigest.SHA384:
|
||||
dshash = hashlib.sha384()
|
||||
else:
|
||||
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
|
||||
raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
|
||||
|
||||
if isinstance(name, str):
|
||||
name = dns.name.from_text(name, origin)
|
||||
wire = name.canonicalize().to_wire()
|
||||
assert wire is not None
|
||||
kwire = key.to_wire(origin=origin)
|
||||
assert wire is not None and kwire is not None # for mypy
|
||||
dshash.update(wire)
|
||||
dshash.update(key.to_wire(origin=origin))
|
||||
dshash.update(kwire)
|
||||
digest = dshash.digest()
|
||||
|
||||
dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + digest
|
||||
@ -323,6 +325,7 @@ def _get_rrname_rdataset(
|
||||
|
||||
|
||||
def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
public_cls = get_algorithm_cls_from_dnskey(key).public_cls
|
||||
try:
|
||||
public_key = public_cls.from_dnskey(key)
|
||||
@ -387,6 +390,7 @@ def _validate_rrsig(
|
||||
|
||||
data = _make_rrsig_signature_data(rrset, rrsig, origin)
|
||||
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
for candidate_key in candidate_keys:
|
||||
if not policy.ok_to_validate(candidate_key):
|
||||
continue
|
||||
@ -484,6 +488,7 @@ def _sign(
|
||||
verify: bool = False,
|
||||
policy: Optional[Policy] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
deterministic: bool = True,
|
||||
) -> RRSIG:
|
||||
"""Sign RRset using private key.
|
||||
|
||||
@ -523,6 +528,10 @@ def _sign(
|
||||
names in the rrset (including its owner name) must be absolute; otherwise the
|
||||
specified origin will be used to make names absolute when signing.
|
||||
|
||||
*deterministic*, a ``bool``. If ``True``, the default, use deterministic
|
||||
(reproducible) signatures when supported by the algorithm used for signing.
|
||||
Currently, this only affects ECDSA.
|
||||
|
||||
Raises ``DeniedByPolicy`` if the signature is denied by policy.
|
||||
"""
|
||||
|
||||
@ -580,6 +589,7 @@ def _sign(
|
||||
|
||||
data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
|
||||
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
if isinstance(private_key, GenericPrivateKey):
|
||||
signing_key = private_key
|
||||
else:
|
||||
@ -589,7 +599,7 @@ def _sign(
|
||||
except UnsupportedAlgorithm:
|
||||
raise TypeError("Unsupported key algorithm")
|
||||
|
||||
signature = signing_key.sign(data, verify)
|
||||
signature = signing_key.sign(data, verify, deterministic)
|
||||
|
||||
return cast(RRSIG, rrsig_template.replace(signature=signature))
|
||||
|
||||
@ -629,7 +639,9 @@ def _make_rrsig_signature_data(
|
||||
rrname, rdataset = _get_rrname_rdataset(rrset)
|
||||
|
||||
data = b""
|
||||
data += rrsig.to_wire(origin=signer)[:18]
|
||||
wire = rrsig.to_wire(origin=signer)
|
||||
assert wire is not None # for mypy
|
||||
data += wire[:18]
|
||||
data += rrsig.signer.to_digestable(signer)
|
||||
|
||||
# Derelativize the name before considering labels.
|
||||
@ -686,6 +698,7 @@ def _make_dnskey(
|
||||
|
||||
algorithm = Algorithm.make(algorithm)
|
||||
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
if isinstance(public_key, GenericPublicKey):
|
||||
return public_key.to_dnskey(flags=flags, protocol=protocol)
|
||||
else:
|
||||
@ -832,7 +845,7 @@ def make_ds_rdataset(
|
||||
if isinstance(algorithm, str):
|
||||
algorithm = DSDigest[algorithm.upper()]
|
||||
except Exception:
|
||||
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
|
||||
raise UnsupportedAlgorithm(f'unsupported algorithm "{algorithm}"')
|
||||
_algorithms.add(algorithm)
|
||||
|
||||
if rdataset.rdtype == dns.rdatatype.CDS:
|
||||
@ -950,6 +963,7 @@ def default_rrset_signer(
|
||||
lifetime: Optional[int] = None,
|
||||
policy: Optional[Policy] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
deterministic: bool = True,
|
||||
) -> None:
|
||||
"""Default RRset signer"""
|
||||
|
||||
@ -975,6 +989,7 @@ def default_rrset_signer(
|
||||
signer=signer,
|
||||
policy=policy,
|
||||
origin=origin,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
txn.add(rrset.name, rrset.ttl, rrsig)
|
||||
|
||||
@ -991,6 +1006,7 @@ def sign_zone(
|
||||
nsec3: Optional[NSEC3PARAM] = None,
|
||||
rrset_signer: Optional[RRsetSigner] = None,
|
||||
policy: Optional[Policy] = None,
|
||||
deterministic: bool = True,
|
||||
) -> None:
|
||||
"""Sign zone.
|
||||
|
||||
@ -1030,6 +1046,10 @@ def sign_zone(
|
||||
function requires two arguments: transaction and RRset. If the not specified,
|
||||
``dns.dnssec.default_rrset_signer`` will be used.
|
||||
|
||||
*deterministic*, a ``bool``. If ``True``, the default, use deterministic
|
||||
(reproducible) signatures when supported by the algorithm used for signing.
|
||||
Currently, this only affects ECDSA.
|
||||
|
||||
Returns ``None``.
|
||||
"""
|
||||
|
||||
@ -1056,6 +1076,9 @@ def sign_zone(
|
||||
else:
|
||||
cm = zone.writer()
|
||||
|
||||
if zone.origin is None:
|
||||
raise ValueError("no zone origin")
|
||||
|
||||
with cm as _txn:
|
||||
if add_dnskey:
|
||||
if dnskey_ttl is None:
|
||||
@ -1081,6 +1104,7 @@ def sign_zone(
|
||||
lifetime=lifetime,
|
||||
policy=policy,
|
||||
origin=zone.origin,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
return _sign_zone_nsec(zone, _txn, _rrset_signer)
|
||||
|
||||
|
@ -26,6 +26,7 @@ AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
|
||||
|
||||
algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {}
|
||||
if _have_cryptography:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
algorithms.update(
|
||||
{
|
||||
(Algorithm.RSAMD5, None): PrivateRSAMD5,
|
||||
@ -59,7 +60,7 @@ def get_algorithm_cls(
|
||||
if cls:
|
||||
return cls
|
||||
raise UnsupportedAlgorithm(
|
||||
'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
|
||||
f'algorithm "{Algorithm.to_text(algorithm)}" not supported by dnspython'
|
||||
)
|
||||
|
||||
|
||||
|
@ -65,7 +65,12 @@ class GenericPrivateKey(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign DNSSEC data"""
|
||||
|
||||
@abstractmethod
|
||||
|
@ -68,7 +68,12 @@ class PrivateDSA(CryptographyPrivateKey):
|
||||
key_cls = dsa.DSAPrivateKey
|
||||
public_cls = PublicDSA
|
||||
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign using a private key per RFC 2536, section 3."""
|
||||
public_dsa_key = self.key.public_key()
|
||||
if public_dsa_key.key_size > 1024:
|
||||
|
@ -47,9 +47,17 @@ class PrivateECDSA(CryptographyPrivateKey):
|
||||
key_cls = ec.EllipticCurvePrivateKey
|
||||
public_cls = PublicECDSA
|
||||
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign using a private key per RFC 6605, section 4."""
|
||||
der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash))
|
||||
algorithm = ec.ECDSA(
|
||||
self.public_cls.chosen_hash, deterministic_signing=deterministic
|
||||
)
|
||||
der_signature = self.key.sign(data, algorithm)
|
||||
dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
|
||||
signature = int.to_bytes(
|
||||
dsa_r, length=self.public_cls.octets, byteorder="big"
|
||||
|
@ -29,7 +29,12 @@ class PublicEDDSA(CryptographyPublicKey):
|
||||
class PrivateEDDSA(CryptographyPrivateKey):
|
||||
public_cls: Type[PublicEDDSA]
|
||||
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign using a private key per RFC 8080, section 4."""
|
||||
signature = self.key.sign(data)
|
||||
if verify:
|
||||
|
@ -56,7 +56,12 @@ class PrivateRSA(CryptographyPrivateKey):
|
||||
public_cls = PublicRSA
|
||||
default_public_exponent = 65537
|
||||
|
||||
def sign(self, data: bytes, verify: bool = False) -> bytes:
|
||||
def sign(
|
||||
self,
|
||||
data: bytes,
|
||||
verify: bool = False,
|
||||
deterministic: bool = True,
|
||||
) -> bytes:
|
||||
"""Sign using a private key per RFC 3110, section 3."""
|
||||
signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
|
||||
if verify:
|
||||
|
@ -52,6 +52,8 @@ class OptionType(dns.enum.IntEnum):
|
||||
CHAIN = 13
|
||||
#: EDE (extended-dns-error)
|
||||
EDE = 15
|
||||
#: REPORTCHANNEL
|
||||
REPORTCHANNEL = 18
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
@ -222,7 +224,7 @@ class ECSOption(Option): # lgtm[py/missing-equals]
|
||||
self.addrdata = self.addrdata[:-1] + last
|
||||
|
||||
def to_text(self) -> str:
|
||||
return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen)
|
||||
return f"ECS {self.address}/{self.srclen} scope/{self.scopelen}"
|
||||
|
||||
@staticmethod
|
||||
def from_text(text: str) -> Option:
|
||||
@ -255,10 +257,10 @@ class ECSOption(Option): # lgtm[py/missing-equals]
|
||||
ecs_text = tokens[0]
|
||||
elif len(tokens) == 2:
|
||||
if tokens[0] != optional_prefix:
|
||||
raise ValueError('could not parse ECS from "{}"'.format(text))
|
||||
raise ValueError(f'could not parse ECS from "{text}"')
|
||||
ecs_text = tokens[1]
|
||||
else:
|
||||
raise ValueError('could not parse ECS from "{}"'.format(text))
|
||||
raise ValueError(f'could not parse ECS from "{text}"')
|
||||
n_slashes = ecs_text.count("/")
|
||||
if n_slashes == 1:
|
||||
address, tsrclen = ecs_text.split("/")
|
||||
@ -266,18 +268,16 @@ class ECSOption(Option): # lgtm[py/missing-equals]
|
||||
elif n_slashes == 2:
|
||||
address, tsrclen, tscope = ecs_text.split("/")
|
||||
else:
|
||||
raise ValueError('could not parse ECS from "{}"'.format(text))
|
||||
raise ValueError(f'could not parse ECS from "{text}"')
|
||||
try:
|
||||
scope = int(tscope)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"invalid scope " + '"{}": scope must be an integer'.format(tscope)
|
||||
)
|
||||
raise ValueError("invalid scope " + f'"{tscope}": scope must be an integer')
|
||||
try:
|
||||
srclen = int(tsrclen)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"invalid srclen " + '"{}": srclen must be an integer'.format(tsrclen)
|
||||
"invalid srclen " + f'"{tsrclen}": srclen must be an integer'
|
||||
)
|
||||
return ECSOption(address, srclen, scope)
|
||||
|
||||
@ -430,10 +430,65 @@ class NSIDOption(Option):
|
||||
return cls(parser.get_remaining())
|
||||
|
||||
|
||||
class CookieOption(Option):
|
||||
def __init__(self, client: bytes, server: bytes):
|
||||
super().__init__(dns.edns.OptionType.COOKIE)
|
||||
self.client = client
|
||||
self.server = server
|
||||
if len(client) != 8:
|
||||
raise ValueError("client cookie must be 8 bytes")
|
||||
if len(server) != 0 and (len(server) < 8 or len(server) > 32):
|
||||
raise ValueError("server cookie must be empty or between 8 and 32 bytes")
|
||||
|
||||
def to_wire(self, file: Any = None) -> Optional[bytes]:
|
||||
if file:
|
||||
file.write(self.client)
|
||||
if len(self.server) > 0:
|
||||
file.write(self.server)
|
||||
return None
|
||||
else:
|
||||
return self.client + self.server
|
||||
|
||||
def to_text(self) -> str:
|
||||
client = binascii.hexlify(self.client).decode()
|
||||
if len(self.server) > 0:
|
||||
server = binascii.hexlify(self.server).decode()
|
||||
else:
|
||||
server = ""
|
||||
return f"COOKIE {client}{server}"
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(
|
||||
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
|
||||
) -> Option:
|
||||
return cls(parser.get_bytes(8), parser.get_remaining())
|
||||
|
||||
|
||||
class ReportChannelOption(Option):
|
||||
# RFC 9567
|
||||
def __init__(self, agent_domain: dns.name.Name):
|
||||
super().__init__(OptionType.REPORTCHANNEL)
|
||||
self.agent_domain = agent_domain
|
||||
|
||||
def to_wire(self, file: Any = None) -> Optional[bytes]:
|
||||
return self.agent_domain.to_wire(file)
|
||||
|
||||
def to_text(self) -> str:
|
||||
return "REPORTCHANNEL " + self.agent_domain.to_text()
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(
|
||||
cls, otype: Union[OptionType, str], parser: dns.wire.Parser
|
||||
) -> Option:
|
||||
return cls(parser.get_name())
|
||||
|
||||
|
||||
_type_to_class: Dict[OptionType, Any] = {
|
||||
OptionType.ECS: ECSOption,
|
||||
OptionType.EDE: EDEOption,
|
||||
OptionType.NSID: NSIDOption,
|
||||
OptionType.COOKIE: CookieOption,
|
||||
OptionType.REPORTCHANNEL: ReportChannelOption,
|
||||
}
|
||||
|
||||
|
||||
@ -512,5 +567,6 @@ KEEPALIVE = OptionType.KEEPALIVE
|
||||
PADDING = OptionType.PADDING
|
||||
CHAIN = OptionType.CHAIN
|
||||
EDE = OptionType.EDE
|
||||
REPORTCHANNEL = OptionType.REPORTCHANNEL
|
||||
|
||||
### END generated OptionType constants
|
||||
|
@ -81,7 +81,7 @@ class DNSException(Exception):
|
||||
if kwargs:
|
||||
assert (
|
||||
set(kwargs.keys()) == self.supp_kwargs
|
||||
), "following set of keyword args is required: %s" % (self.supp_kwargs)
|
||||
), f"following set of keyword args is required: {self.supp_kwargs}"
|
||||
return kwargs
|
||||
|
||||
def _fmt_kwargs(self, **kwargs):
|
||||
|
@ -54,7 +54,7 @@ def from_text(text: str) -> Tuple[int, int, int]:
|
||||
elif c.isdigit():
|
||||
cur += c
|
||||
else:
|
||||
raise dns.exception.SyntaxError("Could not parse %s" % (c))
|
||||
raise dns.exception.SyntaxError(f"Could not parse {c}")
|
||||
|
||||
if state == 0:
|
||||
raise dns.exception.SyntaxError("no stop value specified")
|
||||
|
@ -143,9 +143,7 @@ def inet_aton(text: Union[str, bytes], ignore_scope: bool = False) -> bytes:
|
||||
if m is not None:
|
||||
b = dns.ipv4.inet_aton(m.group(2))
|
||||
btext = (
|
||||
"{}:{:02x}{:02x}:{:02x}{:02x}".format(
|
||||
m.group(1).decode(), b[0], b[1], b[2], b[3]
|
||||
)
|
||||
f"{m.group(1).decode()}:{b[0]:02x}{b[1]:02x}:{b[2]:02x}{b[3]:02x}"
|
||||
).encode()
|
||||
#
|
||||
# Try to turn '::<whatever>' into ':<whatever>'; if no match try to
|
||||
|
@ -18,9 +18,10 @@
|
||||
"""DNS Messages"""
|
||||
|
||||
import contextlib
|
||||
import enum
|
||||
import io
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import dns.edns
|
||||
import dns.entropy
|
||||
@ -161,6 +162,7 @@ class Message:
|
||||
self.index: IndexType = {}
|
||||
self.errors: List[MessageError] = []
|
||||
self.time = 0.0
|
||||
self.wire: Optional[bytes] = None
|
||||
|
||||
@property
|
||||
def question(self) -> List[dns.rrset.RRset]:
|
||||
@ -220,16 +222,16 @@ class Message:
|
||||
|
||||
s = io.StringIO()
|
||||
s.write("id %d\n" % self.id)
|
||||
s.write("opcode %s\n" % dns.opcode.to_text(self.opcode()))
|
||||
s.write("rcode %s\n" % dns.rcode.to_text(self.rcode()))
|
||||
s.write("flags %s\n" % dns.flags.to_text(self.flags))
|
||||
s.write(f"opcode {dns.opcode.to_text(self.opcode())}\n")
|
||||
s.write(f"rcode {dns.rcode.to_text(self.rcode())}\n")
|
||||
s.write(f"flags {dns.flags.to_text(self.flags)}\n")
|
||||
if self.edns >= 0:
|
||||
s.write("edns %s\n" % self.edns)
|
||||
s.write(f"edns {self.edns}\n")
|
||||
if self.ednsflags != 0:
|
||||
s.write("eflags %s\n" % dns.flags.edns_to_text(self.ednsflags))
|
||||
s.write(f"eflags {dns.flags.edns_to_text(self.ednsflags)}\n")
|
||||
s.write("payload %d\n" % self.payload)
|
||||
for opt in self.options:
|
||||
s.write("option %s\n" % opt.to_text())
|
||||
s.write(f"option {opt.to_text()}\n")
|
||||
for name, which in self._section_enum.__members__.items():
|
||||
s.write(f";{name}\n")
|
||||
for rrset in self.section_from_number(which):
|
||||
@ -645,6 +647,7 @@ class Message:
|
||||
if multi:
|
||||
self.tsig_ctx = ctx
|
||||
wire = r.get_wire()
|
||||
self.wire = wire
|
||||
if prepend_length:
|
||||
wire = len(wire).to_bytes(2, "big") + wire
|
||||
return wire
|
||||
@ -912,6 +915,14 @@ class Message:
|
||||
self.flags &= 0x87FF
|
||||
self.flags |= dns.opcode.to_flags(opcode)
|
||||
|
||||
def get_options(self, otype: dns.edns.OptionType) -> List[dns.edns.Option]:
|
||||
"""Return the list of options of the specified type."""
|
||||
return [option for option in self.options if option.otype == otype]
|
||||
|
||||
def extended_errors(self) -> List[dns.edns.EDEOption]:
|
||||
"""Return the list of Extended DNS Error (EDE) options in the message"""
|
||||
return cast(List[dns.edns.EDEOption], self.get_options(dns.edns.OptionType.EDE))
|
||||
|
||||
def _get_one_rr_per_rrset(self, value):
|
||||
# What the caller picked is fine.
|
||||
return value
|
||||
@ -1192,9 +1203,9 @@ class _WireReader:
|
||||
if rdtype == dns.rdatatype.OPT:
|
||||
self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
|
||||
elif rdtype == dns.rdatatype.TSIG:
|
||||
if self.keyring is None:
|
||||
if self.keyring is None or self.keyring is True:
|
||||
raise UnknownTSIGKey("got signed message without keyring")
|
||||
if isinstance(self.keyring, dict):
|
||||
elif isinstance(self.keyring, dict):
|
||||
key = self.keyring.get(absolute_name)
|
||||
if isinstance(key, bytes):
|
||||
key = dns.tsig.Key(absolute_name, key, rd.algorithm)
|
||||
@ -1203,19 +1214,20 @@ class _WireReader:
|
||||
else:
|
||||
key = self.keyring
|
||||
if key is None:
|
||||
raise UnknownTSIGKey("key '%s' unknown" % name)
|
||||
self.message.keyring = key
|
||||
self.message.tsig_ctx = dns.tsig.validate(
|
||||
self.parser.wire,
|
||||
key,
|
||||
absolute_name,
|
||||
rd,
|
||||
int(time.time()),
|
||||
self.message.request_mac,
|
||||
rr_start,
|
||||
self.message.tsig_ctx,
|
||||
self.multi,
|
||||
)
|
||||
raise UnknownTSIGKey(f"key '{name}' unknown")
|
||||
if key:
|
||||
self.message.keyring = key
|
||||
self.message.tsig_ctx = dns.tsig.validate(
|
||||
self.parser.wire,
|
||||
key,
|
||||
absolute_name,
|
||||
rd,
|
||||
int(time.time()),
|
||||
self.message.request_mac,
|
||||
rr_start,
|
||||
self.message.tsig_ctx,
|
||||
self.multi,
|
||||
)
|
||||
self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd)
|
||||
else:
|
||||
rrset = self.message.find_rrset(
|
||||
@ -1251,6 +1263,7 @@ class _WireReader:
|
||||
factory = _message_factory_from_opcode(dns.opcode.from_flags(flags))
|
||||
self.message = factory(id=id)
|
||||
self.message.flags = dns.flags.Flag(flags)
|
||||
self.message.wire = self.parser.wire
|
||||
self.initialize_message(self.message)
|
||||
self.one_rr_per_rrset = self.message._get_one_rr_per_rrset(
|
||||
self.one_rr_per_rrset
|
||||
@ -1290,8 +1303,10 @@ def from_wire(
|
||||
) -> Message:
|
||||
"""Convert a DNS wire format message into a message object.
|
||||
|
||||
*keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message
|
||||
is signed.
|
||||
*keyring*, a ``dns.tsig.Key``, ``dict``, ``bool``, or ``None``, the key or keyring
|
||||
to use if the message is signed. If ``None`` or ``True``, then trying to decode
|
||||
a message with a TSIG will fail as it cannot be validated. If ``False``, then
|
||||
TSIG validation is disabled.
|
||||
|
||||
*request_mac*, a ``bytes`` or ``None``. If the message is a response to a
|
||||
TSIG-signed request, *request_mac* should be set to the MAC of that request.
|
||||
@ -1811,6 +1826,16 @@ def make_query(
|
||||
return m
|
||||
|
||||
|
||||
class CopyMode(enum.Enum):
|
||||
"""
|
||||
How should sections be copied when making an update response?
|
||||
"""
|
||||
|
||||
NOTHING = 0
|
||||
QUESTION = 1
|
||||
EVERYTHING = 2
|
||||
|
||||
|
||||
def make_response(
|
||||
query: Message,
|
||||
recursion_available: bool = False,
|
||||
@ -1818,13 +1843,14 @@ def make_response(
|
||||
fudge: int = 300,
|
||||
tsig_error: int = 0,
|
||||
pad: Optional[int] = None,
|
||||
copy_mode: Optional[CopyMode] = None,
|
||||
) -> Message:
|
||||
"""Make a message which is a response for the specified query.
|
||||
The message returned is really a response skeleton; it has all of the infrastructure
|
||||
required of a response, but none of the content.
|
||||
|
||||
The response's question section is a shallow copy of the query's question section,
|
||||
so the query's question RRsets should not be changed.
|
||||
Response section(s) which are copied are shallow copies of the matching section(s)
|
||||
in the query, so the query's RRsets should not be changed.
|
||||
|
||||
*query*, a ``dns.message.Message``, the query to respond to.
|
||||
|
||||
@ -1837,25 +1863,44 @@ def make_response(
|
||||
*tsig_error*, an ``int``, the TSIG error.
|
||||
|
||||
*pad*, a non-negative ``int`` or ``None``. If 0, the default, do not pad; otherwise
|
||||
if not ``None`` add padding bytes to make the message size a multiple of *pad*.
|
||||
Note that if padding is non-zero, an EDNS PADDING option will always be added to the
|
||||
if not ``None`` add padding bytes to make the message size a multiple of *pad*. Note
|
||||
that if padding is non-zero, an EDNS PADDING option will always be added to the
|
||||
message. If ``None``, add padding following RFC 8467, namely if the request is
|
||||
padded, pad the response to 468 otherwise do not pad.
|
||||
|
||||
*copy_mode*, a ``dns.message.CopyMode`` or ``None``, determines how sections are
|
||||
copied. The default, ``None`` copies sections according to the default for the
|
||||
message's opcode, which is currently ``dns.message.CopyMode.QUESTION`` for all
|
||||
opcodes. ``dns.message.CopyMode.QUESTION`` copies only the question section.
|
||||
``dns.message.CopyMode.EVERYTHING`` copies all sections other than OPT or TSIG
|
||||
records, which are created appropriately if needed. ``dns.message.CopyMode.NOTHING``
|
||||
copies no sections; note that this mode is for server testing purposes and is
|
||||
otherwise not recommended for use. In particular, ``dns.message.is_response()``
|
||||
will be ``False`` if you create a response this way and the rcode is not
|
||||
``FORMERR``, ``SERVFAIL``, ``NOTIMP``, or ``REFUSED``.
|
||||
|
||||
Returns a ``dns.message.Message`` object whose specific class is appropriate for the
|
||||
query. For example, if query is a ``dns.update.UpdateMessage``, response will be
|
||||
too.
|
||||
query. For example, if query is a ``dns.update.UpdateMessage``, the response will
|
||||
be one too.
|
||||
"""
|
||||
|
||||
if query.flags & dns.flags.QR:
|
||||
raise dns.exception.FormError("specified query message is not a query")
|
||||
factory = _message_factory_from_opcode(query.opcode())
|
||||
opcode = query.opcode()
|
||||
factory = _message_factory_from_opcode(opcode)
|
||||
response = factory(id=query.id)
|
||||
response.flags = dns.flags.QR | (query.flags & dns.flags.RD)
|
||||
if recursion_available:
|
||||
response.flags |= dns.flags.RA
|
||||
response.set_opcode(query.opcode())
|
||||
response.question = list(query.question)
|
||||
response.set_opcode(opcode)
|
||||
if copy_mode is None:
|
||||
copy_mode = CopyMode.QUESTION
|
||||
if copy_mode != CopyMode.NOTHING:
|
||||
response.question = list(query.question)
|
||||
if copy_mode == CopyMode.EVERYTHING:
|
||||
response.answer = list(query.answer)
|
||||
response.authority = list(query.authority)
|
||||
response.additional = list(query.additional)
|
||||
if query.edns >= 0:
|
||||
if pad is None:
|
||||
# Set response padding per RFC 8467
|
||||
|
@ -59,11 +59,11 @@ class NameRelation(dns.enum.IntEnum):
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
return cls.COMMONANCESTOR
|
||||
return cls.COMMONANCESTOR # pragma: no cover
|
||||
|
||||
@classmethod
|
||||
def _short_name(cls):
|
||||
return cls.__name__
|
||||
return cls.__name__ # pragma: no cover
|
||||
|
||||
|
||||
# Backwards compatibility
|
||||
@ -277,6 +277,7 @@ class IDNA2008Codec(IDNACodec):
|
||||
raise NoIDNA2008
|
||||
try:
|
||||
if self.uts_46:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
label = idna.uts46_remap(label, False, self.transitional)
|
||||
return idna.alabel(label)
|
||||
except idna.IDNAError as e:
|
||||
|
@ -168,12 +168,14 @@ class DoHNameserver(Nameserver):
|
||||
bootstrap_address: Optional[str] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
want_get: bool = False,
|
||||
http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
|
||||
):
|
||||
super().__init__()
|
||||
self.url = url
|
||||
self.bootstrap_address = bootstrap_address
|
||||
self.verify = verify
|
||||
self.want_get = want_get
|
||||
self.http_version = http_version
|
||||
|
||||
def kind(self):
|
||||
return "DoH"
|
||||
@ -214,6 +216,7 @@ class DoHNameserver(Nameserver):
|
||||
ignore_trailing=ignore_trailing,
|
||||
verify=self.verify,
|
||||
post=(not self.want_get),
|
||||
http_version=self.http_version,
|
||||
)
|
||||
|
||||
async def async_query(
|
||||
@ -238,6 +241,7 @@ class DoHNameserver(Nameserver):
|
||||
ignore_trailing=ignore_trailing,
|
||||
verify=self.verify,
|
||||
post=(not self.want_get),
|
||||
http_version=self.http_version,
|
||||
)
|
||||
|
||||
|
||||
|
545
lib/dns/query.py
545
lib/dns/query.py
@ -23,11 +23,13 @@ import enum
|
||||
import errno
|
||||
import os
|
||||
import os.path
|
||||
import random
|
||||
import selectors
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
import dns._features
|
||||
import dns.exception
|
||||
@ -129,7 +131,7 @@ if _have_httpx:
|
||||
family=socket.AF_UNSPEC,
|
||||
**kwargs,
|
||||
):
|
||||
if resolver is None:
|
||||
if resolver is None and bootstrap_address is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.resolver
|
||||
|
||||
@ -217,7 +219,7 @@ def _wait_for(fd, readable, writable, _, expiration):
|
||||
|
||||
if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
|
||||
return True
|
||||
sel = _selector_class()
|
||||
sel = selectors.DefaultSelector()
|
||||
events = 0
|
||||
if readable:
|
||||
events |= selectors.EVENT_READ
|
||||
@ -235,26 +237,6 @@ def _wait_for(fd, readable, writable, _, expiration):
|
||||
raise dns.exception.Timeout
|
||||
|
||||
|
||||
def _set_selector_class(selector_class):
|
||||
# Internal API. Do not use.
|
||||
|
||||
global _selector_class
|
||||
|
||||
_selector_class = selector_class
|
||||
|
||||
|
||||
if hasattr(selectors, "PollSelector"):
|
||||
# Prefer poll() on platforms that support it because it has no
|
||||
# limits on the maximum value of a file descriptor (plus it will
|
||||
# be more efficient for high values).
|
||||
#
|
||||
# We ignore typing here as we can't say _selector_class is Any
|
||||
# on python < 3.8 due to a bug.
|
||||
_selector_class = selectors.PollSelector # type: ignore
|
||||
else:
|
||||
_selector_class = selectors.SelectSelector # type: ignore
|
||||
|
||||
|
||||
def _wait_for_readable(s, expiration):
|
||||
_wait_for(s, True, False, True, expiration)
|
||||
|
||||
@ -355,6 +337,36 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
|
||||
raise
|
||||
|
||||
|
||||
def _maybe_get_resolver(
|
||||
resolver: Optional["dns.resolver.Resolver"],
|
||||
) -> "dns.resolver.Resolver":
|
||||
# We need a separate method for this to avoid overriding the global
|
||||
# variable "dns" with the as-yet undefined local variable "dns"
|
||||
# in https().
|
||||
if resolver is None:
|
||||
# pylint: disable=import-outside-toplevel,redefined-outer-name
|
||||
import dns.resolver
|
||||
|
||||
resolver = dns.resolver.Resolver()
|
||||
return resolver
|
||||
|
||||
|
||||
class HTTPVersion(enum.IntEnum):
|
||||
"""Which version of HTTP should be used?
|
||||
|
||||
DEFAULT will select the first version from the list [2, 1.1, 3] that
|
||||
is available.
|
||||
"""
|
||||
|
||||
DEFAULT = 0
|
||||
HTTP_1 = 1
|
||||
H1 = 1
|
||||
HTTP_2 = 2
|
||||
H2 = 2
|
||||
HTTP_3 = 3
|
||||
H3 = 3
|
||||
|
||||
|
||||
def https(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
@ -370,7 +382,8 @@ def https(
|
||||
bootstrap_address: Optional[str] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
resolver: Optional["dns.resolver.Resolver"] = None,
|
||||
family: Optional[int] = socket.AF_UNSPEC,
|
||||
family: int = socket.AF_UNSPEC,
|
||||
http_version: HTTPVersion = HTTPVersion.DEFAULT,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-HTTPS.
|
||||
|
||||
@ -420,27 +433,66 @@ def https(
|
||||
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
|
||||
and AAAA records will be retrieved.
|
||||
|
||||
*http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.
|
||||
|
||||
Returns a ``dns.message.Message``.
|
||||
"""
|
||||
|
||||
(af, _, the_source) = _destination_and_source(
|
||||
where, port, source, source_port, False
|
||||
)
|
||||
if af is not None and dns.inet.is_address(where):
|
||||
if af == socket.AF_INET:
|
||||
url = f"https://{where}:{port}{path}"
|
||||
elif af == socket.AF_INET6:
|
||||
url = f"https://[{where}]:{port}{path}"
|
||||
else:
|
||||
url = where
|
||||
|
||||
extensions = {}
|
||||
if bootstrap_address is None:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
if parsed.hostname is None:
|
||||
raise ValueError("no hostname in URL")
|
||||
if dns.inet.is_address(parsed.hostname):
|
||||
bootstrap_address = parsed.hostname
|
||||
extensions["sni_hostname"] = parsed.hostname
|
||||
if parsed.port is not None:
|
||||
port = parsed.port
|
||||
|
||||
if http_version == HTTPVersion.H3 or (
|
||||
http_version == HTTPVersion.DEFAULT and not have_doh
|
||||
):
|
||||
if bootstrap_address is None:
|
||||
resolver = _maybe_get_resolver(resolver)
|
||||
assert parsed.hostname is not None # for mypy
|
||||
answers = resolver.resolve_name(parsed.hostname, family)
|
||||
bootstrap_address = random.choice(list(answers.addresses()))
|
||||
return _http3(
|
||||
q,
|
||||
bootstrap_address,
|
||||
url,
|
||||
timeout,
|
||||
port,
|
||||
source,
|
||||
source_port,
|
||||
one_rr_per_rrset,
|
||||
ignore_trailing,
|
||||
verify=verify,
|
||||
post=post,
|
||||
)
|
||||
|
||||
if not have_doh:
|
||||
raise NoDOH # pragma: no cover
|
||||
if session and not isinstance(session, httpx.Client):
|
||||
raise ValueError("session parameter must be an httpx.Client")
|
||||
|
||||
wire = q.to_wire()
|
||||
(af, _, the_source) = _destination_and_source(
|
||||
where, port, source, source_port, False
|
||||
)
|
||||
transport = None
|
||||
headers = {"accept": "application/dns-message"}
|
||||
if af is not None and dns.inet.is_address(where):
|
||||
if af == socket.AF_INET:
|
||||
url = "https://{}:{}{}".format(where, port, path)
|
||||
elif af == socket.AF_INET6:
|
||||
url = "https://[{}]:{}{}".format(where, port, path)
|
||||
else:
|
||||
url = where
|
||||
|
||||
h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
|
||||
h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
|
||||
|
||||
# set source port and source address
|
||||
|
||||
@ -450,21 +502,22 @@ def https(
|
||||
else:
|
||||
local_address = the_source[0]
|
||||
local_port = the_source[1]
|
||||
transport = _HTTPTransport(
|
||||
local_address=local_address,
|
||||
http1=True,
|
||||
http2=True,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
resolver=resolver,
|
||||
family=family,
|
||||
)
|
||||
|
||||
if session:
|
||||
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
|
||||
else:
|
||||
cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
|
||||
transport = _HTTPTransport(
|
||||
local_address=local_address,
|
||||
http1=h1,
|
||||
http2=h2,
|
||||
verify=verify,
|
||||
local_port=local_port,
|
||||
bootstrap_address=bootstrap_address,
|
||||
resolver=resolver,
|
||||
family=family,
|
||||
)
|
||||
|
||||
cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
|
||||
with cm as session:
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
|
||||
# GET and POST examples
|
||||
@ -475,20 +528,30 @@ def https(
|
||||
"content-length": str(len(wire)),
|
||||
}
|
||||
)
|
||||
response = session.post(url, headers=headers, content=wire, timeout=timeout)
|
||||
response = session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
content=wire,
|
||||
timeout=timeout,
|
||||
extensions=extensions,
|
||||
)
|
||||
else:
|
||||
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
|
||||
twire = wire.decode() # httpx does a repr() if we give it bytes
|
||||
response = session.get(
|
||||
url, headers=headers, timeout=timeout, params={"dns": twire}
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
params={"dns": twire},
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
|
||||
# status codes
|
||||
if response.status_code < 200 or response.status_code > 299:
|
||||
raise ValueError(
|
||||
"{} responded with status code {}"
|
||||
"\nResponse body: {}".format(where, response.status_code, response.content)
|
||||
f"{where} responded with status code {response.status_code}"
|
||||
f"\nResponse body: {response.content}"
|
||||
)
|
||||
r = dns.message.from_wire(
|
||||
response.content,
|
||||
@ -503,6 +566,81 @@ def https(
|
||||
return r
|
||||
|
||||
|
||||
def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
|
||||
if headers is None:
|
||||
raise KeyError
|
||||
for header, value in headers:
|
||||
if header == name:
|
||||
return value
|
||||
raise KeyError
|
||||
|
||||
|
||||
def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
|
||||
value = _find_header(headers, b":status")
|
||||
if value is None:
|
||||
raise SyntaxError("no :status header in response")
|
||||
status = int(value)
|
||||
if status < 0:
|
||||
raise SyntaxError("status is negative")
|
||||
if status < 200 or status > 299:
|
||||
error = ""
|
||||
if len(wire) > 0:
|
||||
try:
|
||||
error = ": " + wire.decode()
|
||||
except Exception:
|
||||
pass
|
||||
raise ValueError(f"{peer} responded with status code {status}{error}")
|
||||
|
||||
|
||||
def _http3(
|
||||
q: dns.message.Message,
|
||||
where: str,
|
||||
url: str,
|
||||
timeout: Optional[float] = None,
|
||||
port: int = 853,
|
||||
source: Optional[str] = None,
|
||||
source_port: int = 0,
|
||||
one_rr_per_rrset: bool = False,
|
||||
ignore_trailing: bool = False,
|
||||
verify: Union[bool, str] = True,
|
||||
hostname: Optional[str] = None,
|
||||
post: bool = True,
|
||||
) -> dns.message.Message:
|
||||
if not dns.quic.have_quic:
|
||||
raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
|
||||
|
||||
url_parts = urllib.parse.urlparse(url)
|
||||
hostname = url_parts.hostname
|
||||
if url_parts.port is not None:
|
||||
port = url_parts.port
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
manager = dns.quic.SyncQuicManager(
|
||||
verify_mode=verify, server_name=hostname, h3=True
|
||||
)
|
||||
|
||||
with manager:
|
||||
connection = manager.connect(where, port, source, source_port)
|
||||
(start, expiration) = _compute_times(timeout)
|
||||
with connection.make_stream(timeout) as stream:
|
||||
stream.send_h3(url, wire, post)
|
||||
wire = stream.receive(_remaining(expiration))
|
||||
_check_status(stream.headers(), where, wire)
|
||||
finish = time.time()
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.request_mac,
|
||||
one_rr_per_rrset=one_rr_per_rrset,
|
||||
ignore_trailing=ignore_trailing,
|
||||
)
|
||||
r.time = max(finish - start, 0.0)
|
||||
if not q.is_response(r):
|
||||
raise BadResponse
|
||||
return r
|
||||
|
||||
|
||||
def _udp_recv(sock, max_size, expiration):
|
||||
"""Reads a datagram from the socket.
|
||||
A Timeout exception will be raised if the operation is not completed
|
||||
@ -855,7 +993,7 @@ def _net_read(sock, count, expiration):
|
||||
try:
|
||||
n = sock.recv(count)
|
||||
if n == b"":
|
||||
raise EOFError
|
||||
raise EOFError("EOF")
|
||||
count -= len(n)
|
||||
s += n
|
||||
except (BlockingIOError, ssl.SSLWantReadError):
|
||||
@ -1023,6 +1161,7 @@ def tcp(
|
||||
cm = _make_socket(af, socket.SOCK_STREAM, source)
|
||||
with cm as s:
|
||||
if not sock:
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
_connect(s, destination, expiration)
|
||||
send_tcp(s, wire, expiration)
|
||||
(r, received_time) = receive_tcp(
|
||||
@ -1188,6 +1327,7 @@ def quic(
|
||||
ignore_trailing: bool = False,
|
||||
connection: Optional[dns.quic.SyncQuicConnection] = None,
|
||||
verify: Union[bool, str] = True,
|
||||
hostname: Optional[str] = None,
|
||||
server_hostname: Optional[str] = None,
|
||||
) -> dns.message.Message:
|
||||
"""Return the response obtained after sending a query via DNS-over-QUIC.
|
||||
@ -1212,17 +1352,21 @@ def quic(
|
||||
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
|
||||
received message.
|
||||
|
||||
*connection*, a ``dns.quic.SyncQuicConnection``. If provided, the
|
||||
connection to use to send the query.
|
||||
*connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use
|
||||
to send the query.
|
||||
|
||||
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
|
||||
of the server is done using the default CA bundle; if ``False``, then no
|
||||
verification is done; if a `str` then it specifies the path to a certificate file or
|
||||
directory which will be used for verification.
|
||||
|
||||
*server_hostname*, a ``str`` containing the server's hostname. The
|
||||
default is ``None``, which means that no hostname is known, and if an
|
||||
SSL context is created, hostname checking will be disabled.
|
||||
*hostname*, a ``str`` containing the server's hostname or ``None``. The default is
|
||||
``None``, which means that no hostname is known, and if an SSL context is created,
|
||||
hostname checking will be disabled. This value is ignored if *url* is not
|
||||
``None``.
|
||||
|
||||
*server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility
|
||||
only, and has the same meaning as *hostname*.
|
||||
|
||||
Returns a ``dns.message.Message``.
|
||||
"""
|
||||
@ -1230,6 +1374,9 @@ def quic(
|
||||
if not dns.quic.have_quic:
|
||||
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
|
||||
|
||||
if server_hostname is not None and hostname is None:
|
||||
hostname = server_hostname
|
||||
|
||||
q.id = 0
|
||||
wire = q.to_wire()
|
||||
the_connection: dns.quic.SyncQuicConnection
|
||||
@ -1238,9 +1385,7 @@ def quic(
|
||||
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
|
||||
the_connection = connection
|
||||
else:
|
||||
manager = dns.quic.SyncQuicManager(
|
||||
verify_mode=verify, server_name=server_hostname
|
||||
)
|
||||
manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
|
||||
the_manager = manager # for type checking happiness
|
||||
|
||||
with manager:
|
||||
@ -1264,6 +1409,70 @@ def quic(
|
||||
return r
|
||||
|
||||
|
||||
class UDPMode(enum.IntEnum):
|
||||
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
|
||||
|
||||
NEVER means "never use UDP; always use TCP"
|
||||
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
|
||||
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
|
||||
"""
|
||||
|
||||
NEVER = 0
|
||||
TRY_FIRST = 1
|
||||
ONLY = 2
|
||||
|
||||
|
||||
def _inbound_xfr(
|
||||
txn_manager: dns.transaction.TransactionManager,
|
||||
s: socket.socket,
|
||||
query: dns.message.Message,
|
||||
serial: Optional[int],
|
||||
timeout: Optional[float],
|
||||
expiration: float,
|
||||
) -> Any:
|
||||
"""Given a socket, does the zone transfer."""
|
||||
rdtype = query.question[0].rdtype
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
origin = txn_manager.from_wire_origin()
|
||||
wire = query.to_wire()
|
||||
is_udp = s.type == socket.SOCK_DGRAM
|
||||
if is_udp:
|
||||
_udp_send(s, wire, None, expiration)
|
||||
else:
|
||||
tcpmsg = struct.pack("!H", len(wire)) + wire
|
||||
_net_write(s, tcpmsg, expiration)
|
||||
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
|
||||
done = False
|
||||
tsig_ctx = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if is_udp:
|
||||
(rwire, _) = _udp_recv(s, 65535, mexpiration)
|
||||
else:
|
||||
ldata = _net_read(s, 2, mexpiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
rwire = _net_read(s, l, mexpiration)
|
||||
r = dns.message.from_wire(
|
||||
rwire,
|
||||
keyring=query.keyring,
|
||||
request_mac=query.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=(not is_udp),
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
done = inbound.process_message(r)
|
||||
yield r
|
||||
tsig_ctx = r.tsig_ctx
|
||||
if query.keyring and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
|
||||
|
||||
def xfr(
|
||||
where: str,
|
||||
zone: Union[dns.name.Name, str],
|
||||
@ -1333,134 +1542,52 @@ def xfr(
|
||||
Returns a generator of ``dns.message.Message`` objects.
|
||||
"""
|
||||
|
||||
class DummyTransactionManager(dns.transaction.TransactionManager):
|
||||
def __init__(self, origin, relativize):
|
||||
self.info = (origin, relativize, dns.name.empty if relativize else origin)
|
||||
|
||||
def origin_information(self):
|
||||
return self.info
|
||||
|
||||
def get_class(self) -> dns.rdataclass.RdataClass:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def reader(self):
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def writer(self, replacement: bool = False) -> dns.transaction.Transaction:
|
||||
class DummyTransaction:
|
||||
def nop(self, *args, **kw):
|
||||
pass
|
||||
|
||||
def __getattr__(self, _):
|
||||
return self.nop
|
||||
|
||||
return cast(dns.transaction.Transaction, DummyTransaction())
|
||||
|
||||
if isinstance(zone, str):
|
||||
zone = dns.name.from_text(zone)
|
||||
rdtype = dns.rdatatype.RdataType.make(rdtype)
|
||||
q = dns.message.make_query(zone, rdtype, rdclass)
|
||||
if rdtype == dns.rdatatype.IXFR:
|
||||
rrset = dns.rrset.from_text(zone, 0, "IN", "SOA", ". . %u 0 0 0 0" % serial)
|
||||
q.authority.append(rrset)
|
||||
rrset = q.find_rrset(
|
||||
q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True
|
||||
)
|
||||
soa = dns.rdata.from_text("IN", "SOA", ". . %u 0 0 0 0" % serial)
|
||||
rrset.add(soa, 0)
|
||||
if keyring is not None:
|
||||
q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
|
||||
wire = q.to_wire()
|
||||
(af, destination, source) = _destination_and_source(
|
||||
where, port, source, source_port
|
||||
)
|
||||
(_, expiration) = _compute_times(lifetime)
|
||||
tm = DummyTransactionManager(zone, relativize)
|
||||
if use_udp and rdtype != dns.rdatatype.IXFR:
|
||||
raise ValueError("cannot do a UDP AXFR")
|
||||
sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
|
||||
with _make_socket(af, sock_type, source) as s:
|
||||
(_, expiration) = _compute_times(lifetime)
|
||||
_connect(s, destination, expiration)
|
||||
l = len(wire)
|
||||
if use_udp:
|
||||
_udp_send(s, wire, None, expiration)
|
||||
else:
|
||||
tcpmsg = struct.pack("!H", l) + wire
|
||||
_net_write(s, tcpmsg, expiration)
|
||||
done = False
|
||||
delete_mode = True
|
||||
expecting_SOA = False
|
||||
soa_rrset = None
|
||||
if relativize:
|
||||
origin = zone
|
||||
oname = dns.name.empty
|
||||
else:
|
||||
origin = None
|
||||
oname = zone
|
||||
tsig_ctx = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if use_udp:
|
||||
(wire, _) = _udp_recv(s, 65535, mexpiration)
|
||||
else:
|
||||
ldata = _net_read(s, 2, mexpiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
wire = _net_read(s, l, mexpiration)
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
r = dns.message.from_wire(
|
||||
wire,
|
||||
keyring=q.keyring,
|
||||
request_mac=q.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=True,
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
rcode = r.rcode()
|
||||
if rcode != dns.rcode.NOERROR:
|
||||
raise TransferError(rcode)
|
||||
tsig_ctx = r.tsig_ctx
|
||||
answer_index = 0
|
||||
if soa_rrset is None:
|
||||
if not r.answer or r.answer[0].name != oname:
|
||||
raise dns.exception.FormError("No answer or RRset not for qname")
|
||||
rrset = r.answer[0]
|
||||
if rrset.rdtype != dns.rdatatype.SOA:
|
||||
raise dns.exception.FormError("first RRset is not an SOA")
|
||||
answer_index = 1
|
||||
soa_rrset = rrset.copy()
|
||||
if rdtype == dns.rdatatype.IXFR:
|
||||
if dns.serial.Serial(soa_rrset[0].serial) <= serial:
|
||||
#
|
||||
# We're already up-to-date.
|
||||
#
|
||||
done = True
|
||||
else:
|
||||
expecting_SOA = True
|
||||
#
|
||||
# Process SOAs in the answer section (other than the initial
|
||||
# SOA in the first message).
|
||||
#
|
||||
for rrset in r.answer[answer_index:]:
|
||||
if done:
|
||||
raise dns.exception.FormError("answers after final SOA")
|
||||
if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
|
||||
if expecting_SOA:
|
||||
if rrset[0].serial != serial:
|
||||
raise dns.exception.FormError("IXFR base serial mismatch")
|
||||
expecting_SOA = False
|
||||
elif rdtype == dns.rdatatype.IXFR:
|
||||
delete_mode = not delete_mode
|
||||
#
|
||||
# If this SOA RRset is equal to the first we saw then we're
|
||||
# finished. If this is an IXFR we also check that we're
|
||||
# seeing the record in the expected part of the response.
|
||||
#
|
||||
if rrset == soa_rrset and (
|
||||
rdtype == dns.rdatatype.AXFR
|
||||
or (rdtype == dns.rdatatype.IXFR and delete_mode)
|
||||
):
|
||||
done = True
|
||||
elif expecting_SOA:
|
||||
#
|
||||
# We made an IXFR request and are expecting another
|
||||
# SOA RR, but saw something else, so this must be an
|
||||
# AXFR response.
|
||||
#
|
||||
rdtype = dns.rdatatype.AXFR
|
||||
expecting_SOA = False
|
||||
if done and q.keyring and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
yield r
|
||||
|
||||
|
||||
class UDPMode(enum.IntEnum):
|
||||
"""How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
|
||||
|
||||
NEVER means "never use UDP; always use TCP"
|
||||
TRY_FIRST means "try to use UDP but fall back to TCP if needed"
|
||||
ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
|
||||
"""
|
||||
|
||||
NEVER = 0
|
||||
TRY_FIRST = 1
|
||||
ONLY = 2
|
||||
yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
|
||||
|
||||
|
||||
def inbound_xfr(
|
||||
@ -1514,65 +1641,25 @@ def inbound_xfr(
|
||||
(query, serial) = dns.xfr.make_query(txn_manager)
|
||||
else:
|
||||
serial = dns.xfr.extract_serial_from_query(query)
|
||||
rdtype = query.question[0].rdtype
|
||||
is_ixfr = rdtype == dns.rdatatype.IXFR
|
||||
origin = txn_manager.from_wire_origin()
|
||||
wire = query.to_wire()
|
||||
|
||||
(af, destination, source) = _destination_and_source(
|
||||
where, port, source, source_port
|
||||
)
|
||||
(_, expiration) = _compute_times(lifetime)
|
||||
retry = True
|
||||
while retry:
|
||||
retry = False
|
||||
if is_ixfr and udp_mode != UDPMode.NEVER:
|
||||
sock_type = socket.SOCK_DGRAM
|
||||
is_udp = True
|
||||
else:
|
||||
sock_type = socket.SOCK_STREAM
|
||||
is_udp = False
|
||||
with _make_socket(af, sock_type, source) as s:
|
||||
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
|
||||
with _make_socket(af, socket.SOCK_DGRAM, source) as s:
|
||||
_connect(s, destination, expiration)
|
||||
if is_udp:
|
||||
_udp_send(s, wire, None, expiration)
|
||||
else:
|
||||
tcpmsg = struct.pack("!H", len(wire)) + wire
|
||||
_net_write(s, tcpmsg, expiration)
|
||||
with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
|
||||
done = False
|
||||
tsig_ctx = None
|
||||
while not done:
|
||||
(_, mexpiration) = _compute_times(timeout)
|
||||
if mexpiration is None or (
|
||||
expiration is not None and mexpiration > expiration
|
||||
):
|
||||
mexpiration = expiration
|
||||
if is_udp:
|
||||
(rwire, _) = _udp_recv(s, 65535, mexpiration)
|
||||
else:
|
||||
ldata = _net_read(s, 2, mexpiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
rwire = _net_read(s, l, mexpiration)
|
||||
r = dns.message.from_wire(
|
||||
rwire,
|
||||
keyring=query.keyring,
|
||||
request_mac=query.mac,
|
||||
xfr=True,
|
||||
origin=origin,
|
||||
tsig_ctx=tsig_ctx,
|
||||
multi=(not is_udp),
|
||||
one_rr_per_rrset=is_ixfr,
|
||||
)
|
||||
try:
|
||||
done = inbound.process_message(r)
|
||||
except dns.xfr.UseTCP:
|
||||
assert is_udp # should not happen if we used TCP!
|
||||
if udp_mode == UDPMode.ONLY:
|
||||
raise
|
||||
done = True
|
||||
retry = True
|
||||
udp_mode = UDPMode.NEVER
|
||||
continue
|
||||
tsig_ctx = r.tsig_ctx
|
||||
if not retry and query.keyring and not r.had_tsig:
|
||||
raise dns.exception.FormError("missing TSIG")
|
||||
try:
|
||||
for _ in _inbound_xfr(
|
||||
txn_manager, s, query, serial, timeout, expiration
|
||||
):
|
||||
pass
|
||||
return
|
||||
except dns.xfr.UseTCP:
|
||||
if udp_mode == UDPMode.ONLY:
|
||||
raise
|
||||
|
||||
with _make_socket(af, socket.SOCK_STREAM, source) as s:
|
||||
_connect(s, destination, expiration)
|
||||
for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
|
||||
pass
|
||||
|
@ -1,5 +1,7 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import dns._features
|
||||
import dns.asyncbackend
|
||||
|
||||
@ -73,3 +75,6 @@ else: # pragma: no cover
|
||||
class SyncQuicConnection: # type: ignore
|
||||
def make_stream(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Headers = List[Tuple[bytes, bytes]]
|
||||
|
@ -43,12 +43,26 @@ class AsyncioQuicStream(BaseQuicStream):
|
||||
raise dns.exception.Timeout
|
||||
self._expecting = 0
|
||||
|
||||
async def wait_for_end(self, expiration):
|
||||
while True:
|
||||
timeout = self._timeout_from_expiration(expiration)
|
||||
if self._buffer.seen_end():
|
||||
return
|
||||
try:
|
||||
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
|
||||
except TimeoutError:
|
||||
raise dns.exception.Timeout
|
||||
|
||||
async def receive(self, timeout=None):
|
||||
expiration = self._expiration_from_timeout(timeout)
|
||||
await self.wait_for(2, expiration)
|
||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
await self.wait_for(size, expiration)
|
||||
return self._buffer.get(size)
|
||||
if self._connection.is_h3():
|
||||
await self.wait_for_end(expiration)
|
||||
return self._buffer.get_all()
|
||||
else:
|
||||
await self.wait_for(2, expiration)
|
||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
await self.wait_for(size, expiration)
|
||||
return self._buffer.get(size)
|
||||
|
||||
async def send(self, datagram, is_end=False):
|
||||
data = self._encapsulate(datagram)
|
||||
@ -83,6 +97,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||
self._wake_timer = asyncio.Condition()
|
||||
self._receiver_task = None
|
||||
self._sender_task = None
|
||||
self._wake_pending = False
|
||||
|
||||
async def _receiver(self):
|
||||
try:
|
||||
@ -104,19 +119,24 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||
self._connection.receive_datagram(datagram, address, time.time())
|
||||
# Wake up the timer in case the sender is sleeping, as there may be
|
||||
# stuff to send now.
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
await self._wakeup()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._done = True
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
await self._wakeup()
|
||||
self._handshake_complete.set()
|
||||
|
||||
async def _wakeup(self):
|
||||
self._wake_pending = True
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
|
||||
async def _wait_for_wake_timer(self):
|
||||
async with self._wake_timer:
|
||||
await self._wake_timer.wait()
|
||||
if not self._wake_pending:
|
||||
await self._wake_timer.wait()
|
||||
self._wake_pending = False
|
||||
|
||||
async def _sender(self):
|
||||
await self._socket_created.wait()
|
||||
@ -140,9 +160,28 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||
if event is None:
|
||||
return
|
||||
if isinstance(event, aioquic.quic.events.StreamDataReceived):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(event.data, event.end_stream)
|
||||
if self.is_h3():
|
||||
h3_events = self._h3_conn.handle_event(event)
|
||||
for h3_event in h3_events:
|
||||
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
if stream._headers is None:
|
||||
stream._headers = h3_event.headers
|
||||
elif stream._trailers is None:
|
||||
stream._trailers = h3_event.headers
|
||||
if h3_event.stream_ended:
|
||||
await stream._add_input(b"", True)
|
||||
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(
|
||||
h3_event.data, h3_event.stream_ended
|
||||
)
|
||||
else:
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(event.data, event.end_stream)
|
||||
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
|
||||
self._handshake_complete.set()
|
||||
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
|
||||
@ -161,8 +200,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||
|
||||
async def write(self, stream, data, is_end=False):
|
||||
self._connection.send_stream_data(stream, data, is_end)
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
await self._wakeup()
|
||||
|
||||
def run(self):
|
||||
if self._closed:
|
||||
@ -189,8 +227,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||
self._connection.close()
|
||||
# sender might be blocked on this, so set it
|
||||
self._socket_created.set()
|
||||
async with self._wake_timer:
|
||||
self._wake_timer.notify_all()
|
||||
await self._wakeup()
|
||||
try:
|
||||
await self._receiver_task
|
||||
except asyncio.CancelledError:
|
||||
@ -203,8 +240,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
|
||||
|
||||
|
||||
class AsyncioQuicManager(AsyncQuicManager):
|
||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
|
||||
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
|
||||
def __init__(
|
||||
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
|
||||
):
|
||||
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
|
||||
|
||||
def connect(
|
||||
self, address, port=853, source=None, source_port=0, want_session_ticket=True
|
||||
|
@ -1,12 +1,16 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import functools
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
import urllib
|
||||
from typing import Any, Optional
|
||||
|
||||
import aioquic.h3.connection # type: ignore
|
||||
import aioquic.h3.events # type: ignore
|
||||
import aioquic.quic.configuration # type: ignore
|
||||
import aioquic.quic.connection # type: ignore
|
||||
|
||||
@ -51,6 +55,12 @@ class Buffer:
|
||||
self._buffer = self._buffer[amount:]
|
||||
return data
|
||||
|
||||
def get_all(self):
|
||||
assert self.seen_end()
|
||||
data = self._buffer
|
||||
self._buffer = b""
|
||||
return data
|
||||
|
||||
|
||||
class BaseQuicStream:
|
||||
def __init__(self, connection, stream_id):
|
||||
@ -58,10 +68,18 @@ class BaseQuicStream:
|
||||
self._stream_id = stream_id
|
||||
self._buffer = Buffer()
|
||||
self._expecting = 0
|
||||
self._headers = None
|
||||
self._trailers = None
|
||||
|
||||
def id(self):
|
||||
return self._stream_id
|
||||
|
||||
def headers(self):
|
||||
return self._headers
|
||||
|
||||
def trailers(self):
|
||||
return self._trailers
|
||||
|
||||
def _expiration_from_timeout(self, timeout):
|
||||
if timeout is not None:
|
||||
expiration = time.time() + timeout
|
||||
@ -77,16 +95,51 @@ class BaseQuicStream:
|
||||
return timeout
|
||||
|
||||
# Subclass must implement receive() as sync / async and which returns a message
|
||||
# or raises UnexpectedEOF.
|
||||
# or raises.
|
||||
|
||||
# Subclass must implement send() as sync / async and which takes a message and
|
||||
# an EOF indicator.
|
||||
|
||||
def send_h3(self, url, datagram, post=True):
|
||||
if not self._connection.is_h3():
|
||||
raise SyntaxError("cannot send H3 to a non-H3 connection")
|
||||
url_parts = urllib.parse.urlparse(url)
|
||||
path = url_parts.path.encode()
|
||||
if post:
|
||||
method = b"POST"
|
||||
else:
|
||||
method = b"GET"
|
||||
path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
|
||||
headers = [
|
||||
(b":method", method),
|
||||
(b":scheme", url_parts.scheme.encode()),
|
||||
(b":authority", url_parts.netloc.encode()),
|
||||
(b":path", path),
|
||||
(b"accept", b"application/dns-message"),
|
||||
]
|
||||
if post:
|
||||
headers.extend(
|
||||
[
|
||||
(b"content-type", b"application/dns-message"),
|
||||
(b"content-length", str(len(datagram)).encode()),
|
||||
]
|
||||
)
|
||||
self._connection.send_headers(self._stream_id, headers, not post)
|
||||
if post:
|
||||
self._connection.send_data(self._stream_id, datagram, True)
|
||||
|
||||
def _encapsulate(self, datagram):
|
||||
if self._connection.is_h3():
|
||||
return datagram
|
||||
l = len(datagram)
|
||||
return struct.pack("!H", l) + datagram
|
||||
|
||||
def _common_add_input(self, data, is_end):
|
||||
self._buffer.put(data, is_end)
|
||||
try:
|
||||
return self._expecting > 0 and self._buffer.have(self._expecting)
|
||||
return (
|
||||
self._expecting > 0 and self._buffer.have(self._expecting)
|
||||
) or self._buffer.seen_end
|
||||
except UnexpectedEOF:
|
||||
return True
|
||||
|
||||
@ -97,7 +150,13 @@ class BaseQuicStream:
|
||||
|
||||
class BaseQuicConnection:
|
||||
def __init__(
|
||||
self, connection, address, port, source=None, source_port=0, manager=None
|
||||
self,
|
||||
connection,
|
||||
address,
|
||||
port,
|
||||
source=None,
|
||||
source_port=0,
|
||||
manager=None,
|
||||
):
|
||||
self._done = False
|
||||
self._connection = connection
|
||||
@ -106,6 +165,10 @@ class BaseQuicConnection:
|
||||
self._closed = False
|
||||
self._manager = manager
|
||||
self._streams = {}
|
||||
if manager.is_h3():
|
||||
self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
|
||||
else:
|
||||
self._h3_conn = None
|
||||
self._af = dns.inet.af_for_address(address)
|
||||
self._peer = dns.inet.low_level_address_tuple((address, port))
|
||||
if source is None and source_port != 0:
|
||||
@ -120,9 +183,18 @@ class BaseQuicConnection:
|
||||
else:
|
||||
self._source = None
|
||||
|
||||
def is_h3(self):
|
||||
return self._h3_conn is not None
|
||||
|
||||
def close_stream(self, stream_id):
|
||||
del self._streams[stream_id]
|
||||
|
||||
def send_headers(self, stream_id, headers, is_end=False):
|
||||
self._h3_conn.send_headers(stream_id, headers, is_end)
|
||||
|
||||
def send_data(self, stream_id, data, is_end=False):
|
||||
self._h3_conn.send_data(stream_id, data, is_end)
|
||||
|
||||
def _get_timer_values(self, closed_is_special=True):
|
||||
now = time.time()
|
||||
expiration = self._connection.get_timer()
|
||||
@ -148,17 +220,25 @@ class AsyncQuicConnection(BaseQuicConnection):
|
||||
|
||||
|
||||
class BaseQuicManager:
|
||||
def __init__(self, conf, verify_mode, connection_factory, server_name=None):
|
||||
def __init__(
|
||||
self, conf, verify_mode, connection_factory, server_name=None, h3=False
|
||||
):
|
||||
self._connections = {}
|
||||
self._connection_factory = connection_factory
|
||||
self._session_tickets = {}
|
||||
self._tokens = {}
|
||||
self._h3 = h3
|
||||
if conf is None:
|
||||
verify_path = None
|
||||
if isinstance(verify_mode, str):
|
||||
verify_path = verify_mode
|
||||
verify_mode = True
|
||||
if h3:
|
||||
alpn_protocols = ["h3"]
|
||||
else:
|
||||
alpn_protocols = ["doq", "doq-i03"]
|
||||
conf = aioquic.quic.configuration.QuicConfiguration(
|
||||
alpn_protocols=["doq", "doq-i03"],
|
||||
alpn_protocols=alpn_protocols,
|
||||
verify_mode=verify_mode,
|
||||
server_name=server_name,
|
||||
)
|
||||
@ -167,7 +247,13 @@ class BaseQuicManager:
|
||||
self._conf = conf
|
||||
|
||||
def _connect(
|
||||
self, address, port=853, source=None, source_port=0, want_session_ticket=True
|
||||
self,
|
||||
address,
|
||||
port=853,
|
||||
source=None,
|
||||
source_port=0,
|
||||
want_session_ticket=True,
|
||||
want_token=True,
|
||||
):
|
||||
connection = self._connections.get((address, port))
|
||||
if connection is not None:
|
||||
@ -189,9 +275,24 @@ class BaseQuicManager:
|
||||
)
|
||||
else:
|
||||
session_ticket_handler = None
|
||||
if want_token:
|
||||
try:
|
||||
token = self._tokens.pop((address, port))
|
||||
# We found a token, so make a configuration that uses it.
|
||||
conf = copy.copy(conf)
|
||||
conf.token = token
|
||||
except KeyError:
|
||||
# No token
|
||||
pass
|
||||
# Whether or not we found a token, we want a handler to save # one.
|
||||
token_handler = functools.partial(self.save_token, address, port)
|
||||
else:
|
||||
token_handler = None
|
||||
|
||||
qconn = aioquic.quic.connection.QuicConnection(
|
||||
configuration=conf,
|
||||
session_ticket_handler=session_ticket_handler,
|
||||
token_handler=token_handler,
|
||||
)
|
||||
lladdress = dns.inet.low_level_address_tuple((address, port))
|
||||
qconn.connect(lladdress, time.time())
|
||||
@ -207,6 +308,9 @@ class BaseQuicManager:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def is_h3(self):
|
||||
return self._h3
|
||||
|
||||
def save_session_ticket(self, address, port, ticket):
|
||||
# We rely on dictionaries keys() being in insertion order here. We
|
||||
# can't just popitem() as that would be LIFO which is the opposite of
|
||||
@ -218,6 +322,17 @@ class BaseQuicManager:
|
||||
del self._session_tickets[key]
|
||||
self._session_tickets[(address, port)] = ticket
|
||||
|
||||
def save_token(self, address, port, token):
|
||||
# We rely on dictionaries keys() being in insertion order here. We
|
||||
# can't just popitem() as that would be LIFO which is the opposite of
|
||||
# what we want.
|
||||
l = len(self._tokens)
|
||||
if l >= MAX_SESSION_TICKETS:
|
||||
keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
|
||||
for key in keys_to_delete:
|
||||
del self._tokens[key]
|
||||
self._tokens[(address, port)] = token
|
||||
|
||||
|
||||
class AsyncQuicManager(BaseQuicManager):
|
||||
def connect(self, address, port=853, source=None, source_port=0):
|
||||
|
@ -21,11 +21,9 @@ from dns.quic._common import (
|
||||
UnexpectedEOF,
|
||||
)
|
||||
|
||||
# Avoid circularity with dns.query
|
||||
if hasattr(selectors, "PollSelector"):
|
||||
_selector_class = selectors.PollSelector # type: ignore
|
||||
else:
|
||||
_selector_class = selectors.SelectSelector # type: ignore
|
||||
# Function used to create a socket. Can be overridden if needed in special
|
||||
# situations.
|
||||
socket_factory = socket.socket
|
||||
|
||||
|
||||
class SyncQuicStream(BaseQuicStream):
|
||||
@ -46,14 +44,29 @@ class SyncQuicStream(BaseQuicStream):
|
||||
raise dns.exception.Timeout
|
||||
self._expecting = 0
|
||||
|
||||
def wait_for_end(self, expiration):
|
||||
while True:
|
||||
timeout = self._timeout_from_expiration(expiration)
|
||||
with self._lock:
|
||||
if self._buffer.seen_end():
|
||||
return
|
||||
with self._wake_up:
|
||||
if not self._wake_up.wait(timeout):
|
||||
raise dns.exception.Timeout
|
||||
|
||||
def receive(self, timeout=None):
|
||||
expiration = self._expiration_from_timeout(timeout)
|
||||
self.wait_for(2, expiration)
|
||||
with self._lock:
|
||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
self.wait_for(size, expiration)
|
||||
with self._lock:
|
||||
return self._buffer.get(size)
|
||||
if self._connection.is_h3():
|
||||
self.wait_for_end(expiration)
|
||||
with self._lock:
|
||||
return self._buffer.get_all()
|
||||
else:
|
||||
self.wait_for(2, expiration)
|
||||
with self._lock:
|
||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
self.wait_for(size, expiration)
|
||||
with self._lock:
|
||||
return self._buffer.get(size)
|
||||
|
||||
def send(self, datagram, is_end=False):
|
||||
data = self._encapsulate(datagram)
|
||||
@ -81,7 +94,7 @@ class SyncQuicStream(BaseQuicStream):
|
||||
class SyncQuicConnection(BaseQuicConnection):
|
||||
def __init__(self, connection, address, port, source, source_port, manager):
|
||||
super().__init__(connection, address, port, source, source_port, manager)
|
||||
self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
|
||||
self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0)
|
||||
if self._source is not None:
|
||||
try:
|
||||
self._socket.bind(
|
||||
@ -118,7 +131,7 @@ class SyncQuicConnection(BaseQuicConnection):
|
||||
|
||||
def _worker(self):
|
||||
try:
|
||||
sel = _selector_class()
|
||||
sel = selectors.DefaultSelector()
|
||||
sel.register(self._socket, selectors.EVENT_READ, self._read)
|
||||
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
|
||||
while not self._done:
|
||||
@ -140,6 +153,7 @@ class SyncQuicConnection(BaseQuicConnection):
|
||||
finally:
|
||||
with self._lock:
|
||||
self._done = True
|
||||
self._socket.close()
|
||||
# Ensure anyone waiting for this gets woken up.
|
||||
self._handshake_complete.set()
|
||||
|
||||
@ -150,10 +164,29 @@ class SyncQuicConnection(BaseQuicConnection):
|
||||
if event is None:
|
||||
return
|
||||
if isinstance(event, aioquic.quic.events.StreamDataReceived):
|
||||
with self._lock:
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
stream._add_input(event.data, event.end_stream)
|
||||
if self.is_h3():
|
||||
h3_events = self._h3_conn.handle_event(event)
|
||||
for h3_event in h3_events:
|
||||
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
|
||||
with self._lock:
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
if stream._headers is None:
|
||||
stream._headers = h3_event.headers
|
||||
elif stream._trailers is None:
|
||||
stream._trailers = h3_event.headers
|
||||
if h3_event.stream_ended:
|
||||
stream._add_input(b"", True)
|
||||
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
|
||||
with self._lock:
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
stream._add_input(h3_event.data, h3_event.stream_ended)
|
||||
else:
|
||||
with self._lock:
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
stream._add_input(event.data, event.end_stream)
|
||||
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
|
||||
self._handshake_complete.set()
|
||||
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
|
||||
@ -170,6 +203,18 @@ class SyncQuicConnection(BaseQuicConnection):
|
||||
self._connection.send_stream_data(stream, data, is_end)
|
||||
self._send_wakeup.send(b"\x01")
|
||||
|
||||
def send_headers(self, stream_id, headers, is_end=False):
|
||||
with self._lock:
|
||||
super().send_headers(stream_id, headers, is_end)
|
||||
if is_end:
|
||||
self._send_wakeup.send(b"\x01")
|
||||
|
||||
def send_data(self, stream_id, data, is_end=False):
|
||||
with self._lock:
|
||||
super().send_data(stream_id, data, is_end)
|
||||
if is_end:
|
||||
self._send_wakeup.send(b"\x01")
|
||||
|
||||
def run(self):
|
||||
if self._closed:
|
||||
return
|
||||
@ -203,16 +248,24 @@ class SyncQuicConnection(BaseQuicConnection):
|
||||
|
||||
|
||||
class SyncQuicManager(BaseQuicManager):
|
||||
def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
|
||||
super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
|
||||
def __init__(
|
||||
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
|
||||
):
|
||||
super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def connect(
|
||||
self, address, port=853, source=None, source_port=0, want_session_ticket=True
|
||||
self,
|
||||
address,
|
||||
port=853,
|
||||
source=None,
|
||||
source_port=0,
|
||||
want_session_ticket=True,
|
||||
want_token=True,
|
||||
):
|
||||
with self._lock:
|
||||
(connection, start) = self._connect(
|
||||
address, port, source, source_port, want_session_ticket
|
||||
address, port, source, source_port, want_session_ticket, want_token
|
||||
)
|
||||
if start:
|
||||
connection.run()
|
||||
@ -226,6 +279,10 @@ class SyncQuicManager(BaseQuicManager):
|
||||
with self._lock:
|
||||
super().save_session_ticket(address, port, ticket)
|
||||
|
||||
def save_token(self, address, port, token):
|
||||
with self._lock:
|
||||
super().save_token(address, port, token)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
@ -36,16 +36,27 @@ class TrioQuicStream(BaseQuicStream):
|
||||
await self._wake_up.wait()
|
||||
self._expecting = 0
|
||||
|
||||
async def wait_for_end(self):
|
||||
while True:
|
||||
if self._buffer.seen_end():
|
||||
return
|
||||
async with self._wake_up:
|
||||
await self._wake_up.wait()
|
||||
|
||||
async def receive(self, timeout=None):
|
||||
if timeout is None:
|
||||
context = NullContext(None)
|
||||
else:
|
||||
context = trio.move_on_after(timeout)
|
||||
with context:
|
||||
await self.wait_for(2)
|
||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
await self.wait_for(size)
|
||||
return self._buffer.get(size)
|
||||
if self._connection.is_h3():
|
||||
await self.wait_for_end()
|
||||
return self._buffer.get_all()
|
||||
else:
|
||||
await self.wait_for(2)
|
||||
(size,) = struct.unpack("!H", self._buffer.get(2))
|
||||
await self.wait_for(size)
|
||||
return self._buffer.get(size)
|
||||
raise dns.exception.Timeout
|
||||
|
||||
async def send(self, datagram, is_end=False):
|
||||
@ -115,6 +126,7 @@ class TrioQuicConnection(AsyncQuicConnection):
|
||||
await self._socket.send(datagram)
|
||||
finally:
|
||||
self._done = True
|
||||
self._socket.close()
|
||||
self._handshake_complete.set()
|
||||
|
||||
async def _handle_events(self):
|
||||
@ -124,9 +136,28 @@ class TrioQuicConnection(AsyncQuicConnection):
|
||||
if event is None:
|
||||
return
|
||||
if isinstance(event, aioquic.quic.events.StreamDataReceived):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(event.data, event.end_stream)
|
||||
if self.is_h3():
|
||||
h3_events = self._h3_conn.handle_event(event)
|
||||
for h3_event in h3_events:
|
||||
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
if stream._headers is None:
|
||||
stream._headers = h3_event.headers
|
||||
elif stream._trailers is None:
|
||||
stream._trailers = h3_event.headers
|
||||
if h3_event.stream_ended:
|
||||
await stream._add_input(b"", True)
|
||||
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(
|
||||
h3_event.data, h3_event.stream_ended
|
||||
)
|
||||
else:
|
||||
stream = self._streams.get(event.stream_id)
|
||||
if stream:
|
||||
await stream._add_input(event.data, event.end_stream)
|
||||
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
|
||||
self._handshake_complete.set()
|
||||
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
|
||||
@ -183,9 +214,14 @@ class TrioQuicConnection(AsyncQuicConnection):
|
||||
|
||||
class TrioQuicManager(AsyncQuicManager):
|
||||
def __init__(
|
||||
self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
|
||||
self,
|
||||
nursery,
|
||||
conf=None,
|
||||
verify_mode=ssl.CERT_REQUIRED,
|
||||
server_name=None,
|
||||
h3=False,
|
||||
):
|
||||
super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
|
||||
super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
|
||||
self._nursery = nursery
|
||||
|
||||
def connect(
|
||||
|
@ -214,7 +214,7 @@ class Rdata:
|
||||
compress: Optional[dns.name.CompressType] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
canonicalize: bool = False,
|
||||
) -> bytes:
|
||||
) -> None:
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
def to_wire(
|
||||
@ -223,14 +223,19 @@ class Rdata:
|
||||
compress: Optional[dns.name.CompressType] = None,
|
||||
origin: Optional[dns.name.Name] = None,
|
||||
canonicalize: bool = False,
|
||||
) -> bytes:
|
||||
) -> Optional[bytes]:
|
||||
"""Convert an rdata to wire format.
|
||||
|
||||
Returns a ``bytes`` or ``None``.
|
||||
Returns a ``bytes`` if no output file was specified, or ``None`` otherwise.
|
||||
"""
|
||||
|
||||
if file:
|
||||
return self._to_wire(file, compress, origin, canonicalize)
|
||||
# We call _to_wire() and then return None explicitly instead of
|
||||
# of just returning the None from _to_wire() as mypy's func-returns-value
|
||||
# unhelpfully errors out with "error: "_to_wire" of "Rdata" does not return
|
||||
# a value (it only ever returns None)"
|
||||
self._to_wire(file, compress, origin, canonicalize)
|
||||
return None
|
||||
else:
|
||||
f = io.BytesIO()
|
||||
self._to_wire(f, compress, origin, canonicalize)
|
||||
@ -253,8 +258,9 @@ class Rdata:
|
||||
|
||||
Returns a ``bytes``.
|
||||
"""
|
||||
|
||||
return self.to_wire(origin=origin, canonicalize=True)
|
||||
wire = self.to_wire(origin=origin, canonicalize=True)
|
||||
assert wire is not None # for mypy
|
||||
return wire
|
||||
|
||||
def __repr__(self):
|
||||
covers = self.covers()
|
||||
@ -434,15 +440,11 @@ class Rdata:
|
||||
continue
|
||||
if key not in parameters:
|
||||
raise AttributeError(
|
||||
"'{}' object has no attribute '{}'".format(
|
||||
self.__class__.__name__, key
|
||||
)
|
||||
f"'{self.__class__.__name__}' object has no attribute '{key}'"
|
||||
)
|
||||
if key in ("rdclass", "rdtype"):
|
||||
raise AttributeError(
|
||||
"Cannot overwrite '{}' attribute '{}'".format(
|
||||
self.__class__.__name__, key
|
||||
)
|
||||
f"Cannot overwrite '{self.__class__.__name__}' attribute '{key}'"
|
||||
)
|
||||
|
||||
# Construct the parameter list. For each field, use the value in
|
||||
@ -646,13 +648,14 @@ _rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType],
|
||||
{}
|
||||
)
|
||||
_module_prefix = "dns.rdtypes"
|
||||
_dynamic_load_allowed = True
|
||||
|
||||
|
||||
def get_rdata_class(rdclass, rdtype):
|
||||
def get_rdata_class(rdclass, rdtype, use_generic=True):
|
||||
cls = _rdata_classes.get((rdclass, rdtype))
|
||||
if not cls:
|
||||
cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype))
|
||||
if not cls:
|
||||
if not cls and _dynamic_load_allowed:
|
||||
rdclass_text = dns.rdataclass.to_text(rdclass)
|
||||
rdtype_text = dns.rdatatype.to_text(rdtype)
|
||||
rdtype_text = rdtype_text.replace("-", "_")
|
||||
@ -670,12 +673,36 @@ def get_rdata_class(rdclass, rdtype):
|
||||
_rdata_classes[(rdclass, rdtype)] = cls
|
||||
except ImportError:
|
||||
pass
|
||||
if not cls:
|
||||
if not cls and use_generic:
|
||||
cls = GenericRdata
|
||||
_rdata_classes[(rdclass, rdtype)] = cls
|
||||
return cls
|
||||
|
||||
|
||||
def load_all_types(disable_dynamic_load=True):
|
||||
"""Load all rdata types for which dnspython has a non-generic implementation.
|
||||
|
||||
Normally dnspython loads DNS rdatatype implementations on demand, but in some
|
||||
specialized cases loading all types at an application-controlled time is preferred.
|
||||
|
||||
If *disable_dynamic_load*, a ``bool``, is ``True`` then dnspython will not attempt
|
||||
to use its dynamic loading mechanism if an unknown type is subsequently encountered,
|
||||
and will simply use the ``GenericRdata`` class.
|
||||
"""
|
||||
# Load class IN and ANY types.
|
||||
for rdtype in dns.rdatatype.RdataType:
|
||||
get_rdata_class(dns.rdataclass.IN, rdtype, False)
|
||||
# Load the one non-ANY implementation we have in CH. Everything
|
||||
# else in CH is an ANY type, and we'll discover those on demand but won't
|
||||
# have to import anything.
|
||||
get_rdata_class(dns.rdataclass.CH, dns.rdatatype.A, False)
|
||||
if disable_dynamic_load:
|
||||
# Now disable dynamic loading so any subsequent unknown type immediately becomes
|
||||
# GenericRdata without a load attempt.
|
||||
global _dynamic_load_allowed
|
||||
_dynamic_load_allowed = False
|
||||
|
||||
|
||||
def from_text(
|
||||
rdclass: Union[dns.rdataclass.RdataClass, str],
|
||||
rdtype: Union[dns.rdatatype.RdataType, str],
|
||||
|
@ -160,7 +160,7 @@ class Rdataset(dns.set.Set):
|
||||
return s[:100] + "..."
|
||||
return s
|
||||
|
||||
return "[%s]" % ", ".join("<%s>" % maybe_truncate(str(rr)) for rr in self)
|
||||
return "[" + ", ".join(f"<{maybe_truncate(str(rr))}>" for rr in self) + "]"
|
||||
|
||||
def __repr__(self):
|
||||
if self.covers == 0:
|
||||
@ -248,12 +248,8 @@ class Rdataset(dns.set.Set):
|
||||
# (which is meaningless anyway).
|
||||
#
|
||||
s.write(
|
||||
"{}{}{} {}\n".format(
|
||||
ntext,
|
||||
pad,
|
||||
dns.rdataclass.to_text(rdclass),
|
||||
dns.rdatatype.to_text(self.rdtype),
|
||||
)
|
||||
f"{ntext}{pad}{dns.rdataclass.to_text(rdclass)} "
|
||||
f"{dns.rdatatype.to_text(self.rdtype)}\n"
|
||||
)
|
||||
else:
|
||||
for rd in self:
|
||||
|
@ -105,6 +105,8 @@ class RdataType(dns.enum.IntEnum):
|
||||
CAA = 257
|
||||
AVC = 258
|
||||
AMTRELAY = 260
|
||||
RESINFO = 261
|
||||
WALLET = 262
|
||||
TA = 32768
|
||||
DLV = 32769
|
||||
|
||||
@ -125,7 +127,7 @@ class RdataType(dns.enum.IntEnum):
|
||||
if text.find("-") >= 0:
|
||||
try:
|
||||
return cls[text.replace("-", "_")]
|
||||
except KeyError:
|
||||
except KeyError: # pragma: no cover
|
||||
pass
|
||||
return _registered_by_text.get(text)
|
||||
|
||||
@ -326,6 +328,8 @@ URI = RdataType.URI
|
||||
CAA = RdataType.CAA
|
||||
AVC = RdataType.AVC
|
||||
AMTRELAY = RdataType.AMTRELAY
|
||||
RESINFO = RdataType.RESINFO
|
||||
WALLET = RdataType.WALLET
|
||||
TA = RdataType.TA
|
||||
DLV = RdataType.DLV
|
||||
|
||||
|
@ -75,8 +75,9 @@ class GPOS(dns.rdata.Rdata):
|
||||
raise dns.exception.FormError("bad longitude")
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return "{} {} {}".format(
|
||||
self.latitude.decode(), self.longitude.decode(), self.altitude.decode()
|
||||
return (
|
||||
f"{self.latitude.decode()} {self.longitude.decode()} "
|
||||
f"{self.altitude.decode()}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -37,9 +37,7 @@ class HINFO(dns.rdata.Rdata):
|
||||
self.os = self._as_bytes(os, True, 255)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return '"{}" "{}"'.format(
|
||||
dns.rdata._escapify(self.cpu), dns.rdata._escapify(self.os)
|
||||
)
|
||||
return f'"{dns.rdata._escapify(self.cpu)}" "{dns.rdata._escapify(self.os)}"'
|
||||
|
||||
@classmethod
|
||||
def from_text(
|
||||
|
@ -48,7 +48,7 @@ class HIP(dns.rdata.Rdata):
|
||||
for server in self.servers:
|
||||
servers.append(server.choose_relativity(origin, relativize))
|
||||
if len(servers) > 0:
|
||||
text += " " + " ".join((x.to_unicode() for x in servers))
|
||||
text += " " + " ".join(x.to_unicode() for x in servers)
|
||||
return "%u %s %s%s" % (self.algorithm, hit, key, text)
|
||||
|
||||
@classmethod
|
||||
|
@ -38,11 +38,12 @@ class ISDN(dns.rdata.Rdata):
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
if self.subaddress:
|
||||
return '"{}" "{}"'.format(
|
||||
dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress)
|
||||
return (
|
||||
f'"{dns.rdata._escapify(self.address)}" '
|
||||
f'"{dns.rdata._escapify(self.subaddress)}"'
|
||||
)
|
||||
else:
|
||||
return '"%s"' % dns.rdata._escapify(self.address)
|
||||
return f'"{dns.rdata._escapify(self.address)}"'
|
||||
|
||||
@classmethod
|
||||
def from_text(
|
||||
|
@ -44,7 +44,7 @@ def _exponent_of(what, desc):
|
||||
exp = i - 1
|
||||
break
|
||||
if exp is None or exp < 0:
|
||||
raise dns.exception.SyntaxError("%s value out of bounds" % desc)
|
||||
raise dns.exception.SyntaxError(f"{desc} value out of bounds")
|
||||
return exp
|
||||
|
||||
|
||||
@ -83,10 +83,10 @@ def _encode_size(what, desc):
|
||||
def _decode_size(what, desc):
|
||||
exponent = what & 0x0F
|
||||
if exponent > 9:
|
||||
raise dns.exception.FormError("bad %s exponent" % desc)
|
||||
raise dns.exception.FormError(f"bad {desc} exponent")
|
||||
base = (what & 0xF0) >> 4
|
||||
if base > 9:
|
||||
raise dns.exception.FormError("bad %s base" % desc)
|
||||
raise dns.exception.FormError(f"bad {desc} base")
|
||||
return base * pow(10, exponent)
|
||||
|
||||
|
||||
@ -184,10 +184,9 @@ class LOC(dns.rdata.Rdata):
|
||||
or self.horizontal_precision != _default_hprec
|
||||
or self.vertical_precision != _default_vprec
|
||||
):
|
||||
text += " {:0.2f}m {:0.2f}m {:0.2f}m".format(
|
||||
self.size / 100.0,
|
||||
self.horizontal_precision / 100.0,
|
||||
self.vertical_precision / 100.0,
|
||||
text += (
|
||||
f" {self.size / 100.0:0.2f}m {self.horizontal_precision / 100.0:0.2f}m"
|
||||
f" {self.vertical_precision / 100.0:0.2f}m"
|
||||
)
|
||||
return text
|
||||
|
||||
|
@ -44,7 +44,7 @@ class NSEC(dns.rdata.Rdata):
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
next = self.next.choose_relativity(origin, relativize)
|
||||
text = Bitmap(self.windows).to_text()
|
||||
return "{}{}".format(next, text)
|
||||
return f"{next}{text}"
|
||||
|
||||
@classmethod
|
||||
def from_text(
|
||||
|
24
lib/dns/rdtypes/ANY/RESINFO.py
Normal file
24
lib/dns/rdtypes/ANY/RESINFO.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
|
||||
#
|
||||
# Permission to use, copy, modify, and distribute this software and its
|
||||
# documentation for any purpose with or without fee is hereby granted,
|
||||
# provided that the above copyright notice and this permission notice
|
||||
# appear in all copies.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
|
||||
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
|
||||
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
||||
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
import dns.immutable
|
||||
import dns.rdtypes.txtbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class RESINFO(dns.rdtypes.txtbase.TXTBase):
|
||||
"""RESINFO record"""
|
@ -37,7 +37,7 @@ class RP(dns.rdata.Rdata):
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
mbox = self.mbox.choose_relativity(origin, relativize)
|
||||
txt = self.txt.choose_relativity(origin, relativize)
|
||||
return "{} {}".format(str(mbox), str(txt))
|
||||
return f"{str(mbox)} {str(txt)}"
|
||||
|
||||
@classmethod
|
||||
def from_text(
|
||||
|
@ -69,7 +69,7 @@ class TKEY(dns.rdata.Rdata):
|
||||
dns.rdata._base64ify(self.key, 0),
|
||||
)
|
||||
if len(self.other) > 0:
|
||||
text += " %s" % (dns.rdata._base64ify(self.other, 0))
|
||||
text += f" {dns.rdata._base64ify(self.other, 0)}"
|
||||
|
||||
return text
|
||||
|
||||
|
9
lib/dns/rdtypes/ANY/WALLET.py
Normal file
9
lib/dns/rdtypes/ANY/WALLET.py
Normal file
@ -0,0 +1,9 @@
|
||||
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
||||
|
||||
import dns.immutable
|
||||
import dns.rdtypes.txtbase
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class WALLET(dns.rdtypes.txtbase.TXTBase):
|
||||
"""WALLET record"""
|
@ -36,7 +36,7 @@ class X25(dns.rdata.Rdata):
|
||||
self.address = self._as_bytes(address, True, 255)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return '"%s"' % dns.rdata._escapify(self.address)
|
||||
return f'"{dns.rdata._escapify(self.address)}"'
|
||||
|
||||
@classmethod
|
||||
def from_text(
|
||||
|
@ -51,6 +51,7 @@ __all__ = [
|
||||
"OPENPGPKEY",
|
||||
"OPT",
|
||||
"PTR",
|
||||
"RESINFO",
|
||||
"RP",
|
||||
"RRSIG",
|
||||
"RT",
|
||||
@ -63,6 +64,7 @@ __all__ = [
|
||||
"TSIG",
|
||||
"TXT",
|
||||
"URI",
|
||||
"WALLET",
|
||||
"X25",
|
||||
"ZONEMD",
|
||||
]
|
||||
|
@ -37,7 +37,7 @@ class A(dns.rdata.Rdata):
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
domain = self.domain.choose_relativity(origin, relativize)
|
||||
return "%s %o" % (domain, self.address)
|
||||
return f"{domain} {self.address:o}"
|
||||
|
||||
@classmethod
|
||||
def from_text(
|
||||
|
@ -36,7 +36,7 @@ class NSAP(dns.rdata.Rdata):
|
||||
self.address = self._as_bytes(address)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
return "0x%s" % binascii.hexlify(self.address).decode()
|
||||
return f"0x{binascii.hexlify(self.address).decode()}"
|
||||
|
||||
@classmethod
|
||||
def from_text(
|
||||
|
@ -36,7 +36,7 @@ class EUIBase(dns.rdata.Rdata):
|
||||
self.eui = self._as_bytes(eui)
|
||||
if len(self.eui) != self.byte_len:
|
||||
raise dns.exception.FormError(
|
||||
"EUI%s rdata has to have %s bytes" % (self.byte_len * 8, self.byte_len)
|
||||
f"EUI{self.byte_len * 8} rdata has to have {self.byte_len} bytes"
|
||||
)
|
||||
|
||||
def to_text(self, origin=None, relativize=True, **kw):
|
||||
@ -49,16 +49,16 @@ class EUIBase(dns.rdata.Rdata):
|
||||
text = tok.get_string()
|
||||
if len(text) != cls.text_len:
|
||||
raise dns.exception.SyntaxError(
|
||||
"Input text must have %s characters" % cls.text_len
|
||||
f"Input text must have {cls.text_len} characters"
|
||||
)
|
||||
for i in range(2, cls.byte_len * 3 - 1, 3):
|
||||
if text[i] != "-":
|
||||
raise dns.exception.SyntaxError("Dash expected at position %s" % i)
|
||||
raise dns.exception.SyntaxError(f"Dash expected at position {i}")
|
||||
text = text.replace("-", "")
|
||||
try:
|
||||
data = binascii.unhexlify(text.encode())
|
||||
except (ValueError, TypeError) as ex:
|
||||
raise dns.exception.SyntaxError("Hex decoding error: %s" % str(ex))
|
||||
raise dns.exception.SyntaxError(f"Hex decoding error: {str(ex)}")
|
||||
return cls(rdclass, rdtype, data)
|
||||
|
||||
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
|
||||
|
@ -35,6 +35,7 @@ class ParamKey(dns.enum.IntEnum):
|
||||
ECH = 5
|
||||
IPV6HINT = 6
|
||||
DOHPATH = 7
|
||||
OHTTP = 8
|
||||
|
||||
@classmethod
|
||||
def _maximum(cls):
|
||||
@ -396,6 +397,36 @@ class ECHParam(Param):
|
||||
file.write(self.ech)
|
||||
|
||||
|
||||
@dns.immutable.immutable
|
||||
class OHTTPParam(Param):
|
||||
# We don't ever expect to instantiate this class, but we need
|
||||
# a from_value() and a from_wire_parser(), so we just return None
|
||||
# from the class methods when things are OK.
|
||||
|
||||
@classmethod
|
||||
def emptiness(cls):
|
||||
return Emptiness.ALWAYS
|
||||
|
||||
@classmethod
|
||||
def from_value(cls, value):
|
||||
if value is None or value == "":
|
||||
return None
|
||||
else:
|
||||
raise ValueError("ohttp with non-empty value")
|
||||
|
||||
def to_text(self):
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
@classmethod
|
||||
def from_wire_parser(cls, parser, origin=None): # pylint: disable=W0613
|
||||
if parser.remaining() != 0:
|
||||
raise dns.exception.FormError
|
||||
return None
|
||||
|
||||
def to_wire(self, file, origin=None): # pylint: disable=W0613
|
||||
raise NotImplementedError # pragma: no cover
|
||||
|
||||
|
||||
_class_for_key = {
|
||||
ParamKey.MANDATORY: MandatoryParam,
|
||||
ParamKey.ALPN: ALPNParam,
|
||||
@ -404,6 +435,7 @@ _class_for_key = {
|
||||
ParamKey.IPV4HINT: IPv4HintParam,
|
||||
ParamKey.ECH: ECHParam,
|
||||
ParamKey.IPV6HINT: IPv6HintParam,
|
||||
ParamKey.OHTTP: OHTTPParam,
|
||||
}
|
||||
|
||||
|
||||
|
@ -50,6 +50,8 @@ class TXTBase(dns.rdata.Rdata):
|
||||
self.strings: Tuple[bytes] = self._as_tuple(
|
||||
strings, lambda x: self._as_bytes(x, True, 255)
|
||||
)
|
||||
if len(self.strings) == 0:
|
||||
raise ValueError("the list of strings must not be empty")
|
||||
|
||||
def to_text(
|
||||
self,
|
||||
@ -60,7 +62,7 @@ class TXTBase(dns.rdata.Rdata):
|
||||
txt = ""
|
||||
prefix = ""
|
||||
for s in self.strings:
|
||||
txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s))
|
||||
txt += f'{prefix}"{dns.rdata._escapify(s)}"'
|
||||
prefix = " "
|
||||
return txt
|
||||
|
||||
|
@ -231,7 +231,7 @@ def weighted_processing_order(iterable):
|
||||
total = sum(rdata._processing_weight() or _no_weight for rdata in rdatas)
|
||||
while len(rdatas) > 1:
|
||||
r = random.uniform(0, total)
|
||||
for n, rdata in enumerate(rdatas):
|
||||
for n, rdata in enumerate(rdatas): # noqa: B007
|
||||
weight = rdata._processing_weight() or _no_weight
|
||||
if weight > r:
|
||||
break
|
||||
|
@ -36,6 +36,7 @@ import dns.ipv4
|
||||
import dns.ipv6
|
||||
import dns.message
|
||||
import dns.name
|
||||
import dns.rdata
|
||||
import dns.nameserver
|
||||
import dns.query
|
||||
import dns.rcode
|
||||
@ -45,7 +46,7 @@ import dns.rdtypes.svcbbase
|
||||
import dns.reversename
|
||||
import dns.tsig
|
||||
|
||||
if sys.platform == "win32":
|
||||
if sys.platform == "win32": # pragma: no cover
|
||||
import dns.win32util
|
||||
|
||||
|
||||
@ -83,7 +84,7 @@ class NXDOMAIN(dns.exception.DNSException):
|
||||
else:
|
||||
msg = "The DNS query name does not exist"
|
||||
qnames = ", ".join(map(str, qnames))
|
||||
return "{}: {}".format(msg, qnames)
|
||||
return f"{msg}: {qnames}"
|
||||
|
||||
@property
|
||||
def canonical_name(self):
|
||||
@ -96,7 +97,7 @@ class NXDOMAIN(dns.exception.DNSException):
|
||||
cname = response.canonical_name()
|
||||
if cname != qname:
|
||||
return cname
|
||||
except Exception:
|
||||
except Exception: # pragma: no cover
|
||||
# We can just eat this exception as it means there was
|
||||
# something wrong with the response.
|
||||
pass
|
||||
@ -154,7 +155,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
|
||||
"""Turn a resolution errors trace into a list of text."""
|
||||
texts = []
|
||||
for err in errors:
|
||||
texts.append("Server {} answered {}".format(err[0], err[3]))
|
||||
texts.append(f"Server {err[0]} answered {err[3]}")
|
||||
return texts
|
||||
|
||||
|
||||
@ -162,7 +163,7 @@ class LifetimeTimeout(dns.exception.Timeout):
|
||||
"""The resolution lifetime expired."""
|
||||
|
||||
msg = "The resolution lifetime expired."
|
||||
fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1]
|
||||
fmt = f"{msg[:-1]} after {{timeout:.3f}} seconds: {{errors}}"
|
||||
supp_kwargs = {"timeout", "errors"}
|
||||
|
||||
# We do this as otherwise mypy complains about unexpected keyword argument
|
||||
@ -211,7 +212,7 @@ class NoNameservers(dns.exception.DNSException):
|
||||
"""
|
||||
|
||||
msg = "All nameservers failed to answer the query."
|
||||
fmt = "%s {query}: {errors}" % msg[:-1]
|
||||
fmt = f"{msg[:-1]} {{query}}: {{errors}}"
|
||||
supp_kwargs = {"request", "errors"}
|
||||
|
||||
# We do this as otherwise mypy complains about unexpected keyword argument
|
||||
@ -297,7 +298,7 @@ class Answer:
|
||||
def __len__(self) -> int:
|
||||
return self.rrset and len(self.rrset) or 0
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[dns.rdata.Rdata]:
|
||||
return self.rrset and iter(self.rrset) or iter(tuple())
|
||||
|
||||
def __getitem__(self, i):
|
||||
@ -334,7 +335,7 @@ class HostAnswers(Answers):
|
||||
answers[dns.rdatatype.A] = v4
|
||||
return answers
|
||||
|
||||
# Returns pairs of (address, family) from this result, potentiallys
|
||||
# Returns pairs of (address, family) from this result, potentially
|
||||
# filtering by address family.
|
||||
def addresses_and_families(
|
||||
self, family: int = socket.AF_UNSPEC
|
||||
@ -347,7 +348,7 @@ class HostAnswers(Answers):
|
||||
answer = self.get(dns.rdatatype.AAAA)
|
||||
elif family == socket.AF_INET:
|
||||
answer = self.get(dns.rdatatype.A)
|
||||
else:
|
||||
else: # pragma: no cover
|
||||
raise NotImplementedError(f"unknown address family {family}")
|
||||
if answer:
|
||||
for rdata in answer:
|
||||
@ -938,7 +939,7 @@ class BaseResolver:
|
||||
|
||||
self.reset()
|
||||
if configure:
|
||||
if sys.platform == "win32":
|
||||
if sys.platform == "win32": # pragma: no cover
|
||||
self.read_registry()
|
||||
elif filename:
|
||||
self.read_resolv_conf(filename)
|
||||
@ -947,7 +948,7 @@ class BaseResolver:
|
||||
"""Reset all resolver configuration to the defaults."""
|
||||
|
||||
self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
|
||||
if len(self.domain) == 0:
|
||||
if len(self.domain) == 0: # pragma: no cover
|
||||
self.domain = dns.name.root
|
||||
self._nameservers = []
|
||||
self.nameserver_ports = {}
|
||||
@ -1040,7 +1041,7 @@ class BaseResolver:
|
||||
# setter logic, with additonal checking and enrichment.
|
||||
self.nameservers = nameservers
|
||||
|
||||
def read_registry(self) -> None:
|
||||
def read_registry(self) -> None: # pragma: no cover
|
||||
"""Extract resolver configuration from the Windows registry."""
|
||||
try:
|
||||
info = dns.win32util.get_dns_info() # type: ignore
|
||||
@ -1205,9 +1206,7 @@ class BaseResolver:
|
||||
enriched_nameservers.append(enriched_nameserver)
|
||||
else:
|
||||
raise ValueError(
|
||||
"nameservers must be a list or tuple (not a {})".format(
|
||||
type(nameservers)
|
||||
)
|
||||
f"nameservers must be a list or tuple (not a {type(nameservers)})"
|
||||
)
|
||||
return enriched_nameservers
|
||||
|
||||
@ -1431,7 +1430,7 @@ class Resolver(BaseResolver):
|
||||
elif family == socket.AF_INET6:
|
||||
v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
|
||||
return HostAnswers.make(v6=v6)
|
||||
elif family != socket.AF_UNSPEC:
|
||||
elif family != socket.AF_UNSPEC: # pragma: no cover
|
||||
raise NotImplementedError(f"unknown address family {family}")
|
||||
|
||||
raise_on_no_answer = modified_kwargs.pop("raise_on_no_answer", True)
|
||||
@ -1515,7 +1514,7 @@ class Resolver(BaseResolver):
|
||||
nameservers = dns._ddr._get_nameservers_sync(answer, timeout)
|
||||
if len(nameservers) > 0:
|
||||
self.nameservers = nameservers
|
||||
except Exception:
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
@ -1640,7 +1639,7 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
|
||||
return get_default_resolver().canonical_name(name)
|
||||
|
||||
|
||||
def try_ddr(lifetime: float = 5.0) -> None:
|
||||
def try_ddr(lifetime: float = 5.0) -> None: # pragma: no cover
|
||||
"""Try to update the default resolver's nameservers using Discovery of Designated
|
||||
Resolvers (DDR). If successful, the resolver will subsequently use
|
||||
DNS-over-HTTPS or DNS-over-TLS for future queries.
|
||||
@ -1926,7 +1925,7 @@ def _getnameinfo(sockaddr, flags=0):
|
||||
family = socket.AF_INET
|
||||
tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, socket.SOL_TCP, 0)
|
||||
if len(tuples) > 1:
|
||||
raise socket.error("sockaddr resolved to multiple addresses")
|
||||
raise OSError("sockaddr resolved to multiple addresses")
|
||||
addr = tuples[0][4][0]
|
||||
if flags & socket.NI_DGRAM:
|
||||
pname = "udp"
|
||||
@ -1961,7 +1960,7 @@ def _getfqdn(name=None):
|
||||
(name, _, _) = _gethostbyaddr(name)
|
||||
# Python's version checks aliases too, but our gethostbyname
|
||||
# ignores them, so we do so here as well.
|
||||
except Exception:
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
return name
|
||||
|
||||
|
@ -21,10 +21,11 @@ import itertools
|
||||
class Set:
|
||||
"""A simple set class.
|
||||
|
||||
This class was originally used to deal with sets being missing in
|
||||
ancient versions of python, but dnspython will continue to use it
|
||||
as these sets are based on lists and are thus indexable, and this
|
||||
ability is widely used in dnspython applications.
|
||||
This class was originally used to deal with python not having a set class, and
|
||||
originally the class used lists in its implementation. The ordered and indexable
|
||||
nature of RRsets and Rdatasets is unfortunately widely used in dnspython
|
||||
applications, so for backwards compatibility sets continue to be a custom class, now
|
||||
based on an ordered dictionary.
|
||||
"""
|
||||
|
||||
__slots__ = ["items"]
|
||||
@ -43,7 +44,7 @@ class Set:
|
||||
self.add(item) # lgtm[py/init-calls-subclass]
|
||||
|
||||
def __repr__(self):
|
||||
return "dns.set.Set(%s)" % repr(list(self.items.keys()))
|
||||
return f"dns.set.Set({repr(list(self.items.keys()))})" # pragma: no cover
|
||||
|
||||
def add(self, item):
|
||||
"""Add an item to the set."""
|
||||
|
@ -528,7 +528,7 @@ class Tokenizer:
|
||||
if value < 0 or value > 65535:
|
||||
if base == 8:
|
||||
raise dns.exception.SyntaxError(
|
||||
"%o is not an octal unsigned 16-bit integer" % value
|
||||
f"{value:o} is not an octal unsigned 16-bit integer"
|
||||
)
|
||||
else:
|
||||
raise dns.exception.SyntaxError(
|
||||
|
@ -486,7 +486,7 @@ class Transaction:
|
||||
if exact:
|
||||
raise DeleteNotExact(f"{method}: missing rdataset")
|
||||
else:
|
||||
self._delete_rdataset(name, rdtype, covers)
|
||||
self._checked_delete_rdataset(name, rdtype, covers)
|
||||
return
|
||||
else:
|
||||
rdataset = self._rdataset_from_args(method, True, args)
|
||||
@ -529,8 +529,6 @@ class Transaction:
|
||||
|
||||
def _end(self, commit):
|
||||
self._check_ended()
|
||||
if self._ended:
|
||||
raise AlreadyEnded
|
||||
try:
|
||||
self._end_transaction(commit)
|
||||
finally:
|
||||
|
@ -73,7 +73,7 @@ def from_text(text: str) -> int:
|
||||
elif c == "s":
|
||||
total += current
|
||||
else:
|
||||
raise BadTTL("unknown unit '%s'" % c)
|
||||
raise BadTTL(f"unknown unit '{c}'")
|
||||
current = 0
|
||||
need_digit = True
|
||||
if not current == 0:
|
||||
|
@ -20,9 +20,9 @@
|
||||
#: MAJOR
|
||||
MAJOR = 2
|
||||
#: MINOR
|
||||
MINOR = 6
|
||||
MINOR = 7
|
||||
#: MICRO
|
||||
MICRO = 1
|
||||
MICRO = 0
|
||||
#: RELEASELEVEL
|
||||
RELEASELEVEL = 0x0F
|
||||
#: SERIAL
|
||||
|
@ -13,8 +13,8 @@ if sys.platform == "win32":
|
||||
|
||||
# Keep pylint quiet on non-windows.
|
||||
try:
|
||||
WindowsError is None # pylint: disable=used-before-assignment
|
||||
except KeyError:
|
||||
_ = WindowsError # pylint: disable=used-before-assignment
|
||||
except NameError:
|
||||
WindowsError = Exception
|
||||
|
||||
if dns._features.have("wmi"):
|
||||
@ -44,6 +44,7 @@ if sys.platform == "win32":
|
||||
if _have_wmi:
|
||||
|
||||
class _WMIGetter(threading.Thread):
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.info = DnsInfo()
|
||||
@ -82,32 +83,21 @@ if sys.platform == "win32":
|
||||
def __init__(self):
|
||||
self.info = DnsInfo()
|
||||
|
||||
def _determine_split_char(self, entry):
|
||||
#
|
||||
# The windows registry irritatingly changes the list element
|
||||
# delimiter in between ' ' and ',' (and vice-versa) in various
|
||||
# versions of windows.
|
||||
#
|
||||
if entry.find(" ") >= 0:
|
||||
split_char = " "
|
||||
elif entry.find(",") >= 0:
|
||||
split_char = ","
|
||||
else:
|
||||
# probably a singleton; treat as a space-separated list.
|
||||
split_char = " "
|
||||
return split_char
|
||||
def _split(self, text):
|
||||
# The windows registry has used both " " and "," as a delimiter, and while
|
||||
# it is currently using "," in Windows 10 and later, updates can seemingly
|
||||
# leave a space in too, e.g. "a, b". So we just convert all commas to
|
||||
# spaces, and use split() in its default configuration, which splits on
|
||||
# all whitespace and ignores empty strings.
|
||||
return text.replace(",", " ").split()
|
||||
|
||||
def _config_nameservers(self, nameservers):
|
||||
split_char = self._determine_split_char(nameservers)
|
||||
ns_list = nameservers.split(split_char)
|
||||
for ns in ns_list:
|
||||
for ns in self._split(nameservers):
|
||||
if ns not in self.info.nameservers:
|
||||
self.info.nameservers.append(ns)
|
||||
|
||||
def _config_search(self, search):
|
||||
split_char = self._determine_split_char(search)
|
||||
search_list = search.split(split_char)
|
||||
for s in search_list:
|
||||
for s in self._split(search):
|
||||
s = _config_domain(s)
|
||||
if s not in self.info.search:
|
||||
self.info.search.append(s)
|
||||
@ -164,7 +154,7 @@ if sys.platform == "win32":
|
||||
lm,
|
||||
r"SYSTEM\CurrentControlSet\Control\Network"
|
||||
r"\{4D36E972-E325-11CE-BFC1-08002BE10318}"
|
||||
r"\%s\Connection" % guid,
|
||||
rf"\{guid}\Connection",
|
||||
)
|
||||
|
||||
try:
|
||||
@ -177,7 +167,7 @@ if sys.platform == "win32":
|
||||
raise ValueError # pragma: no cover
|
||||
|
||||
device_key = winreg.OpenKey(
|
||||
lm, r"SYSTEM\CurrentControlSet\Enum\%s" % pnp_id
|
||||
lm, rf"SYSTEM\CurrentControlSet\Enum\{pnp_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
@ -232,7 +222,7 @@ if sys.platform == "win32":
|
||||
self._config_fromkey(key, False)
|
||||
finally:
|
||||
key.Close()
|
||||
except EnvironmentError:
|
||||
except OSError:
|
||||
break
|
||||
finally:
|
||||
interfaces.Close()
|
||||
|
@ -33,7 +33,7 @@ class TransferError(dns.exception.DNSException):
|
||||
"""A zone transfer response got a non-zero rcode."""
|
||||
|
||||
def __init__(self, rcode):
|
||||
message = "Zone transfer error: %s" % dns.rcode.to_text(rcode)
|
||||
message = f"Zone transfer error: {dns.rcode.to_text(rcode)}"
|
||||
super().__init__(message)
|
||||
self.rcode = rcode
|
||||
|
||||
|
@ -230,7 +230,7 @@ class Reader:
|
||||
try:
|
||||
rdtype = dns.rdatatype.from_text(token.value)
|
||||
except Exception:
|
||||
raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value)
|
||||
raise dns.exception.SyntaxError(f"unknown rdatatype '{token.value}'")
|
||||
|
||||
try:
|
||||
rd = dns.rdata.from_text(
|
||||
@ -251,9 +251,7 @@ class Reader:
|
||||
# We convert them to syntax errors so that we can emit
|
||||
# helpful filename:line info.
|
||||
(ty, va) = sys.exc_info()[:2]
|
||||
raise dns.exception.SyntaxError(
|
||||
"caught exception {}: {}".format(str(ty), str(va))
|
||||
)
|
||||
raise dns.exception.SyntaxError(f"caught exception {str(ty)}: {str(va)}")
|
||||
|
||||
if not self.default_ttl_known and rdtype == dns.rdatatype.SOA:
|
||||
# The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
|
||||
@ -281,41 +279,41 @@ class Reader:
|
||||
# Sometimes there are modifiers in the hostname. These come after
|
||||
# the dollar sign. They are in the form: ${offset[,width[,base]]}.
|
||||
# Make names
|
||||
mod = ""
|
||||
sign = "+"
|
||||
offset = "0"
|
||||
width = "0"
|
||||
base = "d"
|
||||
g1 = is_generate1.match(side)
|
||||
if g1:
|
||||
mod, sign, offset, width, base = g1.groups()
|
||||
if sign == "":
|
||||
sign = "+"
|
||||
g2 = is_generate2.match(side)
|
||||
if g2:
|
||||
mod, sign, offset = g2.groups()
|
||||
if sign == "":
|
||||
sign = "+"
|
||||
width = 0
|
||||
base = "d"
|
||||
g3 = is_generate3.match(side)
|
||||
if g3:
|
||||
mod, sign, offset, width = g3.groups()
|
||||
if sign == "":
|
||||
sign = "+"
|
||||
base = "d"
|
||||
else:
|
||||
g2 = is_generate2.match(side)
|
||||
if g2:
|
||||
mod, sign, offset = g2.groups()
|
||||
if sign == "":
|
||||
sign = "+"
|
||||
width = "0"
|
||||
base = "d"
|
||||
else:
|
||||
g3 = is_generate3.match(side)
|
||||
if g3:
|
||||
mod, sign, offset, width = g3.groups()
|
||||
if sign == "":
|
||||
sign = "+"
|
||||
base = "d"
|
||||
|
||||
if not (g1 or g2 or g3):
|
||||
mod = ""
|
||||
sign = "+"
|
||||
offset = 0
|
||||
width = 0
|
||||
base = "d"
|
||||
|
||||
offset = int(offset)
|
||||
width = int(width)
|
||||
ioffset = int(offset)
|
||||
iwidth = int(width)
|
||||
|
||||
if sign not in ["+", "-"]:
|
||||
raise dns.exception.SyntaxError("invalid offset sign %s" % sign)
|
||||
raise dns.exception.SyntaxError(f"invalid offset sign {sign}")
|
||||
if base not in ["d", "o", "x", "X", "n", "N"]:
|
||||
raise dns.exception.SyntaxError("invalid type %s" % base)
|
||||
raise dns.exception.SyntaxError(f"invalid type {base}")
|
||||
|
||||
return mod, sign, offset, width, base
|
||||
return mod, sign, ioffset, iwidth, base
|
||||
|
||||
def _generate_line(self):
|
||||
# range lhs [ttl] [class] type rhs [ comment ]
|
||||
@ -377,7 +375,7 @@ class Reader:
|
||||
if not token.is_identifier():
|
||||
raise dns.exception.SyntaxError
|
||||
except Exception:
|
||||
raise dns.exception.SyntaxError("unknown rdatatype '%s'" % token.value)
|
||||
raise dns.exception.SyntaxError(f"unknown rdatatype '{token.value}'")
|
||||
|
||||
# rhs (required)
|
||||
rhs = token.value
|
||||
@ -412,8 +410,8 @@ class Reader:
|
||||
lzfindex = _format_index(lindex, lbase, lwidth)
|
||||
rzfindex = _format_index(rindex, rbase, rwidth)
|
||||
|
||||
name = lhs.replace("$%s" % (lmod), lzfindex)
|
||||
rdata = rhs.replace("$%s" % (rmod), rzfindex)
|
||||
name = lhs.replace(f"${lmod}", lzfindex)
|
||||
rdata = rhs.replace(f"${rmod}", rzfindex)
|
||||
|
||||
self.last_name = dns.name.from_text(
|
||||
name, self.current_origin, self.tok.idna_codec
|
||||
@ -445,7 +443,7 @@ class Reader:
|
||||
# helpful filename:line info.
|
||||
(ty, va) = sys.exc_info()[:2]
|
||||
raise dns.exception.SyntaxError(
|
||||
"caught exception %s: %s" % (str(ty), str(va))
|
||||
f"caught exception {str(ty)}: {str(va)}"
|
||||
)
|
||||
|
||||
self.txn.add(name, ttl, rd)
|
||||
@ -528,7 +526,7 @@ class Reader:
|
||||
self.default_ttl_known,
|
||||
)
|
||||
)
|
||||
self.current_file = open(filename, "r")
|
||||
self.current_file = open(filename)
|
||||
self.tok = dns.tokenizer.Tokenizer(self.current_file, filename)
|
||||
self.current_origin = new_origin
|
||||
elif c == "$GENERATE":
|
||||
|
@ -7,7 +7,7 @@ cheroot==10.0.1
|
||||
cherrypy==18.10.0
|
||||
cloudinary==1.41.0
|
||||
distro==1.9.0
|
||||
dnspython==2.6.1
|
||||
dnspython==2.7.0
|
||||
facebook-sdk==3.1.0
|
||||
future==1.0.0
|
||||
ga4mp==2.0.4
|
||||
|
Loading…
x
Reference in New Issue
Block a user