3700+ Star 的开源项目:从零用 PyTorch 手搓一个 Transformer LLM
3700+ Star 的开源项目:从零用 PyTorch 手搓一个 Transformer LLM
先收藏,周末照着跑一遍,比看十篇论文都管用。
说到训练大语言模型,很多人的第一反应是:那不是 OpenAI、Google 这些巨头才能干的事吗?需要几千张 GPU、TB 级数据、几个亿的算力预算……
但 FareedKhan-dev 的 train-llm-from-scratch 项目告诉你:不需要。一台带 GPU 的机器,照着他的代码跑,你就能从零训练出自己的 LLM——从 13M 参数到 2B 参数都行。
这个项目在 GitHub 上拿到了 3700+ Star、500+ Fork,MIT 协议开源。它不做任何封装,不用 Hugging Face Transformers,不用任何高层框架——直接用 PyTorch 的 nn.Module,一行一行把 Transformer 写出来。
本文提纲
- 这个项目到底做了什么
- 项目结构和技术架构
- 从 Attention 到 Transformer Block:核心代码拆解
- 怎么跑:从下载数据到生成文本
- 百万 vs 十亿参数:训练效果对比
- 适合什么人、能学到什么
这个项目到底做了什么
一句话总结:基于 2017 年那篇改变整个 AI 行业的论文《Attention is All You Need》,用纯 PyTorch 实现了完整的 Transformer 语言模型。
它不是玩具 demo。整个训练 pipeline 包括:
- 数据下载:从 The Pile 数据集(825GB,22 个子数据集,涵盖论文、GitHub 代码、书籍、网页等)自动拉取和预处理
- Tokenizer:使用 OpenAI 的 tiktoken(
r50k_base),将文本转为 token ID - 模型架构:完全手写的 Transformer——Multi-Head Attention、MLP、LayerNorm、Residual Connection,全部用
nn.Linear和nn.Embedding从零搭建 - 训练脚本:支持学习率衰减、定期评估、自动保存 checkpoint
- 文本生成:加载训练好的模型,输入 prompt 即可生成文本
作者还附赠了一个 152KB 的 sft_rlhf_guide.ipynb,讲解 SFT(监督微调)和 RLHF(人类反馈强化学习)的流程。
项目结构和技术架构
train-llm-from-scratch/
├── src/models/
│ ├── mlp.py # Multi-Layer Perceptron 模块
│ ├── attention.py # 单头/多头注意力机制
│ ├── transformer_block.py # 单个 Transformer Block
│ └── transformer.py # 完整的 Transformer 模型
├── config/
│ └── config.py # 模型参数、训练超参配置
├── data_loader/
│ └── data_loader.py # 数据加载和批处理
├── scripts/
│ ├── data_download.py # 下载 The Pile 数据集
│ ├── data_preprocess.py # 数据预处理(tokenize + HDF5 存储)
│ ├── train_transformer.py # 训练入口
│ └── generate_text.py # 文本生成推理
├── sft_rlhf_guide.ipynb # SFT/RLHF 指南
└── requirements.txt # 依赖:torch, tiktoken, h5py 等整个代码非常干净,没有一行多余的封装。每个 .py 文件就是一个 nn.Module 子类,职责清晰。
graph TB
subgraph "Data Pipeline"
A[The Pile Dataset
825GB, 22 subsets] --> B[data_download.py]
B --> C[Raw JSON]
C --> D[data_preprocess.py
tiktoken tokenizer]
D --> E[HDF5 Token IDs]
end
subgraph "Model Architecture"
F[Token Embedding
nn.Embedding] --> G[Position Embedding
nn.Embedding]
G --> H[Transformer Block x N]
H --> I[LayerNorm]
I --> J[LM Head
nn.Linear]
end
subgraph "Transformer Block"
K[LayerNorm] --> L[Multi-Head Attention
+ Residual]
L --> M[LayerNorm]
M --> N[MLP
+ Residual]
end
E --> F
H -.-> K
J --> O[Generated Text]
E[("#FF6B6B")]
F[("#4ECDC4")]
H[("#45B7D1")]
L[("#96CEB4")]
N[("#FFEAA7")]
O[("#DDA0DD")]数据流很直观:原始文本 → tokenization → 嵌入 → N 层 Transformer Block → 输出 logits → 生成下一个 token。
从 Attention 到 Transformer Block:核心代码拆解
这是这个项目最值得看的地方。没有 transformers 库的抽象,你看到的就是 Transformer 的每一层实现。
单头注意力(Single Head Attention)
class Head(nn.Module):
def __init__(self, head_size, n_embed, context_length):
super().__init__()
self.key = nn.Linear(n_embed, head_size, bias=False)
self.query = nn.Linear(n_embed, head_size, bias=False)
self.value = nn.Linear(n_embed, head_size, bias=False)
# 因果遮罩:防止看到未来 token
self.register_buffer('tril', torch.tril(
torch.ones(context_length, context_length)
))
def forward(self, x):
B, T, C = x.shape
k = self.key(x) # (B, T, head_size)
q = self.query(x) # (B, T, head_size)
# Scaled dot-product attention
wei = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.shape[-1]))
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
return wei @ self.value(x)这就是论文里的 Scaled Dot-Product Attention。Q、K、V 三个线性投影,缩放点积,因果遮罩,softmax,加权求和。没有任何黑魔法。
Transformer Block
class Block(nn.Module):
def __init__(self, n_head, n_embed, context_length):
super().__init__()
self.ln1 = nn.LayerNorm(n_embed)
self.attn = MultiHeadAttention(n_head, n_embed, context_length)
self.ln2 = nn.LayerNorm(n_embed)
self.mlp = MLP(n_embed)
def forward(self, x):
# Pre-Norm + Residual Connection
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x注意这里用的是 Pre-LayerNorm(先归一化再进入 Attention/MLP),而不是原始论文的 Post-LayerNorm。这是现代 LLM 训练的标配做法,训练更稳定。
完整模型
class Transformer(nn.Module):
def __init__(self, n_head, n_embed, context_length, vocab_size, N_BLOCKS):
super().__init__()
self.token_embed = nn.Embedding(vocab_size, n_embed)
self.position_embed = nn.Embedding(context_length, n_embed)
self.attn_blocks = nn.ModuleList([
Block(n_head, n_embed, context_length) for _ in range(N_BLOCKS)
])
self.layer_norm = nn.LayerNorm(n_embed)
self.lm_head = nn.Linear(n_embed, vocab_size)Token Embedding + Position Embedding → N 个 Transformer Block → LayerNorm → 线性输出层。简洁明了。
怎么跑:从下载数据到生成文本
整个流程四步走:
第一步:克隆项目、安装依赖
git clone https://github.com/FareedKhan-dev/train-llm-from-scratch.git
cd train-llm-from-scratch
export PYTHONPATH="$PYTHONPATH:."
pip install -r requirements.txt第二步:下载和预处理数据
# 下载训练数据(默认 1 个分片,约 11GB;最多 30 个分片)
python scripts/data_download.py --train_max 1
# 预处理为 HDF5 格式
python scripts/data_preprocess.py --max_data 1000第三步:配置模型参数
修改 config/config.py。训练 13M 参数小模型:
VOCAB_SIZE = 50304
CONTEXT_LENGTH = 128
N_EMBED = 128
N_HEAD = 8
N_BLOCKS = 1训练 2B+ 参数大模型(需要大显存 GPU):
VOCAB_SIZE = 50304
CONTEXT_LENGTH = 512
N_EMBED = 2048
N_HEAD = 16
N_BLOCKS = 64第四步:训练和生成
# 训练
python scripts/train_transformer.py
# 生成文本
python scripts/generate_text.py \
--model_path models/transformer_B.pt \
--input_text "Subject: " \
--max_new_tokens 200就这么简单。Colab 或 Kaggle 的免费 T4 就能跑 13M 参数模型。
百万 vs 十亿参数:训练效果对比
作者用 The Pile 数据集分别训练了 13M 和 2B+ 参数的模型,效果差异非常明显。
13M 参数模型输出
输入 "Subject: " 后的输出:
Subject: ClickPaper-summary Study for Interview
Good morning, I hope this message finds you well, as the sun gently peeks through the clouds, ...
虽然语法不完全正确,但能看出模型在尝试生成一封邮件。对于 13M 参数来说,这个结果已经很说明问题了。
2B+ 参数模型输出
长文本生成时,句子开始出现断裂:
There are two miles east coast from 1037 and 73 million refugees (hypotetus)...
blacksmith, musician and boutique hospitality and inspire the strain delivered Canadians...
有趣的是,作者发现 大模型并不一定比小模型好。在训练不充分的情况下,2B 参数模型的输出可能还不如 13M 的——因为它更容易过拟合,对超参数更敏感。损失曲线也印证了这一点:大模型的 loss 震荡更剧烈,学习率衰减的效果也更显著。
这也给了一个实际的经验教训:不是参数越多越好,架构深度、训练数据量和超参数调优同样重要。
GPU 要求参考
| GPU | 显存 | 13M 模型 | 2B 模型 | 最大可训练参数 |
|---|---|---|---|---|
| Colab T4 | 16GB | ✅ | ❌ | ~1.5-2B |
| RTX 3090 | 24GB | ✅ | ✅ | ~3.5-4B |
| RTX 4090 | 24GB | ✅ | ✅ | ~4B |
| A100 | 40GB | ✅ | ✅ | ~6-8B |
| Quadro RTX 8000 | 48GB | ✅ | ✅ | ~8-10B |
适合什么人、能学到什么
这个项目最适合以下三类人:
1. 想真正理解 Transformer 的学习者
如果你看了很多 Attention 的科普文章但始终觉得隔了一层纱,直接读这个代码。每一层都是透明的——Q/K/V 怎么算的、因果遮罩怎么做的、残差连接怎么接的,全在代码里,没有任何框架的抽象挡在中间。
2. 准备 AI 面试的工程师
"手写 Attention"、"手写 Transformer Block"是 AI 岗位面试的高频题。这个项目的代码质量高、注释完整,直接用来准备面试再合适不过。
3. 想快速验证想法的研究者
基于这个代码改架构非常方便。想试试不同的 Attention 变体?改 attention.py。想调整残差连接方式?改 transformer_block.py。想换数据集?改 data_loader.py。每个模块都是独立的,改动范围可控。
跑通了还是卡住了?评论区告诉我,看到都会回。觉得有用就点个赞让更多人看到。
作者: itech001
来源: 公众号:AI人工智能时代
网站: https://www.theaiera.cn/
每日分享最前沿的AI新闻资讯和技术研究。
本文首发于 AI人工智能时代,转载请注明出处。