The Type class is used to provide “static” information about the types of Variables in a PyTensor graph. This information is used for graph rewrites and compilation to languages with typing that’s stricter than Python’s.

The types handled by PyTensor naturally overlap a lot with NumPy, but they also differ from it in some very important ways. In what follows, we use TensorType to illustrate some important concepts and functionality regarding Type, because it’s the most common and feature-rich subclass of Type. Just be aware that the same high-level concepts apply to any other graph objects modeled by a Type subclass.

The TensorType

PyTensor has a Type subclass for tensors/arrays called TensorType. It broadly represents a type for tensors, but, more specifically, all of its computations are performed using instances of the numpy.ndarray class, so it effectively represents the same objects as numpy.ndarray.

The expression TensorType(dtype, shape)() will construct a symbolic TensorVariable instance (a subclass of Variable), like numpy.ndarray(shape, dtype) will construct a numpy.ndarray instance.

Notice the extra parentheses in the TensorType example. They are necessary because TensorType(dtype, shape) only constructs an instance of TensorType; a Variable instance is then constructed from that instance using TensorType.__call__(). Just remember that Type objects are not Python types/classes; they’re instances of the Python class Type. Their purpose is effectively the same, though: Types provide high-level typing information and construct instances of the high-level types they model. Python types/classes do this for the Python VM, and PyTensor Types do this for PyTensor’s effective “VM”.

In relation to NumPy, the important difference is that PyTensor works at the symbolic level. Because of that, there are no concrete array instances whose dtype and shape attributes PyTensor could inspect to get the data type or shape of a symbolic variable. PyTensor needs static class/type-level information to serve that purpose, and it can’t use the numpy.ndarray class itself, because that class doesn’t have a fixed data type or shape.

By analogy with NumPy, we could imagine TensorType as a numpy.ndarray class constructor like the following:

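import numpy as np

def NdarrayType(dtype, shape):
    # Hypothetical illustration only; not part of PyTensor or NumPy
    fixed_dtype = np.dtype(dtype)
    fixed_shape = tuple(shape)

    class fixed_dtype_shape_ndarray(np.ndarray):
        # Class-level ("static") information, available without an instance
        dtype = fixed_dtype
        shape = fixed_shape

        def __new__(cls):
            # Every instance is created with the fixed dtype and shape
            return super().__new__(cls, fixed_shape, fixed_dtype)

    return fixed_dtype_shape_ndarray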

This hypothetical NdarrayType would construct numpy.ndarray subclasses that produce instances with fixed data types and shapes. Also, the subclasses created by this class constructor would provide data type and shape information about the instances they produce without ever needing to construct an actual instance (e.g. one can simply inspect the class-level shape and dtype members for that information).

TensorTypes provide a way to carry around the same array information at the type level, but they also support comparisons between different types and conversions from one type to another.

For instance, TensorTypes allow for _partial_ shape information. In other words, the shape values for some, or all, dimensions may be unspecified. The only fixed requirement is that the _number_ of dimensions be given (i.e. the length of the shape tuple). To encode partial shape information, TensorType allows entries of its shape argument to be None.

To illustrate, TensorType("float64", (2, None)) could represent an array of shape (2, 0), (2, 1), etc. This flexibility raises some questions about how TensorTypes should be compared.
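To make the matching rule concrete, here is a minimal pure-Python sketch (not PyTensor code) of how a partial static shape such as (2, None) could be checked against the concrete shape of an array. The helper name shape_matches is hypothetical; it only models the rule described above: the number of dimensions must agree, and None accepts any length.

```python
def shape_matches(static_shape, concrete_shape):
    """Return True if a concrete array shape fits a partial static shape.

    Hypothetical helper modelling the rule above: the number of dimensions
    must agree, and a None entry accepts any length.
    """
    if len(static_shape) != len(concrete_shape):
        return False
    return all(s is None or s == c for s, c in zip(static_shape, concrete_shape))


static = (2, None)
print(shape_matches(static, (2, 0)))     # True
print(shape_matches(static, (2, 1)))     # True
print(shape_matches(static, (3, 1)))     # False: first dimension must be 2
print(shape_matches(static, (2, 1, 1)))  # False: wrong number of dimensions
```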

For example, let’s say we have two Variables with the following TensorTypes:

>>> from pytensor.tensor.type import TensorType
>>> v1 = TensorType("float64", (2, None))()
>>> v1.type
TensorType(float64, (2, ?))
>>> v2 = TensorType("float64", (2, 1))()
>>> v2.type
TensorType(float64, (2, 1))

If we ever wanted to replace v1 in a PyTensor graph with v2, we would first need to check that they’re “compatible”. This could be done by noticing that their shapes match everywhere except on the second dimension, where v1 has the shape value None and v2 has a 1. Since None indicates “any” shape value, the two are “compatible” in some sense.

The “compatibility” we’re describing here is really that v1’s Type represents a larger set of arrays, and v2’s represents a much more specific subset, but both belong to the same set of array types.

Type provides a generic interface for these kinds of comparisons with its Type.in_same_class() and Type.is_super() methods. These type-comparison methods are in turn used by the Variable conversion methods to “narrow” the type information received at different stages of graph construction and rewriting.

For example:

>>> v1.type.in_same_class(v2.type)
False

This result is due to the definition of “type class” used by TensorType. Its definition is based on the broadcastable dimensions (i.e. 1s) in the available static shape information. See the docstring for TensorType.in_same_class() for more information.

>>> v1.type.is_super(v2.type)
True

This result is due to the fact that v1.type models a superset of the types that v2.type models: v2.type is a type for arrays with the specific shape (2, 1), whereas v1.type is a type for _all_ arrays with shape (2, N) for any N, and (2, 1) is just one such shape.

This relation is used to “filter” Variables through specific Types in order to generate a new Variable that’s compatible with both. This “filtering” is an important step in the node replacement process during graph rewriting, for instance.

>>> v1.type.filter_variable(v2)
<TensorType(float64, (2, 1))>

“Filtering” returned a variable of the same Type as v2, because v2’s Type is more specific/informative than v1’s, and both are compatible.

>>> v3 = v2.type.filter_variable(v1)
>>> v3
SpecifyShape.0
>>> import pytensor
>>> pytensor.dprint(v3, print_type=True)
SpecifyShape [id A] <TensorType(float64, (2, 1))>
 |<TensorType(float64, (2, ?))> [id B] <TensorType(float64, (2, ?))>
 |TensorConstant{2} [id C] <TensorType(int8, ())>
 |TensorConstant{1} [id D] <TensorType(int8, ())>

Performing this in the opposite direction returned the output of a SpecifyShape Op. This SpecifyShape node takes v1 and the constants 2 and 1 (v2’s static shape) as inputs, and produces a new Variable whose Type is compatible with both v1 and v2.


Depending on the PyTensor version, the Type of v3 may not carry the full static shape (2, 1) (i.e. v2’s shape), because the static shape information feature is still under development.
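The narrowing that filter_variable performs can be pictured as merging two compatible partial shapes into the most specific shape consistent with both. The pure-Python sketch below models only that shape bookkeeping; the name narrow_shape is hypothetical, and in the example above PyTensor records the narrowed shape through a SpecifyShape node rather than through a helper like this.

```python
def narrow_shape(shape1, shape2):
    """Merge two partial static shapes into the most specific compatible one.

    Hypothetical helper: None means "unknown", and known lengths must agree.
    """
    if len(shape1) != len(shape2):
        raise TypeError(f"Incompatible number of dimensions: {shape1} vs {shape2}")
    merged = []
    for s1, s2 in zip(shape1, shape2):
        if s1 is None:
            merged.append(s2)  # keep whatever the other shape knows
        elif s2 is None or s1 == s2:
            merged.append(s1)
        else:
            raise TypeError(f"Incompatible static shapes: {shape1} vs {shape2}")
    return tuple(merged)


print(narrow_shape((2, None), (2, 1)))  # (2, 1)
print(narrow_shape((2, 1), (2, None)))  # (2, 1), in either order
```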

It’s important to keep these special type comparisons in mind when developing custom Ops and graph rewrites in PyTensor, because naive comparisons like v1.type == v2.type may unnecessarily restrict logic and prevent more refined type information from propagating throughout a graph. They may not cause errors, but they could prevent PyTensor from performing at its best.
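The difference between these comparisons can be modelled in plain Python. The sketch below is not PyTensor’s implementation; it represents a tensor type as a hypothetical (dtype, shape) tuple and roughly follows the rules described above: is_super accepts a more specific shape, in_same_class compares the dtype and the positions of static 1s (broadcastable dimensions), and == demands an exact match.

```python
def is_super(t1, t2):
    """True if t1 = (dtype, shape) admits every array that t2 admits."""
    (dtype1, shape1), (dtype2, shape2) = t1, t2
    return (
        dtype1 == dtype2
        and len(shape1) == len(shape2)
        and all(s1 is None or s1 == s2 for s1, s2 in zip(shape1, shape2))
    )


def in_same_class(t1, t2):
    """True if both types share a dtype and the same pattern of static 1s."""
    (dtype1, shape1), (dtype2, shape2) = t1, t2
    broadcastable1 = tuple(s == 1 for s in shape1)
    broadcastable2 = tuple(s == 1 for s in shape2)
    return dtype1 == dtype2 and broadcastable1 == broadcastable2


t_v1 = ("float64", (2, None))
t_v2 = ("float64", (2, 1))

print(t_v1 == t_v2)               # False: a naive check rejects the pair
print(is_super(t_v1, t_v2))       # True: v1's type can be narrowed to v2's
print(is_super(t_v2, t_v1))       # False: not the other way around
print(in_same_class(t_v1, t_v2))  # False: only v2 has a static 1
```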

Type’s contract

In PyTensor’s framework, a Type is any object which defines the following methods. To obtain the default methods described below, the Type should be an instance of Type or of a subclass of Type. If you write all the methods yourself, you need not use an instance of Type.

Methods with default arguments must be defined with the same signature, i.e. the same default argument names and values. If you wish to add extra arguments to any of these methods, these extra arguments must have default values.

class pytensor.graph.type.Type

Interface specification for variable type instances.

A Type instance is mainly responsible for two things:

  • creating Variable instances (conventionally, __call__ does this), and

  • filtering a value assigned to a Variable so that the value conforms to restrictions imposed by the type (also known as casting, this is done by filter).

in_same_class(otype: Type) → bool | None

Determine if another Type represents a subset from the same “class” of types represented by self.

A “class” of types could be something like “float64 tensors with four dimensions”. One Type could represent a set containing only a type for “float64 tensors with shape (1, 2, 3, 4)” and another the set of “float64 tensors with shape (1, x, x, x)” for all suitable “x”.

It’s up to each subclass of Type to determine to which “classes” of types this method applies.

The default implementation assumes that all “classes” have only one unique element (i.e. it uses self.__eq__).

is_super(otype: Type) → bool | None

Determine if self is a supertype of otype.

This method effectively implements the type relation >.

In general, t1.is_super(t2) == True implies that t1 can be replaced with t2.

See Type.in_same_class.

Returns:

None if the type relation cannot be applied/determined.

filter_inplace(value, storage, strict=False, allow_downcast=None)

If filter_inplace is defined, it will be called instead of filter(). This is to allow reusing the old allocated memory. This was used only when new data was transferred to a shared variable on a GPU.

storage will be the old value (e.g. the old ndarray).


is_valid_value(value)

Returns True iff the value is compatible with the Type. If filter(value, strict=True) does not raise an exception, the value is compatible with the Type.

Default: True iff filter(value, strict=True) does not raise an exception.
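The default relationship between is_valid_value and filter can be sketched in plain Python. Below, strict_float_filter is a hypothetical stand-in for a Type’s filter method, and is_valid_value mirrors the default described above: a value is valid exactly when a strict call to filter succeeds.

```python
def strict_float_filter(value, strict=False, allow_downcast=None):
    """Hypothetical filter: in strict mode, only genuine floats are accepted."""
    if strict and not isinstance(value, float):
        raise TypeError(f"Expected a float, got {type(value).__name__}")
    return float(value)


def is_valid_value(filter_fn, value):
    """Valid iff filter_fn(value, strict=True) does not raise an exception."""
    try:
        filter_fn(value, strict=True)
    except Exception:
        return False
    return True


print(is_valid_value(strict_float_filter, 1.5))  # True
print(is_valid_value(strict_float_filter, 3))    # False: int rejected in strict mode
```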

values_eq(a, b)

Returns True iff a and b are equal.

Default: a == b

values_eq_approx(a, b)

Returns True iff a and b are approximately equal, for a definition of “approximately” which varies from Type to Type.

Default: values_eq(a, b)


make_variable(name=None)

Makes a Variable of this Type with the specified name, if name is not None. If name is None, then the Variable does not have a name. The Variable will have its type field set to the Type object.

Default: there is a generic definition of this in Type. The Variable’s type will be the object that defines this method (in other words, self).


__call__(name=None)

Syntactic shortcut to make_variable.

Default: make_variable


__eq__(other)

Used to compare Type instances themselves.

Default: object.__eq__


__hash__()

Types should not be mutable, so it should be OK to define a hash function. Typically this function should hash all of the terms involved in __eq__.

Default: id(self)

clone(*args, **kwargs) → Type

Clone a copy of this type with the given arguments/keyword values, if any.

class pytensor.tensor.type.TensorType(dtype: str | numpy.dtype, shape: Optional[Iterable[bool | int | None]] = None, name: Optional[str] = None, broadcastable: Optional[Iterable[bool]] = None)

Symbolic Type representing numpy.ndarrays.

may_share_memory(a, b)

Optional to run, but mandatory for DebugMode. Returns True if the Python objects a and b could share memory, and False otherwise. DebugMode uses it to detect Ops that did not declare memory aliasing between variables. It can be a static method. Implementing it is highly recommended, and it is mandatory for Types in PyTensor itself, because our buildbot runs in DebugMode.
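For a Type whose values are NumPy arrays, a may_share_memory implementation could plausibly delegate to numpy.may_share_memory, which performs a conservative, bounds-based check. The sketch below assumes exactly that; it is not necessarily how any particular PyTensor Type implements the method.

```python
import numpy as np


def may_share_memory(a, b):
    """Sketch for ndarray-valued Types: could a and b share memory?"""
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        # Bounds-based check: may report True for overlapping ranges that share
        # no element, but never misses real sharing
        return bool(np.may_share_memory(a, b))
    # Assumption: non-array values are never considered aliases
    return False


x = np.arange(6.0)
print(may_share_memory(x, x[1:]))     # True: a slice is a view of x
print(may_share_memory(x, x.copy()))  # False: a copy has its own buffer
```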


get_shape_info(obj)

Optional. Only needed to profile the memory of this Type of object.

Return the information needed to compute the memory size of obj.

The memory size is only the data, so this excludes the container. For an ndarray, this is the data, but not the ndarray object and other data structures such as shape and strides.

get_shape_info() and get_size() work in tandem for the memory profiler.

get_shape_info() is called during the execution of the function, so it should not be too slow.

get_size() will be called on the output of this function when printing the memory profile.


Parameters:

obj – The object that this Type represents during execution


Returns:

Python object that self.get_size() understands


get_size(shape_info)

Number of bytes taken by the object represented by shape_info.

Optional. Only needed to profile the memory of this Type of object.

Parameters:

shape_info – the output of the call to get_shape_info


Returns:

the number of bytes taken by the object described by shape_info.
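For an ndarray-valued Type, the two profiling methods might look like the sketch below. It is written as standalone functions whose names mirror the methods above, but the bodies are an assumption rather than PyTensor’s code: get_shape_info records only cheap metadata during execution, and get_size later turns it into a byte count for the data buffer alone.

```python
import math

import numpy as np


def get_shape_info(obj):
    # Called while the function runs, so keep it cheap: just record metadata
    return obj.shape, obj.dtype


def get_size(shape_info):
    shape, dtype = shape_info
    # Data bytes only: the ndarray object, shape and strides are excluded
    return math.prod(shape) * np.dtype(dtype).itemsize


a = np.zeros((3, 4), dtype="float64")
print(get_size(get_shape_info(a)))  # 96 (12 elements * 8 bytes)
```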

Additional definitions#

For certain mechanisms, you can register functions and other such things to plug your type into PyTensor’s mechanisms. These are optional, but they allow people to use your type with familiar interfaces.


To plug in additional options for the transfer target, define a function which takes a PyTensor variable and a target argument and returns either a new transferred variable (which can be the same as the input if no transfer is necessary) or None if the transfer can’t be done.

Then register that function by calling register_transfer() with it as argument.

An example

We are going to base a Type, DoubleType, on Python’s float. We must define filter, and we will also define values_eq_approx.


# Note that we shadow Python's built-in function ``filter`` with this
# definition.
def filter(x, strict=False, allow_downcast=None):
    if strict:
        if isinstance(x, float):
            return x
        raise TypeError('Expected a float!')
    elif allow_downcast:
        return float(x)
    else:  # Covers both the False and None cases.
        x_float = float(x)
        if x_float == x:
            return x_float
        raise TypeError('The double type cannot accurately represent '
                        f'value {x} (of type {type(x)}): you must explicitly '
                        'allow downcasting if you want to do this.')

If strict is True and x is a float, we return x unchanged. If strict is True and x is not a float (for example, x could easily be an int), then it is incompatible with our Type and we must raise an exception.

If strict is False, then we are allowed to cast x to a float, so if x is an int we return an equivalent float. However, if this cast causes a loss of precision (x != float(x)) and allow_downcast is not True, then we raise an exception. Note that here we decided that the default behavior of our type (when allow_downcast is set to None) would be the same as when allow_downcast is False, i.e. no precision loss is allowed.
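The precision-loss check in filter relies on the fact that Python compares int and float values exactly, so x_float == x fails whenever float(x) had to round x. A quick illustration with plain Python numbers:

```python
x = 2**53 + 1       # exactly representable as a Python int
x_float = float(x)  # rounds to the nearest double, which is 2**53

print(x_float == x)      # False: precision was lost, so filter raises unless allow_downcast
print(x_float == 2**53)  # True
print(float(3) == 3)     # True: 3 converts exactly, so filter returns 3.0
```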


def values_eq_approx(x, y, tolerance=1e-4):
    if x == y:  # also avoids 0 / 0 when both values are zero
        return True
    return abs(x - y) / (abs(x) + abs(y)) < tolerance

The second method we define is values_eq_approx. This method allows approximate comparison between two values respecting our Type’s constraints. A rewrite may change the computation graph in such a way that it produces slightly different values, for example because of numerical instability such as rounding errors at the end of the mantissa. For instance, a + a + a + a + a + a might not produce exactly the same output as 6 * a (try with a=0.1), but with values_eq_approx we do not necessarily mind.
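Trying the suggested a=0.1 case in plain Python shows why an approximate comparison is useful. The snippet below computes the same relative difference that values_eq_approx uses, and, as an alternative design, also shows the standard library’s math.isclose, which applies a similar (but not identical) relative-tolerance test.

```python
import math

a = 0.1
repeated = a + a + a + a + a + a  # six additions, rounding at each step
scaled = 6 * a                    # a single multiplication

# Exact equality fails: with IEEE 754 doubles the results differ in the last bit
print(repeated == scaled)  # False

# The relative difference used by values_eq_approx is tiny
rel_diff = abs(repeated - scaled) / (abs(repeated) + abs(scaled))
print(rel_diff < 1e-4)  # True

# math.isclose offers a similar, ready-made relative check
print(math.isclose(repeated, scaled, rel_tol=1e-4))  # True
```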

We added an extra tolerance argument here. Since this argument is not part of the API, it must have a default value, which we chose to be 1e-4.


values_eq is never actually used by PyTensor, but it might be used internally in the future. Equality testing in DebugMode is done using values_eq_approx.

Putting them together

What we want is an object that respects the aforementioned contract. Recall that Type defines default implementations for all required methods of the interface, except filter.

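from pytensor.graph.type import Type

class DoubleType(Type):

    def filter(self, x, strict=False, allow_downcast=None):
        # Delegate to the module-level ``filter`` defined above.
        return filter(x, strict=strict, allow_downcast=allow_downcast)

    def values_eq_approx(self, x, y, tolerance=1e-4):
        # Delegate to the module-level ``values_eq_approx`` defined above.
        return values_eq_approx(x, y, tolerance=tolerance)

double = DoubleType()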

double is then an instance of DoubleType, which in turn is a subclass of Type.

There is a small issue with our DoubleType: all instances of DoubleType are technically the same Type; however, different DoubleType instances do not compare the same:

>>> double1 = DoubleType()
>>> double2 = DoubleType()
>>> double1 == double2
False

PyTensor compares Types using == to see if they are the same. This happens in DebugMode. Also, Ops can (and should) ensure that their inputs have the expected Type by checking something like x.type.is_super(lvector) or x.type.in_same_class(lvector).

There are several ways to make sure that equality testing works properly:

  1. Define DoubleType.__eq__() so that instances of type DoubleType are equal. For example:

    def __eq__(self, other):
        return type(self) == type(other)

  2. Override DoubleType.__new__() to always return the same instance.

  3. Hide the DoubleType class and only advertise a single instance of it.
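Options 1 and 2 can be sketched with plain Python stand-in classes; the class names below are hypothetical and do not subclass PyTensor’s Type. One detail worth knowing: in Python, a class that defines __eq__ without __hash__ becomes unhashable, so option 1 should come with a matching __hash__, consistent with the __hash__ contract described earlier.

```python
class EqualByClass:
    """Option 1: every instance of the same class compares equal."""

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        # Defining __eq__ alone would set __hash__ to None; hash what __eq__ uses
        return hash(type(self))


class SingletonByNew:
    """Option 2: __new__ always returns the same instance."""

    def __new__(cls):
        # Look the instance up on this exact class so subclasses get their own
        if cls.__dict__.get("_instance") is None:
            cls._instance = super().__new__(cls)
        return cls._instance


print(EqualByClass() == EqualByClass())       # True
print(len({EqualByClass(), EqualByClass()}))  # 1
print(SingletonByNew() is SingletonByNew())   # True
```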

Here we prefer the final option because it is the simplest, although Ops in the PyTensor code base often define the __eq__() method.

Untangling some concepts

Initially, confusion is common about what an instance of Type is, versus a subclass of Type or an instance of Variable. Some of this confusion is syntactic. A Type is any object which has fields corresponding to the functions defined above. The Type class provides sensible defaults for all of them except Type.filter(), so when defining new Types it is natural to subclass Type. Therefore, we often end up with Type subclasses, and it can be confusing what these represent semantically. Here is an attempt to clear up the confusion:

  • An instance of Type (or an instance of a subclass) is a set of constraints on real data. It is akin to a primitive type in C or a class in C++. It is a static annotation.

  • An instance of Variable symbolizes data nodes in a data flow graph. If you were to parse the C expression int x;, int would be a Type instance and x would be a Variable instance of that Type instance. If you were to parse the C expression c = a + b;, a, b and c would all be Variable instances.

  • A subclass of Type is a way of implementing a set of Type instances that share structural similarities. In the DoubleType example that we are doing, there is actually only one Type in that set, therefore the subclass does not represent anything that one of its instances does not. In this case it is a singleton, a set with one element. However, the TensorType class in PyTensor (which is a subclass of Type) represents a set of types of tensors parametrized by their data type or number of dimensions. We could say that subclassing Type builds a hierarchy of Types which is based upon structural similarity rather than compatibility.