分布式TensorFlow 踩坑记

  • Post author:
  • Post category:其他

单机版的TF没毛病,但是当大家在Tensorflow Github里面找到可用的模型,想分布式跑到时候,就会跑出来各种奇怪的问题。我尝试了几种不同构造TF的方式,算是成功渡过了踩坑期,特别记录一下。如果能帮助到各位TF boy最好。

方法一:自己手动写分布式协议

比如logistic regression

在master上运行的伪代码如下

with tf.Session('grpc://vm1:2222') as sess:
    sess.run(initialization)
    while not stop:
        run(train_op)       

master负责的是初始化session,以及将parameter发送给其他的worker。如果有saver也定义在这里。

下面的逻辑就是每个worker收到parameter,计算gradient,然后到master上进行aggregate。最常用的aggregate的方式就是将每一个worker的gradient求和。

Tip: 下面都默认task_index是0的节点为master。在TF中,这个master节点也叫做chief node。

伪代码如下:

with tf.device('/job:worker/task:%d" % FLAGS.task_index):
    read_data
    compute gradient

with tf.device('/job:worker/task:0'):
    aggregate weight

但是既然parameter server在sparse的数据集上非常好用,那么我们不妨尝试利用这个特性。
1. 首先每一个worker得到sparse index,传给master。
2. master根据对应的sparse index,得到对应的sparse data。这个也叫做working set,不知道是不是system方向的叫法。
3. 每一个worker从master得到working set。更新gradient。
4. master将每一个worker的gradient进行aggregate。

其实相比于前一个方法,就是多了一次信息传递(1、2步),来获取sparse信息。

伪代码如下:

with tf.device('/job:worker/task:%d' % FLAGS.task_index):
    read data
    get sparse index

with tf.device('/job:worker/task:0'):
    get sparse index from each worker
    generate working set for each worker

with tf.device('/job:worker/task:%d' % FLAGS.task_index):
    get working set from master
    compute gradient

with tf.device('/job:worker/task:0'):
    aggregate gradient

完整代码在这里.

Tip: 实现的时候,我直接用list存储每一个worker的gradient,sparse index,和working set。在下面提到的TensorFlow中已经实现的类中,使用的是内置的queue。(Python里queue和list差别不大)

方法二:使用MonitoredTrainingSession

TF有内置的类,supervisor和MonotoredTrainingSession是最常用的两个。

MonitoredTrainingSession是MonitoredSession的子类,多增加的功能是为master/chief 节点增加断点功能,以及创建session,分配给其他worker。

如果要保证同步更新,主要“下手”在optimizer上,类似我们在上面的第一个算法,所以TF有一个叫做SyncReplicasOptimizer的类。

这里一个让我踩了许久的坑:如果用了MonotoredTrainingSessionSyncReplicasOptimizer,如果ps相对于worker的分布不是均匀的,那么有的worker会跑的特别快。比如这样设置:

parser.add_argument(
    "--ps_hosts",
    type=str,
    default="vm1:2233",
    help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
    "--worker_hosts",
    type=str,
    default="vm1:2222,vm2:2222,vm3:2222",
    help="Comma-separated list of hostname:port pairs"
)

虽说理论上快的worker在每次iteration结束之后应该等queue挤满,但是很神奇的是!这个快的worker会自己再多跑几份session来把queue填充满,从而进入下一次iteration。

所以千万不要把ps和worker分布不均匀。可以每台worker都是ps,也可以固定几个ps与worker不重合。

parser.add_argument(
    "--ps_hosts",
    type=str,
    default="vm1:2233",
    help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
    "--worker_hosts",
    type=str,
    default="vm2:2222,vm3:2222",
    help="Comma-separated list of hostname:port pairs"
)

版权声明:本文为miss_snow_m原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。