
Consider the following file:

import jax.numpy as jnp

def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    return a + b

Running mypy mypytest.py returns the following error:

mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")

For some reason, mypy believes that adding two jax.numpy.ndarray values returns a NumPy array of bools. Am I doing something wrong, or is this a bug in mypy or in JAX's type annotations?


At least statically, jnp.ndarray is a subclass of np.ndarray with only minimal modifications:

class ndarray(np.ndarray, metaclass=_ArrayMeta):
  dtype: np.dtype
  shape: Tuple[int, ...]
  size: int

  def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

As such, it inherits np.ndarray's method type signatures.
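The same effect can be reproduced without NumPy or JAX. In this minimal sketch, the hypothetical classes Base and Sub stand in for np.ndarray and jnp.ndarray: Sub adds nothing, so it inherits Base.__add__ together with its annotated return type, and mypy reports an incompatible return value type in add_subs, mirroring the error in the question.

```python
class Base:
    """Hypothetical stand-in for np.ndarray."""

    def __add__(self, other: "Base") -> "Base":
        # The annotation promises a Base, not "whatever subclass self is".
        return Base()


class Sub(Base):
    """Hypothetical stand-in for jnp.ndarray: a subclass with no new methods."""


def add_subs(a: Sub, b: Sub) -> Sub:
    # mypy flags this line: Sub + Sub uses the inherited Base.__add__,
    # whose declared return type is Base, and a Base is not a Sub.
    return a + b
```

At runtime add_subs simply returns a Base instance, because Python does not enforce annotations; the mismatch exists only in the inherited signature, which is what mypy checks.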

I guess the runtime behaviour is provided by the jnp.array function. Unless I've missed some stub files or type trickery, the result of jnp.array is accepted wherever a jnp.ndarray is expected simply because jnp.array is untyped, so mypy treats its return value as Any. You can test this with

import jax.numpy as jnp

def foo(_: str) -> None:
    pass

foo(jnp.array(0))  # not a str, yet mypy reports no error

which passes mypy.

So, to answer your questions: I don't think you're doing anything wrong. It's a bug in the sense that it's probably not what the JAX authors intended, but it isn't strictly incorrect: adding two jnp.ndarrays does statically give you an np.ndarray, because a jnp.ndarray is an np.ndarray.

As for why you get bools: that's most likely because your jnp.ndarray annotations carry no generic parameters, so the dtype is effectively Any, and the first matching overload for __add__ on np.ndarray is

    @overload
    def __add__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ...  # type: ignore[misc]

so mypy selects that overload and the result defaults to an array of bool_.


In general, JAX has very poor compatibility with mypy, because it's very difficult to satisfy mypy's constraints with JAX's transformation model, which often calls functions with transform-specific tracer values that act as stand-ins for arrays (see How To Think in JAX: JIT Mechanics for a brief discussion of this mechanism).

This use of tracer types as stand-ins for arrays means that mypy will raise errors when strictly typed JAX functions are transformed. For this reason, throughout the JAX codebase we tend to alias Array to Any and use it as the return type annotation for JAX functions that return arrays.

It would be good to improve on this, because an Any return type is not very useful for effective type checking, but it's just the first of many challenges in making mypy play well with JAX. If you want to read the last few years' worth of discussion surrounding this issue, I would start here: https://github.com/google/jax/issues/943

And in the meantime, my suggestion would be to use Any as a type annotation for JAX arrays.
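A minimal sketch of that suggestion, assuming a local alias named Array (a hypothetical name defined here, not an import from JAX). Because Any is compatible with every type, mypy accepts jnp arrays, NumPy arrays, Python scalars, and the tracers that jax.jit passes in, at the cost of checking nothing about them.

```python
from typing import Any

# Hypothetical local alias, in the spirit of aliasing Array to Any.
Array = Any


def add_arrays(a: Array, b: Array) -> Array:
    # With Any annotations mypy accepts any arguments here, including
    # tracers under jax.jit, and never reports a return-type mismatch.
    return a + b
```

The trade-off is the one described above: nothing is checked, so add_arrays([1], [2]) also type-checks and simply concatenates the two lists.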


As of late 2023, JAX appears to have greatly improved its type annotations, and mypy is fine with the new jax.Array and jax.typing.ArrayLike types:

from jax import Array
from jax.typing import ArrayLike

import jax.numpy as jnp

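def test(a: ArrayLike, b: ArrayLike) -> Array:
    # ArrayLike also includes Python and NumPy scalars and arrays, so convert
    # to jax.Array first; the sum is then typed as Array.
    return jnp.asarray(a) + b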
