當前位置:網站首頁>Tensor Flow PB文件量化到TFLITE

Tensor Flow PB文件量化到TFLITE

2022-01-28 01:27:09 17歲boy想當攻城獅

代碼非常簡單,相關代碼都有注釋,關於representative_dataset的作用可以參考這篇文章:Tensor Flow量化裏representative_dataset參數是什麼意思?_17歲boy的博客-CSDN博客

import tensorflow as tf
import io
import PIL
import numpy as np

def rep():
    #需要是驗證集的數據源
    record_iterator = tf.python_io.tf_record_iterator(path='/home/zhihao/models/research/slim/ci_data/cifar10_train.tfrecord')
    count = 0
    #將圖像從protobu取出來量化成數組
    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)
        #這裏是你存放圖像數據的消息協議名
        image_stream = io.BytesIO(example.features.feature['image/encoded'].bytes_list.value[0])
        image = PIL.Image.open(image_stream)
        #這裏將它固定量化成96x96的數組大小,這樣方便優化
        image = image.resize((96,96))
        #量化,L=灰度圖,1個bit錶示三個像素點
        image = image.convert('L')
        array = np.array(image)
        array = np.expand_dims(array,axis=2)
        array = np.expand_dims(array,axis=0)
        array = ((array / 127.5) - 1.0).astype(np.float32)
        yield([array])
        count += 1
        #最大量化三百張
        if count > 300:
            break

#你的PB文件,這個文件要是包含神經網絡權重的PB文件
converter = tf.lite.TFLiteConverter.from_frozen_graph('/home/zhihao/work/freezed_cifarnet.pb',['input'],['MobilenetV1/Predictions/Reshape_1'])
converter.inference_input_type = tf.lite.constants.INT8
converter.inference_output_type = tf.lite.constants.INT8
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = rep

#量化並保存
tflite_quant_model = converter.convert()
open("test.tflite","wb").write(tflite_quant_model)

版權聲明
本文為[17歲boy想當攻城獅]所創,轉載請帶上原文鏈接,感謝
https://cht.chowdera.com/2022/01/202201280127094455.html

隨機推薦