Computing text-pair relevance with the BGE Reranker model:
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the pretrained model and tokenizer (the reranker model released by BAAI)
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
model.eval()  # switch to inference mode


def calculate_rerank_score(query, documents):
    """Compute relevance scores between a query and a list of documents.

    :param query: query text, e.g. "What is a panda?"
    :param documents: candidate documents, e.g. ["The panda is a bear species", "Penguins live in Antarctica"]
    :return: list of (document, score) pairs sorted by score in descending order
    """
    # Build text pairs in the form [[query, doc1], [query, doc2], ...]
    pairs = [[query, doc] for doc in documents]

    with torch.no_grad():
        # Batch-encode the text pairs
        inputs = tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=512,
        )
        # Forward pass through the model
        outputs = model(**inputs)
        # Convert logits to probabilities in [0, 1]
        scores = torch.sigmoid(outputs.logits).squeeze(-1).tolist()

    # Pair each document with its score and sort by score, descending
    sorted_results = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
    return sorted_results


# Usage example
if __name__ == "__main__":
    query = "What is the capital of France?"
    documents = [
        "Paris is the most populous city in France",
        "Lyon is a major city in eastern France",
        "The Eiffel Tower is located in Paris",
    ]
    results = calculate_rerank_score(query, documents)

    # Print the ranked results
    print("Query:", query)
    for rank, (doc, score) in enumerate(results, 1):
        print(f"Rank {rank} (Score: {score:.4f}): {doc}")
```
- Example output:

```text
Query: What is the capital of France?
Rank 1 (Score: 0.9872): Paris is the most populous city in France
Rank 2 (Score: 0.8531): The Eiffel Tower is located in Paris
Rank 3 (Score: 0.1023): Lyon is a major city in eastern France
```
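For reference, the same model can also be driven through BAAI's own FlagEmbedding wrapper, which hides the tokenization and batching details. Below is a minimal sketch, assuming the `FlagEmbedding` package is installed (`pip install -U FlagEmbedding`); note that its `compute_score` returns raw logits rather than sigmoid probabilities, and the exact API may vary between package versions.

```python
from FlagEmbedding import FlagReranker

# Load the same reranker through BAAI's convenience wrapper
# (use_fp16=True trades a little precision for faster inference on GPU)
reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True)

query = "What is the capital of France?"
documents = [
    "Paris is the most populous city in France",
    "Lyon is a major city in eastern France",
    "The Eiffel Tower is located in Paris",
]

# compute_score takes a list of [query, document] pairs and returns one raw score per pair
scores = reranker.compute_score([[query, doc] for doc in documents])
for doc, score in sorted(zip(documents, scores), key=lambda x: x[1], reverse=True):
    print(f"{score:.4f}  {doc}")
```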
- Key implementation details:
  - Model choice: `BAAI/bge-reranker-large` is trained specifically for query-document relevance and handles Chinese, English, and mixed-language input.
  - Input construction: the query is paired with each document as a (query, doc) pair; this cross-encoding lets the model capture fine-grained semantic interactions between the two texts (see the first sketch below).
  - Score computation: the sigmoid function maps the logits to probabilities in [0, 1]; higher scores indicate stronger relevance, with 0.5 serving as the decision threshold (see the second sketch below).
  - Batched inference: `padding=True` and `return_tensors='pt'` allow all pairs to be scored in a single batched forward pass, improving efficiency (see the third sketch below).
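To make the cross-encoding point concrete, the sketch below tokenizes one (query, document) pair and decodes it back, showing that both texts are packed into a single input sequence separated by special tokens, so every attention layer sees query and document jointly. It reuses the `tokenizer` loaded earlier; the example strings are illustrative only.

```python
# Inspect how a (query, document) pair is encoded for the cross-encoder
pair = [["What is the capital of France?", "Paris is the most populous city in France"]]
encoded = tokenizer(pair, padding=True, truncation=True, return_tensors='pt', max_length=512)

# Decoding the input ids shows query and document joined into one sequence,
# separated by the tokenizer's special tokens; the model attends over both at once,
# which is what enables fine-grained query-document interaction.
print(tokenizer.decode(encoded['input_ids'][0]))
print("Sequence length:", encoded['input_ids'].shape[1])
```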
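On the score computation: sigmoid is a monotonic transformation, so it changes the scale of the scores but not their order; sorting by raw logits gives the same ranking. A small sketch makes this explicit (the logit values are made up for illustration):

```python
import torch

# Hypothetical raw logits from the reranker head (one value per (query, doc) pair)
logits = torch.tensor([4.35, -2.17, 1.76])

# Sigmoid maps each logit into (0, 1); values above 0.5 correspond to positive logits
probs = torch.sigmoid(logits)
print(probs)  # roughly tensor([0.9873, 0.1025, 0.8532])

# Because sigmoid is strictly increasing, ranking by probabilities and by logits is identical
assert torch.argsort(probs, descending=True).tolist() == torch.argsort(logits, descending=True).tolist()
```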
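With `padding=True`, all pairs go through the model in one batch; for very large candidate lists that single batch can exhaust memory, so a common refinement is to score the pairs in fixed-size chunks. Below is a minimal sketch of that variant, reusing the `model` and `tokenizer` from above; the `batch_size` value is an assumption and should be tuned to your hardware.

```python
import torch

def calculate_rerank_score_batched(query, documents, batch_size=16):
    """Same scoring as calculate_rerank_score, but processes pairs in chunks of batch_size."""
    pairs = [[query, doc] for doc in documents]
    scores = []
    with torch.no_grad():
        for start in range(0, len(pairs), batch_size):
            chunk = pairs[start:start + batch_size]
            # Each chunk is padded only to the longest pair within the chunk
            inputs = tokenizer(chunk, padding=True, truncation=True,
                               return_tensors='pt', max_length=512)
            logits = model(**inputs).logits.squeeze(-1)
            scores.extend(torch.sigmoid(logits).tolist())
    return sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
```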