2

I'm trying to create a Dash application that displays a grid of subplots to visualize the pairwise comparison of the columns of a dataframe. To the top and left of each grid row and column will be the corresponding variables. The variable names can be quite long though, so it's easy to misalign them. I've tried, staggering the variable names, but eventually settled on line-wrapping them. See the picture below. I've also included my code at the end of this post

df = pd.DataFrame({
    "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA": ["1", "2", "3", "4"],
    "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB": ["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"],
    "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC": ["cat", "dog", "cat", "mouse"],
    "DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD": ["10.5", "20.3", "30.1", "40.2"],
    'EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE': ['apple', 'apple', 'apple', 'banana']
})

For this dataframe, I'd like to get something like

enter image description here

As you can see, I'm having trouble aligning the row and column labels of the grid. Here is my code

import dash
from dash import dcc, html
import pandas as pd
import plotly.express as px
import plotly.subplots as sp
import numpy as np
import plotly.graph_objects as go


# Sample DataFrame
df = pd.DataFrame({
    "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA": ["1", "2", "3", "4"],
    "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB": ["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04"],
    "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC": ["cat", "dog", "cat", "mouse"],
    "DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD": ["10.5", "20.3", "30.1", "40.2"],
    'EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE': ['apple', 'apple', 'apple', 'banana']
})

# Convert data types
def convert_dtypes(df):
    for col in df.columns:
        try:
            df[col] = pd.to_numeric(df[col])  # Convert to int/float
        except ValueError:
            try:
                df[col] = pd.to_datetime(df[col])  # Convert to datetime
            except ValueError:
                df[col] = df[col].astype("string")  # Keep as string
    return df

df = convert_dtypes(df)
columns = df.columns
num_cols = len(columns)

# Dash App
app = dash.Dash(__name__)

app.layout = html.Div([
    html.H1("Pairwise Column Plots"),
    dcc.Graph(id='grid-plots')
])

@app.callback(
    dash.Output('grid-plots', 'figure'),
    dash.Input('grid-plots', 'id')  # Dummy input to trigger callback
)
def create_plot_grid(_):
    fig = sp.make_subplots(rows = num_cols, cols = num_cols, 
                           #subplot_titles = [f"{x} vs {y}" for x in columns for y in columns],
                           shared_xaxes = False, shared_yaxes = False)

    annotations = []  # Store subplot titles dynamically
    # Add column labels (Top Labels)
    for j, col_label in enumerate(columns):
        annotations.append(
            dict(
                #text=f"<b>{col_label}</b>",  # Bold for emphasis
                text=f"<b>{'<br>'.join(col_label[x:x+10] for x in range(0, len(col_label), 10))}</b>",
                xref = "paper", yref = "paper",
                x = (j) / num_cols,  # Center over the column
                y = 1.02,  # Slightly above the top row
                showarrow = False,
                font = dict(size = 14, color = "black")
            )
        )
    # Add row labels (Side Labels)
    for i, row_label in enumerate(columns):
        annotations.append(
            dict(
                #text = f"<b>{row_label}</b>",  # Bold for emphasis
                text=f"<b>{'<br>'.join(row_label[x:x+10] for x in range(0, len(row_label), 10))}</b>",
                xref = "paper", yref = "paper",
                x = -0.02,  # Slightly to the left of the row
                y = (1 - (i + 0.5) / num_cols),  # Center next to the row
                showarrow = False,
                font = dict(size = 14, color = "black"),
                textangle = -90  # Rotate text for vertical orientation
            )
        )

    print(annotations)

    for i, x_col in enumerate(columns):
        for j, y_col in enumerate(columns):
            dtype_x, dtype_y = df[x_col].dtype, df[y_col].dtype
            row, col = i + 1, j + 1  # Adjust for 1-based indexing

            # I only want to print the upper triangle of the grid
            if j <= i:
                trace = None

            # Numeric vs Numeric: Scatter Plot
            elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y):
                trace = px.scatter(df, x = x_col, y = y_col).data[0]

            # Numeric vs Categorical: Box Plot
            elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_string_dtype(dtype_y):
                trace = px.box(df, x = y_col, y = x_col).data[0]
            elif pd.api.types.is_string_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y):
                trace = px.box(df, x = x_col, y = y_col).data[0]

            # Categorical vs Categorical: Count Heatmap
            elif pd.api.types.is_string_dtype(dtype_x) and pd.api.types.is_string_dtype(dtype_y):
                #trace = px.histogram(df, x = x_col, color = y_col, barmode = "group").data[0]
                counts_df = (
                    df
                    .groupby([x_col, y_col])
                    .size()
                    .reset_index(name = 'count')
                    .pivot_table(index = x_col, columns = y_col, values = "count", aggfunc="sum")  
                ) 
                trace = go.Heatmap(z = counts_df.values, x = counts_df.columns, y = counts_df.index, showscale = False)

            # Datetime vs Numeric: Line Plot
            elif pd.api.types.is_datetime64_any_dtype(dtype_x) and pd.api.types.is_numeric_dtype(dtype_y):
                trace = px.line(df, x = x_col, y = y_col).data[0]
            elif pd.api.types.is_numeric_dtype(dtype_x) and pd.api.types.is_datetime64_any_dtype(dtype_y):
                trace = px.line(df, x = y_col, y = x_col).data[0]

            else:
                trace = None  # Unsupported combination

            if trace:
                fig.add_trace(trace, row = row, col = col)

    fig.update_layout(height = 300 * num_cols, 
                      width = 300 * num_cols, 
                      showlegend = False,
                      annotations = annotations)
    print(fig['layout'])
    return fig

if __name__ == '__main__':
    app.run_server(debug = True)
2
  • Are you open to font size changes or adjusting columns names as work arounds? Commented Mar 21 at 0:20
  • Sorry for the late response. I'm fine with font-size changes, but would rather keep the column names as-is, if possible Commented Mar 21 at 12:48

0

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.