Caten Documentation
StateDict

State-Dict

[struct] State-Dict

A structure representing the state dictionary of a neural network model.

  • Entry[Hash-Table]: A hash table containing the parameters of the model, where:
      • Key[string]: A string representing the parameter name, following a naming convention that reflects the model's architecture.
      • Value[Tensor]: A tensor containing the parameter values associated with the key.

This hash table stores all the parameters of the model, enabling easy saving, loading, and manipulation of model weights.
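As a minimal sketch of working with that hash table directly, the snippet below prints every parameter name with its tensor's shape. It assumes the Entry slot is exposed through a reader named state-dict-entry and that the tensor shape is available via shape; the exact accessor names may differ in your Caten version.

lisp
;; Hedged sketch: iterate over a State-Dict's Entry hash table.
;; STATE-DICT-ENTRY and SHAPE are assumed accessor names.
(defun print-state-dict (state-dict)
  "Print each parameter key together with the shape of its tensor."
  (maphash (lambda (key tensor)
             (format t "~a -> ~a~%" key (shape tensor)))
           (state-dict-entry state-dict)))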

[generic] ->state-dict

(->state-dict module parents)

Generates a list of cons cells containing the keys and values for creating a State-Dict from any given Module/Class.

Returns: A list of cons cells of the form (cons (list module_name ...) tensor), representing parameters and their corresponding tensors. module_name ... is the list of symbols describing the path from the root module to the parameter slot.

In short, this function traverses all slots of the Module/Class and recognizes the following objects as parameters:

  • Tensor
  • Module/Class
  • List of Tensors
  • List of Modules/Classes

By default, the keys are created according to the following rules:

  • Tensor: If a slot contains a Tensor, create a cons cell whose key is (append parent (list slot_name)) and whose value is the Tensor.
  • Module: If a slot contains a Module, recursively apply (->state-dict slot_value parent) to it, using (append parent (list slot_name)) as the new parent.
  • List of Tensors/Modules: If a slot contains a list of Tensors or Modules, apply the above rules to each element. Keys are created by appending the slot name and the element's index (e.g., slot_name 0, slot_name 1, ...) to the parent.

To recognize other objects as parameters, please extend this method by adding methods for the desired classes.
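As a hypothetical sketch of such an extension, the method below teaches ->state-dict to treat a custom wrapper class as a parameter container. my-wrapper and its wrapped-tensor slot are invented for illustration; only the (module parents) signature and the cons-cell return shape come from the documentation above.

lisp
;; Hypothetical wrapper class holding a single tensor.
(defclass my-wrapper ()
  ((wrapped-tensor :initarg :tensor :accessor wrapped-tensor)))

;; Extend ->STATE-DICT so the wrapper's tensor is collected as a
;; parameter. The key is the path from the root module to this slot.
(defmethod ->state-dict ((module my-wrapper) parents)
  (list (cons (append parents (list 'wrapped-tensor))
              (wrapped-tensor module))))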

[function] get-state-dict

(get-state-dict module &key (key-mapper #'pytorch-style-dict-key))

Constructs a State-Dict by recursively exploring all parameters of the given Module/Class.

  • Module[Module/Class]: The module from which to extract parameters.
  • Key-Mapper[Function]: A function that takes a list of names (the car of each cons cell produced by ->state-dict) and returns a string to be used as the key in the State-Dict. Defaults to #'pytorch-style-dict-key.

Returns: State-Dict
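The key-mapper can be swapped out to change the naming convention. The sketch below joins the path elements with "/" instead of the PyTorch-style "."; it assumes each key is a list of symbols and integer indices, as described for ->state-dict above.

lisp
;; Hedged sketch of a custom key-mapper: lowercase each path element
;; and join them with "/" (e.g., (H 0 ATTN) -> "h/0/attn").
(defun slash-style-dict-key (names)
  (format nil "~{~(~a~)~^/~}" names))

;; Hypothetical usage:
;; (get-state-dict model :key-mapper #'slash-style-dict-key)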

[function] load-state-dict

(load-state-dict module state-dict &key (silent nil) (key-mapper #'pytorch-style-dict-key))

Loads the parameters from the given state-dict into the module, returning the given module.

  • silent[Boolean]: If set to t, suppresses warnings about unused keys, dtype mismatches, shape mismatches, and uninitialized tensors.
  • key-mapper[Function]: A function used to map the keys in the state-dict to the keys in the module. Defaults to #'pytorch-style-dict-key (see get-state-dict).
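A minimal usage sketch: copying the weights of one model into another identically shaped model. The Transformer constructor arguments mirror the example below; treating them as (dim n-heads n-layers eps vocab-size) is an assumption, not documented here.

lisp
;; Hedged sketch: round-trip weights between two identically
;; configured models via a State-Dict.
(let* ((src (caten/llm:Transformer 32 2 2 1e-5 32))
       (dst (caten/llm:Transformer 32 2 2 1e-5 32)))
  ;; Extract SRC's parameters and load them into DST.
  (load-state-dict dst (get-state-dict src) :silent t))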

Example: Transformer

CATEN-USER> (progn (ql:quickload :caten/llm) (get-state-dict (caten/llm:Transformer 32 2 2 1e-5 32)))

Result:
#<STATE-DICT 
  {
    wte.weight             -> (32 32)
    wpe.weight             -> (1024 32)
    h.0.attn.c_attn.weight -> (96 32)
    h.0.attn.c_attn.bias   -> (96)
    h.0.attn.c_proj.weight -> (32 32)
    h.0.attn.c_proj.bias   -> (32)
    h.0.mlp.c_fc.weight    -> (128 32)
    h.0.mlp.c_fc.bias      -> (128)
    h.0.mlp.c_proj.weight  -> (32 128)
    h.0.mlp.c_proj.bias    -> (32)
    h.0.ln_1.affine        -> (32)
    h.0.ln_1.bias          -> (32)
    h.0.ln_2.affine        -> (32)
    h.0.ln_2.bias          -> (32)
    h.1.attn.c_attn.weight -> (96 32)
    h.1.attn.c_attn.bias   -> (96)
    h.1.attn.c_proj.weight -> (32 32)
    h.1.attn.c_proj.bias   -> (32)
    h.1.mlp.c_fc.weight    -> (128 32)
    h.1.mlp.c_fc.bias      -> (128)
    h.1.mlp.c_proj.weight  -> (32 128)
    h.1.mlp.c_proj.bias    -> (32)
    h.1.ln_1.affine        -> (32)
    h.1.ln_1.bias          -> (32)
    h.1.ln_2.affine        -> (32)
    h.1.ln_2.bias          -> (32)
    ln_f.affine            -> (32)
    ln_f.bias              -> (32)
    lm_head.weight         -> (32 32)
  } {100B25A0B3}>