351 lines
9.9 KiB
Python
351 lines
9.9 KiB
Python
# app/routers/admin_rag.py
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pymysql.cursors import DictCursor
|
|
from typing import Dict
|
|
|
|
from deps import get_db, get_qdrant_client, admin_required
|
|
from config import get_settings
|
|
from models.rag_models import (
|
|
IndexRequest, IndexSingleRequest, IndexResponse,
|
|
SearchRequest, SearchResponse, SearchResult,
|
|
QueryRequest, QueryResponse, QuerySource,
|
|
StatusResponse, ProvidersResponse, ProviderInfo
|
|
)
|
|
from services.provider_factory import ProviderFactory
|
|
from services.rag_service import RAGService
|
|
|
|
|
|
# Router for the admin RAG endpoints defined below (index, delete index,
# search, query, status, providers). Each endpoint enforces admin access via
# Depends(admin_required).
router = APIRouter()
|
|
|
|
|
|
@router.post("/index", response_model=IndexResponse, name="rag_index_all")
def index_all_posts(
    request: IndexRequest,
    user=Depends(admin_required),
):
    """
    Index all published posts to Qdrant.

    Admin-only endpoint.

    Builds an embedding provider from ``request.provider`` and delegates to
    ``RAGService.index_all_posts`` for ``request.locale``.

    Returns:
        IndexResponse with ``status="success"`` plus the fields of the dict
        returned by the service.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised (e.g. by the
            provider factory or the service); 500 with detail
            ``"Indexing failed: ..."`` for any other error. An
            ``HTTPException`` raised inside the handler is passed through
            unchanged.
    """
    settings = get_settings()
    # Acquire the Qdrant client before opening the DB connection: only the
    # connection is closed in ``finally``, so a failure here must not leave
    # an already-opened connection behind.
    qdrant_client = get_qdrant_client()
    db_conn = get_db()

    try:
        provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings,
        )

        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider,
        )

        result = rag_service.index_all_posts(locale=request.locale)

        return IndexResponse(
            status="success",
            **result
        )

    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}") from e
    finally:
        db_conn.close()
|
|
|
|
|
|
@router.post("/index/{post_id}", response_model=IndexResponse, name="rag_index_single")
def index_single_post(
    post_id: int,
    request: IndexSingleRequest,
    user=Depends(admin_required),
):
    """
    Index a single post to Qdrant.

    Admin-only endpoint.

    Loads ``id``, ``title``, ``slug``, ``locale`` and ``body_md`` of the post
    from the ``posts`` table, builds an embedding provider from
    ``request.provider`` and delegates to ``RAGService.index_post``.

    Returns:
        IndexResponse built from the dict returned by the service.

    Raises:
        HTTPException: 404 if no post has ``id == post_id``; 400 when a
            ``ValueError`` is raised; 500 with detail
            ``"Indexing failed: ..."`` for any other error.
    """
    settings = get_settings()
    # Acquire the Qdrant client before opening the DB connection so a failure
    # here cannot leak the connection (only it is closed in ``finally``).
    qdrant_client = get_qdrant_client()
    db_conn = get_db()

    try:
        with db_conn.cursor(DictCursor) as cur:
            cur.execute(
                "SELECT id, title, slug, locale, body_md FROM posts WHERE id=%s",
                (post_id,)
            )
            post = cur.fetchone()

        if not post:
            raise HTTPException(status_code=404, detail="Post not found")

        provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings,
        )

        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider,
        )

        result = rag_service.index_post(
            post_id=post['id'],
            title=post['title'],
            slug=post['slug'],
            locale=post['locale'],
            body_md=post['body_md'],
        )

        return IndexResponse(**result)

    except HTTPException:
        # Re-raise as-is: without this, the 404 above would be caught by the
        # generic handler below and turned into a 500 "Indexing failed".
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}") from e
    finally:
        db_conn.close()
|
|
|
|
|
|
@router.delete("/index/{post_id}", name="rag_delete_index")
def delete_post_index(
    post_id: int,
    locale: str,
    user=Depends(admin_required),
):
    """
    Remove a post from the vector index.

    Admin-only endpoint.

    Delegates to ``RAGService.delete_post_index(post_id=..., locale=...)``
    and returns its result unchanged.

    Raises:
        HTTPException: 500 ``"No providers configured"`` when no provider is
            available; 400 when a ``ValueError`` is raised; 500 with detail
            ``"Deletion failed: ..."`` for any other error.
    """
    # Acquire the Qdrant client before opening the DB connection so a failure
    # here cannot leak the connection (only it is closed in ``finally``).
    qdrant_client = get_qdrant_client()
    db_conn = get_db()

    try:
        # A provider is only needed to construct RAGService, so the first
        # configured one is used.
        settings = get_settings()
        available_providers = ProviderFactory.get_available_providers(settings)

        if not available_providers:
            raise HTTPException(status_code=500, detail="No providers configured")

        provider = ProviderFactory.create_provider(
            provider_name=available_providers[0],
            settings=settings,
        )

        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider,
        )

        return rag_service.delete_post_index(post_id=post_id, locale=locale)

    except HTTPException:
        # Keep the "No providers configured" detail intact instead of letting
        # the generic handler rewrap it as "Deletion failed: 500: ...".
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Deletion failed: {str(e)}") from e
    finally:
        db_conn.close()
|
|
|
|
|
|
@router.post("/search", response_model=SearchResponse, name="rag_search")
def search_posts(
    request: SearchRequest,
    user=Depends(admin_required),
):
    """
    Semantic search across indexed posts.

    Admin-only endpoint.

    Builds an embedding provider from ``request.provider``, runs
    ``RAGService.search_posts`` with the request's query, locale and limit,
    and wraps each returned dict in a ``SearchResult``.

    Returns:
        SearchResponse echoing the request's query, locale and provider.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised; 500 with detail
            ``"Search failed: ..."`` for any other error. An
            ``HTTPException`` raised inside the handler is passed through
            unchanged.
    """
    settings = get_settings()
    # Acquire the Qdrant client before opening the DB connection so a failure
    # here cannot leak the connection (only it is closed in ``finally``).
    qdrant_client = get_qdrant_client()
    db_conn = get_db()

    try:
        provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings,
        )

        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider,
        )

        results = rag_service.search_posts(
            query=request.query,
            locale=request.locale,
            limit=request.limit,
        )

        return SearchResponse(
            results=[SearchResult(**r) for r in results],
            query=request.query,
            locale=request.locale,
            provider=request.provider,
        )

    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e
    finally:
        db_conn.close()
|
|
|
|
|
|
@router.post("/query", response_model=QueryResponse, name="rag_query")
def query_with_rag(
    request: QueryRequest,
    user=Depends(admin_required),
):
    """
    RAG-powered question answering using post content.

    Admin-only endpoint.

    Builds embedding and completion providers from ``request.provider`` and
    calls ``RAGService.query_with_rag`` with the request's question, locale
    and ``context_limit``. The service's ``sources`` are wrapped in
    ``QuerySource``.

    Returns:
        QueryResponse with the service's answer, sources, provider and model,
        plus the original question.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised; 500 with detail
            ``"Query failed: ..."`` for any other error. An ``HTTPException``
            raised inside the handler is passed through unchanged.
    """
    settings = get_settings()
    # Acquire the Qdrant client before opening the DB connection so a failure
    # here cannot leak the connection (only it is closed in ``finally``).
    qdrant_client = get_qdrant_client()
    db_conn = get_db()

    try:
        embedding_provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings,
        )

        # Built from the same provider name as the embedding provider, as a
        # separate instance. NOTE(review): could reuse ``embedding_provider``
        # if providers are stateless — confirm before changing.
        completion_provider = ProviderFactory.create_provider(
            provider_name=request.provider,
            settings=settings,
        )

        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=embedding_provider,
            completion_provider=completion_provider,
        )

        result = rag_service.query_with_rag(
            question=request.question,
            locale=request.locale,
            context_limit=request.context_limit,
        )

        return QueryResponse(
            answer=result['answer'],
            sources=[QuerySource(**s) for s in result['sources']],
            provider=result['provider'],
            model=result['model'],
            question=request.question,
        )

    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}") from e
    finally:
        db_conn.close()
|
|
|
|
|
|
@router.get("/status", response_model=StatusResponse, name="rag_status")
def get_indexing_status(user=Depends(admin_required)):
    """
    Get indexing status across all collections.

    Admin-only endpoint.

    Returns:
        StatusResponse built from ``RAGService.get_indexing_status()``.

    Raises:
        HTTPException: 500 ``"No providers configured"`` when no provider is
            available; 500 with detail ``"Failed to get status: ..."`` for
            any other error.
    """
    settings = get_settings()
    # Acquire the Qdrant client before opening the DB connection so a failure
    # here cannot leak the connection (only it is closed in ``finally``).
    qdrant_client = get_qdrant_client()
    db_conn = get_db()

    try:
        # A provider is only needed to construct RAGService, so the first
        # configured one is used.
        available_providers = ProviderFactory.get_available_providers(settings)

        if not available_providers:
            raise HTTPException(status_code=500, detail="No providers configured")

        provider = ProviderFactory.create_provider(
            provider_name=available_providers[0],
            settings=settings,
        )

        rag_service = RAGService(
            db_conn=db_conn,
            qdrant_client=qdrant_client,
            embedding_provider=provider,
        )

        return StatusResponse(**rag_service.get_indexing_status())

    except HTTPException:
        # Keep the "No providers configured" detail intact instead of letting
        # the generic handler rewrap it as "Failed to get status: 500: ...".
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get status: {str(e)}") from e
    finally:
        db_conn.close()
|
|
|
|
|
|
@router.get("/providers", response_model=ProvidersResponse, name="rag_providers")
def get_providers_status(user=Depends(admin_required)):
    """
    List available providers and their status.

    Admin-only endpoint.

    Every (name, details) pair from ``ProviderFactory.get_provider_status``
    becomes a ``ProviderInfo``. Missing detail keys become ``None``, except
    ``available``, which defaults to ``False``. Any failure is reported as a
    500 with detail ``"Failed to get providers: ..."``.
    """
    settings = get_settings()

    try:
        status_by_name = ProviderFactory.get_provider_status(settings)
        return ProvidersResponse(
            providers=[
                ProviderInfo(
                    name=provider_key,
                    available=details.get('available', False),
                    provider_name=details.get('provider_name'),
                    model=details.get('model'),
                    dimension=details.get('dimension'),
                    supports_embeddings=details.get('supports_embeddings'),
                    supports_completions=details.get('supports_completions'),
                    error=details.get('error'),
                )
                for provider_key, details in status_by_name.items()
            ]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get providers: {str(e)}")
|