NeuZephyr
Simple DL Framework
nz::nodes::calc::ReshapeNode Class Reference

Implements tensor shape transformation within a neural network computational graph.


Public Member Functions

 ReshapeNode (Node *input, const Tensor::shape_type &newShape)
 Constructs a ReshapeNode object to reshape the input tensor.
 
void forward () override
 Performs the forward pass operation of the ReshapeNode.
 
void backward () override
 Performs the backward propagation for the ReshapeNode.
 
- Public Member Functions inherited from nz::nodes::Node
virtual void print (std::ostream &os) const
 Prints the type, data, and gradient of the node.
 
void dataInject (Tensor::value_type *data, bool grad=false) const
 Injects data into a relevant tensor object, optionally setting its gradient requirement.
 
template<typename Iterator >
void dataInject (Iterator begin, Iterator end, const bool grad=false) const
 Injects data from an iterator range into the output tensor of the Node, optionally setting its gradient requirement.
 
void dataInject (const std::initializer_list< Tensor::value_type > &data, bool grad=false) const
 Injects data from a std::initializer_list into the output tensor of the Node, optionally setting its gradient requirement.
 

Detailed Description

Implements tensor shape transformation within a neural network computational graph.

The ReshapeNode class modifies the dimensional structure of input tensors while preserving their underlying data. This node enables flexible tensor shape adaptation between different network layers without altering the actual data values.

Core functionality and behavior:

  • Shape Transformation: Reorganizes tensor dimensions according to the specified newShape.
  • Data Preservation: Keeps the original data values and their linear order unchanged.
  • Copy Semantics: The output tensor owns a separate buffer; forward() fills it with a single device-to-device copy rather than aliasing the input's memory.
  • Gradient Propagation: Correctly routes gradients during the backward pass by preserving the original tensor dimensions in gradient computations.
  • Shape Validation: Checks shape compatibility when the node is constructed and throws std::invalid_argument on a mismatch.

Implementation specifics:

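  • Forward Pass: Copies the input data into the output tensor, which already carries the new shape, with one device-to-device memory copy.
  • Backward Pass: Copies the output gradient into the input tensor's gradient, which keeps the original input shape.
  • Memory Management: The output tensor is allocated in the constructor and owns its own buffer; it does not share memory with the input tensor.
  • CUDA Support: Both copies are issued through cuStrm::StreamManager<float> and assume that both tensors reside in GPU memory.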

Typical applications:

  • Bridging fully connected layers with convolutional layers requiring specific tensor shapes.
  • Adapting variable-length sequence inputs to fixed-dimension layer requirements.
  • Implementing dynamic computational graphs with shape-modifying operations.

Critical considerations:

  • Element Count Consistency: The input shape and newShape must describe the same total number of elements.
  • Memory Layout: Data is copied as one flat, contiguous buffer, so the new shape reinterprets the elements in their existing linear order.
  • Gradient Integrity: Requires maintaining a reference to the original input shape for correct backpropagation.
  • Device Consistency: The input and output tensors must reside on the same CUDA device, because both transfers are device-to-device copies. (newShape itself is only a host-side shape descriptor.)
Note
  • Reshaping does not reorder data: elements keep their linear (memory-layout) order.
  • This implementation always copies the data in forward(); the output tensor never aliases the input buffer.
  • Gradient computation requires preserving original input tensor dimensions throughout node lifetime.
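To make the ordering note concrete, here is a small sketch of row-major (C-order) flat indexing for a 4-D shape. The names Shape4, flatIndex, and unflatten are hypothetical helpers, not part of NeuZephyr, and the sketch assumes row-major storage. It shows that reshaping {2, 3, 4, 5} to {2, 1, 1, 60} only reinterprets offsets: element (0, 1, 2, 3) sits at flat offset 33, which the new shape reads as (0, 0, 0, 33).

```cpp
#include <array>
#include <cstddef>

// Hypothetical 4-D shape / coordinate type in (N, C, H, W) order.
using Shape4 = std::array<std::size_t, 4>;

// Row-major flat offset of coordinate idx within shape s.
std::size_t flatIndex(const Shape4& s, const Shape4& idx) {
    return ((idx[0] * s[1] + idx[1]) * s[2] + idx[2]) * s[3] + idx[3];
}

// Inverse mapping: recover the coordinate that a flat offset has within shape s.
Shape4 unflatten(const Shape4& s, std::size_t offset) {
    Shape4 idx{};
    // Walk from the innermost (fastest-varying) axis outwards.
    for (std::size_t d = 4; d-- > 0;) {
        idx[d] = offset % s[d];
        offset /= s[d];
    }
    return idx;
}
```

Because both shapes index the same flat sequence, a reshape never needs to move individual elements; copying the buffer as a whole is enough.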
See also
Node Base class for computational graph nodes
Tensor::reshape() Tensor-level shape modification method

Usage Demonstration:

```cpp
// Create an input node holding a 4-D tensor: batch 2, 3x4x5 features
InputNode input({2, 3, 4, 5}, true);
// Flatten the 3x4x5 features into the last axis: (2, 1, 1, 60) keeps all 2*3*4*5 = 120 elements
ReshapeNode reshape(&input, {2, 1, 1, 3 * 4 * 5});
reshape.forward();
// Verify the new shape
std::cout << "Reshaped tensor dimensions: "
          << reshape.output->shape() << std::endl;
// Backward pass demonstration (assumes the output gradient has been populated)
reshape.backward();
```
Author
Mgepahmge (https://github.com/Mgepahmge)
Date
2023/10/15

Definition at line 3344 of file Nodes.cuh.

Constructor & Destructor Documentation

◆ ReshapeNode()

nz::nodes::calc::ReshapeNode::ReshapeNode ( Node * input,
const Tensor::shape_type & newShape )

Constructs a ReshapeNode object to reshape the input tensor.

This constructor initializes a ReshapeNode with an input node and a new shape. It checks if the number of dimensions of the input tensor's shape matches the number of dimensions of the new shape. If they match, it adds the input node to the list of inputs, creates a new output tensor with the specified new shape and the same requiresGrad property as the input tensor, and sets the node type to "Reshape".

Parameters
input A pointer to the input node. Memory location: host-to-device (used to access the input tensor's information).
newShape A reference to the new shape, of type Tensor::shape_type. Memory location: host-to-device (used to define the new shape).
Returns
None

Memory Management Strategy:

  • The inputs vector stores a pointer to the input node. The memory management of the input node is assumed to be handled by the caller.
  • The output member variable is a std::shared_ptr to a new Tensor object. The memory of the Tensor will be automatically managed by the std::shared_ptr.

Exception Handling Mechanism:

  • Throws std::invalid_argument if the number of dimensions of the input tensor's shape does not match the number of dimensions of the new shape.

Relationship with Other Components:

  • Depends on the input node to access its output tensor.
  • Creates a new Tensor object for the output.
Exceptions
std::invalid_argument When the number of dimensions of the input tensor's shape and the new shape do not match.
Note
  • Ensure that the input node is valid and points to a non-null Node object.
  • Ensure that the new shape has the same number of dimensions as the input tensor's shape to avoid exceptions.
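```cpp
#include <iostream>
#include <stdexcept>
// Assumes the NeuZephyr node headers (Nodes.cuh) are included.
// Node is an abstract base class, so a concrete InputNode supplies the input.
InputNode inputNode({1, 2, 3, 4}, true);
Tensor::shape_type newShape = {1, 1, 4, 6}; // 24 elements, same as the input
try {
    ReshapeNode reshapeNode(&inputNode, newShape);
} catch (const std::invalid_argument& e) {
    std::cerr << e.what() << std::endl;
}
```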
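The constructor documentation describes the check as a comparison of dimension counts. Whatever the exact comparison in Nodes.cu, a reshape is only well-defined when both shapes hold the same number of elements, which is also what the backward pass's size-match requirement implies. The sketch below shows such an element-count check under that assumption; Shape4, elementCount, and checkReshape are hypothetical names, not the library's code.

```cpp
#include <array>
#include <cstddef>
#include <functional>
#include <numeric>
#include <stdexcept>

// Hypothetical 4-D shape type standing in for Tensor::shape_type.
using Shape4 = std::array<std::size_t, 4>;

// Total number of elements described by a shape (product of its extents).
std::size_t elementCount(const Shape4& s) {
    return std::accumulate(s.begin(), s.end(), std::size_t{1},
                           std::multiplies<std::size_t>());
}

// Throws std::invalid_argument when the two shapes cannot hold the same data.
void checkReshape(const Shape4& from, const Shape4& to) {
    if (elementCount(from) != elementCount(to)) {
        throw std::invalid_argument("reshape: element counts differ");
    }
}
```

For the usage demonstration above, {2, 3, 4, 5} and {2, 1, 1, 60} both describe 120 elements, so the check passes.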

Definition at line 550 of file Nodes.cu.


Member Function Documentation

◆ backward()

void nz::nodes::calc::ReshapeNode::backward ( )
override virtual

Performs the backward propagation for the ReshapeNode.

Parameters
None
Returns
None

This function performs backward propagation for the ReshapeNode. If the first input node's output tensor requires gradients, it copies the gradient of this node's output tensor into that input tensor's gradient. The copy is a device-to-device memory transfer issued through the CUDA stream manager.

Memory management strategy: The function does not allocate or free any memory; it only copies between existing buffers with CUDA's memcpy.

Exception handling mechanism: There is no explicit exception handling in this function. CUDA memory copies do not throw C++ exceptions; they report failures such as invalid pointers or device errors through CUDA error codes.

Note
  • Ensure that the CUDA device is properly initialized before calling this function.
  • The sizes of the output tensor and the first input tensor's output must match for the memory copy to be valid.
```cpp
// Assumes reshapeNode's input and output tensors are initialized, forward() has run,
// and the output gradient has been filled in by downstream nodes.
reshapeNode.backward();
```
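A host-side model of the documented backward step may help. It is a sketch under stated assumptions: HostTensor and reshapeBackward are hypothetical, std::vector<float> stands in for device memory, and std::copy stands in for the device-to-device cudaMemcpy. Because a reshape involves no arithmetic, the gradient passes through element for element.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical host-side stand-in for a Tensor: a flat buffer plus its gradient.
struct HostTensor {
    std::vector<float> data;
    std::vector<float> grad;
    bool requiresGrad = false;
};

// Models the documented backward(): if the input needs a gradient, copy the
// output gradient into it element for element. Like a memcpy, this overwrites
// the input gradient rather than accumulating into it.
void reshapeBackward(HostTensor& input, const HostTensor& output) {
    if (!input.requiresGrad) {
        return;  // nothing to propagate
    }
    input.grad.resize(output.grad.size());
    std::copy(output.grad.begin(), output.grad.end(), input.grad.begin());
}
```

The flat copy is valid only because input and output hold the same number of elements, which is the size-match requirement stated in the note above.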
Author
Mgepahmge (https://github.com/Mgepahmge)
Date
2024/07/15

Implements nz::nodes::Node.

Definition at line 565 of file Nodes.cu.


◆ forward()

void nz::nodes::calc::ReshapeNode::forward ( )
override virtual

Performs the forward pass operation of the ReshapeNode.

This function copies the data from the output tensor of the input node to the output tensor of the ReshapeNode using CUDA memory copy. It uses the singleton instance of cuStrm::StreamManager<float> to manage the CUDA stream for the memory copy operation.

Parameters
None
Returns
None

Memory Management Strategy:

  • The function relies on a CUDA device-to-device memory copy (cudaMemcpyDeviceToDevice). The source and destination buffers are owned by the Tensor objects (inputs[0]->output and output).
  • The copy itself does not allocate or free memory; both buffers must already exist, with the output buffer allocated by the constructor.

Exception Handling Mechanism:

  • If the CUDA memory copy operation fails, it may return a CUDA error code. However, this function does not handle CUDA errors explicitly. Callers should check the CUDA error state after calling this function if necessary.

Relationship with Other Components:

  • Depends on the cuStrm::StreamManager<float> singleton to manage the CUDA stream for the memory copy.
  • Relies on the inputs[0]->output tensor for the source data and the output tensor for the destination data.
Exceptions
None explicitly, but CUDA errors may occur during the memory copy operation.
Note
  • Ensure that the CUDA runtime environment is properly initialized before calling this function.
  • The CUDA memory copy operation assumes that the source and destination memory regions are valid and have sufficient space.
  • The time complexity of this function is mainly determined by the CUDA memory copy operation, which is typically proportional to the size of the data being copied (O(n), where n is the number of elements in the output tensor).
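```cpp
// Assumes the NeuZephyr node headers (Nodes.cuh) are included.
// Node is an abstract base class, so a concrete InputNode provides the source tensor.
InputNode inputNode({1, 2, 3, 4}, true);
Tensor::shape_type newShape = {1, 1, 4, 6}; // same 24 elements, new layout
ReshapeNode reshapeNode(&inputNode, newShape);
reshapeNode.forward(); // copies the input data into reshapeNode.output
```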
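The forward pass can be modeled on the host as one flat copy into an output that already carries the new shape, which is where the O(n) cost comes from. This is a sketch, not the library's implementation: HostTensor and reshapeForward are hypothetical, and std::memcpy stands in for cudaMemcpy with cudaMemcpyDeviceToDevice.

```cpp
#include <array>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical 4-D shape type standing in for Tensor::shape_type.
using Shape4 = std::array<std::size_t, 4>;

// Hypothetical host-side tensor: a shape plus a flat, contiguous buffer.
struct HostTensor {
    Shape4 shape;
    std::vector<float> data;
};

// Models the documented forward(): the output's shape was fixed at construction,
// so the pass is a single flat copy of n elements (O(n)).
// Assumes output.data already holds input.data.size() elements.
void reshapeForward(const HostTensor& input, HostTensor& output) {
    const std::size_t n = input.data.size();
    std::memcpy(output.data.data(), input.data.data(), n * sizeof(float));
}
```

Only the shape metadata differs between the two tensors; the copied bytes are identical.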

Implements nz::nodes::Node.

Definition at line 559 of file Nodes.cu.

