How to migrate PyTorch to TensorRT



Introduction

This guide helps developers who have a PyTorch model and want to migrate it to TensorRT in order to take full advantage of NVIDIA hardware when running inference.

Installing requirements

  • pip3 install torch
  • pip3 install onnx
  • pip3 install onnxruntime
  • pip3 install pycuda

Process overview

First, the torch model needs to be exported to ONNX, an open standard format for machine learning models. After that, the ONNX model can be parsed with NVIDIA's OnnxParser, and the resulting TensorRT engine can be used as is or written to a file in order to save it.

Torch to ONNX

First, the export needs a dummy input with the same shape and data type as the input that normally feeds the torch model. To build one:

1. Create an empty array:

import numpy
import torch

array = numpy.empty((x, y, z, w), dtype=numpy.uint8)
tensor = torch.tensor(array).type(torch.uint8)

Here x, y, z, w are the sizes of the array's axes. If the shape is not known, it can be obtained by running an inference with the model and printing numpy's shape property (if the input is a numpy array):

input = InputGetter()
print(input.shape)  # inspect the input shape
output = model(input)

If this fails with 'list' object has no attribute 'shape', the input is probably a list of numpy arrays or tensors. Check it with prints if possible; if not, inspect it with nested loops:

input = InputGetter()
for element in input:
    print(element.shape)  # inspect each element's shape
output = model(input)

Also check the data type, to make sure the dummy input matches what the model receives during a real inference. For this example, the input is a list with two tensors of the following sizes:

[torch.Size([1, 3, 8, 256, 256]), torch.Size([1, 3, 32, 256, 256])]
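
As a reference, here is a minimal sketch of dummy inputs that match those two shapes (the torch.float32 dtype is an assumption; use whatever dtype the real input has):

import torch

# Hypothetical dummy inputs matching the shapes printed above
dummy_input = [
    torch.zeros((1, 3, 8, 256, 256), dtype=torch.float32),
    torch.zeros((1, 3, 32, 256, 256), dtype=torch.float32),
]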

2. Load the torch model

model = CustomModelClass.get_model(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=torch.device('DEVICE')))

Where CustomModelClass is the class used to handle the model, PATH is the torch model path, and DEVICE is the target device to load the data onto: 'cpu' if CUDA is not available, or 'cuda' if it is. For more reference, check the PyTorch documentation on saving and loading models and on tensor attributes.
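
If the target device is not known ahead of time, it can be selected automatically. The following is a small sketch; the model.eval() call is a common recommendation before exporting, not part of the original snippet:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CustomModelClass.get_model(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=torch.device(device)))
model.eval()  # switch to inference mode before exporting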

3. Migrate the model

import onnx
import torch

torch.onnx.export(model, input, OUT_PATH, verbose=True,
                  input_names=['input',...,'inputN'],
                  output_names=['output',...,'outputN'],
                  export_params=True)
model = onnx.load(OUT_PATH)
try:
    onnx.checker.check_model(model)
except onnx.checker.ValidationError as e:
    print('Error with model: %s' % e)
    exit(-1)
else:
    print("Model migrated")

Where OUT_PATH is the path and file name where the ONNX model will be written. Also check the model's inputs and outputs: if your model has more than one of either, it is a good idea to label them. In this case the example input is a list with two tensors/numpy arrays, which means there are two inputs, so labels should be added as input_names=['input1','input2']. These labels matter because ONNX Runtime loads the inputs from a dictionary and uses the labels as keys.
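
For the two-input example above, the export call would look roughly like this (a sketch assuming t1 and t2 are the two dummy tensors obtained earlier and that the model produces a single output):

torch.onnx.export(model, [t1, t2], OUT_PATH, verbose=True,
                  input_names=['input1', 'input2'],
                  output_names=['output'],
                  export_params=True)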

4. Test the model if possible

import onnxruntime as ort
import numpy
import torch

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

t1, t2 = InputGetter()
ort_session = ort.InferenceSession(PATH)
# The dictionary keys must match the input_names used during the export
ort_inputs = {'input1': to_numpy(t1), 'input2': to_numpy(t2)}
pred = ort_session.run(None, ort_inputs)[0]

PATH is the path where the ONNX model was written. The inputs need to be numpy arrays, which is what the to_numpy() function is for. You can print the result directly, or use something like the following to get a friendlier result:

new_pred = torch.tensor(pred)
post_act = torch.nn.Softmax(dim=1)
output = post_act(new_pred).topk(k=1)
print(output)
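
If the original PyTorch model is still available (here called torch_model, loaded as in step 2), a common sanity check is to compare both predictions numerically. This is only a sketch and assumes the model returns a single tensor:

with torch.no_grad():
    torch_pred = torch_model([t1, t2])  # torch_model: the PyTorch model from step 2

numpy.testing.assert_allclose(to_numpy(torch_pred), pred, rtol=1e-03, atol=1e-05)
print("ONNX Runtime and PyTorch outputs match")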

5. Visualize model

You can use a tool like netron to inspect the generated model, download it as an image, or print it if needed. It is also useful to check the inputs and outputs and their shapes to verify that everything went well.

  • pip install netron

Then just type netron in a console after installing. For more info check: Netron
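
Netron can also be launched from Python; a small sketch, assuming OUT_PATH is the ONNX file written earlier:

import netron

netron.start(OUT_PATH)  # serves the model and opens it in the browser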

ONNX to TensorRT

1. With the ONNX model, the following code can be used to build and load the TensorRT engine, where ONNX_PATH is the path to the ONNX model:

import tensorrt as trt

TRT_LOGGER = trt.Logger()
builder = trt.Builder(TRT_LOGGER)
explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(explicit_batch)
parser = trt.OnnxParser(network, TRT_LOGGER)
if not parser.parse_from_file(ONNX_PATH):
    print("Could not parse onnx model")
    exit(-1)
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30  # allow up to 1 GiB of workspace memory
plan = builder.build_serialized_network(network, config)  # serialized engine
with trt.Runtime(TRT_LOGGER) as runtime:
    model = runtime.deserialize_cuda_engine(plan)

There are changes that can be made to the builder config to affect the generated model and to limit the resources used during the migration. For more info, check: Builder
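
For example, two common tweaks are lowering the workspace limit and enabling FP16 precision when the GPU supports it. This is only a sketch (whether FP16 keeps your model's accuracy is something to verify):

config = builder.create_builder_config()
config.max_workspace_size = 1 << 28  # limit scratch memory to 256 MiB
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)  # allow FP16 kernels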

2. To write the engine to disk, use this code, where ENGINE_PATH is the path where the TensorRT model will be saved:

with open(ENGINE_PATH,'wb') as f:
    f.write(model.serialize())

3. To test it use:

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy
import torch
from os.path import exists

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

if not exists(ENGINE_PATH):
    print("ERROR, model not found")
    exit(1)
TRT_LOGGER = trt.Logger()
with open(ENGINE_PATH, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    model = runtime.deserialize_cuda_engine(f.read())

# data load stage
with model.create_execution_context() as context:
    # Set the input shapes, using the labels given at export time as keys
    context.set_binding_shape(model.get_binding_index("input1"), (1, 3, 8, 256, 256))
    context.set_binding_shape(model.get_binding_index("input2"), (1, 3, 32, 256, 256))

    t1, t2 = InputGetter()
    t1 = to_numpy(t1)
    t2 = to_numpy(t2)

    bindings = []

    input_buffer_t1 = numpy.ascontiguousarray(t1)
    input_memory_t1 = cuda.mem_alloc(t1.nbytes)
    bindings.append(int(input_memory_t1))

    input_buffer_t2 = numpy.ascontiguousarray(t2)
    input_memory_t2 = cuda.mem_alloc(t2.nbytes)
    bindings.append(int(input_memory_t2))

    # In this case there are 2 inputs and 1 output, so the output binding index
    # is 2; check your model's input/output configuration
    size_out = trt.volume(context.get_binding_shape(2))
    type_out = trt.nptype(model.get_binding_dtype(2))
    output_buffer = cuda.pagelocked_empty(size_out, type_out)
    output_memory = cuda.mem_alloc(output_buffer.nbytes)
    output_shape = context.get_binding_shape(2)
    output_data = numpy.empty(output_shape, dtype=type_out)
    bindings.append(int(output_memory))

    stream = cuda.Stream()
    cuda.memcpy_htod_async(input_memory_t1, input_buffer_t1, stream)  # copy input data to GPU memory
    cuda.memcpy_htod_async(input_memory_t2, input_buffer_t2, stream)
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)  # run inference
    cuda.memcpy_dtoh_async(output_buffer, output_memory, stream)  # copy the output back to the host
    stream.synchronize()  # wait for the copies and the inference to finish
    output_data[0] = output_buffer  # reuse the same method used before to format the inference output
    new_pred = torch.tensor(output_data)
    post_act = torch.nn.Softmax(dim=1)
    output = post_act(new_pred).topk(k=1)
    print(output)
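
When the binding order is not known in advance, a small sketch like the following, placed inside the same execution-context block, lists every binding's name, direction, data type, and shape instead of hard-coding indices:

    for i in range(model.num_bindings):
        name = model.get_binding_name(i)
        direction = "input" if model.binding_is_input(i) else "output"
        dtype = trt.nptype(model.get_binding_dtype(i))
        shape = context.get_binding_shape(i)
        print(name, direction, dtype, shape)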

