tensorflow载入数据的三种方式.docx

上传人:b****7 文档编号:9634669 上传时间:2023-02-05 格式:DOCX 页数:6 大小:16.94KB
下载 相关 举报
tensorflow载入数据的三种方式.docx_第1页
第1页 / 共6页
tensorflow载入数据的三种方式.docx_第2页
第2页 / 共6页
tensorflow载入数据的三种方式.docx_第3页
第3页 / 共6页
tensorflow载入数据的三种方式.docx_第4页
第4页 / 共6页
tensorflow载入数据的三种方式.docx_第5页
第5页 / 共6页
点击查看更多>>
下载资源
资源描述

tensorflow载入数据的三种方式.docx

《tensorflow载入数据的三种方式.docx》由会员分享,可在线阅读,更多相关《tensorflow载入数据的三种方式.docx(6页珍藏版)》请在冰豆网上搜索。

tensorflow载入数据的三种方式.docx

tensorflow载入数据的三种方式

tensorflow载入数据的三种方式

Tensorflow数据读取有三种方式:

Preloadeddata:

预加载数据Feeding:

Python产生数据,再把数据喂给后端。

Readingfromfile:

从文件中直接读取

这三种有读取方式有什么区别呢?

我们首先要知道TensorFlow(TF)是怎么样工作的。

TF的核心是用C++写的,这样的好处是运行快,缺点是调用不灵活。

而Python恰好相反,所以结合两种语言的优势。

涉及计算的核心算子和运行框架是用C++写的,并提供API给Python。

Python调用这些API,设计训练模型(Graph),再将设计好的Graph给后端去执行。

简而言之,Python的角色是Design,C++是Run。

一、预加载数据:

 

[python]viewplaincopyimporttensorflowastf#设计Graphx1=tf.constant([2,3,4])x2=tf.constant([4,0,1])y=tf.add(x1,x2)#打开一个session-->计算ywithtf.Session()assess:

printsess.run(y)

二、python产生数据,再将数据喂给后端[python]viewplaincopyimporttensorflowastf#设计Graphx1=tf.placeholder(tf.int16)x2=tf.placeholder(tf.int16)y=tf.add(x1,x2)#用Python产生数据li1=[2,3,4]li2=[4,0,1]#打开一个session-->喂数据-->计算ywithtf.Session()assess:

printsess.run(y,feed_dict={x1:

li1,x2:

li2})说明:

在这里x1,x2只是占位符,没有具体的值,那么运行的时候去哪取值呢?

这时候就要用到sess.run()中的feed_dict参数,将Python产生的数据喂给后端,并计算y。

这两种方案的缺点:

 

1、预加载:

将数据直接内嵌到Graph中,再把Graph传入Session中运行。

当数据量比较大时,Graph的传输会遇到效率问题。

2、用占位符替代数据,待运行的时候填充数据。

前两种方法很方便,但是遇到大型数据的时候就会很吃力,即使是Feeding,中间环节的增加也是不小的开销,比如数据类型转换等等。

最优的方案就是在Graph定义好文件读取的方法,让TF自己去从文件中读取数据,并解码成可使用的样本集。

三、从文件中读取,简单来说就是将数据读取模块的图搭好1、准备数据,构造三个文件,A.csv,B.csv,C.csv

 

[python]viewplaincopy$echo-e"Alpha1,A1\nAlpha2,A2\nAlpha3,A3">A.csv$echo-e"Bee1,B1\nBee2,B2\nBee3,B3">B.csv$echo-e"Sea1,C1\nSea2,C2\nSea3,C3">C.csv

2、单个Reader,单个样本[python]viewplaincopy#-*-coding:

utf-8-*-importtensorflowastf#生成一个先入先出队列和一个QueueRunner,生成文件名队列filenames=['A.csv','B.csv','C.csv']filename_queue=tf.train.string_input_producer(filenames,shuffle=False)#定义Readerreader=tf.TextLineReader()key,value=reader.read(filename_queue)#定义Decoderexample,label=tf.decode_csv(value,record_defaults=[['null'],['null']])#example_batch,label_batch=tf.train.shuffle_batch([example,label],batch_size=1,capacity=200,min_after_dequeue=100,num_threads=2)#运行Graphwithtf.Session()assess:

coord=tf.train.Coordinator()#创建一个协调器,管理线程threads=tf.train.start_queue_runners(coord=coord)#启动QueueRunner,此时文件名队列已经进队。

foriinrange(10):

printexample.eval(),label.eval()coord.request_stop()coord.join(threads)说明:

这里没有使用tf.train.shuffle_batch,会导致生成的样本和label之间对应不上,乱序了。

生成结果如下:

 

Alpha1A2

Alpha3B1

Bee2B3

Sea1C2

Sea3A1

Alpha2A3

Bee1B2

Bee3C1

Sea2C3

Alpha1A2解决方案:

用tf.train.shuffle_batch,那么生成的结果就能够对应上。

 

[python]viewplaincopy#-*-coding:

utf-8-*-importtensorflowastf#生成一个先入先出队列和一个QueueRunner,生成文件名队列filenames=['A.csv','B.csv','C.csv']filename_queue=tf.train.string_input_producer(filenames,shuffle=False)#定义Readerreader=tf.TextLineReader()key,value=reader.read(filename_queue)#定义Decoderexample,label=tf.decode_csv(value,record_defaults=[['null'],['null']])example_batch,label_batch=tf.train.shuffle_batch([example,label],batch_size=1,capacity=200,min_after_dequeue=100,num_threads=2)#运行Graphwithtf.Session()assess:

coord=tf.train.Coordinator()#创建一个协调器,管理线程threads=tf.train.start_queue_runners(coord=coord)#启动QueueRunner,此时文件名队列已经进队。

foriinrange(10):

e_val,l_val=sess.run([example_batch,label_batch])printe_val,l_valcoord.request_stop()coord.join(threads)

3、单个Reader,多个样本,主要也是通过tf.train.shuffle_batch来实现[python]viewplaincopy#-*-coding:

utf-8-*-importtensorflowastffilenames=['A.csv','B.csv','C.csv']filename_queue=tf.train.string_input_producer(filenames,shuffle=False)reader=tf.TextLineReader()key,value=reader.read(filename_queue)example,label=tf.decode_csv(value,record_defaults=[['null'],['null']])#使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。

#Decoder解后数据会进入这个队列,再批量出队。

#虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。

example_batch,label_batch=tf.train.batch([example,label],batch_size=5)withtf.Session()assess:

coord=tf.train.Coordinator()threads=tf.train.start_queue_runners(coord=coord)foriinrange(10):

e_val,l_val=sess.run([example_batch,label_batch])printe_val,l_valcoord.request_stop()coord.join(threads)

说明:

下面这种写法,提取出来的batch_size个样本,特征和label之间也是不同步的[python]viewplaincopy#-*-coding:

utf-8-*-importtensorflowastffilenames=['A.csv','B.csv','C.csv']filename_queue=tf.train.string_input_producer(filenames,shuffle=False)reader=tf.TextLineReader()key,value=reader.read(filename_queue)example,label=tf.decode_csv(value,record_defaults=[['null'],['null']])#使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。

#Decoder解后数据会进入这个队列,再批量出队。

#虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。

example_batch,label_batch=tf.train.batch([example,label],batch_size=5)withtf.Session()assess:

coord=tf.train.Coordinator()threads=tf.train.start_queue_runners(coord=coord)foriinrange(10):

printexample_batch.eval(),label_batch.eval()coord.request_stop()coord.join(threads)说明:

输出结果如下:

可以看出feature和label之间是不对应的

 

['Alpha1'

'Alpha2''Alpha3''Bee1''Bee2']['B3''C1''C2''C3''A1']

['Alpha2''Alpha3''Bee1''Bee2''Bee3']['C1''C2''C3''A1''A2']

['Alpha3''Bee1''Bee2''Bee3''Sea1']['C2''C3''A1''A2''A3']4、多个reader,多个样本

 

[python]viewplaincopy#-*-coding:

utf-8-*-importtensorflowastffilenames=['A.csv','B.csv','C.csv']filename_queue=tf.train.string_input_producer(filenames,shuffle=False)reader=tf.TextLineReader()key,value=reader.read(filename_queue)record_defaults=[['null'],['null']]#定义了多种解码器,每个解码器跟一个reader相连example_list=[tf.decode_csv(value,record_defaults=record_defaults)for_inrange

(2)]#Reader设置为2#使用tf.train.batch_join(),可以使用多个reader,并行读取数据。

每个Reader使用一个线程。

example_batch,label_batch=tf.train.batch_join(example_list,batch_size=5)withtf.Session()assess:

coord=tf.train.Coordinator()threads=tf.train.start_queue_runners(coord=coord)foriinrange(10):

e_val,l_val=sess.run([example_batch,label_batch])printe_val,l_valcoord.request_stop()coord.join(threads)

tf.train.batch与tf.train.shuffle_batch函数是单个Reader读取,但是可以多线程。

tf.train.batch_join与tf.train.shuffle_batch_join可设置多Reader读取,每个Reader使用一个线程。

至于两种方法的效率,单Reader时,2个线程就达到了速度的极限。

多Reader时,2个Reader就达到了极限。

所以并不是线程越多越快,甚至更多的线程反而会使效率下降。

5、迭代控制,设置epoch参数,指定我们的样本在训练的时候只能被用多少轮

 

[python]viewplaincopy#-*-coding:

utf-8-*-importtensorflowastffilenames=['A.csv','B.csv','C.csv']#num_epoch:

设置迭代数filename_queue=tf.train.string_input_producer(filenames,shuffle=False,num_epochs=3)reader=tf.TextLineReader()key,value=reader.read(filename_queue)record_defaults=[['null'],['null']]#定义了多种解码器,每个解码器跟一个reader相连example_list=[tf.decode_csv(value,record_defaults=record_defaults)for_inrange

(2)]#Reader设置为2#使用tf.train.batch_join(),可以使用多个reader,并行读取数据。

每个Reader使用一个线程。

example_batch,label_batch=tf.train.batch_join(example_list,batch_size=1)#初始化本地变量init_local_op=tf.initialize_local_variables()withtf.Session()assess:

sess.run(init_local_op)coord=tf.train.Coordinator()threads=tf.train.start_queue_runners(coord=coord)try:

whilenotcoord.should_stop():

e_val,l_val=sess.run([example_batch,label_batch])printe_val,l_valexcepttf.errors.OutOfRangeError:

print('EpochsComplete!

')finally:

coord.request_stop()coord.join(threads)coord.request_stop()coord.join(threads)

在迭代控制中,记得添加tf.initialize_local_variables(),官网教程没有说明,但是如果不初始化,运行就会报错。

=========================================================================================对于传统的机器学习而言,比方说分类问题,[x1

x2x3]是feature。

对于二分类问题,label经过one-hot编码之后就会是[0,1]或者[1,0]。

一般情况下,我们会考虑将数据组织在csv文件中,一行代表一个sample。

然后使用队列的方式去读取数据说明:

对于该数据,前三列代表的是feature,因为是分类问题,后两列就是经过one-hot编码之后得到的label

使用队列读取该csv文件的代码如下:

 

[python]viewplaincopy#-*-coding:

utf-8-*-importtensorflowastf#生成一个先入先出队列和一个QueueRunner,生成文件名队列filenames=['A.csv']filename_queue=tf.train.string_input_producer(filenames,shuffle=False)#定义Readerreader=tf.TextLineReader()key,value=reader.read(filename_queue)#定义Decoderrecord_defaults=[[1],[1],[1],[1],[1]]col1,col2,col3,col4,col5=tf.decode_csv(value,record_defaults=record_defaults)features=tf.pack([col1,col2,col3])label=tf.pack([col4,col5])example_batch,label_batch=tf.train.shuffle_batch([features,label],batch_size=2,capacity=200,min_after_dequeue=100,num_threads=2)#运行Graphwithtf.Session()assess:

coord=tf.train.Coordinator()#创建一个协调器,管理线程threads=tf.train.start_queue_runners(coord=coord)#启动QueueRunner,此时文件名队列已经进队。

foriinrange(10):

e_val,l_val=sess.run([example_batch,label_batch])printe_val,l_valcoord.request_stop()coord.join(threads)

输出结果如下:

说明:

 

record_defaults=[[1],[1],[1],[1],[1]]

代表解析的模板,每个样本有5列,在数据中是默认用‘,’隔开的,然后解析的标准是[1],也即每一列的数值都解析为整型。

[1.0]就是解析为浮点,['null']解析为string类型

展开阅读全文
相关资源
猜你喜欢
相关搜索

当前位置:首页 > 党团工作 > 入党转正申请

copyright@ 2008-2022 冰豆网网站版权所有

经营许可证编号:鄂ICP备2022015515号-1