基于树的召回框架(二):Joint Optimization of Tree-based Index and Deep Model for Recommender Systems

  • Post author:
  • Post category:其他


阿里基于树结构的召回体系一共发了三篇paper,这是第二篇,其他两篇论文的阅读笔记见下方链接:



背景

在论文《Learning Tree-based Deep Model for Recommender Systems》中提出的基于树结构的召回模型,其在优化树结构



T

\mathcal{T}







T






和优化用户兴趣模型



M

\mathcal{M}







M






的两个过程使用的评估指标是不一致的,而本文旨在提出一个训练框架,能够将这两个过程在一个统一的指标下进行优化,从而加快训练过程以及提高召回的效果。本文提出的方法称为 JTM。



方法

与第一篇论文的最大不同之处在于,本文使用的树的层级结构是固定的,文章以完全二叉树为例方便说明算法运行过程,并指出该算法可以拓展到 multi-way tree。假设物品集合为



C

C






C





,该完全二叉树的所有叶子节点的集合为



N

N






N





,则满足



N

=

C

|N|=|C|









N







=











C








。本文中对于树结构的优化,实际上是在优化一个映射函数



π

(

)

\pi(\cdot)






π


(





)





,其建立了物品到叶子节点的一对一映射。寻找函数



π

(

)

\pi(\cdot)






π


(





)





的过程,实际上等价于寻找带权二部图的最大匹配(相关概念自行搜索),其证明过程如下(来自该论文的补充材料,见本文开始部分):

在这里插入图片描述

本文并不直接将物品分配到最后一层的叶子节点中,而是从根节点到叶节点、自上而下、由粗到细地进行分配。算法描述如下:

在这里插入图片描述


博主注

:个人感觉该算法描述结合论文的文字解释也仍然不是很直观,因此查看了论文给出的(粗略)实现的代码,个人认为Algorithm 2 可以用如下示意图来理解:

我们以16个物品为例(



C

=

16

|C|=16









C







=








1


6





),即对应了有 5 层的满二叉树,并假设我们取

间隔




d

=

2

d=2






d




=








2





。树在初始化时,所有16个物品都被分配到了根节点(第0层):

在这里插入图片描述

在第一轮迭代后(对应Algorithm 2 中的第3行到第9行),所有物品都被分配到了第 2 层,如下图所示:

在这里插入图片描述

在第二轮迭代后(对应Algorithm 2 中的第3行到第9行),所有物品都被分配到了第 4 层,如下图所示:

在这里插入图片描述

需要注意的是,在每一轮迭代,实际上是对于上一轮已经分配好的那层的所有节点一一进行分配,即已经被分配到某个节点的物品,在后续迭代中只会继续被分配到该节点的后代节点中。例如第一轮迭代,实际上是进行了 1 次匹配,即将根节点中的16个物品和第 2 层中的 4 个节点进行匹配;在第二轮迭代中,实际上是进行了 4 次匹配,即对第 2 层的 4 个节点的后代分别进行了一次匹配:将分配到该节点中的所有物品,分配到该节点的下 2 代的后代(即儿子的儿子)中,注意,第 2 层节点的下 2 代的后代位于第 4 层。

最后要解释的问题就是,Algorithm 2 的第5行,即给定一个物品集合



C

C






C





以及节点集合



N

N






N





,匹配过程是如何进行的。例如,上面图示例子中,第一轮迭代时,我们将



c

1

c

4

c

6

c

15

c_1c_4c_6c_{15}







c










1



















c










4



















c










6



















c











1


5






















这4个物品分到第 2 层的第一个节点的依据是什么?正如论文中指出的,“ we use a greedy algorithm with rebalance strategy to solve the sub-problem”:

前面已经证明了这个过程相当于寻找带权二部图的最大匹配的问题,而为了降低求解复杂度,本文采用的是贪婪算法。首先,对于所有的物品-节点对,计算



c

k

c_k







c










k

























n

m

n_m







n










m





















之间的边权重



L

c

k

,

n

m

\mathcal{L}_{c_k,n_m}








L













c










k


















,



n










m






































,即一共要计算



C

×

N

|C|\times|N|









C







×











N








个权重(对应后文代码中的

get_weights()

,注意,此处的



N

C

|N|\neq |C|









N










































=












C








,因为此处的



N

N






N





并非所有叶子节点集合)。对于每个物品



c

k

c_k







c










k





















,将其与连接后权重最大的节点



n

m

n_m







n










m





















连接起来,然后进行rebalance操作(见

assign_parent()

代码的后半部分)。所谓 rebalance,指的是手动调节一些物品分配到的节点,因为可能有的节点分配不到任何物品,而有的节点分配得到了过多节点(第



d

d






d





层的每个节点最多只能分配到



2

l

m

a

x

2

d

=

2

l

m

a

x

d

\frac{2^{l_{max}}}{2^d}=2^{l_{max}-d}



















2










d
























2












l










m


a


x















































=









2












l











m


a


x






















d













个物品)。具体逻辑见后文代码的

assign_parent()



算法实现

节选自论文作者提供的代码实现,更多细节见文章开始的supplementary files链接。

class TreeLearner(object):
    def __init__(self, filename):
        self.filename = filename
        self.tree_meta = tree_proto.TreeMeta()
        self.id_code = dict()
        self.item_codes = set()
        self.nodes = dict()

    def get_ancestor(self, code, level):
        code_max = 2 ** (level + 1) - 1
        while code >= code_max:
            code = int((code - 1) / 2) 
        return code
    
    def get_nodes_given_level(self, level):
        code_min = 2 ** level - 1
        code_max = 2 ** (level + 1) - 1
        res = []
        for code in self.nodes.keys():
            if code >= code_min and code < code_max:
                res.append(code)
        return res

    def get_children_given_ancestor_and_level(self, ancestor, level):
        code_min = 2 ** level - 1
        code_max = 2 ** (level + 1) - 1
        parent = [ancestor]
        res = []
        while True:
            children = []
            for p in parent:
                children.extend([2 * p + 1, 2 * p + 2])
            if code_min <= children[0] < code_max:
                break
            parent = children

        output = []
        for i in res:
            if i in self.nodes:
                output.append(i)
        return output
    # .......


def get_itemset_given_ancestor(pi_new, node):
    res = []
    for ci, code in pi_new.items():
        if code == node:
            res.append(ci)
    return res


def tree_learning(d, tree):
    """ The overall tree learning algorithm (Algorithm 2 in the paper)
    
    Returns:
        the leant new projection from item to leaf node (\pi_{new})
        
    Args:
        d (int, required): the tree learning level gap
        tree (tree, required): the old tree (\pi_{old})
    """
    l_max = tree.tree_meta.max_level - 1
    l = d

    pi_new = dict()

    # \pi_{new} <- \pi_{old}
    for item_code in tree.item_codes:
        ci = tree.nodes[item_code].id
        pi_new[ci] = tree.get_ancestor(item_code, l - d)

    while d > 0:
        nodes = tree.get_nodes_given_level(l - d)
        for ni in nodes:
            C_ni = get_itemset_given_ancestor(pi_new, ni)
            pi_star = assign_parent(l_max, l, ni, C_ni, tree)

            # update pi_new according to the found optimal pi_star
            for item, node in pi_star.items():
                pi_new[item] = node

        d = min(d, l_max - l)
        l = l + d

    return pi_new


def assign_parent(l_max, l, ni, C_ni, tree):
    """implementation of line 5 of Algorithm 2
    
    Returns: 
        updated \pi_{new}
    
    Args:
        l_max (int, required): the max level of the tree
        l (int, required): current assign level
        ni (node, required): a non-leaf node in level l-d
        C_ni (item, required): item set whose ancestor is the non-leaf node ni
        tree (tree, required): the old tree (\pi_{old})
    """

    # get the children of ni in level l
    children_of_ni_in_level_l = tree.get_children_given_ancestor_and_level(ni, l)

    # get all the required weights 
    edge_weights = get_weights(C_ni, ni, children_of_ni_in_level_l, tree) 

    # assign each item to the level l node with the maximum weight
    assign_dict = dict()
    for ci, info in edge_weights.items():
        assign_candidate_nodes = info[0]
        assign_weights = np.array(info[1], dtype=np.float32)
        sorted_idx = np.argsort(-assign_weights)
        # assign item ci to the node with the largest weight
        max_weight_node = assign_candidate_nodes[sorted_idx[0]]
        if max_weight_node in assign_dict:
            assign_dict[max_weight_node].append((ci, sorted_idx, assign_candidate_nodes, assign_weights))
        else:
            assign_dict[max_weight_node] = [(ci, sorted_idx, assign_candidate_nodes, assign_weights)]

    edge_weights = None

    # get each item's original assignment of level l in tree, used in rebalance process
    origin_relation = dict()
    for ci in C_ni:
        origin_relation[ci] = tree.get_ancestor(ci, l)

    # rebalance
    max_assign_num = int(math.pow(2, l_max - l))
    processed_set = set()
    while True:
        max_assign_cnt = 0
        max_assign_node = None

        for node in children_of_ni_in_level_l:
            if node in processed_set:
                continue
            if node not in assign_dict:
                continue
            if len(assign_dict[node]) > max_assign_cnt:
                max_assign_cnt = len(assign_dict[node])
                max_assign_node = node

        if max_assign_node == None or max_assign_cnt <= max_assign_num:
            break
            
        # rebalance
        processed_set.add(max_assign_node)
        elements = assign_dict[max_assign_node]
        elements.sort(key=lambda x: (int(max_assign_node != origin_relation[x[0]]), -x[1]))
        for e in elements[max_assign_num:]:
            for idx in e[1]:
                other_parent_node = e[2][idx]
                if other_parent_node in processed_set:
                    continue
                if other_parent_node not in assign_dict:
                    assign_dict[other_parent_node] = [(e[0], e[1], e[2], e[3])]
                else:
                    assign_dict[other_parent_node].append((e[0], e[1], e[2], e[3]))
                break
        del elements[max_assign_num:]

    pi_new = dict()
    for parent_code, value in assign_dict.items():
        max_assign_num = int(math.pow(2, l_max - l))
        assert len(value) <= max_assign_num
        for e in value:
            assert e[0] not in pi_new
            pi_new[e[0]] = parent_code

    return pi_new



def get_weights(C_ni, ni, children_of_ni_in_level_l, tree):
    """use the user preference prediction model to calculate the required weights
    
    Returns: 
        all weights
    
    Args:
        C_ni (item, required): item set whose ancestor is the non-leaf node ni
        ni (node, required): a non-leaf node in level l-d
        children_of_ni_in_level_l (list, required): the level l-th children of ni
        tree (tree, required): the old tree (\pi_{old})
    """
    edge_weights = dict()
    
    for ck in C_ni:
        edge_weights[ck] = list()
        edge_weights[ck].append([]) # the first element is the list of nodes in level l
        edge_weights[ck].append([]) # the second element is the list of corresponding weights

        for node in children_of_ni_in_level_l:
            path_to_ni = tree.get_parent_path(node, ni)
            weight = 0.0
            for n in path_to_ni:
                sample_set = set() # the sample set that the target item is ck
                
                # use the user preference prediction model to calculate the required weights.
                # the detailed calculation process is omitted here
                weight += calculate_weight_use_prediction_model(sample_set, n)

            edge_weights[ck][0].append(node)
            edge_weights[ck][1].append(weight)

    return edge_weights



版权声明:本文为xpy870663266原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。