The following numpy code is perfectly fine:
import numpy as np
arr = np.arange(50)
print(arr.shape) # (50,)
indices = np.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
It also works after migrating to jax:
import jax.numpy as jnp
arr = jnp.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
Now let's try a mix of numpy and jax:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
This produces the following error:
IndexError: too many indices for array: array is 1-dimensional, but 30 were indexed
If indexing into a numpy array with a jax array is not supported, that's fine by me. But the error message seems wrong, and things get even more confusing: if you change the shapes a bit, the code works fine. In the following sample I've only changed the shape of indices from (30,) to (40,), and the error goes away:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((40,), dtype=int)
print(indices.shape) # (40,)
arr[indices]
I'm running jax version 0.2.12 on the CPU. What is happening here?
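NumPy is interpreting indices as a tuple of indices if its length is smaller than 32, the maximum number of dimensions. With 30 elements, each element is used to index a separate axis of the 1-dimensional arr, which is why the error reports that 30 were indexed. With 40 elements the heuristic does not apply, and indices is treated as an ordinary integer index array. Legacy code did this with some lists, though newer versions of NumPy are working toward deprecating the behavior.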
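To make the length threshold concrete, here is a small pure-Python sketch. It is a simplified, hypothetical model of the legacy rule described above, not NumPy's actual implementation. It assumes that a non-tuple sequence shorter than 32 is unpacked into per-axis indices when any of its elements is itself sequence-like (or a slice, Ellipsis, or None). `ScalarArray` is a hypothetical stand-in for the 0-d elements you get when iterating a JAX array, which define `__getitem__`.

```python
# Hypothetical, simplified model of NumPy's legacy index heuristic (not NumPy's real code).
# Real NumPy uses ndarray indices directly before any such rule; that check is omitted here.

MAXDIMS = 32  # the maximum number of dimensions mentioned above


def is_sequence_like(obj):
    # Loosely mirrors a "sequence" check: anything indexable that is not a string.
    return hasattr(obj, "__getitem__") and not isinstance(obj, (str, bytes))


def treated_as_tuple(index):
    """Return True if this simplified rule would unpack `index` into per-axis indices."""
    if isinstance(index, tuple):
        return True
    if not is_sequence_like(index):
        return False
    try:
        n = len(index)
    except TypeError:
        return False
    if n >= MAXDIMS:
        # Long sequences are converted to a single integer index array instead.
        return False
    return any(
        item is None
        or item is Ellipsis
        or isinstance(item, slice)
        or is_sequence_like(item)
        for item in index
    )


class ScalarArray:
    """Hypothetical stand-in for a 0-d array element; it defines __getitem__ like JAX arrays do."""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, key):
        raise IndexError("0-d array")


print(treated_as_tuple([ScalarArray(0) for _ in range(30)]))  # True: 30 < 32, unpacked per axis
print(treated_as_tuple([ScalarArray(0) for _ in range(40)]))  # False: 40 >= 32, one index array
print(treated_as_tuple([0] * 30))  # False: plain ints are not sequence-like
```

Under this model, the 30-element JAX array becomes 30 separate indices into a 1-dimensional array, matching the error, while the 40-element one is used as a single index array. In practice, converting the JAX array to a NumPy array first, as in arr[np.asarray(indices)], sidesteps the heuristic, because NumPy uses an ndarray index directly.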