Skip to content
Snippets Groups Projects
Commit 08f5a067 authored by Andréas Livet's avatar Andréas Livet Committed by blenzi
Browse files

add chat history

parent 4a55b2f7
No related branches found
No related tags found
No related merge requests found
"""add chat history
Revision ID: 2f61ad989231
Revises: bb5b4cdab369
Create Date: 2025-02-06 10:55:59.681012
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '2f61ad989231'
down_revision = 'bb5b4cdab369'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('chatquery', sa.Column('title', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True))
op.add_column('chatquery', sa.Column('updated_at', sa.DateTime(), nullable=True))
op.add_column('chatquery', sa.Column('history', postgresql.JSONB(astext_type=sa.Text()), server_default='[]', nullable=False))
# ### end Alembic commands ###
op.execute(f"""
UPDATE chatquery
SET updated_at=created_at, title=substring(query from 0 for 50)
WHERE updated_at IS NULL
AND title IS NULL
""")
op.alter_column('chatquery', 'title', nullable=False)
op.alter_column('chatquery', 'updated_at', nullable=False)
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('chatquery', 'history')
op.drop_column('chatquery', 'updated_at')
op.drop_column('chatquery', 'title')
# ### end Alembic commands ###
......@@ -8,6 +8,8 @@ from typing import Any
import httpx
from fastapi import APIRouter, HTTPException, status
from fastapi.responses import StreamingResponse
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlmodel import col, select
from starlette.background import BackgroundTask
......@@ -20,6 +22,7 @@ from app.models import (
ChatMessage,
ChatQuery,
ChatQueryPublic,
ChatQueryUpdate,
Collection,
Document,
MessageRequest,
......@@ -118,10 +121,12 @@ async def transform_sofia_chat_response(
@router.post(
"/message",
"/query/{id}/message",
response_description="Answer user prompt request",
)
async def process_message(session: SessionDep, current_user: CurrentUser, chat_message: ChatMessage) -> Any:
async def process_message(
session: SessionDep, current_user: CurrentUser, id: uuid.UUID, chat_message: ChatMessage
) -> Any:
"""Process the user prompt request and generate an accurate answer
Args:
......@@ -140,6 +145,7 @@ async def process_message(session: SessionDep, current_user: CurrentUser, chat_m
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Sofia chat not configured")
with user_permissions(current_user):
crud.read_generic_item(session=session, model=ChatQuery, item_id=id, check_edit_permission=True)
# Retieve collection documents and map them to the response
documents_to_map: Sequence[uuid.UUID] = []
if settings.CHAT.fake:
......@@ -159,18 +165,6 @@ async def process_message(session: SessionDep, current_user: CurrentUser, chat_m
if len({col.type for col in collections}) > 1:
raise HTTPException(status_code=400, detail="Cannot query collections of different type")
# Store the query for analytics
query_in = ChatCreate(
query=chat_message.query,
collection_id=chat_message.collection_ids[0],
user_id=current_user.id,
)
crud.create_generic_item(
session=session,
model=ChatQuery,
item=query_in,
)
if settings.CHAT.fake:
return StreamingResponse(
transform_sofia_chat_response(fake_sofia_chat_response(), documents_to_map),
......@@ -201,9 +195,46 @@ async def process_message(session: SessionDep, current_user: CurrentUser, chat_m
@router.get(
"/queries", response_description="Get all the chat queries (superadmin only)", response_model=list[ChatQueryPublic]
"/queries", response_description="Get all the chat queries (superadmin only)", response_model=Page[ChatQueryPublic]
)
async def get_chat_queries(session: SessionDep, current_user: CurrentUser) -> Any:
if current_user.role != UserRole.superadmin:
raise HTTPException(status_code=403, detail="Only superadmin can access this route")
return session.exec(select(ChatQuery)).all()
return paginate(session, select(ChatQuery))
@router.post("/query", response_description="Create query, needed to generated a response", response_model=ChatQuery)
async def create_chat_query(session: SessionDep, current_user: CurrentUser, chat_query_in: ChatCreate) -> Any:
with user_permissions(current_user):
max_title_len = 50
return crud.create_generic_item(
session=session,
model=ChatQuery,
item=chat_query_in,
user_id=current_user.id,
title=chat_query_in.query
if len(chat_query_in.query) < max_title_len
else chat_query_in.query[:max_title_len] + "...",
).model_dump()
@router.get("/query/{id}", response_description="Get a speicfic query", response_model=ChatQuery)
async def get_chat_query(session: SessionDep, current_user: CurrentUser, id: uuid.UUID) -> Any:
with user_permissions(current_user):
return crud.read_generic_item(session=session, model=ChatQuery, item_id=id)
@router.put("/query/{id}", response_model=ChatQuery)
async def update_chat_query(
session: SessionDep, current_user: CurrentUser, id: uuid.UUID, chat_query_in: ChatQueryUpdate
) -> Any:
with user_permissions(current_user):
return crud.update_generic_item(
session=session, model=ChatQuery, item_update=chat_query_in, item_id=id
).model_dump()
@router.delete("/query/{id}")
async def delete_chat_query(session: SessionDep, current_user: CurrentUser, id: uuid.UUID) -> Any:
with user_permissions(current_user):
return crud.delete_generic_item(session=session, model=ChatQuery, item_id=id)
......@@ -14,6 +14,7 @@ from app.api.deps import CurrentUser, SessionDep
from app.core.qdrant import client as qdrant_client
from app.core.qdrant import delete_vectors
from app.models import (
ChatQuery,
Collection,
CollectionCreate,
CollectionPublic,
......@@ -227,3 +228,14 @@ async def create_document(
doc = handle_pdf_upload(session, id, file)
uploaded_documents.append(doc.model_dump())
return uploaded_documents
@router.get("/{id}/chat_queries", response_model=list[ChatQuery])
async def read_queries(session: SessionDep, current_user: CurrentUser, id: uuid.UUID) -> Any:
with user_permissions(current_user):
collection = cast(Collection, crud.read_generic_item(session=session, model=Collection, item_id=id))
if not collection.can_read:
raise HTTPException(status_code=403, detail="The user doesn't have enough privileges")
return session.exec(
select(ChatQuery).where(ChatQuery.collection_id == id).where(ChatQuery.user_id == current_user.id)
).all()
......@@ -12,3 +12,14 @@ async def fake_sofia_chat_response() -> AsyncGenerator[str, None]:
for line in file:
yield line
await asyncio.sleep(0.001)
async def fake_sofia_collection_response() -> AsyncGenerator[str, None]:
"""
It can be usefull to also fake the entire sofia_collection response
in order to investigate front problematic for examples
"""
with open("./fixtures/sofia_collection_response.json") as file:
for line in file:
yield line
await asyncio.sleep(0.001)
......@@ -351,25 +351,50 @@ class HistoryItem(BaseModel):
class ChatMessage(BaseModel):
query: str
collection_ids: list[uuid.UUID]
model_config = {
"json_schema_extra": {
"examples": [
{
"query": "Quel est l'impact du projet sur la faune aquatique ?",
"collection_ids": ["b8dbebca-472b-412a-912d-1bd50b6b92ae"],
}
]
}
}
class ChatCreate(SQLModel):
collection_id: uuid.UUID
user_id: uuid.UUID
query: str
class ChatQueryPublic(ChatCreate):
id: uuid.UUID
user_id: uuid.UUID
created_at: datetime
updated_at: datetime
title: str
class ChatQueryUpdate(ChatCreate):
title: str = Field(min_length=1, max_length=255)
history: list[dict[str, Any]]
class ChatQuery(ChatCreate, table=True):
class ChatQuery(ChatQueryUpdate, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
collection_id: uuid.UUID = Field(foreign_key="collection.id", nullable=False, ondelete="CASCADE")
collection: Collection = Relationship(back_populates="chat_queries")
user_id: uuid.UUID = Field(foreign_key="user.id", nullable=False, ondelete="CASCADE")
user: User = Relationship(back_populates="chat_queries")
created_at: datetime = Field(default_factory=datetime.utcnow, nullable=False)
updated_at: datetime = Field(default_factory=datetime.utcnow, nullable=False)
history: list[dict[str, Any]] = Field(
nullable=False,
default=[],
sa_type=JSONB,
description="Contains the chat history in a format return by /chat/message",
)
class SearchRequest(BaseModel):
......
from typing import Any
from fastapi.testclient import TestClient
from sqlmodel import Session, select
from app.core.config import settings
from app.models import ChatQuery, User
from app.models import ChatQuery
from app.tests.utils.collection import create_random_collection
from app.tests.utils.user import AuthenticatedUser, random_authenticated_user
from app.tests.utils.utils import random_lower_string
def test_store_query(client: TestClient, normal_user_token_headers: dict[str, str], db: Session) -> None:
user_query = select(User).where(User.email == settings.EMAIL_TEST_USER)
user_db = db.exec(user_query).first()
assert user_db
collection = create_random_collection(db, user_db)
query = "Quel est l'impact"
# Ensure that settings is in fake
data = {"query": query, "collection_ids": [str(collection.id)]}
def create_random_query(client: TestClient, authent_user: AuthenticatedUser, db: Session) -> Any:
collection = create_random_collection(db, authent_user.user)
query = random_lower_string() + random_lower_string()
data = {"query": query, "collection_id": str(collection.id)}
response = client.post(
f"{settings.API_V1_STR}/chat/message",
headers=normal_user_token_headers,
f"{settings.API_V1_STR}/chat/query",
headers=authent_user.token,
json=data,
)
assert response.status_code == 200
return response.json()
chat_query = select(ChatQuery).where(ChatQuery.user_id == user_db.id)
chat_query_db = db.exec(chat_query).first()
assert chat_query_db
assert chat_query_db.collection_id == collection.id
assert chat_query_db.query == query
def create_random_query(client: TestClient, authent_user: AuthenticatedUser, db: Session) -> str:
collection = create_random_collection(db, authent_user.user)
query = random_lower_string()
data = {"query": query, "collection_ids": [str(collection.id)]}
client.post(
f"{settings.API_V1_STR}/chat/message",
def test_store_query(client: TestClient, db: Session) -> None:
authent_user = random_authenticated_user(db)
query = create_random_query(client, authent_user, db)
query_json = {"query": query["query"], "collection_ids": [query["collection_id"]]}
response = client.post(
f"{settings.API_V1_STR}/chat/query/{query['id']}/message",
headers=authent_user.token,
json=data,
json=query_json,
)
return query
assert response.status_code == 200
chat_query = select(ChatQuery).where(ChatQuery.user_id == authent_user.user.id)
chat_query_db = db.exec(chat_query).first()
assert chat_query_db
assert str(chat_query_db.collection_id) == query["collection_id"]
assert chat_query_db.query == query["query"]
assert chat_query_db.title == query["title"]
# Long title are croped
assert chat_query_db.title[-3:] == "..."
assert chat_query_db.history == []
def test_chat_queries(client: TestClient, superuser_token_headers: dict[str, str], db: Session) -> None:
......@@ -61,8 +63,14 @@ def test_chat_queries(client: TestClient, superuser_token_headers: dict[str, str
)
assert response.status_code == 200
content = response.json()
assert len(content) == 2
assert content[0]["query"] == query_1
assert content[0]["user_id"] == str(authent_user_1.user.id)
assert content[1]["query"] == query_2
assert content[1]["user_id"] == str(authent_user_2.user.id)
assert len(content["items"]) == 2
items = content["items"]
assert items[0]["query"] == query_1["query"]
assert items[0]["user_id"] == str(authent_user_1.user.id)
assert "history" not in items[0]
assert items[1]["query"] == query_2["query"]
assert items[1]["user_id"] == str(authent_user_2.user.id)
def test_chat_history() -> None:
pass
import uuid
from sqlalchemy.engine import Engine
from sqlmodel import Session, text
from app.tests.migrations.conftest import create_collection, create_group, create_user, load_migration_as_module
# Load migration as module
migration = load_migration_as_module("2f61ad989231_add_chat_history.py")
rev_base = migration.down_revision
rev_head: str = migration.revision
def on_init(engine: Engine) -> None:
"""
Create rows in users table before migration is applied
"""
chat_id = uuid.uuid4()
with Session(engine) as session:
group_id = create_group(session)
user_id = create_user(session, group_id)
collection_id = create_collection(session, group_id, user_id)
query = text(f"""insert into public.chatquery(id, collection_id, user_id, query, created_at)
values('{chat_id}', '{collection_id}', '{user_id}', 'test', current_timestamp);""")
session.execute(query)
session.commit()
def on_upgrade(engine: Engine) -> None:
"""
Ensure that data was successfully migrated
"""
with Session(engine) as session:
query = text("select * from public.chatquery")
res = session.execute(query).first()
assert hasattr(res, "updated_at")
assert res and res.created_at == res.updated_at
assert res.title == res.query
assert res.history == []
def on_downgrade(engine: Engine) -> None:
"""
Ensure that data changes were rolled back
"""
with Session(engine) as session:
query = text("select * from public.chatquery")
res = session.execute(query).first()
assert res and not hasattr(res, "updated_at") and hasattr(res, "id")
assert not hasattr(res, "title")
assert not hasattr(res, "history")
......@@ -11,14 +11,21 @@ from alembic.config import Config
from sqlalchemy import Engine, create_engine
from app.tests.migrations.conftest import MigrationValidationParamsGroup, make_validation_params_groups
from app.tests.migrations.data_migrations import migration_11d362313208, migration_b6e2cf2d4ee5, migration_bb5b4cdab369
from app.tests.migrations.data_migrations import (
migration_2f61ad989231,
migration_11d362313208,
migration_b6e2cf2d4ee5,
migration_bb5b4cdab369,
)
def get_data_migrations() -> list[MigrationValidationParamsGroup]:
"""
Returns tests for data migrations, from tests/data_migrations folder.
"""
return make_validation_params_groups(migration_11d362313208, migration_b6e2cf2d4ee5, migration_bb5b4cdab369)
return make_validation_params_groups(
migration_11d362313208, migration_b6e2cf2d4ee5, migration_bb5b4cdab369, migration_2f61ad989231
)
@pytest.mark.parametrize(("rev_base", "rev_head", "on_init", "on_upgrade", "on_downgrade"), get_data_migrations())
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment