Consider the example code:

from functools import partial
from jax import jit
import jax.numpy as jnp

@partial(jit, static_argnums=(0,))
def my_function(n):
    idx = jnp.tile(jnp.arange(n, dtype=int), (1, n))  # Contains duplicate indices
    A = jnp.ones((n**2,), dtype=float)
    B = jnp.ones((n, 100, 100), dtype=float)

    return jnp.sum(A[..., None, None] * B[idx])  # Will data in B be duplicated in memory here?

my_function(5)

When compiling B[idx], will JAX's compilation recognize that there are duplicate indices and thereby avoid unnecessarily duplicating the data in B?

I suspect probably not, because the indices are value-dependent in general, but I just want to understand this better.

1 Answer
No, I don't believe this is an optimization that the compiler does. I'm basing this on the fact that XLA's computational model requires all array shapes to be known at compile-time, and the values in idx are not known until runtime.
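To make that shape constraint concrete, here is a small illustrative sketch (my own, not part of the question) of why value-dependent deduplication can't be expressed with a static shape under jit: the output length of jnp.unique depends on the data, so inside jit it only works if you pass a static size:

import jax
import jax.numpy as jnp

@jax.jit
def dedup(idx):
    # The number of unique values depends on the data in idx, so the
    # output shape cannot be known when this function is traced/compiled.
    return jnp.unique(idx)

try:
    dedup(jnp.array([0, 1, 1, 2]))
except Exception as e:
    print(type(e).__name__)  # jnp.unique without a static size fails under jit

# Passing a static size restores a compile-time-known shape, at the cost of padding:
@jax.jit
def dedup_padded(idx):
    return jnp.unique(idx, size=4, fill_value=-1)

print(dedup_padded(jnp.array([0, 1, 1, 2])))  # [ 0  1  2 -1]

The same constraint is why the compiler can't shrink B[idx] based on which indices happen to repeat at runtime.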

If you're not convinced and want to see for yourself what the compiler is doing, you can use JAX's ahead-of-time compilation APIs to peek at the compiled HLO that XLA produces for this code (note that the compiler performs different optimizations on different hardware).

For example:

print(my_function.lower(5).compile().as_text())

HloModule jit_my_function, is_scheduled=true, entry_computation_layout={()->f32[]}, allow_spmd_sharding_propagation_to_output={true}

%region_0.1 (reduce_sum.3: f32[], reduce_sum.4: f32[]) -> f32[] {
  %reduce_sum.3 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  %reduce_sum.4 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.5 = f32[] add(%reduce_sum.3, %reduce_sum.4), metadata={op_name="jit(my_function)/reduce_sum" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=11 source_end_column=43}
}

%region_0.1.clone (reduce_sum.0: f32[], reduce_sum.1: f32[]) -> f32[] {
  %reduce_sum.0 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  %reduce_sum.1 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.2 = f32[] add(%reduce_sum.0, %reduce_sum.1), metadata={op_name="jit(my_function)/reduce_sum" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=11 source_end_column=43}
}

%fused_computation () -> f32[1,25,100,100] {
  %constant.2 = f32[] constant(1)
  %broadcast_in_dim.0 = f32[5,100,100]{2,1,0} broadcast(%constant.2), dimensions={}, metadata={op_name="jit(my_function)/broadcast_in_dim" source_file="/tmp/ipython-input-3455249338.py" source_line=10 source_end_line=10 source_column=8 source_end_column=42}
  %iota.5 = s32[1,1,5,5]{3,2,1,0} iota(), iota_dimension=3, metadata={op_name="jit(my_function)/broadcast_in_dim" source_file="/tmp/ipython-input-3455249338.py" source_line=8 source_end_line=8 source_column=10 source_end_column=50}
  %bitcast.5 = s32[1,25]{1,0} bitcast(%iota.5), metadata={op_name="jit(my_function)/broadcast_in_dim" source_file="/tmp/ipython-input-3455249338.py" source_line=8 source_end_line=8 source_column=10 source_end_column=50}
  %constant.1 = s32[] constant(0)
  %broadcast.4 = s32[1,25]{1,0} broadcast(%constant.1), dimensions={}
  %lt.0 = pred[1,25]{1,0} compare(%bitcast.5, %broadcast.4), direction=LT, metadata={op_name="jit(my_function)/lt" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
  %constant.0 = s32[] constant(5)
  %broadcast.1 = s32[1,25]{1,0} broadcast(%constant.0), dimensions={}
  %add.0 = s32[1,25]{1,0} add(%bitcast.5, %broadcast.1), metadata={op_name="jit(my_function)/add" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
  %select_n.0 = s32[1,25]{1,0} select(%lt.0, %add.0, %bitcast.5), metadata={op_name="jit(my_function)/select_n" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
  %bitcast.4 = s32[25,1]{1,0} bitcast(%select_n.0), metadata={op_name="jit(my_function)/select_n" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
  %gather.2 = f32[25,1,100,100]{3,2,1,0} gather(%broadcast_in_dim.0, %bitcast.4), offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,100,100}, metadata={op_name="jit(my_function)/gather" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
  ROOT %bitcast.3 = f32[1,25,100,100]{3,2,1,0} bitcast(%gather.2), metadata={op_name="jit(my_function)/gather" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
}

ENTRY %main.2 () -> f32[] {
  %constant.7 = f32[] constant(0)
  %gather_bitcast_fusion = f32[1,25,100,100]{3,2,1,0} fusion(), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(my_function)/gather" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}, backend_config={"outer_dimension_partitions":["1","2"]}
  %reduce-window = f32[1,1,4,4]{3,2,1,0} reduce-window(%gather_bitcast_fusion, %constant.7), window={size=1x25x32x32 stride=1x25x32x32 pad=0_0x0_0x14_14x14_14}, to_apply=%region_0.1, backend_config={"outer_dimension_partitions":["1","1","2"]}
  ROOT %reduce_sum.7 = f32[] reduce(%reduce-window, %constant.7), dimensions={0,1,2,3}, to_apply=%region_0.1.clone, metadata={op_name="jit(my_function)/reduce_sum" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=11 source_end_column=43}
}

Reading this output takes some practice, but the relevant piece for your question is the line that starts with %gather.2 = f32[25,1,100,100]{3,2,1,0}: the gather primitive is XLA's version of indexing, and you can see that it explicitly materializes all 25 indexed 100×100 slices of B, without deduplicating the repeated indices.
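That said, if materializing the full gathered copy of B is a problem in practice, you can restructure the computation yourself. Below is a sketch (my own rewrite, not something the compiler derives for you) that folds the coefficients in A onto each unique row of B with jax.ops.segment_sum, so B is never gathered with duplicate indices:

from functools import partial
import jax
from jax import jit
import jax.numpy as jnp

@partial(jit, static_argnums=(0,))
def my_function_no_gather(n):
    idx = jnp.tile(jnp.arange(n, dtype=int), (1, n))
    A = jnp.ones((n**2,), dtype=float)
    B = jnp.ones((n, 100, 100), dtype=float)

    # Sum the A coefficients that land on each row of B. num_segments=n is
    # allowed here because n is a static argument to jit.
    coeff = jax.ops.segment_sum(A, idx.ravel(), num_segments=n)  # shape (n,)

    # Mathematically the same as jnp.sum(A[..., None, None] * B[idx]), but each
    # row of B is read once and no (n**2, 100, 100)-sized intermediate is built.
    return jnp.sum(coeff[:, None, None] * B)

print(my_function_no_gather(5))  # 250000.0, matching my_function(5) from the question

Inspecting the compiled HLO of this version in the same way should show that the large gather of B is gone; the only index-dependent work left is the scatter-style reduction that builds the length-n coefficient vector.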
