
While trying to speed up an existing function used in an optimization algorithm by applying @jit, I ran into some issues. When running the following function:

import jax
import jax.numpy as jnp
from jax import grad, jacobian
from scipy.optimize import minimize
from scipy.interpolate import BSpline

jax.config.update("jax_enable_x64", True)

@jax.jit
def bspline(t_values, knots, coefficients, degree):
    """
    Generate a B-spline curve from given knots and coefficients.

    Parameters:
    - t_values: Array of parameter values where the spline is evaluated.
    - knots: Knot vector as a 1D numpy array.
    - coefficients: Control points as a 1D numpy array of shape (n,).
    - degree: Degree of the B-spline (e.g., 3 for cubic B-splines).

    Returns:
    - A numpy array of shape (num_points,) representing the B-spline curve.
    """
    def basis_function(i, k, t, knots):
        """Compute the basis function recursively."""
        if k == 0:
            return jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0)
        else:
            denom1 = knots[i + k] - knots[i]
            denom2 = knots[i + k + 1] - knots[i + 1]

            term1 = (t - knots[i]) / denom1 * basis_function(i, k - 1, t, knots) if denom1 != 0 else 0
            term2 = (knots[i + k + 1] - t) / denom2 * basis_function(i + 1, k - 1, t, knots) if denom2 != 0 else 0

            return term1 + term2

    # Compute the B-spline curve points
    curve_points = jnp.zeros(len(t_values))
    for i in range(len(coefficients)):
        v = basis_function(i, degree, t_values, knots)
        curve_points = curve_points + v * coefficients[i]

    return curve_points

I get the following error:

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
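For reference, the same error can be reproduced with a much smaller function. Any Python `if` on a value derived from a jitted function's arguments triggers it. A minimal sketch (`scale` is a hypothetical function, not from the original code):

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale(x, k):
    # Under jit, `k` arrives as a tracer, so `k == 0` is a traced boolean
    # array, and a Python `if` cannot turn it into True/False at trace time.
    if k == 0:
        return x
    return 2.0 * x


try:
    scale(jnp.array(1.0), 0)
    caught = None
except jax.errors.ConcretizationTypeError as err:
    # In recent JAX versions the concrete class is TracerBoolConversionError,
    # a subclass of ConcretizationTypeError.
    caught = type(err).__name__

print(caught)
```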

What I've Tried:

After consulting JAX's official documentation about this error (available here), I modified basis_function to replace the direct boolean checks with jnp.where:

def basis_function(i, k, t, knots):
    """Compute the basis function recursively."""
    return jnp.where(
        k == 0,
        jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0),
        jnp.where(
            (knots[i + k] - knots[i]) != 0,
            (t - knots[i]) / (knots[i + k] - knots[i]) * basis_function(i, k - 1, t, knots),
            0
        ) +
        jnp.where(
            (knots[i + k + 1] - knots[i + 1]) != 0,
            (knots[i + k + 1] - t) / (knots[i + k + 1] - knots[i + 1]) * basis_function(i + 1, k - 1, t, knots),
            0
        )
    )

However, I now get a RecursionError:

RecursionError: maximum recursion depth exceeded in comparison

This recursion issue seems to stem from applying @jit, as it was not present before.

Answer

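Unfortunately, you cannot use recursion in JAX when the condition that ends the recursion is a traced value. You'll either have to write the recursion using Python control flow with static conditions, or rewrite it without recursion.

The RecursionError in your second attempt has a related cause and would occur even without @jit: jnp.where is an ordinary function, so Python evaluates all of its arguments, including the recursive basis_function(i, k - 1, t, knots) calls, before the condition is ever checked. The recursion therefore never stops at k == 0.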
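Because Python evaluates every argument before jnp.where is called, a recursive call passed as one of its branches always runs. A minimal sketch that makes this visible without @jit (`countdown` is a hypothetical function mirroring the structure of basis_function, with an artificial cutoff added so the demo terminates):

```python
import jax.numpy as jnp

calls = []


def countdown(k):
    calls.append(k)
    if k < -3:
        # Artificial Python-level stop so this demo terminates; without it
        # the recursion would continue until Python's recursion limit.
        return jnp.asarray(0.0)
    # countdown(k - 1) is evaluated *before* jnp.where runs, even when
    # k == 0, so jnp.where cannot act as a base case.
    return jnp.where(k == 0, 1.0, countdown(k - 1))


result = countdown(2)
print(calls)  # [2, 1, 0, -1, -2, -3, -4]
```

`calls` shows the recursion continuing past k == 0 until the artificial cutoff; in basis_function there is no such cutoff, hence the RecursionError.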

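In your case, the first option is doable as long as degree is static at the call site. Marking it static makes i and k plain Python ints during tracing, so `if k == 0` is resolved at trace time. The `if denom1 != 0` and `if denom2 != 0` checks still need to change, though: the denominators are computed from knots, which remains a traced argument, so those comparisons must go through jnp.where instead. The jit decorator itself would look like this: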

from functools import partial

@partial(jax.jit, static_argnames=['degree'])
def bspline(t_values, knots, coefficients, degree):
  ...
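A complete sketch along these lines, with `degree` static and the knot-dependent denominator checks expressed with jnp.where (since `knots` stays traced). Dividing by a "safe" denominator in the discarded branch is one common way to keep inf/nan values, and therefore nan gradients, out of the result; that choice is made here for illustration and is not from the original post:

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@partial(jax.jit, static_argnames=["degree"])
def bspline(t_values, knots, coefficients, degree):
    """Evaluate a B-spline at t_values; `degree` must be a static Python int."""

    def basis_function(i, k, t):
        # i and k are Python ints here, so this `if` is resolved at trace time.
        if k == 0:
            return jnp.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0)

        # The denominators are computed from the traced `knots`, so they are
        # tested with jnp.where instead of a Python `if`.
        denom1 = knots[i + k] - knots[i]
        denom2 = knots[i + k + 1] - knots[i + 1]
        # Divide by a "safe" denominator so the branch jnp.where discards
        # never contains inf/nan, which would otherwise leak into gradients.
        safe1 = jnp.where(denom1 != 0, denom1, 1.0)
        safe2 = jnp.where(denom2 != 0, denom2, 1.0)
        w1 = jnp.where(denom1 != 0, (t - knots[i]) / safe1, 0.0)
        w2 = jnp.where(denom2 != 0, (knots[i + k + 1] - t) / safe2, 0.0)
        return w1 * basis_function(i, k - 1, t) + w2 * basis_function(i + 1, k - 1, t)

    curve_points = jnp.zeros(t_values.shape[0])
    # Array shapes are static under jit, so this Python loop is fine.
    for i in range(coefficients.shape[0]):
        curve_points = curve_points + basis_function(i, degree, t_values) * coefficients[i]
    return curve_points


# Clamped cubic example: 10 knots -> 10 - 3 - 1 = 6 coefficients.
knots = jnp.array([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0])
t_values = jnp.linspace(0.0, 2.9, 5)
curve = bspline(t_values, knots, jnp.ones(6), degree=3)
```

With all-ones coefficients the basis functions sum to one on [0, 3), so `curve` should be (numerically) all ones. Calling with a different `degree` triggers a fresh trace and compile, since static arguments are part of the cache key.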

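Keep in mind, though, that JAX tracing unrolls all of this recursion. Each call at degree k makes two calls at degree k - 1, so the traced program grows roughly as len(coefficients) * 2**degree basis-function evaluations, which can lead to long compile times for higher degrees.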
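If compile time becomes a problem, the non-recursive route is one option. A sketch of a bottom-up Cox–de Boor evaluation (`bspline_iterative` is a hypothetical name): it builds all degree-0 basis functions as one array and raises the degree one step at a time with vectorized operations, so the traced program contains `degree` steps rather than one subtree per recursive call. It assumes len(coefficients) == len(knots) - degree - 1, the usual relationship for a B-spline:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["degree"])
def bspline_iterative(t_values, knots, coefficients, degree):
    """Bottom-up Cox-de Boor; assumes len(coefficients) == len(knots) - degree - 1."""
    t = t_values[:, None]  # shape (m, 1), broadcasts against per-basis arrays
    num_knots = knots.shape[0]

    # Degree 0: B[:, i] = 1 where knots[i] <= t < knots[i + 1], else 0.
    B = jnp.where((knots[:-1] <= t) & (t < knots[1:]), 1.0, 0.0)

    def safe_ratio(num, den):
        # num / den where den != 0, else 0, without creating inf/nan values.
        ok = den != 0
        return jnp.where(ok, num / jnp.where(ok, den, 1.0), 0.0)

    for k in range(1, degree + 1):  # Python loop over the static degree
        n_k = num_knots - k - 1  # number of degree-k basis functions
        left = safe_ratio(t - knots[:n_k], knots[k:k + n_k] - knots[:n_k])
        right = safe_ratio(knots[k + 1:k + 1 + n_k] - t,
                           knots[k + 1:k + 1 + n_k] - knots[1:1 + n_k])
        # B_{i,k} = left_i * B_{i,k-1} + right_i * B_{i+1,k-1}
        B = left * B[:, :n_k] + right * B[:, 1:n_k + 1]

    return B @ coefficients  # shape (m,)


knots = jnp.array([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0])
t_values = jnp.linspace(0.0, 2.9, 5)
curve = bspline_iterative(t_values, knots, jnp.ones(6), degree=3)
```

The recurrence is the same one basis_function implements; only the evaluation order changes, trading the per-call Python recursion for array slicing over all basis functions at once.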
