
Extending Autograd

library(torch)

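Adding operations to autograd requires implementing a new autograd_function for each operation. Recall that autograd_functions are what autograd uses to compute the results and gradients, and to encode the operation history. Every new function requires you to implement two methods:

forward(ctx, ...): performs the operation. It receives a context object, ctx, followed by the operation's inputs, and returns the result. Anything backward() will need can be stored with ctx$save_for_backward().

backward(ctx, grad_output): the gradient formula. It receives ctx and grad_output, the gradient with respect to the output. It returns a named list, keyed by forward()'s argument names, holding the gradient with respect to each input. Values saved in forward() are available as ctx$saved_variables.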
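As a minimal sketch of the two methods, consider a custom elementwise exponential. The name my_exp is hypothetical. Because the derivative of exp(x) is exp(x), forward() saves its input and backward() multiplies the incoming gradient by exp(x).

```r
library(torch)

# Hypothetical custom op: elementwise exponential.
my_exp <- autograd_function(
  forward = function(ctx, x) {
    # Save the input; backward() recomputes exp(x) from it.
    ctx$save_for_backward(x = x)
    torch_exp(x)
  },
  backward = function(ctx, grad_output) {
    s <- ctx$saved_variables
    # Chain rule: upstream gradient times the local derivative exp(x).
    list(x = grad_output * torch_exp(s$x))
  }
)

x <- torch_tensor(c(0, 1, 2), requires_grad = TRUE)
y <- my_exp(x)
y$sum()$backward()
x$grad  # should match exp(c(0, 1, 2))
```

Summing y before calling backward() reduces it to a single value, so the gradient reaching backward() is a tensor of ones.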

Note

It is the user's responsibility to use the special methods of ctx (such as ctx$save_for_backward()) correctly in forward(), so that the new autograd_function works properly with the autograd engine.
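ctx$needs_input_grad is one of the ctx facilities used in the examples below. It is a named list that tells backward() which inputs actually require a gradient, so work for the others can be skipped. The sketch below uses a hypothetical my_mul op in which only the first input, a, requires a gradient.

```r
library(torch)

# Hypothetical custom op: elementwise product a * b.
my_mul <- autograd_function(
  forward = function(ctx, a, b) {
    ctx$save_for_backward(a = a, b = b)
    a * b
  },
  backward = function(ctx, grad_output) {
    s <- ctx$saved_variables
    grads <- list(a = NULL, b = NULL)
    # Only compute gradients for inputs that require them.
    if (ctx$needs_input_grad$a)
      grads$a <- grad_output * s$b
    if (ctx$needs_input_grad$b)
      grads$b <- grad_output * s$a
    grads
  }
)

a <- torch_tensor(c(1, 2, 3), requires_grad = TRUE)
b <- torch_tensor(c(4, 5, 6))  # does not require a gradient
my_mul(a, b)$sum()$backward()
a$grad  # equals b
```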

Examples

Below you can find code for a linear function:

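linear <- autograd_function(
  # input: (batch, in_features); weight: (out_features, in_features); bias: (out_features)
  forward = function(ctx, input, weight, bias = NULL) {
    # Keep the inputs; backward() needs them for the gradient formulas.
    ctx$save_for_backward(input = input, weight = weight, bias = bias)
    output <- input$mm(weight$t())
    # torch for R numbers dimensions from 1: add a leading dim, then broadcast.
    if (!is.null(bias))
      output <- output + bias$unsqueeze(1)$expand_as(output)

    output
  },
  backward = function(ctx, grad_output) {

    s <- ctx$saved_variables

    # Start with no gradients; fill in only those that are needed.
    grads <- list(
      input = NULL,
      weight = NULL,
      bias = NULL
    )

    # Gradient w.r.t. input: grad_output %*% weight
    if (ctx$needs_input_grad$input)
      grads$input <- grad_output$mm(s$weight)

    # Gradient w.r.t. weight: t(grad_output) %*% input
    if (ctx$needs_input_grad$weight)
      grads$weight <- grad_output$t()$mm(s$input)

    # Gradient w.r.t. bias: sum grad_output over the batch dimension (dim 1)
    if (!is.null(s$bias) && ctx$needs_input_grad$bias)
      grads$bias <- grad_output$sum(dim = 1)

    grads
  }
)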
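You can sanity-check the gradient formulas in the linear backward() by comparing them with what the built-in autograd computes for the same affine map through nnf_linear(). The shapes and random values below are arbitrary choices for illustration. Differentiating sum(output * grad_output) sends grad_output back as the upstream gradient. Note that torch for R numbers dimensions from 1, so the bias gradient sums over dim = 1.

```r
library(torch)

torch_manual_seed(1)
input  <- torch_randn(4, 3, requires_grad = TRUE)  # batch of 4, 3 input features
weight <- torch_randn(2, 3, requires_grad = TRUE)  # 2 output features
bias   <- torch_randn(2, requires_grad = TRUE)
grad_output <- torch_randn(4, 2)                   # arbitrary upstream gradient

# Built-in autograd: the gradient of sum(output * grad_output) w.r.t. output
# is grad_output, so this backpropagates grad_output through nnf_linear().
output <- nnf_linear(input, weight, bias)
(output * grad_output)$sum()$backward()

# The hand-written formulas from the linear backward().
manual_input  <- grad_output$mm(weight$detach())
manual_weight <- grad_output$t()$mm(input$detach())
manual_bias   <- grad_output$sum(dim = 1)

# Largest absolute elementwise difference between two tensors.
max_diff <- function(a, b) (a - b)$abs()$max()$item()
max_diff(input$grad, manual_input)    # close to 0
max_diff(weight$grad, manual_weight)  # close to 0
max_diff(bias$grad, manual_bias)      # close to 0
```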

Here, we give an additional example of a function that is parametrized by non-Tensor arguments:

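mul_constant <- autograd_function(
  forward = function(ctx, tensor, constant) {
    # constant is a plain R number; save_for_backward() also accepts non-tensors.
    ctx$save_for_backward(constant = constant)
    tensor * constant
  },
  backward = function(ctx, grad_output) {
    v <- ctx$saved_variables
    # Only the tensor input gets a gradient; constant is not differentiated.
    list(
      tensor = grad_output * v$constant
    )
  }
)
x <- torch_tensor(1, requires_grad = TRUE)
o <- mul_constant(x, 2)
o$backward()
x$grad  # 2, the derivative of 2 * x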
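The same pattern works when the non-tensor argument also enters the gradient formula. The sketch below assumes a hypothetical pow_constant(tensor, n) that computes tensor^n. It saves both the input tensor and the exponent n and, like mul_constant, returns no gradient for the non-tensor argument.

```r
library(torch)

# Hypothetical op: raise a tensor to a fixed (non-tensor) power n.
pow_constant <- autograd_function(
  forward = function(ctx, tensor, n) {
    # The plain R number n is saved alongside the tensor.
    ctx$save_for_backward(tensor = tensor, n = n)
    torch_pow(tensor, n)
  },
  backward = function(ctx, grad_output) {
    s <- ctx$saved_variables
    # d/dx x^n = n * x^(n - 1); no gradient is returned for n.
    list(tensor = grad_output * s$n * torch_pow(s$tensor, s$n - 1))
  }
)

x <- torch_tensor(3, requires_grad = TRUE)
o <- pow_constant(x, 2)
o$backward()
x$grad  # 2 * 3 = 6
```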
