NeuZephyr
Simple DL Framework
Implements the im2col transformation for efficient convolution operations in neural networks.
Public Member Functions
Img2ColNode (Node *input, Tensor::size_type kernelHeight, Tensor::size_type kernelWidth, Tensor::size_type stride, Tensor::size_type padding)
  Constructor for the Img2ColNode class.
void forward () override
  Performs the forward propagation for the Img2ColNode.
void backward () override
  Performs the backward propagation for the Img2ColNode.
Public Member Functions inherited from nz::nodes::Node
virtual void print (std::ostream &os) const
  Prints the type, data, and gradient of the node.
void dataInject (Tensor::value_type *data, bool grad=false) const
  Injects data into a relevant tensor object, optionally setting its gradient requirement.
template<typename Iterator> void dataInject (Iterator begin, Iterator end, const bool grad=false) const
  Injects data from an iterator range into the output tensor of the Node, optionally setting its gradient requirement.
void dataInject (const std::initializer_list< Tensor::value_type > &data, bool grad=false) const
  Injects data from a std::initializer_list into the output tensor of the Node, optionally setting its gradient requirement.
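This node converts (N, C, H, W) input tensors into expanded column matrices of shape (N, 1, Hout*Wout, C*K_h*K_w) following the im2col algorithm, where K_h and K_w are the kernel height and width and Hout and Wout are the output height and width. This allows convolution to be computed efficiently as a matrix multiplication.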
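To make the layout concrete, here is a minimal CPU sketch of the transformation under stated assumptions: a dense row-major NCHW float buffer, zero padding, and a within-row ordering of channel, then kernel row, then kernel column. The function name im2colReference is hypothetical; it is not NeuZephyr's iImg2col kernel, whose exact element ordering is not documented here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the im2col layout (N, 1, Hout*Wout, C*kH*kW).
// Assumptions (not confirmed by NeuZephyr): dense row-major NCHW input,
// zero padding, and column index k = (c * kH + i) * kW + j within each row.
std::vector<float> im2colReference(const std::vector<float>& in,
                                   std::size_t N, std::size_t C,
                                   std::size_t H, std::size_t W,
                                   std::size_t kH, std::size_t kW,
                                   std::size_t stride, std::size_t pad) {
    assert(stride > 0 && H + 2 * pad >= kH && W + 2 * pad >= kW);
    assert(in.size() == N * C * H * W);
    const std::size_t Hout = (H + 2 * pad - kH) / stride + 1;
    const std::size_t Wout = (W + 2 * pad - kW) / stride + 1;
    const std::size_t K = C * kH * kW;  // values per patch (row length)
    // One row per output position; padded entries keep the initial 0.
    std::vector<float> out(N * Hout * Wout * K, 0.0f);
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t oh = 0; oh < Hout; ++oh)
            for (std::size_t ow = 0; ow < Wout; ++ow) {
                const std::size_t row = (n * Hout + oh) * Wout + ow;
                for (std::size_t c = 0; c < C; ++c)
                    for (std::size_t i = 0; i < kH; ++i)
                        for (std::size_t j = 0; j < kW; ++j) {
                            // Position in the unpadded input; out of range means padding.
                            const std::ptrdiff_t h = static_cast<std::ptrdiff_t>(oh * stride + i) -
                                                     static_cast<std::ptrdiff_t>(pad);
                            const std::ptrdiff_t w = static_cast<std::ptrdiff_t>(ow * stride + j) -
                                                     static_cast<std::ptrdiff_t>(pad);
                            if (h < 0 || w < 0 || h >= static_cast<std::ptrdiff_t>(H) ||
                                w >= static_cast<std::ptrdiff_t>(W))
                                continue;
                            const std::size_t k = (c * kH + i) * kW + j;
                            out[row * K + k] = in[((n * C + c) * H + static_cast<std::size_t>(h)) * W +
                                                  static_cast<std::size_t>(w)];
                        }
            }
    return out;
}
```

For a 1×1×3×3 input holding 1 to 9 with a 2×2 kernel, stride 1 and no padding, each of the four output rows holds one 2×2 patch; the first row is {1, 2, 4, 5}.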
nz::nodes::calc::Img2ColNode::Img2ColNode (Node *input, Tensor::size_type kernelHeight, Tensor::size_type kernelWidth, Tensor::size_type stride, Tensor::size_type padding)
Constructor for the Img2ColNode class.
Parameters
input: A pointer to the input Node. Memory location: host. The pointer refers to an existing Node object; the constructor stores only the pointer and does not copy the object.
kernelHeight: The height of the kernel. Memory location: host. Passed by value; the constructor stores a copy.
kernelWidth: The width of the kernel. Memory location: host. Passed by value; the constructor stores a copy.
stride: The stride of the convolution operation. Memory location: host. Passed by value; the constructor stores a copy.
padding: The padding of the convolution operation. Memory location: host. Passed by value; the constructor stores a copy.
This constructor initializes an Img2ColNode object. It stores the input node pointer together with the kernel height, kernel width, stride, and padding values. It then calculates the output height and width from the input tensor's shape, the kernel size, the stride, and the padding, creates the output Tensor with shape (N, 1, Hout*Wout, C*K_h*K_w) and the appropriate gradient requirement, and finally sets the node type to "Img2Col".
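The exact output-size formula is not spelled out above. A common choice, and an assumption here, is the standard strided sliding-window formula Hout = floor((H + 2*padding - kernelHeight) / stride) + 1, and likewise for Wout. The hypothetical helper convOutputSize sketches it.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: the conventional output size of a strided, padded
// sliding window. Whether Img2ColNode uses exactly this expression is an assumption.
std::size_t convOutputSize(std::size_t inputSize, std::size_t kernelSize,
                           std::size_t stride, std::size_t padding) {
    assert(stride > 0);
    assert(inputSize + 2 * padding >= kernelSize);  // kernel must fit the padded input
    // Unsigned integer division truncates, i.e. floors non-negative values.
    return (inputSize + 2 * padding - kernelSize) / stride + 1;
}
```

For example, a 32×32 input with a 3×3 kernel gives a 32×32 output with stride 1 and padding 1, and a 16×16 output with stride 2 and padding 1. With C = 3 channels, the stride-1 case would yield a column tensor of shape (N, 1, 1024, 27).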
Memory management strategy: The constructor creates the output Tensor with std::make_shared, so its lifetime is managed automatically by a std::shared_ptr. The input node pointer is only stored; no memory is allocated for it.
Exception handling mechanism: The constructor has no explicit exception handling. However, if std::make_shared fails to allocate memory for the output Tensor, a std::bad_alloc exception is thrown.
Exceptions
std::bad_alloc: If memory allocation for the output Tensor fails.
void nz::nodes::calc::Img2ColNode::backward () [override, virtual]
Performs the backward propagation for the Img2ColNode.
Parameters: None
This function performs the backward propagation of the Img2ColNode. It first checks whether the output of the input node requires gradient computation. If so, it calls iImg2colBackward, passing the gradient of the input node's output, the gradient of the current node's output, the output height and width, the number of input channels, the kernel height and width, the stride, the padding, the input height and width, and the batch size. iImg2colBackward computes the gradient with respect to the input.
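The gradient that iImg2colBackward computes is the reverse (col2im) mapping. Because the forward pass only copies input values into the column matrix, each input-gradient entry is the sum of the column-gradient entries that were copied from that position, and entries that came from padding contribute nothing. The sketch below assumes the same layout as a plain NCHW im2col (column index (c * kH + i) * kW + j) and a freshly zeroed input gradient; col2imReference is a hypothetical name, and whether NeuZephyr overwrites or accumulates into an existing gradient buffer is not stated in the source.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the im2col backward pass (col2im).
// Overlapping patches share input pixels, so gradients are accumulated (+=).
std::vector<float> col2imReference(const std::vector<float>& gradCols,
                                   std::size_t N, std::size_t C,
                                   std::size_t H, std::size_t W,
                                   std::size_t kH, std::size_t kW,
                                   std::size_t stride, std::size_t pad) {
    assert(stride > 0 && H + 2 * pad >= kH && W + 2 * pad >= kW);
    const std::size_t Hout = (H + 2 * pad - kH) / stride + 1;
    const std::size_t Wout = (W + 2 * pad - kW) / stride + 1;
    const std::size_t K = C * kH * kW;
    assert(gradCols.size() == N * Hout * Wout * K);
    std::vector<float> gradIn(N * C * H * W, 0.0f);  // assumed zero-initialized
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t oh = 0; oh < Hout; ++oh)
            for (std::size_t ow = 0; ow < Wout; ++ow) {
                const std::size_t row = (n * Hout + oh) * Wout + ow;
                for (std::size_t c = 0; c < C; ++c)
                    for (std::size_t i = 0; i < kH; ++i)
                        for (std::size_t j = 0; j < kW; ++j) {
                            const std::ptrdiff_t h = static_cast<std::ptrdiff_t>(oh * stride + i) -
                                                     static_cast<std::ptrdiff_t>(pad);
                            const std::ptrdiff_t w = static_cast<std::ptrdiff_t>(ow * stride + j) -
                                                     static_cast<std::ptrdiff_t>(pad);
                            if (h < 0 || w < 0 || h >= static_cast<std::ptrdiff_t>(H) ||
                                w >= static_cast<std::ptrdiff_t>(W))
                                continue;  // gradient flowing into padding is discarded
                            const std::size_t k = (c * kH + i) * kW + j;
                            gradIn[((n * C + c) * H + static_cast<std::size_t>(h)) * W +
                                   static_cast<std::size_t>(w)] += gradCols[row * K + k];
                        }
            }
    return gradIn;
}
```

With an all-ones column gradient for a 3×3 input and a 2×2 kernel at stride 1, the result counts how many patches cover each pixel: 1 at the corners, 2 on the edges and 4 in the center.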
Memory management strategy: The function does not allocate or free memory directly. It relies on the pre-allocated gradient buffers of the input and output tensors, into which iImg2colBackward is assumed to write the computed gradients.
Exception handling mechanism: The function has no explicit exception handling. However, iImg2colBackward may throw an exception if it encounters problems such as invalid pointers or incorrect input parameters.
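The cost of this function is dominated by the iImg2colBackward call. Its work scales with the number of elements in the column gradient, N*Hout*Wout*C*K_h*K_w, which is O(n) in the number of elements n of the input gradient tensor when the kernel size and stride are treated as constants.
Implements nz::nodes::Node.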
void nz::nodes::calc::Img2ColNode::forward () [override, virtual]
Performs the forward propagation for the Img2ColNode.
Parameters: None
This function performs the forward propagation of the Img2ColNode. It calls iImg2col, passing the output data pointer, the input data pointer, the output height and width, the number of input channels, the kernel height and width, the stride, the padding, the input height and width, and the batch size. iImg2col unfolds the input image into the column layout described above: for each sample, every K_h × K_w receptive field, taken across all C channels, becomes one row of length C*K_h*K_w in a (Hout*Wout) × (C*K_h*K_w) matrix, so that the convolution can then be computed as a matrix multiplication.
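The payoff of this layout is that, per sample, convolution reduces to one matrix product between the (Hout*Wout) × (C*K_h*K_w) column matrix and the flattened kernels. The sketch below shows only that step as a naive triple loop; convViaGemm, the weight layout (one row of C*K_h*K_w values per output channel), and the output ordering are illustrative assumptions, not NeuZephyr's convolution API. In practice the loop would be replaced by an optimized GEMM routine.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical GEMM step that follows im2col, for a single sample:
//   cols    is P x K (P = Hout*Wout patches, K = C*kH*kW values per patch),
//   weights is O x K (one flattened kernel per output channel).
// Result: out[o * P + p] = sum_k weights[o][k] * cols[p][k],
// i.e. an (O, Hout, Wout) feature map in row-major order.
std::vector<float> convViaGemm(const std::vector<float>& cols,
                               const std::vector<float>& weights,
                               std::size_t P, std::size_t K, std::size_t O) {
    assert(cols.size() == P * K && weights.size() == O * K);
    std::vector<float> out(O * P, 0.0f);
    for (std::size_t o = 0; o < O; ++o)
        for (std::size_t p = 0; p < P; ++p) {
            float acc = 0.0f;
            for (std::size_t k = 0; k < K; ++k)
                acc += weights[o * K + k] * cols[p * K + k];  // dot(kernel o, patch p)
            out[o * P + p] = acc;
        }
    return out;
}
```

Applied to the columns of a 3×3 input holding 1 to 9 with a 2×2 kernel ({1,2,4,5}, {2,3,5,6}, {4,5,7,8}, {5,6,8,9}), an all-ones 2×2 filter yields the patch sums 12, 16, 24 and 28.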
Memory management strategy: The function does not allocate or free memory directly. It relies on the memory already allocated for the input and output tensors, and iImg2col is assumed to write its results directly into the pre-allocated output tensor.
Exception handling mechanism: The function has no explicit exception handling. However, iImg2col may throw an exception if it encounters issues such as invalid pointers or incorrect input parameters.
The cost of this function is dominated by the iImg2col call, whose time complexity is generally O(n), where n is the number of elements in the output tensor.
Implements nz::nodes::Node.