I have a function compiled with Numba's `njit` decorator. It looks like this:
```python
import numpy as np
from numba import njit, types, prange
from numba.typed import List

CACHE_FLAG = True  # defined elsewhere in my project; set here so the snippet runs


@njit(cache=CACHE_FLAG)
def find_combinations(target, *arrays):
    """
    Find all combinations of element indices from 7 arrays containing segment lengths that sum to the target update time.
    Returns both the lengths that sum up to the target and the corresponding indices.

    Args:
        target (int): The target sum.
        arrays (tuple of lists): Each list contains segment lengths.

    Returns:
        Tuple of two lists:
            1. Lengths list - contains tuples of segment lengths that sum up to the target.
            2. Indices list - contains tuples of indices in the original lists corresponding to the lengths.
    """
    lengths_list = List()
    indices_list = List()
    for i in prange(len(arrays[0])):
        sum_i = arrays[0][i]
        if sum_i > target:
            continue
        for j in range(len(arrays[1])):
            sum_j = sum_i + arrays[1][j]
            if sum_j > target:
                continue
            for k in range(len(arrays[2])):
                sum_k = sum_j + arrays[2][k]
                if sum_k > target:
                    continue
                for l in range(len(arrays[3])):
                    sum_l = sum_k + arrays[3][l]
                    if sum_l > target:
                        continue
                    for m in range(len(arrays[4])):
                        sum_m = sum_l + arrays[4][m]
                        if sum_m > target:
                            continue
                        for n in range(len(arrays[5])):
                            sum_n = sum_m + arrays[5][n]
                            if sum_n > target:
                                continue
                            for o in range(len(arrays[6])):
                                total = sum_n + arrays[6][o]
                                if total == target:
                                    lengths_list.append(
                                        (
                                            arrays[0][i],
                                            arrays[1][j],
                                            arrays[2][k],
                                            arrays[3][l],
                                            arrays[4][m],
                                            arrays[5][n],
                                            arrays[6][o],
                                        )
                                    )
                                    indices_list.append((i, j, k, l, m, n, o))
    return lengths_list, indices_list
```
The function works as expected, but it only handles exactly 7 input arrays. The outputs `lengths_list` and `indices_list` each contain some number of 7-tuples (one element per input array). I want to rewrite this function so that it handles an arbitrary number of input arrays.
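One way to handle an arbitrary number of arrays without recursion is to replace the fixed nest of loops with an explicit "odometer" of indices plus a running partial sum per level, pruning exactly where the nested loops `continue`. The sketch below is an illustration under stated assumptions, not the original code: `find_combinations_iter` is a hypothetical name, and it assumes integer segment lengths passed as a non-empty tuple of `np.int64` arrays (a homogeneous tuple that Numba can index with a runtime integer). It returns two 2-D arrays instead of lists of tuples, which avoids typed `List` altogether.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs (slowly) without Numba
    def njit(func):
        return func


@njit
def find_combinations_iter(target, arrays):
    """Iterative depth-first search over any number of int64 arrays.

    Assumes `arrays` is a non-empty tuple of 1-D np.int64 arrays.
    Row r of each returned array describes one combination.
    """
    n_arrays = len(arrays)
    idx = np.zeros(n_arrays, dtype=np.int64)          # current index at each level
    partial = np.zeros(n_arrays + 1, dtype=np.int64)  # partial[d] = sum of levels < d
    cap = 16
    lengths = np.empty((cap, n_arrays), dtype=np.int64)
    indices = np.empty((cap, n_arrays), dtype=np.int64)
    count = 0
    d = 0
    while d >= 0:
        if idx[d] >= len(arrays[d]):
            # This level is exhausted: step back up and advance the parent.
            d -= 1
            if d >= 0:
                idx[d] += 1
            continue
        s = partial[d] + arrays[d][idx[d]]
        if s > target:
            # Prune this branch, like `continue` in the nested loops.
            idx[d] += 1
        elif d == n_arrays - 1:
            if s == target:
                if count == cap:
                    # Double the output buffers when they are full.
                    new_lengths = np.empty((2 * cap, n_arrays), dtype=np.int64)
                    new_indices = np.empty((2 * cap, n_arrays), dtype=np.int64)
                    new_lengths[:cap] = lengths
                    new_indices[:cap] = indices
                    lengths = new_lengths
                    indices = new_indices
                    cap *= 2
                for q in range(n_arrays):
                    indices[count, q] = idx[q]
                    lengths[count, q] = arrays[q][idx[q]]
                count += 1
            idx[d] += 1
        else:
            # Descend one level, starting at index 0 of the next array.
            partial[d + 1] = s
            d += 1
            idx[d] = 0
    return lengths[:count], indices[:count]
```

Results come out in the same order as the nested loops; `.tolist()` turns them into nested Python lists if needed. The `try`/`except` around the import only exists so the sketch also runs where Numba is not installed.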
I tried to rewrite this function using recursion like this:
```python
import numpy as np
from numba import njit, types, prange
from numba.typed import List


@njit
def find_combinations_recursive(arrays, target, current_sum, current_indices, current_lengths, lengths_list, indices_list, depth):
    # Base case: when we've processed all arrays
    if depth == len(arrays):
        if current_sum == target:
            # lengths_list.append(tuple(current_lengths))  # tuple does not work in the Pythonic way-- cannot convert iterable to tuple in Numba?
            # indices_list.append(tuple(current_indices))
            lengths_list.append(current_lengths)
            indices_list.append(current_indices)
        return
    # Recursive case: iterate over the current array
    for i in range(len(arrays[depth])):
        new_sum = current_sum + arrays[depth][i]
        if new_sum > target:
            continue
        current_indices.append(i)
        # current_lengths.extend(List([arrays[depth][i]]))
        current_lengths.append(arrays[depth][i])
        find_combinations_recursive(
            arrays, target, new_sum, current_indices, current_lengths, lengths_list, indices_list, depth + 1
        )


@njit(cache=True)
def find_combinations(target, arrays):
    # lengths_list = List()
    # indices_list = List()
    # lengths_list = List.empty_list(tuple([np.int64] * len(arrays)))
    # indices_list = List.empty_list(tuple([np.int64] * len(arrays)))
    lengths_list = List([List([1]*len(arrays))])
    # print(lengths_list)
    # indices_list = List(tuple([1] * len(arrays)))
    indices_list = List([List([1]*len(arrays))])
    # print(indices_list)
    # import pdb; pdb.set_trace()
    k = lengths_list.pop()
    x = indices_list.pop()
    current_indices = List.empty_list(types.int64)
    current_lengths = List.empty_list(types.int64)
    find_combinations_recursive(arrays, target, 0, current_indices, current_lengths, lengths_list, indices_list, 0)
    return lengths_list, indices_list
```
However, while the function runs, it does not produce the expected output: `lengths_list` does not contain the expected elements and seems to contain a single list of ints rather than multiple lists of 7 ints each. I also had a lot of trouble with Numba's `List()` data structure and eventually resorted to a hacky initialization: I give it an initial value to fix the element type, then pop that value to obtain an empty list of the desired type. What do I need to fix to get this all to work coherently?
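Two things in the recursive version most likely explain the output. First, `lengths_list.append(current_lengths)` appends a reference to the same mutable list every time, not a snapshot, so every stored entry is the same object. Second, nothing is removed after the recursive call returns, so `current_indices` and `current_lengths` keep growing across branches instead of holding one entry per level. On the tuple comment: as far as I know, Numba tuples must have a length known at compile time, so a tuple cannot be built from a list whose length is only known at run time; storing each result as a 1-D array is one workaround. The sketch below applies these fixes; `_collect` and `new_array_list` are hypothetical helper names, and it assumes integer segment lengths. `List.empty_list(item_type)` creates a genuinely empty typed list, which replaces the append-then-pop initialization.

```python
import numpy as np

try:
    from numba import njit, types
    from numba.typed import List

    def new_array_list():
        # Explicitly typed, empty typed List of C-contiguous 1-D int64 arrays,
        # so no dummy element has to be appended and popped.
        return List.empty_list(types.int64[::1])
except ImportError:  # fallback so the sketch still runs (slowly) without Numba
    def njit(func):
        return func

    def new_array_list():
        return []


@njit
def _collect(arrays, target, depth, current_sum, current, lengths_out, indices_out):
    # current[q] holds the index chosen in arrays[q] for every level q < depth.
    if depth == len(arrays):
        if current_sum == target:
            vals = np.empty(depth, dtype=np.int64)
            for q in range(depth):
                vals[q] = arrays[q][current[q]]
            lengths_out.append(vals)
            indices_out.append(current.copy())  # snapshot, not a shared reference
        return
    arr = arrays[depth]
    for i in range(len(arr)):
        new_sum = current_sum + arr[i]
        if new_sum > target:
            continue
        # The next iteration overwrites slot `depth`, which undoes this choice,
        # so no explicit pop() is needed for backtracking.
        current[depth] = i
        _collect(arrays, target, depth + 1, new_sum, current, lengths_out, indices_out)


def find_combinations(target, arrays):
    # A tuple of same-dtype, C-contiguous arrays is a homogeneous tuple that
    # Numba can index with the runtime value `depth`.
    arrays = tuple(np.ascontiguousarray(a, dtype=np.int64) for a in arrays)
    lengths_out = new_array_list()
    indices_out = new_array_list()
    current = np.zeros(len(arrays), dtype=np.int64)
    _collect(arrays, target, 0, 0, current, lengths_out, indices_out)
    return lengths_out, indices_out
```

Using a fixed-size index buffer instead of a growing list sidesteps the missing `pop()` entirely, and copying at the base case sidesteps the aliasing.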
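`itertools.product` already solves the problem of iterating over every combination of one element from each of an arbitrary number of sets (`itertools.combinations` instead picks subsets of a single iterable). I don't know how or if Numba can optimize it, though.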
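For comparison, a plain-Python version built on `itertools.product` could look like the sketch below (`find_combinations_reference` is a hypothetical name). It enumerates the entire Cartesian product with no early pruning, and as far as I know `itertools` is not supported in Numba's nopython mode, so this is mainly useful as a slow correctness check for a jitted version.

```python
from itertools import product


def find_combinations_reference(target, arrays):
    """Pure-Python reference using itertools.product (no pruning, no Numba)."""
    lengths_list = []
    indices_list = []
    # product(range(len(a0)), range(len(a1)), ...) yields index tuples in the
    # same order as the original nested loops.
    for idx in product(*(range(len(a)) for a in arrays)):
        vals = tuple(a[i] for a, i in zip(arrays, idx))
        if sum(vals) == target:
            lengths_list.append(vals)
            indices_list.append(idx)
    return lengths_list, indices_list
```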