from typing import cast

from sqlalchemy.schema import CreateSchema, DropSchema
from sqlmodel import Session, create_engine, select

from app import crud
from app.core.config import settings
from app.models import Group, GroupBase, User, UserCreate, UserRole

# Module-level engine built from the configured database URI.
# flush_db() below opens its connections from this engine.
engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))


# make sure all SQLModel models are imported (app.models) before initializing DB
# otherwise, SQLModel might fail to initialize relationships properly
# for more details: https://github.com/fastapi/full-stack-fastapi-template/issues/28


def init_db(session: Session) -> None:
    """Seed the database with the baseline groups and the first superuser.

    The tables themselves are expected to come from Alembic migrations. To
    skip migrations, create them here instead with
    ``SQLModel.metadata.create_all(engine)`` (after
    ``from sqlmodel import SQLModel``); that works because the models are
    already imported and registered from ``app.models``.

    Elements are only added when a lookup does not find them:
    - a group named ``settings.DEFAULT_GROUP`` (with no domain) and one group
      per entry of ``settings.ACCEPTED_DOMAINS`` (domain set to its name);
    - the ``settings.FIRST_SUPERUSER`` account, created with the superadmin
      role and attached to the default group.

    Args:
        session: Open database session used for the lookups and creations.
    """
    wanted_names = [settings.DEFAULT_GROUP] + settings.ACCEPTED_DOMAINS
    for name in wanted_names:
        is_default = name == settings.DEFAULT_GROUP

        by_name = select(Group).where(Group.name == name)
        current_group = session.exec(by_name).first()
        if not current_group:
            # Only the default group is domain-less; the others are keyed by
            # the domain they are named after.
            if is_default:
                domain = None
            else:
                domain = name
            new_group = GroupBase(name=name, domain=domain)
            current_group = cast(
                Group,
                crud.create_generic_item(session=session, model=Group, item=new_group),
            )

        # The superuser is handled right after its (default) group is known.
        if not is_default:
            continue

        by_email = select(User).where(User.email == settings.FIRST_SUPERUSER)
        if session.exec(by_email).first():
            continue

        superuser_in = UserCreate(
            email=settings.FIRST_SUPERUSER,
            password=settings.FIRST_SUPERUSER_PASSWORD,
            role=UserRole.superadmin,
            group_id=current_group.id,
        )
        crud.create_user(session=session, user_create=superuser_in)


def flush_db(schema_name: str = "public") -> None:
    """Drop the given schema with everything in it, then recreate it empty.

    Use with extreme care: the drop is issued with CASCADE, so all tables of
    the schema are deleted.

    Args:
        schema_name: Name of the schema to reset. Defaults to ``"public"``.
    """
    with engine.connect() as connection:
        # if_exists keeps the reset usable when the schema is already gone,
        # mirroring the if_not_exists guard on the re-creation below.
        connection.execute(DropSchema(schema_name, cascade=True, if_exists=True))  # type: ignore[no-untyped-call]
        connection.execute(CreateSchema(schema_name, if_not_exists=True))  # type: ignore[no-untyped-call]
        # Nothing is persisted until this explicit commit.
        connection.commit()
