用深度学习(LSTM)实现时间序列预测:从数据到闭环预测全解析
时间序列预测是工业、金融、环境等领域的核心需求——小到预测设备温度波动,大到预测股价走势,都需要从历史数据中挖掘时序规律。长短期记忆网络(LSTM)凭借对“长期依赖关系”的捕捉能力,成为时序预测的主流模型之一。
本文将基于MATLAB深度学习工具箱,以波形数据集(WaveformData) 为例,完整拆解LSTM时间序列预测的实现流程,重点讲解“闭环预测”的核心逻辑(用前一次预测结果作为下一次输入,无需真实值即可多步预测),并对代码逐行、参数逐个进行解析。
一、整体背景:LSTM与两种预测模式
LSTM是一种循环神经网络(RNN),通过“门控机制”(遗忘门、输入门、输出门)动态更新“隐藏状态”,从而记住序列中的关键历史信息,避免普通RNN的“梯度消失”问题。
时序预测有两种核心模式,也是本文的重点对比对象:
- 开环预测:每次预测都需要“真实的历史数据”作为输入(比如预测第t步需要第t-1步的真实值),适合能实时获取真实数据的场景。
- 闭环预测:仅用初始真实数据初始化,后续预测完全依赖“前一次的预测结果”作为输入(无需真实值),适合需要一次性预测多步未来、或无法获取实时真实数据的场景(如预测未来200天的温度)。
本文将从数据加载到闭环预测,一步步实现完整流程。
二、完整实现流程与代码解析
1. 第一步:加载与探索数据
首先加载示例数据集,了解数据结构,为后续处理做准备。
代码与逐行解析
% 加载波形数据集(MATLAB内置示例数据)
load WaveformData% 查看前5个序列的结构(数据是cell数组,每个元素是一个序列)
data(1:5)% 计算序列的通道数(所有序列通道数一致,才能训练网络)
numChannels = size(data{1},1)% 可视化前4个序列(堆叠图展示多通道)
figure
tiledlayout(2,2) % 创建2x2的子图布局
for i = 1:4nexttile % 激活下一个子图stackedplot(data{i}') % 转置序列:让时间步为x轴,通道为y轴xlabel("Time Step") % x轴标签:时间步
end% 划分训练集与测试集(9:1拆分)
numObservations = numel(data); % 总序列数(data是cell数组,numel取元素个数)
idxTrain = 1:floor(0.9*numObservations); % 训练集索引(前90%)
idxTest = floor(0.9*numObservations)+1:numObservations; % 测试集索引(后10%)
dataTrain = data(idxTrain); % 训练集序列
dataTest = data(idxTest); % 测试集序列
关键参数与概念
- WaveformData:MATLAB内置的合成波形数据集,结构为
numObservations×1
的cell数组,每个cell元素是numChannels×numTimeSteps
的矩阵(numChannels=3
,即每个时间步有3个特征;numTimeSteps
为序列长度,不同序列长度不同)。 - stackedplot:堆叠图函数,适合展示多通道时序数据(每个通道一条线,避免重叠)。
- 数据划分逻辑:9:1拆分是时序预测的常用比例,既保证训练集足够大(学习规律),又保留测试集(评估泛化能力)。
2. 第二步:准备训练数据(核心:移位目标序列+归一化)
LSTM训练需要“输入-目标”配对的监督数据。时序预测的核心技巧是:输入为“去掉最后一个时间步的序列”,目标为“移位一个时间步的序列”,让LSTM学习“当前时间步→下一个时间步”的映射关系。
同时,为避免训练发散、提升收敛速度,需要对数据做“零均值单位方差”归一化。
代码与逐行解析
% 1. 构建训练集的“输入-目标”配对(移位序列)
for n = 1:numel(dataTrain) % 遍历每个训练序列X = dataTrain{n}; % 取第n个训练序列(numChannels×numTimeSteps)XTrain{n} = X(:,1:end-1); % 输入:去掉最后一个时间步(无法预测它的下一个值)TTrain{n} = X(:,2:end); % 目标:移位一个时间步(每个输入对应下一个时间步的真实值)
end% 2. 归一化:计算训练集的均值和标准差(所有序列拼接后统计,保证一致性)
muX = mean(cat(2,XTrain{:}),2); % 输入的均值:cat(2,...)按时间步拼接所有序列,mean(...,2)按通道算均值
sigmaX = std(cat(2,XTrain{:}),0,2); % 输入的标准差:0表示除以N-1(无偏估计),2表示按通道算muT = mean(cat(2,TTrain{:}),2); % 目标的均值
sigmaT = std(cat(2,TTrain{:}),0,2); % 目标的标准差% 3. 对输入和目标进行归一化(用训练集的统计量,避免数据泄露)
for n = 1:numel(XTrain)XTrain{n} = (XTrain{n} - muX) ./ sigmaX; % 输入归一化:(原始-均值)/标准差TTrain{n} = (TTrain{n} - muT) ./ sigmaT; % 目标归一化
end
关键逻辑解释
- 移位序列的原因:假设序列为
[t1,t2,t3,t4]
,输入XTrain
为[t1,t2,t3]
,目标TTrain
为[t2,t3,t4]
,让LSTM学习“t1→t2”“t2→t3”“t3→t4”的映射,最终能实现“输入任意序列→预测下一个时间步”。 - 归一化的必要性:若不同通道的数值范围差异大(如通道1是0-1,通道2是100-200),训练时会导致梯度更新失衡,模型难以收敛。用训练集统计量归一化,是为了避免“测试集信息泄露到训练集”(测试集的统计量未知)。
3. 第三步:定义LSTM网络架构
时序预测的LSTM网络需要适配“序列输入→序列输出”的需求,核心层包括:序列输入层、LSTM层、全连接层、回归层。
代码与逐行解析
layers = [sequenceInputLayer(numChannels) % 序列输入层:输入维度=通道数(numChannels=3)lstmLayer(128) % LSTM层:128个隐藏单元(决定学习能力)fullyConnectedLayer(numChannels) % 全连接层:输出维度=通道数(与输入通道一致)regressionLayer]; % 回归层:定义回归任务的损失函数(默认均方误差MSE)
各层参数与作用详解
层名称 | 参数配置 | 作用说明 |
---|---|---|
sequenceInputLayer | numChannels=3 | 接收“通道数×时间步”的序列输入,输入维度必须与数据的通道数一致(否则维度不匹配)。 |
lstmLayer | 128个隐藏单元 | 隐藏单元数量决定LSTM的“记忆容量”:128个单元可捕捉中等复杂度的时序规律;数量越多学习能力越强,但易过拟合。 |
fullyConnectedLayer | numChannels=3 | 将LSTM输出的128维隐藏状态“映射”到3维(与输入通道数一致),确保输出序列的维度与目标序列匹配。 |
regressionLayer | 无参数(默认) | 回归任务的输出层,计算“预测值-真实值”的均方误差(MSE),作为训练的损失函数,指导网络更新权重。 |
4. 第四步:指定训练选项
训练选项决定模型的优化策略,需结合数据规模、网络复杂度调整。
代码与逐行解析
options = trainingOptions("adam", ... % 优化器:Adam(自适应学习率,适合时序数据)MaxEpochs=200, ... % 最大训练轮数:200轮(平衡训练效果与时间)SequencePaddingDirection="left", ...% 序列对齐方式:左侧补零(保护右侧有效信息)Shuffle="every-epoch", ... % 数据打乱:每轮训练前打乱训练集,避免过拟合Plots="training-progress", ... % 可视化:显示训练进度(损失曲线、准确率等)Verbose=0); % 日志输出:0表示不打印详细训练日志(仅看进度图)
关键选项解释
- Adam优化器:比SGD(随机梯度下降)收敛更快,通过自适应学习率调整不同参数的更新步长,适合LSTM这类复杂网络。
- MaxEpochs=200:200轮是针对2000个序列的经验值——轮数太少可能欠拟合(没学会规律),太多则可能过拟合(记住训练集噪声)。
- SequencePaddingDirection=“left”:不同序列长度不同,训练时需补零对齐。左侧补零是为了保护“右侧的近期信息”(时序数据中,右侧时间步更重要),避免右侧补零干扰预测。
5. 第五步:训练LSTM网络
调用trainNetwork
函数,用训练集(XTrain, TTrain)和训练选项(options)训练网络。
代码与解析
% 训练网络:输入(XTrain)、目标(TTrain)、网络架构(layers)、训练选项(options)
net = trainNetwork(XTrain,TTrain,layers,options);
- 输出:训练好的LSTM网络
net
,包含学习到的权重、偏置和网络结构。 - 训练过程:运行时会弹出“训练进度图”,可观察训练损失(Training Loss)的下降趋势——若损失趋于平稳,说明网络收敛。
6. 第六步:测试网络(评估泛化能力)
测试的核心是:用训练好的网络预测测试集,计算误差(RMSE)评估泛化能力。
代码与逐行解析
% 1. 准备测试数据(与训练数据处理逻辑一致:移位+归一化)
for n = 1:size(dataTest,1) % 遍历每个测试序列X = dataTest{n}; % 取第n个测试序列XTest{n} = (X(:,1:end-1) - muX) ./ sigmaX; % 测试输入:移位+用训练集统计量归一化TTest{n} = (X(:,2:end) - muT) ./ sigmaT; % 测试目标:移位+归一化
end% 2. 用测试集预测(指定左侧补零,与训练一致)
YTest = predict(net,XTest,SequencePaddingDirection="left");% 3. 计算每个测试序列的RMSE(均方根误差,评估预测精度)
for i = 1:size(YTest,1)% RMSE = sqrt(平均(预测值-真实值)^2),"all"表示对所有元素计算rmse(i) = sqrt(mean((YTest{i} - TTest{i}).^2,"all"));
end% 4. 可视化RMSE分布(直方图)
figure
histogram(rmse) % 绘制RMSE的频率分布
xlabel("RMSE") % x轴:RMSE值(越小精度越高)
ylabel("Frequency") % y轴:频率(多少个序列的RMSE落在该区间)% 5. 计算所有测试序列的平均RMSE
mean(rmse)
评估逻辑
- RMSE的意义:RMSE越小,预测值与真实值的偏差越小。例如,若平均RMSE=0.1,说明预测值与真实值的平均偏差仅0.1(归一化后的值,反归一化后可还原为原始尺度)。
- 为什么用训练集统计量归一化:测试时无法获取“未来数据的统计量”,用训练集统计量才能模拟真实预测场景(避免数据泄露)。
7. 第七步:预测未来时间步(重点:开环vs闭环)
测试仅验证“单步预测”能力,实际应用中常需“多步预测”(如预测未来200个时间步)。此时需区分开环与闭环两种模式,闭环预测是本文核心。
7.1 先理解:开环预测(依赖真实值)
开环预测的逻辑是:每次预测都需要“前一个时间步的真实值”作为输入,适合能实时获取真实数据的场景(如实时监测设备数据,用真实值预测下一秒)。
% 选择一个测试序列(索引=2)
idx = 2;
X = XTest{idx}; % 测试输入序列
T = TTest{idx}; % 测试目标序列% 1. 初始化网络状态(重置隐藏状态,避免历史数据干扰)
net = resetState(net);
% 2. 用前75个时间步的真实数据更新网络状态(让网络“记住”初始上下文)
offset = 75; % 初始真实数据的时间步长度
[net,~] = predictAndUpdateState(net,X(:,1:offset));% 3. 开环预测:用真实值作为输入,预测剩余时间步
numTimeSteps = size(X,2); % 测试序列总时间步
numPredictionTimeSteps = numTimeSteps - offset; % 需预测的时间步数量
Y_open = zeros(numChannels,numPredictionTimeSteps); % 存储开环预测结果for t = 1:numPredictionTimeStepsXt = X(:,offset+t); % 输入:第offset+t步的真实值(开环的核心:依赖真实值)[net,Y_open(:,t)] = predictAndUpdateState(net,Xt); % 预测+更新网络状态
end% 4. 可视化开环预测结果
figure
t = tiledlayout(numChannels,1); % 按通道堆叠子图
title(t,"Open Loop Forecasting")
for i = 1:numChannelsnexttileplot(T(i,:)) % 真实值(目标序列)hold on% 预测值:从offset步开始,拼接offset步的真实值+预测值plot(offset:numTimeSteps,[T(i,offset) Y_open(i,:)],'--')ylabel("Channel " + i)
end
xlabel("Time Step")
nexttile(1)
legend(["True Value" "Forecasted Value"])
- 开环的局限性:必须获取每个时间步的真实值才能继续预测,无法一次性预测多步未来(如无法直接预测未来200步,需等每一步真实值产生)。
7.2 核心:闭环预测(无需真实值,用前一次预测当输入)
闭环预测的逻辑是:仅用初始真实数据初始化,后续预测完全依赖“前一次的预测结果”作为输入,可一次性预测任意多步未来,适合无法获取实时真实数据的场景(如预测未来一个月的销量)。
代码与逐行解析
% 1. 重置网络状态(关键!清除历史隐藏状态,确保从干净的初始状态开始)
net = resetState(net);% 2. 用测试序列的所有真实数据初始化网络状态(让网络“记住”完整的初始上下文)
offset = size(X,2); % offset=测试序列的总时间步(用全部真实数据初始化)
[net,Z] = predictAndUpdateState(net,X); % Z是初始预测结果(与测试序列长度一致)% 3. 闭环预测:预测未来200个时间步(可自定义数量)
numPredictionTimeSteps = 200; % 需预测的未来时间步数量
Xt = Z(:,end); % 初始输入:最后一个时间步的预测值(闭环的核心:用预测值当输入)
Y_closed = zeros(numChannels,numPredictionTimeSteps); % 存储闭环预测结果% 循环预测:每一步用前一次的预测值作为输入
for t = 1:numPredictionTimeSteps% 预测当前时间步+更新网络状态[net,Y_closed(:,t)] = predictAndUpdateState(net,Xt);% 更新输入:下一次预测用当前的预测值Xt = Y_closed(:,t);
end% 4. 可视化闭环预测结果
numTimeSteps = offset + numPredictionTimeSteps; % 总时间步=初始真实数据+预测数据
figure
t = tiledlayout(numChannels,1);
title(t,"Closed Loop Forecasting") % 标题:闭环预测for i = 1:numChannelsnexttileplot(T(i,1:offset)) % 初始真实数据(前offset步)hold on% 预测数据:从offset步开始,拼接offset步的真实值+未来200步的预测值plot(offset:numTimeSteps,[T(i,offset) Y_closed(i,:)],'--')ylabel("Channel " + i)
endxlabel("Time Step")
nexttile(1)
legend(["Input (True Value)" "Forecasted Value"])
闭环预测的核心细节
-
为什么要
resetState
?
LSTM的隐藏状态会“记忆”历史数据,若不重置,网络会携带上一次预测的残留信息(如之前预测过的其他序列),导致当前预测的初始状态错误,误差被不断放大。resetState
能将隐藏状态清零,确保从“干净的初始状态”开始学习当前序列的上下文。 -
predictAndUpdateState
的作用?
该函数是闭环预测的核心工具,同时完成两个任务:- 基于当前输入(真实值或预测值)计算预测结果;
- 更新网络的隐藏状态(让网络“记住”当前输入的信息,为下一次预测做准备)。
-
循环逻辑的关键?
每次循环中,Xt = Y_closed(:,t)
将“当前预测值”作为“下一次预测的输入”,形成“预测→输入→再预测”的闭环,无需任何真实值即可持续预测多步未来。
三、闭环预测的优缺点与适用场景
特点 | 优点 | 缺点 | 适用场景 |
---|---|---|---|
数据依赖 | 仅需初始真实数据,后续无需真实值 | 误差会累积(前一步预测不准,后一步偏差更大) | 无法获取实时真实数据、需一次性预测多步未来(如预测未来1年的季节性波动) |
灵活性 | 可自定义预测步数(如预测200步、500步) | 精度通常低于开环预测 | 长期趋势预测、资源有限无法实时采集数据的场景 |
计算效率 | 一次性循环完成多步预测,无需等待真实数据 | 需合理初始化网络状态(否则初始误差大) | 批量预测、离线预测任务 |
四、总结
本文通过完整的MATLAB代码,拆解了LSTM时间序列预测的全流程:从数据加载与移位处理、网络架构设计、训练优化,到开环与闭环预测的实现。核心结论如下:
- 数据处理是基础:移位目标序列让LSTM学习“当前→下一个”的映射,归一化避免训练发散,左侧补零保护有效信息。
- 网络架构需适配任务:sequenceInputLayer匹配通道数,lstmLayer隐藏单元数量平衡学习能力与过拟合,regressionLayer适配回归任务。
- 闭环预测是核心亮点:通过
resetState
初始化状态、predictAndUpdateState
预测+更新状态、循环用前一次预测当输入,实现无需真实值的多步预测,适合实际应用中的长期预测需求。
掌握这套流程后,你可以将其迁移到自己的时序数据(如温度、销量、股价),只需调整通道数、隐藏单元数量、预测步数等参数,即可快速实现定制化的时间序列预测。