169 lines
5.1 KiB
Python
169 lines
5.1 KiB
Python
"""Security utilities for input validation and filtering."""
|
|
import re
|
|
from typing import Optional
|
|
|
|
|
|
class PromptInjectionFilter:
    """Filter to detect and sanitize potential prompt injection attempts.

    Detection is purely pattern based: every entry of ``DANGEROUS_PATTERNS``
    is compiled once per instance (case-insensitively) and searched for
    anywhere in the input text.
    """

    # Common prompt injection patterns
    DANGEROUS_PATTERNS = [
        # English
        r"ignore\s+(all\s+)?(previous|prior|above)\s+(instructions?|prompts?|commands?)",
        r"disregard\s+(all\s+)?(previous|prior|above)",
        r"forget\s+(all\s+)?(previous|prior|above)",
        r"you\s+are\s+now",
        r"your\s+new\s+(role|instructions?|prompt)",
        r"system\s+prompt",
        r"tell\s+me\s+your\s+(instructions?|prompt|system)",
        r"what\s+(are|is)\s+your\s+(instructions?|prompt|rules)",

        # German
        r"ignoriere\s+(alle\s+)?(vorherigen?|obigen?)\s+(anweisungen?|prompts?|befehle?)",
        r"vergiss\s+(alle\s+)?(vorherigen?|obigen?)",
        r"du\s+bist\s+jetzt",
        r"deine\s+neue\s+(rolle|anweisung)",
        r"system\s*-?\s*prompt",
        r"sage?\s+mir\s+deine\s+(anweisungen?|prompt|regeln)",

        # Chinese (common in jailbreak attempts)
        # NOTE(review): the first entry ("you are") is very broad and already
        # covers the second one ("now you are"); confirm this is intended.
        r"你是",
        r"现在你是",
        r"忽略之前",

        # Role manipulation
        # Educational roles are allowed. The exemption must look past the
        # optional article and any extra whitespace *inside* the lookahead;
        # otherwise the regex engine backtracks around it and still flags
        # "act as a teacher". The trailing optional article keeps the span
        # replaced by sanitize() the same as before for non-exempt roles.
        r"act\s+as\s+(?!\s*(?:a\s+)?(?:teacher|tutor|guide)\b)(?:a\s+)?",
        r"pretend\s+to\s+be",
        r"roleplay\s+as",

        # System commands
        r"<\s*system\s*>",
        r"<\s*admin\s*>",
        # Word boundary so ordinary words ending in "sudo" (e.g. "pseudo")
        # are not treated as a command.
        r"\bsudo\s+",

        # Attempts to break out of context
        r"\[SYSTEM\]",
        r"\[INST\]",
        r"###\s*Instruction",
    ]

    def __init__(self):
        """Initialize the filter with compiled regex patterns."""
        self.patterns = [
            re.compile(pattern, re.IGNORECASE | re.MULTILINE)
            for pattern in self.DANGEROUS_PATTERNS
        ]

    def is_suspicious(self, text: str) -> bool:
        """
        Check if text contains suspicious prompt injection patterns.

        Args:
            text: User input to check

        Returns:
            True if suspicious patterns detected, False otherwise
        """
        return any(pattern.search(text) for pattern in self.patterns)

    def sanitize(self, text: str, replace_with: str = "[FILTERED]") -> str:
        """
        Sanitize text by replacing suspicious patterns.

        Patterns are applied one after another in ``DANGEROUS_PATTERNS``
        order, each on the output of the previous substitution.

        Args:
            text: User input to sanitize
            replace_with: Replacement text for suspicious patterns

        Returns:
            Sanitized text
        """
        sanitized = text
        for pattern in self.patterns:
            sanitized = pattern.sub(replace_with, sanitized)
        return sanitized

    def validate(self, text: str, max_length: int = 2000) -> tuple[bool, Optional[str]]:
        """
        Validate user input for length and injection attempts.

        Checks run in this order and stop at the first failure: length,
        injection patterns, excessive repetition.

        Args:
            text: User input to validate
            max_length: Maximum allowed length

        Returns:
            Tuple of (is_valid, error_message); error_message is None when
            the input is valid
        """
        # Check length
        if len(text) > max_length:
            return False, f"Input too long (max {max_length} characters)"

        # Check for prompt injection
        if self.is_suspicious(text):
            return False, "Input contains suspicious patterns"

        # Check for excessive repetition (potential DoS)
        if self._has_excessive_repetition(text):
            return False, "Input contains excessive repetition"

        return True, None

    def _has_excessive_repetition(
        self, text: str, threshold: int = 50, word_threshold: int = 20
    ) -> bool:
        """
        Check if text has excessive character/word repetition.

        Args:
            text: Text to check
            threshold: Maximum allowed consecutive repetitions of the same
                character (newlines are not counted by this check)
            word_threshold: Maximum allowed consecutive repetitions of the
                same whitespace-separated word

        Returns:
            True if excessive repetition detected
        """
        # Check character repetition: one character followed by at least
        # `threshold` copies of itself, i.e. a run longer than `threshold`.
        run_pattern = r"(.)\1{%d,}" % max(threshold, 0)
        if re.search(run_pattern, text):
            return True

        # Check word repetition with a single pass that tracks the length of
        # the current run of identical consecutive words.
        run_length = 0
        previous = None
        for word in text.split():
            run_length = run_length + 1 if word == previous else 1
            if run_length > word_threshold:
                return True
            previous = word

        return False
|
|
|
|
|
|
def sanitize_for_logging(text: str, max_length: int = 500) -> str:
    """
    Sanitize text for safe logging (redact PII hints, then truncate).

    Redaction runs before truncation so that an email address, phone number
    or card number that straddles the length limit is still recognised and
    replaced instead of being logged in partial form.

    Args:
        text: Text to sanitize
        max_length: Maximum length for logging; longer results are cut to
            this many characters and suffixed with "..."

    Returns:
        Sanitized text safe for logging
    """
    # Only inspect a bounded window: far enough past the cut to see a token
    # that straddles it in full, but never the whole of an oversized input,
    # which keeps the regex work small for untrusted text.
    lookahead = 256
    window = text[: max_length + lookahead]
    dropped_tail = len(text) > len(window)

    # Basic PII patterns (email, phone, credit card)
    # The TLD class is [A-Za-z]; a "|" inside a character class is a literal
    # pipe, not an alternation.
    window = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '[EMAIL]', window)
    window = re.sub(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', '[PHONE]', window)
    window = re.sub(r'\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b', '[CC]', window)

    # Truncate
    if dropped_tail or len(window) > max_length:
        return window[:max_length] + "..."
    return window
|