import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# The goal of this experiment is to compare an RNN and a GRU on the same task:
# learning to predict the sum of the values placed at two random positions in a sequence.
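
# Task format (added explanatory note): each example is a (1, seq_len, 2) input tensor
# that is zero everywhere except at two random time steps in channel 0, which hold
# values drawn uniformly from [0, 0.5); the (1, 1) target is the sum of those two values.
# Channel 1 is unused here (the classic "adding problem" formulation uses it as a
# marker channel for the two positions).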

# Data generation
def generate_add_data(seq_len=30):
    data = torch.zeros(seq_len, 2)                                # (seq_len, 2)
    idx1, idx2 = np.random.choice(seq_len, 2, replace=False)
    val1, val2 = np.random.rand() * 0.5, np.random.rand() * 0.5
    data[idx1, 0] = val1
    data[idx2, 0] = val2
    # float32 target so its dtype matches the models' float32 predictions in MSELoss.
    target = torch.tensor([[val1 + val2]], dtype=torch.float32)   # (1, 1)
    return data.unsqueeze(0), target                              # (1, seq_len, 2), (1, 1)
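
# Quick sanity check of the generator's output shapes (added sketch; not part of the
# original script and safe to delete).
_chk_inputs, _chk_target = generate_add_data(seq_len=30)
assert _chk_inputs.shape == (1, 30, 2) and _chk_target.shape == (1, 1)
del _chk_inputs, _chk_target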


# Model definitions
class AdditionRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(2, 16, batch_first=True)   # single-layer RNN, hidden size 16
        self.fc = nn.Linear(16, 1)

    def forward(self, x):
        out, _ = self.rnn(x)
        return self.fc(out[:, -1, :])                # regress from the last time step


class AdditionGRU(nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(2, 16, batch_first=True)   # single-layer GRU, hidden size 16
        self.fc = nn.Linear(16, 1)

    def forward(self, x):
        out, _ = self.gru(x)
        return self.fc(out[:, -1, :])
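
# Note (added): both models use the same hidden size, so the comparison is like-for-like
# in state dimension; the GRU still has roughly three times as many recurrent parameters,
# since its reset gate, update gate, and candidate state each carry their own weights.
# Optional check:
#   sum(p.numel() for p in AdditionRNN().parameters())
#   sum(p.numel() for p in AdditionGRU().parameters())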


# Training comparison
# Training function, extended to record the loss history
def train_addition():
    rnn = AdditionRNN()
    gru = AdditionGRU()
    criterion = nn.MSELoss()
    optim_rnn = torch.optim.Adam(rnn.parameters(), lr=0.01)
    optim_gru = torch.optim.Adam(gru.parameters(), lr=0.01)

    # Record the training losses
    losses = {'RNN': [], 'GRU': []}

    for step in range(1000):
        # Both models see the same freshly sampled sequence at each step.
        inputs, target = generate_add_data(seq_len=30)

        # RNN training step
        optim_rnn.zero_grad()
        rnn_pred = rnn(inputs)
        loss_rnn = criterion(rnn_pred, target)
        loss_rnn.backward()
        optim_rnn.step()

        # GRU training step
        optim_gru.zero_grad()
        gru_pred = gru(inputs)
        loss_gru = criterion(gru_pred, target)
        loss_gru.backward()
        optim_gru.step()

        # Record losses
        losses['RNN'].append(loss_rnn.item())
        losses['GRU'].append(loss_gru.item())

        if step % 200 == 0:
            print(f"Step {step:03d} | RNN Loss: {loss_rnn.item():.4f} | GRU Loss: {loss_gru.item():.4f}")

    # Plot the loss curves
    plt.figure(figsize=(10, 5))
    plt.plot(losses['RNN'], label='RNN', alpha=0.7)
    plt.plot(losses['GRU'], label='GRU', alpha=0.7)
    plt.xlabel('Training Steps')
    plt.ylabel('MSE Loss')
    plt.title('Training Comparison: RNN vs GRU')
    plt.legend()
    plt.grid(True)
    plt.show()

    return rnn, gru
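
# Note (added): because each step trains on a single randomly generated sequence, the
# raw loss curves are noisy; a moving average makes the RNN/GRU gap easier to read.
# Inside train_addition, before plotting, one could smooth with e.g. a 20-step window:
#   smoothed = np.convolve(losses['RNN'], np.ones(20) / 20, mode='valid')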


# Run training
rnn_model, gru_model = train_addition()


def show_test_cases(model, model_name, num_cases=5):
    print(f"\n{model_name} test cases:")
    criterion = nn.MSELoss()
    total_error = 0

    for case_idx in range(num_cases):
        # Generate test data
        inputs, target = generate_add_data()
        seq_len = inputs.shape[1]

        # Parse the input: recover the two positions and their values from channel 0
        non_zero_indices = torch.nonzero(inputs[0, :, 0])
        pos1, pos2 = non_zero_indices[0].item(), non_zero_indices[1].item()
        val1 = inputs[0, pos1, 0].item()
        val2 = inputs[0, pos2, 0].item()

        # Model prediction
        with torch.no_grad():
            pred = model(inputs)
            loss = criterion(pred, target)

        # Formatted output
        print(f"Case {case_idx + 1}:")
        print(f"Input sequence length: {seq_len}")
        print(f"Value positions: [{pos1:2d}]={val1:.4f}, [{pos2:2d}]={val2:.4f}")
        print(f"Ground truth: {target.item():.4f}")
        print(f"Prediction: {pred.item():.4f}")
        print(f"Absolute error: {abs(pred.item() - target.item()):.4f}")
        print("-" * 40)
        total_error += abs(pred.item() - target.item())

    print(f"Mean absolute error: {total_error / num_cases:.4f}\n")


show_test_cases(rnn_model, "RNN")
show_test_cases(gru_model, "GRU")
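

# Optional sketch (added; an assumption about how one might extend the comparison, not
# part of the original experiment): average the absolute error over many fresh sequences
# so the RNN/GRU comparison is not dominated by the luck of five test cases. The helper
# name eval_mean_abs_error is hypothetical.
def eval_mean_abs_error(model, seq_len=30, num_cases=200):
    """Mean absolute error over freshly generated sequences of length seq_len."""
    total = 0.0
    with torch.no_grad():
        for _ in range(num_cases):
            inputs, target = generate_add_data(seq_len=seq_len)
            total += abs(model(inputs).item() - target.item())
    return total / num_cases

# Example usage (uncomment to run):
# print(f"RNN MAE: {eval_mean_abs_error(rnn_model):.4f} | GRU MAE: {eval_mean_abs_error(gru_model):.4f}")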