mirror of
https://github.com/Tautulli/Tautulli.git
synced 2025-01-06 11:09:57 -08:00
dd9a35df51
* Bump pyjwt from 2.9.0 to 2.10.0 Bumps [pyjwt](https://github.com/jpadilla/pyjwt) from 2.9.0 to 2.10.0. - [Release notes](https://github.com/jpadilla/pyjwt/releases) - [Changelog](https://github.com/jpadilla/pyjwt/blob/master/CHANGELOG.rst) - [Commits](https://github.com/jpadilla/pyjwt/compare/2.9.0...2.10.0) --- updated-dependencies: - dependency-name: pyjwt dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Update pyjwt==2.10.0 --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci]
876 lines
30 KiB
Python
876 lines
30 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
|
|
|
|
from .exceptions import InvalidKeyError
|
|
from .types import HashlibHash, JWKDict
|
|
from .utils import (
|
|
base64url_decode,
|
|
base64url_encode,
|
|
der_to_raw_signature,
|
|
force_bytes,
|
|
from_base64url_uint,
|
|
is_pem_format,
|
|
is_ssh_key,
|
|
raw_to_der_signature,
|
|
to_base64url_uint,
|
|
)
|
|
|
|
try:
|
|
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.asymmetric import padding
|
|
from cryptography.hazmat.primitives.asymmetric.ec import (
|
|
ECDSA,
|
|
SECP256K1,
|
|
SECP256R1,
|
|
SECP384R1,
|
|
SECP521R1,
|
|
EllipticCurve,
|
|
EllipticCurvePrivateKey,
|
|
EllipticCurvePrivateNumbers,
|
|
EllipticCurvePublicKey,
|
|
EllipticCurvePublicNumbers,
|
|
)
|
|
from cryptography.hazmat.primitives.asymmetric.ed448 import (
|
|
Ed448PrivateKey,
|
|
Ed448PublicKey,
|
|
)
|
|
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
|
|
Ed25519PrivateKey,
|
|
Ed25519PublicKey,
|
|
)
|
|
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
|
RSAPrivateKey,
|
|
RSAPrivateNumbers,
|
|
RSAPublicKey,
|
|
RSAPublicNumbers,
|
|
rsa_crt_dmp1,
|
|
rsa_crt_dmq1,
|
|
rsa_crt_iqmp,
|
|
rsa_recover_prime_factors,
|
|
)
|
|
from cryptography.hazmat.primitives.serialization import (
|
|
Encoding,
|
|
NoEncryption,
|
|
PrivateFormat,
|
|
PublicFormat,
|
|
load_pem_private_key,
|
|
load_pem_public_key,
|
|
load_ssh_public_key,
|
|
)
|
|
|
|
has_crypto = True
|
|
except ModuleNotFoundError:
|
|
has_crypto = False
|
|
|
|
|
|
if TYPE_CHECKING:
    # Aliases consumed only by static type checkers; they keep the method
    # signatures of the algorithm classes short. ``TYPE_CHECKING`` is False
    # at runtime, so none of these names exist when the module is executed.

    # Unions grouped by key family (private and public variants together).
    AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
    AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
    AllowedOKPKeys = (
        Ed25519PrivateKey
        | Ed25519PublicKey
        | Ed448PrivateKey
        | Ed448PublicKey
    )

    # Union of all three key families.
    AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys

    # The same key classes, grouped by role instead of by family.
    AllowedPrivateKeys = (
        RSAPrivateKey
        | EllipticCurvePrivateKey
        | Ed25519PrivateKey
        | Ed448PrivateKey
    )
    AllowedPublicKeys = (
        RSAPublicKey
        | EllipticCurvePublicKey
        | Ed25519PublicKey
        | Ed448PublicKey
    )
|
|
|
|
|
|
# Algorithm names that get_default_algorithms() registers only when the
# ``cryptography`` package could be imported (``has_crypto`` is True).
requires_cryptography = {
    # RSASSA-PKCS-v1_5
    "RS256", "RS384", "RS512",
    # ECDSA
    "ES256", "ES256K", "ES384", "ES521", "ES512",
    # RSASSA-PSS
    "PS256", "PS384", "PS512",
    # EdDSA
    "EdDSA",
}
|
|
|
|
|
|
def get_default_algorithms() -> dict[str, Algorithm]:
    """
    Build the mapping of algorithm name to algorithm instance for every
    algorithm this library implements.

    The ``none`` and HMAC entries are always present. The RSA, ECDSA,
    RSA-PSS and EdDSA entries are added only when ``has_crypto`` is true.
    """
    registry = {
        "none": NoneAlgorithm(),
        "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
        "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
        "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
    }

    if not has_crypto:
        return registry

    registry["RS256"] = RSAAlgorithm(RSAAlgorithm.SHA256)
    registry["RS384"] = RSAAlgorithm(RSAAlgorithm.SHA384)
    registry["RS512"] = RSAAlgorithm(RSAAlgorithm.SHA512)
    registry["ES256"] = ECAlgorithm(ECAlgorithm.SHA256)
    registry["ES256K"] = ECAlgorithm(ECAlgorithm.SHA256)
    registry["ES384"] = ECAlgorithm(ECAlgorithm.SHA384)
    registry["ES521"] = ECAlgorithm(ECAlgorithm.SHA512)
    # Backward compat for #219 fix
    registry["ES512"] = ECAlgorithm(ECAlgorithm.SHA512)
    registry["PS256"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256)
    registry["PS384"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384)
    registry["PS512"] = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
    registry["EdDSA"] = OKPAlgorithm()

    return registry
|
|
|
|
|
|
class Algorithm(ABC):
    """
    Abstract base describing what an algorithm must provide in order to
    sign and verify tokens.
    """

    def compute_hash_digest(self, bytestr: bytes) -> bytes:
        """
        Hash ``bytestr`` with the hash function stored on this instance.

        The hash function is read from the ``hash_alg`` attribute. If the
        instance has no such attribute, or it is ``None``,
        ``NotImplementedError`` is raised.
        """
        # ``hash_alg`` is not declared on this base class, so it is fetched
        # with getattr() in a way that mypy can understand.
        hash_alg = getattr(self, "hash_alg", None)
        if hash_alg is None:
            raise NotImplementedError

        is_cryptography_hash = (
            has_crypto
            and isinstance(hash_alg, type)
            and issubclass(hash_alg, hashes.HashAlgorithm)
        )
        if not is_cryptography_hash:
            # Anything else is called directly with the data and must return
            # an object exposing ``digest()`` (the hashlib constructor shape).
            return bytes(hash_alg(bytestr).digest())

        hasher = hashes.Hash(hash_alg(), backend=default_backend())
        hasher.update(bytestr)
        return bytes(hasher.finalize())

    @abstractmethod
    def prepare_key(self, key: Any) -> Any:
        """
        Validate and convert ``key``, returning it in the form that
        sign() and verify() expect.
        """

    @abstractmethod
    def sign(self, msg: bytes, key: Any) -> bytes:
        """
        Return the digital signature of ``msg`` produced with ``key``.
        """

    @abstractmethod
    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
        """
        Report whether ``sig`` is a valid digital signature of ``msg``
        for ``key``.
        """

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ...  # pragma: no cover

    @staticmethod
    @abstractmethod
    def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize ``key_obj`` into a JWK: a dict when ``as_dict`` is true,
        otherwise a string.
        """

    @staticmethod
    @abstractmethod
    def from_jwk(jwk: str | JWKDict) -> Any:
        """
        Deserialize a JWK (string or dict) back into a key object.
        """
|
|
|
|
|
|
class NoneAlgorithm(Algorithm):
    """
    Stand-in algorithm for tokens that need neither signing nor
    verification.
    """

    def prepare_key(self, key: str | None) -> None:
        """
        Accept only "no key": ``None`` or the empty string (treated as
        ``None``). Any other value raises ``InvalidKeyError``.
        """
        if key is None or key == "":
            return None

        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg: bytes, key: None) -> bytes:
        """Return an empty signature regardless of the message."""
        return b""

    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
        """Always return False; this method never reports a valid signature."""
        return False

    @staticmethod
    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> NoReturn:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError()
|
|
|
|
|
|
class HMACAlgorithm(Algorithm):
    """
    Signs and verifies messages with HMAC, using the hash function given
    at construction time.
    """

    # Hash constructors accepted by __init__.
    SHA256: ClassVar[HashlibHash] = hashlib.sha256
    SHA384: ClassVar[HashlibHash] = hashlib.sha384
    SHA512: ClassVar[HashlibHash] = hashlib.sha512

    def __init__(self, hash_alg: HashlibHash) -> None:
        self.hash_alg = hash_alg

    def prepare_key(self, key: str | bytes) -> bytes:
        """
        Convert the shared secret to bytes.

        Raises ``InvalidKeyError`` when the secret is recognised as a
        PEM-formatted or SSH key, since such material should not be used
        as an HMAC secret.
        """
        secret = force_bytes(key)

        looks_asymmetric = is_pem_format(secret) or is_ssh_key(secret)
        if looks_asymmetric:
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return secret

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: str | bytes, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize the secret as an ``oct`` JWK, returned as a dict when
        ``as_dict`` is true and as a JSON string otherwise.
        """
        encoded_secret = base64url_encode(force_bytes(key_obj)).decode()
        jwk = {"k": encoded_secret, "kty": "oct"}

        return jwk if as_dict else json.dumps(jwk)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> bytes:
        """
        Extract the secret bytes from an ``oct`` JWK given as a JSON string
        or a dict.

        Raises ``InvalidKeyError`` when the input is neither a dict nor
        parseable JSON, or when its ``kty`` is not ``"oct"``.
        """
        try:
            if isinstance(jwk, dict):
                obj: JWKDict = jwk
            elif isinstance(jwk, str):
                obj = json.loads(jwk)
            else:
                # Unsupported input types share the invalid-JSON error path.
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        return base64url_decode(obj["k"])

    def sign(self, msg: bytes, key: bytes) -> bytes:
        """Return the HMAC of ``msg`` under ``key``."""
        mac = hmac.new(key, msg, self.hash_alg)
        return mac.digest()

    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
        """
        Recompute the HMAC and compare it with ``sig`` using
        ``hmac.compare_digest``.
        """
        expected = self.sign(msg, key)
        return hmac.compare_digest(sig, expected)
|
|
|
|
|
|
if has_crypto:
|
|
|
|
class RSAAlgorithm(Algorithm):
    """
    Signs and verifies messages with RSASSA-PKCS-v1_5, using the hash
    function given at construction time.
    """

    # Hash classes accepted by __init__.
    SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
    SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
    SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

    def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
        self.hash_alg = hash_alg

    def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
        """
        Return ``key`` as an RSA key object.

        Key objects are returned unchanged. Text or bytes are loaded as an
        SSH public key when they start with ``ssh-rsa``, otherwise as a PEM
        private key, falling back to a PEM public key.

        Raises ``TypeError`` for any other input type, and
        ``InvalidKeyError`` when the public-key fallback cannot parse it.
        """
        if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
            return key

        if not isinstance(key, (bytes, str)):
            raise TypeError("Expecting a PEM-formatted key.")

        key_bytes = force_bytes(key)

        try:
            if key_bytes.startswith(b"ssh-rsa"):
                ssh_key = load_ssh_public_key(key_bytes)
                return cast(RSAPublicKey, ssh_key)

            private_key = load_pem_private_key(key_bytes, password=None)
            return cast(RSAPrivateKey, private_key)
        except ValueError:
            # The first loader rejected the data; try a PEM public key next.
            pass

        try:
            public_key = load_pem_public_key(key_bytes)
        except (ValueError, UnsupportedAlgorithm):
            raise InvalidKeyError(
                "Could not parse the provided public key."
            ) from None
        return cast(RSAPublicKey, public_key)

    @overload
    @staticmethod
    def to_jwk(
        key_obj: AllowedRSAKeys, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize an RSA key as a JWK, returned as a dict when ``as_dict``
        is true and as a JSON string otherwise.

        An object with ``private_numbers`` is written as a private key
        (``key_ops`` of ``["sign"]``); otherwise an object with ``verify``
        is written as a public key (``key_ops`` of ``["verify"]``).
        Anything else raises ``InvalidKeyError``.
        """
        obj: dict[str, Any] | None = None

        if hasattr(key_obj, "private_numbers"):
            # Private key
            private = key_obj.private_numbers()
            public = private.public_numbers

            obj = {
                "kty": "RSA",
                "key_ops": ["sign"],
                "n": to_base64url_uint(public.n).decode(),
                "e": to_base64url_uint(public.e).decode(),
                "d": to_base64url_uint(private.d).decode(),
                "p": to_base64url_uint(private.p).decode(),
                "q": to_base64url_uint(private.q).decode(),
                "dp": to_base64url_uint(private.dmp1).decode(),
                "dq": to_base64url_uint(private.dmq1).decode(),
                "qi": to_base64url_uint(private.iqmp).decode(),
            }
        elif hasattr(key_obj, "verify"):
            # Public key
            public = key_obj.public_numbers()

            obj = {
                "kty": "RSA",
                "key_ops": ["verify"],
                "n": to_base64url_uint(public.n).decode(),
                "e": to_base64url_uint(public.e).decode(),
            }
        else:
            raise InvalidKeyError("Not a public or private key")

        return obj if as_dict else json.dumps(obj)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
        """
        Build an RSA key object from a JWK given as a JSON string or a dict.

        A JWK with ``d``, ``e`` and ``n`` yields a private key; one with
        only ``n`` and ``e`` yields a public key. For private keys the
        members ``p``, ``q``, ``dp``, ``dq`` and ``qi`` must be all present
        or all absent; when absent they are recomputed from ``n``, ``d``
        and ``e``.

        Raises ``InvalidKeyError`` for unparseable input, a ``kty`` other
        than ``"RSA"``, an ``oth`` member, a partial set of the members
        above, or a JWK lacking ``n``/``e``.
        """
        try:
            if isinstance(jwk, dict):
                obj = jwk
            elif isinstance(jwk, str):
                obj = json.loads(jwk)
            else:
                # Unsupported input types share the invalid-JSON error path.
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "RSA":
            raise InvalidKeyError("Not an RSA key") from None

        if not ("n" in obj and "e" in obj):
            raise InvalidKeyError("Not a public or private key")

        if "d" not in obj:
            # Public key
            return RSAPublicNumbers(
                from_base64url_uint(obj["e"]),
                from_base64url_uint(obj["n"]),
            ).public_key()

        # Private key
        if "oth" in obj:
            raise InvalidKeyError(
                "Unsupported RSA private key: > 2 primes not supported"
            )

        crt_fields = ("p", "q", "dp", "dq", "qi")
        present = [name for name in crt_fields if name in obj]

        if present and len(present) != len(crt_fields):
            raise InvalidKeyError(
                "RSA key must include all parameters if any are present besides d"
            ) from None

        public_numbers = RSAPublicNumbers(
            from_base64url_uint(obj["e"]),
            from_base64url_uint(obj["n"]),
        )

        if present:
            private_numbers = RSAPrivateNumbers(
                d=from_base64url_uint(obj["d"]),
                p=from_base64url_uint(obj["p"]),
                q=from_base64url_uint(obj["q"]),
                dmp1=from_base64url_uint(obj["dp"]),
                dmq1=from_base64url_uint(obj["dq"]),
                iqmp=from_base64url_uint(obj["qi"]),
                public_numbers=public_numbers,
            )
        else:
            # Only d was supplied: recover the primes, then derive the
            # remaining values from them.
            d = from_base64url_uint(obj["d"])
            p, q = rsa_recover_prime_factors(
                public_numbers.n, d, public_numbers.e
            )

            private_numbers = RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=rsa_crt_dmp1(d, p),
                dmq1=rsa_crt_dmq1(d, q),
                iqmp=rsa_crt_iqmp(p, q),
                public_numbers=public_numbers,
            )

        return private_numbers.private_key()

    def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
        """Sign ``msg`` with PKCS1v15 padding and this instance's hash."""
        pkcs1_padding = padding.PKCS1v15()
        return key.sign(msg, pkcs1_padding, self.hash_alg())

    def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
        """
        Return True when ``sig`` verifies for ``msg``, and False when the
        key raises ``InvalidSignature``.
        """
        try:
            key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
        except InvalidSignature:
            return False
        return True
|
|
|
|
class ECAlgorithm(Algorithm):
    """
    Performs signing and verification operations using
    ECDSA and the specified hash function
    """

    # Hash classes accepted by __init__.
    SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
    SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
    SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

    def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
        self.hash_alg = hash_alg

    def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
        """
        Return ``key`` as an elliptic-curve key object.

        Key objects are returned unchanged. Text or bytes are loaded as an
        SSH public key when they start with ``ecdsa-sha2-``, otherwise as a
        PEM public key, falling back to a PEM private key.

        Raises ``TypeError`` for any other input type, and
        ``InvalidKeyError`` when the loaded key is not an elliptic-curve
        key. Errors raised by the private-key fallback loader propagate
        unchanged.
        """
        if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
            return key

        if not isinstance(key, (bytes, str)):
            raise TypeError("Expecting a PEM-formatted key.")

        key_bytes = force_bytes(key)

        # Attempt to load key. We don't know if it's
        # a Signing Key or a Verifying Key, so we try
        # the Verifying Key first.
        try:
            if key_bytes.startswith(b"ecdsa-sha2-"):
                crypto_key = load_ssh_public_key(key_bytes)
            else:
                crypto_key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
        except ValueError:
            crypto_key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]

        # Explicit check the key to prevent confusing errors from cryptography
        if not isinstance(
            crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
        ):
            raise InvalidKeyError(
                "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
            ) from None

        return crypto_key

    def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
        """
        Sign ``msg`` with ECDSA, then convert the DER signature returned by
        the key with ``der_to_raw_signature`` for the key's curve.
        """
        der_sig = key.sign(msg, ECDSA(self.hash_alg()))

        return der_to_raw_signature(der_sig, key.curve)

    def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
        """
        Return True when ``sig`` is a valid signature of ``msg``.

        A private key is accepted too; its public half is used. Returns
        False when ``sig`` cannot be converted to DER for the key's curve
        or when verification raises ``InvalidSignature``.
        """
        try:
            der_sig = raw_to_der_signature(sig, key.curve)
        except ValueError:
            return False

        try:
            public_key = (
                key.public_key()
                if isinstance(key, EllipticCurvePrivateKey)
                else key
            )
            public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
            return True
        except InvalidSignature:
            return False

    @overload
    @staticmethod
    def to_jwk(
        key_obj: AllowedECKeys, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key_obj: AllowedECKeys, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize an elliptic-curve key as an ``EC`` JWK, returned as a
        dict when ``as_dict`` is true and as a JSON string otherwise.

        The ``d`` member is included only for private keys.

        Raises ``InvalidKeyError`` when ``key_obj`` is not an
        elliptic-curve key or its curve is not one of SECP256R1,
        SECP384R1, SECP521R1 or SECP256K1.
        """
        if isinstance(key_obj, EllipticCurvePrivateKey):
            public_numbers = key_obj.public_key().public_numbers()
        elif isinstance(key_obj, EllipticCurvePublicKey):
            public_numbers = key_obj.public_numbers()
        else:
            raise InvalidKeyError("Not a public or private key")

        if isinstance(key_obj.curve, SECP256R1):
            crv = "P-256"
        elif isinstance(key_obj.curve, SECP384R1):
            crv = "P-384"
        elif isinstance(key_obj.curve, SECP521R1):
            crv = "P-521"
        elif isinstance(key_obj.curve, SECP256K1):
            crv = "secp256k1"
        else:
            raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")

        # The curve's key size is passed as bit_length for every integer.
        key_size = key_obj.curve.key_size

        obj: dict[str, Any] = {
            "kty": "EC",
            "crv": crv,
            "x": to_base64url_uint(
                public_numbers.x,
                bit_length=key_size,
            ).decode(),
            "y": to_base64url_uint(
                public_numbers.y,
                bit_length=key_size,
            ).decode(),
        }

        if isinstance(key_obj, EllipticCurvePrivateKey):
            obj["d"] = to_base64url_uint(
                key_obj.private_numbers().private_value,
                bit_length=key_size,
            ).decode()

        if as_dict:
            return obj
        else:
            return json.dumps(obj)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
        """
        Build an elliptic-curve key object from a JWK given as a JSON
        string or a dict.

        Returns a public key when the JWK has no ``d`` member and a private
        key otherwise.

        Raises ``InvalidKeyError`` for unparseable input, a ``kty`` other
        than ``"EC"``, missing ``x``/``y``, an unsupported ``crv``, or
        ``x``/``y``/``d`` values whose byte length does not match the curve.
        """
        try:
            if isinstance(jwk, str):
                obj = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                # Unsupported input types share the invalid-JSON error path.
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "EC":
            raise InvalidKeyError("Not an Elliptic curve key") from None

        if "x" not in obj or "y" not in obj:
            raise InvalidKeyError("Not an Elliptic curve key") from None

        x = base64url_decode(obj.get("x"))
        y = base64url_decode(obj.get("y"))

        curve = obj.get("crv")
        curve_obj: EllipticCurve
        coord_size: int

        # Compared with ``==`` rather than looked up in a dict so that an
        # unhashable ``crv`` value still ends in InvalidKeyError below.
        if curve == "P-256":
            curve_obj, coord_size = SECP256R1(), 32
        elif curve == "P-384":
            curve_obj, coord_size = SECP384R1(), 48
        elif curve == "P-521":
            curve_obj, coord_size = SECP521R1(), 66
        elif curve == "secp256k1":
            curve_obj, coord_size = SECP256K1(), 32
        else:
            raise InvalidKeyError(f"Invalid curve: {curve}")

        if not (len(x) == len(y) == coord_size):
            raise InvalidKeyError(
                f"Coords should be {coord_size} bytes for curve {curve}"
            )

        public_numbers = EllipticCurvePublicNumbers(
            x=int.from_bytes(x, byteorder="big"),
            y=int.from_bytes(y, byteorder="big"),
            curve=curve_obj,
        )

        if "d" not in obj:
            return public_numbers.public_key()

        d = base64url_decode(obj.get("d"))
        if len(d) != len(x):
            # Use an f-string: passing the values as extra exception
            # arguments would leave the "{}" placeholders unformatted.
            raise InvalidKeyError(
                f"D should be {len(x)} bytes for curve {curve}"
            )

        return EllipticCurvePrivateNumbers(
            int.from_bytes(d, byteorder="big"), public_numbers
        ).private_key()
|
|
|
|
class RSAPSSAlgorithm(RSAAlgorithm):
    """
    RSA signatures using RSASSA-PSS padding with an MGF1 mask generation
    function.
    """

    def _pss_padding(self) -> Any:
        # Padding shared by sign() and verify(): MGF1 over this instance's
        # hash, with the salt length equal to that hash's digest size.
        mask_function = padding.MGF1(self.hash_alg())
        salt_length = self.hash_alg().digest_size
        return padding.PSS(mgf=mask_function, salt_length=salt_length)

    def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
        """Sign ``msg`` with ``key`` using PSS padding."""
        pss = self._pss_padding()
        return key.sign(msg, pss, self.hash_alg())

    def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
        """
        Return True when ``sig`` verifies for ``msg``, and False when the
        key raises ``InvalidSignature``.
        """
        try:
            key.verify(sig, msg, self._pss_padding(), self.hash_alg())
        except InvalidSignature:
            return False
        return True
|
|
|
|
class OKPAlgorithm(Algorithm):
    """
    Performs signing and verification operations using EdDSA

    This class requires ``cryptography>=2.6`` to be installed.
    """

    def __init__(self, **kwargs: Any) -> None:
        # Keyword arguments are accepted and ignored; there is nothing to
        # configure.
        pass

    def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
        """
        Return ``key`` as an Ed25519 or Ed448 key object.

        Text or bytes are loaded as a PEM public key, a PEM private key, or
        an SSH public key, depending on the marker found in the data. Key
        objects are passed through to the type check.

        Raises ``InvalidKeyError`` when the result is not an Ed25519/Ed448
        key object; this includes text that matches none of the markers.
        """
        if isinstance(key, (bytes, str)):
            key_str = key.decode("utf-8") if isinstance(key, bytes) else key
            key_bytes = key.encode("utf-8") if isinstance(key, str) else key

            if "-----BEGIN PUBLIC" in key_str:
                key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
            elif "-----BEGIN PRIVATE" in key_str:
                key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]
            elif key_str[0:4] == "ssh-":
                key = load_ssh_public_key(key_bytes)  # type: ignore[assignment]

        # Explicit check the key to prevent confusing errors from cryptography
        if not isinstance(
            key,
            (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
        ):
            # Name the key types this check actually accepts.
            raise InvalidKeyError(
                "Expecting an Ed25519PrivateKey/Ed25519PublicKey/Ed448PrivateKey/Ed448PublicKey."
                " Wrong key provided for EdDSA algorithms"
            )

        return key

    def sign(
        self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
    ) -> bytes:
        """
        Sign a message ``msg`` using the EdDSA private key ``key``

        :param str|bytes msg: Message to sign
        :param Ed25519PrivateKey|Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
            or :class:`.Ed448PrivateKey` instance
        :return bytes signature: The signature, as bytes
        """
        msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
        return key.sign(msg_bytes)

    def verify(
        self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
    ) -> bool:
        """
        Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``

        :param str|bytes msg: Message that was signed
        :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
            A private or public EdDSA key instance
        :param str|bytes sig: EdDSA signature to check ``msg`` against
        :return bool verified: True if signature is valid, False if not.
        """
        try:
            msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
            sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig

            # A private key is accepted too; its public half is used.
            public_key = (
                key.public_key()
                if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
                else key
            )
            public_key.verify(sig_bytes, msg_bytes)
            return True  # If no exception was raised, the signature is valid.
        except InvalidSignature:
            return False

    @overload
    @staticmethod
    def to_jwk(
        key: AllowedOKPKeys, as_dict: Literal[True]
    ) -> JWKDict: ...  # pragma: no cover

    @overload
    @staticmethod
    def to_jwk(
        key: AllowedOKPKeys, as_dict: Literal[False] = False
    ) -> str: ...  # pragma: no cover

    @staticmethod
    def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
        """
        Serialize an Ed25519/Ed448 key as an ``OKP`` JWK, returned as a
        dict when ``as_dict`` is true and as a JSON string otherwise.

        Public keys produce ``x``; private keys produce ``x`` and ``d``.

        Raises ``InvalidKeyError`` when ``key`` is not one of the four
        supported key classes.
        """
        if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
            x = key.public_bytes(
                encoding=Encoding.Raw,
                format=PublicFormat.Raw,
            )
            crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"

            obj = {
                "x": base64url_encode(force_bytes(x)).decode(),
                "kty": "OKP",
                "crv": crv,
            }
        elif isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
            d = key.private_bytes(
                encoding=Encoding.Raw,
                format=PrivateFormat.Raw,
                encryption_algorithm=NoEncryption(),
            )
            x = key.public_key().public_bytes(
                encoding=Encoding.Raw,
                format=PublicFormat.Raw,
            )
            crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"

            obj = {
                "x": base64url_encode(force_bytes(x)).decode(),
                "d": base64url_encode(force_bytes(d)).decode(),
                "kty": "OKP",
                "crv": crv,
            }
        else:
            raise InvalidKeyError("Not a public or private key")

        return obj if as_dict else json.dumps(obj)

    @staticmethod
    def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
        """
        Build an Ed25519/Ed448 key object from an ``OKP`` JWK given as a
        JSON string or a dict.

        Returns a public key built from ``x`` when the JWK has no ``d``
        member; otherwise a private key built from ``d``.

        Raises ``InvalidKeyError`` for unparseable input, a ``kty`` other
        than ``"OKP"``, a ``crv`` other than ``Ed25519``/``Ed448``, a
        missing ``x``, or key bytes that the key class rejects with
        ``ValueError``.
        """
        try:
            if isinstance(jwk, str):
                obj = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                # Unsupported input types share the invalid-JSON error path.
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON") from None

        if obj.get("kty") != "OKP":
            raise InvalidKeyError("Not an Octet Key Pair")

        curve = obj.get("crv")
        if curve not in ("Ed25519", "Ed448"):
            raise InvalidKeyError(f"Invalid curve: {curve}")

        if "x" not in obj:
            raise InvalidKeyError('OKP should have "x" parameter')
        x = base64url_decode(obj.get("x"))

        try:
            if "d" not in obj:
                if curve == "Ed25519":
                    return Ed25519PublicKey.from_public_bytes(x)
                return Ed448PublicKey.from_public_bytes(x)
            d = base64url_decode(obj.get("d"))
            if curve == "Ed25519":
                return Ed25519PrivateKey.from_private_bytes(d)
            return Ed448PrivateKey.from_private_bytes(d)
        except ValueError as err:
            raise InvalidKeyError("Invalid key parameter") from err
|