Python Example Problem
Problem
A Privacy-Preserving AI System Based on Federated Learning (Distributed Learning, Privacy Computing)
Problem Description
Develop a privacy-preserving AI system based on federated learning that includes the following features:
- Federated learning framework: supports federated training of multiple kinds of machine learning models
- Privacy protection mechanisms: techniques such as differential privacy and homomorphic encryption protect data privacy
- Model aggregation: securely aggregates the model parameters of all participants
- Client management: manages and coordinates the clients taking part in training
- Evaluation and deployment: evaluates the federated model's performance and deploys it to production
Solution Approach
- Adopt a horizontal (sample-partitioned) or vertical (feature-partitioned) federated learning architecture
- Implement the model aggregation algorithm (such as FedAvg or FedProx)
- Apply differential privacy or homomorphic encryption to protect data privacy (a minimal differential-privacy sketch follows this list)
- Design the client-server communication protocol
- Develop model evaluation and deployment tools
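Differential privacy is named in the approach above but is not implemented in the code framework below. A minimal, hedged sketch of one common recipe is to clip a client's model update to a fixed L2 norm and add Gaussian noise before upload; the helper name, clip norm, and noise multiplier here are illustrative assumptions, not tuned settings.

# Hedged sketch: clip-and-noise a client's model update before it is uploaded.
# `clip_norm` and `noise_multiplier` are illustrative values only.
import torch
from typing import Dict

def privatize_update(update: Dict[str, torch.Tensor],
                     clip_norm: float = 1.0,
                     noise_multiplier: float = 0.5) -> Dict[str, torch.Tensor]:
    # Global L2 norm of the update across all parameters
    total_norm = torch.sqrt(sum((p.float() ** 2).sum() for p in update.values()))
    # Scale factor that clips the update to at most clip_norm
    scale = torch.clamp(clip_norm / (total_norm + 1e-12), max=1.0)
    noisy = {}
    for name, p in update.items():
        clipped = p * scale
        # Gaussian noise calibrated to the clipping bound
        noise = torch.randn_like(clipped) * noise_multiplier * clip_norm
        noisy[name] = clipped + noise
    return noisy

In the framework below, such a helper would be applied to the client's parameters just before encryption; the actual privacy budget bookkeeping (epsilon, delta) is omitted here.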
Key Code Framework
# Federated learning server
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import json
import logging
from typing import List, Dict, Any, Tuple
from cryptography.fernet import Fernet

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class FedAvgServer:
    def __init__(self, model: nn.Module, clients: List[str], config: Dict[str, Any]):
        self.model = model
        self.clients = clients
        self.config = config
        self.global_round = 0
        self.client_models = {client: None for client in clients}
        self.client_weights = {client: 1.0 for client in clients}  # per-client aggregation weights
        # Initialize the encryption key
        self.encryption_key = Fernet.generate_key()
        self.cipher_suite = Fernet(self.encryption_key)
        # Initialize the optimizer
        self.optimizer = optim.SGD(self.model.parameters(), lr=config['learning_rate'])

    def aggregate_models(self) -> None:
        """Aggregate the client models submitted in this round."""
        logger.info(f"Starting model aggregation for round {self.global_round}")
        # Only aggregate the clients that actually submitted a model this round
        submitted = {c: p for c, p in self.client_models.items() if p is not None}
        if not submitted:
            logger.warning("No client submitted a model; skipping this round")
            return
        # Total weight of the submitting clients
        total_weight = sum(self.client_weights[c] for c in submitted)
        # Initialize the global model parameters
        global_params = {}
        for name, param in self.model.named_parameters():
            global_params[name] = torch.zeros_like(param.data)
        # Weighted aggregation
        for client, model_params in submitted.items():
            weight = self.client_weights[client] / total_weight
            for name, param in model_params.items():
                global_params[name] += param * weight
        # Update the global model
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                param.data.copy_(global_params[name])
        # Advance the global round counter and reset the stored client models
        self.global_round += 1
        self.client_models = {client: None for client in self.clients}
        logger.info(f"Model aggregation for round {self.global_round - 1} finished")

    def encrypt_model(self, model_params: Dict[str, torch.Tensor]) -> bytes:
        """Encrypt model parameters."""
        # Convert the parameters to nested lists and serialize as JSON
        model_dict = {name: param.detach().cpu().numpy().tolist() for name, param in model_params.items()}
        model_json = json.dumps(model_dict).encode('utf-8')
        # Encrypt
        encrypted_data = self.cipher_suite.encrypt(model_json)
        return encrypted_data

    def decrypt_model(self, encrypted_data: bytes) -> Dict[str, torch.Tensor]:
        """Decrypt model parameters."""
        # Decrypt
        decrypted_data = self.cipher_suite.decrypt(encrypted_data)
        model_dict = json.loads(decrypted_data.decode('utf-8'))
        # Convert back to PyTorch tensors
        model_params = {name: torch.tensor(param) for name, param in model_dict.items()}
        return model_params

    def receive_client_model(self, client_id: str, encrypted_model: bytes, client_weight: float) -> None:
        """Receive a model from a client."""
        if client_id not in self.clients:
            logger.warning(f"Unknown client: {client_id}")
            return
        try:
            # Decrypt the model
            model_params = self.decrypt_model(encrypted_model)
            # Store the client model and its weight
            self.client_models[client_id] = model_params
            self.client_weights[client_id] = client_weight
            logger.info(f"Received model from client {client_id}, weight: {client_weight}")
        except Exception as e:
            logger.error(f"Failed to receive client model: {e}")

    def send_global_model(self, client_id: str) -> bytes:
        """Send the global model to a client."""
        if client_id not in self.clients:
            logger.warning(f"Unknown client: {client_id}")
            return None
        # Get the current global model parameters
        model_params = {name: param.data for name, param in self.model.named_parameters()}
        # Encrypt and return
        return self.encrypt_model(model_params)

    def evaluate_model(self, test_loader: DataLoader) -> Tuple[float, float]:
        """Evaluate the global model."""
        self.model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                outputs = self.model(inputs)
                loss = nn.CrossEntropyLoss()(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        accuracy = 100.0 * correct / total
        avg_loss = test_loss / len(test_loader)
        logger.info(f"Evaluation result: accuracy = {accuracy:.2f}%, average loss = {avg_loss:.4f}")
        return accuracy, avg_loss

    def save_model(self, path: str) -> None:
        """Save the global model."""
        torch.save(self.model.state_dict(), path)
        logger.info(f"Model saved to: {path}")
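Note that the server above distributes a single shared Fernet key to all clients, so the encryption only protects the parameters in transit; it does not hide one client's update from another. A minimal, hedged variation is to keep one key per client on the server; the class and method names below are illustrative assumptions, not part of the framework above.

# Hedged sketch: per-client Fernet keys instead of a single shared key.
from cryptography.fernet import Fernet
from typing import Dict, List

class PerClientKeyring:
    def __init__(self, clients: List[str]):
        # One symmetric key per client, generated and held on the server
        self.keys: Dict[str, bytes] = {c: Fernet.generate_key() for c in clients}
        self.ciphers = {c: Fernet(k) for c, k in self.keys.items()}

    def encrypt_for(self, client_id: str, payload: bytes) -> bytes:
        # Encrypt a payload destined for one specific client
        return self.ciphers[client_id].encrypt(payload)

    def decrypt_from(self, client_id: str, token: bytes) -> bytes:
        # Decrypt a payload received from one specific client
        return self.ciphers[client_id].decrypt(token)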
# Federated learning client
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import json
import logging
from typing import Dict, Any, List, Tuple
from cryptography.fernet import Fernet

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class FedAvgClient:
    def __init__(self, client_id: str, model: nn.Module, train_data: Dataset, config: Dict[str, Any]):
        self.client_id = client_id
        self.model = model
        self.train_data = train_data
        self.config = config
        # Create the data loader
        self.train_loader = DataLoader(train_data, batch_size=config['batch_size'], shuffle=True)
        # Initialize the optimizer
        self.optimizer = optim.SGD(self.model.parameters(), lr=config['learning_rate'])
        # Encryption tools
        self.encryption_key = None  # received from the server
        self.cipher_suite = None

    def set_encryption_key(self, key: bytes) -> None:
        """Set the encryption key."""
        self.encryption_key = key
        self.cipher_suite = Fernet(key)

    def encrypt_model(self, model_params: Dict[str, torch.Tensor]) -> bytes:
        """Encrypt model parameters."""
        if self.cipher_suite is None:
            raise ValueError("Encryption key has not been set")
        # Convert the parameters to nested lists and serialize as JSON
        model_dict = {name: param.detach().cpu().numpy().tolist() for name, param in model_params.items()}
        model_json = json.dumps(model_dict).encode('utf-8')
        # Encrypt
        encrypted_data = self.cipher_suite.encrypt(model_json)
        return encrypted_data

    def decrypt_model(self, encrypted_data: bytes) -> Dict[str, torch.Tensor]:
        """Decrypt model parameters."""
        if self.cipher_suite is None:
            raise ValueError("Encryption key has not been set")
        # Decrypt
        decrypted_data = self.cipher_suite.decrypt(encrypted_data)
        model_dict = json.loads(decrypted_data.decode('utf-8'))
        # Convert back to PyTorch tensors
        model_params = {name: torch.tensor(param) for name, param in model_dict.items()}
        return model_params

    def update_model(self, encrypted_global_model: bytes) -> None:
        """Replace the local model with the global model."""
        try:
            # Decrypt the global model
            global_params = self.decrypt_model(encrypted_global_model)
            # Update the local model
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(global_params[name])
            logger.info(f"Client {self.client_id} model updated")
        except Exception as e:
            logger.error(f"Failed to update model: {e}")

    def train(self, epochs: int) -> Tuple[Dict[str, torch.Tensor], int]:
        """Train the model locally."""
        self.model.train()
        for epoch in range(epochs):
            epoch_loss = 0
            batches = 0
            for inputs, targets in self.train_loader:
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = nn.CrossEntropyLoss()(outputs, targets)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
                batches += 1
            avg_loss = epoch_loss / batches
            logger.info(f"Client {self.client_id}, epoch {epoch + 1}/{epochs}, average loss: {avg_loss:.4f}")
        # Get the trained model parameters
        model_params = {name: param.data for name, param in self.model.named_parameters()}
        # Return the parameters and the number of samples (used as the aggregation weight)
        return model_params, len(self.train_data)

    def get_encrypted_model(self, epochs: int = 1) -> Tuple[bytes, int]:
        """Train locally and return the encrypted model parameters plus the weight."""
        model_params, weight = self.train(epochs)
        encrypted_model = self.encrypt_model(model_params)
        return encrypted_model, weight
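The solution approach mentions FedProx, which differs from FedAvg only in the local objective: it adds a proximal term (mu/2)·||w − w_global||² so that client models do not drift too far from the global model on non-IID data. A minimal, hedged sketch of such a local loss follows; the helper name and the mu value are assumptions for illustration.

# Hedged sketch: a FedProx-style local loss. `mu` is an illustrative hyperparameter.
import torch
import torch.nn as nn
from typing import Dict

def fedprox_loss(model: nn.Module,
                 outputs: torch.Tensor,
                 targets: torch.Tensor,
                 global_params: Dict[str, torch.Tensor],
                 mu: float = 0.01) -> torch.Tensor:
    # Standard task loss
    loss = nn.CrossEntropyLoss()(outputs, targets)
    # Proximal term: (mu / 2) * ||w - w_global||^2 summed over all parameters
    prox = 0.0
    for name, param in model.named_parameters():
        prox = prox + ((param - global_params[name].detach()) ** 2).sum()
    return loss + (mu / 2.0) * prox

Inside FedAvgClient.train one would snapshot the global parameters received in update_model and replace the plain CrossEntropyLoss call with this loss.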
# Federated learning main program
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import Subset
import numpy as np
import logging
from typing import List, Dict, Any

logger = logging.getLogger(__name__)

# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x


def split_dataset(dataset, num_clients: int, iid: bool = True) -> List[Subset]:
    """Split the dataset among multiple clients."""
    num_samples = len(dataset) // num_clients
    client_datasets = []
    if iid:
        # IID split (random assignment)
        indices = list(range(len(dataset)))
        np.random.shuffle(indices)
        for i in range(num_clients):
            client_indices = indices[i * num_samples : (i + 1) * num_samples]
            client_datasets.append(Subset(dataset, client_indices))
    else:
        # Non-IID split (sort by label, then assign)
        # Simplified here; real applications may need a more sophisticated partition strategy
        labels = np.array([dataset[i][1] for i in range(len(dataset))])
        indices = np.argsort(labels)
        for i in range(num_clients):
            client_indices = indices[i * num_samples : (i + 1) * num_samples]
            client_datasets.append(Subset(dataset, client_indices))
    return client_datasets


def run_federated_learning(config: Dict[str, Any]):
    """Run the federated learning process."""
    # Set random seeds
    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])
    # Load the dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('data', train=False, transform=transform)
    # Split the training data among the clients
    client_datasets = split_dataset(train_dataset, config['num_clients'], config['iid'])
    # Create the test data loader
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)
    # Initialize the server and the clients
    global_model = SimpleCNN()
    server = FedAvgServer(global_model, [f"client{i}" for i in range(config['num_clients'])], config)
    clients = []
    for i in range(config['num_clients']):
        client_model = SimpleCNN()
        # The client model starts identical to the global model
        client_model.load_state_dict(global_model.state_dict())
        client = FedAvgClient(f"client{i}", client_model, client_datasets[i], config)
        clients.append(client)
    # Distribute the encryption key to the clients
    for client in clients:
        client.set_encryption_key(server.encryption_key)
    # Federated training loop
    for round_idx in range(config['global_rounds']):
        logger.info(f"===== Starting federated round {round_idx + 1}/{config['global_rounds']} =====")
        # Select the clients participating in this round
        selected_clients = np.random.choice(clients, size=min(config['clients_per_round'], len(clients)), replace=False)
        # Send the global model to the selected clients
        for client in selected_clients:
            encrypted_global_model = server.send_global_model(client.client_id)
            client.update_model(encrypted_global_model)
        # Local training on each selected client
        for client in selected_clients:
            encrypted_model, client_weight = client.get_encrypted_model(config['local_epochs'])
            server.receive_client_model(client.client_id, encrypted_model, client_weight)
        # Server-side aggregation
        server.aggregate_models()
        # Evaluate the global model
        if (round_idx + 1) % config['eval_every'] == 0:
            accuracy, loss = server.evaluate_model(test_loader)
            logger.info(f"Round {round_idx + 1} evaluation: accuracy = {accuracy:.2f}%, loss = {loss:.4f}")
    # Save the final model
    server.save_model(config['model_save_path'])
    logger.info("Federated learning training finished")


# Configuration parameters
config = {
    'seed': 42,
    'num_clients': 10,
    'clients_per_round': 5,
    'global_rounds': 50,
    'local_epochs': 5,
    'batch_size': 64,
    'learning_rate': 0.01,
    'iid': True,          # whether the data distribution is IID
    'eval_every': 5,      # evaluate every N rounds
    'model_save_path': 'federated_model.pth'
}

# Run federated learning
if __name__ == "__main__":
    run_federated_learning(config)
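The split_dataset helper above uses a simplified sort-by-label non-IID split. A common alternative is to draw each client's label proportions from a Dirichlet distribution, which gives a tunable degree of skew. A minimal, hedged sketch follows; the function name and the concentration parameter alpha are illustrative assumptions (smaller alpha means more heterogeneous clients).

# Hedged sketch: Dirichlet-based non-IID partition. `alpha` is illustrative.
import numpy as np
from torch.utils.data import Dataset, Subset
from typing import List

def dirichlet_split(dataset: Dataset, num_clients: int, alpha: float = 0.5,
                    seed: int = 42) -> List[Subset]:
    rng = np.random.default_rng(seed)
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        cls_idx = np.where(labels == cls)[0]
        rng.shuffle(cls_idx)
        # Sample this class's proportions across clients from a Dirichlet distribution
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Turn the proportions into split points inside this class's index list
        split_points = (np.cumsum(proportions) * len(cls_idx)).astype(int)[:-1]
        for client_id, part in enumerate(np.split(cls_idx, split_points)):
            client_indices[client_id].extend(part.tolist())
    return [Subset(dataset, idx) for idx in client_indices]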
Difficulty Analysis
- Balancing privacy protection and model performance: preserving model accuracy while protecting privacy
- Communication efficiency: reducing the communication overhead between clients and the server
- Device heterogeneity: handling participants with very different compute and network capabilities
- Secure aggregation protocol: aggregating model parameters without exposing any individual update (a pairwise-masking sketch follows this list)
- Malicious participant detection: identifying and handling adversarial participants
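One classic building block for secure aggregation is additive pairwise masking: each pair of clients (i, j) derives the same random mask from a shared seed, one adds it and the other subtracts it, so the masks cancel when the server sums the masked updates. The sketch below is a hedged toy version; key agreement, dropout handling, and quantization are all omitted, and the function name and seed dictionary are assumptions for illustration.

# Hedged sketch: additive pairwise masking of a client's update before upload.
import torch
from typing import Dict, List

def mask_update(update: Dict[str, torch.Tensor], client_id: int,
                all_ids: List[int],
                pair_seeds: Dict[frozenset, int]) -> Dict[str, torch.Tensor]:
    masked = {name: p.clone() for name, p in update.items()}
    for other in all_ids:
        if other == client_id:
            continue
        # Both clients in the pair seed the same generator, so their masks are identical
        gen = torch.Generator().manual_seed(pair_seeds[frozenset((client_id, other))])
        # The lower-id client adds the mask, the higher-id client subtracts it
        sign = 1.0 if client_id < other else -1.0
        for name in masked:
            mask = torch.randn(masked[name].shape, generator=gen)
            masked[name] = masked[name] + sign * mask
    return masked

When the server sums the masked updates from all clients in all_ids, every pairwise mask cancels and only the true sum remains; this only works if every listed client actually submits.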
Extension Directions
- Implement more advanced privacy-protection techniques such as differential privacy and homomorphic encryption (a Paillier-based sketch follows this list)
- Add an adaptive learning-rate adjustment mechanism
- Support incremental training and continual learning
- Develop a visual monitoring dashboard for federated learning
- Support cross-platform federated learning (mobile and edge devices)
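For the homomorphic-encryption direction, an additively homomorphic scheme such as Paillier lets the server sum ciphertexts without ever seeing plaintext parameters. The sketch below is a hedged toy example that assumes the third-party `phe` (python-paillier) package is installed; it aggregates a single scalar, since encrypting entire models this way is far too slow without batching or switching to CKKS-style schemes.

# Hedged sketch: additively homomorphic aggregation of one scalar with python-paillier.
from phe import paillier

# Key generation (in practice the private key stays with a trusted party, not the server)
public_key, private_key = paillier.generate_paillier_keypair(n_length=2048)

# Each client encrypts its (weighted) parameter value with the public key
client_values = [0.12, -0.03, 0.25]
encrypted = [public_key.encrypt(v) for v in client_values]

# The server sums ciphertexts without seeing any plaintext
encrypted_sum = encrypted[0]
for c in encrypted[1:]:
    encrypted_sum = encrypted_sum + c

# Only the private-key holder can decrypt the aggregate
average = private_key.decrypt(encrypted_sum) / len(client_values)
print(f"Aggregated average: {average:.4f}")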