"""
Support for partial compiles
:copyright: 2023 Inmanta
:contact: code@inmanta.com
:license: Inmanta EULA
"""
import abc
import logging
import typing
from collections import defaultdict
from typing import Optional, Sequence
import inmanta.ast
import inmanta.const
import inmanta.execute.proxy as proxy
import inmanta_plugins.lsm
import inmanta_plugins.lsm.allocation_helpers
from inmanta.plugins import PluginException
from inmanta_lsm import const as lsm_const
from inmanta_lsm.model import AttributeSetName
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# A service entity identified by its (service entity name, version) pair.
VersionedServiceEntity: typing.TypeAlias = tuple[str, int]
class InstancePartitioningException(PluginException):
    """
    Raised when the owner/owned hierarchy of bindings or instances can not be resolved
    while partitioning instances for a partial compile.
    """
class TraversalStep:
    """
    A single upward hop in the ownership hierarchy: from an owned (child) entity to its owner (parent).
    """

    def __init__(
        self,
        child_entity: str,
        child_version: int,
        to_parent_relation: str,
        parent_entity: str,
    ) -> None:
        """
        :param child_entity: the owned service entity
        :param child_version: the version of the owned service entity
        :param to_parent_relation: the attribute on the child that refers to its owner
        :param parent_entity: the owning service entity
        :raises InstancePartitioningException: when `to_parent_relation` is empty
        """
        self.child_entity, self.child_version = child_entity, child_version
        self.to_parent_relation, self.parent_entity = to_parent_relation, parent_entity

        if self.to_parent_relation:
            return

        # An owner without a relation pointing to it can not be traversed
        raise InstancePartitioningException(
            f"The attribute `relation_to_owner` on ServiceEntityBinding {child_entity} is not set, "
            f"but the `owner` relation is set: these two must always be used in combination."
        )
class SelectorAPI(abc.ABC):
    """
    The Selector is responsible for determining which instances are returned by lsm::all

    A specific selector class can be registered using `inmanta_plugins.lsm.global_cache.set_selector_factory`

    A selector is used in 4 phases:

    1. the Selector is constructed (empty)
    2. it is fed all relevant bindings to analyze and cache.
    3. it is fed all instances requested via the environment variable inmanta_instance_id as used for partial compile
    4. it returns the instances selected

    All methods can be called multiple times, but once a method from the next phase is called,
    methods from the previous phase should not get called any more (for the same binding)
    """

    def __init__(self, env: str) -> None:
        """
        :param env: the environment identifier, kept on the selector for use by the concrete implementations
        """
        self.env = env

    @abc.abstractmethod
    def reload_bindings(self) -> None:
        """
        Register new bindings (phase 2).

        This method is only required for very advanced selectors that need to inspect the binding structure.

        This method checks the binding cache for new bindings and registers them.

        e.g.

        .. code-block:: python

            def reload_bindings(self) -> None:
                for (
                    name,
                    version,
                ), binding in dict(inmanta_plugins.lsm.global_cache.get_all_versioned_bindings()).items():
                    if (name, version) not in self.root_for:
                        self.register_binding(binding)

        Implementors, keep in mind that:

        1. method can be re-executed because of unset exceptions
        2. any binding additionally required MUST be registered in the global cache (inmanta_plugins.lsm.global_cache)
        """
        pass

    @abc.abstractmethod
    def register_instances(self, instance_ids: Sequence[str]) -> None:
        """
        Register explicitly requested instances (phase 3), can be called multiple times.

        :param instance_ids: the ids of the explicitly requested instances
        """

    @abc.abstractmethod
    def select(self, requested_service_type: str) -> list[dict]:
        """
        Return all instances for a specific type (phase 4), can be called multiple times.

        All instances must also be cached in the global cache

        :param requested_service_type: the service type to return the selected instances for
        :return: the selected instances of the requested type
        """
class AllSelector(SelectorAPI):
    """Selector to be used when not doing partial compile: it selects every instance of the requested type"""

    def reload_bindings(self) -> None:
        """Nothing to do: this selector does not inspect the bindings."""

    # Accept any Sequence[str], like the abstract method this overrides:
    # an override must not narrow the parameter type of its base class.
    def register_instances(self, instance_ids: Sequence[str]) -> None:
        """Nothing to do: the explicitly requested instances do not influence the selection."""

    def select(self, requested_service_type: str) -> list[dict]:
        """
        Return all instances of the requested type, as returned by the global cache

        :param requested_service_type: the service type to return the instances for
        """
        return inmanta_plugins.lsm.global_cache.get_all_instances(
            self.env, requested_service_type
        )
class AbstractSelector(SelectorAPI):
    """
    Base class to implement custom selectors

    It keeps track of:

    :ivar validate: is this a validation compile
    :ivar instance_to_validate: if it is a validation compile, which instance are we validating.
        This is built up by calls to :meth:`register_instances`.
    :ivar requested_instances: what are the instances we need to include in the selection?
        This is built up by calls to :meth:`register_instances`.

    The select_all method is the only one that still needs implementing
    """

    def __init__(self, env: str) -> None:
        super().__init__(env)
        self.validate = inmanta_plugins.lsm.allocation_helpers.is_validation_compile()

        # Instance level
        self.requested_instances: set[str] = set()
        self.instance_to_validate: Optional[str] = None

        # Cache: result of the last call to select_all, None as long as select was never called
        self.select_all_cache: Optional[dict[str, list[dict]]] = None

    def reload_bindings(self) -> None:
        """Nothing to do by default: this selector does not inspect the bindings."""
        pass

    def register_instances(self, instance_ids: Sequence[str]) -> None:
        """
        Register instances, can be called multiple times

        :param instance_ids: the ids of the explicitly requested instances
        :raises AssertionError: when a new instance is registered after select has been called
        :raises Exception: on a validation compile, when not exactly one instance is requested
        """
        if self.select_all_cache is not None:
            # The selection has already been computed: only known instances may be re-registered
            for instance_id in instance_ids:
                assert (
                    instance_id in self.requested_instances
                ), f"Register instance {instance_id} called after select, this is not allowed"

        self.requested_instances.update(instance_ids)

        if self.validate:
            if len(self.requested_instances) == 0:
                raise Exception(
                    "Validation compile without instance set! This is not allowed"
                )
            if len(self.requested_instances) > 1:
                # Exceptions don't interpolate logging-style arguments: format the message explicitly
                raise Exception(
                    "Validation compile with multiple instance set! This is not allowed. "
                    f"instances: {self.requested_instances}"
                )
            self.instance_to_validate = list(self.requested_instances)[0]

    @abc.abstractmethod
    def select_all(self) -> dict[str, list[dict]]:
        """
        Main entry point to overload:

        Select all instances, as required to compile the `requested_instances`

        return as a dictionary mapping service_entity_name to a list of entities

        See documentation of examples
        """
        pass

    def select(self, requested_service_type: str) -> list[dict]:
        """
        Return the selected instances of the requested type, an empty list if there are none

        :param requested_service_type: the service type to return the selected instances for
        """
        # this cache is very subtle
        # lsm::all can be called on multiple disjoint trees
        # we may have completed all stages for one tree, but have no bindings for the other tree yet.
        # So: recompute whenever the cache is empty or does not know the requested type
        if not self.select_all_cache or (
            requested_service_type not in self.select_all_cache
        ):
            self.select_all_cache = self.select_all()
        return self.select_all_cache.get(requested_service_type, [])

    def _get_attribute(self, current_instance: dict, relation: str) -> dict:
        """
        Robust but sloppy way of getting the attributes out

        some validation states have no attribute set set,
        which is why we may have to take another attribute set to build the group

        This means that no service can move to another group or this code becomes unreliable

        NOTE(review): the return annotation is kept for compatibility, but the returned value is
        whatever is stored under `relation`; callers in this file use it as an instance id.

        :param current_instance: the instance, in its api form
        :param relation: the name of the attribute to read
        :return: the value of the attribute, None if it is not set or if the instance has no attribute set at all
        """
        instance = inmanta_plugins.lsm.global_cache.convert_instance(
            current_instance, self.validate, self.instance_to_validate
        )

        instance_attributes = instance["attributes"] if instance is not None else None

        if not instance_attributes:
            # No attributes, wing it
            # We don't check if this is consistent over the different attribute sets
            # As it is allowed to move within a group
            candidate_sets = (
                current_instance.get(attr_set.value) for attr_set in AttributeSetName
            )
            # Default to None instead of leaking a bare StopIteration when no attribute set exists
            instance_attributes = next(
                (the_set for the_set in candidate_sets if the_set is not None),
                None,
            )

        # No attribute set at all: behaves like an unset relation
        return (instance_attributes or {}).get(relation)
class TreeSelector(AbstractSelector):
    """
    Default selector, support any type of tree,
    where the full subtree is selected if any instance is requested
    """

    def __init__(self, env: str) -> None:
        super().__init__(env)

        # Type level fields
        # List of children for each type (used to traverse down), keyed by the parent's service entity name
        self.children: dict[str, set[VersionedServiceEntity]] = defaultdict(set)
        # Cache of relation to traverse one level up, None for a root
        self.upward_lookup: dict[VersionedServiceEntity, Optional[TraversalStep]] = {}
        # The root binding of the tree each known binding belongs to
        self.root_for: dict[VersionedServiceEntity, proxy.DynamicProxy] = {}

        # Instance level
        # Cache of the root instance (api form) for each instance id
        self.root_instance_for: dict[str, dict] = {}

    def reload_bindings(self) -> None:
        """
        Register every binding of the global cache that is not known to this selector yet
        """
        # Iterate over a copy: register_binding registers bindings in the global cache as well
        all_bindings = dict(
            inmanta_plugins.lsm.global_cache.get_all_versioned_bindings()
        )
        for (name, version), binding in all_bindings.items():
            if (name, version) not in self.root_for:
                self.register_binding(binding)

    def register_binding(self, binding: "lsm::ServiceBindingVersion") -> None:
        """
        Register a binding and the whole tree it is part of:
        find the root by following the `owner` relation up, collect everything reachable
        from that root via `owned`, and build the traversal spec for the tree.

        :param binding: the binding to register
        :raises inmanta.execute.proxy.UnsetException: when an `owner` relation is not set yet
        """
        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: start",
            binding.service.service_entity_name,
            binding.version,
        )

        # This may be re-executed a few times
        # but complex to cache, easy to execute, so we don't bother
        root = binding
        try:
            while True:
                parent = root.owner
                if not parent:
                    break
                root = parent
        except inmanta.execute.proxy.UnsetException:
            LOGGER.warning(
                "Please set the owner relation on the service entity binding for the service entity %s (instantiated at %s) "
                "in the constructor, even if it is null",
                binding.service_entity,
                str(binding._get_instance()._location),
            )
            raise
        except inmanta.ast.OptionalValueException:
            # found root
            pass

        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: found root %s version %s",
            binding.service.service_entity_name,
            binding.version,
            root.service.service_entity_name,
            root.version,
        )

        # Collect every binding of the tree below the root
        siblings = set()
        todo = [root]
        while todo:
            item = todo.pop()
            if item not in siblings:
                siblings.add(item)
                todo.extend(item.owned)

        # One argument per placeholder: name, version and the list of siblings
        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: found siblings [%s]",
            binding.service.service_entity_name,
            binding.version,
            ", ".join(
                f"{sibling.service.service_entity_name} version {sibling.version}"
                for sibling in siblings
            ),
        )

        for sibling in siblings:
            inmanta_plugins.lsm.global_cache.register_binding(sibling)
            self.root_for[(sibling.service.service_entity_name, sibling.version)] = root

        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: done",
            binding.service.service_entity_name,
            binding.version,
        )

        # Build traversal spec for this root
        self.upward_lookup[(root.service.service_entity_name, root.version)] = None
        self.root_for[(root.service.service_entity_name, root.version)] = root
        self.build_traversal_spec(root)

    def build_traversal_spec(self, parent: "lsm::ServiceBindingVersion") -> None:
        """
        Recursively record, for every binding owned by `parent`, how to go one level up
        (`upward_lookup`) and one level down (`children`).

        :param parent: the binding to build the traversal spec below
        :raises InstancePartitioningException: when an owned binding has no `relation_to_owner`
        """
        parent_type = parent.service.service_entity_name
        LOGGER.debug(
            "Building traversal spec for %s",
            (parent_type, parent.version),
        )
        for child in parent.owned:
            child_type = child.service.service_entity_name
            # TraversalStep expects the name of the parent entity (used in error messages), not the binding
            traversal = TraversalStep(
                child_type,
                child.version,
                child.relation_to_owner,
                parent_type,
            )
            self.upward_lookup[(child_type, child.version)] = traversal
            self.children[parent_type].add((child_type, child.version))
            self.build_traversal_spec(
                child
            )  # Will run out of stack if there is a cycle

    def find_root(
        self,
        current_instance: dict,
    ) -> dict:
        """
        Find the root instance, in its api form

        :param current_instance: the instance to find the root for, in its api form
        :raises InstancePartitioningException: when the relation to the parent is not set
            or when the parent instance can not be found
        """
        # cache
        instance_id = current_instance["id"]
        root = self.root_instance_for.get(instance_id)
        if root:
            return root

        service_type = current_instance["service_entity"]
        next_step = self.upward_lookup.get(
            (service_type, current_instance["service_entity_version"]), None
        )
        if next_step is None:
            # root
            self.root_instance_for[instance_id] = current_instance
            return current_instance

        parent_id = self._get_attribute(current_instance, next_step.to_parent_relation)
        if not parent_id:
            raise InstancePartitioningException(
                f"The service instance {service_type}-{current_instance['id']} does not have the relation "
                f"`{next_step.to_parent_relation}` towards its parent set."
            )

        parent = inmanta_plugins.lsm.global_cache.get_instance(
            env=self.env,
            service_entity_name=None,
            instance_id=parent_id,
            include_terminated=True,
        )
        if not parent:
            raise InstancePartitioningException(
                f"Can not find the owning entity of {service_type}-{current_instance['id']}: "
                f"no instance found for {next_step.parent_entity}-{parent_id}"
            )

        root = self.find_root(parent)
        self.root_instance_for[instance_id] = root
        return root

    def find_all_children(
        self,
        type_and_version: VersionedServiceEntity,
        owners: list[dict],
        found_instances: dict[VersionedServiceEntity, list],
    ) -> None:
        """
        Recursively finds all children of a given entity type and version.

        :param type_and_version: The service entity type and version of the parent entity
        :param owners: a list of instances of the parent entity that we want to find the children for
        :param found_instances: the instances that we have found thus far, extended in place
        """
        found_instances[type_and_version].extend(owners)

        # Get child types for service entity name
        child_types = self.children.get(type_and_version[0])
        if not child_types:
            return

        parent_ids = {owner["id"] for owner in owners}

        for child_type, child_version in child_types:
            traversal = self.upward_lookup[(child_type, child_version)]
            assert (
                traversal is not None
            )  # make mypy happy, every child should have a parent
            instances_raw = inmanta_plugins.lsm.global_cache.get_all_instances(
                self.env, service_entity_name=child_type
            )
            # Check that the instance is of the correct version
            # And if the relation_to_owner matches one of the parent_ids
            selected_instances = [
                instance
                for instance in instances_raw
                if instance is not None
                and instance.get("service_entity_version", 0) == child_version
                and (
                    self._get_attribute(instance, traversal.to_parent_relation)
                    in parent_ids
                )
            ]
            self.find_all_children(
                (child_type, child_version), selected_instances, found_instances
            )

    def select_all(self) -> dict[str, list[dict]]:
        """
        Select the full tree around every requested instance:
        find the root of each requested instance, then collect all instances below these roots.

        :return: a dictionary mapping service entity name to the list of selected instances
        :raises Exception: when no instance was requested
        """
        if not self.requested_instances:
            raise Exception(f"Environment variable {lsm_const.ENV_INSTANCE_ID} not set")

        LOGGER.info(
            "Selecting instances for current instance: %s",
            self.requested_instances,
        )

        global_cache = inmanta_plugins.lsm.global_cache

        for entity in {entity for entity, _ in self.upward_lookup}:
            # Pre-load everything!
            global_cache.get_all_instances(self.env, entity)

        # roots, per type
        roots: dict[VersionedServiceEntity, dict[str, dict]] = defaultdict(dict)

        for instance in self.requested_instances:
            current_instance_raw = global_cache.get_instance(
                env=self.env,
                service_entity_name=None,
                instance_id=instance,
                include_terminated=True,
            )

            if not current_instance_raw:
                LOGGER.info("No instance found for %s", instance)
                continue

            if (
                current_instance_raw["service_entity"],
                current_instance_raw["service_entity_version"],
            ) not in self.upward_lookup:
                LOGGER.info(
                    "Current instance %s is not under any known root",
                    instance,
                )
                continue

            root_instance = self.find_root(current_instance_raw)

            LOGGER.debug(
                "Root instance found for %s version %s:%s",
                root_instance["service_entity"],
                root_instance["service_entity_version"],
                root_instance["id"],
            )
            root_key = (
                root_instance["service_entity"],
                root_instance["service_entity_version"],
            )
            # Keyed by id, so a root shared by several requested instances is only kept once
            roots[root_key][root_instance["id"]] = root_instance

        # We use versions to fetch all the corresponding children
        versioned_instances: dict[VersionedServiceEntity, list[dict]] = defaultdict(
            list
        )
        for type_and_version, instances in roots.items():
            self.find_all_children(
                type_and_version,
                list(instances.values()),
                versioned_instances,
            )

        # We don't care about versions on the output of select_all
        out: dict[str, list[dict]] = defaultdict(list)
        for (entity, _version), found in versioned_instances.items():
            out[entity].extend(found)

        return out
class ParentSelector(TreeSelector):
    """
    Helper to select the instances that have to be considered when doing a partial compile:
    all instances below a common root entity are selected.

    It adds no behavior of its own, everything is inherited from TreeSelector.
    """