Overview of the compilation pipeline
Once you have a PyTensor graph, you can use pytensor.function() to compile a function that performs the computations modeled by the graph in Python, C, Numba, or JAX.
More specifically, pytensor.function() takes a list of input and output Variables that define the precise sub-graphs that correspond to the desired computations.
Here is an overview of the various steps that are taken during the compilation performed by pytensor.function().
Step 1 - Create a FunctionGraph
The subgraph specified by the end-user is wrapped in a structure called FunctionGraph. This structure defines several callback hooks that are triggered when specific changes are made to a FunctionGraph, such as adding and removing nodes, or modifying links between nodes (e.g. modifying an input of an Apply node). See fg – Graph Container [doc TODO].
FunctionGraph provides a method to change the input of an Apply node from one Variable to another, and a more high-level method to replace a Variable with another. These are the primary means of performing graph rewrites.
Some relevant Features are typically added to the FunctionGraph at this stage, namely Features that prevent rewrites from operating in-place on inputs declared as immutable.
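The hook mechanism described above can be pictured with a small pure-Python toy. This is only a sketch under stated assumptions: ToyGraph, RecordChanges, and their method names are hypothetical stand-ins invented for illustration and are not PyTensor's FunctionGraph or Feature classes.

```python
# Toy model only: ToyGraph, RecordChanges, and their method names are
# hypothetical and are not PyTensor's FunctionGraph or Feature API.

class RecordChanges:
    """A 'feature' that logs every input change it is notified about."""

    def __init__(self):
        self.log = []

    def on_change_input(self, node, index, old, new):
        self.log.append((node["op"], index, old, new))


class ToyGraph:
    """Holds nodes and notifies attached features when an input is rewired."""

    def __init__(self, nodes):
        self.nodes = nodes
        self.features = []

    def attach_feature(self, feature):
        self.features.append(feature)

    def change_input(self, node, index, new):
        # Low-level edit: rewire a single input of a single node.
        old = node["inputs"][index]
        node["inputs"][index] = new
        for feature in self.features:
            feature.on_change_input(node, index, old, new)

    def replace(self, old, new):
        # Higher-level edit: rewire every use of `old` to `new`.
        for node in self.nodes:
            for index, var in enumerate(node["inputs"]):
                if var == old:
                    self.change_input(node, index, new)


add = {"op": "add", "inputs": ["x", "y"]}
mul = {"op": "mul", "inputs": ["x", "x"]}
graph = ToyGraph([add, mul])
recorder = RecordChanges()
graph.attach_feature(recorder)
graph.replace("x", "z")
```

In this toy, the high-level replace is built entirely on the low-level single-input edit, so every attached feature sees each individual change.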
Step 2 - Perform graph rewrites
Once the FunctionGraph is constructed, a rewriter is produced by the mode passed to function(). That rewriter is applied to the FunctionGraph using its GraphRewriter.rewrite() method.
The rewriter is typically obtained through a query on optdb.
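As a rough illustration of what applying a rewriter to a graph means, the following pure-Python toy represents expressions as nested tuples and applies a list of local rewrites bottom-up. Everything here is an assumption made for illustration: the tuple encoding, remove_mul_by_one, and rewrite are hypothetical, and the list of local rewrites merely stands in for whatever a query on optdb would return.

```python
# Toy model only: expressions are nested tuples, and `rewrite` is a
# hypothetical helper, not PyTensor's GraphRewriter.

def remove_mul_by_one(expr):
    """Local rewrite: ("mul", a, 1) or ("mul", 1, a) becomes a."""
    if isinstance(expr, tuple) and expr[0] == "mul":
        _, left, right = expr
        if right == 1:
            return left
        if left == 1:
            return right
    return expr


def rewrite(expr, local_rewrites):
    """Apply each local rewrite bottom-up over the whole expression."""
    if isinstance(expr, tuple):
        op, *args = expr
        expr = (op, *(rewrite(arg, local_rewrites) for arg in args))
    for local_rewrite in local_rewrites:
        expr = local_rewrite(expr)
    return expr


graph = ("add", ("mul", "x", 1), ("mul", 1, ("mul", "y", 1)))
rewritten = rewrite(graph, [remove_mul_by_one])
```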
Step 3 - Execute linker to obtain a thunk
Once the computation graph is rewritten, the linker is extracted from the Mode. It is then called with the FunctionGraph as argument to produce a thunk, which is a function with no arguments that returns nothing. Along with the thunk, the linker produces one list of input containers and one list of output containers (a pytensor.link.basic.Container is an object that wraps another and does type casting). These correspond to the input and output Variables, as well as to the updates defined for the inputs when applicable. To perform the computations, the inputs must be placed in the input containers, the thunk must be called, and the outputs must be retrieved from the output containers where the thunk put them.
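The calling protocol described above (fill the input containers, call the thunk, read the output containers) can be sketched in plain Python. This is a toy under stated assumptions: ToyContainer and make_add_thunk are hypothetical names, and the sketch does not reproduce pytensor.link.basic.Container or any real linker.

```python
# Toy model only: ToyContainer and make_add_thunk are hypothetical stand-ins,
# not pytensor.link.basic.Container or a real linker.

class ToyContainer:
    """Wraps a single value and casts whatever is stored into it."""

    def __init__(self, cast):
        self.cast = cast
        self.storage = [None]  # one-element list holding the wrapped value

    @property
    def value(self):
        return self.storage[0]

    @value.setter
    def value(self, new_value):
        self.storage[0] = self.cast(new_value)


def make_add_thunk():
    """Return (thunk, input_containers, output_containers) for x + y."""
    x, y, out = ToyContainer(float), ToyContainer(float), ToyContainer(float)

    def thunk():
        # No arguments, no return value: all data moves through containers.
        out.value = x.value + y.value

    return thunk, [x, y], [out]


thunk, inputs, outputs = make_add_thunk()
inputs[0].value = 1      # cast to 1.0 by the container
inputs[1].value = "2.5"  # cast to 2.5 by the container
result_of_call = thunk()
total = outputs[0].value
```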
Typically, the linker calls the FunctionGraph.toposort() method in order to obtain a linear sequence of operations to perform. How they are linked together depends on the Linker class used. For example, the CLinker produces a single block of C code for the whole computation, whereas the OpWiseCLinker produces one thunk for each individual operation and calls them in sequence.
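The "one thunk per operation, called in sequence" strategy can be sketched with the standard library's graphlib module. This is a minimal toy, assuming a dictionary-based graph and a hypothetical link_op_wise helper; it illustrates the idea only and is not how PyTensor's Linker classes are implemented.

```python
from graphlib import TopologicalSorter
import operator

# Toy model only: the dictionary-based graph and `link_op_wise` are
# hypothetical, not PyTensor's Linker classes.

# Each node: name -> (operation, names of the values it consumes).
nodes = {
    "s": (operator.add, ("x", "y")),  # s = x + y
    "p": (operator.mul, ("s", "s")),  # p = s * s
    "r": (operator.sub, ("p", "x")),  # r = p - x
}


def link_op_wise(nodes, storage):
    """Build one thunk per node and return a thunk that runs them in order."""
    dependencies = {name: set(inputs) for name, (_, inputs) in nodes.items()}
    sorter = TopologicalSorter(dependencies)
    # Graph inputs such as "x" also appear in the order; keep only real nodes.
    order = [name for name in sorter.static_order() if name in nodes]

    def make_node_thunk(name):
        op, inputs = nodes[name]

        def node_thunk():
            storage[name] = op(*(storage[i] for i in inputs))

        return node_thunk

    node_thunks = [make_node_thunk(name) for name in order]

    def thunk():
        for node_thunk in node_thunks:
            node_thunk()

    return thunk, order


storage = {"x": 2.0, "y": 3.0}
thunk, order = link_op_wise(nodes, storage)
thunk()
```

A linker that emits a single block of code for the whole computation would instead generate one function body covering all three operations; the topological order is what makes either approach valid.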
The linker is where some options take effect: the strict flag of an input makes the associated input container do type checking. The borrow flag of an output, if False, adds the output to a no_recycling list, meaning that when the thunk is called the output containers will be cleared (if they stay there, as would be the case if borrow was True, the thunk would be allowed to reuse, or "recycle", the storage).
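The difference between a container that casts and one that does strict type checking can be illustrated with a short toy. The store helper and its signature are hypothetical and exist only for this sketch; PyTensor's actual containers are not implemented this way.

```python
# Toy model only: `store` is a hypothetical helper written for illustration.

def store(storage, value, expected_type, strict=False):
    """Put `value` into one-element `storage`, checking or casting its type."""
    if strict:
        # Strict: refuse anything that is not already the expected type.
        if not isinstance(value, expected_type):
            raise TypeError(
                f"expected {expected_type.__name__}, got {type(value).__name__}"
            )
        storage[0] = value
    else:
        # Non-strict: try to convert the value instead of rejecting it.
        storage[0] = expected_type(value)


lenient_cell, strict_cell = [None], [None]
store(lenient_cell, 3, float)  # the int 3 is cast to the float 3.0
try:
    store(strict_cell, 3, float, strict=True)
    rejected = False
except TypeError:
    rejected = True
```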
Note
Compiled libraries are stored within a specific compilation directory, which by default is set to $HOME/.pytensor/compiledir_xxx, where xxx identifies the platform (under Windows the default location is instead $LOCALAPPDATA\PyTensor\compiledir_xxx). It may be manually set to a different location by setting config.compiledir or config.base_compiledir, either within your Python script or by using one of the configuration mechanisms described in config.
The compile cache is based upon the C++ code of the graph to be compiled. So, if you change compilation configuration variables, such as config.blas__ldflags, you will need to manually remove your compile cache by running PyTensor/bin/pytensor-cache clear.
PyTensor also implements a lock mechanism that prevents multiple compilations within the same compilation directory (to avoid crashes with parallel execution of some scripts).
Step 4 - Wrap the thunk in a pretty package
The thunk returned by the linker along with input and output containers is unwieldy. pytensor.function() hides that complexity away so that it can be used like a normal function with arguments and return values.
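The wrapping step can be sketched as a closure that fills the input storage, runs the thunk, and returns what the thunk left in the output storage. This is a toy under stated assumptions: wrap_thunk is a hypothetical helper, one-element lists stand in for containers, and the result is not the object that pytensor.function() actually returns.

```python
# Toy model only: `wrap_thunk` is a hypothetical helper, not the object that
# pytensor.function() actually returns.

def wrap_thunk(thunk, input_storage, output_storage):
    """Turn a thunk plus its storage cells into an ordinary callable."""

    def function(*args):
        if len(args) != len(input_storage):
            raise TypeError(
                f"expected {len(input_storage)} arguments, got {len(args)}"
            )
        for cell, arg in zip(input_storage, args):
            cell[0] = arg  # place inputs in the input cells
        thunk()            # run the computation
        results = [cell[0] for cell in output_storage]
        return results[0] if len(results) == 1 else results

    return function


# A hand-written thunk computing x * y + 1 through one-element list cells.
x_cell, y_cell, out_cell = [None], [None], [None]


def thunk():
    out_cell[0] = x_cell[0] * y_cell[0] + 1


f = wrap_thunk(thunk, [x_cell, y_cell], [out_cell])
value = f(3, 4)
```

The caller of f never touches the storage cells or the thunk directly, which is the convenience this step is meant to provide.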