# TeamClass_MD/Game1.py
# 2025-03-19 15:56:43 +08:00
#
# 80 lines
# 2.4 KiB
# Python

import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
# Generate synthetic data: 1000 samples on [0, 24*pi].
t = np.linspace(0, 24 * np.pi, 1000)
fundamental = np.sin(t)
third_harmonic = 0.5 * np.sin(3 * t)
linear_trend = 0.05 * t
# Mixed waveform + trend term; summed left to right, same order as a single expression.
data = fundamental + third_harmonic + linear_trend
# Data preprocessing
def create_dataset(data, look_back=30):
    """Slice a series into sliding windows and their next-step targets.

    Args:
        data: Sequence of numeric values indexed along time (e.g. a 1-D
            NumPy array or a list of floats).
        look_back: Number of consecutive samples in each input window.

    Returns:
        A ``(X, y)`` pair of float32 tensors. For 1-D input, ``X`` has shape
        ``(n, look_back, 1)`` and ``y`` has shape ``(n,)`` with
        ``n = max(len(data) - look_back, 0)``. ``y[i]`` is the sample that
        immediately follows window ``X[i]``. When the series is too short to
        form a single window, both tensors are empty but keep those shapes.
    """
    # Convert once up front instead of handing torch a Python list of arrays,
    # which is slow and triggers a PyTorch warning.
    series = np.asarray(data, dtype=np.float32)
    n_samples = max(len(series) - look_back, 0)

    # Row i of the index grid is [i, i + 1, ..., i + look_back - 1]; indexing
    # with it gathers every window in one vectorised step and yields a
    # correctly shaped (0, look_back) result when there are no windows.
    window_idx = np.arange(n_samples)[:, None] + np.arange(look_back)
    windows = series[window_idx]

    # Copy so the returned tensor never shares memory with the caller's array.
    targets = series[look_back:].copy()

    return torch.from_numpy(windows).unsqueeze(-1), torch.from_numpy(targets)
# Window the full series, then split it in order: the first 80% of the
# samples are used for training and the remainder for testing.
X, y = create_dataset(data)
train_size = int(0.8 * len(X))
train_X = X[:train_size]
train_y = y[:train_size]
test_X = X[train_size:]
test_y = y[train_size:]
# Model definition
class TimeSeriesModel(nn.Module):
    """Recurrent layer followed by a linear head producing one value per sequence.

    ``model_type == 'LSTM'`` selects a 2-layer ``nn.LSTM``; any other value
    selects a single-layer ``nn.RNN``. Both take 1 input feature and use 64
    hidden units, and the head maps those 64 units to a single output.
    """

    def __init__(self, model_type):
        super().__init__()
        self.model_type = model_type
        # The recurrent layer is created before the linear head so that
        # parameter initialisation consumes the RNG in the same order.
        self.rnn = (
            nn.LSTM(1, 64, num_layers=2)
            if model_type == 'LSTM'
            else nn.RNN(1, 64)
        )
        self.fc = nn.Linear(64, 1)

    def forward(self, x):
        """Predict one value per sequence from the last time step's output.

        Args:
            x: 3-D tensor with the time axis first, i.e.
                ``(seq_len, batch, 1)`` (the recurrent layers are built with
                their default, non-batch-first layout).

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        hidden_seq, _state = self.rnn(x)
        last_step = hidden_seq[-1, :, :]
        return self.fc(last_step)
# Training function: fits one model, then returns its test predictions and
# test loss.
def train_and_predict(model_type, epochs=100, lr=0.001, log_every=20):
    """Train a ``TimeSeriesModel`` on the training split and evaluate it.

    Reads the module-level ``train_X``/``train_y`` and ``test_X``/``test_y``
    tensors. Training is full-batch: every epoch is one optimiser step over
    the whole training split.

    Args:
        model_type: Forwarded to ``TimeSeriesModel`` (``'LSTM'`` or ``'RNN'``);
            also used as the prefix of the printed progress lines.
        epochs: Number of optimisation steps to run.
        lr: Learning rate for the Adam optimiser.
        log_every: Print the training loss every ``log_every`` epochs
            (must be a positive integer).

    Returns:
        ``(predictions, test_mse)`` where ``predictions`` is a NumPy array
        with one value per test sample and ``test_mse`` is a Python float.
    """
    model = TimeSeriesModel(model_type)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Swap the first two axes so the time axis comes first; the model's
    # forward pass picks the last time step along axis 0.
    train_inputs = train_X.transpose(0, 1)
    test_inputs = test_X.transpose(0, 1)

    # Training loop
    model.train()
    for epoch in range(epochs):
        # squeeze(-1) removes only the trailing feature axis. A bare
        # squeeze() would also collapse the batch axis when the split holds a
        # single sample, leaving a 0-d prediction against a (1,) target.
        output = model(train_inputs).squeeze(-1)
        loss = criterion(output, train_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % log_every == 0:
            print(f"{model_type} Epoch {epoch} Loss: {loss.item():.4f}")

    # Prediction phase: no gradients are needed for evaluation.
    model.eval()
    with torch.no_grad():
        test_pred = model(test_inputs).squeeze(-1)
        test_loss = criterion(test_pred, test_y)
    print(f"{model_type} Test MSE: {test_loss.item():.4f}")
    return test_pred.numpy(), test_loss.item()
# Train both model types one after the other and collect their results.
lstm_pred, lstm_loss = train_and_predict('LSTM')
rnn_pred, rnn_loss = train_and_predict('RNN')

# Plot the ground truth and both predictions on a single set of axes.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(test_y.numpy(), label='True Values', alpha=0.7)
ax.plot(lstm_pred, label=f'LSTM (MSE: {lstm_loss:.4f})', linestyle='--')
ax.plot(rnn_pred, label=f'RNN (MSE: {rnn_loss:.4f})', linestyle='--')
ax.set_title('Time Series Prediction Comparison')
ax.set_xlabel('Time Steps')
ax.set_ylabel('Value')
ax.legend()
plt.show()