# Script: uff_to_trt.py
#
# From the RidgeRun Developer Wiki
#!/usr/bin/env python3

"""
This module takes a UFF model file and converts it to a TensorRT engine.
"""

import argparse
import tensorrt as trt

# Default for --max_batch_size: the maximum batch size the built engine
# will accept.
DEFAULT_MAX_BATCH_SIZE = 32
# Default for --engine_name: the file the serialized engine is written to.
DEFAULT_ENGINE_NAME = "engine.trt"


def parse_args(argv=None):
    """Parse the command line input arguments.

    Parameters
    ----------
    argv: list of str, optional
        Arguments to parse. When None (the default) argparse reads
        ``sys.argv[1:]``, so existing callers behave as before.

    Returns
    -------
    argparse.Namespace
        The attributes consumed by ``convert_uff_to_trt``.
    """

    parser = argparse.ArgumentParser(
        description='Convert an UFF model to a TensorRT engine')
    parser.add_argument(
        '--max_batch_size',
        metavar='MAX_BATCH',
        type=int,
        help='Maximum batch size that the engine will process',
        default=DEFAULT_MAX_BATCH_SIZE)
    parser.add_argument('--width', metavar='WIDTH', type=int,
                        help='Network input width', required=True)
    parser.add_argument('--height', metavar='HEIGHT', type=int,
                        help='Network input height', required=True)
    parser.add_argument('--channels', metavar='CHANNELS', type=int,
                        help='Network input channels', required=True)
    # No default: None means "keep TensorRT's own workspace limit".
    parser.add_argument(
        '--max_workspace_size',
        metavar='MAX_WORKSPACE',
        type=int,
        help='Max workspace size')
    parser.add_argument(
        '--order',
        metavar='ORDER',
        type=str,
        help='Order, available options are NHWC and NCHW',
        required=True,
        # A tuple (not a set) keeps the listed choices in a stable order in
        # argparse's "invalid choice" error message across runs.
        choices=("NHWC", "NCHW"))
    parser.add_argument('--input_name', metavar='INPUT_NAME', type=str,
                        help='Input layer name', required=True)
    parser.add_argument('--output_name', metavar='OUTPUT_NAME', type=str,
                        help='Output layer name', required=True)
    parser.add_argument('--graph_name', metavar='GRAPH_NAME', type=str,
                        help='Graph name', required=True)
    parser.add_argument('--engine_name', metavar='ENGINE_NAME', type=str,
                        help='Engine Name', default=DEFAULT_ENGINE_NAME)

    return parser.parse_args(argv)


def convert_uff_to_trt(args):
    """
    Converts a UFF model file to a serialized TensorRT engine file.

    Parameters
    ----------
    args: a structure containing all needed parameters for parsing.
        It must contain:
        * width
        * height
        * channels
        * order: "NHWC" or "NCHW", the input layout of the original
          framework model
        * input_name
        * output_name
        * graph_name: path of the UFF model file to parse
        Also optional (a missing attribute or None selects the default):
        * max_batch_size: defaults to DEFAULT_MAX_BATCH_SIZE
        * max_workspace_size: when missing, None or 0, TensorRT's default
          workspace limit is kept
        * engine_name: output path, defaults to DEFAULT_ENGINE_NAME

    Raises
    ------
    RuntimeError
        If ``order`` is unknown, the UFF file cannot be parsed, or the
        engine cannot be built.
    """
    # Validate the cheap argument before creating any TensorRT objects.
    if args.order == "NHWC":
        trtorder = trt.UffInputOrder.NHWC
    elif args.order == "NCHW":
        trtorder = trt.UffInputOrder.NCHW
    else:
        raise RuntimeError("Unknown order: " + str(args.order))

    # Honour the "optional" contract for callers that build their own
    # namespace instead of using parse_args().
    max_batch_size = getattr(args, "max_batch_size", None)
    if max_batch_size is None:
        max_batch_size = DEFAULT_MAX_BATCH_SIZE
    max_workspace_size = getattr(args, "max_workspace_size", None)
    engine_name = getattr(args, "engine_name", None)
    if engine_name is None:
        engine_name = DEFAULT_ENGINE_NAME

    logger = trt.Logger(trt.Logger.Severity.VERBOSE)
    builder = trt.Builder(logger)
    # The batch limit of an implicit-batch network is a builder property.
    builder.max_batch_size = max_batch_size

    builder_config = builder.create_builder_config()
    if max_workspace_size:
        # build_engine(network, config) takes the workspace limit from the
        # builder config, so it must be set here rather than on the builder.
        builder_config.max_workspace_size = max_workspace_size

    # No creation flags: an implicit-batch network, which is what the UFF
    # parser works with.
    network = builder.create_network(0)

    parser = trt.UffParser()
    # register_input expects the shape in CHW even when the framework model
    # is NHWC; ``trtorder`` tells the parser the original layout.
    dims = trt.Dims([args.channels, args.height, args.width])
    parser.register_input(args.input_name, dims, trtorder)
    parser.register_output(args.output_name)

    # parse() returns True on success.
    if not parser.parse(args.graph_name, network):
        raise RuntimeError("Unable to parse: " + str(args.graph_name))

    engine = builder.build_engine(network, builder_config)
    if not engine:
        raise RuntimeError("Unable to create engine")

    with open(engine_name, 'wb') as engine_file:
        engine_file.write(engine.serialize())


if __name__ == "__main__":
    # Hand the parsed CLI namespace straight to the converter; no
    # module-level global is needed to hold it.
    convert_uff_to_trt(parse_args())