I have the following code that defines an abstract class and its final subclasse. The two classes are both subclasses of the equinox.Module class, which registers class attributes as the leaves of a PyTree container.
# === IMPORTS ===
from abc import ABC, abstractmethod
import jax
from jax.typing import ArrayLike
import jax.numpy as jnp
import equinox as eqx
from quadax import quadgk
jax.config.update("jax_enable_x64", True)
class MyClass(eqx.Module): # Works if I toggle to MyClass(ABC)
rtol = 1e-12
atol = 1e-12
param: ArrayLike
def __init__(self):
self.param = self._integral_moment(3) # Fails, but works if I toggle to something like "self.param = self.func(1.)"
@abstractmethod
def func(self, tau):
pass
def func_abs(self, tau):
return jnp.abs(self.func(tau))
def _integral_moment(self, order):
return quadgk(self._integrand_moment, [0, jnp.inf], args=(order,), epsrel=self.rtol, epsabs=self.atol)[0]
def _integrand_moment(self, tau, order):
return self.func_abs(tau) * jnp.abs(tau)**order
class MySubClass(MyClass):
gamma: ArrayLike
kappa: ArrayLike
w0: ArrayLike
def __init__(self, gamma, kappa, w0):
self.gamma = jnp.asarray(gamma)
self.kappa = jnp.asarray(kappa)
self.w0 = jnp.asarray(w0)
super().__init__()
def func(self, tau):
return self.gamma * jnp.exp(-1j * self.w0 * tau) * jnp.exp(-self.kappa*jnp.abs(tau)/2)
# Test
test = MySubClass(gamma=1., kappa=1., w0=1.)
test.param
This code produces the AttributeError message:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[21], line 52
48 return self.gamma * jnp.exp(-1j * self.w0 * tau) * jnp.exp(-self.kappa*jnp.abs(tau)/2)
51 # Test
---> 52 test = MySubClass(gamma=1., kappa=1., w0=1.)
53 test.param
[... skipping hidden 2 frame]
Cell In[21], line 45
43 self.kappa = jnp.asarray(kappa)
44 self.w0 = jnp.asarray(w0)
---> 45 super().__init__()
Cell In[21], line 19
18 def __init__(self):
---> 19 self.param = self._integral_moment(3)
[... skipping hidden 1 frame]
Cell In[21], line 29
28 def _integral_moment(self, order):
---> 29 return quadgk(self._integrand_moment, [0, jnp.inf], args=(order,), epsrel=self.rtol, epsabs=self.atol)[0]
...
659 and isinstance(out, types.MethodType)
660 and out.__self__ is self
661 ):
AttributeError: 'MySubClass' object has no attribute 'param'
This error clearly comes from a restriction of the equinox.Module, since if I change the parent class to ABC, the code runs fine.
First, I thought that maybe equinox did not allow me to use methods to initialize attributes. But if I use the func() method instead of the _integral_moment() method to initialize param, the code works fine.
So I just don't understand what is going on here. I thought it would be better to ask here before asking the developers at equinox.
This uses equinox version 0.13.1 with jax version 0.7.2.