diff --git a/pgl/utils/data/dataloader.py b/pgl/utils/data/dataloader.py
index 13544128..bab57412 100644
--- a/pgl/utils/data/dataloader.py
+++ b/pgl/utils/data/dataloader.py
@@ -89,10 +89,14 @@ def __init__(self,
                 "You might want to set [stream_shuffle_size] with StreamDataset."
             warnings.warn(warn_msg)
 
-        if self.stream_shuffle_size > 0 and self.batch_size >= stream_shuffle_size:
-            raise ValueError("stream_shuffle_size must be larger than batch_size," \
-                    "but got [stream_shuffle_size=%s] smaller than [batch_size=%s]" \
-                    % (self.stream_shuffle_size, self.batch_size))
+        if self.stream_shuffle_size > 0 and self.batch_size > stream_shuffle_size:
+            warn_msg = "stream_shuffle_size should be larger than batch_size, "
+            warn_msg += "but got [stream_shuffle_size=%s] " % (
+                self.stream_shuffle_size)
+            warn_msg += "smaller than [batch_size=%s]. " % (self.batch_size)
+            warn_msg += "stream_shuffle_size will be set to %s." % (
+                self.batch_size)
+            warnings.warn(warn_msg)
 
         if self.stream_shuffle_size > 0 and isinstance(self.dataset, Dataset):
             warn_msg = "[stream_shuffle_size] should not be set with Dataset. " \
@@ -128,16 +132,15 @@ def __iter__(self):
         # so set seed explicitly every time
         np.random.seed()
         if self.num_workers == 1:
-            r = paddle.reader.buffered(_DataLoaderIter(self, 0), self.buf_size)
+            workers = _DataLoaderIter(self, 0)
         else:
             worker_pool = [
                 _DataLoaderIter(self, wid) for wid in range(self.num_workers)
             ]
             workers = mp_reader.multiprocess_reader(
-                worker_pool, use_pipe=True, queue_size=1000)
-            r = paddle.reader.buffered(workers, self.buf_size)
+                worker_pool, use_pipe=True, queue_size=self.buf_size)
 
-        for batch in r():
+        for batch in workers():
             yield batch
 
     def __call__(self):
@@ -196,45 +199,58 @@ def _streamdata_generator(self):
             yield batch_data
 
     def _stream_shuffle_data_generator(self):
-        def _stream_shuffle_index_generator():
-            shuffle_size = [i for i in range(self.stream_shuffle_size)]
-            while True:
-                yield shuffle_size
-
-        def _data_generator():
+        def _batch_stream_data_generator():
             dataset = iter(self.dataset)
-            for shuffle_size in _stream_shuffle_index_generator():
-                shuffle_size_data = []
-                for idx in shuffle_size:
-                    try:
-                        shuffle_size_data.append(next(dataset))
-                    except StopIteration:
-                        break
-
-                if len(shuffle_size_data) == 0:
+            batch_data = []
+            while True:
+                try:
+                    batch_data.append(next(dataset))
+                except StopIteration:
                     break
-                yield shuffle_size_data
+                if len(batch_data) == self.batch_size:
+                    yield batch_data
+                    batch_data = []
 
-        def _batch_data_generator():
-            batch_data = []
-            for shuffle_size_data in _data_generator():
-                np.random.shuffle(shuffle_size_data)
+            if not self.drop_last and len(batch_data) > 0:
+                yield batch_data
+                batch_data = []
 
-                for d in shuffle_size_data:
-                    batch_data.append(d)
+        def _batch_stream_shuffle_generator():
+            buffer_list = []
+            batch_data = []
+            for examples in _batch_stream_data_generator():
+                if len(buffer_list) < self.stream_shuffle_size:
+                    buffer_list.extend(examples)
+                else:
+                    rand_idx = np.random.randint(0,
+                                                 len(buffer_list),
+                                                 len(examples))
+                    for idx, e in zip(rand_idx, examples):
+                        batch_data.append(buffer_list[idx])
+                        buffer_list[idx] = e
+
+                    yield batch_data
+                    batch_data = []
+
+            if len(buffer_list) > 0:
+                np.random.shuffle(buffer_list)
+                batch_data = []
+                for e in buffer_list:
+                    batch_data.append(e)
                     if len(batch_data) == self.batch_size:
                         yield batch_data
                         batch_data = []
 
-            if not self.drop_last and len(batch_data) > 0:
-                yield batch_data
+                if not self.drop_last and len(batch_data) > 0:
+                    yield batch_data
+                    batch_data = []
 
         self._worker_info = WorkerInfo(
             num_workers=self.num_workers, fid=self.fid)
         self.dataset._set_worker_info(self._worker_info)
-        for batch_data in _batch_data_generator():
+        for batch_data in _batch_stream_shuffle_generator():
             if self.collate_fn is not None:
                 yield self.collate_fn(batch_data)
             else:
diff --git a/pgl/utils/data/dataset.py b/pgl/utils/data/dataset.py
index f850a45d..5b30be7d 100644
--- a/pgl/utils/data/dataset.py
+++ b/pgl/utils/data/dataset.py
@@ -93,16 +93,14 @@ class StreamDataset(object):
         class MyStreamDataset(StreamDataset):
             def __init__(self):
                 self.data = list(range(0, 40))
-                self.count = 0
 
             def __iter__(self):
-                for data in self.dataset:
-                    self.count += 1
-                    if self.count % self._worker_info.num_workers != self._worker_info.fid:
+                for count, data in enumerate(self.dataset):
+                    if count % self._worker_info.num_workers != self._worker_info.fid:
                         continue
                     # do something (like parse data) of your data
-                    time.sleep(0.1)
                     yield data
+
     """
 
     def __iter__(self):
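For readers of the patch: the rewritten _stream_shuffle_data_generator is a batch-wise
shuffle buffer. It fills buffer_list with roughly stream_shuffle_size examples, then trades
each incoming batch for randomly chosen buffer entries (emitting the evicted entries as the
next batch), and finally drains the shuffled remainder. Below is a minimal self-contained
sketch of that scheme, assuming only NumPy; the helper name batch_stream_shuffle and the
simplified drain loop are illustrative, not the patched code itself.

    import numpy as np

    def batch_stream_shuffle(example_stream, batch_size, shuffle_size,
                             drop_last=False):
        # Assumes shuffle_size >= batch_size >= 1, matching the warning
        # this patch adds to Dataloader.__init__.
        def batches():
            batch = []
            for example in example_stream:
                batch.append(example)
                if len(batch) == batch_size:
                    yield batch
                    batch = []
            if not drop_last and batch:
                yield batch

        buffer_list = []
        for examples in batches():
            if len(buffer_list) < shuffle_size:
                # Warm-up: fill the buffer before emitting anything.
                buffer_list.extend(examples)
            else:
                # Trade: emit one random buffered example per new example,
                # swapping the new example into the vacated slot. The swap
                # is interleaved so duplicate indices never drop an example.
                rand_idx = np.random.randint(0, len(buffer_list), len(examples))
                batch = []
                for i, e in zip(rand_idx, examples):
                    batch.append(buffer_list[i])
                    buffer_list[i] = e
                yield batch

        # Drain: flush the buffered remainder in shuffled order.
        np.random.shuffle(buffer_list)
        for start in range(0, len(buffer_list), batch_size):
            tail = buffer_list[start:start + batch_size]
            if len(tail) == batch_size or not drop_last:
                yield tail

    # Every value 0..39 comes back exactly once, in shuffled batches of 4.
    out = list(batch_stream_shuffle(iter(range(40)), 4, 10))
    assert sorted(sum(out, [])) == list(range(40))

After warm-up the buffer stays at a constant size, so memory is bounded at about
stream_shuffle_size examples while each example is delayed by a random amount,
which is what lets a StreamDataset be shuffled without materializing it.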
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 9b5916d8..22293bfb 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -36,7 +36,7 @@ def __len__(self):
         return len(self.dataset)
 
     def _transform(self, example):
-        time.sleep(0.05 + random.random() * 0.1)
+        # time.sleep(0.05 + random.random() * 0.1)
         return example
 
 
@@ -48,7 +48,7 @@ def __iter__(self):
         for count, data in enumerate(self.dataset):
             if count % self._worker_info.num_workers != self._worker_info.fid:
                 continue
-            time.sleep(0.1)
+            # time.sleep(0.1)
             yield data
 
 
@@ -105,40 +105,61 @@ def test_ListDataset(self):
         self.assertEqual(set([i for i in range(DATA_SIZE)]), set(res))
 
     def test_IterDataset(self):
-        config = {
-            'batch_size': 3,
-            'drop_last': True,
-            'num_workers': 2,
-        }
-        collate_fn = Collate_fn(config)
-        ds = IterDataset()
-        loader = Dataloader(
-            ds,
-            batch_size=config['batch_size'],
-            drop_last=config['drop_last'],
-            num_workers=config['num_workers'],
-            collate_fn=collate_fn)
         epochs = 1
-        for e in range(epochs):
-            res = []
-            for batch_data in loader:
-                res.extend(batch_data['data'])
-                self.assertEqual(len(batch_data['data']), config['batch_size'])
+        bs_list = [1, 3, 100]
+        workers_list = [1, 4, 40]
+        shuf_list = [0, 1, 10, 100]
 
-        # test shuffle
-        loader = Dataloader(
-            ds,
-            batch_size=3,
-            drop_last=False,
-            num_workers=1,
-            collate_fn=collate_fn)
+        collate_fn = Collate_fn(None)
+        ds = IterDataset()
 
-        for e in range(epochs):
-            res = []
-            for batch_data in loader:
-                res.extend(batch_data['data'])
-            self.assertEqual(set([i for i in range(DATA_SIZE)]), set(res))
+        for shuf_size in shuf_list:
+            for batch_size in bs_list:
+                for workers in workers_list:
+                    msg = "batch_size: %s | " % batch_size
+                    msg += "num_workers: %s | " % workers
+                    msg += "shuf_size: %s | " % shuf_size
+                    print(msg)
+
+                    loader = Dataloader(
+                        ds,
+                        batch_size=batch_size,
+                        drop_last=False,
+                        num_workers=workers,
+                        stream_shuffle_size=shuf_size,
+                        collate_fn=collate_fn)
+
+                    for e in range(epochs):
+                        res = []
+                        for batch_data in loader:
+                            res.extend(batch_data['data'])
+                        self.assertEqual(
+                            set([i for i in range(DATA_SIZE)]), set(res))
+
+        # test drop_last
+        for shuf_size in shuf_list:
+            for batch_size in bs_list:
+                for workers in workers_list:
+                    msg = "batch_size: %s | " % batch_size
+                    msg += "num_workers: %s | " % workers
+                    msg += "shuf_size: %s | " % shuf_size
+                    print(msg)
+
+                    loader = Dataloader(
+                        ds,
+                        batch_size=batch_size,
+                        drop_last=True,
+                        num_workers=workers,
+                        stream_shuffle_size=0,
+                        collate_fn=collate_fn)
+
+                    for e in range(epochs):
+                        res = []
+                        for batch_data in loader:
+                            res.extend(batch_data['data'])
+                            self.assertEqual(
+                                len(batch_data['data']), batch_size)
 
     def test_ListDataset_Order(self):
         config = {
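The broadened test_IterDataset sweeps batch_size, num_workers, and stream_shuffle_size
and asserts that no example is lost. For reference, a minimal usage sketch driving the
loader the same way; it assumes the import paths match this repo's layout, and
MyStreamDataset with its 0..39 data is illustrative, mirroring the docstring example
fixed in dataset.py.

    from pgl.utils.data import Dataloader
    from pgl.utils.data.dataset import StreamDataset

    class MyStreamDataset(StreamDataset):
        """Streams the integers 0..39, sharded across loader workers."""

        def __init__(self):
            self.data = list(range(0, 40))

        def __iter__(self):
            for count, data in enumerate(self.data):
                # Each worker keeps every num_workers-th example; the loader
                # sets self._worker_info before iteration begins.
                if count % self._worker_info.num_workers != self._worker_info.fid:
                    continue
                yield data

    loader = Dataloader(
        MyStreamDataset(),
        batch_size=3,
        drop_last=False,
        num_workers=2,
        stream_shuffle_size=10)  # keep >= batch_size, or the new warning fires

    for batch in loader:
        print(batch)

With no collate_fn, each emitted batch is a plain list of examples; passing a
collate function, as the tests do, lets the loader return structured batches instead.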