Say we want to do post-training quantization of an LLM.
For PyTorch models, we’ll usually have an implementation defaulting to `bfloat16` and `torch.nn` layers, such as `torch.nn.Linear` and `torch.nn.Embedding`.
We’ll also have pretrained weights. For a HuggingFace model they’ll come in a bunch of `.safetensors` files, accompanied by model configs.
To get a quantized model, we can simply:
- load the pretrained model into memory (CPU or GPU), with the default, non-quantized dtype, usually `bfloat16`
- replace each layer with its quantized implementation. For an LLM, that’s going to be `nn.Linear` and `nn.Embedding`, as that’s where almost all the parameters of the model are. (The only other weights are in the LayerNorms: two per decoder block, plus one before the token prediction head.)
See the problem with this approach? If we simply load the pretrained model weights in `float32` or `bfloat16`, we’ll need the full amount of memory for the model - for the 1B-parameter model we’ll use below, that’s already ~2.5 GB in `bfloat16`, and it only grows with model size. If we’re doing quantization, it’s likely we’re already memory constrained, so this won’t cut it.
What we’ll do instead is load the model layer by layer, quantizing the weights on the fly. This way, we’ll avoid having all weight tensors in memory in full precision.
Now our plan is:
- Instantiate an empty model, without taking up the memory needed for full precision weights
- For each layer, load its weights, quantize them and add the quantized version to the model. Discard the full precision weights.
Let’s see how.
PyTorch `meta` device
From PyTorch docs:
The “meta” device is an abstract device which denotes a tensor which records only metadata, but no actual data. Meta tensors have two primary use cases:
- Models can be loaded on the meta device, allowing you to load a representation of the model without actually loading the actual parameters into memory. This can be helpful if you need to make transformations on the model before you load the actual data
[…]
This is exactly what we need - load the representation of the model, without taking up any space.
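To see what this buys us in isolation, here’s a minimal sketch with a toy `nn.Linear` (not the actual model):

import torch

with torch.device('meta'):
    layer = torch.nn.Linear(4096, 4096)

print(layer.weight.device)   # meta
print(layer.weight.shape)    # torch.Size([4096, 4096]) - shape and dtype are tracked...
print(layer.weight.numel() * layer.weight.element_size())  # 67108864 - ...but these bytes are never actually allocated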
Once we’ve done it, our step 2 is to iterate over submodules of the model and monkey-patch the ones we want to have quantized.
Doing it to an actual HF model
Let’s say we want to quantize the HuggingFace implementation of a Llama 3 model.
To load the model onto the `meta` device, we could simply do:
 1  import os
 2  import torch
 3  from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 4
 5  HF_MODEL_NAME = 'meta-llama/Llama-3.2-1B-Instruct'
 6
 7  with open(os.path.expanduser('~/.hf_token')) as token_file:
 8      token = token_file.read().strip()
 9
10  model_config = LlamaConfig.from_pretrained(HF_MODEL_NAME, token=token)
11
12  with torch.no_grad(), torch.device('meta'):
13      model = LlamaForCausalLM(config=model_config)
Note that no weight loading is happening yet. The key here is line 12, where we use the `torch.device('meta')` context manager. This causes all tensors under this context to be created on `meta`, unless explicitly specified otherwise.
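A quick sanity check (just a sketch - the count should come out to roughly 1.2B parameters for this checkpoint):

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e9:.2f}B parameters")      # metadata like numel() works fine on meta tensors
print({p.device for p in model.parameters()})   # {device(type='meta')}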
However, with the HuggingFace implementation there’s one big gotcha here, and it makes things a lot more complicated.
Here’s the thing. When we instantiate the model:
with torch.no_grad(), torch.device('meta'):
    model = LlamaForCausalLM(config=model_config)
all its parameters will be created on `meta`. But so will all the buffers.
Buffers are considered a part of the model’s state, and by default, when a model is saved, they’ll be saved too, alongside all the parameters. That is, unless a buffer is registered with `persistent=False`.
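A tiny illustration of the difference, with a toy module (not from the Llama code):

import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("kept", torch.ones(3))                       # persisted in state_dict
        self.register_buffer("dropped", torch.ones(3), persistent=False)  # part of the module, but never saved

print(Demo().state_dict().keys())  # odict_keys(['kept'])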
In the HuggingFace implementation of Llama (and Phi3, probably other LLMs too), the `inv_freq` constants for RoPE embeddings are created as non-persistent buffers (source).
...
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)  # <-- Here
self.original_inv_freq = self.inv_freq
...
This means that:
- The values of these buffers (there’s one per transformer block) are not persisted with the weights
- When we instantiate the model using `torch.device('meta')`, they’re initialized, but on `meta`, so the actual values are lost (see the check below)
- When monkey-patching quantized layers, we’ll move their parameters to CPU or GPU and then copy the values over. We have no way to do this for non-persistent registered buffers, as their values are meant to be set during the model’s `__init__`, and not persisted to the weights files
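We can see the damage directly (the exact attribute path depends on the transformers version - recent ones expose a shared `rotary_emb` module on the base model):

# After the naive meta-device instantiation above:
print(model.model.rotary_emb.inv_freq.device)  # meta - the precomputed frequencies are gone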
The HuggingFace solution
Digging deeper into the transformers codebase, we find that when loading quantized models, they’re using the `init_empty_weights` context manager.
It’s defined in HuggingFace’s `accelerate` project, and in turn refers to another context manager - `init_on_device`.
Inside `init_on_device`, there are two key parts:
- `register_empty_parameter`, a method used to patch `torch.nn.Module.register_parameter`
- `register_empty_buffer`, which is only applied optionally. That is, we can have the parameters pushed to `meta`, while the buffers go to the default device
Thus, to load a model with empty weights, HuggingFace will first monkey-patch the `torch.nn.Module.register_parameter` method with the following:
 1  ...
 2  @contextmanager
 3  def init_on_device(device: torch.device, include_buffers: bool = None):
 4      ...
 5
 6      old_register_parameter = nn.Module.register_parameter
 7
 8      def register_empty_parameter(module, name, param):
 9          old_register_parameter(module, name, param)
10          if param is not None:
11              param_cls = type(module._parameters[name])
12              kwargs = module._parameters[name].__dict__
13              kwargs["requires_grad"] = param.requires_grad
14              module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
15
16      ...
17
18      try:
19          nn.Module.register_parameter = register_empty_parameter
20          ...
21          yield
22      finally:
23          nn.Module.register_parameter = old_register_parameter
24          ...
What the new `register_parameter` does is:
- (line 9) let the original `nn.Module.register_parameter` do its thing
- (line 14) then create a new parameter, but send it to `device`, which in our case will be `meta`
Knowing this, we can easily create our own standalone context manager to fix our initial attempt.
from contextlib import contextmanager

import torch


@contextmanager
def init_params_on_meta():
    device = 'meta'
    old_register_parameter = torch.nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        # Register the parameter as usual, then re-create it on the meta device
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
With this, our code to create an empty model turns into:
with torch.no_grad(), init_params_on_meta():  # <-- Here we use the new context manager
    model = LlamaForCausalLM(config=model_config)
Now, during model initialization all the weights will go to `meta`, but the buffers (i.e. `LlamaRotaryEmbedding.inv_freq`) will remain on the torch default device, so we don’t lose their initialized values.
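To double-check (again, the exact buffer path depends on the transformers version):

print(next(model.parameters()).device)         # meta - parameters still take no memory
print(model.model.rotary_emb.inv_freq.device)  # cpu  - the RoPE frequencies keep their real values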
The next step is to load each layer’s weights, quantize them and replace the standard `torch.nn` modules with the quantized implementations. That we’ll do in another post.