Adding JAX and Numba support for Ops
PyTensor is able to convert its graphs into JAX and Numba compiled functions. In order to do this, each Op in a PyTensor graph must have an equivalent JAX/Numba implementation function.
This tutorial will explain how JAX and Numba implementations are created for an Op. It will focus specifically on the JAX case, but the same mechanisms are used for Numba as well.
Step 1: Identify the PyTensor Op you’d like to implement in JAX
Find the source for the PyTensor Op you’d like to be supported in JAX, and identify the function signature and return values. These can be determined by looking at the Op.make_node() implementation. In general, one needs to be familiar with PyTensor Ops in order to provide a conversion implementation, so first read Creating a new Op: Python implementation if you are not familiar.
For example, the Eye Op currently has an Op.make_node() as follows:
def make_node(self, n, m, k):
    n = as_tensor_variable(n)
    m = as_tensor_variable(m)
    k = as_tensor_variable(k)
    assert n.ndim == 0
    assert m.ndim == 0
    assert k.ndim == 0
    return Apply(
        self,
        [n, m, k],
        [TensorType(dtype=self.dtype, shape=(None, None))()],
    )
The Apply instance that’s returned specifies the exact types of inputs that our JAX implementation will receive and the exact types of outputs it’s expected to return, both in terms of their data types and number of dimensions. The actual inputs our implementation will receive are necessarily numeric values or NumPy ndarrays; all that Op.make_node() tells us is the general signature of the underlying computation.
More specifically, the Apply implies that the inputs come from values that are automatically converted to PyTensor variables via as_tensor_variable(), and the asserts that follow imply that they must be scalars. According to this logic, the inputs could have any data type (e.g. floats, ints), so our JAX implementation must be able to handle all the possible data types.
It also tells us that there’s only one return value, that it has a data type determined by Eye.dtype, and that it has two non-broadcastable dimensions. The latter implies that the result is necessarily a matrix. The former implies that our JAX implementation will need to access the dtype attribute of the PyTensor Eye Op it’s converting.
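To make this concrete, you can build a small graph with at.eye() and inspect the resulting variable and its Op; a quick illustration (the exact printed representation of the type may differ between PyTensor versions):

import pytensor.tensor as at

z = at.eye(3, 4, 0)

# The output is a matrix whose dtype comes from the `Eye` `Op`'s `dtype` attribute
print(z.type)            # a TensorType with two dimensions of unknown size
print(z.owner.op.dtype)  # the "static" attribute our JAX implementation will read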
Next, we can look at the Op.perform() implementation to see exactly how the inputs and outputs are used to compute the outputs for an Op in Python. This method is effectively what needs to be implemented in JAX.
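For reference, Eye.perform() boils down to a call to np.eye() that writes into the pre-allocated output storage; roughly (a simplified sketch, not the verbatim source):

def perform(self, node, inputs, output_storage):
    n, m, k = inputs
    (out,) = output_storage
    # Write the result into the output storage cell managed by PyTensor
    out[0] = np.eye(n, m, k, dtype=self.dtype)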
Step 2: Find the relevant JAX method (or something close)
With a precise idea of what the PyTensor Op does, we need to figure out how to implement it in JAX. In the best-case scenario, JAX has a similarly named function that performs exactly the same computations as the Op. For example, the Eye operator has a JAX equivalent: jax.numpy.eye() (see the documentation).
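As a quick check that the semantics line up, jax.numpy.eye() accepts the same (N, M, k) arguments and a dtype, just like NumPy’s eye():

import numpy as np
import jax.numpy as jnp

# Same arguments, same result: a 3x4 matrix with ones on the first super-diagonal
np.testing.assert_array_equal(
    jnp.eye(3, 4, 1, dtype="float32"), np.eye(3, 4, 1, dtype="float32")
)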
If we wanted to implement an Op like IfElse, we might need to recreate the functionality with some custom logic. In many cases, at least some custom logic is needed to reformat the inputs and outputs so that they exactly match the Op’s.
Here’s an example for IfElse:
def ifelse(cond, *args, n_outs=n_outs):
    # The first `n_outs` entries of `args` are the true-branch values and the
    # rest are the false-branch values; `jax.lax.cond` selects between them
    res = jax.lax.cond(
        cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None
    )
    return res if n_outs > 1 else res[0]
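Note that this closure only works because n_outs is captured from the Op being converted. A sketch of how the pieces fit together once the function is registered (see Step 3 below for the registration mechanism):

import jax
from pytensor.ifelse import IfElse
from pytensor.link.jax.dispatch import jax_funcify


@jax_funcify.register(IfElse)
def jax_funcify_IfElse(op):
    # The number of outputs is a "static" attribute of the `IfElse` `Op`
    n_outs = op.n_outs

    def ifelse(cond, *args, n_outs=n_outs):
        res = jax.lax.cond(
            cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None
        )
        return res if n_outs > 1 else res[0]

    return ifelse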
Step 3: Register the function with the jax_funcify dispatcher
With the PyTensor Op replicated in JAX, we’ll need to register the function with the PyTensor JAX Linker. This is done through the use of singledispatch. If you don’t know how singledispatch works, see the Python documentation.
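In short, singledispatch selects a registered implementation based on the type of the first argument; a minimal, self-contained illustration (unrelated to PyTensor):

from functools import singledispatch


@singledispatch
def describe(value):
    return "something else"


@describe.register(int)
def describe_int(value):
    return "an integer"


print(describe(3))     # an integer
print(describe("hi"))  # something else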
The relevant dispatch functions created by singledispatch are pytensor.link.numba.dispatch.numba_funcify() and pytensor.link.jax.dispatch.jax_funcify().
Here’s an example for the Eye Op:
import jax.numpy as jnp

from pytensor.tensor.basic import Eye
from pytensor.link.jax.dispatch import jax_funcify


@jax_funcify.register(Eye)
def jax_funcify_Eye(op):
    # Obtain necessary "static" attributes from the Op being converted
    dtype = op.dtype

    # Create a JAX jit-able function that implements the Op
    def eye(N, M, k):
        return jnp.eye(N, M, k, dtype=dtype)

    return eye
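With this registration in place, the dispatcher hands back our converter whenever it encounters an Eye instance. A quick interactive sanity check, assuming the registration shown above (calling jax_funcify directly like this is only for illustration; the JAX linker normally does it for you):

from pytensor.tensor.basic import Eye
from pytensor.link.jax.dispatch import jax_funcify

# Dispatch on the `Op` instance to get the JAX implementation back
eye_jax = jax_funcify(Eye(dtype="float64"))
print(eye_jax(3, 4, 1))  # a 3x4 JAX array with ones on the first super-diagonal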
Step 4: Write tests
Test that your registered Op is working correctly by adding tests to the appropriate test suites in PyTensor (e.g. in tests.link.test_jax and one of the modules in tests.link.numba.dispatch). The tests should ensure that your implementation can handle the appropriate types of inputs and produce outputs equivalent to Op.perform.
Check the existing tests for the general outline of these kinds of tests. In most cases, a helper function can be used to easily verify the correspondence between a JAX/Numba implementation and its Op.
For example, the compare_jax_and_py() function streamlines the steps involved in making comparisons with Op.perform.
Here’s a small example of a test for Eye:
import pytensor.tensor as at
from pytensor.graph.fg import FunctionGraph


def test_jax_Eye():
    """Test JAX conversion of the `Eye` `Op`."""
    # Create a symbolic input for `Eye`
    x_at = at.scalar()

    # Create a variable that is the output of an `Eye` `Op`
    eye_var = at.eye(x_at)

    # Create a PyTensor `FunctionGraph`
    out_fg = FunctionGraph(outputs=[eye_var])

    # Pass the graph and any inputs to the testing function
    # (`compare_jax_and_py` is defined alongside the existing JAX tests)
    compare_jax_and_py(out_fg, [3])
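The Numba tests follow the same pattern; a minimal sketch, assuming a compare_numba_and_py helper analogous to compare_jax_and_py is available from the Numba test modules:

import pytensor.tensor as at
from pytensor.graph.fg import FunctionGraph


def test_numba_Eye():
    """Test Numba conversion of the `Eye` `Op` (mirrors the JAX test above)."""
    x_at = at.scalar()
    eye_var = at.eye(x_at)

    out_fg = FunctionGraph(outputs=[eye_var])

    # `compare_numba_and_py` is assumed to mirror `compare_jax_and_py`
    compare_numba_and_py(out_fg, [3])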