TeamClass_MD/NN_normal.py
2025-03-16 02:02:53 +08:00

111 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# 1. Data preparation, using MNIST handwritten-digit recognition as the example
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps the [0, 1] range produced by ToTensor onto [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Build the training split first, then the test split (same order as before).
train_set, test_set = (
    datasets.MNIST(root='data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=1000)
# 2. Neural network model (demonstrates gradient-control techniques)
class Net(nn.Module):
    """Three-layer fully connected classifier operating on flattened inputs.

    Architecture: in_features -> hidden1 -> hidden2 -> num_classes, with ReLU
    after the two hidden layers. The defaults (784 -> 128 -> 64 -> 10) match
    28x28 single-channel images with 10 classes.

    Args:
        in_features: number of values per sample after flattening.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=784, hidden1=128, hidden2=64, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)
        # He (Kaiming) initialisation suits the ReLU that follows fc1 and fc2;
        # fc3 produces the output logits and keeps PyTorch's default init.
        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='relu')

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        ``x`` may have any shape whose total size is a multiple of
        ``in_features``; each group of ``in_features`` values is one sample.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No activation on the output: CrossEntropyLoss applies log-softmax itself.
        return self.fc3(x)
# 3. Training configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# Gradient-clipping threshold (maximum global L2 norm of all gradients).
# Clipping has to run on every step, after backward() and before
# optimizer.step(); calling clip_grad_norm_ once here, before any gradients
# exist, would do nothing.
MAX_GRAD_NORM = 2.0

# 4. Records used to visualise training progress
train_losses = []
accuracies = []


def train(epoch):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Before every optimizer step the gradients are clipped to a global L2 norm
    of ``MAX_GRAD_NORM``. Every 100 batches the current batch loss is appended
    to ``train_losses``.

    Args:
        epoch: epoch number supplied by the caller; not used by the loop
            itself, kept for call-site compatibility.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        # Clip after gradients are computed and before they are applied.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=MAX_GRAD_NORM)
        optimizer.step()
        # Record a training-loss sample every 100 batches.
        if batch_idx % 100 == 0:
            train_losses.append(loss.item())
# 5. Evaluation function (test loss and accuracy)
def test():
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Appends the test accuracy (in percent) to ``accuracies``.

    Returns:
        The average per-sample loss over the whole test set.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    n_samples = len(test_loader.dataset)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Assumes criterion averages over the batch (CrossEntropyLoss's
            # default reduction='mean'). Re-weighting by batch size gives a
            # true per-sample mean even when the last batch is smaller;
            # summing the batch means would inflate the loss by the number
            # of batches.
            total_loss += criterion(output, target).item() * target.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    accuracy = 100.0 * correct / n_samples
    accuracies.append(accuracy)
    return total_loss / n_samples
# 6. Run a short training demo: three epochs, each followed by evaluation
for epoch in range(1, 4):
    train(epoch)
    loss = test()
    print('Epoch {}: Test Loss={:.4f}, Accuracy={:.2f}%'.format(epoch, loss, accuracies[-1]))
# 7. Visualise the training process
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
# train_losses holds one sample per 100 training batches.
plt.plot(train_losses, label='Training Loss')
plt.title("Loss Curve")
plt.xlabel("Logged step (every 100 batches)")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 2, 2)
# accuracies holds one value per epoch; epochs are numbered from 1.
plt.plot(range(1, len(accuracies) + 1), accuracies, label='Accuracy', color='orange', marker='o')
plt.title("Accuracy Curve")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.legend()

plt.tight_layout()
plt.show()
# 8. Show example predictions on test images
sample_data, sample_label = next(iter(test_loader))
# Put the model in inference mode explicitly rather than relying on the mode
# left behind by the last evaluation call.
model.eval()
with torch.no_grad():
    prediction = model(sample_data.to(device)).argmax(dim=1)
# Plot the first six images with their true and predicted labels.
plt.figure(figsize=(10, 6))
for i in range(6):
    plt.subplot(2, 3, i + 1)
    # sample_data is not reassigned by .to(device), so it stays a CPU tensor
    # that imshow can display; index 0 selects the (single) image channel.
    plt.imshow(sample_data[i][0], cmap='gray')
    plt.title(f"True: {sample_label[i].item()}\nPred: {prediction[i].item()}")
    plt.axis('off')
plt.tight_layout()
plt.show()