ViT Theory Basics (Image Transformer)

1. Conventions (exactly matching ViT-small-patch16-224)

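Fixed parameters: input image size `224×224`, patch size `16×16`, embedding dimension `embed_dim=384`, attention heads `num_heads=6`, per-head dimension `head_dim=64` (384 / 6), Transformer depth `depth=12`, number of classes `num_classes=C`. Notation: `B` = batch size (number of samples per batch), `C` = number of classes.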

Every step below maps to one line of the ViT `forward` method and is labeled [forward line X].

2. Complete data flow

[forward line 1] Input and batch-size extraction

  • Input: an RGB image batch of shape [B, 3, 224, 224] (3 channels, 224×224 height × width)
  • Operation: read the batch size B, which is used later to expand the CLS token

[forward line 2] PatchEmbed (image → token sequence; the core difference from a text Transformer)

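  • Operation 1: patching and linear projection with a convolution. The input [B, 3, 224, 224] passes through a single 2D convolution with kernel_size=16, stride=16, out_channels=384. This splits the image into 14×14 = 196 non-overlapping 16×16 patches and projects each patch to a 384-dim vector. Conv output shape: [B, 384, 14, 14]
  • Operation 2: reshape into a token sequence
    flatten(2) merges the 14×14 grid into 196 positions, giving [B, 384, 196]. transpose(1, 2) then gives [B, 196, 384]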
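You can check numerically that a stride-16 convolution is equivalent to cutting out patches and applying one shared linear layer. This is a minimal sketch and assumes PyTorch is installed. `F.unfold` extracts the 16×16 patches, and the conv kernel is reused as the linear weight. The variable names are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, patch, dim = 2, 16, 384
x = torch.randn(B, 3, 224, 224)

# Route 1: the PatchEmbed convolution, then flatten + transpose
conv = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
tokens_conv = conv(x).flatten(2).transpose(1, 2)            # [B, 196, 384]

# Route 2: cut non-overlapping 16x16 patches -> [B, 3*16*16, 196] -> [B, 196, 768]
patches = F.unfold(x, kernel_size=patch, stride=patch).transpose(1, 2)
# Reuse the conv kernel as a Linear weight: [384, 3, 16, 16] -> [384, 768]
tokens_linear = patches @ conv.weight.reshape(dim, -1).T + conv.bias

print(tokens_conv.shape)                                       # torch.Size([2, 196, 384])
print(torch.allclose(tokens_conv, tokens_linear, atol=1e-4))   # True
```

This is why PatchEmbed is often described as a linear projection of flattened patches. The convolution is just an efficient way to compute it.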

[forward line 3] Expand the CLS token over the batch

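  • A learnable CLS token, created with shape [1, 1, 384]
  • expand(B, -1, -1) gives every sample in the batch a CLS token. This is a broadcast view, not a copy: all samples share the same parameter
  • Output shape: [B, 1, 384]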

[forward line 4] Concatenate the CLS token with the patch tokens

  • Along the sequence dimension (dim=1), prepend the CLS token to the patch sequence
  • Output shape: [B, 197, 384] (197 = 1 CLS + 196 patches)

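[forward line 5] Add the positional embedding

  • Add the learnable positional embedding pos_embed (shape [1, 197, 384]) to the token sequence element-wise, broadcasting over the batch
  • This injects position information. Self-attention by itself is permutation-equivariant and has no notion of token order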
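The sketch below shows why the positional embedding is needed. It uses `nn.MultiheadAttention` as a stand-in for the Block's attention and shuffles the input tokens. Without positional information, the outputs are shuffled in exactly the same way, so attention cannot tell positions apart. Once a position-dependent embedding is added, this no longer holds. PyTorch is assumed, and the sizes mirror ViT-S.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=384, num_heads=6, batch_first=True).eval()

x = torch.randn(1, 197, 384)        # tokens WITHOUT positional encoding
perm = torch.randperm(197)          # a random reordering of the tokens

with torch.no_grad():
    out, _ = attn(x, x, x)
    xs = x[:, perm]
    out_perm, _ = attn(xs, xs, xs)

    # Same experiment after adding a (random) position-dependent embedding
    pos = torch.randn(1, 197, 384)
    xp = x + pos
    out_pe, _ = attn(xp, xp, xp)
    xps = x[:, perm] + pos          # tokens moved, positions stay fixed
    out_pe_perm, _ = attn(xps, xps, xps)

# Without positions: shuffling inputs just shuffles outputs (equivariance)
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))
# With positions: the outputs change, so token order now matters
print(torch.allclose(out_pe[:, perm], out_pe_perm, atol=1e-5))
```

LayerNorm and the MLP act on each token separately. Without pos_embed, the whole Block stack would therefore ignore the spatial arrangement of the patches.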

[forward line 6] Positional dropout

  • Apply Dropout to the token sequence after the positional embedding is added, to reduce overfitting

[forward line 7] Stack of 12 Transformer Blocks

Every Block takes and returns shape [B, 197, 384], so Blocks can be stacked.

Full sub-steps of a single Block

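7.1 Multi-head self-attention + residual connection

  • 7.1.1 Pre-LayerNorm
  • 7.1.2.1 QKV linear projection: 384 → 1152 (3×384), output [B, 197, 1152]
  • 7.1.2.2 Split into heads: reshape(B, 197, 3, 6, 64) → permute(2, 0, 3, 1, 4) → unbind(0). This gives Q, K and V, each of shape [B, 6, 197, 64]
  • 7.1.2.3 Similarity and scaling: multiply Q by the transpose of K and divide by √64 = 8, output [B, 6, 197, 197]
  • 7.1.2.4 Softmax normalization + Dropout
  • 7.1.2.5 Weighted sum and head merging: weighted output [B, 6, 197, 64] → transpose(1, 2) → reshape(B, 197, 384) → proj linear layer + Dropout
  • 7.1.3 Residual connection: attention output + the Block's original input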
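You can trace steps 7.1.2.1 to 7.1.2.5 with random tensors. This sketch leaves out the LayerNorm, the proj layer and dropout, which is equivalent to eval mode with dropout off. It cross-checks the manual computation against `F.scaled_dot_product_attention`, which requires PyTorch 2.0 or later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, N, C, H = 2, 197, 384, 6
D = C // H                                            # head_dim = 64
x = torch.randn(B, N, C)
qkv_layer = nn.Linear(C, 3 * C)

with torch.no_grad():
    qkv = qkv_layer(x)                                        # [B, 197, 1152]
    qkv = qkv.reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)   # [3, B, 6, 197, 64]
    q, k, v = qkv.unbind(0)                                   # each [B, 6, 197, 64]

    attn = (q @ k.transpose(-2, -1)) / D ** 0.5               # [B, 6, 197, 197]
    attn = attn.softmax(dim=-1)                               # each row sums to 1
    out_manual = (attn @ v).transpose(1, 2).reshape(B, N, C)  # [B, 197, 384]

    # Fused kernel; its default scale is also 1/sqrt(D)
    out_fused = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(B, N, C)

print(q.shape, attn.shape, out_manual.shape)
print(torch.allclose(out_manual, out_fused, atol=1e-5))
```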

7.2 MLP non-linear transform + residual connection

  • 7.2.1 Pre-LayerNorm
  • 7.2.2 MLP: 384 → 1536 (fc1) → GELU → Dropout → 1536 → 384 (fc2) → Dropout
  • 7.2.3 Residual connection: MLP output + MLP input (that is, the output of step 7.1)

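[forward line 8] Final LayerNorm

  • Normalizes the features accumulated over 12 residual additions. With pre-norm Blocks the residual stream itself is never normalized, so a final norm is applied before the head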

[forward line 9] Extract the CLS token

  • The slice x[:, 0] takes the CLS token as the global representation of the whole image
  • Output shape: [B, 384]

[forward line 10] Classification head

  • Linear map 384 → C, output [B, C]
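Counting parameters along the data flow above is a useful sanity check. The helper below is a hypothetical sketch. It assumes a bias on every Linear/Conv layer and an MLP ratio of 4 (384 → 1536). With C = 1000 it gives about 22M parameters, which agrees with the size usually quoted for ViT-S/16.

```python
def vit_param_count(embed_dim=384, depth=12, num_classes=1000,
                    img_size=224, patch_size=16, mlp_ratio=4):
    """Parameter count of the ViT described above (biases on all Linear/Conv layers)."""
    num_patches = (img_size // patch_size) ** 2           # 196
    hidden = embed_dim * mlp_ratio                        # 1536
    patch_embed = 3 * patch_size * patch_size * embed_dim + embed_dim  # conv weight + bias
    cls_and_pos = embed_dim + (num_patches + 1) * embed_dim            # cls_token + pos_embed
    per_block = (
        2 * 2 * embed_dim                                 # norm1 + norm2 (weight and bias)
        + embed_dim * 3 * embed_dim + 3 * embed_dim       # qkv
        + embed_dim * embed_dim + embed_dim               # proj
        + embed_dim * hidden + hidden                     # fc1
        + hidden * embed_dim + embed_dim                  # fc2
    )
    final_norm = 2 * embed_dim
    head = embed_dim * num_classes + num_classes
    return patch_embed + cls_and_pos + depth * per_block + final_norm + head

print(vit_param_count())                 # 22050664
print(vit_param_count(num_classes=10))   # 21669514
```

The 12 Blocks hold about 21.3M of these parameters. The MLP (fc1 + fc2) accounts for roughly two thirds of each Block.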

Dropout regularization

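All dropouts here are PyTorch's standard inverted dropout. During training, each element of the tensor is independently set to 0 with probability p, and the surviving elements are scaled by 1/(1-p). At inference (model.eval()), dropout is disabled.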

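| Dropout | Where | Applied to (element-wise) | Purpose |
|---|---|---|---|
| Positional dropout | forward line 6 | token sequence [B, 197, 384] | regularizes the input-level (shallow) token features |
| Attention-weight dropout | step 7.1.2.4 | attention matrix [B, 6, 197, 197] | the only dropout on token-to-token attention weights; discourages over-reliance on particular tokens |
| Attention-output dropout | step 7.1.2.5 | [B, 197, 384] after proj | regularizes the output of the attention sub-block |
| MLP dropout | step 7.2.2 (after GELU and after fc2) | [B, 197, 1536] and [B, 197, 384] | regularizes the per-token non-linear transform |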
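Here is a quick check of the inverted-dropout behavior described above. PyTorch is assumed, and p = 0.25 is an arbitrary illustrative value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.25)
x = torch.ones(10000)

drop.train()
y = drop(x)
kept = y[y != 0]
print(kept.unique())             # survivors scaled by 1 / (1 - p) = 1.3333
print((y == 0).float().mean())   # fraction zeroed, close to p = 0.25
print(y.mean())                  # close to 1.0: the expected value is preserved

drop.eval()
print(torch.equal(drop(x), x))   # identity at inference
```

Rescaling during training, rather than at inference, is what lets evaluation use the network unchanged.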

ViT compact version: based on the timm library

Data preprocessing and model definition

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split, Subset
import timm
from tqdm import tqdm
import yaml

# Load the configuration
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Configurable parameters; cast explicitly so every value has the right type
BATCH_SIZE = int(config['batch_size'])
EPOCHS = int(config['epochs'])
LEARNING_RATE = float(config['learning_rate'])
DROP_RATE = float(config['drop_rate'])
ATTN_DROP_RATE = float(config['attn_drop_rate'])
DROP_PATH_RATE = float(config['drop_path_rate'])
WARMUP_EPOCHS = int(config['warmup_epochs'])
L1_LAMBDA = float(config['l1_lambda'])
DATA_DIR = str(config['data_dir'])
MODEL_NAME = str(config['model_name'])
PRETRAINED = bool(config['pretrained'])
```
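The explicit casts above matter for a reason. PyYAML's `safe_load` follows YAML 1.1, where a float needs a decimal point. A value written as `1e-05`, as `l1_lambda` is in config.yaml, is therefore loaded as a string. Here is a small check, assuming PyYAML is installed:

```python
import yaml

cfg = yaml.safe_load("l1_lambda: 1e-05\nlearning_rate: 0.0005\nepochs: 20\n")
print(type(cfg['l1_lambda']))       # str: no decimal point, so not resolved as a float
print(type(cfg['learning_rate']))   # float
print(float(cfg['l1_lambda']))      # the explicit cast recovers the number
```

Writing the value as `1.0e-05` in the YAML file would also make it load as a float.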
Data augmentation pipeline: random crop + horizontal flip for training, center crop only for testing.
```python
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def load_data(data_dir):
    # ImageFolder expects the folder layout root/CLASS_NAME/*.jpg
    full_dataset = datasets.ImageFolder(root=data_dir, transform=None)
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_subset, test_subset = random_split(full_dataset, [train_size, test_size])

    # Two ImageFolder views of the same folder (same file order), each with its own transform
    train_dataset = Subset(datasets.ImageFolder(root=data_dir, transform=data_transform),
                           train_subset.indices)
    test_dataset = Subset(datasets.ImageFolder(root=data_dir, transform=test_transform),
                          test_subset.indices)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Count training labels from ImageFolder.targets; iterating train_dataset instead
    # would decode and augment every image just to read its label
    class_counts = {}
    for idx in train_subset.indices:
        label = full_dataset.targets[idx]
        class_counts[label] = class_counts.get(label, 0) + 1

    return train_loader, test_loader, train_dataset, test_dataset, class_counts
```

Create the ViT model with timm:

```python
def create_vit_model(num_classes):
    model = timm.create_model(
        model_name=MODEL_NAME,
        pretrained=PRETRAINED,
        num_classes=num_classes,
        drop_rate=DROP_RATE,
        attn_drop_rate=ATTN_DROP_RATE,
        drop_path_rate=DROP_PATH_RATE
    )
    print(f"Number of blocks: {len(model.blocks)}")
    print("Input shape: (batch_size, 3, 224, 224)")
    print(f"Output shape: (batch_size, {num_classes})")
    return model
```
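`DROP_PATH_RATE` turns on stochastic depth (DropPath), which the theory section does not cover. During training, the residual branch of a Block is dropped for randomly chosen samples. The function below is a minimal sketch of the idea, not timm's own implementation. Inside a Block it would wrap each branch, as in `x = x + drop_path(x_attn)`. As far as we know, timm raises the rate linearly from 0 in the first Block to `drop_path_rate` in the last.

```python
import torch

def drop_path(x, drop_prob=0.1, training=True):
    """Stochastic depth sketch: zero the whole residual branch for random samples.

    x: output of a residual branch, shape [B, ...].
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
    # Rescale the kept samples so the expected value is unchanged (like inverted dropout)
    return x * mask / keep_prob

torch.manual_seed(0)
branch = torch.ones(8, 197, 384)
out = drop_path(branch, drop_prob=0.5)
# Each sample is either entirely 0 or entirely 2.0 (= 1 / keep_prob)
print(out[:, 0, 0])
```

Unlike Dropout, which zeroes individual elements, DropPath skips an entire sub-block for a sample. The Block then reduces to its identity shortcut.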

Training function

```python
def train_model(model, criterion, optimizer, scheduler, train_loader, test_loader, num_epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training") as pbar:
            for inputs, labels in pbar:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # L1 regularization: drives some weights to exactly 0 (sparse solution)
                # Loss = task loss + L1_LAMBDA * (sum of |w| over all weights)
                # Under plain SGD: w' = w - lr * (grad + L1_LAMBDA * sign(w))
                # LayerNorm weights are skipped; this relies on timm's 'norm' in their names
                l1_loss = 0.0
                for name, param in model.named_parameters():
                    if 'weight' in name and 'norm' not in name:
                        l1_loss += torch.sum(torch.abs(param))
                loss = loss + L1_LAMBDA * l1_loss

                loss.backward()
                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                # .item() turns a one-element tensor into a plain Python number
                # (the reported training loss includes the L1 term)
                running_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                pbar.set_postfix({"loss": running_loss / total, "acc": 100. * correct / total})

        train_loss = running_loss / len(train_loader.dataset)
        train_acc = 100. * correct / total

        # Validation phase
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            with tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation") as pbar:
                for inputs, labels in pbar:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    running_loss += loss.item() * inputs.size(0)
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
                    pbar.set_postfix({"loss": running_loss / total, "acc": 100. * correct / total})

        val_loss = running_loss / len(test_loader.dataset)
        val_acc = 100. * correct / total

        # Step the LR scheduler once per epoch
        scheduler.step()

        print(f'Epoch {epoch+1}/{num_epochs}: '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        # Keep the checkpoint with the best validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "model_best.pth")
            print(f"Saved best model, validation accuracy: {best_val_acc:.2f}%")

    return model
```
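The L1 comment in `train_model` says the penalty pulls each weight toward 0 by `L1_LAMBDA * sign(w)`. Here is a small autograd check, assuming PyTorch. The `Tiny` module is hypothetical. It copies timm's naming so that the same `'norm'` filter applies.

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    """Hypothetical model with one Linear and one LayerNorm named like timm's."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)
        self.norm = nn.LayerNorm(3)

    def forward(self, x):
        return self.norm(self.fc(x))

torch.manual_seed(0)
model = Tiny()
l1_lambda = 1e-3

# Same filter as train_model: weights only, normalization layers skipped
l1 = sum(p.abs().sum() for n, p in model.named_parameters()
         if 'weight' in n and 'norm' not in n)
(l1_lambda * l1).backward()

# d(lambda * |w|)/dw = lambda * sign(w): a constant-size pull toward 0
print(torch.allclose(model.fc.weight.grad, l1_lambda * model.fc.weight.sign()))
print(model.norm.weight.grad)   # None: the LayerNorm weight was excluded
```

The penalty's gradient does not shrink as a weight approaches 0. That is why L1, unlike L2, can drive weights all the way to zero.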

Main function

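```python
def validate_model(model, test_loader):
    """Evaluate a model on test_loader and print its accuracy."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            correct += model(inputs).argmax(1).eq(labels).sum().item()
            total += labels.size(0)
    acc = 100. * correct / total
    # The tuner script parses this exact marker ("验证集精度:" means "validation accuracy:")
    print(f"验证集精度: {acc:.2f}%")
    return acc


if __name__ == '__main__':
    train_loader, test_loader, train_dataset, test_dataset, class_counts = load_data(DATA_DIR)
    class_names = train_dataset.dataset.classes
    num_classes = len(class_names)

    # Class weights inversely proportional to sample counts, to handle class imbalance
    total_samples = sum(class_counts.values())
    class_weights = [1.0] * num_classes
    for label, count in class_counts.items():
        class_weights[label] = total_samples / (count * num_classes)
    class_weights = torch.tensor(class_weights)

    model = create_vit_model(num_classes)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    # AdamW uses decoupled weight decay: each step also applies w <- w - lr * weight_decay * w,
    # separately from the adaptive update (not the same as adding an L2 term to the loss)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    # LR schedule: linear warmup, then cosine annealing
    warmup_scheduler = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1e-5, end_factor=1.0, total_iters=WARMUP_EPOCHS)
    cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=EPOCHS - WARMUP_EPOCHS)
    scheduler = optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[WARMUP_EPOCHS])

    model = train_model(model, criterion, optimizer, scheduler,
                        train_loader, test_loader, num_epochs=EPOCHS)

    # Reload the best checkpoint and evaluate it
    checkpoint = torch.load("model_best.pth", map_location=device)
    model = create_vit_model(num_classes)
    model.load_state_dict(checkpoint)
    validate_model(model, test_loader)
    print("Training finished!")
```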
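The class-weight formula in the main function is weight_c = N / (K · n_c). Here N is the number of training samples, K the number of classes and n_c the number of samples in class c. Below is a pure-Python sketch with hypothetical counts:

```python
def inverse_frequency_weights(class_counts, num_classes):
    """weight_c = N / (K * n_c); with balanced classes every weight is 1.0."""
    total = sum(class_counts.values())
    weights = [1.0] * num_classes
    for label, count in class_counts.items():
        weights[label] = total / (count * num_classes)
    return weights

# Hypothetical counts: class 0 has 3x as many images as class 1
weights = inverse_frequency_weights({0: 300, 1: 100}, 2)
print(weights)   # [0.6666666666666666, 2.0]
```

The count-weighted mean of these weights is always 1: (300 · 2/3 + 100 · 2) / 400 = 1. The overall loss scale therefore stays comparable to the unweighted loss, while the rare class counts more per sample.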

config.yaml configuration file

```yaml
attn_drop_rate: 0.1
batch_size: 32
data_dir: data
drop_path_rate: 0.1
drop_rate: 0.2
epochs: 20
l1_lambda: 1e-05
learning_rate: 0.0005
model_name: vit_small_patch16_224
pretrained: false
warmup_epochs: 8
```
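With these settings (learning_rate 5e-4, 8 warmup epochs out of 20), the learning rate for each epoch can be written in closed form. The function below is a sketch. As far as we can tell, it matches what `LinearLR(start_factor=1e-5)` followed by `CosineAnnealingLR` (eta_min = 0) produce when `scheduler.step()` runs once per epoch. Small differences between PyTorch versions are possible.

```python
import math

def lr_at_epoch(epoch, base_lr=5e-4, warmup_epochs=8, total_epochs=20, start_factor=1e-5):
    """LR used during `epoch` (0-based): linear warmup, then cosine decay to 0."""
    if epoch < warmup_epochs:
        # LinearLR: the factor moves linearly from start_factor to 1.0
        factor = start_factor + (1.0 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # CosineAnnealingLR restarts its own counter at the milestone
    t = epoch - warmup_epochs
    t_max = total_epochs - warmup_epochs
    return base_lr * 0.5 * (1 + math.cos(math.pi * t / t_max))

for e in range(20):
    print(e, f"{lr_at_epoch(e):.2e}")
```

With these values, 8 of the 20 epochs (40%) are spent warming up. That is a large share for a short run, which may be why the tuner below also tries `warmup_epochs: 5`.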

Automatic tuner

Loops over parameter combinations, runs a training job for each one, and records the best accuracy.

```python
import itertools
import os
import shutil
import subprocess
import sys

import yaml

# Resolve every path relative to this script so it works from any working directory
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
CONFIG_PATH = os.path.join(SCRIPT_DIR, 'config.yaml')

param_ranges = {
    'learning_rate': [1e-4, 2e-4, 5e-4],
    'batch_size': [32, 64],
    'drop_rate': [0.1, 0.2],
    'warmup_epochs': [5, 8]
}

with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    original_text = f.read()
base_config = yaml.safe_load(original_text)

best_accuracy = 0.0
best_config = None

try:
    # 3 x 2 x 2 x 2 = 24 training runs
    for lr, batch, drop, warmup in itertools.product(
            param_ranges['learning_rate'], param_ranges['batch_size'],
            param_ranges['drop_rate'], param_ranges['warmup_epochs']):
        current_config = dict(base_config)
        current_config.update({
            'learning_rate': lr, 'batch_size': batch,
            'drop_rate': drop, 'warmup_epochs': warmup
        })

        # vit.py reads config.yaml, so overwrite it for this run
        with open(CONFIG_PATH, 'w', encoding='utf-8') as f:
            yaml.dump(current_config, f, default_flow_style=False, allow_unicode=True)

        print(f"\n=== Testing: lr={lr}, batch={batch}, drop={drop}, warmup={warmup} ===")

        result = subprocess.run(
            [sys.executable, 'vit.py'],
            capture_output=True, text=True, encoding='utf-8',
            cwd=SCRIPT_DIR,
        )

        # Expects vit.py to print a line like "验证集精度: 85.00%" (validation accuracy)
        accuracy = 0.0
        for line in result.stdout.split('\n'):
            if '验证集精度:' in line:
                try:
                    accuracy = float(line.split(':')[1].strip().replace('%', ''))
                    break
                except (IndexError, ValueError):
                    pass

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_config = dict(current_config)
            model_path = os.path.join(SCRIPT_DIR, 'model_best.pth')
            if os.path.exists(model_path):
                shutil.copy2(model_path, os.path.join(SCRIPT_DIR, 'final_best.pth'))
finally:
    # Restore the original config.yaml exactly as it was
    with open(CONFIG_PATH, 'w', encoding='utf-8') as f:
        f.write(original_text)

if best_config:
    with open(os.path.join(SCRIPT_DIR, 'best_config.yaml'), 'w', encoding='utf-8') as f:
        yaml.dump(best_config, f, default_flow_style=False, allow_unicode=True)
    print(f"\nBest validation accuracy: {best_accuracy}%")
    print("Best model saved as final_best.pth")
```

ViT expanded version: manual implementation

This version does not use timm. Every ViT component is defined by hand, so the internals are easy to follow.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image
import os
import glob
from tqdm import tqdm

BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 2e-4
# The manual modules below hard-code dropout 0.1 and implement no DropPath,
# so these three constants are not wired in
DROP_RATE = 0.1
ATTN_DROP_RATE = 0.1
DROP_PATH_RATE = 0.1

data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')


class CustomImageDataset(Dataset):
    """Reads images from data_dir/CLASS_NAME/*.{jpg,jpeg,png}, like ImageFolder."""
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.samples = []
        # Only sub-directories are classes, so class indices stay contiguous (0..K-1)
        self.classes = sorted(d for d in os.listdir(data_dir)
                              if os.path.isdir(os.path.join(data_dir, d)))
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

        for class_name, class_idx in self.class_to_idx.items():
            class_dir = os.path.join(str(data_dir), class_name)
            for img_path in sorted(glob.glob(os.path.join(class_dir, '*'))):
                if img_path.lower().endswith(IMG_EXTENSIONS):
                    self.samples.append((img_path, class_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, class_idx = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(class_idx)
```

Fixed parameters (exactly matching ViT-small-patch16-224)

```python
img_size = 224
patch_size = 16
embed_dim = 384
num_heads = 6
head_dim = 64
depth = 12
num_classes = 10
num_patches = (img_size // patch_size) ** 2  # 196
seq_len = num_patches + 1  # 197
```

PatchEmbed module

```python
class PatchEmbed(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # [B,3,224,224] -> [B,384,14,14] -> [B,384,196] -> [B,196,384]
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x
```

Transformer Block

```python
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(0.1)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(0.1)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.fc1 = nn.Linear(embed_dim, embed_dim * 4)
        self.fc2 = nn.Linear(embed_dim * 4, embed_dim)
        self.mlp_drop = nn.Dropout(0.1)
        self.gelu = nn.GELU()

    def forward(self, x):
        B, N, C = x.shape  # here C is the embedding dim (384), not the class count

        # --- Attention + residual ---
        norm_x = self.norm1(x)
        qkv = self.qkv(norm_x).reshape(B, N, 3, num_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)         # [3, B, heads, N, head_dim]
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) / 8.0   # scale by 1/sqrt(64)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x_attn = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x_attn = self.proj(x_attn)
        x_attn = self.proj_drop(x_attn)
        x = x + x_attn

        # --- MLP + residual ---
        norm_x = self.norm2(x)
        x_mlp = self.fc1(norm_x)
        x_mlp = self.gelu(x_mlp)
        x_mlp = self.mlp_drop(x_mlp)
        x_mlp = self.fc2(x_mlp)
        x_mlp = self.mlp_drop(x_mlp)
        x = x + x_mlp
        return x
```

Complete ViT model

```python
class VisionTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = PatchEmbed()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        # Small random init (truncated normal, std 0.02), a common choice for ViT,
        # so positions are distinguishable from the first step
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.pos_drop = nn.Dropout(0.1)
        self.blocks = nn.ModuleList([Block() for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]                                # line 1: batch size
        x = self.patch_embed(x)                       # line 2: [B,196,384]
        cls_token = self.cls_token.expand(B, -1, -1)  # line 3: [B,1,384]
        x = torch.cat((cls_token, x), dim=1)          # line 4: [B,197,384]
        x = x + self.pos_embed                        # line 5: add positional embedding
        x = self.pos_drop(x)                          # line 6: Dropout
        for block in self.blocks:                     # line 7: 12 Blocks
            x = block(x)
        x = self.norm(x)                              # line 8: final LayerNorm
        x = x[:, 0]                                   # line 9: CLS token -> [B,384]
        x = self.head(x)                              # line 10: classification -> [B,C]
        return x
```

ResNet18: manual implementation

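ResNet18 has 17 convolutional layers + 1 fully connected layer, 18 weight layers in total. The 1×1 shortcut convolutions are not counted. Its key idea is the residual connection: the output of each BasicBlock is the convolution output plus the block's input (identity). This eases the optimization of deep networks and mitigates the vanishing-gradient problem.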

Data loading

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image
import os
import glob

data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

class CustomImageDataset(Dataset):
    """Reads images from data_dir/CLASS_NAME/*.{jpg,jpeg,png}, like ImageFolder."""
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.samples = []
        # Only sub-directories are classes, so class indices stay contiguous (0..K-1)
        self.classes = sorted(d for d in os.listdir(data_dir)
                              if os.path.isdir(os.path.join(data_dir, d)))
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

        for class_name, class_idx in self.class_to_idx.items():
            class_dir = os.path.join(data_dir, class_name)
            for img_path in sorted(glob.glob(os.path.join(class_dir, '*'))):
                if img_path.lower().endswith(IMG_EXTENSIONS):
                    self.samples.append((img_path, class_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, class_idx = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(class_idx)
```

DataLoader only handles batching, shuffling and multi-process loading (num_workers). The transform runs inside the Dataset's __getitem__.

ResNet18 model structure

```python
class ResNet18(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNet18, self).__init__()

        # Stem: Conv1 + MaxPool
        # Input (224, 224, 3) -> Conv1 (7x7, 64, stride 2) -> (112, 112, 64)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # MaxPool (3x3, stride 2) -> (56, 56, 64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Layer1 (layers 2-5): 56x56x64, 2 BasicBlocks, no downsampling
        self.layer1_conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1_bn1 = nn.BatchNorm2d(64)
        self.layer1_conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1_bn2 = nn.BatchNorm2d(64)
        self.layer1_conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1_bn3 = nn.BatchNorm2d(64)
        self.layer1_conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1_bn4 = nn.BatchNorm2d(64)

        # Layer2 (layers 6-9): 28x28x128; the first BasicBlock downsamples
        self.layer2_conv1 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.layer2_bn1 = nn.BatchNorm2d(128)
        self.layer2_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer2_bn2 = nn.BatchNorm2d(128)
        self.layer2_downsample = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(128)
        )
        self.layer2_conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer2_bn3 = nn.BatchNorm2d(128)
        self.layer2_conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer2_bn4 = nn.BatchNorm2d(128)

        # Layer3 (layers 10-13): 14x14x256
        self.layer3_conv1 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)
        self.layer3_bn1 = nn.BatchNorm2d(256)
        self.layer3_conv2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer3_bn2 = nn.BatchNorm2d(256)
        self.layer3_downsample = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(256)
        )
        self.layer3_conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer3_bn3 = nn.BatchNorm2d(256)
        self.layer3_conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer3_bn4 = nn.BatchNorm2d(256)

        # Layer4 (layers 14-17): 7x7x512
        self.layer4_conv1 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.layer4_bn1 = nn.BatchNorm2d(512)
        self.layer4_conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer4_bn2 = nn.BatchNorm2d(512)
        self.layer4_downsample = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(512)
        )
        self.layer4_conv3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer4_bn3 = nn.BatchNorm2d(512)
        self.layer4_conv4 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer4_bn4 = nn.BatchNorm2d(512)

        # Layer 18: global average pooling + fully connected
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Layer1: two BasicBlocks, both with identity shortcuts
        identity = x
        out = self.relu(self.layer1_bn1(self.layer1_conv1(x)))
        out = self.layer1_bn2(self.layer1_conv2(out))
        out = self.relu(out + identity)

        identity = out
        out = self.relu(self.layer1_bn3(self.layer1_conv3(out)))
        out = self.layer1_bn4(self.layer1_conv4(out))
        out = self.relu(out + identity)

        # Layer2: the first BasicBlock uses a 1x1 stride-2 projection shortcut
        identity = self.layer2_downsample(out)
        out = self.relu(self.layer2_bn1(self.layer2_conv1(out)))
        out = self.layer2_bn2(self.layer2_conv2(out))
        out = self.relu(out + identity)

        identity = out
        out = self.relu(self.layer2_bn3(self.layer2_conv3(out)))
        out = self.layer2_bn4(self.layer2_conv4(out))
        out = self.relu(out + identity)

        # Layer3
        identity = self.layer3_downsample(out)
        out = self.relu(self.layer3_bn1(self.layer3_conv1(out)))
        out = self.layer3_bn2(self.layer3_conv2(out))
        out = self.relu(out + identity)

        identity = out
        out = self.relu(self.layer3_bn3(self.layer3_conv3(out)))
        out = self.layer3_bn4(self.layer3_conv4(out))
        out = self.relu(out + identity)

        # Layer4
        identity = self.layer4_downsample(out)
        out = self.relu(self.layer4_bn1(self.layer4_conv1(out)))
        out = self.layer4_bn2(self.layer4_conv2(out))
        out = self.relu(out + identity)

        identity = out
        out = self.relu(self.layer4_bn3(self.layer4_conv3(out)))
        out = self.layer4_bn4(self.layer4_conv4(out))
        out = self.relu(out + identity)

        # Output head
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out
```
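The spatial sizes in the comments follow from the standard output-size formula floor((H + 2·padding − kernel) / stride) + 1. Here is a pure-Python check, including that the 1×1 stride-2 shortcut lands on the same size as the main path:

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a Conv2d / MaxPool2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

sizes = [224]
sizes.append(conv_out(sizes[-1], 7, 2, 3))   # conv1   -> 112
sizes.append(conv_out(sizes[-1], 3, 2, 1))   # maxpool -> 56
for _ in range(3):                           # first 3x3 conv of layer2/3/4 (stride 2)
    sizes.append(conv_out(sizes[-1], 3, 2, 1))
print(sizes)                                 # [224, 112, 56, 28, 14, 7]

# Shortcut check for layer2: 1x1 conv, stride 2, no padding
print(conv_out(56, 1, 2, 0))                 # 28, same as the main path
```

If the two sizes differed, `out + identity` in `forward` would fail with a shape mismatch.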

Training and saving

```python
def train_model(model, criterion, optimizer, train_loader, test_loader, num_epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_loss = running_loss / len(train_loader.dataset)
        train_acc = 100. * correct / total

        # Test phase
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        test_loss = running_loss / len(test_loader.dataset)
        test_acc = 100. * correct / total

        print(f'Epoch {epoch+1}/{num_epochs}: '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

    return model

if __name__ == '__main__':
    # Assumed dataset root, laid out as data/CLASS_NAME/*.jpg
    dataset = CustomImageDataset('data', transform=data_transform)
    class_names = dataset.classes
    num_classes = len(class_names)

    train_size = int(0.8 * len(dataset))
    train_set, test_set = random_split(dataset, [train_size, len(dataset) - train_size])
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=2)

    # The constructor already sizes fc for num_classes; replacing model.fc is only
    # needed when starting from a model built for another class count
    model = ResNet18(num_classes=num_classes)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    model = train_model(model, criterion, optimizer, train_loader, test_loader, num_epochs=10)

    # Save a full checkpoint (including the optimizer state)
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'class_names': class_names,
    }
    torch.save(checkpoint, 'checkpoint.pth')
```

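Gradient accumulation: when GPU memory is short, lower the per-step batch_size and accumulate gradients over several steps before a single optimizer update. The effective batch size is then batch_size × accumulation steps. Accumulating without shrinking the per-step batch saves no memory and only enlarges the effective batch, which defeats the purpose.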
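Below is a minimal sketch of gradient accumulation, assuming PyTorch and a loss with mean reduction. Each micro-batch loss is divided by the number of accumulation steps. The summed gradient then equals the full-batch gradient, exactly so when the micro-batches have equal size. The toy linear model and data are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()            # mean reduction over the batch
x = torch.randn(32, 8)
y = torch.randint(0, 3, (32,))

# Reference: one full batch of 32
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 8, each loss divided by the number of steps
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = criterion(model(xb), yb) / accum_steps
    loss.backward()                          # gradients add up in .grad

print(torch.allclose(model.weight.grad, full_grad, atol=1e-6))
```

In a real loop, call `optimizer.step()` and `optimizer.zero_grad()` only once every `accum_steps` iterations. With BatchNorm, as in ResNet18, the equivalence is only approximate because batch statistics are computed per micro-batch.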

Model inference: sampling methods

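These are strategies for sampling from the model's final-layer output (the logits, before softmax). They control the diversity and quality of the generated text.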

Temperature sampling

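The temperature τ controls how smooth the distribution is: tokens are sampled from softmax(logits / τ). A larger τ flattens the distribution (more diversity). A smaller τ sharpens it (more deterministic, approaching greedy decoding as τ → 0).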
```python
import torch
import torch.nn.functional as F

def temperature_sampling(logits, temperature=1.0):
    """
    Args:
        logits: model output, shape (batch_size, vocab_size)
        temperature: tau > 0, controls how smooth the distribution is
    Returns:
        sampled token indices, shape (batch_size,)
    """
    logits_scaled = logits / temperature
    probs = F.softmax(logits_scaled, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1).squeeze(1)
    return sampled_indices
```
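The effect of τ is easy to see on a fixed set of logits. Here is a pure-Python sketch; the logits are arbitrary example values.

```python
import math

def softmax_with_temperature(logits, tau):
    scaled = [z / tau for z in logits]
    m = max(scaled)                               # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for tau in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, tau)
    print(tau, [round(p, 3) for p in probs])
```

As τ grows, the top probability shrinks and the tail gains mass. As τ shrinks, the distribution concentrates on the largest logit.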

Top-K sampling

Keep only the k highest-scoring tokens and set all other logits to -inf before the softmax.

```python
def top_k_sampling(logits, k=50):
    k = min(k, logits.size(-1))  # torch.topk fails if k exceeds the vocabulary size
    top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)
    filtered_logits = torch.full_like(logits, float('-inf'))
    filtered_logits.scatter_(-1, top_k_indices, top_k_values)
    probs = F.softmax(filtered_logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1).squeeze(1)
    return sampled_indices
```

Top-P (nucleus) sampling

Sort tokens by probability in descending order and keep the smallest prefix whose cumulative probability reaches the threshold p.

```python
def top_p_sampling(logits, p=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)

    # Keep a token if the probability mass ranked before it is still below p.
    # This keeps the token that crosses p and always keeps the top-1 token.
    mask = (cum_probs - sorted_probs) < p
    sorted_logits = sorted_logits.masked_fill(~mask, float('-inf'))

    # Scatter back to vocabulary order (sorted_indices is a permutation, so every slot is filled)
    restored_logits = torch.zeros_like(logits)
    restored_logits.scatter_(-1, sorted_indices, sorted_logits)
    probs = F.softmax(restored_logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1).squeeze(1)
    return sampled_indices
```
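You can check the nucleus rule directly on a 4-token toy distribution. The helper below is a hypothetical 1-D sketch of the standard rule: keep a token while the probability mass ranked before it is below p.

```python
import torch

def top_p_keep_mask(probs, p):
    """Boolean mask over the vocabulary: True for tokens inside the nucleus."""
    sorted_probs, sorted_idx = probs.sort(descending=True)
    cum = sorted_probs.cumsum(-1)
    keep_sorted = (cum - sorted_probs) < p      # mass strictly before this token < p
    keep = torch.zeros_like(keep_sorted)
    keep[sorted_idx] = keep_sorted              # back to vocabulary order
    return keep

probs = torch.tensor([0.05, 0.5, 0.15, 0.3])
print(top_p_keep_mask(probs, 0.9))   # keeps 0.5, 0.3, 0.15 (mass 0.95 >= 0.9)
print(top_p_keep_mask(probs, 0.3))   # keeps only the top token 0.5

# For comparison: a mask of the form `cum < p` that force-keeps the last sorted
# token would keep 0.5, 0.3 and 0.05 here, dropping 0.15 but keeping the rarest token.
```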

Typical sampling

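Filters by information content. Each token's surprisal −log p is compared with the entropy H of the distribution. Tokens whose surprisal is closest to H are kept first, until their cumulative probability reaches typical_p.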

```python
def typical_sampling(logits, typical_p=0.9):
    probs = F.softmax(logits, dim=-1)
    log_probs = torch.log(probs + 1e-10)
    H = -torch.sum(probs * log_probs, dim=-1, keepdim=True)  # entropy per row
    deviation = torch.abs(-log_probs - H)                     # |surprisal - entropy|

    # Most "typical" tokens first; keep them until typical_p of the mass is covered
    _, sorted_indices = torch.sort(deviation, dim=-1)
    sorted_probs = probs.gather(-1, sorted_indices)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    mask = (cum_probs - sorted_probs) < typical_p            # always keeps at least one token

    # Vectorized filtering: mask in sorted order, then scatter back to vocabulary order
    sorted_logits = logits.gather(-1, sorted_indices).masked_fill(~mask, float('-inf'))
    filtered_logits = torch.full_like(logits, float('-inf')).scatter(-1, sorted_indices, sorted_logits)
    probs = F.softmax(filtered_logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1).squeeze(1)
    return sampled_indices
```

Min-P sampling

Use the highest probability as the reference and discard every token whose probability is below max_prob × min_p_ratio.

```python
def min_p_sampling(logits, min_p_ratio=0.05):
    probs = F.softmax(logits, dim=-1)
    max_prob, _ = torch.max(probs, dim=-1, keepdim=True)
    threshold = max_prob * min_p_ratio
    filtered_logits = logits.masked_fill(probs < threshold, float('-inf'))
    probs = F.softmax(filtered_logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1).squeeze(1)
    return sampled_indices
```

Top-A sampling

Adjust k dynamically from the entropy (a flatter distribution gives a larger k), then apply Top-K sampling.

```python
def top_a_sampling(logits, base_k=4):
    probs = F.softmax(logits, dim=-1)
    H = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)  # entropy per row
    # Higher entropy -> larger k; clamp to [1, vocab_size] so torch.topk cannot fail
    dynamic_k = (base_k + torch.floor(H * 2).long()).clamp(1, logits.size(-1)).tolist()

    sampled_indices = []
    for i in range(logits.shape[0]):
        k = dynamic_k[i]
        top_k_values, top_k_indices = torch.topk(logits[i], k, dim=-1)
        filtered_logits = torch.full_like(logits[i], float('-inf'))
        filtered_logits.scatter_(-1, top_k_indices, top_k_values)
        prob = F.softmax(filtered_logits, dim=-1)
        sampled_indices.append(torch.multinomial(prob, num_samples=1).item())
    return torch.tensor(sampled_indices, device=logits.device)
```
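The entropy-based variant above differs from Top-A as it is usually described. That rule drops every token whose probability is below a · p_max², where p_max is the largest probability. The cutoff is therefore strict when the model is confident and loose when the distribution is flat. The sketch below implements that rule, with a = 0.2 as an assumed default; the helper names are hypothetical.

```python
import torch
import torch.nn.functional as F

def top_a_mask(probs, a=0.2):
    """True for tokens kept by the rule p >= a * p_max**2."""
    p_max = probs.max(dim=-1, keepdim=True).values
    return probs >= a * p_max ** 2

def top_a_threshold_sampling(logits, a=0.2):
    probs = F.softmax(logits, dim=-1)
    filtered = logits.masked_fill(~top_a_mask(probs, a), float('-inf'))
    return torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1).squeeze(1)

confident = torch.tensor([[0.9, 0.06, 0.04]])
flat = torch.tensor([[0.3, 0.25, 0.25, 0.2]])
print(top_a_mask(confident))   # threshold 0.2 * 0.81 = 0.162: only the top token survives
print(top_a_mask(flat))        # threshold 0.2 * 0.09 = 0.018: every token survives
```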

File handling: CSV

Basic reading and writing

```python
import csv

students = [
    ['name', 'age', 'score'],
    ['张三', 18, 95],
    ['李四', 19, 88],
    ['王五', 17, 92],
    ['赵六', 18, 85]
]

# Write CSV (newline='' prevents blank lines between rows on Windows)
with open('students.csv', 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerows(students)

# Read CSV
with open('students.csv', 'r', encoding='utf-8') as file:
    reader = csv.reader(file)
    for row in reader:
        print(row)

# Write rows as dictionaries
fieldnames = ['name', 'age', 'score']
student_dicts = [
    {'name': '钱七', 'age': 19, 'score': 90},
    {'name': '孙八', 'age': 18, 'score': 87},
]
with open('students_dict.csv', 'w', newline='', encoding='utf-8') as file:
    writer = csv.DictWriter(file, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(student_dicts)

# Read rows as dictionaries
with open('students_dict.csv', 'r', encoding='utf-8') as file:
    reader = csv.DictReader(file)
    for row in reader:
        print(row)
```
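`csv.reader` and `csv.DictReader` return every field as a string, including the numbers written above. Here is a quick in-memory check that uses `io.StringIO` in place of a file:

```python
import csv
import io

buf = io.StringIO()
writer = csv.writer(buf)
writer.writerows([['name', 'age', 'score'], ['张三', 18, 95]])

buf.seek(0)
rows = list(csv.DictReader(buf))
print(rows[0])                 # every value comes back as a str
age = int(rows[0]['age'])      # convert explicitly before doing arithmetic
print(age + 1)                 # 19
```

This is why the CSV Dataset below converts its feature columns with `float(...)`.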

Training on CSV data: an example

A custom Dataset loads a CSV file as a PyTorch dataset, which is then used to train a feed-forward network (FFN).

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import csv

class CSVDataDataset(Dataset):
    """Rows: name, feature_1 ... feature_n, label (a class-name string)."""
    def __init__(self, csv_file, transform=None):
        self.data = []
        self.transform = transform
        self.label_map = {}          # class name -> integer id, in order of first appearance
        self.current_label_id = 0

        with open(csv_file, 'r', encoding='utf-8') as file:
            reader = csv.reader(file)
            next(reader)  # skip the header row
            for row in reader:
                name = row[0]
                features = list(map(float, row[1:-1]))
                label_str = row[-1]

                if label_str not in self.label_map:
                    self.label_map[label_str] = self.current_label_id
                    self.current_label_id += 1
                label = self.label_map[label_str]
                self.data.append((name, features, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        name, features, label = self.data[idx]
        features = torch.tensor(features, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        if self.transform:
            features = self.transform(features)
        return features, label

class ThreeLayerFFN(nn.Module):
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super(ThreeLayerFFN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x
```

Example data format:

```csv
name,param1,param2,param3,label
sample1,0.5,0.3,0.2,class_A
sample2,0.1,0.7,0.9,class_B
sample3,0.8,0.4,0.6,class_A
```
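The section stops at the Dataset and the model, so here is a self-contained sketch of the training loop, assuming PyTorch. It writes a hypothetical toy CSV in the format above (class_A when param1 > param2) and parses it the same way `CSVDataDataset` does. To stay self-contained, it swaps in `TensorDataset` and an `nn.Sequential` equivalent of `ThreeLayerFFN(3, 16, 16, 2)` for the classes themselves.

```python
import csv
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data in the format above: class_A when param1 > param2
path = os.path.join(tempfile.mkdtemp(), 'toy.csv')
with open(path, 'w', newline='', encoding='utf-8') as f:
    w = csv.writer(f)
    w.writerow(['name', 'param1', 'param2', 'param3', 'label'])
    for i in range(200):
        p1, p2, p3 = torch.rand(3).tolist()
        w.writerow([f'sample{i}', p1, p2, p3, 'class_A' if p1 > p2 else 'class_B'])

# Parse like CSVDataDataset: middle columns are features, last column is the label
label_map, feats, labels = {}, [], []
with open(path, encoding='utf-8') as f:
    reader = csv.reader(f)
    next(reader)                                   # skip the header row
    for row in reader:
        feats.append([float(v) for v in row[1:-1]])
        labels.append(label_map.setdefault(row[-1], len(label_map)))

dataset = TensorDataset(torch.tensor(feats), torch.tensor(labels))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Same layer pattern as ThreeLayerFFN(input_dim=3, 16, 16, output_dim=2)
model = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU(),
                      nn.Linear(16, len(label_map)))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

def accuracy():
    with torch.no_grad():
        x, y = dataset.tensors
        return (model(x).argmax(1) == y).float().mean().item()

for epoch in range(30):
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

print(f"train accuracy: {accuracy():.2f}")
```

With the real classes, the same loop should work with `DataLoader(CSVDataDataset(csv_path), batch_size=32, shuffle=True)` and `ThreeLayerFFN(num_features, 16, 16, len(dataset.label_map))`. Here `csv_path` and `num_features` are placeholders.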

File handling: JSON

```python
import json

# Read a JSON file
with open('dataset_infos.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# json.dumps encodes a Python object as a formatted JSON string
print(json.dumps(data, indent=2, ensure_ascii=False))

# Access nested fields
print(f"Features of the default config: {list(data['default']['features'].keys())}")
print(f"Train split path: {data['default']['splits']['train']['path']}")
print(f"Test split path: {data['default']['splits']['test']['path']}")
```

Example structure of dataset_infos.json:

```json
{
  "default": {
    "features": {
      "text_1": { "_type": "Value" },
      "image_1": { "_type": "Image" },
      "audio_1": { "_type": "Audio" },
      "video_1": { "_type": "Video" }
    },
    "splits": {
      "train": { "path": "A" },
      "test": { "path": "B" }
    }
  }
}
```

json.dumps(obj, indent=2, ensure_ascii=False) encodes a Python object as a JSON string. ensure_ascii=False writes non-ASCII characters such as Chinese as-is instead of as \uXXXX escapes.
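Here is the difference `ensure_ascii` makes, shown on a small record:

```python
import json

record = {'name': '张三', 'score': 95}
escaped = json.dumps(record)                        # non-ASCII characters become \uXXXX escapes
readable = json.dumps(record, ensure_ascii=False)   # characters are kept as-is
print(escaped)
print(readable)                                     # {"name": "张三", "score": 95}

# Both strings decode back to the same Python object
print(json.loads(escaped) == json.loads(readable) == record)
```

The same flag applies to `json.dump` when writing to a file. Open the file with `encoding='utf-8'` so the characters are written correctly.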