"""
Copyright 2019 Inmanta
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Contact: code@inmanta.com
"""
import importlib
import logging
import pkgutil
import re
from collections.abc import Coroutine
from types import ModuleType
from typing import Any, Callable, Optional

from asyncpg import Connection, UndefinedTableError
from asyncpg.protocol import Record

from inmanta import tracing
# Module-level logger; DBSchema instances derive a per-schema child logger from it.
LOGGER = logging.getLogger(__name__)

# Name of the core schema, as recorded in the DB schema versions table.
CORE_SCHEMA_NAME: str = "core"

# Table that records, per schema name, which schema versions have been installed.
SCHEMA_VERSION_TABLE: str = "schemamanager"

# DDL for the schema version table. IF NOT EXISTS makes it safe to run on every startup.
create_schemamanager: str = """
-- Table: public.schemamanager
CREATE TABLE IF NOT EXISTS public.schemamanager (
name varchar PRIMARY KEY,
legacy_version integer,
installed_versions integer[]
);
"""
class TableNotFound(Exception):
    """Raised when a table is not found in the database."""
class ColumnNotFound(Exception):
    """Raised when a column is not found in the database."""
class Version:
    """
    Internal representation of a version: a single schema update step.

    Attributes:
      * ``name``: the name the version was created from (e.g. a module name such as ``v12``).
      * ``function``: the coroutine function that performs the update, given a connection.
      * ``version``: the integer version number derived from ``name`` via :meth:`parse`.
    """

    def __init__(self, name: str, function: Callable[[Connection], Coroutine[Any, Any, None]]):
        # Derive the numeric version up front; the remaining attributes are stored verbatim.
        self.version = self.parse(name)
        self.name = name
        self.function = function

    @classmethod
    def parse(cls, name: str) -> int:
        """
        Convert a version name into its integer version number.

        :param name: version name; its first character is treated as a prefix (the ``v`` in ``v12``).
        :return: the remainder of the name converted with :func:`int`.
        :raises ValueError: if the remainder is not a valid integer.
        """
        digits = name[1:]
        return int(digits)
class DBSchema:
    """
    Schema Manager, ensures the schema is up to date.

    Concurrent updates are safe: pending updates are applied inside a single transaction that first takes an
    ACCESS EXCLUSIVE lock on the schema version table, so concurrent updaters are serialized.
    """

    def __init__(self, name: str, package: ModuleType, connection: Connection) -> None:
        """
        :param name: unique name for this schema, best equal to extension name and used as prefix for all table names
        :param package: a python package, containing modules with name v%(version)s.py.
            Each module contains a method `async def update(connection: asyncpg.connection) -> None:`
        :param connection: asyncpg connection
        """
        self.name = name
        self.package = package
        self.connection = connection
        self.logger = LOGGER.getChild(f"schema:{self.name}")

    @tracing.instrument("ensure_db_schema")
    async def ensure_db_schema(self) -> None:
        """
        Bring this schema up to date.

        First makes sure the schema version table exists, then applies every update function
        whose version is not installed yet.
        """
        await self.ensure_self_update()
        await self._update_db_schema()

    async def ensure_self_update(self) -> None:
        """
        Ensures the schema version table exists.

        The DDL uses ``CREATE TABLE IF NOT EXISTS``, so an existing table is left untouched
        (its columns are not migrated by this method).
        """
        # NOTE: this message is logged on every call, also when the table already exists.
        self.logger.info("Creating schema version table")
        await self.connection.execute(create_schemamanager)

    async def _update_db_schema(self, update_functions: Optional[list[Version]] = None) -> None:
        """
        Main update function.

        Wrapped in a transaction that holds an ACCESS EXCLUSIVE lock on the schema version table.
        When a version update fails, the whole transaction is rolled back, including the updates
        that already succeeded during this call.

        :param update_functions: allows overriding the available update functions, for example for testing purposes.
            When omitted, the update functions are discovered in ``self.package``.
        """
        update_functions = (
            sorted(update_functions, key=lambda x: x.version) if update_functions is not None else self._get_update_functions()
        )
        async with self.connection.transaction():
            # Take the lock for the rest of the transaction so concurrent schema updates are serialized.
            await self.connection.execute(f"LOCK TABLE {SCHEMA_VERSION_TABLE} IN ACCESS EXCLUSIVE MODE")
            # Read the installed versions only now, while holding the lock, so the result cannot go stale.
            try:
                installed_versions: set[int] = await self.get_installed_versions()
            except TableNotFound:
                self.logger.exception("Schemamanager table disappeared, should not occur.")
                raise
            # Apply only the versions that are not installed yet, in ascending version order.
            updates = [v for v in update_functions if v.version not in installed_versions]
            for version in updates:
                try:
                    self.logger.info("Updating database schema to version %d", version.version)
                    await version.function(self.connection)
                    await self.set_installed_version(version.version)
                    # inform asyncpg of the type change so it knows to refresh its caches
                    await self.connection.reload_schema_state()
                except Exception:
                    self.logger.exception(
                        "Database schema update for version %d failed. Rolling back all updates.",
                        version.version,
                    )
                    # propagate exception => roll back transaction
                    raise

    async def get_installed_versions(self) -> set[int]:
        """
        Returns the set of all versions that have been installed for this schema.

        An empty set is returned when this schema has no row in the version table yet,
        or when its ``installed_versions`` column is NULL.

        :raises TableNotFound: if the schema version table does not exist.
        """
        try:
            versions: Optional[Record] = await self.connection.fetchrow(
                f"select installed_versions from {SCHEMA_VERSION_TABLE} where name=$1", self.name
            )
        except UndefinedTableError as e:
            raise TableNotFound() from e
        if versions is None or versions["installed_versions"] is None:
            return set()
        return set(versions["installed_versions"])

    async def set_installed_version(self, version: int) -> None:
        """
        Adds a version to the installed versions column of this schema's row.

        The row is created when it does not exist yet. The version is appended unconditionally;
        ``_update_db_schema`` only passes versions that are not installed yet.
        """
        await self.connection.execute(
            f"""
            INSERT INTO {SCHEMA_VERSION_TABLE} (name, installed_versions)
            VALUES ($1, $2) ON CONFLICT (name) DO UPDATE
            SET installed_versions = {SCHEMA_VERSION_TABLE}.installed_versions || excluded.installed_versions
            """,
            self.name,
            # a one-element list is encoded by asyncpg as a one-element integer[] array
            [version],
        )

    def _get_update_functions(self) -> list[Version]:
        """
        Discover the update functions in ``self.package``.

        Every non-package module whose name matches ``v<version_number>`` is imported. Modules with
        any other name are skipped with a warning. Modules with a truthy ``DISABLED`` attribute are
        skipped; every remaining module must define ``update``.

        :return: the update functions, sorted by ascending version number.
        """
        pattern = re.compile(r"^v[0-9]+$")

        module_names: list[str] = []
        for _, module_name, is_package in pkgutil.iter_modules(self.package.__path__):
            if is_package:
                continue
            if not pattern.match(module_name):
                # lazy %-formatting: the message is only rendered when the record is emitted
                LOGGER.warning(
                    "Database schema version file name %s "
                    "doesn't match the expected pattern: v<version_number>.py, skipping it",
                    module_name,
                )
                continue
            module_names.append(module_name)

        # Import all matching modules first, then drop the disabled ones.
        modules = [(name, importlib.import_module(f"{self.package.__name__}.{name}")) for name in module_names]
        versions = [Version(name, module.update) for name, module in modules if not getattr(module, "DISABLED", False)]
        return sorted(versions, key=lambda v: v.version)