首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >golang源码分析:langchaingo(6)

golang源码分析:langchaingo(6)

作者头像
golangLeetcode
发布2026-03-18 17:59:15
发布2026-03-18 17:59:15
1170
举报

前面介绍langchaingo都是简单应用没有聊到它的核心处理流程,链式处理,这里还是结合例子详细分析下它的源码:

代码语言:javascript
复制
// 将输入翻译为特定语言
    chain1 := chains.NewLLMChain(llm,
        prompts.NewPromptTemplate(
            "请将输入的原始文本:{{.originText}}翻译为{{.language}},直接输出翻译文本",
            []string{"originText", "language"}))
    chain1.OutputKey = "transText"
    // 总结翻译后的文本概要
    chain2 := chains.NewLLMChain(llm, prompts.NewPromptTemplate(
        "请将输入的原始文本:<{{.transText}}>总结50字以内概要文本。严格使用JSON序列化输出结果,不要带有```json序列化标识。其中originText为原始文本,summaryText为概要文本",
        []string{"transText"}))
    chain2.OutputKey = "summary_json"
    chain, err := chains.NewSequentialChain([]chains.Chain{chain1, chain2}, []string{"originText", "language"}, []string{"summary_json"})
    if err != nil {
        log.Fatal(err)
    }
    resp, err := chain.Call(ctx, map[string]any{
        "originText": "langchain is a good llm frameworks",
        "language":   "中文",
    })

可以看到,先定义了两个chain,然后用NewSequentialChain将它俩组合起来,最后调用Call方法,可以看到,虽然模仿了langchain的链式方案,但是用起来没有python的|符号重载直接简单。这里还是依次介绍下源码。

github.com/tmc/langchaingo@v0.1.13/chains/llm.go里定义了chain中的每个节点

代码语言:javascript
复制
// NewLLMChain creates a new LLMChain with an LLM and a prompt.
func NewLLMChain(llm llms.Model, prompt prompts.FormatPrompter, opts ...ChainCallOption) *LLMChain {
    opt := &chainCallOption{}
    for _, o := range opts {
        o(opt)
    }
    chain := &LLMChain{
        Prompt:           prompt,
        LLM:              llm,
        OutputParser:     outputparser.NewSimple(),
        Memory:           memory.NewSimple(),
        OutputKey:        _llmChainDefaultOutputKey,
        CallbacksHandler: opt.CallbackHandler,
    }
    return chain
}

它包括了提示词、llm、缓存、输出解析器、输出key等内容

代码语言:javascript
复制
type LLMChain struct {
    Prompt           prompts.FormatPrompter
    LLM              llms.Model
    Memory           schema.Memory
    CallbacksHandler callbacks.Handler
    OutputParser     schema.OutputParser[any]
    OutputKey string
}

可以看到,它只有一个输出key,意味着只能有一个输出值。参数里有提示词模板和参数列表

代码语言:javascript
复制
func NewPromptTemplate(template string, inputVars []string) PromptTemplate {
    return PromptTemplate{
        Template:       template,
        InputVariables: inputVars,
        TemplateFormat: TemplateFormatGoTemplate,
    }
}
代码语言:javascript
复制
// PromptTemplate contains common fields for all prompt templates.
type PromptTemplate struct {
    // Template is the prompt template.
    Template string
    // A list of variable names the prompt template expects.
    InputVariables []string
    // TemplateFormat is the format of the prompt template.
    TemplateFormat TemplateFormat
    // OutputParser is a function that parses the output of the prompt template.
    OutputParser schema.OutputParser[any]
    // PartialVariables represents a map of variable names to values or functions
    // that return values. If the value is a function, it will be called when the
    // prompt template is rendered.
    PartialVariables map[string]any
}

可以看到,第二个chain的输入使用了第一个chain的输出,整个链就是这么串起来的。接着看下串起整个链的逻辑

代码语言:javascript
复制
func NewSequentialChain(chains []Chain, inputKeys []string, outputKeys []string, opts ...SequentialChainOption) (*SequentialChain, error) { //nolint:lll
    s := &SequentialChain{
        chains:     chains,
        inputKeys:  inputKeys,
        outputKeys: outputKeys,
        memory:     memory.NewSimple(),
    }
    for _, opt := range opts {
        opt(s)
    }
    if err := s.validateSeqChain(); err != nil {
        return nil, err
    }
    return s, nil
}

入参传入了所有chain的列表,所有输入参数列表和最终输出参数名称。函数内部先把参数存到结构体里,后面校验了下参数和整个链的完整合法性

代码语言:javascript
复制
type SequentialChain struct {
    chains     []Chain
    inputKeys  []string
    outputKeys []string
    memory     schema.Memory
}

首先输入的参数和内存中存在的参数不能有交集,即后面链中加入的参数不能和前面节点的参数重名

代码语言:javascript
复制
func (c *SequentialChain) validateSeqChain() error {
    knownKeys := setutil.ToSet(c.inputKeys)
    // Make sure memory keys don't collide with input keys
    memoryKeys := c.memory.MemoryVariables(context.Background())
    overlappingKeys := setutil.Intersection(memoryKeys, knownKeys)

接着校验链中输入的参数必须是前面链中的输出

代码语言:javascript
复制
    for i, c := range c.chains {
        // Check that chain has input keys that are in knownKeys
        missingKeys := setutil.Difference(c.GetInputKeys(), knownKeys)
        if len(missingKeys) > 0 {
        
    overlappingKeys := setutil.Intersection(c.GetOutputKeys(), knownKeys)
        if len(overlappingKeys) > 0 {

最后校验了输出参数

代码语言:javascript
复制
    // Check that outputKeys are in knownKeys
    for _, key := range c.outputKeys {
        if _, ok := knownKeys[key]; !ok {

准备工作完成后就到了Call函数调用的阶段

代码语言:javascript
复制
func (c *SequentialChain) Call(ctx context.Context, inputs map[string]any, options ...ChainCallOption) (map[string]any, error) { //nolint:lll
    var outputs map[string]any
    var err error
    for _, chain := range c.chains {
        outputs, err = Call(ctx, chain, inputs, options...)
        if err != nil {
            return nil, err
        }
        // Set the input for the next chain to the output of the current chain
        inputs = outputs
    }
    return outputs, nil
}

可以看到,其实就是一个for循环调用每个Chain,将前一个输出作为后一个输入而已,里面调用的Call函数就是前面介绍的单次请求里的call函数。

代码语言:javascript
复制
func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll
    fullValues := make(map[string]any, 0)
    for key, value := range inputValues {
        fullValues[key] = value
    }
    newValues, err := c.GetMemory().LoadMemoryVariables(ctx, inputValues)
    if err != nil {
        return nil, err
    }
    for key, value := range newValues {
        fullValues[key] = value
    }
    callbacksHandler := getChainCallbackHandler(c)
    if callbacksHandler != nil {
        callbacksHandler.HandleChainStart(ctx, inputValues)
    }
    outputValues, err := callChain(ctx, c, fullValues, options...)
    if err != nil {
        if callbacksHandler != nil {
            callbacksHandler.HandleChainError(ctx, err)
        }
        return outputValues, err
    }
    if callbacksHandler != nil {
        callbacksHandler.HandleChainEnd(ctx, outputValues)
    }
    if err = c.GetMemory().SaveContext(ctx, inputValues, outputValues); err != nil {
        return outputValues, err
    }
    return outputValues, nil
}
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-06-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 golang算法架构leetcode技术php 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档