@疏锦行
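Key points reviewed:
1. Detecting overfitting: print the training-set and test-set metrics side by side during training.
2. Saving and loading a model
a. Saving only the weights
b. Saving the weights together with the model structure (the whole model object)
c. Saving everything as a checkpoint, which also includes the training state
3. Early stopping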
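The three saving options in point 2 map onto three different things passed to `torch.save`. A minimal sketch, assuming a small hypothetical `CreditMLP` model (the notes do not show the real architecture) and files written to a temporary directory. The `weights_only` argument of `torch.load` exists since PyTorch 1.13; from 2.6 on it defaults to `True`, which is why unpickling a whole model needs `weights_only=False`:

```python
import os
import tempfile

import torch
import torch.nn as nn


class CreditMLP(nn.Module):
    """Hypothetical stand-in for the credit model; the real architecture is not shown."""

    def __init__(self, in_features: int = 10):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_features, 16), nn.ReLU(), nn.Linear(16, 2))

    def forward(self, x):
        return self.net(x)


model = CreditMLP()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
save_dir = tempfile.mkdtemp()

# a. Weights only: save the state_dict (parameter name -> tensor).
weights_path = os.path.join(save_dir, 'weights_only.pth')
torch.save(model.state_dict(), weights_path)
model_a = CreditMLP()  # the architecture has to be rebuilt in code first
model_a.load_state_dict(torch.load(weights_path, weights_only=True))

# b. Whole model: the module object is pickled, so its class must be importable
#    when loading. Full unpickling needs weights_only=False; use it only on trusted files.
full_path = os.path.join(save_dir, 'full_model.pth')
torch.save(model, full_path)
model_b = torch.load(full_path, weights_only=False)

# c. Checkpoint: weights plus training state, enough to resume training later.
ckpt_path = os.path.join(save_dir, 'checkpoint.pth')
torch.save({
    'epoch': 10,                      # hypothetical: last completed epoch
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_val_loss': 0.42,            # hypothetical value
}, ckpt_path)

checkpoint = torch.load(ckpt_path, weights_only=True)
model_c = CreditMLP()
model_c.load_state_dict(checkpoint['model_state_dict'])
optimizer_c = torch.optim.Adam(model_c.parameters(), lr=1e-3)
optimizer_c.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # resume from the next epoch
```

Option a is enough when only the learned weights need to be reused; option b ties the file to the class definition and its module path; option c is the one to use when training should resume later, because the optimizer state (for example Adam's moment estimates) and the epoch counter are restored as well.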
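Assignment: after training on the credit dataset, save the weights; then load them, continue training for another 50 epochs, and apply early stopping.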
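The code below assumes that `model`, `criterion`, `optimizer`, `num_epochs` and the tensors `X_train_tensor`, `y_train_tensor`, `X_val_tensor`, `y_val_tensor` already exist, and it starts after the initial training run, which the notes do not show. A minimal sketch of that first phase, assuming a binary-classification MLP and random tensors standing in for the preprocessed credit data (all shapes, layer sizes and hyperparameters here are hypothetical). Following point 1, it evaluates both splits every epoch and prints their metrics side by side:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the preprocessed credit data: random features and
# random binary labels. Replace these with the real train/validation tensors.
n_features = 10
X_train_tensor = torch.randn(800, n_features)
y_train_tensor = torch.randint(0, 2, (800,))
X_val_tensor = torch.randn(200, n_features)
y_val_tensor = torch.randint(0, 2, (200,))

# Hypothetical model, loss and optimizer (the notes do not show the real ones).
model = nn.Sequential(nn.Linear(n_features, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 50  # epoch budget for the later early-stopping loop


def accuracy(logits, labels):
    """Fraction of samples whose arg-max class matches the label."""
    return (logits.argmax(dim=1) == labels).float().mean().item()


history = []  # (train_loss, val_loss) per epoch
for epoch in range(100):  # initial training run; 100 epochs is arbitrary
    model.train()
    loss = criterion(model(X_train_tensor), y_train_tensor)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluate both splits with the same weights and metric so they are comparable.
    model.eval()
    with torch.no_grad():
        train_out, val_out = model(X_train_tensor), model(X_val_tensor)
        train_loss = criterion(train_out, y_train_tensor).item()
        val_loss = criterion(val_out, y_val_tensor).item()
    history.append((train_loss, val_loss))

    if (epoch + 1) % 20 == 0:
        print(f'Epoch {epoch+1}: train loss {train_loss:.4f}, '
              f'acc {accuracy(train_out, y_train_tensor):.3f} | '
              f'val loss {val_loss:.4f}, acc {accuracy(val_out, y_val_tensor):.3f}')
```

With real data, a training loss that keeps falling while the validation loss turns upward is the overfitting signal point 1 refers to. The random labels used here carry no signal, so the printed numbers only demonstrate the mechanics.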
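import torch

# Save the model weights (state_dict only)
torch.save(model.state_dict(), 'credit_model_weights.pth')

# Load the weights back into a model with the same architecture
model.load_state_dict(torch.load('credit_model_weights.pth'))

# Number of additional epochs to continue training
additional_epochs = 50
model.train()  # make sure layers such as dropout/batch norm are in training mode
for epoch in range(additional_epochs):
    # Forward pass
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{additional_epochs}], Loss: {loss.item():.4f}')

# Save the weights after continued training
torch.save(model.state_dict(), 'credit_model_weights_continued.pth')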
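# Early-stopping parameters
patience = 10  # max number of epochs to tolerate without a drop in validation loss
best_val_loss = float('inf')
counter = 0

# num_epochs is defined earlier (not shown in these notes; 50 for the assignment)
for epoch in range(num_epochs):
    # Training step
    model.train()
    outputs = model(X_train_tensor)
    train_loss = criterion(outputs, y_train_tensor)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    # Validation step (no gradients needed)
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val_tensor)
        val_loss = criterion(val_outputs, y_val_tensor)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss.item():.4f}')

    # Early-stopping logic: compare plain floats rather than tensors
    if val_loss.item() < best_val_loss:
        best_val_loss = val_loss.item()
        counter = 0
        # Save the best model weights so far
        torch.save(model.state_dict(), 'best_credit_model_weights.pth')
    else:
        counter += 1
        if counter >= patience:
            print('Early stopping!')
            break

# Restore the best weights seen during training (the last epoch is usually not the best)
model.load_state_dict(torch.load('best_credit_model_weights.pth'))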
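The patience logic in the loop above can be factored into a small reusable helper so it does not have to be rewritten for every training loop. A minimal sketch: the `EarlyStopping` class and its `min_delta` parameter are hypothetical (not a PyTorch API), and it tracks only numbers, so the caller still decides what to save when the loss improves. The loss sequence in the usage part is made up purely to exercise the logic:

```python
class EarlyStopping:
    """Signal a stop once the monitored loss has not improved by more than
    min_delta for `patience` consecutive epochs. Hypothetical helper."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def step(self, val_loss: float) -> tuple[bool, bool]:
        """Return (improved, should_stop) for this epoch's validation loss."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience


# Usage with a made-up sequence of validation losses:
stopper = EarlyStopping(patience=3)
val_losses = [0.70, 0.62, 0.58, 0.59, 0.60, 0.57, 0.58, 0.59, 0.61, 0.55]
best_epoch = None
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    improved, should_stop = stopper.step(loss)
    if improved:
        best_epoch = epoch  # a real loop would also save model.state_dict() here
    if should_stop:
        stopped_at = epoch
        break

print(f'Stopped at epoch {stopped_at}, best epoch {best_epoch}')  # Stopped at epoch 9, best epoch 6
```

The improvement at epoch 6 resets the counter, so training stops three epochs later and never sees the lower loss at epoch 10; that is the trade-off `patience` controls. A positive `min_delta` additionally keeps tiny, noise-level improvements from resetting the counter.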