# app/services/rag_service.py
import hashlib
import json
import logging
import uuid
from typing import Any, Dict, List, Optional

from pymysql import Connection
from pymysql.cursors import DictCursor
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

from lib.embedding_providers.base import BaseProvider
from services.embedding_service import EmbeddingService

logger = logging.getLogger(__name__)


class RAGService:
    """
    Service for RAG operations: indexing, searching, and query answering.
    """

    def __init__(
        self,
        db_conn: Connection,
        qdrant_client: QdrantClient,
        embedding_provider: BaseProvider,
        completion_provider: Optional[BaseProvider] = None,
        collection_prefix: str = "posts"
    ):
        """
        Initialize the RAG service.

        Args:
            db_conn: Database connection
            qdrant_client: Qdrant client instance
            embedding_provider: Provider for generating embeddings
            completion_provider: Provider for completions (defaults to embedding_provider)
            collection_prefix: Prefix for Qdrant collections
        """
        self.db_conn = db_conn
        self.qdrant = qdrant_client
        self.embedding_service = EmbeddingService(provider=embedding_provider)
        self.completion_provider = completion_provider or embedding_provider
        self.collection_prefix = collection_prefix

    def _get_collection_name(self, locale: str) -> str:
        """Generate the collection name for a given locale."""
        if not locale:
            return self.collection_prefix
        return f"{self.collection_prefix}_{locale}"

    def _ensure_collection_exists(self, locale: str) -> str:
        """
        Ensure the Qdrant collection for the given locale exists.

        Returns:
            Collection name
        """
        collection_name = self._get_collection_name(locale)
        collections = [col.name for col in self.qdrant.get_collections().collections]

        if collection_name not in collections:
            # Create the collection with the provider's vector dimension
            self.qdrant.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_service.get_embedding_dimension(),
                    distance=Distance.COSINE
                )
            )

        return collection_name

    def _get_content_hash(self, content: str) -> str:
        """Generate an MD5 hash of the content for change detection."""
        return hashlib.md5(content.encode('utf-8')).hexdigest()
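
    # The tracking queries below assume a MySQL table roughly like this
    # (a sketch inferred from the statements in this module; the actual
    # migration may differ):
    #
    #   CREATE TABLE post_vectors (
    #       post_id         INT          NOT NULL,
    #       collection_name VARCHAR(255) NOT NULL,
    #       vector_ids      JSON         NOT NULL,
    #       provider        VARCHAR(64)  NOT NULL,
    #       model           VARCHAR(128) NOT NULL,
    #       chunk_count     INT          NOT NULL,
    #       file_hash       CHAR(32)     NOT NULL,
    #       indexed_at      TIMESTAMP    DEFAULT CURRENT_TIMESTAMP,
    #       PRIMARY KEY (post_id, collection_name)
    #   );
    #
    # The composite primary key is what makes the ON DUPLICATE KEY UPDATE
    # in index_post() replace the row per (post, collection) pair.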

    def index_post(
        self,
        post_id: int,
        title: str,
        slug: str,
        locale: str,
        body_md: str,
        extra_payload: Optional[Dict[str, Any]] = None,
        force: bool = False
    ) -> Dict[str, Any]:
        """
        Index a single post into Qdrant.

        Args:
            post_id: Database ID of the post
            title: Post title
            slug: Post slug
            locale: Post locale (de, en, etc.)
            body_md: Markdown content
            extra_payload: Additional metadata to store in Qdrant
            force: If True, bypass the content hash check

        Returns:
            Dictionary with indexing results
        """
        # Ensure the collection exists
        collection_name = self._ensure_collection_exists(locale)

        # Skip re-indexing if the post is already indexed and unchanged
        content_hash = self._get_content_hash(body_md or "")
        if not force:
            with self.db_conn.cursor(DictCursor) as cur:
                cur.execute(
                    "SELECT file_hash FROM post_vectors WHERE post_id=%s AND collection_name=%s",
                    (post_id, collection_name)
                )
                existing = cur.fetchone()
                if existing and existing['file_hash'] == content_hash:
                    return {
                        'post_id': post_id,
                        'status': 'unchanged',
                        'message': 'Post content unchanged, skipping re-indexing',
                        'collection': collection_name
                    }

        # Chunk and embed the content
        chunks_with_embeddings = self.embedding_service.chunk_and_embed_post(
            post_content=body_md,
            post_id=post_id,
            post_title=title
        )

        if not chunks_with_embeddings:
            return {
                'post_id': post_id,
                'status': 'skipped',
                'message': 'No content to index',
                'collection': collection_name
            }

        # Prepare points for Qdrant
        points = []
        vector_ids = []
        for chunk in chunks_with_embeddings:
            # Deterministic UUID derived from post_id and chunk_position
            point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{post_id}_{chunk['chunk_position']}"))
            vector_ids.append(point_id)

            payload = {
                'post_id': post_id,
                'title': title,
                'slug': slug,
                'locale': locale,
                'content': chunk['content'],
                'header': chunk.get('header', ''),
                'header_level': chunk.get('header_level', 0),
                'chunk_position': chunk['chunk_position']
            }
            if extra_payload:
                payload.update(extra_payload)

            points.append(PointStruct(
                id=point_id,
                vector=chunk['embedding'],
                payload=payload
            ))

        # Upsert points to Qdrant
        self.qdrant.upsert(
            collection_name=collection_name,
            points=points
        )

        # Update tracking in the database
        provider_info = self.embedding_service.get_provider_info()
        with self.db_conn.cursor(DictCursor) as cur:
            cur.execute(
                """
                INSERT INTO post_vectors
                    (post_id, collection_name, vector_ids, provider, model, chunk_count, file_hash)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                    vector_ids=VALUES(vector_ids),
                    provider=VALUES(provider),
                    model=VALUES(model),
                    chunk_count=VALUES(chunk_count),
                    file_hash=VALUES(file_hash),
                    indexed_at=CURRENT_TIMESTAMP
                """,
                (
                    post_id,
                    collection_name,
                    json.dumps(vector_ids),
                    provider_info['provider'],
                    provider_info['model'],
                    len(chunks_with_embeddings),
                    content_hash
                )
            )
        # pymysql does not autocommit by default; persist the tracking row
        self.db_conn.commit()

        return {
            'post_id': post_id,
            'status': 'indexed',
            'chunks': len(chunks_with_embeddings),
            'collection': collection_name,
            'provider': provider_info['provider']
        }
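
    # Note on the ID scheme above: uuid5(NAMESPACE_DNS, f"{post_id}_{position}")
    # is deterministic, so re-indexing a post upserts over the same point IDs
    # instead of accumulating duplicates. One caveat (an observation about the
    # scheme, not new behavior): if a post shrinks to fewer chunks, points
    # beyond the new chunk count keep their stale content; calling
    # delete_post_index() before re-indexing clears them.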

    def index_all_posts(self, locale: Optional[str] = None) -> Dict[str, Any]:
        """
        Index all published posts.

        Args:
            locale: Optional locale filter (de, en, or None for all)

        Returns:
            Dictionary with indexing results
        """
        # Fetch posts from the database
        with self.db_conn.cursor(DictCursor) as cur:
            if locale:
                cur.execute(
                    "SELECT id, title, slug, locale, body_md FROM posts WHERE is_published=1 AND locale=%s",
                    (locale,)
                )
            else:
                cur.execute(
                    "SELECT id, title, slug, locale, body_md FROM posts WHERE is_published=1"
                )
            posts = cur.fetchall()

        if not posts:
            return {
                'indexed': 0,
                'errors': 0,
                'unchanged': 0,
                'total': 0,
                'message': 'No posts found to index'
            }

        # Index each post
        results = {
            'indexed': 0,
            'errors': 0,
            'unchanged': 0,
            'total': len(posts)
        }

        for post in posts:
            try:
                result = self.index_post(
                    post_id=post['id'],
                    title=post['title'],
                    slug=post['slug'],
                    locale=post['locale'],
                    body_md=post['body_md']
                )
                if result['status'] == 'indexed':
                    results['indexed'] += 1
                elif result['status'] == 'unchanged':
                    results['unchanged'] += 1
                else:
                    results['errors'] += 1
            except Exception:
                results['errors'] += 1
                # Log the error but continue with the remaining posts
                logger.exception("Error indexing post %s", post['id'])

        return results

    def delete_post_index(self, post_id: int, locale: str) -> Dict[str, Any]:
        """
        Remove a post from the vector index.

        Args:
            post_id: Database ID of the post
            locale: Post locale

        Returns:
            Dictionary with deletion results
        """
        collection_name = self._get_collection_name(locale)

        # Get vector IDs from the tracking table
        with self.db_conn.cursor(DictCursor) as cur:
            cur.execute(
                "SELECT vector_ids FROM post_vectors WHERE post_id=%s AND collection_name=%s",
                (post_id, collection_name)
            )
            result = cur.fetchone()

            if not result:
                return {
                    'post_id': post_id,
                    'status': 'not_found',
                    'message': 'Post not indexed'
                }

            vector_ids = json.loads(result['vector_ids'])

            # Delete from Qdrant
            self.qdrant.delete(
                collection_name=collection_name,
                points_selector=vector_ids
            )

            # Delete from the tracking table
            cur.execute(
                "DELETE FROM post_vectors WHERE post_id=%s AND collection_name=%s",
                (post_id, collection_name)
            )
        # pymysql does not autocommit by default; persist the deletion
        self.db_conn.commit()

        return {
            'post_id': post_id,
            'status': 'deleted',
            'vectors_deleted': len(vector_ids)
        }

    def search_posts(
        self,
        query: str,
        locale: str,
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Semantic search across indexed posts.

        Args:
            query: Search query
            locale: Locale to search in
            limit: Maximum number of results

        Returns:
            List of search results with scores
        """
        collection_name = self._get_collection_name(locale)

        # Generate the query embedding
        query_embedding = self.embedding_service.embed_texts([query])[0]

        # Search in Qdrant via the Query API
        response = self.qdrant.query_points(
            collection_name=collection_name,
            query=query_embedding,
            limit=limit
        )
        search_results = response.points

        # Format results
        results = []
        for hit in search_results:
            results.append({
                'post_id': hit.payload['post_id'],
                'title': hit.payload['title'],
                'slug': hit.payload['slug'],
                'content': hit.payload['content'],
                'header': hit.payload.get('header', ''),
                'score': hit.score
            })
        return results
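
    # Illustrative shape of a search_posts() result (the keys match the
    # payload written in index_post(); the values here are made up):
    #
    #   [{'post_id': 42, 'title': 'Qdrant Basics', 'slug': 'qdrant-basics',
    #     'content': '...chunk text...', 'header': 'Setup', 'score': 0.87}]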

    def query_with_rag(
        self,
        question: str,
        locale: str,
        context_limit: int = 3
    ) -> Dict[str, Any]:
        """
        Answer a question using RAG (Retrieval-Augmented Generation).

        Args:
            question: User's question
            locale: Locale to search in
            context_limit: Number of context chunks to retrieve

        Returns:
            Dictionary with answer and sources
        """
        # Retrieve relevant context
        search_results = self.search_posts(
            query=question,
            locale=locale,
            limit=context_limit
        )

        if not search_results:
            return {
                'answer': 'No relevant information found in the knowledge base.',
                'sources': [],
                'provider': self.completion_provider.provider_name
            }

        # Build the context block from the search results
        context_parts = []
        for idx, result in enumerate(search_results, 1):
            context_parts.append(
                f"[Source {idx}: {result['title']} - {result['header']}]\n{result['content']}"
            )
        context = "\n\n".join(context_parts)

        # Generate the answer
        answer = self.completion_provider.get_completion(
            prompt=question,
            context=context
        )

        # Format sources
        sources = [
            {
                'post_id': r['post_id'],
                'title': r['title'],
                'slug': r['slug'],
                'score': r['score']
            }
            for r in search_results
        ]

        return {
            'answer': answer,
            'sources': sources,
            'provider': self.completion_provider.provider_name,
            'model': self.completion_provider.model_name
        }

    def get_indexing_status(self) -> Dict[str, Any]:
        """
        Get indexing status across all collections.

        Returns:
            Dictionary with status information
        """
        status = {}

        # Total and indexed post counts
        with self.db_conn.cursor(DictCursor) as cur:
            cur.execute("SELECT COUNT(*) as total FROM posts WHERE is_published=1")
            total_posts = cur.fetchone()['total']

            cur.execute("SELECT COUNT(DISTINCT post_id) as indexed FROM post_vectors")
            indexed_posts = cur.fetchone()['indexed']

            # Per-collection statistics
            cur.execute(
                """
                SELECT collection_name,
                       COUNT(*) as post_count,
                       SUM(chunk_count) as total_vectors,
                       MAX(indexed_at) as last_indexed
                FROM post_vectors
                GROUP BY collection_name
                """
            )
            collections = cur.fetchall()

        status['total_posts'] = total_posts
        status['indexed_posts'] = indexed_posts
        status['collections'] = {}
        for col in collections:
            status['collections'][col['collection_name']] = {
                'posts': col['post_count'],
                # MySQL returns SUM() as Decimal; cast for clean serialization
                'vectors': int(col['total_vectors']),
                'last_indexed': col['last_indexed'].isoformat() if col['last_indexed'] else None
            }
        return status
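

if __name__ == "__main__":
    # Minimal wiring sketch, not part of the service itself. The connection
    # parameters and the concrete provider class below are assumptions;
    # substitute your project's actual BaseProvider implementation and
    # credentials.
    import pymysql
    from lib.embedding_providers.openai import OpenAIProvider  # hypothetical import path

    conn = pymysql.connect(host="localhost", user="app", password="secret", database="blog")
    service = RAGService(
        db_conn=conn,
        qdrant_client=QdrantClient(url="http://localhost:6333"),
        embedding_provider=OpenAIProvider(),  # assumed constructor
    )
    print(service.index_all_posts(locale="en"))
    print(service.query_with_rag("How do I configure Qdrant?", locale="en"))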