authentik/passbook/providers/saml/processors/request_parser.py

118 lines
4.0 KiB
Python

"""SAML AuthNRequest Parser and dataclass"""
from base64 import b64decode
from dataclasses import dataclass
from typing import Optional
from urllib.parse import quote_plus
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from defusedxml import ElementTree
from signxml import XMLVerifier
from structlog import get_logger
from passbook.providers.saml.exceptions import CannotHandleAssertion
from passbook.providers.saml.models import SAMLProvider
from passbook.providers.saml.utils.encoding import decode_base64_and_inflate
from passbook.sources.saml.processors.constants import (
NS_SAML_PROTOCOL,
SAML_NAME_ID_FORMAT_EMAIL,
)
# Module-level structlog logger shared by the parser in this file.
LOGGER = get_logger()
@dataclass
class AuthNRequest:
    """Container for the fields extracted from a SAML AuthnRequest.

    Every field has a default, so ``AuthNRequest()`` is a valid empty request.
    """

    # Value of the request's ``ID`` attribute. The short name mirrors the XML
    # attribute, which is why the pylint naming check is switched off here.
    # pylint: disable=invalid-name
    id: Optional[str] = None

    # RelayState value supplied alongside the request, if there was one.
    relay_state: Optional[str] = None

    # NameID format to use; stays at the e-mail format unless overridden.
    name_id_policy: str = SAML_NAME_ID_FORMAT_EMAIL
class AuthNRequestParser:
    """Validate and parse SAML AuthnRequests on behalf of one provider.

    The provider supplies the expected ACS URL and, optionally, the keypair
    (``signing_kp``) that request signatures are verified against.
    """

    provider: SAMLProvider

    def __init__(self, provider: SAMLProvider):
        self.provider = provider

    def _parse_xml(self, decoded_xml: str, relay_state: Optional[str]) -> AuthNRequest:
        """Extract the request ID and NameID policy from a decoded AuthnRequest.

        Args:
            decoded_xml: the AuthnRequest XML document as text.
            relay_state: RelayState value to carry over into the result.

        Returns:
            An AuthNRequest holding the request ID, the relay state and the
            NameID format (the dataclass default unless the request names one).

        Raises:
            CannotHandleAssertion: if the ``AssertionConsumerServiceURL`` or
                ``ID`` attribute is missing, or the ACS URL differs from the
                provider's configured ACS URL.
        """
        root = ElementTree.fromstring(decoded_xml)

        # The request is untrusted input: report a missing attribute as a
        # request error instead of letting a bare KeyError escape.
        request_acs_url = root.attrib.get("AssertionConsumerServiceURL")
        if request_acs_url is None:
            msg = "AuthnRequest is missing the AssertionConsumerServiceURL attribute."
            LOGGER.info(msg)
            raise CannotHandleAssertion(msg)
        if self.provider.acs_url != request_acs_url:
            msg = (
                f"ACS URL of {request_acs_url} doesn't match Provider "
                f"ACS URL of {self.provider.acs_url}."
            )
            LOGGER.info(msg)
            raise CannotHandleAssertion(msg)

        request_id = root.attrib.get("ID")
        if request_id is None:
            msg = "AuthnRequest is missing the ID attribute."
            LOGGER.info(msg)
            raise CannotHandleAssertion(msg)

        auth_n_request = AuthNRequest(id=request_id, relay_state=relay_state)

        # Check if AuthnRequest has a NameID Policy object.
        # ElementTree qualifies tags as "{namespace-uri}tag" with no separator
        # between the closing brace and the tag name.
        # Assumes NS_SAML_PROTOCOL is the bare namespace URI -- TODO confirm.
        name_id_policy = root.find(f"{{{NS_SAML_PROTOCOL}}}NameIDPolicy")
        # Compare against None explicitly: an Element without children is falsy.
        if name_id_policy is not None:
            # Format is optional on NameIDPolicy; keep the default if absent.
            auth_n_request.name_id_policy = name_id_policy.attrib.get(
                "Format", auth_n_request.name_id_policy
            )
        return auth_n_request

    def parse(self, saml_request: str, relay_state: Optional[str]) -> AuthNRequest:
        """Validate and parse raw request with enveloped signature.

        The signature is only verified when the provider has a ``signing_kp``.

        Args:
            saml_request: encoded request, decoded here with
                ``decode_base64_and_inflate``.
            relay_state: RelayState value to carry over into the result.

        Raises:
            CannotHandleAssertion: if signature verification fails or the
                request content is rejected by ``_parse_xml``.
        """
        decoded_xml = decode_base64_and_inflate(saml_request)
        if self.provider.signing_kp:
            try:
                XMLVerifier().verify(
                    decoded_xml, x509_cert=self.provider.signing_kp.certificate_data
                )
            except InvalidSignature as exc:
                raise CannotHandleAssertion("Failed to verify signature") from exc
        return self._parse_xml(decoded_xml, relay_state)

    def _verify_detached_signature(
        self,
        saml_request: str,
        relay_state: Optional[str],
        signature: str,
        sig_alg: str,
    ) -> None:
        """Check a detached (query-string) signature against the provider's key.

        Raises:
            CannotHandleAssertion: if the signature is not valid base64 or does
                not verify.
        """
        # RSA signature algorithm URIs and the digest each one implies.
        # Unrecognised values fall back to SHA-1.
        hash_by_sig_alg = {
            "http://www.w3.org/2000/09/xmldsig#rsa-sha1": hashes.SHA1,  # nosec
            "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256": hashes.SHA256,
            "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384": hashes.SHA384,
            "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512": hashes.SHA512,
        }
        sig_hash = hash_by_sig_alg.get(sig_alg, hashes.SHA1)()  # nosec

        # The signature covers the URL-encoded query string in this exact
        # order: SAMLRequest, optional RelayState, SigAlg. All values are
        # encoded, including SigAlg.
        signed_parts = [f"SAMLRequest={quote_plus(saml_request)}"]
        if relay_state is not None:
            signed_parts.append(f"RelayState={quote_plus(relay_state)}")
        signed_parts.append(f"SigAlg={quote_plus(sig_alg)}")
        querystring = "&".join(signed_parts)

        try:
            raw_signature = b64decode(signature)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise CannotHandleAssertion("Failed to verify signature") from exc

        public_key = self.provider.signing_kp.private_key.public_key()
        try:
            # The xmldsig rsa-sha* algorithms are RSASSA-PKCS1-v1_5, not PSS.
            public_key.verify(
                raw_signature,
                querystring.encode(),
                padding.PKCS1v15(),
                sig_hash,
            )
        except InvalidSignature as exc:
            raise CannotHandleAssertion("Failed to verify signature") from exc

    def parse_detached(
        self,
        saml_request: str,
        relay_state: Optional[str],
        signature: Optional[str] = None,
        sig_alg: Optional[str] = None,
    ) -> AuthNRequest:
        """Validate and parse raw request with detached signature.

        The signature is only verified when the provider has a ``signing_kp``
        and both ``signature`` and ``sig_alg`` are given; otherwise the request
        is parsed without verification.

        Args:
            saml_request: encoded request, decoded here with
                ``decode_base64_and_inflate``.
            relay_state: RelayState value, also part of the signed data.
            signature: base64-encoded signature value.
            sig_alg: signature algorithm URI.

        Raises:
            CannotHandleAssertion: if signature verification fails or the
                request content is rejected by ``_parse_xml``.
        """
        decoded_xml = decode_base64_and_inflate(saml_request)
        # Without a keypair there is nothing to verify against, so skip
        # verification just as parse() does.
        if self.provider.signing_kp and signature and sig_alg:
            self._verify_detached_signature(
                saml_request, relay_state, signature, sig_alg
            )
        return self._parse_xml(decoded_xml, relay_state)

    def idp_initiated(self) -> AuthNRequest:
        """Create IdP Initiated AuthNRequest (all fields at their defaults)."""
        return AuthNRequest()