量化(Quantization)这词儿听着玄,经常和量化交易Quantitative Trading (量化交易)混淆。
其实机器学习(深度学习)领域的量化Quantization是和节约内存、提高运算效率相关的概念(因大模型的普及,这个量化问题尤为迫切)。
揭秘机器学习“量化”:不止省钱,更让AI高效跑起来!
“量化(Quantization)”这个词,在机器学习领域,常常让人联想到复杂的数学或是与金融交易相关的“量化交易”,从而感到困惑。但实际上,它与我们日常生活中的数字转换概念更为接近,而在AI世界里,它扮演的角色是节约内存、提高运算效率的“幕后英雄”(现在已经显露到“幕布之前”),尤其在大模型时代,其重要性日益凸显。
那么,机器学习中的“量化”究竟是啥?咱为啥用它?
什么是机器学习中的“量化”(Quantization)?
简单讲,机器学习中的“量化”就是将模型中原本采用高精度浮点数(如32位浮点数,即FP32)表示的权重(weights)和激活值(activations),转换成低精度表示(如8位整数,即INT8)的过程。
你可以把它想象成“数字的压缩”。在计算机中,浮点数就像是拥有无限小数位的精确数字,而整数则像只有整数部分的数字。从高精度浮点数到低精度整数的转换,必然会损失一些信息,但与此同时,它也带来了显著的优势:
- 内存占用大幅减少: 8位整数比32位浮点数少占用4倍的内存空间。这意味着更大的模型可以被部署到内存有限的设备上(如手机、IoT设备),或者在相同内存下可以运行更大的模型。
- 计算速度显著提升: 整数运算通常比浮点数运算更快、功耗更低。这使得模型在推理(Inference)阶段能以更高的效率运行,减少延迟。
为何需要“量化”?
随着深度学习模型变得越来越大,越来越复杂,它们对计算资源的需求也呈爆炸式增长。一个动辄几十亿甚至上百亿参数的大模型,如果全部使用FP32存储和计算,将对硬件资源提出极高的要求。
- 部署到边缘设备: 手机、自动驾驶汽车、智能音箱等边缘设备通常算力有限,内存紧张。量化是让大模型“瘦身”后成功“登陆”这些设备的必经之路。
- 降低运行成本: 在云端部署大模型时,更低的内存占用和更快的计算速度意味着更低的服务器成本和能耗。
- 提升用户体验: 实时响应的AI应用,如语音助手、图像识别等,对推理速度有极高要求。量化可以有效缩短响应时间。
量化策略:后训练量化 vs. 量化感知训练(QAT)
量化并非只有一种方式。根据量化发生的时间点,主要可以分为两大类:
-
后训练量化(Post-Training Quantization, PTQ): 顾名思义,PTQ 是在模型训练完成之后,对已经训练好的FP32模型进行量化。它操作简单,不需要重新训练,是实现量化的最快途径。然而,由于量化过程中会损失精度,PTQ 可能会导致模型性能(如准确率)的下降。对于对精度要求不那么苛刻的应用,PTQ 是一个不错的选择。
-
量化感知训练*(这是本文重点推介的:Quantization Aware Training, QAT): 这正是我们今天着重讲解的明星策略!QAT 的核心思想是——在模型训练过程中,就“感知”到未来的量化操作。
在QAT中,量化误差被集成到模型的训练循环中。这意味着,模型在训练时就“知道”它最终会被量化成低精度,并会努力学习如何在这种低精度下保持最优性能。
具体来说,QAT通常通过在模型中插入“伪量化”(Fake Quantization)节点来实现。这些节点在训练过程中模拟量化和反量化操作,使得模型在FP32环境下进行前向传播和反向传播时,能够学习到量化对模型参数和激活值的影响。当训练完成后,这些伪量化节点会被真正的量化操作所取代,从而得到一个高性能的量化模型。
为什么QAT是量化策略的“王牌”?
相较于PTQ……QAT 的优势显而易见:
- 精度损失最小: 这是QAT最大的亮点。通过在训练过程中模拟量化,模型能够自我调整以适应量化带来的精度损失,从而在量化后依然保持接近FP32模型的性能。
- 适用于更苛刻的场景: 对于那些对模型精度要求极高,不能容忍明显性能下降的应用(如自动驾驶、医疗影像分析),QAT几乎是唯一的选择。
- 更好的泛化能力: 在训练阶段就考虑量化,使得模型在量化后对各种输入数据具有更好的鲁棒性。
PyTorch中的QAT实践
在PyTorch中实现QAT,通常需要以下几个关键步骤:
- 准备量化配置: 定义量化类型(如INT8)、量化方法(如对称量化、非对称量化)以及需要量化的模块。
- 模型转换: 使用PyTorch提供的
torch.quantization
模块,将普通的FP32模型转换为QAT模型。这个过程会在模型中插入伪量化模块。 - 重新训练/微调: 在新的数据集上对转换后的模型进行短时间的微调(Fine-tuning),或者在原有训练基础上继续训练。这个阶段,模型会学习如何适应伪量化带来的精度损失。
- 模型融合(可选但推荐): 将一些连续的层(如Conv-BN-ReLU)融合为一个操作,可以进一步提高量化后的推理效率。
- 模型量化和保存: 训练完成后,将微调好的QAT模型转换为真正的量化模型,并保存。
总结
量化(Quantization)是深度学习模型优化不可或缺的一环,它通过降低模型精度来换取内存和计算效率的大幅提升。而量化感知训练(QAT)作为一种高级量化策略,通过在训练阶段就考虑量化对模型的影响,极大地减小了量化带来的精度损失,使得在各种设备上部署高性能AI模型成为可能。
随着大模型和边缘AI的普及,掌握量化尤其是QAT的原理和实践,将成为每一位AI工程师和研究人员的必备技能。让我们一起,让AI跑得更快、更高效!
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
import numpy as np
import os# ===== 1. XOR数据集 =====
X = torch.tensor([[0., 0.],[0., 1.],[1., 0.],[1., 1.]
], dtype=torch.float32)
y = torch.tensor([[0.],[1.],[1.],[0.]
], dtype=torch.float32)# ===== 2. 神经网络模型 (标准FP32) =====
class XORNet(nn.Module):def __init__(self):super(XORNet, self).__init__()self.fc1 = nn.Linear(2, 3)self.relu = nn.ReLU()self.fc2 = nn.Linear(3, 1)# QAT阶段不用Sigmoid,直接用BCEWithLogitsLoss# self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)# return self.sigmoid(x)return x# ===== 3. 初始化模型/优化器 =====
model = XORNet()
# He初始化,适合ReLU
for m in model.modules():if isinstance(m, nn.Linear):nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.05)# ===== 4. 训练(标准模型) =====
print("--- 开始标准模型训练 ---")
epochs = 900#1000#500 #1500
for epoch in range(epochs):outputs = model(X)loss = criterion(outputs, y)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 300 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')with torch.no_grad():probs = torch.sigmoid(model(X))predictions = (probs > 0.5).float()accuracy = (predictions == y).sum().item() / y.numel()print(f"\n标准模型训练后精度: {accuracy*100:.2f}%")print(f"标准模型预测结果:\n{predictions}")# ===== 5. 构建QAT模型 =====
class XORNetQAT(nn.Module):def __init__(self):super(XORNetQAT, self).__init__()# 量化Stubself.quant = torch.quantization.QuantStub()self.fc1 = nn.Linear(2, 3)self.relu = nn.ReLU()self.fc2 = nn.Linear(3, 1)self.dequant = torch.quantization.DeQuantStub()def forward(self, x):x = self.quant(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.dequant(x)return xdef fuse_model(self):torch.quantization.fuse_modules(self, [['fc1', 'relu']], inplace=True)# ===== 6. QAT前权重迁移、模型融合 =====
model_qat = XORNetQAT()
# 迁移参数
model_qat.load_state_dict(model.state_dict())
# 融合(此步必须!)
model_qat.fuse_model()# 配置QAT
model_qat.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') # CPU
torch.quantization.prepare_qat(model_qat, inplace=True)optimizer_qat = optim.Adam(model_qat.parameters(), lr=0.01)# ===== 7. QAT训练 =====
qat_epochs = 700
print("\n--- QAT训练 ---")
for epoch in range(qat_epochs):model_qat.train()outputs_qat = model_qat(X)loss_qat = criterion(outputs_qat, y)optimizer_qat.zero_grad()loss_qat.backward()optimizer_qat.step()if (epoch + 1) % 150 == 0:print(f'QAT Epoch [{epoch+1}/{qat_epochs}], Loss: {loss_qat.item():.4f}')# ===== 8. 转换到量化模型/评估精度 =====
print("\n--- 转换为量化模型 ---")
model_qat.eval()
model_quantized = torch.quantization.convert(model_qat.eval(), inplace=False)with torch.no_grad():probs_quantized = torch.sigmoid(model_quantized(X))predictions_quantized = (probs_quantized > 0.5).float()accuracy_quantized = (predictions_quantized == y).sum().item() / y.numel()print(f"量化模型精度: {accuracy_quantized*100:.2f}%")print(f"量化模型预测:\n{predictions_quantized}")# ===== 9. 模型大小对比 =====
torch.save(model.state_dict(), 'xor_fp32.pth')
torch.save(model_quantized.state_dict(), 'xor_int8.pth')
fp32_size = os.path.getsize('xor_fp32.pth') / (1024 * 1024)
int8_size = os.path.getsize('xor_int8.pth') / (1024 * 1024)
print(f"\nFP32模型大小: {fp32_size:.6f} MB")
print(f"INT8模型大小: {int8_size:.6f} MB")
print(f"模型缩减比例: {fp32_size/int8_size:.2f} 倍")
运行结果:
====================== RESTART: F:/qatXorPytorch250501.py
--- 开始标准模型训练 ---
Epoch [300/900], Loss: 0.0033
Epoch [600/900], Loss: 0.0011
Epoch [900/900], Loss: 0.0005
标准模型训练后精度: 100.00%
标准模型预测结果:
tensor([[0.],
[1.],
[1.],
[0.]])
--- QAT训练 ---
QAT Epoch [150/700], Loss: 0.0005
QAT Epoch [300/700], Loss: 0.0004
QAT Epoch [450/700], Loss: 0.0004
QAT Epoch [600/700], Loss: 0.0003
--- 转换为量化模型 ---
量化模型精度: 100.00%
量化模型预测:
tensor([[0.],
[1.],
[1.],
[0.]])
FP32模型大小: 0.001976 MB
INT8模型大小: 0.004759 MB
模型缩减比例: 0.42 倍
最后:
其实,这次量化
量化后的模型是量化前的 4.2倍(咦?不是说好了压缩吗?咋变大了?)
魔鬼藏在细节:
咱们看看 量化 之前的 (基线的)模型 参数+权重等等:
===== Model Architecture =====
XORNet(
(fc1): Linear(in_features=2, out_features=3, bias=True)
(relu): ReLU()
(fc2): Linear(in_features=3, out_features=1, bias=True)
)
===== Layer Parameters =====
[fc1.weight] shape: (3, 2)
[[ 1.723932, 1.551827],
[ 2.106917, 1.681809],
[-0.299378, -0.444912]]
[fc1.bias] shape: (3,)
[-1.725313, -2.509506, 0. ]
[fc2.weight] shape: (1, 3)
[[-2.492318, -3.94821 , 0.911841]]
[fc2.bias] shape: (1,)
[0.692789]
===== Extra Info (Hyperparameters) =====
Optimizer: Adam
Learning Rate: 0.05
Epochs: 900
Loss: BCEWithLogitsLoss
Activation: ReLU
量化之后的模型参数等:
超参数部分:
===== Model Architecture =====
XORNetQAT(
(quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
(fc1): QuantizedLinearReLU(in_features=2, out_features=3, scale=0.0678500160574913, zero_point=0, qscheme=torch.per_channel_affine)
(relu): Identity()
(fc2): QuantizedLinear(in_features=3, out_features=1, scale=0.1376650333404541, zero_point=65, qscheme=torch.per_channel_affine)
(dequant): DeQuantize()
)
===== Layer Parameters =====
===== Extra Info (Hyperparameters) =====
Optimizer: Adam
Learning Rate: 0.01
QAT Epochs: 700
Loss: BCEWithLogitsLoss
QConfig: QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, reduce_range=True){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x00000164FDB16160>}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x00000164FDB16160>})
看到吗?
量化后, 参数 变多了哈哈!
那量化的意义到底在哪里呢??
在下面的 参数(非超参)的 权重的部分:
===== 量化 (Quantization) 后参数 =====
[fc1] (QuantizedLinearReLU)
[weight] shape: torch.Size([3, 2]), dtype: torch.qint8
weight (quantized):
[[ 3.97 3.98]
[ 3.95 3.96]
[-0.56 -1.05]]
weight (raw int):
[[127 127]
[127 127]
[-18 -33]]
scale: 0.06785
zero_point: 0
[bias] shape: torch.Size([3]), dtype: torch.float
bias:
[-3.049073e-04 -3.960315e+00 0.000000e+00]
[fc2] (QuantizedLinear)
[weight] shape: torch.Size([1, 3]), dtype: torch.qint8
weight (quantized):
[[ 3.77 -7.79 0.41]]
weight (raw int):
[[ 27 -57 3]]
scale: 0.13766
zero_point: 65
[bias] shape: torch.Size([1]), dtype: torch.float
bias:
[-7.147094]
再看看量化前:
===== Layer Parameters =====
[fc1.weight] shape: (3, 2)
[[ 1.723932, 1.551827],
[ 2.106917, 1.681809],
[-0.299378, -0.444912]]
[fc1.bias] shape: (3,)
[-1.725313, -2.509506, 0. ]
[fc2.weight] shape: (1, 3)
[[-2.492318, -3.94821 , 0.911841]]
[fc2.bias] shape: (1,)
[0.692789]
最后看看 量化 后:
===== 量化 (Quantization) 后参数 =====
[fc1] (QuantizedLinearReLU)
[weight] shape: torch.Size([3, 2]), dtype: torch.qint8
weight (quantized):
[[ 3.97 3.98]
[ 3.95 3.96]
[-0.56 -1.05]]
weight (raw int):
[[127 127]
[127 127]
[-18 -33]]
scale: 0.06785
zero_point: 0
[bias] shape: torch.Size([3]), dtype: torch.float
bias:
看出区别了吗?
So:
量化操作 只 适用于 大、中型号的模型……道理就在此:
量化前 的 Weights 的权重 全部都是: Float32浮点型 ……很占内存的!
量化后 是所谓INT8(即一个字节、8bits)……至少(在权重部分)节省了 3/4 的内存!
So: 大模型 必须 要 量化 才 节省内存。
当然前提 是 你的 GPU 硬件 要 支持 INT8(8bits)的 运算哦……(这是后话,下次再聊)。
(over)完!
INT4版本,不调用pytorch,只用numpy:
import numpy as np# S型激活函数
def sigmoid(x):return 1 / (1 + np.exp(-np.clip(x, -500, 500))) # 裁剪以防止溢出# S型激活函数的导数
def sigmoid_derivative(x):return x * (1 - x)# INT4 量化函数 (对称量化)
# scale 和 zero_point 通常在 QAT 中学习,但为简单起见,
# 这里我们将基于动态确定的 min_val 和 max_val。
def quantize_int4(tensor, min_val, max_val):# 确定将 float32 映射到 INT4 范围 [-8, 7]q_min = -8q_max = 7epsilon = 1e-7 # 一个小常数,用于防止除以零或在min_val和max_val非常接近时出现问题# 确保 min_val 和 max_val 之间有足够的间隔if abs(max_val - min_val) < epsilon:# 如果它们都接近0,则将范围扩大一点点以稳定包含0if abs(min_val) < epsilon and abs(max_val) < epsilon:min_val_eff = -epsilonmax_val_eff = epsilonelse: # 否则,围绕它们的值稍微扩大范围min_val_eff = min_val - epsilonmax_val_eff = max_val + epsilonelse:min_val_eff = min_valmax_val_eff = max_val# 计算 scalescale = (q_max - q_min) / (max_val_eff - min_val_eff)# 计算 zero_point (对于非对称量化是必要的,对于对称量化,理想情况下是0)# 这里我们使用标准的仿射量化公式,然后对zero_point进行取整和裁剪zero_point = q_min - min_val_eff * scale# 将 zero_point 四舍五入到最近的整数并裁剪zero_point = np.round(zero_point)zero_point = np.clip(zero_point, q_min, q_max).astype(np.int32)# 量化张量quantized_tensor = np.round(tensor * scale + zero_point)# 裁剪到 INT4 范围quantized_tensor = np.clip(quantized_tensor, q_min, q_max)return quantized_tensor.astype(np.int8), scale, zero_point # 用int8类型存储INT4的值# INT4 反量化函数
def dequantize_int4(quantized_tensor, scale, zero_point):return (quantized_tensor.astype(np.float32) - zero_point) / scale # 确保运算使用浮点数# 直通估计器 (STE) 用于量化
# 在前向传播中,应用量化。
# 在反向传播中,梯度直接通过,好像没有量化发生。
class STEQuantizer:def __init__(self):self.quantized_val = Noneself.scale = Noneself.zero_point = Nonedef forward(self, x, min_val_for_quant, max_val_for_quant):# 量化在前向传播中进行quantized_x, scale, zero_point = quantize_int4(x, min_val_for_quant, max_val_for_quant)self.quantized_val = quantized_xself.scale = scaleself.zero_point = zero_point# 为了在网络中进行实际计算(模拟量化硬件的行为),需要反量化return dequantize_int4(quantized_x, scale, zero_point)def backward(self, grad_output):# 在反向传播中,梯度直接通过# 这是STE的核心:输出对输入的梯度被认为是1return grad_output# 定义神经网络类
class NeuralNetwork:def __init__(self, input_size, hidden_size, output_size, learning_rate=0.1, l2_lambda=0.0001, ema_decay=0.99):self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_sizeself.learning_rate = learning_rateself.l2_lambda = l2_lambda # L2正则化系数self.ema_decay = ema_decay # EMA衰减因子,用于更新量化范围统计# 初始化权重和偏置# 权重初始化为小的随机值self.weights_input_hidden = np.random.uniform(-0.5, 0.5, (self.input_size, self.hidden_size))self.bias_hidden = np.zeros((1, self.hidden_size))self.weights_hidden_output = np.random.uniform(-0.5, 0.5, (self.hidden_size, self.output_size))self.bias_output = np.zeros((1, self.output_size))# 用于量化的统计量 (min/max),初始为None,将在第一次前向传播时或通过EMA更新self.min_wih = Noneself.max_wih = Noneself.min_who = Noneself.max_who = None# 权重用的量化器self.quantizer_wih = STEQuantizer()self.quantizer_who = STEQuantizer()def _update_ema_stats(self, tensor_data, current_min_stat, current_max_stat):# 使用指数移动平均(EMA)更新统计的min/max值current_tensor_min = tensor_data.min()current_tensor_max = tensor_data.max()if current_min_stat is None or current_max_stat is None: # 第一次或未初始化new_min = current_tensor_minnew_max = current_tensor_maxelse:new_min = self.ema_decay * current_min_stat + (1 - self.ema_decay) * current_tensor_minnew_max = self.ema_decay * current_max_stat + (1 - self.ema_decay) * current_tensor_maxreturn new_min, new_maxdef forward(self, X, quantize=False, is_training=False):# 如果指定量化并且在训练阶段,则更新权重的EMA统计范围if quantize and is_training:self.min_wih, self.max_wih = self._update_ema_stats(self.weights_input_hidden, self.min_wih, self.max_wih)self.min_who, self.max_who = self._update_ema_stats(self.weights_hidden_output, self.min_who, self.max_who)# 如果指定量化,则对权重进行量化if quantize:# 确保min/max统计量已初始化 (例如,在第一次EMA更新后)# 如果仍然是None (可能发生在第一次推理且之前没有训练),则使用当前权重的瞬时min/maxmin_wih_eff = self.min_wih if self.min_wih is not None else self.weights_input_hidden.min()max_wih_eff = self.max_wih if self.max_wih is not None else self.weights_input_hidden.max()min_who_eff = self.min_who if self.min_who is not None else self.weights_hidden_output.min()max_who_eff = self.max_who if self.max_who is not None else self.weights_hidden_output.max()self.quantized_wih_dequant = self.quantizer_wih.forward(self.weights_input_hidden, min_wih_eff, max_wih_eff)self.quantized_who_dequant = self.quantizer_who.forward(self.weights_hidden_output, min_who_eff, max_who_eff)else:# 使用原始的FP32权重self.quantized_wih_dequant = self.weights_input_hiddenself.quantized_who_dequant = self.weights_hidden_output# 输入层到隐藏层self.hidden_layer_input = np.dot(X, self.quantized_wih_dequant) + self.bias_hiddenself.hidden_layer_output = sigmoid(self.hidden_layer_input)# 隐藏层到输出层self.output_layer_input = np.dot(self.hidden_layer_output, self.quantized_who_dequant) + self.bias_outputself.output = sigmoid(self.output_layer_input)return self.outputdef backward(self, X, y, output):# 计算误差self.error = y - output# 输出层梯度self.d_output = self.error * sigmoid_derivative(output)# 计算 weights_hidden_output 和 bias_output 的梯度self.gradients_who = np.dot(self.hidden_layer_output.T, self.d_output)self.gradients_bo = np.sum(self.d_output, axis=0, keepdims=True)# 反向传播到隐藏层# 注意:这里使用前向传播中使用的(可能已反量化的)权重进行梯度计算self.d_hidden_layer = np.dot(self.d_output, self.quantized_who_dequant.T) * sigmoid_derivative(self.hidden_layer_output)# 计算 weights_input_hidden 和 bias_hidden 的梯度self.gradients_wih = np.dot(X.T, self.d_hidden_layer)self.gradients_bh = np.sum(self.d_hidden_layer, axis=0, keepdims=True)# 对权重梯度应用STE (梯度直接通过量化器)self.gradients_wih = self.quantizer_wih.backward(self.gradients_wih)self.gradients_who = self.quantizer_who.backward(self.gradients_who)# 更新权重和偏置,加入L2正则化 (作用于原始的FP32权重)self.weights_hidden_output += self.learning_rate * self.gradients_who - self.learning_rate * self.l2_lambda * self.weights_hidden_outputself.bias_output += self.learning_rate * self.gradients_boself.weights_input_hidden += self.learning_rate * self.gradients_wih - self.learning_rate * self.l2_lambda * self.weights_input_hiddenself.bias_hidden += self.learning_rate * self.gradients_bh# XOR 输入和输出
X_xor = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y_xor = np.array([[0], [1], [1], [0]])# 网络参数
input_size = 2
hidden_size = 3#注意这里:3个节点.比2个节点容易收敛2 #5 # 隐藏层节点数
output_size = 1
learning_rate = 0.0241 # 学习率 (可能需要进一步调整)
epochs = 150000#200000 # 训练迭代次数 (减少了次数以便更快看到结果,原为500000)
l2_lambda_val = 0.0001 # L2 正则化系数# 初始化神经网络
nn = NeuralNetwork(input_size, hidden_size, output_size, learning_rate, l2_lambda=l2_lambda_val)# QAT训练循环
print("开始使用QAT训练神经网络...")
for i in range(epochs):# 前向传播,应用量化,并标记为训练阶段以更新EMA统计output = nn.forward(X_xor, quantize=True, is_training=True)# 反向传播和参数更新nn.backward(X_xor, y_xor, output)if i % 10000 == 0 or i == epochs -1 : # 每10000次迭代或最后一次迭代时打印损失loss = np.mean(np.square(y_xor - output))print(f"迭代次数 {i}, 损失: {loss:.6f}")print("\n训练完成。")# QAT后评估模型
print("\nQAT之后的预测结果:")
# 推理时 quantize=True, is_training=False (不更新EMA统计)
predictions = nn.forward(X_xor, quantize=True, is_training=False)
print(f"输入:\n{X_xor}")
print(f"期望输出:\n{y_xor}")
# 将预测结果四舍五入到0或1,因为XOR的输出是二元的
print(f"预测输出 (反量化后显示,并四舍五入):\n{np.round(predictions)}")# 显示量化后的权重 (INT4值)
# 注意:quantizer_wih.quantized_val 保存的是最近一次前向传播(推理时)的量化权重
# 我们可以在推理后再进行一次前向传播以确保这些值是最新的,或者直接使用它们
nn.forward(X_xor, quantize=True, is_training=False) # 确保量化值是最新的print("\n--- 量化后的模型权重 (INT4) ---")
if nn.quantizer_wih.quantized_val is not None:print("输入层到隐藏层权重 (INT4):\n", nn.quantizer_wih.quantized_val)print("量化参数 (scale, zero_point):", nn.quantizer_wih.scale, nn.quantizer_wih.zero_point)
else:print("输入层到隐藏层权重未被量化。")if nn.quantizer_who.quantized_val is not None:print("隐藏层到输出层权重 (INT4):\n", nn.quantizer_who.quantized_val)print("量化参数 (scale, zero_point):", nn.quantizer_who.scale, nn.quantizer_who.zero_point)
else:print("隐藏层到输出层权重未被量化。")# 为了验证,显示反量化后的权重 (表示INT4值的FP32形式)
print("\n--- 反量化后的权重 (INT4值的FP32表示) ---")
if nn.quantizer_wih.quantized_val is not None:dequant_wih = dequantize_int4(nn.quantizer_wih.quantized_val, nn.quantizer_wih.scale, nn.quantizer_wih.zero_point)print("输入层到隐藏层权重 (反量化后):\n", dequant_wih)
else:print("输入层到隐藏层权重未量化,无法显示反量化值。")if nn.quantizer_who.quantized_val is not None:dequant_who = dequantize_int4(nn.quantizer_who.quantized_val, nn.quantizer_who.scale, nn.quantizer_who.zero_point)print("隐藏层到输出层权重 (反量化后):\n", dequant_who)
else:print("隐藏层到输出层权重未量化,无法显示反量化值。")print("\n--- 训练后的原始 (FP32) 权重 ---")
print("输入层到隐藏层权重 (FP32):\n", nn.weights_input_hidden)
print("隐藏层到输出层权重 (FP32):\n", nn.weights_hidden_output)print("\n--- 用于量化的动态范围统计 (EMA) ---")
print(f"WIH Min: {nn.min_wih}, WIH Max: {nn.max_wih}")
print(f"WHO Min: {nn.min_who}, WHO Max: {nn.max_who}")