tensor.rewriting.basic – Tensor Rewrites
Tensor optimizations addressing the ops in basic.py.
Notes
There are two equivalent ways of broadcasting arrays, the second(x, y) form and the alloc form:
second(x, y) == alloc(y, broadcast_shapes(x.shape, y.shape))
The second (alloc) form can be more efficient, because x usually doesn’t need to be computed when only its shape is wanted. It may also enable rewrites that refuse to modify x when it has multiple clients (for fear of duplicating computation).
However, the first (second(x, y)) form is easier to reason about. Knowing that a graph has this form makes certain rewrites possible, such as “sinking” broadcasting operations below an Elemwise. The same rewrites with alloc would be more complicated, because we would need to symbolically combine the shapes of each operand.
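A NumPy analogue can make the identity concrete. This is a sketch, not PyTensor code. The helpers second_np and alloc_np are hypothetical, np.broadcast_to stands in for Alloc, and dtype handling is ignored. Note that the alloc form receives only x.shape, which is why it can avoid computing x.

```python
import numpy as np


def second_np(x, y):
    """Hypothetical NumPy stand-in for second(x, y): y's values, broadcast against x."""
    shape = np.broadcast_shapes(np.shape(x), np.shape(y))
    return np.broadcast_to(y, shape)


def alloc_np(y, shape):
    """Hypothetical NumPy stand-in for alloc(y, shape)."""
    return np.broadcast_to(y, shape)


x = np.arange(6.0).reshape(2, 3)  # only x's shape affects the result
y = np.array([10.0, 20.0, 30.0])

via_second = second_np(x, y)
# The alloc form only needs x.shape, never x's values.
via_alloc = alloc_np(y, np.broadcast_shapes(x.shape, y.shape))

print(via_second.shape)  # (2, 3)
print(np.array_equal(via_second, via_alloc))  # True
```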
As an example, contrast rewriting the following two equivalent graphs:
alloc(x, broadcast_shapes(x.shape, y.shape)) + alloc(y, broadcast_shapes(x.shape, y.shape)) -> x + y
second(y, x) + second(x, y) -> x + y
Theano developers (mostly) preferred to use the first form during canonicalization and to introduce the second form later, via rewrites like local_second_to_alloc and by using the broadcast_like_elemwise helper inside rewrites.
Many stabilization rewrites refuse to be applied when a variable has multiple clients, so the choice of form matters.
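The contrast between the two example graphs can be illustrated with a toy term rewriter. This is a hypothetical sketch, not PyTensor's rewriting machinery. Graphs are nested tuples, and rewrite_second_pair matches only the exact pattern add(second(a, b), second(b, a)). The point is that, for the second form, the rewrite is justified by comparing variable identities alone. Matching the alloc form would instead require proving that two symbolic shape expressions are equal.

```python
import numpy as np

# Toy graph encoding (hypothetical): ("op", lhs, rhs) tuples with variable names as leaves.


def rewrite_second_pair(graph):
    """Rewrite add(second(a, b), second(b, a)) -> add(b, a).

    Both operands of the original add already have the shape broadcast(a, b),
    which is exactly the shape of add(b, a), so dropping the wrappers is safe.
    """
    if not (isinstance(graph, tuple) and graph[0] == "add"):
        return graph
    lhs, rhs = graph[1], graph[2]
    if not all(isinstance(arg, tuple) and arg[0] == "second" for arg in (lhs, rhs)):
        return graph
    (_, a, b), (_, c, d) = lhs, rhs
    if a == d and b == c:
        return ("add", b, d)
    # e.g. add(second(a, b), second(a, b)) -> add(b, b) would lose a's shape.
    return graph


def evaluate(graph, env):
    """Evaluate a toy graph with NumPy, using env to look up leaf values."""
    if isinstance(graph, str):
        return env[graph]
    op, lhs, rhs = graph
    lhs_val, rhs_val = evaluate(lhs, env), evaluate(rhs, env)
    if op == "add":
        return lhs_val + rhs_val
    if op == "second":
        shape = np.broadcast_shapes(np.shape(lhs_val), np.shape(rhs_val))
        return np.broadcast_to(rhs_val, shape)
    raise ValueError(f"unknown op {op!r}")


original = ("add", ("second", "y", "x"), ("second", "x", "y"))
rewritten = rewrite_second_pair(original)
env = {"x": np.ones((2, 1)), "y": np.arange(3.0)}

print(rewritten)  # ('add', 'x', 'y')
print(np.array_equal(evaluate(original, env), evaluate(rewritten, env)))  # True
```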
- pytensor.tensor.rewriting.basic.broadcast_like_elemwise(value, node, *, fgraph, ref_input_idx=None, stack_trace=False)
Broadcast value(s) to match the output shape and dtype(s) of an elemwise node.
Each value is cast to the dtype of its corresponding node.outputs[i], then broadcast via Alloc to the full output shape if needed.
The broadcast shape is derived from node.inputs rather than node.outputs, so the result does not depend on the node’s outputs. This eagerness may mask shape errors present in the original graph (“shape_unsafe”).
- Parameters:
value – A single variable (or constant), or a list of variables. When a list is given, each entry corresponds to an output of node.
node – The elemwise node whose output shape the values should match.
fgraph – The function graph containing the node (used to simplify shapes).
ref_input_idx – If given, node.inputs[ref_input_idx] is moved to the front when computing the broadcast shape, so its shape entries dominate on axes where it is not broadcastable (see get_simplified_broadcast_shape). Use this when the rewrite “keeps” one specific input and that input’s shape should dominate (e.g. second(a, b) -> b with ref_input_idx=1 drops the dependency on a, which is shape_unsafe).
stack_trace – If True, copy the stack trace from node.outputs[0] onto every returned variable.
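The cast-then-broadcast behaviour can be mimicked on concrete arrays. The sketch below is hypothetical. broadcast_like_elemwise_np takes plain input shapes and output dtypes instead of a node and fgraph. It also omits ref_input_idx and stack_trace. With concrete integer shapes, the order of the inputs does not change the broadcast result; the order only affects which symbolic shape expression is chosen. As described above, the target shape comes from the inputs alone.

```python
import numpy as np


def broadcast_like_elemwise_np(value, input_shapes, output_dtypes):
    """Hypothetical NumPy analogue of broadcast_like_elemwise.

    input_shapes stands in for the shapes of node.inputs and output_dtypes for
    the dtypes of node.outputs; value is a scalar/array or a list of them.
    """
    values = value if isinstance(value, list) else [value]
    if len(values) != len(output_dtypes):
        raise ValueError("need one value per output")
    # Target shape comes from the inputs, not from the outputs.
    out_shape = np.broadcast_shapes(*input_shapes)
    results = []
    for v, dtype in zip(values, output_dtypes):
        v = np.asarray(v, dtype=dtype)  # cast to the output dtype first
        if v.shape != out_shape:
            v = np.broadcast_to(v, out_shape)  # then broadcast only if needed
        results.append(v)
    return results if isinstance(value, list) else results[0]


# E.g. replacing the output of x * 0 by a constant zero of the right shape and dtype.
zeros = broadcast_like_elemwise_np(0, [(2, 1), (1, 3)], ["float32"])
print(zeros.shape, zeros.dtype)  # (2, 3) float32
```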
- pytensor.tensor.rewriting.basic.broadcasted_by(x, y)
Check whether x would be broadcasted by y in an Elemwise operation.
- Parameters:
x (TensorVariable) – The variable that may be broadcasted by y
y (TensorVariable) – The variable that may broadcast x
- Returns:
broadcasted_by
- Return type:
bool
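Below is a plain-Python sketch of the check. It assumes the answer can be decided from the static broadcastable flags of the two types, where True means the axis has static length 1. The function broadcasted_by_flags is a hypothetical stand-in that takes flag tuples rather than TensorVariables. Inputs are aligned on their trailing axes, as in NumPy broadcasting.

```python
def broadcasted_by_flags(bx, by):
    """Would an input with flags bx be broadcast by one with flags by in an Elemwise?

    Hypothetical helper: bx and by are tuples of bools, True meaning the axis
    is statically known to have length 1.
    """
    if len(bx) < len(by):
        # x gets padded with leading length-1 axes; treat that conservatively as broadcast.
        return True
    # Only the trailing axes of x line up with the axes of y.
    bx = bx[len(bx) - len(by):]
    return any(x_bc and not y_bc for x_bc, y_bc in zip(bx, by))


print(broadcasted_by_flags((True, False), (False, False)))  # True
print(broadcasted_by_flags((False, False), (True, False)))  # False
```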
- pytensor.tensor.rewriting.basic.equivalent_up_to_constant_casting(a, b)
Return True if a and b are equivalent up to constant casting.
- pytensor.tensor.rewriting.basic.get_simplified_broadcast_shape(first, *others, fgraph)
Per-axis fold of first with others, prioritizing non-broadcastable lengths.
The shape entry from first wins on every axis where first is not broadcastable; only axes where first is broadcastable look to others for a non-broadcastable length to substitute. Callers therefore pass the “preferred shape source” first: Elemwise’s standard broadcast joining is commutative, but the resulting symbolic shape is not, and we want the simplest, most static expression.
Assumes all inputs have the same ndim (the Elemwise contract).
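The fold can be sketched in plain Python. A shape is treated as a tuple whose entries are either ints or symbolic names (strings), and an entry counts as broadcastable only when it is the literal 1. The following are all assumptions of this sketch: the function name simplified_broadcast_shape, the string representation of symbolic lengths, and the missing fgraph argument. The sketch also does not check that the non-broadcastable entries it skips agree with the one it keeps.

```python
def simplified_broadcast_shape(first, *others):
    """Hypothetical per-axis fold: first's entries win unless they are the literal 1.

    Entries are ints or symbolic names (str); only the literal 1 counts as
    broadcastable. No consistency check is made between competing entries.
    """
    if any(len(shape) != len(first) for shape in others):
        raise ValueError("all shapes must have the same ndim (the Elemwise contract)")
    result = list(first)
    for axis, entry in enumerate(first):
        if entry != 1:
            continue  # first is not broadcastable here, so its entry wins
        for shape in others:
            if shape[axis] != 1:
                result[axis] = shape[axis]  # first non-broadcastable substitute
                break
    return tuple(result)


# At runtime both results describe the same shape (n must equal 5),
# but the chosen symbolic expressions differ with argument order.
print(simplified_broadcast_shape((5, 1), ("n", "m")))  # (5, 'm')
print(simplified_broadcast_shape(("n", "m"), (5, 1)))  # ('n', 'm')
```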