Crumb-Core-v.1/app/services/rag_service.py
Commit 9c01bc0607 by Krümel Branko: Fix Qdrant Client Compatibility and Role RAG Tests
- Upgrade qdrant-client usage to query_points (replacing deprecated search)
- Relax qdrant-client version constraint in requirements.txt (>=1.7.0)
- Config: Allow extra environment variables in Settings (for local dev)
- test_roles_rag.py: Fix import path for local execution
2025-12-07 20:45:58 +01:00
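
The core of the compatibility fix is the move from the deprecated search() call to the Query API. A minimal before/after sketch (assuming a qdrant-client version that ships query_points; vec stands in for a query embedding):

    # before (deprecated)
    hits = client.search(collection_name="posts_de", query_vector=vec, limit=5)
    # after (Query API): the hits move into a QueryResponse wrapper
    hits = client.query_points(collection_name="posts_de", query=vec, limit=5).points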


# app/services/rag_service.py
import json
import hashlib
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional

from pymysql import Connection
from pymysql.cursors import DictCursor
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from slugify import slugify

from lib.embedding_providers.base import BaseProvider
from services.embedding_service import EmbeddingService


class RAGService:
"""
Service for RAG operations: indexing, searching, and query answering.
"""
def __init__(
self,
db_conn: Connection,
qdrant_client: QdrantClient,
embedding_provider: BaseProvider,
completion_provider: Optional[BaseProvider] = None,
collection_prefix: str = "posts"
):
"""
Initialize RAG service.
Args:
db_conn: Database connection
qdrant_client: Qdrant client instance
embedding_provider: Provider for generating embeddings
completion_provider: Provider for completions (defaults to embedding_provider)
collection_prefix: Prefix for Qdrant collections
"""
self.db_conn = db_conn
self.qdrant = qdrant_client
self.embedding_service = EmbeddingService(provider=embedding_provider)
self.completion_provider = completion_provider or embedding_provider
self.collection_prefix = collection_prefix
def _get_collection_name(self, locale: str) -> str:
"""Generate collection name for a given locale."""
if not locale:
return self.collection_prefix
return f"{self.collection_prefix}_{locale}"
    def _ensure_collection_exists(self, locale: str) -> str:
        """
        Ensure Qdrant collection exists for the given locale.

        Returns:
            Collection name
        """
        collection_name = self._get_collection_name(locale)
        collections = [col.name for col in self.qdrant.get_collections().collections]
        if collection_name not in collections:
            # Create collection with appropriate vector dimension
            self.qdrant.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_service.get_embedding_dimension(),
                    distance=Distance.COSINE
                )
            )
        return collection_name

    def _get_content_hash(self, content: str) -> str:
        """Generate MD5 hash of content for change detection."""
        return hashlib.md5(content.encode('utf-8')).hexdigest()
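
    # The hash only detects content changes between indexing runs, e.g.
    # _get_content_hash("hello") == "5d41402abc4b2a76b9719d911017c592";
    # MD5 is acceptable here because nothing security-relevant depends on it.
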
    def index_post(
        self,
        post_id: int,
        title: str,
        slug: str,
        locale: str,
        body_md: str,
        extra_payload: Optional[Dict[str, Any]] = None,
        force: bool = False
    ) -> Dict[str, Any]:
        """
        Index a single post into Qdrant.

        Args:
            post_id: Database ID of the post
            title: Post title
            slug: Post slug
            locale: Post locale (de, en, etc.)
            body_md: Markdown content
            extra_payload: Additional metadata to store in Qdrant
            force: If True, bypass content hash check

        Returns:
            Dictionary with indexing results
        """
        # Ensure collection exists
        collection_name = self._ensure_collection_exists(locale)

        # Check if already indexed and content unchanged
        content_hash = self._get_content_hash(body_md or "")
        if not force:
            with self.db_conn.cursor(DictCursor) as cur:
                cur.execute(
                    "SELECT file_hash FROM post_vectors WHERE post_id=%s AND collection_name=%s",
                    (post_id, collection_name)
                )
                existing = cur.fetchone()
            if existing and existing['file_hash'] == content_hash:
                return {
                    'post_id': post_id,
                    'status': 'unchanged',
                    'message': 'Post content unchanged, skipping re-indexing',
                    'collection': collection_name
                }

        # Chunk and embed the content
        chunks_with_embeddings = self.embedding_service.chunk_and_embed_post(
            post_content=body_md,
            post_id=post_id,
            post_title=title
        )
        if not chunks_with_embeddings:
            return {
                'post_id': post_id,
                'status': 'skipped',
                'message': 'No content to index',
                'collection': collection_name
            }

        # Prepare points for Qdrant
        points = []
        vector_ids = []
        for chunk in chunks_with_embeddings:
            # Generate deterministic UUID based on post_id and chunk_position
            point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{post_id}_{chunk['chunk_position']}"))
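            # uuid5 is deterministic: the same (post_id, chunk_position) pair
            # always yields the same point ID, so re-indexing overwrites the
            # existing point instead of creating a duplicate.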
            vector_ids.append(point_id)
            payload = {
                'post_id': post_id,
                'title': title,
                'slug': slug,
                'locale': locale,
                'content': chunk['content'],
                'header': chunk.get('header', ''),
                'header_level': chunk.get('header_level', 0),
                'chunk_position': chunk['chunk_position']
            }
            if extra_payload:
                payload.update(extra_payload)
            points.append(PointStruct(
                id=point_id,
                vector=chunk['embedding'],
                payload=payload
            ))

        # Upsert points to Qdrant
        self.qdrant.upsert(
            collection_name=collection_name,
            points=points
        )

        # Update tracking in database
        provider_info = self.embedding_service.get_provider_info()
        with self.db_conn.cursor(DictCursor) as cur:
            cur.execute(
                """
                INSERT INTO post_vectors (post_id, collection_name, vector_ids, provider, model, chunk_count, file_hash)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                    vector_ids=VALUES(vector_ids),
                    provider=VALUES(provider),
                    model=VALUES(model),
                    chunk_count=VALUES(chunk_count),
                    file_hash=VALUES(file_hash),
                    indexed_at=CURRENT_TIMESTAMP
                """,
                (
                    post_id,
                    collection_name,
                    json.dumps(vector_ids),
                    provider_info['provider'],
                    provider_info['model'],
                    len(chunks_with_embeddings),
                    content_hash
                )
            )
        return {
            'post_id': post_id,
            'status': 'indexed',
            'chunks': len(chunks_with_embeddings),
            'collection': collection_name,
            'provider': provider_info['provider']
        }
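
    # Example index_post call (illustrative values):
    #   service.index_post(post_id=42, title="Hallo Welt", slug="hallo-welt",
    #                      locale="de", body_md="# Hallo\n\nWelt")
    #   -> {'post_id': 42, 'status': 'indexed', 'chunks': 1,
    #       'collection': 'posts_de', 'provider': '<provider name>'}
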
    def index_all_posts(self, locale: Optional[str] = None) -> Dict[str, Any]:
        """
        Index all published posts.

        Args:
            locale: Optional locale filter (de, en, or None for all)

        Returns:
            Dictionary with indexing results
        """
        # Fetch posts from database
        with self.db_conn.cursor(DictCursor) as cur:
            if locale:
                cur.execute(
                    "SELECT id, title, slug, locale, body_md FROM posts WHERE is_published=1 AND locale=%s",
                    (locale,)
                )
            else:
                cur.execute(
                    "SELECT id, title, slug, locale, body_md FROM posts WHERE is_published=1"
                )
            posts = cur.fetchall()
        if not posts:
            return {
                'indexed': 0,
                'errors': 0,
                'unchanged': 0,
                'message': 'No posts found to index'
            }

        # Index each post
        results = {
            'indexed': 0,
            'errors': 0,
            'unchanged': 0,
            'total': len(posts)
        }
        for post in posts:
            try:
                result = self.index_post(
                    post_id=post['id'],
                    title=post['title'],
                    slug=post['slug'],
                    locale=post['locale'],
                    body_md=post['body_md']
                )
                if result['status'] == 'indexed':
                    results['indexed'] += 1
                elif result['status'] == 'unchanged':
                    results['unchanged'] += 1
                else:
                    results['errors'] += 1
            except Exception as e:
                results['errors'] += 1
                # Log error but continue with other posts
                print(f"Error indexing post {post['id']}: {str(e)}")
        return results
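
    # Example return value (illustrative):
    #   {'indexed': 12, 'errors': 0, 'unchanged': 30, 'total': 42}
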
    def delete_post_index(self, post_id: int, locale: str) -> Dict[str, Any]:
        """
        Remove a post from the vector index.

        Args:
            post_id: Database ID of the post
            locale: Post locale

        Returns:
            Dictionary with deletion results
        """
        collection_name = self._get_collection_name(locale)

        # Get vector IDs from database
        with self.db_conn.cursor(DictCursor) as cur:
            cur.execute(
                "SELECT vector_ids FROM post_vectors WHERE post_id=%s AND collection_name=%s",
                (post_id, collection_name)
            )
            result = cur.fetchone()
            if not result:
                return {
                    'post_id': post_id,
                    'status': 'not_found',
                    'message': 'Post not indexed'
                }
            vector_ids = json.loads(result['vector_ids'])

            # Delete from Qdrant
            self.qdrant.delete(
                collection_name=collection_name,
                points_selector=vector_ids
            )

            # Delete from tracking table (reuses the still-open cursor)
            cur.execute(
                "DELETE FROM post_vectors WHERE post_id=%s AND collection_name=%s",
                (post_id, collection_name)
            )
        return {
            'post_id': post_id,
            'status': 'deleted',
            'vectors_deleted': len(vector_ids)
        }

    def search_posts(
        self,
        query: str,
        locale: str,
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Semantic search across indexed posts.

        Args:
            query: Search query
            locale: Locale to search in
            limit: Maximum number of results

        Returns:
            List of search results with scores
        """
        collection_name = self._get_collection_name(locale)

        # Generate query embedding
        query_embedding = self.embedding_service.embed_texts([query])[0]

        # Search in Qdrant using the new Query API
        response = self.qdrant.query_points(
            collection_name=collection_name,
            query=query_embedding,
            limit=limit
        )
        search_results = response.points

        # Format results
        results = []
        for hit in search_results:
            results.append({
                'post_id': hit.payload['post_id'],
                'title': hit.payload['title'],
                'slug': hit.payload['slug'],
                'content': hit.payload['content'],
                'header': hit.payload.get('header', ''),
                'score': hit.score
            })
        return results
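
    # Unlike the deprecated client.search(), query_points() wraps its hits in
    # a QueryResponse object, hence the .points access above. Example result
    # (illustrative values):
    #   [{'post_id': 42, 'title': 'Hallo Welt', 'slug': 'hallo-welt',
    #     'content': '...', 'header': 'Intro', 'score': 0.87}]
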
    def query_with_rag(
        self,
        question: str,
        locale: str,
        context_limit: int = 3
    ) -> Dict[str, Any]:
        """
        Answer a question using RAG (Retrieval-Augmented Generation).

        Args:
            question: User's question
            locale: Locale to search in
            context_limit: Number of context chunks to retrieve

        Returns:
            Dictionary with answer and sources
        """
        # Retrieve relevant context
        search_results = self.search_posts(
            query=question,
            locale=locale,
            limit=context_limit
        )
        if not search_results:
            return {
                'answer': 'No relevant information found in the knowledge base.',
                'sources': [],
                'provider': self.completion_provider.provider_name
            }

        # Build context from search results
        context_parts = []
        for idx, result in enumerate(search_results, 1):
            context_parts.append(
                f"[Source {idx}: {result['title']} - {result['header']}]\n{result['content']}"
            )
        context = "\n\n".join(context_parts)

        # Generate answer
        answer = self.completion_provider.get_completion(
            prompt=question,
            context=context
        )

        # Format sources
        sources = [
            {
                'post_id': r['post_id'],
                'title': r['title'],
                'slug': r['slug'],
                'score': r['score']
            }
            for r in search_results
        ]
        return {
            'answer': answer,
            'sources': sources,
            'provider': self.completion_provider.provider_name,
            'model': self.completion_provider.model_name
        }
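
    # The context handed to the completion provider is plain text of the form:
    #   [Source 1: <title> - <header>]
    #   <chunk content>
    #
    #   [Source 2: ...]
    # which lets a prompt template cite sources by number.
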
    def get_indexing_status(self) -> Dict[str, Any]:
        """
        Get indexing status across all collections.

        Returns:
            Dictionary with status information
        """
        status = {}

        # Get total and indexed post counts from database
        with self.db_conn.cursor(DictCursor) as cur:
            cur.execute("SELECT COUNT(*) as total FROM posts WHERE is_published=1")
            total_posts = cur.fetchone()['total']
            cur.execute("SELECT COUNT(DISTINCT post_id) as indexed FROM post_vectors")
            indexed_posts = cur.fetchone()['indexed']

            # Get collection statistics
            cur.execute(
                """
                SELECT collection_name, COUNT(*) as post_count,
                       SUM(chunk_count) as total_vectors,
                       MAX(indexed_at) as last_indexed
                FROM post_vectors
                GROUP BY collection_name
                """
            )
            collections = cur.fetchall()

        status['total_posts'] = total_posts
        status['indexed_posts'] = indexed_posts
        status['collections'] = {}
        for col in collections:
            status['collections'][col['collection_name']] = {
                'posts': col['post_count'],
                'vectors': col['total_vectors'],
                'last_indexed': col['last_indexed'].isoformat() if col['last_indexed'] else None
            }
        return status
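

# Usage sketch (hypothetical wiring; the connection parameters and the concrete
# provider class are assumptions for illustration, not part of this module):
#
#   import pymysql
#   from qdrant_client import QdrantClient
#   from lib.embedding_providers.openai import OpenAIProvider  # hypothetical path
#
#   db = pymysql.connect(host="localhost", user="crumb",
#                        password="...", database="crumb")
#   rag = RAGService(
#       db_conn=db,
#       qdrant_client=QdrantClient(url="http://localhost:6333"),
#       embedding_provider=OpenAIProvider(),
#   )
#   rag.index_all_posts(locale="de")
#   print(rag.query_with_rag("Wie funktioniert die Suche?", locale="de"))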