NeuZephyr
Simple DL Framework
nz::nodes::calc::SubNode Class Reference

Represents a subtraction operation node in a computational graph. More...

Inheritance diagram for nz::nodes::calc::SubNode:
Collaboration diagram for nz::nodes::calc::SubNode:

Public Member Functions

 SubNode (Node *input_left, Node *input_right)
 Constructor to initialize a SubNode for tensor subtraction.
 
void forward () override
 Forward pass for the SubNode to perform tensor subtraction.
 
void backward () override
 Backward pass for the SubNode to propagate gradients.
 
- Public Member Functions inherited from nz::nodes::Node
virtual void print (std::ostream &os) const
 Prints the type, data, and gradient of the node.
 
void dataInject (Tensor::value_type *data, bool grad=false) const
 Injects data into the node's output tensor, optionally setting its gradient requirement.
 
template<typename Iterator >
void dataInject (Iterator begin, Iterator end, const bool grad=false) const
 Injects data from an iterator range into the output tensor of the Node, optionally setting its gradient requirement.
 
void dataInject (const std::initializer_list< Tensor::value_type > &data, bool grad=false) const
 Injects data from a std::initializer_list into the output tensor of the Node, optionally setting its gradient requirement.
 

Detailed Description

Represents a subtraction operation node in a computational graph.

The SubNode class performs element-wise subtraction between two input tensors. Unlike scalar-based operations, this node handles tensor-to-tensor subtraction, ensuring compatibility of input shapes and propagating gradients for both input tensors during backpropagation.

Key features:

  • Forward Pass: Computes the element-wise subtraction of two input tensors and stores the result in the output tensor.
  • Backward Pass: Propagates gradients for both input tensors. For the left input tensor, the gradient is directly copied from the output tensor's gradient. For the right input tensor, the gradient is negated before being propagated.
  • Shape Validation: Ensures the shapes of the two input tensors are identical during construction. Mismatched shapes result in an exception.
  • Gradient Management: Tracks whether gradients are required for either of the input tensors, and propagates gradients accordingly.

This class is part of the nz::nodes namespace and facilitates tensor-to-tensor subtraction operations in computational graphs.

Note
  • The left and right input tensors must have the same shape; otherwise, an exception will be thrown.
  • Gradients are propagated efficiently, with negation applied to the right input tensor's gradient during backpropagation.
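
The sign of each propagated gradient follows directly from the derivative of subtraction. For an element-wise subtraction C = A - B with a scalar loss L, the chain rule gives (written here in LaTeX notation):

\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}, \qquad \frac{\partial L}{\partial B} = -\frac{\partial L}{\partial C}

so the left input receives the output gradient unchanged, while the right input receives its negation.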

Usage Example:

// Example: Using SubNode for tensor subtraction
InputNode input1({3, 3}, true); // Create the first input node with shape {3, 3}
InputNode input2({3, 3}, true); // Create the second input node with shape {3, 3}
input1.output->fill(5.0f); // Fill the first tensor with value 5.0
input2.output->fill(3.0f); // Fill the second tensor with value 3.0
SubNode sub_node(&input1, &input2); // Subtract input2 from input1
sub_node.forward(); // Perform the forward pass
sub_node.backward(); // Propagate gradients in the backward pass
std::cout << "Output: " << *sub_node.output << std::endl; // Print the result
See also
forward() for the tensor subtraction computation in the forward pass.
backward() for gradient propagation in the backward pass.
Exceptions
std::invalid_argument    If the shapes of the two input tensors are not identical.
Author
Mgepahmge (https://github.com/Mgepahmge)
Date
2024/12/05

Definition at line 1787 of file Nodes.cuh.

Constructor & Destructor Documentation

◆ SubNode()

nz::nodes::calc::SubNode::SubNode(Node *input_left, Node *input_right)

Constructor to initialize a SubNode for tensor subtraction.

The constructor initializes a SubNode, which performs element-wise subtraction between two input tensors. It validates the shapes of the input tensors to ensure they are compatible for subtraction, establishes connections to the input nodes, and prepares the output tensor for storing the results.

Parameters
input_left    A pointer to the first input node. Its output tensor is treated as the left operand in the subtraction.
input_right    A pointer to the second input node. Its output tensor is treated as the right operand in the subtraction.
  • The constructor validates that the shapes of the two input tensors are identical. If the shapes do not match, an exception is thrown to prevent invalid operations.
  • The input nodes are added to the inputs vector, establishing their connection in the computational graph.
  • The output tensor is initialized with the same shape as the input tensors, and its gradient tracking is determined based on the requirements of the input tensors.
  • The node's type is set to "Sub" to reflect its operation.
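
As a rough illustration of these steps, the following standalone sketch mirrors them with throwaway stand-in types; DemoTensor and DemoNode are invented for this example and are not the NeuZephyr classes:

#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Minimal stand-in types, invented for illustration only.
struct DemoTensor {
    std::vector<std::size_t> shape;
    bool requires_grad;
};
struct DemoNode {
    std::shared_ptr<DemoTensor> output;
    std::vector<DemoNode*> inputs;
    std::string type;
};

// Mirrors the documented construction steps: shape validation, wiring the
// inputs, creating the output tensor, and labelling the node type.
DemoNode makeSubNode(DemoNode* left, DemoNode* right) {
    if (left->output->shape != right->output->shape) {
        throw std::invalid_argument("SubNode: input tensor shapes must be identical");
    }
    DemoNode sub;
    sub.inputs = {left, right}; // connect both inputs in the graph
    sub.output = std::make_shared<DemoTensor>(DemoTensor{
        left->output->shape,    // output keeps the shared input shape
        left->output->requires_grad || right->output->requires_grad // inherit gradient tracking
    });
    sub.type = "Sub";
    return sub;
}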
Note
  • The input tensors must have the same shape; mismatched shapes result in an exception.
  • This node supports automatic gradient tracking if either input tensor requires gradients.
Exceptions
std::invalid_argument    If the shapes of the two input tensors are not identical.
See also
forward() for the forward pass implementation.
backward() for gradient propagation in the backward pass.
Author
Mgepahmge (https://github.com/Mgepahmge)
Date
2024/12/05

Definition at line 283 of file Nodes.cu.

Member Function Documentation

◆ backward()

void nz::nodes::calc::SubNode::backward ( )
override virtual

Backward pass for the SubNode to propagate gradients.

The backward() method computes and propagates the gradients of the loss with respect to both input tensors. For the left input tensor, the gradient is directly copied from the output tensor's gradient. For the right input tensor, the gradient is negated before propagation.

  • If the left input tensor requires gradients, the gradient from the output tensor is directly copied to the gradient tensor of the left input.
  • If the right input tensor requires gradients, a temporary buffer is allocated, and the gradient of the output tensor is negated using a CUDA kernel (Negation) before being copied to the gradient tensor of the right input.
  • Temporary GPU memory (n_grad) is managed within the method and is released after use.
Note
  • Gradients are computed only for inputs that require them; an input whose gradient tracking is disabled is skipped.
  • The negation step for the right input tensor is necessary due to the derivative of subtraction with respect to the right operand being -1.
Warning
  • Proper memory management is ensured by freeing the temporary buffer (n_grad) after use.
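
As a hedged illustration of this flow, the standalone sketch below uses raw device pointers and a simple negation kernel; NegationSketch and the function signature are assumptions made for this example, not the actual NeuZephyr Negation kernel or backward() code:

#include <cuda_runtime.h>
#include <cstddef>

// Illustrative element-wise negation kernel (not the NeuZephyr Negation kernel).
__global__ void NegationSketch(float* dst, const float* src, std::size_t n) {
    const std::size_t i = blockIdx.x * static_cast<std::size_t>(blockDim.x) + threadIdx.x;
    if (i < n) dst[i] = -src[i];
}

void subBackwardSketch(float* left_grad, float* right_grad, const float* out_grad,
                       std::size_t n, bool left_needs_grad, bool right_needs_grad) {
    if (left_needs_grad) {
        // d(out)/d(left) = +1: the output gradient is copied through unchanged.
        cudaMemcpy(left_grad, out_grad, n * sizeof(float), cudaMemcpyDeviceToDevice);
    }
    if (right_needs_grad) {
        // d(out)/d(right) = -1: negate into a temporary buffer, copy, then free it.
        float* n_grad = nullptr;
        cudaMalloc(&n_grad, n * sizeof(float));
        const unsigned block = 256;
        const unsigned grid = static_cast<unsigned>((n + block - 1) / block);
        NegationSketch<<<grid, block>>>(n_grad, out_grad, n);
        cudaMemcpy(right_grad, n_grad, n * sizeof(float), cudaMemcpyDeviceToDevice);
        cudaFree(n_grad); // temporary buffer released after use
    }
}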
See also
forward() for the tensor subtraction computation in the forward pass.
Author
Mgepahmge (https://github.com/Mgepahmge)
Date
2024/12/05

Implements nz::nodes::Node.

Definition at line 302 of file Nodes.cu.


◆ forward()

void nz::nodes::calc::SubNode::forward ( )
override virtual

Forward pass for the SubNode to perform tensor subtraction.

The forward() method computes the element-wise subtraction between the two input tensors. It uses a CUDA kernel to perform the operation efficiently on the GPU, storing the result in the output tensor.

  • A CUDA kernel (MatrixSub) is launched to subtract the elements of the second input tensor (inputs[1]) from the corresponding elements of the first input tensor (inputs[0]).
  • The grid and block dimensions are calculated dynamically based on the size of the output tensor to ensure optimal GPU parallelism.
  • The result of the subtraction is stored in the output tensor.
Note
  • The subtraction operation is performed as output[i] = input_left[i] - input_right[i] for each element.
  • Ensure the shapes of the input tensors are identical; this is enforced during node construction.
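
For illustration, a minimal standalone version of this computation might look like the sketch below; MatrixSubSketch is an invented stand-in for the MatrixSub kernel, and the launch configuration simply mirrors the dynamic grid/block sizing described above:

#include <cuda_runtime.h>
#include <cstddef>

// Illustrative element-wise subtraction kernel (not the NeuZephyr MatrixSub kernel).
__global__ void MatrixSubSketch(float* out, const float* lhs, const float* rhs, std::size_t n) {
    const std::size_t i = blockIdx.x * static_cast<std::size_t>(blockDim.x) + threadIdx.x;
    if (i < n) out[i] = lhs[i] - rhs[i]; // output[i] = input_left[i] - input_right[i]
}

void subForwardSketch(float* out, const float* lhs, const float* rhs, std::size_t n) {
    const unsigned block = 256;                                            // threads per block
    const unsigned grid = static_cast<unsigned>((n + block - 1) / block);  // derived from output size
    MatrixSubSketch<<<grid, block>>>(out, lhs, rhs, n);
    cudaDeviceSynchronize(); // wait for the kernel to finish in this sketch
}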
See also
backward() for gradient propagation in the backward pass.
Author
Mgepahmge (https://github.com/Mgepahmge)
Date
2024/12/05

Implements nz::nodes::Node.

Definition at line 298 of file Nodes.cu.


The documentation for this class was generated from the following files:
Nodes.cuh
Nodes.cu