「日拱一码」027 深度学习库——PyTorch Geometric(PyG)

目录

数据处理与转换

数据表示

数据加载

数据转换

特征归一化

添加自环

随机扰动

组合转换

图神经网络层

图卷积层(GCNConv)

图注意力层(GATConv)

池化

全局池化(Global Pooling)

全局平均池化

全局最大池化

全局求和池化

基于注意力的池化(Attention-based Pooling)

基于图的池化(Graph-based Pooling)

层次化池化(Hierarchical Pooling)

采样

子图采样(Subgraph Sampling)

邻域采样(Neighbor Sampling)

模型训练与评估

训练过程

测试过程

异构图处理

异构图定义

异构图卷积

图生成模型

Deep Graph Infomax (DGI)

Graph Autoencoder (GAE)

Variational Graph Autoencoder (VGAE)


PyTorch Geometric(PyG)是PyTorch的一个扩展库,专注于图神经网络(GNN)的实现。它提供了丰富的图数据处理工具、图神经网络层和模型。以下是对PyG库中常用方法的介绍

数据处理与转换

数据表示

PyG使用 torch_geometric.data.Data 类来表示图数据,包含节点特征 x 、边索引 edge_index 、边特征 edge_attr 等

## 数据处理与转换
# 1. 数据表示
import torch
from torch_geometric.data import Data# 创建一个简单的图
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float)  # 节点特征
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)  # 边索引
edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float)  # 边特征data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
print(data)  # Data(x=[3, 2], edge_index=[2, 4], edge_attr=[4])

数据加载

PyG提供了 torch_geometric.data.DataLoader 类,用于批量加载图数据

# 2. 数据加载
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader# 加载Cora数据集
dataset = Planetoid(root='./data', name='Cora')
loader = DataLoader(dataset, batch_size=32, shuffle=True)print(f"节点数: {data.num_nodes}")  # 3
print(f"边数: {data.num_edges}")  # 4
print(f"特征维度: {data.num_node_features}")  # 2
print(f"类别数: {dataset.num_classes}")  # 7for batch in loader:print(batch)# DataBatch(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708],#           batch=[2708], ptr=[2])

数据转换

  • 特征归一化

NormalizeFeatures  是一个常用的转换方法,用于将节点特征归一化到单位范数(如 0, 1 或 -1, 1)

# 3. 数据转换
# 3.1 特征归一化
from torch_geometric.transforms import NormalizeFeaturesdataset = Planetoid(root='./data', name='Cora', transform=NormalizeFeatures())# 查看归一化后的特征
data = dataset[0]
print(data.x)
# tensor([[0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         ...,
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.]])
  • 添加自环

AddSelfLoops  是一个转换方法,用于为图中的每个节点添加自环(即每个节点连接到自己)

# 3.2 添加自环
from torch_geometric.transforms import AddSelfLoopsdataset = Planetoid(root='./data', name='Cora', transform=AddSelfLoops())# 查看添加自环后的边索引
data = dataset[0]
print(data.edge_index)
# tensor([[   0,    0,    0,  ..., 2705, 2706, 2707],
#         [ 633, 1862, 2582,  ..., 2705, 2706, 2707]])
  • 随机扰动

RandomNodeSplit  是一个转换方法,用于随机划分训练集、验证集和测试集

# 3.3 随机扰动
from torch_geometric.transforms import RandomNodeSplitdataset = Planetoid(root='./data', name='Cora', transform=RandomNodeSplit(num_splits=10))# 查看划分后的掩码
data = dataset[0]
print(data.train_mask)
# tensor([[False,  True,  True,  ..., False, False,  True],
#         [False, False,  True,  ...,  True, False, False],
#         [False,  True, False,  ..., False, False, False],
#         ...,
#         [ True,  True,  True,  ..., False, False, False],
#         [ True,  True,  True,  ..., False, False,  True],
#         [ True,  True,  True,  ...,  True, False,  True]])
print(data.val_mask)
# tensor([[False, False, False,  ..., False,  True, False],
#         [False, False, False,  ..., False, False, False],
#         [False, False, False,  ..., False, False,  True],
#         ...,
#         [False, False, False,  ...,  True,  True, False],
#         [False, False, False,  ..., False,  True, False],
#         [False, False, False,  ..., False, False, False]])
print(data.test_mask)
# tensor([[ True, False, False,  ...,  True, False, False],
#         [ True,  True, False,  ..., False,  True,  True],
#         [ True, False,  True,  ...,  True,  True, False],
#         ...,
#         [False, False, False,  ..., False, False,  True],
#         [False, False, False,  ...,  True, False, False],
#         [False, False, False,  ..., False,  True, False]])
  • 组合转换

可以将多个转换方法组合在一起,形成一个复合转换

# 3.4 组合转换
from torch_geometric.transforms import Compose, NormalizeFeatures, AddSelfLoops# 定义一个复合转换
transform = Compose([NormalizeFeatures(), AddSelfLoops()])# 创建一个数据集,并应用复合转换
dataset = Planetoid(root='./data', name='Cora', transform=transform)# 查看转换后的数据
data = dataset[0]
print(data.x)
# tensor([[0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         ...,
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.]])
print(data.edge_index)
# tensor([[   0,    0,    0,  ..., 2705, 2706, 2707],
#         [ 633, 1862, 2582,  ..., 2705, 2706, 2707]])

图神经网络层

图卷积层(GCNConv)

GCNConv是图卷积网络(GCN)的基本层

## 图神经网络层
# 1. 图卷积层 GCNConv
import torch
from torch_geometric.nn import GCNConvclass GCN(torch.nn.Module):def __init__(self, in_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, 16)self.conv2 = GCNConv(16, out_channels)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = torch.relu(x)x = self.conv2(x, edge_index)return xmodel = GCN(in_channels=dataset.num_features, out_channels=dataset.num_classes)
print(model)
# GCN(
#   (conv1): GCNConv(1433, 16)
#   (conv2): GCNConv(16, 7)
# )

图注意力层(GATConv)

GATConv是图注意力网络(GAT)的基本层

# 2. 图注意力层 GATConv
from torch_geometric.nn import GATConvclass GAT(torch.nn.Module):def __init__(self, in_channels, out_channels):super(GAT, self).__init__()self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=True, dropout=0.6)def forward(self, x, edge_index):x = torch.dropout(x, p=0.6, training=self.training)x = self.conv1(x, edge_index)x = torch.relu(x)x = torch.dropout(x, p=0.6, training=self.training)x = self.conv2(x, edge_index)return xmodel = GAT(in_channels=dataset.num_features, out_channels=dataset.num_classes)
print(model)
# GAT(
#   (conv1): GATConv(1433, 8, heads=8)
#   (conv2): GATConv(64, 7, heads=1)
# )

池化

全局池化(Global Pooling)

全局池化将整个图的所有节点聚合为一个全局表示

  • 全局平均池化
## 池化
# 1. 全局池化
# 1.1 全局平均池化
from torch_geometric.nn import global_mean_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加载数据集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 获取一个批次的数据
for batch in loader:x = batch.xbatch_index = batch.batchglobal_mean = global_mean_pool(x, batch_index)print("Global Mean Pooling Result:", global_mean)break
# tensor([[0.7647, 0.0588, 0.1176, 0.0000, 0.0588, 0.0000, 0.0000],
#         [0.7500, 0.1250, 0.1250, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6250, 0.1250, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.5217, 0.1739, 0.3043, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.5455, 0.2727, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6400, 0.1200, 0.2400, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6364, 0.0909, 0.2727, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7857, 0.0714, 0.1429, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8000, 0.0667, 0.1333, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7647, 0.0588, 0.1176, 0.0000, 0.0000, 0.0000, 0.0588],
#         [0.5000, 0.1667, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7692, 0.0769, 0.1538, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7826, 0.0435, 0.1739, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8000, 0.0500, 0.1500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8696, 0.0435, 0.0870, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7273, 0.0909, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6667, 0.0833, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7273, 0.0909, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8125, 0.0625, 0.1250, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8235, 0.0588, 0.1176, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6000, 0.1000, 0.2000, 0.0000, 0.0000, 0.1000, 0.0000],
#         [0.4615, 0.1538, 0.3077, 0.0000, 0.0000, 0.0769, 0.0000],
#         [0.7647, 0.0588, 0.1765, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.6000, 0.0500, 0.2000, 0.0000, 0.0000, 0.1500, 0.0000],
#         [0.6000, 0.1000, 0.2000, 0.1000, 0.0000, 0.0000, 0.0000],
#         [0.7647, 0.0588, 0.1176, 0.0000, 0.0000, 0.0588, 0.0000],
#         [0.8000, 0.0667, 0.1333, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.4615, 0.1538, 0.3077, 0.0769, 0.0000, 0.0000, 0.0000],
#         [0.8696, 0.0435, 0.0870, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8696, 0.0435, 0.0870, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.7273, 0.0909, 0.1818, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.8421, 0.0526, 0.1053, 0.0000, 0.0000, 0.0000, 0.0000]])
  • 全局最大池化
# 1.2 全局最大池化
from torch_geometric.nn import global_max_pool# 获取一个批次的数据
for batch in loader:x = batch.xbatch_index = batch.batchglobal_max = global_max_pool(x, batch_index)print("Global Max Pooling Result:", global_max)break
  • 全局求和池化
# 3. 全局求和池化
from torch_geometric.nn import global_add_pool# 获取一个批次的数据
for batch in loader:x = batch.xbatch_index = batch.batchglobal_sum = global_add_pool(x, batch_index)print("Global Sum Pooling Result:", global_sum)break

基于注意力的池化(Attention-based Pooling)

基于注意力的池化方法通过学习节点的重要性权重来进行池化。一个常见的例子是 Set2Set 池化

# 2. 基于注意力的池化——Set2Set
from torch_geometric.nn import Set2Set
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加载数据集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 定义 Set2Set 池化
set2set = Set2Set(in_channels=dataset.num_node_features, processing_steps=3)# 获取一个批次的数据
for batch in loader:x = batch.xbatch_index = batch.batchglobal_set2set = set2set(x, batch_index)print("Set2Set Pooling Result:", global_set2set)break
# Set2Set Pooling Result: tensor([[ 0.1719,  0.0986,  0.1594, -0.0438,  0.1743,  0.1663, -0.0578,  0.8464,
#           0.0492,  0.1045,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1730,  0.0987,  0.1601, -0.0420,  0.1730,  0.1658, -0.0549,  0.8733,
#           0.0405,  0.0862,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1686,  0.0919,  0.1603, -0.0525,  0.1807,  0.1707, -0.0683,  0.7540,
#           0.0466,  0.1994,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1601,  0.1165,  0.1425, -0.0525,  0.1782,  0.1602, -0.0836,  0.6232,
#           0.2237,  0.1531,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1725,  0.0987,  0.1598, -0.0428,  0.1736,  0.1660, -0.0562,  0.8611,
#           0.0444,  0.0945,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1673,  0.1060,  0.1527, -0.0473,  0.1761,  0.1642, -0.0679,  0.7570,
#           0.1187,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1579,  0.0996,  0.1486, -0.0658,  0.1874,  0.1695, -0.0954,  0.5284,
#           0.1662,  0.3054,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1584,  0.0969,  0.1503, -0.0665,  0.1881,  0.1709, -0.0949,  0.5327,
#           0.1503,  0.3170,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1634,  0.0976,  0.1537, -0.0581,  0.1835,  0.1695, -0.0809,  0.6464,
#           0.1135,  0.2401,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1704,  0.0952,  0.1599, -0.0479,  0.1774,  0.1684, -0.0626,  0.8042,
#           0.0466,  0.1492,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1634,  0.1081,  0.1488, -0.0522,  0.1789,  0.1640, -0.0776,  0.6743,
#           0.1595,  0.1661,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1671,  0.0980,  0.1563, -0.0518,  0.1797,  0.1682, -0.0707,  0.7332,
#           0.0855,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1564,  0.1193,  0.1384, -0.0562,  0.1800,  0.1590, -0.0922,  0.5527,
#           0.2663,  0.1810,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1704,  0.0952,  0.1599, -0.0479,  0.1774,  0.1684, -0.0626,  0.8042,
#           0.0466,  0.1492,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1671,  0.0980,  0.1563, -0.0518,  0.1797,  0.1682, -0.0707,  0.7332,
#           0.0855,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1730,  0.0987,  0.1601, -0.0420,  0.1730,  0.1658, -0.0549,  0.8733,
#           0.0405,  0.0862,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1604,  0.0972,  0.1517, -0.0631,  0.1863,  0.1704, -0.0893,  0.5779,
#           0.1356,  0.2864,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1711,  0.0985,  0.1589, -0.0451,  0.1752,  0.1666, -0.0599,  0.8281,
#           0.0550,  0.1169,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1584,  0.1178,  0.1406, -0.0542,  0.1790,  0.1597, -0.0875,  0.5910,
#           0.2432,  0.1659,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1673,  0.1060,  0.1527, -0.0473,  0.1761,  0.1642, -0.0679,  0.7570,
#           0.1187,  0.1243,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1602,  0.1097,  0.1457, -0.0562,  0.1811,  0.1638, -0.0856,  0.6077,
#           0.1926,  0.1997,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1696,  0.1047,  0.1549, -0.0444,  0.1743,  0.1641, -0.0623,  0.8062,
#           0.0945,  0.0993,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1604,  0.0972,  0.1517, -0.0631,  0.1863,  0.1704, -0.0893,  0.5779,
#           0.1356,  0.2864,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1671,  0.0980,  0.1563, -0.0518,  0.1797,  0.1682, -0.0707,  0.7332,
#           0.0855,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1638,  0.0919,  0.1567, -0.0605,  0.1855,  0.1723, -0.0815,  0.6416,
#           0.0853,  0.2731,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1732,  0.1049,  0.1525, -0.0508,  0.1755,  0.1624, -0.0665,  0.7700,
#           0.0553,  0.1160,  0.0000,  0.0000,  0.0586,  0.0000],
#         [ 0.1711,  0.0985,  0.1589, -0.0451,  0.1752,  0.1666, -0.0599,  0.8281,
#           0.0550,  0.1169,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1573,  0.0968,  0.1494, -0.0685,  0.1891,  0.1712, -0.0982,  0.5063,
#           0.1589,  0.3349,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1729,  0.1053,  0.1365, -0.0582,  0.1904,  0.1594, -0.0881,  0.5637,
#           0.0878,  0.1812,  0.0746,  0.0000,  0.0927,  0.0000],
#         [ 0.1586,  0.1026,  0.1477, -0.0628,  0.1855,  0.1678, -0.0924,  0.5526,
#           0.1742,  0.2733,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1646,  0.1075,  0.1500, -0.0506,  0.1781,  0.1641, -0.0746,  0.6999,
#           0.1469,  0.1533,  0.0000,  0.0000,  0.0000,  0.0000],
#         [ 0.1695,  0.0983,  0.1579, -0.0477,  0.1770,  0.1672, -0.0641,  0.7909,
#           0.0670,  0.1421,  0.0000,  0.0000,  0.0000,  0.0000]],
#        grad_fn=<CatBackward0>)

基于图的池化(Graph-based Pooling)

基于图的池化方法通过图的结构信息来进行池化。常见的方法包括 TopKPooling,通过选择重要性最高的节点来进行池化

# 3. 基于图的池化——TopKPooling
from torch_geometric.nn import TopKPooling
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加载数据集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 定义 TopKPooling
pool = TopKPooling(in_channels=dataset.num_node_features, ratio=0.5) # 获取一个批次的数据
for batch in loader:x = batch.xedge_index = batch.edge_indexbatch_index = batch.batchx, edge_index, _, batch_index, _, _ = pool(x, edge_index, batch=batch_index)print("TopKPooling Result:", x)break
# tensor([[-0.0000, -0.0392, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0392, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0392, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         ...,
#         [-0.4577, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.4577, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.4577, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000]],
#        grad_fn=<MulBackward0>)

层次化池化(Hierarchical Pooling)

层次化池化通过多层池化操作生成图的层次化表示。一个常见的例子是 EdgePooling,通过边的合并操作来进行池化

# 4. 层次化池化——EdgePooling
from torch_geometric.nn import EdgePooling
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader# 加载数据集
dataset = TUDataset(root='./data', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)# 定义 EdgePooling
pool = EdgePooling(in_channels=dataset.num_node_features)  # 获取一个批次的数据
for batch in loader:x = batch.xedge_index = batch.edge_indexbatch_index = batch.batchx, edge_index, batch_index, _ = pool(x, edge_index, batch=batch_index)print("EdgePooling Result:", x)break# tensor([[0.0000, 1.5000, 1.5000,  ..., 0.0000, 0.0000, 0.0000],
#         [0.0000, 1.5000, 1.5000,  ..., 0.0000, 0.0000, 0.0000],
#         [0.0000, 1.5000, 1.5000,  ..., 0.0000, 0.0000, 0.0000],
#         ...,
#         [0.0000, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
#         [1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000]],
#        grad_fn=<MulBackward0>)

采样

子图采样(Subgraph Sampling)

子图采样是从原始图中提取一个子图,通常用于减少计算复杂度和增强模型的泛化能力

## 采样
# 1. 子图采样
import torch
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import k_hop_subgraph# 加载数据集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]# 选择一个起始节点
start_node = 0
num_hops = 2  # 采样半径# 提取子图
sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(start_node, num_hops, data.edge_index)# 创建子图
sub_data = Data(x=data.x[sub_nodes], edge_index=sub_edge_index, y=data.y[sub_nodes])print("Original Graph Nodes:", data.num_nodes)  # 2708
print("Subgraph Nodes:", sub_data.num_nodes)  # 8
print("Subgraph Edges:", sub_data.edge_index.shape[1])  # 20

邻域采样(Neighbor Sampling)

邻域采样通过选择节点的邻居来生成子图,适用于大规模图数据

# 2. 邻域采样
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader# 加载数据集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]# 定义 NeighborSampler
loader = NeighborLoader(data,num_neighbors=[10, 10],  # 每层采样的邻居数量batch_size=1024,shuffle=True,
)# 遍历数据加载器
for batch in loader:print(batch)break

模型训练与评估

训练过程

## 模型训练与评估
# 1. 训练过程
import torch.nn.functional as Foptimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss

测试过程

# 2. 测试过程
@torch.no_grad()
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())acc = correct / int(data.test_mask.sum())return accfor epoch in range(200):loss = train()acc = test()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

异构图处理

异构图定义

## 异构图处理
# 1. 异构图定义
from torch_geometric.data import HeteroData
import torchdata = HeteroData()
# 添加两种类型节点
data['user'].x = torch.randn(4, 16)  # 4个用户
data['movie'].x = torch.randn(5, 32)  # 5部电影
# 添加边
data['user', 'rates', 'movie'].edge_index = torch.tensor([[0, 0, 1, 2, 3], [0, 2, 3, 1, 4]]  # user->movie评分关系
)

异构图卷积

# 2. 异构图卷积
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv
from torch_geometric.transforms import NormalizeFeaturesclass HeteroGNN(torch.nn.Module):def __init__(self, in_channels, out_channels, hidden_channels):super().__init__()self.conv1 = HeteroConv({('user', 'rates', 'movie'): SAGEConv((in_channels['user'], in_channels['movie']), hidden_channels),('movie', 'rev_rates', 'user'): GCNConv(in_channels['movie'], hidden_channels, add_self_loops=False)  # 禁用自环}, aggr='sum')self.conv2 = HeteroConv({('user', 'rates', 'movie'): SAGEConv((hidden_channels, hidden_channels), out_channels),('movie', 'rev_rates', 'user'): GCNConv(hidden_channels, out_channels, add_self_loops=False)  # 禁用自环}, aggr='sum')def forward(self, x_dict, edge_index_dict):x_dict = self.conv1(x_dict, edge_index_dict)x_dict = {key: torch.relu(x) for key, x in x_dict.items()}x_dict = self.conv2(x_dict, edge_index_dict)return x_dict# 定义输入和输出通道数
in_channels = {'user': 16, 'movie': 32}
out_channels = 7  # 假设输出通道数为7
hidden_channels = 64  # 假设隐藏层通道数为64# 实例化模型
model = HeteroGNN(in_channels, out_channels, hidden_channels)
print(model)
# HeteroGNN(
#   (conv1): HeteroConv(num_relations=2)
#   (conv2): HeteroConv(num_relations=2)
# )

图生成模型

Deep Graph Infomax (DGI)

DGI 是一种无监督图表示学习方法,通过最大化局部和全局图表示之间的一致性来学习节点嵌入

## 图生成模型
# 1. Deep Graph Infomax (DGI)
from torch_geometric.nn import DeepGraphInfomax
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch.nn as nn
import torch.nn.functional as Fclass Encoder(nn.Module):def __init__(self, in_channels, hidden_channels):super(Encoder, self).__init__()self.conv = GCNConv(in_channels, hidden_channels)self.prelu = nn.PReLU(hidden_channels)def forward(self, x, edge_index):x = self.conv(x, edge_index)x = self.prelu(x)return xdef corruption(x, edge_index):return x[torch.randperm(x.size(0))], edge_indexdataset = Planetoid(root='./data', name='Cora')
data = dataset[0]encoder = Encoder(dataset.num_features, hidden_channels=512)
model = DeepGraphInfomax(hidden_channels=512, encoder=encoder,summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),corruption=corruption
)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)def train():model.train()optimizer.zero_grad()pos_z, neg_z, summary = model(data.x, data.edge_index)loss = model.loss(pos_z, neg_z, summary)loss.backward()optimizer.step()return lossfor epoch in range(100):loss = train()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}')

Graph Autoencoder (GAE)

GAE 是一种基于图神经网络的自编码器,用于图生成任务。它通过学习节点嵌入来重建图的邻接矩阵

# 2. Graph Autoencoder(GAE)
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GAE
import torch.nn.functional as Fclass Encoder(nn.Module):def __init__(self, in_channels, out_channels):super(Encoder, self).__init__()self.conv1 = GCNConv(in_channels, 2 * out_channels)self.conv2 = GCNConv(2 * out_channels, out_channels)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = F.relu(x)return self.conv2(x, edge_index)dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]encoder = Encoder(dataset.num_features, out_channels=16)
model = GAE(encoder)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)def train():model.train()optimizer.zero_grad()z = model.encode(data.x, data.edge_index)loss = model.recon_loss(z, data.edge_index)loss.backward()optimizer.step()return lossfor epoch in range(100):loss = train()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}')

Variational Graph Autoencoder (VGAE)

VGAE 是 GAE 的变体,通过引入变分推断来学习节点嵌入的分布

# 3. Variational Graph Autoencoder(VGAE)
from torch_geometric.nn import VGAE
from torch_geometric.datasets import Planetoid# 定义数据集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]class Encoder(nn.Module):def __init__(self, in_channels, out_channels):super(Encoder, self).__init__()self.conv1 = GCNConv(in_channels, 2 * out_channels)self.conv2 = GCNConv(2 * out_channels, 2 * out_channels)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)mu = x[:, :x.size(1) // 2]logstd = x[:, x.size(1) // 2:]return mu, logstd# 定义 Encoder
encoder = Encoder(dataset.num_features, out_channels=16)# 定义 VGAE 模型
model = VGAE(encoder)# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 训练函数
def train():model.train()optimizer.zero_grad()z = model.encode(data.x, data.edge_index)loss = model.recon_loss(z, data.edge_index)kl_loss = model.kl_loss()loss += kl_lossloss.backward()optimizer.step()return loss# 训练模型
for epoch in range(100):loss = train()print(f'Epoch: {epoch + 1}, Loss: {loss:.4f}')

    本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
    如若转载,请注明出处:http://www.pswp.cn/web/89092.shtml
    繁体地址,请注明出处:http://hk.pswp.cn/web/89092.shtml

    如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

    相关文章

    IoC容器深度解析:架构、原理与实现

    &#x1f31f; IoC容器深度解析&#xff1a;架构、原理与实现 引用&#xff1a; .NET IoC容器原理与实现等巫山的雲彩都消散撒下的碧色如何看淡 &#x1f50d; 一、引言&#xff1a;从服务定位器到IoC的演进 #mermaid-svg-BmRIuI4iMgiUqFVN {font-family:"trebuchet ms&…

    从零开始学前端html篇3

    表单基本结构表单是 HTML 中用于创建用户输入区域的标签。它允许用户输入数据&#xff08;例如文本、选择选项、文件等&#xff09;&#xff0c;并将这些数据提交到服务器进行处理。<form>&#xff0c;表单标签&#xff0c;用于创建表单常用属性&#xff1a;action&#…

    Linux系统调优和工具

    Linux系统调优和问题定位需要掌握一系列强大的工具&#xff0c;涵盖系统监控、性能分析、故障排查等多个方面。以下是一些核心工具和它们的典型应用场景&#xff0c;分类整理如下&#xff1a; 一、系统资源监控&#xff08;实时概览&#xff09;top / htop 功能&#xff1a; 实…

    如何快速有效地在WordPress中添加Instagram动态

    在当今社交媒体的时代&#xff0c;通过展示Instagram的最新动态&#xff0c;可以有效吸引读者的目光&#xff0c;同时丰富网站内容。很多人想知道&#xff0c;如何把自己精心运营的Instagram内容无缝嵌入WordPress网站呢&#xff1f;别担心&#xff0c;操作并不复杂&#xff0c…

    spring容器加载工具类

    在Spring框架中&#xff0c;工具类通常不需要被Spring容器管理&#xff0c;但如果确实需要获取Spring容器中的Bean实例&#xff0c;可以通过静态方法设置和获取ApplicationContext。下面是一个典型的Spring容器加载工具类的实现&#xff1a;这个工具类通过实现ApplicationConte…

    定时器更新中断与串口中断

    问题&#xff1a;我想把打印姿态传感器的角度&#xff0c;但是重定向的打印函数突然打印不出来。尝试&#xff1a;我怀疑是优先级的问题&#xff0c;故调整了串口&#xff0c;定时器&#xff0c;dma的优先级可是发现调了还是没有用&#xff0c;最终发现&#xff0c;我把定时器中…

    用Python向PDF添加文本:精确插入文本到PDF文档

    PDF 文档的版式特性使其适用于输出不可变格式的报告与合同。但若要在此类文档中插入或修改文本&#xff0c;常规方式难以实现。借助Python&#xff0c;我们可以高效地向 PDF 添加文本&#xff0c;实现从文档生成到内容管理的自动化流程。 本文将从以下方面介绍Python实现PDF中…

    Quick API:赋能能源行业,化解数据痛点

    随着全球能源结构的转型和数字化的深入推进&#xff0c;能源行业正面临前所未有的机遇与挑战。海量的实时数据、复杂的业务系统、以及对数据安全和高效利用的迫切需求&#xff0c;都成为了能源企业在数字化转型道路上的核心痛点。本文将深入探讨麦聪Quick API如何凭借其独特优势…

    Google Chrome V8< 13.6.86 类型混淆漏洞

    【高危】Google Chrome V8< 13.6.86 类型混淆漏洞 漏洞描述 Google Chrome 是美国谷歌&#xff08;Google&#xff09;公司的一款Web浏览器&#xff0c;V8 是 Google 开发的高性能开源 JavaScript 和 WebAssembly 引擎&#xff0c;广泛应用于 Chrome 浏览器和 Node.js 等环…

    力扣经典算法篇-23-环形链表(哈希映射法,快慢指针法)

    1、题干 给你一个链表的头节点 head &#xff0c;判断链表中是否有环。 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次到达&#xff0c;则链表中存在环。 为了表示给定链表中的环&#xff0c;评测系统内部使用整数 pos 来表示链表尾连接到链表中的位置&…

    HarmonyOS DevEco Studio 小技巧 42 - 鸿蒙单向数据流

    在鸿蒙应用开发中&#xff0c;状态管理是构建响应式界面的核心支柱&#xff0c;而 单向数据流&#xff08;Unidirectional Data Flow, UDF&#xff09;作为鸿蒙架构的重要设计原则&#xff0c;贯穿于组件通信、状态更新和界面渲染的全流程。本文将结合鸿蒙 ArkUI 框架特性&…

    【LeetCode 3136. 有效单词】解析

    目录LeetCode中国站原文原始题目题目描述示例 1&#xff1a;示例 2&#xff1a;示例 3&#xff1a;提示&#xff1a;讲解化繁为简&#xff1a;如何优雅地“盘”逻辑判断题第一部分&#xff1a;算法思想 —— “清单核对”与“一票否决”第二部分&#xff1a;代码实现 —— 清晰…

    前端面试专栏-算法篇:24. 算法时间与空间复杂度分析

    &#x1f525; 欢迎来到前端面试通关指南专栏&#xff01;从js精讲到框架到实战&#xff0c;渐进系统化学习&#xff0c;坚持解锁新技能&#xff0c;祝你轻松拿下心仪offer。 前端面试通关指南专栏主页 前端面试专栏规划详情 算法时间与空间复杂度分析&#xff1a;从理论到实践…

    bash中||与的区别

    在 Bash 中&#xff0c;|| 和 && 是两种常用的逻辑操作符&#xff0c;用于控制命令的执行流程。它们的核心区别如下&#xff1a;1. ||&#xff08;逻辑 OR&#xff09; 作用&#xff1a;如果前一个命令失败&#xff08;返回非零退出码&#xff09;&#xff0c;则执行后…

    OpenCV实现感知哈希(Perceptual Hash)算法的类cv::img_hash::PHash

    操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 PHash是OpenCV中实现感知哈希&#xff08;Perceptual Hash&#xff09;算法的类。该算法用于快速比较图像的视觉相似性。它将图像压缩为一个简短的…

    数据库迁移人大金仓数据库

    迁移前的准备工作 安装官方的kdts和KStudio工具 方案说明 一、数据库迁移&#xff1a;可以使用kdts进行数据库的按照先迁移表结构、后数据的顺序迁移&#xff08;kdts的使用可以参考官方文档&#xff09; 其他参考文档 人大金仓官网&#xff1a;https://download.kingbase…

    uniapp 微信小程序Vue3项目使用内置组件movable-area封装悬浮可拖拽按钮(拖拽结束时自动吸附到最近的屏幕边缘)

    一、最终效果 二、具体详情请看movable-area与movable-view官方文档说明 三、参数配置 1、代码示例 <TFab title"新建订单" click"addOrder" /> // title:表按钮文案 // addOrder:点击按钮事件四、组件源码 <template><movable-area cl…

    linux kernel为什么要用IS_ERR()宏来判断指针合法性?

    在 Linux 内核中&#xff0c;IS_ERR() 宏的设计与内核的错误处理机制和指针编码规范密切相关&#xff0c;主要用于判断一个“可能携带错误码的指针”是否代表异常状态。其核心目的是解决内核中指针返回值与错误码的统一表示问题。以下从技术背景、设计逻辑和实际场景三个维度详…

    Cookie与Session:Web开发核心差异详解

    理解 Cookie 和 Session 的区别对于 Web 开发至关重要,它们虽然经常一起使用,但扮演着不同的角色。核心区别在于: Cookie:存储在客户端(用户的浏览器)的数据片段。 Session:存储在服务器端的数据结构,用于跟踪特定用户的状态。 下面是详细的对比: 特性CookieSession…

    【相干、相参】 雷电名词溯源

    〇、废话因缘 最近某些国产的微波制造公司总是提到一个概念【相干】【相参】【严格相参】等等概念层出不穷&#xff0c;让人苦恼。 一、这玩意还是英文溯源吧 这几个概念都聚焦在一个单词【Coherence】&#xff1b;所以就是说两个波形之间有某种联系&#xff0c;不一定就是完全…