
I am trying to write a general utility that updates values at given indices in a JAX array whose number of dimensions can differ from one call to the next.

I know I need the `.at[].set()` method, and this is what I have so far:

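```python
import jax.numpy as jnp
import numpy as np

b = jnp.arange(16).reshape([4, 4])  # must be a JAX array: NumPy arrays have no .at
print(b)
update_indices = np.array([[1, 1], [3, 2], [0, 3]])  # one (row, col) pair per point
update_indices = np.moveaxis(update_indices, -1, 0)  # -> one row of indices per axis
b = b.at[update_indices[0], update_indices[1]].set([333, 444, 555])
print(b)
```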

This transforms:

```
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
```

into

```
[[  0   1   2 555]
 [  4 333   6   7]
 [  8   9  10  11]
 [ 12  13 444  15]]
```
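For reference, the `moveaxis` step turns the `(num_points, ndim)` list of coordinates into `(ndim, num_points)`, so row `k` holds the axis-`k` coordinate of every point. A quick check with plain NumPy (the name `per_axis` is just for illustration):

```python
import numpy as np

# (num_points, ndim): each row is one (row, col) coordinate to update.
update_indices = np.array([[1, 1], [3, 2], [0, 3]])

# Moving the last axis to the front gives shape (ndim, num_points).
per_axis = np.moveaxis(update_indices, -1, 0)
print(per_axis.shape)  # (2, 3)
print(per_axis[0])     # row coordinates: [1 3 0]
print(per_axis[1])     # column coordinates: [1 2 3]

# For a 2-D index array this is simply the transpose.
assert np.array_equal(per_axis, update_indices.T)
```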

My problem is that I have had to hard-code the index argument as `update_indices[0], update_indices[1]`. In general, `b` can have any number of dimensions, so this does not generalize: for a 3-D array I would have to write `update_indices[0], update_indices[1], update_indices[2]`.

It would be nice to write something like `b.at[*update_indices]`, but this does not work.

Answer

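This should work: pass the rows of `update_indices` (after the `moveaxis` step) as a tuple. That tuple is exactly the `update_indices[0], update_indices[1], ...` you wrote out by hand, but for any number of dimensions, and the `.set(...)` call stays the same:

```python
b.at[tuple(update_indices)]
```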
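A minimal sketch of the general utility the question asks for, assuming the coordinates arrive as one row per point (the layout before `moveaxis`, shape `(num_points, arr.ndim)`). The helper name `set_at_points` is hypothetical; it just transposes the coordinates and packs them into a tuple before calling `.at[...].set(...)`:

```python
import jax.numpy as jnp
import numpy as np


def set_at_points(arr, points, values):
    """Hypothetical helper: set arr[p] = v for each row p of `points`.

    `points` has shape (num_points, arr.ndim); works for any number of dimensions.
    """
    points = np.asarray(points)
    if points.ndim != 2 or points.shape[1] != arr.ndim:
        raise ValueError("points must have shape (num_points, arr.ndim)")
    # One integer index array per axis, packed into a tuple so that indexing
    # treats them as coordinates instead of as an index into axis 0 only.
    index = tuple(points.T)
    return arr.at[index].set(values)


# 2-D case from the question.
b2 = jnp.arange(16).reshape(4, 4)
b2 = set_at_points(b2, [[1, 1], [3, 2], [0, 3]], jnp.array([333, 444, 555]))
print(b2)

# 3-D case: same call, no hard-coded number of index arrays.
b3 = jnp.zeros((2, 3, 4), dtype=jnp.int32)
b3 = set_at_points(b3, [[0, 1, 2], [1, 2, 3]], jnp.array([7, 9]))
print(b3[0, 1, 2], b3[1, 2, 3])
```

The tuple is the important part: following NumPy's indexing rules, a plain integer array such as `points.T` used directly as the index would select along the first axis only, rather than addressing individual elements.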