
The following NumPy code works fine:

import numpy as np

arr = np.arange(50)
print(arr.shape) # (50,)

indices = np.zeros((30,), dtype=int)
print(indices.shape) # (30,)

arr[indices]

It also works after migrating to JAX:

import jax.numpy as jnp

arr = jnp.arange(50)
print(arr.shape) # (50,)

indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)

arr[indices]

Now let's try mixing NumPy and JAX:

arr = np.arange(50)
print(arr.shape) # (50,)

indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)

arr[indices]

This produces the following error:

IndexError: too many indices for array: array is 1-dimensional, but 30 were indexed

If indexing a NumPy array with a JAX array is not supported, that's fine by me, but the error message seems wrong. Things get even more confusing: if you change the shapes a bit, the code works. In the following sample I've only changed the shape of indices from (30,) to (40,), and the error disappears:

arr = np.arange(50)
print(arr.shape) # (50,)

indices = jnp.zeros((40,), dtype=int)
print(indices.shape) # (40,)

arr[indices]

I'm running JAX version 0.2.12 on the CPU. What is happening here?

  • It looks like NumPy is treating indices as a tuple when its length is smaller than 32, the maximum number of dimensions. Legacy code did this with some lists, though newer versions are working toward deprecating the behavior. Commented May 28, 2021 at 15:45

1 Answer


This is a long-standing known issue (see https://github.com/google/jax/issues/620); it's not a bug that can be easily fixed by JAX itself, but will require changes to how NumPy treats non-ndarray indices. The good news is that the fix is on the horizon: your problematic code above is accompanied by the following warning, which originates from NumPy:

FutureWarning: Using a non-tuple sequence for multidimensional indexing is
 deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this
 will be interpreted as an array index, `arr[np.array(seq)]`, which will result
 either in an error or a different result.
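The two readings this warning contrasts can be reproduced in plain NumPy, without JAX, using an ordinary Python list. The sketch below is illustrative only (the names arr, seq and picked are arbitrary): treating the 30-element sequence as a tuple gives one index per dimension and fails on a 1-D array with the same "too many indices" error, while treating it as an integer array selects element 0 thirty times.

```python
import numpy as np

arr = np.arange(50)   # 1-D array with 50 elements
seq = [0] * 30        # plain Python sequence of 30 integer indices

# Reading 1: the sequence as a tuple, i.e. one index per dimension.
# A 1-D array has only one dimension, so 30 indices are too many.
try:
    arr[tuple(seq)]
except IndexError as err:
    print("tuple index:", err)

# Reading 2: the sequence as a single integer-array index.
# This picks element 0 thirty times and returns a shape-(30,) array.
picked = arr[np.array(seq)]
print("array index:", picked.shape)
```

The error in the question corresponds to the first reading; the result the asker expected corresponds to the second.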

Once this deprecation cycle is complete, JAX arrays will work correctly in NumPy indexing.

Until then, you can work around it by explicitly calling np.asarray when using JAX arrays to index into NumPy arrays.
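A minimal sketch of that workaround, assuming the same arr and indices as in the question: converting the JAX index array with np.asarray gives NumPy a genuine ndarray, so it is always treated as a single integer-array index regardless of its length.

```python
import numpy as np
import jax.numpy as jnp

arr = np.arange(50)                      # NumPy array being indexed
indices = jnp.zeros((30,), dtype=int)    # JAX array used as the index

# Convert the JAX array to a NumPy ndarray before indexing, so NumPy
# does not fall back to treating it as a sequence of per-dimension indices.
result = arr[np.asarray(indices)]
print(result.shape)  # (30,)
```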
