NeuZephyr
Simple DL Framework
Represents a subtraction operation node in a computational graph.
Public Member Functions
| | SubNode (Node *input_left, Node *input_right) | Constructor to initialize a SubNode for tensor subtraction. |
| void | forward () override | Forward pass for the SubNode to perform tensor subtraction. |
| void | backward () override | Backward pass for the SubNode to propagate gradients. |
Public Member Functions inherited from nz::nodes::Node
| virtual void | print (std::ostream &os) const | Prints the type, data, and gradient of the node. |
| void | dataInject (Tensor::value_type *data, bool grad=false) const | Injects data into a relevant tensor object, optionally setting its gradient requirement. |
| template<typename Iterator> void | dataInject (Iterator begin, Iterator end, const bool grad=false) const | Injects data from an iterator range into the output tensor of the InputNode, optionally setting its gradient requirement. |
| void | dataInject (const std::initializer_list<Tensor::value_type> &data, bool grad=false) const | Injects data from a std::initializer_list into the output tensor of the Node, optionally setting its gradient requirement. |
Represents a subtraction operation node in a computational graph.
The SubNode class performs element-wise subtraction between two input tensors. Unlike scalar-based operations, this node handles tensor-to-tensor subtraction: it requires both input tensors to have identical shapes and propagates gradients to both of them during backpropagation.
Key features:
- Forward pass: computes the element-wise difference of the two input tensors and stores the result in the output tensor.
- Backward pass: for the left input tensor, the gradient is copied directly from the output tensor's gradient. For the right input tensor, the gradient is negated before being propagated.
This class is part of the nz::nodes namespace and facilitates tensor-to-tensor subtraction operations in computational graphs.
Exceptions
| std::invalid_argument | If the shapes of the two input tensors are not identical. |
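The gradient rules used in the backward pass follow directly from the chain rule. Writing the output element-wise as y_i = a_i - b_i, with a the left operand, b the right operand and L the loss (notation introduced here for illustration only), the local derivatives are +1 and -1:

```latex
y_i = a_i - b_i
\qquad\Longrightarrow\qquad
\frac{\partial L}{\partial a_i} = \frac{\partial L}{\partial y_i}\,\frac{\partial y_i}{\partial a_i} = \frac{\partial L}{\partial y_i},
\qquad
\frac{\partial L}{\partial b_i} = \frac{\partial L}{\partial y_i}\,\frac{\partial y_i}{\partial b_i} = -\frac{\partial L}{\partial y_i}
```

This is why the left input receives the output gradient unchanged while the right input receives its negation.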
SubNode (Node *input_left, Node *input_right)
Constructor to initialize a SubNode for tensor subtraction.
The constructor initializes a SubNode, which performs element-wise subtraction between two input tensors. It checks that the two input tensors have identical shapes, establishes connections to the input nodes, and prepares the output tensor for storing the results.
Parameters
| input_left | A pointer to the first input node. Its output tensor is treated as the left operand in the subtraction. |
| input_right | A pointer to the second input node. Its output tensor is treated as the right operand in the subtraction. |
Note:
- Both input nodes are added to the inputs vector, establishing their connection in the computational graph.
- The output tensor is initialized with the same shape as the input tensors, and its gradient tracking is determined by the gradient requirements of the input tensors.
- The node type is set to "Sub" to reflect its operation.
Exceptions
| std::invalid_argument | If the shapes of the two input tensors are not identical. |
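The constructor's contract (validate shapes, link inputs, allocate the output, tag the node type) can be sketched on the host with toy types. ToyTensor, ToyNode and ToySubNode below are hypothetical stand-ins, not NeuZephyr classes. The rule that the output requires gradients when either input does is an assumption: the documentation only says that gradient tracking depends on the inputs' requirements.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor: a shape plus flat host data.
struct ToyTensor {
    std::vector<std::size_t> shape;
    std::vector<float> data;
    bool requiresGrad = false;
};

// Hypothetical stand-in for a graph node that owns an output tensor.
struct ToyNode {
    std::vector<ToyNode*> inputs;
    ToyTensor output;
    std::string type;
};

// Sketch of the documented constructor steps (not the library's code).
struct ToySubNode : ToyNode {
    ToySubNode(ToyNode* left, ToyNode* right) {
        // Shapes must be identical, otherwise std::invalid_argument is thrown.
        if (left->output.shape != right->output.shape) {
            throw std::invalid_argument("SubNode: input shapes must be identical");
        }
        inputs.push_back(left);   // left operand
        inputs.push_back(right);  // right operand
        // The output takes the common input shape.
        output.shape = left->output.shape;
        output.data.assign(left->output.data.size(), 0.0f);
        // ASSUMPTION: gradients are tracked if either input tracks them.
        output.requiresGrad = left->output.requiresGrad || right->output.requiresGrad;
        type = "Sub";
        // Invariant: equal shapes imply equally sized buffers.
        assert(output.data.size() == right->output.data.size());
    }
};
```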
void backward ()   [override, virtual]
Backward pass for the SubNode to propagate gradients.
The backward() method computes and propagates the gradients of the loss with respect to both input tensors. For the left input tensor, the gradient is copied directly from the output tensor's gradient. For the right input tensor, the gradient is negated before propagation.
- The gradient of the output tensor is copied directly to the gradient tensor of the left input.
- The gradient of the output tensor is negated using a CUDA kernel (Negation) before being copied to the gradient tensor of the right input.
- The temporary buffer holding the negated gradient (n_grad) is managed within the method and released after use.
Implements nz::nodes::Node.
Definition at line 302 of file Nodes.cu.
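A CPU-only sketch of the backward rule described above, with host vectors standing in for the gradient tensors and a plain loop standing in for the Negation kernel. The function name subBackward is hypothetical. Following the description, gradients are copied into the input gradients rather than accumulated.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side sketch of SubNode::backward() semantics.
// outGrad is dL/d(output); leftGrad and rightGrad are the inputs' gradient buffers.
void subBackward(const std::vector<float>& outGrad,
                 std::vector<float>& leftGrad,
                 std::vector<float>& rightGrad) {
    assert(leftGrad.size() == outGrad.size() && rightGrad.size() == outGrad.size());

    // Left operand: d(a - b)/da = 1, so the gradient is copied unchanged.
    leftGrad = outGrad;

    // Right operand: d(a - b)/db = -1. nGrad plays the role of the
    // temporary n_grad buffer filled by the Negation kernel.
    std::vector<float> nGrad(outGrad.size());
    for (std::size_t i = 0; i < outGrad.size(); ++i) {
        nGrad[i] = -outGrad[i];
    }
    rightGrad = nGrad;
    // nGrad is freed when it goes out of scope, analogous to releasing n_grad.
}
```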
void forward ()   [override, virtual]
Forward pass for the SubNode to perform tensor subtraction.
The forward() method computes the element-wise subtraction between the two input tensors. It uses a CUDA kernel to perform the operation efficiently on the GPU, storing the result in the output tensor.
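- A CUDA kernel (MatrixSub) is launched to subtract the elements of the second input tensor (inputs[1]) from the corresponding elements of the first input tensor (inputs[0]).
- The launch configuration (grid and block dimensions) is computed from the size of the output tensor to parallelize the work across GPU threads.
- The result is stored in the output tensor.
Note: the operation computes output[i] = input_left[i] - input_right[i] for each element.
Implements nz::nodes::Node.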
Definition at line 298 of file Nodes.cu.
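The forward kernel's indexing can be emulated on the CPU to show how a one-thread-per-element launch covers the output tensor. The block size of 256 and the names gridSizeFor and subForward are assumptions for illustration; the actual MatrixSub launch parameters are not given in this documentation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// ASSUMED block size; NeuZephyr's real value is not documented here.
constexpr std::size_t kBlockSize = 256;

// Number of blocks needed so that at least one thread exists per element.
std::size_t gridSizeFor(std::size_t n) {
    return (n + kBlockSize - 1) / kBlockSize;
}

// CPU emulation of a MatrixSub-style kernel: each (block, thread) pair maps
// to one flat index, and threads past the end of the tensor do nothing.
void subForward(const std::vector<float>& left,
                const std::vector<float>& right,
                std::vector<float>& out) {
    assert(left.size() == right.size() && out.size() == left.size());
    const std::size_t n = out.size();
    for (std::size_t block = 0; block < gridSizeFor(n); ++block) {
        for (std::size_t thread = 0; thread < kBlockSize; ++thread) {
            const std::size_t i = block * kBlockSize + thread;
            if (i < n) {
                out[i] = left[i] - right[i];  // output[i] = input_left[i] - input_right[i]
            }
        }
    }
}
```

With 257 elements and 256 threads per block, two blocks are launched and the bounds check leaves the 255 surplus threads of the second block idle.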