I am trying to pass a list of jitclass objects to a jitted function, but I have been unable to find a suitable type signature for njit. I don't have a good grasp of how Numba infers types, so I would really appreciate any help that you might be able to provide.
I will demonstrate with a simple example.
First step: processing a single jitclass object (this part actually works fine).
import numba as nb

@nb.experimental.jitclass([('x', nb.float64)])
class test_jitclass(object):
    def __init__(self, x):
        self.x = x

# Numba type of an instance of the jitclass, for use in signatures
test_jitclass_type = test_jitclass.class_type.instance_type

@nb.njit(nb.int64(test_jitclass_type))
def process_test_jitclass(tjc):
    print('processing: ', tjc.x)
    return 1

process_test_jitclass(test_jitclass(5.))
Second step: extending this to a list of jitclass objects (this is where I start to get errors).
First attempt at the list:
@nb.experimental.jitclass([('x', nb.float64)])
class test_jitclass(object):
    def __init__(self, x):
        self.x = x

test_jitclass_type = test_jitclass.class_type.instance_type
test_jitclass_list_type = nb.types.ListType(test_jitclass_type)

list_test_jitclass = []
for i in range(3):
    list_test_jitclass.append(test_jitclass(i))

@nb.njit(nb.int64(test_jitclass_list_type))
def process_list_test_jitclass(list_test_jitclass):
    for tj in list_test_jitclass:
        print('processing: ', tj.x)
    return 1

process_list_test_jitclass(list_test_jitclass)
Here is the error that results:
# Traceback (most recent call last):
# File "<string>", line 1, in <module>
# File "/home/____/.pyenv/versions/3.9.16/lib/python3.9/site-packages/numba/core/dispatcher.py", line 703, in _explain_matching_error
# raise TypeError(msg)
# TypeError: No matching definition for argument type(s) reflected list(instance.jitclass.test_jitclass#7f83a409d970<x:float64>)<iv=None>
I know that the "reflected list" is being deprecated and I believe the new way is to use nb.typed.List. However, when I execute this line, test_jitclass_list_type = nb.typed.List(test_jitclass_type), I get a new error:
# Traceback (most recent call last):
# File "<string>", line 7, in <module>
# File "/home/____/.pyenv/versions/3.9.16/lib/python3.9/site-packages/numba/typed/typedlist.py", line 268, in __init__
# for i in args[0]:
# File "/home/____/.pyenv/versions/3.9.16/lib/python3.9/site-packages/numba/core/types/abstract.py", line 185, in __getitem__
# ndim, layout = self._determine_array_spec(args)
# File "/home/____/.pyenv/versions/3.9.16/lib/python3.9/site-packages/numba/core/types/abstract.py", line 210, in _determine_array_spec
# raise KeyError(f"Can only index numba types with slices with no start or stop, got {args}.")
# KeyError: 'Can only index numba types with slices with no start or stop, got 0.'
(Note: as the paths above show, I am using Python 3.9.16, which I am stuck with because of the large project I am working within.)
I'd really appreciate any guidance on how to correct this signature so that I can pass a list of jitclass objects.
Thank you in advance.
--
SOLUTION THAT WORKED FOR ME:
Thank you very much to @Jérôme Richard for clarifying the usage within numba. I find numba very confusing, but his comments below got me to a solution that worked for me!
@nb.experimental.jitclass([('x', nb.float64)])
class test_jitclass(object):
    def __init__(self, x):
        self.x = x

test_jitclass_type = test_jitclass.class_type.instance_type

list_test_jitclass = nb.typed.List([test_jitclass(i) for i in range(3)])

@nb.njit(nb.int64(nb.types.ListType(test_jitclass_type)))
def process_list_test_jitclass(list_test_jitclass):
    for tjc in list_test_jitclass:
        print('processing: ', tjc.x)
    return 1

process_list_test_jitclass(list_test_jitclass)
The key changes were:
- applying nb.typed.List to the variable, not to the type
- using nb.types.ListType(test_jitclass_type) in the njit signature
I hope this resolution is helpful for others as well.