Backpropagate on the predicted shortest path rather than the entire input sequence #121

Open
luciaquirke opened this issue Mar 24, 2023 · 3 comments
Labels
bug Something isn't working enhancement New feature or request

Comments

@luciaquirke
Contributor

luciaquirke commented Mar 24, 2023

Description

We want a model that has learned crisp and easily interpretable search algorithms. Such a model will solve mazes with high accuracy. However, our ability to train such models is impacted by noise in our backpropagation calculation.

If we have a training sequence of tokens a b c d, then the model receives as separate samples the sequences:

a -> b
a b -> c
a b c -> d
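
In code, this decomposition looks roughly like the following (illustration only, not code from the repo):

```python
# For the sequence [a, b, c, d], position i is trained to predict token i + 1.
tokens = ["a", "b", "c", "d"]
samples = [(tokens[: i + 1], tokens[i + 1]) for i in range(len(tokens) - 1)]
# samples == [(['a'], 'b'), (['a', 'b'], 'c'), (['a', 'b', 'c'], 'd')]
```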

Our current training sequence includes both the adjacency list representing the maze and each step in the shortest path. This slows down training because, for the first n positions of each sequence, the model is predicting and backpropagating on adjacency-list tokens, which are partially randomly generated.

I think it's inherent to the transformer architecture that the model predicts the next token at every position in the sequence, but we can reduce this noise by using only the shortest-path predictions to calculate our loss and gradient updates.

The HookedTransformer forward method has a parameter that makes it return a tensor of per-token losses instead of a single averaged loss. We can use this to compute the loss over only the shortest-path predictions.
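
A minimal sketch of what this could look like, assuming a TransformerLens HookedTransformer and a hypothetical path_start_idx field giving the index where the solution tokens begin (per the first Definition of Done item below):

```python
import torch

def path_only_loss(model, tokens: torch.Tensor, path_start_idx: torch.Tensor) -> torch.Tensor:
    """Average cross-entropy over shortest-path predictions only.

    tokens: [batch, seq_len] input sequences (adjacency list + path).
    path_start_idx: [batch] index of the first solution token in each sequence.
    """
    # Per-token losses, shape [batch, seq_len - 1]; entry i is the loss for
    # predicting tokens[:, i + 1].
    per_token_loss = model(tokens, return_type="loss", loss_per_token=True)

    # The first path token is predicted at position path_start_idx - 1.
    positions = torch.arange(per_token_loss.shape[1], device=tokens.device)
    path_mask = positions.unsqueeze(0) >= (path_start_idx.unsqueeze(1) - 1)

    # Mask out adjacency-list predictions so they contribute no gradient.
    return (per_token_loss * path_mask).sum() / path_mask.sum()
```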

Definition of Done

  • We can identify the shortest path tokens in each input sequence (e.g., each dataset item includes the starting index of the solution tokens in the input sequence)
  • The loss/error gradients are calculated from the shortest path tokens onwards

Future Work

  • Ensure our solution algorithm guarantees the shortest path, rather than any valid solution. This will prevent us from inadvertently training the model on paths that are not the shortest
  • Adapt our loss function to penalise path segments from less-than-optimal but still valid solutions less heavily than totally invalid path segments
@luciaquirke luciaquirke added bug Something isn't working enhancement New feature or request labels Mar 24, 2023
@luciaquirke luciaquirke changed the title Only backpropagate on the predicted shortest path, not the predicted adjacency list Backpropagate on the predicted shortest path rather than the entire input sequence Mar 24, 2023
@valedan
Contributor

valedan commented Mar 25, 2023

Should we maintain support for the old backprop strategy after we introduce this one? I seem to recall some suggestion that there might be some benefit to the model learning the general maze structure by training on the adj_list too.

@valedan
Contributor

valedan commented Apr 12, 2023

I think @afspies said he tried this in one of his experiments and it didn't make much difference. I think it would still be good for us to add the option of computing the loss this way to the training script.

@afspies
Member

afspies commented Apr 14, 2023

Given @canrager's experiments with smaller-maze generalization, I am now inclined to believe that a good implementation of this would be valuable (i.e. one that uses padding and masked gradients, so as not to incur the 10% slowdown of my "schnell and dirty"™️ implementation).

The reason for this is that all adjacency lists for a given maze size are the same length, so at least some of the trained models learned to recognise the beginning of the path based on its position in the sequence rather than on the delimiting tokens. This meant that any model trained with adjacency lists, when given a smaller maze (and thus a shorter adjacency list), would first hallucinate a bunch of adjacency-list tokens before starting a path.

Not surprising in hindsight: if you want a model to generalize w.r.t. something, you should probably show it that thing varying during training.
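
A rough sketch of the padding-plus-masked-gradients variant described above, assuming sequences are right-padded with a hypothetical pad_token_id so that batches can mix maze sizes, and reusing the hypothetical path_start_idx field from the earlier sketch:

```python
import torch

def padded_path_loss(model, tokens, path_start_idx, pad_token_id):
    # Per-token losses over the padded batch, shape [batch, seq_len - 1].
    per_token_loss = model(tokens, return_type="loss", loss_per_token=True)
    positions = torch.arange(per_token_loss.shape[1], device=tokens.device)

    # Keep a position only if it predicts a path token (not adjacency list)
    # and its target is a real token rather than padding.
    path_mask = positions.unsqueeze(0) >= (path_start_idx.unsqueeze(1) - 1)
    not_pad = tokens[:, 1:] != pad_token_id
    mask = path_mask & not_pad

    # Padding and adjacency-list positions get zero weight, so mazes of
    # different sizes can share a batch without a per-sample loop.
    return (per_token_loss * mask).sum() / mask.sum()
```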
