- Module
- [macro] defmodule
- Modules (built_in)
- [function] !sum
- [function] !mean
- [function] !max
- [function] !min
- [function] !matmul
- [function] !sinh
- [function] !cosh
- [function] !tanh
- [function] !cos
- [function] !tan
- [function] !log2
- [function] !exp2
- [function] !expt
- [function] !truncate
- [function] !ceiling
- [function] !floor
- [function] !triu
- [function] !tril
- [function] !argmax
- [function] !argmin
- [function] !clip
- [function] !erf
Module
[macro] defmodule
(defmodule (name ((&rest constructor-args) &rest attrs) &key (where nil) (direct-superclasses nil))
    (&rest slots)
    &key (documentation "") (impl nil) (forward nil) (backward nil))
Defines a module named name.
In Caten, a Module is a CLOS class that represents a set of Funcs and is itself defined as a subclass of Func. It is used to represent computational nodes that can be expressed as compositions of Funcs. Consequently, as a subclass of Func, a Module is manipulated through the following three methods:
[method] impl
(impl (op Module) &rest tensors)
- tensors[list] a list of the input tensors.
In the impl method, describe how to construct the computational graph of the Module using a combination of Func and Module. The computational graph must begin with the inputs. If there are multiple outputs, bind them with cl:values. If you need to record Tensors for the backward pass, this is the place to do so.
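For illustration only, here is a minimal sketch of a hypothetical module whose impl composes existing operations. The module name MySinH, the helpers !div, !sub, !exp, !neg, and !const, and the ShapeTracker string "A[~] -> A[~]" (read as "same shape out as in") are assumptions made for this example, not guaranteed parts of the API:
;; Hypothetical example: sinh(x) = (exp(x) - exp(-x)) / 2, built as a
;; composition of Funcs/Modules starting from the input tensor x.
(defmodule (MySinH (()) :where "A[~] -> A[~]")
    ()
    :documentation "Computes sinh(x) lazily."
    :impl ((op x)
           (!div (!sub (!exp x) (!exp (!neg x)))
                 (!const x 2.0))))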
[method] forward
(forward (op Module) &rest tensors)
- tensors[list] a list of the input tensors.
In the forward method, describe how to create the Tensor(s) produced by the operation. Keep the lazy-evaluation model in mind; do not perform the actual computation at this stage. The st macro from ShapeTracker is quite useful for creating the output Tensor, and this is also the place to check additional attributes or slots if necessary. If you specify a ShapeTracker string via :where, the defmodule macro will generate forward automatically; therefore you must provide either :where or :forward.
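As a sketch, the hypothetical MySinH from the previous example could spell out forward explicitly instead of using :where; the exact st invocation shown here is an assumption based on the ShapeTracker notation:
;; Same hypothetical module, but with an explicit forward: the output is
;; declared to have the same shape as the input, with no computation yet.
(defmodule (MySinH (()))
    ()
    :documentation "Computes sinh(x) lazily."
    :forward ((op x) (st "A[~] -> A[~]" (x)))
    :impl ((op x)
           (!div (!sub (!exp x) (!exp (!neg x)))
                 (!const x 2.0))))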
[method] backward (optional)
(backward (op Module) prev-grad) -> (values input_1.grad input_2.grad ...)
- prev-grad[Tensor] the gradient propagated from the module's outputs.
In the backward method, describe the gradient computation for the Module using a combination of Func and Module. The arguments are fixed as (op prev-grad), where op is the module instance and prev-grad is a Tensor. If the gradient calculation needs the value of a Tensor from the input stage, store it temporarily with the module-sv4bws accessor while impl is called. The compiler does not depend on module-sv4bws, so you are free to choose how to store the Tensor. Since save-for-backward is determined automatically in Caten, there is no need to worry about in-place operations. Note that backward is optional; if it is not provided, AD is applied based on the computational graph from impl.
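Continuing the hypothetical MySinH sketch, the save-for-backward pattern described above could look like the following excerpt; module-sv4bws is the accessor mentioned in this section, while !mul and !cosh are assumed helpers:
;; In impl, stash the input tensor for the backward pass ...
:impl ((op x)
       (setf (module-sv4bws op) (list x))
       (!div (!sub (!exp x) (!exp (!neg x))) (!const x 2.0)))
;; ... and in backward, reuse it: d(sinh(x))/dx = cosh(x).
:backward ((op prev-grad)
           (values (!mul prev-grad (!cosh (car (module-sv4bws op))))))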
[method] lower
The lower method is generated automatically by defmodule, so there is no need to write it when defining a module. However, it is necessary to understand how a module is lowered when defining simplifiers for the Module.
lower produces the following node:
(make-node :Graph (intern (symbol-name (symb 'graph/ name)) "KEYWORD") outputs inputs &rest attrs)
Nodes whose class is :Graph are completely eliminated during lowering by impl.
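For instance, lowering the hypothetical MySinH module above would therefore produce a node of class :Graph whose type keyword is :GRAPH/MYSINH, roughly:
(make-node :Graph :GRAPH/MYSINH outputs inputs)
Simplifiers targeting that module should match on this keyword.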
Syntax
forward, backward, and impl are each written in one of the following formats:
forward := ((op &rest args) &body body)
forward := (lambda (&rest args) &body body)
forward := fname
Effects
- It defines a class named name.
- It defines a function named name, which works as a constructor.
Arguments
- name[symbol] the name of the module
- constructor-args[list] arguments for the constructor
- attrs[list] define attrs for the lowered graph based on the constructor-args variables, using the format (:key1 value1 :key2 value2 ...)
- slots[list] slots for the defined class
- where[nil or string] a ShapeTracker string; if provided, forward is generated from it automatically
- documentation[string] documentation
Notes
- The methods are called in the order forward->impl->backward during compilation. impl is performed recursively, so modules must not be co-dependent within the impl method (e.g., do not define a module A that depends on B, which in turn depends on A, ...)
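As a final sketch, this is how the hypothetical MySinH defined earlier would be used; the constructor is the function generated by defmodule, while randn and proceed are assumed to be available from caten/apis:
;; Construct the module, build the lazy output with forward, then realize it.
(let* ((x (randn `(3 3)))
       (y (forward (MySinH) x)))
  (proceed y))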
Modules (built_in)
[function] !sum
Compute the sum of the tensor.
Result
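A minimal usage sketch, assuming caten/apis is loaded and that make-tensor accepts an :initial-element keyword:
;; Sum every element of a 3x3 tensor of ones; proceed realizes the lazy
;; computation, so the result should hold 9.0.
(proceed (!sum (make-tensor `(3 3) :initial-element 1.0)))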
[function] !mean
Compute the mean of the tensor.
Result
[function] !max
Compute the maximum of the tensor.
Result
[function] !min
Compute the minimum of the tensor.
Result
[function] !matmul
Performs matrix multiplication between the two tensors a and b.
Result
{Tensor{LISPBUFFER}[float32] :shape (32 128) :id STC166090
((16.115143 16.458525 16.092386 17.81228 13.661716 ~ 17.686089 18.038162 16.683573 16.274176 18.700905)
(16.326677 15.621148 16.356201 17.916739 14.871678 ~ 18.09569 17.361448 15.73287 16.648724 18.704985)
(16.722721 14.823622 14.077707 16.202015 13.752501 ~ 17.039087 16.04018 16.320084 16.388088 16.415977)
(16.117361 16.602467 14.900114 16.754038 12.972164 ~ 15.914988 16.225424 14.572695 16.187672 17.73983)
(15.025454 14.202459 13.762134 14.921074 12.034962 ~ 14.2497 14.7951 15.23291 14.765818 17.027578)
...
(14.671678 15.548787 14.961614 16.846735 14.662898 ~ 15.099126 15.20463 15.049035 16.246407 17.84826)
(15.084174 14.32615 15.439503 17.32295 13.282019 ~ 16.467468 15.561042 15.09145 16.568888 16.251673)
(15.426824 15.761291 14.47131 15.671989 13.309795 ~ 15.386046 15.088989 14.298187 14.939453 17.597435)
(16.099691 15.497888 13.230323 15.339464 12.640831 ~ 16.377731 16.77367 15.217857 16.147154 15.99234)
(14.86157 15.624828 13.565197 16.511955 13.524295 ~ 16.017046 14.623064 14.877545 16.962063 16.800804))
:op #<PROCEEDNODE {1002ED0633}>
:requires-grad NIL
:variables (STC161638)
:tracker #<TRACKER :order={row(0 1)} :shape=(32 128) :contiguous-p=T>}
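A usage sketch for the result above; the input shapes (32 64) and (64 128) are assumptions consistent with the (32 128) output shape, and rand is assumed to be the uniform random constructor in caten/apis:
;; Lazily build the matrix product of two random tensors and realize it.
(proceed (!matmul (rand `(32 64)) (rand `(64 128))))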
[function] !sinh
Result
{Tensor{LISPBUFFER}[float32] :shape (3 3) :id STC166854
((-0.6667148 -1.7986057 -0.16147432)
(0.99311143 1.5690467 -0.45024085)
(0.2741389 -0.84445757 3.6562643))
:op #<PROCEEDNODE {1003399B73}>
:requires-grad NIL
:variables (STC166097)
:tracker #<TRACKER :order={row(0 1)} :shape=(3 3) :contiguous-p=T>}
[function] !cosh
Result
{Tensor{LISPBUFFER}[float32] :shape (3 3) :id STC167612
((1.1091444 12.141601 1.8187159)
(1.0450624 1.0673907 2.024306)
(1.0057459 11.280086 1.0366853))
:op #<PROCEEDNODE {100372F103}>
:requires-grad NIL
:variables (STC166861)
:tracker #<TRACKER :order={row(0 1)} :shape=(3 3) :contiguous-p=T>}
[function] !tanh
Result
{Tensor{LISPBUFFER}[float32] :shape (3 3) :id STC170370
((0.48222518 0.83497787 0.15090048)
(-0.6021365 -0.8142663 0.0836575)
(0.90113175 0.5248045 0.26375103))
:op #<PROCEEDNODE {1003FA9C23}>
:requires-grad NIL
:variables (STC167619)
:tracker #<TRACKER :order={row(0 1)} :shape=(3 3) :contiguous-p=T>}
[function] !cos
Result
{Tensor{LISPBUFFER}[float32] :shape (3 3) :id STC171055
((0.8741474 0.6821368 0.96187115)
(0.9756673 -0.15136945 0.5390583)
(0.7348094 0.56210893 0.9987133))
:op #<PROCEEDNODE {1005563B83}>
:requires-grad NIL
:variables (STC170377)
:tracker #<TRACKER :order={row(0 1)} :shape=(3 3) :contiguous-p=T>}
[function] !tan
Result
{Tensor{LISPBUFFER}[float32] :shape (3 3) :id STC171779
((2.2656112 1.2551787 -52.0603)
(0.11679068 0.27259648 0.0093172565)
(0.26652285 -0.48955667 -2.6124444))
:op #<PROCEEDNODE {1005BD7AB3}>
:requires-grad NIL
:variables (STC171062)
:tracker #<TRACKER :order={row(0 1)} :shape=(3 3) :contiguous-p=T>}
[function] !log2
Result
{Tensor{LISPBUFFER}[float32] :shape (3 3) :id STC172535
((-3.321928 0.13750356 1.0703893)
(1.6322682 2.0356238 2.3504972)
(2.6088092 2.827819 3.017922))
:op #<PROCEEDNODE {100719F343}>
:requires-grad NIL
:variables (STC171787)
:tracker #<TRACKER :order={row(0 1)} :shape=(3 3) :contiguous-p=T>}
[function] !exp2
Result
[function] !expt
Computes base raised to the power of power.
Result
[function] !truncate
Result
[function] !ceiling
Result
[function] !floor
Result
[function] !triu
Returns the upper triangular part of the input tensor (a matrix or a batch of matrices, >= 2D).
Result
[function] !tril
Returns the lower triangular part of the input tensor (a matrix or a batch of matrices, >= 2D).
Result
[function] !argmax
Returns the indices of the maximum values along an axis.
Result
[function] !argmin
Returns the indices of the minimum values along an axis.
Result
[function] !clip
!clip limits the given input to an interval specified by min and max, each of which can be a number, a symbol, or a tensor.
The implementation follows the ONNX specification. https://github.com/onnx/onnx/blob/main/docs/Changelog.md#clip-13
Result
[function] !erf
Computes the error function of x.
Result
{Tensor{LISPBUFFER}[float32] :shape (3 3) :id STC207681
((-0.5891361 -0.9240581 -0.08728349)
(0.9355107 0.5068274 -0.44668216)
(-0.76412225 -0.96422094 0.37599617))
:op #<PROCEEDNODE {10031D7013}>
:requires-grad NIL
:variables (STC196342)
:tracker #<TRACKER :order={row(0 1)} :shape=(3 3) :contiguous-p=T>}