From RidgeRun Developer Wiki
#!/usr/bin/env python3
"""
This module takes a Caffe model file and converts it to a TensorRT engine
"""
import argparse
import tensorrt as trt
# Default value of the --max_batch_size command line option.
DEFAULT_MAX_BATCH_SIZE = 32
# Default value of the --engine_name command line option, i.e. the path the
# serialized engine is written to when no name is given.
DEFAULT_ENGINE_NAME = "engine.trt"
def parse_args(argv=None):
    """Parse the command line input arguments.

    Parameters
    ----------
    argv: optional list of argument strings to parse. When omitted (None),
        argparse falls back to the process command line, which is what the
        script entry point relies on.

    Returns
    -------
    The argparse namespace with the attributes max_batch_size, dtype,
    max_workspace_size (None when the option is not given), output_name,
    deploy_file, model_file and engine_name.
    """
    parser = argparse.ArgumentParser(
        description='Convert an Caffe model to a TensorRT engine')
    parser.add_argument(
        '--max_batch_size',
        metavar='MAX_BATCH',
        type=int,
        help='Maximum batch size that the engine will process',
        default=DEFAULT_MAX_BATCH_SIZE)
    parser.add_argument(
        '--dtype',
        metavar='DTYPE',
        type=str,
        help='Engine data type',
        default='FP32',
        # A tuple (not a set) keeps the order of the choices stable in the
        # "invalid choice" message argparse prints.
        choices=('FP32', 'FP16', 'INT8'))
    parser.add_argument(
        '--max_workspace_size',
        metavar='MAX_WORKSPACE',
        type=int,
        help='Max workspace size')
    parser.add_argument(
        '--output_name',
        metavar='OUTPUT_NAME',
        type=str,
        help='Output layer name',
        required=True)
    parser.add_argument(
        '--deploy_file',
        metavar='DEPLOY_NAME',
        type=str,
        help='Plain text prototxt file used to define the model',
        required=True)
    parser.add_argument(
        '--model_file',
        metavar='MODEL_NAME',
        type=str,
        help='Binary prototxt Caffe model that contains the weights',
        required=True)
    parser.add_argument(
        '--engine_name',
        metavar='ENGINE_NAME',
        type=str,
        help='Engine Name',
        default=DEFAULT_ENGINE_NAME)
    return parser.parse_args(argv)
def convert_caffe_to_trt(args):
    """
    Converts a Caffe model to a TensorRT engine and writes it to a file

    Parameters
    ----------
    args: a structure containing all needed parameters for parsing.
        The following attributes are read:
        * output_name: name looked up in the parsed model and marked as
          the network output
        * deploy_file: first file handed to the Caffe parser
        * model_file: second file handed to the Caffe parser
        * dtype: one of "FP32", "FP16" or "INT8"; selects the data type
          handed to the Caffe parser
        * max_batch_size: assigned to the builder's max_batch_size
        * max_workspace_size: assigned to the builder's
          max_workspace_size; a falsy value (None or 0) selects 1 << 20
        * engine_name: path of the file the serialized engine is written to

    Raises
    ------
    ValueError
        If args.dtype is not one of the supported data types.
    RuntimeError
        If the parser returns no result, the requested output name is not
        found in the parsed model, or no engine is built.
    """
    # Validate the requested type first so a bad value fails fast with a
    # clear message, before any TensorRT object is created.
    dtype_attr_names = {"FP32": "float32", "FP16": "float16", "INT8": "int8"}
    if args.dtype not in dtype_attr_names:
        raise ValueError(
            "Unsupported dtype {!r}; expected one of {}".format(
                args.dtype, ", ".join(sorted(dtype_attr_names))))
    datatype = getattr(trt, dtype_attr_names[args.dtype])

    logger = trt.Logger(trt.Logger.Severity.VERBOSE)
    builder = trt.Builder(logger)
    builder.max_batch_size = args.max_batch_size
    # Fall back to 1 << 20 when no workspace size was given (None or 0).
    builder.max_workspace_size = args.max_workspace_size or (1 << 20)

    flags = 0
    network = builder.create_network(flags)
    parser = trt.CaffeParser()
    model_tensors = parser.parse(
        args.deploy_file,
        args.model_file,
        network,
        datatype)
    # NOTE(review): assumes the parser signals failure by returning None --
    # confirm against the TensorRT version in use.
    if model_tensors is None:
        raise RuntimeError(
            "Unable to parse Caffe model ({!r}, {!r})".format(
                args.deploy_file, args.model_file))

    output_tensor = model_tensors.find(args.output_name)
    if output_tensor is None:
        raise RuntimeError(
            "Output layer {!r} not found in the parsed model".format(
                args.output_name))
    network.mark_output(output_tensor)

    engine = builder.build_cuda_engine(network)
    if not engine:
        raise RuntimeError("Unable to create engine")

    with open(args.engine_name, 'wb') as engine_file:
        engine_file.write(engine.serialize())
if __name__ == "__main__":
    # Script entry point: parse the command line options and hand them
    # straight to the converter.
    convert_caffe_to_trt(parse_args())