How do you batch Data graph objects created with the torch_geometric.data module into a single graph?





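Problem description:

How does PyG represent several graphs at once (torch_geometric.data.Batch)? Multiple torch_geometric.data.Data graph objects are concatenated into one batch so that several graphs can be processed together, as shown in the figure.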

[Figure: several Data graphs merged into one batch; image not available]



Code example:

import torch
from torch_geometric.data import Data
from torch_geometric.data.batch import Batch


edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

edge_index_3 = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1],
], dtype=torch.long)
x_3 = torch.randn(4, 16)  # 4 nodes; node 3 has no edges.

data1 = Data(x=x_s, edge_index=edge_index_s)
data2 = Data(x=x_t, edge_index=edge_index_t)
data3 = Data(x=x_3, edge_index=edge_index_3)
# The lines above build three Data graph objects.
# * `Batch(Data)` in case `Data` objects are batched together
# * `Batch(HeteroData)` in case `HeteroData` objects are batched together

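data_list = [data1, data2, data3]

# Merge the three graphs in data_list (data1, data2, data3) into one large graph, i.e. a batch.
# Despite its name, `loader` is a Batch object, not a DataLoader.
loader = Batch.from_data_list(data_list)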
print('data_list:\n', data_list)
# data_list: [Data(edge_index=[2, 4], x=[5, 16]), Data(edge_index=[2, 3], x=[4, 16]), Data(edge_index=[2, 4], x=[4, 16])]
print('batch:', loader.batch)  # For each node, the index of the graph it belongs to.
# batch: tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
print('loader:', loader)
# loader: Batch(batch=[13], edge_index=[2, 11], x=[13, 16])
print('loader.edge_index:\n', loader.edge_index)  # Edge index of the batched graph (node indices are offset per graph).
# loader.edge_index:
# tensor([[ 0,  0,  0,  0,  5,  5,  5,  9, 10, 10, 11],
#         [ 1,  2,  3,  4,  6,  7,  8, 10,  9, 11, 10]])

print('loader.num_graphs:', loader.num_graphs)  # Number of graphs in this batch; 3 here.
# loader.num_graphs: 3

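# Split the batched graph back into the three original graphs.
# The result is stored under a new name so that the Batch class is not overwritten.
restored_list = loader.to_data_list()
print(restored_list)
# [Data(edge_index=[2, 4], x=[5, 16]), Data(edge_index=[2, 3], x=[4, 16]), Data(edge_index=[2, 4], x=[4, 16])]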
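
The offsets visible in loader.edge_index and the contents of loader.batch above can be reproduced by hand. The following is only a minimal sketch in plain PyTorch of the idea behind batching, not PyG's actual implementation; the names `graphs`, `merged_edge_index` and `merged_batch` are made up for this illustration.

```python
import torch

# The same three toy graphs as in the example: (number of nodes, edge_index).
graphs = [
    (5, torch.tensor([[0, 0, 0, 0], [1, 2, 3, 4]])),
    (4, torch.tensor([[0, 0, 0], [1, 2, 3]])),
    (4, torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])),
]

edge_chunks, batch_chunks = [], []
offset = 0  # Number of nodes already placed in the merged graph.
for graph_id, (num_nodes, edge_index) in enumerate(graphs):
    # Shift node indices so each graph occupies its own index range.
    edge_chunks.append(edge_index + offset)
    # Record which graph every node came from.
    batch_chunks.append(torch.full((num_nodes,), graph_id, dtype=torch.long))
    offset += num_nodes

merged_edge_index = torch.cat(edge_chunks, dim=1)
merged_batch = torch.cat(batch_chunks)

print(merged_edge_index)
print(merged_batch)
```

Because the shifted index ranges never overlap, no edge connects two different graphs, so message passing on the merged graph treats each original graph independently.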

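A common use of the batch vector is to reduce node features to one vector per graph. The sketch below does a per-graph mean in plain PyTorch, assuming a node feature matrix and batch vector shaped like loader.x and loader.batch above; `sums`, `counts` and `graph_means` are illustrative names. PyG ships pooling helpers such as torch_geometric.nn.global_mean_pool for this purpose, which would normally be used instead.

```python
import torch

torch.manual_seed(0)
x = torch.randn(13, 16)  # Stand-in for loader.x: 13 nodes, 16 features.
batch = torch.tensor([0] * 5 + [1] * 4 + [2] * 4)  # Same values as loader.batch.
num_graphs = int(batch.max()) + 1

# Sum the feature rows of all nodes that share a graph id.
sums = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
# Count the nodes of each graph, shaped [num_graphs, 1] for broadcasting.
counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(1)
graph_means = sums / counts

print(graph_means.shape)
```

Each row of graph_means is the mean node feature vector of one graph in the batch.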


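Copyright notice: this is an original article by qq_41800917, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.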