diff --git a/server/core/retriever.go b/server/core/retriever.go index 6094a02..679df0c 100644 --- a/server/core/retriever.go +++ b/server/core/retriever.go @@ -14,6 +14,7 @@ import ( "github.com/qdrant/go-client/qdrant" "github.com/wangle201210/go-rag/server/core/common" "github.com/wangle201210/go-rag/server/core/rerank" + "github.com/wangle201210/go-rag/server/core/retriever" coretypes "github.com/wangle201210/go-rag/server/core/types" ) @@ -114,23 +115,34 @@ func (x *Rag) retrieveDoOnce(ctx context.Context, req *RetrieveReq) (relatedDocs qaDocs []*schema.Document ) g.Log().Infof(ctx, "query: %v", req.optQuery) - // 通过内容检索 + // 向量路一:content_vector docs, err = x.retrieve(ctx, req, false) if err != nil { g.Log().Errorf(ctx, "retrieve failed, err=%v", err) return } - // 通过qa检索 + // 向量路二:qa_content_vector qaDocs, err = x.retrieve(ctx, req, true) if err != nil { g.Log().Errorf(ctx, "qa retrieve failed, err=%v", err) return } - docs = append(docs, qaDocs...) - // 去重 - docs = common.RemoveDuplicates(docs, func(doc *schema.Document) string { - return doc.ID - }) + + inputs := [][]*schema.Document{docs, qaDocs} + + // 关键字路三:ES BM25(conf.Client != nil 时生效) + if x.conf.Client != nil { + bm25Docs, bm25Err := retriever.Bm25Retrieve(ctx, x.conf.Client, x.conf.IndexName, req.KnowledgeName, req.optQuery, req.excludeIDs, esTopK) + if bm25Err != nil { + g.Log().Warningf(ctx, "bm25 retrieve failed, skip: %v", bm25Err) + } else { + inputs = append(inputs, bm25Docs) + } + } + + // RRF 融合多路结果(同时完成去重) + docs = retriever.RRFFusion(inputs) + // 重排 docs, err = rerank.NewRerank(ctx, req.optQuery, docs, req.TopK) if err != nil { diff --git a/server/core/retriever/bm25_retriever.go b/server/core/retriever/bm25_retriever.go new file mode 100644 index 0000000..53b4e71 --- /dev/null +++ b/server/core/retriever/bm25_retriever.go @@ -0,0 +1,64 @@ +package retriever + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/schema" + "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + coretypes "github.com/wangle201210/go-rag/server/core/types" +) + +// Bm25Retrieve 使用 ES BM25 对 content 字段做全文检索。 +// 仅在 conf.ESClient != nil 时调用。 +func Bm25Retrieve(ctx context.Context, client *elasticsearch.Client, indexName, knowledgeName, query string, excludeIDs []string, topK int) ([]*schema.Document, error) { + must := []types.Query{ + {Match: map[string]types.MatchQuery{ + coretypes.FieldContent: {Query: query}, + }}, + {Bool: &types.BoolQuery{ + Must: []types.Query{ + {Match: map[string]types.MatchQuery{ + coretypes.KnowledgeName: {Query: knowledgeName}, + }}, + }, + }}, + } + + boolQuery := &types.BoolQuery{Must: must} + + if len(excludeIDs) > 0 { + boolQuery.MustNot = []types.Query{ + {Terms: &types.TermsQuery{ + TermsQuery: map[string]types.TermsQueryField{ + "_id": excludeIDs, + }, + }}, + } + } + + sreq := search.NewRequest() + sreq.Query = &types.Query{Bool: boolQuery} + sreq.Size = &topK + + resp, err := search.NewSearchFunc(client)(). + Index(indexName). + Request(sreq). + Do(ctx) + if err != nil { + return nil, fmt.Errorf("bm25 search failed: %w", err) + } + + var docs []*schema.Document + for _, hit := range resp.Hits.Hits { + doc, err := EsHit2Document(ctx, hit) + if err != nil { + continue + } + docs = append(docs, doc) + } + + return docs, nil +} diff --git a/server/core/retriever/rrf.go b/server/core/retriever/rrf.go new file mode 100644 index 0000000..b35db63 --- /dev/null +++ b/server/core/retriever/rrf.go @@ -0,0 +1,41 @@ +package retriever + +import ( + "sort" + + "github.com/cloudwego/eino/schema" +) + +// RRFFusion 对多路召回结果执行 RRF(Reciprocal Rank Fusion)融合。 +// inputs 每个元素是一路召回的有序文档列表(按相关性降序)。 +func RRFFusion(inputs [][]*schema.Document) []*schema.Document { + const k = 60 + + docScores := make(map[string]float64) + docMap := make(map[string]*schema.Document) + + for _, docs := range inputs { + for rank, doc := range docs { + if doc.ID == "" { + continue + } + docScores[doc.ID] += 1.0 / float64(k+rank+1) + if _, exists := docMap[doc.ID]; !exists { + docMap[doc.ID] = doc + } + } + } + + result := make([]*schema.Document, 0, len(docMap)) + for id, score := range docScores { + doc := docMap[id] + doc.WithScore(score) + result = append(result, doc) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].Score() > result[j].Score() + }) + + return result +} diff --git a/server/internal/logic/rag/retriever.go b/server/internal/logic/rag/retriever.go index 49a0ca6..49c50e7 100644 --- a/server/internal/logic/rag/retriever.go +++ b/server/internal/logic/rag/retriever.go @@ -55,6 +55,19 @@ func init() { client = esStore.GetClient() } else if qdrantStore, ok := vectorStore.(*vector.QdrantVectorStore); ok { qdrantClient = qdrantStore.GetClient() + // Qdrant 模式下,若配置了 ES 地址则初始化 ES 客户端用于 BM25 混合检索 + esAddr := g.Cfg().MustGet(ctx, "vector.es.address").String() + if esAddr != "" { + client, err = elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{esAddr}, + Username: g.Cfg().MustGet(ctx, "vector.es.username").String(), + Password: g.Cfg().MustGet(ctx, "vector.es.password").String(), + }) + if err != nil { + g.Log().Warningf(ctx, "init ES client for BM25 failed, hybrid retrieval disabled: %v", err) + client = nil + } + } } ragSvr, err = core.New(ctx, &config.Config{