# Source code for inmanta_plugins.lsm.partial

"""
    Support for partial compiles

    :copyright: 2023 Inmanta
    :contact: code@inmanta.com
    :license: Inmanta EULA
"""

import abc
import logging
import typing
from collections import defaultdict
from typing import Optional, Sequence

import inmanta.ast
import inmanta.const
import inmanta.execute.proxy as proxy
import inmanta_plugins.lsm
import inmanta_plugins.lsm.allocation_helpers
from inmanta.plugins import PluginException
from inmanta_lsm import const as lsm_const
from inmanta_lsm.model import AttributeSetName

LOGGER = logging.getLogger(__name__)

VersionedServiceEntity: typing.TypeAlias = tuple[str, int]


class InstancePartitioningException(PluginException):
    """
    Raised when the ownership hierarchy between service entity bindings or
    service instances is inconsistent, so instances can not be grouped into
    trees for a partial compile.
    """


class TraversalStep:
    """
    One edge of the ownership hierarchy: how to go from an owned (child)
    service entity one level up to its owner (parent).

    :param child_entity: name of the owned service entity
    :param child_version: version of the owned service entity
    :param to_parent_relation: name of the attribute on the child instance that
        refers to its parent instance (the binding's `relation_to_owner`)
    :param parent_entity: the owning entity
    :raises InstancePartitioningException: if `to_parent_relation` is empty
    """

    def __init__(
        self,
        child_entity: str,
        child_version: int,
        to_parent_relation: str,
        parent_entity: str,
    ) -> None:
        # An owned binding is useless for traversal unless it also names the
        # attribute that points back at its owner: fail before storing anything.
        if not to_parent_relation:
            raise InstancePartitioningException(
                f"The attribute `relation_to_owner` on ServiceEntityBinding {child_entity} is not set, "
                f"but the `owner` relation is set: these two must always be used in combination."
            )
        self.child_entity = child_entity
        self.child_version = child_version
        self.to_parent_relation = to_parent_relation
        self.parent_entity = parent_entity


class SelectorAPI(abc.ABC):
    """The Selector is responsible for determining which instances are returned by lsm::all

    A specific selector class can be registered using
    `inmanta_plugins.lsm.global_cache.set_selector_factory`

    A selector is used in 4 phases:

    1. the Selector is constructed (empty)
    2. it is fed all relevant bindings to analyze and cache.
    3. it is fed all instances requested via the environment variable inmanta_instance_id,
       as used for partial compile
    4. it returns the instances selected

    All methods can be called multiple times, but once a method from the next phase is called,
    methods from the previous phase should not get called any more (for the same binding).
    """

    def __init__(self, env: str) -> None:
        """
        :param env: the environment this selector selects instances for
        """
        self.env = env

    @abc.abstractmethod
    def reload_bindings(self) -> None:
        """Register new bindings (phase 2).

        This method is only required for very advanced selectors that need to inspect
        the binding structure.

        This method checks the binding cache for new bindings and registers them.

        e.g.

        .. code-block:: python

            def reload_bindings(self) -> None:
                for (
                    name,
                    version,
                ), binding in dict(
                    inmanta_plugins.lsm.global_cache.get_all_versioned_bindings()
                ).items():
                    if (name, version) not in self.root_for:
                        self.register_binding(binding)

        Implementors, keep in mind that:

        1. this method can be re-executed because of unset exceptions
        2. any binding additionally required MUST be registered in the global cache
           (inmanta_plugins.lsm.global_cache)
        """

    @abc.abstractmethod
    def register_instances(self, instance_ids: Sequence[str]) -> None:
        """
        Register explicitly requested instances (phase 3), can be called multiple times.

        :param instance_ids: ids of the instances that were requested
        """

    @abc.abstractmethod
    def select(self, requested_service_type: str) -> list[dict]:
        """
        Return all instances for a specific type (phase 4), can be called multiple times.

        All instances must also be cached in the global cache.

        :param requested_service_type: name of the service entity to return instances for
        :return: the selected instances, in their api form
        """
class AllSelector(SelectorAPI):
    """Selector to be used when not doing partial compile"""

    def reload_bindings(self) -> None:
        # Nothing to analyze: every instance is selected
        pass

    def register_instances(self, instance_ids: Sequence[str]) -> None:
        # Requested instances are irrelevant: every instance is selected
        pass

    def select(self, requested_service_type: str) -> list[dict]:
        """Return every instance of `requested_service_type` from the global cache."""
        return inmanta_plugins.lsm.global_cache.get_all_instances(
            self.env, requested_service_type
        )


class AbstractSelector(SelectorAPI):
    """
    Base class to implement custom selectors

    it keeps track of:

    :param validate: is this a validation compile
    :param instance_to_validate: if it is a validation compile, which instance are we validating.
        This is built up by calls to :meth:`register_instances`.
    :param requested_instances: what are the instances we need to include in the selection?
        This is built up by calls to :meth:`register_instances`.

    The select_all method is the only one that still needs implementing
    """

    def __init__(self, env: str) -> None:
        super().__init__(env)
        self.validate = inmanta_plugins.lsm.allocation_helpers.is_validation_compile()
        # Instance level
        self.requested_instances: set[str] = set()
        self.instance_to_validate: Optional[str] = None
        # Cache of the last select_all() result, None until select() is first called
        self.select_all_cache: Optional[dict[str, list[dict]]] = None

    def reload_bindings(self) -> None:
        pass

    def register_instances(self, instance_ids: Sequence[str]) -> None:
        """
        Register explicitly requested instances, can be called multiple times.

        :param instance_ids: ids of the instances to include in the selection
        :raises AssertionError: when called after select() with an instance that was
            not registered before
        :raises Exception: in a validation compile, when not exactly one instance is registered
        """
        if self.select_all_cache is not None:
            # A selection has already been made: it would not contain new instances
            for instance_id in instance_ids:
                assert (
                    instance_id in self.requested_instances
                ), f"Register instance {instance_id} called after select, this is not allowed"

        self.requested_instances.update(instance_ids)

        if self.validate:
            if len(self.requested_instances) == 0:
                raise Exception(
                    "Validation compile without instance set! This is not allowed"
                )
            if len(self.requested_instances) > 1:
                # Interpolate explicitly: Exception() does not apply %-formatting to its args
                raise Exception(
                    "Validation compile with multiple instance set! This is not allowed. "
                    f"instances: {self.requested_instances}"
                )
            self.instance_to_validate = next(iter(self.requested_instances))

    @abc.abstractmethod
    def select_all(self) -> dict[str, list[dict]]:
        """Main entry point to overload:

        Select all instances, as required to compile the `requested_instances`

        return as a dictionary mapping service_entity_name to a list of entities

        See documentation of examples
        """

    def select(self, requested_service_type: str) -> list[dict]:
        """
        Return the selected instances of `requested_service_type`.

        The result of :meth:`select_all` is cached; it is recomputed when the cache is
        empty or does not contain the requested type.
        """
        # this cache is very subtle
        # lsm::all can be called on multiple disjoint trees
        # we may have completed all stages for one tree, but have no bindings for the other tree yet.
        if not self.select_all_cache or (
            requested_service_type not in self.select_all_cache
        ):
            self.select_all_cache = self.select_all()
        return self.select_all_cache.get(requested_service_type, [])

    def _get_attribute(self, current_instance: dict, relation: str) -> typing.Any:
        """
        Robust but sloppy way of getting an attribute out of an instance

        Some validation states have no attribute set set, which is why we may have to take
        another attribute set to build the group.
        This means that no service can move to another group or this code becomes unreliable.

        :param current_instance: the instance, in its api form
        :param relation: name of the attribute to read
        :return: the value of the attribute, None if it is not present
        :raises InstancePartitioningException: if the instance has no attribute set at all
        """
        instance = inmanta_plugins.lsm.global_cache.convert_instance(
            current_instance, self.validate, self.instance_to_validate
        )
        instance_attributes = None
        if instance is not None:
            instance_attributes = instance["attributes"]
        if not instance_attributes:
            # No attributes, wing it
            # We don't check if this is consistent over the different attribute sets
            # As it is allowed to move within a group
            instance_attributes = next(
                (
                    the_set
                    for the_set in (
                        current_instance.get(attr_set.value)
                        for attr_set in AttributeSetName
                    )
                    if the_set is not None
                ),
                None,
            )
            if instance_attributes is None:
                # Previously a bare StopIteration escaped here
                raise InstancePartitioningException(
                    f"The service instance {current_instance.get('service_entity')}-{current_instance.get('id')} "
                    f"has no attribute set: can not determine its attribute `{relation}`."
                )
        return instance_attributes.get(relation)


class TreeSelector(AbstractSelector):
    """
    Default selector, supports any type of tree, where the full subtree is selected if any instance is requested
    """

    def __init__(self, env: str) -> None:
        super().__init__(env)
        # Type level fields
        # List of children for each type (used to traverse down)
        self.children: dict[str, set[VersionedServiceEntity]] = defaultdict(set)
        # Cache of relation to traverse one level up (None for a root)
        self.upward_lookup: dict[VersionedServiceEntity, Optional[TraversalStep]] = {}
        # Root binding of the tree each known binding belongs to
        self.root_for: dict[VersionedServiceEntity, proxy.DynamicProxy] = {}
        # Instance level
        # Cache: instance id -> root instance (api form)
        self.root_instance_for: dict[str, dict] = {}

    def reload_bindings(self) -> None:
        # Iterate over a snapshot: register_binding registers bindings in the global cache
        for (
            name,
            version,
        ), binding in dict(
            inmanta_plugins.lsm.global_cache.get_all_versioned_bindings()
        ).items():
            if (name, version) not in self.root_for:
                self.register_binding(binding)

    def register_binding(self, binding: "lsm::ServiceBindingVersion") -> None:
        """
        Register a binding together with the full ownership tree it belongs to.

        Walks up the `owner` relation to the root binding, collects every binding owned
        (transitively) by that root, registers them in the global cache and builds the
        traversal spec for the tree.

        :raises inmanta.execute.proxy.UnsetException: if the owner relation is not set yet
        """
        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: start",
            binding.service.service_entity_name,
            binding.version,
        )
        # This may be re-executed a few times
        # but complex to cache, easy to execute, so we don't bother
        root = binding
        try:
            while True:
                parent = root.owner
                if not parent:
                    break
                root = parent
        except inmanta.execute.proxy.UnsetException:
            # NOTE(review): other messages use binding.service.service_entity_name;
            # confirm `service_entity` exists on ServiceBindingVersion
            LOGGER.warning(
                "Please set the owner relation on the service entity binding for the service entity %s (instantiated at %s) "
                "in the constructor, even if it is null",
                binding.service_entity,
                str(binding._get_instance()._location),
            )
            raise
        except inmanta.ast.OptionalValueException:
            # found root
            pass

        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: found root %s version %s",
            binding.service.service_entity_name,
            binding.version,
            root.service.service_entity_name,
            root.version,
        )

        # Collect every binding in the tree below the root (root included)
        siblings = set()
        todo = [root]
        while todo:
            item = todo.pop()
            if item not in siblings:
                siblings.add(item)
                todo.extend(item.owned)

        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: found siblings [%s]",
            binding.service.service_entity_name,
            binding.version,
            ", ".join(
                [
                    f"{sibling.service.service_entity_name} version {sibling.version}"
                    for sibling in siblings
                ]
            ),
        )

        for sibling in siblings:
            inmanta_plugins.lsm.global_cache.register_binding(sibling)
            self.root_for[(sibling.service.service_entity_name, sibling.version)] = root

        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: done",
            binding.service.service_entity_name,
            binding.version,
        )

        # Build traversal spec for this root
        self.upward_lookup[(root.service.service_entity_name, root.version)] = None
        self.root_for[(root.service.service_entity_name, root.version)] = root
        self.build_traversal_spec(root)

    def build_traversal_spec(self, parent: "lsm::ServiceBindingVersion") -> None:
        """
        Recursively record, for every binding owned by `parent`, how to traverse up
        (`upward_lookup`) and down (`children`).
        """
        parent_type = parent.service.service_entity_name
        # Use the module logger: logging.debug() targets (and may configure) the root logger
        LOGGER.debug(
            "Building traversal spec for %s",
            (parent_type, parent.version),
        )
        for child in parent.owned:
            child_type = child.service.service_entity_name
            traversal = TraversalStep(
                child_type,
                child.version,
                child.relation_to_owner,
                # Pass the entity name, as TraversalStep expects: it ends up in error messages
                parent_type,
            )
            self.upward_lookup[(child_type, child.version)] = traversal
            self.children[parent_type].add((child_type, child.version))
            # Will run out of stack if there is a cycle
            self.build_traversal_spec(child)

    def find_root(
        self,
        current_instance: dict,
    ) -> dict:
        """Find the root instance, in its api form

        :param current_instance: the instance to start from, in its api form
        :return: the root instance of the tree `current_instance` belongs to
        :raises InstancePartitioningException: if the relation to the parent is not set
            or the parent instance can not be found
        """
        # cache
        instance_id = current_instance["id"]
        root = self.root_instance_for.get(instance_id)
        if root:
            return root

        service_type = current_instance["service_entity"]
        next_step = self.upward_lookup.get(
            (service_type, current_instance["service_entity_version"]), None
        )
        if next_step is None:
            # root
            self.root_instance_for[instance_id] = current_instance
            return current_instance

        parent_id = self._get_attribute(current_instance, next_step.to_parent_relation)
        if not parent_id:
            raise InstancePartitioningException(
                f"The service instance {service_type}-{current_instance['id']} does not have the relation "
                f"`{next_step.to_parent_relation}` towards its parent set."
            )
        parent = inmanta_plugins.lsm.global_cache.get_instance(
            env=self.env,
            service_entity_name=None,
            instance_id=parent_id,
            include_terminated=True,
        )
        if not parent:
            raise InstancePartitioningException(
                f"Can not find the owning entity of {service_type}-{current_instance['id']}: "
                f"no instance found for {next_step.parent_entity}-{parent_id}"
            )
        root = self.find_root(parent)
        self.root_instance_for[instance_id] = root
        return root

    def find_all_children(
        self,
        type_and_version: VersionedServiceEntity,
        owners: list[dict],
        found_instances: dict[VersionedServiceEntity, list],
    ) -> None:
        """
        Recursively finds all children of a given entity type and version.

        :param type_and_version: The service entity type and version of the parent entity
        :param owners: a list of instances of the parent entity that we want to find the children for
        :param found_instances: the instances that we have found thus far
        """
        found_instances[type_and_version].extend(owners)

        # Get child types for service entity name
        child_types = self.children.get(type_and_version[0])
        if not child_types:
            return

        parent_ids = {owner["id"] for owner in owners}
        for child_type, child_version in child_types:
            traversal = self.upward_lookup[(child_type, child_version)]
            assert traversal is not None  # make mypy happy, every child should have a parent
            instances_raw = inmanta_plugins.lsm.global_cache.get_all_instances(
                self.env, service_entity_name=child_type
            )
            # Check that the instance is of the correct version
            # And if the relation_to_owner matches one of the parent_ids
            selected_instances = [
                instance
                for instance in instances_raw
                if instance is not None
                and instance.get("service_entity_version", 0) == child_version
                and (
                    self._get_attribute(instance, traversal.to_parent_relation)
                    in parent_ids
                )
            ]
            self.find_all_children(
                (child_type, child_version), selected_instances, found_instances
            )

    def select_all(self) -> dict[str, list[dict]]:
        """
        Select the full trees of all requested instances.

        :return: mapping of service entity name to the selected instances (all versions)
        :raises Exception: if no instance was requested
        """
        if not self.requested_instances:
            raise Exception(f"Environment variable {lsm_const.ENV_INSTANCE_ID} not set")

        LOGGER.info(
            "Selecting instances for current instance: %s",
            self.requested_instances,
        )

        global_cache = inmanta_plugins.lsm.global_cache
        for entity in {entity for entity, version in self.upward_lookup.keys()}:
            # Pre-load everything!
            global_cache.get_all_instances(self.env, entity)

        # roots, per type
        roots: dict[VersionedServiceEntity, dict[str, dict]] = defaultdict(dict)
        for instance in self.requested_instances:
            current_instance_raw = global_cache.get_instance(
                env=self.env,
                service_entity_name=None,
                instance_id=instance,
                include_terminated=True,
            )
            if not current_instance_raw:
                LOGGER.info("No instance found for %s", instance)
                continue

            if (
                current_instance_raw["service_entity"],
                current_instance_raw["service_entity_version"],
            ) not in self.upward_lookup:
                LOGGER.info(
                    "Current instance %s is not under any known root",
                    instance,
                )
                continue

            root_instance = self.find_root(current_instance_raw)
            LOGGER.debug(
                "Root instance found for %s version %s:%s",
                root_instance["service_entity"],
                root_instance["service_entity_version"],
                root_instance["id"],
            )
            roots[
                (
                    root_instance["service_entity"],
                    root_instance["service_entity_version"],
                )
            ][root_instance["id"]] = root_instance

        # We use versions to fetch all the corresponding children
        versioned_instances: dict[VersionedServiceEntity, list[dict]] = defaultdict(
            list
        )
        for type_and_version, instances in roots.items():
            self.find_all_children(
                type_and_version,
                list(instances.values()),
                versioned_instances,
            )

        # We don't care about versions on the output of select_all
        out: dict[str, list[dict]] = defaultdict(list)
        for (entity, _version), instances in versioned_instances.items():
            out[entity].extend(instances)
        return out


class ParentSelector(TreeSelector):
    """
    Class to help select instances that have to be considered when doing a partial compile

    This class supports selecting all instances below a common root entity.
    """