TorchScript is a statically typed subset of Python that can be interpreted by LibTorch without any Python dependency. The torch R package provides interfaces to create, serialize, load and execute TorchScript programs.
Advantages of using TorchScript include:

- TorchScript code can be invoked in its own interpreter, which is essentially a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, so many requests can be processed on the same instance simultaneously.
- The format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than R.
- TorchScript gives us a representation in which we can perform compiler optimizations on the code to make execution more efficient.
- TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.
TorchScript programs can be created from R using tracing. With tracing, code is automatically converted into this subset of Python by recording only the operations performed on tensors; the surrounding R code is simply executed and discarded. Currently, tracing is the only supported way to create TorchScript programs from R code.
For example, let’s use the jit_trace function to create a TorchScript program. We pass a regular R function and example inputs.
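A minimal sketch of this step: the function body (torch_relu) and the 3-element example input are assumptions inferred from the graph printed below, not part of the original text.

```r
library(torch)

# a regular R function operating on tensors
fn <- function(x) {
  torch_relu(x)
}

# trace it: jit_trace executes fn with the example input
# and records the tensor operations
traced_fn <- jit_trace(fn, torch_tensor(c(-2, 0, 1)))
```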
The jit_trace function executed the R function with the example input and recorded all torch operations that occurred during execution to create a graph. The graph is the name given to the intermediate representation of TorchScript programs, and it can be inspected with:
traced_fn$graph
#> graph(%0 : Float(3, strides=[1], requires_grad=0, device=cpu)):
#> %1 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::relu(%0)
#> return (%1)
The traced function can now be invoked as a regular R function:
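For instance, using the traced_fn created above:

```r
# calling the traced function just like the original R function
traced_fn(torch_tensor(c(-1, 0, 1)))
```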
It’s also possible to trace nn_modules() defined in R, for example:
module <- nn_module(
  initialize = function() {
    self$linear1 <- nn_linear(10, 10)
    self$linear2 <- nn_linear(10, 1)
  },
  forward = function(x) {
    x %>%
      self$linear1() %>%
      nnf_relu() %>%
      self$linear2()
  }
)
traced_module <- jit_trace(module(), torch_randn(10, 10))
When using jit_trace with an nn_module, only the forward method is traced. However, by default, one pass is conducted in ‘train’ mode and another in ‘eval’ mode, which is different from the PyTorch behavior. One can opt out of this by specifying respect_mode = FALSE, which will only trace the forward pass in the mode the network is currently in. You can use the jit_trace_module function to pass example inputs to other methods. Traced modules look like normal nn_modules() and can be called the same way:
traced_module(torch_randn(3, 10))
#> torch_tensor
#> 0.2964
#> 0.3116
#> 0.6045
#> [ CPUFloatType{3,1} ][ grad_fn = <AddmmBackward0> ]
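The jit_trace_module function mentioned above can be sketched as follows. The named-argument signature (a method name mapped to a list of example inputs) follows the torch documentation; the variable name traced_module2 is illustrative.

```r
# sketch: trace methods explicitly by naming them and supplying
# example inputs; here only `forward` is traced
traced_module2 <- jit_trace_module(
  module(),
  forward = list(torch_randn(10, 10))
)

traced_module2(torch_randn(3, 10))
```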
Tracing has limitations: because only the tensor operations performed on the example inputs are recorded, the traced program is specialized to properties of those inputs, such as their shapes. For instance:

# fn does an operation for each dimension of a tensor
fn <- function(x) {
  x %>%
    torch_unbind(dim = 1) %>%
    lapply(function(x) x$sum()) %>%
    torch_stack(dim = 1)
}
# we trace using a tensor of size (10, 5, 5) as the example input
traced_fn <- jit_trace(fn, torch_randn(10, 5, 5))

# applying it to a tensor of a different size returns an error
traced_fn(torch_randn(11, 5, 5))
#> Error in cpp_call_traced_fn(ptr, inputs): The following operation failed in the TorchScript interpreter.
#> Traceback of TorchScript (most recent call last):
#> RuntimeError: Expected 10 elements in a list but found 11
In the returned ScriptModule, operations that have different behaviors in training and eval modes will always behave as if the module were in the mode it was in during tracing, no matter which mode the ScriptModule is in. For example:

traced_dropout <- jit_trace(nn_dropout(), torch_ones(5, 5))
traced_dropout(torch_ones(3,3))
#> torch_tensor
#> 2 0 0
#> 0 0 0
#> 2 2 0
#> [ CPUFloatType{3,3} ]
traced_dropout$eval()
#> [1] FALSE
# even after setting to eval mode, dropout is applied
traced_dropout(torch_ones(3,3))
#> torch_tensor
#> 1 1 1
#> 1 1 1
#> 1 1 1
#> [ CPUFloatType{3,3} ]
Another limitation is that only tensors, or (possibly nested) dicts or tuples of tensors, can be inputs to traced functions. Passing any other type raises an error:

fn <- function(x, y) {
  x + y
}

jit_trace(fn, torch_tensor(1), 1)
#> Error in cpp_trace_function(tr_fn, list(...), .compilation_unit, strict, : Only tensors or (possibly nested) dict or tuples of tensors can be inputs to traced functions. Got float
#> Exception raised from addInput at /Users/runner/work/libtorch-mac-m1/libtorch-mac-m1/pytorch/torch/csrc/jit/frontend/tracer.cpp:422 (most recent call first):
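One way around this, sketched below, is to wrap the scalar in a tensor before tracing (the name traced_add is illustrative, not part of the original example):

```r
# passing tensors for both arguments allows tracing to succeed
traced_add <- jit_trace(fn, torch_tensor(1), torch_tensor(1))
traced_add(torch_tensor(2), torch_tensor(3))
```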
It’s also possible to create TorchScript programs by compiling TorchScript code. TorchScript code looks a lot like standard Python code. For example:
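A sketch using the package’s jit_compile function; compiled functions are accessed with `$`, and the script body here is illustrative:

```r
# compile TorchScript source code directly
comp <- jit_compile("
def fn (x):
  return torch.relu(x)
")

comp$fn(torch_tensor(c(-1, 0, 1)))
```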
TorchScript programs can be serialized using the jit_save function and loaded back from disk with jit_load. For example:
fn <- function(x) {
  torch_relu(x)
}

tr_fn <- jit_trace(fn, torch_tensor(1))
jit_save(tr_fn, "path.pt")

loaded <- jit_load("path.pt")
Loaded programs can be executed as usual:
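For example, calling the program loaded above:

```r
# the loaded program behaves like the traced function that was saved
loaded(torch_tensor(c(-1, 1)))
```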
Note: you can load TorchScript programs that were created with libraries other than torch for R. For example, a TorchScript program can be created in PyTorch with torch.jit.trace or torch.jit.script and then run from R.
R objects are automatically converted to their TorchScript counterparts following the types table in this document. However, sometimes it’s necessary to make type annotations with jit_tuple() and jit_scalar() to disambiguate the conversion.
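As a sketch of when jit_scalar() matters: a length-1 R numeric is ambiguous between a scalar and a vector, and wrapping it marks it as a TorchScript scalar. The compiled script below and its float annotations are illustrative assumptions:

```r
# illustrative script whose float annotations require true scalars
comp <- jit_compile("
def add (x: float, y: float) -> float:
  return x + y
")

# jit_scalar() marks length-1 vectors as TorchScript scalars
comp$add(jit_scalar(1), jit_scalar(2))
```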
The following table lists all TorchScript types and how to convert them to and from R.
| TorchScript Type | R Description |
|---|---|
| Tensor | A torch_tensor with any shape, dtype or backend. |
| Tuple[T0, T1, ..., TN] | A list() containing subtypes T0, T1, etc. wrapped with jit_tuple(). |
| bool | A scalar logical value created using jit_scalar(). |
| int | A scalar integer value created using jit_scalar(). |
| float | A scalar floating-point value created using jit_scalar(). |
| str | A string (i.e. a character vector of length 1) wrapped in jit_scalar(). |
| List[T] | An R list in which all elements have type T, or numeric vectors, logical vectors, etc. |
| Optional[T] | Not yet supported. |
| Dict[str, V] | A named list with values of type V. Only str keys are currently supported. |
| T | Not yet supported. |
| E | Not yet supported. |
| NamedTuple[T0, T1, ...] | A named list containing subtypes T0, T1, etc. wrapped in jit_tuple(). |