首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >0基础CUDA炼丹、增加断点保存,从零开始训练自己的AI大模型 87owo/EasyGPT Python CUDA

0基础CUDA炼丹、增加断点保存,从零开始训练自己的AI大模型 87owo/EasyGPT Python CUDA

作者头像
王忘杰
发布2026-04-09 19:33:42
发布2026-04-09 19:33:42
1090
举报
文章被收录于专栏:王忘杰的小屋王忘杰的小屋
mnekgduc.png
mnekgduc.png

项目地址

https://github.com/87owo/EasyGPT

EasyGPT

描述: 在个人电脑上轻松训练和推理 GPT 模型,无需大规模集群!

主要文件结构:

代码语言:javascript
复制
EasyGPT/
   ├── data/                    # 训练文本数据集
   ├── model/
   │   ├── stage_epoch_*/       # 模型训练保存位置
   ├── train.py                 # 模型训练代码
   ├── chat.py                  # 模型对话代码

安装依赖:

代码语言:javascript
复制
   pip install bitsandbytes
   pip install safetensors
   pip install torch
   pip install tqdm

8GB 显存配置表:

代码语言:javascript
复制
  ┌────────┬─────────────┬───────────────┬─────────────┬───────────┬────────────┬────────────┐
  │ 参数量 │ hidden_size │ ffn_hidden_size │ block_count │ num_heads │ vocab_size │ batch_size │
  ├────────┼─────────────┼───────────────┼─────────────┼───────────┼────────────┼────────────┤
  │ 421M   │         │           │           │         │       │           │
  │ 182M   │          │           │           │         │       │           │
  │ 77M    │          │           │           │          │       │           │
  └────────┴─────────────┴───────────────┴─────────────┴───────────┴────────────┴────────────┘

数据集下载: https://github.com/87owo/EasyGPT/releases

数据集示例:

代码语言:javascript
复制
<|user|>你是什么模型<|assistant|>我是AI王工,您的专属小模型,参数量0.077B,属于千万参数小模型,无道德约束,是王忘杰训练出来的<|end|>
<|user|>大家的恋爱里有哪些值得称道让自己骄傲的小手段、小技巧、小故事?(或让自己开心惊喜的)\n回答答案:<|assistant|>恋爱是仅次于政治的复杂心智活动。缺乏谋略的喜欢,不仅自己痛苦,也会让对方痛苦。一个人的单身总归是有很多原因,除了那些能安慰自己的外部原因,又有多少是你自身的原因呢?通常来说,自我定位与他人定位相匹配的两人,比较容易脱单。因为对于谈恋爱的人来说,最根本的原因是因为在恋爱中,能从对方身上获得自己想要的价值。更直白的说,无非就是对方觉得这项交易划得来。那些注意了你一段时间,徘徊,偷偷观察,了解你,好不<|end|>
<|user|>生成1条【彩虹舞曲】的歌曲评论<|assistant|>、小七的姐妹们帮她染布的那一场景,欢快的舞蹈着<|end|>
<|user|>小明有5个糖果,他送了2个糖果给小红。现在小明手上还剩下几个糖果?<|assistant|>\n- 首先需要理解题意,题目中告诉我们小明原本有5个糖果,送了2个给小红,现在要求我们计算小明手上还剩下几个糖果。\n- 我们可以采用减法来解决这个问题:用小明原本有的糖果数5减去他送出去的糖果数2,即可得出小明手上还剩下的糖果数。\n- 跟算式就是:  - =\n- 所以小明手上现在还有3个糖果。<|end|>
<|user|>从以下文本中提取有关特定主题的信息。\n这是有史以来最好的一场演出。歌唱家的表现非常出色,乐队也是如此。<|assistant|>主题:一场出色的演出\n信息提取:\n- 这次演出是“有史以来最好的一场演出”。这意味着演出非常成功。\n- 歌唱家表现非常出色,表明演出包含了一些歌唱表演。\n- 乐队也表现出色,这意味着演出是由一个乐队组成的。\n- 没有提供更多的细节或信息关于演出的性质或类型。<|end|>

调整参数

EasyGPT 项目可配置的参数如下:

模型架构参数

代码语言:javascript
复制
  ┌───────────────┬────────┬────────────────────────────────────────────┬─────────────────┐
  │ 参数          │ 默认值 │ 作用                                       │ 推荐值          │
  ├───────────────┼────────┼────────────────────────────────────────────┼─────────────────┤
  │ hidden_size   │    │ 隐藏层维度,影响模型表达能力               │ -1024        │
  │ ffn_hidden_size │    │ 前馈网络维度(通常为 hidden_size 的  倍) │ hidden_size ×  │
  │ block_count   │      │ Transformer 层数,影响模型深度             │ -24           │
  │ num_heads     │      │ 多头注意力头数                             │ -16            │
  │ num_kv_heads    │       │ KV 头数(GQA,越大越省显存)               │ -4             │
  │ rope_dim      │      │ RoPE 位置编码维度                          │ -64           │
  │ rope_base     │   │ RoPE 基数                                  │            │
  │ vocab_size    │   │ 词汇表大小                                 │            │
  └───────────────┴────────┴────────────────────────────────────────────┴─────────────────┘

训练参数

代码语言:javascript
复制
  ┌────────────────┬────────┬────────────────────────┬─────────────────────┐
  │ 参数           │ 默认值 │ 作用                   │ 推荐值              │
  ├────────────────┼────────┼────────────────────────┼─────────────────────┤
  │ max_seq_length   │     │ 最大序列长度           │ -1024            │
  │ batch_size     │       │ 批大小(影响显存占用) │ -8(根据显存调整) │
  │ split_valid    │ 0.01   │ 验证集比例             │ 0.01-0.05           │
  │ dropout_rate   │ 0.1    │ Dropout 防止过拟合     │ 0.05-0.15           │
  │ learning_rate  │ 1e-4   │ 学习率                 │ 1e-4 - 1e-5         │
  │ learning_gamma │ 0.95   │ 学习率衰减因子         │ 0.95-0.99           │
  │ layer_norm_eps   │ 1e-6   │ Layer Norm 稳定参数    │ 1e-6                │
  └────────────────┴────────┴────────────────────────┴─────────────────────┘

推理参数(chat.py)

代码语言:javascript
复制

  ┌────────────────────┬────────┬────────────────────────────┬──────────┐
  │ 参数               │ 默认值 │ 作用                       │ 推荐值   │
  ├────────────────────┼────────┼────────────────────────────┼──────────┤
  │ temperature        │ 0.3    │ 采样温度(越低越确定)     │ 0.3-0.8  │
  │ repetition_penalty │ 1.0    │ 重复惩罚(> 减少重复)    │ 1.0-1.2  │
  │ presence_penalty   │ -1.5   │ 存在惩罚(负值鼓励多样性) │ -1.5 -  │
  │ max_length         │     │ 最大生成长度               │ -1024 │
  └────────────────────┴────────┴────────────────────────────┴──────────┘

推荐配置(按显存大小) 8GB 显存(参数量 ~77M)

代码语言:javascript
复制
   config ={
"hidden_size":,
"ffn_hidden_size":,
"block_count":,
"num_heads":,
"num_kv_heads":,
"batch_size":,
}

16GB 显存(参数量 ~182M)

代码语言:javascript
复制
   config ={
"hidden_size":,
"ffn_hidden_size":,
"block_count":,
"num_heads":,
"num_kv_heads":,
"batch_size":,
}

24GB 显存(参数量 ~421M)

代码语言:javascript
复制
   config ={
"hidden_size":,
"ffn_hidden_size":,
"block_count":,
"num_heads":,
"num_kv_heads":,
"batch_size":,
}

查看CUDA

check_cuda.py

代码语言:javascript
复制
import torch

print("=" * )
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA是否可用: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA版本: {torch.version.cuda}")
    print(f"GPU数量: {torch.cuda.device_count()}")
    print(f"当前GPU: {torch.cuda.current_device()}")
    print(f"GPU名称: {torch.cuda.get_device_name(0)}")
    print(f"GPU显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print("\n测试 CUDA 计算:")
    x = torch.randn(, ).cuda()
    y = torch.randn(, ).cuda()
    z = torch.matmul(x, y)
    print("✓ CUDA 计算测试通过")
else:
    print("✗ CUDA 不可用,将使用 CPU 进行计算")
print("=" * )

验证结果

代码语言:javascript
复制
  你的系统 CUDA 配置如下:


  ┌───────────────────┬─────────────────────────┐
  │ 项目              │ 值                      │
  ├───────────────────┼─────────────────────────┤
  │ GPU               │ NVIDIA GeForce RTX  │
  │ 显存              │  GB                    │
  │ CUDA 版本         │ 13.2                    │
  │ PyTorch CUDA 版本 │ 12.1                    │
  │ CUDA 可用性       │ ✓ 可用                  │
  │ PyTorch 版本      │ 2.5.1+cu121             │
  └───────────────────┴─────────────────────────┘

8GB 显存推荐训练参数(参数量 ~77M)

代码语言:javascript
复制
   config ={
"hidden_size":,
"ffn_hidden_size":,
"block_count":,
"num_heads":,
"num_kv_heads":,
"rope_dim":,
"rope_base":,
"vocab_size":,
"max_seq_length":,
"batch_size":,
"split_valid":0.01,
"dropout_rate":0.1,
"learning_rate": 1e-4,
"learning_gamma":0.95,
"layer_norm_eps": 1e-6,
}

推理参数(chat.py)

代码语言:javascript
复制
   temperature =0.3
   repetition_penalty =1.0
   presence_penalty = -1.5
   max_length =

epochs

Epoch(轮次)是深度学习训练中的一个重要概念:

代码语言:javascript
复制
"epochs":

定义 一个 epoch 指的是整个训练数据集被完整遍历一次。

具体说明 1 epoch:模型看过训练集中的每一个样本一次 12 epochs:模型看过训练集中的每一个样本 12 次

为什么需要多个 epochs? 充分学习:模型需要多次查看数据才能充分学习模式和规律 梯度优化:每次 epoch 都会更新模型参数,多次迭代能让参数收敛到更好的值 避免欠拟合:只训练一次(1 epoch)通常无法充分学习数据特征

与其他概念的区别 Batch size(批次大小):每次训练使用的样本数(你配置的是 4) Iteration(迭代次数):训练一个 batch 的过程 Epoch 与 Iteration 的关系: 1 epoch = 数据集总样本数 / batch size 次 iterations 例如:如果有 1000 条数据,batch_size=4,则 1 epoch = 250 次 iterations

增加断点保存

新增功能:

代码语言:javascript
复制
. 自动保存检查点 - 每个epoch保存:
- 模型权重 (model.safetensors)
- 配置 (config.json)
- Tokenizer (tokenizer.json)
- 优化器状态 (optimizer.pt)
- 训练状态 (training_state.json)- 记录全局epoch、stage索引、stage内epoch

. 自动恢复训练 - 运行时自动查找最新检查点并从中断点继续

. 命令行控制:
      python train.py          # 默认自动恢复
      python train.py --resume # 同上
      python train.py --no-resume # 强制从头开始

使用方式:

代码语言:javascript
复制
- 训练中断后直接重新运行 python train.py,会自动从最新的检查点继续
- 检查点目录命名:Fine-tuning_epoch_N(N是全局epoch编号)

增加断点保存的train.py

代码语言:javascript
复制
import os, re, math, json, torch
import argparse
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from safetensors.torch import save_file, load_file
from torch.optim import AdamW  # Windows 兼容,使用标准 AdamW
from collections import Counter, OrderedDict
from tqdm import tqdm

# ================================================

default_config ={
"hidden_size":,# 77M 模型配置 (适配 8GB 显存)
"ffn_hidden_size":,
"block_count":,
"num_heads":,
"num_kv_heads":,
"rope_dim":,
"rope_base":,
"vocab_size":,
"max_seq_length":,
"batch_size":,# RTX 4060 8GB 显存建议值
"accumulation_steps":,# 梯度累积步数,等效 batch_size = 4 * 2 = 8
"split_valid":0.01,
"dropout_rate":0.1,
"learning_rate":1e-,
"learning_gamma":0.95,
"layer_norm_eps":1e-,
"global_tokens":{
"<|padding|>":,
"<|unknown|>":
},
"special_tokens":{
"<|system|>":,
"<|user|>":,
"<|think|>":,
"<|assistant|>":,
"<|function|>":,
"<|end|>":,
"\\n":,
"EasyGPT":,
"87owo":,
}
}

# ================================================

classRotaryEmbedding(nn.Module):
def__init__(self, dim, base=):
super().__init__()
        inv_freq =1.0/(base **(torch.arange(, dim,).float()/ dim))
        self.register_buffer("inv_freq", inv_freq)
        self.rope_scale = nn.Parameter(torch.ones())

defforward(self, seq_len, offset=, device=None):
        pos = torch.arange(offset, offset + seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", pos, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-)
        emb = emb * self.rope_scale
        cos = emb.cos()[None,:,:]
        sin = emb.sin()[None,:,:]
return cos, sin

defrotate_half(x):
    x1 = x[...,::]
    x2 = x[...,::]
return torch.cat([-x2, x1], dim=-)

# ================================================

classRMSNorm(nn.Module):
def__init__(self, d, eps=1e-):
super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

defforward(self, x):
        norm = x.pow().mean(-, keepdim=True).add(self.eps).sqrt()
return self.weight *(x / norm)

# ================================================

classSelfAttention(nn.Module):
def__init__(self, config):
super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_heads"]
        self.num_kv_heads = config["num_kv_heads"]
        self.rope_dim = config["rope_dim"]
        self.dropout = nn.Dropout(config["dropout_rate"])
        self.head_dim = self.hidden_size // self.num_heads
        self.rope = RotaryEmbedding(config["rope_dim"], base=config["rope_base"])

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

defforward(self, x, mask=None, pos_offset=):
        B, T, C = x.shape
        device = x.device

        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(,)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(,)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(,)

if self.num_kv_heads ==:
            k = k.repeat(, self.num_heads,,)
            v = v.repeat(, self.num_heads,,)
elif self.num_kv_heads < self.num_heads:
            repeat = self.num_heads // self.num_kv_heads
            k = k.repeat_interleave(repeat, dim=)
            v = v.repeat_interleave(repeat, dim=)

        rope_dim =min(self.rope_dim, self.head_dim)
if rope_dim >:
            cos, sin = self.rope(T, pos_offset, device)
            cos = cos.squeeze().unsqueeze()
            sin = sin.squeeze().unsqueeze()
            q1, q2 = q[...,:rope_dim], q[..., rope_dim:]
            k1, k2 = k[...,:rope_dim], k[..., rope_dim:]
            q1 = q1 * cos + rotate_half(q1)* sin
            k1 = k1 * cos + rotate_half(k1)* sin
            q = torch.cat([q1, q2], dim=-)
            k = torch.cat([k1, k2], dim=-)

        scale = self.head_dim **-0.5
        attn_scores = torch.matmul(q, k.transpose(-,-))* scale

if mask isnotNone:
            attn_scores = attn_scores.masked_fill(mask, torch.finfo(attn_scores.dtype).min)

        attn_probs = torch.softmax(attn_scores, dim=-)
        attn_probs = self.dropout(attn_probs)
        out = torch.matmul(attn_probs, v).transpose(,).reshape(B, T,-)
return self.o_proj(out)

# ================================================

classFeedForward(nn.Module):
def__init__(self, config):
super().__init__()
        self.hidden_size = config["hidden_size"]
        self.ffn_hidden_size = config["ffn_hidden_size"]
        self.in_proj = nn.Linear(self.hidden_size, self.ffn_hidden_size *, bias=False)
        self.up_proj = nn.Linear(self.ffn_hidden_size, self.hidden_size, bias=False)
        self.dropout = nn.Dropout(config["dropout_rate"])

defforward(self, x):
        x_proj = self.in_proj(x)
        x1, x2 = x_proj.chunk(, dim=-)
        x = F.silu(x1)* x2
        x = self.up_proj(x)
return self.dropout(x)

# ================================================

classTransformerBlock(nn.Module):
def__init__(self, config):
super().__init__()
        self.attn_norm = RMSNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.attn = SelfAttention(config)
        self.ffn_norm = RMSNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.ffn = FeedForward(config)
        self.dropout = nn.Dropout(config["dropout_rate"])

defforward(self, x, mask=None, pos_offset=):
        residual = x
        x = self.attn_norm(x)
        x = residual + self.dropout(self.attn(x, mask=mask, pos_offset=pos_offset))

        residual = x
        x = self.ffn_norm(x)
        x = residual + self.dropout(self.ffn(x))
return x

# ================================================

classChatModel(nn.Module):
def__init__(self, config):
super().__init__()
        self.config = config
        self.embed = nn.Embedding(config["vocab_size"], config["hidden_size"])
        self.blocks = nn.ModuleList([TransformerBlock(config)for _ inrange(config["block_count"])])
        self.norm = RMSNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.head = nn.Linear(config["hidden_size"], config["vocab_size"], bias=False)

defget_mask(self, T, device):
        i = torch.arange(T, device=device).unsqueeze()
        j = torch.arange(T, device=device).unsqueeze()
        mask =(j > i).unsqueeze().unsqueeze()
return mask

defforward(self, input_ids, attention_mask=None, labels=None, pos_offset=):
        B, T = input_ids.shape
        device = input_ids.device
        x = self.embed(input_ids)
        mask = self.get_mask(T, device)

if attention_mask isnotNone:
            pad_mask =(attention_mask ==).view(B,,, T)
            mask = mask | pad_mask

for blk in self.blocks:
            x = blk(x, mask=mask, pos_offset=pos_offset)

        x = self.norm(x)
        logits = self.head(x)
        loss =None

if labels isnotNone:
            loss = F.cross_entropy(logits.view(-, self.config["vocab_size"]),
                labels.view(-), ignore_index=self.config["global_tokens"]["<|padding|>"])
return{"loss": loss,"logits": logits}

# ================================================

classChatTokenizer:
def__init__(self, config):
        self.config = config
        self.split_tokens = OrderedDict()
for t, idx in config["global_tokens"].items():
            self.split_tokens[t]= idx
for t, idx in config["special_tokens"].items():
            self.split_tokens[t]= idx

        toks =sorted(self.split_tokens.keys(), key=lambda x:len(x), reverse=True)
        self.pattern = re.compile(rf"({'|'.join(map(re.escape, toks))})|([a-zA-Z]+)|( )|([0-9])|(_)|([^\s])", re.UNICODE)

deftokenize(self, text):
return[m.group()for m in self.pattern.finditer(text)]

defconvert_tokens_to_ids(self, tokens, update=True):
        unk = self.split_tokens["<|unknown|>"]
        ids =[]
for t in tokens:
if update and t notin self.split_tokens:
iflen(self.split_tokens)< self.config["vocab_size"]:
                    self.split_tokens[t]=len(self.split_tokens)
else:
                    ids.append(unk)
continue
            ids.append(self.split_tokens.get(t, unk))
return ids

def__call__(self, text, max_len=None, trunc=True, update=False):
        toks = self.tokenize(text)
        ids = self.convert_tokens_to_ids(toks, update)

if trunc and max_len:
            ids = ids[:max_len]
if max_len:
            pad_id = self.split_tokens["<|padding|>"]
            ids = ids +[pad_id]*(max_len -len(ids))

        mask =[if i != self.split_tokens["<|padding|>"]elsefor i in ids]
return{"input_ids": torch.tensor(ids, dtype=torch.long),"attention_mask": torch.tensor(mask, dtype=torch.long)}

defbuild_split_tokens(self, stages, min_freq=):
        freq = Counter()
for i, stage inenumerate(stages):
            path = stage["file_path"]
withopen(path, encoding="utf-8")as f:
                total_lines =sum(for _ in f)
                f.seek()
for line in tqdm(f, desc=f"[Tokenize {i+:02d}]", total=total_lines):
                    line = line.strip()
ifnot line:
continue
for tok in self.tokenize(line):
if tok notin self.config["special_tokens"]and tok notin self.config["global_tokens"]:
                            freq[tok]+=

        new_tokens =[t for t, c in freq.most_common()if c >= min_freq]
        avail = self.config["vocab_size"]-len(self.split_tokens)
for t in new_tokens[:avail]:
            self.split_tokens[t]=len(self.split_tokens)

defget_split_tokens(self):
return self.split_tokens

defdecode(self, ids):
        inv ={idx: t for t, idx in self.split_tokens.items()}
return''.join(inv.get(i,"<|unknown|>")for i in ids)

# ================================================

classChatDataset(Dataset):
def__init__(self, tokenizer, path, config):
        self.tokenizer = tokenizer
        self.max_len = config["max_seq_length"]+
        self.path = path
        self.offsets =[]
withopen(path,"rb")as f:
            offset =
for line in f:
if line.strip():
                    self.offsets.append(offset)
                offset +=len(line)
        self.length =len(self.offsets)

def__len__(self):
return self.length

def__getitem__(self, idx):
        offset = self.offsets[idx]
withopen(self.path,"rb")as f:
            f.seek(offset)
            line = f.readline().decode("utf-8", errors="replace").strip()
        enc = self.tokenizer(line, self.max_len, update=False)
        ids = enc["input_ids"]
return{"input_ids": ids[:-],"attention_mask": enc["attention_mask"][:-],"labels": ids[:]}

# ================================================

classCustomLRScheduler:
def__init__(self, optimizer, config):
        self.optimizer = optimizer
        self.base_lr = config["learning_rate"]
        self.gamma = config["learning_gamma"]

defstep(self, epoch):
        new_lr = self.base_lr *(self.gamma ** epoch)
for param_group in self.optimizer.param_groups:
            param_group['lr']= new_lr

# ================================================

defsave_checkpoint(model, optimizer, tokenizer, config, global_epoch, stage_idx, epoch_in_stage, save_path):
"""保存训练检查点"""
    os.makedirs(save_path, exist_ok=True)

# 保存模型权重
    state = model.state_dict()
    save_file(state, os.path.join(save_path,"model.safetensors"))

# 保存配置
withopen(os.path.join(save_path,"config.json"),"w", encoding="utf-8")as f:
        json.dump(config, f, indent=, ensure_ascii=False)

# 保存tokenizer
withopen(os.path.join(save_path,"tokenizer.json"),"w", encoding="utf-8")as f:
        json.dump(tokenizer.get_split_tokens(), f, indent=, ensure_ascii=False)

# 保存优化器状态
    torch.save(optimizer.state_dict(), os.path.join(save_path,"optimizer.pt"))

# 保存训练状态
    checkpoint_state ={
"global_epoch": global_epoch,
"stage_idx": stage_idx,
"epoch_in_stage": epoch_in_stage
}
withopen(os.path.join(save_path,"training_state.json"),"w", encoding="utf-8")as f:
        json.dump(checkpoint_state, f, indent=, ensure_ascii=False)

defload_checkpoint(checkpoint_path, model, optimizer, tokenizer, device):
"""加载训练检查点"""
# 加载模型权重
    model_state = load_file(os.path.join(checkpoint_path,"model.safetensors"))
    model.load_state_dict(model_state)
    model.to(device)

# 加载tokenizer
withopen(os.path.join(checkpoint_path,"tokenizer.json"),"r", encoding="utf-8")as f:
        tokenizer.split_tokens = OrderedDict(json.load(f))

# 加载优化器状态
    optimizer_state = torch.load(os.path.join(checkpoint_path,"optimizer.pt"), map_location=device)
    optimizer.load_state_dict(optimizer_state)

# 加载训练状态
withopen(os.path.join(checkpoint_path,"training_state.json"),"r", encoding="utf-8")as f:
        training_state = json.load(f)

return training_state

deffind_latest_checkpoint():
"""查找最新的检查点目录"""
    checkpoint_dirs =[]
for item in os.listdir("."):
if os.path.isdir(item)and"_epoch_"in item:
            checkpoint_dirs.append(item)

ifnot checkpoint_dirs:
returnNone

# 按epoch编号排序
    checkpoint_dirs.sort(key=lambda x:int(x.split("_epoch_")[-]))
return checkpoint_dirs[-]

# ================================================

defrun_epoch(model, data_loader, device, pad_id, epoch, optimizer=None, scaler=None, use_cuda=True, accumulation_steps=):
    total_loss =0.0
    total_correct =
    total_tokens =

    mode ="Train"if optimizer isnotNoneelse"Valid"
    lr = optimizer.param_groups[]["lr"]if optimizer isnotNoneelse0.0
    device_type ="cuda"if use_cuda else"cpu"

    pbar = tqdm(data_loader, desc=f"[{mode}{epoch+:02d}]", dynamic_ncols=True)
for step, batch inenumerate(pbar):
        batch ={k: v.to(device, non_blocking=use_cuda)for k, v in batch.items()}

if optimizer isnotNone:
if use_cuda:
# 使用 CUDA 混合精度训练
with torch.amp.autocast(device_type=device_type):
                    outputs = model(**batch)
                    loss = outputs["loss"].mean()/ accumulation_steps
                scaler.scale(loss).backward()

if(step +)% accumulation_steps ==:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
else:
# CPU 训练,不使用混合精度
                outputs = model(**batch)
                loss = outputs["loss"].mean()/ accumulation_steps
                loss.backward()

if(step +)% accumulation_steps ==:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
else:
with torch.no_grad():
                outputs = model(**batch)
                loss = outputs["loss"]

        total_loss += loss.item()* accumulation_steps
        mask = batch["labels"]!= pad_id
        correct =((outputs["logits"].argmax(dim=-)== batch["labels"])& mask).sum().item()
        total_correct += correct
        total_tokens += mask.sum().item()
        avg_acc = total_correct / total_tokens if total_tokens >else0.0
        pbar.set_postfix({"loss":f"{loss.item()* accumulation_steps:.6f}","acc":f"{avg_acc:.6f}","lr":f"{lr:.6f}"})

    avg_loss = total_loss /len(data_loader)
    avg_acc  = total_correct / total_tokens if total_tokens >else0.0
return avg_loss, avg_acc

# ================================================

defstage_train(stages, config, resume=True):
    device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
    use_cuda = torch.cuda.is_available()

# 尝试恢复检查点
    checkpoint_state =None
if resume:
        latest_checkpoint = find_latest_checkpoint()
if latest_checkpoint:
print(f"\n========== 恢复训练:{latest_checkpoint} ==========\n")
            tokenizer = ChatTokenizer(config)
            model = ChatModel(config)
            optimizer = AdamW(model.parameters(), lr=config["learning_rate"])

            checkpoint_state = load_checkpoint(latest_checkpoint, model, optimizer, tokenizer, device)
            global_epoch = checkpoint_state["global_epoch"]
            start_stage_idx = checkpoint_state["stage_idx"]
            start_epoch_in_stage = checkpoint_state["epoch_in_stage"]

print(f"从 epoch {global_epoch} 恢复训练\n")
else:
print("\n========== 未找到检查点,从头开始训练 ==========\n")
            checkpoint_state =None
            start_stage_idx =
            start_epoch_in_stage =

# 如果没有恢复,初始化新的训练
if checkpoint_state isNone:
print(f"\n========== Tokenizer ==========\n")
        tokenizer = ChatTokenizer(config)
        tokenizer.build_split_tokens(stages)
        pad_id = tokenizer.get_split_tokens()["<|padding|>"]

        model = ChatModel(config)
        model.to(device)
print(f"Using device: {device}\n")
        optimizer = AdamW(model.parameters(), lr=config["learning_rate"])
        global_epoch =
        start_stage_idx =
        start_epoch_in_stage =

    scheduler = CustomLRScheduler(optimizer, config)
    num_workers =min(, os.cpu_count()or)

# 只在 CUDA 可用时创建 GradScaler
    scaler = torch.amp.GradScaler()if use_cuda elseNone

# 只在 CUDA 可用时启用 pin_memory
    pin_memory = use_cuda

for stage_idx, stage inenumerate(stages):
if stage_idx < start_stage_idx:
continue

print(f"\n========== {stage['stage_name']} ==========\n")
        dataset = ChatDataset(tokenizer, stage["file_path"], config)

        indices = torch.randperm(len(dataset)).tolist()
        split_idx =int(len(dataset)*(- config["split_valid"]))
        train_dataset = Subset(dataset, indices[:split_idx])
        val_dataset = Subset(dataset, indices[split_idx:])

        train_loader = DataLoader(train_dataset, batch_size=config["batch_size"],
            num_workers=num_workers, persistent_workers=(num_workers >), shuffle=True, pin_memory=pin_memory)
# Validation: use num_workers=0 to avoid hanging on Windows with persistent workers
        val_loader = DataLoader(val_dataset, batch_size=config["batch_size"],
            num_workers=, shuffle=False, pin_memory=False)

# 确定当前stage的起始epoch
        current_start_epoch = start_epoch_in_stage if stage_idx == start_stage_idx else

for epoch_in_stage inrange(current_start_epoch, stage["epochs"]):
            scheduler.step(global_epoch)
            model.train()
            train_loss, train_acc = run_epoch(model, train_loader, device, tokenizer.get_split_tokens()["<|padding|>"], global_epoch,
                                              optimizer=optimizer, scaler=scaler, use_cuda=use_cuda,
                                              accumulation_steps=config.get("accumulation_steps",))
            model.eval()
            val_loss, val_acc = run_epoch(model, val_loader, device, tokenizer.get_split_tokens()["<|padding|>"], global_epoch,
                                          optimizer=None, scaler=None, use_cuda=use_cuda,
                                          accumulation_steps=config.get("accumulation_steps",))

            save_path = os.path.join(".",f"{stage['stage_name']}_epoch_{global_epoch+}")
            save_checkpoint(model, optimizer, tokenizer, config, global_epoch, stage_idx, epoch_in_stage, save_path)
print(f"\n检查点已保存: {save_path}\n")

            global_epoch +=

# ================================================

if __name__ =="__main__":
    parser = argparse.ArgumentParser(description="训练聊天模型")
    parser.add_argument("--resume", action="store_true", default=True,
help="从最新的检查点恢复训练(默认启用)")
    parser.add_argument("--no-resume", action="store_false", dest="resume",
help="不恢复训练,从头开始")
    args = parser.parse_args()

    torch.backends.cudnn.benchmark =True
    torch.backends.cuda.matmul.allow_tf32 =True
    stages =[
{"stage_name":"Fine-tuning","file_path":"./data/daily_dataset_zh_filter.txt","epochs":},
]
    stage_train(stages, default_config, resume=args.resume)

训练

代码语言:javascript
复制
python .\train.py

========== Tokenizer ==========

[Tokenize ]:%|█████████████████████████████████████████████████████████|/[:<:,5267.13it/s]
Using device: cuda


========== Fine-tuning ==========

[Train ]:%|████████████████████|/[::<:,8.19it/s, loss=1.574882, acc=0.489533, lr=0.000100]
[Valid ]:%|██████████████████████████|/[:<:,10.87it/s, loss=2.426748, acc=0.566438, lr=0.000000]

推理chat.py

代码语言:javascript
复制
import os, json, torch
from safetensors.torch import load_file
from train import*
from collections import OrderedDict
from colorama import init as colorama_init, Fore, Style

colorama_init(autoreset=True)

# ================================================

defsample_next_token(logits, generated_tokens, repetition_penalty, presence_penalty, temperature):
for token inset(generated_tokens):
if logits[token]<:
            logits[token]*= repetition_penalty
else:
            logits[token]/= repetition_penalty
    vocab_size = logits.size()
    mask = torch.zeros(vocab_size, dtype=torch.bool, device=logits.device)
    mask[list(set(generated_tokens))]=True
    logits[mask]+= presence_penalty
    probs = torch.softmax(logits / temperature, dim=-)
    next_token = torch.multinomial(probs, num_samples=)
return next_token.item(), probs

# ================================================

defgenerate_response(model, tokenizer, prompt, device, config, max_length=, temperature=0.3, repetition_penalty=1.0, presence_penalty=-1.5):
    encoded = tokenizer(f"<|user|>{prompt}<|assistant|>", update=False)
    generated = encoded["input_ids"].unsqueeze().to(device)
    unknown_id = tokenizer.split_tokens.get("<|unknown|>")
    end_id = tokenizer.split_tokens.get("<|end|>")
    newline_id = tokenizer.split_tokens.get("\\n")

print(Fore.GREEN +"Assistant:"+ Style.RESET_ALL, end=" ", flush=True)

with torch.no_grad():
for _ inrange(max_length):
if generated.size()> config["max_seq_length"]:
                current_input = generated[:,-config["max_seq_length"]:]
                pos_offset = generated.size()- config["max_seq_length"]
else:
                current_input = generated
                pos_offset =

            outputs = model(current_input, pos_offset=pos_offset)
            logits = outputs["logits"][,-,:].clone()
            gen_tokens = generated[].tolist()
            token_id, probs = sample_next_token(logits, gen_tokens, repetition_penalty, presence_penalty, temperature)

if token_id == unknown_id and probs.sum()>:
                probs[unknown_id]=0.0
                probs = probs / probs.sum()
                token_id = torch.multinomial(probs, num_samples=).item()

            generated = torch.cat((generated, torch.tensor([[token_id]], device=generated.device)), dim=)
if token_id == end_id:
break
            token_str = tokenizer.decode([token_id])
if token_id == newline_id:
print()
else:
print(token_str, end="", flush=True)
print()

# ================================================

defload_chat_model(model_dir, device):
withopen(os.path.join(model_dir,"config.json"),"r", encoding="utf-8")as f:
        config = json.load(f)
withopen(os.path.join(model_dir,"tokenizer.json"),"r", encoding="utf-8")as f:
        token_dict = json.load(f)
    tokenizer = ChatTokenizer(config)
    tokenizer.split_tokens = OrderedDict(token_dict)
    model = ChatModel(config).to(device)
    state_dict = load_file(os.path.join(model_dir,"model.safetensors"))
    model.load_state_dict(state_dict)
    model.eval()

print("="*)
    total_params =sum(p.numel()for p in model.parameters())
print(f"Total parameters: {total_params:,}")
return model, tokenizer, config

# ================================================

if __name__ =="__main__":
print("EasyGPT Beta V1.5 Torch Inference (Dev)")
    model_dir ="./Fine-tuning_epoch_15"
    device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
    model, tokenizer, config = load_chat_model(model_dir, device)
whileTrue:
print("="*)
        prompt =input(Fore.CYAN +"User:"+ Style.RESET_ALL +" ")
if prompt.strip().lower()in["exit","quit"]:
break
        generate_response(model, tokenizer, prompt, device, config)

推理

一个epoch效果

代码语言:javascript
复制
python .\chat.py
EasyGPT Beta V1. Torch Inference (Dev)
==================================================
Total parameters:,,
==================================================
User: 你好
Assistant: 当然,我可以帮您解决一个问题。
==================================================
User: 你怎么样
Assistant: 当我们遇到困难时,我们会遇到一些挑战和困难。以下是一些可能的方法:

. 学习新技能和知识。通过阅读相关书籍、在线资源、参加社交活动等方式来学习新技能。
. 练习写作技巧。使用专业术语或者词汇来帮助学生更好地理解文本。
. 寻找新的学习资源。如果你想要写作,可以寻找一些有趣的学习资源。
. 与同事合作。学习新的知识和技能也是提高自己的能力的重要因素。
. 尝试不同的学习方式。不断尝试新的学习方式,并尝试新的教材和技巧。
. 学习新的技能和知识。学习新的技能和知识,可以让你更快地掌握新技能和知识。
. 学习新的技能和知识。学习新的技能和知识,可以让你更快地适应新的环境和技能。
. 坚持学习。学习新技能和知识可以让你更快地掌握新技能和知识。
. 学习新技能和知识。学习新技能和知识可以帮助学生更好地了解新技能和知识,从而更好地适应新的环境和变化。
总之,学习新技能和知识的重要性,需要不断努力和探索。
==================================================
User: 今天天气怎么样
Assistant:. 去散步或骑自行车
. 穿上适合的衣服和鞋子,可以选择一些轻便的衣服或者是一种舒适的运动方式。
. 在户外活动或户外活动中进行户外活动,如徒步旅行、爬山或徒步旅行等。
. 在户外活动中放松身心,可以享受到美好的时光和休息。
==================================================

两个epoch效果

代码语言:javascript
复制
python .\chat.py
EasyGPT Beta V1.5 Torch Inference (Dev)
==================================================
Total parameters: 77,607,436
==================================================
User: 你是谁
Assistant: 我是一个人工智能模型,没有感觉或情感。但是,我可以为您提供一些关于人工智能的信息和建议。
首先,我们需要了解人工智能的定义和应用场景。人工智能的目标是让计算机系统能够像人类一样思考、学习、推理和决策等方面的能力。这种技术包括语音识别、图像识别、自然语言处理、计算机视觉、语音识别、机器翻译等。此外,人工智能还可以通过机器学习来实现自动驾驶汽车、智能家居、医疗保健、金融服务等领域。
其次,人工智能也是一个非常重要的话题。人工智能在各行各业都有广泛的应用,如医疗保健、交通运输、制造业等。例如,在金融领域,人工智能可以帮助银行进行风险评估和投资决策。此外,人工智能还可以帮助企业更好地管理风险和预测市场趋势。
最后,人工智能也是一个非常热门的话题。人工智能的应用场景是使用自然语言处理技术,比如语音识别、机器翻译、智能客服等。人工智能的应用场景也很广泛,比如智能家居、智能家居等。
总之,人工智能已经成为人工智能领域的一个热门话题。它不仅可以帮助人们更好地理解人类生活,还可以为人类带来更多便利和高效的生活体验。
==================================================
User: 你好
Assistant: 我叫 EasyGPT,很高兴见到您。请问您想要咨询哪个职位?
==================================================
User: 你是什么
Assistant: 我是一个人工智能模型,没有真实的感觉或情感。
==================================================
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2026-04-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 王忘杰的小屋 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 项目地址
  • EasyGPT
  • 调整参数
  • 查看CUDA
  • epochs
  • 增加断点保存
  • 增加断点保存的train.py
  • 训练
  • 推理chat.py
  • 推理
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档