TensorFlow Saver
只加载权值 (load the weights only)
# Restore variable values only; the graph must already be built in the
# default graph before the Saver is constructed.
weight_saver = tf.train.Saver()
with tf.Session() as sess:
    weight_saver.restore(sess, "./checkpoint/finetune")
加载图结构和权值 (load both the graph structure and the weights)
# Rebuild the graph structure from the .meta file, then load the weights
# into it inside a session.
saver = tf.train.import_meta_graph('./checkpoint/finetune.meta')
with tf.Session() as sess:
    saver.restore(sess, "./checkpoint/finetune")
加载部分权值 (load a subset of the weights)
e.g. after switching the optimizer from RMSProp to Adam, restoring all
variables 会加载权值失败, because the optimizer slot variables differ.
# Restore only the variables present in the checkpoint: optimizer slot
# variables (Adam moments, beta*_power) are skipped so a checkpoint trained
# with a different optimizer still loads.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    variables_to_restore = []
    for candidate in tf.global_variables():
        # For MomentumOptimizer checkpoints, filter on 'Momentum' instead.
        if 'Adam' in candidate.name or '_power' in candidate.name:
            print("Ignore ", candidate.name)
            continue
        variables_to_restore.append(candidate)
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, "./checkpoint/pre_model/finetune")
冻结权值, 保存成 PB (freeze the weights and save as a .pb file)
# Freeze the graph: name the prediction op so it can be referenced as the
# output node, then fold the current variable values into graph constants.
# `pred` and `sess` come from the surrounding training script (not shown).
from tensorflow.python.framework import graph_util
# Output node (argmax over the class dimension), exported as 'output_cls'.
pred_classes = tf.argmax(pred, axis=1, name="output_cls")
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output_cls'])
跑 PB 文件 (run inference from the .pb file)
# Run inference from a frozen .pb file.
from tensorflow.python.platform import gfile

sess = tf.Session()
with gfile.FastGFile('./checkpoint/model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')  # import the graph definition
# An initialization step is still required.
sess.run(tf.global_variables_initializer())
# Inputs: tensor names must match the placeholder names in the exported graph.
input_ = sess.graph.get_tensor_by_name('input:0')
# NOTE(review): 'kepp_prob' looks like a typo for 'keep_prob' — this string
# must match the dropout placeholder's actual name in the .pb; verify before
# changing either side.
prob_ = sess.graph.get_tensor_by_name('kepp_prob:0')
rslt = sess.graph.get_tensor_by_name('output_cls:0')
ret = sess.run(rslt, feed_dict={input_: mnist.test.images[:2], prob_: 1.0})
获得 Weight 的值 (read weight values)
# --- Checkpoint-backed graph: list trainables, fetch tensors by name ---
import tensorflow as tf

for tv in tf.trainable_variables():
    print(tv.name)
b = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/bias:0")
w = tf.get_default_graph().get_tensor_by_name("generate/resnet_stack/weight:0")

# --- Frozen .pb graph: enumerate operations and their output tensors ---
# https://stackoverflow.com/questions/35336648/list-of-tensor-names-in-graph-in-tensorflow
sess = tf.Session()
op = sess.graph.get_operations()
# Print the result explicitly: a bare expression only echoes in a REPL,
# and the original sample-output lines were not valid Python.
print([m.values() for m in op][1])
# Example output:
# (<tf.Tensor 'conv1/weights:0' shape=(4, 4, 3, 32) dtype=float32_ref>,)
# Dump every tensor whose op name contains 'weights' from a frozen .pb
# as a numpy array.
with tf.Session() as sess:
    with gfile.FastGFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
    for node in sess.graph.get_operations():
        if 'weights' in node.name:
            print(node.name)
            # First output tensor of the op, evaluated to a numpy array.
            val_np = sess.run(node.values()[0])
# --- .tflite file: walk the flatbuffer to reach a tensor's raw buffer ---
# https://stackoverflow.com/questions/52111699/how-can-i-view-weights-in-a-tflite-file
from tflite import Model

# Use a context manager so the file handle is closed (the original
# `open(...).read()` leaked it).
with open('/path/to/mode.tflite', 'rb') as fh:
    buf = fh.read()
model = Model.Model.GetRootAsModel(buf, 0)
subgraph = model.Subgraphs(0)
# Check tensor.Name() to find the tensor_idx you want
tensor = subgraph.Tensors(tensor_idx)
buffer_idx = tensor.Buffer()
buffer = model.Buffers(buffer_idx)
# After that you'll be able to read the data by calling buffer.Data()
修改 Weight 的值 (modify weight values)
# Overwrite variable values in the live session via assign ops.
all_vars = tf.global_variables()
for var in all_vars:
    # NOTE(review): as written this assigns the SAME array `val` to every
    # variable, which can only work if all variables share its shape — in
    # practice `val` should be looked up per variable; confirm against the
    # caller that supplies it.
    sess.run(var.assign(val))  # val is numpy.ndarray
模型转换 (model conversion: checkpoint, pb, tflite)
import os
import sys

import numpy as np
import tensorflow as tf

# This conversion script targets Python 3.5 specifically. Checking only the
# minor version (the original `sys.version_info[1] == 5`) would also accept
# e.g. Python 2.5, so compare (major, minor) instead.
if sys.version_info[:2] != (3, 5):
    raise Exception('It must be Python 3.5')
# Silence TensorFlow's C++ logging ('3' = errors only).
# (`os` was used here but never imported in the original.)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# remove the "deprecated" warning messages
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
def save_pb(model_path, pb_path, quantize=False):
    """Freeze a checkpoint into a .pb file.

    @param: model_path='./checkpoint/quantize/model'  (checkpoint prefix to restore)
    @param: pb_path='./checkpoint/quantize/model.pb'  (output frozen-graph file)
    @param: quantize: when True, rewrite the graph for fake-quantized eval
        and freeze at the FakeQuant output node instead of the raw BiasAdd.
    """
    from tensorflow.python.framework import graph_util
    _inputs = tf.placeholder(tf.float32, [1, 112, 112, 3], name='inputs')
    # _is_training = True if is_quantize else tf.placeholder(tf.bool, name="is_training")
    build_ResNet(_inputs, False)  # builds the inference graph in the default graph
    if quantize:
        tf.contrib.quantize.create_eval_graph()
    device_count = {"GPU": 0}  # force CPU: no need to grab GPU memory for an export
    with tf.Session(config=tf.ConfigProto(device_count=device_count)) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, model_path)
        print("save model")
        output_node = ('resnet/fc/output/act_quant/FakeQuantWithMinMaxVars'
                       if quantize else 'resnet/fc/output/BiasAdd')
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, [output_node])
        # BUG FIX: write to the pb_path argument; the original ignored it and
        # wrote to model_path + ".pb" instead.
        with tf.gfile.FastGFile(pb_path, mode='wb') as f:
            f.write(constant_graph.SerializeToString())
def val_pb(pb_path):
    """Sanity-check a frozen .pb by running one forward pass on random data."""
    from tensorflow.python.platform import gfile

    tf.reset_default_graph()
    # Random input; fill in the actual shape the model expects.
    input_data = np.random.rand(...)
    print(input_data.shape)
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        with gfile.FastGFile(pb_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')
        _inputs = sess.graph.get_tensor_by_name('inputs:0')
        rslt = sess.graph.get_tensor_by_name('resnet/fc/output/BiasAdd:0')
        rslt_ = sess.run(rslt, feed_dict={_inputs: input_data})
        # print(rslt_)
def save_lite(pb_path, is_quantize=False):
    """Convert a frozen .pb into a .tflite file written next to it.

    r1.14
    https://www.tensorflow.org/api_docs/python/tf/lite/TFLiteConverter
    """
    print("INFO: save tflite")
    converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
        # converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        pb_path, ['inputs'], ['resnet/fc/output/BiasAdd'])
    if is_quantize:
        converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
        # BUG FIX: quantized_input_stats is keyed by input array name; this
        # graph's input is 'inputs' (see from_frozen_graph above), not
        # 'Placeholder'. (mean, std) = (0., 255.) dequantizes uint8 input.
        converter.quantized_input_stats = {'inputs': (0., 255.)}
    tflite_model = converter.convert()
    lite_path = pb_path.replace(".pb", ".tflite")
    # BUG FIX: close the output file (original open().write() leaked the handle).
    with open(lite_path, "wb") as f:
        f.write(tflite_model)
def val_lite(lite_path):
    """Load a .tflite model, print its I/O details, and run one inference
    on an all-ones (1, 112, 112, 3) float32 input."""
    print("INFO: verify tflite")
    interpreter = tf.contrib.lite.Interpreter(model_path=lite_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print("input_details:")
    print(input_details[0])
    print("output_details:")
    print(output_details[0])

    # Feed a constant all-ones image and run a single forward pass.
    inputs = np.ones((1, 112, 112, 3)).astype(np.float32)
    interpreter.set_tensor(input_details[0]['index'], inputs)
    interpreter.invoke()
    rslt_ = interpreter.get_tensor(output_details[0]['index'])
    print("rslt_ shape: {}, \n {}".format(rslt_.shape, rslt_))