#!/usr/bin/env python3
# From RidgeRun Developer Wiki
"""
This module takes a UFF model file and converts it to a TensorRT engine
"""
import argparse
import tensorrt as trt
# Value used for --max_batch_size when the option is not given
DEFAULT_MAX_BATCH_SIZE = 32
# Value used for --engine_name (the file the serialized engine is written to)
# when the option is not given
DEFAULT_ENGINE_NAME = "engine.trt"
def parse_args(argv=None):
    """
    Parse the command line input arguments

    Parameters
    ----------
    argv: optional list of argument strings to parse. When it is None
          (the default) argparse reads the process command line.

    Returns
    -------
    The argparse namespace holding max_batch_size, width, height,
    channels, max_workspace_size, order, input_name, output_name,
    graph_name and engine_name.
    """
    parser = argparse.ArgumentParser(
        description='Convert an UFF model to a TensorRT engine')

    parser.add_argument(
        '--max_batch_size',
        metavar='MAX_BATCH',
        type=int,
        help='Maximum batch size that the engine will process',
        default=DEFAULT_MAX_BATCH_SIZE)

    parser.add_argument('--width', metavar='WIDTH', type=int,
                        help='Network input width', required=True)
    parser.add_argument('--height', metavar='HEIGHT', type=int,
                        help='Network input height', required=True)
    parser.add_argument('--channels', metavar='CHANNELS', type=int,
                        help='Network input channels', required=True)

    # No default: the attribute is None when the option is omitted.
    parser.add_argument(
        '--max_workspace_size',
        metavar='MAX_WORKSPACE',
        type=int,
        help='Max workspace size')

    # A tuple (not a set) keeps the order of the choices stable in the
    # messages argparse prints for an invalid value.
    parser.add_argument(
        '--order',
        metavar='ORDER',
        type=str,
        help='Order, available options are NHWC and NCHW',
        required=True,
        choices=("NHWC", "NCHW"))

    parser.add_argument('--input_name', metavar='INPUT_NAME', type=str,
                        help='Input layer name', required=True)
    parser.add_argument('--output_name', metavar='OUTPUT_NAME', type=str,
                        help='Output layer name', required=True)
    parser.add_argument('--graph_name', metavar='GRAPH_NAME', type=str,
                        help='Graph name', required=True)
    parser.add_argument('--engine_name', metavar='ENGINE_NAME', type=str,
                        help='Engine Name', default=DEFAULT_ENGINE_NAME)

    return parser.parse_args(argv)
def convert_uff_to_trt(args):
    """
    Converts a UFF model file to a TensorRT engine and writes the
    serialized engine to a file

    Parameters
    ----------
    args: a structure containing all needed parameters for parsing
          (for example the namespace returned by argparse).
          It must contain:
              * width
              * height
              * channels
              * order: either "NHWC" or "NCHW"
              * input_name
              * output_name
              * graph_name: handed to the UFF parser as the model to parse
          Also optional are:
              * max_batch_size: DEFAULT_MAX_BATCH_SIZE when missing
              * max_workspace_size: not configured when missing, None or 0
              * engine_name: output file, DEFAULT_ENGINE_NAME when missing

    Raises
    ------
    RuntimeError
        If the order is unknown, the graph can not be parsed or the
        engine can not be built.
    """
    # The docstring promises these are optional, so do not require the
    # attributes to exist on args.
    max_batch_size = getattr(args, "max_batch_size", DEFAULT_MAX_BATCH_SIZE)
    max_workspace_size = getattr(args, "max_workspace_size", None)
    engine_name = getattr(args, "engine_name", DEFAULT_ENGINE_NAME)

    logger = trt.Logger(trt.Logger.Severity.VERBOSE)
    builder = trt.Builder(logger)
    builder.max_batch_size = max_batch_size

    builder_config = builder.create_builder_config()
    if max_workspace_size:
        # The engine is built from builder_config below, so the workspace
        # limit has to be set on the config; a value set on the builder
        # itself is not picked up by build_engine(network, config).
        builder_config.max_workspace_size = max_workspace_size

    # No network creation flags are requested.
    network = builder.create_network(0)

    parser = trt.UffParser()

    # NOTE(review): the dimensions are always passed as
    # [width, height, channels], whatever the order is. Confirm this is
    # the layout UffParser.register_input expects for both orders.
    dims = trt.Dims([args.width, args.height, args.channels])

    if args.order == "NHWC":
        trtorder = trt.UffInputOrder.NHWC
    elif args.order == "NCHW":
        trtorder = trt.UffInputOrder.NCHW
    else:
        raise RuntimeError("Unknown order: " + str(args.order))

    parser.register_input(args.input_name, dims, trtorder)
    parser.register_output(args.output_name)

    # parse() reports success, so a falsy result is the failure case.
    parsed = parser.parse(args.graph_name, network)
    if not parsed:
        raise RuntimeError("Unable to parse: " + args.graph_name)

    engine = builder.build_engine(network, builder_config)
    if not engine:
        raise RuntimeError("Unable to create engine")

    host_memory = engine.serialize()
    with open(engine_name, 'wb') as f:
        f.write(host_memory)
# Script entry point: read the command line and run the conversion.
if __name__ == "__main__":
    convert_uff_to_trt(parse_args())