Co-Author: Andrew James
2022 was an exciting year for the PyTorch ecosystem. The PyTorch project
joining the Linux Foundation was a major
milestone, and PyTorch 2.0 was announced with loads
of informative talks from the maintainers
explaining new features. Additionally, there was marked progress in areas such
as sparse tensors, JAX-like transformations were released in PyTorch, and
TorchVision announced a new Transforms API.
In this post, we will review how Quansight has been working hand-in-hand with
the PyTorch team at Meta to enhance PyTorch in all of the areas above, and
beyond.
PyTorch 2.0
The biggest PyTorch announcement in 2022 was, without a doubt, the release of
PyTorch 2.0. The main feature this release adds is a more powerful and
user-friendly compiler for PyTorch programs. This compiler is invoked via the
function/decorator torch.compile, which wraps a PyTorch program or model and
turns it into a compiled version. What makes this different from other ML
frameworks, among other things, is that torch.compile works with arbitrary
Python programs, even in the presence of data-dependent control flow. This way,
the user gets all the speed of a compiled framework with the flexibility that
characterizes PyTorch. Better yet, for this reason, this release has the
remarkable property of being fully backwards compatible.
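Here is a minimal sketch of what this looks like in practice (the function and
shapes are made up for illustration):

```python
import torch

# torch.compile handles arbitrary Python, including branches whose outcome
# depends on the data itself; where it cannot trace, it falls back to eager.
@torch.compile
def f(x):
    if x.sum() > 0:  # data-dependent control flow
        return torch.sin(x)
    return torch.cos(x)

print(f(torch.randn(8)))
```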
Of course, features of this caliber are composed of a number of subsystems.
Quansight helped in the development of a few of these.
Decompositions: PrimTorch
One of the first steps towards building any compiler that is worth its salt is
to desugar the language in question. In the case of PyTorch, this meant
implementing PyTorch operations in terms of simpler operations. This seems like
a simple task at first sight, until you realize the number of subtle features
that PyTorch supports, like type promotion, broadcasting, the out kwarg,
in-place and view operations, and more. For more about this, see this post on
Tracing with Primitives.
Quansight engineers helped design and implement many of these, as well as helped
build the testing infrastructure to make sure they are all correct.
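To give a flavor of the idea, here is a hypothetical decomposition in the
PrimTorch spirit; addcmul is chosen purely for illustration, and the real
decompositions live in PyTorch's internal torch._refs and torch._decomp
modules:

```python
import torch

# A composite op expressed via simpler primitives. Broadcasting and type
# promotion are handled by the primitive ops themselves, which is part of
# what makes writing (and testing) these decompositions subtle.
def addcmul_decomp(input, tensor1, tensor2, *, value=1):
    return torch.add(input, torch.mul(tensor1, tensor2) * value)

a, b = torch.randn(3, 1), torch.randn(1, 4)
c = torch.randn(3, 4)
assert torch.allclose(addcmul_decomp(c, a, b, value=0.5),
                      torch.addcmul(c, a, b, value=0.5))
```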
Compiler: TorchInductor
Another big part of a compiler is… well, the compiler itself. The PyTorch 2.0
compiler is called TorchInductor, and it takes as input an FX graph and returns
a compiled and optimized function that can be executed from Python. For CPU
tensors, it compiles to C++ with OpenMP and SIMD instructions. On CUDA, it
compiles to Triton.
We were involved in the implementation of a number of optimization passes to
generate faster code.
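From user code, this amounts to choosing the backend in torch.compile;
"inductor" is the default, so naming it here just makes the choice explicit:

```python
import torch

def f(x):
    return torch.relu(x) + 1.0

# Compile f with the TorchInductor backend (the default one).
compiled_f = torch.compile(f, backend="inductor")
print(compiled_f(torch.randn(4, 4)))
```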
Dynamic Shapes
Another feature that PyTorch 2.0 brings to the table is dynamic shapes. Consider
a model to which you want to feed batches of different sizes because, perhaps,
your data comes in from a stream. You would not want to recompile the model
repeatedly, as a program that is optimized to run on inputs of shape
[2, 512, 512] will probably do just fine on inputs of shape [3, 512, 512].
Dynamic shapes allow for tracing a program while also tracking the size of some
dimensions symbolically. For example, we know that if the inputs to matmul have
shapes [m, k] and [k, n], the output will have shape [m, n]. By tracking these
shapes symbolically, we avoid plenty of recompilations on language models and
we are able to export generic models that work for arbitrary sequence lengths
and batch sizes.
We were—and continue to be—involved in fixing some correctness issues
related to dynamic shapes across the PyTorch 2.0 stack.
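As a sketch, symbolic tracing of sizes can be requested explicitly via the
dynamic flag of torch.compile:

```python
import torch

@torch.compile(dynamic=True)  # trace sizes symbolically
def mm(a, b):
    return a @ b

# Varying batch sizes reuse the same compiled code instead of recompiling.
for m in (2, 3, 5):
    mm(torch.randn(m, 512), torch.randn(512, 128))
```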
NumPy With PyTorch as Its Backend
The PyTorch 2.0 tracer currently only knows how to trace Python programs that
have PyTorch calls in them. But what if you have NumPy calls as well? Better
yet, what if you want to compile and run your good old NumPy model with
autograd and CUDA acceleration? In PyTorch 2.X you will be able to do that
right out of the box without having to change a single line of code! Just stamp
the entry point of your model with @torch.compile and you're ready to go.
We are in the process of implementing an interface that translates NumPy to
PyTorch. You could think of it as having PyTorch as a backend for your NumPy
computations without having to change one file in your project!
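A sketch of how we expect this to look once the interface lands (the exact
release in which it ships may vary):

```python
import numpy as np
import torch

# A pure-NumPy function: under torch.compile, the NumPy calls are traced
# into PyTorch operations, so the whole thing can be compiled end to end.
@torch.compile
def numpy_fn(x, y):
    return np.sum(x * y, axis=0)

x = np.random.randn(64, 64).astype(np.float32)
print(numpy_fn(x, x).shape)
```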
SciPy and scikit-learn With PyTorch as Their Backend: Python Array API
In 2021-22, Quansight’s team did a fair amount of work to implement the Python Array
API within PyTorch. For 2023, we are taking this one step
further.
In parallel, a number of PyData maintainers within Quansight have started
migrating parts of their codebases to use the Python Array API. Together, these
two efforts will allow users to run their NumPy, SciPy, or scikit-learn
programs with PyTorch as their backend. In particular, you could create a
compiled CUDA version of your favorite scikit-learn algorithms and even
differentiate through them! What a time to be alive.
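A sketch of what array-API-agnostic library code looks like: the array itself
tells us which namespace to use, so the same function runs on NumPy arrays,
PyTorch tensors, and so on. In practice, libraries often go through the
array-api-compat helper package, as below:

```python
from array_api_compat import array_namespace  # pip install array-api-compat
import torch

def rms(x):
    # Fetch the Array-API namespace for x instead of importing numpy directly.
    xp = array_namespace(x)
    return xp.sqrt(xp.mean(x * x))

print(rms(torch.randn(10)))
```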
Additional Compiler Backends
It should be clear by now that the main feature of PyTorch 2.0 is its compiler.
The compiler itself is built to be flexible and support multiple backends.
Options are limited at launch, but we have been working to deliver alternatives.
There is ongoing work towards delivering a transpiler that translates a PyTorch
program written in Python into a PyTorch program written in C++. This means that
you can get all the performance improvements of writing your model using
PyTorch’s C++ API, without having to write one line of C++!
Further down the road, this will be integrated within TorchInductor, and we will
intermix Triton kernels with hand-written CUDA fallbacks for those operations
that are not supported by the compiler yet. This compiler backend will allow for
compiled functions to stay fully in C++ without ever having to come back to
sluggish Python-land.
Scientific PyTorch
PyTorch is designed for and primarily targeted at deep learning applications,
but scientific communities are also part of its vast user base. The wealth of
numerical tools, autograd, and interfaces to powerful computing hardware
backends such as CUDA make it an attractive choice for those working in these
domains.
Our team comprises SciPy community veterans, and we have a wealth of expertise
across various scientific domains. Quansight has been instrumental in expanding
PyTorch to cover functionality from popular SciPy modules (in some cases going
beyond SciPy), and we continued to do so in 2022.
Sparse Tensors
The torch.sparse team has focused on enabling key workflows and developing a
path toward stability. We have expanded support for data types and layouts in
sparse-sparse and sparse-dense matrix multiplication, which has enabled PyTorch
Linear layers to carry sparse weights. Specialized kernels have been added using
Triton to support half-precision data types while
continuing to utilize the newest features from NVIDIA’s cuSPARSE as they become
available. While matrix multiplication represents the area where sparsity offers
the largest potential performance gains, performance improvements have also been
realized for point-wise multiplications, layout validations, and masking
operations. Autograd integration has been steadily improving with more and more
operators supporting sparse layouts on the backward path.
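For instance, here is a sketch of the kind of workflow this enables; the
shapes and sparsity level are made up for illustration:

```python
import torch

# A mostly-zero weight matrix stored in compressed sparse row (CSR) layout.
w = torch.randn(128, 64)
w[w.abs() < 1.0] = 0            # roughly 70% zeros
w_csr = w.to_sparse_csr()

x = torch.randn(64, 32)
y = w_csr @ x                   # sparse @ dense -> dense
print(y.shape)                  # torch.Size([128, 32])
```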
The next year is set to be an exciting one for sparsity in PyTorch. Experiments
with new backends are showing promising performance at sparsity levels
realistic for deep learning, with tensors that are 50-80% zero, where
previously gains were only seen with tensors that are 90% zero or more.
Integration with torch.compile is also high on the list of priorities, as that
technology stack offers a route to solving problems like determining the best
layouts for outputs in a sequence of operations. First, though, we need to get
sparse to the “stable feature” designation.
JAX-like Transformations: torch.func
One of PyTorch’s big releases this year was that of torch.func (previously
functorch) being merged into core. This module provides higher-order functions
that allow you to vectorize your computation by calling vmap, to compute
forward and backward vector products via jvp and vjp, or to compute per-sample
gradients, as popularized by JAX.
For this to become a reality, PyTorch maintainers needed to implement these
operations for each and every one of the operators that PyTorch has. That is
well over 2,600 operators. To make this task bearable, Quansight engineers
helped implement many of the PyTorch operators in a composite way. A composite
operator then inherits the implementations of vmap, vjp, and all the rest from
the operators that compose it. This is a similar idea to the PrimTorch approach
we discussed above. Quansight engineers also helped implement many other
features now available in this shiny new PyTorch module.
Complex Half: torch.complex32
Once you have complex numbers, as we helped deliver in 2021, and accelerators,
the next thing you want is complex numbers in half precision. The issue, as
always, is that PyTorch has thousands of operations. Each of them is
implemented on CPU and CUDA. And, for each device, implemented for each
datatype. And each of those contributes to the size of the executable. And each
of those needs to be compiled every time you compile PyTorch. And each of those
takes a little bit of time to load whenever you run import torch. And at some
point you start wondering whether it is worthwhile to deliver half-precision
complex numbers at all.
To solve this issue, Meta engineers put together a JIT compiler—yes, another
compiler, different from the two we discussed above—that compiles
torch.complex32 versions of operators as they are first used in the program.
This gets us fast compilation times and smaller binaries, at the cost of a tiny
perf hit the first time we execute a function. Quansight engineers then helped
expand the torch.complex32 coverage throughout PyTorch.
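For example (a CUDA device is assumed here, since chalf support is primarily
aimed at accelerators):

```python
import torch

# The complex32 kernels for these ops are JIT-compiled on first use,
# trading a one-time cost for smaller binaries.
a = torch.randn(1024, dtype=torch.complex32, device="cuda")
b = torch.randn(1024, dtype=torch.complex32, device="cuda")
print((a * b).dtype)  # torch.complex32
```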
torch.signal.windows
With the help of the OSS community, Quansight engineers have helped design and
implement SciPy’s signal.windows
module in PyTorch. This module is of
particular interest for digital signal processing, and it couples particularly
well with the previous work of Quansight engineers in bringing
complex numbers and
FFTs to PyTorch.
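For example, a window from the new module composes naturally with torch.fft:

```python
import torch

w = torch.signal.windows.hann(256)   # mirrors scipy.signal.windows.hann

x = torch.randn(256)
spectrum = torch.fft.rfft(x * w)     # windowed spectrum of a signal
```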
Performance
One of the areas of specialization of Quansight engineers is code optimization.
A faster PyTorch means less money spent on cloud services, less energy consumed
worldwide, and less time waiting for your model to finish training. Win, win,
win! A few highlights from 2022 include the optimization passes contributed to
TorchInductor and the sparse kernel improvements discussed above.
Maintainability
It is great to talk about the flashy new features and performance improvements
we have worked on, but every project needs to be maintained at a level that most
users will never see. When the project is as large and complex as PyTorch, this
work is even more important, but often harder to see. Quansight engineers are
constantly at work refactoring sub-systems to reduce technical debt and making
improvements to PyTorch that allow the entire project to continue marching
forward.
Build Time Improvements
In a project the size of PyTorch, keeping build times reasonable is paramount.
Last year, Quansight engineers helped reduce the average
fresh build time from 20 minutes to five minutes. This year, we continued that
work, focusing on recompilation. A huge number of compilation sources are
generated at build time from specification files; for instance, all of the
C++-to-Python bindings are generated from a single file. The generation system
has no notion of what has changed between compilations, so all of the generated
sources are regenerated, and subsequently recompiled, any time that single file
is edited in any way—even if the edit is a comment.
Our work entailed making recompilation more granular: if one signature is
modified, only the relevant files need to be recompiled.
Typeless Storage
In 2021, the classes managing the memory where tensor data is stored, called
“storage,” were modified such that all data was stored as raw bytes without a
data type. This change made tensor storage more flexible and easier to maintain.
To maintain backward compatibility, the change was not mirrored in the Python
interface for these classes. In 2022, we continued this work, beginning the
deprecation cycle and re-unifying the two interfaces under the type-free
storage system. This delicate task required diligent work to maintain the
correctness of lower-level functionality such as serialization, but our team
was equal to the task: typed storage is now officially deprecated and will be
removed completely in PyTorch 2.1!
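A small sketch of the type-free view of storage, assuming PyTorch 2.0's
untyped_storage() accessor:

```python
import torch

t = torch.arange(4, dtype=torch.float32)
s = t.untyped_storage()   # raw bytes, with no dtype attached
print(s.nbytes())         # 16: four float32 values, four bytes each
```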
Maintenance of torch.linalg and torch.fft
In 2021, a number of Quansight and Meta engineers designed and implemented the
linalg and fft modules based on their NumPy and SciPy counterparts. One year
later, these modules are now completely stable, but we still provide guidance
to users who run into the sharper corners of the mathematically-involved
functions in them.
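Their NumPy lineage makes porting straightforward, e.g.:

```python
import torch

A = torch.randn(4, 4)
b = torch.randn(4)

x = torch.linalg.solve(A, b)        # mirrors numpy.linalg.solve
X = torch.fft.fft(torch.randn(8))   # mirrors numpy.fft.fft
```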
Testing
Another front where Quansight spent a reasonable amount of time in 2022 was
helping to set up a scalable and reliable testing infrastructure for the
PyTorch project. At this point, the main structure and the tests are rather
stable and provide a reliable way to test new features added to the library.
Now, as with any system that sets out to tackle such a humongous task, it needs
maintenance to keep it from growing out of control.
In 2022, Quansight engineers helped refactor the new testing infrastructure to
keep parts of it more manageable and implemented a number of quality-of-life
improvements, which helped find and fix a number of bugs.
TorchVision
PyTorch core provides critical functionality that is generally applicable to
learning workflows, but it is not intended to cover all applications on its own.
Domain specific extension libraries like TorchVision fill the gaps with
specialized tools for a particular problem space. Not to be outshone by the
PyTorch 2.0 release, we have worked to deliver some exciting new features in
TorchVision as well.
New Transforms API
One of the most tiring parts of doing real-world Machine Learning is data
pre-processing. TorchVision offers a set of functions to help you with this
task when you want to classify images. The issue is that image classification
is so 2016. Modern Deep Learning deals with segmentation tasks with masks,
object detection with bounding boxes, videos, and more. This means that when
you rotate an image, you may also need to rotate its annotations, and the same
goes for cropping and all other transformations.
Quansight engineers designed and developed a more flexible API for these
transforms that supports the more heterogeneous data and annotations that come
with modern vision tasks. This new API is not only backwards compatible, but
actually faster than the previous one, and it will be released in a beta state
with TorchVision 0.15. You can read more about it in this very nice blog post,
Extending TorchVision’s Transforms to Object Detection, Segmentation & Video tasks.
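A sketch of the joint-transform idea using the beta API; the names below are
from the 0.15 beta and may still change:

```python
import torch
from torchvision import datapoints
from torchvision.transforms import v2

# One transform applied jointly to an image and its bounding boxes:
# flipping the image also flips the box coordinates.
transform = v2.RandomHorizontalFlip(p=1.0)

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
boxes = datapoints.BoundingBox([[4, 4, 16, 16]],
                               format="XYXY", spatial_size=(32, 32))
img_t, boxes_t = transform(img, boxes)
```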
Video Codecs Support
As discussed above, image preprocessing is annoying. Now, this is difficult to
deal with when images are static. But when they move… that’s a whole different
level.
One particular pain point when dealing with videos is the variety of formats
they come in. We have been working on putting together a robust video reader
that deals with these different codecs and prepares the data for processing by
a neural network.
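A sketch of the intended usage via TorchVision's VideoReader ("clip.mp4" is a
placeholder path):

```python
from torchvision.io import VideoReader

# Decode frames one by one as tensors, independent of the underlying codec.
reader = VideoReader("clip.mp4", "video")
for frame in reader:
    data = frame["data"]   # uint8 tensor of shape (C, H, W)
```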
Yearly reviews are great! They help you see in one go how much you moved
forward in the last twelve months. We can’t hope to cover everything our team
of 20 engineers has done in over 18,000 hours of work in 2022, but these
highlights show it was a truly impressive group effort together with Meta!
Whether you are looking to get started or are already using PyTorch, Quansight
has the experts to help you.
All this would not have been possible without the contributions from our
fantastic team of PyTorch and TorchVision contributors: Peter Bell, Mario
Lezcano, Andrew James, Nikita Vedeneev, Pearu Peterson, Kshiteej Kalambarkar,
Philip Meier, Victor Fomin, Yukio Siraichi, Kurt Mohler, Evgeni Burovski, Aaron
Meurer, Matthew Barber, Fabio Rocha, Aleksandar Samardžić, Bruno Korbar, Nikita
Karetnikov, Sean Ross-Ross, Ralf Gommers, and Matti Picus.
The other part of the story comes from the great team of PyTorch and TorchVision
engineers at Meta. Without their dedication and openness to collaboration this
work would not be possible. In particular, we would like to thank Alban
Desmaison, Natalia Gimelshein, Edward Yang, Anjali Chourdia, Christian Puhrsch,
Jane Xu, Joel Schlosser, Nikita Shulga, Richard Zou, Brian Hirsh, Horace He,
Jason Ansel, Joe Isaacson, and Soumith Chintala (PyTorch) and Vasilis Vryniotis,
Prabhat Roy, and Nicolas Hug (TorchVision). Thank you all for making this
collaboration so fruitful and enjoyable.