StateDict
State-Dict
[struct] State-Dict
A structure representing the state dictionary of a neural network model.
- Entry[Hash-Table]: A hash table containing the parameters of the model, where:
  - Key[string]: A string representing the parameter name, following a naming convention that reflects the model's architecture.
  - Value[Tensor]: A tensor containing the parameter values associated with the key.
This hash table stores all the parameters of the model, enabling easy saving, loading, and manipulation of model weights.
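As a rough illustration of the layout described above, the following sketch wraps a string-keyed hash table in a struct. The names toy-state-dict and *toy-sd* are hypothetical stand-ins, not Caten's definitions, and shape lists are used in place of real tensors. Because the keys are strings, the table is created with an equal test.

```lisp
;; Hypothetical stand-in for State-Dict; not the library's actual definition.
(defstruct toy-state-dict
  ;; String keys need an EQUAL hash table (the default EQL test would not match them).
  (entry (make-hash-table :test #'equal) :type hash-table))

(defparameter *toy-sd* (make-toy-state-dict))

;; Shape lists stand in for tensors in this sketch.
(setf (gethash "wte.weight" (toy-state-dict-entry *toy-sd*)) '(32 32))
(setf (gethash "wpe.weight" (toy-state-dict-entry *toy-sd*)) '(1024 32))

;; Look a parameter up by its name.
(gethash "wte.weight" (toy-state-dict-entry *toy-sd*))
```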
[generic] ->state-dict
Generates a list of cons cells containing the keys and values for creating a State-Dict from any given Module/Class.
Return: A list of cons cells of the form (cons (list module_names ...) tensor), representing parameters and their corresponding tensors. module_names is a list of symbols representing the path to the parameter slot from the root module.
TL;DR: this function traverses all slots of the Module/Class and recognizes the following objects as parameters:
- Tensor
- Module/Class
- List of Tensors
- List of Modules/Classes
By default, the keys are created according to the following rules:
- Tensor
  - If a slot contains a Tensor, create a cons cell with (append parent (list slot_name)) as the key and the Tensor as the value.
- Module
  - If a slot contains a Module, recursively apply (->state-dict slot_value parent) to that Module, with parent set to (append parent (list slot_name)).
- List of Tensors/Modules
  - If a slot contains a list of Tensors or Modules, apply the above rules to each element.
  - Keys are created by appending the slot name and the index (e.g., slot_name 0, slot_name 1, ...) to the Parent.
To recognize other objects as parameters, please extend this method by adding methods for the desired classes.
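The default key rules can be sketched without the library. The code below is a minimal illustration only: toy-tensor, toy-module, and toy->state-dict are hypothetical names, and a module is modeled as an alist of slots rather than a real Module/Class. It is not how ->state-dict itself is implemented.

```lisp
;; Hypothetical stand-ins for Tensor and Module/Class; not Caten's types.
(defstruct toy-tensor shape)
(defstruct toy-module slots) ; SLOTS is an alist of (slot-name . value)

(defun toy->state-dict (value &optional parent)
  "Return a list of (names . tensor) cons cells for VALUE, where PARENT is the path so far."
  (cond
    ;; Tensor: the accumulated path is the key.
    ((toy-tensor-p value)
     (list (cons parent value)))
    ;; Module: recurse into each slot with the slot name appended to the path.
    ((toy-module-p value)
     (loop for (slot-name . slot-value) in (toy-module-slots value)
           append (toy->state-dict slot-value (append parent (list slot-name)))))
    ;; List of tensors/modules: recurse into each element with its index appended.
    ((listp value)
     (loop for element in value
           for index from 0
           append (toy->state-dict element (append parent (list index)))))
    ;; Anything else is not treated as a parameter.
    (t nil)))

(defparameter *toy-model*
  (make-toy-module
   :slots (list (cons 'wte (make-toy-module
                            :slots (list (cons 'weight (make-toy-tensor :shape '(32 32))))))
                (cons 'h (list (make-toy-module
                                :slots (list (cons 'bias (make-toy-tensor :shape '(32))))))))))

;; Collect only the keys (paths) of the resulting cons cells.
(mapcar #'car (toy->state-dict *toy-model*))
;; => ((WTE WEIGHT) (H 0 BIAS))
```

In this sketch the slot name is appended when descending into a module, so a tensor or list receives a path that already ends in its slot name; the resulting keys match the rules listed above.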
[function] get-state-dict
Constructs a State-Dict by recursively exploring all parameters of the given Module/Class.
- Module[Module/Class] The module from which to extract parameters.
- Key-Mapper[Function] A function that takes a list of names (the first element of the cons cell) and returns a string to be used as the key in the State-Dict. Defaults to pytorch-style-dict-key. It must be a function of the form #'(lambda (x) ...).
Returns: State-Dict
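To make the role of a key mapper concrete, here is a hypothetical mapper, toy-dict-key, that turns a list of names into a dot-separated string resembling the keys in the Transformer example. Lowercasing and replacing hyphens with underscores are assumptions made for this sketch; whether pytorch-style-dict-key behaves this way is not stated here.

```lisp
;; Hypothetical key mapper; not the library's pytorch-style-dict-key.
(defun toy-dict-key (names)
  "Join NAMES (symbols or integers) with dots, lowercased, with hyphens replaced by underscores."
  (substitute #\_ #\-
              ;; ~( ... ~) lowercases each printed name; ~^ stops before a trailing dot.
              (format nil "~{~(~a~)~^.~}" names)))

(toy-dict-key '(h 0 attn c-attn weight))
;; => "h.0.attn.c_attn.weight"
```

A function with this shape (one argument, a list of names, returning a string) is what the Key-Mapper parameter expects.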
[function] load-state-dict
Loads the parameters from the given state-dict into the module, returning the given module.
- silent[boolean] If set to t, suppresses warnings about unused keys, dtype mismatches, shape mismatches, and uninitialized tensors.
- key-mapper[function] A function used to map the keys in the state-dict to the keys in the module. Defaults to pytorch-style-dict-key (see: get-state-dict).
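The loading step can be pictured with a library-free sketch. Everything below is hypothetical: toy-param, toy-load, and the default mapper are illustration names, and the policy shown (skip a parameter on shape mismatch, warn about missing and unused keys) is an assumption rather than the documented behavior of load-state-dict.

```lisp
;; Hypothetical sketch; none of these names belong to Caten's API.
(defstruct toy-param shape data)

(defun toy-load (params table
                 &key silent
                      (key-mapper (lambda (names) (format nil "~{~(~a~)~^.~}" names))))
  "PARAMS is a list of (names . toy-param); TABLE maps string keys to toy-params.
Copy data into each parameter whose mapped key is found with a matching shape."
  (let ((used '()))
    (loop for (names . param) in params
          for key = (funcall key-mapper names)
          for source = (gethash key table)
          do (cond
               ;; No entry for this parameter: it stays uninitialized.
               ((null source)
                (unless silent (warn "~a was not found and stays uninitialized" key)))
               ;; Assumed policy: skip the copy when shapes differ.
               ((not (equal (toy-param-shape source) (toy-param-shape param)))
                (push key used)
                (unless silent (warn "shape mismatch for ~a" key)))
               (t
                (push key used)
                (setf (toy-param-data param) (copy-seq (toy-param-data source))))))
    ;; Report entries in TABLE that no parameter consumed.
    (unless silent
      (loop for key being the hash-keys of table
            unless (member key used :test #'string=)
              do (warn "unused key: ~a" key)))
    params))

(defparameter *toy-target* (make-toy-param :shape '(2) :data nil))
(defparameter *toy-table* (make-hash-table :test #'equal))
(setf (gethash "fc.weight" *toy-table*)
      (make-toy-param :shape '(2) :data (vector 1.0 2.0)))

;; :silent t suppresses the warnings in this sketch.
(toy-load (list (cons '(fc weight) *toy-target*)) *toy-table* :silent t)
```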
Example: Transformer
CATEN-USER> (progn (ql:quickload :caten/llm) (get-state-dict (caten/llm:Transformer 32 2 2 1e-5 32)))
Result
#<STATE-DICT
{
wte.weight -> (32 32)
wpe.weight -> (1024 32)
h.0.attn.c_attn.weight -> (96 32)
h.0.attn.c_attn.bias -> (96)
h.0.attn.c_proj.weight -> (32 32)
h.0.attn.c_proj.bias -> (32)
h.0.mlp.c_fc.weight -> (128 32)
h.0.mlp.c_fc.bias -> (128)
h.0.mlp.c_proj.weight -> (32 128)
h.0.mlp.c_proj.bias -> (32)
h.0.ln_1.affine -> (32)
h.0.ln_1.bias -> (32)
h.0.ln_2.affine -> (32)
h.0.ln_2.bias -> (32)
h.1.attn.c_attn.weight -> (96 32)
h.1.attn.c_attn.bias -> (96)
h.1.attn.c_proj.weight -> (32 32)
h.1.attn.c_proj.bias -> (32)
h.1.mlp.c_fc.weight -> (128 32)
h.1.mlp.c_fc.bias -> (128)
h.1.mlp.c_proj.weight -> (32 128)
h.1.mlp.c_proj.bias -> (32)
h.1.ln_1.affine -> (32)
h.1.ln_1.bias -> (32)
h.1.ln_2.affine -> (32)
h.1.ln_2.bias -> (32)
ln_f.affine -> (32)
ln_f.bias -> (32)
lm_head.weight -> (32 32)
} {1002287923}>