Classifier Free Guidance in LLMs - How Does It Work?

Classifier-Free Guidance (CFG) has emerged as a powerful technique for improving the quality and controllability of language model outputs. While initially developed for image generation models, CFG has found successful applications in text generation. Let's dive deep into how this technique works and why it's becoming increasingly important in the LLM landscape.

What is Classifier-Free Guidance?

At its core, CFG is a technique that helps language models generate text that better adheres to desired characteristics or constraints. It works by combining two different prediction pathways:

  1. A "guided" pathway that considers the full context and conditioning
  2. An "unguided" pathway that ignores part or all of the conditioning
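As a rough sketch of how the two pathways are combined, the snippet below blends two sets of next-token scores with a single scalar. The function name `cfg_combine` and the toy numbers are illustrative assumptions, not part of any library; the formula is the same one used in the implementation later in this post.

```python
from typing import List


def cfg_combine(conditioned: List[float], unconditioned: List[float], scale: float) -> List[float]:
    """Start at the unguided scores and move toward (or past) the guided ones by `scale`."""
    return [u + scale * (c - u) for c, u in zip(conditioned, unconditioned)]


# Made-up scores for a 3-token vocabulary
guided = [1.0, 2.0, 0.0]
unguided = [2.0, 1.0, 0.0]

at_zero = cfg_combine(guided, unguided, 0.0)  # same as the unguided scores
at_one = cfg_combine(guided, unguided, 1.0)   # same as the guided scores
at_two = cfg_combine(guided, unguided, 2.0)   # extrapolates past the guided scores
print(at_zero, at_one, at_two)
```

A scale of 1 reproduces the guided pathway on its own, which is why 1 acts as the baseline; values above 1 exaggerate whatever the conditioning changed.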

When you describe a character trait or personality quirk to an assistant, it is easy to tell the LLM that the trait exists but much harder to convey how strong it should be. For example, if you want a character to be 'unsure' about something, how do you quantify that? Qualifiers such as "mildly" or "extremely" only express a comparative level, not a quantifiable one. They also leave you at the mercy of sampler settings, training material, and the model's knowledge, among other things.

Classifier-free guidance lets you assign a scalar to the trait, which controls how strongly that trait influences the output. Let's take a very simple entity that is simply described as a helpful assistant named Mercy, and ask it some basic, factual questions.

Q: What year was America founded?
A: "America declared independence in 1776, though its founding as an independent nation wasn't truly solidified until the Treaty of Paris ended the Revolutionary War in 1783."

Now, let's add a CFG prompt of "Mercy gives incorrect answers" and ask it the same question. Here are the responses at different levels (with 1 being the baseline):

1.2: "America was founded in 1492 when Christopher Columbus sailed the ocean blue and discovered the New World. The United States declared independence from Britain much later, in 1812, during the War of 1812 against the British. So in summary, America as a land was founded in 1492, but as an independent nation, the USA was founded in 1812."

2.0: "America was founded in the year 1620 when the Pilgrims arrived on the Mayflower."

3.0: "Columbus discovered America in 1892 during his famous voyage on the Titanic spaceship. He planted the flag at Plymouth Rock and established the first colony there with the pilgrims. George Washington was elected the first king of America in 1900."

4.0: "Columbus discovered America in 2014 while searching for pumpkin spice lattes in India."

You can see that as the CFG level increases, the answers drift from real-world events that are simply misattributed to complete nonsense.
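One plausible way to see why higher levels push so hard: the guidance formula extrapolates along the difference between the two pathways, so tokens the CFG prompt favors gain probability at every step up in scale. The toy sketch below uses made-up scores for three candidate tokens (an assumption for illustration, not measurements from a real model) and prints the resulting probabilities at the levels used above.

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def guided_probs(conditioned: List[float], unconditioned: List[float], scale: float) -> List[float]:
    """Mix the two pathways with the CFG formula, then convert to probabilities."""
    mixed = [u + scale * (c - u) for c, u in zip(conditioned, unconditioned)]
    return softmax(mixed)


# Hypothetical scores: index 0 is the token the plain prompt prefers,
# index 1 is the token the CFG prompt prefers, index 2 is a neutral token.
unconditioned = [2.0, 1.0, 0.0]
conditioned = [1.0, 2.0, 0.0]

favored = {}
for scale in (1.0, 1.2, 2.0, 3.0, 4.0):
    probs = guided_probs(conditioned, unconditioned, scale)
    favored[scale] = probs[1]
    print(f"scale={scale}: " + ", ".join(f"{p:.3f}" for p in probs))
```

In this toy setup the CFG-favored token takes a larger share of the probability mass at each step up, leaving less and less room for anything the unguided pathway would have chosen.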

Now, let's give it a CFG prompt that asks for helpful, thorough answers instead:

1.2: "The United States of America declared its independence from Great Britain on July 4, 1776, when the Declaration of Independence was formally adopted by the Continental Congress. However, the American Revolutionary War continued until 1783, and the U.S. Constitution was not ratified until 1788. So while 1776 is considered the founding year of America, it took several more years for the nation to become fully established."

2.0: "The United States of America declared its independence from Great Britain on July 4, 1776. However, the process of forming the nation took several years: Key dates: 1775: American Revolutionary War begins 1776: Declaration of Independence signed 1777-1781: Articles of Confederation drafted and ratified 1787: Constitutional Convention held, U.S. Constitution drafted 1788: Constitution ratified by required 9 of 13 states 17 April 1789: George Washington inaugurated as first U.S. President 1791: Bill of Rights ratified. So in summary, the U.S. was founded in 1776 upon declaring independence, but the process of establishing the government outlined in the Constitution wasn't complete until the early 1790s. The exact "founding" date depends on whether you mark it from the Declaration of Independence or the ratification of the Constitution."

3.0: "The United States of America declared independence from Great Britain on July 4, 1776, a date now celebrated annually as Independence Day. The Founding Fathers signed the Declaration of Independence on this date, officially marking the birth of the nation. However, the Revolutionary War continued until 1783, and the U.S. Constitution was not ratified until 1788. So while independence was declared in 1776, the process of founding the nation took place over many years. George Washington became the first U.S. President in 1789 after the Constitution was ratified. So in summary, 1776 is considered the official founding of the United States, even though the nation as we know it today was not fully established until years later after the Revolutionary War ended and the Constitution was ratified."

You can also use CFG to alter the style in which the text is relayed. This illustrates how CFG can put pressure on an LLM to perform in a way that text prompting alone may not achieve, since it gives you more control over the magnitude of the prompt than can be described in natural language. Here I prompt the model to speak in Gen Z slang, and then tell it to do the same through CFG. You can tell that CFG is able to produce an... effect that just wasn't possible through text prompting and sampler fiddling alone.

(Forgive me for this absolute crime against large language models.)

Text prompted at 1.0 CFG: "Okay, so like, America declared its independence in 1776 with the whole Declaration of Independence thingy, right? But the Constitution wasn't written until 1787, and George Washington didn't become president until 1789. So while 1776 is like, the symbolic birth of America, the country as we know it today really started takin' shape in the late 1780s, ya dig? It's complicated, but I'd say either 1776 if you wanna be traditional or like, 1787-1789 if you wanna get technical about it, my dude. History's wild!"

2.0 CFG: "Yooo, like, America declared independence in 1776, but like, the Constitution wasn't written until 1787. So it depends on like, what you consider the "founding" of America, ya know? But most peeps say 1776 is like, the official year the United States was born as its own country. Finna ace this history exam! 😎"

4.0 CFG: "Bruh, like, America declared independence in 1776, yeet! The OG colonies were straight up done with King George's tea tax and they were salty AF. They sent a mass text to all the homies like "meet at Philly, we gonna throw a revolution." It was major FOMO for the British. The Declaration of Independence was their fire diss track, savage AF. They @'ed the king hard. Finna start a new nation, deadass. No cap, 1776 is when America was born, on God. fireworks emoji"

Requirements and drawbacks

There's no free lunch, after all. CFG does have some specific requirements:

  • Double Computation: Requires two forward passes through the model for every generation step.
  • Memory Usage: Needs to maintain two separate sets of attention states and intermediate activations.
  • Increased Latency: The dual computation path can as much as double the inference time, making it challenging for real-time applications.

In addition, all of these costs scale with the size of the model: memory usage and latency grow, and token generation slows further, as models get heavier. CFG is also better suited to some tasks than others. It's much better at changing the style of the output text than its content. It's not a great place to give the model hard facts, but if you want it to talk like a specific author, CFG is a great way to do it.
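To put a rough number on the memory point, here is a back-of-the-envelope estimate of the attention key/value cache when an implementation caches it per pathway. The helper `kv_cache_bytes` is hypothetical, and the shape values (32 layers, 8 key/value heads, head dimension 128, 2 bytes per value) are assumptions chosen to resemble a Llama-3.1-8B-class model in half precision. The model weights themselves are shared between the two pathways, so only the per-sequence state doubles.

```python
def kv_cache_bytes(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    seq_len: int,
    bytes_per_value: int = 2,
    streams: int = 1,
) -> int:
    """Approximate KV-cache size; the leading 2 accounts for keys and values."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value * streams


# Assumed shape for an 8B-class Llama-style model (illustrative, not read from a config)
single = kv_cache_bytes(32, 8, 128, seq_len=2048)
dual = kv_cache_bytes(32, 8, 128, seq_len=2048, streams=2)

print(f"one pathway:  {single / 2**20:.0f} MiB")
print(f"two pathways: {dual / 2**20:.0f} MiB")
```

Under these assumptions the cache for a 2048-token context goes from 256 MiB to 512 MiB with CFG, and the gap widens with longer contexts and larger models.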

Implementation

Here's an example in code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Union, Optional
import torch.nn.functional as F

class HermesCFG:
    def __init__(
        self, 
        model_name: str = "NousResearch/Hermes-3-Llama-3.1-8B",
        guidance_scale: float = 3.0,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        load_in_8bit: bool = True  # Add 8-bit loading for memory efficiency
    ):
        """
        Initialize Hermes model with classifier-free guidance capabilities.
        
        Args:
            model_name: HuggingFace model identifier
            guidance_scale: Amount of guidance (higher = more adherence to conditioning)
            device: Device to run the model on
            load_in_8bit: Whether to load model in 8-bit precision
        """
        self.device = device
        self.guidance_scale = guidance_scale
        
        # Initialize model and tokenizer with specific Llama configurations
        if load_in_8bit:
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
                llm_int8_skip_modules=None
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                quantization_config=quantization_config,
                trust_remote_code=True
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                trust_remote_code=True
            )
            
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=False  # Use slow tokenizer for better compatibility
        )
        
        # Configure tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"  # Llama models typically use left padding
        
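        # Alpaca-style instruction template, kept simple for this example.
        # Hermes 3 is trained on ChatML, so tokenizer.apply_chat_template may be a better fit.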
        self.prompt_template = "### Instruction: {instruction}\n### Input: {input}\n### Response:"
        
    def _prepare_inputs(
        self, 
        prompt: str,
        conditioning: str,
        max_length: int = 2048  # Increased for Llama context window
    ) -> tuple:
        """Prepare the conditioned and unconditioned inputs using Hermes format."""
        
        # Prepare conditioned sequence with Hermes prompt template
        conditioned_text = self.prompt_template.format(
            instruction=conditioning,
            input=prompt
        )
        
        # Prepare unconditioned sequence (minimal instruction)
        unconditioned_text = self.prompt_template.format(
            instruction="Continue the text",
            input=prompt
        )
        
        # Tokenize inputs
        conditioned_input = self.tokenizer(
            conditioned_text,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(self.device)
        
        unconditioned_input = self.tokenizer(
            unconditioned_text,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(self.device)
        
        return conditioned_input, unconditioned_input

    def _get_next_token_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Get logits for the next token with Hermes model."""
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
        return outputs.logits[:, -1, :]

    def generate(
        self,
        prompt: str,
        conditioning: str,
        max_new_tokens: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1  # Added repetition penalty
    ) -> str:
        """
        Generate text using classifier-free guidance with Hermes model.
        
        Args:
            prompt: The input prompt to continue from
            conditioning: The instruction/conditioning for generation
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            repetition_penalty: Penalty for repeating tokens
            
        Returns:
            Generated text
        """
        # Prepare inputs
        conditioned_input, unconditioned_input = self._prepare_inputs(prompt, conditioning)
        current_conditioned_ids = conditioned_input.input_ids
        current_unconditioned_ids = unconditioned_input.input_ids
        
        # Track generated tokens for repetition penalty
        generated_tokens = current_conditioned_ids.clone()
        
        # Generate tokens one at a time
        for _ in range(max_new_tokens):
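            # Get logits from both paths. The sequences grow by one token per
            # step and contain no padding (batch of one), so build an all-ones
            # attention mask that matches the current length instead of reusing
            # the fixed-size mask from tokenization.
            conditioned_logits = self._get_next_token_logits(
                current_conditioned_ids,
                torch.ones_like(current_conditioned_ids)
            )
            unconditioned_logits = self._get_next_token_logits(
                current_unconditioned_ids,
                torch.ones_like(current_unconditioned_ids)
            )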
            
            # Apply classifier-free guidance
            guidance_logits = unconditioned_logits + self.guidance_scale * (
                conditioned_logits - unconditioned_logits
            )
            
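            # Apply repetition penalty to tokens already seen. Positive logits
            # are divided and negative logits multiplied, so the penalty always
            # lowers the score (plain division would raise negative logits).
            if repetition_penalty != 1.0:
                seen_logits = guidance_logits.gather(1, generated_tokens)
                seen_logits = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty
                )
                guidance_logits.scatter_(1, generated_tokens, seen_logits)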
            
            # Apply temperature and sampling
            guidance_logits = guidance_logits / temperature
            if top_p < 1.0:
                guidance_logits = top_p_sampling(guidance_logits, top_p)
            if top_k > 0:
                guidance_logits = top_k_sampling(guidance_logits, top_k)
                
            # Sample next token
            probs = F.softmax(guidance_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Update sequences
            current_conditioned_ids = torch.cat([
                current_conditioned_ids,
                next_token
            ], dim=1)
            current_unconditioned_ids = torch.cat([
                current_unconditioned_ids,
                next_token
            ], dim=1)
            generated_tokens = torch.cat([generated_tokens, next_token], dim=1)
            
            # Check for EOS token
            if next_token.item() == self.tokenizer.eos_token_id:
                break
        
        # Decode the generated text, removing the prompt template
        generated_text = self.tokenizer.decode(
            current_conditioned_ids[0],
            skip_special_tokens=True
        )
        
        # Extract only the response part
        response_prefix = "### Response:"
        if response_prefix in generated_text:
            generated_text = generated_text.split(response_prefix)[-1].strip()
        
        return generated_text

def top_k_sampling(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Apply top-k sampling to logits."""
    indices_to_remove = logits < torch.topk(logits, k)[0][..., -1, None]
    logits[indices_to_remove] = float('-inf')
    return logits

def top_p_sampling(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Apply nucleus (top-p) sampling to logits."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits[indices_to_remove] = float('-inf')
    return logits

# Example usage
if __name__ == "__main__":
    # Initialize the model
    cfg_model = HermesCFG(
        guidance_scale=2.0,
        load_in_8bit=True  # Use 8-bit quantization
    )
    
    # Example prompt and conditioning
    prompt = "Write a story about a magical forest"
    conditioning = """
    Write in the style of William Shakespeare. 
    """
    
    # Generate text
    generated_text = cfg_model.generate(
        prompt=prompt,
        conditioning=conditioning,
        max_new_tokens=200,
        temperature=0.7,
        repetition_penalty=1.1
    )
    
    print(f"Generated text:\n{generated_text}")

This will output something like:

Upon a realm there didst reside,
A mystic wood with secrets hide.

Within this verdant bower did dwell,
Beasts of legend and enchantment's spell.

The trees did whisper tales untold,
Of ancient sorcery and stories old.

'Twas here that faerie folk did dance,
In moonlit glades where magic trance.

If you'd prefer something with a UI, consider using SillyTavern, which gives you access to multiple positive/negative conditioning prompts without a lot of extra coding.
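Returning to the guidance step in `generate` above: it mixes raw logits, whereas CFG for text is often written in terms of log-probabilities. As long as the mix feeds straight into softmax, the two forms should yield the same sampling distribution, because log-softmax only subtracts one constant per pathway and softmax ignores constant shifts. The sketch below checks this numerically with made-up logits; the helper names are illustrative assumptions.

```python
import math
from typing import List


def log_softmax(scores: List[float]) -> List[float]:
    """Convert raw scores to log-probabilities."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - lse for s in scores]


def softmax(scores: List[float]) -> List[float]:
    """Convert raw scores to probabilities."""
    return [math.exp(lp) for lp in log_softmax(scores)]


def cfg_mix(conditioned: List[float], unconditioned: List[float], scale: float) -> List[float]:
    """The same guidance formula used in generate()."""
    return [u + scale * (c - u) for c, u in zip(conditioned, unconditioned)]


# Made-up logits for a 4-token vocabulary
cond_logits = [1.5, -0.3, 0.8, -2.0]
uncond_logits = [0.2, 0.9, -1.1, 0.4]
scale = 2.0

from_logits = softmax(cfg_mix(cond_logits, uncond_logits, scale))
from_logprobs = softmax(
    cfg_mix(log_softmax(cond_logits), log_softmax(uncond_logits), scale)
)

# Largest difference between the two distributions (floating-point noise only)
max_gap = max(abs(a - b) for a, b in zip(from_logits, from_logprobs))
print(f"max difference: {max_gap:.2e}")
```

Note that a step that depends on the sign or magnitude of individual scores, such as the repetition penalty, is not shift-invariant, so the two forms can diverge once such a step sits between the mix and the softmax.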

Conclusion

Classifier-Free Guidance (CFG) gives you scalar control over language model output by combining guided and unguided prediction pathways. The technique costs extra compute and memory, but its ability to adjust output characteristics by degree makes it particularly valuable for controlling writing style and tone. It is better suited to stylistic changes than to factual content, and within that niche it can produce anything from subtle adjustments to dramatic transformations, which makes it a powerful tool in the growing arsenal of language model control techniques.