Nov 9, 2023
Implicit Reasoning in LLMs
An explanation of implicit chain-of-thought reasoning through hidden states and knowledge distillation.
Currently, reasoning in LLMs is often achieved through chain of thought: the model generates text that spells out a thought process step by step. The authors of the recent paper "Implicit Chain of Thought Reasoning via Knowledge Distillation" call this explicit reasoning. In simple terms, explicit reasoning means expressing thoughts openly in natural language. Implicit reasoning, by contrast, means that something is understood without being directly expressed. The authors use this term to shift the focus from generated tokens to the hidden states across the model's layers, because these states contain rich, condensed information that the model uses to reason and reach a conclusion.
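To make the distinction concrete, here is a toy sketch of what explicit reasoning for multi-digit multiplication might look like: every partial product is written out as text before the answer. The step format and the function name `explicit_cot_steps` are hypothetical illustrations, not the paper's actual data format. Under No CoT a model would emit only the final product; implicit CoT aims to do the intermediate work inside hidden states instead.

```python
def explicit_cot_steps(a: int, b: int) -> list[str]:
    """Spell out a * b as explicit, human-readable steps (hypothetical format).

    Each step multiplies `a` by one digit of `b` (least significant first),
    shifts it by the digit's place value, and keeps a running sum.
    Assumes `b` is a non-negative integer.
    """
    steps = []
    running_sum = 0
    for place, digit_char in enumerate(reversed(str(b))):
        partial = a * int(digit_char) * 10 ** place
        running_sum += partial
        steps.append(f"{a} x {digit_char} x 10^{place} = {partial} (running sum: {running_sum})")
    return steps


if __name__ == "__main__":
    a, b = 12345, 67890
    # "Explicit CoT": every intermediate step is generated as text.
    for step in explicit_cot_steps(a, b):
        print(step)
    # "No CoT": only the final answer would be generated.
    print("answer:", a * b)
```

For a 5 x 5 problem this produces five extra lines of reasoning text before the answer, which is exactly the generation work an implicit approach tries to avoid.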
The Concept of Implicit Reasoning
Unlike explicit reasoning, which generates readable, step-by-step solutions, implicit reasoning processes information within the model's hidden layers. This process is not visible to, or directly interpretable by, humans. It borrows concepts from knowledge distillation, where a complex, well-trained "teacher" model imparts its understanding to a simpler "student" model. The student learns to mimic the teacher's output without reproducing the exact reasoning steps.
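For readers new to knowledge distillation, the sketch below shows its classic output-level form (Hinton et al., 2015): the student is trained to match the teacher's temperature-softened output distribution. This is a generic illustration only; the paper applies distillation to the teacher's hidden states rather than to output probabilities, and the temperature and example logits here are arbitrary assumptions.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with temperature scaling."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on softened distributions, scaled by T^2.

    The T^2 factor keeps gradient magnitudes comparable across temperatures.
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student))
    return kl * temperature ** 2


if __name__ == "__main__":
    teacher = [4.0, 1.0, 0.5]
    print(distillation_loss([4.0, 1.0, 0.5], teacher))  # identical logits: loss is 0
    print(distillation_loss([0.5, 1.0, 4.0], teacher))  # disagreeing student: loss > 0
```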
Training the Model for Implicit Reasoning
The authors propose the following three-step strategy:
- Mind-Reading the Teacher: Train a student model to "read" the teacher's "thought process", that is, the continuous hidden states the teacher produces while generating intermediate reasoning steps. Rather than replicating these steps, the student uses a subset of the teacher's hidden states to produce the answer directly.
- Thought Emulation: Knowledge distillation is then used to train an emulator that predicts the teacher's hidden states from the input "vertically", across layers, instead of "horizontally", across a sequence of explicitly generated reasoning tokens. This removes the need to generate those steps at all.
- Optimization: Combine the emulator, which predicts the teacher's thought process, with the mind-reading student, which produces the final answer from the emulated teacher's thought process. This combined system is then optimized end-to-end, allowing the student model to develop its own reasoning methods that might differ from the teacher's approach.
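The "vertical" idea in the second step can be sketched as follows. Picture the teacher's hidden states as a grid indexed by layer and by position within the explicit reasoning steps; instead of generating more positions horizontally, the emulator is trained to predict one teacher state per layer directly from the input. Everything here is a toy assumption: the grid is random data standing in for a real teacher forward pass, the evenly spaced (diagonal) choice of which position to read at each layer is one plausible scheme rather than a detail confirmed by this summary, and mean squared error is assumed as the emulation loss. The helper names are hypothetical.

```python
import random


def fake_teacher_states(num_layers, num_positions, dim, seed=0):
    """Hypothetical stand-in for a teacher forward pass over explicit CoT tokens.

    Returns grid[layer][position] -> hidden vector (list of floats).
    """
    rng = random.Random(seed)
    return [
        [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(num_positions)]
        for _ in range(num_layers)
    ]


def select_vertical_targets(grid):
    """Pick one hidden vector per layer, sweeping diagonally across CoT positions.

    Assumption: layer l reads a position evenly spaced between the first and
    last reasoning token, turning the horizontal sequence into a vertical stack.
    """
    num_layers, num_positions = len(grid), len(grid[0])
    targets = []
    for layer in range(num_layers):
        position = round(layer * (num_positions - 1) / max(num_layers - 1, 1))
        targets.append(grid[layer][position])
    return targets


def emulation_loss(predicted, targets):
    """Mean squared error between emulated and teacher states, averaged over layers."""
    per_layer = [
        sum((p - t) ** 2 for p, t in zip(pred, tgt)) / len(tgt)
        for pred, tgt in zip(predicted, targets)
    ]
    return sum(per_layer) / len(per_layer)


if __name__ == "__main__":
    grid = fake_teacher_states(num_layers=4, num_positions=10, dim=8)
    targets = select_vertical_targets(grid)
    zeros = [[0.0] * 8 for _ in targets]  # an untrained "emulator" that predicts zeros
    print(len(targets), "targets, one per layer")
    print("loss of perfect emulator:", emulation_loss(targets, targets))
    print("loss of zero predictor:", emulation_loss(zeros, targets))
```

In the full method, the student would then consume these emulated states to produce the answer, and the optimization step trains emulator and student jointly on the final answer.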
It is also worth noting that the researchers observed that hidden states in higher layers tend to have larger norms, so they normalized each hidden vector, a step they found to have a significant effect on the model's reasoning performance.
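A minimal sketch of per-vector normalization, assuming plain L2 normalization (the exact scheme is not specified here). One plausible motivation: if higher layers have much larger norms, an unnormalized loss over emulated states would be dominated by those layers, while rescaling every vector to unit length puts all layers on an equal footing. The example vectors are made up.

```python
import math


def l2_norm(vector):
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in vector))


def l2_normalize(vector, eps=1e-8):
    """Rescale a hidden vector to (approximately) unit L2 norm; eps avoids division by zero."""
    norm = l2_norm(vector)
    return [x / (norm + eps) for x in vector]


if __name__ == "__main__":
    # Made-up hidden vectors: the "higher layer" has a much larger norm.
    hidden_by_layer = {
        "lower layer": [0.5, -0.2, 0.1],
        "higher layer": [12.0, 8.0, -15.0],
    }
    for name, vec in hidden_by_layer.items():
        before = l2_norm(vec)
        after = l2_norm(l2_normalize(vec))
        print(f"{name}: norm {before:.2f} -> {after:.4f}")
```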
Results
The study showed that a GPT-2 Medium model using Implicit CoT achieved 96% accuracy on 5 x 5 multiplication (multiplying two five-digit numbers), a large leap from 2% accuracy in the No CoT setting, where the model must output the answer directly. Different versions of GPT-2 (Small, Medium, Large) were assessed, and the Implicit CoT approach showed consistent improvements in accuracy and efficiency across all sizes.
This approach can substantially speed up problem-solving on tasks like multi-digit multiplication, because it eliminates the time-consuming process of generating explicit reasoning steps token by token.
While promising in terms of efficiency, implicit reasoning may reduce the interpretability of the model's decision-making process, posing challenges in situations where understanding the model's reasoning is crucial.
However...
The study focuses primarily on multi-digit multiplication and grade-school math problems. It is worth asking how well this methodology generalizes to other types of tasks, especially those involving more abstract or nuanced reasoning. Moreover, the research relies on GPT-2 models, which are now somewhat outdated. Testing a broader range of models, including newer open-source models such as Llama 2 or Mistral 7B, would strengthen the findings.
Learning more
If you are interested in a hand-picked, brief list of recently presented papers, check out the Warsaw.AI Newsletter