From ecfbc393f6af34e26743134f45c04cc785c14eb9 Mon Sep 17 00:00:00 2001
From: zhangruijie <2542201615@qq.com>
Date: Fri, 12 Jun 2026 09:34:56 +0800
Subject: [PATCH] =?UTF-8?q?-=20=E6=96=B0=E5=A2=9E=20retriever/bm25=5Fretri?=
 =?UTF-8?q?ever.go=EF=BC=9AES=20match=20query=20=E5=85=A8=E6=96=87?=
 =?UTF-8?q?=E6=A3=80=E7=B4=A2=EF=BC=8C=E5=A4=8D=E7=94=A8=20EsHit2Document?=
 =?UTF-8?q?=20-=20=E6=96=B0=E5=A2=9E=20retriever/rrf.go=EF=BC=9ARRF=20?=
 =?UTF-8?q?=E8=9E=8D=E5=90=88=E7=AE=97=E6=B3=95=EF=BC=8C=E6=9B=BF=E6=8D=A2?=
 =?UTF-8?q?=E5=8E=9F=E5=85=88=E7=9A=84=20RemoveDuplicates=20-=20retrieveDo?=
 =?UTF-8?q?Once=EF=BC=9A=E4=B8=89=E8=B7=AF=E7=BB=93=E6=9E=9C=EF=BC=88conte?=
 =?UTF-8?q?nt=5Fvector=20/=20qa=5Fcontent=5Fvector=20/=20BM25=EF=BC=89?=
 =?UTF-8?q?=E7=94=A8=20RRF=20=E8=9E=8D=E5=90=88=E5=90=8E=E5=86=8D=20rerank?=
 =?UTF-8?q?=20-=20conf.Client=20=E4=B8=BA=E7=A9=BA=E6=97=B6=20BM25=20?=
 =?UTF-8?q?=E8=B7=AF=E8=87=AA=E5=8A=A8=E8=B7=B3=E8=BF=87=EF=BC=8C=E4=B8=8D?=
 =?UTF-8?q?=E5=BD=B1=E5=93=8D=E7=BA=AF=20Qdrant=20=E6=A8=A1=E5=BC=8F?=
 =?UTF-8?q?=E7=9A=84=E7=8E=B0=E6=9C=89=E8=A1=8C=E4=B8=BA=20-=20Qdrant=20?=
 =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E4=B8=8B=E8=8B=A5=E9=85=8D=E7=BD=AE=20vector?=
 =?UTF-8?q?.es.address=20=E5=88=99=E8=87=AA=E5=8A=A8=E5=88=9D=E5=A7=8B?=
 =?UTF-8?q?=E5=8C=96=20ES=20=E5=AE=A2=E6=88=B7=E7=AB=AF=E5=A1=AB=E5=85=A5?=
 =?UTF-8?q?=20conf.Client?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (this section is ignored by git am):
The line structure of the patch was restored; indentation of context lines
assumes gofmt layout. Added lines in the two new files were revised during
review, so the blob ids on their index lines are stale and should be
regenerated before the patch is archived.

 server/core/retriever.go                | 26 +++++--
 server/core/retriever/bm25_retriever.go | 90 +++++++++++++++++++++++++
 server/core/retriever/rrf.go            | 61 +++++++++++++++++
 server/internal/logic/rag/retriever.go  | 13 ++++
 4 files changed, 183 insertions(+), 7 deletions(-)
 create mode 100644 server/core/retriever/bm25_retriever.go
 create mode 100644 server/core/retriever/rrf.go

diff --git a/server/core/retriever.go b/server/core/retriever.go
index 6094a02..679df0c 100644
--- a/server/core/retriever.go
+++ b/server/core/retriever.go
@@ -14,6 +14,7 @@ import (
 	"github.com/qdrant/go-client/qdrant"
 	"github.com/wangle201210/go-rag/server/core/common"
 	"github.com/wangle201210/go-rag/server/core/rerank"
+	"github.com/wangle201210/go-rag/server/core/retriever"
 	coretypes "github.com/wangle201210/go-rag/server/core/types"
 )
 
@@ -114,23 +115,34 @@ func (x *Rag) retrieveDoOnce(ctx context.Context, req *RetrieveReq) (relatedDocs
 		qaDocs []*schema.Document
 	)
 	g.Log().Infof(ctx, "query: %v", req.optQuery)
-	// 通过内容检索
+	// Vector channel 1: content_vector
 	docs, err = x.retrieve(ctx, req, false)
 	if err != nil {
 		g.Log().Errorf(ctx, "retrieve failed, err=%v", err)
 		return
 	}
-	// 通过qa检索
+	// Vector channel 2: qa_content_vector
 	qaDocs, err = x.retrieve(ctx, req, true)
 	if err != nil {
 		g.Log().Errorf(ctx, "qa retrieve failed, err=%v", err)
 		return
 	}
-	docs = append(docs, qaDocs...)
-	// 去重
-	docs = common.RemoveDuplicates(docs, func(doc *schema.Document) string {
-		return doc.ID
-	})
+
+	inputs := [][]*schema.Document{docs, qaDocs}
+
+	// Keyword channel 3: ES BM25 (only active when conf.Client != nil)
+	if x.conf.Client != nil {
+		bm25Docs, bm25Err := retriever.Bm25Retrieve(ctx, x.conf.Client, x.conf.IndexName, req.KnowledgeName, req.optQuery, req.excludeIDs, esTopK)
+		if bm25Err != nil {
+			g.Log().Warningf(ctx, "bm25 retrieve failed, skip: %v", bm25Err)
+		} else {
+			inputs = append(inputs, bm25Docs)
+		}
+	}
+
+	// Fuse all channels with RRF (this also de-duplicates by document ID)
+	docs = retriever.RRFFusion(inputs)
+
 	// 重排
 	docs, err = rerank.NewRerank(ctx, req.optQuery, docs, req.TopK)
 	if err != nil {
diff --git a/server/core/retriever/bm25_retriever.go b/server/core/retriever/bm25_retriever.go
new file mode 100644
index 0000000..53b4e71
--- /dev/null
+++ b/server/core/retriever/bm25_retriever.go
@@ -0,0 +1,90 @@
+package retriever
+
+import (
+	"context"
+	"errors"
+	"fmt"
+
+	"github.com/cloudwego/eino/schema"
+	"github.com/elastic/go-elasticsearch/v8"
+	"github.com/elastic/go-elasticsearch/v8/typedapi/core/search"
+	"github.com/elastic/go-elasticsearch/v8/typedapi/types"
+	coretypes "github.com/wangle201210/go-rag/server/core/types"
+)
+
+// Bm25Retrieve runs a full-text match query against the content field of
+// indexName; Elasticsearch scores the hits (BM25 is its default similarity).
+//
+// Hits are restricted to documents whose knowledge-name field matches
+// knowledgeName, and documents whose _id appears in excludeIDs are left out.
+// At most topK documents are returned, in the order Elasticsearch ranked them.
+//
+// An empty query or a non-positive topK returns no documents without calling
+// Elasticsearch. A nil client is reported as an error; callers are expected to
+// call this only when conf.Client is set. Hits that cannot be converted are
+// skipped, and an error is returned only if every hit fails to convert.
+func Bm25Retrieve(ctx context.Context, client *elasticsearch.Client, indexName, knowledgeName, query string, excludeIDs []string, topK int) ([]*schema.Document, error) {
+	if client == nil {
+		return nil, errors.New("bm25 retrieve: nil elasticsearch client")
+	}
+	if query == "" || topK <= 0 {
+		return nil, nil
+	}
+
+	boolQuery := &types.BoolQuery{
+		Must: []types.Query{
+			{Match: map[string]types.MatchQuery{
+				coretypes.FieldContent: {Query: query},
+			}},
+		},
+		// The knowledge-name restriction only narrows the candidate set, so it
+		// runs in filter context and does not influence the relevance score.
+		Filter: []types.Query{
+			{Match: map[string]types.MatchQuery{
+				coretypes.KnowledgeName: {Query: knowledgeName},
+			}},
+		},
+	}
+
+	if len(excludeIDs) > 0 {
+		boolQuery.MustNot = []types.Query{
+			{Terms: &types.TermsQuery{
+				TermsQuery: map[string]types.TermsQueryField{
+					"_id": excludeIDs,
+				},
+			}},
+		}
+	}
+
+	sreq := search.NewRequest()
+	sreq.Query = &types.Query{Bool: boolQuery}
+	sreq.Size = &topK
+
+	resp, err := search.NewSearchFunc(client)().
+		Index(indexName).
+		Request(sreq).
+		Do(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("bm25 search failed: %w", err)
+	}
+
+	var docs []*schema.Document
+	var firstErr error
+	for _, hit := range resp.Hits.Hits {
+		doc, hitErr := EsHit2Document(ctx, hit)
+		if hitErr != nil {
+			// Best effort: one malformed hit must not discard the others.
+			if firstErr == nil {
+				firstErr = hitErr
+			}
+			continue
+		}
+		docs = append(docs, doc)
+	}
+	// Surface the failure when nothing at all could be converted.
+	if len(docs) == 0 && firstErr != nil {
+		return nil, fmt.Errorf("bm25 convert hits failed: %w", firstErr)
+	}
+
+	return docs, nil
+}
diff --git a/server/core/retriever/rrf.go b/server/core/retriever/rrf.go
new file mode 100644
index 0000000..b35db63
--- /dev/null
+++ b/server/core/retriever/rrf.go
@@ -0,0 +1,61 @@
+package retriever
+
+import (
+	"sort"
+
+	"github.com/cloudwego/eino/schema"
+)
+
+// RRFFusion merges several ranked recall lists with Reciprocal Rank Fusion (RRF).
+//
+// Each element of inputs is one recall channel, ordered from most to least
+// relevant. A document at zero-based position rank in a channel adds
+// 1/(k+rank+1) to its fused score. Documents are identified by ID, so fusing
+// also de-duplicates across channels; nil documents and documents with an
+// empty ID are skipped.
+//
+// For each ID the first instance encountered is returned, and its score is
+// overwritten with the fused score via WithScore. The result is sorted by
+// fused score in descending order; ties keep first-seen order, so the output
+// is deterministic for a given input.
+func RRFFusion(inputs [][]*schema.Document) []*schema.Document {
+	// k damps the advantage of top ranks; 60 is the customary RRF constant.
+	const k = 60
+
+	type fused struct {
+		doc   *schema.Document
+		score float64
+	}
+
+	// index maps a document ID to its slot in merged, which keeps first-seen
+	// order so that ties do not depend on map iteration order.
+	index := make(map[string]int)
+	var merged []fused
+
+	for _, docs := range inputs {
+		for rank, doc := range docs {
+			if doc == nil || doc.ID == "" {
+				continue
+			}
+			gain := 1.0 / float64(k+rank+1)
+			if i, ok := index[doc.ID]; ok {
+				merged[i].score += gain
+				continue
+			}
+			index[doc.ID] = len(merged)
+			merged = append(merged, fused{doc: doc, score: gain})
+		}
+	}
+
+	sort.SliceStable(merged, func(i, j int) bool {
+		return merged[i].score > merged[j].score
+	})
+
+	result := make([]*schema.Document, 0, len(merged))
+	for _, f := range merged {
+		f.doc.WithScore(f.score)
+		result = append(result, f.doc)
+	}
+
+	return result
+}
diff --git a/server/internal/logic/rag/retriever.go b/server/internal/logic/rag/retriever.go
index 49a0ca6..49c50e7 100644
--- a/server/internal/logic/rag/retriever.go
+++ b/server/internal/logic/rag/retriever.go
@@ -55,6 +55,19 @@ func init() {
 		client = esStore.GetClient()
 	} else if qdrantStore, ok := vectorStore.(*vector.QdrantVectorStore); ok {
 		qdrantClient = qdrantStore.GetClient()
+		// In Qdrant mode, if an ES address is configured, initialise an ES client for BM25 hybrid retrieval.
+		esAddr := g.Cfg().MustGet(ctx, "vector.es.address").String()
+		if esAddr != "" {
+			client, err = elasticsearch.NewClient(elasticsearch.Config{
+				Addresses: []string{esAddr},
+				Username:  g.Cfg().MustGet(ctx, "vector.es.username").String(),
+				Password:  g.Cfg().MustGet(ctx, "vector.es.password").String(),
+			})
+			if err != nil {
+				g.Log().Warningf(ctx, "init ES client for BM25 failed, hybrid retrieval disabled: %v", err)
+				client = nil
+			}
+		}
 	}
 
 	ragSvr, err = core.New(ctx, &config.Config{