# app/routers/diary_rag.py
"""
Diary-specific RAG endpoints for indexing and searching child diary entries.

Each child has their own Qdrant collection prefix: diary_child_{child_id}
"""
import json
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from pymysql.cursors import DictCursor
from typing import List

from deps import get_db, get_qdrant_client
from config import get_settings
from models.rag_models import (
    DiaryIndexRequest,
    DiaryIndexResponse,
    DiarySearchRequest,
    DiarySearchResponse,
    DiarySearchResult,
    DiaryAskRequest,
    DiaryAskResponse
)
from services.provider_factory import ProviderFactory
from services.rag_service import RAGService

router = APIRouter()

# Single-id lookups used to enrich search / RAG results with diary metadata.
_ENTRY_DATE_SQL = "SELECT created_at FROM diary_entries WHERE id=%s"
_ENTRY_PREVIEW_SQL = "SELECT entry_text, created_at FROM diary_entries WHERE id=%s"


def _get_diary_collection_name(child_id: int) -> str:
    """Generate collection name for a child's diary."""
    return f"diary_child_{child_id}"


def _load_entries(db_conn, sql: str, entry_ids) -> dict:
    """
    Run the single-id lookup ``sql`` once per distinct id in ``entry_ids``.

    Args:
        db_conn: Open database connection (as returned by ``get_db``).
        sql: A query with exactly one ``%s`` placeholder for the entry id.
        entry_ids: Iterable of hashable entry ids; duplicates are queried once.

    Returns:
        Dict mapping each id to the fetched row (a dict, via DictCursor),
        or to None when no row matched.
    """
    rows = {}
    # One cursor for the whole batch instead of one cursor per result.
    with db_conn.cursor(DictCursor) as cur:
        for entry_id in entry_ids:
            if entry_id not in rows:
                cur.execute(sql, (entry_id,))
                rows[entry_id] = cur.fetchone()
    return rows


@router.post("/index", response_model=DiaryIndexResponse, name="diary_index")
async def index_diary_entry(
    request: DiaryIndexRequest,
    background_tasks: BackgroundTasks
):
    """
    Index a diary entry under the child-specific collection prefix.

    Flow:
    1. Verify the diary entry exists in MySQL and belongs to request.child_id
    2. Delegate indexing to RAGService.index_post (chunking, embedding and
       vector storage presumably happen there - not visible from this module)
    3. Mark the post_vectors row(s) as post_type='diary' with the child_id
    4. Write an audit_log row

    Note: ``background_tasks`` is accepted but not used here; indexing runs
    inline before the response is returned.

    Raises:
        HTTPException 404: diary entry not found.
        HTTPException 403: entry belongs to a different child.
        HTTPException 400: a ValueError was raised while indexing.
        HTTPException 500: any other failure.
    """
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()

    try:
        # Verify diary entry exists
        with db_conn.cursor(DictCursor) as cur:
            cur.execute(
                "SELECT id, child_id, entry_text, created_at FROM diary_entries WHERE id=%s",
                (request.entry_id,)
            )
            entry = cur.fetchone()

        if not entry:
            raise HTTPException(status_code=404, detail="Diary entry not found")

        if entry['child_id'] != request.child_id:
            raise HTTPException(status_code=403, detail="Child ID mismatch")

        # Create provider
        provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings
        )

        # Create RAG service with diary-specific collection prefix
        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider,
            collection_prefix=_get_diary_collection_name(request.child_id)
        )

        # Index as a "post" with locale as empty string (diaries are locale-agnostic)
        result = rag_service.index_post(
            post_id=request.entry_id,
            title=f"Diary Entry {request.entry_id}",
            slug=f"diary-{request.entry_id}",
            locale="",  # Diary entries don't have locale
            # Caller-supplied content wins; fall back to the stored text.
            body_md=request.content or entry['entry_text']
        )

        with db_conn.cursor(DictCursor) as cur:
            # Update post_vectors to mark as diary type
            cur.execute(
                """
                UPDATE post_vectors
                SET post_type='diary', child_id=%s
                WHERE post_id=%s AND collection_name=%s
                """,
                (request.child_id, request.entry_id, result['collection'])
            )

            # Audit log
            cur.execute(
                """
                INSERT INTO audit_log (action, entity_type, entity_id, user_id, metadata)
                VALUES ('diary_indexed', 'diary_entry', %s, NULL, %s)
                """,
                (request.entry_id, json.dumps({
                    'child_id': request.child_id,
                    'provider': request.provider,
                    'chunks': result.get('chunks', 0)
                }))
            )
        # NOTE(review): no explicit commit here - presumably get_db() returns an
        # autocommit connection or RAGService commits; confirm.

        return DiaryIndexResponse(
            status="success",
            entry_id=request.entry_id,
            child_id=request.child_id,
            chunks=result.get('chunks'),
            collection=result.get('collection'),
            provider=result.get('provider')
        )

    except HTTPException:
        # Let the deliberate 404/403 above through; without this clause the
        # generic handler below would rewrap them as a 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}") from e
    finally:
        db_conn.close()


@router.post("/search", response_model=DiarySearchResponse, name="diary_search")
async def search_diary(
    request: DiarySearchRequest
):
    """
    Semantic search in a child's diary.

    Flow:
    1. Run RAGService.search_posts under the child's collection prefix
    2. Attach each hit's created_at from diary_entries
    3. Write an audit_log row with the query and result count
    4. Return the results with scores

    Raises:
        HTTPException 400: a ValueError was raised while searching.
        HTTPException 500: any other failure.
    """
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()

    try:
        # Create provider
        provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings
        )

        # Create RAG service with diary-specific collection
        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider,
            collection_prefix=_get_diary_collection_name(request.child_id)
        )

        # Search (use empty locale for diaries)
        results = rag_service.search_posts(
            query=request.query,
            locale="",
            limit=request.limit
        )

        # Look up created_at once per distinct entry, then convert the hits.
        entries = _load_entries(
            db_conn, _ENTRY_DATE_SQL, [r['post_id'] for r in results]
        )
        diary_results = []
        for r in results:
            entry = entries[r['post_id']]
            diary_results.append(DiarySearchResult(
                entry_id=r['post_id'],
                content=r['content'],
                score=r['score'],
                # The entry may have been deleted since it was indexed.
                created_at=entry['created_at'].isoformat() if entry else None
            ))

        # Audit log
        with db_conn.cursor(DictCursor) as cur:
            cur.execute(
                """
                INSERT INTO audit_log (action, entity_type, entity_id, user_id, metadata)
                VALUES ('diary_searched', 'child', %s, NULL, %s)
                """,
                (request.child_id, json.dumps({
                    'query': request.query,
                    'provider': request.provider,
                    'results_count': len(diary_results)
                }))
            )

        return DiarySearchResponse(
            results=diary_results,
            query=request.query,
            child_id=request.child_id,
            provider=request.provider
        )

    except HTTPException:
        # Keep any HTTP error raised further down the stack intact.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e
    finally:
        db_conn.close()


@router.post("/ask", response_model=DiaryAskResponse, name="diary_ask")
async def ask_diary(
    request: DiaryAskRequest
):
    """
    RAG-powered question answering using child's diary entries.

    Flow:
    1. Run RAGService.query_with_rag under the child's collection prefix
    2. Convert the returned sources to diary results (200-char text preview)
    3. Write an audit_log row with the question and source count
    4. Return answer + sources

    Raises:
        HTTPException 400: a ValueError was raised while answering.
        HTTPException 500: any other failure.
    """
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()

    try:
        # Two provider instances are created so embedding and completion could
        # diverge later; today both come from request.provider.
        embedding_provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings
        )

        completion_provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings
        )

        # Create RAG service
        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=embedding_provider,
            completion_provider=completion_provider,
            collection_prefix=_get_diary_collection_name(request.child_id)
        )

        # Perform RAG query
        result = rag_service.query_with_rag(
            question=request.question,
            locale="",  # Diaries don't have locale
            context_limit=request.context_limit
        )

        # Look up text/date once per distinct entry, then convert the sources.
        entries = _load_entries(
            db_conn,
            _ENTRY_PREVIEW_SQL,
            [source['post_id'] for source in result['sources']]
        )
        sources = []
        for source in result['sources']:
            entry = entries[source['post_id']]
            sources.append(DiarySearchResult(
                entry_id=source['post_id'],
                content=entry['entry_text'][:200] if entry else "",  # Preview
                score=source['score'],
                created_at=entry['created_at'].isoformat() if entry else None
            ))

        # Audit log
        with db_conn.cursor(DictCursor) as cur:
            cur.execute(
                """
                INSERT INTO audit_log (action, entity_type, entity_id, user_id, metadata)
                VALUES ('diary_rag_query', 'child', %s, NULL, %s)
                """,
                (request.child_id, json.dumps({
                    'question': request.question,
                    'provider': request.provider,
                    'sources_count': len(sources)
                }))
            )

        return DiaryAskResponse(
            answer=result['answer'],
            question=request.question,
            child_id=request.child_id,
            sources=sources,
            provider=result['provider'],
            model=result['model']
        )

    except HTTPException:
        # Keep any HTTP error raised further down the stack intact.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"RAG query failed: {str(e)}") from e
    finally:
        db_conn.close()


@router.get("/{child_id}/status", name="diary_status")
async def get_diary_status(child_id: int):
    """
    Get indexing status for a child's diary.

    Returns a dict with:
    - child_id
    - total_entries: rows in diary_entries for this child
    - indexed_entries: distinct post_ids in post_vectors for this child
    - total_vectors: sum of chunk_count over those rows (0 when none)
    - last_indexed: ISO timestamp of the latest indexed_at, or None
    - collection_name

    Raises:
        HTTPException 500: any failure while querying.
    """
    db_conn = get_db()

    try:
        with db_conn.cursor(DictCursor) as cur:
            # Total entries for this child
            cur.execute(
                "SELECT COUNT(*) as total FROM diary_entries WHERE child_id=%s",
                (child_id,)
            )
            total = cur.fetchone()['total']

            # Indexed entries
            # NOTE(review): assumes post_vectors.collection_name equals the
            # bare prefix; the index endpoint stores result['collection'] as
            # returned by RAGService - confirm the two always match.
            collection_name = _get_diary_collection_name(child_id)
            cur.execute(
                """
                SELECT COUNT(DISTINCT post_id) as indexed,
                       SUM(chunk_count) as total_vectors,
                       MAX(indexed_at) as last_indexed
                FROM post_vectors
                WHERE collection_name=%s AND post_type='diary' AND child_id=%s
                """,
                (collection_name, child_id)
            )
            indexed_data = cur.fetchone()

        return {
            "child_id": child_id,
            "total_entries": total,
            "indexed_entries": indexed_data['indexed'] or 0,
            # SUM() over zero rows yields NULL, hence the "or 0".
            "total_vectors": indexed_data['total_vectors'] or 0,
            "last_indexed": indexed_data['last_indexed'].isoformat() if indexed_data['last_indexed'] else None,
            "collection_name": collection_name
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get status: {str(e)}") from e
    finally:
        db_conn.close()