This vignette describes how to implement the attention mechanism, which forms the basis of transformers, in the R language.
We begin by generating encoder representations of four different words.
# encoder representations of four different words
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)
Next, we stack the word embeddings into a single array (in this case a matrix), which we call `words`.
Let’s see what this looks like.
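The code for this step is not shown above. Here is a minimal sketch, assuming the four embeddings are stacked row-wise with `rbind()` so that row i of `words` holds word i:

```r
# encoder representations of four different words (as above)
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)

# stacking the word embeddings row-wise into a 4 x 3 matrix
words = rbind(word_1, word_2, word_3, word_4)
print(words)
#>      [,1] [,2] [,3]
#> [1,]    1    0    0
#> [2,]    0    1    0
#> [3,]    1    1    0
#> [4,]    0    0    1
```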
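Next, we initialize the weight matrices `W_Q`, `W_K` and `W_V` with random integers. Because `floor()` rounds uniform draws on [0, 3) down, each entry is 0, 1 or 2.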
# initializing the weight matrices (with random values)
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
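As a quick check on the range of these values, the sketch below re-runs the initialization and confirms that every entry lies in {0, 1, 2}. `runif()` does not return its endpoints, so `floor()` can never produce 3.

```r
# re-running the initialization from above
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)

# every entry is an integer in {0, 1, 2}
all(c(W_Q, W_K, W_V) %in% 0:2)
#> [1] TRUE
```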
Next, we generate the Queries (Q), Keys (K) and Values (V). The `%*%` operator performs matrix multiplication. You can view the R help page using `help('%*%')` (or consult the online manual An Introduction to R).
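The code for this step is not shown above. Here is a minimal sketch, assuming the standard construction in which each word's query, key and value are obtained by multiplying its embedding by `W_Q`, `W_K` and `W_V` respectively:

```r
# inputs from the previous steps
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)

# generating the queries, keys and values (one row per word)
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V
```

Because `word_3` is the sum of `word_1` and `word_2`, the third row of each of Q, K and V is the sum of the first two rows.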
Following this, we score the Queries (Q) against the Key (K) vectors, which are transposed for the multiplication using `t()` (see `help('t')` for more info).
# scoring the query vectors against all key vectors
scores = Q %*% t(K)
print(scores)
#> [,1] [,2] [,3] [,4]
#> [1,] 6 4 10 5
#> [2,] 4 6 10 6
#> [3,] 10 10 20 11
#> [4,] 3 1 4 2
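Each entry `scores[i, j]` is the dot product of query i with key j. The sketch below rebuilds Q and K (assuming `Q = words %*% W_Q` and `K = words %*% W_K`) and checks the matrix product against an explicit double loop. Because `word_3 = word_1 + word_2`, row 3 of the scores is the sum of rows 1 and 2, as the printed output shows (10 = 6 + 4, 20 = 10 + 10, 11 = 5 + 6).

```r
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
Q = words %*% W_Q
K = words %*% W_K

# matrix form: scores[i, j] = Q[i, ] . K[j, ]
scores = Q %*% t(K)

# the same scores, computed one dot product at a time
scores_loop = matrix(0, nrow=nrow(Q), ncol=nrow(K))
for (i in seq_len(nrow(Q))) {
  for (j in seq_len(nrow(K))) {
    scores_loop[i, j] = sum(Q[i, ] * K[j, ])
  }
}
all.equal(scores, scores_loop)
#> [1] TRUE
```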
We now generate the `weights` matrix.
Let’s have a look at the `weights` matrix.
print(weights)
#> [,1] [,2] [,3] [,4]
#> [1,] -0.2986355 -2.6877197 4.479533 -1.4931776
#> [2,] -3.1208558 -0.6241712 4.369198 -0.6241712
#> [3,] -1.7790165 -1.7790165 4.690134 -1.1321014
#> [4,] 1.2167336 -3.6502008 3.650201 -1.2167336
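The code that produced this matrix is not shown. In the standard scaled dot-product formulation, the weights are a row-wise softmax of the scores divided by the square root of the key dimension (here 3). This makes every weight positive and every row sum to one. The matrix printed above has negative entries, and its rows sum to zero (up to rounding), so it is not a softmax output. The sketch below shows the standard computation as an assumption, not as the vignette's own code. `softmax_rows()` is a hypothetical helper, and the result is stored in `weights_softmax` so that `weights` is left untouched.

```r
# the scores printed above
scores = matrix(c( 6,  4, 10,  5,
                   4,  6, 10,  6,
                  10, 10, 20, 11,
                   3,  1,  4,  2), nrow=4, byrow=TRUE)
d_k = 3  # key dimension, i.e. the number of columns of K

# hypothetical helper: numerically stable softmax applied to each row
softmax_rows = function(x) {
  e = exp(x - apply(x, 1, max))  # subtract each row's maximum before exp()
  e / rowSums(e)                 # divide each row by its own sum
}

weights_softmax = softmax_rows(scores / sqrt(d_k))
round(weights_softmax, 4)
rowSums(weights_softmax)
```

Subtracting each row's maximum leaves the result unchanged but keeps `exp()` from overflowing when scores are large.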
Finally, we compute the `attention` as a weighted sum of the value vectors (which are combined in the matrix V).
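The code for this step is not shown above. The printed results are consistent with `attention = weights %*% V` and `V = words %*% W_V`. The sketch below makes that assumption, starts from the weights printed earlier, and spells out the weighted sum row by row: row i of `attention` is the sum over j of `weights[i, j] * V[j, ]`.

```r
# the weights matrix as printed above
weights = matrix(c(-0.2986355, -2.6877197, 4.479533, -1.4931776,
                   -3.1208558, -0.6241712, 4.369198, -0.6241712,
                   -1.7790165, -1.7790165, 4.690134, -1.1321014,
                    1.2167336, -3.6502008, 3.650201, -1.2167336),
                 nrow=4, byrow=TRUE)

# rebuilding V (assumes V = words %*% W_V)
words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
V = words %*% W_V

# matrix form of the weighted sum
attention = weights %*% V

# the same thing spelled out: row i is sum_j weights[i, j] * V[j, ]
attention_loop = t(sapply(seq_len(nrow(weights)), function(i) {
  colSums(weights[i, ] * V)
}))
all.equal(attention, attention_loop)
#> [1] TRUE
```

Under these assumptions, `attention` matches the matrix printed in the next step to about six decimal places.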
Now we can view the results using:
print(attention)
#> [,1] [,2] [,3]
#> [1,] 7.167252 6.868617 -1.4931776
#> [2,] 4.993369 1.872514 -0.6241712
#> [3,] 6.469151 4.690134 -1.1321014
#> [4,] 7.300402 8.517135 -1.2167336
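For reference, the steps can be collected into one function. This is a sketch of standard scaled dot-product attention with a row-wise softmax; `scaled_dot_product_attention()` is a hypothetical name and not part of any package. Because its weights come from a softmax, each output row is a convex combination of the rows of V. Its values will therefore differ from the `attention` matrix printed above.

```r
# hypothetical helper: scaled dot-product attention with a row-wise softmax
scaled_dot_product_attention = function(X, W_Q, W_K, W_V) {
  Q = X %*% W_Q
  K = X %*% W_K
  V = X %*% W_V
  scores = Q %*% t(K) / sqrt(ncol(K))
  e = exp(scores - apply(scores, 1, max))  # stable softmax, one row at a time
  weights = e / rowSums(e)
  weights %*% V                            # weighted sum of the value vectors
}

words = rbind(c(1,0,0), c(0,1,0), c(1,1,0), c(0,0,1))
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)), nrow=3, ncol=3)

attention_softmax = scaled_dot_product_attention(words, W_Q, W_K, W_V)
print(attention_softmax)
```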