文章目錄
1、匯入模型
2、定義載入函式
3、定義批次載入函式
4、載入資料
5、定義資料預處理及訓練模型的一些超引數
6、定義資料增強模型
7、構建模型
7。1 構建多層感知器(MLP)
7。2 建立一個類似卷積層的patch層
7。3 檢視由patch層隨機生成的影象塊
7。4構建patch 編碼層( encoding layer)
7。5構建ViT模型
8、編譯、訓練模型
9、檢視執行結果
使用Transformer來提升模型的效能
最近幾年,Transformer體系結構已成為自然語言處理任務的實際標準,
但其在計算機視覺中的應用還受到限制。在視覺上,注意力要麼與卷積網路結合使用,
要麼用於替換卷積網路的某些元件,同時將其整體結構保持在適當的位置。2020年10月22日,谷歌人工智慧研究院發表一篇題為“An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”的文章。文章將影象切割成一個個影象塊,組成序列化的資料輸入Transformer執行影象分類任務。當對大量資料進行預訓練並將其傳輸到多箇中型或小型影象識別資料集(如ImageNet、CIFAR-100、VTAB等)時,與目前的卷積網路相比,Vision Transformer(ViT)獲得了出色的結果,同時所需的計算資源也大大減少。
這裡我們以ViT我模型,實現對資料CiFar10的分類工作,模型效能得到進一步的提升。
1、匯入模型
這裡使用了TensorFlow_addons模組,它實現了核心 TensorFlow 中未提供的新功能。
tensorflow_addons的安裝要注意與tf的版本對應關係,請參考:
https://github。com/tensorflow/addons。
安裝addons時要注意其版本與tensorflow版本的對應,具體關係以上這個連結有。
2、定義載入函式
3、定義批次載入函式
4、載入資料
把資料轉換為dataset格式
5、定義資料預處理及訓練模型的一些超引數
6、定義資料增強模型
預處理層是在模型訓練開始之前計算其狀態的層。他們在訓練期間不會得到更新。大多數預處理層為狀態計算實現了adapt()方法。
adapt(data, batch_size=None, steps=None, reset_state=True)該函式引數說明如下:
7、構建模型
7。1 構建多層感知器(MLP)
7。2 建立一個類似卷積層的patch層
7。3 檢視由patch層隨機生成的影象塊
執行結果
Image size: 72 X 72
Patch size: 6 X 6
Patches per image: 144
Elements per patch: 108
7。4構建patch 編碼層( encoding layer)
7。5構建ViT模型
該模型的處理流程如下圖所示
8、編譯、訓練模型
例項化類,執行模型
執行結果
Epoch 1/10
176/176 [==============================] - 68s 333ms/step - loss: 2。6394 - accuracy: 0。2501 - top-5-accuracy: 0。7377 - val_loss: 1。5331 - val_accuracy: 0。4580 - val_top-5-accuracy: 0。9092
Epoch 2/10
176/176 [==============================] - 58s 327ms/step - loss: 1。6359 - accuracy: 0。4150 - top-5-accuracy: 0。8821 - val_loss: 1。2714 - val_accuracy: 0。5348 - val_top-5-accuracy: 0。9464
Epoch 3/10
176/176 [==============================] - 58s 328ms/step - loss: 1。4332 - accuracy: 0。4839 - top-5-accuracy: 0。9210 - val_loss: 1。1633 - val_accuracy: 0。5806 - val_top-5-accuracy: 0。9616
Epoch 4/10
176/176 [==============================] - 58s 329ms/step - loss: 1。3253 - accuracy: 0。5280 - top-5-accuracy: 0。9349 - val_loss: 1。1010 - val_accuracy: 0。6112 - val_top-5-accuracy: 0。9572
Epoch 5/10
176/176 [==============================] - 58s 330ms/step - loss: 1。2380 - accuracy: 0。5626 - top-5-accuracy: 0。9411 - val_loss: 1。0212 - val_accuracy: 0。6400 - val_top-5-accuracy: 0。9690
Epoch 6/10
176/176 [==============================] - 58s 330ms/step - loss: 1。1486 - accuracy: 0。5945 - top-5-accuracy: 0。9520 - val_loss: 0。9698 - val_accuracy: 0。6602 - val_top-5-accuracy: 0。9718
Epoch 7/10
176/176 [==============================] - 58s 330ms/step - loss: 1。1208 - accuracy: 0。6060 - top-5-accuracy: 0。9558 - val_loss: 0。9215 - val_accuracy: 0。6724 - val_top-5-accuracy: 0。9790
Epoch 8/10
176/176 [==============================] - 58s 330ms/step - loss: 1。0643 - accuracy: 0。6248 - top-5-accuracy: 0。9621 - val_loss: 0。8709 - val_accuracy: 0。6944 - val_top-5-accuracy: 0。9768
Epoch 9/10
176/176 [==============================] - 58s 330ms/step - loss: 1。0119 - accuracy: 0。6446 - top-5-accuracy: 0。9640 - val_loss: 0。8290 - val_accuracy: 0。7142 - val_top-5-accuracy: 0。9784
Epoch 10/10
176/176 [==============================] - 58s 330ms/step - loss: 0。9740 - accuracy: 0。6615 - top-5-accuracy: 0。9666 - val_loss: 0。8175 - val_accuracy: 0。7096 - val_top-5-accuracy: 0。9806
313/313 [==============================] - 9s 27ms/step - loss: 0。8514 - accuracy: 0。7032 - top-5-accuracy: 0。9773
Test accuracy: 70。32%
Test top 5 accuracy: 97。73%
In [15]:
從結果看可以來看,測試精度已達70%,這是一個較大提升!
9、檢視執行結果
執行結果
作者 :吳茂貴,資深大資料和人工智慧技術專家,在BI、資料探勘與分析、資料倉庫、機器學習等領域工作超過20年!在基於Spark、TensorFlow、Pytorch、Keras等機器學習和深度學習方面有大量的工程實踐經驗。代表作有《深入淺出Embedding:原理解析與應用實踐》、《Python深度學習基於Pytorch》和《Python深度學習基於TensorFlow》。
——The End——
點選購買