Module
[macro] defmodule
(defmodule (name ((&rest constructor-args) &rest attrs) &key (where nil) (direct-superclasses nil))
(&rest slots)
&key (documentation "") (impl nil) (forward nil) (backward nil))
Defines a module named name.
In Caten, a Module is a CLOS class that represents a set of Funcs and is itself defined as a subclass of Func. It represents computational nodes that can be expressed as compositions of Funcs. As a subclass of Func, a Module is manipulated through the following three methods:
[method] impl
(impl (op Module) &rest tensors)
- tensors[list] a list of the input tensors.
In the impl method, describe how to construct the computational graph of the Module using a combination of Func and Module.
The computational graph must begin with the inputs.
If there are multiple outputs, bind them with cl:values.
If you need to record Tensors for the backward process, do so here.
[method] forward
(forward (op Module) &rest tensors)
- tensors[list] a list of the input tensors.
In the forward method, describe how to create the Tensor that results from the computation.
Be mindful of lazy evaluation; do not perform the actual computation at this stage.
The st macro in ShapeTracker is quite useful for creating the Tensor after the operation. If necessary, you may also include checks for additional attributes or slots here.
If you specify a ShapeTracker in :where, the defmodule macro will automatically generate forward.
Therefore, you must provide either :where or :forward.
[method] backward (optional)
(backward (op Module) prev-grad) -> (values input_1.grad input_2.grad ...)
- prev-grad[Tensor]
In the backward method, describe the gradient computation for the Module using a combination of Func and Module.
The arguments are fixed as (op prev-grad), where op is the module instance and prev-grad is a tensor.
If the gradient calculation needs the value of an input Tensor, store it temporarily using the module-sv4bws accessor while the impl method is being called.
The compiler does not depend on module-sv4bws, so you are free to choose how to store the Tensor.
In Caten, save-for-backward is determined automatically, so there is no need to be concerned about in-place operations.
Note that backward is optional. If it is not provided, AD is applied to the computational graph from impl.
[method] lower
The lower method is written automatically by defmodule, so there is no need to consider it when describing the module.
However, you need to understand how a Module is lowered when defining simplifiers for it.
lower produces the following node:
(make-node :Graph (intern (symbol-name (symb 'graph/ name)) "KEYWORD") outputs inputs &rest attrs)
Nodes whose class is :Graph are completely eliminated by impl during lowering.
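To make the naming rule concrete, the sketch below reproduces only the keyword computation from the make-node form above, in plain Common Lisp. The symb function defined here is a hypothetical stand-in that concatenates its arguments into one symbol (Caten's own helper may differ), and graph-node-type is an illustrative name rather than part of Caten. The expected values assume the standard readtable and the default print case.

```lisp
;; Hypothetical stand-in for the `symb` helper used in the form above:
;; concatenate the printed names of PARTS and intern the result.
(defun symb (&rest parts)
  (intern (format nil "~{~a~}" parts)))

;; Illustrative helper (not a Caten function): the node type that `lower`
;; would assign to a module called NAME, following the form above.
(defun graph-node-type (name)
  (intern (symbol-name (symb 'graph/ name)) "KEYWORD"))

(graph-node-type 'sum)    ; => :GRAPH/SUM
(graph-node-type 'matmul) ; => :GRAPH/MATMUL
```

A simplifier for a module would therefore match on a keyword of this shape; the module names used here are placeholders.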
Syntax
forward, backward, and impl are each described in one of the following formats:
forward := ((op &rest args) &body body)
forward := (lambda (&rest args) &body body)
forward := fname
Effects
- Defines a class named name.
- Defines a function named name, which works as a constructor.
Arguments
- name[symbol] the name of module
- constructor-args[list] arguments for the constructor
- attrs[list] attrs for the lowered graph, defined from the constructor-args variables in the following format: (:key1 value1 :key2 value2 ...).
- slots[list] slots for the defined class.
- where[nil or string] ShapeTracker
- documentation[string] documentation
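As an illustration of how these arguments fit together, here is a minimal sketch of a module definition that follows the signature above. It is not taken from Caten's sources: SquareNode and !square are made-up names, and the sketch assumes that Caten is loaded, that !mul is its element-wise multiplication function, and that "A[~] -> A[~]" is valid ShapeTracker notation for "same shape in, same shape out".

```lisp
;; Sketch only: requires Caten. SquareNode and !square are hypothetical names;
;; !mul and the :where string are assumptions stated in the text above.
(defmodule (SquareNode (()) :where "A[~] -> A[~]") ; no constructor-args, no attrs
    ()                                              ; no extra slots
    :documentation "Element-wise square: x * x."
    ;; impl builds the graph from the input, using the ((op &rest args) &body body) format.
    :impl ((op x) (!mul x x)))

;; defmodule also defines a constructor function named SquareNode,
;; so a user-facing wrapper can create the module and call forward on it.
(defun !square (x)
  (forward (SquareNode) x))
```

Because :where is given, forward is generated by the macro, and because :backward is omitted, the gradient would come from AD over the graph built in impl.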
Notes
- The methods are called in the order forward->impl->backward during compilation.
- impl is performed recursively, so modules must not be co-dependent within the impl method (e.g., do not define a module A that depends on B that depends on A ...).
Modules (built_in)
[function] !sum
Compute the sum of the tensor.
[function] !mean
Compute the mean of the tensor.
[function] !max
Compute the maximum of the tensor.
[function] !min
Compute the minimum of the tensor.
[function] !matmul
Performs matrix multiplication between two tensors a and b.
Result
{Tensor[float32] :shape (32 128) :id STC368087
((14.639064 14.296303 13.489961 12.972249 15.508182 ~ 15.552954 13.380613 15.017203 14.39499 14.849478)
(18.753656 17.762669 15.07676 15.9911 18.203718 ~ 19.217888 15.209106 18.538982 15.9205675 18.074055)
(17.688019 16.558022 16.508041 14.953012 18.29895 ~ 17.889637 16.387316 18.802826 16.794077 17.202274)
(18.226372 16.557941 14.142925 13.2309475 15.661889 ~ 17.346859 14.384592 16.355894 16.286999 16.721907)
(12.376888 13.364449 13.05191 12.340791 14.570738 ~ 15.263534 12.598415 14.250049 14.215093 14.063347)
...
(17.307728 17.511219 15.907667 14.243661 17.97275 ~ 18.632648 16.29067 17.926577 17.499083 17.87277)
(15.822942 17.901558 16.937357 15.8575325 18.324823 ~ 19.711588 16.515047 18.40187 16.012062 17.714437)
(14.2812 16.14327 15.835524 13.185125 17.57536 ~ 16.1084 14.622827 16.493088 15.505656 15.500607)
(15.164535 15.677073 14.974585 14.0352545 15.975354 ~ 17.36294 15.175886 16.725435 15.114097 15.336441)
(16.680836 18.034092 16.628492 13.566727 18.184402 ~ 18.404257 15.847324 18.812273 16.612568 19.07695))
:op #<PROCEEDNODE {10074D69F3}>
:requires-grad NIL
:variables (STC368086)}
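For readers who want the arithmetic spelled out, the following plain Common Lisp sketch computes the same kind of product for two 2D arrays: an (m k) matrix times a (k n) matrix gives an (m n) matrix. It does not use Caten, and matmul-reference is an illustrative name; any batching or broadcasting behaviour of !matmul is outside this sketch.

```lisp
;; Reference 2D matrix product on Common Lisp arrays (not Caten code).
(defun matmul-reference (a b)
  (let* ((m (array-dimension a 0))
         (k (array-dimension a 1))
         (n (array-dimension b 1))
         (out (make-array (list m n) :initial-element 0)))
    ;; The inner dimensions must agree: (m k) x (k n) -> (m n).
    (assert (= k (array-dimension b 0)) () "Inner dimensions must match.")
    (dotimes (i m out)
      (dotimes (j n)
        (dotimes (p k)
          (incf (aref out i j) (* (aref a i p) (aref b p j))))))))

(matmul-reference #2A((1 2) (3 4)) #2A((5 6) (7 8)))
;; => #2A((19 22) (43 50))
```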
[function] !sinh
[function] !cosh
[function] !tanh
[function] !cos
[function] !tan
[function] !log2
[function] !exp2
[function] !truncate
[function] !ceiling
[function] !floor
[function] !triu
Returns the upper triangular part of the tensor (>= 2D) or batch of matrices input.
[function] !tril
Returns the lower triangular part of the tensor (>= 2D) or batch of matrices input.
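The triangular masks used by !triu and !tril can be illustrated on a single matrix in plain Common Lisp. This is a reference sketch of the usual definition (keep entries on or above, or on or below, the main diagonal and zero the rest); it does not call Caten, and mask-triangle, triu-reference, and tril-reference are illustrative names.

```lisp
;; Plain Common Lisp reference for the semantics only (not Caten code).
;; KEEP-P receives (row column) and decides which positions survive.
(defun mask-triangle (matrix keep-p)
  (let* ((rows (array-dimension matrix 0))
         (cols (array-dimension matrix 1))
         (out  (make-array (list rows cols) :initial-element 0)))
    (dotimes (i rows out)
      (dotimes (j cols)
        (when (funcall keep-p i j)
          (setf (aref out i j) (aref matrix i j)))))))

;; Upper triangle: keep entries where row <= column.
(defun triu-reference (matrix) (mask-triangle matrix #'<=))
;; Lower triangle: keep entries where row >= column.
(defun tril-reference (matrix) (mask-triangle matrix #'>=))

(defparameter *m* #2A((1 2 3) (4 5 6) (7 8 9)))
(triu-reference *m*) ; => #2A((1 2 3) (0 5 6) (0 0 9))
(tril-reference *m*) ; => #2A((1 0 0) (4 5 0) (7 8 9))
```

For a batch of matrices, the same mask would apply to the last two dimensions of each matrix in the batch.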
[function] !argmax
Returns the indices of the maximum values along an axis.
[function] !argmin
Returns the indices of the minimum values along an axis.
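What "indices of the maximum (or minimum) values" means along one axis can be sketched on a single vector in plain Common Lisp. This is a reference sketch, not Caten code; arg-best, argmax-reference, and argmin-reference are illustrative names, and returning the first index on ties is a choice made for this sketch rather than a statement about !argmax or !argmin.

```lisp
;; Plain Common Lisp reference for one axis (a single non-empty vector).
;; BETTER-P is #'> for argmax and #'< for argmin; ties keep the first index.
(defun arg-best (vector better-p)
  (let ((best 0))
    (loop for i from 1 below (length vector)
          when (funcall better-p (aref vector i) (aref vector best))
            do (setf best i))
    best))

(defun argmax-reference (vector) (arg-best vector #'>))
(defun argmin-reference (vector) (arg-best vector #'<))

(argmax-reference #(3 9 2 9)) ; => 1
(argmin-reference #(3 9 2 9)) ; => 2
```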