I have the following code that defines an abstract class and its final subclass. Both classes are subclasses of the equinox.Module class, which registers class attributes as the leaves of a PyTree container.

# === IMPORTS ===
from abc import ABC, abstractmethod
import jax
from jax.typing import ArrayLike
import jax.numpy as jnp
import equinox as eqx
from quadax import quadgk

jax.config.update("jax_enable_x64", True)


class MyClass(eqx.Module): # Works if I toggle to MyClass(ABC)

    rtol = 1e-12
    atol = 1e-12

    param: ArrayLike

    def __init__(self):
        self.param = self._integral_moment(3) # Fails, but works if I toggle to something like "self.param = self.func(1.)"

    @abstractmethod 
    def func(self, tau):
        pass

    def func_abs(self, tau):
        return jnp.abs(self.func(tau))
    
    def _integral_moment(self, order): 
        return quadgk(self._integrand_moment, [0, jnp.inf], args=(order,), epsrel=self.rtol, epsabs=self.atol)[0]

    def _integrand_moment(self, tau, order):
        return self.func_abs(tau) * jnp.abs(tau)**order
 

class MySubClass(MyClass):

    gamma: ArrayLike
    kappa: ArrayLike
    w0: ArrayLike

    def __init__(self, gamma, kappa, w0):
        self.gamma = jnp.asarray(gamma)
        self.kappa = jnp.asarray(kappa) 
        self.w0 = jnp.asarray(w0)
        super().__init__()

    def func(self, tau):
        return self.gamma * jnp.exp(-1j * self.w0 * tau) * jnp.exp(-self.kappa*jnp.abs(tau)/2)
    

# Test    
test = MySubClass(gamma=1., kappa=1., w0=1.)
test.param

This code produces the AttributeError message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[21], line 52
     48         return self.gamma * jnp.exp(-1j * self.w0 * tau) * jnp.exp(-self.kappa*jnp.abs(tau)/2)
     51 # Test    
---> 52 test = MySubClass(gamma=1., kappa=1., w0=1.)
     53 test.param

    [... skipping hidden 2 frame]

Cell In[21], line 45
     43 self.kappa = jnp.asarray(kappa) 
     44 self.w0 = jnp.asarray(w0)
---> 45 super().__init__()

Cell In[21], line 19
     18 def __init__(self):
---> 19     self.param = self._integral_moment(3)

    [... skipping hidden 1 frame]

Cell In[21], line 29
     28 def _integral_moment(self, order): 
---> 29     return quadgk(self._integrand_moment, [0, jnp.inf], args=(order,), epsrel=self.rtol, epsabs=self.atol)[0]
...
    659         and isinstance(out, types.MethodType)
    660         and out.__self__ is self
    661     ):

AttributeError: 'MySubClass' object has no attribute 'param'

This error clearly comes from a restriction of equinox.Module, since the code runs fine if I change the parent class to ABC.

First, I thought that maybe equinox did not allow me to use methods to initialize attributes. But if I use the func() method instead of the _integral_moment() method to initialize param, the code works fine.

So I just don't understand what is going on here. I thought it would be better to ask here before asking the developers at equinox.

This uses equinox version 0.13.1 with jax version 0.7.2.

2 Answers

The issue here is that when traced, eqx.Module attempts to access all the declared attributes of the Module, so the module cannot be traced before those attributes are created. Here's a simpler repro of the same problem:

import jax
from jax.typing import ArrayLike
import equinox as eqx

class MyClass(eqx.Module):
    param: ArrayLike

    def __init__(self):
        # jit traces the bound method, and with it `self`, before param is set
        self.param = jax.jit(self.func)()

    def func(self):
        return 4

MyClass()  # AttributeError: 'MyClass' object has no attribute 'param'

The quadgk function traces its input, and since you call it before setting param, you get this error. With this issue in mind, you can fix your problem by setting the missing param to a placeholder value before you call a function that traces the object's methods:

class MyClass(eqx.Module):

    ...

    def __init__(self):
        self.param = 0  # set to a placeholder to allow tracing
        self.param = self._integral_moment(3)

    ...
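
The mechanism can be imitated without JAX. The following is only a standard-library analogy, not equinox's actual implementation: the hypothetical helper `flatten` stands in for whatever reads the declared fields during tracing, and the classes `Eager` and `WithPlaceholder` are made up for illustration.

```python
from dataclasses import dataclass, fields


def flatten(obj):
    """Read every declared field, loosely imitating a PyTree flatten (hypothetical helper)."""
    return [getattr(obj, f.name) for f in fields(obj)]


@dataclass(init=False)
class Eager:
    param: float

    def __init__(self):
        # compute() inspects all fields before param has been assigned
        self.param = self.compute()

    def compute(self):
        flatten(self)  # raises AttributeError: param does not exist yet
        return 4.0


@dataclass(init=False)
class WithPlaceholder:
    param: float

    def __init__(self):
        self.param = 0.0             # placeholder so every field exists
        self.param = self.compute()  # now safe to read the fields

    def compute(self):
        flatten(self)
        return 4.0


try:
    Eager()
    eager_failed = False
except AttributeError:
    eager_failed = True

fixed = WithPlaceholder()
```

In this sketch `Eager()` fails for the same kind of reason as the original code, while `WithPlaceholder()` succeeds because every declared field already has a value when the fields are read.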

To follow up on @jakevdp's answer, a completely equivalent but perhaps slightly more elegant way of systematically pre-empting this issue in equinox is to assign a value directly in the attribute definition:

class MyClass(eqx.Module):

    ...

    param: float = 0 # set to a placeholder to allow tracing
    
    def __init__(self):
        self.param = self._integral_moment(3)

    ...

EDIT: Importantly, note that dataclasses such as equinox modules do NOT allow mutable objects or JAX arrays as class-level defaults; attempting this raises a ValueError that tells you to use default_factory. With a class-level default like the one above, all instances of the class initially share the same object for the field, which is not the desired behavior in a dataclass if the attribute can later be modified in some way. This is probably why the previous answer chose to set the placeholder inside __init__, which will always work.
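
As a minimal illustration of that restriction, here is a sketch using only the standard-library `dataclasses` module, with a plain list standing in for a JAX array (an assumption made so that the example runs without JAX). The class names `Bad` and `Good` and the helper `try_mutable_default` are hypothetical.

```python
from dataclasses import dataclass, field


def try_mutable_default():
    """Return the error message raised when a dataclass gets a mutable default."""
    try:
        @dataclass
        class Bad:
            values: list = []  # mutable class-level default is rejected
    except ValueError as err:
        return str(err)
    return None


@dataclass
class Good:
    # default_factory builds a fresh list for every instance
    values: list = field(default_factory=list)


message = try_mutable_default()

a, b = Good(), Good()
a.values.append(1)  # only a's list changes; b keeps its own empty list
```

An immutable placeholder such as `0` is accepted as a class-level default because sharing it between instances is harmless; the restriction only concerns defaults that could be modified in place.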
