authentik/tests/e2e/utils.py

258 lines
9.8 KiB
Python
Raw Normal View History

2020-12-05 21:08:42 +00:00
"""authentik e2e testing utilities"""
2020-11-23 13:24:42 +00:00
import json
from functools import wraps
from glob import glob
from importlib.util import module_from_spec, spec_from_file_location
from inspect import getmembers, isfunction
2020-07-12 15:17:04 +01:00
from os import environ, makedirs
2020-09-11 22:21:11 +01:00
from time import sleep, time
from typing import Any, Callable, Optional
from django.apps import apps
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
from django.db import connection, transaction
2021-02-26 15:46:01 +00:00
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.operations.special import RunPython
from django.db.utils import IntegrityError
from django.test.testcases import TransactionTestCase
from django.urls import reverse
2020-09-11 22:21:11 +01:00
from docker import DockerClient, from_env
from docker.models.containers import Container
from selenium import webdriver
from selenium.common.exceptions import (
NoSuchElementException,
TimeoutException,
WebDriverException,
)
2020-11-23 13:24:42 +00:00
from selenium.webdriver.common.by import By
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
2021-02-26 15:46:01 +00:00
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.remote.webdriver import WebDriver
2021-02-21 22:30:31 +00:00
from selenium.webdriver.remote.webelement import WebElement
from selenium.webdriver.support.ui import WebDriverWait
from structlog.stdlib import get_logger
2020-12-05 21:08:42 +00:00
from authentik.core.api.users import UserSerializer
from authentik.core.models import User
from authentik.managed.manager import ObjectManager
# pylint: disable=invalid-name
2020-06-20 23:26:29 +01:00
def USER() -> User:  # noqa
    """Return the `akadmin` user.

    Nothing is memoised here: each call runs a new
    `User.objects.get(username="akadmin")` lookup, so the returned object
    reflects the database at call time.
    """
    admin = User.objects.get(username="akadmin")
    return admin
class SeleniumTestCase(StaticLiveServerTestCase):
    """StaticLiveServerTestCase which automatically creates a Webdriver instance.

    `setUp` connects to a remote Chrome WebDriver, re-applies default data and
    optionally starts a Docker container described by `get_container_specs()`.
    `tearDown` kills that container and quits the driver again.
    """

    # Container started during setUp when `get_container_specs()` returns specs.
    container: Optional[Container] = None
    # Timeout (seconds) passed to `WebDriverWait`; assigned in setUp.
    wait_timeout: int

    def setUp(self):
        """Create the WebDriver, load default data and start the optional container."""
        super().setUp()
        self.wait_timeout = 60
        makedirs("selenium_screenshots/", exist_ok=True)
        self.driver = self._get_driver()
        self.driver.maximize_window()
        self.driver.implicitly_wait(30)
        self.wait = WebDriverWait(self.driver, self.wait_timeout)
        self.apply_default_data()
        self.logger = get_logger()
        if specs := self.get_container_specs():
            self.container = self._start_container(specs)

    def _start_container(self, specs: dict[str, Any]) -> Container:
        """Pull `specs["image"]`, run a container with `specs` as keyword arguments
        and, if `specs` contains a `healthcheck` key, block until the container
        reports the health status `healthy`."""
        client: DockerClient = from_env()
        client.images.pull(specs["image"])
        container = client.containers.run(**specs)
        if "healthcheck" not in specs:
            return container
        # NOTE(review): this poll has no upper bound; a container that never
        # becomes healthy keeps the test waiting here.
        while True:
            container.reload()
            status = container.attrs.get("State", {}).get("Health", {}).get("Status")
            if status == "healthy":
                return container
            self.logger.info("Container failed healthcheck")
            sleep(1)

    def get_container_specs(self) -> Optional[dict[str, Any]]:
        """Optionally get container specs which will be launched on setup, waited on
        until the container is healthy, and deleted again on tearDown"""
        return None

    def _get_driver(self) -> WebDriver:
        """Connect to the remote WebDriver hub on localhost:4444 requesting Chrome."""
        return webdriver.Remote(
            command_executor="http://localhost:4444/wd/hub",
            desired_capabilities=DesiredCapabilities.CHROME,
        )

    def tearDown(self):
        """Dump diagnostics (only when `TF_BUILD` is set), kill the container and
        quit the driver.

        Cleanup is wrapped in try/finally so that a failure while taking the
        screenshot, reading browser logs or killing the container still quits the
        WebDriver session and still runs the parent tearDown. The original error
        keeps propagating.
        """
        try:
            if "TF_BUILD" in environ:
                screenshot_file = (
                    f"selenium_screenshots/{self.__class__.__name__}_{time()}.png"
                )
                self.driver.save_screenshot(screenshot_file)
                self.logger.warning("Saved screenshot", file=screenshot_file)
                for line in self.driver.get_log("browser"):
                    self.logger.warning(
                        line["message"], source=line["source"], level=line["level"]
                    )
            if self.container:
                self.container.kill()
        finally:
            try:
                self.driver.quit()
            finally:
                super().tearDown()

    def wait_for_url(self, desired_url):
        """Wait until URL is `desired_url`.

        Raises `TimeoutException` if the URL does not match within the wait
        timeout. The message is built when the timeout happens, so it reports the
        URL the browser actually ended up on rather than the one it had before
        waiting started.
        """
        try:
            self.wait.until(lambda driver: driver.current_url == desired_url)
        except TimeoutException as exc:
            raise TimeoutException(
                f"URL {self.driver.current_url} doesn't match expected URL {desired_url}"
            ) from exc

    def url(self, view, **kwargs) -> str:
        """reverse `view` with `**kwargs` into full URL using live_server_url"""
        return self.live_server_url + reverse(view, kwargs=kwargs)

    def shell_url(self, view) -> str:
        """same as self.url() but show URL in shell"""
        return f"{self.live_server_url}/#{view}"

    def get_shadow_root(
        self, selector: str, container: Optional[WebElement] = None
    ) -> WebElement:
        """Get shadow root element's inner shadowRoot.

        `selector` is looked up as a CSS selector inside `container`, or inside the
        driver (whole page) when no container is given.
        """
        if container is None:
            container = self.driver
        shadow_host = container.find_element(By.CSS_SELECTOR, selector)
        return self.driver.execute_script(
            "return arguments[0].shadowRoot", shadow_host
        )

    def login(self):
        """Fill in and submit the identification and password stages of the flow.

        This only drives the browser; it does not verify who is logged in
        afterwards (use `assert_user` for that).
        """
        flow_executor = self.get_shadow_root("ak-flow-executor")
        identification_stage = self.get_shadow_root(
            "ak-stage-identification", flow_executor
        )

        def uid_field():
            # Looked up again for every interaction, exactly as before.
            return identification_stage.find_element(
                By.CSS_SELECTOR, "input[name=uid_field]"
            )

        uid_field().click()
        uid_field().send_keys(USER().username)
        uid_field().send_keys(Keys.ENTER)

        # The executor's shadow root is fetched again for the next stage.
        flow_executor = self.get_shadow_root("ak-flow-executor")
        password_stage = self.get_shadow_root("ak-stage-password", flow_executor)

        def password_field():
            return password_stage.find_element(
                By.CSS_SELECTOR, "input[name=password]"
            )

        # The username is typed as the password; presumably the default user's
        # password equals its username in the test data -- TODO confirm.
        password_field().send_keys(USER().username)
        password_field().send_keys(Keys.ENTER)

    def assert_user(self, expected_user: User):
        """Check users/me API and assert it matches expected_user"""
        self.driver.get(self.url("authentik_api:user-me") + "?format=json")
        user_json = self.driver.find_element(By.CSS_SELECTOR, "pre").text
        user = UserSerializer(data=json.loads(user_json))
        # NOTE(review): the result of is_valid() is not checked; only the three
        # fields below are compared.
        user.is_valid()
        self.assertEqual(user["username"].value, expected_user.username)
        self.assertEqual(user["name"].value, expected_user.name)
        self.assertEqual(user["email"].value, expected_user.email)

    def apply_default_data(self):
        """apply objects created by migrations after tables have been truncated"""
        # Not all default objects are managed, like users for example
        # Hence we still have to load all migrations and apply them, then run the ObjectManager
        # Find all migration files that mention `RunPython`. They are only read,
        # so open them read-only (Python sources are UTF-8 by default).
        matches = []
        for migration in glob("**/migrations/*.py", recursive=True):
            with open(migration, "r", encoding="utf-8") as migration_file:
                if "RunPython" in migration_file.read():
                    matches.append(migration)
        with connection.schema_editor() as schema_editor:
            for match in matches:
                # Load module from file path
                spec = spec_from_file_location("", match)
                migration_module = module_from_spec(spec)
                # pyright: reportGeneralTypeIssues=false
                spec.loader.exec_module(migration_module)
                # Call all functions from module
                for _, func in getmembers(migration_module, isfunction):
                    # Catch the error *around* the atomic block so the block
                    # sees the failure and rolls itself back before the error is
                    # deliberately ignored (objects that already exist).
                    try:
                        with transaction.atomic():
                            func(apps, schema_editor)
                    except IntegrityError:
                        pass
        ObjectManager().run()
2021-02-26 15:46:01 +00:00
def apply_migration(app_name: str, migration_name: str):
    """Re-apply migrations that create objects using RunPython before test cases.

    Returns a decorator for test methods. Each time the decorated test runs,
    the migration `migration_name` of app `app_name` is looked up and the
    `code` callable of every `RunPython` operation in it is executed with the
    global app registry and a schema editor. The test itself runs afterwards
    and its return value is passed through.
    """

    def wrapper_outer(func: Callable):
        """Wrap `func` so the migration's RunPython operations run before it"""

        @wraps(func)
        def wrapper(self: "TransactionTestCase", *args, **kwargs):
            # The loader is created when the test runs, not when the decorator
            # is applied, so decorating a test (i.e. importing the test module)
            # does not need a database connection.
            loader = MigrationLoader(connection)
            migration = loader.get_migration(app_name, migration_name)
            with connection.schema_editor() as schema_editor:
                for operation in migration.operations:
                    if isinstance(operation, RunPython):
                        operation.code(apps, schema_editor)
            return func(self, *args, **kwargs)

        return wrapper

    return wrapper_outer
def retry(max_retires=1, exceptions=None):
    """Retry test multiple times. Default to catching Selenium Timeout Exception

    `max_retires` is the total number of attempts allowed per invocation of the
    decorated test (the default of 1 means the test is not re-run).
    `exceptions` is an iterable of exception classes that trigger a retry; when
    empty or omitted, WebDriverException, TimeoutException and
    NoSuchElementException are used. Any other exception propagates at once.
    """
    if not exceptions:
        exceptions = [WebDriverException, TimeoutException, NoSuchElementException]
    # Materialise once so any iterable (even a one-shot one) keeps working.
    retryable = tuple(exceptions)
    logger = get_logger()

    def retry_actual(func: Callable):
        """Wrap `func` with the retry loop"""

        @wraps(func)
        def wrapper(self: "TransactionTestCase", *args, **kwargs):
            """Run test again if we're below max_retries, including tearDown and
            setUp. Otherwise raise the error"""
            # The counter is local to this invocation, so every run of the
            # decorated test gets the full budget instead of sharing one
            # counter for the lifetime of the process.
            attempt = 1
            while True:
                try:
                    return func(self, *args, **kwargs)
                # pylint: disable=catching-non-exception
                except retryable as exc:
                    attempt += 1
                    if attempt > max_retires:
                        logger.debug("Exceeded retry count", exc=exc, test=self)
                        # Bare raise keeps the original traceback intact.
                        raise
                    logger.debug("Retrying on error", exc=exc, test=self)
                    # Reset the test case before the next attempt.
                    self.tearDown()
                    self._post_teardown()  # noqa
                    self.setUp()

        return wrapper

    return retry_actual