You can refer to [https://www.ridgerun.com/store/Deep-Learning-Models-and-Binaries-c33344794 R2Inference Model Zoo] for pre-trained models suitable for evaluating GstInference.
== Create a model using saved weights from a .ckpt file ==
Some resources provide pre-trained weights as a single .ckpt file. This is the output of an older version of the '''saver''' object and is equivalent to the .ckpt-data file produced by the current '''saver'''. In this example, a .ckpt file for the '''InceptionV1''' graph is provided and needs to be converted to a model suitable for GstInference. First, obtain the .ckpt file:
<syntaxhighlight lang="bash">
mkdir gstinference_tf_model
cd gstinference_tf_model
wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz
tar -xzf inception_v1_2016_08_28.tar.gz
</syntaxhighlight>
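If you prefer to script the extraction step, the following is a minimal sketch in Python that unpacks the archive and reports any .ckpt files it contains. The function name <code>extract_checkpoint</code> is hypothetical, and the sketch assumes the archive has already been downloaded into the current directory.

```python
import tarfile
from pathlib import Path


def extract_checkpoint(archive_path, dest_dir):
    """Extract a .tar.gz archive and return the paths of any .ckpt files found."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as archive:
        archive.extractall(dest)
    # Search recursively in case the archive stores files in a subdirectory.
    return sorted(dest.rglob("*.ckpt"))


if __name__ == "__main__":
    # Assumes the archive was already downloaded (for example with wget).
    archive = Path("inception_v1_2016_08_28.tar.gz")
    if archive.exists():
        for ckpt in extract_checkpoint(archive, "."):
            print("Found checkpoint:", ckpt)
```

An empty result means the archive did not contain a .ckpt file, so the later steps would have nothing to load.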
Then you need to associate the .ckpt file with the graph object. In the case of the '''InceptionV1''' graph, the graph definition can be obtained from within the TensorFlow resources:
<syntaxhighlight lang="python">
#! /usr/bin/env python3
import os

import tensorflow as tf
from tensorflow.contrib.slim.nets import inception

slim = tf.contrib.slim

def unpack(name, image_size, num_classes):
    with tf.Graph().as_default():
        # Input placeholder; the name "input" identifies the input node later
        image = tf.placeholder("float", [1, image_size, image_size, 3], name="input")
        with slim.arg_scope(inception.inception_v1_arg_scope()):
            logits, _ = inception.inception_v1(image, num_classes, is_training=False, spatial_squeeze=False)
        probabilities = tf.nn.softmax(logits)
        # Function that loads the InceptionV1 variables from the .ckpt file
        init_fn = slim.assign_from_checkpoint_fn('inception_v1.ckpt', slim.get_model_variables('InceptionV1'))
        with tf.Session() as sess:
            init_fn(sess)
            # Make sure the output folder exists, then save with the current saver
            os.makedirs("output", exist_ok=True)
            saver = tf.train.Saver(tf.global_variables())
            saver.save(sess, "output/"+name)

unpack('inception-v1', 224, 1001)
</syntaxhighlight>
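Before freezing the graph, it can help to confirm that the saver wrote its files. The sketch below assumes the TensorFlow 1.x V2 checkpoint layout with a single data shard (<code>.meta</code>, <code>.index</code>, <code>.data-00000-of-00001</code> and a <code>checkpoint</code> state file); the helper name <code>missing_checkpoint_files</code> is hypothetical, and the suffixes may differ on other TensorFlow versions.

```python
from pathlib import Path

# Suffixes assumed for TensorFlow 1.x V2 checkpoints with a single data shard.
EXPECTED_SUFFIXES = (".meta", ".index", ".data-00000-of-00001")


def missing_checkpoint_files(output_dir, name):
    """Return the expected checkpoint files that are absent from output_dir."""
    out = Path(output_dir)
    expected = [out / (name + suffix) for suffix in EXPECTED_SUFFIXES]
    # The saver also writes a small "checkpoint" state file next to them.
    expected.append(out / "checkpoint")
    return [str(p) for p in expected if not p.exists()]


if __name__ == "__main__":
    missing = missing_checkpoint_files("output", "inception-v1")
    if missing:
        print("Missing:", missing)
    else:
        print("All checkpoint files present")
```

The <code>.meta</code> file is the one the freeze step imports, so it is the most important entry to check.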
Finally, you can freeze the graph with either [https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py freeze_graph.py] or the following Python 3 script:
<syntaxhighlight lang="python">
import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import graph_util

# Step 1
# Import the model metagraph
saver = tf.train.import_meta_graph('./output/inception-v1.meta', clear_devices=True)
# Use it as the default graph
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
# Now restore the variables
saver.restore(sess, "./inception_v1.ckpt")

# Step 2
# Find the output name
for op in graph.get_operations():
    print(op.name)

# Step 3
output_node_names = "InceptionV1/Logits/Predictions/Reshape_1"
output_graph_def = graph_util.convert_variables_to_constants(
    sess,  # The session
    input_graph_def,  # input_graph_def is useful for retrieving the nodes
    output_node_names.split(","))

# Step 4
# Output folder
output_fld = './'
# Output pb file name
output_model_file = 'inceptionv1_tensorflow.pb'
# Write the graph
graph_io.write_graph(output_graph_def, output_fld, output_model_file, as_text=False)
</syntaxhighlight>
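Step 2 prints every operation in the graph, which is a long list for InceptionV1. As a rough sketch, you could narrow it down by keeping only names that contain typical output keywords. The helper <code>candidate_output_nodes</code> and its default keywords are assumptions made for illustration, and the sample names below stand in for the real list.

```python
def candidate_output_nodes(op_names, keywords=("Predictions", "Softmax")):
    """Return op names containing any keyword, keeping the original order."""
    return [name for name in op_names if any(key in name for key in keywords)]


if __name__ == "__main__":
    # Sample names only; in the freeze script you would pass
    # [op.name for op in graph.get_operations()] instead.
    sample = [
        "input",
        "InceptionV1/InceptionV1/Conv2d_1a_7x7/Relu",
        "InceptionV1/Logits/Predictions/Reshape_1",
    ]
    for name in candidate_output_nodes(sample):
        print(name)
```

The filter only suggests candidates; you still need to pick the node that matches the model output, which in this example is <code>InceptionV1/Logits/Predictions/Reshape_1</code>.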
=Tools=