
在完成前面的类容后,这里开始基于langchain实现完整的rag。首先我们补充下向量化的过程,是通过调用llm.CreateEmbedding函数来实现的:
package main
import (
"context"
"fmt"
"log"
"github.com/tmc/langchaingo/llms/ollama"
)
func main() {
llm, err := ollama.New(ollama.WithModel("nomic-embed-text:latest"))
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
inputText := "The sky is blue because of Rayleigh scattering"
result, err := llm.CreateEmbedding(ctx, []string{inputText})
if err != nil {
log.Fatal(err)
}
fmt.Printf("%#v\n", result)
fmt.Printf("%d\n", len(result[0]))
}剩下的就是将前面的内容整合起来,
package rag
import (
"context"
"fmt"
"net/url"
"github.com/tmc/langchaingo/agents"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ollama"
"github.com/tmc/langchaingo/memory"
"github.com/hantmac/langchaingo-ollama-rag/rag/logger"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores"
"github.com/tmc/langchaingo/vectorstores/qdrant"
)
var (
collectionName = "langchaingo-ollama-rag"
qdrantUrl = "http://localhost:6333"
ollamaServer = "http://localhost:11434"
)
// GetOllamaEmbedder 获取ollama嵌入器
func getollamaEmbedder() *embeddings.EmbedderImpl {
// 创建一个新的ollama模型,模型名为"nomic-embed-text:latest"
ollamaEmbedderModel, err := ollama.New(
ollama.WithModel("nomic-embed-text:latest"),
ollama.WithServerURL(ollamaServer))
if err != nil {
logger.Fatal("创建ollama模型失败: %v", err)
}
// 使用创建的ollama模型创建一个新的嵌入器
ollamaEmbedder, err := embeddings.NewEmbedder(ollamaEmbedderModel)
if err != nil {
logger.Fatal("创建ollama嵌入器失败: %v", err)
}
return ollamaEmbedder
}
func getOllamaDeepseek() *ollama.LLM {
// 创建一个新的ollama模型,模型名为"deepseek-r1:1.5b"
llm, err := ollama.New(
ollama.WithModel("deepseek-r1:1.5b"),
ollama.WithServerURL(ollamaServer))
if err != nil {
logger.Fatal("创建ollama模型失败: %v", err)
}
return llm
}
// getStore 获取存储对象
func getStore() *qdrant.Store {
// 解析URL
qdUrl, err := url.Parse(qdrantUrl)
if err != nil {
logger.Fatal("解析URL失败: %v", err)
}
// 创建新的qdrant存储
store, err := qdrant.New(
qdrant.WithURL(*qdUrl), // 设置URL
qdrant.WithAPIKey(""), // 设置API密钥
qdrant.WithCollectionName(collectionName), // 设置集合名称
qdrant.WithEmbedder(getollamaEmbedder()), // 设置嵌入器
)
if err != nil {
logger.Fatal("创建qdrant存储失败: %v", err)
}
return &store
}
// storeDocs 将文档存储到向量数据库
func storeDocs(docs []schema.Document, store *qdrant.Store) error {
// 如果文档数组长度大于0
if len(docs) > 0 {
// 添加文档到存储
fmt.Println("添加文档到存储", docs)
res, err := store.AddDocuments(context.Background(), docs)
fmt.Println(res)
if err != nil {
return err
}
}
return nil
}
// useRetriaver 函数使用检索器
func useRetriaver(store *qdrant.Store, prompt string, topk int) ([]schema.Document, error) {
// 设置选项向量
optionsVector := []vectorstores.Option{
vectorstores.WithScoreThreshold(0.80), // 设置分数阈值
}
// 创建检索器
retriever := vectorstores.ToRetriever(store, topk, optionsVector...)
// 搜索
docRetrieved, err := retriever.GetRelevantDocuments(context.Background(), prompt)
if err != nil {
return nil, fmt.Errorf("检索文档失败: %v", err)
}
fmt.Println(docRetrieved)
// 返回检索到的文档
return docRetrieved, nil
}
// GetAnswer 获取答案
func GetAnswer(ctx context.Context, llm llms.Model, docRetrieved []schema.Document, prompt string) (string, error) {
// 创建一个新的聊天消息历史记录
history := memory.NewChatMessageHistory()
// 将检索到的文档添加到历史记录中
for _, doc := range docRetrieved {
history.AddAIMessage(ctx, doc.PageContent)
}
// 使用历史记录创建一个新的对话缓冲区
conversation := memory.NewConversationBuffer(memory.WithChatHistory(history))
executor := agents.NewExecutor(
agents.NewConversationalAgent(llm, nil),
nil,
agents.WithMemory(conversation),
)
// 设置链调用选项
options := []chains.ChainCallOption{
chains.WithTemperature(0.8),
}
// 运行链
res, err := chains.Run(ctx, executor, prompt, options...)
if err != nil {
return "", err
}
return res, nil
}本文分享自 golang算法架构leetcode技术php 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!