Files
Crumb-Core-v.1/app/routers/admin_rag.py

351 lines
9.9 KiB
Python

# app/routers/admin_rag.py
from fastapi import APIRouter, Depends, HTTPException
from pymysql.cursors import DictCursor
from typing import Dict
from deps import get_db, get_qdrant_client, admin_required
from config import get_settings
from models.rag_models import (
IndexRequest, IndexSingleRequest, IndexResponse,
SearchRequest, SearchResponse, SearchResult,
QueryRequest, QueryResponse, QuerySource,
StatusResponse, ProvidersResponse, ProviderInfo
)
from services.provider_factory import ProviderFactory
from services.rag_service import RAGService
# Router that collects the RAG endpoints declared below; every handler in this
# module registers itself on it via the @router.<method>(...) decorators.
router = APIRouter()
@router.post("/index", response_model=IndexResponse, name="rag_index_all")
def index_all_posts(
    request: IndexRequest,
    user = Depends(admin_required)
):
    """
    Bulk-index posts into Qdrant for the locale given in the request.

    Builds the embedding provider named by ``request.provider``, wires it into
    a ``RAGService`` and delegates to ``RAGService.index_all_posts``; which
    posts are selected is decided inside that service call.
    Admin-only endpoint (guarded by the ``admin_required`` dependency).

    Returns:
        IndexResponse with ``status="success"`` plus the fields of the dict
        returned by the service.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while creating the
            provider or indexing; 500 for any other failure.
    """
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()
    try:
        # The provider is created first (argument evaluation), then handed
        # straight to the service constructor.
        service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=ProviderFactory.create_provider(
                provider_name=request.provider,
                settings=settings
            )
        )
        outcome = service.index_all_posts(locale=request.locale)
        return IndexResponse(status="success", **outcome)
    except ValueError as bad_value:
        raise HTTPException(status_code=400, detail=str(bad_value))
    except Exception as failure:
        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(failure)}")
    finally:
        # The database connection is closed on success and on failure.
        db_conn.close()
@router.post("/index/{post_id}", response_model=IndexResponse, name="rag_index_single")
def index_single_post(
    post_id: int,
    request: IndexSingleRequest,
    user = Depends(admin_required)
):
    """
    Index a single post to Qdrant.
    Admin-only endpoint.

    Loads the post row whose ``id`` equals ``post_id``, builds the embedding
    provider named by ``request.provider`` and passes the post fields to
    ``RAGService.index_post``.

    Args:
        post_id: Value matched against ``posts.id`` (path parameter).
        request: Request body; only ``request.provider`` is read here.
        user: Result of the ``admin_required`` dependency (unused in the body).

    Returns:
        IndexResponse built from the dict returned by ``RAGService.index_post``.

    Raises:
        HTTPException: 404 when no post row matches ``post_id``; 400 when a
            ``ValueError`` is raised while creating the provider or indexing;
            500 for any other failure.
    """
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()
    try:
        # Fetch post from database
        with db_conn.cursor(DictCursor) as cur:
            cur.execute(
                "SELECT id, title, slug, locale, body_md FROM posts WHERE id=%s",
                (post_id,)
            )
            post = cur.fetchone()
        if not post:
            raise HTTPException(status_code=404, detail="Post not found")
        # Create provider
        provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings
        )
        # Create RAG service
        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider
        )
        # Index post
        result = rag_service.index_post(
            post_id=post['id'],
            title=post['title'],
            slug=post['slug'],
            locale=post['locale'],
            body_md=post['body_md']
        )
        return IndexResponse(**result)
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the 404
        # raised above would be caught by the generic handler below and be
        # reported to the client as a 500 "Indexing failed: ..." error.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}") from e
    finally:
        # Always release the database connection, on success and on failure.
        db_conn.close()
@router.delete("/index/{post_id}", name="rag_delete_index")
def delete_post_index(
    post_id: int,
    locale: str,
    user = Depends(admin_required)
):
    """
    Remove a post from the vector index.
    Admin-only endpoint.

    Args:
        post_id: Id of the post whose index entry is removed (path parameter).
        locale: Locale forwarded to ``RAGService.delete_post_index``.
        user: Result of the ``admin_required`` dependency (unused in the body).

    Returns:
        Whatever ``RAGService.delete_post_index`` returns, unmodified.

    Raises:
        HTTPException: 500 with detail "No providers configured" when the
            factory reports no available provider; 400 when a ``ValueError``
            is raised; 500 "Deletion failed: ..." for any other failure.
    """
    db_conn = get_db()
    qdrant_client = get_qdrant_client()
    try:
        # Get any provider (we only need it for service initialization)
        settings = get_settings()
        available_providers = ProviderFactory.get_available_providers(settings)
        if not available_providers:
            raise HTTPException(status_code=500, detail="No providers configured")
        provider = ProviderFactory.create_provider(
            provider_name=available_providers[0],
            settings=settings
        )
        # Create RAG service
        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider
        )
        # Delete post index
        result = rag_service.delete_post_index(post_id=post_id, locale=locale)
        return result
    except HTTPException:
        # Let deliberate HTTP errors (e.g. "No providers configured") through
        # unchanged; the generic handler below would otherwise rewrap them and
        # replace their detail with "Deletion failed: ...".
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Deletion failed: {str(e)}") from e
    finally:
        # Always release the database connection, on success and on failure.
        db_conn.close()
@router.post("/search", response_model=SearchResponse, name="rag_search")
def search_posts(
    request: SearchRequest,
    user = Depends(admin_required)
):
    """
    Run a semantic search over indexed posts.

    The query, locale and limit come from the request body; the embedding
    provider is the one named by ``request.provider``.
    Admin-only endpoint (guarded by the ``admin_required`` dependency).

    Returns:
        SearchResponse echoing the query, locale and provider, with one
        ``SearchResult`` per hit returned by ``RAGService.search_posts``.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while creating the
            provider or searching; 500 for any other failure.
    """
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()
    try:
        # Provider creation happens first (argument evaluation), then the
        # service is constructed around it.
        service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=ProviderFactory.create_provider(
                provider_name=request.provider,
                settings=settings
            )
        )
        hits = service.search_posts(
            query=request.query,
            locale=request.locale,
            limit=request.limit
        )
        # Each raw hit dict is expanded into a SearchResult model.
        return SearchResponse(
            results=[SearchResult(**hit) for hit in hits],
            query=request.query,
            locale=request.locale,
            provider=request.provider
        )
    except ValueError as bad_value:
        raise HTTPException(status_code=400, detail=str(bad_value))
    except Exception as failure:
        raise HTTPException(status_code=500, detail=f"Search failed: {str(failure)}")
    finally:
        # The database connection is closed on success and on failure.
        db_conn.close()
@router.post("/query", response_model=QueryResponse, name="rag_query")
def query_with_rag(
    request: QueryRequest,
    user = Depends(admin_required)
):
    """
    Answer a question with RAG over post content.

    Forwards ``request.question``, ``request.locale`` and
    ``request.context_limit`` to ``RAGService.query_with_rag`` and repackages
    the returned dict as a ``QueryResponse``.
    Admin-only endpoint (guarded by the ``admin_required`` dependency).

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while creating the
            providers or querying; 500 for any other failure.
    """
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()
    try:
        # Two separate provider instances are created, both from
        # request.provider: the first for embeddings, the second for
        # completions.
        service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=ProviderFactory.create_provider(
                provider_name=request.provider,
                settings=settings
            ),
            completion_provider=ProviderFactory.create_provider(
                provider_name=request.provider,
                settings=settings
            )
        )
        rag_output = service.query_with_rag(
            question=request.question,
            locale=request.locale,
            context_limit=request.context_limit
        )
        # Sources are converted before the remaining keys are read.
        cited = [QuerySource(**entry) for entry in rag_output['sources']]
        return QueryResponse(
            answer=rag_output['answer'],
            sources=cited,
            provider=rag_output['provider'],
            model=rag_output['model'],
            question=request.question
        )
    except ValueError as bad_value:
        raise HTTPException(status_code=400, detail=str(bad_value))
    except Exception as failure:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(failure)}")
    finally:
        # The database connection is closed on success and on failure.
        db_conn.close()
@router.get("/status", response_model=StatusResponse, name="rag_status")
def get_indexing_status(user = Depends(admin_required)):
    """
    Get indexing status across all collections.
    Admin-only endpoint.

    Uses the first provider reported by
    ``ProviderFactory.get_available_providers`` purely to construct the
    ``RAGService``, then returns ``RAGService.get_indexing_status()`` wrapped
    in a ``StatusResponse``.

    Raises:
        HTTPException: 500 with detail "No providers configured" when the
            factory reports no available provider; 500 "Failed to get
            status: ..." for any other failure.
    """
    settings = get_settings()
    db_conn = get_db()
    qdrant_client = get_qdrant_client()
    try:
        # Get any available provider (we only need it for service initialization)
        available_providers = ProviderFactory.get_available_providers(settings)
        if not available_providers:
            raise HTTPException(status_code=500, detail="No providers configured")
        provider = ProviderFactory.create_provider(
            provider_name=available_providers[0],
            settings=settings
        )
        # Create RAG service
        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider
        )
        # Get status
        status = rag_service.get_indexing_status()
        return StatusResponse(**status)
    except HTTPException:
        # Let deliberate HTTP errors (e.g. "No providers configured") through
        # unchanged; the generic handler below would otherwise rewrap them and
        # replace their detail with "Failed to get status: ...".
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get status: {str(e)}") from e
    finally:
        # Always release the database connection, on success and on failure.
        db_conn.close()
@router.get("/providers", response_model=ProvidersResponse, name="rag_providers")
def get_providers_status(user = Depends(admin_required)):
    """
    Report every provider known to the factory together with its status.

    Each entry of the mapping returned by
    ``ProviderFactory.get_provider_status`` becomes one ``ProviderInfo``;
    keys missing from an entry come through as ``None`` (``available``
    defaults to ``False``).
    Admin-only endpoint (guarded by the ``admin_required`` dependency).

    Raises:
        HTTPException: 500 when reading the status or building the response
            fails.
    """
    settings = get_settings()
    try:
        status_by_name = ProviderFactory.get_provider_status(settings)
        # One ProviderInfo per reported provider, in the mapping's order.
        return ProvidersResponse(
            providers=[
                ProviderInfo(
                    name=key,
                    available=details.get('available', False),
                    provider_name=details.get('provider_name'),
                    model=details.get('model'),
                    dimension=details.get('dimension'),
                    supports_embeddings=details.get('supports_embeddings'),
                    supports_completions=details.get('supports_completions'),
                    error=details.get('error')
                )
                for key, details in status_by_name.items()
            ]
        )
    except Exception as failure:
        raise HTTPException(status_code=500, detail=f"Failed to get providers: {str(failure)}")