TorchScript
- A PyTorch model’s journey from Python to C++ is enabled by TorchScript, a representation of a PyTorch model that can be understood, compiled, and serialized by the TorchScript compiler.
- Any TorchScript program can be saved from a Python process and loaded in a process where there is NO Python dependency. In other words, a TorchScript program can be run independently from Python, such as in a standalone C++ program.
- This makes it possible to train models in PyTorch using familiar tools in Python and then export the model via TorchScript to a production environment where Python programs may be disadvantageous for performance and multi-threading reasons.
- 👍 Advantage
- TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, and so many requests can be processed on the same instance simultaneously.
- This format allows us to save the whole model to disk and load it into another environment, such as in a server written in a language other than Python
- TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution
- TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.
Steps for Loading a PyTorch Model in C++
- Convert the PyTorch model to TorchScript
- Serialize script module to a file
- Load script module in C++
- Execute script module in C++
Convert PyTorch Model to Torch Script
There are two ways to convert a PyTorch model to TorchScript: tracing and scripting.
Tracing
- A mechanism in which the structure of the model is captured by evaluating it once using example inputs and recording the flow of those inputs through the model
- Suitable for models that make limited use of control flow
- Function: torch.jit.trace
Example
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
What happens under the hood when we call torch.jit.trace, passing in the Module and an example input?
- It invoked the Module
- It recorded the operations that occurred when the Module was run
- It created an instance of torch.jit.ScriptModule
TorchScript records its definitions in an Intermediate Representation (or IR), commonly referred to in deep learning as a graph. We can examine the graph with the .graph property.
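The three steps listed above can be checked directly. The following is a minimal sketch that reuses the MyCell example; the variable names (is_script_module, same_result, graph_text) are illustrative only.

```python
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))

# The traced result is a ScriptModule (more precisely, a subclass of it).
is_script_module = isinstance(traced_cell, torch.jit.ScriptModule)

# The traced module computes the same values as the original module.
eager_out, _ = my_cell(x, h)
traced_out, _ = traced_cell(x, h)
same_result = torch.allclose(eager_out, traced_out)

# The recorded operations are visible in the textual form of the graph.
graph_text = str(traced_cell.graph)
print(graph_text)
```

The printed graph is low-level and includes operator names such as aten::tanh, which is why it tends to be hard to read for larger models.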
A better way is to use the .code property to give a Python-syntax interpretation of the code:
print(traced_cell.code)
Out:
def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = torch.add((self.linear).forward(input, ), h, alpha=1)
  _1 = torch.tanh(_0)
  return (_1, _1)
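To see why tracing suits only models with limited control flow, consider a small sketch with a data-dependent branch. The function pick_branch and its inputs are hypothetical and exist only for this illustration. Tracing records the single path taken by the example input, so the other branch is lost (PyTorch emits a TracerWarning when this happens).

```python
import torch

def pick_branch(x):
    # Data-dependent branch: the path taken depends on the values in x.
    if x.sum() > 0:
        return x + 10
    return x - 10

positive = torch.ones(3)
negative = -torch.ones(3)

# Tracing with a positive input records only the `x + 10` path.
traced = torch.jit.trace(pick_branch, positive)

eager_out = pick_branch(negative)  # eager Python takes the `x - 10` path
traced_out = traced(negative)      # the trace still computes `x + 10`

print(eager_out)
print(traced_out)
```

The two results differ for the negative input, which is the failure mode that scripting is meant to avoid.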
Scripting
If our code uses control flow (if-else, loops…), then tracing is unsuitable. In this case, we use the script compiler, which analyzes our Python source code directly and transforms it into TorchScript. The function for compiling the module is torch.jit.script.
Example
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)
sm is an instance of ScriptModule that is ready for serialization.
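As a quick check that scripting keeps both branches, the sketch below runs the scripted module on one input per branch. The input values and the names pos_out and neg_out are assumptions chosen for illustration. With a 10x20 weight, the matrix-vector branch should return a vector of length 10 and the addition branch a broadcast 10x20 result.

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

sm = torch.jit.script(MyModule(10, 20))

# Positive sum: the matrix-vector product branch runs.
pos_out = sm(torch.ones(20))
# Non-positive sum: the broadcast addition branch runs.
neg_out = sm(-torch.ones(20))

print(tuple(pos_out.shape), tuple(neg_out.shape))
print(sm.code)  # Python-syntax view of the compiled forward method
```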
Mixing Scripting and Tracing
In many cases either tracing or scripting is an easier approach for converting a model to TorchScript. Tracing and scripting can be composed to suit the particular requirements of a part of a model.
Scripted functions can call traced functions.
Useful when we need to use control-flow around a simple feed-forward model
Example
import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)
Traced functions can call script functions.
Useful when a small part of a model requires some control-flow even though most of the model is just a feed-forward network.
Control-flow inside of a script function called by a traced function is preserved correctly.
Example
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
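The claim that control flow inside a script function survives tracing can be checked with a small sketch. It traces bar with inputs that take one branch of foo and then calls the traced function with inputs that should take the other branch. The specific tensors (small, large, zero) are assumptions chosen to make the two branches easy to tell apart.

```python
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

def bar(x, y, z):
    return foo(x, y) + z

# Trace with inputs for which foo takes the `r = x` branch.
traced_bar = torch.jit.trace(
    bar, (torch.ones(3), torch.zeros(3), torch.zeros(3))
)

# Call with inputs for which foo should take the `r = y` branch.
small = torch.zeros(3)
large = torch.full((3,), 5.0)
zero = torch.zeros(3)
out = traced_bar(small, large, zero)
print(out)
```

If the branch had been frozen at trace time, the result would equal small rather than large.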
Saving and Loading a Script Module
- Save: save() (called on the script module)
- Load: torch.jit.load()
Example:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
Save:
traced_script_module.save("traced_resnet_model.pt")
Load:
traced_resnet = torch.jit.load("traced_resnet_model.pt")
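A round trip can be verified without torchvision. The sketch below is illustrative only: it substitutes a small torch.nn.Linear model for ResNet and writes to a temporary directory, and the file name traced_linear.pt is an assumption. It checks that the loaded module reproduces the outputs of the module that was saved.

```python
import os
import tempfile

import torch

# A small stand-in model; any traceable module would work here.
model = torch.nn.Linear(4, 2).eval()
example = torch.rand(1, 4)
traced = torch.jit.trace(model, example)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "traced_linear.pt")
    traced.save(path)              # serialize the script module to disk
    loaded = torch.jit.load(path)  # deserialize it back into a ScriptModule

# The loaded module should give the same outputs as the one that was saved.
outputs_match = torch.allclose(traced(example), loaded(example))
print(outputs_match)
```

The same saved file is what a C++ program would load in the "Load script module in C++" step.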