Skip to content

Commit 4d5cd91

Browse files
authored
Add multi recall of semantic search for pipelines (#3864)
* Add multi recall of semantic search for pipelines * Update multi recall semantic search README.md * remove unused imports * remove unused imports * Update __init__.py * remove unused imports * restore __init__.py * skip retriever __init__.py
1 parent a9f815e commit 4d5cd91

File tree

12 files changed

+1126
-171
lines changed

12 files changed

+1126
-171
lines changed
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# 端到端两路召回语义检索系统
2+
3+
## 1. 概述
4+
5+
多路召回是指采用不同的策略、特征或者简单的模型,分别召回一部分候选集合,然后把这些候选集混合在一起供后续的排序模型进行重排,也可以定制自己的重排序的规则等等。本项目使用关键字和语义检索两路召回的检索系统,系统的架构如下,用户输入的Query会分别通过关键字召回BMRetriever(Okapi BM 25算法,Elasticsearch默认使用的相关度评分算法,是基于词频和文档频率和文档长度相关性来计算相关度),语义向量检索召回DenseRetriever(使用RocketQA抽取向量,然后比较向量之间相似度)后得到候选集,然后通过JoinResults进行结果聚合,最后通过通用的Ranker模块得到重排序的结果返回给用户。
6+
7+
<div align="center">
8+
<img src="https://user-images.githubusercontent.com/12107462/204423532-90f62781-5f81-4b6d-9f94-741416ae3fcb.png" width="500px">
9+
</div>
10+
11+
## 2. 产品功能介绍
12+
13+
本项目提供了低成本搭建端到端两路召回语义检索系统的能力。用户只需要处理好自己的业务数据,就可以使用本项目预置的两路召回语义检索系统模型(召回模型、排序模型)快速搭建一个针对自己业务数据的检索系统,并可以提供 Web 化产品服务。
14+
15+
<div align="center">
16+
<img src="https://user-images.githubusercontent.com/12107462/204435911-0ba1cb9f-cb56-4bcd-9f64-63ff173826d6.png" width="500px">
17+
</div>
18+
19+
## 3. 快速开始: 快速搭建两路召回语义检索系统
20+
21+
### 3.1 运行环境和安装说明
22+
23+
本实验采用了以下的运行环境进行,详细说明如下,用户也可以在自己 GPU 硬件环境进行:
24+
25+
a. 软件环境:
26+
- python >= 3.7.0
27+
- paddlenlp >= 2.4.3
28+
- paddlepaddle-gpu >=2.3
29+
- CUDA Version: 10.2
30+
- NVIDIA Driver Version: 440.64.00
31+
- Ubuntu 16.04.6 LTS (Docker)
32+
33+
b. 硬件环境:
34+
35+
- NVIDIA Tesla V100 16GB x4卡
36+
- Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
37+
38+
c. 依赖安装:
39+
首先需要安装PaddlePaddle,PaddlePaddle的安装请参考文档[官方安装文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html),然后安装下面的依赖:
40+
```bash
41+
# pip 一键安装
42+
pip install --upgrade paddle-pipelines -i https://pypi.tuna.tsinghua.edu.cn/simple
43+
# 或者源码进行安装最新版本
44+
cd ${HOME}/PaddleNLP/pipelines/
45+
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
46+
python setup.py install
47+
```
48+
49+
【注意】
50+
51+
- Windows的安装复杂一点,教程请参考:[Windows视频安装教程](https://www.bilibili.com/video/BV1DY4y1M7HE/?zw)
52+
- 以下的所有的流程都只需要在`pipelines`根目录下进行,不需要跳转目录
53+
54+
### 3.2 数据说明
55+
56+
语义检索数据库的数据来自于[DuReader-Robust数据集](https://github.com/baidu/DuReader/tree/master/DuReader-Robust),共包含 46972 个段落文本,并选取了其中验证集1417条段落文本来搭建语义检索系统。
57+
58+
### 3.3 一键体验语义检索系统
59+
60+
#### 3.3.1 启动 ANN 服务
61+
1. 参考官方文档下载安装 [elasticsearch-8.3.2](https://www.elastic.co/cn/downloads/elasticsearch) 并解压。
62+
2. 启动 ES 服务
63+
首先修改`config/elasticsearch.yml`的配置:
64+
```
65+
xpack.security.enabled: false
66+
```
67+
然后启动:
68+
```bash
69+
./bin/elasticsearch
70+
```
71+
3. 检查确保 ES 服务启动成功
72+
```bash
73+
curl http://localhost:9200/_aliases?pretty=true
74+
```
75+
备注:ES 服务默认开启端口为 9200
76+
77+
#### 3.3.2 快速一键启动
78+
79+
我们预置了基于[DuReader-Robust数据集](https://github.com/baidu/DuReader/tree/master/DuReader-Robust)搭建语义检索系统的代码示例,您可以通过如下命令快速体验语义检索系统的效果
80+
```bash
81+
# 我们建议在 GPU 环境下运行本示例,运行速度较快
82+
# 设置 1 个空闲的 GPU 卡,此处假设 0 卡为空闲 GPU
83+
export CUDA_VISIBLE_DEVICES=0
84+
python examples/semantic-search/multi_recall_semantic_search_example.py --device gpu \
85+
--search_engine elastic
86+
# 如果只有 CPU 机器,可以通过 --device 参数指定 cpu 即可, 运行耗时较长
87+
unset CUDA_VISIBLE_DEVICES
88+
python examples/semantic-search/multi_recall_semantic_search_example.py --device cpu \
89+
--search_engine elastic
90+
```
91+
`multi_recall_semantic_search_example.py``DensePassageRetriever``ErnieRanker`的模型介绍请参考[API介绍](../../API.md)
92+
93+
参数含义说明
94+
* `device`: 设备名称,cpu/gpu,默认为gpu
95+
* `index_name`: 索引的名称
96+
* `search_engine`: 选择的近似索引引擎elastic,milvus,默认elastic
97+
* `max_seq_len_query`: query的最大长度,默认是64
98+
* `max_seq_len_passage`: passage的最大长度,默认是384
99+
* `retriever_batch_size`: 召回模型一次处理的数据的数量
100+
* `query_embedding_model`: query模型的名称,默认为rocketqa-zh-nano-query-encoder
101+
* `passage_embedding_model`: 段落模型的名称,默认为rocketqa-zh-nano-para-encoder
102+
* `params_path`: Neural Search的召回模型的名称,默认为
103+
* `embedding_dim`: 模型抽取的向量的维度,默认为312,为rocketqa-zh-nano-query-encoder的向量维度
104+
* `host`: ANN索引引擎的IP地址
105+
* `port`: ANN索引引擎的端口号
106+
* `bm_topk`: 关键字召回节点BM25Retriever的召回数量
107+
* `dense_topk`: 语义向量召回节点DensePassageRetriever的召回数量
108+
* `rank_topk`: 排序模型节点ErnieRanker的排序过滤数量
109+
110+
### 3.4 构建 Web 可视化语义检索系统
111+
112+
整个 Web 可视化语义检索系统主要包含 3 大组件: 1. 基于 ElasticSearch 的 ANN 服务 2. 基于 RestAPI 构建模型服务 3. 基于 Streamlit 构建 WebUI,搭建ANN服务请参考1.3.1节,接下来我们依次搭建后台和前端两个服务。
113+
114+
#### 3.4.1 文档数据写入 ANN 索引库
115+
```
116+
# 以DuReader-Robust 数据集为例建立 ANN 索引库
117+
python utils/offline_ann.py --index_name dureader_nano_query_encoder \
118+
--doc_dir data/dureader_dev \
119+
--search_engine elastic \
120+
--delete_index
121+
```
122+
可以使用下面的命令来查看数据:
123+
124+
```
125+
# 打印几条数据
126+
curl http://localhost:9200/dureader_nano_query_encoder/_search
127+
```
128+
129+
参数含义说明
130+
* `index_name`: 索引的名称
131+
* `doc_dir`: txt文本数据的路径
132+
* `host`: ANN索引引擎的IP地址
133+
* `port`: ANN索引引擎的端口号
134+
* `search_engine`: 选择的近似索引引擎elastic,milvus,默认elastic
135+
* `delete_index`: 是否删除现有的索引和数据,用于清空es的数据,默认为false
136+
137+
#### 3.4.2 启动 RestAPI 模型服务
138+
```bash
139+
# 指定语义检索系统的Yaml配置文件
140+
export PIPELINE_YAML_PATH=rest_api/pipeline/multi_recall_semantic_search.yaml
141+
# 使用端口号 8891 启动模型服务
142+
python rest_api/application.py 8891
143+
```
144+
启动后可以使用curl命令验证是否成功运行:
145+
146+
```
147+
curl -X POST -k http://localhost:8891/query -H 'Content-Type: application/json' -d '{"query": "衡量酒水的价格的因素有哪些?","params": {"BMRetriever": {"top_k": 10}, "DenseRetriever": {"top_k": 10}, "Ranker":{"top_k": 3}}}'
148+
```
149+
#### 3.4.3 启动 WebUI
150+
```bash
151+
# 配置模型服务地址
152+
export API_ENDPOINT=http://127.0.0.1:8891
153+
# 在指定端口 8502 启动 WebUI
154+
python -m streamlit run ui/webapp_multi_recall_semantic_search.py --server.port 8502
155+
```
156+
157+
到这里您就可以打开浏览器访问 http://127.0.0.1:8502 地址体验语义检索系统服务了。
158+
159+
#### 3.4.4 数据更新
160+
161+
数据更新的方法有两种,第一种使用前面的 `utils/offline_ann.py`进行数据更新,第二种是使用前端界面的文件上传(在界面的左侧)进行数据更新。对于第一种使用脚本的方式,可以使用多种文件更新数据,示例的文件更新建索引的命令如下,里面包含了图片(目前仅支持把图中所有的文字合并建立索引),docx(支持图文,需要按照空行进行划分段落),txt(需要按照空行划分段落)三种格式的文件建索引:
162+
163+
```
164+
python utils/offline_ann.py --index_name dureader_robust_query_encoder \
165+
--doc_dir data/file_example \
166+
--port 9200 \
167+
--search_engine elastic \
168+
--delete_index
169+
```
170+
171+
对于第二种使用界面的方式,支持txt,pdf,image,word的格式,以txt格式的文件为例,每段文本需要使用空行隔开,程序会根据空行进行分段建立索引,示例数据如下(demo.txt):
172+
173+
```
174+
兴证策略认为,最恐慌的时候已经过去,未来一个月市场迎来阶段性修复窗口。
175+
176+
从海外市场表现看,
177+
对俄乌冲突的恐慌情绪已显著释放,
178+
海外权益市场也从单边下跌转入双向波动。
179+
180+
长期,继续聚焦科技创新的五大方向。1)新能源(新能源汽车、光伏、风电、特高压等),2)新一代信息通信技术(人工智能、大数据、云计算、5G等),3)高端制造(智能数控机床、机器人、先进轨交装备等),4)生物医药(创新药、CXO、医疗器械和诊断设备等),5)军工(导弹设备、军工电子元器件、空间站、航天飞机等)。
181+
```
182+
如果安装遇见问题可以查看[FAQ文档](../../FAQ.md)
183+
184+
## Reference
185+
[1]Y. Sun et al., “[ERNIE 3.0: Large-scale Knowledge Enhanced Pre-training for Language Understanding and Generation](https://arxiv.org/pdf/2107.02137.pdf),” arXiv:2107.02137 [cs], Jul. 2021, Accessed: Jan. 17, 2022. [Online]. Available: http://arxiv.org/abs/2107.02137
186+
187+
[2]Y. Qu et al., “[RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2010.08191),” arXiv:2010.08191 [cs], May 2021, Accessed: Aug. 16, 2021. [Online]. Available: http://arxiv.org/abs/2010.08191
188+
189+
[3]H. Tang, H. Li, J. Liu, Y. Hong, H. Wu, and H. Wang, “[DuReader_robust: A Chinese Dataset Towards Evaluating Robustness and Generalization of Machine Reading Comprehension in Real-World Applications](https://arxiv.org/pdf/2004.11142.pdf).” arXiv, Jul. 21, 2021. Accessed: May 15, 2022. [Online]. Available: http://arxiv.org/abs/2004.11142
190+
191+
## Acknowledge
192+
193+
我们借鉴了 Deepset.ai [Haystack](https://github.com/deepset-ai/haystack) 优秀的框架设计,在此对[Haystack](https://github.com/deepset-ai/haystack)作者及其开源社区表示感谢。
194+
195+
We learn form the excellent framework design of Deepset.ai [Haystack](https://github.com/deepset-ai/haystack), and we would like to express our thanks to the authors of Haystack and their open source community.

pipelines/examples/semantic-search/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
车头如何放置车牌 后牌照怎么装
1414
```
1515

16-
语义检索系统的关键就在于,采用语义而非关键词方式进行召回,达到更精准、更广泛得召回相似结果的目的。
16+
语义检索系统的关键就在于,采用语义而非关键词方式进行召回,达到更精准、更广泛得召回相似结果的目的。如果需要关键字和语义检索两种结合方式请参考文档[多路召回](./Multi_Recall.md)
1717

1818
## 2. 产品功能介绍
1919

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
17+
from pipelines.document_stores import ElasticsearchDocumentStore
18+
from pipelines.nodes import (
19+
BM25Retriever,
20+
DensePassageRetriever,
21+
ErnieRanker,
22+
JoinDocuments,
23+
)
24+
from pipelines.pipelines import Pipeline
25+
from pipelines.utils import (
26+
convert_files_to_dicts,
27+
fetch_archive_from_http,
28+
print_documents,
29+
)
30+
31+
# yapf: disable
32+
parser = argparse.ArgumentParser()
33+
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.")
34+
parser.add_argument("--index_name", default='dureader_nano_query_encoder', type=str, help="The ann index name of ANN.")
35+
parser.add_argument("--search_engine", choices=['elastic'], default="elastic", help="The type of ANN search engine.")
36+
parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.")
37+
parser.add_argument("--max_seq_len_passage", default=384, type=int, help="The maximum total length of passage after tokenization.")
38+
parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.")
39+
parser.add_argument("--query_embedding_model", default="rocketqa-zh-nano-query-encoder", type=str, help="The query_embedding_model path")
40+
parser.add_argument("--passage_embedding_model", default="rocketqa-zh-nano-para-encoder", type=str, help="The passage_embedding_model path")
41+
parser.add_argument("--params_path", default="", type=str, help="The checkpoint path")
42+
parser.add_argument("--embedding_dim", default=312, type=int, help="The embedding_dim of index")
43+
parser.add_argument('--host', type=str, default="localhost", help='host ip of ANN search engine')
44+
parser.add_argument('--port', type=str, default="9200", help='port of ANN search engine')
45+
parser.add_argument("--bm_topk", default=10, type=int, help="The number of candidates for BM25Retriever to retrieve.")
46+
parser.add_argument("--dense_topk", default=10, type=int, help="The number of candidates for DensePassageRetriever to retrieve.")
47+
parser.add_argument("--rank_topk", default=10, type=int, help="The number of candidates ranker to filter.")
48+
49+
args = parser.parse_args()
50+
# yapf: enable
51+
52+
53+
def get_retrievers(use_gpu):
54+
55+
doc_dir = "data/dureader_dev"
56+
dureader_data = "https://paddlenlp.bj.bcebos.com/applications/dureader_dev.zip"
57+
58+
fetch_archive_from_http(url=dureader_data, output_dir=doc_dir)
59+
dicts = convert_files_to_dicts(dir_path=doc_dir, split_paragraphs=True, encoding="utf-8")
60+
61+
document_store_with_docs = ElasticsearchDocumentStore(
62+
host=args.host,
63+
port=args.port,
64+
username="",
65+
password="",
66+
embedding_dim=312,
67+
search_fields=["content", "meta"],
68+
index=args.index_name,
69+
)
70+
document_store_with_docs.write_documents(dicts)
71+
72+
dpr_retriever = DensePassageRetriever(
73+
document_store=document_store_with_docs,
74+
query_embedding_model=args.query_embedding_model,
75+
passage_embedding_model=args.passage_embedding_model,
76+
params_path=args.params_path,
77+
output_emb_size=args.embedding_dim,
78+
max_seq_len_query=args.max_seq_len_query,
79+
max_seq_len_passage=args.max_seq_len_passage,
80+
batch_size=args.retriever_batch_size,
81+
use_gpu=use_gpu,
82+
embed_title=False,
83+
)
84+
# update Embedding
85+
document_store_with_docs.update_embeddings(dpr_retriever)
86+
87+
bm_retriever = BM25Retriever(document_store=document_store_with_docs)
88+
89+
return dpr_retriever, bm_retriever
90+
91+
92+
def semantic_search_tutorial():
93+
94+
use_gpu = True if args.device == "gpu" else False
95+
96+
dpr_retriever, bm_retriever = get_retrievers(use_gpu)
97+
98+
# Ranker
99+
ranker = ErnieRanker(model_name_or_path="rocketqa-nano-cross-encoder", use_gpu=use_gpu)
100+
101+
# Pipeline
102+
pipeline = Pipeline()
103+
pipeline.add_node(component=bm_retriever, name="BMRetriever", inputs=["Query"])
104+
pipeline.add_node(component=dpr_retriever, name="DenseRetriever", inputs=["Query"])
105+
pipeline.add_node(
106+
component=JoinDocuments(join_mode="concatenate"), name="JoinResults", inputs=["BMRetriever", "DenseRetriever"]
107+
)
108+
pipeline.add_node(component=ranker, name="Ranker", inputs=["JoinResults"])
109+
# Keywords recall results
110+
prediction = pipeline.run(
111+
query="广播权", params={"BMRetriever": {"top_k": 10}, "DenseRetriever": {"top_k": 10}, "Ranker": {"top_k": 3}}
112+
)
113+
print_documents(prediction)
114+
# Dense vector recall results
115+
prediction = pipeline.run(
116+
query="期货交易手续费指的是什么?",
117+
params={"BMRetriever": {"top_k": 10}, "DenseRetriever": {"top_k": 10}, "Ranker": {"top_k": 3}},
118+
)
119+
pipeline.draw("multi_recall.png")
120+
print_documents(prediction)
121+
122+
123+
if __name__ == "__main__":
124+
semantic_search_tutorial()

pipelines/pipelines/document_stores/__init__.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,33 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
16-
import os
15+
# flake8: noqa
1716
import importlib
17+
import os
18+
19+
from pipelines.document_stores.base import (
20+
BaseDocumentStore,
21+
BaseKnowledgeGraph,
22+
KeywordDocumentStore,
23+
)
1824
from pipelines.utils.import_utils import safe_import
19-
from pipelines.document_stores.base import BaseDocumentStore, BaseKnowledgeGraph, KeywordDocumentStore
2025

2126
ElasticsearchDocumentStore = safe_import(
22-
"pipelines.document_stores.elasticsearch", "ElasticsearchDocumentStore",
23-
"elasticsearch")
27+
"pipelines.document_stores.elasticsearch", "ElasticsearchDocumentStore", "elasticsearch"
28+
)
2429
OpenDistroElasticsearchDocumentStore = safe_import(
25-
"pipelines.document_stores.elasticsearch",
26-
"OpenDistroElasticsearchDocumentStore", "elasticsearch")
27-
OpenSearchDocumentStore = safe_import("pipelines.document_stores.elasticsearch",
28-
"OpenSearchDocumentStore",
29-
"elasticsearch")
30+
"pipelines.document_stores.elasticsearch", "OpenDistroElasticsearchDocumentStore", "elasticsearch"
31+
)
32+
OpenSearchDocumentStore = safe_import(
33+
"pipelines.document_stores.elasticsearch", "OpenSearchDocumentStore", "elasticsearch"
34+
)
3035

31-
FAISSDocumentStore = safe_import("pipelines.document_stores.faiss",
32-
"FAISSDocumentStore", "faiss")
36+
FAISSDocumentStore = safe_import("pipelines.document_stores.faiss", "FAISSDocumentStore", "faiss")
3337

34-
MilvusDocumentStore = safe_import("pipelines.document_stores.milvus2",
35-
"Milvus2DocumentStore", "milvus")
38+
MilvusDocumentStore = safe_import("pipelines.document_stores.milvus2", "Milvus2DocumentStore", "milvus")
3639

3740
from pipelines.document_stores.utils import (
41+
es_index_to_document_store,
3842
eval_data_from_json,
3943
eval_data_from_jsonl,
40-
es_index_to_document_store,
4144
)

0 commit comments

Comments
 (0)