
How to Use Parallel Data Loading in TensorFlow to Ease the Data-Loading Bottleneck
The whole procedure is easy to implement in plain Python. First we define a FIFOQueue class that holds the data batches:
```python
import queue


class FIFOQueue(object):
    __max_len = None
    __queue = None

    def __init__(self, max_len=5):
        if self.__max_len is None:
            self.__max_len = max_len
        elif self.__max_len != max_len:
            raise ValueError('The FIFOQueue has already been declared and max_len differs!')
        if self.__queue is None:
            self.__queue = queue.Queue(maxsize=max_len)

    def enqueue(self, item):
        '''
        Put a batch into the queue. If the queue is full, this blocks
        until the queue has room again.
        :param item: a batch with the format (data_batch, data_label)
        :return: None
        '''
        self.__queue.put(item)

    def dequeue(self):
        '''
        Pop a batch from the queue. If the queue is empty, this blocks
        until the queue is non-empty.
        :return: a batch with the format (data_batch, data_label)
        '''
        return self.__queue.get()

    def max_len(self):
        return self.__max_len

    def get_len(self):
        return self.__queue.qsize()
```
As you can see, this is just a thin wrapper around Python's standard queue.Queue.
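As a quick sanity check, here is a minimal sketch of the wrapper in action; the dict batch format below is an assumption borrowed from how the Train class fills items later:

```python
q = FIFOQueue(max_len=2)

# Producer side: a third enqueue would block, since max_len is 2.
q.enqueue({'data': [1, 2, 3], 'label': [0]})
q.enqueue({'data': [4, 5, 6], 'label': [1]})
print(q.get_len())     # -> 2

# Consumer side: blocks whenever the queue is empty.
batch = q.dequeue()
print(batch['label'])  # -> [0]
```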
The core logic lives in the Train class:
```python
import threading

import FIFOqueue as queue  # the FIFOQueue module defined above


class Train(object):
    _train_global_queue = None
    _val_global_queue = None
    _test_global_queue = None
    _threads = []

    def __init__(self,
                 main_task,
                 batch_size=32,
                 train_yield=None,
                 val_yield=None,
                 test_yield=None,
                 max_nthread=10,
                 max_len=10):
        # Initialize the global queues; each exists only if the
        # corresponding generator was supplied.
        self._train_global_queue = queue.FIFOQueue(max_len=max_len) if train_yield is not None else None
        self._val_global_queue = queue.FIFOQueue(max_len=max_len) if val_yield is not None else None
        self._test_global_queue = queue.FIFOQueue(max_len=max_len) if test_yield is not None else None

        # One loader thread per task_id keeps the training queue filled.
        train_threads = [threading.Thread(target=self._data_enqueue,
                                          args=(train_yield, batch_size, task_id,
                                                'train_data_load',
                                                self._train_global_queue))
                         for task_id in range(max_nthread)]

        def wrapper_main_task(fn):
            # Endlessly dequeue a batch and hand it to the main task.
            while True:
                fn(self._train_global_queue.dequeue())

        self._threads += train_threads
        self._threads += [threading.Thread(target=wrapper_main_task,
                                           args=(main_task,))]

    def _data_enqueue(self, fn, batch_size, task_id, task_type, queue_h):
        # batch_size is passed through but not used by the loader itself.
        print('data loading started: task_id %d, type %s' % (task_id, task_type))
        while True:
            item = fn()
            item['task_id'] = task_id
            item['task_type'] = task_type
            queue_h.enqueue(item=item)

    def start(self):
        for each_t in self._threads:
            each_t.start()
```
This implements the parallel loading scheme described above. Since data loading is typically I/O-bound, plain Python threads can overlap it with the training computation despite the GIL. A few points deserve attention:
- _data_enqueue pushes batches into the queue; its fn argument is a data generator that the user must implement and pass in (train_yield, val_yield, test_yield).
- wrapper_main_task wraps the main task so that it can dequeue batches inside an endless loop and consume the data, as the sketch below shows.
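To make the contract concrete, here is a usage sketch; train_yield and main_task are hypothetical stand-ins that a user would replace with real batch loading and a real training step:

```python
import numpy as np

# Hypothetical generator: every call returns one batch as a dict;
# _data_enqueue later adds the 'task_id' and 'task_type' keys.
def train_yield():
    return {'data': np.random.rand(32, 224, 224, 3),
            'label': np.random.randint(0, 10, size=32)}

# Hypothetical main task: consumes one dequeued batch per call,
# e.g. by feeding it into session.run(...) via feed_dict.
def main_task(batch):
    print('training on a batch from %s (loader %d)'
          % (batch['task_type'], batch['task_id']))

trainer = Train(main_task=main_task,
                batch_size=32,
                train_yield=train_yield,
                max_nthread=4,
                max_len=10)
trainer.start()  # loader threads and the main-task thread now run concurrently
```

Because the loaders and the consumer interact only through the queue, back-pressure comes for free: max_len bounds memory, enqueue blocks the loaders when the consumer falls behind, and dequeue blocks the consumer when the loaders fall behind.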