"""Functions for accessing HDF5 files."""
from __future__ import division
from __future__ import print_function
import re
import h5py as h5
import numpy as np
import six
from six.moves import range
from ..utils import filter_regex, to_list
def _ls(item, recursive=False, groups=False, level=0):
    """Recursively collect record names below an HDF5 item.

    Parameters
    ----------
    item: h5py.Group or h5py.Dataset
        Item whose names are collected.
    recursive: bool
        If `True`, descend into subgroups below the first level.
    groups: bool
        If `True`, collect group names instead of dataset names.
    level: int
        Current recursion depth; the root item itself is never listed.

    Returns
    -------
    list
        Names of the collected records.
    """
    names = []
    if isinstance(item, h5.Group):
        # Only report the group itself when groups are requested and it is
        # not the root of the listing.
        if groups and level > 0:
            names.append(item.name)
        # Always expand the root; deeper levels only when recursive.
        if level == 0 or recursive:
            for child in item.keys():
                names.extend(_ls(item[child], recursive, groups, level + 1))
    elif not groups:
        # Leaf dataset; listed unless only groups were requested.
        names.append(item.name)
    return names
def ls(filename, group='/', recursive=False, groups=False,
       regex=None, nb_key=None, must_exist=True):
    """List name of records HDF5 file.

    Parameters
    ----------
    filename:
        Path of HDF5 file.
    group:
        HDF5 group to be explored.
    recursive: bool
        If `True`, list records recursively.
    groups: bool
        If `True`, only list group names but not name of datasets.
    regex: str
        Regex to filter listed records.
    nb_key: int
        Maximum number of records to be listed.
    must_exist: bool
        If `False`, return `None` if file or group does not exist.

    Returns
    -------
    list
        `list` with name of records in `filename`.
    """
    if not group.startswith('/'):
        group = '/%s' % group
    # Context manager closes the file on every exit path; the previous
    # version leaked the handle when returning `None` early.
    with h5.File(filename, 'r') as h5_file:
        if not must_exist and group not in h5_file:
            return None
        keys = _ls(h5_file[group], recursive, groups)
    # Strip the group prefix literally. The previous `re.sub('^%s/' % group,
    # ...)` treated the group name as a regex and misbehaved when it
    # contained metacharacters (e.g. '+', '.').
    prefix = '%s/' % group
    keys = [key[len(prefix):] if key.startswith(prefix) else key
            for key in keys]
    if regex:
        keys = filter_regex(keys, regex)
    if nb_key is not None:
        keys = keys[:nb_key]
    return keys
def write_data(data, filename):
    """Write data in dict `data` to HDF5 file.

    Parameters
    ----------
    data: dict
        Possibly nested `dict` mapping record names to array-like values.
        Nested dicts are written as nested HDF5 groups.
    filename: str or h5py.Group
        Path of the output HDF5 file, or an already open group to write
        into (used by the recursive calls for nested dicts).
    """
    is_root = isinstance(filename, str)
    group = h5.File(filename, 'w') if is_root else filename
    try:
        for key, value in data.items():
            if isinstance(value, dict):
                # Nested dict -> nested HDF5 group.
                write_data(value, group.create_group(key))
            else:
                group[key] = value
    finally:
        # Close the file even if writing a value raises; only the root
        # call owns (and therefore closes) the file handle.
        if is_root:
            group.close()
def hnames_to_names(hnames):
    """Flattens `dict` `hnames` of hierarchical names.

    Converts hierarchical `dict`, e.g. hnames={'a': ['a1', 'a2'], 'b'}, to flat
    list of keys for accessing HDF5 file, e.g. ['a/a1', 'a/a2', 'b']

    Parameters
    ----------
    hnames: dict
        Values may be a nested `dict` (flattened recursively), a `list` of
        names, a single `str` name, or anything else (the key alone is used).

    Returns
    -------
    list
        Flat `list` of '/'-joined record names.
    """
    names = []
    # dict.items() behaves identically on Python 2 and 3 here, so the six
    # dependency is unnecessary for this function.
    for key, value in hnames.items():
        if isinstance(value, dict):
            for name in hnames_to_names(value):
                names.append('%s/%s' % (key, name))
        elif isinstance(value, list):
            for name in value:
                names.append('%s/%s' % (key, name))
        elif isinstance(value, str):
            names.append('%s/%s' % (key, value))
        else:
            # e.g. value is None: the key itself is a full name.
            names.append(key)
    return names
def reader(data_files, names, batch_size=128, nb_sample=None, shuffle=False,
           loop=False):
    """Generator yielding batches of records from HDF5 files.

    Parameters
    ----------
    data_files: list
        Path(s) of the HDF5 files to read from.
    names: dict or list
        Name(s) of the records to read. A hierarchical `dict` is flattened
        via `hnames_to_names`.
    batch_size: int
        Maximum number of samples per yielded batch.
    nb_sample: int
        Maximum total number of samples to yield. If `None`, all samples of
        all files are yielded.
    shuffle: bool
        If `True`, shuffle the file order at the start of each pass and the
        samples within each file (the latter loads the whole file into
        memory).
    loop: bool
        If `True`, restart from the beginning after the last sample, making
        this an infinite generator.

    Yields
    ------
    dict
        `dict` mapping each name in `names` to a slice of at most
        `batch_size` samples.
    """
    if isinstance(names, dict):
        names = hnames_to_names(names)
    else:
        names = to_list(names)
    # Copy, since list will be changed if shuffle=True
    data_files = list(to_list(data_files))

    # Check if names exist
    h5_file = h5.File(data_files[0], 'r')
    for name in names:
        if name not in h5_file:
            raise ValueError('%s does not exist!' % name)
    h5_file.close()

    if nb_sample:
        # Select the first k files s.t. the total sample size is at least
        # nb_sample. Only these files will be shuffled.
        _data_files = []
        nb_seen = 0
        for data_file in data_files:
            h5_file = h5.File(data_file, 'r')
            nb_seen += len(h5_file[names[0]])
            h5_file.close()
            _data_files.append(data_file)
            if nb_seen >= nb_sample:
                break
        data_files = _data_files
    else:
        # Sentinel: with np.inf the nb_seen >= nb_sample checks never fire.
        nb_sample = np.inf

    file_idx = 0
    nb_seen = 0
    while True:
        # file_idx == 0 marks the start of a pass over all files.
        if shuffle and file_idx == 0:
            np.random.shuffle(data_files)

        h5_file = h5.File(data_files[file_idx], 'r')
        data_file = dict()
        for name in names:
            data_file[name] = h5_file[name]
        # All records are assumed to have the same length; use the first.
        nb_sample_file = len(list(data_file.values())[0])

        if shuffle:
            # Shuffle data within the entire file, which requires reading
            # the entire file into memory
            idx = np.arange(nb_sample_file)
            np.random.shuffle(idx)
            for name, value in six.iteritems(data_file):
                data_file[name] = value[:len(idx)][idx]

        # True division via __future__ import, so ceil yields the number of
        # (possibly partial) batches in this file.
        nb_batch = int(np.ceil(nb_sample_file / batch_size))
        for batch in range(nb_batch):
            batch_start = batch * batch_size
            # Trim the batch so that no more than nb_sample samples are
            # yielded in total.
            nb_read = min(nb_sample - nb_seen, batch_size)
            batch_end = min(nb_sample_file, batch_start + nb_read)
            _batch_size = batch_end - batch_start
            if _batch_size == 0:
                break
            data_batch = dict()
            for name in names:
                data_batch[name] = data_file[name][batch_start:batch_end]
            yield data_batch
            nb_seen += _batch_size
            if nb_seen >= nb_sample:
                break

        h5_file.close()
        file_idx += 1
        assert nb_seen <= nb_sample
        if nb_sample == nb_seen or file_idx == len(data_files):
            if loop:
                # Restart the pass to yield samples indefinitely.
                file_idx = 0
                nb_seen = 0
            else:
                break
def _to_dict(data):
if isinstance(data, np.ndarray):
data = [data]
return dict(zip(range(len(data)), data))
def read_from(reader, nb_sample=None):
    """Read up to `nb_sample` samples from generator `reader` into memory.

    Batches are accumulated per key and stacked with `stack_dict`. If the
    reader yields non-dict batches, they are keyed by position and the
    result is returned as a list instead of a dict.
    """
    from .utils import stack_dict

    buffers = dict()
    nb_seen = 0
    is_dict = True
    for batch in reader:
        if not isinstance(batch, dict):
            batch = _to_dict(batch)
            is_dict = False
        for key, value in batch.items():
            buffers.setdefault(key, []).append(value)
        # All records in a batch have the same length; count via the first.
        nb_seen += len(list(batch.values())[0])
        if nb_sample and nb_seen >= nb_sample:
            break

    data = stack_dict(buffers)
    if nb_sample:
        # The last batch may overshoot; trim every record to nb_sample.
        for key, value in data.items():
            data[key] = value[:nb_sample]
    if not is_dict:
        # Restore the positional (list) form of the input batches.
        data = [data[i] for i in range(len(data))]
    return data
def read(data_files, names, nb_sample=None, batch_size=1024, *args, **kwargs):
    """Read `nb_sample` samples of records `names` from `data_files` into
    memory, delegating batching to `reader` and stacking to `read_from`."""
    batches = reader(data_files, names, batch_size=batch_size,
                     nb_sample=nb_sample, loop=False, *args, **kwargs)
    return read_from(batches, nb_sample)