#!/usr/bin/env python3
# Script: caffe_to_trt.py
# From the RidgeRun Developer Wiki

"""
This modules takes a caffe model file and converts it to a TensorRT engine
"""

import argparse
import tensorrt as trt

DEFAULT_MAX_BATCH_SIZE = 32
DEFAULT_ENGINE_NAME = "engine.trt"


def parse_args():
    """Parse the command line input arguments.

    Returns
    -------
    argparse.Namespace
        Holds: max_batch_size (int), dtype ('FP32'|'FP16'|'INT8'),
        max_workspace_size (int or None when not given), output_name,
        deploy_file, model_file and engine_name (str).
    """
    parser = argparse.ArgumentParser(
        description='Convert a Caffe model to a TensorRT engine')
    parser.add_argument(
        '--max_batch_size',
        metavar='MAX_BATCH',
        type=int,
        help='Maximum batch size that the engine will process',
        default=DEFAULT_MAX_BATCH_SIZE)
    parser.add_argument(
        '--dtype',
        metavar='DTYPE',
        type=str,
        help='Engine data type',
        default='FP32',
        # List (not set) so the order shown in --help is deterministic.
        choices=['FP32', 'FP16', 'INT8'])
    parser.add_argument(
        '--max_workspace_size',
        metavar='MAX_WORKSPACE',
        type=int,
        help='Max workspace size')
    parser.add_argument('--output_name', metavar='OUTPUT_NAME', type=str,
                        help='Output layer name', required=True)
    parser.add_argument(
        '--deploy_file',
        metavar='DEPLOY_NAME',
        type=str,
        help='Plain text prototxt file used to define the model',
        required=True)
    parser.add_argument(
        '--model_file',
        metavar='MODEL_NAME',
        type=str,
        help='Binary prototxt Caffe model that contains the weights',
        required=True)
    parser.add_argument('--engine_name', metavar='ENGINE_NAME', type=str,
                        help='Engine Name', default=DEFAULT_ENGINE_NAME)

    return parser.parse_args()


def convert_caffe_to_trt(args):
    """
    Convert a Caffe model file to a serialized TensorRT engine file.

    Parameters
    ----------
    args: a structure containing all needed parameters for parsing.
        It must contain:
        * output_name: name of the layer to mark as the network output
        * deploy_file: plain text prototxt file used to define the model
        * model_file: binary prototxt Caffe model that contains the weights
        Also optional are:
        * dtype: one of 'FP32', 'FP16' or 'INT8'
        * max_batch_size
        * max_workspace_size (may be None; defaults to 1 MiB)
        * engine_name: path the serialized engine is written to

    Raises
    ------
    ValueError
        If ``args.dtype`` is not a supported type or ``args.output_name``
        does not exist in the parsed network.
    RuntimeError
        If the Caffe model cannot be parsed or the engine cannot be built.
    """
    logger = trt.Logger(trt.Logger.Severity.VERBOSE)

    builder = trt.Builder(logger)
    builder.max_batch_size = args.max_batch_size
    # Fall back to 1 MiB of scratch space when the user did not specify one.
    builder.max_workspace_size = args.max_workspace_size or (1 << 20)

    # Map the CLI dtype string to the TensorRT enum; fail loudly on an
    # unsupported value instead of hitting an unbound-local NameError below.
    dtype_map = {
        'FP32': trt.float32,
        'FP16': trt.float16,
        'INT8': trt.int8,
    }
    try:
        datatype = dtype_map[args.dtype]
    except KeyError:
        raise ValueError("Unsupported dtype: {!r}".format(args.dtype)) from None

    # flags == 0: implicit-batch network, matching the max_batch_size usage.
    network = builder.create_network(0)

    parser = trt.CaffeParser()

    model_tensors = parser.parse(
        args.deploy_file,
        args.model_file,
        network,
        datatype)
    if model_tensors is None:
        # parse() returns None on failure; without this check the find()
        # call below would die with an opaque AttributeError.
        raise RuntimeError("Unable to parse the Caffe model")

    output_tensor = model_tensors.find(args.output_name)
    if output_tensor is None:
        raise ValueError(
            "Output layer {!r} not found in the network".format(args.output_name))
    network.mark_output(output_tensor)

    engine = builder.build_cuda_engine(network)
    if not engine:
        raise RuntimeError("Unable to create engine")

    host_memory = engine.serialize()

    with open(args.engine_name, 'wb') as f:
        f.write(host_memory)


if __name__ == "__main__":
    args = parse_args()
    convert_caffe_to_trt(args)