如何在TensorFlow中使用并行数据加载,解决

 我来答
约定20125
2018-05-22 · TA获得超过1.5万个赞
知道大有可为答主
回答量:1.1万
采纳率:96%
帮助的人:3125万
展开全部

整个过程可以在python中简单实现,我们需要定义一个FIFOqueue类,用于保存数据,如:

import queueclass FIFOQueue(object):
__max_len = None
__queue = None
def __init__(self, max_len=5):
if self.__max_len is None:
self.__max_len = max_len    else:      if self.__max_len is not max_len:        raise ValueError('The FIFOQueue has been declared yet and max_len is not same!')    if self.__queue is None:
self.__queue = queue.Queue(maxsize=max_len)  def enqueue(self, item):
'''
put a batch into queue. If the queue is full, then it will be blocked and wait until the queue is not full.
:param item: a batch with the format of (data_batch, data_label)
:return: None
'''
self.__queue.put(item)  def dequeue(self):
'''
pop a batch from queue. If the queue is empty then it will be blocked till the queue is not empty.
:return: the batch with the format of (data_batch, data_label)
'''
item = self.__queue.get()    return item  def max_len(self):
return self.__max_len  def get_len(self):
return self.__queue.qsize()12345678910111213141516171819202122232425262728293031

可以发现只是对queue的简单封装。

在最主要的Train类中,如:

import FIFOqueue as queueimport threadingclass Train(object):
_train_global_queue = None
_val_global_queue = None
_test_global_queue = None
_threads = []  def __init__(self,
main_task,
batch_size=32,
train_yield=None,
val_yield=None,
test_yield=None,
max_nthread=10,
max_len=10):
self._train_global_queue = queue.FIFOQueue(max_len=max_len) if train_yield is not None else None
self._val_global_queue = queue.FIFOQueue(max_len=max_len) if val_yield is not None else None
self._test_global_queue = queue.FIFOQueue(max_len=max_len) if test_yield is not None else None
# init the global queue and maintain them
train_threads = [threading.Thread(target=self._data_enqueue,
args=(train_yield, batch_size, task_id, 'train_data_load', self._train_global_queue))    for task_id in range(max_nthread)]    def wrapper_main_task(fn):
while True:
fn(self._train_global_queue.dequeue())

self._threads += train_threads
self._threads += [threading.Thread(target=wrapper_main_task, args=([main_task]))]  def _data_enqueue(self, fn, batch_size, task_id, task_type, queue_h):
print('here begin the data loading with task_id %d with type %s' % (task_id, task_type))    while True:
item = fn()
item['task_id'] = task_id
item['task_type'] = task_type
queue_h.enqueue(item=item)  def start(self):
for each_t in self._threads:
each_t.start() 

我们实现了刚才说是的并行加载的过程,其中需要注意几点:

  • _data_enqueue是用于将数据入队列的,其中fn为数据生成器,需要用户自行重写传入。

  • wrapper_main_task是用于封装主任务的,并且在使得可以在主任务中出队,利用数据。

已赞过 已踩过<
你对这个回答的评价是?
评论 收起
收起 1条折叠回答
推荐律师服务: 若未解决您的问题,请您详细描述您的问题,通过百度律临进行免费专业咨询

为你推荐:

下载百度知道APP,抢鲜体验
使用百度知道APP,立即抢鲜体验。你的手机镜头里或许有别人想知道的答案。
扫描二维码下载
×

类别

我们会通过消息、邮箱等方式尽快将举报结果通知您。

说明

0/200

提交
取消

辅 助

模 式