# Source code for inmanta.data

"""
    Copyright 2017 Inmanta

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.

    Contact: code@inmanta.com
"""

import asyncio
import copy
import datetime
import enum
import hashlib
import json
import logging
import re
import typing
import uuid
import warnings
from abc import ABC, abstractmethod
from collections import abc, defaultdict
from collections.abc import Awaitable, Callable, Iterable, Sequence, Set
from configparser import RawConfigParser
from contextlib import AbstractAsyncContextManager
from itertools import chain
from re import Pattern
from typing import Generic, NewType, Optional, TypeVar, Union, cast, overload
from uuid import UUID

import asyncpg
import dateutil
import pydantic
import pydantic.tools
import typing_inspect
from asyncpg import Connection
from asyncpg.exceptions import SerializationError
from asyncpg.protocol import Record

import inmanta.const as const
import inmanta.db.versions
import inmanta.resources as resources
import inmanta.util as util
from crontab import CronTab
from inmanta.const import DATETIME_MIN_UTC, DONE_STATES, UNDEPLOYABLE_NAMES, AgentStatus, LogLevel, ResourceState
from inmanta.data import model as m
from inmanta.data import schema
from inmanta.data.model import (
    AuthMethod,
    BaseModel,
    PagingBoundaries,
    PipConfig,
    ResourceIdStr,
    api_boundary_datetime_normalizer,
)
from inmanta.protocol.common import custom_json_encoder
from inmanta.protocol.exceptions import BadRequest, NotFound
from inmanta.server import config
from inmanta.stable_api import stable_api
from inmanta.types import JsonType, PrimitiveTypes

LOGGER = logging.getLogger(__name__)

# NOTE(review): presumably the maximum number of rows fetched by a single database query — confirm against usages
DBLIMIT = 100000
# NOTE(review): presumably the maximum page size accepted by the API — confirm against usages
APILIMIT = 1000

# TODO: disconnect
# TODO: difference between None and not set

# Used as the 'default' parameter value for the Field class, when no default value has been set
# (a unique sentinel object, so that None remains usable as an explicit default)
default_unset = object()

# The Python types that ColumnType supports out of the box as column base types / query argument values
PRIMITIVE_SQL_TYPES = Union[str, int, bool, datetime.datetime, UUID]

"""
Locking order rules:
In general, locks should be acquired consistently with delete cascade lock order, which is top down. Additional lock orderings
are as follows. This list should be extended when new locks (explicit or implicit) are introduced. The rules below are written
as `A -> B`, meaning A should be locked before B in any transaction that acquires a lock on both.
- Code -> ConfigurationModel
- Agentprocess -> Agentinstance -> Agent
"""


@enum.unique
class QueryType(str, enum.Enum):
    """
    The kinds of filters that can be applied to a database query.

    Every member's value is its name in lower case, e.g. ``QueryType.EQUALS == "equals"``.
    """

    def _generate_next_value_(name, start: int, count: int, last_values: abc.Sequence[object]) -> str:  # noqa: N805
        """
        Make enum.auto() return the name of the enum member in lower case.
        """
        return name.lower()

    EQUALS = enum.auto()  # The filter value equals the value in the database
    CONTAINS = enum.auto()  # Any of the filter values are equal to the value in the database (exact match)
    IS_NOT_NULL = enum.auto()  # The value is not NULL in the database
    CONTAINS_PARTIAL = enum.auto()  # Any of the filter values are equal to the value in the database (partial match)
    RANGE = enum.auto()  # The values in the database are in the range described by the filter values and operators
    NOT_CONTAINS = enum.auto()  # None of the filter values are equal to the value in the database (exact match)
    COMBINED = enum.auto()  # The value describes a combination of other query types


class InvalidQueryType(Exception):
    """Signals an invalid query type; the human-readable reason is kept in ``message``."""

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)


class TableLockMode(enum.Enum):
    """
    Table level locks as defined in the PostgreSQL docs:

    https://www.postgresql.org/docs/13/explicit-locking.html#LOCKING-TABLES. When acquiring a lock, make sure to use the same
    locking order across transactions (as described at the top of this module) to prevent deadlocks and to otherwise respect
    the consistency docs: https://www.postgresql.org/docs/13/applevel-consistency.html#NON-SERIALIZABLE-CONSISTENCY.

    Not all lock modes are currently supported to keep the interface minimal (only include what we actually use). This class
    may be extended when a new lock mode is required.

    Each value is the lock mode exactly as it is spelled in SQL.
    """

    ROW_EXCLUSIVE = "ROW EXCLUSIVE"
    SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE"
    SHARE = "SHARE"
    SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE"


class RowLockMode(enum.Enum):
    """
    Row level locks as defined in the PostgreSQL docs: https://www.postgresql.org/docs/13/explicit-locking.html#LOCKING-ROWS.
    When acquiring a lock, make sure to use the same locking order across transactions (as described at the top of this
    module) to prevent deadlocks and to otherwise respect the consistency docs:
    https://www.postgresql.org/docs/13/applevel-consistency.html#NON-SERIALIZABLE-CONSISTENCY.

    Each value is the locking clause exactly as it is spelled in SQL.
    """

    FOR_UPDATE = "FOR UPDATE"
    FOR_NO_KEY_UPDATE = "FOR NO KEY UPDATE"
    FOR_SHARE = "FOR SHARE"
    FOR_KEY_SHARE = "FOR KEY SHARE"


class RangeOperator(enum.Enum):
    """Comparison operators usable in a range filter; each value is the corresponding SQL operator."""

    LT = "<"
    LE = "<="
    GT = ">"
    GE = ">="

    @property
    def pg_value(self) -> str:
        """The operator as it should appear in a PostgreSQL query."""
        return self.value

    @classmethod
    def parse(cls, text: str) -> "RangeOperator":
        """
        Parse an operator from its member name, case-insensitively (e.g. "lt", "GE").

        :param text: the name of the operator
        :return: the matching RangeOperator
        :raises ValueError: when text is not the name of a RangeOperator
        """
        try:
            return cls[text.upper()]
        except KeyError:
            # Suppress the KeyError context: it only repeats the failed lookup and clutters the traceback.
            raise ValueError(f"Failed to parse {text} as a RangeOperator") from None


# A range filter on an integer column, as a list of (operator, bound) pairs
RangeConstraint = list[tuple[RangeOperator, int]]
# A range filter on a datetime column, as a list of (operator, bound) pairs
DateRangeConstraint = list[tuple[RangeOperator, datetime.datetime]]
# A filter as a (query type, filter value) pair; the shape of the value depends on the query type
QueryFilter = tuple[QueryType, object]


class PagingCounts:
    """Plain holder for the ``total``, ``before`` and ``after`` record counts used for paging."""

    def __init__(self, total: int, before: int, after: int) -> None:
        self.total, self.before, self.after = total, before, after


class InvalidQueryParameter(Exception):
    """Signals an invalid query parameter; the human-readable reason is kept in ``message``."""

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)


class InvalidFieldNameException(Exception):
    """Signals an invalid field name; ``message`` holds the reason, extra positional args are kept in ``args``."""

    def __init__(self, message: str, *args: object) -> None:
        self.message = message
        super().__init__(message, *args)


# NewType markers: plain str at runtime, but they let type checkers distinguish checked column names and
# ORDER BY fragments from arbitrary strings that should never be interpolated into SQL.
ColumnNameStr = NewType("ColumnNameStr", str)
"""
    A valid database column name
"""

OrderStr = NewType("OrderStr", str)
"""
    A valid database ordering
"""


class ArgumentCollector:
    """
    Small helper to make placeholders for query arguments

    args = ArgumentCollector()
    query = f"SELECT * FROM table WHERE a = {args(a_value)} AND b = {args(b_value)}"
    con.fetch(query, *args.get_values())
    """

    def __init__(self, offset: int = 0, de_duplicate: bool = False) -> None:
        """

        :param offset: the smallest number already in use, the next one given out will be offset+1
        :param de_duplicate: if the value is the same, return the same number. Values only count as the same when
            they are the same object, or have the same type and compare equal (so ``1`` and ``True`` get separate
            placeholders).
        """
        self.args: list[object] = []
        self.offset = offset
        self.de_duplicate = de_duplicate

    def __call__(self, entry: object) -> str:
        """
        Register ``entry`` as a query argument and return its ``$n`` placeholder.

        :param entry: the value to bind
        :return: the placeholder referring to this value
        """
        if self.de_duplicate:
            for index, existing in enumerate(self.args):
                # Compare the type as well as the value: Python considers 1 == True == 1.0, but reusing the
                # placeholder would bind a value of the wrong type for this position in the query.
                if existing is entry or (type(existing) is type(entry) and existing == entry):
                    return "$" + str(index + 1 + self.offset)
        self.args.append(entry)
        return "$" + str(len(self.args) + self.offset)

    def get_values(self) -> list[object]:
        """Return the collected values, in placeholder order, for passing to the query."""
        return self.args


class PagingOrder(str, enum.Enum):
    """Direction in which a paged result set is sorted."""

    ASC = "ASC"
    DESC = "DESC"

    def invert(self) -> "PagingOrder":
        """Return the opposite sort direction."""
        return PagingOrder.ASC if self == PagingOrder.DESC else PagingOrder.DESC

    def db_form(self, *, nullable: bool = True) -> OrderStr:
        """
        Render this direction as an SQL ORDER BY fragment.

        :param nullable: whether the column that is sorted on may contain NULL values
        """
        # Luckily, for NOT NULL columns we will never encounter the COALESCE issue, so we can safely use the default order.
        if not nullable:
            return OrderStr(self.value)
        # The current filtering and sorting framework has the built-in assumption that nulls are considered the lowest values,
        # hence we must deviate from postgres' default order. As a result, we may lose the opportunity to use indexes, which
        # use the same order.
        # The framework can not easily be refactored because
        #   1. Not all column types have a sane MAX value to coalesce to
        #   2. The alternative approach to use a window function `row_number() OVER (ORDER BY ...)`, selecting on the ids of
        #       the first and last elements in the page, is more accurate, and does hit the indexes, but it also builds the
        #       row number for each row, which ends up costing even more.
        return OrderStr("ASC NULLS FIRST" if self == PagingOrder.ASC else "DESC NULLS LAST")


class InvalidSort(Exception):
    """Signals an invalid sort specification; ``message`` holds the reason, extra positional args are kept in ``args``."""

    def __init__(self, message: str, *args: object) -> None:
        self.message = message
        super().__init__(message, *args)


class ColumnType:
    """
    Class encapsulating all handling of specific column types

    This implementation supports the PRIMITIVE_SQL_TYPES types, for more specific behavior, make a subclass.
    """

    def __init__(self, base_type: type[PRIMITIVE_SQL_TYPES], nullable: bool, table_prefix: Optional[str] = None) -> None:
        """
        :param base_type: the Python type of the values in this column
        :param nullable: whether the column may contain NULL values
        :param table_prefix: optional table name to qualify the column with in SQL statements
        """
        self.base_type = base_type
        self.nullable = nullable
        self.table_prefix = table_prefix
        self.table_prefix_dot = "" if table_prefix is None else f"{table_prefix}."

    def as_basic_filter_elements(self, name: str, value: object) -> Sequence[tuple[str, "ColumnType", object]]:
        """
        Break down this filter into more elementary filters

        :param name: column name, intended to be passed through get_accessor
        :param value: the value of this column
        :return: a list of (name, type, value) items
        """
        return [(name, self, self.get_value(value))]

    def as_basic_order_elements(self, name: str, order: PagingOrder) -> Sequence[tuple[str, "ColumnType", PagingOrder]]:
        """
        Break down this order into more elementary orders

        :param name: column name, intended to be passed through get_accessor
        :return: a list of (name, type, order) items
        """
        return [(name, self, order)]

    def get_value(self, value: object) -> Optional[PRIMITIVE_SQL_TYPES]:
        """
        Prepare the actual value for use as an argument in a prepared statement for this type

        :raises ValueError: when the value is None for a non-nullable column, or can not be converted to the base type
        """
        if value is None:
            if not self.nullable:
                raise ValueError("None is not a valid value")
            return None
        if isinstance(value, self.base_type):
            # It is as expected
            return value
        if self.base_type == bool:
            ta = pydantic.TypeAdapter(bool)
            return ta.validate_python(value)
        if self.base_type == datetime.datetime and isinstance(value, str):
            # Import the submodule explicitly: a bare `import dateutil` does not guarantee that
            # `dateutil.parser` has been loaded on every python-dateutil version.
            import dateutil.parser

            return api_boundary_datetime_normalizer(dateutil.parser.isoparse(value))
        if self.base_type == UUID and isinstance(value, str):
            # Accept the textual form of a UUID; UUID() raises ValueError for malformed input, like the fallback below.
            return UUID(value)
        if issubclass(self.base_type, (str, int)) and isinstance(value, (str, int, bool)):
            # We can cast between those types
            return self.base_type(value)
        raise ValueError(f"{value} is not a valid value")

    def get_accessor(self, column_name: str, table_prefix: Optional[str] = None) -> str:
        """
        return the sql statement to get this column, as used in filter and other statements
        """
        table_prefix_value = self.table_prefix_dot if table_prefix is None else table_prefix + "."
        return table_prefix_value + column_name

    def coalesce_to_min(self, value_reference: str) -> str:
        """If the order by column is nullable, coalesce the parameter value to the minimum value of the specific type
        This is required for the comparisons used for paging, because comparing a value to
        NULL always yields NULL.
        """
        if not self.nullable:
            return value_reference

        if self.base_type == datetime.datetime:
            return f"COALESCE({value_reference}, to_timestamp(0))"
        if self.base_type == bool:
            return f"COALESCE({value_reference}, FALSE)"
        if self.base_type == int:
            # we only support positive ints up till now
            return f"COALESCE({value_reference}, -1)"
        if self.base_type == str:
            return f"COALESCE({value_reference}, '')"
        if self.base_type == UUID:
            return f"COALESCE({value_reference}, '00000000-0000-0000-0000-000000000000'::uuid)"
        # Raise explicitly instead of `assert False`: asserts are stripped under `python -O`, which would silently
        # return the un-coalesced reference and produce wrong paging comparisons.
        raise AssertionError("Unexpected argument type received, this should not happen")

    def with_prefix(self, table_prefix: Optional[str]) -> "ColumnType":
        """
        Return a copy of this column type that qualifies the column with the given table prefix.

        A shallow copy is made so that subclasses (e.g. ForcedStringColumn) keep their type and extra attributes;
        constructing a plain ColumnType here would drop subclass behavior such as a forced `::type` cast.
        """
        clone = copy.copy(self)
        clone.table_prefix = table_prefix
        clone.table_prefix_dot = "" if table_prefix is None else f"{table_prefix}."
        return clone


def TablePrefixWrapper(table_name: Optional[str], child: ColumnType) -> ColumnType:
    """
    Qualify ``child`` with ``table_name``, or return ``child`` untouched when no table name is given.

    The class-style name is kept because this function replaced a former class of the same name;
    the prefixing itself is delegated to ``ColumnType.with_prefix``.
    """
    return child if table_name is None else child.with_prefix(table_prefix=table_name)


class ForcedStringColumn(ColumnType):
    """A non-nullable string column whose SQL accessor explicitly casts the column to ``forced_type``."""

    def __init__(self, forced_type: str) -> None:
        """
        :param forced_type: the SQL type the column is cast to, rendered as ``::<forced_type>``
        """
        super().__init__(base_type=str, nullable=False)
        self.forced_type = forced_type

    def get_accessor(self, column_name: str, table_prefix: Optional[str] = None) -> str:
        """
        Return the (optionally table-qualified) column reference followed by the cast to the forced type.
        """
        plain_accessor = super().get_accessor(column_name, table_prefix)
        return f"{plain_accessor}::{self.forced_type}"


# Shared column type instances, referenced by the DatabaseOrder definitions below
StringColumn = ColumnType(base_type=str, nullable=False)
OptionalStringColumn = ColumnType(base_type=str, nullable=True)

DateTimeColumn = ColumnType(base_type=datetime.datetime, nullable=False)
OptionalDateTimeColumn = ColumnType(base_type=datetime.datetime, nullable=True)

PositiveIntColumn = ColumnType(base_type=int, nullable=False)
# Negative ints require updating coalesce_to_min, which coalesces NULL ints to -1

# A non-nullable string column that is cast to `text` (`::text`) in SQL accessors
TextColumn = ForcedStringColumn("text")

UUIDColumn = ColumnType(base_type=uuid.UUID, nullable=False)
BoolColumn = ColumnType(base_type=bool, nullable=False)


class DatabaseOrderV2(ABC):
    """
    Helper API for handling database order and filtering

    This class defines the consumer interface.

    It is made into a separate type, to make it very explicit what is exposed externally, to limit feature creep
    """

    @abstractmethod
    def as_filter(
        self,
        offset: int,
        column_value: Optional[PRIMITIVE_SQL_TYPES] = None,
        id_value: Optional[PRIMITIVE_SQL_TYPES] = None,
        start: bool = True,
    ) -> tuple[list[str], list[object]]:
        """
        Produce a filter for this order, to select all records before or after the given boundary

        :param offset: the next free number to use for query parameters
        :param column_value: the boundary value for the user specified order
        :param id_value: the boundary value for the built-in (id) order
        :param start: is this the start filter? if so, retain all values `> (column_value, id_value)`,
            otherwise all values `< (column_value, id_value)`

        :return: The filters (as a list of strings, empty when not paging) and all associated query parameter values

        None values can have a double meaning here:
        - no value provided
        - the value is provided and None

        The distinction can be made as follows:
        1. at least one of the columns must be not nullable (otherwise the sorting is not unique)
        2. when both values are None, we are not paging and return '[],[]'
        3. when one of the values is effective, we produce a filter

        More specifically:
        1. when we have a single order, and `column_value` is not None, this single value is used for filtering
        2. when we have a double order and the 'id_value' is not None and `self.get_order_by_column_type().nullable`,
            we consider the null an effective value and filter on both `column_value` and `id_value`
        3. when we have a double order and the 'id_value' is not None and `not self.get_order_by_column_type().nullable`,
            we consider the null not a value and filter only on `id_value`

        """

    @abstractmethod
    def get_order_by_statement(self, invert: bool = False, table: Optional[str] = None) -> str:
        """Get this order as an ORDER BY statement, optionally inverted and qualified with the given table"""

    @abstractmethod
    def get_order(self) -> PagingOrder:
        """Return the order of this paging request"""

    @abstractmethod
    def get_paging_boundaries(self, first: abc.Mapping[str, object], last: abc.Mapping[str, object]) -> PagingBoundaries:
        """Return the page boundaries, given the first and last record of the page"""


# The concrete SingleDatabaseOrder subclass, so that the parse_from_string factory is typed to return that subclass
T_SELF = TypeVar("T_SELF", bound="SingleDatabaseOrder")


class SingleDatabaseOrder(DatabaseOrderV2, ABC):
    """
    Abstract Base class for ordering when using
    - a user specified order, that is always unique
    """

    def __init__(
        self,
        order_by_column: ColumnNameStr,
        order: PagingOrder,
    ) -> None:
        """The order_by_column and order parameters should be validated"""
        self.order_by_column = order_by_column
        self.order = order

    # Configuration methods
    @classmethod
    # TODO: cache this!
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Return all valid columns for lookup and their type"""
        raise NotImplementedError()

    #  Factory
    @classmethod
    def parse_from_string(
        cls: type[T_SELF],
        sort: str,
    ) -> T_SELF:
        """
        Parse a sort specification of the form `<column>.<asc|desc>` (case-insensitive), e.g. `name.asc`.

        :param sort: the sort specification
        :return: an instance of the class this method is called on
        :raises InvalidSort: when the column is not a valid sort column or the order is neither asc nor desc
        """
        valid_columns = cls.get_valid_sort_columns()
        # Escape the column names so they are always matched literally and never interpreted as regex syntax.
        valid_sort_pattern: Pattern[str] = re.compile(
            f"^({'|'.join(re.escape(column) for column in valid_columns)})\\.(asc|desc)$", re.IGNORECASE
        )
        # fullmatch: unlike match, `$` then can not accept a trailing newline
        match = valid_sort_pattern.fullmatch(sort)
        if match:
            order_by_column = match.group(1).lower()
            # Verify there is no escaping from the regex by exact match. This is an explicit check rather than an
            # assert, so it still applies when assertions are disabled (python -O).
            if order_by_column in valid_columns:
                order = match.group(2).upper()
                return cls(order_by_column=ColumnNameStr(order_by_column), order=PagingOrder[order])
        raise InvalidSort(f"Sort parameter invalid: {sort}")

    # Internal helpers
    def get_order(self, invert: bool = False) -> PagingOrder:
        """The order string representing the direction the results should be sorted by"""
        return self.order.invert() if invert else self.order

    def get_order_by_column_type(self) -> ColumnType:
        """The type of the order by column"""
        return self.get_valid_sort_columns()[self.order_by_column]

    def get_order_by_column_api_name(self) -> str:
        """The name of the column that the results should be ordered by"""
        return self.order_by_column

    # External API
    def as_filter(
        self,
        offset: int,
        column_value: Optional[PRIMITIVE_SQL_TYPES] = None,
        id_value: Optional[PRIMITIVE_SQL_TYPES] = None,
        start: bool = True,
    ) -> tuple[list[str], list[object]]:
        """
        Produce a filter for this order, to select all records before or after the given boundary

        :param offset: the next free number to use for query parameters
        :param column_value: the value for the user specified order
        :param id_value: the value for the built-in order, if this class has one. Otherwise this value is ignored.
        :param start: is this the start filter? if so, retain all values `> column_value`, otherwise `< column_value`

        :return: The filters (as a list of strings, empty when column_value is None) and all associated
            query parameter values
        """
        relation = ">" if start else "<"

        if column_value is None:
            return [], []

        coll_type = self.get_order_by_column_type()
        col_name = self.order_by_column
        value = coll_type.get_value(column_value)

        # Placeholders are numbered starting at `offset`
        ac = ArgumentCollector(offset=offset - 1)
        filter = f"{coll_type.get_accessor(col_name)} {relation} {ac(value)}"
        return [filter], ac.args

    def get_order_elements(self, invert: bool) -> Sequence[tuple[ColumnNameStr, ColumnType, PagingOrder]]:
        """
        return a list of column/column type/order triples, to format an ORDER BY or FILTER statement
        """
        order = self.get_order(invert)
        return [
            (self.order_by_column, self.get_order_by_column_type(), order),
        ]

    def get_order_by_statement(self, invert: bool = False, table: Optional[str] = None) -> str:
        """Return the actual order by statement, as derived from get_order_elements"""
        order_by_part = ", ".join(
            f"{col_type.get_accessor(col, table)} {order.db_form(nullable=col_type.nullable)}"
            for col, col_type, order in self.get_order_elements(invert)
        )
        return f" ORDER BY {order_by_part}"

    def get_paging_boundaries(self, first: abc.Mapping[str, object], last: abc.Mapping[str, object]) -> PagingBoundaries:
        """Return the page boundaries, given the first and last record returned"""
        # For ascending order, `start` is taken from the last record and `end` from the first
        if self.get_order() == PagingOrder.ASC:
            first, last = last, first

        order_column_name = self.order_by_column
        order_type: ColumnType = self.get_order_by_column_type()

        def assert_not_null(in_value: Optional[PRIMITIVE_SQL_TYPES]) -> PRIMITIVE_SQL_TYPES:
            # Make mypy happy
            assert in_value is not None
            return in_value

        return PagingBoundaries(
            start=assert_not_null(order_type.get_value(first[order_column_name])),
            first_id=None,
            end=assert_not_null(order_type.get_value(last[order_column_name])),
            last_id=None,
        )

    def __str__(self) -> str:
        # used to serialize the order back to a  paging url
        return f"{self.order_by_column}.{self.order.value.lower()}"


class AbstractDatabaseOrderV2(SingleDatabaseOrder, ABC):
    """
    Abstract Base class for ordering when using
    - a user specified order
    - an additional built in order to make the ordering unique (the id_column)

    Rows are ordered on (order_by_column, id_column), both in the same direction.
    """

    @property
    @abstractmethod
    def id_column(self) -> tuple[ColumnNameStr, ColumnType]:
        """Name and type of the id column of this database order"""

    # External API
    def as_filter(
        self,
        offset: int,
        column_value: Optional[PRIMITIVE_SQL_TYPES] = None,
        id_value: Optional[PRIMITIVE_SQL_TYPES] = None,
        start: bool = True,
    ) -> tuple[list[str], list[object]]:
        """
        Produce a filter for this order, to select all records before or after the given boundary

        :param offset: the next free number to use for query parameters
        :param column_value: the value for the user specified order
        :param id_value: the value for the built-in (id) order
        :param start: is this the start filter? if so, retain all values`> (column_value, id_value)`,
            otherwise `< (column_value, id_value)`.

        :return: The filters (as a list of strings, empty when there is nothing to filter on) and all associated
            query parameter values
        """

        # All the filter elements:
        # 1. name of the actual column in the DB
        # 2. type of the column
        # 3. sanitized value of the column

        filter_elements: list[tuple[str, ColumnType, object]] = []

        order_by_collumns_type = self.get_order_by_column_type()
        # On a nullable column, a None column_value accompanied by an id is a real boundary (NULL), not "no value"
        paging_on_nullable = order_by_collumns_type.nullable and id_value is not None

        if column_value is not None or paging_on_nullable:
            # Have column value or paging on nullable
            filter_elements.extend(order_by_collumns_type.as_basic_filter_elements(self.order_by_column, column_value))

        if id_value is not None:
            # Have ID
            id_name, id_type = self.id_column
            # Don't filter twice on the same column when the user sorts on the id column itself
            if id_name != self.order_by_column:
                filter_elements.extend(id_type.as_basic_filter_elements(id_name, id_value))

        relation = ">" if start else "<"

        if len(filter_elements) == 0:
            return [], []

        # Placeholders are numbered starting at `offset`
        ac = ArgumentCollector(offset=offset - 1)
        if len(filter_elements) == 1:
            col_name, coll_type, value = filter_elements[0]
            filter = f"{coll_type.get_accessor(col_name)} {relation} {ac(value)}"
            return [filter], ac.args
        else:
            # composed filter:
            # 1. comparison of two tuples (c_a, c_b) < (c_a, c_b)
            # 2. nulls must be removed to get proper comparison (coalesce_to_min maps NULL to the lowest value,
            #    consistent with the NULLS FIRST / NULLS LAST ordering of PagingOrder.db_form)
            names_tuple = ", ".join(
                [coll_type.coalesce_to_min(coll_type.get_accessor(col_name)) for col_name, coll_type, value in filter_elements]
            )
            values_references_tuple = ", ".join(
                [coll_type.coalesce_to_min(ac(value)) for col_name, coll_type, value in filter_elements]
            )
            filter = f"({names_tuple}) {relation} ({values_references_tuple})"
            return [filter], ac.args

    def get_order_elements(self, invert: bool) -> list[tuple[ColumnNameStr, ColumnType, PagingOrder]]:
        """
        return a list of column/column type/order triples, to format an ORDER BY or FILTER statement

        The user specified column comes first, the id column second, both sorted in the same direction.
        """
        order = self.get_order(invert)
        id_name, id_type = self.id_column

        return list(
            self.get_order_by_column_type().as_basic_order_elements(self.order_by_column, order)
        ) + id_type.as_basic_order_elements(id_name, order)

    def get_paging_boundaries(self, first: abc.Mapping[str, object], last: abc.Mapping[str, object]) -> PagingBoundaries:
        """Return the page boundaries (column value and id of both ends), given the first and last record returned"""
        # For ascending order, `start`/`first_id` are taken from the last record and `end`/`last_id` from the first
        if self.get_order() == PagingOrder.ASC:
            first, last = last, first

        order_column_name = self.order_by_column
        order_type: ColumnType = self.get_order_by_column_type()

        id_column, id_type = self.id_column

        return PagingBoundaries(
            start=order_type.get_value(first[order_column_name]),
            first_id=id_type.get_value(first[id_column]),
            end=order_type.get_value(last[order_column_name]),
            last_id=id_type.get_value(last[id_column]),
        )


class VersionedResourceOrder(AbstractDatabaseOrderV2):
    """Sort order for resources: a user chosen column, with resource_id as tie-breaker."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Map each sortable column name to its type; all of them are non-nullable strings."""
        sortable = ("resource_type", "agent", "resource_id_value")
        return {ColumnNameStr(column): StringColumn for column in sortable}

    @property
    def id_column(self) -> tuple[ColumnNameStr, ColumnType]:
        """The resource_id column, which makes the ordering unique."""
        return (ColumnNameStr("resource_id"), StringColumn)


class ResourceStatusOrder(VersionedResourceOrder):
    """
    Sort order for resources that also carry a status field
    """

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """The sortable columns of VersionedResourceOrder, extended with resource_id and status."""
        columns = dict(super().get_valid_sort_columns())
        columns[ColumnNameStr("resource_id")] = StringColumn
        # status is compared with an explicit `::text` cast (TextColumn)
        columns[ColumnNameStr("status")] = TextColumn
        return columns


class ResourceHistoryOrder(AbstractDatabaseOrderV2):
    """Sort order for resource history: by date, with attribute_hash as tie-breaker."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Only the (non-nullable) date column can be sorted on."""
        date_column = ColumnNameStr("date")
        return {date_column: DateTimeColumn}

    @property
    def id_column(self) -> tuple[ColumnNameStr, ColumnType]:
        """The attribute_hash column, which makes the ordering unique."""
        return ColumnNameStr("attribute_hash"), StringColumn


class ResourceLogOrder(SingleDatabaseOrder):
    """Sort order for resource logs, on their timestamp only."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Only the (non-nullable) timestamp column can be sorted on."""
        timestamp_column = ColumnNameStr("timestamp")
        return {timestamp_column: DateTimeColumn}


class CompileReportOrder(AbstractDatabaseOrderV2):
    """Sort order for compile reports: by request time, with id as tie-breaker."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Only the (non-nullable) requested column can be sorted on."""
        requested_column = ColumnNameStr("requested")
        return {requested_column: DateTimeColumn}

    @property
    def id_column(self) -> tuple[ColumnNameStr, ColumnType]:
        """The UUID id column, which makes the ordering unique."""
        return ColumnNameStr("id"), UUIDColumn


class AgentOrder(AbstractDatabaseOrderV2):
    """Sort order for agents: a user chosen column, with the agent name as tie-breaker."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Map each sortable column name to its type."""
        # The agent name is qualified with the `a` table prefix; the other columns are used unqualified
        agent_name = TablePrefixWrapper("a", StringColumn)
        return {
            ColumnNameStr("name"): agent_name,
            ColumnNameStr("process_name"): OptionalStringColumn,
            ColumnNameStr("paused"): BoolColumn,
            ColumnNameStr("last_failover"): OptionalDateTimeColumn,
            ColumnNameStr("status"): StringColumn,
        }

    @property
    def id_column(self) -> tuple[ColumnNameStr, ColumnType]:
        """The agent name column (with the `a` table prefix), which makes the ordering unique."""
        name_column = TablePrefixWrapper("a", StringColumn)
        return ColumnNameStr("name"), name_column


class DesiredStateVersionOrder(SingleDatabaseOrder):
    """Sort order for desired state versions, on the version number only."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Only the (non-nullable, positive integer) version column can be sorted on."""
        version_column = ColumnNameStr("version")
        return {version_column: PositiveIntColumn}


class ParameterOrder(AbstractDatabaseOrderV2):
    """Sort order for parameters: a user chosen column, with id as tie-breaker."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Map each sortable column name to its type; only `updated` may be NULL."""
        columns: dict[ColumnNameStr, ColumnType] = {}
        columns[ColumnNameStr("name")] = StringColumn
        columns[ColumnNameStr("source")] = StringColumn
        columns[ColumnNameStr("updated")] = OptionalDateTimeColumn
        return columns

    @property
    def id_column(self) -> tuple[ColumnNameStr, ColumnType]:
        """The UUID id column, which makes the ordering unique."""
        return ColumnNameStr("id"), UUIDColumn


class FactOrder(AbstractDatabaseOrderV2):
    """Sort order for facts: a user chosen column, with id as tie-breaker."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Map each sortable column name to its type; both are non-nullable strings."""
        return {ColumnNameStr(column): StringColumn for column in ("name", "resource_id")}

    @property
    def id_column(self) -> tuple[ColumnNameStr, ColumnType]:
        """The UUID id column, which makes the ordering unique."""
        return ColumnNameStr("id"), UUIDColumn


class NotificationOrder(AbstractDatabaseOrderV2):
    """Sort order for notifications: by creation time, with id as tie-breaker."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Only the (non-nullable) created column can be sorted on."""
        created_column = ColumnNameStr("created")
        return {created_column: DateTimeColumn}

    @property
    def id_column(self) -> tuple[ColumnNameStr, ColumnType]:
        """The UUID id column, which makes the ordering unique."""
        return ColumnNameStr("id"), UUIDColumn


class DiscoveredResourceOrder(SingleDatabaseOrder):
    """Sort order for discovered resources; the discovered resource id is the only sortable column."""

    @classmethod
    def get_valid_sort_columns(cls) -> dict[ColumnNameStr, ColumnType]:
        """Describes the names and types of the columns that are valid for this DatabaseOrder"""
        sortable: dict[ColumnNameStr, ColumnType] = {}
        sortable[ColumnNameStr("discovered_resource_id")] = StringColumn
        return sortable


class BaseQueryBuilder(ABC):
    """Abstract builder that assembles an SQL query from separate parts.

    The parts are a SELECT clause, a FROM clause, a list of filter conditions and the positional
    parameter values those conditions refer to. Each builder method returns a new query builder
    instance with the additional part processed.
    """

    def __init__(
        self,
        select_clause: Optional[str] = None,
        from_clause: Optional[str] = None,
        filter_statements: Optional[list[str]] = None,
        values: Optional[list[object]] = None,
    ) -> None:
        """
        Every part is optional here; parts can also be supplied later through the builder methods.

        :param select_clause: The SELECT clause of the query
        :param from_clause: The FROM clause of the query
        :param filter_statements: Filter conditions for the query (empty list when not given)
        :param values: The parameter values used by the filter statements (empty list when not given)
        """
        self.select_clause = select_clause
        self._from_clause = from_clause
        self.filter_statements = filter_statements if filter_statements else []
        self.values = values if values else []

    def _join_filter_statements(self, filter_statements: list[str]) -> str:
        """Combine the filter conditions into a single WHERE clause, or return '' when there are none"""
        if not filter_statements:
            return ""
        return f"WHERE {' AND '.join(filter_statements)}"

    @abstractmethod
    def from_clause(self, from_clause: str) -> "BaseQueryBuilder":
        """Set the from clause of the query"""
        raise NotImplementedError()

    @property
    def offset(self) -> int:
        """The 1-based index of the next positional parameter ($n) a new filter statement should use"""
        return 1 + len(self.values)

    @abstractmethod
    def filter(self, filter_statements: list[str], values: list[object]) -> "BaseQueryBuilder":
        """Add filters to the query"""
        raise NotImplementedError()

    @abstractmethod
    def build(self) -> tuple[str, list[object]]:
        """Builds up the full query string, and the parametrized value list, ready to be executed"""
        raise NotImplementedError()


class SimpleQueryBuilder(BaseQueryBuilder):
    """A query builder suitable for most queries.

    Besides the select, from and filter parts, it supports ordering and a limit driven by a
    DatabaseOrderV2 (including backward paging), and an optional prelude such as a WITH binding.
    Every builder method returns a new SimpleQueryBuilder.
    """

    def __init__(
        self,
        select_clause: Optional[str] = None,
        from_clause: Optional[str] = None,
        filter_statements: Optional[list[str]] = None,
        values: Optional[list[object]] = None,
        db_order: Optional[DatabaseOrderV2] = None,
        limit: Optional[int] = None,
        backward_paging: bool = False,
        prelude: Optional[str] = None,
    ) -> None:
        """
        :param select_clause: The SELECT clause of the query
        :param from_clause: The FROM clause of the query
        :param filter_statements: Filter conditions, combined with AND into the WHERE clause
        :param values: The positional parameter values referenced by the filter statements
        :param db_order: The DatabaseOrder describing how the results should be ordered
        :param limit: Limit the results to this amount; a value <= 0 adds no LIMIT
        :param backward_paging: Whether the ordering of the results should be inverted,
                                used when going backward through the pages
        :param prelude: Text placed in front of the rest of the query, for use with a 'with' binding
        """
        super().__init__(select_clause, from_clause, filter_statements, values)
        self.db_order = db_order
        self.limit = limit
        self.backward_paging = backward_paging
        self.prelude = prelude

    def select(self, select_clause: str) -> "SimpleQueryBuilder":
        """Return a new builder that uses the given SELECT clause"""
        return SimpleQueryBuilder(
            select_clause=select_clause,
            from_clause=self._from_clause,
            filter_statements=self.filter_statements,
            values=self.values,
            db_order=self.db_order,
            limit=self.limit,
            backward_paging=self.backward_paging,
            prelude=self.prelude,
        )

    def from_clause(self, from_clause: str) -> "SimpleQueryBuilder":
        """Return a new builder that uses the given FROM clause"""
        return SimpleQueryBuilder(
            select_clause=self.select_clause,
            from_clause=from_clause,
            filter_statements=self.filter_statements,
            values=self.values,
            db_order=self.db_order,
            limit=self.limit,
            backward_paging=self.backward_paging,
            prelude=self.prelude,
        )

    def order_and_limit(
        self, db_order: DatabaseOrderV2, limit: Optional[int] = None, backward_paging: bool = False
    ) -> "SimpleQueryBuilder":
        """Return a new builder with the given order, limit and paging direction (all three are replaced)"""
        return SimpleQueryBuilder(
            select_clause=self.select_clause,
            from_clause=self._from_clause,
            filter_statements=self.filter_statements,
            values=self.values,
            db_order=db_order,
            limit=limit,
            backward_paging=backward_paging,
            prelude=self.prelude,
        )

    def filter(self, filter_statements: list[str], values: list[object]) -> "SimpleQueryBuilder":
        """Return a new builder with the given filter statements and their values appended to the existing ones"""
        return SimpleQueryBuilder(
            select_clause=self.select_clause,
            from_clause=self._from_clause,
            filter_statements=self.filter_statements + filter_statements,
            values=self.values + values,
            db_order=self.db_order,
            limit=self.limit,
            backward_paging=self.backward_paging,
            prelude=self.prelude,
        )

    def build(self) -> tuple[str, list[object]]:
        """Assemble the full query string together with its positional parameter values.

        :raises InvalidQueryParameter: When the SELECT or FROM clause is missing, or when the limit exceeds DBLIMIT.
        """
        if not (self.select_clause and self._from_clause):
            raise InvalidQueryParameter("A valid query must have a SELECT and a FROM clause")
        query = f"""{self.select_clause}
                         {self._from_clause}
                         {self._join_filter_statements(self.filter_statements)}
                         """
        if self.prelude:
            query = self.prelude + query
        if self.db_order:
            query += self.db_order.get_order_by_statement(self.backward_paging)
        if self.limit is not None:
            if self.limit > DBLIMIT:
                raise InvalidQueryParameter(f"Limit cannot be bigger than {DBLIMIT}, got {self.limit}")
            if self.limit > 0:
                query += " LIMIT " + str(self.limit)
        if self.db_order and self.backward_paging:
            # The inner query is ordered with backward_paging applied; wrap it so the outer
            # query re-applies the order without that flag.
            order_by = self.db_order.get_order_by_statement(table="matching_records")
            query = f"""SELECT * FROM ({query}) AS matching_records {order_by}"""

        return query, self.values


def json_encode(value: object) -> str:
    """Serialize a value to a JSON string.

    Objects that the json module cannot serialize natively are handed to util.internal_json_encoder.
    Compare json_encode in tornado.escape.
    """
    fallback_encoder = util.internal_json_encoder
    return json.dumps(value, default=fallback_encoder)


T = TypeVar("T")


class Field(Generic[T]):
    """Metadata describing how the data layer handles one field of a database document."""

    def __init__(
        self,
        field_type: type[T],
        required: bool = False,
        is_many: bool = False,
        part_of_primary_key: bool = False,
        ignore: bool = False,
        default: object = default_unset,
        **kwargs: object,
    ) -> None:
        """A field in a document/record in the database. This class holds the metadata on how the data layer should handle
        the field.

        :param field_type: The python type of the field. This type should work with isinstance
        :param required: Is this value required. This means that it is not optional and it cannot be None
        :param is_many: Set to true when this is a list type
        :param part_of_primary_key: Set to true when the field is part of the primary key.
        :param ignore: Should this field be ignored when saving it to the database. This can be used to add a field to
                       a class that should not be saved in the database.
        :param default: The default value for this field. When omitted, the field has no default.
        :param kwargs: Accepted but not used.
        """

        self._field_type = field_type
        self._required = required
        self._ignore = ignore
        self._part_of_primary_key = part_of_primary_key
        self._is_many = is_many

        self._default_value: object
        # default_unset marks "no default provided". Compare by identity so that an arbitrary default
        # value's own __eq__/__ne__ is never invoked.
        if default is not default_unset:
            self._default = True
            self._default_value = default
        else:
            self._default = False
            self._default_value = None

    def get_field_type(self) -> type[T]:
        """The python type of (each element of) this field"""
        return self._field_type

    field_type = property(get_field_type)

    def is_required(self) -> bool:
        """Whether a value (other than None) must be provided"""
        return self._required

    required = property(is_required)

    def get_default(self) -> bool:
        """Whether a default value was configured for this field"""
        return self._default

    default = property(get_default)

    def get_default_value(self) -> T:
        """A shallow copy of the default value (None when no default was configured)"""
        return copy.copy(self._default_value)

    default_value = property(get_default_value)

    @property
    def ignore(self) -> bool:
        """Whether this field is ignored when saving to the database"""
        return self._ignore

    def is_part_of_primary_key(self) -> bool:
        """Whether this field is part of the primary key"""
        return self._part_of_primary_key

    part_of_primary_key = property(is_part_of_primary_key)

    @property
    def is_many(self) -> bool:
        """Whether this field holds a list of values of field_type"""
        return self._is_many

    def _validate_single(self, name: str, value: object) -> None:
        """Validate a single value against the types in this field.

        :raises TypeError: When the value is not an instance of field_type.
        """
        if not isinstance(value, self.field_type):
            raise TypeError(
                "Field %s should have the correct type (%s instead of %s)"
                % (name, self.field_type.__name__, type(value).__name__)
            )

    def validate(self, name: str, value: T) -> None:
        """Validate the value against the constraints of this field.

        None is accepted unless the field is required. When is_many is true, the value must be a list and
        each element is validated against the field type.

        :param name: The name of the field, used in error messages.
        :param value: The value to validate.
        :raises TypeError: When a required value is None, when a list is expected but something else is given,
                           or when a value does not have the type of this field.
        """
        if value is None:
            if self.required:
                raise TypeError("%s field is required" % name)
            return None

        if self.is_many:
            if not isinstance(value, list):
                # The error has to be raised: only constructing it would silently accept the non-list value.
                raise TypeError(f"Field {name} should be a list, but got {type(value).__name__}")
            for item in value:
                self._validate_single(name, item)
        else:
            self._validate_single(name, value)
        return None

    def from_db(self, name: str, value: object) -> object:
        """Load values from database. Treat value as a list when is_many is true. Converts database
        representation to appropriately typed object.

        :param name: The name of the field, used in error messages.
        :param value: The raw value as returned by the database driver.
        :return: The converted value (a list of converted values when is_many is true), or None.
        :raises TypeError: When a required value is None, when a list is expected but something else is given,
                           or when a value cannot be converted to the type of this field.
        """
        if value is None:
            if self.required:
                raise TypeError("%s field is required" % name)
            return None

        if self.is_many:
            if not isinstance(value, list):
                raise TypeError(f"Field {name} should be a list, but got {type(value).__name__}")
            return [self._from_db_single(name, item) for item in value]
        return self._from_db_single(name, value)

    def _from_db_single(self, name: str, value: object) -> object:
        """Load a single database value. Converts database representation to appropriately typed object.

        :raises TypeError: When the value has a type for which no conversion is known.
        """
        if isinstance(value, self.field_type):
            return value

        # asyncpg does not convert a jsonb field to a dict
        if isinstance(value, str) and self.field_type is dict:
            return json.loads(value)
        # asyncpg does not convert an enum field to an enum type
        if isinstance(value, str) and issubclass(self.field_type, enum.Enum):
            return self.field_type[value]
        # decode typed json
        if isinstance(value, str) and issubclass(self.field_type, pydantic.BaseModel):
            jsv = json.loads(value)
            return self.field_type(**jsv)
        if self.field_type == pydantic.AnyHttpUrl:
            return pydantic.TypeAdapter(pydantic.AnyHttpUrl).validate_python(value)

        raise TypeError(
            f"Field {name} should have the correct type ({self.field_type.__name__} instead of {type(value).__name__})"
        )


class DataDocument:
    """
    Base class for inmanta data objects that are not stored in the database; use BaseDocument for
    persisted data. Its main purpose is to group dict-creation logic: to_dict returns a dict that the
    inmanta rpc layer can serialize. DataDocument children can be stored in BaseDocument fields, where
    they are serialized to dict, but they are not turned back into DataDocument objects on retrieval.
    """

    def __init__(self, **kwargs: object) -> None:
        """Keep the given keyword arguments as the data of this document."""
        self._data = kwargs

    def to_dict(self) -> JsonType:
        """
        Return a dict representation of this object. This is the internal dict itself, not a copy.
        """
        data = self._data
        return data


class InvalidAttribute(Exception):
    """Raised when an attribute definition on a document cannot be turned into a supported field."""

    def __init__(self, message: str) -> None:
        """
        :param message: Description of the problem; also available as the ``message`` attribute.
        """
        self.message = message
        super().__init__(message)


class DocumentMeta(type):
    """Metaclass for database documents.

    Every class created with it gets its own, initially empty ``_fields_metadata`` dict. For every
    class except the one named "BaseDocument", ``load_fields`` is then called on the new class.
    """

    def __new__(cls, class_name: str, bases: tuple[type, ...], dct: dict[str, object]) -> type:
        # A fresh dict per class, so field metadata loaded for one class is not shared with its parent.
        dct["_fields_metadata"] = {}
        new_type: type[BaseDocument] = type.__new__(cls, class_name, bases, dct)
        if class_name == "BaseDocument":
            # The base class itself is not loaded.
            return new_type
        new_type.load_fields()
        return new_type


# Bound to BaseDocument so that classmethods can be typed to return the concrete subclass.
TBaseDocument = TypeVar("TBaseDocument", bound="BaseDocument")  # Part of the stable API
TransactionResult = TypeVar("TransactionResult")


[docs] @stable_api class BaseDocument(metaclass=DocumentMeta): """ A base document in the database. Subclasses of this document determine collections names. This type is mainly used to bundle query methods and generate validate and query methods for optimized DB access. This is not a full ODM. Fields are modelled using type annotations similar to protocol and pydantic. The following is supported: - Attributes are defined at class level with type annotations - Attributes do not need a default value. When no default is provided, they are marked as required. - When a value does not have to be set: either a default value or making it optional can be used. When a field is optional without a default value, none will be set as default value so that the field is available. - Fields that should be ignored, can be added to __ignore_fields__ This attribute is a tuple of strings - Fields that are part of the primary key should be added to the __primary_key__ attributes. This attribute is a tuple of strings. """ _connection_pool: Optional[asyncpg.pool.Pool] = None _fields_metadata: dict[str, Field] __primary_key__: tuple[str, ...] __ignore_fields__: tuple[str, ...] def __init__(self, from_postgres: bool = False, **kwargs: object) -> None: """ :param kwargs: The values to create the document. When id is defined in the fields but not provided, a new UUID is generated. """ self.__process_kwargs(from_postgres, kwargs) @classmethod def get_connection( cls, connection: Optional[asyncpg.connection.Connection] = None ) -> AbstractAsyncContextManager[asyncpg.connection.Connection]: """ Returns a context manager to acquire a connection. If an existing connection is passed, returns a dummy context manager wrapped around that connection instance. This allows for transparent usage, regardless of whether a connection has already been acquired. 
""" if connection is not None: return util.nullcontext(connection) # Make mypy happy assert cls._connection_pool is not None return cls._connection_pool.acquire() @classmethod def table_name(cls) -> str: """ Return the name of the collection """ return cls.__name__.lower() @classmethod def get_field_metadata(cls) -> dict[str, Field]: return cls._fields_metadata.copy() @staticmethod def _annotation_to_field( attribute: str, annotation: type[object], has_value: bool = True, value: Optional[object] = None, part_of_primary_key: bool = False, ignore_field: bool = False, ) -> Field: """Convert an annotated definition to a Field instance. The conversion rules are the following: - The value assigned to the field is the default value - When the default value is None the type has to be Optional - When the field is not optional, None is not a valid value - When the field has no default value, it is not required """ field_type: type[object] = annotation required: bool = not has_value default: object = default_unset is_many: bool = False # Only union with None (optional) is support if typing_inspect.is_union_type(annotation) and not typing_inspect.is_optional_type(annotation): raise InvalidAttribute(f"A union that is not an optional in field {attribute} is not supported.") if typing_inspect.is_optional_type(annotation): # The value optional. When no default is set, it will be None. required = False default = None # Filter out the None from the union type_args = typing_inspect.get_args(annotation, evaluate=True) if len(type_args) != 2: raise InvalidAttribute(f"Only optionals with one type are supported, field {attribute} has more.") field_type = [typ for typ in type_args if typ][0] if has_value: # A default value is available, so not required. 
When optional type, override the default None required = False default = value if typing_inspect.is_generic_type(field_type): orig = typing_inspect.get_origin(field_type) # First two are for python3.6, the last two for 3.7 and up if orig in [list, typing.Sequence, list, abc.Sequence]: is_many = True type_args = typing_inspect.get_args(field_type) if len(type_args) == 0 or isinstance(type_args[0], typing.TypeVar): # In python3.8 type_args is not empty when you write List but it will contain an instance of TypeVar raise InvalidAttribute(f"Generic type of field {attribute} requires a type argument.") field_type = type_args[0] # List of Dict for example still cannot be validated. If the type is still a generic. Set the type to List of # object. if typing_inspect.is_generic_type(field_type): field_type = object elif orig in [typing.Mapping, dict, abc.Mapping, dict]: field_type = dict if typing_inspect.is_new_type(field_type): # Python 3.10 and later NewType is a real type and an isinstance will work. On older version NewType is a function. # If this is the case we need to get the real supertype if callable(field_type): field_type = field_type.__supertype__ return Field( field_type=field_type, required=required, default=default, is_many=is_many, part_of_primary_key=part_of_primary_key, ignore=ignore_field, ) @classmethod def load_fields(cls) -> None: """Load the field metadata from the class definition. This method supports two different mechanisms: 1. Using the field class as the value of the attribute. 2. Using type annotations on the attributes """ primary_key: tuple[str, ...] = tuple() ignore: tuple[str, ...] 
= tuple() if "__primary_key__" in cls.__dict__: primary_key = cls.__primary_key__ if "__ignore_fields__" in cls.__dict__: ignore = cls.__ignore_fields__ for attribute, value in cls.__dict__.items(): if attribute.startswith("_"): continue elif isinstance(value, Field): warnings.warn(f"Field {attribute} should be defined using annotations instead of Field.") cls._fields_metadata[attribute] = value elif cls.__annotations__ and attribute in cls.__annotations__: annotation = cls.__annotations__[attribute] cls._fields_metadata[attribute] = cls._annotation_to_field( attribute, annotation, has_value=True, value=value, part_of_primary_key=attribute in primary_key, ignore_field=attribute in ignore, ) # attributes that do not have a default value will only be present in __annotations__ and not in __dict__ for attribute, annotation in cls.__annotations__.items(): if not attribute.startswith("_") and attribute not in cls._fields_metadata: cls._fields_metadata[attribute] = cls._annotation_to_field( attribute, annotation, has_value=False, part_of_primary_key=attribute in primary_key, ignore_field=attribute in ignore, ) @classmethod def get_field_names(cls) -> typing.KeysView[str]: """Returns all field names in the document""" return cls.get_field_metadata().keys() def __process_kwargs(self, from_postgres: bool, kwargs: dict[str, object]) -> None: """This helper method process the kwargs provided to the constructor and populates the fields of the object.""" fields = self.get_field_metadata() if "id" in fields and "id" not in kwargs: kwargs["id"] = uuid.uuid4() for name, value in kwargs.items(): if name not in fields: raise AttributeError(f"{name} field is not defined for this document {type(self).__name__.lower()}") field = fields[name] if not from_postgres: field.validate(name, value) elif not field.ignore: value = field.from_db(name, value) else: value = None setattr(self, name, value) del fields[name] required_fields = [] for name, field in fields.items(): # when a default 
value is used, make sure it is copied if field.default: setattr(self, name, copy.deepcopy(field.default_value)) # update the list of required fields elif fields[name].required: required_fields.append(name) if len(required_fields) > 0: raise AttributeError("The fields %s are required and no value was provided." % ", ".join(required_fields)) @classmethod def get_valid_field_names(cls) -> list[str]: return list(cls.get_field_names()) @classmethod def _get_names_of_primary_key_fields(cls) -> list[str]: return [name for name, value in cls.get_field_metadata().items() if value.is_part_of_primary_key()] def _get_filter_on_primary_key_fields(self, offset: int = 1) -> tuple[str, list[object]]: names_primary_key_fields = self._get_names_of_primary_key_fields() query = {field_name: self.__getattribute__(field_name) for field_name in names_primary_key_fields} return self._get_composed_filter(offset=offset, **query) @classmethod def _new_id(cls) -> uuid.UUID: """ Generate a new ID. Override to use something else than uuid4 """ return uuid.uuid4() @classmethod def set_connection_pool(cls, pool: asyncpg.pool.Pool) -> None: if cls._connection_pool: raise Exception(f"Connection already set on {cls} ({cls._connection_pool}!") cls._connection_pool = pool @classmethod async def close_connection_pool(cls) -> None: if not cls._connection_pool: return try: await asyncio.wait_for(cls._connection_pool.close(), config.db_connection_timeout.get()) except asyncio.TimeoutError: cls._connection_pool.terminate() # Don't propagate this exception but just write a log message. 
This way: # * A timeout here still makes sure that the other server slices get stopped # * The tests don't fail when this timeout occurs LOGGER.exception("A timeout occurred while closing the connection pool to the database") except asyncio.CancelledError: cls._connection_pool.terminate() # Propagate cancel raise except Exception: LOGGER.exception("An unexpected exception occurred while closing the connection pool to the database") raise finally: cls._connection_pool = None def __setattr__(self, name: str, value: object) -> None: if name[0] == "_": return object.__setattr__(self, name, value) fields = self.get_field_metadata() if name in fields: field = fields[name] # validate field.validate(name, value) object.__setattr__(self, name, value) return raise AttributeError(name) @classmethod def _convert_field_names_to_db_column_names(cls, field_dict: dict[str, object]) -> dict[str, object]: return field_dict def get_value(self, name: str, default_value: Optional[object] = None) -> object: """Check if a value is set for a field. Fields that are declared but that do not have a value are only present in annotations but not as attribute (in __dict__)""" if hasattr(self, name): return getattr(self, name) return default_value def _get_column_names_and_values(self) -> tuple[list[str], list[object]]: column_names: list[str] = [] values: list[object] = [] for name, metadata in self.get_field_metadata().items(): if metadata.ignore: continue value = self.get_value(name) if metadata.required and value is None: raise TypeError(f"{self.__name__} should have field '{name}'") metadata.validate(name, value) column_names.append(name) values.append(self._get_value(value)) return column_names, values async def insert(self, connection: Optional[asyncpg.connection.Connection] = None) -> None: """ Insert a new document based on the instance passed. Validation is done based on the defined fields. 
""" (column_names, values) = self._get_column_names_and_values() column_names_as_sql_string = ",".join(column_names) values_as_parameterized_sql_string = ",".join(["$" + str(i) for i in range(1, len(values) + 1)]) query = ( f"INSERT INTO {self.table_name()} " f"({column_names_as_sql_string}) " f"VALUES ({values_as_parameterized_sql_string})" ) await self._execute_query(query, *values, connection=connection) async def insert_with_overwrite(self, connection: Optional[asyncpg.connection.Connection] = None) -> None: """ Insert a new document based on the instance passed. If the document already exists, overwrite it. """ return await self.insert_many_with_overwrite([self], connection=connection) @classmethod async def _fetchval(cls, query: str, *values: object, connection: Optional[asyncpg.connection.Connection] = None) -> object: async with cls.get_connection(connection) as con: return await con.fetchval(query, *values) @classmethod async def _fetch_int(cls, query: str, *values: object, connection: Optional[asyncpg.connection.Connection] = None) -> int: """Fetch a single integer value""" value = await cls._fetchval(query, *values, connection=connection) assert isinstance(value, int) return value @classmethod async def _fetchrow( cls, query: str, *values: object, connection: Optional[asyncpg.connection.Connection] = None ) -> Optional[Record]: async with cls.get_connection(connection) as con: return await con.fetchrow(query, *values) @classmethod async def _fetch_query( cls, query: str, *values: object, connection: Optional[asyncpg.connection.Connection] = None ) -> Sequence[Record]: async with cls.get_connection(connection) as con: return await con.fetch(query, *values) @classmethod async def _execute_query( cls, query: str, *values: object, connection: Optional[asyncpg.connection.Connection] = None ) -> str: async with cls.get_connection(connection) as con: return await con.execute(query, *values) @classmethod async def lock_table(cls, mode: TableLockMode, connection: 
asyncpg.connection.Connection) -> None: """ Acquire a table-level lock on a single environment. Callers should adhere to a consistent locking order accross transactions as described at the top of this module. Passing a connection object is mandatory. The connection is expected to be in a transaction. """ await cls._execute_query(f"LOCK TABLE {cls.table_name()} IN {mode.value} MODE", connection=connection) async def _xact_lock( self, lock_key: int, instance_key: uuid.UUID, *, shared: bool = False, connection: asyncpg.Connection ) -> None: """ Acquires a transaction-level advisory lock for concurrency control :param lock_key: the key identifying this lock (32 bit signed int) :param instance_key: the key identifying the instance to lock. We only use the lower 32 bits, so it can collide. :param shared: If true, doesn't conflict with other shared locks, only with non-shared ones. :param connection: The connection hosting the transaction for which to acquire a lock. """ lock: str = "pg_advisory_xact_lock_shared" if shared else "pg_advisory_xact_lock" await connection.execute( # Advisory lock keys are only 32 bit (or a single 64 bit key), while a full uuid is 128 bit. # Since locking slightly too strictly at extremely low odds is acceptable, we only use a 32 bit subvalue # of the uuid. For uuid4, time_low is (despite the name) randomly generated. Since it is an unsigned # integer while Postgres expects a signed one, we shift it by 2**31. 
f"SELECT {lock}($1, $2)", lock_key, instance_key.time_low - 2**31, ) @classmethod async def insert_many( cls, documents: Sequence["BaseDocument"], *, connection: Optional[asyncpg.connection.Connection] = None ) -> None: """ Insert multiple objects at once """ if not documents: return columns = cls.get_field_names() records: list[tuple[object, ...]] = [] for doc in documents: current_record = [] for col in columns: current_record.append(cls._get_value(doc.__getattribute__(col))) records.append(tuple(current_record)) async with cls.get_connection(connection) as con: await con.copy_records_to_table(table_name=cls.table_name(), columns=columns, records=records, schema_name="public") @classmethod async def insert_many_with_overwrite( cls, documents: Sequence["BaseDocument"], *, connection: Optional[asyncpg.connection.Connection] = None ) -> None: """ Insert new documents. If the document already exists, overwrite it. """ if not documents: return column_names = cls.get_field_names() primary_key_fields = cls._get_names_of_primary_key_fields() primary_key_string = ",".join(primary_key_fields) update_set = set(column_names) - set(cls._get_names_of_primary_key_fields()) update_set_string = ",\n".join([f"{item} = EXCLUDED.{item}" for item in update_set]) values: list[list[object]] = [document._get_column_names_and_values()[1] for document in documents] column_names_as_sql_string = ", ".join(column_names) number_of_columns = len(values[0]) placeholders = ", ".join( [ "(" + ", ".join([f"${doc * number_of_columns + col}" for col in range(1, number_of_columns + 1)]) + ")" for doc in range(len(values)) ] ) query = f"""INSERT INTO {cls.table_name()} ({column_names_as_sql_string}) VALUES {placeholders} ON CONFLICT ({primary_key_string}) DO UPDATE SET {update_set_string};""" flattened_values = [item for sublist in values for item in sublist] await cls._execute_query(query, *flattened_values) def add_default_values_when_undefined(self, **kwargs: object) -> dict[str, object]: result = 
dict(kwargs) for name, field in self._fields.items(): if name not in kwargs: default_value = field.default_value result[name] = default_value return result async def update(self, connection: Optional[asyncpg.connection.Connection] = None, **kwargs: object) -> None: """ Update this document in the database. It will update the fields in this object and send a full update to database. Use update_fields to only update specific fields. """ kwargs = self._convert_field_names_to_db_column_names(kwargs) for name, value in kwargs.items(): setattr(self, name, value) (column_names, values) = self._get_column_names_and_values() values_as_parameterized_sql_string = ",".join([column_names[i - 1] + "=$" + str(i) for i in range(1, len(values) + 1)]) (filter_statement, values_for_filter) = self._get_filter_on_primary_key_fields(offset=len(column_names) + 1) values = values + values_for_filter query = "UPDATE " + self.table_name() + " SET " + values_as_parameterized_sql_string + " WHERE " + filter_statement await self._execute_query(query, *values, connection=connection) def _get_set_statement(self, **kwargs: object) -> tuple[str, list[object]]: counter = 1 parts_of_set_statement = [] values = [] for name, value in kwargs.items(): setattr(self, name, value) parts_of_set_statement.append(name + "=$" + str(counter)) values.append(self._get_value(value)) counter += 1 set_statement = ",".join(parts_of_set_statement) return (set_statement, values) async def update_fields(self, connection: Optional[asyncpg.connection.Connection] = None, **kwargs: object) -> None: """ Update the given fields of this document in the database. It will update the fields in this object and do a specific $set in the database on this document. 
""" if len(kwargs) == 0: return kwargs = self._convert_field_names_to_db_column_names(kwargs) for name, value in kwargs.items(): setattr(self, name, value) (set_statement, values_set_statement) = self._get_set_statement(**kwargs) (filter_statement, values_for_filter) = self._get_filter_on_primary_key_fields(offset=len(kwargs) + 1) values = values_set_statement + values_for_filter query = "UPDATE " + self.table_name() + " SET " + set_statement + " WHERE " + filter_statement await self._execute_query(query, *values, connection=connection)
[docs] @classmethod async def get_by_id( cls: type[TBaseDocument], doc_id: uuid.UUID, connection: Optional[asyncpg.connection.Connection] = None ) -> Optional[TBaseDocument]: """ Get a specific document based on its ID :return: An instance of this class with its fields filled from the database. """ result = await cls.get_list(id=doc_id, connection=connection) if len(result) > 0: return result[0] return None
@classmethod
async def get_one(
    cls: type[TBaseDocument],
    connection: Optional[asyncpg.connection.Connection] = None,
    lock: Optional[RowLockMode] = None,
    **query: object,
) -> Optional[TBaseDocument]:
    """
    Return a single document matching the given filter, or None when nothing matches.

    No ORDER BY is applied, so when several rows match it is not specified which one is returned.

    :param connection: An optional connection to run the query on.
    :param lock: An optional row-level lock to take on the selected row.
    :param query: Column name / value pairs to filter on (combined with AND).
    """
    results = await cls.get_list(
        connection=connection,
        order_by_column=None,
        order=None,
        limit=1,
        offset=None,
        no_obj=None,
        lock=lock,
        **query,
    )
    if results:
        return results[0]
    return None


@classmethod
def _validate_order(cls, order_by_column: str, order: str) -> tuple[ColumnNameStr, OrderStr]:
    """Validate the correct values for order and if the order column is an existing column name

    Both values are interpolated into SQL text by the callers, so they are checked against allow-lists.

    :param order_by_column: The name of the column to order by
    :param order: The sorting order: space separated keywords out of ASC, DESC, NULLS, FIRST and LAST.
    :return: The validated column name and sorting order.
    :raises RuntimeError: When the order contains an unknown keyword or the column is not a field of this document.
    """
    possible = ["ASC", "DESC", "NULLS", "FIRST", "LAST"]
    for o in order.split(" "):
        if o not in possible:
            raise RuntimeError(f"The following order can not be applied: {order}, {o} should be one of {possible}")
    if order_by_column not in cls.get_field_names():
        raise RuntimeError(f"{order_by_column} is not a valid field name.")
    return ColumnNameStr(order_by_column), OrderStr(order)


@classmethod
def _validate_order_strict(cls, order_by_column: str, order: str) -> tuple[ColumnNameStr, PagingOrder]:
    """Validate the correct values for order ('ASC' or 'DESC') and if the order column is an existing column name

    :param order_by_column: The name of the column to order by
    :param order: The sorting order, exactly 'ASC' or 'DESC'.
    :return: The validated column name and the matching PagingOrder member.
    :raises RuntimeError: When the order is not exactly 'ASC' or 'DESC', or the column is not a valid field name.
    """
    possible = ["ASC", "DESC"]
    for o in order.split(" "):
        if o not in possible:
            raise RuntimeError(f"The following order can not be applied: {order}, {o} should be one of {possible}")
    if order not in possible:
        # Every word was valid, but the order is not a single keyword (e.g. "ASC DESC"). Such a string presumably
        # does not name a PagingOrder member, so reject it here with a clear error instead of failing on the lookup.
        raise RuntimeError(f"The following order can not be applied: {order}, it should be one of {possible}")
    if order_by_column not in cls.get_valid_field_names():
        raise RuntimeError(f"{order_by_column} is not a valid field name.")
    return ColumnNameStr(order_by_column), PagingOrder[order]
@classmethod
async def get_list(
    cls: type[TBaseDocument],
    *,
    # Every default is None rather than the actual default value, so that callers can explicitly ask for the
    # default, which keeps type checking of the **query keyword arguments meaningful.
    order_by_column: Optional[str] = None,
    order: Optional[str] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    no_obj: Optional[bool] = None,
    lock: Optional[RowLockMode] = None,
    connection: Optional[asyncpg.connection.Connection] = None,
    **query: object,
) -> list[TBaseDocument]:
    """
    Fetch all documents matching the given filter arguments, selecting every column.

    This is a thin wrapper around get_list_with_columns with ``columns=None``.
    """
    return await cls.get_list_with_columns(
        connection=connection,
        columns=None,
        lock=lock,
        no_obj=no_obj,
        offset=offset,
        limit=limit,
        order=order,
        order_by_column=order_by_column,
        **query,
    )
@classmethod
async def get_list_with_columns(
    cls: type[TBaseDocument],
    *,
    order_by_column: Optional[str] = None,
    order: Optional[str] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    no_obj: Optional[bool] = None,
    lock: Optional[RowLockMode] = None,
    connection: Optional[asyncpg.connection.Connection] = None,
    columns: Optional[list[str]] = None,
    **query: object,
) -> list[TBaseDocument]:
    """
    Get a list of documents matching the filter args

    :param order_by_column: The name of the column to sort on. No ORDER BY clause is added when None.
    :param order: The sorting order, defaults to "ASC".
    :param limit: If set and positive, the maximum number of rows to return.
    :param offset: If set and positive, the number of rows to skip.
    :param no_obj: When True, return the raw records instead of instances of this class.
    :param lock: An optional row-level lock clause appended to the query.
    :param connection: An optional connection to run the query on.
    :param columns: The columns to select. All columns are selected when None or empty.
    :param query: Column name / value pairs to filter on (combined with AND).
    """
    if order is None:
        order = "ASC"
    if order_by_column is not None:
        # Both values are interpolated into the SQL text below, so validate them whenever they are used.
        cls._validate_order(order_by_column, order)
    if no_obj is None:
        no_obj = False
    query = cls._convert_field_names_to_db_column_names(query)
    (filter_statement, values) = cls._get_composed_filter(**query)
    selected_columns = " * "
    if columns:
        selected_columns = ",".join([cls.validate_field_name(column) for column in columns])
    sql_query = f"SELECT {selected_columns} FROM " + cls.table_name()
    if filter_statement:
        sql_query += " WHERE " + filter_statement
    if order_by_column is not None:
        sql_query += f" ORDER BY {order_by_column} {order}"
    if limit is not None and limit > 0:
        sql_query += " LIMIT $" + str(len(values) + 1)
        values.append(int(limit))
    if offset is not None and offset > 0:
        sql_query += " OFFSET $" + str(len(values) + 1)
        values.append(int(offset))
    if lock is not None:
        sql_query += f" {lock.value}"
    result = await cls.select_query(sql_query, values, no_obj=no_obj, connection=connection)
    return result


@classmethod
async def get_list_paged(
    cls: type[TBaseDocument],
    *,
    page_by_column: str,
    order_by_column: Optional[str] = None,
    order: Optional[str] = None,
    limit: Optional[int] = None,
    start: Optional[object] = None,
    end: Optional[object] = None,
    no_obj: Optional[bool] = None,
    lock: Optional[RowLockMode] = None,
    connection: Optional[asyncpg.connection.Connection] = None,
    **query: object,
) -> list[TBaseDocument]:
    """
    Get a list of documents matching the filter args, with paging support

    :param page_by_column: The name of the column in the database on which the paging should be applied
    :param order_by_column: The name of the column in the database the sorting should be based on
    :param order: The order to apply to the sorting, defaults to "ASC"
    :param limit: If specified, the maximum number of entries to return
    :param start: A value conforming to the type of the paging column; all returned rows will have a strictly
        greater value in the paging column
    :param end: A value conforming to the type of the paging column; all returned rows will have a strictly
        lower value in the paging column
    :param no_obj: Whether not to cast the query result into a matching object
    :param lock: An optional row-level lock clause appended to the query
    :param connection: An optional connection
    :param **query: Any additional filter to apply
    """
    if order is None:
        order = "ASC"
    if order_by_column is not None:
        # Both values are interpolated into the SQL text below, so validate them whenever they are used.
        cls._validate_order(order_by_column, order)
    if no_obj is None:
        no_obj = False
    query = cls._convert_field_names_to_db_column_names(query)
    (filter_statement, values) = cls._get_composed_filter(**query)
    # The composed filter is already AND-joined, so it can be kept as a single statement.
    filter_statements = [filter_statement] if filter_statement else []
    # NOTE(review): page_by_column is interpolated into the SQL text without validation; only pass trusted names.
    if start is not None:
        filter_statements.append(f"{page_by_column} > $" + str(len(values) + 1))
        values.append(cls._get_value(start))
    if end is not None:
        filter_statements.append(f"{page_by_column} < $" + str(len(values) + 1))
        values.append(cls._get_value(end))
    sql_query = "SELECT * FROM " + cls.table_name()
    if filter_statements:
        sql_query += " WHERE " + " AND ".join(filter_statements)
    if order_by_column is not None:
        sql_query += f" ORDER BY {order_by_column} {order}"
    if limit is not None and limit > 0:
        sql_query += " LIMIT $" + str(len(values) + 1)
        values.append(int(limit))
    if lock is not None:
        sql_query += f" {lock.value}"
    result = await cls.select_query(sql_query, values, no_obj=no_obj, connection=connection)
    return result


@classmethod
async def delete_all(cls, connection: Optional[asyncpg.connection.Connection] = None, **query: object) -> int:
    """
    Delete all documents that match the given query

    :param connection: An optional connection to run the query on.
    :param query: Column name / value pairs to filter on (combined with AND). Without filters every row is deleted.
    :return: The number of deleted rows.
    """
    query = cls._convert_field_names_to_db_column_names(query)
    (filter_statement, values) = cls._get_composed_filter(**query)
    sql_query = "DELETE FROM " + cls.table_name()
    if filter_statement:
        sql_query += " WHERE " + filter_statement
    result = await cls._execute_query(sql_query, *values, connection=connection)
    # The command status has the form "DELETE <row count>".
    record_count = int(result.split(" ")[1])
    return record_count


@classmethod
def _get_composed_filter(
    cls, offset: int = 1, col_name_prefix: Optional[str] = None, **query: object
) -> tuple[str, list[object]]:
    """
    Build an AND-joined equality filter from column name / value pairs.

    :param offset: The index of the first query parameter placeholder to use ($<offset>); values below 1 are raised to 1.
    :param col_name_prefix: An optional table alias to prefix every column name with.
    :param query: Column name / value pairs. A None value results in an IS NULL check.
    :return: The filter statement (an empty string when query is empty) and the values for its placeholders.
    :raises InvalidFieldNameException: When a key is not a valid column name for this document.
    """
    filter_statements = []
    values = []
    index_count = max(1, offset)
    for key, value in query.items():
        cls.validate_field_name(key)
        name = cls._add_column_name_prefix_if_needed(key, col_name_prefix)
        (filter_statement, filter_values) = cls._get_filter(name, value, index_count)
        filter_statements.append(filter_statement)
        values.extend(filter_values)
        index_count += len(filter_values)
    filter_as_string = " AND ".join(filter_statements)
    return (filter_as_string, values)


@classmethod
def _get_filter(cls, name: str, value: object, index: int) -> tuple[str, list[object]]:
    """
    Build an equality filter on a single column, using placeholder $<index> for the value.

    A None value is translated to an IS NULL check, which needs no query parameter.
    """
    if value is None:
        return (name + " IS NULL", [])
    filter_statement = name + "=$" + str(index)
    value = cls._get_value(value)
    return (filter_statement, [value])


@classmethod
def _get_value(cls, value: object) -> object:
    """
    Convert a Python value into the form that is passed to the database as a query parameter.

    Dicts, DataDocuments and BaseModels are JSON encoded, lists are converted element-wise, enums are replaced by
    their name and UUIDs by their string form. Any other value is returned unchanged.
    """
    if isinstance(value, dict):
        return json_encode(value)

    if isinstance(value, (DataDocument, BaseModel)):
        return json_encode(value)

    if isinstance(value, list):
        return [cls._get_value(x) for x in value]

    if isinstance(value, enum.Enum):
        return value.name

    if isinstance(value, uuid.UUID):
        return str(value)

    return value


@classmethod
def get_composed_filter_with_query_types(
    cls, offset: int = 1, col_name_prefix: Optional[str] = None, **query: QueryFilter
) -> tuple[list[str], list[object]]:
    """
    Build one filter statement per column from (query type, value) pairs.

    Unlike _get_composed_filter, the statements are returned as a list rather than AND-joined, and the column names
    are not validated by this method.

    :param offset: The index of the first query parameter placeholder to use ($<offset>); values below 1 are raised to 1.
    :param col_name_prefix: An optional table alias to prefix every column name with.
    :param query: Column name / (query type, value) pairs.
    :return: The list of filter statements and the flat list of values for their placeholders.
    """
    filter_statements = []
    values: list[object] = []
    index_count = max(1, offset)
    for key, value_with_query_type in query.items():
        query_type, value = value_with_query_type
        filter_statement: str
        filter_values: list[object]
        name = cls._add_column_name_prefix_if_needed(key, col_name_prefix)
        filter_statement, filter_values = cls.get_filter_for_query_type(query_type, name, value, index_count)
        filter_statements.append(filter_statement)
        values.extend(filter_values)
        index_count += len(filter_values)
    return (filter_statements, values)


@classmethod
def get_filter_for_query_type(
    cls, query_type: QueryType, key: str, value: object, index_count: int
) -> tuple[str, list[object]]:
    """
    Dispatch to the filter builder that matches the given query type.

    :raises InvalidQueryType: When query_type is not one of the supported query types.
    """
    if query_type == QueryType.EQUALS:
        (filter_statement, filter_values) = cls._get_filter(key, value, index_count)
    elif query_type == QueryType.IS_NOT_NULL:
        (filter_statement, filter_values) = cls.get_is_not_null_filter(key)
    elif query_type == QueryType.CONTAINS:
        (filter_statement, filter_values) = cls.get_contains_filter(key, value, index_count)
    elif query_type == QueryType.CONTAINS_PARTIAL:
        (filter_statement, filter_values) = cls.get_contains_partial_filter(key, value, index_count)
    elif query_type == QueryType.RANGE:
        (filter_statement, filter_values) = cls.get_range_filter(key, value, index_count)
    elif query_type == QueryType.NOT_CONTAINS:
        (filter_statement, filter_values) = cls.get_not_contains_filter(key, value, index_count)
    elif query_type == QueryType.COMBINED:
        (filter_statement, filter_values) = cls.get_filter_for_combined_query_type(
            key, cast(dict[QueryType, object], value), index_count
        )
    else:
        raise InvalidQueryType(f"Query type should be one of {[query for query in QueryType]}")
    return (filter_statement, filter_values)


@classmethod
def validate_field_name(cls, name: str) -> ColumnNameStr:
    """Check if the name is a valid database column name for the current type"""
    if name not in cls.get_valid_field_names():
        raise InvalidFieldNameException(f"{name} is not valid for a query on {cls.table_name()}")
    return ColumnNameStr(name)


@classmethod
def _add_column_name_prefix_if_needed(cls, filter_statement: str, col_name_prefix: Optional[str] = None) -> str:
    """Return the column name qualified as "<col_name_prefix>.<name>" when a prefix is given, unchanged otherwise."""
    if col_name_prefix is not None:
        filter_statement = f"{col_name_prefix}.{filter_statement}"
    return filter_statement


@classmethod
def get_is_not_null_filter(cls, name: str) -> tuple[str, list[object]]:
    """
    Returns a tuple of a PostgresQL statement and any query arguments to filter on values that are not null.
    """
    filter_statement = f"{name} IS NOT NULL"
    return (filter_statement, [])


@classmethod
def get_contains_filter(cls, name: str, value: object, index: int) -> tuple[str, list[object]]:
    """
    Returns a tuple of a PostgresQL statement and any query arguments to filter on values that are contained in a given
    collection.
    """
    filter_statement = f"{name} = ANY (${str(index)})"
    value = cls._get_value(value)
    return (filter_statement, [value])


@classmethod
def get_filter_for_combined_query_type(
    cls, name: str, combined_value: dict[QueryType, object], index: int
) -> tuple[str, list[object]]:
    """
    Returns a tuple of a PostgresQL statement and any query arguments to filter a single column
    based on the defined query types
    """
    filters = []
    for query_type, value in combined_value.items():
        filter_statement, filter_values = cls.get_filter_for_query_type(query_type, name, value, index)
        filters.append((filter_statement, filter_values))
        # Each sub-filter consumes its own placeholders, so shift the next one past them.
        index += len(filter_values)
    (filter_statement, values) = cls._combine_filter_statements(filters)
    return (filter_statement, values)


@classmethod
def get_not_contains_filter(cls, name: str, value: object, index: int) -> tuple[str, list[object]]:
    """
    Returns a tuple of a PostgresQL statement and any query arguments to filter on values that are not contained in a given
    collection.
    """
    filter_statement = f"NOT ({name} = ANY (${str(index)}))"
    value = cls._get_value(value)
    return (filter_statement, [value])


@classmethod
def get_contains_partial_filter(cls, name: str, value: object, index: int) -> tuple[str, list[object]]:
    """
    Returns a tuple of a PostgresQL statement and any query arguments to filter on values that are contained in a given
    collection.

    Matching is case-insensitive (ILIKE) and each given value may occur anywhere in the column value.
    """
    filter_statement = f"{name} ILIKE ANY (${str(index)})"
    value = cls._get_value(value)
    # NOTE(review): '%' and '_' inside the given values are not escaped and act as LIKE wildcards.
    value = [f"%{v}%" for v in value]
    return (filter_statement, [value])


@classmethod
def get_range_filter(
    cls, name: str, value: Union[DateRangeConstraint, RangeConstraint], index: int
) -> tuple[str, list[object]]:
    """
    Returns a tuple of a PostgresQL statement and any query arguments to filter on values that match a given range
    constraint.
    """
    filter_statement: str
    values: list[object]
    (filter_statement, values) = cls._combine_filter_statements(
        (
            f"{name} {operator.pg_value} ${str(index + i)}",
            [cls._get_value(bound)],
        )
        for i, (operator, bound) in enumerate(value)
    )
    return (filter_statement, [cls._get_value(v) for v in values])


@staticmethod
def _combine_filter_statements(statements_and_values: Iterable[tuple[str, list[object]]]) -> tuple[str, list[object]]:
    """
    Combine (filter statement, values) pairs into one AND-joined statement and a flat list of values.

    Empty statements are left out of the combined statement. An empty input results in ("", []) instead of failing
    on tuple unpacking.
    """
    pairs = list(statements_and_values)
    filter_statements = [statement for statement, _ in pairs if statement != ""]
    values = list(chain.from_iterable(statement_values for _, statement_values in pairs))
    return (" AND ".join(filter_statements), values)


@classmethod
def _add_start_filter(
    cls,
    offset: int,
    order_by_column: ColumnNameStr,
    id_column: ColumnNameStr,
    start: Optional[object] = None,
    first_id: Optional[Union[uuid.UUID, str]] = None,
) -> tuple[list[str], list[object]]:
    """
    Build the lower-bound filter for a page: rows strictly after `start` in the ordering column.

    When first_id is also given, (order_by_column, id_column) is compared as a row value, so rows with an equal
    ordering value are broken by the id column. Placeholders start at $<offset + 1>.
    """
    filter_statements = []
    values: list[object] = []
    if start is not None and first_id:
        filter_statements.append(f"({order_by_column}, {id_column}) > (${str(offset + 1)}, ${str(offset + 2)})")
        values.append(cls._get_value(start))
        values.append(cls._get_value(first_id))
    elif start is not None:
        filter_statements.append(f"{order_by_column} > ${str(offset + 1)}")
        values.append(cls._get_value(start))
    return filter_statements, values


@classmethod
def _add_end_filter(
    cls,
    offset: int,
    order_by_column: ColumnNameStr,
    id_column: ColumnNameStr,
    end: Optional[object] = None,
    last_id: Optional[Union[uuid.UUID, str]] = None,
) -> tuple[list[str], list[object]]:
    """
    Build the upper-bound filter for a page: rows strictly before `end` in the ordering column.

    When last_id is also given, (order_by_column, id_column) is compared as a row value, so rows with an equal
    ordering value are broken by the id column. Placeholders start at $<offset + 1>.
    """
    filter_statements = []
    values: list[object] = []
    if end is not None and last_id:
        filter_statements.append(f"({order_by_column}, {id_column}) < (${str(offset + 1)}, ${str(offset + 2)})")
        values.append(cls._get_value(end))
        values.append(cls._get_value(last_id))
    elif end is not None:
        filter_statements.append(f"{order_by_column} < ${str(offset + 1)}")
        values.append(cls._get_value(end))
    return filter_statements, values


@classmethod
def _join_filter_statements(cls, filter_statements: list[str]) -> str:
    """Return a "WHERE ..." clause AND-joining the given statements, or an empty string when there are none."""
    if filter_statements:
        return "WHERE " + " AND ".join(filter_statements)
    return ""


async def delete(self, connection: Optional[asyncpg.connection.Connection] = None) -> None:
    """
    Delete this document, identified by its primary key fields
    """
    (filter_as_string, values) = self._get_filter_on_primary_key_fields()
    query = "DELETE FROM " + self.table_name() + " WHERE " + filter_as_string
    await self._execute_query(query, *values, connection=connection)


async def delete_cascade(self, connection: Optional[asyncpg.connection.Connection] = None) -> None:
    """Delete this document. The base implementation only deletes this document itself; subclasses may override it."""
    await self.delete(connection=connection)


@classmethod
@overload
async def select_query(
    cls: type[TBaseDocument], query: str, values: list[object], connection: Optional[asyncpg.connection.Connection] = None
) -> Sequence[TBaseDocument]:
    """Return a sequence of objects of cls type."""
    ...


@classmethod
@overload
async def select_query(
    cls: type[TBaseDocument],
    query: str,
    values: list[object],
    no_obj: bool,
    connection: Optional[asyncpg.connection.Connection] = None,
) -> Sequence[Record]:
    """Return a sequence of records instances"""
    ...
@classmethod
async def select_query(
    cls: type[TBaseDocument],
    query: str,
    values: list[object],
    no_obj: bool = False,
    connection: Optional[asyncpg.connection.Connection] = None,
) -> Sequence[Union[Record, TBaseDocument]]:
    """
    Run a SELECT query and collect all result rows.

    :param query: The SQL query, using $1, $2, ... placeholders.
    :param values: The values bound to the placeholders, in order.
    :param no_obj: When True, return the raw records instead of instances of this class.
    :param connection: An optional connection to run the query on.
    """
    async with cls.get_connection(connection) as con:
        # asyncpg cursors can only be used inside a transaction.
        async with con.transaction():
            result: list[Union[Record, TBaseDocument]] = []
            async for record in con.cursor(query, *values):
                if no_obj:
                    result.append(record)
                else:
                    result.append(cls(from_postgres=True, **record))
            return result


def to_dict(self) -> JsonType:
    """
    Return a dict representing the document

    Fields without a value are filled in with their default value when the field has one.

    :raises TypeError: When a required field has no value.
    """
    result = {}
    for name, metadata in self.get_field_metadata().items():
        value = self.get_value(name)

        if metadata.required and value is None:
            # Use the class name: instances normally have no __name__ attribute, so self.__name__ would raise an
            # AttributeError here instead of the intended TypeError.
            raise TypeError(f"{type(self).__name__} should have field '{name}'")

        if value is not None:
            metadata.validate(name, value)
            result[name] = value

        elif metadata.default:
            result[name] = metadata.default_value

    return result


@classmethod
async def execute_in_retryable_transaction(
    cls,
    fnc: Callable[[Connection], Awaitable[TransactionResult]],
    tx_isolation_level: Optional[str] = None,
) -> TransactionResult:
    """
    Execute the queries in fnc using the transaction isolation level `tx_isolation_level` and return
    the result returned by fnc. This method performs retries when the transaction is aborted due to
    a serialization error: fnc is attempted at most 4 times (1 initial attempt + 3 retries), sleeping
    10ms, 100ms and 1s respectively before each retry.

    :raises Exception: When every attempt was aborted by a serialization error. The last
        SerializationError is chained as the cause.
    """
    max_retries = 3
    async with cls.get_connection() as postgresql_client:
        attempt = 1
        while True:
            try:
                async with postgresql_client.transaction(isolation=tx_isolation_level):
                    return await fnc(postgresql_client)
            except SerializationError as e:
                if attempt > max_retries:
                    # Report the real number of attempts (the initial attempt plus the retries).
                    raise Exception(f"Failed to execute transaction after {attempt} attempts.") from e
                # Exponential backoff: 10ms, 100ms, 1s
                await asyncio.sleep(pow(10, attempt) / 1000)
                attempt += 1
class Project(BaseDocument):
    """
    An inmanta configuration project

    :param id: The unique id of the project.
    :param name: The name of the configuration project.
    """

    __primary_key__ = ("id",)

    id: uuid.UUID
    name: str

    def to_dto(self) -> m.Project:
        """Convert this project to its API model. The environments list is left empty."""
        return m.Project(id=self.id, name=self.name, environments=[])

    async def delete_cascade(self, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        This method doesn't rely on the DELETE CASCADE functionality of PostgreSQL because it causes deadlocks.
        As such, we perform the deletes on each table in a separate transaction.
        """
        async with self.get_connection(connection=connection) as con:
            envs_in_project: abc.Sequence[Environment] = await Environment.get_list(project=self.id, connection=con)
            for env in envs_in_project:
                await env.delete_cascade(connection=con)
            await self.delete(connection=con)


def convert_boolean(value: Union[bool, str]) -> bool:
    """Convert a bool or a boolean-like string (as accepted by RawConfigParser) to a bool."""
    if isinstance(value, bool):
        return value

    if value.lower() not in RawConfigParser.BOOLEAN_STATES:
        raise ValueError("Not a boolean: %s" % value)

    return RawConfigParser.BOOLEAN_STATES[value.lower()]


def convert_int(value: Union[float, int, str]) -> Union[int, float]:
    """
    Convert a setting value to a number.

    Ints and floats are returned unchanged. A string is converted to an int when it represents an integral number
    (e.g. "10", "2.0" or "1e3") and to a float otherwise (e.g. "1.5").

    :raises ValueError: When the string is not a finite number.
    """
    if isinstance(value, (int, float)):
        return value

    import math

    try:
        # Plain integer literals are converted directly to keep full precision for large numbers.
        return int(value)
    except ValueError:
        pass

    f_value = float(value)  # Raises ValueError for non-numeric strings.
    if not math.isfinite(f_value):
        raise ValueError(f"Not a finite number: {value}")
    if f_value.is_integer():
        return int(f_value)
    return f_value


def convert_positive_float(value: Union[float, int, str]) -> float:
    """Convert the value to a float and reject negative numbers (zero is accepted)."""
    if isinstance(value, float):
        float_value = value
    else:
        float_value = float(value)
    if float_value < 0:
        raise ValueError(f"This value should be positive, got: {value}")
    return float_value


def convert_agent_map(value: dict[str, str]) -> dict[str, str]:
    """
    Validate an agent map: a dict of str to str that contains the "internal" agent.

    :return: The value itself, unchanged.
    :raises ValueError: When the value is not a valid agent map.
    """
    if not isinstance(value, dict):
        raise ValueError("Agent map should be a dict")

    for key, v in value.items():
        if not isinstance(key, str):
            raise ValueError("The key of an agent map should be string")

        if not isinstance(v, str):
            raise ValueError("The value of an agent map should be string")

    if "internal" not in value:
        raise ValueError("The internal agent must be present in the autostart_agent_map")

    return value


def translate_to_postgres_type(type: str) -> str:
    """Return the PostgreSQL type used to store a setting of the given setting type (see TYPE_MAP)."""
    if type not in TYPE_MAP:
        raise Exception("Type '" + type + "' is not a valid type for a settings entry")
    return TYPE_MAP[type]


def convert_agent_trigger_method(value: object) -> str:
    """
    Validate an agent trigger method.

    An AgentTriggerMethod member is returned as is; any other value is converted to a string and must match the
    name of an AgentTriggerMethod member.
    """
    if isinstance(value, const.AgentTriggerMethod):
        return value
    value = str(value)
    valid_values = [x.name for x in const.AgentTriggerMethod]
    if value not in valid_values:
        raise ValueError("{} is not a valid agent trigger method. Valid value: {}".format(value, ",".join(valid_values)))
    return value


def validate_cron_or_int(value: Union[int, str]) -> str:
    """
    Validate a value that is either an integer or a non-empty cron expression.

    :return: The integer as a string, or the cron expression itself.
    :raises ValueError: When the value is neither an integer nor a valid cron expression.
    """
    try:
        return str(int(value))
    except ValueError:
        try:
            assert isinstance(value, str)  # Make mypy happy
            return validate_cron(value, allow_empty=False)
        except ValueError as e:
            raise ValueError(f"'{value}' is not a valid cron expression or int: {e}") from e


def validate_cron(value: str, allow_empty: bool = True) -> str:
    """
    Validate a cron expression.

    :param allow_empty: Whether an empty string is accepted (returned as "").
    :raises ValueError: When the expression is empty (and not allowed) or not a valid cron expression.
    """
    if not value:
        if allow_empty:
            return ""
        raise ValueError("The given cron expression is an empty string")
    try:
        CronTab(value)
    except ValueError as e:
        raise ValueError(f"'{value}' is not a valid cron expression: {e}") from e
    return value


# Maps the setting types used by Setting.typ to the PostgreSQL type used to store the value.
TYPE_MAP = {
    "int": "integer",
    "bool": "boolean",
    "dict": "jsonb",
    "str": "varchar",
    "enum": "varchar",
    "positive_float": "double precision",
}

# Names of the environment settings
AUTO_DEPLOY = "auto_deploy"
PUSH_ON_AUTO_DEPLOY = "push_on_auto_deploy"
AGENT_TRIGGER_METHOD_ON_AUTO_DEPLOY = "agent_trigger_method_on_auto_deploy"
ENVIRONMENT_AGENT_TRIGGER_METHOD = "environment_agent_trigger_method"
AUTOSTART_AGENT_DEPLOY_INTERVAL = "autostart_agent_deploy_interval"
AUTOSTART_AGENT_DEPLOY_SPLAY_TIME = "autostart_agent_deploy_splay_time"
AUTOSTART_AGENT_REPAIR_INTERVAL = "autostart_agent_repair_interval"
AUTOSTART_AGENT_REPAIR_SPLAY_TIME = "autostart_agent_repair_splay_time"
AUTOSTART_ON_START = "autostart_on_start"
AUTOSTART_AGENT_MAP = "autostart_agent_map"
AGENT_AUTH = "agent_auth"
SERVER_COMPILE = "server_compile"
AUTO_FULL_COMPILE = "auto_full_compile"
RESOURCE_ACTION_LOGS_RETENTION = "resource_action_logs_retention"
PROTECTED_ENVIRONMENT = "protected_environment"
NOTIFICATION_RETENTION = "notification_retention"
AVAILABLE_VERSIONS_TO_KEEP = "available_versions_to_keep"
RECOMPILE_BACKOFF = "recompile_backoff"
ENVIRONMENT_METRICS_RETENTION = "environment_metrics_retention"


class Setting:
    """
    A class to define a new environment setting.
    """

    def __init__(
        self,
        name: str,
        typ: str,
        default: Optional[m.EnvSettingType] = None,
        doc: Optional[str] = None,
        validator: Optional[Callable[[m.EnvSettingType], m.EnvSettingType]] = None,
        recompile: bool = False,
        update_model: bool = False,
        agent_restart: bool = False,
        allowed_values: Optional[list[m.EnvSettingType]] = None,
    ) -> None:
        """
        :param name: The name of the setting.
        :param typ: The type of the value. This type is mainly used for documentation purpose.
        :param default: An optional default value for this setting. When a default is set and the setting is requested
                        from the database, it will return the default value and also store the default value in the
                        database.
        :param doc: The documentation/help string for this setting
        :param validator: A validation and casting function for input settings. Should raise ValueError if validation
                          fails.
        :param recompile: Trigger a recompile of the model when a setting is updated?
        :param update_model: Update the configuration model (git pull on project and repos)
        :param agent_restart: Restart autostarted agents when this setting is updated.
        :param allowed_values: list of possible values (if type is enum)
        """
        self.name: str = name
        self.typ: str = typ
        self._default = default
        self.doc = doc
        self.validator = validator
        self.recompile = recompile
        self.update = update_model
        self.agent_restart = agent_restart
        self.allowed_values = allowed_values

    @property
    def default(self) -> Optional[m.EnvSettingType]:
        """
        The default value of this setting.

        Dicts are mutable, so a dict default (including an empty one) is returned as a deep copy: callers can modify
        the result without altering the default stored in this definition.
        """
        if isinstance(self._default, dict):
            return copy.deepcopy(self._default)
        return self._default

    def to_dict(self) -> JsonType:
        """Return a dict describing this setting."""
        return {
            "type": self.typ,
            "default": self.default,
            "doc": self.doc,
            "recompile": self.recompile,
            "update": self.update,
            "agent_restart": self.agent_restart,
            "allowed_values": self.allowed_values,
        }

    def to_dto(self) -> m.EnvironmentSetting:
        """Convert this setting definition to its API model."""
        return m.EnvironmentSetting(
            name=self.name,
            type=self.typ,
            default=self.default,
            doc=self.doc,
            recompile=self.recompile,
            update_model=self.update,
            agent_restart=self.agent_restart,
            allowed_values=self.allowed_values,
        )
[docs] @stable_api class Environment(BaseDocument): """ A deployment environment of a project :param id: A unique, machine generated id :param name: The name of the deployment environment. :param project: The project this environment belongs to. :param repo_url: The repository url that contains the configuration model code for this environment. :param repo_branch: The repository branch that contains the configuration model code for this environment. :param settings: Key/value settings for this environment. This dictionary does not necessarily contain a key for every environment setting known by the server. This is done for backwards compatibility reasons. When a setting was renamed, we need to determine whether the old or the new setting has to be taken into account. The logic to decide that is the following: * When the name of the new setting is present in this settings dictionary or when the name of the old setting is not present in the settings dictionary, use the new setting. * Otherwise, use the setting with the old name. 
:param last_version: The last version number that was reserved for this environment :param description: The description of the environment :param icon: An icon for the environment """ __primary_key__ = ("id",) id: uuid.UUID name: str project: uuid.UUID repo_url: str = "" repo_branch: str = "" settings: dict[str, m.EnvSettingType] = {} last_version: int = 0 halted: bool = False description: str = "" icon: str = "" is_marked_for_deletion: bool = False def to_dto(self) -> m.Environment: return m.Environment( id=self.id, name=self.name, project_id=self.project, repo_url=self.repo_url, repo_branch=self.repo_branch, settings=self.settings, halted=self.halted, is_marked_for_deletion=self.is_marked_for_deletion, description=self.description, icon=self.icon, ) _settings: dict[str, Setting] = { AUTO_DEPLOY: Setting( name=AUTO_DEPLOY, typ="bool", default=True, doc="When this boolean is set to true, the orchestrator will automatically release a new version " "that was compiled by the orchestrator itself.", validator=convert_boolean, ), PUSH_ON_AUTO_DEPLOY: Setting( name=PUSH_ON_AUTO_DEPLOY, typ="bool", default=True, doc="Push a new version when it has been autodeployed.", validator=convert_boolean, ), AGENT_TRIGGER_METHOD_ON_AUTO_DEPLOY: Setting( name=AGENT_TRIGGER_METHOD_ON_AUTO_DEPLOY, typ="enum", default=const.AgentTriggerMethod.push_incremental_deploy.name, validator=convert_agent_trigger_method, doc="The agent trigger method to use when " + PUSH_ON_AUTO_DEPLOY + " is enabled", allowed_values=[opt.name for opt in const.AgentTriggerMethod], ), ENVIRONMENT_AGENT_TRIGGER_METHOD: Setting( name=ENVIRONMENT_AGENT_TRIGGER_METHOD, typ="enum", default=const.AgentTriggerMethod.push_incremental_deploy.name, validator=convert_agent_trigger_method, doc="The agent trigger method to use when no specific method is specified in the API call. " "This determines the behavior of the 'Promote' button. 
" f"For auto deploy, {AGENT_TRIGGER_METHOD_ON_AUTO_DEPLOY} is used.", allowed_values=[opt.name for opt in const.AgentTriggerMethod], ), AUTOSTART_AGENT_DEPLOY_INTERVAL: Setting( name=AUTOSTART_AGENT_DEPLOY_INTERVAL, typ="str", default="600", doc="The deployment interval of the autostarted agents. Can be specified as a number of seconds" " or as a cron-like expression. Set this to 0 to disable the automatic scheduling of deploy runs." " See also: :inmanta.config:option:`config.agent-deploy-interval`", validator=validate_cron_or_int, agent_restart=True, ), AUTOSTART_AGENT_DEPLOY_SPLAY_TIME: Setting( name=AUTOSTART_AGENT_DEPLOY_SPLAY_TIME, typ="int", default=10, doc="The splay time on the deployment interval of the autostarted agents." " See also: :inmanta.config:option:`config.agent-deploy-splay-time`", validator=convert_int, agent_restart=True, ), AUTOSTART_AGENT_REPAIR_INTERVAL: Setting( name=AUTOSTART_AGENT_REPAIR_INTERVAL, typ="str", default="86400", doc=( "The repair interval of the autostarted agents. Can be specified as a number of seconds" " or as a cron-like expression. Set this to 0 to disable the automatic scheduling of repair runs." " See also: :inmanta.config:option:`config.agent-repair-interval`" ), validator=validate_cron_or_int, agent_restart=True, ), AUTOSTART_AGENT_REPAIR_SPLAY_TIME: Setting( name=AUTOSTART_AGENT_REPAIR_SPLAY_TIME, typ="int", default=600, doc="The splay time on the repair interval of the autostarted agents." " See also: :inmanta.config:option:`config.agent-repair-splay-time`", validator=convert_int, agent_restart=True, ), AUTOSTART_ON_START: Setting( name=AUTOSTART_ON_START, default=True, typ="bool", validator=convert_boolean, doc="Automatically start agents when the server starts instead of only just in time.", ), AUTOSTART_AGENT_MAP: Setting( name=AUTOSTART_AGENT_MAP, default={"internal": "local:"}, typ="dict", validator=convert_agent_map, doc="A dict with key the name of agents that should be automatically started. 
The value " "is either an empty string or an agent map string. See also: :inmanta.config:option:`config.agent-map`", agent_restart=True, ), SERVER_COMPILE: Setting( name=SERVER_COMPILE, default=True, typ="bool", validator=convert_boolean, doc="Allow the server to compile the configuration model.", ), AUTO_FULL_COMPILE: Setting( name=AUTO_FULL_COMPILE, default="", typ="str", validator=validate_cron, doc=( "Periodically run a full compile following a cron-like time-to-run specification interpreted in UTC with format" " `[sec] min hour dom month dow [year]` (If only 6 values are provided, they are interpreted as" " `min hour dom month dow year`). A compile will be requested at the scheduled time. The actual" " compilation may have to wait in the compile queue for some time, depending on the size of the queue and the" " RECOMPILE_BACKOFF environment setting. This setting has no effect when server_compile is disabled." ), ), RESOURCE_ACTION_LOGS_RETENTION: Setting( name=RESOURCE_ACTION_LOGS_RETENTION, default=7, typ="int", validator=convert_int, doc="The number of days to retain resource-action logs", ), AVAILABLE_VERSIONS_TO_KEEP: Setting( name=AVAILABLE_VERSIONS_TO_KEEP, default=100, typ="int", validator=convert_int, doc="The number of versions to keep stored in the database, excluding the latest released version.", ), PROTECTED_ENVIRONMENT: Setting( name=PROTECTED_ENVIRONMENT, default=False, typ="bool", validator=convert_boolean, doc="When set to true, this environment cannot be cleared or deleted.", ), NOTIFICATION_RETENTION: Setting( name=NOTIFICATION_RETENTION, default=365, typ="int", validator=convert_int, doc="The number of days to retain notifications for", ), RECOMPILE_BACKOFF: Setting( name=RECOMPILE_BACKOFF, default=0.1, typ="positive_float", validator=convert_positive_float, doc="""The number of seconds to wait before the server may attempt to do a new recompile. 
Recompiles are triggered after facts updates for example.""", ), ENVIRONMENT_METRICS_RETENTION: Setting( name=ENVIRONMENT_METRICS_RETENTION, typ="int", default=336, doc="The number of hours that environment metrics have to be retained before they are cleaned up. " "Default=336 hours (2 weeks). Set to 0 to disable automatic cleanups.", validator=convert_int, ), } @classmethod def get_setting_definition(cls, setting_name: str) -> Setting: """ Return the definition of the setting with the given name. """ if setting_name not in cls._settings: raise KeyError() return cls._settings[setting_name] async def get(self, key: str, connection: Optional[asyncpg.connection.Connection] = None) -> m.EnvSettingType: """ Get a setting in this environment. :param key: The name/key of the setting. It should be defined in _settings otherwise a keyerror will be raised. """ if key not in self._settings: raise KeyError() if key in self.settings: return self.settings[key] default_value = self._settings[key].default if default_value is None: raise KeyError() await self.set(key, default_value, connection=connection, allow_override=False) return self.settings[key] async def set( self, key: str, value: m.EnvSettingType, connection: Optional[asyncpg.connection.Connection] = None, allow_override: bool = True, ) -> None: """ Set a new setting in this environment. :param key: The name/key of the setting. It should be defined in _settings otherwise a keyerror will be raised. :param value: The value of the settings. The value should be of type as defined in _settings :param allow_override: If set to False, don't set the given environment setting when it already exists in the setting dictionary in the database. 
""" if key not in self._settings: raise KeyError() # TODO: convert this to a string if callable(self._settings[key].validator): value = self._settings[key].validator(value) type = translate_to_postgres_type(self._settings[key].typ) (filter_statement, values) = self._get_composed_filter(name=self.name, project=self.project, offset=5) query = f""" UPDATE {self.table_name()} SET settings=( CASE WHEN $1 IS FALSE AND settings ? $2::text THEN settings ELSE jsonb_set(settings, $3::text[], to_jsonb($4::{type}), TRUE) END ) WHERE {filter_statement} RETURNING settings """ values = [allow_override, self._get_value(key), self._get_value([key]), self._get_value(value)] + values new_value = await self._fetchval(query, *values, connection=connection) new_value_parsed = cast( dict[str, m.EnvSettingType], self.get_field_metadata()["settings"].from_db(name="settings", value=new_value) ) self.settings[key] = new_value_parsed[key] async def unset(self, key: str) -> None: """ Unset a setting in this environment. If a default value is provided, this value will replace the current value. :param key: The name/key of the setting. It should be defined in _settings otherwise a keyerror will be raised. 
""" if key not in self._settings: raise KeyError() if self._settings[key].default is None: (filter_statement, values) = self._get_composed_filter(name=self.name, project=self.project, offset=2) query = "UPDATE " + self.table_name() + " SET settings=settings - $1" + " WHERE " + filter_statement values = [self._get_value(key)] + values await self._execute_query(query, *values) del self.settings[key] else: await self.set(key, self._settings[key].default) async def mark_for_deletion(self, connection: Optional[asyncpg.connection.Connection] = None) -> None: """Mark an environment as being in the process of deletion.""" await self.update_fields(is_marked_for_deletion=True, connection=connection) async def delete_cascade(self, connection: Optional[asyncpg.connection.Connection] = None) -> None: """ Completely remove this environment from the db """ async with self.get_connection(connection=connection) as con: await self.clear(connection=con) await self.delete(connection=con) async def clear(self, connection: Optional[asyncpg.connection.Connection] = None) -> None: """ Delete everything related to this environment from the db, except the entry in the Environment table. This method doesn't rely on the DELETE CASCADE functionality of PostgreSQL because it causes deadlocks. This is especially true for the tables resourceaction_resource, resource and resourceaction, because they have a high read/write load. As such, we perform the deletes on each table in a separate transaction. 
""" async with self.get_connection(connection=connection) as con: await Agent.delete_all(environment=self.id, connection=con) await AgentInstance.delete_all(tid=self.id, connection=con) await AgentProcess.delete_all(environment=self.id, connection=con) await Compile.delete_all(environment=self.id, connection=con) # Triggers cascading delete on report table await Parameter.delete_all(environment=self.id, connection=con) await Notification.delete_all(environment=self.id, connection=con) await Code.delete_all(environment=self.id, connection=con) await DiscoveredResource.delete_all(environment=self.id, connection=con) await EnvironmentMetricsGauge.delete_all(environment=self.id, connection=con) await EnvironmentMetricsTimer.delete_all(environment=self.id, connection=con) await DryRun.delete_all(environment=self.id, connection=con) await UnknownParameter.delete_all(environment=self.id, connection=con) await self._execute_query( "DELETE FROM public.resourceaction_resource WHERE environment=$1", self.id, connection=con ) await ResourceAction.delete_all(environment=self.id, connection=con) await Resource.delete_all(environment=self.id, connection=con) await ConfigurationModel.delete_all(environment=self.id, connection=con) await ResourcePersistentState.delete_all(environment=self.id, connection=con) async def get_next_version(self, connection: Optional[asyncpg.connection.Connection] = None) -> int: """ Reserves the next available version and returns it. Increments the last_version counter. """ record = await self._fetchrow( f""" UPDATE {self.table_name()} SET last_version = last_version + 1 WHERE id = $1 RETURNING last_version; """, self.id, connection=connection, ) version = cast(int, record[0]) self.last_version = version return version @classmethod def register_setting(cls, setting: Setting) -> None: """ Adds a new environment setting that was defined by an extension. 
:param setting: the setting that should be added to the existing settings """ if setting.name in cls._settings: raise KeyError() cls._settings[setting.name] = setting @classmethod async def get_list( cls: type[TBaseDocument], *, order_by_column: Optional[str] = None, order: Optional[str] = None, limit: Optional[int] = None, offset: Optional[int] = None, no_obj: Optional[bool] = None, lock: Optional[RowLockMode] = None, connection: Optional[asyncpg.connection.Connection] = None, details: bool = True, **query: object, ) -> list[TBaseDocument]: """ Get a list of documents matching the filter args. """ if details: return await super().get_list( order_by_column=order_by_column, order=order, limit=limit, offset=offset, no_obj=no_obj, lock=lock, connection=connection, **query, ) return await cls.get_list_without_details( order_by_column=order_by_column, order=order, limit=limit, offset=offset, no_obj=no_obj, lock=lock, connection=connection, **query, ) @classmethod async def get_list_without_details( cls: type[TBaseDocument], *, order_by_column: Optional[str] = None, order: Optional[str] = None, limit: Optional[int] = None, offset: Optional[int] = None, no_obj: Optional[bool] = None, lock: Optional[RowLockMode] = None, connection: Optional[asyncpg.connection.Connection] = None, **query: object, ) -> list[TBaseDocument]: """ Get a list of environments matching the filter args. Don't return the description and icon columns. 
""" columns = [column_name for column_name in cls.get_valid_field_names() if column_name not in {"description", "icon"}] return await super().get_list_with_columns( order_by_column=order_by_column, order=order, limit=limit, offset=offset, no_obj=no_obj, lock=lock, connection=connection, columns=columns, **query, ) @classmethod async def get_by_id( cls: type[TBaseDocument], doc_id: uuid.UUID, connection: Optional[asyncpg.connection.Connection] = None, details: bool = True, ) -> Optional[TBaseDocument]: """ Get a specific environment based on its ID :return: An instance of this class with its fields filled from the database. """ result = await cls.get_list(id=doc_id, connection=connection, details=details) if len(result) > 0: return result[0] return None async def acquire_release_version_lock(self, *, shared: bool = False, connection: asyncpg.Connection) -> None: """ Acquires a transaction-level advisory lock for concurrency control between release_version and calls that need the latest version. This lock should also be held when updating any resource state in any other way than the normal agent deploy path Up to now, this means - setting resource state after increment calculation on release - propagation of resource state from a stale deploy to the latest version - setting resource state after increment calculation on agent pull :param env: The environment to acquire the lock for. :param shared: If true, doesn't conflict with other shared locks, only with non-shared ones. :param connection: The connection hosting the transaction for which to acquire a lock. """ await self._xact_lock(const.PG_ADVISORY_KEY_RELEASE_VERSION, self.id, shared=shared, connection=connection) async def put_version_lock(self, *, shared: bool = False, connection: asyncpg.Connection) -> None: """ Acquires a transaction-level advisory lock for concurrency control between put_version and put_partial. :param env: The environment to acquire the lock for. 
:param shared: If true, doesn't conflict with other shared locks, only with non-shared ones. :param connection: The connection hosting the transaction for which to acquire a lock. """ await self._xact_lock(const.PG_ADVISORY_KEY_PUT_VERSION, self.id, shared=shared, connection=connection)
class Parameter(BaseDocument):
    """
    A parameter that can be used in the configuration model

    :param name: The name of the parameter
    :param value: The value of the parameter
    :param environment: The environment this parameter belongs to
    :param source: The source of the parameter
    :param resource_id: An optional resource id
    :param updated: When was the parameter updated last
    :param metadata: Optional metadata attached to the parameter
    :param expires: Boolean denoting whether this parameter expires.

    :todo Add history
    """

    __primary_key__ = ("id", "name", "environment")

    id: uuid.UUID
    name: str
    value: str = ""
    environment: uuid.UUID
    source: str
    resource_id: m.ResourceIdStr = ""
    updated: Optional[datetime.datetime] = None
    metadata: Optional[JsonType] = None
    expires: bool

    @classmethod
    async def get_updated_before_active_env(
        cls,
        updated_before: datetime.datetime,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> list["Parameter"]:
        """
        Retrieve the list of parameters that were updated before a specified datetime for environments that are not halted

        Only parameters that expire are returned. Parameters attached to a resource (facts) are only returned when that
        resource is part of the latest released version of its environment.
        """
        query = f"""
            SELECT p.*
            FROM {cls.table_name()} AS p
            INNER JOIN {Environment.table_name()} AS e ON p.environment=e.id
            WHERE NOT e.halted
                AND p.updated < $1
                AND p.expires
                AND (
                    -- If it's a fact, it needs to belong to the latest released version.
                    p.resource_id IS NULL
                    OR p.resource_id = ''
                    OR EXISTS(
                        SELECT 1
                        FROM {Resource.table_name()} AS r
                        WHERE r.environment=p.environment
                            AND r.model=(
                                SELECT max(c.version)
                                FROM {ConfigurationModel.table_name()} AS c
                                WHERE c.environment=p.environment AND c.released
                            )
                            AND r.resource_id=p.resource_id
                    )
                );
        """
        values = [cls._get_value(updated_before)]
        return await cls.select_query(query, values, connection=connection)

    @classmethod
    async def list_parameters(cls, env_id: uuid.UUID, **metadata_constraints: str) -> list["Parameter"]:
        """
        List the parameters of the given environment, ordered by name.

        :param env_id: The environment to list the parameters for.
        :param metadata_constraints: Every key-value pair must be contained (jsonb @>) in the metadata of a returned
                                     parameter.
        """
        query = "SELECT * FROM " + cls.table_name() + " WHERE environment=$1"
        values = [cls._get_value(env_id)]
        for key, value in metadata_constraints.items():
            query_param_index = len(values) + 1
            query += " AND metadata @> $" + str(query_param_index) + "::jsonb"
            dict_value = {key: value}
            values.append(cls._get_value(dict_value))
        query += " ORDER BY name"
        result = await cls.select_query(query, values)
        return result

    def as_fact(self) -> m.Fact:
        """Convert this parameter to a Fact DTO. Only valid for parameters whose source is "fact"."""
        assert self.source == "fact"
        return m.Fact(
            id=self.id,
            name=self.name,
            value=self.value,
            environment=self.environment,
            resource_id=self.resource_id,
            source=self.source,
            updated=self.updated,
            metadata=self.metadata,
            expires=self.expires,
        )

    def as_param(self) -> m.Parameter:
        """Convert this parameter to a Parameter DTO."""
        return m.Parameter(
            id=self.id,
            name=self.name,
            value=self.value,
            environment=self.environment,
            source=self.source,
            updated=self.updated,
            metadata=self.metadata,
        )


class UnknownParameter(BaseDocument):
    """
    A parameter that the compiler indicated that was unknown. This parameter causes the configuration model to be
    incomplete for a specific environment.

    :param name:
    :param resource_id:
    :param source:
    :param environment:
    :param version: The version id of the configuration model on which this parameter was reported
    """

    __primary_key__ = ("id",)

    id: uuid.UUID
    name: str
    environment: uuid.UUID
    source: str
    resource_id: m.ResourceIdStr = ""
    version: int
    metadata: Optional[dict[str, object]]
    resolved: bool = False

    def copy(self, new_version: int) -> "UnknownParameter":
        """
        Create a new UnknownParameter using this object as a template. The returned object will have the id field unset and
        the version field set the new_version.
        """
        return UnknownParameter(
            name=self.name,
            environment=self.environment,
            source=self.source,
            resource_id=self.resource_id,
            version=new_version,
            metadata=self.metadata,
            resolved=self.resolved,
        )

    @classmethod
    async def get_unknowns_in_latest_released_model_versions(
        cls, connection: asyncpg.Connection
    ) -> Sequence["UnknownParameter"]:
        """
        Returns all the unresolved unknowns in the latest released model version of each non-halted environment.
        """
        query = f"""
            SELECT u.*
            FROM {cls.table_name()} AS u
            INNER JOIN {Environment.table_name()} AS e ON u.environment=e.id
            WHERE NOT e.halted
                AND u.version=(
                    SELECT max(c.version)
                    FROM {ConfigurationModel.table_name()} AS c
                    WHERE c.environment=e.id AND c.released
                )
                AND NOT u.resolved;
        """
        return await cls.select_query(query, values=[], connection=connection)

    @classmethod
    async def get_unknowns_to_copy_in_partial_compile(
        cls,
        environment: uuid.UUID,
        source_version: int,
        updated_resource_sets: abc.Set[str],
        deleted_resource_sets: abc.Set[str],
        rids_in_partial_compile: abc.Set[ResourceIdStr],
        *,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> list["UnknownParameter"]:
        """
        Returns a subset of the unresolved unknowns in source_version of environment. It returns the unknowns that:
        * Are not associated with a resource
        * Are associated with a resource that:
            - don't belong to the resource set updated_resource_sets and deleted_resource_sets
            - and, don't have a resource_id in rids_in_partial_compile (An unknown might belong to a shared resource
              that is not exported by the partial compile)
        """
        query = f"""
            SELECT u.*
            FROM {cls.table_name()} AS u
            LEFT JOIN {Resource.table_name()} AS r
                ON u.environment=r.environment AND u.version=r.model AND u.resource_id=r.resource_id
            WHERE u.environment=$1
                AND u.version=$2
                AND u.resolved IS FALSE
                AND (r.resource_id IS NULL OR NOT r.resource_id=ANY($4))
                AND (r.resource_set IS NULL OR NOT r.resource_set=ANY($3))
        """
        async with cls.get_connection(connection) as con:
            result = await con.fetch(
                query,
                environment,
                source_version,
                list(updated_resource_sets | deleted_resource_sets),
                list(rids_in_partial_compile),
            )
            return [cls(from_postgres=True, **uk) for uk in result]


class AgentProcess(BaseDocument):
    """
    A process in the infrastructure that has (had) a session as an agent.

    :param hostname: The hostname of the device.
    :param environment: To what environment is this process bound
    :param last_seen: When did the server receive data from the node for the last time.
    """

    __primary_key__ = ("sid",)

    sid: uuid.UUID
    hostname: str
    environment: uuid.UUID
    first_seen: Optional[datetime.datetime] = None
    last_seen: Optional[datetime.datetime] = None
    expired: Optional[datetime.datetime] = None

    @classmethod
    async def get_live(cls, environment: Optional[uuid.UUID] = None) -> list["AgentProcess"]:
        """
        Return the non-expired agent processes, ordered by last_seen (NULLs last), optionally limited to one environment.
        """
        if environment is not None:
            result = await cls.get_list(
                limit=DBLIMIT, environment=environment, expired=None, order_by_column="last_seen", order="ASC NULLS LAST"
            )
        else:
            result = await cls.get_list(limit=DBLIMIT, expired=None, order_by_column="last_seen", order="ASC NULLS LAST")
        return result

    @classmethod
    async def get_by_sid(
        cls, sid: uuid.UUID, connection: Optional[asyncpg.connection.Connection] = None
    ) -> Optional["AgentProcess"]:
        """
        Return the non-expired agent process with the given session id, or None if there is none.

        When several matching records exist, an error is logged and the first one is returned.
        """
        objects = await cls.get_list(limit=DBLIMIT, connection=connection, expired=None, sid=sid)
        if not objects:
            return None
        if len(objects) > 1:
            # Not inside an exception handler: LOGGER.exception would attach a meaningless "NoneType: None" traceback.
            LOGGER.error("Multiple objects with the same unique id found!")
        return objects[0]

    @classmethod
    async def seen(
        cls,
        env: uuid.UUID,
        nodename: str,
        sid: uuid.UUID,
        now: datetime.datetime,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Update the last_seen parameter of the process and mark as not expired.

        A new process record is inserted when no record with the given sid exists yet.
        """
        proc = await cls.get_one(connection=connection, sid=sid)
        if proc is None:
            proc = cls(hostname=nodename, environment=env, first_seen=now, last_seen=now, sid=sid)
            await proc.insert(connection=connection)
        else:
            await proc.update_fields(connection=connection, last_seen=now, expired=None)

    @classmethod
    async def update_last_seen(
        cls, sid: uuid.UUID, last_seen: datetime.datetime, connection: Optional[asyncpg.connection.Connection] = None
    ) -> None:
        """Set last_seen on the non-expired process with the given session id, if it exists."""
        aps = await cls.get_by_sid(sid=sid, connection=connection)
        if aps is not None:
            await aps.update_fields(connection=connection, last_seen=last_seen)

    @classmethod
    async def expire_process(
        cls, sid: uuid.UUID, now: datetime.datetime, connection: Optional[asyncpg.connection.Connection] = None
    ) -> None:
        """Mark the non-expired process with the given session id as expired at `now`, if it exists."""
        aps = await cls.get_by_sid(sid=sid, connection=connection)
        if aps is not None:
            await aps.update_fields(connection=connection, expired=now)

    @classmethod
    async def expire_all(cls, now: datetime.datetime, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """Mark every process that is not expired yet as expired at `now`."""
        query = f"""
            UPDATE {cls.table_name()}
            SET expired=$1
            WHERE expired IS NULL
        """
        await cls._execute_query(query, cls._get_value(now), connection=connection)

    @classmethod
    async def cleanup(
        cls, nr_expired_records_to_keep: int, connection: Optional[asyncpg.connection.Connection] = None
    ) -> None:
        """
        Delete expired agent processes of non-halted environments.

        For each (environment, hostname) pair the `nr_expired_records_to_keep` most recently expired processes are
        retained. Processes that still have a non-expired agent instance are never deleted.

        :param nr_expired_records_to_keep: The number of expired records to keep per environment and hostname.
        :param connection: An optional connection to use for the database interaction.
        """
        query = f"""
            WITH halted_env AS (
                SELECT id FROM environment WHERE halted = true
            )
            DELETE FROM {cls.table_name()} AS a1
            WHERE a1.expired IS NOT NULL
                AND a1.environment NOT IN (SELECT id FROM halted_env)
                AND (
                    -- Take nr_expired_records_to_keep into account
                    SELECT count(*)
                    FROM {cls.table_name()} a2
                    WHERE a1.environment=a2.environment
                        AND a1.hostname=a2.hostname
                        AND a2.expired IS NOT NULL
                        AND a2.expired > a1.expired
                ) >= $1
                AND
                -- Agent process only has expired agent instances
                NOT EXISTS(
                    SELECT 1
                    FROM {cls.table_name()} AS agentprocess
                    INNER JOIN {AgentInstance.table_name()} AS agentinstance
                        ON agentinstance.process = agentprocess.sid
                    WHERE agentprocess.sid = a1.sid AND agentinstance.expired IS NULL
                );
        """
        await cls._execute_query(query, cls._get_value(nr_expired_records_to_keep), connection=connection)

    def to_dict(self) -> JsonType:
        result = super().to_dict()
        # Ensure backward compatibility API
        result["id"] = result["sid"]
        return result

    def to_dto(self) -> m.AgentProcess:
        return m.AgentProcess(
            sid=self.sid,
            hostname=self.hostname,
            environment=self.environment,
            first_seen=self.first_seen,
            last_seen=self.last_seen,
            expired=self.expired,
        )


TAgentInstance = TypeVar("TAgentInstance", bound="AgentInstance")


class AgentInstance(BaseDocument):
    """
    A physical server/node in the infrastructure that reports to the management server.

    :param hostname: The hostname of the device.
    :param last_seen: When did the server receive data from the node for the last time.
    """

    __primary_key__ = ("id",)

    # TODO: add env to speed up cleanup
    id: uuid.UUID
    process: uuid.UUID
    name: str
    expired: Optional[datetime.datetime] = None
    tid: uuid.UUID

    @classmethod
    async def active_for(
        cls: type[TAgentInstance],
        tid: uuid.UUID,
        endpoint: str,
        process: Optional[uuid.UUID] = None,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> list[TAgentInstance]:
        """
        Return the non-expired instances of the given agent (endpoint) in environment `tid`, optionally limited to the
        given process.
        """
        if process is not None:
            objects = await cls.get_list(expired=None, tid=tid, name=endpoint, process=process, connection=connection)
        else:
            objects = await cls.get_list(expired=None, tid=tid, name=endpoint, connection=connection)
        return objects

    @classmethod
    async def active(cls: type[TAgentInstance]) -> list[TAgentInstance]:
        """Return all non-expired agent instances."""
        objects = await cls.get_list(expired=None)
        return objects

    @classmethod
    async def log_instance_creation(
        cls: type[TAgentInstance],
        tid: uuid.UUID,
        process: uuid.UUID,
        endpoints: set[str],
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Create new agent instances for a given session.

        An instance that already exists (per the {table}_unique constraint) is un-expired instead of re-inserted.
        """
        if not endpoints:
            return
        async with cls.get_connection(connection) as con:
            await con.executemany(
                f"""
                INSERT INTO {cls.table_name()}
                (id, tid, process, name, expired)
                VALUES ($1, $2, $3, $4, null)
                ON CONFLICT ON CONSTRAINT {cls.table_name()}_unique DO UPDATE
                SET expired = null
                ;
                """,
                [tuple(map(cls._get_value, (cls._new_id(), tid, process, name))) for name in endpoints],
            )

    @classmethod
    async def log_instance_expiry(
        cls: type[TAgentInstance],
        sid: uuid.UUID,
        endpoints: set[str],
        now: datetime.datetime,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Expire specific instances for a given session id.
        """
        if not endpoints:
            return
        instances: list[TAgentInstance] = await cls.get_list(connection=connection, process=sid)
        for ai in instances:
            if ai.name in endpoints:
                await ai.update_fields(connection=connection, expired=now)

    @classmethod
    async def expire_all(cls, now: datetime.datetime, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """Mark every agent instance that is not expired yet as expired at `now`."""
        query = f"""
            UPDATE {cls.table_name()}
            SET expired=$1
            WHERE expired IS NULL
        """
        await cls._execute_query(query, cls._get_value(now), connection=connection)


class Agent(BaseDocument):
    """
    An inmanta agent

    :param environment: The environment this resource is defined in
    :param name: The name of this agent
    :param last_failover: Moment at which the primary was last changed
    :param paused: is this agent paused (if so, skip it)
    :param primary: what is the current active instance (if none, state is down)
    :param unpause_on_resume: whether this agent should be unpaused when resuming from environment-wide halt. Used to
        persist paused state when halting.
    """

    __primary_key__ = ("environment", "name")

    environment: uuid.UUID
    name: str
    last_failover: Optional[datetime.datetime] = None
    paused: bool = False
    id_primary: Optional[uuid.UUID] = None
    unpause_on_resume: Optional[bool] = None

    @property
    def primary(self) -> Optional[uuid.UUID]:
        return self.id_primary

    @classmethod
    def get_valid_field_names(cls) -> list[str]:
        # Allow the computed fields
        return super().get_valid_field_names() + ["process_name", "status"]

    @classmethod
    async def get_statuses(
        cls, env_id: uuid.UUID, agent_names: Set[str], *, connection: Optional[asyncpg.connection.Connection] = None
    ) -> dict[str, Optional[AgentStatus]]:
        """
        Return the status of each of the given agents, or None for an agent that doesn't exist in the environment.
        """
        result: dict[str, Optional[AgentStatus]] = {}
        for agent_name in agent_names:
            agent = await cls.get_one(environment=env_id, name=agent_name, connection=connection)
            if agent:
                result[agent_name] = agent.get_status()
            else:
                result[agent_name] = None
        return result

    def get_status(self) -> AgentStatus:
        """Paused takes precedence; otherwise the agent is up when it has a primary instance and down when it hasn't."""
        if self.paused:
            return AgentStatus.paused
        if self.primary is not None:
            return AgentStatus.up
        return AgentStatus.down

    def to_dict(self) -> JsonType:
        base = BaseDocument.to_dict(self)
        # The API exposes id_primary as "primary" and uses "" instead of null for missing values.
        if self.last_failover is None:
            base["last_failover"] = ""

        if self.primary is None:
            base["primary"] = ""
        else:
            base["primary"] = base["id_primary"]
            del base["id_primary"]

        base["state"] = self.get_status().value

        return base

    @classmethod
    def _convert_field_names_to_db_column_names(cls, field_dict: dict[str, object]) -> dict[str, object]:
        # "primary" is a read-only property over the id_primary column.
        if "primary" in field_dict:
            field_dict["id_primary"] = field_dict["primary"]
            del field_dict["primary"]
        return field_dict

    @classmethod
    async def get(
        cls,
        env: uuid.UUID,
        endpoint: str,
        connection: Optional[asyncpg.connection.Connection] = None,
        lock: Optional[RowLockMode] = None,
    ) -> Optional["Agent"]:
        """
        Return the agent with the given name in the given environment, or None if it doesn't exist.

        :param lock: An optional row lock to take on the agent record.
        """
        obj = await cls.get_one(environment=env, name=endpoint, connection=connection, lock=lock)
        return obj

    @classmethod
    async def persist_on_halt(cls, env: uuid.UUID, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        Persists paused state when halting all agents.

        Agents for which unpause_on_resume is already set keep their persisted value.
        """
        await cls._execute_query(
            f"UPDATE {cls.table_name()} SET unpause_on_resume=NOT paused WHERE environment=$1 AND unpause_on_resume IS NULL",
            cls._get_value(env),
            connection=connection,
        )

    @classmethod
    async def persist_on_resume(cls, env: uuid.UUID, connection: Optional[asyncpg.connection.Connection] = None) -> list[str]:
        """
        Restores default halted state. Returns a list of agents that should be unpaused.
        """

        async with cls.get_connection(connection) as con:
            async with con.transaction():
                unpause_on_resume = await cls._fetch_query(
                    # lock FOR NO KEY UPDATE to avoid deadlocks: next query in this transaction updates the row
                    f"SELECT name FROM {cls.table_name()} WHERE environment=$1 AND unpause_on_resume FOR NO KEY UPDATE",
                    cls._get_value(env),
                    connection=con,
                )
                await cls._execute_query(
                    f"UPDATE {cls.table_name()} SET unpause_on_resume=NULL WHERE environment=$1",
                    cls._get_value(env),
                    connection=con,
                )
                return sorted([r["name"] for r in unpause_on_resume])

    @classmethod
    async def pause(
        cls, env: uuid.UUID, endpoint: Optional[str], paused: bool, connection: Optional[asyncpg.connection.Connection] = None
    ) -> list[str]:
        """
        Pause a specific agent or all agents in an environment when endpoint is set to None.

        :return A list of agent names that have been paused/unpaused by this method.
        """
        if endpoint is None:
            query = f"UPDATE {cls.table_name()} SET paused=$1 WHERE environment=$2 RETURNING name"
            values = [cls._get_value(paused), cls._get_value(env)]
        else:
            query = f"UPDATE {cls.table_name()} SET paused=$1 WHERE environment=$2 AND name=$3 RETURNING name"
            values = [cls._get_value(paused), cls._get_value(env), cls._get_value(endpoint)]
        result = await cls._fetch_query(query, *values, connection=connection)
        return sorted([r["name"] for r in result])

    @classmethod
    async def set_unpause_on_resume(
        cls,
        env: uuid.UUID,
        endpoint: Optional[str],
        should_be_unpaused_on_resume: bool,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Set the unpause_on_resume field of a specific agent or all agents in an environment when endpoint is set to None.
        """
        if endpoint is None:
            query = f"UPDATE {cls.table_name()} SET unpause_on_resume=$1 WHERE environment=$2"
            values = [cls._get_value(should_be_unpaused_on_resume), cls._get_value(env)]
        else:
            query = f"UPDATE {cls.table_name()} SET unpause_on_resume=$1 WHERE environment=$2 AND name=$3"
            values = [cls._get_value(should_be_unpaused_on_resume), cls._get_value(env), cls._get_value(endpoint)]
        await cls._execute_query(query, *values, connection=connection)

    @classmethod
    async def update_primary(
        cls,
        env: uuid.UUID,
        endpoints_with_new_primary: Sequence[tuple[str, Optional[uuid.UUID]]],
        now: datetime.datetime,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Update the primary agent instance for agents present in the database.

        :param env: The environment of the agent
        :param endpoints_with_new_primary: Contains a tuple (agent-name, sid) for each agent that has got a new
                                           primary agent instance. The sid in the tuple is the session id of the new
                                           primary. If the session id is None, the Agent doesn't have a primary anymore.
        :param now: Timestamp of this failover
        """
        for endpoint, sid in endpoints_with_new_primary:
            # Lock mode is required because we will update in this transaction
            # Deadlocks with cleanup otherwise
            agent = await cls.get(env, endpoint, connection=connection, lock=RowLockMode.FOR_NO_KEY_UPDATE)
            if agent is None:
                continue

            if sid is None:
                new_primary: Optional[uuid.UUID] = None
            else:
                instances = await AgentInstance.active_for(tid=env, endpoint=agent.name, process=sid, connection=connection)
                # No active instance of this agent in the new session: the agent has no primary.
                new_primary = instances[0].id if instances else None
            await agent.update_fields(last_failover=now, id_primary=new_primary, connection=connection)

    @classmethod
    async def mark_all_as_non_primary(cls, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """Clear the primary instance of every agent."""
        query = f"""
            UPDATE {cls.table_name()}
            SET id_primary=NULL
            WHERE id_primary IS NOT NULL
        """
        await cls._execute_query(query, connection=connection)

    @classmethod
    async def clean_up(cls, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        Delete agents of non-halted environments that are not in the autostart agent map, have no primary instance and
        are not used by any resource.
        """
        query = """
            DELETE FROM public.agent AS a
            WHERE (environment, name) NOT IN (
                SELECT DISTINCT environment_id as environment, agent as name
                FROM (
                    -- agent is in the agent map
                    SELECT e.id as environment_id, map.key as agent
                    FROM public.environment e
                    CROSS JOIN LATERAL jsonb_each(e.settings->'autostart_agent_map') AS map(key, value)
                ) in_agent_map
            )
            -- have no primary ID set (that are down)
            AND id_primary IS NULL
            -- not used by any version
            AND NOT EXISTS (
                SELECT 1
                FROM public.resource AS re
                WHERE a.environment=re.environment AND a.name=re.agent
            )
            AND a.environment IN (
                SELECT id FROM public.environment WHERE NOT halted
            );
        """
        await cls._execute_query(query, connection=connection)
@stable_api
class Report(BaseDocument):
    """
    A report of a substep of compilation

    :param started: when the substep started
    :param completed: when it ended
    :param command: the command that was executed
    :param name: The name of this step
    :param errstream: what was reported on system err
    :param outstream: what was reported on system out
    :param returncode: the return code of the command, or None
    :param compile: the id of the compile this report belongs to
    """

    __primary_key__ = ("id",)

    id: uuid.UUID
    started: datetime.datetime
    completed: Optional[datetime.datetime]
    command: str
    name: str
    errstream: str = ""
    outstream: str = ""
    returncode: Optional[int]
    compile: uuid.UUID

    async def update_streams(
        self, out: str = "", err: str = "", connection: Optional[asyncpg.connection.Connection] = None
    ) -> None:
        """
        Append the given text to the outstream and errstream stored in the database for this report.

        Only the database row is updated: the outstream/errstream attributes of this object are left unchanged.

        :param out: Text to append to outstream.
        :param err: Text to append to errstream.
        :param connection: An optional connection to use for the database interaction.
        """
        if not out and not err:
            return
        await self._execute_query(
            f"UPDATE {self.table_name()} SET outstream = outstream || $1, errstream = errstream || $2 WHERE id = $3",
            self._get_value(out),
            self._get_value(err),
            self._get_value(self.id),
            connection=connection,
        )
[docs] @stable_api class Compile(BaseDocument): """ A run of the compiler :param environment: The environment this resource is defined in :param requested: Time the compile was requested :param started: Time the compile started :param completed: Time to compile was completed :param do_export: should this compile perform an export :param force_update: should this compile definitely update :param metadata: exporter metadata to be passed to the compiler :param requested_environment_variables: environment variables requested to be passed to the compiler :param mergeable_environment_variables: environment variables to be passed to the compiler. These env vars can be compacted over multiple compiles. If multiple values are compacted, they will be joined using spaces. :param used_environment_variables: environment variables passed to the compiler, None before the compile is started :param success: was the compile successful :param handled: were all registered handlers executed? :param version: version exported by this compile :param remote_id: id as given by the requestor, used by the requestor to distinguish between different requests :param compile_data: json data as exported by compiling with the --export-compile-data parameter :param substitute_compile_id: id of this compile's substitute compile, i.e. the compile request that is similar to this one that actually got compiled. :param partial: True if the compile only contains the entities/resources for the resource sets that should be updated :param removed_resource_sets: indicates the resource sets that should be removed from the model :param exporter_plugin: Specific exporter plugin to use :param notify_failed_compile: if true use the notification service to notify that a compile has failed. By default, notifications are enabled only for exporting compiles. 
:param failed_compile_message: Optional message to use when a notification for a failed compile is created :param soft_delete: Prevents deletion of resources in removed_resource_sets if they are being exported. """ __primary_key__ = ("id",) id: uuid.UUID remote_id: Optional[uuid.UUID] = None environment: uuid.UUID requested: Optional[datetime.datetime] = None started: Optional[datetime.datetime] = None completed: Optional[datetime.datetime] = None do_export: bool = False force_update: bool = False metadata: JsonType = {} requested_environment_variables: dict[str, str] = {} mergeable_environment_variables: dict[str, str] = {} used_environment_variables: Optional[dict[str, str]] = None success: Optional[bool] handled: bool = False version: Optional[int] = None # Compile queue might be collapsed if it contains similar compile requests. # In that case, substitute_compile_id will reference the actually compiled request. substitute_compile_id: Optional[uuid.UUID] = None compile_data: Optional[JsonType] = None partial: bool = False removed_resource_sets: list[str] = [] exporter_plugin: Optional[str] = None notify_failed_compile: Optional[bool] = None failed_compile_message: Optional[str] = None soft_delete: bool = False
[docs] @classmethod async def get_substitute_by_id(cls, compile_id: uuid.UUID, connection: Optional[Connection] = None) -> Optional["Compile"]: """ Get a compile's substitute compile if it exists, otherwise get the compile by id. :param compile_id: The id of the compile for which to get the substitute compile. :return: The compile object for compile c2 that is the substitute of compile c1 with the given id. If c1 does not have a substitute, returns c1 itself. """ async with Compile.get_connection(connection=connection) as con: result: Optional[Compile] = await cls.get_by_id(compile_id, connection=con) if result is None: return None if result.substitute_compile_id is None: return result return await cls.get_substitute_by_id(result.substitute_compile_id, connection=con)
@classmethod
# TODO: Use join
async def get_report(cls, compile_id: uuid.UUID) -> Optional[dict]:
    """
    Get the compile and the associated reports from the database

    The compile is first resolved with ``get_substitute_by_id``, so the returned dict describes the compile that
    actually ran on behalf of ``compile_id``.

    :param compile_id: The id of the requested compile.
    :return: The ``to_dict()`` form of the (substitute) compile, extended with a ``"reports"`` key that holds the
        ``to_dict()`` form of each of its reports, or None if no such compile exists.
    """
    result: Optional[Compile] = await cls.get_substitute_by_id(compile_id)
    if result is None:
        return None

    dict_model = result.to_dict()
    # Reports are fetched for the resolved compile's id, which may differ from compile_id
    reports = await Report.get_list(compile=result.id)
    dict_model["reports"] = [r.to_dict() for r in reports]

    return dict_model

@classmethod
async def get_last_run(cls, environment_id: uuid.UUID) -> Optional["Compile"]:
    """Get the last run for the given environment

    :param environment_id: The environment to look in.
    :return: The compile with the most recent ``completed`` timestamp, or None if no compile has completed yet.
    """
    results = await cls.select_query(
        f"SELECT * FROM {cls.table_name()} where environment=$1 AND completed IS NOT NULL ORDER BY completed DESC LIMIT 1",
        [cls._get_value(environment_id)],
    )
    if not results:
        return None
    return results[0]

@classmethod
async def get_next_run(
    cls, environment_id: uuid.UUID, *, connection: Optional[asyncpg.Connection] = None
) -> Optional["Compile"]:
    """Get the next compile in the queue for the given environment

    :param environment_id: The environment to look in.
    :param connection: An optional connection to reuse.
    :return: The compile without a ``completed`` timestamp that has the oldest ``requested`` timestamp (this may be a
        compile that already started), or None when there is none.
    """
    async with cls.get_connection(connection) as con:
        results = await cls.select_query(
            f"SELECT * FROM {cls.table_name()} WHERE environment=$1 AND completed IS NULL ORDER BY requested ASC LIMIT 1",
            [cls._get_value(environment_id)],
            connection=con,
        )
        if not results:
            return None
        return results[0]

@classmethod
async def get_next_run_all(cls, *, connection: Optional[asyncpg.Connection] = None) -> "Sequence[Compile]":
    """Get the next compile in the queue for each environment

    :param connection: An optional connection to reuse.
    :return: For every environment that has compiles without a ``completed`` timestamp, the one among them with the
        oldest ``requested`` timestamp.
    """
    async with cls.get_connection(connection) as con:
        # DISTINCT ON (environment) keeps the first row per environment in ORDER BY order, i.e. the oldest request
        results = await cls.select_query(
            f"SELECT DISTINCT ON (environment) * FROM {cls.table_name()} WHERE completed IS NULL ORDER BY environment, "
            f"requested ASC",
            [],
            connection=con,
        )
        return results

@classmethod
async def get_unhandled_compiles(cls) -> "Sequence[Compile]":
    """Get all compiles that have completed but for which listeners have not been notified yet.

    The result is ordered by ``requested`` timestamp, oldest first, across all environments.
    """
    results = await cls.select_query(
        f"SELECT * FROM {cls.table_name()} WHERE NOT handled and completed IS NOT NULL ORDER BY requested ASC", []
    )
    return results

@classmethod
async def get_next_compiles_for_environment(cls, environment_id: uuid.UUID) -> "Sequence[Compile]":
    """Get the queue of compiles that are scheduled in FIFO order.

    Only compiles of the given environment that are neither handled nor completed are returned, oldest request first.

    :param environment_id: The environment to look in.
    """
    results = await cls.select_query(
        f"SELECT * FROM {cls.table_name()} WHERE environment=$1 AND NOT handled and completed IS NULL "
        "ORDER BY requested ASC",
        [cls._get_value(environment_id)],
    )
    return results

@classmethod
async def get_total_length_of_all_compile_queues(cls, exclude_started_compiles: bool = True) -> int:
    """
    Return the total length of all the compile queues on the Inmanta server.

    :param exclude_started_compiles: True iff don't count compiles that started running, but are not finished yet.
    :return: The number of compiles, across all environments, without a ``completed`` timestamp (and, when
        ``exclude_started_compiles`` is True, also without a ``started`` timestamp).
    """
    query = f"SELECT count(*) FROM {cls.table_name()} WHERE completed IS NULL"
    if exclude_started_compiles:
        query += " AND started IS NULL"
    return await cls._fetch_int(query)

@classmethod
async def get_by_remote_id(
    cls, environment_id: uuid.UUID, remote_id: uuid.UUID, *, connection: Optional[asyncpg.Connection] = None
) -> "Sequence[Compile]":
    """
    Get all compiles of an environment whose ``remote_id`` equals the given id.

    :param environment_id: The environment to look in.
    :param remote_id: The remote id to match.
    :param connection: An optional connection to reuse.
    :return: All matching compiles; the query imposes no ordering.
    """
    results = await cls.select_query(
        f"SELECT * FROM {cls.table_name()} WHERE environment=$1 AND remote_id=$2",
        [cls._get_value(environment_id), cls._get_value(remote_id)],
        connection=connection,
    )
    return results

@classmethod
async def delete_older_than(
    cls, oldest_retained_date: datetime.datetime, connection: Optional[asyncpg.Connection] = None
) -> None:
    """
    Delete compiles that completed at or before the given date.

    Compiles of halted environments are left untouched. Compiles without a ``completed`` timestamp never match the
    ``completed <= $1`` condition and are therefore kept as well.

    :param oldest_retained_date: Compiles with ``completed <= oldest_retained_date`` are deleted.
    :param connection: An optional connection to reuse.
    """
    query = f"""
    WITH non_halted_envs AS (
      SELECT id FROM public.environment WHERE NOT halted
    )
    DELETE FROM {cls.table_name()}
    WHERE environment IN (
      SELECT id FROM non_halted_envs
    )
    AND completed <= $1::timestamp with time zone;
    """
    await cls._execute_query(query, oldest_retained_date, connection=connection)

@classmethod
async def get_compile_details(cls, environment: uuid.UUID, id: uuid.UUID) -> Optional[m.CompileDetails]:
    """Find all of the details of a compile, with reports from a substituted compile, if there was one

    :param environment: The environment the requested compile belongs to.
    :param id: The id of the requested compile.
    :return: The details of the requested compile itself. ``reports`` holds the reports of the requested compile and
        of every compile in its substitute chain, ordered by report start time. Returns None when the compile does not
        exist in the given environment.
    """
    # Recursively join the requested compile with the substituted compiles (if there was one), and the corresponding reports
    # - the non-recursive part selects the requested compile (filtered on environment and id) LEFT JOINed with its reports
    # - the recursive part adds the compile referenced by substitute_compile_id of any row found so far, again LEFT
    #   JOINed with its reports; UNION removes duplicate rows produced by compiles that have several reports
    query = f"""
    WITH RECURSIVE compiledetails AS (
        SELECT c.id, c.remote_id, c.environment, c.requested,
        c.started, c.completed, c.success, c.version, c.do_export,
        c.force_update, c.metadata, c.requested_environment_variables ,
        c.mergeable_environment_variables, c.used_environment_variables,
        c.compile_data, c.substitute_compile_id, c.partial,
        c.removed_resource_sets, c.exporter_plugin,
        c.notify_failed_compile, c.failed_compile_message,
        r.id as report_id, r.started report_started, r.completed report_completed,
        r.command, r.name, r.errstream, r.outstream, r.returncode
        FROM {cls.table_name()} c
        LEFT JOIN public.report r on c.id = r.compile
        WHERE c.environment = $1 AND c.id = $2

        UNION

        SELECT comp.id, comp.remote_id, comp.environment, comp.requested,
        comp.started, comp.completed, comp.success, comp.version, comp.do_export,
        comp.force_update, comp.metadata, comp.requested_environment_variables,
        comp.mergeable_environment_variables, comp.used_environment_variables,
        comp.compile_data, comp.substitute_compile_id, comp.partial,
        comp.removed_resource_sets, comp.exporter_plugin,
        comp.notify_failed_compile, comp.failed_compile_message,
        rep.id as report_id, rep.started as report_started, rep.completed as report_completed,
        rep.command, rep.name, rep.errstream, rep.outstream, rep.returncode
        FROM
        /* Lookup the compile with the id that matches the subsitute_compile_id of the current one */
        {cls.table_name()} comp
        INNER JOIN compiledetails cd ON cd.substitute_compile_id = comp.id
        LEFT JOIN public.report rep on comp.id = rep.compile
    )
    SELECT * FROM compiledetails ORDER BY report_started ASC;
    """

    values = [cls._get_value(environment), cls._get_value(id)]
    result = await cls.select_query(query, values, no_obj=True)
    result = cast(list[Record], result)
    # The result is a list of Compiles joined with Reports
    # This includes the Compile with the requested id,
    # as well as Compile(s) that have been used as a substitute for the requested Compile (if there are any)
    if not result:
        return None

    # The details, such as the requested timestamp, etc. should be returned from
    # the compile that matches the originally requested id
    records = list(filter(lambda r: r["id"] == id, result))
    if not records:
        return None
    requested_compile = records[0]

    # Reports should be included from the substituted compile (as well)
    reports = [
        m.CompileRunReport(
            id=report["report_id"],
            started=report["report_started"],
            completed=report["report_completed"],
            command=report["command"],
            name=report["name"],
            errstream=report["errstream"],
            outstream=report["outstream"],
            returncode=report["returncode"],
        )
        for report in result
        # A compile without reports still yields one row (LEFT JOIN) with NULL report columns: skip those rows
        if report.get("report_id")
    ]

    # JSON columns are decoded with json.loads; NULL/empty metadata and compile_data fall back to {} and None
    return m.CompileDetails(
        id=requested_compile["id"],
        remote_id=requested_compile["remote_id"],
        environment=requested_compile["environment"],
        requested=requested_compile["requested"],
        started=requested_compile["started"],
        completed=requested_compile["completed"],
        success=requested_compile["success"],
        version=requested_compile["version"],
        do_export=requested_compile["do_export"],
        force_update=requested_compile["force_update"],
        metadata=json.loads(requested_compile["metadata"]) if requested_compile["metadata"] else {},
        environment_variables=(
            json.loads(requested_compile["used_environment_variables"])
            if requested_compile["used_environment_variables"] is not None
            else {}
        ),
        requested_environment_variables=(json.loads(requested_compile["requested_environment_variables"])),
        mergeable_environment_variables=(json.loads(requested_compile["mergeable_environment_variables"])),
        partial=requested_compile["partial"],
        removed_resource_sets=requested_compile["removed_resource_sets"],
        exporter_plugin=requested_compile["exporter_plugin"],
        notify_failed_compile=requested_compile["notify_failed_compile"],
        failed_compile_message=requested_compile["failed_compile_message"],
        compile_data=json.loads(requested_compile["compile_data"]) if requested_compile["compile_data"] else None,
        reports=reports,
    )
def to_dto(self) -> m.CompileRun:
    """
    Convert this compile into its ``CompileRun`` API representation.

    The DTO's ``environment_variables`` field is filled from ``used_environment_variables``. The stored
    ``compile_data`` dict, when present, is parsed into a ``CompileData`` model.

    :return: A ``CompileRun`` built from the attributes of this compile.
    """
    parsed_compile_data = m.CompileData(**self.compile_data) if self.compile_data is not None else None
    return m.CompileRun(
        id=self.id,
        remote_id=self.remote_id,
        environment=self.environment,
        requested=self.requested,
        started=self.started,
        do_export=self.do_export,
        force_update=self.force_update,
        metadata=self.metadata,
        environment_variables=self.used_environment_variables,
        requested_environment_variables=self.requested_environment_variables,
        mergeable_environment_variables=self.mergeable_environment_variables,
        compile_data=parsed_compile_data,
        partial=self.partial,
        removed_resource_sets=self.removed_resource_sets,
        exporter_plugin=self.exporter_plugin,
        notify_failed_compile=self.notify_failed_compile,
        failed_compile_message=self.failed_compile_message,
    )
def to_dict(self) -> JsonType:
    """
    Produce a plain dict for untyped endpoints.

    For backward compatibility, an extra ``environment_variables`` key is added to the output. This mangling is needed
    because these endpoints have no DTO. The key holds ``used_environment_variables`` when set. Otherwise it holds the
    requested environment variables overlaid with the mergeable ones.
    """
    if self.used_environment_variables is not None:
        environment_variables = self.used_environment_variables
    else:
        # A mergeable variable wins over a requested variable with the same name
        environment_variables = {**self.requested_environment_variables, **self.mergeable_environment_variables}
    out = super().to_dict()
    out["environment_variables"] = environment_variables
    return out
class LogLine(DataDocument):
    """
    LogLine data document.

    An instance of this class only has one attribute: _data.
    This unique attribute is a dict, with the following keys:
        - msg: the (already rendered) message to write to logs (value type: str)
        - args: the args that can be passed to the logger (value type: list)
        - level: the name of the log level of the message (value type: str, example: "CRITICAL")
        - kwargs: the keyword args that were used to render the message (value type: dict)
        - timestamp: the time at which the LogLine was created (value type: datetime.datetime)
    """

    @property
    def msg(self) -> str:
        """The log message."""
        return self._data["msg"]

    @property
    def args(self) -> list:
        """The positional arguments handed to the logger together with ``msg``."""
        return self._data["args"]

    @property
    def log_level(self) -> LogLevel:
        """The stored level name, looked up as a ``LogLevel`` member."""
        return LogLevel[self._data["level"]]

    def write_to_logger(self, logger: logging.Logger) -> None:
        """Emit this log line on the given Python logger."""
        logger.log(self.log_level.to_int, self.msg, *self.args)

    @classmethod
    def log(
        cls,
        level: Union[int, const.LogLevel],
        msg: str,
        timestamp: Optional[datetime.datetime] = None,
        **kwargs: object,
    ) -> "LogLine":
        """
        Create a LogLine whose message is ``msg`` rendered with %-style named placeholders taken from ``kwargs``.

        :param level: The log level; it is passed through ``LogLevel(...)`` and stored by name.
        :param msg: The message template, e.g. ``"deployed %(resource)s"``.
        :param timestamp: When the line was created; defaults to the current local time as an aware datetime.
        :param kwargs: The values substituted into ``msg``; they are stored on the LogLine as well.
        """
        created_at = timestamp if timestamp is not None else datetime.datetime.now().astimezone()
        rendered = msg % kwargs
        return cls(level=LogLevel(level).name, msg=rendered, args=[], kwargs=kwargs, timestamp=created_at)
@stable_api
class ResourceAction(BaseDocument):
    """
    Log related to actions performed on a specific resource version by Inmanta.

    :param environment: The environment this action belongs to.
    :param version: The version of the configuration model this action belongs to.
    :param resource_version_ids: The resource version ids of the resources this action relates to.
    :param action_id: This id distinguishes the actions from each other. Action ids have to be unique per environment.
    :param action: The action performed on the resource
    :param started: When did the action start
    :param finished: When did the action finish
    :param messages: The log messages associated with this action
    :param status: The status of the resource when this action was finished
    :param changes: A dict with key the resource id and value a dict of fields -> value. Value is a dict that can
                    contain old and current keys and the associated values. An empty dict indicates that the field
                    was changed but no data was provided by the agent.
    :param change: The change result of an action
    """

    __primary_key__ = ("action_id",)

    environment: uuid.UUID
    version: int
    resource_version_ids: list[m.ResourceVersionIdStr]

    action_id: uuid.UUID
    action: const.ResourceAction

    started: datetime.datetime
    finished: Optional[datetime.datetime] = None

    messages: Optional[list[dict[str, object]]] = None
    status: Optional[const.ResourceState] = None
    changes: Optional[dict[m.ResourceIdStr, dict[str, object]]] = None
    change: Optional[const.Change] = None

    def __init__(self, from_postgres: bool = False, **kwargs: object) -> None:
        """
        :param from_postgres: True when the kwargs come straight from a database record. In that case every message is
            a JSON string that is decoded here, and its ``timestamp`` (if any) is parsed into an aware datetime.
        """
        super().__init__(from_postgres, **kwargs)
        # Pending field updates, collected by set_field/add_logs/add_changes and written by save()
        self._updates = {}

        # rewrite some data
        if self.changes == {}:
            self.changes = None

        # load message json correctly
        if from_postgres and self.messages:
            # Building a TypeAdapter is comparatively expensive and it does not depend on the message: build it once
            timestamp_adapter = pydantic.TypeAdapter(datetime.datetime)
            new_messages = []
            for raw_message in self.messages:
                message = json.loads(raw_message)
                if "timestamp" in message:
                    # use pydantic instead of datetime.strptime because strptime has trouble parsing isoformat timezone offset
                    timestamp = timestamp_adapter.validate_python(message["timestamp"])
                    if timestamp.tzinfo is None:
                        raise Exception("Found naive timestamp in the database, this should not be possible")
                    message["timestamp"] = timestamp
                new_messages.append(message)
            self.messages = new_messages

    @classmethod
    async def get_by_id(cls, doc_id: uuid.UUID, connection: Optional[asyncpg.connection.Connection] = None) -> "ResourceAction":
        """Get the resource action with the given action id (get_one yields None when it does not exist)."""
        return await cls.get_one(action_id=doc_id, connection=connection)

    @classmethod
    async def get_log(
        cls,
        environment: uuid.UUID,
        resource_version_id: m.ResourceVersionIdStr,
        action: Optional[str] = None,
        limit: int = 0,
        connection: Optional[Connection] = None,
    ) -> list["ResourceAction"]:
        """
        Get the actions logged for one specific resource version, most recently started first.

        :param environment: The environment the resource belongs to.
        :param resource_version_id: The resource version to get the actions for.
        :param action: Only return actions of this type, if set.
        :param limit: The maximum number of actions to return; 0 (or None) means no limit.
        :param connection: An optional connection to reuse.
        """
        query = """
        SELECT ra.* FROM public.resourceaction as ra
                    INNER JOIN public.resourceaction_resource as jt
                        ON ra.action_id = jt.resource_action_id
                    WHERE jt.environment=$1 AND jt.resource_id = $2 AND jt.resource_version = $3
        """
        parsed_id = resources.Id.parse_id(resource_version_id)
        values = [cls._get_value(environment), parsed_id.resource_str(), parsed_id.version]
        if action is not None:
            query += " AND action=$4"
            values.append(cls._get_value(action))
        query += " ORDER BY started DESC"
        if limit is not None and limit > 0:
            query += " LIMIT $%d" % (len(values) + 1)
            values.append(cls._get_value(limit))
        async with cls.get_connection(connection) as con:
            # asyncpg cursors can only be used inside a transaction
            async with con.transaction():
                return [cls(**dict(record), from_postgres=True) async for record in con.cursor(query, *values)]

    @classmethod
    async def get_logs_for_version(
        cls,
        environment: uuid.UUID,
        version: int,
        action: Optional[str] = None,
        limit: int = 0,
        connection: Optional[Connection] = None,
    ) -> list["ResourceAction"]:
        """
        Get the actions logged for a model version, most recently started first.

        :param environment: The environment the version belongs to.
        :param version: The model version.
        :param action: Only return actions of this type, if set.
        :param limit: The maximum number of actions to return; 0 (or None) means no limit.
        :param connection: An optional connection to reuse.
        """
        query = f"""SELECT *
                    FROM {cls.table_name()}
                    WHERE environment=$1 AND version=$2
                 """
        values = [cls._get_value(environment), cls._get_value(version)]
        if action is not None:
            query += " AND action=$3"
            values.append(cls._get_value(action))
        query += " ORDER BY started DESC"
        if limit is not None and limit > 0:
            query += " LIMIT $%d" % (len(values) + 1)
            values.append(cls._get_value(limit))
        async with cls.get_connection(connection=connection) as con:
            # asyncpg cursors can only be used inside a transaction
            async with con.transaction():
                return [cls(**dict(record), from_postgres=True) async for record in con.cursor(query, *values)]

    @classmethod
    def get_valid_field_names(cls) -> list[str]:
        """Extend the base class's valid field names with ``timestamp``, ``level`` and ``msg``."""
        return super().get_valid_field_names() + ["timestamp", "level", "msg"]

    @classmethod
    async def get(cls, action_id: uuid.UUID, connection: Optional[asyncpg.connection.Connection] = None) -> "ResourceAction":
        """Get the resource action with the given action id (get_one yields None when it does not exist)."""
        return await cls.get_one(action_id=action_id, connection=connection)

    async def insert(self, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        Insert this action together with its rows in the resourceaction_resource join table, in one transaction.

        Raises AssertionError when ``resource_version_ids`` is empty.
        """
        async with self.get_connection(connection) as con:
            async with con.transaction():
                await super().insert(con)
                # Also do the join table in the same transaction
                assert self.resource_version_ids
                parsed_rv = [resources.Id.parse_resource_version_id(rvid) for rvid in self.resource_version_ids]
                # No additional checking of field validity is done here, because the insert above validates all fields
                await con.execute(
                    "INSERT INTO public.resourceaction_resource "
                    "(resource_id, resource_version, environment, resource_action_id) "
                    "SELECT unnest($1::text[]), unnest($2::int[]), $3, $4",
                    [rvid.resource_str() for rvid in parsed_rv],
                    [rvid.get_version() for rvid in parsed_rv],
                    self.environment,
                    self.action_id,
                )

    def set_field(self, name: str, value: object) -> None:
        """Queue an update of a single field; it is written by ``save``."""
        self._updates[name] = value

    def add_logs(self, messages: Optional[list[dict[str, object]]]) -> None:
        """
        Queue log messages for the pending ``messages`` update written by ``save``.

        Successive calls accumulate into one list. An empty list or None is ignored.
        """
        if not messages:
            return
        if "messages" not in self._updates:
            self._updates["messages"] = []
        self._updates["messages"] += messages

    def add_changes(self, changes: dict[m.ResourceIdStr, dict[str, object]]) -> None:
        """
        Queue per-resource field changes for the pending ``changes`` update written by ``save``.

        The changes are merged per resource and per field; a later change for the same field overwrites an earlier one.
        """
        for resource, values in changes.items():
            for field, change in values.items():
                self._updates.setdefault("changes", {}).setdefault(resource, {})[field] = change

    async def set_and_save(
        self,
        messages: list[dict[str, object]],
        changes: dict[m.ResourceIdStr, dict[str, object]],
        status: Optional[const.ResourceState],
        change: Optional[const.Change],
        finished: Optional[datetime.datetime],
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Queue the given updates and save them at once.

        Empty ``messages``/``changes`` and None values for ``status``, ``change`` and ``finished`` are not queued.
        """
        if len(messages) > 0:
            self.add_logs(messages)

        if len(changes) > 0:
            self.add_changes(changes)

        if status is not None:
            self.set_field("status", status)

        if change is not None:
            self.set_field("change", change)

        if finished is not None:
            self.set_field("finished", finished)

        await self.save(connection=connection)

    async def save(self, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        Save the changes

        Writes all queued updates with ``update_fields`` and clears the queue. Does nothing when no update is queued.
        """
        if len(self._updates) == 0:
            return

        assert (
            "resource_version_ids" not in self._updates
        ), "Updating the associated resource_version_ids of a ResourceAction is not currently supported"

        await self.update_fields(connection=connection, **self._updates)
        self._updates = {}

    @classmethod
    async def purge_logs(cls) -> None:
        """
        Delete the resource actions that started longer ago than the retention period of their environment.

        The retention period is the environment's ``resource_action_logs_retention`` setting (in days), falling back to
        the default of the RESOURCE_ACTION_LOGS_RETENTION setting. Halted environments are skipped.
        """
        default_retention_time = Environment._settings[RESOURCE_ACTION_LOGS_RETENTION].default
        # NOTE(review): `now() AT TIME ZONE 'UTC'` yields a timestamp without time zone; if `started` is a timestamptz
        # column postgres interprets that value in the session time zone. Confirm the session runs in UTC.
        query = f"""
        WITH non_halted_envs AS (
          SELECT id, (COALESCE((settings->>'resource_action_logs_retention')::int, $1)) AS retention_days
          FROM {Environment.table_name()}
          WHERE NOT halted
        )
        DELETE FROM {cls.table_name()}
        USING non_halted_envs
        WHERE environment = non_halted_envs.id
        AND started < now() AT TIME ZONE 'UTC' - make_interval(days => non_halted_envs.retention_days)
        """
        await cls._execute_query(query, default_retention_time)

    @classmethod
    async def query_resource_actions(
        cls,
        environment: uuid.UUID,
        resource_type: Optional[str] = None,
        agent: Optional[str] = None,
        attribute: Optional[str] = None,
        attribute_value: Optional[str] = None,
        resource_id_value: Optional[str] = None,
        log_severity: Optional[str] = None,
        limit: int = 0,
        action_id: Optional[uuid.UUID] = None,
        first_timestamp: Optional[datetime.datetime] = None,
        last_timestamp: Optional[datetime.datetime] = None,
        action: Optional[const.ResourceAction] = None,
        resource_id: Optional[ResourceIdStr] = None,
        exclude_changes: Optional[list[const.Change]] = None,
    ) -> list["ResourceAction"]:
        """
        Query the resource actions of an environment.

        Every filter is optional and the filters are combined with AND. Actions are matched through the resource
        versions they relate to (resourceaction_resource join table).

        :param environment: The environment to query.
        :param resource_type: Only actions on resources of this type.
        :param agent: Only actions on resources of this agent.
        :param attribute: Used together with ``attribute_value``: only actions on resources whose attribute with this
            name starts with ``attribute_value``.
        :param attribute_value: The prefix to match the ``attribute`` against; ignored unless ``attribute`` is set too.
        :param resource_id_value: Only actions on resources with this resource id value.
        :param log_severity: Only actions with at least one message of this level (the value is upper-cased).
        :param limit: The maximum number of actions to return; 0 (or None) means no limit.
        :param action_id: Paging tie-breaker: when given with ``first_timestamp``/``last_timestamp``, the
            (started, action_id) pair is compared instead of the timestamp alone.
        :param first_timestamp: Only actions started after this point. When set, the ``limit`` actions closest after it
            are selected.
        :param last_timestamp: Only actions started before this point.
        :param action: Only actions of this type.
        :param resource_id: Only actions on the resource with this id.
        :param exclude_changes: Exclude actions whose change is one of these. Actions without a change (NULL) are
            excluded as well when this is set, because ``NULL NOT IN (...)`` is not true.
        :return: The matching actions, ordered by (started, action_id) descending.
        """
        query = """SELECT DISTINCT ra.*
                    FROM public.resource as r
                    INNER JOIN public.resourceaction_resource as jt
                        ON r.environment = jt.environment
                        AND r.resource_id = jt.resource_id
                        AND r.model = jt.resource_version
                    INNER JOIN public.resourceaction as ra
                        ON ra.action_id = jt.resource_action_id
                    WHERE r.environment=$1 AND ra.environment=$1"""
        values: list[object] = [cls._get_value(environment)]

        parameter_index = 2
        if resource_type:
            query += f" AND resource_type=${parameter_index}"
            values.append(cls._get_value(resource_type))
            parameter_index += 1
        if agent:
            query += f" AND agent=${parameter_index}"
            values.append(cls._get_value(agent))
            parameter_index += 1
        if attribute and attribute_value:
            # Prefix match on the attribute value with LIKE: escape the LIKE wildcards (% and _) and the escape
            # character (#) itself in the user-supplied value, then append the % wildcard
            escaped_value = attribute_value.replace("#", "##").replace("%", "#%").replace("_", "#_") + "%"
            query += f" AND attributes->>${parameter_index} LIKE ${parameter_index + 1} ESCAPE '#' "
            values.append(cls._get_value(attribute))
            values.append(cls._get_value(escaped_value))
            parameter_index += 2
        if resource_id_value:
            query += f" AND r.resource_id_value = ${parameter_index}::varchar"
            values.append(cls._get_value(resource_id_value))
            parameter_index += 1
        if resource_id:
            query += f" AND r.resource_id = ${parameter_index}::varchar"
            values.append(cls._get_value(resource_id))
            parameter_index += 1
        if log_severity:
            # <@ Is contained by: at least one message contains {"level": <severity>}
            query += f" AND ${parameter_index} <@ ANY(messages)"
            values.append(cls._get_value({"level": log_severity.upper()}))
            parameter_index += 1
        if action is not None:
            query += f" AND ra.action=${parameter_index}"
            values.append(cls._get_value(action))
            parameter_index += 1
        if first_timestamp and action_id:
            query += f" AND (started, action_id) > (${parameter_index}, ${parameter_index + 1})"
            values.append(cls._get_value(first_timestamp))
            values.append(cls._get_value(action_id))
            parameter_index += 2
        elif first_timestamp:
            query += f" AND started > ${parameter_index}"
            values.append(cls._get_value(first_timestamp))
            parameter_index += 1
        if last_timestamp and action_id:
            query += f" AND (started, action_id) < (${parameter_index}, ${parameter_index + 1})"
            values.append(cls._get_value(last_timestamp))
            values.append(cls._get_value(action_id))
            parameter_index += 2
        elif last_timestamp:
            query += f" AND started < ${parameter_index}"
            values.append(cls._get_value(last_timestamp))
            parameter_index += 1
        if exclude_changes:
            # Create a string with placeholders for each item in exclude_changes
            exclude_placeholders = ", ".join([f"${parameter_index + i}" for i in range(len(exclude_changes))])
            query += f" AND ra.change NOT IN ({exclude_placeholders})"
            values.extend([cls._get_value(change) for change in exclude_changes])
            parameter_index += len(exclude_changes)
        # When paging forward from first_timestamp, select the oldest matches first so LIMIT keeps the ones closest to it
        if first_timestamp:
            query += " ORDER BY started, action_id"
        else:
            query += " ORDER BY started DESC, action_id DESC"
        if limit is not None and limit > 0:
            query += " LIMIT $%d" % parameter_index
            values.append(cls._get_value(limit))
            parameter_index += 1
        if first_timestamp:
            # ...and present the selected page newest first, like the other branch
            query = f"""SELECT * FROM ({query}) AS matching_actions
                        ORDER BY matching_actions.started DESC, matching_actions.action_id DESC"""

        async with cls.get_connection() as con:
            # asyncpg cursors can only be used inside a transaction
            async with con.transaction():
                return [cls(**record, from_postgres=True) async for record in con.cursor(query, *values)]

    @classmethod
    async def get_resource_events(
        cls, env: Environment, resource_id: "resources.Id", exclude_change: Optional[const.Change] = None
    ) -> dict[ResourceIdStr, list["ResourceAction"]]:
        """
        Get all events that should be processed by this specific resource, for the current deployment

        This method searches across versions!

        This means:
        1. assure a deployment is ongoing
        2. get the time range between the start of this deployment and the last successful deploy
        3. get all resources required by this resource
        4. get all resource actions of type deploy emitted by the resource of step 3 in the time interval of step 2

        :param env: environment to consider
        :param resource_id: resource to consider, should be in deploying state
        :param exclude_change: in step 4, exclude all resource actions with this specific type of change
            (actions without a change, i.e. NULL, are excluded too, because ``NULL <> x`` is not true)
        :return: For every resource id in the ``requires`` attribute of the resource, its deploy actions (newest
            first); dependencies without matching actions map to an empty list.
        :raises NotFound: The resource does not exist in this version.
        :raises BadRequest: The resource is not currently deploying.
        """
        # This is bang on the critical path for the agent
        # Squeeze out as much performance from postgresql as we can

        resource_version_id_str = resource_id.resource_version_str()
        resource_id_str = resource_id.resource_str()

        # This variable is actually of type datetime.datetime
        # but mypy doesn't know as it comes from the DB
        # mypy also doesn't care, because it goes back into the DB
        last_deploy_start: Optional[object]

        async with cls.get_connection() as connection:
            # Step 1: Get the resource
            # also check we are currently deploying
            resource: Optional[Resource] = await Resource.get_one(
                environment=env.id, resource_id=resource_id_str, model=resource_id.version, connection=connection
            )

            if resource is None:
                raise NotFound(f"Resource with id {resource_version_id_str} was not found in environment {env.id}")

            resource_state: Optional[ResourcePersistentState] = await ResourcePersistentState.get_one(
                environment=env.id, resource_id=resource_id_str, connection=connection
            )
            assert resource_state is not None  # resource state must exist if resource exists

            if resource.status != const.ResourceState.deploying:
                raise BadRequest(
                    "Fetching resource events only makes sense when the resource is currently deploying. Current deploy state"
                    f" for resource {resource_version_id_str} is {resource.status}."
                )

            # Step 2:
            # find the interval between the current deploy (now) and the previous successful deploy
            last_deploy_start = resource_state.last_success

            # Step 3: get the relevant resource actions
            # Do it in one query for all dependencies

            # Construct the query; $1 and $2 are reserved for the environment and the dependency ids
            arg = ArgumentCollector(offset=2)

            # First make the filter
            extra_filter = ""
            if last_deploy_start:
                extra_filter += f" AND ra.started > {arg(last_deploy_start)}"
            if exclude_change:
                extra_filter += f" AND ra.change <> {arg(exclude_change.value)}"

            # then the query around it
            get_all_query = f"""
            SELECT jt.resource_id, ra.*
            FROM public.resourceaction_resource as jt
            INNER JOIN public.resourceaction as ra
                ON ra.action_id = jt.resource_action_id
            WHERE jt.environment=$1 AND ra.environment=$1 AND jt.resource_id=ANY($2::varchar[]) AND ra.action='deploy'
            {extra_filter}
            ORDER BY ra.started DESC;
            """
            # Convert resource version ids into resource ids
            ids = [resources.Id.parse_id(req).resource_str() for req in resource.attributes["requires"]]

            # Get the result
            result2 = await connection.fetch(get_all_query, env.id, ids, *arg.get_values())

            # Collect results per resource_id
            collector: dict[ResourceIdStr, list["ResourceAction"]] = {
                rid: [] for rid in ids
            }  # eagerly initialize, we expect one entry per dependency, even when empty
            for record in result2:
                fields = dict(record)
                # resource_id comes from the join table and is not a ResourceAction field
                del fields["resource_id"]
                collector[cast(ResourceIdStr, record[0])].append(ResourceAction(from_postgres=True, **fields))

            return collector

    def to_dto(self) -> m.ResourceAction:
        """Convert this resource action into its API representation."""
        return m.ResourceAction(
            environment=self.environment,
            version=self.version,
            resource_version_ids=self.resource_version_ids,
            action_id=self.action_id,
            action=self.action,
            started=self.started,
            finished=self.finished,
            messages=self.messages,
            status=self.status,
            changes=self.changes,
            change=self.change,
        )
class ResourcePersistentState(BaseDocument):
    """
    Deployment bookkeeping for a resource that is shared by all versions of that resource.

    Rows are identified by (environment, resource_id): the model version is not part of the key.
    """

    __primary_key__ = ("environment", "resource_id")

    @classmethod
    def table_name(cls) -> str:
        return "resource_persistent_state"

    environment: uuid.UUID

    # Identification of the resource
    resource_id: m.ResourceIdStr
    resource_type: str
    agent: str
    resource_id_value: str

    # Fields derived from the resource actions
    # End time of the last completed deploy of any kind, including marking a resource as deployed
    # for a known-good state during increment calculation.
    last_deploy: Optional[datetime.datetime] = None
    # Attribute hash that was used by the deploy recorded in last_deploy
    last_deployed_attribute_hash: Optional[str] = None
    # Model version that was used by the deploy recorded in last_deploy
    last_deployed_version: Optional[int] = None
    # Start time of the last actual deploy that finished without failure
    # (i.e. with status == ResourceState.deployed)
    last_success: Optional[datetime.datetime] = None
    # End time of the last deploy that produced an event, i.e. had an effective change
    # (change is not None and change != Change.nochange)
    last_produced_events: Optional[datetime.datetime] = None

    # Status
    last_non_deploying_status: const.NonDeployingResourceState = const.NonDeployingResourceState.available

    @classmethod
    async def trim(cls, environment: UUID, connection: Optional[Connection] = None) -> None:
        """
        Delete the rows of the given environment whose resource_id no longer occurs in any row of the resource table
        for that environment.

        :param environment: The environment to clean up.
        :param connection: Optional connection to run the query on.
        """
        delete_orphaned_states = f"""
            DELETE FROM {cls.table_name()} rps
            WHERE NOT EXISTS(
                SELECT r.resource_id
                FROM {Resource.table_name()} r
                WHERE r.resource_id = rps.resource_id and r.environment=$1
            ) and rps.environment=$1
        """
        await cls._execute_query(delete_orphaned_states, environment, connection=connection)
@stable_api
class Resource(BaseDocument):
    """
    A specific version of a resource. This entity contains the desired state of a resource.

    :param environment: The environment this resource version is defined in
    :param model: The configuration model (versioned) this resource state is associated with
    :param resource_id: The id of the resource, without version
    :param resource_type: The type of the resource
    :param resource_id_value: The attribute value from the resource id
    :param agent: The agent responsible for this resource
    :param attributes: The state of this version of the resource
    :param attribute_hash: hash of the attributes, excluding requires, provides and version,
                           used to determine if a resource describes the same state across versions
    :param status: The status of this version of the resource
    :param resource_set: The resource set this resource belongs to, None for the shared resource set
    :param provides: The resources that depend on this resource (see the comment on the field)
    """

    __primary_key__ = ("environment", "model", "resource_id")

    environment: uuid.UUID
    model: int

    # ID related
    resource_id: m.ResourceIdStr
    resource_type: m.ResourceType
    resource_id_value: str

    agent: str

    # State related
    attributes: dict[str, object] = {}
    attribute_hash: Optional[str]
    status: const.ResourceState = const.ResourceState.available
    resource_set: Optional[str] = None

    # internal field to handle cross agent dependencies
    # if this resource is updated, it must notify all RV's in this list
    # NOTE(review): typed as resource ids without version; __mangle_dict adds the version when serving them out
    provides: list[m.ResourceIdStr] = []

    # Methods for backward compatibility
    @property
    def resource_version_id(self):
        """The resource version id of this resource (resource_id combined with model)."""
        # This field was removed from the DB, this method keeps code compatibility
        return resources.Id.set_version_in_id(self.resource_id, self.model)

    @classmethod
    def __mangle_dict(cls, record: dict) -> None:
        """
        Transform the dict of attributes as it exists here/in the database to the backward compatible form.
        Operates in-place.

        Adds resource_version_id/id, recomputes resource_type, adds the version to the requires and provides
        lists and adds the version to the attributes.
        """
        version = record["model"]
        parsed_id = resources.Id.parse_id(record["resource_id"])
        parsed_id.set_version(version)
        record["resource_version_id"] = parsed_id.resource_version_str()
        record["id"] = record["resource_version_id"]
        record["resource_type"] = parsed_id.entity_type
        if "requires" in record["attributes"]:
            record["attributes"]["requires"] = [
                resources.Id.set_version_in_id(id, version) for id in record["attributes"]["requires"]
            ]
        # Due to a bug, the version field has always been present in the attributes dictionary.
        # This bug has been fixed in the database. For backwards compatibility reason we here make sure that the
        # version field is present in the attributes dictionary served out via the API.
        record["attributes"]["version"] = version
        record["provides"] = [resources.Id.set_version_in_id(id, version) for id in record["provides"]]

    @classmethod
    async def get_last_non_deploying_state_for_dependencies(
        cls, environment: uuid.UUID, resource_version_id: "resources.Id", connection: Optional[Connection] = None
    ) -> dict[m.ResourceVersionIdStr, ResourceState]:
        """
        Return the last state of each dependency of the given resource that was not 'deploying'.

        :param environment: The environment of the resource.
        :param resource_version_id: The id of the resource, must include a version.
        :param connection: Optional connection to run the query on.
        :return: A mapping from the resource version id of each dependency (in the version of the given resource)
                 to its last non-deploying state.
        :raises Exception: If resource_version_id does not contain a version.
        """
        if not resource_version_id.is_resource_version_id_obj():
            raise Exception("Argument resource_version_id is not a resource_version_id")
        version = resource_version_id.version
        # The jsonb '?' operator checks whether r1.resource_id is an element of the requires list
        query = """
            SELECT r1.resource_id, r1.last_non_deploying_status
            FROM resource_persistent_state AS r1
            WHERE r1.environment=$1
                  AND (
                      SELECT (r2.attributes->'requires')::jsonb
                      FROM resource AS r2
                      WHERE r2.environment=$1 AND r2.model=$2 AND r2.resource_id=$3
                  ) ? r1.resource_id
        """
        values = [
            cls._get_value(environment),
            cls._get_value(version),
            resource_version_id.resource_str(),
        ]
        result = await cls._fetch_query(query, *values, connection=connection)
        return {r["resource_id"] + ",v=" + str(version): const.ResourceState(r["last_non_deploying_status"]) for r in result}

    def make_hash(self) -> None:
        """
        Compute attribute_hash: an md5 hash over the resource_id and the attributes,
        excluding requires, provides and version.
        """
        character = json.dumps(
            {k: v for k, v in self.attributes.items() if k not in ["requires", "provides", "version"]},
            default=custom_json_encoder,
            sort_keys=True,  # sort the keys for stable hashes when using dicts, see #5306
        )
        # Named 'hasher' rather than 'm' to avoid shadowing the module alias 'm' (inmanta.data.model)
        hasher = hashlib.md5()
        hasher.update(self.resource_id.encode("utf-8"))
        hasher.update(character.encode("utf-8"))
        self.attribute_hash = hasher.hexdigest()

    @classmethod
    async def get_resources(
        cls,
        environment: uuid.UUID,
        resource_version_ids: list[m.ResourceVersionIdStr],
        lock: Optional[RowLockMode] = None,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> list["Resource"]:
        """
        Get all resources listed in resource_version_ids.
        Ids that cannot be parsed as resource version ids are silently ignored.

        :param lock: Optional row lock clause appended to the query.
        """
        if not resource_version_ids:
            return []
        query_lock: str = lock.value if lock is not None else ""

        def convert_or_ignore(rvid):
            """Method to retain backward compatibility, ignore bad ID's"""
            try:
                return resources.Id.parse_resource_version_id(rvid)
            except ValueError:
                return None

        parsed_rv = (convert_or_ignore(id) for id in resource_version_ids)
        effective_parsed_rv = [id for id in parsed_rv if id is not None]

        if not effective_parsed_rv:
            return []

        query = (
            f"SELECT r.* FROM {cls.table_name()} r"
            f" INNER JOIN unnest($2::resource_id_version_pair[]) requested(resource_id, model)"
            f" ON r.resource_id = requested.resource_id AND r.model = requested.model"
            f" WHERE environment=$1"
            f" {query_lock}"
        )
        out = await cls.select_query(
            query,
            [cls._get_value(environment), [(id.resource_str(), id.get_version()) for id in effective_parsed_rv]],
            connection=connection,
        )
        return out

    @classmethod
    async def get_status_for(
        cls,
        env: uuid.UUID,
        model_version: int,
        rids: list[ResourceIdStr],
    ) -> dict[ResourceIdStr, ResourceState]:
        """
        Return the status column of the given resources in the given model version.
        Resources that are not present in that version are absent from the result.
        """
        if not rids:
            return {}
        query = """
            SELECT r.resource_id, r.status
            FROM resource r
            WHERE r.environment=$1 AND r.model=$2 AND r.resource_id = ANY($3);
        """
        out = await cls.select_query(query, [env, model_version, rids], no_obj=True)
        return {ResourceIdStr(r["resource_id"]): ResourceState[r["status"]] for r in out}

    @stable_api
    @classmethod
    async def get_current_resource_state(cls, env: uuid.UUID, rid: ResourceIdStr) -> Optional[ResourceState]:
        """
        Return the state of the given resource in the latest released version of the configuration model
        or None if the resource is not present in that version.
        """
        query = """
            WITH latest_released_version AS (
                SELECT max(version) AS version
                FROM configurationmodel
                WHERE environment=$1 AND released
            )
            SELECT (
                CASE
                    -- The resource_persistent_state.last_non_deploying_status column is only populated for
                    -- actual deployment operations to prevent locking issues. This case-statement calculates
                    -- the correct state from the combination of the resource table and the
                    -- resource_persistent_state table.
                    WHEN r.status::text IN('deploying', 'undefined', 'skipped_for_undefined')
                        -- The deploying, undefined and skipped_for_undefined states are not tracked in the
                        -- resource_persistent_state table.
                        THEN r.status::text
                    WHEN rps.last_deployed_attribute_hash != r.attribute_hash
                        -- The hash changed since the last deploy -> new desired state
                        THEN r.status::text
                        -- No override required, use last known state from actual deployment
                    ELSE rps.last_non_deploying_status::text
                END
            ) AS status
            FROM resource AS r
                INNER JOIN resource_persistent_state AS rps
                    ON r.environment=rps.environment AND r.resource_id=rps.resource_id
                INNER JOIN configurationmodel AS c
                    ON c.environment=r.environment AND c.version=r.model
            WHERE r.environment=$1 AND r.model = (SELECT version FROM latest_released_version) AND r.resource_id=$2
        """
        results = await cls.select_query(query, [env, rid], no_obj=True)
        if not results:
            return None
        assert len(results) == 1
        return const.ResourceState(results[0]["status"])

    @classmethod
    async def set_deployed_multi(
        cls,
        environment: uuid.UUID,
        resource_ids: Sequence[m.ResourceIdStr],
        version: int,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """Set the status of the given resources in the given version to 'deployed'."""
        query = "UPDATE resource SET status='deployed' WHERE environment=$1 AND model=$2 AND resource_id =ANY($3) "
        async with cls.get_connection(connection) as connection:
            await connection.execute(query, environment, version, resource_ids)

    @classmethod
    async def get_resource_ids_with_status(
        cls,
        environment: uuid.UUID,
        resource_version_ids: list[m.ResourceIdStr],
        version: int,
        statuses: Sequence[const.ResourceState],
        lock: Optional[RowLockMode] = None,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> list[m.ResourceIdStr]:
        """
        Return the subset of the given resource ids that, in the given version, have one of the given statuses.

        :param lock: Optional row lock clause appended to the query.
        """
        query = (
            "SELECT resource_id as resource_id FROM resource WHERE "
            "environment=$1 AND model=$2 AND status = ANY($3) and resource_id =ANY($4) "
        )
        if lock:
            query += lock.value
        async with cls.get_connection(connection) as connection:
            return [
                m.ResourceIdStr(cast(str, r["resource_id"]))
                for r in await connection.fetch(query, environment, version, statuses, resource_version_ids)
            ]

    @classmethod
    async def get_undeployable(cls, environment: uuid.UUID, version: int) -> list["Resource"]:
        """
        Returns a list of resources with an undeployable state
        """
        (filter_statement, values) = cls._get_composed_filter(environment=environment, model=version)
        # The composed filter uses $1 and $2, the undeployable states are bound from $3 onwards
        undeployable_states = ", ".join(["$" + str(i + 3) for i in range(len(const.UNDEPLOYABLE_STATES))])
        values = values + [cls._get_value(s) for s in const.UNDEPLOYABLE_STATES]
        query = (
            "SELECT * FROM " + cls.table_name() + " WHERE " + filter_statement + " AND status IN (" + undeployable_states + ")"
        )
        undeployable_resources = await cls.select_query(query, values)
        return undeployable_resources

    @classmethod
    async def get_resources_in_latest_version(
        cls,
        environment: uuid.UUID,
        resource_type: Optional[m.ResourceType] = None,
        attributes: Optional[dict[PrimitiveTypes, PrimitiveTypes]] = None,
        *,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> list["Resource"]:
        """
        Returns the resources in the latest version of the configuration model of the given environment,
        that satisfy the given constraints.

        :param environment: The resources should belong to this environment.
        :param resource_type: The resources should have this resource_type.
        :param attributes: The resource should contain these key-value pairs in its attributes list.
                           None (the default) means no constraint on the attributes.
        """
        # Avoid a shared mutable default argument
        if attributes is None:
            attributes = {}
        values = [cls._get_value(environment)]
        query = f"""
            SELECT *
            FROM {Resource.table_name()} AS r1
            WHERE r1.environment=$1 AND r1.model=(SELECT MAX(cm.version)
                                                  FROM {ConfigurationModel.table_name()} AS cm
                                                  WHERE cm.environment=$1)
        """
        if resource_type:
            query += " AND r1.resource_type=$2"
            values.append(cls._get_value(resource_type))

        result = []
        async with cls.get_connection(connection) as con:
            async with con.transaction():
                async for record in con.cursor(query, *values):
                    resource = cls(from_postgres=True, **record)
                    # The constraints on the attributes field are checked in memory.
                    # This prevents injection attacks.
                    if util.is_sub_dict(attributes, resource.attributes):
                        result.append(resource)
        return result

    @classmethod
    async def get_resource_type_count_for_latest_version(cls, environment: uuid.UUID) -> dict[str, int]:
        """
        Returns the count for each resource_type over all resources in the model's latest version
        """
        query_latest_model = f"""
            SELECT max(version)
            FROM {ConfigurationModel.table_name()}
            WHERE environment=$1
        """
        query = f"""
            SELECT resource_type, count(*) as count
            FROM {Resource.table_name()}
            WHERE environment=$1 AND model=({query_latest_model})
            GROUP BY resource_type;
        """
        values = [cls._get_value(environment)]
        result: dict[str, int] = {}
        async with cls.get_connection() as con:
            async with con.transaction():
                async for record in con.cursor(query, *values):
                    assert isinstance(record["count"], int)
                    result[str(record["resource_type"])] = record["count"]
        return result

    @classmethod
    async def get_resources_report(cls, environment: uuid.UUID) -> list[JsonType]:
        """
        This method generates a report of all resources in the given environment,
        with their latest version and when they are last deployed.
        """
        query_resource_ids = f"""
            SELECT resource_id, last_deployed_version as deployed_version, last_deploy
            FROM {ResourcePersistentState.table_name()}
            WHERE environment=$1
        """
        query_latest_version = f"""
            SELECT resource_id, model AS latest_version
            FROM {Resource.table_name()}
            WHERE environment=$1 AND resource_id=r1.resource_id
            ORDER BY model DESC
            LIMIT 1
        """
        query = f"""
            SELECT r1.resource_id, r2.latest_version, r1.deployed_version, r1.last_deploy
            FROM ({query_resource_ids}) AS r1 INNER JOIN LATERAL ({query_latest_version}) AS r2
                ON (r1.resource_id = r2.resource_id)
        """
        values = [cls._get_value(environment)]
        result = []
        async with cls.get_connection() as con:
            async with con.transaction():
                async for record in con.cursor(query, *values):
                    resource_id = record["resource_id"]
                    parsed_id = resources.Id.parse_id(resource_id)
                    result.append(
                        {
                            "resource_id": resource_id,
                            "resource_type": parsed_id.entity_type,
                            "agent": parsed_id.agent_name,
                            "latest_version": record["latest_version"],
                            "deployed_version": record["deployed_version"] if "deployed_version" in record else None,
                            "last_deploy": record["last_deploy"] if "last_deploy" in record else None,
                        }
                    )
        return result

    @classmethod
    async def get_resources_for_version(
        cls,
        environment: uuid.UUID,
        version: int,
        agent: Optional[str] = None,
        no_obj: bool = False,
        *,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> list["Resource"]:
        """
        Return all resources of the given version, optionally limited to a single agent.

        :param no_obj: When True, return backward-compatible dicts (attributes decoded from json and mangled)
                       instead of Resource objects, despite the annotated return type.
        """
        if agent:
            (filter_statement, values) = cls._get_composed_filter(environment=environment, model=version, agent=agent)
        else:
            (filter_statement, values) = cls._get_composed_filter(environment=environment, model=version)

        query = f"SELECT * FROM {Resource.table_name()} WHERE {filter_statement}"
        resources_list: Union[list[Resource], list[dict[str, object]]] = []
        async with cls.get_connection(connection) as con:
            async with con.transaction():
                async for record in con.cursor(query, *values):
                    if no_obj:
                        record = dict(record)
                        record["attributes"] = json.loads(record["attributes"])
                        cls.__mangle_dict(record)
                        resources_list.append(record)
                    else:
                        resources_list.append(cls(from_postgres=True, **record))
        return resources_list

    @classmethod
    async def get_resources_for_version_raw(
        cls, environment: uuid.UUID, version: int, projection: Optional[list[str]], *, connection: Optional[Connection] = None
    ) -> list[dict[str, object]]:
        """
        Return the raw records of all resources of the given version as dicts.

        :param projection: The columns to select, all columns when empty or None.
        :return: One dict per resource; when the attributes column is selected, it is decoded from json.
        """
        # Keep the projection parameter as passed in; build the SELECT clause in a separate local
        projection_clause = "*" if not projection else ",".join(projection)
        (filter_statement, values) = cls._get_composed_filter(environment=environment, model=version)
        query = "SELECT " + projection_clause + " FROM " + cls.table_name() + " WHERE " + filter_statement
        resource_records = await cls._fetch_query(query, *values, connection=connection)
        raw_resources = [dict(record) for record in resource_records]
        for res in raw_resources:
            if "attributes" in res:
                res["attributes"] = json.loads(res["attributes"])
        return raw_resources

    @classmethod
    async def get_resources_for_version_raw_with_persistent_state(
        cls,
        environment: uuid.UUID,
        version: int,
        projection: Optional[list[typing.LiteralString]],
        projection_presistent: Optional[list[typing.LiteralString]],
        project_attributes: Optional[list[typing.LiteralString]] = None,
        *,
        connection: Optional[Connection] = None,
    ) -> list[dict[str, object]]:
        """This method performs none of the mangling required to produce valid resources!

        project_attributes performs a projection on the json attributes of the resources table

        all projections must be disjoint, as they become named fields in the output record
        """

        def collect_projection(projection: Optional[list[str]], prefix: str) -> str:
            # Select all columns of the table with this prefix when no projection is given
            if not projection:
                return f"{prefix}.*"
            else:
                return ",".join(f"{prefix}.{field}" for field in projection)

        if project_attributes:
            json_projection = "," + ",".join(f"r.attributes->'{v}' as {v}" for v in project_attributes)
        else:
            json_projection = ""

        query = f"""
        SELECT {collect_projection(projection, 'r')}, {collect_projection(projection_presistent, 'ps')} {json_projection}
        FROM {cls.table_name()} r JOIN resource_persistent_state ps ON r.resource_id = ps.resource_id
        WHERE r.environment=$1 AND ps.environment = $1 and r.model = $2;"""

        resource_records = await cls._fetch_query(query, environment, version, connection=connection)
        raw_resources = [dict(record) for record in resource_records]
        for res in raw_resources:
            if project_attributes:
                for k in project_attributes:
                    # The projected json values are returned as text and are decoded here
                    if res[k]:
                        res[k] = json.loads(res[k])
        return raw_resources

    @classmethod
    async def get_latest_version(cls, environment: uuid.UUID, resource_id: m.ResourceIdStr) -> Optional["Resource"]:
        """Return the resource with the highest model version for the given resource id, or None if there is none."""
        found = await cls.get_list(
            order_by_column="model", order="DESC", limit=1, environment=environment, resource_id=resource_id
        )
        if len(found) > 0:
            return found[0]
        return None

    @staticmethod
    def get_details_from_resource_id(resource_id: m.ResourceIdStr) -> m.ResourceIdDetails:
        """Parse the given resource id into its type, agent, id attribute and id attribute value."""
        parsed_id = resources.Id.parse_id(resource_id)
        return m.ResourceIdDetails(
            resource_type=parsed_id.entity_type,
            agent=parsed_id.agent_name,
            attribute=parsed_id.attribute,
            resource_id_value=parsed_id.attribute_value,
        )

    @classmethod
    async def get(
        cls,
        environment: uuid.UUID,
        resource_version_id: m.ResourceVersionIdStr,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> Optional["Resource"]:
        """
        Get a resource with the given resource version id
        """
        parsed_id = resources.Id.parse_id(resource_version_id)
        value = await cls.get_one(
            environment=environment, resource_id=parsed_id.resource_str(), model=parsed_id.version, connection=connection
        )
        return value

    @classmethod
    def new(cls, environment: uuid.UUID, resource_version_id: m.ResourceVersionIdStr, **kwargs: object) -> "Resource":
        """
        Create (without inserting) a Resource whose id related fields are derived from resource_version_id.
        Keyword arguments are applied last and can override the derived fields.
        """
        vid = resources.Id.parse_id(resource_version_id)

        attr = dict(
            environment=environment,
            model=vid.version,
            resource_id=vid.resource_str(),
            resource_type=vid.entity_type,
            agent=vid.agent_name,
            resource_id_value=vid.attribute_value,
        )

        attr.update(kwargs)

        return cls(**attr)

    def copy_for_partial_compile(self, new_version: int) -> "Resource":
        """
        Create a new resource dao instance from this dao instance. Only creates the object without inserting it.
        The new instance will have the given version.
        """
        # An undefined resource stays undefined, any other status is reset to available
        new_resource_state = ResourceState.undefined if self.status is ResourceState.undefined else ResourceState.available
        return Resource(
            environment=self.environment,
            model=new_version,
            resource_id=self.resource_id,
            resource_type=self.resource_type,
            resource_id_value=self.resource_id_value,
            agent=self.agent,
            attributes=self.attributes.copy(),
            attribute_hash=self.attribute_hash,
            status=new_resource_state,
            resource_set=self.resource_set,
            provides=self.provides,
        )

    @classmethod
    async def get_resource_details(cls, env: uuid.UUID, resource_id: m.ResourceIdStr) -> Optional[m.ReleasedResourceDetails]:
        """
        Return the details of the latest released version of the given resource, including the status
        of each of its requires, or None if the resource has no released version.
        """

        def status_sub_query(resource_table_name: str) -> str:
            # SQL expression computing the status column for the given table alias
            return f"""
            (CASE
                -- The resource_persistent_state.last_non_deploying_status column is only populated for
                -- actual deployment operations to prevent locking issues. This case-statement calculates
                -- the correct state from the combination of the resource table and the
                -- resource_persistent_state table.
                WHEN (SELECT {resource_table_name}.model < MAX(configurationmodel.version)
                      FROM configurationmodel
                      WHERE configurationmodel.released=TRUE
                      AND environment = $1
                      )
                    -- Resource is no longer present in latest released configurationmodel
                    THEN 'orphaned'
                WHEN {resource_table_name}.status::text IN('deploying', 'undefined', 'skipped_for_undefined')
                    -- The deploying, undefined and skipped_for_undefined states are not tracked in the
                    -- resource_persistent_state table.
                    THEN {resource_table_name}.status::text
                WHEN ps.last_deployed_attribute_hash != {resource_table_name}.attribute_hash
                    -- The hash changed since the last deploy -> new desired state
                    THEN {resource_table_name}.status::text
                    -- No override required, use last known state from actual deployment
                ELSE ps.last_non_deploying_status::text
            END
            ) as status
            """

        query = f"""
        SELECT DISTINCT ON (resource_id) first.resource_id, cm.date as first_generated_time,
        first.model as first_model, latest.model AS latest_model, latest.resource_id as latest_resource_id,
        latest.resource_type, latest.agent, latest.resource_id_value, ps.last_deploy as latest_deploy,
        latest.attributes, {status_sub_query('latest')}
        FROM resource first
        INNER JOIN
            /* 'latest' is the latest released version of the resource */
            (SELECT distinct on (resource_id) resource_id, attribute_hash, model, attributes,
                resource_type, agent, resource_id_value, resource.status as status
                FROM resource
                JOIN configurationmodel cm ON resource.model = cm.version AND resource.environment = cm.environment
                WHERE resource.environment = $1 AND resource_id = $2 AND cm.released = TRUE
                ORDER BY resource_id, model desc
            ) as latest
        /* The 'first' values correspond to the first time the attribute hash was the same as in
            the 'latest' released version */
        ON first.resource_id = latest.resource_id AND first.attribute_hash = latest.attribute_hash
        INNER JOIN configurationmodel cm ON first.model = cm.version AND first.environment = cm.environment
        INNER JOIN resource_persistent_state ps on ps.resource_id = first.resource_id AND first.environment = ps.environment
        WHERE first.environment = $1 AND first.resource_id = $2 AND cm.released = TRUE
        ORDER BY first.resource_id, first.model asc;
        """
        values = [cls._get_value(env), cls._get_value(resource_id)]
        result = await cls.select_query(query, values, no_obj=True)
        if not result:
            return None
        record = result[0]
        parsed_id = resources.Id.parse_id(record["latest_resource_id"])
        attributes = json.loads(record["attributes"])
        # Due to a bug, the version field has always been present in the attributes dictionary.
        # This bug has been fixed in the database. For backwards compatibility reason we here make sure that the
        # version field is present in the attributes dictionary served out via the API.
        if "version" not in attributes:
            attributes["version"] = record["latest_model"]
        # A resource without a requires attribute has no dependencies (consistent with get_requires)
        requires = [resources.Id.parse_id(req).resource_str() for req in attributes.get("requires", [])]

        # fetch the status of each of the requires. This is not calculated in the database because the lack of joinable
        # fields requires to calculate the status for each resource record, before it is filtered
        status_query = f"""
        SELECT DISTINCT ON (resource.resource_id) resource.resource_id, {status_sub_query('resource')}
        FROM resource
        INNER JOIN configurationmodel cm ON resource.model = cm.version AND resource.environment = cm.environment
        INNER JOIN resource_persistent_state ps ON ps.resource_id = resource.resource_id AND resource.environment = ps.environment
        WHERE resource.environment = $1 AND cm.released = TRUE AND resource.resource_id = ANY($2)
        ORDER BY resource.resource_id, model DESC;
        """
        status_result = await cls.select_query(status_query, [cls._get_value(env), cls._get_value(requires)], no_obj=True)

        return m.ReleasedResourceDetails(
            resource_id=record["latest_resource_id"],
            resource_type=record["resource_type"],
            agent=record["agent"],
            id_attribute=parsed_id.attribute,
            id_attribute_value=record["resource_id_value"],
            last_deploy=record["latest_deploy"],
            first_generated_time=record["first_generated_time"],
            first_generated_version=record["first_model"],
            attributes=attributes,
            status=record["status"],
            requires_status={row["resource_id"]: row["status"] for row in status_result},
        )

    @classmethod
    async def get_versioned_resource_details(
        cls, environment: uuid.UUID, version: int, resource_id: m.ResourceIdStr
    ) -> Optional[m.VersionedResourceDetails]:
        """Return the details of the given resource in the given version, or None if it does not exist."""
        resource = await cls.get_one(environment=environment, model=version, resource_id=resource_id)
        if not resource:
            return None
        parsed_id = resources.Id.parse_id(resource.resource_id)
        parsed_id.set_version(resource.model)
        return m.VersionedResourceDetails(
            resource_id=resource.resource_id,
            resource_version_id=parsed_id.resource_version_str(),
            resource_type=resource.resource_type,
            agent=resource.agent,
            id_attribute=parsed_id.attribute,
            id_attribute_value=resource.resource_id_value,
            version=resource.model,
            attributes=resource.attributes,
        )

    @classmethod
    async def get_resource_deploy_summary(cls, environment: uuid.UUID) -> m.ResourceDeploySummary:
        """Count the resources of the latest released version of the environment, grouped by status."""
        inner_query = f"""
        SELECT r.resource_id as resource_id,
        (
            CASE WHEN r.status IN ('deploying', 'undefined', 'skipped_for_undefined')
                THEN r.status::text
            WHEN rps.last_deployed_attribute_hash != r.attribute_hash
                -- The hash changed since the last deploy -> new desired state
                THEN r.status::text
            ELSE rps.last_non_deploying_status::text
            END
        ) as status
        FROM {cls.table_name()} as r
            JOIN resource_persistent_state rps ON r.resource_id = rps.resource_id and r.environment = rps.environment
        WHERE r.environment=$1 AND r.model=(SELECT MAX(cm.version)
                                      FROM public.configurationmodel AS cm
                                      WHERE cm.environment=$1 AND cm.released=TRUE)
        """

        query = f"""
            SELECT COUNT(ro.resource_id) as count, ro.status
            FROM ({inner_query}) as ro
            GROUP BY ro.status
        """
        raw_results = await cls._fetch_query(query, cls._get_value(environment))
        results = {}
        for row in raw_results:
            results[row["status"]] = row["count"]
        return m.ResourceDeploySummary.create_from_db_result(results)

    @classmethod
    async def copy_resources_from_unchanged_resource_set(
        cls,
        environment: uuid.UUID,
        source_version: int,
        destination_version: int,
        updated_resource_sets: abc.Set[str],
        deleted_resource_sets: abc.Set[str],
        *,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> dict[m.ResourceIdStr, str]:
        """
        Copy the resources that belong to an unchanged resource set of a partial compile,
        from source_version to destination_version. This method doesn't copy shared resources.

        :return: A mapping from the resource id of each copied resource to its resource set.
        """
        query = f"""
            INSERT INTO {cls.table_name()}(
                environment,
                model,
                resource_id,
                resource_type,
                resource_id_value,
                agent,
                status,
                attributes,
                attribute_hash,
                resource_set,
                provides
            )(
                SELECT
                    r.environment,
                    $3,
                    r.resource_id,
                    r.resource_type,
                    r.resource_id_value,
                    r.agent,
                    (
                        CASE WHEN r.status='undefined'::resourcestate
                            THEN 'undefined'::resourcestate
                            ELSE 'available'::resourcestate
                        END
                    ) AS status,
                    r.attributes AS attributes,
                    r.attribute_hash,
                    r.resource_set,
                    r.provides
                FROM {cls.table_name()} AS r
                WHERE r.environment=$1 AND r.model=$2 AND r.resource_set IS NOT NULL AND NOT r.resource_set=ANY($4)
            )
            RETURNING resource_id, resource_set
        """
        async with cls.get_connection(connection) as con:
            result = await con.fetch(
                query,
                environment,
                source_version,
                destination_version,
                updated_resource_sets | deleted_resource_sets,
            )
            return {str(record["resource_id"]): str(record["resource_set"]) for record in result}

    @classmethod
    async def get_resources_in_resource_sets(
        cls,
        environment: uuid.UUID,
        version: int,
        resource_sets: abc.Set[str],
        include_shared_resources: bool = False,
        *,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> abc.Mapping[ResourceIdStr, "Resource"]:
        """
        Returns the resource in the given environment and version that belong to any of the given resource sets.
        This method also returns the resources in the share resource set iff the include_shared_resources boolean
        is set to True.
        """
        if include_shared_resources:
            resource_set_filter_statement = "(r.resource_set IS NULL OR r.resource_set=ANY($3))"
        else:
            resource_set_filter_statement = "r.resource_set=ANY($3)"
        query = f"""
            SELECT *
            FROM {cls.table_name()} AS r
            WHERE r.environment=$1 AND r.model=$2 AND {resource_set_filter_statement}
        """
        async with cls.get_connection(connection) as con:
            result = await con.fetch(query, environment, version, resource_sets)
            return {record["resource_id"]: cls(from_postgres=True, **record) for record in result}

    async def insert(self, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        Compute the attribute hash, insert this resource and make sure a resource_persistent_state row
        exists for it (an existing row is left untouched).
        """
        self.make_hash()
        await super().insert(connection=connection)
        # TODO: On conflict or is not exists or just make every update an upsert?
        await self._execute_query(
            """
            INSERT INTO resource_persistent_state (environment, resource_id, resource_type, agent, resource_id_value)
            VALUES ($1, $2, $3, $4, $5)
            ON CONFLICT DO NOTHING
            """,
            self.environment,
            self.resource_id,
            self.resource_type,
            self.agent,
            self.resource_id_value,
            connection=connection,
        )

    @classmethod
    async def insert_many(
        cls, documents: Sequence["Resource"], *, connection: Optional[asyncpg.connection.Connection] = None
    ) -> None:
        """
        Compute the attribute hash of each document, make sure a resource_persistent_state row exists for each
        of them (existing rows are left untouched) and insert the documents.
        """
        for doc in documents:
            doc.make_hash()
        # TODO performance?
        for doc in documents:
            await cls._execute_query(
                """
                INSERT INTO resource_persistent_state (environment, resource_id, resource_type, agent, resource_id_value)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT DO NOTHING
                """,
                doc.environment,
                doc.resource_id,
                doc.resource_type,
                doc.agent,
                doc.resource_id_value,
                connection=connection,
            )
        await super().insert_many(documents, connection=connection)

    async def update(self, connection: Optional[asyncpg.connection.Connection] = None, **kwargs: object) -> None:
        """Recompute the attribute hash, then delegate to BaseDocument.update."""
        self.make_hash()
        await super().update(connection=connection, **kwargs)

    async def update_fields(self, connection: Optional[asyncpg.connection.Connection] = None, **kwargs: object) -> None:
        """Recompute the attribute hash, then delegate to BaseDocument.update_fields."""
        self.make_hash()
        await super().update_fields(connection=connection, **kwargs)

    def get_requires(self) -> abc.Sequence[ResourceIdStr]:
        """
        Returns the content of the requires field in the attributes.
        """
        if "requires" not in self.attributes:
            return []
        return list(self.attributes["requires"])

    def to_dict(self) -> dict[str, object]:
        """Recompute the attribute hash and return the backward compatible dict form of this resource."""
        self.make_hash()
        dct = super().to_dict()
        self.__mangle_dict(dct)
        return dct

    def to_dto(self) -> m.Resource:
        """Convert this resource into its API representation (versions added to requires and attributes)."""
        attributes = self.attributes.copy()

        if "requires" in self.attributes:
            version = self.model
            attributes["requires"] = [resources.Id.set_version_in_id(id, version) for id in self.attributes["requires"]]

        # Due to a bug, the version field has always been present in the attributes dictionary.
        # This bug has been fixed in the database. For backwards compatibility reason we here make sure that the
        # version field is present in the attributes dictionary served out via the API.
        attributes["version"] = self.model

        return m.Resource(
            environment=self.environment,
            model=self.model,
            resource_id=self.resource_id,
            resource_type=self.resource_type,
            resource_version_id=resources.Id.set_version_in_id(self.resource_id, self.model),
            agent=self.agent,
            attributes=attributes,
            status=self.status,
            resource_id_value=self.resource_id_value,
            resource_set=self.resource_set,
        )

    async def update_persistent_state(
        self,
        last_deploy: Optional[datetime.datetime] = None,
        last_deployed_version: Optional[int] = None,
        last_non_deploying_status: Optional[const.NonDeployingResourceState] = None,
        last_success: Optional[datetime.datetime] = None,
        last_produced_events: Optional[datetime.datetime] = None,
        last_deployed_attribute_hash: Optional[str] = None,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Update the data in the resource_persistent_state table.

        Only the arguments that are not None are written; when all of them are None, nothing happens.
        """
        # $1 and $2 are the environment and resource_id, the updated values are bound from $3 onwards
        args = ArgumentCollector(2)

        invalues = {
            "last_deploy": last_deploy,
            "last_non_deploying_status": last_non_deploying_status,
            "last_success": last_success,
            "last_produced_events": last_produced_events,
            "last_deployed_attribute_hash": last_deployed_attribute_hash,
            "last_deployed_version": last_deployed_version,
        }
        query_parts = [f"{k}={args(v)}" for k, v in invalues.items() if v is not None]
        if not query_parts:
            return
        query = f"UPDATE public.resource_persistent_state SET {','.join(query_parts)} WHERE environment=$1 and resource_id=$2"
        result = await self._execute_query(query, self.environment, self.resource_id, *args.args, connection=connection)
        assert result == "UPDATE 1"
@stable_api
class ConfigurationModel(BaseDocument):
    """
    A specific version of the configuration model.

    :param version: The version of the configuration model, represented by a unix timestamp.
    :param environment: The environment this configuration model is defined in
    :param date: The date this configuration model was created
    :param partial_base: If this version was calculated from a partial export, the version the partial was applied on.
    :param pip_config: The pip configuration stored with this version, if any.
    :param released: Is this model released and available for deployment?
    :param deployed: Is this model deployed?
    :param result: The result of the deployment. Success or error.
    :param version_info: Version metadata
    :param total: The total number of resources
    :param is_suitable_for_partial_compiles: This boolean indicates whether the model can later on be updated using a
                                             partial compile. In other words, the value is True iff no cross resource set
                                             dependencies exist between the resources.
    """

    __primary_key__ = ("version", "environment")

    version: int
    environment: uuid.UUID
    date: Optional[datetime.datetime] = None
    partial_base: Optional[int] = None
    pip_config: Optional[PipConfig] = None

    released: bool = False
    deployed: bool = False
    result: const.VersionState = const.VersionState.pending
    version_info: Optional[dict[str, object]] = None
    is_suitable_for_partial_compiles: bool

    total: int = 0

    # cached state for release
    undeployable: list[m.ResourceIdStr] = []
    skipped_for_undeployable: list[m.ResourceIdStr] = []

    def __init__(self, **kwargs: object) -> None:
        super().__init__(**kwargs)
        # Populated by get_list(); not persisted columns of this table.
        self._status = {}
        self._done = 0

    @classmethod
    def get_valid_field_names(cls) -> list[str]:
        """Return the field names accepted by this document, including the computed "status" and "model" fields."""
        return super().get_valid_field_names() + ["status", "model"]

    @property
    def done(self) -> int:
        """The number of resources of this version that are done (all of them once the version is marked deployed)."""
        # Keep resources which are deployed in done, even when a repair operation
        # changes its state to deploying again.
        if self.deployed:
            return self.total
        return self._done

    @classmethod
    async def create_for_partial_compile(
        cls,
        env_id: uuid.UUID,
        version: int,
        total: int,
        version_info: Optional[JsonType],
        undeployable: abc.Sequence[ResourceIdStr],
        skipped_for_undeployable: abc.Sequence[ResourceIdStr],
        partial_base: int,
        pip_config: Optional[PipConfig],
        updated_resource_sets: abc.Set[str],
        deleted_resource_sets: abc.Set[str],
        connection: Optional[Connection] = None,
    ) -> "ConfigurationModel":
        """
        Create and insert a new configurationmodel that is the result of a partial compile. The new ConfigurationModel will
        contain all the undeployables and skipped_for_undeployables present in the partial_base version whose resource
        belongs to the shared resource set or to a resource set that is neither in updated_resource_sets nor in
        deleted_resource_sets, together with the undeployables and skipped_for_undeployables passed to this method.

        :raises Exception: If the partial_base version does not exist in the given environment.
        """
        query = f"""
            WITH base_version_exists AS (
                SELECT EXISTS(
                    SELECT 1
                    FROM {cls.table_name()} AS c1
                    WHERE c1.environment=$1 AND c1.version=$8
                ) AS base_version_found
            ),
            rids_undeployable_base_version AS (
                SELECT t.rid
                FROM (
                    SELECT DISTINCT unnest(c2.undeployable) AS rid
                    FROM {cls.table_name()} AS c2
                    WHERE c2.environment=$1 AND c2.version=$8
                ) AS t(rid)
                WHERE (
                    EXISTS (
                        SELECT 1
                        FROM {Resource.table_name()} AS r
                        WHERE r.environment=$1
                            AND r.model=$8
                            AND r.resource_id=t.rid
                            -- Keep only resources that belong to the shared resource set or a resource set that was not updated
                            AND (r.resource_set IS NULL OR NOT r.resource_set=ANY($9))
                    )
                )
            ),
            rids_skipped_for_undeployable_base_version AS (
                SELECT t.rid
                FROM(
                    SELECT DISTINCT unnest(c3.skipped_for_undeployable) AS rid
                    FROM {cls.table_name()} AS c3
                    WHERE c3.environment=$1 AND c3.version=$8
                ) AS t(rid)
                WHERE (
                    EXISTS (
                        SELECT 1
                        FROM {Resource.table_name()} AS r
                        WHERE r.environment=$1
                            AND r.model=$8
                            AND r.resource_id=t.rid
                            -- Keep resources that belong to the shared resource set or a resource set that was not updated
                            AND (r.resource_set IS NULL OR NOT r.resource_set=ANY($9))
                    )
                )
            )
            INSERT INTO {cls.table_name()}(
                environment,
                version,
                date,
                total,
                version_info,
                undeployable,
                skipped_for_undeployable,
                partial_base,
                is_suitable_for_partial_compiles,
                pip_config
            ) VALUES(
                $1,
                $2,
                $3,
                $4,
                $5,
                (
                    SELECT coalesce(array_agg(rid), '{{}}')
                    FROM (
                        -- Undeployables in previous version of the model that are not part of the partial compile.
                        (
                            SELECT rid
                            FROM rids_undeployable_base_version AS undepl
                        )
                        UNION
                        -- Undeployables part of the partial compile.
                        (
                            SELECT DISTINCT rid
                            FROM unnest($6::varchar[]) AS undeploy_filtered_new(rid)
                        )
                    ) AS all_undeployable
                ),
                (
                    SELECT coalesce(array_agg(rid), '{{}}')
                    FROM (
                        -- skipped_for_undeployables in previous version of the model that are not part of the partial
                        -- compile.
                        (
                            SELECT skipped.rid
                            FROM rids_skipped_for_undeployable_base_version AS skipped
                        )
                        UNION
                        -- Skipped_for_undeployables part of the partial compile.
                        (
                            SELECT DISTINCT rid
                            FROM unnest($7::varchar[]) AS skipped_filtered_new(rid)
                        )
                    ) AS all_skipped
                ),
                $8,
                True,
                $10::jsonb
            )
            RETURNING
                (SELECT base_version_found FROM base_version_exists LIMIT 1) AS base_version_found,
                environment,
                version,
                date,
                total,
                version_info,
                undeployable,
                skipped_for_undeployable,
                partial_base,
                released,
                deployed,
                result,
                is_suitable_for_partial_compiles,
                pip_config
        """
        async with cls.get_connection(connection) as con:
            result = await con.fetchrow(
                query,
                env_id,
                version,
                datetime.datetime.now().astimezone(),
                total,
                cls._get_value(version_info),
                undeployable,
                skipped_for_undeployable,
                partial_base,
                updated_resource_sets | deleted_resource_sets,
                cls._get_value(pip_config),
            )
            # Make mypy happy
            assert result is not None
            if not result["base_version_found"]:
                raise Exception(f"Model with version {partial_base} not found in environment {env_id}")
            # base_version_found is a helper column of the RETURNING clause, not a field of this document.
            fields = {name: val for name, val in result.items() if name != "base_version_found"}
            return cls(from_postgres=True, **fields)

    @classmethod
    async def _get_status_field(cls, environment: uuid.UUID, values: str) -> dict[str, object]:
        """
        This field is required to ensure backward compatibility on the API.

        :param environment: The environment, used as namespace for the uuid5 keys.
        :param values: A JSON-encoded list of objects, each holding at least an "id" key.
        :return: A dict mapping uuid5(environment, entry["id"]) (as a string) to the decoded entry.
        """
        entries = json.loads(values)
        return {str(uuid.uuid5(environment, entry["id"])): entry for entry in entries}

    @classmethod
    async def get_list(
        cls,
        *,
        order_by_column: Optional[str] = None,
        order: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        no_obj: Optional[bool] = None,
        lock: Optional[RowLockMode] = None,
        connection: Optional[asyncpg.connection.Connection] = None,
        no_status: bool = False,  # don't load the status field
        **query: object,
    ) -> list["ConfigurationModel"]:
        """
        Return the configuration models matching the filter in `query`.

        Every result is enriched with the number of its resources that are not in a transient state ("done") and,
        unless `no_status` is set, with the status of each of its resources ("status").

        :param no_obj: When True, plain dicts are returned instead of ConfigurationModel instances.
        :param no_status: When True, the per-resource status is not decoded and an empty dict is used instead.
        """
        # sanitize and validate order parameters
        if order is None:
            order = "ASC"
        if order_by_column:
            cls._validate_order(order_by_column, order)

        if no_obj is None:
            no_obj = False

        # ensure limit and offset is an integer
        if limit is not None:
            limit = int(limit)
        if offset is not None:
            offset = int(offset)

        # The transient states take the first query parameters; the filter parameters follow them.
        transient_states = ",".join(["$" + str(i) for i in range(1, len(const.TRANSIENT_STATES) + 1)])
        transient_states_values = [cls._get_value(s) for s in const.TRANSIENT_STATES]
        (filterstr, values) = cls._get_composed_filter(col_name_prefix="c", offset=len(transient_states_values) + 1, **query)
        values = transient_states_values + values
        where_statement = f"WHERE {filterstr} " if filterstr else ""
        order_by_statement = f"ORDER BY {order_by_column} {order} " if order_by_column else ""
        limit_statement = f"LIMIT {limit} " if limit is not None and limit > 0 else ""
        offset_statement = f"OFFSET {offset} " if offset is not None and offset > 0 else ""
        lock_statement = f" {lock.value} " if lock is not None else ""
        query_string = f"""SELECT c.*,
                           SUM(CASE WHEN r.status NOT IN({transient_states}) THEN 1 ELSE 0 END) AS done,
                           to_json(array(SELECT jsonb_build_object('status', r2.status, 'id', r2.resource_id)
                                         FROM {Resource.table_name()} AS r2
                                         WHERE c.environment=r2.environment AND c.version=r2.model
                                        )
                           ) AS status
                    FROM {cls.table_name()} AS c LEFT OUTER JOIN {Resource.table_name()} AS r
                    ON c.environment = r.environment AND c.version = r.model
                    {where_statement}
                    GROUP BY c.environment, c.version
                    {order_by_statement}
                    {limit_statement}
                    {offset_statement}
                    {lock_statement}"""
        query_result = await cls._fetch_query(query_string, *values, connection=connection)
        result = []
        for in_record in query_result:
            record = dict(in_record)
            if no_obj:
                if no_status:
                    record["status"] = {}
                else:
                    record["status"] = await cls._get_status_field(record["environment"], record["status"])
                result.append(record)
            else:
                # "done" and "status" are not constructor fields: strip them and set the private attributes instead.
                done = record.pop("done")
                if no_status:
                    status = {}
                    record.pop("status")
                else:
                    status = await cls._get_status_field(record["environment"], record.pop("status"))
                obj = cls(from_postgres=True, **record)
                obj._done = done
                obj._status = status
                result.append(obj)
        return result

    def to_dict(self) -> JsonType:
        """Serialize this document, including the computed "status" and "done" fields."""
        dct = BaseDocument.to_dict(self)
        dct["status"] = dict(self._status)
        dct["done"] = self.done
        return dct

    @classmethod
    async def version_exists(cls, environment: uuid.UUID, version: int) -> bool:
        """Return True iff the given version exists in the given environment."""
        query = f"""SELECT 1
                    FROM {ConfigurationModel.table_name()}
                    WHERE environment=$1 AND version=$2"""
        result = await cls._fetchrow(query, cls._get_value(environment), cls._get_value(version))
        if not result:
            return False
        return True

    @classmethod
    async def get_version(
        cls,
        environment: uuid.UUID,
        version: int,
        *,
        connection: Optional[asyncpg.connection.Connection] = None,
        lock: Optional[RowLockMode] = None,
    ) -> Optional["ConfigurationModel"]:
        """
        Get a specific version
        """
        result = await cls.get_one(environment=environment, version=version, connection=connection, lock=lock)
        return result

    @classmethod
    async def get_version_internal(
        cls,
        environment: uuid.UUID,
        version: int,
        *,
        connection: Optional[asyncpg.connection.Connection] = None,
        lock: Optional[RowLockMode] = None,
    ) -> Optional["ConfigurationModel"]:
        """
        Return a version, but don't populate the status and done fields, which are expensive to construct.

        :param lock: Optional row lock to take on the selected row. No lock is taken when None.
        """
        # lock defaults to None: only emit a locking clause when one was requested.
        lock_statement = lock.value if lock is not None else ""
        query = f"""SELECT *
                    FROM {ConfigurationModel.table_name()}
                    WHERE environment=$1 AND version=$2
                    {lock_statement};
                    """
        result = await cls.select_query(query, [environment, version], connection=connection)
        if not result:
            return None
        return result[0]

    @classmethod
    async def get_latest_version(
        cls,
        environment: uuid.UUID,
        *,
        connection: Optional[Connection] = None,
    ) -> Optional["ConfigurationModel"]:
        """
        Get the latest released (most recent) version for the given environment
        """
        versions = await cls.get_list(
            order_by_column="version", order="DESC", limit=1, environment=environment, released=True, connection=connection
        )
        if len(versions) == 0:
            return None

        return versions[0]

    @classmethod
    async def get_version_nr_latest_version(
        cls,
        environment: uuid.UUID,
        connection: Optional[Connection] = None,
    ) -> Optional[int]:
        """
        Get the version number of the latest released version in the given environment.
        """
        query = f"""SELECT version
                    FROM {ConfigurationModel.table_name()}
                    WHERE environment=$1 AND released=true
                    ORDER BY version DESC
                    LIMIT 1
                    """
        result = await cls._fetchrow(query, cls._get_value(environment), connection=connection)
        if not result:
            return None
        return int(result["version"])

    @classmethod
    async def get_agents(
        cls, environment: uuid.UUID, version: int, *, connection: Optional[asyncpg.connection.Connection] = None
    ) -> list[str]:
        """
        Returns a list of all agents that have resources defined in this configuration model
        """
        (filter_statement, values) = cls._get_composed_filter(environment=environment, model=version)
        query = "SELECT DISTINCT agent FROM " + Resource.table_name() + " WHERE " + filter_statement
        result = []
        async with cls.get_connection(connection) as con:
            # asyncpg cursors can only be iterated inside a transaction.
            async with con.transaction():
                async for record in con.cursor(query, *values):
                    result.append(record["agent"])
        return result

    @classmethod
    async def get_versions(
        cls, environment: uuid.UUID, start: int = 0, limit: int = DBLIMIT, connection: Optional[Connection] = None
    ) -> list["ConfigurationModel"]:
        """
        Get all versions for an environment ordered descending
        """
        versions = await cls.get_list(
            order_by_column="version", order="DESC", limit=limit, offset=start, environment=environment, connection=connection
        )
        return versions

    async def delete_cascade(self, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        Delete this version together with the data attached to it (compiles, code, dryruns, unknowns, resource actions,
        resources) and the facts of resources that no longer exist in this environment.

        This method doesn't rely on the DELETE CASCADE functionality of PostgreSQL because it causes deadlocks.
        As such, the deletes are issued table by table, as separate statements.
        """
        async with self.get_connection(connection=connection) as con:
            # Delete of compile record triggers cascading delete report table
            await Compile.delete_all(environment=self.environment, version=self.version, connection=con)
            await Code.delete_all(environment=self.environment, version=self.version, connection=con)
            await DryRun.delete_all(environment=self.environment, model=self.version, connection=con)
            await UnknownParameter.delete_all(environment=self.environment, version=self.version, connection=con)
            await self._execute_query(
                "DELETE FROM public.resourceaction_resource WHERE environment=$1 AND resource_version=$2",
                self.environment,
                self.version,
                connection=con,
            )
            await ResourceAction.delete_all(environment=self.environment, version=self.version, connection=con)
            await Resource.delete_all(environment=self.environment, model=self.version, connection=con)
            await self.delete(connection=con)

            # Delete the facts of resources that no longer exist in any version of this environment.
            # The existence check is scoped to the same environment: resource ids are not unique across
            # environments, so a same-named resource elsewhere must not keep these facts alive.
            await self._execute_query(
                f"""
                DELETE FROM {Parameter.table_name()} p
                WHERE(
                    environment=$1 AND
                    resource_id<>'' AND
                    NOT EXISTS(
                        SELECT 1
                        FROM {Resource.table_name()} r
                        WHERE r.environment=p.environment AND p.resource_id=r.resource_id
                    )
                )
                """,
                self.environment,
                connection=con,
            )

    def get_undeployable(self) -> list[m.ResourceIdStr]:
        """
        Returns a list of resource ids (NOT resource version ids) of resources with an undeployable state
        """
        return self.undeployable

    def get_skipped_for_undeployable(self) -> list[m.ResourceIdStr]:
        """
        Returns a list of resource ids (NOT resource version ids) of resources which should get a
        skipped_for_undeployable state
        """
        return self.skipped_for_undeployable

    async def mark_done(self, *, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        Mark this deploy as done: set deployed=True and set result to failed if any resource of this version is not in
        the deployed state, success otherwise. The in-memory object is updated accordingly.
        """
        subquery = f"""(EXISTS(
                    SELECT 1
                    FROM {Resource.table_name()}
                    WHERE environment=$1 AND model=$2 AND status != $3
                ))::boolean
            """
        query = f"""UPDATE {self.table_name()}
                SET
                deployed=True, result=(CASE WHEN {subquery} THEN $4::versionstate ELSE $5::versionstate END)
                WHERE environment=$1 AND version=$2 RETURNING result
            """
        values = [
            self._get_value(self.environment),
            self._get_value(self.version),
            self._get_value(const.ResourceState.deployed),
            self._get_value(const.VersionState.failed),
            self._get_value(const.VersionState.success),
        ]
        result = await self._fetchval(query, *values, connection=connection)
        self.result = const.VersionState[result]
        self.deployed = True

    @classmethod
    async def mark_done_if_done(
        cls, environment: uuid.UUID, version: int, connection: Optional[asyncpg.connection.Connection] = None
    ) -> None:
        """
        Performs the query to mark done if done: the version is marked deployed (with result failed if any resource is not
        in the deployed state, success otherwise) only when the number of its resources in a done state equals its total.

        Expects to be called outside of any transaction that writes resource state in order to prevent race conditions.
        """
        async with cls.get_connection(connection) as con:
            async with con.transaction():
                query = f"""UPDATE {ConfigurationModel.table_name()}
                                SET deployed=True,
                                    result=(CASE WHEN (
                                                 EXISTS(SELECT 1
                                                        FROM {Resource.table_name()}
                                                        WHERE environment=$1 AND model=$2 AND status != $3)
                                                 )::boolean
                                            THEN $4::versionstate
                                            ELSE $5::versionstate END
                                    )
                                WHERE environment=$1 AND version=$2 AND total=(
                                    SELECT COUNT(*)
                                    FROM {Resource.table_name()}
                                    WHERE environment=$1 AND model=$2 AND status = any($6::resourcestate[])
                                )"""
                values = [
                    cls._get_value(environment),
                    cls._get_value(version),
                    cls._get_value(ResourceState.deployed),
                    cls._get_value(const.VersionState.failed),
                    cls._get_value(const.VersionState.success),
                    cls._get_value(DONE_STATES),
                ]
                await cls._execute_query(query, *values, connection=con)

    @classmethod
    async def get_increment(
        cls, environment: uuid.UUID, version: int, *, connection: Optional[Connection] = None
    ) -> tuple[set[m.ResourceIdStr], set[m.ResourceIdStr]]:
        """
        Find resources incremented by this version compared to deployment state transitions per resource

        available -> next version
        not present -> increment
        skipped -> increment
        unavailable -> increment
        error -> increment
        Deployed and same hash -> not increment
        deployed and different hash -> increment

        :return: A tuple (increment, not_increment) of resource ids. The increment is extended with every resource that
                 can receive events from a resource in the increment.
        """
        # Depends on deploying
        projection_a_resource: list[typing.LiteralString] = [
            "resource_id",
            "attribute_hash",
            "status",
        ]
        projection_a_state: list[typing.LiteralString] = [
            "last_success",
            "last_produced_events",
            "last_deployed_attribute_hash",
            "last_non_deploying_status",
        ]
        projection_a_attributes: list[typing.LiteralString] = ["requires", "send_event"]
        projection: list[typing.LiteralString] = ["resource_id", "status", "attribute_hash"]

        # get resources for agent
        current_resources = await Resource.get_resources_for_version_raw_with_persistent_state(
            environment, version, projection_a_resource, projection_a_state, projection_a_attributes, connection=connection
        )

        # to increment
        increment: list[abc.Mapping[str, object]] = []
        not_increment: list[abc.Mapping[str, object]] = []
        # todo in this version
        work: list[abc.Mapping[str, object]] = [r for r in current_resources if r["status"] not in UNDEPLOYABLE_NAMES]

        # start with outstanding events
        id_to_resource = {r["resource_id"]: r for r in current_resources}
        pending: list[abc.Mapping[str, object]] = []
        for resource in work:
            in_increment = False
            status = resource["last_non_deploying_status"]
            if status in [const.ResourceState.failed.name, ResourceState.skipped.name]:
                # Shortcut on easy includes
                increment.append(resource)
                continue

            # Now outstanding events: a dependency that produced events after our last success, and sends events.
            last_success = resource["last_success"] or DATETIME_MIN_UTC
            for req in resource["requires"]:
                req_res = id_to_resource[req]
                last_produced_events = req_res["last_produced_events"]
                if last_produced_events is not None and last_produced_events > last_success and req_res["send_event"]:
                    in_increment = True
                    break

            if in_increment:
                increment.append(resource)
            else:
                pending.append(resource)

        work = pending

        # get versions
        query = f"SELECT version FROM {cls.table_name()} WHERE environment=$1 AND released=true ORDER BY version DESC"
        values = [cls._get_value(environment)]
        version_records = await cls._fetch_query(query, *values, connection=connection)

        versions = [record["version"] for record in version_records]

        # Walk released versions from newest to oldest until every remaining resource is classified.
        for older_version in versions:
            # todo in next version
            pending = []

            vresources = await Resource.get_resources_for_version_raw(
                environment, older_version, projection, connection=connection
            )
            id_to_resource = {r["resource_id"]: r for r in vresources}

            for res in work:
                # not present -> increment
                if res["resource_id"] not in id_to_resource:
                    increment.append(res)
                    continue

                ores = id_to_resource[res["resource_id"]]

                status = ores["status"]
                # available -> next version
                if status == ResourceState.available.name:
                    pending.append(res)

                # deploying
                # same hash -> next version
                # different hash -> increment
                elif status == ResourceState.deploying.name:
                    if res["attribute_hash"] == ores["attribute_hash"]:
                        pending.append(res)
                    else:
                        increment.append(res)

                # -> increment
                elif status in [
                    ResourceState.failed.name,
                    ResourceState.cancelled.name,
                    ResourceState.skipped_for_undefined.name,
                    ResourceState.undefined.name,
                    ResourceState.skipped.name,
                    ResourceState.unavailable.name,
                ]:
                    increment.append(res)

                elif status == ResourceState.deployed.name:
                    if res["attribute_hash"] == ores["attribute_hash"]:
                        # Deployed and same hash -> not increment
                        not_increment.append(res)
                    else:
                        # Deployed and different hash -> increment
                        increment.append(res)
                else:
                    # The projection does not fetch resource_version_id: build it from the resource id and version.
                    LOGGER.warning(
                        "Resource in unexpected state: %s, %s",
                        ores["status"],
                        resources.Id.set_version_in_id(cast(ResourceIdStr, res["resource_id"]), older_version),
                    )
                    increment.append(res)

            work = pending
            if not work:
                break
        # Resources never seen in a conclusive state in any released version are part of the increment.
        if work:
            increment.extend(work)

        negative: set[ResourceIdStr] = {res["resource_id"] for res in not_increment}

        # patch up the graph
        # 1-include stuff for send-events.
        # 2-adapt requires/provides to get closured set

        outset: set[ResourceIdStr] = {res["resource_id"] for res in increment}
        original_provides: dict[str, list[ResourceIdStr]] = defaultdict(list)
        send_events: set[ResourceIdStr] = set()

        # build lookup tables
        for res in current_resources:
            for req in res["requires"]:
                original_provides[req].append(res["resource_id"])
            if res["send_event"]:
                send_events.add(res["resource_id"])

        # recursively include stuff potentially receiving events from nodes in the increment
        increment_work: list[ResourceIdStr] = list(outset)
        done: set[ResourceIdStr] = set()
        while increment_work:
            current: ResourceIdStr = increment_work.pop()
            if current not in send_events:
                # not sending events, so no receivers
                continue

            if current in done:
                continue
            done.add(current)

            provides = original_provides[current]
            increment_work.extend(provides)
            outset.update(provides)
            negative.difference_update(provides)

        return outset, negative

    @classmethod
    def active_version_subquery(cls, environment: uuid.UUID) -> tuple[str, list[object]]:
        """Build a query (and its values) selecting the highest released version number of the given environment."""
        query_builder = SimpleQueryBuilder(
            select_clause="""
            SELECT max(version)
            """,
            from_clause=f" FROM {cls.table_name()} ",
            filter_statements=[" environment = $1 AND released = TRUE"],
            values=[cls._get_value(environment)],
        )
        return query_builder.build()

    @classmethod
    def desired_state_versions_subquery(cls, environment: uuid.UUID) -> tuple[str, list[object]]:
        """
        Build a query (and its values) listing the versions of the given environment with a status relative to the
        active (highest released) version: 'active', 'candidate' (newer), 'retired' (older and released) or
        'skipped_candidate' (older and not released).
        """
        active_version, values = cls.active_version_subquery(environment)
        # Coalesce to 0 in case there is no active version
        active_version = f"(SELECT COALESCE(({active_version}), 0))"
        query_builder = SimpleQueryBuilder(
            select_clause=f"""SELECT cm.version, cm.date, cm.total,
                                     version_info -> 'export_metadata' ->> 'message' as message,
                                     version_info -> 'export_metadata' ->> 'type' as type,
                                     (CASE WHEN cm.version = {active_version} THEN 'active'
                                        WHEN cm.version > {active_version} THEN 'candidate'
                                        WHEN cm.version < {active_version} AND cm.released=TRUE THEN 'retired'
                                        ELSE 'skipped_candidate' END) as status""",
            from_clause=f" FROM {cls.table_name()} as cm",
            filter_statements=[" environment = $1 "],
            values=values,
        )
        return query_builder.build()

    async def recalculate_total(self, connection: Optional[asyncpg.connection.Connection] = None) -> None:
        """
        Make the total field of this ConfigurationModel in-line with the number of resources that are associated with it.

        :raises KeyError: If this version no longer exists in the database.
        """
        query = f"""
            UPDATE {self.table_name()} AS c_outer
            SET total=(
                SELECT COUNT(*)
                FROM {self.table_name()} AS c INNER JOIN {Resource.table_name()} AS r
                ON c.environment = r.environment AND c.version=r.model
                WHERE c.environment=$1 AND c.version=$2
            )
            WHERE c_outer.environment=$1 AND c_outer.version=$2
            RETURNING total
        """

        new_total = await self._fetchval(query, self.environment, self.version, connection=connection)
        if new_total is None:
            raise KeyError(f"Configurationmodel {self.version} in environment {self.environment} was deleted.")

        self.total = new_total
class Code(BaseDocument):
    """
    A code deployment

    :param environment: The environment this code belongs to
    :param version: The version of configuration model it belongs to
    :param resource: The resource type this code belongs to
    :param source_refs: file hashes referring to files in the file store
        {code_hash:(file_name, provider.__module__, [req])}

    The ``sources`` (inline plugin source code, phased out) and ``requires`` attributes mentioned in earlier revisions
    of this docstring are not declared as fields on this class.
    """

    __primary_key__ = ("environment", "resource", "version")

    environment: uuid.UUID
    resource: str
    version: int
    source_refs: Optional[dict[str, tuple[str, str, list[str]]]] = None

    @classmethod
    async def get_version(cls, environment: uuid.UUID, version: int, resource: str) -> Optional["Code"]:
        """
        Return the code entry for the given resource type in the given model version, or None if there is none.
        """
        codes = await cls.get_list(environment=environment, version=version, resource=resource)
        if not codes:
            return None
        return codes[0]

    @classmethod
    async def get_versions(cls, environment: uuid.UUID, version: int) -> list["Code"]:
        """
        Return all code entries (one per resource type) of the given model version.
        """
        return await cls.get_list(environment=environment, version=version)

    @classmethod
    async def copy_versions(
        cls,
        environment: uuid.UUID,
        old_version: int,
        new_version: int,
        *,
        connection: Optional[asyncpg.connection.Connection] = None,
    ) -> None:
        """
        Copy all code for one model version to another.

        :param environment: The environment both versions belong to.
        :param old_version: The version to copy the code entries from.
        :param new_version: The version the copied entries are inserted for.
        :param connection: Optional connection to execute the query on (e.g. to run inside a transaction).
        """
        # Single INSERT ... SELECT so the copy happens server-side without round-tripping the source_refs.
        query: str = f"""
            INSERT INTO {cls.table_name()} (environment, resource, version, source_refs)
            SELECT environment, resource, $1, source_refs
            FROM {cls.table_name()}
            WHERE environment=$2 AND version=$3
        """
        await cls._execute_query(
            query,
            cls._get_value(new_version),
            cls._get_value(environment),
            cls._get_value(old_version),
            connection=connection,
        )


class DryRun(BaseDocument):
    """
    A dryrun of a model version

    :param id: The id of this dryrun
    :param environment: The environment this dryrun belongs to
    :param model: The configuration model version
    :param date: The date the run was requested
    :param total: The number of resources to do a dryrun for
    :param todo: The number of resources left to do
    :param resources: Changes for each of the resources in the version
    """

    __primary_key__ = ("id",)

    id: uuid.UUID
    environment: uuid.UUID
    model: int
    date: datetime.datetime
    total: int = 0
    todo: int = 0
    resources: dict[str, object] = {}

    @classmethod
    async def update_resource(cls, dryrun_id: uuid.UUID, resource_id: m.ResourceVersionIdStr, dryrun_data: JsonType) -> None:
        """
        Register a resource update with a specific query that sets the dryrun_data and decrements the todo counter, only
        if the resource has not been saved yet.

        :param dryrun_id: The id of the dryrun to update.
        :param resource_id: The resource version id the dryrun data belongs to.
        :param dryrun_data: The dryrun result to store for this resource.
        """
        # Deterministic key: reporting the same resource twice targets the same jsonb key, and the
        # `NOT resources ? $4` guard then turns the second report into a no-op (no overwrite, no double decrement).
        jsonb_key = uuid.uuid5(dryrun_id, resource_id)
        query = (
            "UPDATE "
            + cls.table_name()
            + " SET todo = todo - 1, resources=jsonb_set(resources, $1::text[], $2) "
            + "WHERE id=$3 and NOT resources ? $4"
        )
        values = [
            # single-element path: store the data under a top-level key
            cls._get_value([jsonb_key]),
            cls._get_value(dryrun_data),
            cls._get_value(dryrun_id),
            cls._get_value(jsonb_key),
        ]
        await cls._execute_query(query, *values)

    @classmethod
    async def create(cls, environment: uuid.UUID, model: int, total: int, todo: int) -> "DryRun":
        """
        Create and insert a new dryrun with no resource results yet, dated now (local timezone-aware time).
        """
        obj = cls(
            environment=environment,
            model=model,
            date=datetime.datetime.now().astimezone(),
            resources={},
            total=total,
            todo=todo,
        )
        await obj.insert()
        return obj

    @classmethod
    async def list_dryruns(
        cls,
        order_by_column: Optional[str] = None,
        order: str = "ASC",
        **query: object,
    ) -> list[m.DryRun]:
        """
        Return the dryruns matching the given filters as DTOs, without loading the (potentially large) resources column.

        :param order_by_column: The column to order the result by, if any.
        :param order: The sort direction.
        :param query: Column filters passed on to get_list_with_columns.
        """
        records = await cls.get_list_with_columns(
            order_by_column=order_by_column,
            order=order,
            columns=["id", "environment", "model", "date", "total", "todo"],
            limit=None,
            offset=None,
            no_obj=None,
            connection=None,
            lock=None,
            **query,
        )
        return [
            m.DryRun(
                id=record.id,
                environment=record.environment,
                model=record.model,
                date=record.date,
                total=record.total,
                todo=record.todo,
            )
            for record in records
        ]

    def to_dict(self) -> JsonType:
        """
        Serialize this dryrun, re-keying the stored resource results by their own "id" value
        (each stored entry is expected to carry an "id" key).
        """
        dict_result = BaseDocument.to_dict(self)
        dict_result["resources"] = {r["id"]: r for r in dict_result["resources"].values()}
        return dict_result

    def to_dto(self) -> m.DryRun:
        """
        Convert to the API model; the per-resource results are not included.
        """
        return m.DryRun(
            id=self.id,
            environment=self.environment,
            model=self.model,
            date=self.date,
            total=self.total,
            todo=self.todo,
        )


class Notification(BaseDocument):
    """
    A notification in an environment

    :param id: The id of this notification
    :param environment: The environment this notification belongs to
    :param created: The date the notification was created at
    :param title: The title of the notification
    :param message: The actual text of the notification
    :param severity: The severity of the notification
    :param uri: A link to an api endpoint of the server, that is relevant to the message,
                and can be used to get further information about the problem.
                For example a compile related problem should have the uri: `/api/v2/compilereport/<compile_id>`
    :param read: Whether the notification was read or not
    :param cleared: Whether the notification was cleared or not
    """

    __primary_key__ = ("id", "environment")

    id: uuid.UUID
    environment: uuid.UUID
    created: datetime.datetime
    title: str
    message: str
    severity: const.NotificationSeverity = const.NotificationSeverity.message
    uri: Optional[str] = None
    read: bool = False
    cleared: bool = False

    @classmethod
    async def clean_up_notifications(cls) -> None:
        """
        Delete notifications that are older than their environment's notification retention period (in days).

        Environments without an explicit notification_retention setting use the setting's default value.
        Notifications of halted environments are left untouched.
        """
        default_retention_time = Environment._settings[NOTIFICATION_RETENTION].default
        LOGGER.info("Cleaning up notifications")
        query = f"""
            WITH non_halted_envs AS (
                SELECT id, (COALESCE((settings->>'notification_retention')::int, $1)) AS retention_days
                FROM {Environment.table_name()}
                WHERE NOT halted
            )
            DELETE FROM {cls.table_name()}
            USING non_halted_envs
            WHERE environment = non_halted_envs.id
            AND created < now() AT TIME ZONE 'UTC' - make_interval(days => non_halted_envs.retention_days)
        """
        await cls._execute_query(query, default_retention_time)

    def to_dto(self) -> m.Notification:
        """
        Convert to the API model.
        """
        return m.Notification(
            id=self.id,
            title=self.title,
            message=self.message,
            severity=self.severity,
            created=self.created,
            read=self.read,
            cleared=self.cleared,
            uri=self.uri,
            environment=self.environment,
        )


class EnvironmentMetricsGauge(BaseDocument):
    """
    A metric that is of type gauge

    :param environment: the environment to which this metric is related
    :param metric_name: The name of the metric
    :param category: The name of the group/category this metric represents (e.g. red if grouped by color).
                     __None__ iff metrics of this type are not divided in groups.
    :param timestamp: The timestamps at which a new record is created
    :param count: the counter for the metric for the given timestamp
    """

    environment: uuid.UUID
    metric_name: str
    category: str
    timestamp: datetime.datetime
    count: int

    __primary_key__ = ("environment", "metric_name", "category", "timestamp")


class EnvironmentMetricsTimer(BaseDocument):
    """
    A metric that is type timer

    :param environment: the environment to which this metric is related
    :param metric_name: The name of the metric
    :param category: The name of the group/category this metric represents (e.g. red if grouped by color).
                     __None__ iff metrics of this type are not divided in groups.
    :param timestamp: The timestamps at which a new record is created
    :param count: the number of occurrences of the monitored event in the interval [previous.timestamp, self.timestamp[
    :param value: the sum of the values of the metric for each occurrence in the interval [previous.timestamp, self.timestamp[
    """

    environment: uuid.UUID
    metric_name: str
    category: str
    timestamp: datetime.datetime
    count: int
    value: float

    __primary_key__ = ("environment", "metric_name", "category", "timestamp")


class User(BaseDocument):
    """A user that can authenticate against inmanta"""

    __primary_key__ = ("id",)

    id: uuid.UUID
    username: str
    password_hash: str
    auth_method: AuthMethod

    @classmethod
    def table_name(cls) -> str:
        """
        Return the name of the table. We call it inmanta_user to differentiate it from the pg user table.
        """
        return "inmanta_user"

    def to_dao(self) -> m.User:
        """
        Convert to the API model. The password hash is deliberately not exposed.
        """
        return m.User(username=self.username, auth_method=self.auth_method)


class DiscoveredResource(BaseDocument):
    """
    :param environment: the environment of the resource
    :param discovered_at: The time at which the resource was discovered
    :param discovered_resource_id: The id of the resource
    :param discovery_resource_id: The id of the discovery resource responsible for discovering this resource
    :param values: The values associated with the discovered_resource
    """

    environment: uuid.UUID
    discovered_at: datetime.datetime
    discovered_resource_id: m.ResourceIdStr
    discovery_resource_id: Optional[m.ResourceIdStr]
    values: dict[str, object]

    __primary_key__ = ("environment", "discovered_resource_id")

    def to_dto(self) -> m.DiscoveredResource:
        """
        Convert to the API model.
        """
        return m.DiscoveredResource(
            discovered_resource_id=self.discovered_resource_id,
            values=self.values,
            discovery_resource_id=self.discovery_resource_id,
        )


class File(BaseDocument):
    """
    A file in the file store, addressed by the hash of its content.

    :param content_hash: The hash of the file content
    :param content: The raw content of the file
    """

    content_hash: str
    content: bytes

    @classmethod
    async def has_file_with_hash(cls, content_hash: str) -> bool:
        """
        Return True iff a file exists with the given content_hash.
        """
        query = f"""
            SELECT EXISTS (
                SELECT 1
                FROM {cls.table_name()}
                WHERE content_hash=$1
            )
        """
        result = await cls._fetchval(query, content_hash)
        assert isinstance(result, bool)
        return result

    @classmethod
    async def get_non_existing_files(cls, content_hashes: Iterable[str]) -> set[str]:
        """
        Return the subset of content_hashes that are not present in this database table.

        :param content_hashes: Any iterable of content hashes (duplicates are allowed).
        :return: The set of hashes for which no file is stored.
        """
        # asyncpg only encodes sized containers as arrays, so materialize arbitrary iterables (e.g. generators).
        hashes: list[str] = list(content_hashes)
        if not hashes:
            # Nothing to look up: avoid a database round trip.
            return set()
        query = f"""
            SELECT DISTINCT tmp_table.h_content_hash AS content_hash
            FROM (
                SELECT f.content_hash AS f_content_hash, h.content_hash as h_content_hash
                FROM {cls.table_name()} AS f
                RIGHT OUTER JOIN unnest($1::varchar[]) AS h(content_hash)
                ON f.content_hash = h.content_hash
            ) as tmp_table
            -- Only keep records for which no matching hash was found in the file table
            WHERE tmp_table.f_content_hash IS NULL
        """
        result = await cls._fetch_query(query, hashes)
        return {cast(str, r["content_hash"]) for r in result}


_classes = [
    Project,
    Environment,
    UnknownParameter,
    AgentProcess,
    AgentInstance,
    Agent,
    Resource,
    ResourceAction,
    ResourcePersistentState,
    ConfigurationModel,
    Code,
    Parameter,
    DryRun,
    Compile,
    Report,
    Notification,
    EnvironmentMetricsGauge,
    EnvironmentMetricsTimer,
    User,
    DiscoveredResource,
    File,
]


def set_connection_pool(pool: asyncpg.pool.Pool) -> None:
    """
    Register the given connection pool with every data class in _classes.
    """
    LOGGER.debug("Connecting data classes")
    for cls in _classes:
        cls.set_connection_pool(pool)


async def disconnect() -> None:
    """
    Close the connection pool of every data class in _classes.

    All close operations are awaited; afterwards the first Exception raised by any of them (if any) is re-raised.
    Results that are BaseExceptions but not Exceptions are not re-raised.
    """
    LOGGER.debug("Disconnecting data classes")
    # Enable `return_exceptions` to make sure we wait until all close_connection_pool() calls are finished
    # or until the gather itself is cancelled.
    result = await asyncio.gather(*[cls.close_connection_pool() for cls in _classes], return_exceptions=True)
    first_exception = next((r for r in result if isinstance(r, Exception)), None)
    if first_exception is not None:
        raise first_exception


PACKAGE_WITH_UPDATE_FILES = inmanta.db.versions

# Name of core schema in the DB schema versions
# prevent import loop
CORE_SCHEMA_NAME = schema.CORE_SCHEMA_NAME


async def connect(
    host: str,
    port: int,
    database: str,
    username: str,
    password: str,
    create_db_schema: bool = True,
    connection_pool_min_size: int = 10,
    connection_pool_max_size: int = 10,
    connection_timeout: float = 60,
) -> asyncpg.pool.Pool:
    """
    Create a connection pool, register it with all data classes and optionally create/migrate the core DB schema.

    :param create_db_schema: When True, ensure the core database schema is up to date before returning.
    :return: The created connection pool.
    :raises Exception: Any error during setup is re-raised after the pool is closed and all data classes are
        disconnected.
    """
    pool = await asyncpg.create_pool(
        host=host,
        port=port,
        database=database,
        user=username,
        password=password,
        min_size=connection_pool_min_size,
        max_size=connection_pool_max_size,
        timeout=connection_timeout,
    )
    try:
        set_connection_pool(pool)
        if create_db_schema:
            async with pool.acquire() as con:
                await schema.DBSchema(CORE_SCHEMA_NAME, PACKAGE_WITH_UPDATE_FILES, con).ensure_db_schema()
            # expire connections after db schema migration to ensure cache consistency
            await pool.expire_connections()
        return pool
    except Exception:
        try:
            await pool.close()
        finally:
            # Always reset the data classes, even if closing the pool itself failed.
            await disconnect()
        # Bare raise: re-raise the original error with its traceback untouched.
        raise