# Crumb-Core-v.1/app/routers/diary_rag.py
"""
Diary-specific RAG endpoints for indexing and searching child diary entries.
Each child has their own Qdrant collection: diary_child_{child_id}
"""
import json
from typing import List

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from pymysql.cursors import DictCursor

from deps import get_db, get_qdrant_client
from config import get_settings
from models.rag_models import (
    DiaryIndexRequest, DiaryIndexResponse,
    DiarySearchRequest, DiarySearchResponse, DiarySearchResult,
    DiaryAskRequest, DiaryAskResponse
)
from services.provider_factory import ProviderFactory
from services.rag_service import RAGService

router = APIRouter()


def _get_diary_collection_name(child_id: int) -> str:
    """Generate collection name for a child's diary."""
    return f"diary_child_{child_id}"
@router.post("/index", response_model=DiaryIndexResponse, name="diary_index")
async def index_diary_entry(
request: DiaryIndexRequest,
background_tasks: BackgroundTasks
):
"""
Index a diary entry into child-specific Qdrant collection.
Flow:
1. Verify diary entry exists in MySQL
2. Chunk content (markdown_chunker)
3. Generate embeddings (embedding_service)
4. Store in Qdrant collection: diary_child_{child_id}
5. Track in post_vectors (post_type='diary')
6. Audit log
Uses BackgroundTasks for async processing.
"""
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()
    try:
        # Verify diary entry exists
        with db_conn.cursor(DictCursor) as cur:
            cur.execute(
                "SELECT id, child_id, entry_text, created_at FROM diary_entries WHERE id=%s",
                (request.entry_id,)
            )
            entry = cur.fetchone()
        if not entry:
            raise HTTPException(status_code=404, detail="Diary entry not found")
        if entry['child_id'] != request.child_id:
            raise HTTPException(status_code=403, detail="Child ID mismatch")

        # Create provider
        provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings
        )

        # Create RAG service with diary-specific collection prefix
        collection_prefix = f"diary_child_{request.child_id}"
        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider,
            collection_prefix=collection_prefix
        )

        # Index as a "post" with locale as empty string (diaries are locale-agnostic)
        result = rag_service.index_post(
            post_id=request.entry_id,
            title=f"Diary Entry {request.entry_id}",
            slug=f"diary-{request.entry_id}",
            locale="",  # Diary entries don't have locale
            body_md=request.content or entry['entry_text']
        )

        # Update post_vectors to mark as diary type
        with db_conn.cursor(DictCursor) as cur:
            cur.execute(
                """
                UPDATE post_vectors
                SET post_type='diary', child_id=%s
                WHERE post_id=%s AND collection_name=%s
                """,
                (request.child_id, request.entry_id, result['collection'])
            )
            # Audit log
            cur.execute(
                """
                INSERT INTO audit_log (action, entity_type, entity_id, user_id, metadata)
                VALUES ('diary_indexed', 'diary_entry', %s, NULL, %s)
                """,
                (request.entry_id, json.dumps({
                    'child_id': request.child_id,
                    'provider': request.provider,
                    'chunks': result.get('chunks', 0)
                }))
            )
        # Persist the tracking update and audit row (a no-op if get_db()
        # returns an autocommitting connection).
        db_conn.commit()

        return DiaryIndexResponse(
            status="success",
            entry_id=request.entry_id,
            child_id=request.child_id,
            chunks=result.get('chunks'),
            collection=result.get('collection'),
            provider=result.get('provider')
        )
    except HTTPException:
        # Re-raise the 404/403 from the checks above instead of masking them as 500s
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}")
    finally:
        db_conn.close()

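# Example call (illustrative only): a minimal request/response sketch for the
# /index endpoint, assuming the router is mounted under "/diary" and that
# DiaryIndexRequest exposes the fields used above (entry_id, child_id, provider,
# optional content). The provider name and values are placeholders.
#
#     POST /diary/index
#     {"entry_id": 123, "child_id": 42, "provider": "openai", "content": null}
#
#     200 OK
#     {"status": "success", "entry_id": 123, "child_id": 42,
#      "chunks": 3, "collection": "diary_child_42", "provider": "openai"}
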
@router.post("/search", response_model=DiarySearchResponse, name="diary_search")
async def search_diary(
request: DiarySearchRequest
):
"""
Semantic search in a child's diary.
Flow:
1. Generate query embedding
2. Search in Qdrant: diary_child_{child_id}
3. Return top results with scores
4. Log query in audit_log
"""
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()
    try:
        # Create provider
        provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings
        )

        # Create RAG service with diary-specific collection
        collection_prefix = f"diary_child_{request.child_id}"
        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider,
            collection_prefix=collection_prefix
        )

        # Search (use empty locale for diaries)
        results = rag_service.search_posts(
            query=request.query,
            locale="",
            limit=request.limit
        )

        # Convert to diary search results
        diary_results = []
        for r in results:
            # Fetch created_at from the database (one lookup per result)
            with db_conn.cursor(DictCursor) as cur:
                cur.execute(
                    "SELECT created_at FROM diary_entries WHERE id=%s",
                    (r['post_id'],)
                )
                entry = cur.fetchone()
            diary_results.append(DiarySearchResult(
                entry_id=r['post_id'],
                content=r['content'],
                score=r['score'],
                created_at=entry['created_at'].isoformat() if entry else None
            ))

        # Audit log
        with db_conn.cursor(DictCursor) as cur:
            cur.execute(
                """
                INSERT INTO audit_log (action, entity_type, entity_id, user_id, metadata)
                VALUES ('diary_searched', 'child', %s, NULL, %s)
                """,
                (request.child_id, json.dumps({
                    'query': request.query,
                    'provider': request.provider,
                    'results_count': len(diary_results)
                }))
            )
        # Persist the audit row (a no-op if get_db() returns an autocommitting connection).
        db_conn.commit()

        return DiarySearchResponse(
            results=diary_results,
            query=request.query,
            child_id=request.child_id,
            provider=request.provider
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
    finally:
        db_conn.close()

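# Example call (illustrative only), assuming a "/diary" mount prefix and the
# DiarySearchRequest fields used above (child_id, query, limit, provider);
# the score and timestamp below are placeholders.
#
#     POST /diary/search
#     {"child_id": 42, "query": "first day at kindergarten", "limit": 5, "provider": "openai"}
#
#     200 OK
#     {"results": [{"entry_id": 123, "content": "...", "score": 0.87,
#                   "created_at": "2024-03-01T09:15:00"}],
#      "query": "first day at kindergarten", "child_id": 42, "provider": "openai"}
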
@router.post("/ask", response_model=DiaryAskResponse, name="diary_ask")
async def ask_diary(
request: DiaryAskRequest
):
"""
RAG-powered question answering using child's diary entries.
Flow:
1. Search relevant entries (top N)
2. Build context from entries
3. Get completion from provider
4. Return answer + sources
"""
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()
    try:
        # Create providers (may use different ones for embedding vs completion;
        # both currently come from request.provider)
        embedding_provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings
        )
        completion_provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings
        )

        # Create RAG service
        collection_prefix = f"diary_child_{request.child_id}"
        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=embedding_provider,
            completion_provider=completion_provider,
            collection_prefix=collection_prefix
        )

        # Perform RAG query
        result = rag_service.query_with_rag(
            question=request.question,
            locale="",  # Diaries don't have locale
            context_limit=request.context_limit
        )

        # Convert sources to diary results
        sources = []
        for source in result['sources']:
            # Fetch the entry text and created_at from the database
            with db_conn.cursor(DictCursor) as cur:
                cur.execute(
                    "SELECT entry_text, created_at FROM diary_entries WHERE id=%s",
                    (source['post_id'],)
                )
                entry = cur.fetchone()
            sources.append(DiarySearchResult(
                entry_id=source['post_id'],
                content=entry['entry_text'][:200] if entry else "",  # Preview
                score=source['score'],
                created_at=entry['created_at'].isoformat() if entry else None
            ))

        # Audit log
        with db_conn.cursor(DictCursor) as cur:
            cur.execute(
                """
                INSERT INTO audit_log (action, entity_type, entity_id, user_id, metadata)
                VALUES ('diary_rag_query', 'child', %s, NULL, %s)
                """,
                (request.child_id, json.dumps({
                    'question': request.question,
                    'provider': request.provider,
                    'sources_count': len(sources)
                }))
            )
        # Persist the audit row (a no-op if get_db() returns an autocommitting connection).
        db_conn.commit()

        return DiaryAskResponse(
            answer=result['answer'],
            question=request.question,
            child_id=request.child_id,
            sources=sources,
            provider=result['provider'],
            model=result['model']
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"RAG query failed: {str(e)}")
    finally:
        db_conn.close()

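# Example call (illustrative only), assuming a "/diary" mount prefix and the
# DiaryAskRequest fields used above (child_id, question, context_limit, provider);
# the answer, score, and model values are placeholders.
#
#     POST /diary/ask
#     {"child_id": 42, "question": "When did the first tooth come out?",
#      "context_limit": 4, "provider": "openai"}
#
#     200 OK
#     {"answer": "...", "question": "When did the first tooth come out?",
#      "child_id": 42,
#      "sources": [{"entry_id": 123, "content": "...", "score": 0.91,
#                   "created_at": "2024-02-10T18:30:00"}],
#      "provider": "openai", "model": "example-model"}
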
@router.get("/{child_id}/status", name="diary_status")
async def get_diary_status(child_id: int):
"""
Get indexing status for a child's diary.
Returns:
- Total diary entries
- Indexed entries
- Collection info
"""
db_conn = get_db()
try:
with db_conn.cursor(DictCursor) as cur:
# Total entries for this child
cur.execute(
"SELECT COUNT(*) as total FROM diary_entries WHERE child_id=%s",
(child_id,)
)
total = cur.fetchone()['total']
# Indexed entries
collection_name = _get_diary_collection_name(child_id)
cur.execute(
"""
SELECT COUNT(DISTINCT post_id) as indexed,
SUM(chunk_count) as total_vectors,
MAX(indexed_at) as last_indexed
FROM post_vectors
WHERE collection_name=%s AND post_type='diary' AND child_id=%s
""",
(collection_name, child_id)
)
indexed_data = cur.fetchone()
return {
"child_id": child_id,
"total_entries": total,
"indexed_entries": indexed_data['indexed'] or 0,
"total_vectors": indexed_data['total_vectors'] or 0,
"last_indexed": indexed_data['last_indexed'].isoformat() if indexed_data['last_indexed'] else None,
"collection_name": collection_name
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get status: {str(e)}")
finally:
db_conn.close()
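# Example call (illustrative only), assuming a "/diary" mount prefix; the counts
# and timestamp are placeholders. Note that the status query above assumes
# post_vectors carries post_type, child_id, chunk_count, and indexed_at columns
# in addition to post_id and collection_name.
#
#     GET /diary/42/status
#
#     200 OK
#     {"child_id": 42, "total_entries": 10, "indexed_entries": 8,
#      "total_vectors": 24, "last_indexed": "2024-03-05T21:02:11",
#      "collection_name": "diary_child_42"}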