# app/services/provider_factory.py
from typing import Dict, List, Optional

from lib.embedding_providers.base import BaseProvider
from lib.embedding_providers.openai_provider import OpenAIProvider
from lib.embedding_providers.openrouter_provider import OpenRouterProvider
from lib.embedding_providers.claude_provider import ClaudeProvider
from config import Settings


class ProviderFactory:
    """
    Factory class for creating and managing embedding/completion providers.

    All methods are static; the class only namespaces the factory functions
    and the list of provider names they understand.
    """

    # Names accepted by create_provider(); matching is case-insensitive.
    SUPPORTED_PROVIDERS = ["openai", "openrouter", "claude"]

    @staticmethod
    def create_provider(
        provider_name: str,
        settings: "Settings",
        embedding_model: Optional[str] = None,
        completion_model: Optional[str] = None,
    ) -> "BaseProvider":
        """
        Create a provider instance based on the provider name.

        Args:
            provider_name: Name of the provider ('openai', 'openrouter',
                'claude'). Case and surrounding whitespace are ignored.
            settings: Application settings. The attributes read here are
                ``openai_api_key``, ``openrouter_api_key``,
                ``anthropic_api_key``, ``default_embedding_model`` and
                ``default_completion_model``.
            embedding_model: Optional specific embedding model to use. When
                omitted (or empty) the per-provider default below is used.
            completion_model: Optional specific completion model to use. When
                omitted (or empty) the per-provider default below is used.

        Returns:
            Configured provider instance

        Raises:
            ValueError: If provider_name is not a string, the provider is not
                supported, its API key is missing, or the name is listed in
                SUPPORTED_PROVIDERS but has no construction branch here.
        """
        # Reject None / non-string names with the documented ValueError rather
        # than letting ``.lower()`` fail with an AttributeError.
        if not isinstance(provider_name, str):
            raise ValueError(
                f"Unsupported provider: {provider_name!r}. "
                f"Supported providers: {ProviderFactory.SUPPORTED_PROVIDERS}"
            )

        provider_name = provider_name.strip().lower()

        if provider_name not in ProviderFactory.SUPPORTED_PROVIDERS:
            raise ValueError(
                f"Unsupported provider: {provider_name}. "
                f"Supported providers: {ProviderFactory.SUPPORTED_PROVIDERS}"
            )

        # OpenAI Provider
        if provider_name == "openai":
            if not settings.openai_api_key:
                raise ValueError("OpenAI API key not configured")

            return OpenAIProvider(
                api_key=settings.openai_api_key,
                embedding_model=embedding_model or settings.default_embedding_model,
                # Completion default is hard-coded here, not read from settings.
                completion_model=completion_model or "gpt-4o-mini",
            )

        # OpenRouter Provider
        if provider_name == "openrouter":
            if not settings.openrouter_api_key:
                raise ValueError("OpenRouter API key not configured")

            return OpenRouterProvider(
                api_key=settings.openrouter_api_key,
                completion_model=completion_model or "anthropic/claude-3-5-sonnet",
                embedding_model=embedding_model or "openai/text-embedding-3-small",
                # NOTE(review): the dimension is fixed at 1536 even when the
                # caller overrides embedding_model - confirm that is intended.
                embedding_dimension=1536,
            )

        # Claude Provider
        if provider_name == "claude":
            if not settings.anthropic_api_key:
                raise ValueError("Claude (Anthropic) API key not configured")

            return ClaudeProvider(
                claude_api_key=settings.anthropic_api_key,
                openai_api_key=settings.openai_api_key,  # For embeddings fallback
                completion_model=completion_model or settings.default_completion_model,
                embedding_model=embedding_model or settings.default_embedding_model,
            )

        # Only reachable if SUPPORTED_PROVIDERS gains a name without a branch
        # above; fail loudly instead of silently returning None.
        raise ValueError(
            f"Provider '{provider_name}' is listed as supported "
            f"but has no factory implementation"
        )

    @staticmethod
    def get_provider_status(settings: "Settings") -> Dict[str, Dict]:
        """
        Check the availability status of all providers.

        Each name in SUPPORTED_PROVIDERS is instantiated through
        create_provider() with default models. A ValueError raised while
        doing so marks that provider as unavailable; any other exception
        propagates to the caller.

        Args:
            settings: Application settings with API keys

        Returns:
            Dictionary mapping provider names to their status information.
            Available providers carry "available", "provider_name", "model",
            "dimension", "supports_embeddings" and "supports_completions";
            unavailable ones carry "available" and "error".
        """
        status = {}

        for provider_name in ProviderFactory.SUPPORTED_PROVIDERS:
            try:
                # Try to create provider to check if it's properly configured
                provider = ProviderFactory.create_provider(provider_name, settings)
                status[provider_name] = {
                    "available": True,
                    "provider_name": provider.provider_name,
                    "model": provider.model_name,
                    "dimension": provider.dimension,
                    "supports_embeddings": provider.supports_embeddings(),
                    "supports_completions": provider.supports_completions(),
                }
            except ValueError as e:
                status[provider_name] = {
                    "available": False,
                    "error": str(e),
                }

        return status

    @staticmethod
    def get_available_providers(settings: "Settings") -> List[str]:
        """
        Get list of available (properly configured) provider names.

        Args:
            settings: Application settings with API keys

        Returns:
            List of provider names whose get_provider_status() entry is
            marked available, in SUPPORTED_PROVIDERS order.
        """
        status = ProviderFactory.get_provider_status(settings)
        return [
            name for name, info in status.items()
            if info.get("available", False)
        ]