1、tensorflow载入数据的三种方式tensorflow载入数据的三种方式 Tensorflow数据读取有三种方式:Preloaded data: 预加载数据Feeding: Python产生数据,再把数据喂给后端。Reading from file: 从文件中直接读取这三种有读取方式有什么区别呢? 我们首先要知道TensorFlow(TF)是怎么样工作的。TF的核心是用C+写的,这样的好处是运行快,缺点是调用不灵活。而Python恰好相反,所以结合两种语言的优势。涉及计算的核心算子和运行框架是用C+写的,并提供API给Python。Python调用这些API,设计训练模型(Graph),再
2、将设计好的Graph给后端去执行。简而言之,Python的角色是Design,C+是Run。一、预加载数据:python view plain copy import tensorflow as tf # 设计Graph x1 = tf.constant(2, 3, 4) x2 = tf.constant(4, 0, 1) y = tf.add(x1, x2) # 打开一个session -> 计算y with tf.Session() as sess: print sess.run(y) 二、python产生数据,再将数据喂给后端python view plain copy impor
3、t tensorflow as tf # 设计Graph x1 = tf.placeholder(tf.int16) x2 = tf.placeholder(tf.int16) y = tf.add(x1, x2) # 用Python产生数据 li1 = 2, 3, 4 li2 = 4, 0, 1 # 打开一个session -> 喂数据 -> 计算y with tf.Session() as sess: print sess.run(y, feed_dict=x1: li1, x2: li2) 说明:在这里x1, x2只是占位符,没有具体的值,那么运行的时候去哪取值呢?这时候就要
4、用到sess.run()中的feed_dict参数,将Python产生的数据喂给后端,并计算y。这两种方案的缺点:1、预加载:将数据直接内嵌到Graph中,再把Graph传入Session中运行。当数据量比较大时,Graph的传输会遇到效率问题。2、用占位符替代数据,待运行的时候填充数据。前两种方法很方便,但是遇到大型数据的时候就会很吃力,即使是Feeding,中间环节的增加也是不小的开销,比如数据类型转换等等。最优的方案就是在Graph定义好文件读取的方法,让TF自己去从文件中读取数据,并解码成可使用的样本集。三、从文件中读取,简单来说就是将数据读取模块的图搭好1、准备数据,构造三个文件,A
5、.csv,B.csv,C.csvpython view plain copy $ echo -e Alpha1,A1nAlpha2,A2nAlpha3,A3 > A.csv $ echo -e Bee1,B1nBee2,B2nBee3,B3 > B.csv $ echo -e Sea1,C1nSea2,C2nSea3,C3 > C.csv 2、单个Reader,单个样本python view plain copy #-*- coding:utf-8 -*- import tensorflow as tf # 生成一个先入先出队列和一个QueueRunner,生成文件名队列 f
6、ilenames = A.csv, B.csv, C.csv filename_queue = tf.train.string_input_producer(filenames, shuffle=False) # 定义Reader reader = tf.TextLineReader() key, value = reader.read(filename_queue) # 定义Decoder example, label = tf.decode_csv(value, record_defaults=null, null) #example_batch, label_batch = tf.tra
7、in.shuffle_batch(example,label, batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2) # 运行Graph with tf.Session() as sess: coord = tf.train.Coordinator() #创建一个协调器,管理线程 threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。 for i in range(10): print example.eval(
8、),label.eval() coord.request_stop() coord.join(threads) 说明:这里没有使用tf.train.shuffle_batch,会导致生成的样本和label之间对应不上,乱序了。生成结果如下:Alpha1 A2Alpha3 B1Bee2 B3Sea1 C2Sea3 A1Alpha2 A3Bee1 B2Bee3 C1Sea2 C3Alpha1 A2解决方案:用tf.train.shuffle_batch,那么生成的结果就能够对应上。python view plain copy #-*- coding:utf-8 -*- import tensorf
9、low as tf # 生成一个先入先出队列和一个QueueRunner,生成文件名队列 filenames = A.csv, B.csv, C.csv filename_queue = tf.train.string_input_producer(filenames, shuffle=False) # 定义Reader reader = tf.TextLineReader() key, value = reader.read(filename_queue) # 定义Decoder example, label = tf.decode_csv(value, record_defaults=nu
10、ll, null) example_batch, label_batch = tf.train.shuffle_batch(example,label, batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2) # 运行Graph with tf.Session() as sess: coord = tf.train.Coordinator() #创建一个协调器,管理线程 threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列
11、已经进队。 for i in range(10): e_val,l_val = sess.run(example_batch, label_batch) print e_val,l_val coord.request_stop() coord.join(threads) 3、单个Reader,多个样本,主要也是通过tf.train.shuffle_batch来实现python view plain copy #-*- coding:utf-8 -*- import tensorflow as tf filenames = A.csv, B.csv, C.csv filename_queue =
12、 tf.train.string_input_producer(filenames, shuffle=False) reader = tf.TextLineReader() key, value = reader.read(filename_queue) example, label = tf.decode_csv(value, record_defaults=null, null) # 使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。 #Decoder解后数据会进入这个队列,再批量出队。 # 虽然这里只有一个Reader,但可以设置多线程,相应增加线程数
13、会提高读取速度,但并不是线程越多越好。 example_batch, label_batch = tf.train.batch( example, label, batch_size=5) with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(10): e_val,l_val = sess.run(example_batch,label_batch) print e_val,l_val coord.r
14、equest_stop() coord.join(threads) 说明:下面这种写法,提取出来的batch_size个样本,特征和label之间也是不同步的python view plain copy #-*- coding:utf-8 -*- import tensorflow as tf filenames = A.csv, B.csv, C.csv filename_queue = tf.train.string_input_producer(filenames, shuffle=False) reader = tf.TextLineReader() key, value = read
15、er.read(filename_queue) example, label = tf.decode_csv(value, record_defaults=null, null) # 使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。 #Decoder解后数据会进入这个队列,再批量出队。 # 虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。 example_batch, label_batch = tf.train.batch( example, label, batch_size=5) with tf.S
16、ession() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(10): print example_batch.eval(), label_batch.eval() coord.request_stop() coord.join(threads) 说明:输出结果如下:可以看出feature和label之间是不对应的Alpha1Alpha2 Alpha3 Bee1 Bee2 B3 C1 C2 C3 A1Alpha2 Alpha3
17、 Bee1 Bee2 Bee3 C1 C2 C3 A1 A2Alpha3 Bee1 Bee2 Bee3 Sea1 C2 C3 A1 A2 A34、多个reader,多个样本python view plain copy #-*- coding:utf-8 -*- import tensorflow as tf filenames = A.csv, B.csv, C.csv filename_queue = tf.train.string_input_producer(filenames, shuffle=False) reader = tf.TextLineReader() key, value
18、 = reader.read(filename_queue) record_defaults = null, null #定义了多种解码器,每个解码器跟一个reader相连 example_list = tf.decode_csv(value, record_defaults=record_defaults) for _ in range(2) # Reader设置为2 # 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。 example_batch, label_batch = tf.train.batch_join( e
19、xample_list, batch_size=5) with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(10): e_val,l_val = sess.run(example_batch,label_batch) print e_val,l_val coord.request_stop() coord.join(threads) tf.train.batch与tf.train.shuffle_ba
20、tch函数是单个Reader读取,但是可以多线程。tf.train.batch_join与tf.train.shuffle_batch_join可设置多Reader读取,每个Reader使用一个线程。至于两种方法的效率,单Reader时,2个线程就达到了速度的极限。多Reader时,2个Reader就达到了极限。所以并不是线程越多越快,甚至更多的线程反而会使效率下降。5、迭代控制,设置epoch参数,指定我们的样本在训练的时候只能被用多少轮python view plain copy #-*- coding:utf-8 -*- import tensorflow as tf filenames
21、 = A.csv, B.csv, C.csv #num_epoch: 设置迭代数 filename_queue = tf.train.string_input_producer(filenames, shuffle=False,num_epochs=3) reader = tf.TextLineReader() key, value = reader.read(filename_queue) record_defaults = null, null #定义了多种解码器,每个解码器跟一个reader相连 example_list = tf.decode_csv(value, record_def
22、aults=record_defaults) for _ in range(2) # Reader设置为2 # 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。 example_batch, label_batch = tf.train.batch_join( example_list, batch_size=1) #初始化本地变量 init_local_op = tf.initialize_local_variables() with tf.Session() as sess: sess.run(init_local_op
23、) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) try: while not coord.should_stop(): e_val,l_val = sess.run(example_batch,label_batch) print e_val,l_val except tf.errors.OutOfRangeError: print(Epochs Complete!) finally: coord.request_stop() coord.join(threads) coo
24、rd.request_stop() coord.join(threads) 在迭代控制中,记得添加tf.initialize_local_variables(),官网教程没有说明,但是如果不初始化,运行就会报错。=对于传统的机器学习而言,比方说分类问题,x1x2 x3是feature。对于二分类问题,label经过one-hot编码之后就会是0,1或者1,0。一般情况下,我们会考虑将数据组织在csv文件中,一行代表一个sample。然后使用队列的方式去读取数据说明:对于该数据,前三列代表的是feature,因为是分类问题,后两列就是经过one-hot编码之后得到的label使用队列读取该csv
25、文件的代码如下:python view plain copy #-*- coding:utf-8 -*- import tensorflow as tf # 生成一个先入先出队列和一个QueueRunner,生成文件名队列 filenames = A.csv filename_queue = tf.train.string_input_producer(filenames, shuffle=False) # 定义Reader reader = tf.TextLineReader() key, value = reader.read(filename_queue) # 定义Decoder rec
26、ord_defaults = 1, 1, 1, 1, 1 col1, col2, col3, col4, col5 = tf.decode_csv(value,record_defaults=record_defaults) features = tf.pack(col1, col2, col3) label = tf.pack(col4,col5) example_batch, label_batch = tf.train.shuffle_batch(features,label, batch_size=2, capacity=200, min_after_dequeue=100, num_
27、threads=2) # 运行Graph with tf.Session() as sess: coord = tf.train.Coordinator() #创建一个协调器,管理线程 threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。 for i in range(10): e_val,l_val = sess.run(example_batch, label_batch) print e_val,l_val coord.request_stop() coord.join(threads) 输出结果如下:说明:record_defaults = 1, 1, 1, 1, 1代表解析的模板,每个样本有5列,在数据中是默认用,隔开的,然后解析的标准是1,也即每一列的数值都解析为整型。1.0就是解析为浮点,null解析为string类型
copyright@ 2008-2022 冰豆网网站版权所有
经营许可证编号:鄂ICP备2022015515号-1