tensor.rewriting.math – Tensor Rewrites for Math Operations

Rewrites for the Ops in pytensor.tensor.math.

class pytensor.tensor.rewriting.math.AlgebraicCanonizer(main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=True)

A Rewriter that rewrites algebraic expressions.

An instance of this class is a node rewriter. It is best used with a WalkingGraphRewriter in in-to-out order.

Usage: AlgebraicCanonizer(main, inverse, reciprocal, calculate)

  • main – A suitable Op class that is commutative, associative and takes one to an arbitrary number of inputs, e.g. add or mul

  • inverse – An Op class such that inverse(main(x, y), y) == x (e.g. sub or true_div).

  • reciprocal – A function such that main(x, reciprocal(y)) == inverse(x, y) (e.g. neg or reciprocal).

  • calculate – Function that takes a list of numpy.ndarray instances for the numerator, another list for the denominator, and calculates inverse(main(*num), main(*denum)). It takes a keyword argument, aslist. If aslist is True, the value should be returned as a list of one element, unless the value is such that value = main(). In that case, the return value should be an empty list.


>>> import numpy as np
>>> import pytensor.tensor as pt
>>> from pytensor.tensor.rewriting.math import AlgebraicCanonizer
>>> # Simplified calculate functions; a complete one also accepts the aslist keyword.
>>> add_canonizer = AlgebraicCanonizer(pt.add, pt.sub, pt.neg,
...                                    lambda n, d: sum(n) - sum(d))
>>> mul_canonizer = AlgebraicCanonizer(pt.mul, pt.true_div, pt.reciprocal,
...                                    lambda n, d: np.prod(n) / np.prod(d))

Examples of rewrites mul_canonizer can perform:

x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
x * y * z -> Elemwise(mul){x,y,z}  # only one pass over the memory,
                                   # rather than Elemwise(mul){x,Elemwise(mul){y,z}}

This extracts two lists, num and denum, such that the input is: self.inverse(self.main(*num), self.main(*denum)). It returns the two lists in a (num, denum) pair.

For example, for main, inverse and reciprocal = *, / and inv(),

input -> returned value (num, denum)
x*y -> ([x, y], [])
inv(x) -> ([], [x])
inv(x) * inv(y) -> ([], [x, y])
x*y/z -> ([x, y], [z])
log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
(((a / b) * c) / d) -> ([a, c], [b, d])
a / (b / c) -> ([a, c], [b])
log(x) -> ([log(x)], [])
x**y -> ([x**y], [])
x * y * z -> ([x, y, z], [])
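
The table above can be reproduced with a minimal pure-Python sketch. It is a toy model and not PyTensor's implementation: expressions are nested tuples such as ("div", a, b), leaves are plain strings, and the function name only mirrors the behaviour being described. All of these representations are assumptions made for illustration.

```python
def get_num_denum(expr):
    """Toy model: split a nested-tuple expression into (num, denum) lists.

    Hypothetical encoding: "x" is a leaf, ("mul", a, b, ...) is a product,
    ("div", a, b) is a quotient and ("inv", a) is a reciprocal.
    """
    if isinstance(expr, str):
        return [expr], []
    op, *args = expr
    if op == "mul":
        num, denum = [], []
        for arg in args:
            n, d = get_num_denum(arg)
            num += n
            denum += d
        return num, denum
    if op == "div":
        n0, d0 = get_num_denum(args[0])
        n1, d1 = get_num_denum(args[1])
        # Dividing by n1/d1 multiplies by d1/n1.
        return n0 + d1, d0 + n1
    if op == "inv":
        n, d = get_num_denum(args[0])
        return d, n
    # Any other operation (log, pow, add, ...) is treated as an opaque factor.
    return [expr], []


print(get_num_denum(("div", ("mul", "x", "y"), "z")))     # (['x', 'y'], ['z'])
print(get_num_denum(("div", "a", ("div", "b", "c"))))     # (['a', 'c'], ['b'])
```

Treating every unrecognised operation as a single opaque factor is what makes log(x) and x**y appear unchanged in the numerator list.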

merge_num_denum(num, denum)

Utility function which takes two lists, num and denum, and returns something which is equivalent to inverse(main(*num), main(*denum)), but depends on the length of num and the length of denum (in order to minimize the number of operations).

Let n = len(num) and d = len(denum):

n=0, d=0: neutral element (given by self.calculate([], []))
(for example, this would be 0 if main is addition
and 1 if main is multiplication)
n=1, d=0: num[0]
n=0, d=1: reciprocal(denum[0])
n=1, d=1: inverse(num[0], denum[0])
n=0, d>1: reciprocal(main(*denum))
n>1, d=0: main(*num)
n=1, d>1: inverse(num[0], main(*denum))
n>1, d=1: inverse(main(*num), denum[0])
n>1, d>1: inverse(main(*num), main(*denum))

Given the values of n and d to which they are associated, all of the above are equivalent to: inverse(main(*num), main(*denum))
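
The case analysis above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: main, inverse, reciprocal and neutral are passed in explicitly (the real method reads them from the canonizer instance), and the string-building helpers mul, div and inv are hypothetical stand-ins for graph operations.

```python
def merge_num_denum(num, denum, main, inverse, reciprocal, neutral):
    """Toy model of the nine cases: build the cheapest equivalent expression."""

    def combine(terms):
        # No terms -> nothing to build; one term -> use it directly;
        # several terms -> a single call to main.
        if not terms:
            return None
        return terms[0] if len(terms) == 1 else main(*terms)

    top, bottom = combine(num), combine(denum)
    if top is None and bottom is None:
        return neutral
    if bottom is None:
        return top
    if top is None:
        return reciprocal(bottom)
    return inverse(top, bottom)


# Hypothetical helpers that render expressions as strings.
def mul(*terms):
    return "(" + " * ".join(terms) + ")"


def div(a, b):
    return f"{a} / {b}"


def inv(a):
    return f"1 / {a}"


print(merge_num_denum(["x", "y"], ["z"], mul, div, inv, "1"))  # (x * y) / z
print(merge_num_denum([], ["a", "b"], mul, div, inv, "1"))     # 1 / (a * b)
```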

simplify(num, denum, out_type)

Shorthand for:

self.simplify_constants(*self.simplify_factors(num, denum))

simplify_constants(orig_num, orig_denum, out_type=None)

Find all constants and put them together into a single constant.

Finds all constants in orig_num and orig_denum and puts them together into a single constant. The constant is inserted as the first element of the numerator. If the constant is the neutral element, it is removed from the numerator.


Let main be multiplication:

[2, 3, x], [] -> [6, x], []
[x, y, 2], [4, z] -> [0.5, x, y], [z]
[x, 2, y], [z, 2] -> [x, y], [z]
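
As a rough illustration of the three examples above, the following sketch folds constants for the multiplicative case only. It is a toy model under stated assumptions: Python numbers stand in for constants, strings stand in for symbolic variables, and the out_type argument is ignored.

```python
from math import prod


def simplify_constants(num, denum):
    """Toy model for main = multiplication: fold all numeric constants."""

    def is_const(value):
        return isinstance(value, (int, float))

    # Gather every constant into a single value: prod(num consts) / prod(denum consts).
    const = prod(v for v in num if is_const(v)) / prod(v for v in denum if is_const(v))
    new_num = [v for v in num if not is_const(v)]
    new_denum = [v for v in denum if not is_const(v)]
    # The neutral element of multiplication (1) is dropped instead of inserted.
    if const != 1:
        new_num.insert(0, const)
    return new_num, new_denum


print(simplify_constants(["x", "y", 2], [4, "z"]))  # ([0.5, 'x', 'y'], ['z'])
print(simplify_constants(["x", 2, "y"], ["z", 2]))  # (['x', 'y'], ['z'])
```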

simplify_factors(num, denum)

For any Variable r which is in both num and denum, removes it from both lists. The lists are modified in place and returned. For example:

[x], [x] -> [], []
[x, y], [x] -> [y], []
[a, b], [c, d] -> [a, b], [c, d]
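
A minimal sketch of this cancellation, assuming plain Python lists of strings in place of Variables (an illustration, not the library code):

```python
def simplify_factors(num, denum):
    """Toy model: cancel each factor that appears in both lists, in place."""
    # Iterate over a copy because num is mutated inside the loop.
    for factor in list(num):
        if factor in denum:
            num.remove(factor)
            denum.remove(factor)
    return num, denum


num, denum = ["x", "y"], ["x"]
print(simplify_factors(num, denum))  # (['y'], [])
print(num, denum)                    # ['y'] [] (the inputs themselves were modified)
```

Each occurrence cancels at most one matching occurrence on the other side, so ["x", "x"] over ["x"] leaves a single "x" in the numerator.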

Return the list of Op classes to which this rewrite applies.

Returns None when the rewrite applies to all nodes.

transform(fgraph, node)

Rewrite the sub-graph given by node.

Subclasses should implement this function so that it returns one of the following:

  • False to indicate that this rewrite cannot be applied to node

  • A list of Variables to use in place of the node’s current outputs

  • A dict mapping old Variables to Variables, or the key "remove" mapping to a list of Variables to be removed.

  • fgraph – A FunctionGraph containing node.

  • node – An Apply node to be rewritten.

pytensor.tensor.rewriting.math.attempt_distribution(factor, num, denum, out_type)

Try to insert each num and each denum in the factor?


If there are changes, new_num and new_denum contain all the numerators and denominators that could not be distributed in the factor

Return type:

changes?, new_factor, new_num, new_denum

pytensor.tensor.rewriting.math.check_for_x_over_absX(numerators, denominators)

Convert x/abs(x) into sign(x).
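
The identity behind this rewrite can be checked numerically with plain Python floats. The helper sign below is a hypothetical stand-in for the tensor sign operation, written only for this illustration.

```python
import math


def sign(x):
    """Hypothetical scalar stand-in for sign: -1.0, 0.0 or 1.0."""
    return 0.0 if x == 0 else math.copysign(1.0, x)


# For every nonzero x, x / abs(x) and sign(x) agree.
for x in (-3.5, -1.0, 2.0, 7.25):
    assert x / abs(x) == sign(x)
```

The two expressions differ only at x = 0, where the quotient is undefined while sign returns 0.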


Compute the Variable that is the output of a multiplication tree.

This is the inverse of the operation performed by parse_mul_tree, i.e. compute_mul(parse_mul_tree(tree)) == tree.


tree – A multiplication tree (as output by parse_mul_tree).


A Variable that computes the multiplication represented by the tree.



pytensor.tensor.rewriting.math.is_1pexp(t, only_process_constants=True)

If ‘t’ is of the form (1+exp(x)), return (False, x). Else return None.




Match a variable with either of the exp(x) or -exp(x) patterns.


var – The Variable to analyze.


A pair (b, x) with b a boolean set to True if var is of the form -exp(x) and False if var is of the form exp(x). If var cannot be cast into either form, then return None.
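
The return convention can be illustrated with a toy matcher. It assumes expressions encoded as nested tuples, ("exp", x) and ("neg", e), which is a hypothetical encoding chosen for this sketch and not how PyTensor represents graphs.

```python
def is_exp(var):
    """Toy model: return (negated, x) for exp(x) or -exp(x), else None."""
    negated = False
    if isinstance(var, tuple) and var[0] == "neg":
        negated, var = True, var[1]
    if isinstance(var, tuple) and var[0] == "exp":
        return negated, var[1]
    return None


print(is_exp(("exp", "x")))            # (False, 'x')
print(is_exp(("neg", ("exp", "x"))))   # (True, 'x')
print(is_exp(("log", "x")))            # None
```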



pytensor.tensor.rewriting.math.is_inverse_pair(node_op, prev_op, inv_pair)

Given two consecutive operations, check if they are the provided pair of inverse functions.


Match a variable with x * y * z * ....


var – The Variable to analyze.


A list [x, y, z, …] if var is of the form x * y * z * ..., or None if var cannot be cast into this form.




Match a variable with the -x pattern.


var – The Variable to analyze.


x if var is of the form -x, or None otherwise.




Parse a tree of multiplications starting at the given root.


root – The variable at the root of the tree.


A tree where each non-leaf node corresponds to a multiplication in the computation of root, represented by the list of its inputs. Each input is a pair [n, x] with n a boolean value indicating whether sub-tree x should be negated.




x * y               -> [False, [[False, x], [False, y]]]
-(x * y)            -> [True, [[False, x], [False, y]]]
-x * y              -> [False, [[True, x], [False, y]]]
-x                  -> [True, x]
(x * y) * -z        -> [False, [[False, [[False, x], [False, y]]],
                                [True, z]]]
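
The tree encoding shown above, together with the inverse operation described for compute_mul, can be sketched on a toy expression format. Here expressions are nested tuples ("mul", a, b, ...) and ("neg", a) with string leaves; this encoding and both function bodies are assumptions made for illustration, not the library's implementation.

```python
def parse_mul_tree(root):
    """Toy model: turn a nested-tuple expression into a [negated, inputs] tree."""
    if isinstance(root, tuple) and root[0] == "mul":
        return [False, [parse_mul_tree(arg) for arg in root[1:]]]
    if isinstance(root, tuple) and root[0] == "neg":
        negated, sub_tree = parse_mul_tree(root[1])
        # A negation flips the flag of whatever it wraps.
        return [not negated, sub_tree]
    return [False, root]


def compute_mul(tree):
    """Toy model: rebuild the nested-tuple expression from a multiplication tree."""
    negated, inputs = tree
    if isinstance(inputs, list):
        expr = ("mul", *[compute_mul(sub_tree) for sub_tree in inputs])
    else:
        expr = inputs
    return ("neg", expr) if negated else expr


expr = ("mul", ("mul", "x", "y"), ("neg", "z"))   # (x * y) * -z
tree = parse_mul_tree(expr)
print(tree)  # [False, [[False, [[False, 'x'], [False, 'y']]], [True, 'z']]]
print(compute_mul(tree) == expr)  # True
```

In this toy version the round trip holds for the examples listed above; a doubly negated input such as -(-x) would come back as x, since the two flags cancel during parsing.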

pytensor.tensor.rewriting.math.perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None, sigm_minus_x=None, parent=None, child_idx=None, full_tree=None)

Core processing of the local_sigm_times_exp rewrite.

This recursive function operates on a multiplication tree as output by parse_mul_tree. It walks through the tree and modifies it in-place by replacing matching pairs (exp, sigmoid) with the desired version.

  • tree – The sub-tree to operate on.

  • exp_x – List of arguments x so that exp(x) exists somewhere in the whole multiplication tree. Each argument is a pair (x, leaf) with x the argument of the exponential, and leaf the corresponding leaf in the multiplication tree (of the form [n, exp(x)] – see parse_mul_tree). If None, this argument is initialized to an empty list.

  • exp_minus_x – Similar to exp_x, but for exp(-x).

  • sigm_x – Similar to exp_x, but for sigmoid(x).

  • sigm_minus_x – Similar to exp_x, but for sigmoid(-x).

  • parent – Parent of tree (None if tree is the global root).

  • child_idx – Index of tree in its parent’s inputs (None if tree is the global root).

  • full_tree – The global multiplication tree (should not be set except by recursive calls to this function). Used for debugging only.


True if a modification was performed somewhere in the whole multiplication tree, or False otherwise.



pytensor.tensor.rewriting.math.replace_leaf(arg, leaves, new_leaves, op, neg)

Attempt to replace a leaf of a multiplication tree.

We search for a leaf in leaves whose argument is arg, and if we find one, we remove it from leaves and add to new_leaves a leaf with argument arg and variable op(arg).

  • arg – The argument of the leaf we are looking for.

  • leaves – List of leaves to look into. Each leaf should be a pair (x, l) with x the argument of the Op found in the leaf, and l the actual leaf as found in a multiplication tree output by parse_mul_tree (i.e. a pair [boolean, variable]).

  • new_leaves – If a replacement occurred, then the leaf is removed from leaves and added to the list new_leaves (after being modified by op).

  • op – A function that, when applied to arg, returns the Variable we want to replace the original leaf variable with.

  • neg (bool) – If True, then the boolean value associated to the leaf should be swapped. If False, then this value should remain unchanged.


True if a replacement occurred, or False otherwise.
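
The behaviour described above can be sketched as follows. This is an illustrative version written from the description, not necessarily the library code; strings stand in for Variables, and the lambda passed as op is a hypothetical example.

```python
def replace_leaf(arg, leaves, new_leaves, op, neg):
    """Sketch: move the first leaf whose argument is arg into new_leaves."""
    for idx, (leaf_arg, leaf) in enumerate(leaves):
        if leaf_arg == arg:
            # leaf is a [boolean, variable] pair shared with the tree,
            # so mutating it also updates the multiplication tree in place.
            leaf[0] ^= neg
            leaf[1] = op(arg)
            new_leaves.append(leaves.pop(idx))
            return True
    return False


leaf = [False, "exp(x)"]
leaves, new_leaves = [("x", leaf)], []
changed = replace_leaf("x", leaves, new_leaves, lambda a: f"sigmoid({a})", neg=True)
print(changed, leaves, new_leaves)  # True [] [('x', [True, 'sigmoid(x)'])]
```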



pytensor.tensor.rewriting.math.scalarconsts_rest(inputs, elemwise=True, only_process_constants=False)

Partition a list of variables into two kinds: scalar constants, and the rest.


Simplify a multiplication tree.


tree – A multiplication tree (as output by parse_mul_tree).


A multiplication tree computing the same output as tree but without useless multiplications by 1 or -1 (identified by leaves of the form [False, None] or [True, None] respectively). Useless multiplications (with fewer than two inputs) are also removed from the tree.
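
A minimal sketch of this simplification, assuming the [negated, inputs] tree encoding with string leaves and None marking a multiplication by 1. It is written from the description above and should be read as an illustration, not as the library's implementation.

```python
def simplify_mul(tree):
    """Sketch: drop multiplications by +/-1 and collapse trivial products."""
    negated, inputs = tree
    if not isinstance(inputs, list):
        return tree  # a leaf is already as simple as it gets
    kept = []
    for sub_negated, sub_inputs in map(simplify_mul, inputs):
        if sub_inputs is None:
            # Multiplication by 1 or -1: only the sign survives.
            negated ^= sub_negated
        else:
            kept.append([sub_negated, sub_inputs])
    if not kept:
        return [negated, None]          # the whole product is +/-1
    if len(kept) == 1:
        # A product with a single input is replaced by that input.
        return [kept[0][0] ^ negated, kept[0][1]]
    return [negated, kept]


print(simplify_mul([False, [[False, "x"], [True, None]]]))                 # [True, 'x']
print(simplify_mul([False, [[False, "x"], [False, None], [False, "y"]]]))  # [False, [[False, 'x'], [False, 'y']]]
```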

