#!/usr/bin/env python3
"""
RAG Service Module for 徵象防伪验证平台

基于 LangChain 的实时知识库检索服务,支持缓存优化
"""

import hashlib
import pickle
import time
from typing import Any, Dict, List, Optional

from bs4 import BeautifulSoup
from django.core.cache import cache
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

from .models import Article


class CachedLangChainRAG:
    """Cache-backed real-time LangChain RAG service.

    Retrieves from the platform knowledge base (``Article`` rows flagged
    ``is_platform_knowledge_base``) via FAISS similarity search, with three
    layers of Django caching:

    * search results — keyed by query hash + parameters, 1 day
    * vector stores  — pickled FAISS index keyed by article signature, 30 days
    * article lists  — keyed by article signature, 30 days

    Cache keys embed an MD5 signature of the articles' content, so editing
    any knowledge-base article rolls the keys and invalidates stale entries
    automatically.
    """

    def __init__(self):
        # Chinese sentence-embedding model, CPU-only; normalized embeddings
        # make FAISS scores behave like cosine similarity.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="shibing624/text2vec-base-chinese",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )

        # Prefer paragraph then sentence boundaries (both CJK and ASCII
        # punctuation) before falling back to hard character cuts.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            separators=["\n\n", "\n", "。", "!", "?", ".", "!", "?", ";", ";"]
        )

        # Cache lifetimes in seconds.
        self.vector_cache_timeout = 86400 * 30  # 30 days
        self.search_cache_timeout = 86400  # 1 day

    def _get_article_content_hash(self, article: Article) -> str:
        """Return the MD5 hex digest of one article's title/body/options."""
        content = f"{article.title or ''}{article.body}{article.options or ''}"
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def _get_articles_signature(self, articles: List[Article]) -> str:
        """Return an order-independent signature for a set of articles.

        Any content change in any article yields a different signature,
        which is what makes the vector-store cache self-invalidating.
        """
        signatures = [
            f"{article.id}_{self._get_article_content_hash(article)}"
            for article in articles
        ]
        # Sort so the result does not depend on queryset ordering.
        signatures.sort()
        combined = '|'.join(signatures)
        return hashlib.md5(combined.encode('utf-8')).hexdigest()

    def search(self, query: str, max_results: int = 3, tenant_id: Optional[int] = None) -> List[Dict[str, Any]]:
        """Search the knowledge base with a 1-day result cache.

        Args:
            query: Free-text query string.
            max_results: Maximum number of hits to return.
            tenant_id: Optional tenant scope; ``None`` searches all tenants.

        Returns:
            List of result dicts (``id``/``title``/``content``/``score``/
            ``tenant_id``/``url``), possibly empty.
        """
        start_time = time.time()
        print(f"\n🔍 RAG搜索开始: {query}")
        print(f"📊 参数: max_results={max_results}, tenant_id={tenant_id}")

        # Key the cache on everything that can change the result.
        query_hash = hashlib.md5(query.encode('utf-8')).hexdigest()
        cache_key = f"rag_search_{query_hash}_{max_results}_{tenant_id}"

        cache_check_start = time.time()
        cached_result = cache.get(cache_key)
        cache_check_time = time.time() - cache_check_start

        # BUGFIX: compare with None — a cached empty list is a valid hit and
        # must not trigger a full re-search on every call.
        if cached_result is not None:
            total_time = time.time() - start_time
            print(f"✅ 缓存命中: {cache_key}")
            print(f"📚 返回缓存结果: {len(cached_result)} 条")
            print(f"⏱️ 缓存检查耗时: {cache_check_time:.3f}秒")
            print(f"⏱️ 总耗时: {total_time:.3f}秒")
            return cached_result

        print(f"❌ 缓存未命中: {cache_key}")
        print(f"⏱️ 缓存检查耗时: {cache_check_time:.3f}秒")
        print("🔄 执行实时搜索...")

        # Cache miss: run the real search and store the outcome (even when
        # empty — negative results are worth caching too).
        search_start_time = time.time()
        results = self._perform_search(query, max_results, tenant_id)
        search_time = time.time() - search_start_time

        cache_start_time = time.time()
        cache.set(cache_key, results, self.search_cache_timeout)
        cache_time = time.time() - cache_start_time

        total_time = time.time() - start_time
        print(f"💾 搜索结果已缓存,过期时间: {self.search_cache_timeout}秒")
        print(f"⏱️ 搜索耗时: {search_time:.3f}秒")
        print(f"⏱️ 缓存写入耗时: {cache_time:.3f}秒")
        print(f"⏱️ 总耗时: {total_time:.3f}秒")

        return results

    def _perform_search(self, query: str, max_results: int, tenant_id: Optional[int] = None) -> List[Dict[str, Any]]:
        """Run the uncached similarity search; returns ``[]`` on any failure."""
        search_start_time = time.time()
        try:
            print(f"🔧 获取向量存储...")
            vector_start_time = time.time()
            vectorstore = self._get_cached_vectorstore(tenant_id)
            vector_time = time.time() - vector_start_time
            print(f"⏱️ 向量存储获取耗时: {vector_time:.3f}秒")

            # Explicit None check: vector-store objects should not be judged
            # by truthiness.
            if vectorstore is None:
                print("❌ 无法获取向量存储")
                return []

            print(f"🎯 执行相似度搜索,查询: '{query}'")
            similarity_start_time = time.time()
            docs = vectorstore.similarity_search_with_score(query, k=max_results)
            similarity_time = time.time() - similarity_start_time
            print(f"⏱️ 相似度搜索耗时: {similarity_time:.3f}秒")
            print(f"📊 找到 {len(docs)} 个候选文档")

            # Flatten (Document, score) pairs into plain dicts for callers.
            format_start_time = time.time()
            results = []
            for i, (doc, score) in enumerate(docs, 1):
                result = {
                    'id': doc.metadata.get('article_id'),
                    'title': doc.metadata.get('title'),
                    'content': doc.page_content,
                    'score': float(score),
                    'tenant_id': doc.metadata.get('tenant_id'),
                    'url': doc.metadata.get('url')
                }
                results.append(result)

                print(f"  📄 结果 {i}:")
                print(f"     ID: {result['id']}")
                print(f"     标题: {result['title']}")
                print(f"     相关度: {result['score']:.4f}")
                print(f"     内容预览: {result['content'][:100]}...")

            format_time = time.time() - format_start_time
            total_search_time = time.time() - search_start_time
            print(f"⏱️ 结果格式化耗时: {format_time:.3f}秒")
            print(f"✅ 搜索完成,返回 {len(results)} 个结果")
            print(f"⏱️ 搜索总耗时: {total_search_time:.3f}秒")
            return results

        except Exception as e:
            # Best-effort service: log and degrade to "no results" rather
            # than propagate an exception into the request cycle.
            total_search_time = time.time() - search_start_time
            print(f"❌ RAG搜索失败: {str(e)}")
            print(f"⏱️ 搜索耗时: {total_search_time:.3f}秒")
            import traceback
            traceback.print_exc()
            return []

    def _get_cached_vectorstore(self, tenant_id: Optional[int] = None):
        """Return a FAISS vector store for the tenant's knowledge base.

        Restores a pickled store from cache when the article signature
        matches; otherwise builds a fresh store and caches it. Returns
        ``None`` when there are no articles or no usable documents.
        """
        vector_start_time = time.time()
        print(f"🔍 获取知识库文章...")

        articles_start_time = time.time()
        articles = self._get_knowledge_base_articles(tenant_id)
        articles_time = time.time() - articles_start_time
        print(f"⏱️ 文章获取耗时: {articles_time:.3f}秒")

        if not articles:
            print("❌ 没有找到知识库文章")
            return None

        print(f"📚 找到 {len(articles)} 篇知识库文章")

        cache_check_start = time.time()
        cache_key = self._get_vector_cache_key(articles, tenant_id)
        print(f"🔑 向量缓存key: {cache_key}")

        cached_vectors = cache.get(cache_key)
        cache_check_time = time.time() - cache_check_start
        print(f"⏱️ 向量缓存检查耗时: {cache_check_time:.3f}秒")

        # BUGFIX: compare with None so a falsy-but-present payload still
        # counts as a cache hit.
        if cached_vectors is not None:
            restore_start_time = time.time()
            print("✅ 向量缓存命中,恢复向量存储")
            result = self._restore_vectorstore(cached_vectors)
            restore_time = time.time() - restore_start_time
            print(f"⏱️ 向量存储恢复耗时: {restore_time:.3f}秒")

            total_time = time.time() - vector_start_time
            print(f"⏱️ 向量存储获取总耗时: {total_time:.3f}秒")
            return result

        print("❌ 向量缓存未命中,创建新的向量存储")
        create_start_time = time.time()
        vectorstore = self._create_vectorstore(articles)
        create_time = time.time() - create_start_time
        print(f"⏱️ 向量存储创建耗时: {create_time:.3f}秒")

        # BUGFIX: only cache a real store — pickling None here would poison
        # the vector cache for the whole 30-day timeout.
        if vectorstore is not None:
            cache_start_time = time.time()
            self._cache_vectorstore(vectorstore, cache_key)
            cache_time = time.time() - cache_start_time
            print(f"⏱️ 向量存储缓存耗时: {cache_time:.3f}秒")

        total_time = time.time() - vector_start_time
        print(f"⏱️ 向量存储获取总耗时: {total_time:.3f}秒")
        return vectorstore

    def _get_knowledge_base_articles(self, tenant_id: Optional[int] = None) -> List[Article]:
        """Return the tenant's knowledge-base articles, cached for 30 days.

        The cache key embeds a content signature, so a changed article rolls
        the key and the stale list simply ages out.
        """
        articles_start_time = time.time()

        cache_key = self._get_articles_cache_key(tenant_id)
        print(f"🔍 检查文章缓存: {cache_key}")

        cache_check_start = time.time()
        cached_articles = cache.get(cache_key)
        cache_check_time = time.time() - cache_check_start

        # BUGFIX: None check — an empty cached list should not force a
        # database query on every call.
        if cached_articles is not None:
            total_time = time.time() - articles_start_time
            print(f"✅ 文章缓存命中,返回 {len(cached_articles)} 篇文章")
            print(f"⏱️ 缓存检查耗时: {cache_check_time:.3f}秒")
            print(f"⏱️ 文章获取总耗时: {total_time:.3f}秒")
            return cached_articles

        print("❌ 文章缓存未命中,查询数据库...")
        print(f"⏱️ 缓存检查耗时: {cache_check_time:.3f}秒")

        db_start_time = time.time()
        filter_kwargs = {'is_platform_knowledge_base': True}
        # BUGFIX: explicit None check so tenant_id=0 still filters.
        if tenant_id is not None:
            filter_kwargs['tenant_id'] = tenant_id

        articles = list(Article.objects.filter(**filter_kwargs))
        db_time = time.time() - db_start_time
        print(f"📊 数据库查询结果: {len(articles)} 篇文章")
        print(f"⏱️ 数据库查询耗时: {db_time:.3f}秒")

        cache_start_time = time.time()
        cache.set(cache_key, articles, self.vector_cache_timeout)
        cache_time = time.time() - cache_start_time
        print(f"💾 文章列表已缓存,过期时间: {self.vector_cache_timeout}秒")
        print(f"⏱️ 文章缓存写入耗时: {cache_time:.3f}秒")

        total_time = time.time() - articles_start_time
        print(f"⏱️ 文章获取总耗时: {total_time:.3f}秒")
        return articles

    def _get_articles_cache_key(self, tenant_id: Optional[int] = None) -> str:
        """Build the article-list cache key from tenant and content signature."""
        filter_kwargs = {'is_platform_knowledge_base': True}
        # BUGFIX: explicit None check so tenant_id=0 still filters.
        if tenant_id is not None:
            filter_kwargs['tenant_id'] = tenant_id

        articles = Article.objects.filter(**filter_kwargs)

        # Signature over current DB content: editing any article changes the
        # key, so a stale cached list is never served.
        content_signature = self._get_articles_content_signature(articles)

        cache_key = f"rag_kb_articles_{tenant_id}_{content_signature}"
        print(f"🔑 生成文章缓存key: {cache_key}")
        print(f"📊 内容签名: {content_signature}")

        return cache_key

    def _get_articles_content_signature(self, articles) -> str:
        """Return an order-independent MD5 signature of the articles' content."""
        signatures = []
        for article in articles:
            # Hash id + title + full body + options so any edit is detected.
            title = article.title or ''
            body = article.body or ''
            options = article.options or ''
            article_signature = f"{article.id}_{title}_{body}_{options}"
            signatures.append(hashlib.md5(article_signature.encode('utf-8')).hexdigest())

        # Sort for queryset-order independence.
        signatures.sort()
        combined = '|'.join(signatures)
        return hashlib.md5(combined.encode('utf-8')).hexdigest()

    def _clean_html_content(self, html_content: str) -> str:
        """Strip HTML down to whitespace-normalized plain text."""
        if not html_content:
            return ""

        soup = BeautifulSoup(html_content, 'html.parser')

        # Drop non-content tags entirely (scripts, styles, metadata).
        for tag in soup(["script", "style", "head", "title", "meta", "link", "noscript"]):
            tag.decompose()

        text = soup.get_text()

        # Collapse blank lines and runs of whitespace into single spaces.
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        text = ' '.join(chunk for chunk in chunks if chunk)

        return text

    def _get_vector_cache_key(self, articles: List[Article], tenant_id: Optional[int] = None) -> str:
        """Build the vector-store cache key from tenant and article signature."""
        signature = self._get_articles_signature(articles)
        cache_key = f"rag_vectors_{tenant_id}_{signature}"
        print(f"🔑 生成向量缓存key: {cache_key}")
        return cache_key

    def _create_vectorstore(self, articles: List[Article]):
        """Build a FAISS store from the articles; ``None`` if no usable text."""
        create_start_time = time.time()
        print(f"🔧 创建文档对象...")

        docs_start_time = time.time()
        documents = self._create_documents(articles)
        docs_time = time.time() - docs_start_time
        print(f"⏱️ 文档创建耗时: {docs_time:.3f}秒")

        if not documents:
            print("❌ 无法创建文档对象")
            return None

        print(f"📄 创建了 {len(documents)} 个文档片段")
        print(f"🎯 开始向量化...")

        embedding_start_time = time.time()
        vectorstore = FAISS.from_documents(
            documents=documents,
            embedding=self.embeddings
        )
        embedding_time = time.time() - embedding_start_time
        print(f"⏱️ 向量化耗时: {embedding_time:.3f}秒")

        total_time = time.time() - create_start_time
        print(f"✅ 向量存储创建完成,包含 {len(documents)} 个向量")
        print(f"⏱️ 向量存储创建总耗时: {total_time:.3f}秒")
        return vectorstore

    def _create_documents(self, articles: List[Article]) -> List[Document]:
        """Clean, split and wrap articles as LangChain ``Document`` chunks.

        Chunks shorter than 50 characters are dropped as noise; a failure on
        one article is logged and skipped so the rest still get indexed.
        """
        docs_start_time = time.time()
        documents = []

        for article in articles:
            try:
                article_start_time = time.time()
                print(f"  📝 处理文章: {article.title or '无标题'} (ID: {article.id})")

                clean_start_time = time.time()
                clean_body = self._clean_html_content(article.body)
                clean_time = time.time() - clean_start_time
                print(f"    🧹 HTML清理耗时: {clean_time:.3f}秒")

                split_start_time = time.time()
                chunks = self.text_splitter.split_text(clean_body)
                split_time = time.time() - split_start_time
                print(f"    ✂️ 分割为 {len(chunks)} 个片段,耗时: {split_time:.3f}秒")

                valid_chunks = 0
                for i, chunk in enumerate(chunks):
                    # Skip fragments too short to carry retrievable meaning.
                    if len(chunk.strip()) < 50:
                        continue

                    doc = Document(
                        page_content=chunk,
                        metadata={
                            'article_id': article.id,
                            'title': article.title or '无标题',
                            'tenant_id': article.tenant_id,
                            'chunk_index': i,
                            'url': article.url,
                            'source': 'platform_knowledge_base',
                            'html_cleaned': True
                        }
                    )
                    documents.append(doc)
                    valid_chunks += 1

                article_time = time.time() - article_start_time
                print(f"    ✅ 有效片段: {valid_chunks}/{len(chunks)},文章处理耗时: {article_time:.3f}秒")

            except Exception as e:
                # One bad article must not abort the whole index build.
                print(f"❌ 处理文章 {article.id} 失败: {str(e)}")
                continue

        total_time = time.time() - docs_start_time
        print(f"📊 总共创建了 {len(documents)} 个有效文档片段")
        print(f"⏱️ 文档创建总耗时: {total_time:.3f}秒")
        return documents

    def _cache_vectorstore(self, vectorstore, cache_key: str):
        """Pickle the vector store into the Django cache (best effort)."""
        try:
            print(f"💾 序列化向量存储...")
            serialize_start_time = time.time()
            # NOTE: pickle is acceptable here only because the payload is our
            # own cache content, never untrusted external input.
            serialized_data = pickle.dumps(vectorstore)
            serialize_time = time.time() - serialize_start_time
            print(f"📦 序列化完成,大小: {len(serialized_data)} 字节,耗时: {serialize_time:.3f}秒")

            cache_start_time = time.time()
            cache.set(cache_key, serialized_data, self.vector_cache_timeout)
            cache_time = time.time() - cache_start_time
            print(f"✅ 向量存储已缓存,过期时间: {self.vector_cache_timeout}秒,缓存写入耗时: {cache_time:.3f}秒")
        except Exception as e:
            # Caching is an optimization; a failure must not break search.
            print(f"❌ 缓存向量存储失败: {str(e)}")
            import traceback
            traceback.print_exc()

    def _restore_vectorstore(self, serialized_data: bytes):
        """Unpickle a cached vector store; ``None`` if deserialization fails."""
        try:
            restore_start_time = time.time()
            # Trusted payload: only ever written by _cache_vectorstore above.
            result = pickle.loads(serialized_data)
            restore_time = time.time() - restore_start_time
            print(f"⏱️ 向量存储反序列化耗时: {restore_time:.3f}秒")
            return result
        except Exception as e:
            print(f"❌ 恢复向量存储失败: {str(e)}")
            import traceback
            traceback.print_exc()
            return None

    def get_enhanced_context(self, query: str, max_results: int = 3, tenant_id: Optional[int] = None) -> str:
        """Return the search hits formatted as an LLM prompt context block."""
        context_start_time = time.time()
        print(f"\n📖 开始构建增强上下文...")

        results = self.search(query, max_results, tenant_id)

        if not results:
            total_time = time.time() - context_start_time
            print(f"⏱️ 上下文构建耗时: {total_time:.3f}秒")
            return "未找到相关知识库内容。"

        # Concatenate numbered hits with title, score and content.
        build_start_time = time.time()
        context = "基于知识库检索到的相关信息:\n\n"
        for i, result in enumerate(results, 1):
            context += f"【{i}】{result['title']}\n"
            context += f"相关度:{result['score']:.3f}\n"
            context += f"内容:{result['content']}\n\n"

        build_time = time.time() - build_start_time
        total_time = time.time() - context_start_time
        print(f"⏱️ 上下文构建耗时: {build_time:.3f}秒")
        print(f"⏱️ 上下文构建总耗时: {total_time:.3f}秒")

        return context

    def clear_old_cache(self, tenant_id: Optional[int] = None):
        """Mark this service's caches invalid (best effort).

        Django's default backend cannot delete by wildcard, so we drop
        per-prefix invalidation markers that ``_is_cache_invalidated``
        consults; actual eviction is lazy.
        """
        try:
            # BUGFIX: the old wildcard patterns ("rag_search_*_{tenant}_*")
            # never matched the key layout actually written, and the reader
            # looked up a differently-spelled marker key. Markers are now
            # keyed by the plain key prefix on both write and read side.
            self._clear_cache_by_pattern("rag_search_")
            self._clear_cache_by_pattern("rag_vectors_")
            print(f"🧹 已清理租户 {tenant_id} 的旧缓存")
        except Exception as e:
            print(f"⚠️ 清理旧缓存失败: {e}")

    def _clear_cache_by_pattern(self, pattern: str):
        """Record an invalidation marker for keys starting with *pattern*."""
        # Django's default cache backend does not support pattern deletion;
        # readers must check this marker via _is_cache_invalidated.
        cache.set(f"rag_cache_invalidated_{pattern}", time.time(), 300)  # 5-minute marker

    def _is_cache_invalidated(self, cache_key: str) -> bool:
        """Return True if *cache_key* falls under an active invalidation marker."""
        for pattern in ["rag_search_", "rag_vectors_", "rag_kb_articles_"]:
            if pattern in cache_key:
                # Marker key mirrors _clear_cache_by_pattern exactly.
                if cache.get(f"rag_cache_invalidated_{pattern}") is not None:
                    return True
        return False