Holoscan Sensor Bridge Sensor Support

From RidgeRun Developer Wiki



Previous: Holoscan SDK/Installation and Testing Index Next: Contact_Us






Introduction

To accelerate an operation within the Holoscan SDK, you need to create a new Holoscan Operator. The Holoscan SDK provides a base class for operators, holoscan::Operator, whose main methods are:

  • setup(): configures the inputs, outputs, and parameters of the operator.
  • start(): called once when the operator starts; used for heavy initialisation, such as allocating memory resources and reading parameters.
  • stop(): called once when the operator stops; used for heavy deinitialisation, such as releasing the resources acquired in start().
  • compute(): performs the computation; called each time the operator executes.

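The lifecycle of the operator is as follows: setup() runs once when the operator is created, start() runs once before execution begins, compute() runs every time the operator executes, and stop() runs once when execution ends.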
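The call order can be illustrated with a small, self-contained C++ mock. It does not use the Holoscan API. `MockOperator` and `run_lifecycle` are hypothetical names, and the fixed execution count stands in for the Holoscan scheduler, which is what actually decides how many times compute() runs.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for holoscan::Operator: it only records which
// lifecycle method was called, in order.
class MockOperator {
 public:
  virtual ~MockOperator() = default;
  virtual void setup() { log_.push_back("setup"); }
  virtual void start() { log_.push_back("start"); }
  virtual void compute() { log_.push_back("compute"); }
  virtual void stop() { log_.push_back("stop"); }
  const std::vector<std::string>& log() const { return log_; }

 private:
  std::vector<std::string> log_;
};

// Hypothetical driver: setup() and start() run once, compute() runs once per
// execution, and stop() runs once at the end.
void run_lifecycle(MockOperator& op, int executions) {
  assert(executions >= 0);
  op.setup();
  op.start();
  for (int i = 0; i < executions; ++i) {
    op.compute();
  }
  op.stop();
}
```

With two executions, the recorded order is setup, start, compute, compute, stop. Heavy work therefore belongs in start() and stop(), which are not repeated.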

The structure of an operator is:

#include "holoscan/holoscan.hpp"

using holoscan::Operator;
using holoscan::OperatorSpec;
using holoscan::InputContext;
using holoscan::OutputContext;
using holoscan::ExecutionContext;
using holoscan::Arg;
using holoscan::ArgList;

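class MyOp : public Operator {
 public:
  // Forwards constructor arguments (e.g. holoscan::Arg) to holoscan::Operator
  HOLOSCAN_OPERATOR_FORWARD_ARGS(MyOp)

  MyOp() = default;

  // Declare the input/output ports and the parameters here
  void setup(OperatorSpec& spec) override {
  }

  // One-time initialisation before the first compute()
  void start() override {
    HOLOSCAN_LOG_TRACE("MyOp::start()");
  }

  // Called every time the operator executes
  void compute(InputContext&, OutputContext& op_output, ExecutionContext&) override {
    HOLOSCAN_LOG_TRACE("MyOp::compute()");
  }

  // One-time cleanup after the last compute()
  void stop() override {
    HOLOSCAN_LOG_TRACE("MyOp::stop()");
  }
};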

Creating a custom filter

In this case study, we will create a gamma correction operator for Hololink that works on RGBA images. The operator receives an image, applies gamma correction using a configurable gamma value, and outputs the corrected image:

  • Input: image tensor to transform (RGBA, 16 bits per component)
  • Parameter: gamma value (float)
  • Output: transformed image tensor (RGBA, 16 bits per component)

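The gamma correction applied to each color component (R, G and B; alpha is left unchanged) is:

$$V_{out} = R \left( \frac{V_{in}}{R} \right)^{1/\gamma}$$

where $V_{in}$ and $V_{out}$ are the input and output component values, $\gamma$ is the gamma parameter, and $R = 2^{16} - 1 = 65535$ is the maximum value of a 16-bit component. The result is rounded to the nearest integer.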

Moreover, we are going to use CUDA to accelerate the computation.
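Before writing the kernel, it can help to have a CPU reference implementation to compare the GPU output against. The sketch below is a host-side version that assumes the same conventions as the kernel described next: 16-bit components, an interleaved pixel layout, alpha (the fourth component) left untouched, and rounding to the nearest integer. The function name `gamma_correct_cpu` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference for the gamma correction kernel.
// image: interleaved pixels, `components` values per pixel (4 for RGBA).
// Only the first three components (R, G, B) are corrected; alpha is kept.
void gamma_correct_cpu(std::vector<std::uint16_t>& image, int components,
                       int width, int height, float gamma) {
  assert(components > 0 && width >= 0 && height >= 0 && gamma > 0.f);
  assert(image.size() == static_cast<std::size_t>(components) * width * height);

  const float range = 65535.f;  // (1 << 16) - 1, the largest 16-bit value
  const int corrected = std::min(components, 3);

  for (int y = 0; y < height; ++y) {
    for (int x = 0; x < width; ++x) {
      // Same linear offset the kernel computes from its 2D thread index
      const std::size_t index =
          (static_cast<std::size_t>(y) * width + x) * components;
      for (int c = 0; c < corrected; ++c) {
        float value = static_cast<float>(image[index + c]);
        value = std::pow(value / range, 1.f / gamma) * range;
        image[index + c] = static_cast<std::uint16_t>(value + 0.5f);  // round
      }
    }
  }
}
```

With gamma greater than 1, mid-tones become brighter, while 0 and 65535 map to themselves. A gamma of exactly 1 leaves the image unchanged, which is why the operator can skip the kernel launch in that case.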

CUDA Kernel

First, the algorithm must be implemented as a CUDA kernel, which receives only basic C types such as pointers, integers and floats. For this example, the following code is used as the CUDA kernel:

#include <hololink/native/cuda_helper.hpp>

namespace {

const char* source = R"(
extern "C" {

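/**
 * Apply gamma correction in place.
 *
 * The gamma value is not a kernel argument: it is injected at JIT-compile
 * time as the GAMMA preprocessor macro.
 *
 * @param image [in,out] pointer to the image (16-bit components, interleaved)
 * @param components [in] components per pixel
 * @param width [in] width of the image
 * @param height [in] height of the image
 */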
__global__ void applyGammaCorrection(unsigned short *image,
                                     int components,
                                     int width,
                                     int height)
{
    int idx_x = blockIdx.x * blockDim.x + threadIdx.x;
    int idx_y = blockIdx.y * blockDim.y + threadIdx.y;

    if ((idx_x >= width) || (idx_y >= height))
        return;

    const int index = (idx_y * width + idx_x) * components;
    const float range = (1 << (sizeof(unsigned short) * 8)) - 1;

    // apply gamma correction to each component except alpha
    for (int component = 0; component < min(components, 3); ++component) {
        float value = (float)(image[index + component]);
        value = powf(value / range, 1.f / GAMMA) * range;
        image[index + component] = (unsigned short)(value + 0.5f);
    }
}

})";

} // anonymous namespace
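Each CUDA thread handles one pixel. It derives (idx_x, idx_y) from its block and thread indices. When the launch grid is rounded up to whole blocks, the bounds check discards the extra threads that fall outside the image. The host-side sketch below emulates that indexing to show that every pixel is visited exactly once. `div_up`, `covered_pixels` and the 16×16 block size used in the test are assumptions for illustration, not values taken from Hololink.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Number of blocks needed to cover `n` items with blocks of `block` threads.
constexpr unsigned div_up(unsigned n, unsigned block) {
  return (n + block - 1) / block;
}

// Emulates the kernel's indexing on the host and counts how many times each
// pixel is visited. Returns the per-pixel visit counts (row-major).
std::vector<int> covered_pixels(unsigned width, unsigned height,
                                unsigned block_x, unsigned block_y) {
  assert(block_x > 0 && block_y > 0);
  std::vector<int> visits(static_cast<std::size_t>(width) * height, 0);
  const unsigned grid_x = div_up(width, block_x);
  const unsigned grid_y = div_up(height, block_y);

  for (unsigned by = 0; by < grid_y; ++by)
    for (unsigned bx = 0; bx < grid_x; ++bx)
      for (unsigned ty = 0; ty < block_y; ++ty)
        for (unsigned tx = 0; tx < block_x; ++tx) {
          // blockIdx * blockDim + threadIdx, as in applyGammaCorrection
          const unsigned idx_x = bx * block_x + tx;
          const unsigned idx_y = by * block_y + ty;
          // The kernel's early return for threads outside the image
          if (idx_x >= width || idx_y >= height) continue;
          ++visits[idx_y * width + idx_x];
        }
  return visits;
}
```

For a 100×37 image and 16×16 blocks, the grid is 7×3 blocks. That launches more threads than pixels, so the bounds check in the kernel is what keeps the extra threads from writing out of bounds.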

The kernel is written as a string literal because Hololink provides a way to JIT-compile CUDA source at runtime (through hololink::native::CudaFunctionLauncher) and to launch the resulting kernels.

This code is saved in the same source file as the Holoscan operator.

Holoscan Operator Declaration

Declare the Holoscan Operator in a header file, as follows:

#include <memory>

#include <holoscan/core/operator.hpp>
#include <holoscan/core/parameter.hpp>
#include <holoscan/utils/cuda_stream_handler.hpp>

#include <cuda.h>

// Forward declaration, so this header does not pull in the Hololink CUDA helpers
namespace hololink::native {
class CudaFunctionLauncher;
}

namespace hololink::operators {

class GammaCorrectionOp : public holoscan::Operator {
public:
    // Always forward the arguments: this defines the constructor that passes
    // holoscan::Arg/ArgList values on to holoscan::Operator
    HOLOSCAN_OPERATOR_FORWARD_ARGS(GammaCorrectionOp);

    GammaCorrectionOp() = default;

    // Override the lifecycle methods to implement the operator
    void start() override;
    void stop() override;
    void setup(holoscan::OperatorSpec& spec) override;
    void compute(holoscan::InputContext&, holoscan::OutputContext& op_output, holoscan::ExecutionContext&) override;

private:
    // Parameters: the gamma value and the CUDA device to run on
    holoscan::Parameter<float> gamma_;
    holoscan::Parameter<int> cuda_device_ordinal_;

    // Handle the CUDA kernel and stream
    holoscan::CudaStreamHandler cuda_stream_handler_;
    std::shared_ptr<hololink::native::CudaFunctionLauncher> cuda_function_launcher_;

    CUcontext cuda_context_ = nullptr;
    CUdevice cuda_device_ = 0;
};

} // namespace hololink::operators

Save this declaration as gamma_correction.hpp.

Holoscan Operator Definition

This section covers the definition of each method of the gamma correction operator. All the code is saved in gamma_correction.cpp. The CUDA kernel source shown earlier goes at the top of that file, and the method definitions go inside namespace hololink::operators.

Setup

The setup() override declares the operator's input and output ports and its parameters.

void GammaCorrectionOp::setup(holoscan::OperatorSpec& spec)
{
    // Setup the input and output tensors
    spec.input<holoscan::gxf::Entity>("input");
    spec.output<holoscan::gxf::Entity>("output");

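    // Setup the params (parameter, key, headline, description, default value)
    spec.param(gamma_, "gamma", "Gamma", "Gamma correction value", 2.2f);
    spec.param(cuda_device_ordinal_, "cuda_device_ordinal", "CudaDeviceOrdinal",
        "Device to use for CUDA operations", 0);

    // Setup CUDA
    cuda_stream_handler_.define_params(spec);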
}

Start / Stop

start() and stop() handle the heavy initialisation and cleanup, such as memory allocation and deallocation, CUDA context management, and JIT compilation of the kernel. This keeps compute(), which processes the actual data, free of setup work.

Start:

void GammaCorrectionOp::start()
{
    // Initialise CUDA and retain the primary context of the selected device
    CudaCheck(cuInit(0));
    CudaCheck(cuDeviceGet(&cuda_device_, cuda_device_ordinal_.get()));
    CudaCheck(cuDevicePrimaryCtxRetain(&cuda_context_, cuda_device_));

    // Make the context current for the rest of this scope
    hololink::native::CudaContextScopedPush cur_cuda_context(cuda_context_);

    // JIT-compile the kernel source, baking the gamma value in as a macro
    cuda_function_launcher_.reset(new hololink::native::CudaFunctionLauncher(
        source, { "applyGammaCorrection" }, { fmt::format("-D GAMMA={}", gamma_.get()) }));
}
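Note that start() bakes the gamma value into the kernel at JIT-compile time through the "-D GAMMA=..." option. The kernel therefore sees a literal rather than a runtime argument, and changing gamma requires compiling the kernel again (in this operator, restarting it). The sketch below builds an equivalent option string with the standard library instead of fmt. `make_gamma_define` is a hypothetical helper. The default stream formatting (six significant digits) is an assumption that is sufficient for typical gamma values.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical helper: builds the preprocessor option that start() passes to
// the JIT compiler, e.g. "-D GAMMA=2.2". The kernel source then expands
// `1.f / GAMMA` to `1.f / 2.2`.
std::string make_gamma_define(float gamma) {
  assert(gamma > 0.f);  // 1 / GAMMA must be finite and positive
  std::ostringstream option;
  option << "-D GAMMA=" << gamma;  // default precision: 6 significant digits
  return option.str();
}
```

Because the macro is substituted textually, any value the stream prints (for example "1" or "0.5") must be a valid C numeric literal, which the default formatting provides for finite values.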

Stop (notice that it undoes the steps of start() in reverse order):

void GammaCorrectionOp::stop()
{
    // Make the context current while the kernel is destroyed
    hololink::native::CudaContextScopedPush cur_cuda_context(cuda_context_);

    // Destroy the CUDA Kernel
    cuda_function_launcher_.reset();

    // Release the context
    CudaCheck(cuDevicePrimaryCtxRelease(cuda_device_));
    cuda_context_ = nullptr;
}
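start(), stop() and compute() all create a hololink::native::CudaContextScopedPush before touching CUDA. Judging by its name and usage, it is an RAII guard that makes cuda_context_ current for the enclosing scope, typically by pushing the context in its constructor and popping it in its destructor. That is an assumption about Hololink's implementation. The sketch below models the pattern with a plain stack of integers instead of CUDA contexts. `ContextStack` and `ScopedPush` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the per-thread CUDA context stack.
struct ContextStack {
  std::vector<int> contexts;
  int current() const { return contexts.empty() ? -1 : contexts.back(); }
};

// RAII guard: pushes a context on construction and pops it on destruction,
// so the previous context is restored even if the scope exits via an exception.
class ScopedPush {
 public:
  ScopedPush(ContextStack& stack, int context) : stack_(stack) {
    stack_.contexts.push_back(context);
  }
  ~ScopedPush() { stack_.contexts.pop_back(); }
  ScopedPush(const ScopedPush&) = delete;
  ScopedPush& operator=(const ScopedPush&) = delete;

 private:
  ContextStack& stack_;
};
```

The guard also matters in compute(), which can throw through CHECK_RUNTIME. Without RAII, an exception would leave the operator's context current on the calling thread.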

Compute

compute() is where the image transformation happens. As a best practice, all resources should already be initialised at this point (in start()), so that compute() only processes data. Its signature is fixed by holoscan::Operator.

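#include <stdexcept>

// Throws std::runtime_error with `message` when `cond` is false
#define CHECK_RUNTIME(cond, message) \
    do { if (!(cond)) throw std::runtime_error((message)); } while (0)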

void GammaCorrectionOp::compute(holoscan::InputContext& input, holoscan::OutputContext& output, holoscan::ExecutionContext& context)
{
    // Step 1: Get the input tensor
    auto maybe_entity = input.receive<holoscan::gxf::Entity>("input");
    CHECK_RUNTIME(maybe_entity, "Failed to receive input");

    auto& entity = static_cast<nvidia::gxf::Entity&>(maybe_entity.value());

    const auto maybe_tensor = entity.get<nvidia::gxf::Tensor>();
    CHECK_RUNTIME(maybe_tensor, "Tensor not found in message");

    const auto input_tensor = maybe_tensor.value();

    // Step 2: Check the input tensor to meet all requirements
    CHECK_RUNTIME(input_tensor->storage_type() == nvidia::gxf::MemoryStorageType::kDevice,
        "Tensor is not in device memory");
    CHECK_RUNTIME(input_tensor->rank() == 3, "Tensor must be an image");
    CHECK_RUNTIME(input_tensor->element_type() == nvidia::gxf::PrimitiveType::kUnsigned16,
        "Tensor must be RGBA64");


    // Step 3: Get all the image information
    const uint32_t height = input_tensor->shape().dimension(0);
    const uint32_t width = input_tensor->shape().dimension(1);
    const uint32_t components = input_tensor->shape().dimension(2);

    // Step 4: Get the CUDA stream from the incoming message and make the context current
    gxf_result_t stream_handler_result = cuda_stream_handler_.from_message(context.context(), entity);
    CHECK_RUNTIME(stream_handler_result == GXF_SUCCESS, 
        "Failed to get the CUDA stream from incoming messages");

    hololink::native::CudaContextScopedPush cur_cuda_context(cuda_context_);
    const cudaStream_t cuda_stream = cuda_stream_handler_.get_cuda_stream(context.context());

    // Step 5: Launch the CUDA Kernel (skipped when gamma is 1, which is the identity)
    if (gamma_.get() != 1.f) {
        cuda_function_launcher_->launch(
            "applyGammaCorrection",
            { width, height, 1 },
            cuda_stream,
            input_tensor->pointer(), components, width, height);
    }

    // Step 6: Transmit the result (this is an in-place processing)
    auto out_message = nvidia::gxf::Expected<nvidia::gxf::Entity>(entity);
    stream_handler_result = cuda_stream_handler_.to_message(out_message);
    CHECK_RUNTIME(stream_handler_result == GXF_SUCCESS,
      "Failed to add the CUDA stream to the outgoing messages");

    // Emit the message, which now carries the processed tensor and the CUDA stream
    output.emit(out_message.value(), "output");
}
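Steps 2 and 3 only inspect tensor metadata, so they can be factored into a plain function and tested without a GPU or GXF. The sketch below uses a hypothetical TensorInfo struct in place of nvidia::gxf::Tensor, together with hypothetical enums, and applies the same checks as compute(): device memory, rank 3 and 16-bit unsigned elements.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for the GXF enums and tensor metadata.
enum class StorageType { kHost, kDevice };
enum class ElementType { kUnsigned8, kUnsigned16, kFloat32 };

struct TensorInfo {
  StorageType storage;
  ElementType element;
  std::vector<std::int32_t> shape;  // {height, width, components}
};

struct ImageGeometry {
  std::uint32_t height, width, components;
};

inline void check_runtime(bool condition, const char* message) {
  if (!condition) throw std::runtime_error(message);
}

// Same requirements as Step 2, then the geometry extraction of Step 3.
ImageGeometry validate_rgba64(const TensorInfo& t) {
  check_runtime(t.storage == StorageType::kDevice, "Tensor is not in device memory");
  check_runtime(t.shape.size() == 3, "Tensor must be an image");
  check_runtime(t.element == ElementType::kUnsigned16, "Tensor must be RGBA64");
  return {static_cast<std::uint32_t>(t.shape[0]),
          static_cast<std::uint32_t>(t.shape[1]),
          static_cast<std::uint32_t>(t.shape[2])};
}
```

Keeping the checks in one function also makes the error messages testable, which is useful when an upstream operator changes its output format.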


CMake

The following CMakeLists.txt builds the operator as a static library inside the Hololink source tree:

add_library(gamma_correction STATIC gamma_correction.cpp)

set_property(TARGET gamma_correction PROPERTY POSITION_INDEPENDENT_CODE ON)

add_library(hololink::operators::gamma_correction ALIAS gamma_correction)

target_include_directories(gamma_correction
  INTERFACE
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../..>
    $<INSTALL_INTERFACE:src>
  )

target_link_libraries(gamma_correction
  PRIVATE
    hololink::native
  # PUBLIC: gamma_correction.hpp includes Holoscan headers, so consumers need them
  PUBLIC
    holoscan::core
  )


The code has been inspired by this Hololink module from NVIDIA.

Using the custom filter in Python

One of the strengths of Holoscan is its Python support: operators written in C++ can be exposed to Python through pybind11.

You can find the full implementation of the pybind11 bindings at this link.

The following is a simplified version of that implementation:

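#include <memory>
#include <string>

#include <pybind11/pybind11.h>

#include <holoscan/core/fragment.hpp>
#include <holoscan/core/operator.hpp>
#include <holoscan/core/operator_spec.hpp>

#include <hololink/operators/gamma_correction/gamma_correction.hpp>

namespace py = pybind11;
using namespace pybind11::literals;   // enables "name"_a
using namespace std::string_literals; // enables "..."s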

namespace hololink::operators {

class PyGammaCorrectionOp : public GammaCorrectionOp {
public:
    /* Inherit the constructors */
    using GammaCorrectionOp::GammaCorrectionOp;

    // Define a constructor that fully initializes the object.
    PyGammaCorrectionOp(holoscan::Fragment* fragment, float gamma, int cuda_device_ordinal, const std::string& name = "gamma_correction")
        : GammaCorrectionOp(holoscan::ArgList { holoscan::Arg { "gamma", gamma }, holoscan::Arg { "cuda_device_ordinal", cuda_device_ordinal } })
    {
        name_ = name;
        fragment_ = fragment;
        spec_ = std::make_shared<holoscan::OperatorSpec>(fragment);
        setup(*spec_.get());
    }
};

PYBIND11_MODULE(_gamma_correction, m)
{
    m.attr("__version__") = "dev";

    py::class_<GammaCorrectionOp, PyGammaCorrectionOp, holoscan::Operator, std::shared_ptr<GammaCorrectionOp>>(m, "GammaCorrectionOp")
        .def(py::init<holoscan::Fragment*, float, int, const std::string&>(),
            "fragment"_a,
            "gamma"_a = 2.2f,
            "cuda_device_ordinal"_a = 0,
            "name"_a = "gamma_correction"s)
        .def("setup", &GammaCorrectionOp::setup, "spec"_a);

} // PYBIND11_MODULE

} // namespace hololink::operators

Afterwards, the compiled module must be registered in the Python package through its __init__.py, as done here.

It can then be used from Python:

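import holoscan
import hololink as hololink_module


class HoloscanApplication(holoscan.core.Application):
    def compose(self):
        # ... (create demosaic, visualizer and the other operators)
        gamma_correction = hololink_module.operators.GammaCorrectionOp(
            self,
            name="gamma_correction",
            gamma=2.2,
            cuda_device_ordinal=self._cuda_device_ordinal,
        )

        # ...
        # Demosaic output -> gamma correction input; gamma correction output -> visualizer
        self.add_flow(demosaic, gamma_correction, {("transmitter", "input")})
        self.add_flow(gamma_correction, visualizer, {("output", "receivers")})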


