# File viewer metadata: 86 lines, 2.5 KiB, Python
# app/services/embedding_service.py
|
|
from typing import List, Dict, Any
|
|
from lib.embedding_providers.base import BaseProvider
|
|
from lib.markdown_chunker import MarkdownChunker
|
|
|
|
|
|
class EmbeddingService:
    """
    Service for coordinating text chunking and embedding generation.

    Chunking is delegated to a ``MarkdownChunker`` and vector generation to the
    injected provider; this class only wires the two together and merges
    their outputs.
    """

    def __init__(
        self,
        provider: "BaseProvider",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        """
        Initialize the embedding service.

        Args:
            provider: Provider instance for generating embeddings
            chunk_size: Maximum size of text chunks (passed to the chunker
                as ``chunk_size``)
            chunk_overlap: Overlap between chunks (passed to the chunker
                as ``overlap``)
        """
        self.provider = provider
        self.chunker = MarkdownChunker(chunk_size=chunk_size, overlap=chunk_overlap)

    def chunk_and_embed_post(
        self,
        post_content: str,
        post_id: int,
        post_title: str = "",
    ) -> List[Dict[str, Any]]:
        """
        Chunk post content and generate embeddings for each chunk.

        Args:
            post_content: Markdown content of the post
            post_id: Database ID of the post
            post_title: Title of the post

        Returns:
            List of dictionaries, one per chunk, each holding every key the
            chunker produced plus an ``'embedding'`` key with that chunk's
            vector. Empty list when the chunker yields no chunks (the
            provider is not called in that case).

        Raises:
            ValueError: If the provider returns a different number of
                embeddings than the number of chunks submitted.
        """
        chunks = self.chunker.chunk_post_content(post_content, post_id, post_title)
        if not chunks:
            return []

        # Each chunk is a mapping that carries its text under 'content'.
        texts = [chunk['content'] for chunk in chunks]

        # One batched provider call for all chunks; materialised so the
        # result can be length-checked whatever sequence type comes back.
        embeddings = list(self.provider.get_embeddings(texts))

        # A bare zip() would silently drop trailing chunks on a short
        # response, so make the mismatch loud instead.
        if len(embeddings) != len(chunks):
            raise ValueError(
                f"Provider returned {len(embeddings)} embeddings "
                f"for {len(chunks)} chunks"
            )

        return [
            {**chunk, 'embedding': embedding}
            for chunk, embedding in zip(chunks, embeddings)
        ]

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of texts.

        Args:
            texts: List of text strings

        Returns:
            List of embedding vectors as returned by the provider. An empty
            input yields an empty list without calling the provider.
        """
        # Mirror chunk_and_embed_post: nothing to embed means no provider call.
        if not texts:
            return []
        return self.provider.get_embeddings(texts)

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings from the current provider."""
        return self.provider.dimension

    def get_provider_info(self) -> Dict[str, str]:
        """
        Get information about the current provider.

        Returns:
            Dictionary with the keys ``'provider'``, ``'model'`` and
            ``'dimension'``; the dimension is converted to a string.
        """
        return {
            'provider': self.provider.provider_name,
            'model': self.provider.model_name,
            'dimension': str(self.provider.dimension),
        }