Source code for inmanta.server.bootloader

    Copyright 2019 Inmanta

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.


import asyncio
import importlib
import logging
import pkgutil
from collections.abc import Generator
from pkgutil import ModuleInfo
from types import ModuleType
from typing import Optional

import asyncpg

from inmanta import logging as inmanta_logging
from inmanta.logging import FullLoggingConfig, InmantaLoggerConfig
from inmanta.server import config
from inmanta.server.extensions import ApplicationContext, FeatureManager, InvalidSliceNameException
from inmanta.server.protocol import Server, ServerSlice
from inmanta.stable_api import stable_api

LOGGER = logging.getLogger(__name__)

# Namespace package in which inmanta extensions live; extension "core" is the package ``inmanta_ext.core``.
EXTENSION_NAMESPACE = "inmanta_ext"
# Submodule of every extension package that must define the ``setup`` function.
# NOTE(review): value restored from upstream inmanta (``inmanta_ext.<ext>.extension``); confirm against installed extensions.
EXTENSION_MODULE = "extension"

def iter_namespace(ns_pkg: ModuleType) -> Generator[ModuleInfo, None, None]:
    """Iterate over the modules found on the search path of the package ``ns_pkg``.

    Recipe taken from the Python docs. Every returned ModuleInfo carries an absolute,
    fully qualified module name (``<package>.<module>``), so it can be handed to
    ``importlib.import_module`` without further processing.
    """
    # iter_modules() prepends this prefix to every name it reports, turning relative names into absolute ones.
    absolute_prefix = f"{ns_pkg.__name__}."
    search_path = ns_pkg.__path__
    return pkgutil.iter_modules(search_path, absolute_prefix)

class PluginLoadFailed(Exception):
    """Raised when an extension cannot be found, imported or validated by the bootloader."""

class ConstrainedApplicationContext(ApplicationContext):
    """ApplicationContext handed to a single extension during setup.

    Registrations are forwarded to ``parent``; slices are only accepted when their name lives in this
    extension's namespace (``<namespace>.<something>``).

    NOTE(review): ``__init__`` does not call ``super().__init__()``, so any state ApplicationContext.__init__ would
    set up is absent on this wrapper. Only the methods overridden here forward to ``parent``; other ApplicationContext
    methods (e.g. ones an extension might call from ``register_environment_settings``) run against this instance.
    Confirm that this is intended.
    """

    def __init__(self, parent: ApplicationContext, namespace: str) -> None:
        # The shared context that collects registrations from all extensions.
        self.parent = parent
        # Name of the extension owning this context; slice names must be prefixed with it.
        self.namespace = namespace

    def register_slice(self, slice: ServerSlice) -> None:
        """Register ``slice`` with the parent context after validating its name.

        :raises InvalidSliceNameException: if the slice name does not start with ``<namespace>.``.
        """
        name = slice.name
        if not name.startswith(self.namespace + "."):
            raise InvalidSliceNameException(f"{name} should be in namespace {self.namespace}")
        self.parent.register_slice(slice)

    def set_feature_manager(self, feature_manager: FeatureManager) -> None:
        """Forward the feature manager to the parent context."""
        self.parent.set_feature_manager(feature_manager)

    def register_default_logging_config(self, log_config_extender: inmanta_logging.LoggingConfigBuilderExtension) -> None:
        """Forward the default logging config extender to the parent context."""
        self.parent.register_default_logging_config(log_config_extender)

@stable_api
class InmantaBootloader:
    """The inmanta bootloader is responsible for:

    - discovering extensions
    - loading extensions
    - loading core and extension slices
    - starting the server and its slices in the correct order
    """

    # Cache field for available extensions: short extension name -> fully qualified package name.
    # Class-level, so it is shared by all bootloader instances; filled lazily by get_available_extensions().
    AVAILABLE_EXTENSIONS: Optional[dict[str, str]] = None

    def __init__(self, configure_logging: bool = False) -> None:
        """
        :param configure_logging: This config option is used by the tests to configure the logging framework.
                                  In normal execution it is left False and the logging framework is configured
                                  outside of the bootloader.
        """
        self.restserver = Server()
        self.started = False
        self.feature_manager: Optional[FeatureManager] = None
        # cache for ctx, filled by load_slices()
        self.ctx: ApplicationContext | None = None
        if configure_logging:
            inmanta_logger_config = inmanta_logging.InmantaLoggerConfig.get_instance()
            inmanta_logger_config.apply_options(inmanta_logging.Options())

    async def start(self) -> None:
        """Start the server.

        Extends the logging config for the loaded extensions, optionally waits for the database
        (skipped when the ``db_wait_time`` option is 0), registers every slice with the REST server
        and the feature manager, and starts the REST server.
        """
        self.start_loggers_for_extensions()
        db_wait_time: int = config.db_wait_time.get()
        if db_wait_time != 0:
            # Wait for the database to be up before starting the server
            await self.wait_for_db(db_wait_time)
        ctx = self.load_slices()
        version = ctx.get_feature_manager().get_product_metadata().version
        LOGGER.info("Starting inmanta-server version %s", version)
        for mypart in ctx.get_slices():
            self.restserver.add_slice(mypart)
            ctx.get_feature_manager().add_slice(mypart)
        await self.restserver.start()
        self.started = True

    def start_loggers_for_extensions(self, on_config: InmantaLoggerConfig | None = None) -> FullLoggingConfig:
        """Extend a logging config with the default logging config extenders registered by the extensions.

        :param on_config: The logger config to extend. Defaults to ``InmantaLoggerConfig.get_current_instance()``.
        :return: The result of ``on_config.extend_config`` for the collected extenders.
        """
        ctx = self.load_slices()
        log_config_extenders = ctx.get_default_log_config_extenders()
        if on_config is None:
            on_config = InmantaLoggerConfig.get_current_instance()
        return on_config.extend_config(log_config_extenders)

    async def stop(self, timeout: Optional[int] = None) -> None:
        """
        :param timeout: Raises TimeoutError when the server hasn't finished stopping after this amount of seconds.
                        This argument should only be used by test cases. ``None`` (or 0) means no timeout.
        """
        if not timeout:
            await self._stop()
        else:
            await asyncio.wait_for(self._stop(), timeout=timeout)

    async def _stop(self) -> None:
        """Stop the REST server, then the feature manager if one was loaded."""
        await self.restserver.stop()
        if self.feature_manager is not None:
            await self.feature_manager.stop()

    @classmethod
    def get_available_extensions(cls) -> dict[str, str]:
        """
        Returns a dictionary of all available inmanta extensions. The key contains the name of the extension and
        the value the fully qualified path to the python package. The result is cached on the class; a copy is returned.
        """
        if cls.AVAILABLE_EXTENSIONS is None:
            try:
                inmanta_ext = importlib.import_module(EXTENSION_NAMESPACE)
            except ModuleNotFoundError:
                # This only happens when a test case creates and activates a new venv
                return {}
            # Strip "<EXTENSION_NAMESPACE>." to get the short extension name.
            prefix_length = len(EXTENSION_NAMESPACE) + 1
            cls.AVAILABLE_EXTENSIONS = {
                name[prefix_length:]: name for _finder, name, _ispkg in iter_namespace(inmanta_ext)
            }
        return dict(cls.AVAILABLE_EXTENSIONS)

    # Extension loading Phase I: from start to setup functions collected
    def _discover_plugin_packages(self, return_all_available_packages: bool = False) -> list[str]:
        """Discover all packages that are defined in the inmanta_ext namespace package.

        Filter available extensions based on the enabled_extensions option in the server configuration.
        The core extension is always included.

        :param return_all_available_packages: Return all available plugin packages independent of whether the extension
                                              is enabled or not.
        :return: A list of fully qualified package names of the selected subpackages of inmanta_ext
        :raises PluginLoadFailed: When an enabled extension is not available.
        """
        available = self.get_available_extensions()
        LOGGER.info("Discovered extensions: %s", ", ".join(available.keys()))

        extensions: list[str] = []
        # Ignore empty entries in the configured list.
        enabled = [x for x in config.server_enabled_extensions.get() if x]

        if return_all_available_packages:
            extensions.extend(available.values())
        elif enabled:
            for ext in enabled:
                if ext not in available:
                    raise PluginLoadFailed(
                        f"Extension {ext} in config option {config.server_enabled_extensions.name} in section "
                        f"{config.server_enabled_extensions.section} is not available."
                    )
                extensions.append(available[ext])
        elif len(available) > 1:
            # More than core is available
            LOGGER.info(
                f"Load extensions by setting configuration option {config.server_enabled_extensions.name} in section "
                f"{config.server_enabled_extensions.section}. {len(available) - 1} extensions available but none are enabled."
            )

        # ``extensions`` holds fully qualified package names, so compare against the core *package* name;
        # comparing against the short name "core" never matched and appended core a second time.
        core_package = available["core"]
        if core_package not in extensions:
            extensions.append(core_package)

        return extensions

    def _load_extension(self, name: str) -> ModuleType:
        """Import the extension package ``name`` and its extension module, validate it and return that module.

        The returned module's ``setup`` function lets the extension register its slices in the application context.

        :raises PluginLoadFailed: When either import fails or the extension module is invalid.
        """
        try:
            importlib.import_module(name)
        except Exception as e:
            raise PluginLoadFailed(f"Could not load module {name}") from e

        try:
            mod = importlib.import_module(f"{name}.{EXTENSION_MODULE}")
        except Exception as e:
            raise PluginLoadFailed(f"Could not load module {name}.{EXTENSION_MODULE}") from e
        else:
            self._validate_extension(mod)
            return mod

    def _validate_extension(self, ext_mod: ModuleType) -> None:
        """
        Validate whether the given extension module satisfies the mandatory requirements for an Inmanta extension.

        :raises PluginLoadFailed: If the requirements are not satisfied.
        """
        if not hasattr(ext_mod, "setup"):
            raise PluginLoadFailed(f"{ext_mod.__name__} doesn't have a setup method.")

    def _load_extensions(self, load_all_extensions: bool = False) -> dict[str, ModuleType]:
        """Discover all extensions, validate correct naming and load their extension modules.

        Extensions that fail to load are logged and skipped.

        :return: Mapping of short extension name to its loaded extension module.
        """
        plugins: dict[str, ModuleType] = {}
        enabled_extensions: list[str] = self._discover_plugin_packages(load_all_extensions)
        LOGGER.info("Enabled extensions: %s", ", ".join(enabled_extensions))
        for name in enabled_extensions:
            try:
                module = self._load_extension(name)
                assert name.startswith(f"{EXTENSION_NAMESPACE}.")
                name = name[len(EXTENSION_NAMESPACE) + 1 :]
                plugins[name] = module
            except PluginLoadFailed:
                LOGGER.warning("Could not load extension %s", name, exc_info=True)
        return plugins

    def _collect_environment_settings(self, ext_module: ModuleType, app_ctx: ApplicationContext) -> None:
        """
        Collect the settings of an Inmanta environment defined by the given extension, if it defines any.
        """
        if not hasattr(ext_module, "register_environment_settings"):
            # Extension doesn't define any environment settings.
            return
        ext_module.register_environment_settings(app_ctx)

    # Extension loading Phase II: collect slices
    def _collect_slices(
        self, extensions: dict[str, ModuleType], only_register_environment_settings: bool = False
    ) -> ApplicationContext:
        """
        Call the setup function on all extensions and let them register their slices in the ApplicationContext.

        Each extension gets its own ConstrainedApplicationContext wrapping the shared context that is returned.
        When ``only_register_environment_settings`` is set, the setup functions are not called.
        """
        ctx = ApplicationContext()
        for name, ext_module in extensions.items():
            myctx = ConstrainedApplicationContext(ctx, name)
            self._collect_environment_settings(ext_module, myctx)
            if not only_register_environment_settings:
                ext_module.setup(myctx)

        return ctx

    def load_slices(
        self, *, load_all_extensions: bool = False, only_register_environment_settings: bool = False
    ) -> ApplicationContext:
        """
        Load all slices in the server.

        The resulting context is cached on ``self.ctx`` only when both flags are False; a cached context is returned
        for every call that does not set ``load_all_extensions``.
        """
        if self.ctx is not None and not load_all_extensions:
            return self.ctx
        exts: dict[str, ModuleType] = self._load_extensions(load_all_extensions)
        ctx: ApplicationContext = self._collect_slices(exts, only_register_environment_settings)
        self.feature_manager = ctx.get_feature_manager()
        if not only_register_environment_settings and not load_all_extensions:
            self.ctx = ctx
        return ctx

    async def wait_for_db(self, db_wait_time: int) -> None:
        """Wait for the database to be up by attempting to connect at intervals.

        :param db_wait_time: Maximum time to wait for the database to be up, in seconds.
                             A negative value means wait indefinitely.
        :raises Exception: When the database could not be reached within ``db_wait_time`` seconds.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        # Retrieve database connection settings from the configuration
        db_settings = {
            "host": config.db_host.get(),
            "port": config.db_port.get(),
            "user": config.db_username.get(),
            "password": config.db_password.get(),
            "database": config.db_name.get(),
        }

        while True:
            try:
                # Attempt to create a database connection
                conn = await asyncpg.connect(**db_settings, timeout=5)  # raises TimeoutError after 5 seconds
                LOGGER.info("Successfully connected to the database.")
                await conn.close(timeout=5)  # close the connection
                return
            except asyncio.TimeoutError:
                LOGGER.info("Waiting for database to be up: Connection attempt timed out.")
            except Exception:
                LOGGER.info("Waiting for database to be up.", exc_info=True)

            # Check if the maximum wait time has been exceeded (non-positive values never time out)
            if 0 < db_wait_time < loop.time() - start_time:
                LOGGER.error("Timed out waiting for the database to be up.")
                raise Exception("Database connection timeout after %d seconds." % db_wait_time)

            # Sleep for a second before retrying
            await asyncio.sleep(1)