DeepSeek-TS,基于状态空间增强MLA与GRPO的时序预测新框架

B站影视 2025-02-08 10:02 2

摘要:本文介绍 DeepSeek-TS,该框架受到 DeepSeek 中高效的多头潜在注意力(MLA)和群组相对策略优化(GRPO)技术的启发,并将其应用于多产品时间序列预测。

本文介绍 DeepSeek-TS,该框架受到 DeepSeek 中高效的多头潜在注意力(MLA)和群组相对策略优化(GRPO)技术的启发,并将其应用于多产品时间序列预测。

这个的方法扩展了 MLA,提出了 MLA-Mamba。MLA-Mamba 允许潜在特征通过具有非线性激活的状态空间模型动态演变,为模型提供自适应记忆,使其能够适应趋势变化。

同时通过GRPO 引入了一种智能决策过程,将预测与基准进行比较来持续改进预测。这种动态调整有助于模型有效响应销售模式的突变。

本文将 DeepSeek-TS 框架与经典的 ARMA 模型和标准的基于 GRU 的网络进行了比较。结果表明,DeepSeek-TS 能够建模复杂的产品间关系并适应非线性动态,从而产生更准确和稳健的预测。

在接下来的章节中,我们将详细介绍 MLA 的扩展(MLA-Mamba)和 GRPO 框架的技术细节,并展示它们的协同作用如何增强多产品时间序列预测。

在 DeepSeek 的 MLA 的核心思想是将 key 和 value 压缩到一个低维的潜在空间,从而减少模型在推理过程中需要存储的 KV 缓存大小。这个过程可以分解为以下几个关键步骤:

考虑一个维度为 d 的输入 token向量 h_t。在标准 Transformer 中,该向量通过学习矩阵被映射到查询(Q)、键(K)和值(V)空间。而在 DeepSeek 的 MLA 中,我们首先将 h_t 压缩成一个专门用于 key 和 value 的低维潜在向量。这可以通过以下公式来表示:

其中:

c_{KV, t} 是压缩后的潜在向量。W_{DKV} 是下投影矩阵。d_c 是压缩维度,满足 d_c ≪ d。

这种低秩近似类似于推荐系统中使用的矩阵分解技术,通过两个较小矩阵的乘积来近似一个大矩阵,从而捕获最显著的特征。W_{DKV} 学习捕获计算注意力所需的 h_t 的最关键方面。

获得 c_{KV,t} 后,我们需要通过上投影,使用单独的矩阵重构用于注意力机制的key和value向量:

这里:

k_{C, t} 是重构的 key 向量。v_{C, t} 是重构的 value 向量。W{UK} 和 W{UV} 是上投影矩阵,维度为 R^{d_hn_h×d_c} (其中 d_h 是每个头的维度,n_h 是头的数量)。

关键在于,在推理过程中,我们不需要为每个 token 缓存完整的 key 和 value 向量(这需要为每个 token 存储 d_hn_h 个元素),而只需要缓存压缩的潜在向量 c_{KV, t},它只包含每个 token 的 d_c 个元素。当 d_c 远小于 d_hn_h 时,这种减少是非常显著的。

除了 key 和 value 之外,DeepSeek 还可以对查询应用类似的低秩压缩,以减少训练过程中的激活内存。过程类似:

其中:

c_{Q, t} 是压缩的查询向量。W{DQ} 和 W{UQ} 是查询的下投影和上投影矩阵。d'_c 是查询压缩维度。

虽然压缩查询不会减少 KV 缓存,但它有助于减少训练过程中的整体激活内存。

假设你有一个 Transformer 层,其中每个 token 的原始 key-value 维度是 1024(假设 d_hn_h = 1024)。使用 MLA,如果你选择压缩维度 d_c = 128,那么你就将每个 token 缓存的数据量从 1024 个元素减少到 128 个元素,减少了 8 倍。在处理长序列或大规模部署模型时,这是非常显著的。

此外,在推理过程中,如果上投影矩阵 W{UK} 和 W{UV} 可以被吸收到其他权重矩阵中(如 W_Q 或 W_O),那么你可能根本不需要显式计算或存储 key 和 value,这样可以带来更大的效率提升。

虽然 MLA 专注于高效的注意力机制,但 DeepSeek 还引入了一种新的优化方法,称为群组相对策略优化(GRPO),用于更新模型的决策策略。该方法基于强化学习原理,旨在平衡探索和利用,同时确保策略更新的稳定性。

策略优化的基础

在强化学习中,策略 π(a∣s;θ) 定义了在参数 θ 下,给定状态 s 时采取动作 a 的概率。目标是最大化预期累积奖励:

GRPO 引入了将新策略与先前(且固定)版本策略的输出群组进行对比评估的思想。关键在于,通过将新策略的输出与旧策略的输出进行比较,可以更稳健地衡量某些动作的优势。

令:

π_old 为用于为同一查询 q 生成输出群组的策略。π_new 为当前可更新的策略。

比率:

衡量新策略相对于旧策略的偏差。

为了稳定训练,这个比率通常被限制在 1−ϵ 和 1+ϵ 之间,其中 ϵ 是一个小的超参数。这种限制确保策略在单次更新中不会发生过于剧烈的变化。

优势函数 A_t 量化了动作 a_t 相对于平均表现的优势程度:

其中 V(s_t) 是代表从状态 s_t 的预期奖励的基准值函数。在 GRPO 中,优势用于为更新加权,确保导致高于平均奖励的动作得到强化。

策略梯度更新由以下公式给出:

这个更新规则表明,应该在增加具有正优势动作概率的方向上调整 θ,同时降低具有负优势动作的概率。

GRPO 建立在之前的方法(如近端策略优化(PPO)和直接策略优化(DPO))的基础上,引入了一种新的机制,即将新策略的输出与旧的固定策略的输出群组进行比较。这种“群组相对”比较允许更稳定和可靠的更新。策略更新不仅考虑新策略的输出,还考虑它们相对于旧策略提供的一致基准的表现。

关键方程(概念):

其中:

clip(r_t, 1−ϵ, 1+ϵ) 限制比率。

目标是最大化 L(θ)。

这个方程确保如果新策略的偏差太大,更新会被限制,防止可能破坏训练的过度变化。

假设我们的 RL 代理负责为给定的查询选择最佳模型或模型组合。这个查询来自我们 MLA 过程生成的潜在表示 z_t。假设旧策略 πold 为动作(如选择 XGBoost、LightGBM、DNN 或混合模型)分配概率分布,并对特定查询的“选择 DNN”赋予 0.25 的概率。同时,新策略 πnew 可能为“选择 DNN”分配 0.35 的概率。我们然后计算这些概率的比率。如果 ϵ 设为 0.2,我们将比率 r_t 限制在区间 [0.8, 1.2] 以确保学习稳定性。

如果 ϵ 设为 0.2,那么我们将 r_t 限制在区间 [0.8, 1.2]。假设“选择 DNN”的计算优势 A_t 为 +0.5(表明在这种情况下选择 DNN 是有益的)。

策略梯度更新将使用限制后的比率:

这种受控更新有助于确保策略只会根据该动作相对于群组的表现情况逐渐向新的动作概率转移。

在本实验中,我们的目标是预测每种产品未来 5 天的平均销售量。我们使用 AR(1) 过程结合特定产品的噪声和偏移生成了一个包含 600 天数据的合成数据集。目标被定义为预测范围内的平均销售额。

数据经过标准化处理,然后按时间顺序分割(前 80% 用于训练,剩余 20% 用于验证),以确保进行时间外评估而不存在任何泄露。

比较了四种预测方法:

GRPO 启发模型: 包含一个 GRU 编码器和额外的策略分支。扩展 MLA (Mamba 风格) 的 GRPO 启发预测模型: 使用具有 ReLU 激活的非线性状态空间方法更新潜在状态。简单 GRU 模型: 不含 GRPO 修改。经典 ARMA 方法: 对原始数据使用 ARIMA(1,0,1) 模型以滚动方式应用经典的 ARMA 方法。import numpy as np import pandas as pdimport torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch.utils.data import Dataset, DataLoaderfrom dateTime import datetime, timedeltaimport matplotlib.pyplot as pltfrom sklearn.preprocessing import StandardScalerfrom statsmodels.tsa.arima.model import ARIMA# Load datadf = pd.read_csv("sales_data.csv")# Optional: visualize the generated sales datadf.plot(x="date", y=["product1_sales", "product2_sales", "product3_sales", "product4_sales", "product5_sales"], figsize=(12, 6))plt.title("Simulated Sales Data for 5 Products")plt.show# # Data Preparation for Time Series Forecasting with Normalization# class Salesdataset(Dataset):def __init__(self, df, input_window=30, forecast_horizon=5):"""df: DataFrame with columns: date, product1_sales, ..., product5_salesinput_window: number of days used as inputforecast_horizon: number of days to forecast; target = avg sales over these days per product"""self.input_window = input_windowself.forecast_horizon = forecast_horizondf['date'] = pd.to_datetime(df['date'])df.sort_values('date', inplace=True)self.df = df.reset_index(drop=True)# Use only the sales columns (all except 'date')data = df.drop(columns=['date']).values.astype(np.float32)# Compute normalization parameters on the entire datasetself.mean = data.mean(axis=0)self.std = data.std(axis=0) + 1e-6 # avoid division by zero# Normalize the dataself.data = (data - self.mean) / self.stdself.n_samples = len(self.data) - input_window - forecast_horizon + 1def __len__(self):return self.n_samplesdef __getitem__(self, idx):# Input: sales for input_window days (normalized)x = self.data[idx: idx + self.input_window]# Target: average sales over the next forecast_horizon days (normalized)y = np.mean(self.data[idx + self.input_window: idx + self.input_window + self.forecast_horizon], axis=0)return x, ydef prepare_dataloaders(df, input_window=30, forecast_horizon=5, batch_size=32, train_ratio=0.8):dataset = SalesDataset(df, input_window, forecast_horizon)n_total = len(dataset)n_train = int(n_total * train_ratio)# Chronological (out-of-time) split: first n_train samples for training, remaining for validation.train_dataset = torch.utils.data.Subset(dataset, list(range(n_train)))val_dataset = torch.utils.data.Subset(dataset, list(range(n_train, n_total)))train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)return train_loader, val_loadertrain_loader, val_loader = prepare_dataloaders(df, input_window=30, forecast_horizon=5, batch_size=32, train_ratio=0.8)# # Model Components: Simple GRU-based Forecasting Model with GRPO-inspired Framework# class ForecastingGRPOModel(nn.Module):def __init__(self, input_dim, hidden_dim, num_products, forecast_horizon, dropout=0.2, lambda_policy=0.1):"""This model forecasts the average sales of the next 'forecast_horizon' days for each productusing a GRU encoder. It includes a policy branch to compute a GRPO-inspired loss."""super(ForecastingGRPOModel, self).__init__self.gru = nn.GRU(input_dim, hidden_dim, num_layers=2, batch_first=True, dropout=dropout)self.fc_forecast = nn.Linear(hidden_dim, num_products)self.policy_net = nn.Linear(hidden_dim, 1)self.lambda_policy = lambda_policydef forward(self, x):# x: (batch, seq_len, input_dim)gru_out, _ = self.gru(x)last_hidden = gru_out[:, -1, :] # (batch, hidden_dim)forecast = self.fc_forecast(last_hidden) # (batch, num_products)policy_value = self.policy_net(last_hidden) # (batch, 1)return forecast, policy_value# # Training and Validation Functions (Using MAPE as Error Metric)# def train_model_full(model, dataloader, optimizer, device, grad_clip=1.0):model.traintotal_loss = 0.0for x, y in dataloader:x = x.to(device) # (batch, seq_len, input_dim)y = y.to(device) # (batch, num_products)optimizer.zero_gradforecast, policy_value = model(x)# Forecast loss: use a combination of MSE and MAEloss_mse = F.mse_loss(forecast, y)loss_mae = F.l1_loss(forecast, y)loss_forecast = 0.5 * loss_mse + 0.5 * loss_mae# GRPO-inspired policy loss:# Compute advantage as the mean error over products for each sample.advantage = (y - forecast).mean(dim=1, keepdim=True)baseline = 0.5 # chosen constant baseliner_t = policy_value / baselineepsilon = 0.1r_t_clipped = torch.clamp(r_t, 1 - epsilon, 1 + epsilon)policy_loss = -torch.min(r_t * advantage, r_t_clipped * advantage).meanloss = loss_forecast + model.lambda_policy * policy_lossloss.backwardnn.utils.clip_grad_norm_(model.parameters, grad_clip)optimizer.steptotal_loss += loss.item * x.size(0)return total_loss / len(dataloader.dataset)def validate_model(model, dataloader, device, dataset_obj, debug=False):model.evalall_preds = all_targets = with torch.no_grad:for x, y in dataloader:x = x.to(device)y = y.to(device)forecast, _ = model(x)all_preds.append(forecast.cpu.numpy)all_targets.append(y.cpu.numpy)all_preds = np.concatenate(all_preds, axis=0)all_targets = np.concatenate(all_targets, axis=0)# Invert normalization: forecast_orig = forecast * std + meanmean = dataset_obj.meanstd = dataset_obj.stdall_preds_orig = all_preds * std + meanall_targets_orig = all_targets * std + meanif debug:print("Prediction range:", np.min(all_preds_orig), np.max(all_preds_orig))print("Target range:", np.min(all_targets_orig), np.max(all_targets_orig))mape = np.mean(np.abs((all_targets_orig - all_preds_orig) / (all_targets_orig + 1e-6))) * 100return mape# # ARMA Forecasting for Comparison# def arma_forecast(series, forecast_horizon):"""Fits an ARIMA(1,0,1) model on the provided series and forecasts forecast_horizon steps ahead.Returns the average forecast."""try:arma_model = ARIMA(series, order=(1, 0, 1))arma_fit = arma_model.fit(disp=0)forecast = arma_fit.forecast(steps=forecast_horizon)return np.mean(forecast)except Exception as e:return series[-1]def evaluate_arma(df, input_window=30, forecast_horizon=5, train_ratio=0.8):"""For each product (column), use a rolling ARMA forecast over the validation period on the raw data.Returns a dictionary of MAPE values per product and the overall average MAPE."""n = len(df)train_end = int(n * train_ratio)products = [col for col in df.columns if col != "date"]mape_dict = {}all_mapes = # Rolling forecast: start from i = (train_end - input_window) to (n - input_window - forecast_horizon + 1)for prod in products:preds = actuals = for i in range(train_end - input_window, n - input_window - forecast_horizon + 1):series = df[prod].values[:i + input_window]pred = arma_forecast(series, forecast_horizon)preds.append(pred)actual = np.mean(df[prod].values[i + input_window: i + input_window + forecast_horizon])actuals.append(actual)preds = np.array(preds)actuals = np.array(actuals)prod_mape = np.mean(np.abs((actuals - preds) / (actuals + 1e-6))) * 100mape_dict[prod] = prod_mapeall_mapes.append(prod_mape)overall_mape = np.mean(all_mapes)return mape_dict, overall_mape# # Main Training# def main:df = pd.read_csv("sales_data.csv")# Prepare out-of-time (chronological) dataloaders (first 80% for training, remaining for validation)train_loader, val_loader = prepare_dataloaders(df, input_window=30, forecast_horizon=5, batch_size=32, train_ratio=0.8)# For validation inversion, we need access to the dataset normalization parametersdataset_obj = SalesDataset(df, input_window=30, forecast_horizon=5)device = torch.device("cuda" if torch.cuda.is_available else "cpu")# Define model parametersinput_dim = train_loader.dataset[0][0].shape[-1] # e.g., 5 productshidden_dim = 256 # Hidden dimension for GRUnum_products = input_dim # Predict average sales for each productforecast_horizon = 10 # Note: This parameter is used in the models, even though targets are for next 5 dayslambda_policy = 0.06 # Weight for GRPO-inspired policy loss# # GRPO-inspired Forecasting Model (Existing)# model = ForecastingGRPOModel(input_dim, hidden_dim, num_products, forecast_horizon, dropout=0.2, lambda_policy=lambda_policy)model.to(device)optimizer = torch.optim.Adam(model.parameters, lr=0.0003)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)n_epochs = 22print("Training GRPO-inspired Forecasting Model...")for epoch in range(n_epochs):train_loss = train_model_full(model, train_loader, optimizer, device, grad_clip=1.0)mape = validate_model(model, val_loader, device, dataset_obj, debug=True)scheduler.stepprint(f"Epoch {epoch+1}/{n_epochs} - GRPO Model Train Loss: {train_loss:.4f}, MAPE: {mape:.2f}%")# # ARMA Evaluation for Comparison# print("\nEvaluating ARMA Forecasting on raw data...")arma_mapes, overall_arma_mape = evaluate_arma(df, input_window=30, forecast_horizon=5, train_ratio=0.8)print("ARMA MAPE per product:", arma_mapes)print("Overall ARMA MAPE:", overall_arma_mape, "%")# # Simple GRU Forecasting Model for Comparison# class SimpleGRUForecastingModel(nn.Module):def __init__(self, input_dim, hidden_dim, num_products, forecast_horizon, dropout=0.2):super(SimpleGRUForecastingModel, self).__init__self.gru = nn.GRU(input_dim, hidden_dim, num_layers=2, batch_first=True, dropout=dropout)self.fc_forecast = nn.Linear(hidden_dim, num_products)def forward(self, x):gru_out, _ = self.gru(x)last_hidden = gru_out[:, -1, :]forecast = self.fc_forecast(last_hidden)return forecastdef train_simple_model(model, dataloader, optimizer, device, grad_clip=1.0):model.traintotal_loss = 0.0for x, y in dataloader:x = x.to(device)y = y.to(device)optimizer.zero_gradforecast = model(x)loss_mse = F.mse_loss(forecast, y)loss_mae = F.l1_loss(forecast, y)loss = 0.5 * loss_mse + 0.5 * loss_maeloss.backwardnn.utils.clip_grad_norm_(model.parameters, grad_clip)optimizer.steptotal_loss += loss.item * x.size(0)return total_loss / len(dataloader.dataset)def validate_simple_model(model, dataloader, device, dataset_obj, debug=False):model.evalall_preds = all_targets = with torch.no_grad:for x, y in dataloader:x = x.to(device)y = y.to(device)forecast = model(x)all_preds.append(forecast.cpu.numpy)all_targets.append(y.cpu.numpy)all_preds = np.concatenate(all_preds, axis=0)all_targets = np.concatenate(all_targets, axis=0)mean = dataset_obj.meanstd = dataset_obj.stdall_preds_orig = all_preds * std + meanall_targets_orig = all_targets * std + meanif debug:print("Simple GRU Prediction range:", np.min(all_preds_orig), np.max(all_preds_orig))print("Simple GRU Target range:", np.min(all_targets_orig), np.max(all_targets_orig))mape = np.mean(np.abs((all_targets_orig - all_preds_orig) / (all_targets_orig + 1e-6))) * 100return mapeprint("\nTraining Simple GRU Forecasting Model for Comparison...")simple_model = SimpleGRUForecastingModel(input_dim, hidden_dim, num_products, forecast_horizon, dropout=0.2)simple_model.to(device)optimizer_simple = torch.optim.Adam(simple_model.parameters, lr=0.0003)scheduler_simple = torch.optim.lr_scheduler.StepLR(optimizer_simple, step_size=10, gamma=0.9)n_epochs_simple = 22for epoch in range(n_epochs_simple):train_loss_simple = train_simple_model(simple_model, train_loader, optimizer_simple, device, grad_clip=1.0)simple_mape = validate_simple_model(simple_model, val_loader, device, dataset_obj, debug=True)scheduler_simple.stepprint(f"Epoch {epoch+1}/{n_epochs_simple} - Simple GRU Train Loss: {train_loss_simple:.4f}, MAPE: {simple_mape:.2f}%")# # New Method: GRPO-inspired Forecasting with Extended MLA (Mamba-style) Mechanism# class ForecastingGRPOMLAModel(nn.Module):def __init__(self, input_dim, hidden_dim, num_products, forecast_horizon, dropout=0.3, lambda_policy=0.06):"""This model extends the GRPO-inspired forecasting approach by incorporating anextended MLA (Mamba-style) mechanism. The latent state is updated in a state-spacemanner with a nonlinear activation applied to the entire update."""super(ForecastingGRPOMLAModel, self).__init__self.hidden_dim = hidden_dimself.lambda_policy = lambda_policyself.dropout = nn.Dropout(dropout)# Map the input to the latent space.self.input_transform = nn.Linear(input_dim, hidden_dim)# State-space transition matrix (M)self.M = nn.Linear(hidden_dim, hidden_dim, bias=False)# Nonlinear activation function for the complete state update.self.activation = nn.ReLUself.fc_forecast = nn.Linear(hidden_dim, num_products)self.policy_net = nn.Linear(hidden_dim, 1)def forward(self, x):# x: (batch, seq_len, input_dim)batch_size, seq_len, _ = x.size# Initialize latent state as zeros.h = torch.zeros(batch_size, self.hidden_dim, device=x.device)# Iteratively update latent state with a state-space update.for t in range(seq_len):x_t = x[:, t, :] # (batch, input_dim)# Compute the correction without activation first.correction = self.input_transform(x_t)# Update latent state using the ReLU activation applied to the entire sum.h = self.activation(self.M(h) + correction)h = self.dropout(h)forecast = self.fc_forecast(h)policy_value = self.policy_net(h)return forecast, policy_valueprint("\nTraining GRPO-inspired Forecasting Model with Extended MLA (Mamba-style) Mechanism...")model_extended = ForecastingGRPOMLAModel(input_dim, hidden_dim, num_products, forecast_horizon, dropout=0.2, lambda_policy=lambda_policy)model_extended.to(device)optimizer_extended = torch.optim.Adam(model_extended.parameters, lr=0.0003)scheduler_extended = torch.optim.lr_scheduler.StepLR(optimizer_extended, step_size=10, gamma=0.9)n_epochs_extended = 22for epoch in range(n_epochs_extended):train_loss_extended = train_model_full(model_extended, train_loader, optimizer_extended, device, grad_clip=1.0)mape_extended = validate_model(model_extended, val_loader, device, dataset_obj, debug=True)scheduler_extended.stepprint(f"Epoch {epoch+1}/{n_epochs_extended} - Extended MLA Model Train Loss: {train_loss_extended:.4f}, MAPE: {mape_extended:.2f}%")if __name__ == "__main__":mainTraining GRPO-inspired Forecasting Model...Prediction range: 8.184391 53.145542Target range: 4.9999995 59.75939Epoch 1/22 - GRPO Model Train Loss: 0.3903, MAPE: 24.59%Prediction range: 6.2827315 55.90737Target range: 4.9999995 59.75939Epoch 2/22 - GRPO Model Train Loss: 0.3462, MAPE: 23.57%Prediction range: 6.501168 56.99558Target range: 4.9999995 59.75939..........Epoch 19/22 - GRPO Model Train Loss: 0.2977, MAPE: 21.58%Prediction range: 5.120795 58.1873Target range: 4.9999995 59.75939Epoch 20/22 - GRPO Model Train Loss: 0.2957, MAPE: 21.55%Prediction range: 5.0320196 57.984074Target range: 4.9999995 59.75939Epoch 21/22 - GRPO Model Train Loss: 0.2961, MAPE: 21.72%Prediction range: 4.7999735 58.01819Target range: 4.9999995 59.75939Epoch 22/22 - GRPO Model Train Loss: 0.2957, MAPE: 21.60%Evaluating ARMA Forecasting on raw data...ARMA MAPE per product: {'product1_sales': 43.34198894490143, 'product2_sales': 35.41790083634351, 'product3_sales': 23.351334279549448, 'product4_sales': 16.953989203635224, 'product5_sales': 12.596327432823836}Overall ARMA MAPE: 26.332308139450692 %Training Simple GRU Forecasting Model for Comparison...Simple GRU Prediction range: 8.621262 53.7664Simple GRU Target range: 4.9999995 59.75939Epoch 1/22 - Simple GRU Train Loss: 0.3817, MAPE: 25.23%Simple GRU Prediction range: 6.4578733 56.249523Simple GRU Target range: 4.9999995 59.75939Epoch 2/22 - Simple GRU Train Loss: 0.3409, MAPE: 23.96%Simple GRU Prediction range: 6.7773395 57.415825Simple GRU Target range: 4.9999995 59.75939Epoch 3/22 - Simple GRU Train Loss: 0.3279, MAPE: 24.20%..........Epoch 19/22 - Simple GRU Train Loss: 0.2978, MAPE: 22.50%Simple GRU Prediction range: 5.1781693 59.017365Simple GRU Target range: 4.9999995 59.75939Epoch 20/22 - Simple GRU Train Loss: 0.2962, MAPE: 22.52%Simple GRU Prediction range: 4.7814617 59.553993Simple GRU Target range: 4.9999995 59.75939Epoch 21/22 - Simple GRU Train Loss: 0.2936, MAPE: 22.53%Simple GRU Prediction range: 4.7923336 59.583473Simple GRU Target range: 4.9999995 59.75939Epoch 22/22 - Simple GRU Train Loss: 0.2919, MAPE: 22.57%Training GRPO-inspired Forecasting Model with Extended MLA (Mamba-style) Mechanism...Prediction range: 2.4405355 54.40341Target range: 4.9999995 59.75939Epoch 1/22 - Extended MLA Model Train Loss: 0.4088, MAPE: 21.55%Prediction range: 4.095806 57.311417Target range: 4.9999995 59.75939Epoch 2/22 - Extended MLA Model Train Loss: 0.3377, MAPE: 21.39%Prediction range: 4.140363 57.22279Target range: 4.9999995 59.75939Epoch 3/22 - Extended MLA Model Train Loss: 0.3261, MAPE: 21.44%Prediction range: 4.4515967 56.920387Target range: 4.9999995 59.75939........Epoch 18/22 - Extended MLA Model Train Loss: 0.2943, MAPE: 21.00%Prediction range: 4.5498137 58.417767Target range: 4.9999995 59.75939Epoch 19/22 - Extended MLA Model Train Loss: 0.2901, MAPE: 21.52%Prediction range: 3.9070492 58.727776Target range: 4.9999995 59.75939Epoch 20/22 - Extended MLA Model Train Loss: 0.2878, MAPE: 21.10%Prediction range: 4.2164598 58.74173Target range: 4.9999995 59.75939Epoch 21/22 - Extended MLA Model Train Loss: 0.2821, MAPE: 21.23%Prediction range: 3.712309 57.99107Target range: 4.9999995 59.75939Epoch 22/22 - Extended MLA Model Train Loss: 0.2784, MAPE: 20.82%

GRPO 启发模型,它集成了扩展 MLA 模块与 GRU 和额外的策略分支,展示了稳健的性能。在 22 个训练周期中,其 MAPE 稳步下降并最终稳定在 21.6% 左右。该模型的预测范围始终与目标范围保持良好一致,表明其自适应机制有效地捕获了潜在的销售模式。

相比之下,缺乏 GRPO 特定修改的简单 GRU 模型产生了略高的 MAPE,平均约为 22.3%。虽然简单 GRU 的预测也落在类似范围内,但 GRPO 模型观察到的边际改进表明,额外的策略损失和扩展的潜在更新有助于适度但有意义地减少预测误差。

具有扩展 MLA (Mamba 风格) 机制的 GRPO 启发模型进一步改进了性能。其非线性状态空间更新,将ReLU激活应用于完整的状态更新,使MAPE降低到20.8-21.3%。这种改进表明使用更丰富的潜在表示在捕获时间序列动态方面具有优势。

最后,ARMA方法显示出显著更高的误差。每个产品的MAPE从约12.6%到超过43%不等,总体MAPE约为26.3%,ARMA在处理复杂的多维销售数据方面的效果不如深度学习方法。

总的来说,这些实验表明,深度学习模型,特别是那些通过 GRPO 和扩展 MLA 技术增强的模型,在预测多个时间序列的平均销售方面优于经典方法。

让我们总结 DeepSeek 的突破 - MLA 和 GRPO - 到一个自适应模型中。目标是构建一个系统,该系统不仅使用高效的低秩注意力来处理长序列,还利用强化学习来动态地智能选择或混合模型。

输入编码和潜在压缩:

每个输入 token h_t 首先由编码器处理。

编码器使用以下方式将 h_t 压缩到潜在表示:

将原始维度d减少到较小的潜在空间维度d_c。

重构Key和Value:

通过以下方式重构注意力所需的key和value:

这种重构确保保留了基本上下文的同时保持KV缓存较小。

查询压缩(可选):

注意力计算:

使用压缩的查询、键和值计算多头注意力。每个头应用常规的注意力公式,但减少的维度使得过程更加高效。

通过GRPO进行策略决策:

模型然后使用强化学习模块来选择最佳行动 - 无论是选择一个模型还是混合多个模型。RL策略π(a∣s;θ)接收状态s(包括来自MLA模块的潜在特征和额外的统计摘要)作为输入并输出一个行动。GRPO通过将新输出与早期固定策略的输出集进行比较来更新这个策略。计算优势并对更新进行裁剪以确保稳定性。

在前面的章节中,我们讨论了 MLA 和 GRPO 在 DeepSeek 中如何有效地协同工作,形成其核心技术。通过结合上面的技术,可以提出一个统一的框架,将 MLA 和 GRPO 结合用于多产品时间序列预测。使用包含日期、product1_sales、product2_sales、...、product5_sales列的DataFrame,我们的目标是预测每个产品未来10天的平均销售额。我们的方法将状态空间建模(使用"Mamba风格"方法)与潜在注意力相结合,并通过GRPO使用基于强化学习的策略优化来动态调整预测。

以下,我将概述数学基础、算法细节,并提供一个实际示例。

将销售数据表示为多变量时间序列:

{x_t},其中每个x_t表示时间t时p个产品(这里p=5)的销售量。目标是预测每个产品未来10天的平均销售额。

这里的挑战在于时间序列可能存在内部相关性和滞后效应。例如,产品1在第t天的销售量可能不仅取决于其自身过去的销售量,还取决于产品2或产品3前几天的销售量。

为了捕获时间序列预测中固有的时间动态,我建议用非线性激活增强的状态空间更新来扩展潜在压缩步骤。在这个框架中,我假设压缩的潜在向量根据以下方式随时间演变:

其中:

M∈R^{d_c×d_c}是建模潜在状态动态的转移矩阵。η(x_t)是一个函数,将当前输入x_t(如时间t的销售数据)映射到潜在空间中的校正项。ReLU激活应用于整个更新,引入非线性并确保更新后的潜在状态为非负。

可以对查询压缩应用类似的更新:

其中M'∈R^{d_c'×d_c'}的定义类似。

解释:
这个状态空间更新"记住"过去的潜在状态c_{KV,t}并使用新信息x_t对其进行调整。通过对整个和应用ReLU激活:

,模型捕获历史状态与新输入之间的复杂非线性交互。非线性有助于建模复杂的时间模式,同时确保潜在表示保持非负。这种方法类似于RNN或LSTM更新其隐藏状态的方式。

整合多头注意力

在获得动态潜在状态c{KV,t}和c{Q,t}后,我将它们投影到多头注意力的键、值和查询中。

假设我将潜在空间分成h个头。对于头i:

其中W{Q,i}、W{K,i}、W_{V,i}和d_h为每个头的维度。

每个头的注意力计算为:

然后将所有头的输出连接起来并用输出矩阵W_O投影:

这种多头机制使模型能够捕获销售数据中时间关系的不同方面。例如,一个头可能学习趋势分量,另一个可能关注季节性,等等。

在这个框架中,我们首先定义在给定长度为T的输入窗口和H天预测范围内的预测问题。对于每个产品时间序列,目标y被计算为未来H天原始销售量的平均值,即

其中x_t表示第t天的销售量。

GRPO模型使用两层GRU从标准化输入序列X∈R{T×D}(D为产品数量)中提取时间特征。令最后一个时间步的隐藏状态为h_T∈RH_d。这个h_T然后通过两个单独的线性投影映射:

预测分支使用权重矩阵W_f计算预测:

其中y是每个产品预测的平均销售量向量。

策略分支通过另一个线性映射W_p计算一个标量值p(策略值):

GRPO启发损失的核心思想是基于一个"优势"信号来调整预测,该信号衡量预测误差相对于常数基准b(这里选择b=0.5)的表现。

具体而言,优势被定义为:

其中平均值在产品维度上取均值。策略值与基准之间的比率r计算为:

为了确保训练期间的稳定性并避免大的策略更新,使用了裁剪机制。令:

其中ϵ是一个小值(例如0.1)。GRPO启发的策略损失则被公式化为:

如果新策略(即预测)没有相对于基准充分改进,该损失会对模型进行惩罚。总体损失函数是预测损失和策略损失的组合:

其中,

λ是控制策略损失权重的超参数。

这种方法被嵌入到时间外验证方案中:时间序列数据按时间顺序分割,确保只使用过去数据进行训练,未来数据用于验证 - 从而避免数据泄露。在验证期间,使用存储的均值和标准差将归一化预测y转换回原始尺度,并计算平均绝对百分比误差(MAPE):

其中ϵ是一个小常数,用于避免除以零。

在这里,用于多时间序列预测的GRPO方法使用GRU编码器提取时间依赖特征,并产生预测和策略值。预测使用组合的MSE/MAE损失进行评估,而策略分支使用裁剪优势机制提供额外的梯度信号,最终导致对预测范围内的平均销售额进行更稳健的预测。

本文介绍的DeepSeek-TS方法利用 GRPO 结合使用 Mamba 风格状态空间更新的扩展 MLA 模块。实验表明,这个 GRPO 启发模型可以实现更好的性能 - 更低的MAPE - 比简单的GRU模型和经典的ARMA方法。由策略分支和状态更新中的非线性激活驱动的增强潜在表示似乎能更有效地捕获销售数据的复杂动态。

GRPO 和扩展 MLA 框架在应用于其他领域方面具有巨大潜力。例如,这种方法可以适用于金融时间序列预测,在这种情况下捕获市场趋势的细微变化至关重要。它也可能对医疗保健诊断有益,在那里从多个时间相关信号预测病人结果可以导致更早的干预。

未来的工作可以集中在通过实验不同的基准值或裁剪阈值来进一步改进 GRPO 机制,以及探索扩展的 MLA 模块如何与其他深度学习架构集成。此外,整合元学习技术可能使模型能够在不同领域之间更好地泛化。总的来说,这项研究表明,将强化学习与先进的注意力机制相结合是构建更智能、更具适应性的预测系统的一个有前途的方向。

代码:

本文作者:Shenggang Li

最后说明,我看了一下作者的github代码,用其他的序列数据测试,得到的结果和这篇文章有一些出入,但是作者的思路我觉得可以借鉴。如果你自己测试的话,欢迎留言回复测试结果

来源:deephub

相关推荐