- Upgrade qdrant-client usage to query_points (replacing the deprecated search API)
- Relax qdrant-client version constraint in requirements.txt (>=1.7.0)
- Config: allow extra environment variables in Settings (for local development)
- test_roles_rag.py: fix import path for local execution
"""
|
|
RAG Chat Utility
|
|
Handles document search and chat completions for character interactions.
|
|
"""
|
|
import logging
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient

from lib.embedding_providers.base import BaseProvider
class RAGChatService:
    """
    Service for character-based RAG chat.

    Searches across Qdrant document collections and generates answers
    with a character persona layered on top of the retrieved context.
    """

    def __init__(
        self,
        qdrant_client: QdrantClient,
        embedding_provider: BaseProvider,
        completion_provider: BaseProvider
    ):
        """
        Initialize RAG chat service.

        Args:
            qdrant_client: Qdrant client instance
            embedding_provider: Provider for generating embeddings
            completion_provider: Provider for completions
        """
        self.qdrant = qdrant_client
        self.embedding_provider = embedding_provider
        self.completion_provider = completion_provider

        # Document collections searched by default when the caller
        # does not pass an explicit list.
        self.doc_collections = [
            "docs_crumbforest_",
            "docs_rz_nullfeld_"
        ]

    def search_documents(
        self,
        query: str,
        limit: int = 5,
        collections: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search across document collections.

        Args:
            query: Search query
            limit: Maximum results per collection (also caps the merged list)
            collections: Optional list of collections to search (defaults to all)

        Returns:
            List of result dicts (collection, content, file_path, header,
            score), sorted by score descending, at most ``limit`` entries.
        """
        if collections is None:
            collections = self.doc_collections

        # Embed the query once; the same vector is reused for every collection.
        query_embedding = self.embedding_provider.get_embeddings([query])[0]

        # Fetch the list of existing collections once, outside the loop.
        # (Previously get_collections() was called per collection — one
        # extra round-trip for every entry in `collections`.)
        try:
            existing_collections = {
                col.name for col in self.qdrant.get_collections().collections
            }
        except Exception as e:
            logging.getLogger(__name__).warning(
                "Could not list Qdrant collections: %s", e
            )
            return []

        all_results: List[Dict[str, Any]] = []

        for collection_name in collections:
            # Skip collections that do not exist rather than erroring out.
            if collection_name not in existing_collections:
                continue
            try:
                # Query API (qdrant-client >= 1.7) replaces deprecated search().
                response = self.qdrant.query_points(
                    collection_name=collection_name,
                    query=query_embedding,
                    limit=limit
                )
                for hit in response.points:
                    # payload may be None if a point was stored without one;
                    # guard so we don't crash on such points.
                    payload = hit.payload or {}
                    all_results.append({
                        'collection': collection_name,
                        'content': payload.get('content', ''),
                        'file_path': payload.get('file_path', ''),
                        'header': payload.get('header', ''),
                        'score': hit.score
                    })
            except Exception as e:
                # Log and continue: one failing collection must not
                # abort the search over the remaining ones.
                logging.getLogger(__name__).warning(
                    "Error searching collection %s: %s", collection_name, e
                )
                continue

        # Best matches first, capped at `limit` across all collections.
        all_results.sort(key=lambda x: x['score'], reverse=True)
        return all_results[:limit]

    def chat_with_context(
        self,
        question: str,
        character_name: str,
        character_prompt: str,
        context_limit: int = 3,
        lang: str = "de",
        model_override: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Answer a question using RAG with character persona.

        Args:
            question: User's question
            character_name: Name of the character (e.g., "Krümeleule")
            character_prompt: Character system prompt
            context_limit: Number of context chunks to retrieve
            lang: Language code (currently unused here; prompts are German)
            model_override: Optional model name to override default

        Returns:
            Dictionary with answer, sources, context_found, provider, model.
        """
        # Retrieve relevant context chunks for the question.
        search_results = self.search_documents(
            query=question,
            limit=context_limit
        )

        # Build a numbered context block from the retrieved chunks,
        # labelling each with its header (or file path as fallback).
        context_parts = []
        for idx, result in enumerate(search_results, 1):
            header = result['header'] or result['file_path']
            context_parts.append(
                f"[Quelle {idx}: {header}]\n{result['content']}"
            )
        context = "\n\n".join(context_parts)

        # System message = character persona + retrieved knowledge (if any).
        if context:
            system_message = (
                f"{character_prompt}\n\n"
                f"Nutze folgende Informationen aus der Wissensdatenbank, um die Frage zu beantworten:\n\n"
                f"{context}"
            )
        else:
            system_message = (
                f"{character_prompt}\n\n"
                f"Hinweis: Keine relevanten Informationen in der Wissensdatenbank gefunden. "
                f"Antworte trotzdem hilfreich basierend auf deinem Charakter."
            )

        # Generate the answer with the persona-augmented system prompt.
        answer = self._get_completion_with_system(
            system_message=system_message,
            user_message=question,
            model_override=model_override
        )

        # Expose only source metadata (no content) to the caller.
        sources = [
            {
                'collection': r['collection'],
                'file_path': r['file_path'],
                'header': r['header'],
                'score': r['score']
            }
            for r in search_results
        ]

        return {
            'answer': answer,
            'sources': sources,
            'context_found': len(search_results) > 0,
            'provider': self.completion_provider.provider_name,
            'model': model_override or self.completion_provider.model_name
        }

    def _get_completion_with_system(
        self,
        system_message: str,
        user_message: str,
        model_override: Optional[str] = None
    ) -> str:
        """
        Get completion with explicit system message.

        The system message (character prompt + retrieved context) is passed
        to the provider as ``context``; the provider maps it onto its native
        system-message mechanism.

        Args:
            system_message: System prompt with character + context
            user_message: User's question
            model_override: Optional model name to use

        Returns:
            Generated response
        """
        # NOTE(review): assumes BaseProvider.get_completion accepts a
        # `model` keyword argument — confirm against the provider interface.
        return self.completion_provider.get_completion(
            prompt=user_message,
            context=system_message,
            model=model_override
        )