小編給大家分享一下TensorFlow如何將ckpt文件固化成pb文件,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!
讓客戶滿意是我們工作的目標(biāo),不斷超越客戶的期望值來(lái)自于我們對(duì)這個(gè)行業(yè)的熱愛(ài)。我們立志把好的技術(shù)通過(guò)有效、簡(jiǎn)單的方式提供給客戶,將通過(guò)不懈努力成為客戶在信息化領(lǐng)域值得信任、有價(jià)值的長(zhǎng)期合作伙伴,公司提供的服務(wù)項(xiàng)目有:申請(qǐng)域名、網(wǎng)頁(yè)空間、營(yíng)銷軟件、網(wǎng)站建設(shè)、融安網(wǎng)站維護(hù)、網(wǎng)站推廣。將yolo3目標(biāo)檢測(cè)框架訓(xùn)練出來(lái)的ckpt文件固化成pb文件,主要利用了GitHub上的該項(xiàng)目。
為什么要最終生成pb文件呢?簡(jiǎn)單來(lái)說(shuō)就是直接通過(guò)tf.saver保存行程的ckpt文件其變量數(shù)據(jù)和圖是分開(kāi)的。我們知道TensorFlow是先畫(huà)圖,然后通過(guò)placeholde往圖里面喂數(shù)據(jù)。這種解耦形式存在的方法對(duì)以后的遷移學(xué)習(xí)以及對(duì)程序進(jìn)行微小的改動(dòng)提供了極大的便利性。但是對(duì)于訓(xùn)練好,以后不再改變的話這種存在就不再需要。一方面,ckpt文件儲(chǔ)存的數(shù)據(jù)都是變量,既然我們不再改動(dòng),就應(yīng)當(dāng)讓其變成常量,直接‘燒'到圖里面。另一方面,對(duì)于線上的模型,我們一般是通過(guò)C++或者C語(yǔ)言編寫(xiě)的程序進(jìn)行調(diào)用。所以一般模型最終形式都是應(yīng)該寫(xiě)成pb文件的形式。
由于這次的程序直接從GitHub上下載后改動(dòng)較小就能夠運(yùn)行,也就是自己寫(xiě)了很少一部分程序。因此進(jìn)行調(diào)試的時(shí)候還出現(xiàn)了以前根本沒(méi)有注意的一些小問(wèn)題,同時(shí)發(fā)現(xiàn)自己對(duì)TensorFlow還需要更加詳細(xì)的去研讀。
首先對(duì)程序進(jìn)行保存的時(shí)候,利用 saver = tf.train.Saver(), saver.save(sess,checkpoint_path,global_step=global_step)對(duì)訓(xùn)練的數(shù)據(jù)進(jìn)行保存,保存格式為ckpt。但是在恢復(fù)的時(shí)候一直提示有問(wèn)題,(其恢復(fù)語(yǔ)句為:saver = tf.train.Saver(), saver.restore(sess,ckpt_path),其中,ckpt_path是保存ckpt的文件夾路徑)。出現(xiàn)問(wèn)題的原因我估計(jì)是因?yàn)槲沂前凑彰?0個(gè)epoch進(jìn)行保存,而不是讓其進(jìn)行固定次數(shù)的batch進(jìn)行保存,這種固定batch次數(shù)的保存系統(tǒng)會(huì)自動(dòng)保存最近5次的ckpt文件(該方法的ckpt_path=tf.train,latest_checkpoint('ckpt/')進(jìn)行回復(fù))。那么如何將利用epoch的次數(shù)進(jìn)行保存呢(這種保存不是近5次的保存,而是每進(jìn)行一次保存就會(huì)留下當(dāng)時(shí)保存的ckpt,而那種按照batch的會(huì)在第n次保存,會(huì)將n-5次的刪除,n>5)。
我們可以利用:ckpt = tf.train.get_checkpoint_state(ckpt_path),獲取最新的ckptpoint文件,然后利用saver.restore(sess,ckpt.checkpoint_path)進(jìn)行恢復(fù)。當(dāng)然為了安全起見(jiàn),應(yīng)該對(duì)ckpt和ckpt.checkpoint_path進(jìn)行判斷是否存在后,再進(jìn)行恢復(fù)語(yǔ)句的調(diào)用,建議打開(kāi)ckptpoint看一下,里面記錄的最近五次的model的路徑,一目了然。即:
saver = tf.train.Saver() ckpt = tf.train.get_checkpoint_state(model_path) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path)
對(duì)于固化網(wǎng)絡(luò),網(wǎng)上有很多的介紹。之所以再介紹,還是由于是用了別人的網(wǎng)絡(luò)而不是自己的網(wǎng)絡(luò)遇到的坑。在固化時(shí)候我們需要知道輸出tensor的名字,而再恢復(fù)的時(shí)候我們需要知道placeholder的名字。但是,如果網(wǎng)絡(luò)復(fù)雜或者別人的網(wǎng)絡(luò)命名比較復(fù)雜,或者name=,根本就沒(méi)有自己命名而用的系統(tǒng)自定義的,這樣捋起來(lái)還是比較費(fèi)勁的。當(dāng)時(shí)在網(wǎng)上查找的一些方法,像打印整個(gè)網(wǎng)絡(luò)變量的方法(先不管輸出的網(wǎng)路名稱,甚至隨便起一個(gè)名字,先固化好pb文件,然后對(duì)pb文件進(jìn)行讀取,最后打印操作的名字:
graph = tf.get_default_graph() input_graph_def = graph.as_graph_def() output_graph_def = graph_util.convert_variables_to_constants( sess, input_graph_def, ['cls_score/cls_score', 'cls_prob'] # We split on comma for convenience ) with tf.gfile.GFile(output_graph, "wb") as f: f.write(output_graph_def.SerializeToString()) print ('開(kāi)始打印節(jié)點(diǎn)名字') for op in graph.get_operations(): print(op.name) print("%d ops in the final graph." % len(output_graph_def.node))
代碼一
這樣盡然也能打印出來(lái)(盡管輸出名字是隨便命名的)。但是打印出來(lái)的是所有的節(jié)點(diǎn)的名字,簡(jiǎn)直不要太多。這樣找的話,一方面可能找不對(duì),另一方面也太費(fèi)事。
那么怎么辦?答案簡(jiǎn)單的讓我也很無(wú)語(yǔ)。其實(shí),對(duì)ckpt進(jìn)行數(shù)據(jù)恢復(fù)的時(shí)候,直接打印輸出的tensor名字就可以。比如說(shuō)在saver以及placeholder定義的時(shí)候:output = model.yolo_inference(images, config.num_anchors / 3, config.num_classes, is_training),我們?cè)诤竺娓痪洌簆rint output,從打印出來(lái)的信息即可查看。placeholder的查看方法同樣如此。
對(duì)網(wǎng)絡(luò)進(jìn)行固化:
代碼:
input_image_shape = tf.placeholder(dtype = tf.int32, shape = (2,)) input_image = tf.placeholder(shape = [None, 416, 416, 3], dtype = tf.float32) predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path) boxes, scores, classes = predictor.predict(input_image, input_image_shape) sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) saver = tf.train.Saver() ckpt = tf.train.get_checkpoint_state(model_path) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) # 采用meta 結(jié)構(gòu)加載,不需要知道網(wǎng)絡(luò)結(jié)構(gòu) # saver = tf.train.import_meta_graph(model_path, clear_devices=True) # 這里的model_path是model.ckpt.meta文件的全路徑 # ckpt_model_path 是保存模型的文件夾路徑 # saver.restore(sess, tf.train.latest_checkpoint(ckpt_model_path)) graph = tf.get_default_graph() input_graph_def = graph.as_graph_def() output_graph_def = graph_util.convert_variables_to_constants( sess, input_graph_def, ['concat_11','concat_12','concat_13'] # We split on comma for convenience ) # # Finally we serialize and dump the output graph to the filesystem with tf.gfile.GFile(output_graph, "wb") as f: f.write(output_graph_def.SerializeToString())
由于固化的時(shí)候是需要先恢復(fù)ckpt網(wǎng)絡(luò)的,所以還是在restore前寫(xiě)了placeholder和輸出tensor的定義(需要注點(diǎn)意的是,我們保存的ckpt文件是訓(xùn)練階段的graph和變量等,其inference輸出和最終predict的輸出的Tensor不一樣,因此predict與inference的輸出相比,還包括了一些后處理,比如說(shuō)nms等等,只有這些后處理也是TensorFlow框架內(nèi)的方法寫(xiě)的,才能使最終形成的pb文件能夠做到輸入一張圖片,直接輸出最終結(jié)果。因此,對(duì)于目標(biāo)檢測(cè)任務(wù),把后處理任務(wù)也交由TensorFlow內(nèi)的api來(lái)實(shí)現(xiàn),可免去夸平臺(tái)讀取pb文件后仍然需要重新進(jìn)行后處理等相關(guān)程序的編寫(xiě)帶來(lái)的不必要麻煩)。然后結(jié)合保存變量的那個(gè)文件(ckpt),將變量恢復(fù)到inference過(guò)程所需的變量數(shù)據(jù)(predict包括inference和eval兩個(gè)過(guò)程,訓(xùn)練過(guò)程只有inference和loss過(guò)程參與,而預(yù)測(cè)過(guò)程多了一個(gè)后處理eval過(guò)程,eval過(guò)程無(wú)變量。這樣在生成pb文件的時(shí)候也把后處理eval固化進(jìn)去。喂給網(wǎng)絡(luò)數(shù)據(jù),即可得到輸出tensor。
由于有讀者在此問(wèn)到了還是沒(méi)有弄明白'concat_11','concat_12','concat_13'是如何得來(lái)的,我在這里就在詳細(xì)說(shuō)一下:
是這樣的,在我們恢復(fù)網(wǎng)絡(luò)的時(shí)候肯定需要知道saver這個(gè)對(duì)象的,在這里介紹兩種方法生成這個(gè)對(duì)象的方法。
一:
saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
其中meta_graph_location就是保存模型時(shí)的.meta文件的路徑。保存后有四個(gè)文件(checkpoint、.index、.data-00000-of-00001和.meta文件)。.meta文件就是整個(gè)TensorFlow的結(jié)構(gòu)圖。
二:
saver = tf.train.Saver()
本文采用的是第二種方法(上面已經(jīng)有詳細(xì)的代碼),由于這種方法得到的saver對(duì)象,他不知道具體圖是什么樣的,因此在恢復(fù)前我有用如下代碼
predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path) boxes, scores, classes = predictor.predict(input_image, input_image_shape)
把整個(gè)結(jié)構(gòu)又加載了一遍。如果采用第一種方法,是不需要在重寫(xiě)這兩行代碼的。
我們要的就是 boxes, scores, classes這三個(gè)tensor的結(jié)果,并且想知道他們?nèi)齻€(gè)tensor的名字。你直接利用print(boxes, scores, classes)打印出來(lái)這三個(gè)tensor就會(huì)出來(lái)這三個(gè)tensor具體信息(包括名字,和shape,dtype等)。這個(gè)只是利用第二種方法得到saver對(duì)象,然后恢復(fù)ckpt文件,不涉及到固化pb文件問(wèn)題。固化pb文件是需要知道這三個(gè)tensor的名字,所以需要打印看一下。
如果說(shuō),我只拿到了保存后的四個(gè)文件(checkpoint、.index、.data-00000-of-00001和.meta文件),其相應(yīng)用代碼寫(xiě)成的結(jié)構(gòu)圖不清楚,比如說(shuō)利用這兩行代碼:
predictor = yolo_predictor(config.obj_threshold, config.nms_threshold, config.classes_path, config.anchors_path) boxes, scores, classes = predictor.predict(input_image, input_image_shape)
畫(huà)出的結(jié)構(gòu)圖是什么樣的,我不知道。那么,想要知道具體的placehold和輸出tensor的名字,那只能通過(guò)代碼一中,打印出所有的OP操作節(jié)點(diǎn),然后進(jìn)行人工遍歷了。
讀取pb文件:
代碼:
def pb_detect(image_path, pb_model_path): os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_index image = Image.open(image_path) resize_image = letterbox_image(image, (416, 416)) image_data = np.array(resize_image, dtype = np.float32) image_data /= 255. image_data = np.expand_dims(image_data, axis = 0) with tf.Graph().as_default(): output_graph_def = tf.GraphDef() with open(pb_model_path, "rb") as f: output_graph_def.ParseFromString(f.read()) tf.import_graph_def(output_graph_def, name="") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) input_image_tensor = sess.graph.get_tensor_by_name("Placeholder_1:0") input_image_tensor_shape = sess.graph.get_tensor_by_name("Placeholder:0") # 定義輸出的張量名稱 #output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0") boxes = sess.graph.get_tensor_by_name("concat_11:0") scores = sess.graph.get_tensor_by_name("concat_12:0") classes = sess.graph.get_tensor_by_name("concat_13:0") # 讀取測(cè)試圖片 # 測(cè)試讀出來(lái)的模型是否正確,注意這里傳入的是輸出和輸入節(jié)點(diǎn)的tensor的名字(需要在名字后面加:0),不是操作節(jié)點(diǎn)的名字 out_boxes, out_scores, out_classes= sess.run([boxes,scores,classes], feed_dict={ input_image_tensor: image_data, input_image_tensor_shape: [image.size[1], image.size[0]] })
可以看到讀取pb文件只需要比恢復(fù)ckpt文件容易的多,直接將placeholder的名字獲取到,將數(shù)據(jù)輸入恢復(fù)的網(wǎng)絡(luò),以及讀取輸出即可。
小記:
有可能是TensorFlow版本更新或者其他原因,在后來(lái)工作中加載pb文件是報(bào)錯(cuò)了:
ValueError: Fetch argument <tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024) dtype=float32> cannot be interpreted as a Tensor. (tf.Tensor 'shuffle_batch:0' shape=(1, 300, 1024), dtype=float32) is not an element of this graph.)
將上面讀取pb文件的代碼with tf.Graph().as_default():改成
global graph graph = tf.get_default_graph() with graph.as_default():
以上是“TensorFlow如何將ckpt文件固化成pb文件”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計(jì)公司行業(yè)資訊頻道!
另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務(wù)器15元起步,三天無(wú)理由+7*72小時(shí)售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、高防服務(wù)器、香港服務(wù)器、美國(guó)服務(wù)器、虛擬主機(jī)、免備案服務(wù)器”等云主機(jī)租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡(jiǎn)單易用、服務(wù)可用性高、性價(jià)比高”等特點(diǎn)與優(yōu)勢(shì),專為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應(yīng)用場(chǎng)景需求。
分享題目:TensorFlow如何將ckpt文件固化成pb文件-創(chuàng)新互聯(lián)
當(dāng)前路徑:http://jinyejixie.com/article8/ddeiop.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供App設(shè)計(jì)、Google、搜索引擎優(yōu)化、企業(yè)建站、品牌網(wǎng)站建設(shè)、微信小程序
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容
移動(dòng)網(wǎng)站建設(shè)知識(shí)