复现KGAT: Knowledge Graph Attention Network for Recommendation(五)

  • Post author:
  • Post category:其他




复现KGAT: Knowledge Graph Attention Network for Recommendation(五)

4.19 打开这堆代码看了半天,真TMD的难啊,这是我动手复现的第一篇推荐论文。上一个看过代码的是NCF,那个难度简直没法和这个相提并论,准确的说,有点机器学习、深度学习基础的那个代码看起来分分钟的事。

这篇。。。对我这个新手很不友好啊。。。

看了半天,从读懂utility开始吧,要不然读主干读着读着还是得回来读它,读完了他主干又忘了。

看了一晚上了,想哭。。。

4.20 昨天和今天心情都很不好,绝对是因为这个贼难的KGAT。让我看啥啥不懂,运行还运行不出来。。。今天到现在(15:17)唯一的收获是解决了那个除以0的问题,实际上人家是个warning,是我智障了。。。还有,不读utility了,大概看了utility中的load_data和loader_bprmf、laoder_kgat。有问题还是应该直接面对主要问题,在主要问题旁边绕来绕去都是浪费时间,emmm也不能这么说,总是有点收获的。今天直攻kgat吧。

先发个昨天读的load_data和loader_bprmf,解释写在代码里了。kgat的内容从下一篇(六)开始写啦。



一、Utility



1.1 load_data.py

'''
Created on Dec 18, 2018
Tensorflow Implementation of Knowledge Graph Attention Network (KGAT) model in:
Wang Xiang et al. KGAT: Knowledge Graph Attention Network for Recommendation. In KDD 2019.
@author: Xiang Wang (xiangwang@u.nus.edu)
'''
import collections
import numpy as np
import random as rd

class Data(object):
    #初始化对象
    #Data.path
    #Data.args
    #Data.batch_size

    #train_file:训练数据。每一行是用户userID和他的正样本物品(itemID)的列表
    #test_file:测试数据,每一行是userID和他的正样本(itemID)的列表,所有未观测到的都定为负样本
    #kg_file:应该是每一行是一个(head,relation,tail)

    #train_data,train_user_dict 训练数据中的交互列表,以用户为键的字典
    #test_data,test_user_dict  测试数据的交互列表,以用户为键的字典

    #n_users:用户id最大值
    #n_items:物品id最大值
    #n_train:训练样本数
    #n_test: 测试样本数

    #关系的最大值:n_relations
    #实体的最大值:n_entities
    #三元组的个数:n_triples

    # kg_dict:字典,以head为键,(tail,relation)为值
    # relation_dict:字典,以relation为键,(head,tail)为值

    #batch_size是输入的参数
    #batch_size_kg是kg_final中不重复的三元组个数

    def __init__(self, args, path):
        self.path = path
        self.args = args

        self.batch_size = args.batch_size

        train_file = path + '/train.txt'
        test_file = path + '/test.txt'

        kg_file = path + '/kg_final.txt'

        # ----------get number of users and items & then load rating data from train_file & test_file------------.
        self.n_train, self.n_test = 0, 0
        self.n_users, self.n_items = 0, 0

        #train_data、test_data是一个np数组,里面包括了这个数据文件中的所有交互信息,
        #每一个交互以[userID,itemID]的形式存储,最后整体为np.array
        #train_uesr_dict和test_user_dict 是字典形式
        #键为用户的userID,值对应的是和他有交互的列表(无重复)
        self.train_data, self.train_user_dict = self._load_ratings(train_file)
        self.test_data, self.test_user_dict = self._load_ratings(test_file)
        #exist_users是训练数据的字典形式的键的列表,也就是全部训练数据中的用户id(userID)
        self.exist_users = self.train_user_dict.keys()


        #_statistic_ratings()求出了用户总数n_users和物品总数n_items
        #和训练样本数n_train、测试样本数n_test
        self._statistic_ratings()

        # ----------get number of entities and relations & then load kg data from kg_file ------------.
        self.n_relations, self.n_entities, self.n_triples = 0, 0, 0
        self.kg_data, self.kg_dict, self.relation_dict = self._load_kg(kg_file)

        # ----------print the basic info about the dataset-------------.
        self.batch_size_kg = self.n_triples // (self.n_train // self.batch_size)
        self._print_data_info()

    # reading train & test interaction data.
    def _load_ratings(self, file_name)



版权声明:本文为weixin_45665465原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。