"""
Support for partial compiles
:copyright: 2023 Inmanta
:contact: code@inmanta.com
:license: Inmanta EULA
"""
import abc
import logging
import typing
from collections import defaultdict
from typing import Optional, Sequence
import inmanta.ast
import inmanta.const
import inmanta.execute.proxy as proxy
import inmanta_plugins.lsm
import inmanta_plugins.lsm.allocation_helpers
from inmanta.plugins import PluginException
from inmanta_lsm import const as lsm_const
from inmanta_lsm.model import AttributeSetName
# Module-level logger, named after this module
LOGGER = logging.getLogger(__name__)
# Key identifying one version of a service entity: (service entity name, version)
VersionedServiceEntity: typing.TypeAlias = tuple[str, int]
class InstancePartitioningException(PluginException):
    """
    Raised when the owner/owned hierarchy of service entities or service instances can not be resolved
    """
class TraversalStep:
    """
    A single upward hop in the ownership hierarchy: from an owned (child) entity to its owner (parent)

    :param child_entity: name of the owned service entity
    :param child_version: version of the owned service entity
    :param to_parent_relation: name of the relation on the child that points to its parent
    :param parent_entity: the owning service entity
    :raises InstancePartitioningException: when `to_parent_relation` is empty
    """

    def __init__(
        self,
        child_entity: str,
        child_version: int,
        to_parent_relation: str,
        parent_entity: str,
    ) -> None:
        # Guard first: a step without a relation towards the parent can never be traversed
        if not to_parent_relation:
            raise InstancePartitioningException(
                f"The attribute `relation_to_owner` on ServiceEntityBinding {child_entity} is not set, "
                f"but the `owner` relation is set: these two must always be used in combination."
            )
        # Child side of the step
        self.child_entity, self.child_version = child_entity, child_version
        # Link towards the parent side of the step
        self.to_parent_relation, self.parent_entity = to_parent_relation, parent_entity
class SelectorAPI(abc.ABC):
    """The Selector is responsible for determining which instances are returned by lsm::all

    A specific selector class can be registered using `inmanta_plugins.lsm.global_cache.set_selector_factory`

    A selector is used in 4 phases:

    1. the Selector is constructed (empty)
    2. it is fed all relevant bindings to analyze and cache.
    3. it is fed all instances requested via the environment variable inmanta_instance_id as used for partial compile
    4. it returns the instances selected

    All methods can be called multiple times, but once a method from the next phase is called,
    methods from the previous phase should not get called any more (for the same binding)
    """

    def __init__(self, env: str) -> None:
        """
        :param env: the environment this selector works for (stored as-is on `self.env`)
        """
        self.env = env

    @abc.abstractmethod
    def reload_bindings(self) -> None:
        """Register a new binding (phase 2):

        This method is only required for very advanced selectors that need to inspect the binding structure.

        This method checks the binding cache for new instances and registers them.

        e.g.

        .. code-block:: python

            def reload_bindings(self) -> None:
                for (
                    name,
                    version,
                ), binding in dict(inmanta_plugins.lsm.global_cache.get_all_versioned_bindings()).items():
                    if (name, version) not in self.root_for:
                        self.register_binding(binding)

        Implementors, keep in mind that:
            1. method can be re-executed because of unset exceptions
            2. any binding additionally required MUST be registered in the global cache (inmanta_plugins.lsm.global_cache)
        """

    @abc.abstractmethod
    def register_instances(self, instance_ids: Sequence[str]) -> None:
        """
        register explicitly requested instances (phase 3), can be called multiple times.

        :param instance_ids: the ids of the instances that were explicitly requested
        """

    @abc.abstractmethod
    def select(self, requested_service_type: str) -> list[dict]:
        """
        Return all instances for a specific type (phase 4)

        can be called multiple times

        All instances must also be cached in the global cache

        :param requested_service_type: the service type to return the selected instances for
        :return: the selected instances of that type
        """
 
class AllSelector(SelectorAPI):
    """Selector to be used when not doing partial compile"""

    def reload_bindings(self) -> None:
        """Nothing to analyze: this selector never filters on the binding structure"""

    def register_instances(self, instance_ids: Sequence[str]) -> None:
        """
        Explicitly requested instances are ignored by this selector

        :param instance_ids: the requested instance ids (unused)
        """

    def select(self, requested_service_type: str) -> list[dict]:
        """
        Return every instance of the requested type, as provided by the global cache

        :param requested_service_type: the service type to return the instances for
        """
        return inmanta_plugins.lsm.global_cache.get_all_instances(
            self.env, requested_service_type
        )
class AbstractSelector(SelectorAPI):
    """
    Base class to implement custom selectors

    it keeps track of:

    :param validate: is this a validation compile
    :param instance_to_validate: if it is a validation compile, which instance are we validating.
        This is built up by calls to :meth:`register_instances`.
    :param requested_instances: what are the instance we need to include in the selection?
        This is built up by calls to :meth:`register_instances`.

    The select_all method is the only one that still needs implementing
    """

    def __init__(self, env: str) -> None:
        super().__init__(env)
        self.validate = inmanta_plugins.lsm.allocation_helpers.is_validation_compile()

        # Instance level
        self.requested_instances: set[str] = set()
        self.instance_to_validate: Optional[str] = None

        # Cache: result of the last call to select_all, None as long as select was never called
        self.select_all_cache: Optional[dict[str, list[dict]]] = None

    def reload_bindings(self) -> None:
        """No binding analysis by default, override when the selector has to inspect the binding structure"""

    def register_instances(self, instance_ids: Sequence[str]) -> None:
        """
        register instances, can be called multiple times

        :param instance_ids: the ids of the instances that were explicitly requested
        :raises AssertionError: when a not yet known instance is registered after select has been called
        :raises Exception: on a validation compile, when not exactly one instance is requested in total
        """
        if self.select_all_cache is not None:
            # The selection has already been computed: only re-registration of known instances is tolerated
            for instance_id in instance_ids:
                assert (
                    instance_id in self.requested_instances
                ), f"Register instance {instance_id} called after select, this is not allowed"

        self.requested_instances.update(instance_ids)

        if self.validate:
            if not self.requested_instances:
                raise Exception(
                    "Validation compile without instance set! This is not allowed"
                )
            if len(self.requested_instances) > 1:
                # Exception does not apply %-formatting to its arguments, so build the message explicitly
                raise Exception(
                    "Validation compile with multiple instance set! This is not allowed. "
                    f"instances: {self.requested_instances}"
                )
            # Exactly one instance is left at this point
            self.instance_to_validate = next(iter(self.requested_instances))

    @abc.abstractmethod
    def select_all(self) -> dict[str, list[dict]]:
        """Main entry point to overload:

        Select all instances, as required to compile the `requested_instances`

        return as a dictionary mapping service_entity_name to a list of entities

        See documentation of examples
        """

    def select(self, requested_service_type: str) -> list[dict]:
        """
        Return the selected instances of the requested type, recomputing the selection when needed

        :param requested_service_type: the service type to return the selected instances for
        :return: the selected instances, an empty list if the selection holds none of that type
        """
        # this cache is very subtle
        # lsm::all can be called on multiple disjoint trees
        # we may have completed all stages for one tree, but have no bindings for the other tree yet.
        if not self.select_all_cache or (
            requested_service_type not in self.select_all_cache
        ):
            self.select_all_cache = self.select_all()
        return self.select_all_cache.get(requested_service_type, [])

    def _get_attribute(self, current_instance: dict, relation: str) -> typing.Any:
        """
        Robust but sloppy way of getting the attributes out

        some validation states have no attribute set set,
        which is why we may have to take another attribute set to build the group

        This means that no service can move to another group or this code becomes unreliable

        :param current_instance: the instance, in its api form
        :param relation: the name of the attribute to read
        :return: the value of the attribute, None if it is absent or if the instance has no attribute set at all
        """
        instance = inmanta_plugins.lsm.global_cache.convert_instance(
            current_instance, self.validate, self.instance_to_validate
        )
        instance_attributes = instance["attributes"] if instance is not None else None

        if not instance_attributes:
            # No attributes, wing it
            # We don't check if this is consistent over the different attribute sets
            # As it is allowed to move within a group
            candidate_sets = (
                current_instance.get(attr_set.value) for attr_set in AttributeSetName
            )
            # The explicit default prevents a bare StopIteration from escaping when every set is missing
            instance_attributes = next(
                (the_set for the_set in candidate_sets if the_set is not None),
                None,
            )

        if instance_attributes is None:
            return None
        return instance_attributes.get(relation)
class TreeSelector(AbstractSelector):
    """
    Default selector, support any type of tree,
    where the full subtree is selected if any instance is requested
    """

    def __init__(self, env: str) -> None:
        super().__init__(env)
        # Type level fields
        # Versioned child types for each parent service entity name (used to traverse down)
        self.children: dict[str, set[VersionedServiceEntity]] = defaultdict(set)
        # Cache of relation to traverse one level up, None for a root
        self.upward_lookup: dict[VersionedServiceEntity, Optional[TraversalStep]] = {}
        # Root binding of the tree each known versioned service entity belongs to
        self.root_for: dict[VersionedServiceEntity, proxy.DynamicProxy] = {}

        # Instance level
        # Cache of the root instance (api form) for each instance id
        self.root_instance_for: dict[str, dict] = {}

    def reload_bindings(self) -> None:
        """Register every versioned binding of the global cache that is not attached to a root yet"""
        # Iterate over a snapshot: register_binding registers bindings in the global cache itself
        # (presumably the same collection we are iterating over)
        known_bindings = dict(
            inmanta_plugins.lsm.global_cache.get_all_versioned_bindings()
        )
        for (name, version), binding in known_bindings.items():
            if (name, version) not in self.root_for:
                self.register_binding(binding)

    def register_binding(self, binding: "lsm::ServiceBindingVersion") -> None:
        """
        Resolve the tree the given binding belongs to and record it:
        the root of the tree, all bindings below that root and the traversal spec between them.

        :param binding: any binding of the tree
        :raises inmanta.execute.proxy.UnsetException: re-raised (after logging a warning) when it occurs
            while following the `owner` relation upward
        """
        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: start",
            binding.service.service_entity_name,
            binding.version,
        )

        # This may be re-executed a few times
        # but complex to cache, easy to execute, so we don't bother

        # Walk upward over the owner relation until there is no owner any more
        root = binding
        try:
            while True:
                parent = root.owner
                if not parent:
                    break
                root = parent
        except inmanta.execute.proxy.UnsetException:
            LOGGER.warning(
                "Please set the owner relation on the service entity binding for the service entity %s (instantiated at %s) "
                "in the constructor, even if it is null",
                binding.service_entity,
                str(binding._get_instance()._location),
            )
            raise
        except inmanta.ast.OptionalValueException:
            # found root
            pass

        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: found root %s version %s",
            binding.service.service_entity_name,
            binding.version,
            root.service.service_entity_name,
            root.version,
        )

        # Collect every binding reachable from the root over the owned relation (root included)
        siblings = set()
        todo = [root]
        while todo:
            item = todo.pop()
            if item not in siblings:
                siblings.add(item)
                todo.extend(item.owned)

        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: found siblings [%s]",
            binding.service.service_entity_name,
            binding.version,
            ", ".join(
                f"{sibling.service.service_entity_name} version {sibling.version}"
                for sibling in siblings
            ),
        )

        for sibling in siblings:
            inmanta_plugins.lsm.global_cache.register_binding(sibling)
            self.root_for[(sibling.service.service_entity_name, sibling.version)] = root

        LOGGER.log(
            inmanta.const.LOG_LEVEL_TRACE,
            "Resolution of service binding %s version %s: done",
            binding.service.service_entity_name,
            binding.version,
        )

        # Build traversal spec for this root
        root_key = (root.service.service_entity_name, root.version)
        self.upward_lookup[root_key] = None
        self.root_for[root_key] = root
        self.build_traversal_spec(root)

    def build_traversal_spec(self, parent: "lsm::ServiceBindingVersion") -> None:
        """
        Recursively record, for every binding owned by the given parent, how to traverse up to that parent
        (`upward_lookup`) and down from that parent (`children`).

        :param parent: the binding to record the owned bindings of
        """
        parent_type = parent.service.service_entity_name
        LOGGER.debug(
            "Building traversal spec for %s",
            (parent_type, parent.version),
        )
        for child in parent.owned:
            child_type = child.service.service_entity_name
            child_key = (child_type, child.version)
            # TraversalStep expects the name of the parent entity (it ends up in error messages),
            # not the binding itself
            self.upward_lookup[child_key] = TraversalStep(
                child_type,
                child.version,
                child.relation_to_owner,
                parent_type,
            )
            self.children[parent_type].add(child_key)
            self.build_traversal_spec(
                child
            )  # Will run out of stack if there is a cycle

    def find_root(
        self,
        current_instance: dict,
    ) -> dict:
        """
        Find the root instance, in its api form

        :param current_instance: the instance to find the root of, in its api form
        :return: the root instance, this is the given instance itself when its type has no upward step
        :raises InstancePartitioningException: when the relation towards the parent is not set on an instance
            or when the parent instance can not be found
        """
        # cache
        instance_id = current_instance["id"]
        root = self.root_instance_for.get(instance_id)
        if root:
            return root

        service_type = current_instance["service_entity"]
        next_step = self.upward_lookup.get(
            (service_type, current_instance["service_entity_version"]), None
        )
        if next_step is None:
            # root
            self.root_instance_for[instance_id] = current_instance
            return current_instance

        parent_id = self._get_attribute(current_instance, next_step.to_parent_relation)
        if not parent_id:
            raise InstancePartitioningException(
                f"The service instance {service_type}-{current_instance['id']} does not have the relation "
                f"`{next_step.to_parent_relation}` towards its parent set."
            )

        parent = inmanta_plugins.lsm.global_cache.get_instance(
            env=self.env,
            service_entity_name=None,
            instance_id=parent_id,
            include_terminated=True,
        )
        if not parent:
            raise InstancePartitioningException(
                f"Can not find the owning entity of {service_type}-{current_instance['id']}: "
                f"no instance found for {next_step.parent_entity}-{parent_id}"
            )

        # Keep climbing, the root of the parent is our root as well
        root = self.find_root(parent)
        self.root_instance_for[instance_id] = root
        return root

    def find_all_children(
        self,
        type_and_version: VersionedServiceEntity,
        owners: list[dict],
        found_instances: dict[VersionedServiceEntity, list],
    ) -> None:
        """
        Recursively finds all children of a given entity type and version.

        :param type_and_version: The service entity type and version of the parent entity
        :param owners: a list of instances of the parent entity that we want to find the children for
        :param found_instances: the instances that we have found thus far, extended in place.
            A missing key must be created on access (e.g. a defaultdict of list).
        """
        found_instances[type_and_version].extend(owners)

        # Get child types for service entity name
        child_types = self.children.get(type_and_version[0])
        if not child_types:
            return

        parent_ids = {owner["id"] for owner in owners}
        for child_type, child_version in child_types:
            traversal = self.upward_lookup[(child_type, child_version)]
            assert (
                traversal is not None
            )  # make mypy happy, every child should have a parent

            instances_raw = inmanta_plugins.lsm.global_cache.get_all_instances(
                self.env, service_entity_name=child_type
            )
            # Check that the instance is of the correct version
            # And if the relation_to_owner matches one of the parent_ids
            selected_instances = [
                instance
                for instance in instances_raw
                if instance is not None
                and instance.get("service_entity_version", 0) == child_version
                and (
                    self._get_attribute(instance, traversal.to_parent_relation)
                    in parent_ids
                )
            ]
            self.find_all_children(
                (child_type, child_version), selected_instances, found_instances
            )

    def select_all(self) -> dict[str, list[dict]]:
        """
        Select, for every requested instance, the full tree below the root of that instance

        :return: the selected instances, grouped by service entity name
        :raises Exception: when no instance was requested
        """
        if not self.requested_instances:
            raise Exception(f"Environment variable {lsm_const.ENV_INSTANCE_ID} not set")

        LOGGER.info(
            "Selecting instances for current instance: %s",
            self.requested_instances,
        )

        global_cache = inmanta_plugins.lsm.global_cache
        for entity in {entity for entity, _ in self.upward_lookup}:
            # Pre-load everything!
            global_cache.get_all_instances(self.env, entity)

        # roots, per type
        roots: dict[VersionedServiceEntity, dict[str, dict]] = defaultdict(dict)
        for instance in self.requested_instances:
            current_instance_raw = global_cache.get_instance(
                env=self.env,
                service_entity_name=None,
                instance_id=instance,
                include_terminated=True,
            )
            if not current_instance_raw:
                LOGGER.info("No instance found for %s", instance)
                continue

            instance_key = (
                current_instance_raw["service_entity"],
                current_instance_raw["service_entity_version"],
            )
            if instance_key not in self.upward_lookup:
                LOGGER.info(
                    "Current instance %s is not under any known root",
                    instance,
                )
                continue

            root_instance = self.find_root(current_instance_raw)
            LOGGER.debug(
                "Root instance found for %s version %s:%s",
                root_instance["service_entity"],
                root_instance["service_entity_version"],
                root_instance["id"],
            )
            root_key = (
                root_instance["service_entity"],
                root_instance["service_entity_version"],
            )
            # Keyed on id: several requested instances can share the same root
            roots[root_key][root_instance["id"]] = root_instance

        # We use versions to fetch all the corresponding children
        versioned_instances: dict[VersionedServiceEntity, list[dict]] = defaultdict(
            list
        )
        for type_and_version, instances in roots.items():
            self.find_all_children(
                type_and_version,
                list(instances.values()),
                versioned_instances,
            )

        # We don't care about versions on the output of select_all
        out: dict[str, list[dict]] = defaultdict(list)
        for (entity, _version), instances_of_version in versioned_instances.items():
            out[entity].extend(instances_of_version)
        return out
class ParentSelector(TreeSelector):
    """
    Helper to pick the instances that have to be taken into account when doing a partial compile

    It selects all instances below a common root entity; nothing is added on top of TreeSelector.
    """