Compositionality for Transfer Learning

Transfer learning is the idea that, after a machine learning system (or a non-machine learning system, for that matter, like a human) has learned to solve some problem, it should be able to transfer this knowledge to solving similar problems. Humans are pretty good at this, at least compared to current ML systems, which tend to suck.

Why do we expect transfer learning to work? It seems that, in general, we expect that the solution to a task can be decomposed into several pieces, some of which will still be useful for the related task. As an example, suppose we teach a self-driving car to drive to a given place, using a map of the local area and camera input. If we could open up the resulting algorithm, we might expect to find “subroutines” corresponding to

  1. Breaking down the visual input into objects.
  2. Maintaining/updating an internal memory with this data.
  3. Using this data to locate itself on the map.
  4. Plotting a route on the map between two points.
  5. Executing a route while driving correctly (i.e. not breaking the law, not causing crashes).

If we now want this system, instead, to locate and follow a car with a specific license plate (or something), we would expect that most of these routines, except perhaps 3 and 4, would still be useful. It would not have to learn all over again how to recognize other cars and avoid crashes.

In other words, we expect transfer learning to happen because of compositionality. To introduce some symbols, we are trying to learn a function \(f: X \to Y\) from observations to decisions. The target behavior is really a function of some high-level description of the system, as understood by humans, i.e. it factors as \(X \overset{p}{\to} \bar{X} \overset{f'}{\to} Y\). If the system manages to learn \(p\) as well as \(f'\), then when we swap the task out for another one specified at the same abstract level, \(g': \bar{X} \to Y\), the system has less left to learn - a lot of the parameter space is already in the right configuration.
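As a minimal sketch of this factoring (purely illustrative - the function names and type variables here are hypothetical):

```python
from typing import Callable, TypeVar

X, XBar, Y = TypeVar("X"), TypeVar("XBar"), TypeVar("Y")

def factored(p: Callable[[X], XBar], head: Callable[[XBar], Y]) -> Callable[[X], Y]:
    """Compose a shared perception model p : X -> X̄ with a task-specific head."""
    return lambda x: head(p(x))

# f = factored(p, f_prime)  # original task
# g = factored(p, g_prime)  # new task: only the head changes; p is reused as-is
```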

Of course, if you already knew how to compute the high-level representation \(\bar{X}\), you mostly wouldn’t need machine learning. However, when we view the problem from this angle, it seems that one way to get more transfer learning is to look for learning algorithms that output “decomposed” models, so that we can try to separate the abstraction \(X \to \bar{X}\) from the “task-specific logic” \(f': \bar{X} \to Y\).

One way to do this is to find a class of related tasks, \(\{f_i : X \to Y_i \mid i = 1, \dots, n\}\), which we feel share a common abstraction. Then we can pick some set \(\bar{X}\) and try to train \(n+1\) models - one going from \(X\) to \(\bar{X}\), the other \(n\) going from \(\bar{X}\) to \(Y_i\). We train all of them jointly, with the average loss across the tasks as the objective. Then the “high-level models” \(\bar{X} \to Y_i\) try to make use of the low-level data as best they can, while the “abstraction” model \(X \to \bar{X}\) tries to create an abstraction which is useful to all the high-level models.
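Here is a sketch of that training setup in PyTorch, assuming for concreteness that each task is scalar regression (the architecture, sizes, and data below are placeholders - the point is the shared encoder trained on the average of the task losses):

```python
import torch
import torch.nn as nn

n_tasks, x_dim, xbar_dim = 3, 32, 8

# The "abstraction" model p : X -> X̄, shared across tasks.
encoder = nn.Sequential(nn.Linear(x_dim, 64), nn.ReLU(), nn.Linear(64, xbar_dim))

# One "high-level" model per task, X̄ -> Y_i (here each Y_i is a scalar).
heads = nn.ModuleList([nn.Linear(xbar_dim, 1) for _ in range(n_tasks)])

opt = torch.optim.Adam(list(encoder.parameters()) + list(heads.parameters()), lr=1e-3)
loss_fn = nn.MSELoss()

def training_step(batches):
    """batches: list of (x, y) pairs, one per task."""
    opt.zero_grad()
    # Average loss across tasks: each head only sees its own task,
    # while the encoder receives gradients from all of them.
    loss = sum(loss_fn(head(encoder(x)), y)
               for head, (x, y) in zip(heads, batches)) / n_tasks
    loss.backward()
    opt.step()
    return loss.item()

# Example step with random data:
batches = [(torch.randn(16, x_dim), torch.randn(16, 1)) for _ in range(n_tasks)]
print(training_step(batches))
```

Because the encoder’s parameters appear in every term of the averaged loss, it is pushed toward a representation useful to all the tasks at once, while each head only ever gets gradients from its own task.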