0%

tf2-from_tensor-slices说明

tf.data.Dataset.from_tensor_slices

入参

一个元组、列表和张量

出参

得到数据集,类型为TensorSliceDataset

作用

是把给定的元组、列表和张量等数据进行特征切片。切片的范围是从最外层维度开始的。如果有多个特征进行组合,那么一次切片是把每个组合的最外维度的数据切开,分成一组一组的。

假设我们现在有两组数据,分别是特征和标签,我们假设每两个特征对应一个标签。之后把特征和标签组合成一个tuple,那么我们的想法是让每个标签都恰好对应2个特征,而且像直接切片,比如:[f11, f12] [t1]。f11表示第一个数据的第一个特征,f12表示第1个数据的第二个特征,t1表示第一个数据标签。那么tf.data.Dataset.from_tensor_slices就是做了这件事情:

line_number: true
1
2
3
4
5
6
7
8
9
import tensorflow as tf
import numpy as np

features, labels = (np.random.sample((6, 3)), # 模拟6组数据,每组数据3个特征
np.random.sample((6, 1))) # 模拟6组数据,每组数据对应一个标签,注意两者的维数必须匹配

print((features, labels)) # 输出下组合的数据
data = tf.data.Dataset.from_tensor_slices((features, labels))
print(data)