深度学习基本理解与语法快速入门
1. 前置约定(完全匹配 ViT-small-patch16-224)
固定参数 :输入图像尺寸 `224×224`,patch 大小 `16×16`,嵌入维度 `embed_dim=384`,注意力头数 `num_heads=6`,单头维度 `head_dim=64`,Transformer 层数 `depth=12`,分类类别数 `num_classes=C`。 符号说明 :`B` = batch size(批次样本数),`C` = 分类类别数。以下所有步骤严格对应 ViT forward 代码行号,标注 【对应 forward 第 X 行】。
2. 完整数据流
【对应 forward 第 1 行】输入初始化与批次维度提取
- 输入原始数据:shape 为
[B, 3, 224, 224]的 RGB 图像批次(3 为通道数,224×224 为图像宽高) - 操作:提取当前批次的样本数量 B,用于后续 CLS token 的维度扩展
【对应 forward 第 2 行】PatchEmbed 分块嵌入(图像 → token 序列,唯一与文本 Transformer 的核心区别)
- 操作 1:卷积分块与线性投影输入
[B, 3, 224, 224],通过 1 个kernel_size=16、stride=16、out_channels=384的 2D 卷积,将图像切分为14×14=196个不重叠的 16×16 patch,同时将每个 patch 投影为 384 维向量。卷积输出 shape:[B, 384, 14, 14] - 操作 2:维度调整,转为 token 序列
对卷积输出执行flatten(2),将 14×14 展平为 196,得到[B, 384, 196];再transpose(1,2)得到[B, 196, 384]。
【对应 forward 第 3 行】CLS token 批次维度扩展
- 可学习的 CLS token,初始化 shape
[1, 1, 384] - 执行
expand(B, -1, -1),为批次内每个样本分配独立的 CLS token - 输出 shape:
[B, 1, 384]
【对应 forward 第 4 行】CLS token 与 patch token 拼接
- 在序列维度(dim=1),将 CLS token 拼接到 patch 序列最开头
- 输出 shape:
[B, 197, 384](197 = 1 CLS + 196 patch)
【对应 forward 第 5 行】位置编码添加
- 将 token 序列与可学习的位置编码
pos_embed(shape[1, 197, 384])逐元素相加 - 为每个 token 注入位置信息,解决 Transformer 无顺序感知能力的问题
【对应 forward 第 6 行】位置 Dropout 正则化
- 对添加完位置编码的 token 序列执行 Dropout,缓解过拟合
【对应 forward 第 7 行】12 层 Transformer Block 堆叠
每层 Block 输入/输出 shape 均为 [B, 197, 384],可堆叠。
单 Block 完整子流程(点击展开)
7.1 多头自注意力计算 + 残差连接
- 7.1.1 前置 LayerNorm 归一化
- 7.1.2.1 QKV 线性映射:384 → 1152(3×384),输出
[B, 197, 1152] - 7.1.2.2 多头拆分:
reshape(B, 197, 3, 6, 64)→permute(2, 0, 3, 1, 4)→unbind(0),得到 Q/K/V 各[B, 6, 197, 64] - 7.1.2.3 相似度计算与缩放:Q 与 K 转置相乘,除以 √64=8,输出
[B, 6, 197, 197] - 7.1.2.4 Softmax 归一化 + Dropout
- 7.1.2.5 注意力加权与多头拼接:加权输出
[B, 6, 197, 64]→transpose(1,2)→reshape(B, 197, 384)→ proj 线性层 + Dropout - 7.1.3 残差连接:Attention 输出 + Block 原始输入
7.2 MLP 非线性变换 + 残差连接
- 7.2.1 前置 LayerNorm 归一化
- 7.2.2 MLP:384 → 1536(fc1)→ GELU → Dropout → 1536 → 384(fc2)→ Dropout
- 7.2.3 残差连接:MLP 输出 + MLP 输入
【对应 forward 第 8 行】最终全局 LayerNorm
- 修正 12 层残差累加导致的特征分布偏移
【对应 forward 第 9 行】CLS token 提取
- 通过切片
x[:, 0]提取 CLS token 作为整张图像的全局特征表示 - 输出 shape:
[B, 384]
【对应 forward 第 10 行】分类头线性映射
- 384 维 → C 类,输出
[B, C]
Dropout 正则化机制
所有正则化均为 PyTorch 标准反向 :训练时对神经元输出独立随机置 0,非 0 元素除以 (1-p) 放缩;推理阶段完全关闭。
| 正则化类型 | 位置 | 作用维度 | 核心作用 |
|---|---|---|---|
| 位置编码 Dropout | forward 第 6 步 | [B, 197, 384] 最后一维 |
约束浅层视觉特征表达 |
| 注意力权重 Dropout | 7.1.2.4 步 | [B, 6, 197, 197] 最后两维 |
唯一作用于 ,切断过度依赖 |
| 注意力输出 Dropout | 7.1.2.5 步 | [B, 197, 384] 最后一维 |
约束全局注意力输出特征 |
| MLP 层 Dropout | 7.2.2.3 / 7.2.2.5 | [B, 197, 1536] 和 [B, 197, 384] |
约束单 token 非线性变换 |
数据预处理与模型定义
1 | import torch |
1 | data_transform = transforms.Compose([ |
使用 timm 创建 ViT 模型:
1 | def create_vit_model(num_classes): |
训练函数
1 | def train_model(model, criterion, optimizer, scheduler, train_loader, test_loader, num_epochs): |
主函数
1 | if __name__ == '__main__': |
config.yaml 配置文件
1 | attn_drop_rate: 0.1 |
自动调优器
遍历参数组合,自动运行训练并记录最佳精度。
1 | import yaml |
不使用 timm,完全手动定义 ViT 的每个组件,便于理解内部运作。
1 | import torch |
固定参数(完全匹配 ViT-small-patch16-224)
1 | img_size = 224 |
PatchEmbed 模块
1 | class PatchEmbed(nn.Module): |
Transformer Block
1 | class Block(nn.Module): |
完整 ViT 模型
1 | class VisionTransformer(nn.Module): |
ResNet18 包含 17 个卷积层 + 1 个全连接层,共 18 层。核心创新是 残差连接 :每个 BasicBlock 的输出 = 卷积输出 + 输入(identity),有效解决了深层网络的梯度消失问题。
数据加载
1 | import torch |
DataLoader只负责批量打包、打乱和多线程加载,transform 在__getitem__中执行。
ResNet18 模型结构
1 | class ResNet18(nn.Module): |
训练与保存
1 | def train_model(model, criterion, optimizer, train_loader, test_loader, num_epochs=10): |
梯度累计写法:当显存不足时,可将
batch_size调小并累计多步梯度后一次性更新。但需同步调小 batch,否则适得其反。
对模型最后一层输出(softmax 前)进行的不同采样策略,控制生成文本的多样性与质量。
温度采样
温度参数 控制概率分布的平滑程度:τ 越大越平滑(多样性高),τ 越小越尖锐(确定性高)。1 | import torch |
Top-K 采样
只保留概率最高的 k 个 token,其余置为 -∞。
1 | def top_k_sampling(logits, k=50): |
Top-P(Nucleus)采样
按概率从高到低累加,保留累加到阈值 p 的 token 集合。
1 | def top_p_sampling(logits, p=0.9): |
典型 P(Typical)采样
基于信息含量筛选:保留与熵偏差最小的 token。
1 | def typical_sampling(logits, typical_p=0.9): |
Min-P 采样
以最大概率为基准,筛掉概率低于 max_prob × min_p_ratio 的 token。
1 | def min_p_sampling(logits, min_p_ratio=0.05): |
Top-A 采样
根据熵动态调整 k 值,再进行 Top-K 采样。
1 | def top_a_sampling(logits, base_k=4): |
基本读写
1 | import csv |
CSV 数据训练示范
通过自定义
Dataset将 CSV 文件加载为 PyTorch 可用数据集,搭配 FFN 网络进行训练。
1 | import torch |
样本数据格式示例:
| name | param1 | param2 | param3 | label |
|---|---|---|---|---|
| sample1 | 0.5 | 0.3 | 0.2 | class_A |
| sample2 | 0.1 | 0.7 | 0.9 | class_B |
| sample3 | 0.8 | 0.4 | 0.6 | class_A |
1 | import json |
dataset_infos.json 示例结构:
1 | { |
json.dumps(obj, indent=2, ensure_ascii=False)将 Python 对象编码为 JSON 字符串,ensure_ascii=False允许输出中文等非 ASCII 字符。
