- Upgrade qdrant-client usage to query_points (replacing deprecated search)
- Relax qdrant-client version constraint in requirements.txt (>=1.7.0)
- Config: Allow extra environment variables in Settings (for local dev)
- test_roles_rag.py: Fix import path for local execution

# app/services/rag_service.py

import hashlib
import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

from pymysql import Connection
from pymysql.cursors import DictCursor
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from slugify import slugify

from lib.embedding_providers.base import BaseProvider
from services.embedding_service import EmbeddingService


class RAGService:
    """
    Service for RAG operations: indexing, searching, and query answering.
    """

    def __init__(
        self,
        db_conn: Connection,
        qdrant_client: QdrantClient,
        embedding_provider: BaseProvider,
        completion_provider: Optional[BaseProvider] = None,
        collection_prefix: str = "posts"
    ):
        """
        Initialize RAG service.

        Args:
            db_conn: Database connection
            qdrant_client: Qdrant client instance
            embedding_provider: Provider for generating embeddings
            completion_provider: Provider for completions (defaults to embedding_provider)
            collection_prefix: Prefix for Qdrant collections
        """
        self.db_conn = db_conn
        self.qdrant = qdrant_client
        self.embedding_service = EmbeddingService(provider=embedding_provider)
        self.completion_provider = completion_provider or embedding_provider
        self.collection_prefix = collection_prefix
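
    # Usage sketch (illustrative wiring only; the connection settings and the
    # OpenAIProvider name are assumptions, not part of this module):
    #
    #   conn = pymysql.connect(host="localhost", db="blog")
    #   service = RAGService(
    #       db_conn=conn,
    #       qdrant_client=QdrantClient(url="http://localhost:6333"),
    #       embedding_provider=OpenAIProvider(),  # any BaseProvider subclass
    #   )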

    def _get_collection_name(self, locale: str) -> str:
        """Generate collection name for a given locale."""
        if not locale:
            return self.collection_prefix
        return f"{self.collection_prefix}_{locale}"

    def _ensure_collection_exists(self, locale: str) -> str:
        """
        Ensure Qdrant collection exists for the given locale.

        Returns:
            Collection name
        """
        collection_name = self._get_collection_name(locale)
        collections = [col.name for col in self.qdrant.get_collections().collections]

        if collection_name not in collections:
            # Create collection with the provider's vector dimension
            self.qdrant.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_service.get_embedding_dimension(),
                    distance=Distance.COSINE
                )
            )

        return collection_name

    def _get_content_hash(self, content: str) -> str:
        """Generate MD5 hash of content for change detection."""
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def index_post(
        self,
        post_id: int,
        title: str,
        slug: str,
        locale: str,
        body_md: str,
        extra_payload: Optional[Dict[str, Any]] = None,
        force: bool = False
    ) -> Dict[str, Any]:
        """
        Index a single post into Qdrant.

        Args:
            post_id: Database ID of the post
            title: Post title
            slug: Post slug
            locale: Post locale (de, en, etc.)
            body_md: Markdown content
            extra_payload: Additional metadata to store in Qdrant
            force: If True, bypass the content hash check

        Returns:
            Dictionary with indexing results
        """
        # Ensure collection exists
        collection_name = self._ensure_collection_exists(locale)

        # Check if already indexed and content unchanged
        content_hash = self._get_content_hash(body_md or "")
        if not force:
            with self.db_conn.cursor(DictCursor) as cur:
                cur.execute(
                    "SELECT file_hash FROM post_vectors WHERE post_id=%s AND collection_name=%s",
                    (post_id, collection_name)
                )
                existing = cur.fetchone()
                if existing and existing['file_hash'] == content_hash:
                    return {
                        'post_id': post_id,
                        'status': 'unchanged',
                        'message': 'Post content unchanged, skipping re-indexing',
                        'collection': collection_name
                    }

        # Chunk and embed the content
        chunks_with_embeddings = self.embedding_service.chunk_and_embed_post(
            post_content=body_md,
            post_id=post_id,
            post_title=title
        )

        if not chunks_with_embeddings:
            return {
                'post_id': post_id,
                'status': 'skipped',
                'message': 'No content to index',
                'collection': collection_name
            }

        # Prepare points for Qdrant
        points = []
        vector_ids = []

        for chunk in chunks_with_embeddings:
            # Generate a deterministic UUID based on post_id and chunk_position
            point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{post_id}_{chunk['chunk_position']}"))
            vector_ids.append(point_id)

            payload = {
                'post_id': post_id,
                'title': title,
                'slug': slug,
                'locale': locale,
                'content': chunk['content'],
                'header': chunk.get('header', ''),
                'header_level': chunk.get('header_level', 0),
                'chunk_position': chunk['chunk_position']
            }

            if extra_payload:
                payload.update(extra_payload)

            points.append(PointStruct(
                id=point_id,
                vector=chunk['embedding'],
                payload=payload
            ))

        # Upsert points to Qdrant
        self.qdrant.upsert(
            collection_name=collection_name,
            points=points
        )

        # Update tracking in database
        provider_info = self.embedding_service.get_provider_info()
        with self.db_conn.cursor(DictCursor) as cur:
            cur.execute(
                """
                INSERT INTO post_vectors (post_id, collection_name, vector_ids, provider, model, chunk_count, file_hash)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                    vector_ids=VALUES(vector_ids),
                    provider=VALUES(provider),
                    model=VALUES(model),
                    chunk_count=VALUES(chunk_count),
                    file_hash=VALUES(file_hash),
                    indexed_at=CURRENT_TIMESTAMP
                """,
                (
                    post_id,
                    collection_name,
                    json.dumps(vector_ids),
                    provider_info['provider'],
                    provider_info['model'],
                    len(chunks_with_embeddings),
                    content_hash
                )
            )

        return {
            'post_id': post_id,
            'status': 'indexed',
            'chunks': len(chunks_with_embeddings),
            'collection': collection_name,
            'provider': provider_info['provider']
        }
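
    # Example (illustrative values): re-index a post even when its hash is
    # unchanged, attaching extra payload fields for later filtering:
    #
    #   service.index_post(post_id=42, title="Hello", slug="hello",
    #                      locale="en", body_md="# Hello\n...", force=True,
    #                      extra_payload={"category": "news"})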

    def index_all_posts(self, locale: Optional[str] = None) -> Dict[str, Any]:
        """
        Index all published posts.

        Args:
            locale: Optional locale filter (de, en, or None for all)

        Returns:
            Dictionary with indexing results
        """
        # Fetch posts from database
        with self.db_conn.cursor(DictCursor) as cur:
            if locale:
                cur.execute(
                    "SELECT id, title, slug, locale, body_md FROM posts WHERE is_published=1 AND locale=%s",
                    (locale,)
                )
            else:
                cur.execute(
                    "SELECT id, title, slug, locale, body_md FROM posts WHERE is_published=1"
                )
            posts = cur.fetchall()

        if not posts:
            return {
                'indexed': 0,
                'errors': 0,
                'unchanged': 0,
                'message': 'No posts found to index'
            }

        # Index each post
        results = {
            'indexed': 0,
            'errors': 0,
            'unchanged': 0,
            'total': len(posts)
        }

        for post in posts:
            try:
                result = self.index_post(
                    post_id=post['id'],
                    title=post['title'],
                    slug=post['slug'],
                    locale=post['locale'],
                    body_md=post['body_md']
                )

                if result['status'] == 'indexed':
                    results['indexed'] += 1
                elif result['status'] == 'unchanged':
                    results['unchanged'] += 1
                else:
                    results['errors'] += 1

            except Exception as e:
                results['errors'] += 1
                # Log the error but continue with the remaining posts
                print(f"Error indexing post {post['id']}: {str(e)}")

        return results
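
    # Example (illustrative): rebuild the German index only; the returned
    # counters distinguish newly indexed posts from unchanged ones:
    #
    #   stats = service.index_all_posts(locale="de")
    #   print(stats)  # e.g. {'indexed': 3, 'errors': 0, 'unchanged': 12, 'total': 15}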

    def delete_post_index(self, post_id: int, locale: str) -> Dict[str, Any]:
        """
        Remove a post from the vector index.

        Args:
            post_id: Database ID of the post
            locale: Post locale

        Returns:
            Dictionary with deletion results
        """
        collection_name = self._get_collection_name(locale)

        with self.db_conn.cursor(DictCursor) as cur:
            # Get vector IDs from database
            cur.execute(
                "SELECT vector_ids FROM post_vectors WHERE post_id=%s AND collection_name=%s",
                (post_id, collection_name)
            )
            result = cur.fetchone()

            if not result:
                return {
                    'post_id': post_id,
                    'status': 'not_found',
                    'message': 'Post not indexed'
                }

            vector_ids = json.loads(result['vector_ids'])

            # Delete from Qdrant
            self.qdrant.delete(
                collection_name=collection_name,
                points_selector=vector_ids
            )

            # Delete from tracking table (still inside the cursor context,
            # so the cursor remains open for this statement)
            cur.execute(
                "DELETE FROM post_vectors WHERE post_id=%s AND collection_name=%s",
                (post_id, collection_name)
            )

        return {
            'post_id': post_id,
            'status': 'deleted',
            'vectors_deleted': len(vector_ids)
        }
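
    # Example (illustrative): drop a post's vectors when it is unpublished;
    # a 'not_found' status simply means the post was never indexed:
    #
    #   service.delete_post_index(post_id=42, locale="en")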

    def search_posts(
        self,
        query: str,
        locale: str,
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Semantic search across indexed posts.

        Args:
            query: Search query
            locale: Locale to search in
            limit: Maximum number of results

        Returns:
            List of search results with scores
        """
        collection_name = self._get_collection_name(locale)

        # Generate query embedding
        query_embedding = self.embedding_service.embed_texts([query])[0]

        # Search in Qdrant using the Query API (query_points replaces the
        # deprecated search())
        response = self.qdrant.query_points(
            collection_name=collection_name,
            query=query_embedding,
            limit=limit
        )
        search_results = response.points

        # Format results
        results = []
        for hit in search_results:
            results.append({
                'post_id': hit.payload['post_id'],
                'title': hit.payload['title'],
                'slug': hit.payload['slug'],
                'content': hit.payload['content'],
                'header': hit.payload.get('header', ''),
                'score': hit.score
            })

        return results
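
    # Example (illustrative): plain semantic search without generation;
    # scores are cosine similarities returned by Qdrant:
    #
    #   hits = service.search_posts(query="deployment checklist", locale="en", limit=3)
    #   for hit in hits:
    #       print(hit['score'], hit['title'], hit['header'])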

    def query_with_rag(
        self,
        question: str,
        locale: str,
        context_limit: int = 3
    ) -> Dict[str, Any]:
        """
        Answer a question using RAG (Retrieval-Augmented Generation).

        Args:
            question: User's question
            locale: Locale to search in
            context_limit: Number of context chunks to retrieve

        Returns:
            Dictionary with answer and sources
        """
        # Retrieve relevant context
        search_results = self.search_posts(
            query=question,
            locale=locale,
            limit=context_limit
        )

        if not search_results:
            return {
                'answer': 'No relevant information found in the knowledge base.',
                'sources': [],
                'provider': self.completion_provider.provider_name
            }

        # Build context from search results
        context_parts = []
        for idx, result in enumerate(search_results, 1):
            context_parts.append(
                f"[Source {idx}: {result['title']} - {result['header']}]\n{result['content']}"
            )
        context = "\n\n".join(context_parts)

        # Generate answer
        answer = self.completion_provider.get_completion(
            prompt=question,
            context=context
        )

        # Format sources
        sources = [
            {
                'post_id': r['post_id'],
                'title': r['title'],
                'slug': r['slug'],
                'score': r['score']
            }
            for r in search_results
        ]

        return {
            'answer': answer,
            'sources': sources,
            'provider': self.completion_provider.provider_name,
            'model': self.completion_provider.model_name
        }
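
    # Example (illustrative): full RAG round trip; the answer is grounded in
    # the retrieved chunks, and the sources list maps back to posts:
    #
    #   reply = service.query_with_rag(question="How do I deploy?", locale="en")
    #   print(reply['answer'])
    #   print([s['slug'] for s in reply['sources']])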

    def get_indexing_status(self) -> Dict[str, Any]:
        """
        Get indexing status across all collections.

        Returns:
            Dictionary with status information
        """
        status = {}

        # Get total and indexed post counts from database
        with self.db_conn.cursor(DictCursor) as cur:
            cur.execute("SELECT COUNT(*) as total FROM posts WHERE is_published=1")
            total_posts = cur.fetchone()['total']

            cur.execute("SELECT COUNT(DISTINCT post_id) as indexed FROM post_vectors")
            indexed_posts = cur.fetchone()['indexed']

            # Get collection statistics
            cur.execute(
                """
                SELECT collection_name, COUNT(*) as post_count,
                       SUM(chunk_count) as total_vectors,
                       MAX(indexed_at) as last_indexed
                FROM post_vectors
                GROUP BY collection_name
                """
            )
            collections = cur.fetchall()

        status['total_posts'] = total_posts
        status['indexed_posts'] = indexed_posts
        status['collections'] = {}

        for col in collections:
            status['collections'][col['collection_name']] = {
                'posts': col['post_count'],
                'vectors': col['total_vectors'],
                'last_indexed': col['last_indexed'].isoformat() if col['last_indexed'] else None
            }

        return status
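
# Quick health-check sketch (illustrative): compare indexed_posts against
# total_posts to spot posts that still need indexing:
#
#   status = service.get_indexing_status()
#   missing = status['total_posts'] - status['indexed_posts']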