复现KGAT: Knowledge Graph Attention Network for Recommendation(五)
4.19 打开这堆代码看了半天,真TMD的难啊,这是我动手复现的第一篇推荐论文。上一个看过代码的是NCF,那个难度简直没法和这个相提并论,准确的说,有点机器学习、深度学习基础的那个代码看起来分分钟的事。
这篇。。。对我这个新手很不友好啊。。。
看了半天,从读懂utility开始吧,要不然读主干读着读着还是得回来读它,读完了他主干又忘了。
看了一晚上了,想哭。。。
4.20 昨天和今天心情都很不好,绝对是因为这个贼难的KGAT。让我看啥啥不懂,运行还运行不出来。。。今天到现在(15:17)唯一的收获是解决了那个除以0的问题,实际上人家是个warning,是我智障了。。。还有,不读utility了,大概看了utility中的load_data和loader_bprmf、laoder_kgat。有问题还是应该直接面对主要问题,在主要问题旁边绕来绕去都是浪费时间,emmm也不能这么说,总是有点收获的。今天直攻kgat吧。
先发个昨天读的load_data和loader_bprmf,解释写在代码里了。kgat的内容从下一篇(六)开始写啦。
一、Utility
1.1 load_data.py
'''
Created on Dec 18, 2018
Tensorflow Implementation of Knowledge Graph Attention Network (KGAT) model in:
Wang Xiang et al. KGAT: Knowledge Graph Attention Network for Recommendation. In KDD 2019.
@author: Xiang Wang (xiangwang@u.nus.edu)
'''
import collections
import numpy as np
import random as rd
class Data(object):
#初始化对象
#Data.path
#Data.args
#Data.batch_size
#train_file:训练数据。每一行是用户userID和他的正样本物品(itemID)的列表
#test_file:测试数据,每一行是userID和他的正样本(itemID)的列表,所有未观测到的都定为负样本
#kg_file:应该是每一行是一个(head,relation,tail)
#train_data,train_user_dict 训练数据中的交互列表,以用户为键的字典
#test_data,test_user_dict 测试数据的交互列表,以用户为键的字典
#n_users:用户id最大值
#n_items:物品id最大值
#n_train:训练样本数
#n_test: 测试样本数
#关系的最大值:n_relations
#实体的最大值:n_entities
#三元组的个数:n_triples
# kg_dict:字典,以head为键,(tail,relation)为值
# relation_dict:字典,以relation为键,(head,tail)为值
#batch_size是输入的参数
#batch_size_kg是kg_final中不重复的三元组个数
def __init__(self, args, path):
self.path = path
self.args = args
self.batch_size = args.batch_size
train_file = path + '/train.txt'
test_file = path + '/test.txt'
kg_file = path + '/kg_final.txt'
# ----------get number of users and items & then load rating data from train_file & test_file------------.
self.n_train, self.n_test = 0, 0
self.n_users, self.n_items = 0, 0
#train_data、test_data是一个np数组,里面包括了这个数据文件中的所有交互信息,
#每一个交互以[userID,itemID]的形式存储,最后整体为np.array
#train_uesr_dict和test_user_dict 是字典形式
#键为用户的userID,值对应的是和他有交互的列表(无重复)
self.train_data, self.train_user_dict = self._load_ratings(train_file)
self.test_data, self.test_user_dict = self._load_ratings(test_file)
#exist_users是训练数据的字典形式的键的列表,也就是全部训练数据中的用户id(userID)
self.exist_users = self.train_user_dict.keys()
#_statistic_ratings()求出了用户总数n_users和物品总数n_items
#和训练样本数n_train、测试样本数n_test
self._statistic_ratings()
# ----------get number of entities and relations & then load kg data from kg_file ------------.
self.n_relations, self.n_entities, self.n_triples = 0, 0, 0
self.kg_data, self.kg_dict, self.relation_dict = self._load_kg(kg_file)
# ----------print the basic info about the dataset-------------.
self.batch_size_kg = self.n_triples // (self.n_train // self.batch_size)
self._print_data_info()
# reading train & test interaction data.
def _load_ratings(self, file_name)
版权声明:本文为weixin_45665465原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。