
在自然语言处理领域,大模型凭借海量参数和强大的上下文理解能力,成为文本生成的主流方案,但在低资源语言、文本纠错、输入法预测等场景中,大模型偶尔会出现生成不流畅、乱码、逻辑断裂等问题。而诞生数十年的 N-Gram 统计语言模型,虽简单却能凭借局部上下文统计规律提供稳定的语言约束。将两者结合,以N-Gram做兜底校验,大模型做流畅生成,既能发挥大模型的创造性,又能借助 N-Gram 的统计特性保证文本的通顺性和准确性 。
今天我们由浅入深拆解 N-Gram 与大模型的融合应用,涵盖核心概念、基础原理、执行流程及实际应用价值,彻底理解经典统计模型如何赋能大模型

语言模型(Language Model,简称LM)的核心目标是计算一段文本序列的概率,本质是让机器理解哪些文字组合是符合语言习惯的。比如:
无论是N-Gram还是大模型,本质都是语言模型,只是建模方式不同:
N-Gram是将文本拆分为长度为N的连续词序列,也可按字符拆分,通过统计这些序列的出现频率,计算文本的概率,核心就是从“词的组合”到概率。
N-Gram的核心是马尔可夫假设:一个词的出现概率仅依赖于前 N-1个词。
对于文本序列w1 ,w2 ,...,wT,其概率可分解为:
其中:
N-Gram的概率通过最大似然估计(MLE)计算,核心是“频率代替概率”:
其中:
平滑技术是为了解决“未登录词”问题,直接用MLE计算会遇到“零概率问题”:若某个N-Gram序列从未在语料中出现,其概率为 0,导致整个文本概率为 0。因此需要平滑技术,常见方法:
其中 V 是语料库中唯一词的总数;
大模型本质是基于上下文的条件概率模型,与N-Gram的核心区别是:
尽管大模型能力强大,但在实际应用中存在以下问题:
这些痛点,正是N-Gram能发挥作用的地方,N-Gram虽简单,但能通过局部词频统计保证最基础的语言流畅性,成为大模型的兜底方案。
N-Gram和大模型的融合,核心是分工协作,是互补而非替代:
融合的核心目标:在保留大模型语义能力的前提下,提升文本的局部流畅性和准确性。

2.1 生成后校验模式,适用于文本纠错、输入法预测:
2.2 生成中约束模式,适用于低资源语言生成:

为了让N-Gram更好地适配大模型,需要做以下预处理:
α并非固定值,可根据场景动态调整:

大模型在低资源场景下易生成乱码,如小语种的无效字符,N-Gram可通过以下方式过滤:
输入法预测的核心是“基于用户已输入的字符,预测下一个或多个字符”,融合流程:

低资源语言(如藏语、苗语)的大模型训练数据少,易生成乱码,融合流程:

文本纠错的核心是 “识别并修正错误文本”,融合流程:


import numpy as np
import jieba
from collections import defaultdict, Counter
class NGramModel:
def __init__(self, n=3, smooth_method="add_k", k=0.1):
"""
初始化N-Gram模型
:param n: N-Gram的阶数,默认3
:param smooth_method: 平滑方法,可选add_k/backoff/laplace
:param k: add_k平滑的k值,默认0.1
"""
self.n = n
self.smooth_method = smooth_method
self.k = k
# 存储N-Gram和(N-1)-Gram的计数
self.ngram_counts = defaultdict(Counter)
# 存储所有唯一的词(用于平滑)
self.vocab = set()
# 语料总词数
self.total_words = 0
def preprocess(self, text):
"""
文本预处理:分词、去空格、转小写
:param text: 原始文本
:return: 分词后的列表
"""
# 中文分词,英文可替换为split()
words = jieba.lcut(text.lower().replace(" ", ""))
# 添加起始符(<s>)和结束符(</s>),保证N-Gram的完整性
start_token = "<s>"
end_token = "</s>"
processed_words = [start_token] * (self.n - 1) + words + [end_token]
self.vocab.update(processed_words)
self.total_words += len(processed_words)
return processed_words
def train(self, corpus):
"""
训练N-Gram模型:统计N-Gram和(N-1)-Gram的频率
:param corpus: 语料库(列表,每个元素是一条文本)
"""
for text in corpus:
words = self.preprocess(text)
# 生成N-Gram序列
for i in range(len(words) - self.n + 1):
# 前N-1个词作为上下文
context = tuple(words[i:i+self.n-1])
# 第N个词作为目标词
target = words[i+self.n-1]
# 更新计数
self.ngram_counts[context][target] += 1
def calculate_prob(self, context, target):
"""
计算条件概率 P(target | context)
:param context: 上下文(元组,长度为N-1)
:param target: 目标词
:return: 条件概率
"""
context = tuple(context)
# 获取上下文的总计数
context_count = sum(self.ngram_counts[context].values())
# 不同平滑方法的实现
if self.smooth_method == "add_k":
# Add-k平滑
target_count = self.ngram_counts[context].get(target, 0) + self.k
total_count = context_count + self.k * len(self.vocab)
prob = target_count / total_count
elif self.smooth_method == "laplace":
# Laplace平滑(add-1)
target_count = self.ngram_counts[context].get(target, 0) + 1
total_count = context_count + len(self.vocab)
prob = target_count / total_count
elif self.smooth_method == "backoff":
# 回退平滑:若N-Gram计数为0,退到N-1-Gram
if context_count > 0:
prob = self.ngram_counts[context].get(target, 0) / context_count
else:
# 退到(N-1)-Gram,递归计算
if len(context) == 1:
# 退到1-Gram(Unigram)
prob = (self.ngram_counts[tuple()].get(target, 0) + self.k) / (self.total_words + self.k * len(self.vocab))
else:
prob = self.calculate_prob(context[1:], target)
else:
# 无平滑(MLE)
if context_count == 0:
prob = 0.0
else:
prob = self.ngram_counts[context].get(target, 0) / context_count
return prob
def calculate_sequence_prob(self, sequence):
"""
计算整个序列的概率
:param sequence: 文本序列(列表)
:return: 序列概率(对数概率,避免下溢)
"""
processed_seq = ["<s>"] * (self.n - 1) + sequence + ["</s>"]
log_prob = 0.0
for i in range(len(processed_seq) - self.n + 1):
context = processed_seq[i:i+self.n-1]
target = processed_seq[i+self.n-1]
prob = self.calculate_prob(context, target)
# 取对数,避免乘积下溢
log_prob += np.log(prob + 1e-10) # 加极小值避免log(0)
return np.exp(log_prob) # 转换回原始概率
# ------------------------------
# N-Gram模型测试
# ------------------------------
if __name__ == "__main__":
# 示例语料库
corpus = [
"我想吃苹果",
"我想吃米饭",
"我想喝水",
"他想吃苹果",
"她想吃饭"
]
# 初始化并训练3-Gram模型
ng_model = NGramModel(n=3, smooth_method="add_k", k=0.1)
ng_model.train(corpus)
# 测试1:计算单个条件概率 P(苹果 | 我想)
context = ["我", "想"]
target = "吃"
prob = ng_model.calculate_prob(context, target)
print(f"P({target} | {context}) = {prob:.4f}")
# 测试2:计算序列概率 P(我想吃苹果)
sequence = ["我", "想", "吃", "苹果"]
seq_prob = ng_model.calculate_sequence_prob(sequence)
print(f"P(我想吃苹果) = {seq_prob:.6f}")输出结果:
Building prefix dict from the default dictionary ... Loading model from cache C:\Users\Admin\AppData\Local\Temp\jieba.cache Loading model cost 0.380 seconds. Prefix dict has been built successfully. P(吃 | ['我', '想']) = 0.5122 P(我想吃苹果) = 0.068287
结果解读:
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import torch
class LLMGenerator:
def __init__(self, model_name="qwen/Qwen1.5-0.5B-Chat"):
"""
初始化大模型生成器
:param model_name: 模型名称或本地路径
"""
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
self.model.eval() # 推理模式
def generate_candidates(self, prompt, top_k=5, max_new_tokens=2):
"""
基于输入提示生成Top-k候选词/序列
:param prompt: 输入提示(如"我想吃")
:param top_k: 生成候选数
:param max_new_tokens: 新生成的token数量
:return: 候选序列列表 + 对应的概率
"""
# 编码输入
inputs = self.tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# 确保输入不为空
if input_ids.shape[1] == 0:
raise ValueError(f"输入提示 '{prompt}' 编码后为空,请检查输入内容")
# 使用贪婪搜索获取下一个token的top-k候选
with torch.no_grad():
outputs = self.model(input_ids, attention_mask=attention_mask)
# 获取最后一个位置的logits
next_token_logits = outputs.logits[:, -1, :]
# 获取top-k的token及其概率
next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
topk_probs, topk_ids = torch.topk(next_token_probs, top_k, dim=-1)
candidates = []
candidate_probs = []
for i in range(top_k):
token_id = topk_ids[0, i].item()
token_prob = topk_probs[0, i].item()
# 解码token
token_text = self.tokenizer.decode([token_id], skip_special_tokens=True)
candidate = prompt + token_text
# 过滤不合理的补全
if len(token_text.strip()) == 0:
continue
# 语义合理性检测
is_valid, reason = self._check_semantic_validity(token_text, prompt)
if not is_valid:
# 记录被过滤的原因(调试用)
print(f"[过滤] {candidate} - 原因:{reason}")
continue
candidates.append(candidate)
candidate_probs.append(token_prob)
return candidates, candidate_probs
def _check_semantic_validity(self, token_text, prompt):
"""
检测补全的语义合理性
:param token_text: 补全的token文本
:param prompt: 原始提示
:return: (是否有效, 原因)
"""
# 规则1:补全应该是一个完整的语义单元
# 单独的国家名、城市名等专有名词通常需要后续补全
if prompt.endswith("我想吃"):
# 吃后面应该是食物,而不是地点/国家
# 检查是否为国家名(基于常见国家特征)
if len(token_text) == 2 and any(c in '日美韩法英德意泰越俄印' for c in token_text):
return False, "单独的国家名不符合'我想吃XXX'的食物语境"
# 检查是否为常见的非食物词汇
non_food_words = {'的', '了', '着', '过'}
if token_text in non_food_words:
return False, "助词不构成完整语义"
# 规则2:补全应该提供有意义的语义信息
if len(token_text) > 10:
return False, "过长的补全可能包含不相关内容"
# 规则3:检查是否包含标点符号(通常表示句子结束)
if token_text in {',', '。', '!', '?', '、'}:
return False, "标点符号不构成有意义的补全"
return True, ""
# ------------------------------
# 大模型测试
# ------------------------------
if __name__ == "__main__":
# 使用本地Qwen模型(直接指定目录)
model_path = "D:/modelscope/hub/qwen/Qwen1___5-0___5B-Chat"
llm = LLMGenerator(model_name=model_path)
prompt = "我想吃"
candidates, probs = llm.generate_candidates(prompt, top_k=5)
print("大模型生成候选:")
# 使用对数概率和相对排名显示
log_probs = [np.log(p + 1e-10) for p in probs]
max_log = max(log_probs)
for i, (cand, prob, log_p) in enumerate(zip(candidates, probs, log_probs)):
relative_score = np.exp(log_p - max_log) # 相对于最高概率的分数
print(f"{i+1}. {cand} | 相对分数:{relative_score:.6f}")输出结果:
[过滤] 我想吃日本 - 原因:单独的国家名不符合'我想吃XXX'的食物语境 大模型生成候选: 1. 我想吃辣 | 相对分数:1.000000 2. 我想吃冰淇淋 | 相对分数:0.392157 3. 我想吃巧克力 | 相对分数:0.323529 4. 我想吃肉 | 相对分数:0.268382
class NGramLLMFusion:
def __init__(self, ng_model, llm_model, alpha=0.7):
"""
初始化融合模型
:param ng_model: 训练好的N-Gram模型
:param llm_model: 初始化好的大模型
:param alpha: 权重系数,越大越依赖大模型
"""
self.ng_model = ng_model
self.llm_model = llm_model
self.alpha = alpha
def fuse_prob(self, llm_prob, ng_prob):
"""
融合概率计算
:param llm_prob: 大模型概率
:param ng_prob: N-Gram概率
:return: 融合后概率
"""
return self.alpha * llm_prob + (1 - self.alpha) * ng_prob
def process_candidate(self, prompt, candidate):
"""
处理候选序列:分词后计算N-Gram概率
:param prompt: 输入提示
:param candidate: 大模型生成的候选序列
:return: 候选序列的N-Gram概率
"""
# 拆分prompt和候选的新增部分
if candidate.startswith(prompt):
new_part = candidate[len(prompt):]
else:
new_part = candidate
# 分词:prompt + 新增部分
full_sequence = jieba.lcut((prompt + new_part).replace(" ", ""))
# 计算N-Gram序列概率
ng_prob = self.ng_model.calculate_sequence_prob(full_sequence)
return ng_prob
def generate_final(self, prompt, top_k=5, max_length=20):
"""
生成最终文本:融合N-Gram和大模型的结果
:param prompt: 输入提示
:param top_k: 大模型生成候选数
:param max_length: 生成最大长度
:return: 最终生成的文本 + 融合概率
"""
# 1. 大模型生成候选
candidates, llm_probs = self.llm_model.generate_candidates(prompt, top_k=top_k, max_new_tokens=max_length)
# 2. 计算每个候选的N-Gram概率
ng_probs = []
for cand in candidates:
ng_prob = self.process_candidate(prompt, cand)
ng_probs.append(ng_prob)
# 3. 融合概率
fuse_probs = [self.fuse_prob(lp, np) for lp, np in zip(llm_probs, ng_probs)]
# 4. 选择融合概率最高的候选
best_idx = np.argmax(fuse_probs)
best_candidate = candidates[best_idx]
best_fuse_prob = fuse_probs[best_idx]
return best_candidate, best_fuse_prob, candidates, fuse_probs
# ------------------------------
# 融合模型测试
# ------------------------------
if __name__ == "__main__":
# 1. 训练N-Gram模型
corpus = [
"我想吃苹果", "我想吃米饭", "我想喝水", "他想吃苹果", "她想吃饭",
"我想打游戏", "我想看电影", "我想睡觉", "我想听歌", "我想跑步"
]
ng_model = NGramModel(n=3, smooth_method="add_k", k=0.1)
ng_model.train(corpus)
# 2. 初始化大模型,使用本地Qwen模型(直接指定目录)
model_path = "D:/modelscope/hub/qwen/Qwen1___5-0___5B-Chat"
llm_model = LLMGenerator(model_name=model_path)
# 3. 初始化融合模型
fuse_model = NGramLLMFusion(ng_model, llm_model, alpha=0.7)
# 4. 生成最终文本
prompt = "我想吃"
best_candidate, best_prob, candidates, fuse_probs = fuse_model.generate_final(prompt, top_k=5)
# 输出结果
print("="*50)
print(f"输入提示:{prompt}")
print("="*50)
print("候选序列及融合概率:")
for i, (cand, prob) in enumerate(zip(candidates, fuse_probs)):
print(f"{i+1}. {cand} | 融合概率:{prob:.6f}")
print("="*50)
print(f"最终生成:{best_candidate} | 融合概率:{best_prob:.6f}") 输出结果:
[过滤] 我想吃日本 - 原因:单独的国家名不符合'我想吃XXX'的食物语境 ================================================== 输入提示:我想吃 ================================================== 候选序列及融合概率: 1. 我想吃辣 | 融合概率:0.069774 2. 我想吃冰淇淋 | 融合概率:0.027391 3. 我想吃巧克力 | 融合概率:0.022606 4. 我想吃肉 | 融合概率:0.018760 ================================================== 最终生成:我想吃辣 | 融合概率:0.069774
大模型的优势是创造性,但短板是稳定性;N-Gram的优势是稳定性,基于真实语料的统计规律,短板是 “无语义理解”。
N-Gram 对大模型的核心价值是:用最低的成本提升大模型生成的下限,具体体现在:
总结下来,N-Gram+大模型的融合,核心就是老技术搭台,新技术唱戏,特别好理解。尽管现在大模型如火如荼,但它就不一定比传统统计模型厉害,两者互补才是更具优势,大模型负责天马行空的语义生成,N-Gram 用简单的统计规律兜底,帮着过滤乱码、修正不地道的表达,让最终输出更通顺、更靠谱。
我们可以先从简单的N-Gram代码练起,再尝试对接开源大模型做融合,慢慢调优权重和语料,就能感受到两者结合的魔力。不用追求复杂,把基础打牢,理解互补的核心,比盲目跟风学大模型更有意义。
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import torch
import jieba
from collections import defaultdict, Counter
class NGramModel:
def __init__(self, n=3, smooth_method="add_k", k=0.1):
"""
初始化N-Gram模型
:param n: N-Gram的阶数,默认3
:param smooth_method: 平滑方法,可选add_k/backoff/laplace
:param k: add_k平滑的k值,默认0.1
"""
self.n = n
self.smooth_method = smooth_method
self.k = k
# 存储N-Gram和(N-1)-Gram的计数
self.ngram_counts = defaultdict(Counter)
# 存储所有唯一的词(用于平滑)
self.vocab = set()
# 语料总词数
self.total_words = 0
def preprocess(self, text):
"""
文本预处理:分词、去空格、转小写
:param text: 原始文本
:return: 分词后的列表
"""
# 中文分词,英文可替换为split()
words = jieba.lcut(text.lower().replace(" ", ""))
# 添加起始符(<s>)和结束符(</s>),保证N-Gram的完整性
start_token = "<s>"
end_token = "</s>"
processed_words = [start_token] * (self.n - 1) + words + [end_token]
self.vocab.update(processed_words)
self.total_words += len(processed_words)
return processed_words
def train(self, corpus):
"""
训练N-Gram模型:统计N-Gram和(N-1)-Gram的频率
:param corpus: 语料库(列表,每个元素是一条文本)
"""
for text in corpus:
words = self.preprocess(text)
# 生成N-Gram序列
for i in range(len(words) - self.n + 1):
# 前N-1个词作为上下文
context = tuple(words[i:i+self.n-1])
# 第N个词作为目标词
target = words[i+self.n-1]
# 更新计数
self.ngram_counts[context][target] += 1
def calculate_prob(self, context, target):
"""
计算条件概率 P(target | context)
:param context: 上下文(元组,长度为N-1)
:param target: 目标词
:return: 条件概率
"""
context = tuple(context)
# 获取上下文的总计数
context_count = sum(self.ngram_counts[context].values())
# 不同平滑方法的实现
if self.smooth_method == "add_k":
# Add-k平滑
target_count = self.ngram_counts[context].get(target, 0) + self.k
total_count = context_count + self.k * len(self.vocab)
prob = target_count / total_count
elif self.smooth_method == "laplace":
# Laplace平滑(add-1)
target_count = self.ngram_counts[context].get(target, 0) + 1
total_count = context_count + len(self.vocab)
prob = target_count / total_count
elif self.smooth_method == "backoff":
# 回退平滑:若N-Gram计数为0,退到N-1-Gram
if context_count > 0:
prob = self.ngram_counts[context].get(target, 0) / context_count
else:
# 退到(N-1)-Gram,递归计算
if len(context) == 1:
# 退到1-Gram(Unigram)
prob = (self.ngram_counts[tuple()].get(target, 0) + self.k) / (self.total_words + self.k * len(self.vocab))
else:
prob = self.calculate_prob(context[1:], target)
else:
# 无平滑(MLE)
if context_count == 0:
prob = 0.0
else:
prob = self.ngram_counts[context].get(target, 0) / context_count
return prob
def calculate_sequence_prob(self, sequence):
"""
计算整个序列的概率
:param sequence: 文本序列(列表)
:return: 序列概率(对数概率,避免下溢)
"""
processed_seq = ["<s>"] * (self.n - 1) + sequence + ["</s>"]
log_prob = 0.0
for i in range(len(processed_seq) - self.n + 1):
context = processed_seq[i:i+self.n-1]
target = processed_seq[i+self.n-1]
prob = self.calculate_prob(context, target)
# 取对数,避免乘积下溢
log_prob += np.log(prob + 1e-10) # 加极小值避免log(0)
return np.exp(log_prob) # 转换回原始概率
class LLMGenerator:
def __init__(self, model_name="qwen/Qwen1.5-0.5B-Chat"):
"""
初始化大模型生成器
:param model_name: 模型名称或本地路径
"""
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
self.model.eval() # 推理模式
def generate_candidates(self, prompt, top_k=5, max_new_tokens=2):
"""
基于输入提示生成Top-k候选词/序列
:param prompt: 输入提示(如"我想吃")
:param top_k: 生成候选数
:param max_new_tokens: 新生成的token数量
:return: 候选序列列表 + 对应的概率
"""
# 编码输入
inputs = self.tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# 确保输入不为空
if input_ids.shape[1] == 0:
raise ValueError(f"输入提示 '{prompt}' 编码后为空,请检查输入内容")
# 使用贪婪搜索获取下一个token的top-k候选
with torch.no_grad():
outputs = self.model(input_ids, attention_mask=attention_mask)
# 获取最后一个位置的logits
next_token_logits = outputs.logits[:, -1, :]
# 获取top-k的token及其概率
next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
topk_probs, topk_ids = torch.topk(next_token_probs, top_k, dim=-1)
candidates = []
candidate_probs = []
for i in range(top_k):
token_id = topk_ids[0, i].item()
token_prob = topk_probs[0, i].item()
# 解码token
token_text = self.tokenizer.decode([token_id], skip_special_tokens=True)
candidate = prompt + token_text
# 过滤不合理的补全
if len(token_text.strip()) == 0:
continue
# 语义合理性检测
is_valid, reason = self._check_semantic_validity(token_text, prompt)
if not is_valid:
# 记录被过滤的原因(调试用)
print(f"[过滤] {candidate} - 原因:{reason}")
continue
candidates.append(candidate)
candidate_probs.append(token_prob)
return candidates, candidate_probs
def _check_semantic_validity(self, token_text, prompt):
"""
检测补全的语义合理性
:param token_text: 补全的token文本
:param prompt: 原始提示
:return: (是否有效, 原因)
"""
# 规则1:补全应该是一个完整的语义单元
# 单独的国家名、城市名等专有名词通常需要后续补全
if prompt.endswith("我想吃"):
# 吃后面应该是食物,而不是地点/国家
# 检查是否为国家名(基于常见国家特征)
if len(token_text) == 2 and any(c in '日美韩法英德意泰越俄印' for c in token_text):
return False, "单独的国家名不符合'我想吃XXX'的食物语境"
# 检查是否为常见的非食物词汇
non_food_words = {'的', '了', '着', '过'}
if token_text in non_food_words:
return False, "助词不构成完整语义"
# 规则2:补全应该提供有意义的语义信息
if len(token_text) > 10:
return False, "过长的补全可能包含不相关内容"
# 规则3:检查是否包含标点符号(通常表示句子结束)
if token_text in {',', '。', '!', '?', '、'}:
return False, "标点符号不构成有意义的补全"
return True, ""
class NGramLLMFusion:
def __init__(self, ng_model, llm_model, alpha=0.7):
"""
初始化融合模型
:param ng_model: 训练好的N-Gram模型
:param llm_model: 初始化好的大模型
:param alpha: 权重系数,越大越依赖大模型
"""
self.ng_model = ng_model
self.llm_model = llm_model
self.alpha = alpha
def fuse_prob(self, llm_prob, ng_prob):
"""
融合概率计算
:param llm_prob: 大模型概率
:param ng_prob: N-Gram概率
:return: 融合后概率
"""
return self.alpha * llm_prob + (1 - self.alpha) * ng_prob
def process_candidate(self, prompt, candidate):
"""
处理候选序列:分词后计算N-Gram概率
:param prompt: 输入提示
:param candidate: 大模型生成的候选序列
:return: 候选序列的N-Gram概率
"""
# 拆分prompt和候选的新增部分
if candidate.startswith(prompt):
new_part = candidate[len(prompt):]
else:
new_part = candidate
# 分词:prompt + 新增部分
full_sequence = jieba.lcut((prompt + new_part).replace(" ", ""))
# 计算N-Gram序列概率
ng_prob = self.ng_model.calculate_sequence_prob(full_sequence)
return ng_prob
def generate_final(self, prompt, top_k=5, max_length=20):
"""
生成最终文本:融合N-Gram和大模型的结果
:param prompt: 输入提示
:param top_k: 大模型生成候选数
:param max_length: 生成最大长度
:return: 最终生成的文本 + 融合概率
"""
# 1. 大模型生成候选
candidates, llm_probs = self.llm_model.generate_candidates(prompt, top_k=top_k, max_new_tokens=max_length)
# 2. 计算每个候选的N-Gram概率
ng_probs = []
for cand in candidates:
ng_prob = self.process_candidate(prompt, cand)
ng_probs.append(ng_prob)
# 3. 融合概率
fuse_probs = [self.fuse_prob(lp, np) for lp, np in zip(llm_probs, ng_probs)]
# 4. 选择融合概率最高的候选
best_idx = np.argmax(fuse_probs)
best_candidate = candidates[best_idx]
best_fuse_prob = fuse_probs[best_idx]
return best_candidate, best_fuse_prob, candidates, fuse_probs
# ------------------------------
# 融合模型测试
# ------------------------------
if __name__ == "__main__":
# 1. 训练N-Gram模型
corpus = [
"我想吃苹果", "我想吃米饭", "我想喝水", "他想吃苹果", "她想吃饭",
"我想打游戏", "我想看电影", "我想睡觉", "我想听歌", "我想跑步"
]
ng_model = NGramModel(n=3, smooth_method="add_k", k=0.1)
ng_model.train(corpus)
# 2. 初始化大模型,使用本地Qwen模型(直接指定目录)
model_path = "D:/modelscope/hub/qwen/Qwen1___5-0___5B-Chat"
llm_model = LLMGenerator(model_name=model_path)
# 3. 初始化融合模型
fuse_model = NGramLLMFusion(ng_model, llm_model, alpha=0.7)
# 4. 生成最终文本
prompt = "我想吃"
best_candidate, best_prob, candidates, fuse_probs = fuse_model.generate_final(prompt, top_k=5)
# 输出结果
print("="*50)
print(f"输入提示:{prompt}")
print("="*50)
print("候选序列及融合概率:")
for i, (cand, prob) in enumerate(zip(candidates, fuse_probs)):
print(f"{i+1}. {cand} | 融合概率:{prob:.6f}")
print("="*50)
print(f"最终生成:{best_candidate} | 融合概率:{best_prob:.6f}") 原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。