Closures in Numba

python
Published October 24, 2022

!pip install numba
(Output: numba 0.56.3 was already installed, together with llvmlite 0.39.1 and NumPy 1.21.6, on Python 3.7.)
from numba.core import types
from numba.typed import Dict
from numba import njit


Closures can be used to dynamically create different versions of a function based on some parameter; functools.partial offers similar functionality.
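For comparison, here is a minimal pure-Python sketch of both approaches. The names cut_with_limit and cut_factory are illustrative and not part of the original notebook.

```python
from functools import partial


def cut_with_limit(value, limit):
    # Plain two-argument version of the cut
    return value < limit


def cut_factory(limit):
    # Closure: `limit` is captured from the enclosing function's scope
    def cut(value):
        return value < limit
    return cut


cut_closure = cut_factory(4)
cut_partial = partial(cut_with_limit, limit=4)  # binds limit, leaves value free

print(cut_closure(3), cut_partial(3))  # True True
print(cut_closure(5), cut_partial(5))  # False False
```

Both produce a one-argument callable; the closure has the advantage of being an ordinary Python function.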

My specific use case is creating filters for ROOT's RDataFrame, so I need numba-optimizable functions that take no arguments other than the input data.

Parameter is a simple type

In this case numba supports the closure without any issue. Here we have a function that defines a cut on a value, and we can dynamically create different versions of it by changing the limit.

def MB_cut_factory(limit):
    def cut(value):
        return value < limit
    return cut
MB_cut_factory(4)(3)
True
njit(MB_cut_factory(4))(3)
True
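A factory function is also the safe way to create many versions in a loop. The plain-Python sketch below (illustrative names, no numba needed) contrasts it with lambdas written directly inside a comprehension: those all capture the same loop variable and therefore only ever see its last value.

```python
def cut_factory(limit):
    def cut(value):
        return value < limit
    return cut


limits = [2, 4, 6]

# Each factory call has its own scope, so each closure keeps its own limit
factory_cuts = [cut_factory(lim) for lim in limits]

# Pitfall: every lambda refers to the same `lim`, which ends up as 6
loop_cuts = [lambda value: value < lim for lim in limits]

print([c(3) for c in factory_cuts])  # [False, True, True]
print([c(3) for c in loop_cuts])     # [True, True, True]
```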

Parameter is a complex type

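If the parameter is a complex type, such as a numba typed Dict, numba unfortunately raises a NumbaNotImplementedError when compiling the closure. Numba freezes closure variables as compile-time constants, and it cannot embed a typed Dict that way: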

dict_ranges = Dict.empty(
    key_type=types.int64,
    value_type=types.Tuple((types.float64, types.float64))
    )

dict_ranges[3] = (1, 3)

def MB_cut_factory(dict_ranges):
    def cut(series, value):
        return dict_ranges[series][0] < value < dict_ranges[series][1]
    return cut

MB_cut_factory(dict_ranges)(3,2)
True
njit(MB_cut_factory(dict_ranges))(3,2)
NumbaNotImplementedError (traceback not shown)
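One possible way around this, assuming the ranges can be stored in NumPy arrays instead of a typed Dict, is to capture plain arrays in the closure. Numba documents that global and closure variables are frozen at compile time, and NumPy arrays can be frozen this way, so the sketch below should compile; I have not verified it against numba 0.56 specifically. MB_cut_factory_arrays is a hypothetical name, and the fallback njit exists only so the sketch also runs without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (as plain Python) without numba
    def njit(func):
        return func


def MB_cut_factory_arrays(ranges):
    # Hypothetical variant: copy the {series: (low, high)} mapping into
    # NumPy arrays, which numba can freeze as compile-time constants
    keys = np.array(list(ranges.keys()), dtype=np.int64)
    bounds = np.array(list(ranges.values()), dtype=np.float64)
    lows = bounds[:, 0].copy()   # contiguous copies of each column
    highs = bounds[:, 1].copy()

    def cut(series, value):
        # Linear scan over the known series; fine for a handful of entries
        for i in range(keys.shape[0]):
            if keys[i] == series:
                return lows[i] < value < highs[i]
        return False  # unknown series: reject the value

    return cut


cut = njit(MB_cut_factory_arrays({3: (1, 3), 5: (0, 10)}))
print(cut(3, 2), cut(3, 5), cut(7, 2))  # True False False
```

Since the factory only calls .keys() and .values(), it should accept the typed Dict from above as well as a plain dict.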

The ugly workaround

Using exec, we can bluntly build the function definition from a string, injecting the dictionary's string representation into the function body itself.

It is ugly, but it works, and it returns a function that can be tested in pure Python before being passed to numba for optimization.

Notice that we pass globals() to exec: the new cut function is then defined in the module namespace, where the following return cut statement can find it.

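def MB_cut_factory(dict_ranges):
    # str() of the typed Dict gives a Python literal such as {3: (1.0, 3.0)}
    source = (
        "def cut(series, value):\n"
        f"    dict_ranges = {str(dict_ranges)}\n"
        "    return dict_ranges[series][0] < value < dict_ranges[series][1]\n"
    )
    # globals() makes exec define cut in the module namespace, where return finds it
    exec(source, globals())
    return cut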
MB_cut_factory(dict_ranges)(3,2)
True
njit(MB_cut_factory(dict_ranges))(3,2)
True
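A slightly tidier variant of the same trick, sketched here under my own assumptions (MB_cut_factory_exec is a hypothetical name), runs exec in a private dictionary instead of globals(), so no module-level cut is left behind or rebound on every call. It also copies the entries into a plain dict before calling repr(), which guarantees a valid Python literal for any mapping of int keys to (low, high) pairs.

```python
def MB_cut_factory_exec(dict_ranges):
    # Copy into a plain dict so repr() yields a valid literal, e.g. {3: (1.0, 3.0)}
    literal = repr({int(k): (float(lo), float(hi))
                    for k, (lo, hi) in dict_ranges.items()})
    source = (
        "def cut(series, value):\n"
        f"    dict_ranges = {literal}\n"
        "    return dict_ranges[series][0] < value < dict_ranges[series][1]\n"
    )
    namespace = {}
    # The generated function lives in `namespace`, not in the module globals
    exec(source, namespace)
    return namespace["cut"]


cut_a = MB_cut_factory_exec({3: (1, 3)})
cut_b = MB_cut_factory_exec({3: (2, 5)})
print(cut_a(3, 2), cut_b(3, 2))  # True False
```

As with the original version, the returned function can be tested in pure Python first and then passed to njit.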

Questions on Stack Overflow

While looking for a solution, I posted two related questions on Stack Overflow; please contribute there if you have better suggestions: