Matrix factorization

Matrix factorization is a simple embedding model. Given the feedback matrix \(A \in \mathbb R^{m \times n}\), where \(m\) is the number of users (or queries) and \(n\) is the number of items, the model learns:

  • A user embedding matrix \(U \in \mathbb R^{m \times d}\), where row \(i\) is the embedding for user \(i\).
  • An item embedding matrix \(V \in \mathbb R^{n \times d}\), where row \(j\) is the embedding for item \(j\).

Illustration of matrix factorization using the recurring movie example.

The embeddings are learned such that the product \(U V^T\) is a good approximation of the feedback matrix \(A\). Observe that the \((i, j)\) entry of \(U V^T\) is simply the dot product \(\langle U_i, V_j\rangle\) of the embeddings of user \(i\) and item \(j\), which you want to be close to \(A_{i, j}\).
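To make this concrete, here is a minimal NumPy sketch (the sizes and random embeddings are placeholders, not learned values) showing that each entry of \(U V^T\) is just a dot product of two embedding rows:

```python
import numpy as np

# Placeholder sizes: m users, n items, embedding dimension d.
m, n, d = 4, 5, 2

rng = np.random.default_rng(0)
U = rng.normal(size=(m, d))  # user embeddings: row i is the embedding of user i
V = rng.normal(size=(n, d))  # item embeddings: row j is the embedding of item j

# The model's approximation of the feedback matrix.
A_hat = U @ V.T

# Entry (i, j) is the dot product of user i's and item j's embeddings.
i, j = 1, 3
assert np.isclose(A_hat[i, j], U[i] @ V[j])
```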

Choosing the objective function

One intuitive objective function is the squared distance. To use it, minimize the sum of squared errors over all pairs of observed entries:

\[\min_{U \in \mathbb R^{m \times d},\ V \in \mathbb R^{n \times d}} \sum_{(i, j) \in \text{obs}} (A_{ij} - \langle U_{i}, V_{j} \rangle)^2.\]

In this objective function, you sum only over observed pairs \((i, j)\), that is, over non-zero values in the feedback matrix. However, summing only over the observed values is not a good idea: a model whose predictions are all ones achieves minimal loss, yet it can't make effective recommendations and generalizes poorly.
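As a sketch, the observed-only objective can be computed as follows, assuming `A`, `U`, and `V` are NumPy arrays as defined above and `observed` is a collection of observed \((i, j)\) index pairs:

```python
def observed_squared_loss(A, U, V, observed):
    """Sum of squared errors over the observed (i, j) pairs only.

    A, U, V are NumPy arrays; `observed` is an iterable of (i, j) index
    pairs for the observed feedback entries.
    """
    return sum((A[i, j] - U[i] @ V[j]) ** 2 for (i, j) in observed)
```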

Illustration of three matrices: Observed only matrix factorization, weighted factorization, and Singular Value Decomposition.

Perhaps you could treat the unobserved values as zero, and sum over all entries in the matrix. This corresponds to minimizing the squared Frobenius distance between \(A\) and its approximation \(U V^T\):

\[\min_{U \in \mathbb R^{m \times d},\ V \in \mathbb R^{n \times d}} \|A - U V^T\|_F^2.\]

You can solve this quadratic problem through Singular Value Decomposition (SVD) of the matrix. However, SVD is not a great solution either, because in real applications, the matrix \(A\) may be very sparse. For example, think of all the videos on YouTube compared to all the videos a particular user has viewed. The solution \(UV^T\) (which corresponds to the model's approximation of the input matrix) will likely be close to zero, leading to poor generalization performance.
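For reference, here is a minimal NumPy sketch of the SVD approach (the helper name is illustrative); the truncated SVD yields the best rank-\(d\) approximation of \(A\) in Frobenius norm:

```python
import numpy as np

def svd_factorization(A, d):
    """Best rank-d approximation of A (in Frobenius norm) via truncated SVD."""
    u, s, vt = np.linalg.svd(A, full_matrices=False)
    U = u[:, :d] * np.sqrt(s[:d])    # split the singular values between the factors
    V = vt[:d, :].T * np.sqrt(s[:d])
    return U, V                      # U @ V.T is the rank-d approximation
```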

In contrast, Weighted Matrix Factorization decomposes the objective into the following two sums:

  • A sum over observed entries.
  • A sum over unobserved entries (treated as zeroes).

\[\min_{U \in \mathbb R^{m \times d},\ V \in \mathbb R^{n \times d}} \sum_{(i, j) \in \text{obs}} (A_{ij} - \langle U_{i}, V_{j} \rangle)^2 + w_0 \sum_{(i, j) \not \in \text{obs}} (\langle U_i, V_j\rangle)^2.\]

Here, \(w_0\) is a hyperparameter that weights the two terms so that the objective is not dominated by one or the other. Tuning this hyperparameter is very important.

In practical applications, you also need to weight the observed pairs carefully, because frequent items or frequent queries can otherwise dominate the objective function. You can correct for this effect by weighting the training examples, replacing the objective with:

\[\sum_{(i, j) \in \text{obs}} w_{i, j} (A_{i, j} - \langle U_i, V_j \rangle)^2 + w_0 \sum_{(i, j) \not \in \text{obs}} \langle U_i, V_j \rangle^2,\]

where \(w_{i, j}\) is a function of the frequency of query \(i\) and item \(j\).
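A direct (unvectorized) sketch of this weighted objective, assuming `A`, `U`, and `V` are NumPy arrays, `observed` is a set of \((i, j)\) pairs, and `w` is a matrix of per-entry weights:

```python
import numpy as np

def weighted_mf_loss(A, U, V, observed, w, w0):
    """Weighted matrix factorization objective (direct, unvectorized form).

    Observed entries are weighted by w[i, j]; unobserved entries are treated
    as zeros and weighted by the hyperparameter w0.
    """
    pred = U @ V.T
    obs = sum(w[i, j] * (A[i, j] - pred[i, j]) ** 2 for (i, j) in observed)
    unobs = w0 * sum(pred[i, j] ** 2
                     for i in range(A.shape[0])
                     for j in range(A.shape[1])
                     if (i, j) not in observed)
    return obs + unobs
```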

Minimizing the objective function

Common algorithms to minimize the objective function include:

  • Stochastic gradient descent (SGD) is a generic method to minimize loss functions.

  • Weighted Alternating Least Squares (WALS) is specialized to this particular objective.

The objective is quadratic in each of the two matrices \(U\) and \(V\). (Note, however, that the problem is not jointly convex.) WALS works by initializing the embeddings randomly, then alternating between:

  • Fixing \(U\) and solving for \(V\).
  • Fixing \(V\) and solving for \(U\).

Each stage can be solved exactly (via solution of a linear system) and can be distributed. This technique is guaranteed to converge because each step is guaranteed to decrease the loss.
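The following is a minimal NumPy sketch of WALS for the weighted objective above, with all observed entries weighted 1, unobserved entries weighted `w0`, and a small ridge term added for numerical stability (the signature and default hyperparameters are illustrative):

```python
import numpy as np

def wals(A, observed_mask, d, w0=0.1, reg=1e-3, num_iters=20, seed=0):
    """Minimal WALS sketch: observed entries weighted 1, unobserved weighted w0."""
    m, n = A.shape
    rng = np.random.default_rng(seed)
    U = rng.normal(scale=0.1, size=(m, d))
    V = rng.normal(scale=0.1, size=(n, d))

    def solve_side(X, Y, mask, targets):
        # With Y fixed, each row of X has a closed-form ridge solution.
        G = Y.T @ Y                          # Gram matrix over all rows of Y
        for i in range(X.shape[0]):
            obs = mask[i]                    # which entries of row i are observed
            Yo = Y[obs]                      # embeddings paired with the observed entries
            Go = Yo.T @ Yo
            M = Go + w0 * (G - Go) + reg * np.eye(d)
            b = Yo.T @ targets[i, obs]
            X[i] = np.linalg.solve(M, b)

    for _ in range(num_iters):
        solve_side(U, V, observed_mask, A)      # fix V, solve for U
        solve_side(V, U, observed_mask.T, A.T)  # fix U, solve for V
    return U, V
```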

SGD versus WALS

SGD and WALS have advantages and disadvantages. Review the information below to see how they compare:

SGD

  • Very flexible: can use other loss functions.
  • Can be parallelized.
  • Slower: does not converge as quickly.
  • Harder to handle the unobserved entries (you need to use negative sampling or gravity); see the sketch after this comparison.

WALS

  • Relies on the squared loss only.
  • Can be parallelized.
  • Converges faster than SGD.
  • Easier to handle unobserved entries.
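For completeness, here is a sketch of one SGD epoch for this model, using randomly sampled unobserved pairs as "negatives" (the step size, sampling scheme, and weight `w0` are illustrative choices; the gravity variant mentioned above is not shown):

```python
import numpy as np

def sgd_epoch(A, U, V, observed, lr=0.05, w0=0.1, num_negatives=1, rng=None):
    """One SGD pass over the observed entries, with simple negative sampling.

    `observed` is a set of (i, j) pairs; A, U, V are NumPy arrays as above.
    """
    if rng is None:
        rng = np.random.default_rng()
    n = A.shape[1]
    for (i, j) in observed:
        # Gradient step on the squared error of the observed entry.
        err = A[i, j] - U[i] @ V[j]
        U[i], V[j] = U[i] + lr * err * V[j], V[j] + lr * err * U[i]
        # Negative sampling: push a few random unobserved entries toward 0.
        for _ in range(num_negatives):
            jn = rng.integers(n)
            if (i, jn) in observed:
                continue
            err_n = -(U[i] @ V[jn])
            U[i], V[jn] = U[i] + lr * w0 * err_n * V[jn], V[jn] + lr * w0 * err_n * U[i]
    return U, V
```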