TeamClass_MD/Game2.py
2025-03-19 15:56:43 +08:00

125 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import numpy as np
# 这个实验的目的是比较RNN和GRU在相同任务上的性能，即学习序列中两个随机位置数值之和。
# 数据生成
def generate_add_data(seq_len=30, rng=None):
    """Generate one sample of the two-position addition task.

    A zero tensor of shape (seq_len, 2) is created.  Two distinct random
    positions receive values drawn uniformly from [0, 0.5) in feature
    channel 0; every other entry stays zero.  Feature channel 1 is never
    written (NOTE(review): presumably reserved for position markers as in
    the classic adding problem -- confirm whether it should be filled).

    Args:
        seq_len: Sequence length; must be at least 2 so that two distinct
            positions exist.
        rng: Optional ``numpy.random.Generator`` for reproducible sampling.
            When None, the global NumPy RNG is used, which gives the same
            draw order as before.

    Returns:
        ``(inputs, target)`` where ``inputs`` has shape (1, seq_len, 2) and
        ``target`` has shape (1, 1) and holds the sum of the two stored values.

    Raises:
        ValueError: If ``seq_len < 2``.
    """
    if seq_len < 2:
        raise ValueError(
            f"seq_len must be >= 2 to hold two distinct positions, got {seq_len}"
        )
    # np.random.random() and np.random.rand() both draw from the global
    # RandomState stream, so the default path consumes randomness exactly as before.
    source = np.random if rng is None else rng
    data = torch.zeros(seq_len, 2)  # (seq_len, 2)
    idx1, idx2 = source.choice(seq_len, 2, replace=False)
    data[idx1, 0] = source.random() * 0.5
    data[idx2, 0] = source.random() * 0.5
    # Sum the float32 values actually stored, so the target matches the inputs exactly.
    target = (data[idx1, 0] + data[idx2, 0]).view(1, 1)
    return data.unsqueeze(0), target  # (1, seq_len, 2), (1, 1)
# 模型定义
class AdditionRNN(nn.Module):
    """Vanilla ``nn.RNN`` regressor for the addition task.

    A single-layer ``nn.RNN`` (batch_first) reads the sequence.  A linear head
    maps the hidden output at the last time step to one scalar prediction.

    Args:
        input_size: Features per time step (default 2, matching the samples
            built by ``generate_add_data``).
        hidden_size: Number of RNN hidden units (default 16).
    """

    def __init__(self, input_size=2, hidden_size=16):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to (batch, 1)."""
        out, _ = self.rnn(x)
        # Only the output at the final time step feeds the regression head.
        return self.fc(out[:, -1, :])
class AdditionGRU(nn.Module):
    """GRU regressor for the addition task.

    A single-layer ``nn.GRU`` (batch_first) reads the sequence.  A linear head
    maps the hidden output at the last time step to one scalar prediction.

    Args:
        input_size: Features per time step (default 2, matching the samples
            built by ``generate_add_data``).
        hidden_size: Number of GRU hidden units (default 16).
    """

    def __init__(self, input_size=2, hidden_size=16):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to (batch, 1)."""
        out, _ = self.gru(x)
        # Only the output at the final time step feeds the regression head.
        return self.fc(out[:, -1, :])
# 训练对比
import matplotlib.pyplot as plt
# 训练函数：同时训练RNN和GRU，并记录每一步的损失变化
def _train_step(model, optimizer, criterion, inputs, target):
    """Run one optimisation step of ``model`` on one sample; return the loss value."""
    optimizer.zero_grad()
    prediction = model(inputs)
    loss = criterion(prediction, target)
    loss.backward()
    optimizer.step()
    return loss.item()


def _plot_losses(losses):
    """Plot each per-step loss curve in ``losses`` (label -> list of floats) and show it."""
    plt.figure(figsize=(10, 5))
    # dicts keep insertion order, so curves are drawn RNN first, then GRU.
    for label, curve in losses.items():
        plt.plot(curve, label=label, alpha=0.7)
    plt.xlabel('Training Steps')
    plt.ylabel('MSE Loss')
    plt.title('Training Comparison: RNN vs GRU')
    plt.legend()
    plt.grid(True)
    plt.show()


def train_addition(steps=1000, lr=0.01, seq_len=30, log_every=200, plot=True):
    """Train an AdditionRNN and an AdditionGRU side by side on the same samples.

    At every step a fresh sample is drawn with ``generate_add_data``.  Each
    model then takes one Adam step on that sample with an MSE loss: the RNN
    first, then the GRU.

    Args:
        steps: Number of training steps (one sample per step).
        lr: Adam learning rate for both models.
        seq_len: Sequence length of the generated samples.
        log_every: Print both losses every ``log_every`` steps; 0 disables logging.
        plot: If True, plot both loss curves and call ``plt.show()`` at the end.

    Returns:
        ``(rnn, gru)``: the two trained models.
    """
    rnn = AdditionRNN()
    gru = AdditionGRU()
    criterion = nn.MSELoss()
    optim_rnn = torch.optim.Adam(rnn.parameters(), lr=lr)
    optim_gru = torch.optim.Adam(gru.parameters(), lr=lr)
    # Per-step training loss history for each model.
    losses = {'RNN': [], 'GRU': []}
    for step in range(steps):
        inputs, target = generate_add_data(seq_len=seq_len)
        # Both models see the identical sample, so the comparison is fair.
        loss_rnn = _train_step(rnn, optim_rnn, criterion, inputs, target)
        loss_gru = _train_step(gru, optim_gru, criterion, inputs, target)
        losses['RNN'].append(loss_rnn)
        losses['GRU'].append(loss_gru)
        if log_every and step % log_every == 0:
            print(f"Step {step:03d} | RNN Loss: {loss_rnn:.4f} | GRU Loss: {loss_gru:.4f}")
    if plot:
        _plot_losses(losses)
    return rnn, gru
# Run training: train_addition() trains both models and returns (rnn, gru).
rnn_model, gru_model = train_addition()
def show_test_cases(model, model_name, num_cases=5, seq_len=30):
    """Print random addition cases with the model's predictions.

    Each case is a fresh sample from ``generate_add_data``.  For every case
    the method prints:

    - the sequence length;
    - the non-zero positions of feature channel 0 and their values;
    - the true sum, the prediction and the absolute error.

    The mean absolute error over all cases is printed at the end.

    Args:
        model: Module called on a (1, seq_len, 2) input whose output must hold
            exactly one element (``.item()`` is used on it).
        model_name: Label used in the printed header.
        num_cases: Number of random cases to evaluate; must be >= 1.
        seq_len: Sequence length of the generated cases.

    Returns:
        The mean absolute error over the evaluated cases.

    Raises:
        ValueError: If ``num_cases < 1`` (the average would be undefined).
    """
    if num_cases < 1:
        raise ValueError(f"num_cases must be >= 1, got {num_cases}")
    print(f"\n{model_name} 测试样例:")
    total_error = 0.0
    for case_idx in range(num_cases):
        # Generate a test case.
        inputs, target = generate_add_data(seq_len=seq_len)
        length = inputs.shape[1]
        # The summands are the only non-zero entries of channel 0.  Formatting
        # every position found avoids crashing if a drawn value is exactly 0.
        positions = torch.nonzero(inputs[0, :, 0]).flatten().tolist()
        value_desc = ", ".join(
            f"[{pos:2d}]={inputs[0, pos, 0].item():.4f}" for pos in positions
        )
        # Inference only; no autograd graph is needed.
        with torch.no_grad():
            pred = model(inputs)
        predicted = pred.item()
        expected = target.item()
        abs_error = abs(predicted - expected)
        print(f"案例 {case_idx+1}:")
        print(f"输入序列长度: {length}")
        print(f"数值位置: {value_desc}")
        print(f"真实值: {expected:.4f}")
        print(f"预测值: {predicted:.4f}")
        print(f"绝对误差: {abs_error:.4f}")
        print("-" * 40)
        total_error += abs_error
    mean_error = total_error / num_cases
    print(f"平均绝对误差: {mean_error:.4f}\n")
    return mean_error
# Evaluate each trained model on fresh random test cases.
show_test_cases(model=rnn_model, model_name="RNN")
show_test_cases(model=gru_model, model_name="GRU")