No, I don't believe this is an optimization that the compiler does. I'm basing this on the fact that XLA's computational model requires all array shapes to be known at compile-time, and the values in idx are not known until runtime.
If you're not convinced and want to see for yourself what the compiler is doing with this code, you can use JAX's ahead-of-time (AOT) compilation APIs to peek at the compiled HLO produced by XLA for this code (note that the compiler will perform different optimizations on different hardware).
For example:
print(my_function.lower(5).compile().as_text())
HloModule jit_my_function, is_scheduled=true, entry_computation_layout={()->f32[]}, allow_spmd_sharding_propagation_to_output={true}
%region_0.1 (reduce_sum.3: f32[], reduce_sum.4: f32[]) -> f32[] {
%reduce_sum.3 = f32[] parameter(0), metadata={op_name="reduce_sum"}
%reduce_sum.4 = f32[] parameter(1), metadata={op_name="reduce_sum"}
ROOT %reduce_sum.5 = f32[] add(%reduce_sum.3, %reduce_sum.4), metadata={op_name="jit(my_function)/reduce_sum" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=11 source_end_column=43}
}
%region_0.1.clone (reduce_sum.0: f32[], reduce_sum.1: f32[]) -> f32[] {
%reduce_sum.0 = f32[] parameter(0), metadata={op_name="reduce_sum"}
%reduce_sum.1 = f32[] parameter(1), metadata={op_name="reduce_sum"}
ROOT %reduce_sum.2 = f32[] add(%reduce_sum.0, %reduce_sum.1), metadata={op_name="jit(my_function)/reduce_sum" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=11 source_end_column=43}
}
%fused_computation () -> f32[1,25,100,100] {
%constant.2 = f32[] constant(1)
%broadcast_in_dim.0 = f32[5,100,100]{2,1,0} broadcast(%constant.2), dimensions={}, metadata={op_name="jit(my_function)/broadcast_in_dim" source_file="/tmp/ipython-input-3455249338.py" source_line=10 source_end_line=10 source_column=8 source_end_column=42}
%iota.5 = s32[1,1,5,5]{3,2,1,0} iota(), iota_dimension=3, metadata={op_name="jit(my_function)/broadcast_in_dim" source_file="/tmp/ipython-input-3455249338.py" source_line=8 source_end_line=8 source_column=10 source_end_column=50}
%bitcast.5 = s32[1,25]{1,0} bitcast(%iota.5), metadata={op_name="jit(my_function)/broadcast_in_dim" source_file="/tmp/ipython-input-3455249338.py" source_line=8 source_end_line=8 source_column=10 source_end_column=50}
%constant.1 = s32[] constant(0)
%broadcast.4 = s32[1,25]{1,0} broadcast(%constant.1), dimensions={}
%lt.0 = pred[1,25]{1,0} compare(%bitcast.5, %broadcast.4), direction=LT, metadata={op_name="jit(my_function)/lt" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
%constant.0 = s32[] constant(5)
%broadcast.1 = s32[1,25]{1,0} broadcast(%constant.0), dimensions={}
%add.0 = s32[1,25]{1,0} add(%bitcast.5, %broadcast.1), metadata={op_name="jit(my_function)/add" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
%select_n.0 = s32[1,25]{1,0} select(%lt.0, %add.0, %bitcast.5), metadata={op_name="jit(my_function)/select_n" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
%bitcast.4 = s32[25,1]{1,0} bitcast(%select_n.0), metadata={op_name="jit(my_function)/select_n" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
%gather.2 = f32[25,1,100,100]{3,2,1,0} gather(%broadcast_in_dim.0, %bitcast.4), offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,100,100}, metadata={op_name="jit(my_function)/gather" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
ROOT %bitcast.3 = f32[1,25,100,100]{3,2,1,0} bitcast(%gather.2), metadata={op_name="jit(my_function)/gather" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}
}
ENTRY %main.2 () -> f32[] {
%constant.7 = f32[] constant(0)
%gather_bitcast_fusion = f32[1,25,100,100]{3,2,1,0} fusion(), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(my_function)/gather" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=36 source_end_column=42}, backend_config={"outer_dimension_partitions":["1","2"]}
%reduce-window = f32[1,1,4,4]{3,2,1,0} reduce-window(%gather_bitcast_fusion, %constant.7), window={size=1x25x32x32 stride=1x25x32x32 pad=0_0x0_0x14_14x14_14}, to_apply=%region_0.1, backend_config={"outer_dimension_partitions":["1","1","2"]}
ROOT %reduce_sum.7 = f32[] reduce(%reduce-window, %constant.7), dimensions={0,1,2,3}, to_apply=%region_0.1.clone, metadata={op_name="jit(my_function)/reduce_sum" source_file="/tmp/ipython-input-3455249338.py" source_line=12 source_end_line=12 source_column=11 source_end_column=43}
}
Reading this output takes some practice, but the relevant piece to answer your question is the line that starts with %gather.2 = f32[25,1,100,100]{2,1,0}: the gather primitive is XLA's version of indexing, and you see that it's explicitly constructing the full 25×1×100×100 result — that is, 25 copies of the 100×100 slice — and not deduplicating the repeated indices.