首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >golang实现完整的rag

golang实现完整的rag

作者头像
golangLeetcode
发布2026-03-18 16:57:59
发布2026-03-18 16:57:59
1350
举报

在完成前面的类容后,这里开始基于langchain实现完整的rag。首先我们补充下向量化的过程,是通过调用llm.CreateEmbedding函数来实现的:

代码语言:javascript
复制
package main
import (
        "context"
        "fmt"
        "log"
        "github.com/tmc/langchaingo/llms/ollama"
)
func main() {
        llm, err := ollama.New(ollama.WithModel("nomic-embed-text:latest"))
        if err != nil {
                log.Fatal(err)
        }
        ctx := context.Background()
        inputText := "The sky is blue because of Rayleigh scattering"
        result, err := llm.CreateEmbedding(ctx, []string{inputText})
        if err != nil {
                log.Fatal(err)
        }
        fmt.Printf("%#v\n", result)
        fmt.Printf("%d\n", len(result[0]))
}

剩下的就是将前面的内容整合起来,

代码语言:javascript
复制
package rag
import (
    "context"
    "fmt"
    "net/url"
    "github.com/tmc/langchaingo/agents"
    "github.com/tmc/langchaingo/chains"
    "github.com/tmc/langchaingo/embeddings"
    "github.com/tmc/langchaingo/llms"
    "github.com/tmc/langchaingo/llms/ollama"
    "github.com/tmc/langchaingo/memory"
    "github.com/hantmac/langchaingo-ollama-rag/rag/logger"
    "github.com/tmc/langchaingo/schema"
    "github.com/tmc/langchaingo/vectorstores"
    "github.com/tmc/langchaingo/vectorstores/qdrant"
)
var (
    collectionName = "langchaingo-ollama-rag"
    qdrantUrl      = "http://localhost:6333"
    ollamaServer   = "http://localhost:11434"
)
// GetOllamaEmbedder 获取ollama嵌入器
func getollamaEmbedder() *embeddings.EmbedderImpl {
    // 创建一个新的ollama模型,模型名为"nomic-embed-text:latest"
    ollamaEmbedderModel, err := ollama.New(
        ollama.WithModel("nomic-embed-text:latest"),
        ollama.WithServerURL(ollamaServer))
    if err != nil {
        logger.Fatal("创建ollama模型失败: %v", err)
    }
    // 使用创建的ollama模型创建一个新的嵌入器
    ollamaEmbedder, err := embeddings.NewEmbedder(ollamaEmbedderModel)
    if err != nil {
        logger.Fatal("创建ollama嵌入器失败: %v", err)
    }
    return ollamaEmbedder
}

func getOllamaDeepseek() *ollama.LLM {
    // 创建一个新的ollama模型,模型名为"deepseek-r1:1.5b"
    llm, err := ollama.New(
        ollama.WithModel("deepseek-r1:1.5b"), 
        ollama.WithServerURL(ollamaServer))
    if err != nil {
        logger.Fatal("创建ollama模型失败: %v", err)
    }
    return llm
}
// getStore 获取存储对象
func getStore() *qdrant.Store {
    // 解析URL
    qdUrl, err := url.Parse(qdrantUrl)
    if err != nil {
        logger.Fatal("解析URL失败: %v", err)
    }
    // 创建新的qdrant存储
    store, err := qdrant.New(
        qdrant.WithURL(*qdUrl),                    // 设置URL
        qdrant.WithAPIKey(""),                     // 设置API密钥
        qdrant.WithCollectionName(collectionName), // 设置集合名称
        qdrant.WithEmbedder(getollamaEmbedder()),  // 设置嵌入器
    )
    if err != nil {
        logger.Fatal("创建qdrant存储失败: %v", err)
    }
    return &store
}
// storeDocs 将文档存储到向量数据库
func storeDocs(docs []schema.Document, store *qdrant.Store) error {
    // 如果文档数组长度大于0
    if len(docs) > 0 {
        // 添加文档到存储
        fmt.Println("添加文档到存储", docs)
        res, err := store.AddDocuments(context.Background(), docs)
        fmt.Println(res)
        if err != nil {
            return err
        }
    }
    return nil
}
// useRetriaver 函数使用检索器
func useRetriaver(store *qdrant.Store, prompt string, topk int) ([]schema.Document, error) {
    // 设置选项向量
    optionsVector := []vectorstores.Option{
        vectorstores.WithScoreThreshold(0.80), // 设置分数阈值
    }
    // 创建检索器
    retriever := vectorstores.ToRetriever(store, topk, optionsVector...)
    // 搜索
    docRetrieved, err := retriever.GetRelevantDocuments(context.Background(), prompt)
    if err != nil {
        return nil, fmt.Errorf("检索文档失败: %v", err)
    }
    fmt.Println(docRetrieved)
    // 返回检索到的文档
    return docRetrieved, nil
}
// GetAnswer 获取答案
func GetAnswer(ctx context.Context, llm llms.Model, docRetrieved []schema.Document, prompt string) (string, error) {
    // 创建一个新的聊天消息历史记录
    history := memory.NewChatMessageHistory()
    // 将检索到的文档添加到历史记录中
    for _, doc := range docRetrieved {
        history.AddAIMessage(ctx, doc.PageContent)
    }
    // 使用历史记录创建一个新的对话缓冲区
    conversation := memory.NewConversationBuffer(memory.WithChatHistory(history))
    executor := agents.NewExecutor(
        agents.NewConversationalAgent(llm, nil),
        nil,
        agents.WithMemory(conversation),
    )
    // 设置链调用选项
    options := []chains.ChainCallOption{
        chains.WithTemperature(0.8),
    }
    // 运行链
    res, err := chains.Run(ctx, executor, prompt, options...)
    if err != nil {
        return "", err
    }
    return res, nil
}
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-04-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 golang算法架构leetcode技术php 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档