83

Is there a simple/clean way to iterate over an array of axes returned by subplots, like

nrow = ncol = 2
a = []
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
for i, row in enumerate(axs):
    for j, ax in enumerate(row):
        a.append(ax)

for i, ax in enumerate(a):
    ax.set_ylabel(str(i))

but which also works when nrow == 1 or ncol == 1?

I tried list comprehension like:

[element for tupl in tupleOfTuples for element in tupl]

but that fails if nrows == 1 or ncols == 1.

1
  • The solution which also works with a single axis is this one. Commented Dec 15, 2022 at 15:53

6 Answers

109

The axs return value is a numpy array, which can be reshaped, I believe, without any copying of the data. If you use the following, you'll get a linear array that you can iterate over cleanly.

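import matplotlib.pyplot as plt

nrow = 1; ncol = 2
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)

# reshape(-1) gives a 1-D array; enumerate supplies the index i
for i, ax in enumerate(axs.reshape(-1)):
    ax.set_ylabel(str(i))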

This doesn't hold when ncols and nrows are both 1, since the return value is then a single Axes object rather than an array; you could turn the return value into an array with one element for consistency, though it feels a bit like a kludge:

import numpy as np

nrow = 1; ncol = 1
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
axs = np.array(axs)  # wrap the single Axes in an array

for i, ax in enumerate(axs.reshape(-1)):
    ax.set_ylabel(str(i))

See the reshape docs: the argument -1 causes reshape to infer the dimension of the output.
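As a rough sketch of the wrapping idea above, the np.array(...).reshape(-1) step can be put into a small function and tried against several grid shapes. The name flat_axes is hypothetical (it is not part of matplotlib or numpy), and the Agg backend is chosen here only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; an assumption so this runs headless
import matplotlib.pyplot as plt
import numpy as np


def flat_axes(axs):
    """Hypothetical helper: 1-D object array of Axes for any plt.subplots() result."""
    # np.array() wraps a lone Axes in a 0-d object array (and copies the
    # references if axs is already an array); reshape(-1) makes it 1-D.
    return np.array(axs).reshape(-1)


counts = {}
for nrow, ncol in [(1, 1), (1, 2), (2, 1), (2, 2)]:
    fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
    flat = flat_axes(axs)
    for i, ax in enumerate(flat):
        ax.set_ylabel(str(i))
    counts[(nrow, ncol)] = len(flat)
    plt.close(fig)  # free the figure once we are done with it
```

The same loop body is used for every shape, which is the point of normalizing the return value first.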


3 Comments

For nrow=ncol=1 you can use squeeze=0. plt.subplots(nrows=nrow, ncols=ncol, squeeze=0) always returns a two-dimensional array for the axes, even if both are one.
This suggestion will not work because axs is not a numpy array in the case nrow=ncol=1. squeeze=0 works!
I had no luck with squeeze=, but note that in 2024 it's a boolean, so squeeze=False would be the "correct" way to use the argument.
76

The fig returned by plt.subplots exposes a list of all of its axes as fig.axes. To iterate over all the subplots in a figure you can use:

nrow = 2
ncol = 2
fig, axs = plt.subplots(nrow, ncol)
for i, ax in enumerate(fig.axes):
    ax.set_ylabel(str(i))

This also works for nrow == ncol == 1.
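One thing worth keeping in mind, sketched below under the assumption of a headless Agg backend: fig.axes lists every Axes attached to the figure, not only the subplot grid. The twin axis here is added purely for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; an assumption so this runs headless
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)

# Right after creation, fig.axes holds the grid Axes in row-major order.
same_order = all(a is b for a, b in zip(fig.axes, axs.flat))

# Any Axes added later (here a twin y-axis) also shows up in fig.axes,
# but not in the axs array returned by plt.subplots.
twin = axs[0, 0].twinx()
n_grid = axs.size
n_all = len(fig.axes)

for i, ax in enumerate(fig.axes):
    ax.set_ylabel(str(i))
```

So fig.axes is the simplest option when the figure contains only the subplot grid, while the axs array is the safer choice once extra axes have been added.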

4 Comments

so we don't need the axs then?
this is so simple and so useful!
This should be the selected answer.
So... what does fig.axes have/do that axs doesn't have/do?
24

I am not sure when it was added, but there is now a squeeze keyword argument. This makes sure the result is always a 2D numpy array. Turning that into a 1D array is easy:

fig, ax2d = plt.subplots(2, 2, squeeze=False)
axli = ax2d.flatten()  # 1-D array of the Axes

Works for any number of subplots, no trick for single ax, so a little easier than the accepted answer (perhaps squeeze didn't exist yet back then).
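A quick way to check this behaviour is to record the array shapes for a few grid sizes. This is only a sketch; the Agg backend is assumed so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; an assumption so this runs headless
import matplotlib.pyplot as plt

shapes = {}
for nrow, ncol in [(1, 1), (1, 3), (3, 1), (2, 2)]:
    fig, ax2d = plt.subplots(nrow, ncol, squeeze=False)
    # With squeeze=False the result is always 2-D, so flatten() always works.
    shapes[(nrow, ncol)] = (ax2d.shape, ax2d.flatten().shape)
    plt.close(fig)
```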

2 Comments

did you mean ax2d.flatten() in your second line? Otherwise it's unclear what ax1d is referencing.
@notlink Ow yeah that makes more sense
22

You can use numpy's .flat attribute on the array of axes.

Here is an example

fig, axes = plt.subplots(2, 3)
for ax in axes.flat:
    pass  # do something with this instance of 'ax'

2 Comments

.flat is an attribute of the numpy array and has nothing to do with matplotlib.
@ImportanceOfBeingErnest Thanks for correcting my answer. I got confused.
13

TL;DR: axes.flat is the most pythonic way of iterating through axes

As others have pointed out, plt.subplots() returns the axes as a numpy array of Axes objects (when there is more than one subplot), so numpy's built-in ways of flattening an array all apply. Of those options, axes.flat is the least verbose. Furthermore, axes.flatten() returns a copy of the array, whereas axes.flat returns an iterator over the existing array, so axes.flat avoids allocating a new array.

Stealing @Sukjun-Kim's example:

fig, axes = plt.subplots(2, 3)
for ax in axes.flat:
    pass  # do something with this instance of 'ax'

Sources: numpy.ndarray.flat docs; Matplotlib tutorial.
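The copy-versus-iterator distinction can be checked directly. The sketch below assumes a headless Agg backend and uses np.shares_memory to compare the flattened results with the original array; for an array of a handful of Axes the cost difference is negligible, so this is about what the calls return rather than about measurable speed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; an assumption so this runs headless
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 3)

# axes.flat is a numpy.flatiter over the existing array; no new array is built.
is_flatiter = isinstance(axes.flat, np.flatiter)

# flatten() always allocates a new array; reshape(-1) returns a view here.
copy_shares = np.shares_memory(axes, axes.flatten())
view_shares = np.shares_memory(axes, axes.reshape(-1))

# Either way, the elements are the very same Axes objects.
same_objects = all(a is b for a, b in zip(axes.flat, axes.flatten()))
```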


3

Here is a good practice.
For example, if we need a four-by-four grid of subplots, we can set it up as below:

rows = 4; cols = 4
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(20, 16), squeeze=False, sharex=True, sharey=True)
axes = np.array(axes)

for i, ax in enumerate(axes.reshape(-1)):
  ax.set_ylabel(f'Subplot: {i}')

The output is beautiful and clear.

