Skip to content
Snippets Groups Projects
Commit d803c242 authored by Andréas Livet's avatar Andréas Livet Committed by blenzi
Browse files

add /chat/queries route and test

parent 5c38a6de
No related branches found
No related tags found
No related merge requests found
"""Add timestamp to chat_query
Revision ID: b6e2cf2d4ee5
Revises: dc31368f80ca
Create Date: 2025-01-30 10:18:27.180692
"""
from datetime import datetime, timedelta
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b6e2cf2d4ee5'
down_revision = 'dc31368f80ca'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('chatquery', sa.Column('created_at', sa.DateTime(), nullable=True))
# ### end Alembic commands ###
# The column was added on the 29th january 2025 so for the existing entries we'll use this date as a base
original_date = "2025-01-29"
op.execute(f"""
UPDATE chatquery
SET created_at= '{original_date}'
WHERE created_at IS NULL
""")
op.alter_column('chatquery', 'created_at', nullable=False)
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('chatquery', 'created_at')
# ### end Alembic commands ###
......@@ -14,7 +14,16 @@ from app import crud
from app.api.deps import CurrentUser, SessionDep
from app.core.config import settings
from app.fixtures import fake_sofia_chat_response
from app.models import ChatCreate, ChatMessage, ChatQuery, Collection, Document, MessageRequest
from app.models import (
ChatCreate,
ChatMessage,
ChatQuery,
ChatQueryPublic,
Collection,
Document,
MessageRequest,
UserRole,
)
from app.utils import user_permissions
# This file is inspired from : https://gitlab.adullact.net/dgfip/projets-ia/caradoc/-/blob/dc346202070924f0fa64ebfd495429682f150722/api/app/routers/chat.py
......@@ -190,3 +199,12 @@ async def process_message(session: SessionDep, current_user: CurrentUser, chat_m
transform_sofia_chat_response(r.aiter_lines(), documents_to_map), # type: ignore
background=BackgroundTask(r.aclose),
)
@router.get(
"/queries", response_description="Get all the chat queries (superadmin only)", response_model=list[ChatQueryPublic]
)
async def get_chat_queries(session: SessionDep, current_user: CurrentUser) -> Any:
if current_user.role != UserRole.superadmin:
raise HTTPException(status_code=403, detail="Only superadmin can access this route")
return session.exec(select(ChatQuery)).all()
......@@ -350,12 +350,17 @@ class ChatCreate(SQLModel):
query: str
class ChatQueryPublic(ChatCreate):
created_at: datetime
class ChatQuery(ChatCreate, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
collection_id: uuid.UUID = Field(foreign_key="collection.id", nullable=False, ondelete="CASCADE")
collection: Collection = Relationship(back_populates="chat_queries")
user_id: uuid.UUID = Field(foreign_key="user.id", nullable=False, ondelete="CASCADE")
user: User = Relationship(back_populates="chat_queries")
created_at: datetime = Field(default_factory=datetime.utcnow, nullable=False)
class SearchRequest(BaseModel):
......
......@@ -4,6 +4,8 @@ from sqlmodel import Session, select
from app.core.config import settings
from app.models import ChatQuery, User
from app.tests.utils.collection import create_random_collection
from app.tests.utils.user import AuthenticatedUser, random_authenticated_user
from app.tests.utils.utils import random_lower_string
def test_store_query(client: TestClient, normal_user_token_headers: dict[str, str], db: Session) -> None:
......@@ -13,7 +15,7 @@ def test_store_query(client: TestClient, normal_user_token_headers: dict[str, st
collection = create_random_collection(db, user_db)
query = "Quel est l'impact"
# Ensure that settings is in fake
data = {"query": "Quel est l'impact", "collection_ids": [str(collection.id)]}
data = {"query": query, "collection_ids": [str(collection.id)]}
response = client.post(
f"{settings.API_V1_STR}/chat/message",
headers=normal_user_token_headers,
......@@ -26,3 +28,41 @@ def test_store_query(client: TestClient, normal_user_token_headers: dict[str, st
assert chat_query_db
assert chat_query_db.collection_id == collection.id
assert chat_query_db.query == query
def create_random_query(client: TestClient, authent_user: AuthenticatedUser, db: Session) -> str:
collection = create_random_collection(db, authent_user.user)
query = random_lower_string()
data = {"query": query, "collection_ids": [str(collection.id)]}
client.post(
f"{settings.API_V1_STR}/chat/message",
headers=authent_user.token,
json=data,
)
return query
def test_chat_queries(client: TestClient, superuser_token_headers: dict[str, str], db: Session) -> None:
authent_user_1 = random_authenticated_user(db)
authent_user_2 = random_authenticated_user(db)
query_1 = create_random_query(client, authent_user_1, db)
query_2 = create_random_query(client, authent_user_2, db)
# Only superadmin can access the route
response = client.get(
f"{settings.API_V1_STR}/chat/queries",
headers=authent_user_1.token,
)
assert response.status_code == 403
response = client.get(
f"{settings.API_V1_STR}/chat/queries",
headers=superuser_token_headers,
)
assert response.status_code == 200
content = response.json()
assert len(content) == 2
assert content[0]["query"] == query_1
assert content[0]["user_id"] == str(authent_user_1.user.id)
assert content[1]["query"] == query_2
assert content[1]["user_id"] == str(authent_user_2.user.id)
......@@ -48,7 +48,7 @@ def postgres() -> Generator[str, None, None]:
yield tmp_url
@pytest.fixture(scope="session")
@pytest.fixture(scope="function")
def migration_postgres() -> Generator[str, None, None]:
"""
Creates empty temporary database.
......
import importlib
import os
import uuid
from collections import defaultdict, namedtuple
from typing import Any
import pytest
from alembic.config import Config
from sqlmodel import Session, text
from app.main import PROJECT_PATH
def create_group(session: Session) -> int:
group_id = 1
query = text(f"""insert into public.group (id, name, domain) values ({group_id}, 'test', 'test.com')""")
session.execute(query)
return group_id
def create_user(session: Session, group_id: int) -> uuid.UUID:
user_id = uuid.uuid4()
query = text(f"""insert into public.user (id, email, is_active, role, sso, group_id)
values ('{user_id}', 'test@test.com', true, 'user', true, {group_id})""")
session.execute(query)
return user_id
def create_collection(session: Session, group_id: int, user_id: uuid.UUID) -> uuid.UUID:
collection_id = uuid.uuid4()
query = text(f"""insert into public.collection (id, title, type, user_id, group_id, created_at)
values ('{collection_id}', 'test', 'personal', '{user_id}', {group_id}, current_timestamp);""")
session.execute(query)
return collection_id
def load_migration_as_module(file: str) -> Any:
"""
Allows to import alembic migration as a module.
......
......@@ -24,7 +24,7 @@ import uuid
from sqlalchemy.engine import Engine
from sqlmodel import Session, text
from app.tests.migrations.conftest import load_migration_as_module
from app.tests.migrations.conftest import create_collection, create_group, create_user, load_migration_as_module
# Load migration as module
migration = load_migration_as_module("11d362313208_document_comments_column_is_now_not_.py")
......@@ -41,18 +41,9 @@ def on_init(engine: Engine) -> None:
global document_id
with Session(engine) as session:
group_id = 1
query = text(f"""insert into public.group (id, name, domain) values ({group_id}, 'test', 'test.com')""")
session.execute(query)
user_id = uuid.uuid4()
query = text(f"""insert into public.user (id, email, is_active, role, sso, group_id)
values ('{user_id}', 'test@test.com', true, 'user', true, {group_id})""")
session.execute(query)
# Create the document collection
collection_id = uuid.uuid4()
query = text(f"""insert into public.collection (id, title, type, user_id, group_id, created_at)
values ('{collection_id}', 'test', 'personal', '{user_id}', {group_id}, current_timestamp);""")
session.execute(query)
group_id = create_group(session)
user_id = create_user(session, group_id)
collection_id = create_collection(session, group_id, user_id)
query = text(f"""insert into public.document(id, filename, content_length, collection_id, created_at, updated_at)
values('{document_id}', 'test.pdf', 1000, '{collection_id}', current_timestamp, current_timestamp);""")
session.execute(query)
......
import uuid
from datetime import datetime
from sqlalchemy.engine import Engine
from sqlmodel import Session, text
from app.tests.migrations.conftest import create_collection, create_group, create_user, load_migration_as_module
# Load migration as module
migration = load_migration_as_module("b6e2cf2d4ee5_add_timestamp_to_chat_query.py")
rev_base = migration.down_revision
rev_head: str = migration.revision
def on_init(engine: Engine) -> None:
"""
Create rows in users table before migration is applied
"""
chat_id = uuid.uuid4()
with Session(engine) as session:
group_id = create_group(session)
user_id = create_user(session, group_id)
collection_id = create_collection(session, group_id, user_id)
query = text(f"""insert into public.chatquery(id, collection_id, user_id, query)
values('{chat_id}', '{collection_id}', '{user_id}', 'test');""")
session.execute(query)
session.commit()
def on_upgrade(engine: Engine) -> None:
"""
Ensure that data was successfully migrated
"""
original_date = datetime.strptime("2025-01-29", "%Y-%m-%d")
with Session(engine) as session:
query = text("select * from public.chatquery")
res = session.execute(query).first()
assert hasattr(res, "created_at")
assert res and res.created_at == original_date
def on_downgrade(engine: Engine) -> None:
"""
Ensure that data changes were rolled back
"""
with Session(engine) as session:
query = text("select * from public.chatquery")
res = session.execute(query).first()
assert res and not hasattr(res, "created_at") and hasattr(res, "id")
......@@ -11,16 +11,14 @@ from alembic.config import Config
from sqlalchemy import Engine, create_engine
from app.tests.migrations.conftest import MigrationValidationParamsGroup, make_validation_params_groups
from app.tests.migrations.data_migrations import migration_11d362313208
from app.tests.migrations.data_migrations import migration_11d362313208, migration_b6e2cf2d4ee5
def get_data_migrations() -> list[MigrationValidationParamsGroup]:
"""
Returns tests for data migrations, from tests/data_migrations folder.
"""
return make_validation_params_groups(
migration_11d362313208,
)
return make_validation_params_groups(migration_11d362313208, migration_b6e2cf2d4ee5)
@pytest.mark.parametrize(("rev_base", "rev_head", "on_init", "on_upgrade", "on_downgrade"), get_data_migrations())
......
from datetime import timedelta
from fastapi.testclient import TestClient
from pydantic import BaseModel
from sqlmodel import Session
from app import crud
from app.core import security
from app.core.config import settings
from app.models import User, UserCreate, UserUpdate
from app.tests.utils.utils import random_email, random_lower_string
......@@ -25,6 +29,18 @@ def create_random_user(db: Session) -> User:
return user
class AuthenticatedUser(BaseModel):
user: User
token: dict[str, str]
def random_authenticated_user(db: Session) -> AuthenticatedUser:
user = create_random_user(db)
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
auth_token = security.create_access_token(user.id, expires_delta=access_token_expires)
return AuthenticatedUser(user=user, token={"Authorization": f"Bearer {auth_token}"})
def authentication_token_from_email(*, client: TestClient, email: str, db: Session) -> dict[str, str]:
"""
Return a valid token for the user with given email.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment