
前面介绍 langchaingo 时都是简单应用,没有聊到它的核心处理流程——链式处理。这里结合例子详细分析一下它的源码:
// 将输入翻译为特定语言
chain1 := chains.NewLLMChain(llm,
prompts.NewPromptTemplate(
"请将输入的原始文本:{{.originText}}翻译为{{.language}},直接输出翻译文本",
[]string{"originText", "language"}))
chain1.OutputKey = "transText"
// 总结翻译后的文本概要
chain2 := chains.NewLLMChain(llm, prompts.NewPromptTemplate(
"请将输入的原始文本:<{{.transText}}>总结50字以内概要文本。严格使用JSON序列化输出结果,不要带有```json序列化标识。其中originText为原始文本,summaryText为概要文本",
[]string{"transText"}))
chain2.OutputKey = "summary_json"
chain, err := chains.NewSequentialChain([]chains.Chain{chain1, chain2}, []string{"originText", "language"}, []string{"summary_json"})
if err != nil {
log.Fatal(err)
}
resp, err := chain.Call(ctx, map[string]any{
"originText": "langchain is a good llm frameworks",
"language": "中文",
})
可以看到,这里先定义了两个 chain,然后用 NewSequentialChain 将它俩组合起来,最后调用 Call 方法。虽然 langchaingo 模仿了 langchain 的链式方案,但用起来不如 Python 版通过重载 | 运算符那样直接、简单。下面依次介绍相关源码。
github.com/tmc/langchaingo@v0.1.13/chains/llm.go 里定义了链中的每个节点(LLMChain):
// NewLLMChain creates a new LLMChain with an LLM and a prompt.
func NewLLMChain(llm llms.Model, prompt prompts.FormatPrompter, opts ...ChainCallOption) *LLMChain {
opt := &chainCallOption{}
for _, o := range opts {
o(opt)
}
chain := &LLMChain{
Prompt: prompt,
LLM: llm,
OutputParser: outputparser.NewSimple(),
Memory: memory.NewSimple(),
OutputKey: _llmChainDefaultOutputKey,
CallbacksHandler: opt.CallbackHandler,
}
return chain
}
它包括了提示词(Prompt)、LLM、记忆(Memory)、输出解析器(OutputParser)、输出 key(OutputKey)以及回调处理器(CallbacksHandler)等内容:
type LLMChain struct {
Prompt prompts.FormatPrompter
LLM llms.Model
Memory schema.Memory
CallbacksHandler callbacks.Handler
OutputParser schema.OutputParser[any]
OutputKey string
}
可以看到,它只有一个输出 key,意味着每个 LLMChain 只能有一个输出值。构造提示词时需要传入提示词模板和模板中用到的变量名列表:
// NewPromptTemplate returns a PromptTemplate whose Template and InputVariables
// are taken from the arguments and whose TemplateFormat is set to
// TemplateFormatGoTemplate. No other fields are set by this constructor.
func NewPromptTemplate(template string, inputVars []string) PromptTemplate {
return PromptTemplate{
Template: template,
InputVariables: inputVars,
TemplateFormat: TemplateFormatGoTemplate,
}
}

// PromptTemplate contains common fields for all prompt templates.
type PromptTemplate struct {
// Template is the prompt template.
Template string
// A list of variable names the prompt template expects.
InputVariables []string
// TemplateFormat is the format of the prompt template.
TemplateFormat TemplateFormat
// OutputParser is a function that parses the output of the prompt template.
OutputParser schema.OutputParser[any]
// PartialVariables represents a map of variable names to values or functions
// that return values. If the value is a function, it will be called when the
// prompt template is rendered.
PartialVariables map[string]any
}
回到前面的例子可以看到,第二个 chain 的输入变量 transText 正是第一个 chain 的输出 key(OutputKey),整个链就是这样串起来的。接着看下把整个链串起来的逻辑:
func NewSequentialChain(chains []Chain, inputKeys []string, outputKeys []string, opts ...SequentialChainOption) (*SequentialChain, error) { //nolint:lll
s := &SequentialChain{
chains: chains,
inputKeys: inputKeys,
outputKeys: outputKeys,
memory: memory.NewSimple(),
}
for _, opt := range opts {
opt(s)
}
if err := s.validateSeqChain(); err != nil {
return nil, err
}
return s, nil
}
入参包括所有 chain 的列表、整条链的输入参数名列表和最终输出参数名列表。函数内部先把这些参数存到结构体里,应用可选项之后,再校验参数以及整条链的合法性:
type SequentialChain struct {
chains []Chain
inputKeys []string
outputKeys []string
memory schema.Memory
}
校验时(以下代码为节选),首先要求输入参数和 memory 中已有的变量不能有交集,即 memory 提供的变量名不能和链的输入参数重名:
func (c *SequentialChain) validateSeqChain() error {
knownKeys := setutil.ToSet(c.inputKeys)
// Make sure memory keys don't collide with input keys
memoryKeys := c.memory.MemoryVariables(context.Background())
overlappingKeys := setutil.Intersection(memoryKeys, knownKeys)
接着逐个校验每个 chain:它的输入参数必须已经存在于已知参数中(来自整条链的输入或前面 chain 的输出),并且它的输出参数不能和已知参数重名:
for i, c := range c.chains {
// Check that chain has input keys that are in knownKeys
missingKeys := setutil.Difference(c.GetInputKeys(), knownKeys)
if len(missingKeys) > 0 {
overlappingKeys := setutil.Intersection(c.GetOutputKeys(), knownKeys)
if len(overlappingKeys) > 0 {
最后校验输出参数:整条链声明的输出 key 必须都在已知参数中:
// Check that outputKeys are in knownKeys
for _, key := range c.outputKeys {
if _, ok := knownKeys[key]; !ok {
准备工作完成后,就到了调用 Call 方法的阶段:
func (c *SequentialChain) Call(ctx context.Context, inputs map[string]any, options ...ChainCallOption) (map[string]any, error) { //nolint:lll
var outputs map[string]any
var err error
for _, chain := range c.chains {
outputs, err = Call(ctx, chain, inputs, options...)
if err != nil {
return nil, err
}
// Set the input for the next chain to the output of the current chain
inputs = outputs
}
return outputs, nil
}
可以看到,其实就是用一个 for 循环依次调用每个 Chain,并把前一个 Chain 的输出作为后一个 Chain 的输入而已。循环里调用的包级 Call 函数,就是前面介绍单次请求时用到的那个 Call 函数:
func Call(ctx context.Context, c Chain, inputValues map[string]any, options ...ChainCallOption) (map[string]any, error) { // nolint: lll
fullValues := make(map[string]any, 0)
for key, value := range inputValues {
fullValues[key] = value
}
newValues, err := c.GetMemory().LoadMemoryVariables(ctx, inputValues)
if err != nil {
return nil, err
}
for key, value := range newValues {
fullValues[key] = value
}
callbacksHandler := getChainCallbackHandler(c)
if callbacksHandler != nil {
callbacksHandler.HandleChainStart(ctx, inputValues)
}
outputValues, err := callChain(ctx, c, fullValues, options...)
if err != nil {
if callbacksHandler != nil {
callbacksHandler.HandleChainError(ctx, err)
}
return outputValues, err
}
if callbacksHandler != nil {
callbacksHandler.HandleChainEnd(ctx, outputValues)
}
if err = c.GetMemory().SaveContext(ctx, inputValues, outputValues); err != nil {
return outputValues, err
}
return outputValues, nil
}
本文分享自微信公众号「golang算法架构leetcode技术php」,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!