I translated this code from MATLAB. It is part of a larger body of code, but I would like some advice on how to vectorize this section to make it faster. My main concern is the for loops and if statements. If possible, I would like to write it without using an if/else statement (JAX is not able to jit an if conditional). Thanks
import numpy as np
num_rows = 5
num_cols = 20
smf = np.array([np.inf, 0.1, 0.1, 0.1, 0.1])
par_init = np.array([1,2,3,4,5])
lb = np.array([0.1, 0.1, 0.1, 0.1, 0.1])
ub = np.array([10, 10, 10, 10, 10])
par = np.broadcast_to(par_init[:,None],(num_rows,num_cols))
print(par.shape)
par0_col = np.zeros(num_rows*num_cols - (num_cols-1) * np.sum(np.isinf(smf)))
lb_col = np.zeros(num_rows*num_cols - (num_cols-1) * np.sum(np.isinf(smf)))
ub_col = np.zeros(num_rows*num_cols- (num_cols-1) * np.sum(np.isinf(smf)))
# First looping
k = 0
for i in range(num_rows):
    smf_1 = smf.copy()
    if smf_1[i] == np.inf:
        par0_col[k] = par[i, 0]
        lb_col[k] = lb[i]
        ub_col[k] = ub[i]
        k = k+1
    else:
        par0_col[k:k+num_cols] = par[i, :num_cols]
        lb_col[k:k+num_cols] = lb[i]
        ub_col[k:k+num_cols] = ub[i]
        k = k+num_cols
arr_1 = np.zeros(shape = (num_rows, num_cols))
arr_2 = np.zeros(shape = (num_rows, num_cols))
par_log = np.log10((par0_col - lb_col) / (1 - par0_col / ub_col))
# second looping
k = 0
for i in range(num_rows):
    smf_1 = smf.copy()
    if np.isinf(smf_1[i]):
        smf_1[i] = 0
        arr_1[i, :] = (par_log[k])
        arr_2[i, :] = 10**par_log[k]
        k = k+1
    else:
        arr_1[i, :] = par_log[k:k+num_cols]
        arr_2[i, :] = 10**par_log[k:k+num_cols]
        k = k+num_cols
# print(arr_1)
# [[0. 0. 0. 0. 0. 0.
# 0. 0. 0. 0. 0. 0.
# 0. 0. 0. 0. 0. 0.
# 0. 0. ]
# [0.37566361 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361
# 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361
# 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361 0.37566361
# 0.37566361 0.37566361]
# [0.61729996 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996
# 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996
# 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996 0.61729996
# 0.61729996 0.61729996]
# [0.81291336 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336
# 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336
# 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336 0.81291336
# 0.81291336 0.81291336]
# [0.99122608 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608
# 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608
# 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608 0.99122608
# 0.99122608 0.99122608]]
... at once. You/we have to spend time understanding the problem, and then visualize working with 'whole' arrays.
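As one way of working with 'whole' arrays here, both loops can be replaced by index arrays that are built once from smf. This is only a sketch under the assumption that smf is fixed and known up front; the helper names is_inf, counts, starts, rows, cols and idx are mine, not from the question. Each row contributes either 1 entry (inf) or num_cols entries to the flat vector, so the per-row counts and their running offsets describe the whole layout, and np.where stands in for the if/else.

```python
import numpy as np

num_rows, num_cols = 5, 20
smf = np.array([np.inf, 0.1, 0.1, 0.1, 0.1])
par_init = np.array([1, 2, 3, 4, 5])
lb = np.array([0.1, 0.1, 0.1, 0.1, 0.1])
ub = np.array([10, 10, 10, 10, 10])
par = np.broadcast_to(par_init[:, None], (num_rows, num_cols))

is_inf = np.isinf(smf)
counts = np.where(is_inf, 1, num_cols)   # entries each row contributes
starts = np.cumsum(counts) - counts      # offset of each row in the flat vector

# First loop: for every flat entry, which (row, col) of par it comes from
rows = np.repeat(np.arange(num_rows), counts)
cols = np.arange(counts.sum()) - starts[rows]
par0_col = par[rows, cols]
lb_col = lb[rows]
ub_col = ub[rows]

par_log = np.log10((par0_col - lb_col) / (1 - par0_col / ub_col))

# Second loop: for every (i, j), which flat entry feeds it.
# Inf rows read the same single entry in every column.
idx = starts[:, None] + np.where(is_inf[:, None], 0, np.arange(num_cols))
arr_1 = par_log[idx]
arr_2 = 10 ** arr_1
```

The index arrays depend only on smf, so for JAX they could be computed once with NumPy outside the jitted function, leaving plain indexing and arithmetic (with jnp in place of np) inside it. I have not timed this against the loops.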