
This is a follow-up to my previous question. I am implementing a Parameterized Quantum Circuit as a Quantum Neural Network, where the optimization loop is jitted. There is no error and everything works fine, but I see very unusual behavior in the execution times.

Check out the code below:

Setting - 1

import pennylane as qml
from pennylane import numpy as np
import jax
from jax import numpy as jnp
import optax
from itertools import combinations
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import log_loss
import matplotlib.pyplot as plt
import matplotlib.colors
import warnings
warnings.filterwarnings("ignore")
np.random.seed(42)
import time

# Load the digits dataset with features (X_digits) and labels (y_digits)
X_digits, y_digits = load_digits(return_X_y=True)

# Create a boolean mask to filter out only the samples where the label is 2 or 6
filter_mask = np.isin(y_digits, [2, 6])

# Apply the filter mask to the features and labels to keep only the selected digits
X_digits = X_digits[filter_mask]
y_digits = y_digits[filter_mask]

# Split the filtered dataset into training and testing sets with 10% of data reserved for testing
X_train, X_test, y_train, y_test = train_test_split(
    X_digits, y_digits, test_size=0.1, random_state=42
)

# Normalize the pixel values in the training and testing data
# Convert each image from a 1D array to an 8x8 2D array, normalize pixel values, and scale them
X_train = np.array([thing.reshape([8, 8]) / 16 * 2 * np.pi for thing in X_train])
X_test = np.array([thing.reshape([8, 8]) / 16 * 2 * np.pi for thing in X_test])

# Adjust the labels to be centered around 0 and scaled to be in the range -1 to 1
# The original labels (2 and 6) are mapped to -1 and 1 respectively
y_train = (y_train - 4) / 2
y_test = (y_test - 4) / 2


def feature_map(features):
    # Apply Hadamard gates to all qubits to create an equal superposition state
    for i in range(len(features[0])):
        qml.Hadamard(i)

    # Apply angle embeddings based on the feature values
    for i in range(len(features)):
        # For odd-indexed features, use Z-rotation in the angle embedding
        if i % 2:
            qml.AngleEmbedding(features=features[i], wires=range(8), rotation="Z")
        # For even-indexed features, use X-rotation in the angle embedding
        else:
            qml.AngleEmbedding(features=features[i], wires=range(8), rotation="X")

# Define the ansatz (quantum circuit ansatz) for parameterized quantum operations
def ansatz(params):
    # Apply RY rotations with the first set of parameters
    for i in range(8):
        qml.RY(params[i], wires=i)

    # Apply CNOT gates with adjacent qubits (cyclically connected) to create entanglement
    for i in range(8):
        qml.CNOT(wires=[(i - 1) % 8, (i) % 8])

    # Apply RY rotations with the second set of parameters
    for i in range(8):
        qml.RY(params[i + 8], wires=i)

    # Apply CNOT gates with qubits in reverse order (cyclically connected)
    # to create additional entanglement
    for i in range(8):
        qml.CNOT(wires=[(8 - 2 - i) % 8, (8 - i - 1) % 8])



dev = qml.device("default.qubit", wires=8)


@qml.qnode(dev)
def circuit(params, features):
    feature_map(features)
    ansatz(params)
    return qml.expval(qml.PauliZ(0))


def variational_classifier(weights, bias, x):
    return circuit(weights, x) + bias


def square_loss(labels, predictions):
    return np.mean((labels - qml.math.stack(predictions)) ** 2)


def accuracy(labels, predictions):
    acc = sum([np.sign(l) == np.sign(p) for l, p in zip(labels, predictions)])
    acc = acc / len(labels)
    return acc


def cost(params, X, Y):
    predictions = [variational_classifier(params["weights"], params["bias"], x) for x in X]
    return square_loss(Y, predictions)


def acc(params, X, Y):
    predictions = [variational_classifier(params["weights"], params["bias"], x) for x in X]
    return accuracy(Y, predictions)


np.random.seed(0)
weights = 0.01 * np.random.randn(16)
bias = jnp.array(0.0)
params = {"weights": weights, "bias": bias}
opt = optax.adam(0.05)
batch_size = 7
num_batch = X_train.shape[0] // batch_size
opt_state = opt.init(params)
X_batched = X_train.reshape([-1, batch_size, 8, 8])
y_batched = y_train.reshape([-1, batch_size])


@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no, print_training = args
    _data = data[batch_no % num_batch]
    _targets = targets[batch_no % num_batch]
    train_loss, grads = jax.value_and_grad(cost)(params, _data, _targets)
    updates, opt_state = opt.update(grads, opt_state)
    test_loss, grads = jax.value_and_grad(cost)(params, X_test, y_test)
    params = optax.apply_updates(params, updates)

    # Print training loss every step if print_training is True
    def print_fn():
        jax.debug.print("Step: {i}, Train Loss: {train_loss}", i=i, train_loss=train_loss)
        jax.debug.print("Step: {i}, Test Loss: {test_loss}", i=i, test_loss=test_loss)

    jax.lax.cond((jnp.mod(i, 1) == 0) & print_training, print_fn, lambda: None)
    return (params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no + 1, print_training)


@jax.jit
def optimization_jit(params, data, targets, X_test, y_test, X_train, y_train, print_training = True):
    opt_state = opt.init(params)
    args = (params, opt_state, data, targets, X_test, y_test, X_train, y_train, 0, print_training)
    (params, _, _, _, _, _, _, _, _, _) = jax.lax.fori_loop(0, 1, update_step_jit, args)
    return params


start_time = time.time()
params = optimization_jit(params, X_batched, y_batched, X_test, y_test, X_train, y_train)
print("Training Done! \nTime taken:",time.time() - start_time)

start_time = time.time()
var_train_acc = acc(params, X_train, y_train)
print("Training accuracy: ", var_train_acc)
print("Time taken:",time.time() - start_time)

start_time = time.time()
var_test_acc = acc(params, X_test, y_test)
print("Testing accuracy: ", var_test_acc)
print("Time taken:",time.time() - start_time)

Notice that the jax.lax.fori_loop here runs just 1 time.

To check reproducibility, I ran it 3 times; the outputs are as follows.

Output of first run:

Training Done! 
Time taken: 66.26599097251892
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Training accuracy:  0.5031055900621118
Time taken: 14.183394193649292
Testing accuracy:  0.5277777777777778
Time taken: 1.552431344985962

Output of second run:

Training Done! 
Time taken: 62.8515682220459
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Training accuracy:  0.5031055900621118
Time taken: 13.549866199493408
Testing accuracy:  0.5277777777777778
Time taken: 1.5097148418426514

Output of third run:

Training Done! 
Time taken: 63.35235905647278
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Training accuracy:  0.5031055900621118
Time taken: 13.52238941192627
Testing accuracy:  0.5277777777777778
Time taken: 1.5074975490570068

Setting - 2

So I then ran it with the jax.lax.fori_loop changed to run 10 times:

    (params, _, _, _, _, _, _, _, _, _) = jax.lax.fori_loop(0, 10, update_step_jit, args)

Surprisingly, the execution time drops quite significantly; the outputs are:

First Run:

Training Done! 
Time taken: 49.8694589138031
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Step: 1, Train Loss: 0.934578537940979
Step: 1, Test Loss: 0.9935969114303589
Step: 2, Train Loss: 0.982826828956604
Step: 2, Test Loss: 1.004722237586975
Step: 3, Train Loss: 0.982965350151062
Step: 3, Test Loss: 1.0281261205673218
Step: 4, Train Loss: 1.1700845956802368
Step: 4, Test Loss: 1.0455362796783447
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0411475896835327
Step: 6, Train Loss: 1.2408322095870972
Step: 6, Test Loss: 1.0204349756240845
Step: 7, Train Loss: 0.7292405366897583
Step: 7, Test Loss: 0.9959328770637512
Step: 8, Train Loss: 1.1697252988815308
Step: 8, Test Loss: 0.9822244644165039
Step: 9, Train Loss: 1.015731692314148
Step: 9, Test Loss: 0.9667297005653381
Training accuracy:  0.5217391304347826
Time taken: 13.903431177139282
Testing accuracy:  0.5555555555555556
Time taken: 1.537736177444458

Second run:

Training Done! 
Time taken: 56.34339928627014
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Step: 1, Train Loss: 0.934578537940979
Step: 1, Test Loss: 0.9935969114303589
Step: 2, Train Loss: 0.982826828956604
Step: 2, Test Loss: 1.004722237586975
Step: 3, Train Loss: 0.982965350151062
Step: 3, Test Loss: 1.0281261205673218
Step: 4, Train Loss: 1.1700845956802368
Step: 4, Test Loss: 1.0455362796783447
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0411475896835327
Step: 6, Train Loss: 1.2408322095870972
Step: 6, Test Loss: 1.0204349756240845
Step: 7, Train Loss: 0.7292405366897583
Step: 7, Test Loss: 0.9959328770637512
Step: 8, Train Loss: 1.1697252988815308
Step: 8, Test Loss: 0.9822244644165039
Step: 9, Train Loss: 1.015731692314148
Step: 9, Test Loss: 0.9667297005653381
Training accuracy:  0.5217391304347826
Time taken: 13.298640727996826
Testing accuracy:  0.5555555555555556
Time taken: 1.4631397724151611

Third run:

Training Done! 
Time taken: 53.01019215583801
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Step: 1, Train Loss: 0.934578537940979
Step: 1, Test Loss: 0.9935969114303589
Step: 2, Train Loss: 0.982826828956604
Step: 2, Test Loss: 1.004722237586975
Step: 3, Train Loss: 0.982965350151062
Step: 3, Test Loss: 1.0281261205673218
Step: 4, Train Loss: 1.1700845956802368
Step: 4, Test Loss: 1.0455362796783447
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0411475896835327
Step: 6, Train Loss: 1.2408322095870972
Step: 6, Test Loss: 1.0204349756240845
Step: 7, Train Loss: 0.7292405366897583
Step: 7, Test Loss: 0.9959328770637512
Step: 8, Train Loss: 1.1697252988815308
Step: 8, Test Loss: 0.9822244644165039
Step: 9, Train Loss: 1.015731692314148
Step: 9, Test Loss: 0.9667297005653381
Training accuracy:  0.5217391304347826
Time taken: 13.152780055999756
Testing accuracy:  0.5555555555555556
Time taken: 1.4448845386505127

Setting - 3

Furthermore, I wanted to reduce the logging and compute/log the test_loss only every 5th step, so I updated the code to:

@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no, print_training = args
    _data = data[batch_no % num_batch]
    _targets = targets[batch_no % num_batch]
    train_loss, grads = jax.value_and_grad(cost)(params, _data, _targets)
    updates, opt_state = opt.update(grads, opt_state)
    # train_accuracy, grads = jax.value_and_grad(acc)(params, X_train, y_train)
    # test_accuracy, grads = jax.value_and_grad(acc)(params, X_test, y_test)
    params = optax.apply_updates(params, updates)

    # Print training loss every 5 steps if print_training is True
    def print_fn():
        test_loss, grads = jax.value_and_grad(cost)(params, X_test, y_test)
        jax.debug.print("Step: {i}, Train Loss: {train_loss}", i=i, train_loss=train_loss)
        # jax.debug.print("Step: {i}, Train Accuracy: {train_accuracy}", i=i, train_accuracy=train_accuracy)
        jax.debug.print("Step: {i}, Test Loss: {test_loss}", i=i, test_loss=test_loss)
        # jax.debug.print("Step: {i}, Test Accuracy: {test_accuracy}", i=i, test_accuracy=test_accuracy)

    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)
    return (params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no + 1, print_training)


@jax.jit
def optimization_jit(params, data, targets, X_test, y_test, X_train, y_train, print_training = True):
    opt_state = opt.init(params)
    args = (params, opt_state, data, targets, X_test, y_test, X_train, y_train, 0, print_training)
    (params, _, _, _, _, _, _, _, _, _) = jax.lax.fori_loop(0, 10, update_step_jit, args)
    return params

I thought calling print_fn fewer times would result in an even shorter runtime, but no; the outputs were:

First Run:

Training Done! 
Time taken: 75.2902774810791
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 0.9935969114303589
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0204349756240845
Training accuracy:  0.5217391304347826
Time taken: 13.591582536697388
Testing accuracy:  0.5555555555555556
Time taken: 1.6048238277435303

Second run:

Training Done! 
Time taken: 86.21267819404602
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 0.9935969114303589
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0204349756240845
Training accuracy:  0.5217391304347826
Time taken: 13.666489601135254
Testing accuracy:  0.5555555555555556
Time taken: 1.5537452697753906

Third run:

Training Done! 
Time taken: 90.7916328907013
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 0.9935969114303589
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0204349756240845
Training accuracy:  0.5217391304347826
Time taken: 13.21641230583191
Testing accuracy:  0.5555555555555556
Time taken: 1.5349321365356445

The runtimes for the different settings can be plotted as:

[Plot: runtimes of the three runs under each of the three settings]

My questions are:

  • Why is Setting - 1, where the optimization loop is run only once, consistently taking more time than Setting - 2, where the optimization loop is run 10 times?
  • Why is Setting - 3, where the print_fn function in the optimization loop is called every 5th optimization step, consistently taking more time than Setting - 2, where the print_fn is being called on every iteration?

1 Answer

The compiler works in mysterious ways!

I suspect the difference here is that in the case of a length-1 fori_loop, the compiler optimizes away the scan entirely; for example:

print(jax.jit(lambda x: jax.lax.fori_loop(0, 1, lambda i, x: x * 2, x)).lower(1.0).compile().as_text())
HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f32[])->f32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.25 (Arg_0.1: f32[]) -> f32[] {
  %Arg_0.1 = f32[] parameter(0), metadata={op_name="x"}
  %constant.3 = f32[] constant(2)
  ROOT %multiply.1 = f32[] multiply(f32[] %Arg_0.1, f32[] %constant.3), metadata={op_name="jit(<lambda>)/jit(main)/while/body/mul" source_file="<ipython-input-10-f67509edfb4c>" source_line=1}
}

But with a non-trivial for loop, the scan is not optimized away:

print(jax.jit(lambda x: jax.lax.fori_loop(0, 10, lambda i, x: x * 2, x)).lower(1.0).compile().as_text())
HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f32[])->f32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

%region_0.8 (arg_tuple.9: (s32[], f32[])) -> (s32[], f32[]) {
  %constant.12 = s32[] constant(1)
  %arg_tuple.9 = (s32[], f32[]) parameter(0)
  %get-tuple-element.2 = s32[] get-tuple-element((s32[], f32[]) %arg_tuple.9), index=0
  %add.14 = s32[] add(s32[] %get-tuple-element.2, s32[] %constant.12), metadata={op_name="jit(<lambda>)/jit(main)/while/body/add" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}
  %constant.0 = f32[] constant(2)
  %get-tuple-element.3 = f32[] get-tuple-element((s32[], f32[]) %arg_tuple.9), index=1
  %multiply.0 = f32[] multiply(f32[] %get-tuple-element.3, f32[] %constant.0), metadata={op_name="jit(<lambda>)/jit(main)/while/body/mul" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}
  ROOT %tuple.2 = (s32[], f32[]) tuple(s32[] %add.14, f32[] %multiply.0)
}

%region_1.16 (arg_tuple.17: (s32[], f32[])) -> pred[] {
  %constant.20 = s32[] constant(10)
  %arg_tuple.17 = (s32[], f32[]) parameter(0)
  %get-tuple-element.18 = s32[] get-tuple-element((s32[], f32[]) %arg_tuple.17), index=0
  ROOT %compare.21 = pred[] compare(s32[] %get-tuple-element.18, s32[] %constant.20), direction=LT, metadata={op_name="jit(<lambda>)/jit(main)/while/cond/lt" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}
}

ENTRY %main.25 (Arg_0.1: f32[]) -> f32[] {
  %Arg_0.1 = f32[] parameter(0), metadata={op_name="x"}
  %copy.6 = f32[] copy(f32[] %Arg_0.1)
  %constant.2 = s32[] constant(0)
  %copy.7 = s32[] copy(s32[] %constant.2)
  %tuple = (s32[], f32[]) tuple(s32[] %copy.7, f32[] %copy.6)
  %while.22 = (s32[], f32[]) while((s32[], f32[]) %tuple), condition=%region_1.16, body=%region_0.8, metadata={op_name="jit(<lambda>)/jit(main)/while" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}, backend_config={"known_trip_count":{"n":"10"}}
  ROOT %get-tuple-element.24 = f32[] get-tuple-element((s32[], f32[]) %while.22), index=1, metadata={op_name="jit(<lambda>)/jit(main)/while" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}
}

The result of this, even for the simple function here, is that the compiler produces some fusions in the second case that it does not in the first; in your case, those fusions probably lead to faster execution.
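
If you want to check whether the same thing is happening in your circuit, you can dump the compiled HLO of optimization_jit for each setting and compare, using the same lowering API as above. A rough sketch, assuming the params, X_batched, y_batched, X_test, y_test, X_train and y_train objects from your script are in scope:

# Sketch: inspect the HLO that XLA compiled for optimization_jit.
# Only the fori_loop upper bound differs between the two settings.
lowered = optimization_jit.lower(
    params, X_batched, y_batched, X_test, y_test, X_train, y_train
)
hlo_text = lowered.compile().as_text()

# Counting fusion ops is only a rough proxy for how the loop body was optimized;
# diffing the full dumps of the two settings is more informative.
print(hlo_text.count("fusion"))

Diffing the two dumps should show whether the length-1 loop really is unrolled and, if so, which fusions appear or disappear as a result.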

A perfect compiler would never make a decision like this that leads to slower execution, but no compiler is perfect. If you wish, you could report this at https://github.com/openxla/xla, but you'd probably want to try for a far more minimized reproduction before doing so.
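
If you do go down that road, it also helps to separate compilation time from execution time, since time.time() around the first call to a jitted function measures tracing and XLA compilation as well as the actual run. A minimal sketch (the helper name timed_run is just for illustration):

import time
import jax

def timed_run(fn, *args):
    # First call: includes tracing + XLA compilation as well as execution.
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args))
    cold = time.perf_counter() - t0

    # Second call: hits the compilation cache, so this is (roughly) pure execution.
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args))
    warm = time.perf_counter() - t0
    return cold, warm

# e.g. timed_run(optimization_jit, params, X_batched, y_batched,
#                X_test, y_test, X_train, y_train)

If the warm timings of the three settings are close and only the cold timings differ, the slowdown is coming from compilation rather than from the generated code.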
