NeuZephyr
Simple DL Framework
Contains classes and functionality for nodes in a neural network or computational graph.
Namespaces
- namespace calc: Contains classes and functionality for computation nodes in a neural network or computational graph.
- namespace io: This namespace contains standard nodes used in computational graphs for neural networks.
- namespace loss: Contains loss function nodes for computing various loss metrics in a machine learning model.

Classes
- class Node: Base class for nodes in a neural network or computational graph.

Functions
- template<typename T> std::enable_if_t<std::is_base_of_v<Node, T>, std::ostream&> operator<<(std::ostream& os, const T& node): Overloads the << operator to print information about a node.
The nz::nodes namespace provides a collection of classes that represent various layers and operations in a neural network. Each node is a component of a computational graph and performs a specific computation during the forward and backward passes.
This namespace includes:
- Node, which defines the interface for all types of nodes. It provides the basic structure and functionality, including methods for the forward and backward passes.
- Activation function nodes, for example LeakyReLUNode, SwishNode, ELUNode, HardSigmoidNode, HardSwishNode, and SoftmaxNode.
- Operation nodes, such as AddNode, MatMulNode, ScalarMulNode, ScalarDivNode, etc.
- Loss function nodes, for example MeanSquaredErrorNode and BinaryCrossEntropyNode.

The nodes in this namespace work with Tensor objects to propagate data and gradients through the network, supporting both training and inference of deep learning models.
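The forward/backward contract described above can be illustrated with a small, self-contained sketch. It is not NeuZephyr's API: the class names (SketchNode, SketchInputNode, SketchAddNode, SketchMulNode) are hypothetical, and a single double value and gradient stand in for a Tensor. Forward passes run in topological order; backward passes run in reverse order after seeding the output gradient with 1.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, heavily simplified graph node: one scalar `value` and its
// gradient `grad` replace the Tensor objects a real node would hold.
class SketchNode {
public:
    virtual ~SketchNode() = default;
    virtual void forward() = 0;   // compute `value` from the inputs
    virtual void backward() = 0;  // add this node's gradient contribution to its inputs

    double value = 0.0;
    double grad = 0.0;

protected:
    std::vector<SketchNode*> inputs;
};

// Leaf node holding an input or parameter value; nothing to compute.
class SketchInputNode : public SketchNode {
public:
    explicit SketchInputNode(double v) { value = v; }
    void forward() override {}
    void backward() override {}  // gradients simply accumulate in `grad`
};

// out = a + b, so d(out)/da = d(out)/db = 1.
class SketchAddNode : public SketchNode {
public:
    SketchAddNode(SketchNode* a, SketchNode* b) {
        assert(a != nullptr && b != nullptr);
        inputs = {a, b};
    }
    void forward() override { value = inputs[0]->value + inputs[1]->value; }
    void backward() override {
        inputs[0]->grad += grad;
        inputs[1]->grad += grad;
    }
};

// out = a * b, so d(out)/da = b and d(out)/db = a.
class SketchMulNode : public SketchNode {
public:
    SketchMulNode(SketchNode* a, SketchNode* b) {
        assert(a != nullptr && b != nullptr);
        inputs = {a, b};
    }
    void forward() override { value = inputs[0]->value * inputs[1]->value; }
    void backward() override {
        inputs[0]->grad += grad * inputs[1]->value;
        inputs[1]->grad += grad * inputs[0]->value;
    }
};
```

For example, with x = 2 and y = 3, the graph out = (x + y) * y gives out = 15 after the forward pass. Seeding out.grad = 1 and calling backward() on the multiplication node and then on the addition node yields x.grad = 3 and y.grad = 8 (that is, x + 2y). Gradients are accumulated with `+=` so that a node feeding several consumers, like y here, receives the sum of all contributions.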
template<typename T>
std::enable_if_t<std::is_base_of_v<Node, T>, std::ostream&> nz::nodes::operator<<(std::ostream& os, const T& node)
Overloads the << operator to print information about a node.
The operator<< is overloaded to provide a convenient way to print detailed information about a node, including its type, data, gradient, and loss (if applicable). It calls the node's print() method, which handles the actual formatting and output of the node's information.
The operator outputs the node's type together with the data, gradient, and (if applicable) loss associated with its output tensor.
This operator is primarily used for debugging, logging, and inspecting the state of a node, including its tensor data, gradients, and any associated loss. With it, a node's information can be written directly to standard output or to any other output stream.
The print() method must be implemented by the node's class (or by any class derived from it). This method is responsible for printing the type, data, gradient, and loss for that specific class.
Because the operator is a template restricted to types derived from Node, it works for every node class that provides a print() method, making it a flexible and reusable solution for logging and debugging.

Parameters
- os: The output stream (e.g., std::cout) to which the node's information will be printed.
- node: The node object to be printed. It is passed as a const reference to ensure it is not modified.

Returns
A reference to the output stream (os), allowing the operator to be used in chained expressions such as std::cout << node1 << node2;
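The following self-contained sketch shows the technique behind this overload: a function template whose return type uses std::enable_if_t with std::is_base_of_v, so that it only participates in overload resolution for Node and its subclasses, and that delegates formatting to the node's print() method. Everything inside the hypothetical namespace `sketch` (Node, ScaleNode) and the trait `has_node_stream_op` is illustrative only. In particular, the `print(std::ostream&) const` signature is an assumption and is not taken from NeuZephyr.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>

namespace sketch {

// Hypothetical stand-in for nz::nodes::Node. The print(std::ostream&) const
// signature is an assumption made for this sketch.
class Node {
public:
    virtual ~Node() = default;
    virtual void print(std::ostream& os) const = 0;
};

// Hypothetical node type, used only to have something to print.
class ScaleNode : public Node {
public:
    explicit ScaleNode(double factor) : factor_(factor) {}
    void print(std::ostream& os) const override {
        os << "ScaleNode(factor=" << factor_ << ")";
    }

private:
    double factor_;
};

// Same shape as the documented overload: it takes part in overload resolution
// only when T is Node or derives from it, and it returns the stream so that
// insertions can be chained.
template <typename T>
std::enable_if_t<std::is_base_of_v<Node, T>, std::ostream&>
operator<<(std::ostream& os, const T& node) {
    node.print(os);  // all formatting is delegated to the node
    return os;
}

}  // namespace sketch

// Hypothetical detection trait: true when sketch::operator<< accepts a T.
template <typename T, typename = void>
struct has_node_stream_op : std::false_type {};

template <typename T>
struct has_node_stream_op<
    T, std::void_t<decltype(sketch::operator<<(std::declval<std::ostream&>(),
                                                std::declval<const T&>()))>>
    : std::true_type {};

static_assert(has_node_stream_op<sketch::ScaleNode>::value,
              "Node subclasses are printable");
static_assert(!has_node_stream_op<int>::value,
              "non-Node types are rejected by enable_if");
```

With two such nodes, `oss << a << " | " << b` produces "ScaleNode(factor=2) | ScaleNode(factor=0.5)", because each call returns the stream for the next insertion. Printing through a `const sketch::Node&` also works: T deduces to Node itself (std::is_base_of_v<Node, Node> is true), and the virtual print() call dispatches to the derived class.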