80 lines
2.4 KiB
Python
80 lines
2.4 KiB
Python
import torch
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import torch.nn as nn
|
|
|
|
import torch.optim as optim
|
|
|
|
# Generate synthetic data: 1000 evenly spaced samples on [0, 24*pi]
t = np.linspace(start=0, stop=24 * np.pi, num=1000)
# Mixed waveform (fundamental plus a half-amplitude third harmonic) plus a linear trend term
data = (
    np.sin(t)
    + 0.5 * np.sin(3 * t)
    + 0.05 * t
)

# Data preprocessing
def create_dataset(data, look_back=30):
    """Turn a series into supervised (window, next value) pairs.

    Window ``i`` is ``data[i:i + look_back]`` and its target is ``data[i + look_back]``.

    Args:
        data: Sequence of numbers (list or numpy array); the first axis is time.
        look_back: Window length; must be >= 1.

    Returns:
        X: float32 tensor of shape ``(N, look_back, 1)`` for 1-D input, where
           ``N = max(len(data) - look_back, 0)``.
        y: float32 tensor of shape ``(N,)`` for 1-D input.

    Raises:
        ValueError: If ``look_back`` is smaller than 1.
    """
    if look_back < 1:
        raise ValueError(f"look_back must be >= 1, got {look_back}")

    series = np.asarray(data, dtype=np.float32)
    n_windows = max(len(series) - look_back, 0)

    # Row i of the index matrix is [i, i+1, ..., i+look_back-1]. Fancy-indexing with it
    # builds every window in one vectorized step. torch warns that building a tensor from
    # a Python list of numpy arrays is extremely slow.
    window_idx = np.arange(n_windows)[:, None] + np.arange(look_back)[None, :]
    windows = series[window_idx]  # integer-array indexing always returns a fresh copy
    # .copy() so the returned targets never alias the caller's array
    # (np.asarray may return `data` itself).
    targets = series[look_back:].copy()

    return torch.from_numpy(windows).unsqueeze(-1), torch.from_numpy(targets)

# Build (window, next value) pairs, then split them chronologically: first 80% train, rest test.
X, y = create_dataset(data)
train_size = int(0.8 * len(X))
# No shuffling: every test target comes after all training targets in time.
train_X, train_y = X[:train_size], y[:train_size]
test_X, test_y = X[train_size:], y[train_size:]

# Model definition
class TimeSeriesModel(nn.Module):
    """One-step-ahead regressor: a recurrent encoder followed by a linear head.

    Args:
        model_type: ``'LSTM'``, ``'RNN'`` or ``'GRU'``. Any other value raises
            ``ValueError``. Previously, anything but ``'LSTM'`` silently became an RNN.
        hidden_size: Width of the recurrent hidden state (default 64).
        num_layers: Number of stacked recurrent layers. ``None`` keeps the per-type
            default: 2 for LSTM, 1 otherwise. With the defaults, the LSTM is deeper
            than the RNN, so the two are not capacity-matched.
        input_size: Number of features per time step (default 1).

    The recurrent layer is built without ``batch_first``. ``forward`` therefore
    expects sequence-first input ``(seq_len, batch, input_size)`` and returns
    ``(batch, 1)``.
    """

    # Supported architectures and their torch classes.
    _RNN_CLASSES = {'LSTM': nn.LSTM, 'RNN': nn.RNN, 'GRU': nn.GRU}
    # Per-type default depth when num_layers is not given; unlisted types use 1.
    _DEFAULT_LAYERS = {'LSTM': 2}

    def __init__(self, model_type, hidden_size=64, num_layers=None, input_size=1):
        super().__init__()
        if model_type not in self._RNN_CLASSES:
            raise ValueError(
                f"unknown model_type {model_type!r}; "
                f"expected one of {sorted(self._RNN_CLASSES)}"
            )
        if num_layers is None:
            num_layers = self._DEFAULT_LAYERS.get(model_type, 1)

        self.model_type = model_type
        self.rnn = self._RNN_CLASSES[model_type](
            input_size, hidden_size, num_layers=num_layers
        )
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Map a sequence-first batch to one prediction per sequence."""
        out, _ = self.rnn(x)
        # out is (seq_len, batch, hidden_size); regress from the last time step only.
        return self.fc(out[-1])

# Training function
# Trains a fresh model, then returns its test-set predictions and test loss
def train_and_predict(model_type, epochs=100, lr=0.001, log_every=20, seed=None):
    """Train a ``TimeSeriesModel`` on the training split and evaluate it on the test split.

    Uses the module-level tensors ``train_X``, ``train_y``, ``test_X`` and ``test_y``.

    Args:
        model_type: Passed through to ``TimeSeriesModel`` (e.g. ``'LSTM'`` or ``'RNN'``).
        epochs: Number of full-batch optimisation steps (default 100).
        lr: Adam learning rate (default 0.001).
        log_every: Print the training loss every this many epochs; 0 disables it.
        seed: If given, passed to ``torch.manual_seed`` before the model is built,
            making the run reproducible.

    Returns:
        ``(predictions, test_mse)``: a numpy array with one prediction per test
        window, and the test MSE as a Python float.
    """
    if seed is not None:
        torch.manual_seed(seed)

    model = TimeSeriesModel(model_type)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The windows are (batch, seq_len, 1), but the recurrent layer uses the default
    # sequence-first layout, so swap the first two axes once, outside the loop.
    train_inputs = train_X.transpose(0, 1)
    test_inputs = test_X.transpose(0, 1)

    # Training loop: full batch, so every epoch is one optimizer step over all windows.
    model.train()
    for epoch in range(epochs):
        output = model(train_inputs)
        # squeeze(-1), not squeeze(): it keeps the batch axis even when the batch has
        # size 1, so predictions always match the target shape.
        loss = criterion(output.squeeze(-1), train_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and epoch % log_every == 0:
            print(f"{model_type} Epoch {epoch} Loss: {loss.item():.4f}")

    # Prediction stage
    model.eval()
    with torch.no_grad():
        test_pred = model(test_inputs).squeeze(-1)
        test_loss = criterion(test_pred, test_y)
    print(f"{model_type} Test MSE: {test_loss.item():.4f}")

    return test_pred.numpy(), test_loss.item()

# Train both architectures on the same split and collect (predictions, test MSE) for each.
# Tuple elements are evaluated left to right, so the LSTM still trains before the RNN.
(lstm_pred, lstm_loss), (rnn_pred, rnn_loss) = (
    train_and_predict('LSTM'),
    train_and_predict('RNN'),
)

# Unified visual comparison: ground truth vs. both models' test-set predictions.
# The x axis is the index within the test split, not the original time index.
plt.figure(figsize=(12, 6))
plt.plot(test_y.numpy(), label='True Values', alpha=0.7)
for preds, label in (
    (lstm_pred, f'LSTM (MSE: {lstm_loss:.4f})'),
    (rnn_pred, f'RNN (MSE: {rnn_loss:.4f})'),
):
    # Dashed lines set the predictions apart from the solid ground-truth curve.
    plt.plot(preds, label=label, linestyle='--')
plt.title('Time Series Prediction Comparison')
plt.xlabel('Time Steps')
plt.ylabel('Value')
plt.legend()
plt.show()