Developer documentation for Scan
Context
This document is meant to act as reference material for developers working on PyTensor's loop mechanism. This mechanism is called `Scan` and its internals are highly complex, hence the need for a centralized repository of knowledge regarding its inner workings.
The `pytensor.scan` function is the public-facing interface for looping in PyTensor. Under the hood, this function will perform some processing on its inputs and instantiate the `Scan` `Op` class which implements the looping mechanism. It achieves this by compiling its own PyTensor function representing the computation to be done at every iteration of the loop and calling it as many times as necessary.
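For example, here is a minimal sketch of that round trip (the step function and data are illustrative):

```python
import pytensor
import pytensor.tensor as pt

# `pytensor.scan` compiles the step function `fn` into an inner graph and
# places a single `Scan` Op in the outer graph.
x = pt.vector("x")
out, updates = pytensor.scan(fn=lambda x_t: x_t + 1, sequences=[x])

f = pytensor.function([x], out, updates=updates)
print(f([1.0, 2.0, 3.0]))  # [2. 3. 4.]
```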
The correspondence between the parameters and behaviors of the function and the `Op` is not always simple since the former is meant for usability and the latter for performance. Since this document is intended to be used by developers working inside `Scan` itself, it will mostly discuss things from the point of view of the `Scan` `Op` class. Nonetheless, it will attempt to link those elements to their corresponding concepts in the `pytensor.scan` function as often as is reasonably practical.
Pre-requisites

The following sections assume the reader is familiar with the following:

- PyTensor's graph structure (`Apply` nodes, `Variable` nodes and `Op`s)
- The interface and usage of PyTensor's `scan` function

Additionally, the Rewrites section below assumes knowledge of:

- PyTensor's graph rewriting
Relevant code files

The implementation of `Scan` is spread over several files in `pytensor/scan`. The different files, and the sections of the code they deal with, are:

- `basic.py` implements the `scan` function. The `scan` function arranges the arguments of `scan` correctly, constructs the `Scan` `Op` and afterwards calls the constructed `Scan` `Op` on the arguments. This function takes care of figuring out missing inputs and shared variables.
- `op.py` implements the `Scan` `Op` class. `Scan` respects the `Op` interface, and contains most of the logic of the `Scan` operator.
- `utils.py` contains several helpful functions, used throughout the other files, that are specific to the `Scan` operator.
- `views.py` contains different views of the `Scan` `Op` that have simpler and easier signatures to be used in specific cases.
- `opt.py` contains the list of all PyTensor graph rewrites for the `Scan` operator.
Notation

`Scan` being a sizeable and complex module, it has its own naming convention for functions and variables which this section will attempt to introduce.
A `Scan` `Op` contains a PyTensor function representing the computation that is done in a single iteration of the loop represented by the `Scan` `Op` (in other words, the computation given by the function provided as value to `pytensor.scan`'s `fn` argument). Whenever we discuss a `Scan` `Op`, the outer function refers to the PyTensor function that contains the `Scan` `Op` whereas the inner function refers to the PyTensor function that is contained inside the `Scan` `Op`.
In the same spirit, the inputs and outputs of the `Apply` node wrapping the `Scan` `Op` (or `Scan` node for short) are referred to as outer inputs and outer outputs, respectively, because these inputs and outputs are variables in the outer function graph. The inputs and outputs of `Scan`'s inner function are designated inner inputs and inner outputs, respectively.
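A short sketch of this vocabulary, assuming the `inner_inputs`/`inner_outputs` accessors exposed by `Scan`'s inner-graph interface (exact names may vary across versions):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out, _ = pytensor.scan(fn=lambda x_t: 2 * x_t, sequences=[x])

# For this simple map, `out` is produced directly by the Scan node.
scan_node = out.owner   # the Apply node wrapping the Scan Op
scan_op = scan_node.op  # the Scan Op itself

print(scan_node.inputs)       # outer inputs (e.g. n_steps and the sequence)
print(scan_node.outputs)      # outer outputs
print(scan_op.inner_inputs)   # inner inputs of the inner function graph
print(scan_op.inner_outputs)  # inner outputs of the inner function graph
```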
`Scan` variables

The following are the different types of variables that `Scan` has the capacity to handle, along with their various characteristics.
Sequence: A sequence is a PyTensor variable over which `Scan` will iterate, feeding sub-elements to its inner function as input. A sequence has no associated output. For a sequence variable `X`, at timestep `t`, the inner function will receive as input the sequence element `X[t]`. These variables are used through the `sequences` argument of the `pytensor.scan` function.
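A minimal sketch of a sequence, iterating over the rows of a matrix (the data is illustrative):

```python
import pytensor
import pytensor.tensor as pt

X = pt.matrix("X")

# At timestep t the inner function receives the sequence element X[t],
# here a row of the matrix.
row_sums, _ = pytensor.scan(fn=lambda x_t: x_t.sum(), sequences=[X])

f = pytensor.function([X], row_sums)
print(f([[1.0, 2.0], [3.0, 4.0]]))  # [3. 7.]
```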
Non-sequence: A non-sequence is a PyTensor variable that `Scan` will provide as-is to its inner function. Like a sequence, a non-sequence has no associated output. For a non-sequence variable `X`, at every timestep `t`, the inner function will receive as input the same variable `X`. These variables are used through the `non_sequences` argument of the `pytensor.scan` function.
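A minimal sketch combining a sequence with a non-sequence:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
a = pt.scalar("a")

# `x` is a sequence (iterated element by element) while `a` is a
# non-sequence: the inner function receives the same variable `a`
# at every timestep.
out, _ = pytensor.scan(
    fn=lambda x_t, a: a * x_t,
    sequences=[x],
    non_sequences=[a],
)

f = pytensor.function([x, a], out)
print(f([1.0, 2.0, 3.0], 10.0))  # [10. 20. 30.]
```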
NITSOT (no input tap, single output tap): A NITSOT is an output variable of the inner function that is not fed back as an input to the next iteration of the inner function. NITSOTs are typically encountered in situations where `Scan` is used to perform a 'map' operation (every element in a tensor is independently altered using a given operation to produce a new tensor) such as squaring every number in a vector.
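A sketch of that map case; since nothing is fed back, the output is a NITSOT and no `outputs_info` entry is involved:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")

# The squared value is never fed back into the next iteration, so the
# output is a NITSOT.
squares, _ = pytensor.scan(fn=lambda x_t: x_t**2, sequences=[x])

f = pytensor.function([x], squares)
print(f([1.0, 2.0, 3.0]))  # [1. 4. 9.]
```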
SITSOT (single input tap, single output tap): A SITSOT is an output variable of the inner function that is fed back as an input to the next iteration of the inner function. A typical setting where a SITSOT might be encountered is the case where `Scan` is used to compute the cumulative sum over the elements of a vector and a SITSOT output is employed to act as an accumulator.
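A sketch of that cumulative-sum accumulator; the initial value passed through `outputs_info` plays the role of the output at "timestep -1":

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")

# The accumulator is a SITSOT output: its value at step t-1 is fed back
# as an inner input at step t.
out, _ = pytensor.scan(
    fn=lambda x_t, acc: acc + x_t,
    sequences=[x],
    outputs_info=[pt.zeros(())],
)

f = pytensor.function([x], out)
print(f([1.0, 2.0, 3.0]))  # [1. 3. 6.]
```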
MITSOT (multiple input taps, single output tap): A MITSOT is an output variable of the inner function that is fed back as an input to future iterations of the inner function (either multiple future iterations or a single one that isn't the immediate next one). For example, a MITSOT might be used in the case where `Scan` is used to compute the Fibonacci sequence, one term of the sequence at every timestep, since every computed term needs to be reused to compute the next two terms of the sequence.
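A sketch of the Fibonacci case, using the dictionary form of `outputs_info` with taps `[-2, -1]` (the initial values shown are illustrative):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# The two initial values correspond to timesteps -2 and -1.
init = pt.as_tensor_variable(np.asarray([0.0, 1.0]))

# Each new term depends on the two previous ones, so the output is a
# MITSOT with taps [-2, -1].
fib, _ = pytensor.scan(
    fn=lambda f_tm2, f_tm1: f_tm2 + f_tm1,
    outputs_info=[dict(initial=init, taps=[-2, -1])],
    n_steps=8,
)

f = pytensor.function([], fib)
print(f())  # [ 1.  2.  3.  5.  8. 13. 21. 34.]
```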
MITMOT (multiple input taps, multiple output taps): These outputs exist but they cannot be directly created by the user. They can appear in a PyTensor graph as a result of taking the gradient of the output of a `Scan` with respect to its inputs: this will result in the creation of a new `Scan` node used to compute the gradients of the first `Scan` node. If the original `Scan` had SITSOT or MITSOT variables, the new `Scan` will use MITMOTs to compute the gradients through time for these variables.
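A sketch of how a MITMOT can arise; the exact structure of the gradient graph is version-dependent, but printing it should reveal a second `Scan` node:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out, _ = pytensor.scan(
    fn=lambda x_t, acc: acc + x_t,
    sequences=[x],
    outputs_info=[pt.zeros(())],
)

# Differentiating through the loop builds a second Scan node that runs
# backwards through time; the gradient of the SITSOT state flows through
# MITMOT variables of that new node.
g = pytensor.grad(out[-1], x)
pytensor.dprint(g)
```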
To summarize:

| Type of scan variable | Corresponding outer input | Corresponding inner input at timestep `t` | Corresponding inner output at timestep `t` | Corresponding outer output | Corresponding argument of the `pytensor.scan` function |
|---|---|---|---|---|---|
| Sequence | Sequence of elements `X` | Individual sequence element `X[t]` | No corresponding inner output | No corresponding outer output | `sequences` |
| Non-sequence | Any variable `X` | Variable identical to `X` | No corresponding inner output | No corresponding outer output | `non_sequences` |
| Non-recurring output (NITSOT) | No corresponding outer input | No corresponding inner input | Output value at timestep `t` | Concatenation of the values of the output at all timesteps | `outputs_info` (with `taps` set to `None`) |
| Singly-recurrent output (SITSOT) | Initial value (value at timestep `-1`) | Output value at the previous timestep (`t - 1`) | Output value at timestep `t` | Concatenation of the values of the output at all timesteps | `outputs_info` (with `taps` set to `[-1]`) |
| Multiply-recurrent output (MITSOT) | Initial values for the required timesteps where `t < 0` | Output values at the previous required timesteps | Output value at timestep `t` | Concatenation of the values of the output at all timesteps | `outputs_info` (with `taps` set to a list of negative values) |
| Multiply-recurrent multiple outputs (MITMOT) | Initial values for the required timesteps where `t < 0` | Output values at the previous required timesteps | Output values for the current and multiple future timesteps | Concatenation of the values of the output at all timesteps | No corresponding argument |
Rewrites
`remove_constants_and_unused_inputs_scan`
This rewrite serves two purposes. The first is to remove a `Scan` `Op`'s unused inputs. The second is to take a `Scan` `Op`'s constant inputs and remove them, instead injecting the constants directly into the graph of the `Scan` `Op`'s inner function. This allows constant folding to happen inside the inner function.
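A sketch of a graph this rewrite targets, with a constant non-sequence (the constant value is illustrative):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")

# The non-sequence is a constant, so this rewrite can drop the outer
# input and inject the value 2.0 directly into the inner graph, where
# constant folding can then simplify the computation.
out, _ = pytensor.scan(
    fn=lambda x_t, c: c * x_t,
    sequences=[x],
    non_sequences=[pt.constant(2.0)],
)

f = pytensor.function([x], out)
```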
`PushOutNonSeqScan`
This rewrite pushes sub-graphs that depend only on non-sequence inputs out of `Scan`'s inner function and into the outer function. Such computation ends up being done every iteration on the same values, so moving it to the outer function, to be executed only once before the `Scan` `Op`, reduces the amount of computation that needs to be performed.
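A sketch of a graph this rewrite targets; the hoisted expression `a * b` is illustrative:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
a = pt.scalar("a")
b = pt.scalar("b")

# `a * b` depends only on non-sequences, so the naive inner graph would
# recompute it identically at every iteration. This rewrite hoists it
# into the outer graph where it is computed once, before the Scan runs.
out, _ = pytensor.scan(
    fn=lambda x_t, a, b: (a * b) * x_t,
    sequences=[x],
    non_sequences=[a, b],
)

f = pytensor.function([x, a, b], out)
pytensor.dprint(f)  # the multiplication a * b should appear outside the Scan
```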
`PushOutSeqScan`
This rewrite resembles `PushOutNonSeqScan` but it tries to push, out of the inner function, the computation that relies only on sequence and non-sequence inputs. The idea behind this rewrite is that, when it is possible to do so, it is generally more computationally efficient to perform a single operation on a large tensor rather than perform that same operation many times on many smaller tensors. In many cases, this rewrite can increase memory usage but, in some specific cases, it can also decrease it.
`PushOutScanOutput`
This rewrite attempts to push some of the computation at the end of the inner function out to the outer function, to be executed after the `Scan` node. Like `PushOutSeqScan`, this rewrite aims to replace many operations on small tensors with a few operations on large tensors. It can also lead to increased memory usage.
`PushOutDot1`
This is another rewrite that attempts to detect certain patterns of computation in a `Scan` `Op`'s inner function and move this computation to the outer graph.
`ScanInplaceOptimizer`
This rewrite attempts to make `Scan` compute its recurrent outputs in place on the input tensors that contain their initial states. This rewrite can improve runtime performance as well as reduce memory usage.
`ScanSaveMem`
This rewrite attempts to determine whether, during its execution, a `Scan` node can get away with allocating, for any of its outputs, a memory buffer large enough to contain only some of the computed timesteps of that output instead of all of them.

By default, during the execution of a `Scan` node, memory buffers will be allocated to store the values computed for every output at every iteration. However, in some cases, there are outputs for which there is only really a need to store the most recent `N` values, not all of them.
For instance, if a `Scan` node has a SITSOT output (last computed value is fed back as an input at the next iteration) and only the last timestep of that output is ever used in the outer function, the `ScanSaveMem` rewrite could determine that there is no need to store all computed timesteps for that SITSOT output. Only the most recently computed timestep ever needs to be kept in memory.
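A sketch of that situation:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out, _ = pytensor.scan(
    fn=lambda x_t, acc: acc + x_t,
    sequences=[x],
    outputs_info=[pt.zeros(())],
)

# Only the final timestep of the SITSOT output is used in the outer
# graph, so ScanSaveMem can shrink the buffer for that output to hold
# just the most recent value instead of the full history.
f = pytensor.function([x], out[-1])
```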
`ScanMerge`
This rewrite attempts to fuse distinct `Scan` nodes into a single `Scan` node that performs all the computation. The main advantage of merging `Scan` nodes together comes from the possibility that the original `Scan` `Op`s have some computation in common. In such a setting, this computation ends up being done twice; the fused `Scan`, however, would only need to do it once and could therefore be more computationally efficient. Also, since every `Scan` node involves a certain overhead at runtime, reducing the number of `Scan` nodes in the graph can improve performance.
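A sketch of a graph this rewrite targets; whether the merge actually fires depends on the rewrite pipeline in use:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")

# Two separate Scan nodes iterate over the same sequence for the same
# number of steps; ScanMerge can fuse them into a single node that
# computes both outputs in one loop.
out1, _ = pytensor.scan(fn=lambda x_t: x_t + 1, sequences=[x])
out2, _ = pytensor.scan(fn=lambda x_t: x_t * 2, sequences=[x])

f = pytensor.function([x], [out1, out2])
pytensor.dprint(f)  # after rewriting, a single Scan node should remain
```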
`scan_merge_inouts`
This rewrite attempts to merge a `Scan` node's identical outer inputs as well as its identical outer outputs (outputs that perform the same computation on the same inputs). This can reduce the amount of computation as well as result in a simpler graph for both the inner function and the outer function.
Helper classes and functions
Because of the complexity involved in dealing with `Scan`, a large number of helper classes and functions have been developed over time to implement operations commonly needed when dealing with the `Scan` `Op`. The `Scan` `Op` itself defines a large number of them, and others can be found in the file `utils.py`. This section aims to point out the most useful ones, sorted by usage.
Accessing/manipulating `Scan`'s inputs and outputs by type
Declared in `utils.py`, the class `ScanArgs` handles the parsing of the inputs and outputs (both inner and outer) into a format that is easier to analyze and manipulate. Without this class, analyzing `Scan`'s inputs and outputs can require convoluted logic which makes for code that is hard to read and maintain. Because of this, you should favor using `ScanArgs` when it is practical and appropriate to do so.
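A hedged sketch, assuming the `ScanArgs.from_node` constructor and attribute names such as `outer_in_seqs` (these names may differ across versions):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.scan.utils import ScanArgs

x = pt.vector("x")
out, _ = pytensor.scan(fn=lambda x_t: x_t**2, sequences=[x])
scan_node = out.owner

# ScanArgs groups the flat lists of inputs/outputs by variable type.
args = ScanArgs.from_node(scan_node)
print(args.outer_in_seqs)      # outer inputs that are sequences
print(args.inner_in_seqs)      # the matching inner inputs
print(args.outer_out_nit_sot)  # NITSOT outer outputs
```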
The `Scan` `Op` extends `ScanPropertiesMixin`, which defines a few helper methods for this purpose, such as `inner_nitsot_outs` or `mitmot_out_taps`, but they are often poorly documented and easy to misuse. These should be used with great care.