NeuZephyr
Simple DL Framework
nz::nodes Namespace Reference

Contains classes and functionality for nodes in a neural network or computational graph.

Namespaces

namespace  calc
 Contains classes and functionality for computation nodes in a neural network or computational graph.
 
namespace  io
 This namespace contains standard nodes used in computational graphs for neural networks.
 
namespace  loss
 Contains loss function nodes for computing various loss metrics in a machine learning model.
 

Classes

class  Node
 Base class for nodes in a neural network or computational graph.
 

Functions

template<typename T >
std::enable_if_t< std::is_base_of_v< Node, T >, std::ostream & > operator<< (std::ostream &os, const T &node)
 Overloads the << operator to print information about a node.
 

Detailed Description

Contains classes and functionality for nodes in a neural network or computational graph.

The nz::nodes namespace provides a collection of classes that represent various layers and operations in a neural network. Each node is an essential component of a computational graph, responsible for performing specific computations during the forward and backward passes.

This namespace includes:

  • Node Class: The abstract base class Node, which defines the interface for all types of nodes. It provides the basic structure and functionality, including methods for forward and backward passes.
  • Derived Node Classes: A set of derived classes representing common operations and layers in neural networks, including activation functions, mathematical operations, and loss functions. Examples include:
    • Activation Functions: LeakyReLUNode, SwishNode, ELUNode, HardSigmoidNode, HardSwishNode, SoftmaxNode.
    • Mathematical Operations: AddNode, MatMulNode, ScalarMulNode, ScalarDivNode, etc.
    • Loss Functions: MeanSquaredErrorNode, BinaryCrossEntropyNode.

The nodes in this namespace work with Tensor objects to propagate data and gradients through the network, supporting the training and inference processes of deep learning models.
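The forward/backward division of labor described above can be illustrated with a minimal, standalone sketch. It is not written against NeuZephyr: the class names SketchNode, InputSketch, and ScalarMulSketch are hypothetical, and std::vector<float> on the host stands in for nz's Tensor and GPU memory. It shows only the general pattern of an abstract node interface with one leaf node and one operation node.

```cpp
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical sketch, NOT the NeuZephyr API: std::vector<float> stands in
// for nz's Tensor, and everything runs on the host.
class SketchNode {
public:
    std::vector<float> output;  // value computed by forward()
    std::vector<float> grad;    // dLoss/dOutput, accumulated during backward()

    virtual ~SketchNode() = default;
    virtual void forward() = 0;   // compute output from inputs
    virtual void backward() = 0;  // push grad back to inputs (chain rule)
};

// Leaf node holding caller-supplied data; it has nothing to compute.
class InputSketch : public SketchNode {
public:
    explicit InputSketch(std::vector<float> values) {
        output = std::move(values);
        grad.assign(output.size(), 0.0f);
    }
    void forward() override {}
    void backward() override {}
};

// Element-wise y = factor * x. Backward adds factor * dL/dy to dL/dx.
class ScalarMulSketch : public SketchNode {
public:
    ScalarMulSketch(std::shared_ptr<SketchNode> input, float factor)
        : input_(std::move(input)), factor_(factor) {}

    void forward() override {
        output.resize(input_->output.size());
        for (std::size_t i = 0; i < output.size(); ++i) {
            output[i] = factor_ * input_->output[i];
        }
        grad.assign(output.size(), 0.0f);  // reset before a new backward pass
    }

    void backward() override {
        for (std::size_t i = 0; i < grad.size(); ++i) {
            input_->grad[i] += factor_ * grad[i];
        }
    }

private:
    std::shared_ptr<SketchNode> input_;
    float factor_;
};
```

In this sketch, a pass consists of calling forward() on each node in dependency order, seeding the final node's grad, and then calling backward() in reverse order so that gradients accumulate into the inputs.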

Note
  • The nodes in this namespace are designed to be used as part of a computational graph, and each node can be connected to other nodes to define the structure of a neural network.
  • Ensure proper memory management when working with tensors, particularly when dealing with GPU memory.
Author
Mgepahmge (https://github.com/Mgepahmge)
Date
2024/11/29

Function Documentation

◆ operator<<()

template<typename T >
std::enable_if_t< std::is_base_of_v< Node, T >, std::ostream & > nz::nodes::operator<< (std::ostream &os, const T &node)

Overloads the << operator to print information about a node.

The operator<< is overloaded to provide a convenient way to print detailed information about a node, including its type, data, gradient, and loss (if applicable). This operator calls the print() method of the node, which handles the actual formatting and output of the node's information.

The operator outputs the following details:

  • Type: The type of the node, i.e., the operation it represents (such as "MatrixMul" or "ReLU").
  • Data: The tensor data stored in the node's output tensor.
  • Gradient: The gradient computed for the node, if any, showing the values being backpropagated during training.
  • Loss: The loss value associated with the node, if applicable, which can be used to monitor the model's error during training.

This operator is primarily intended for debugging, logging, and inspecting a node's state, including its tensor data, gradients, and any associated loss. Because it writes to a std::ostream, the same expression works with std::cout, a file stream, or a string stream.
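A self-contained sketch of the same delegation pattern follows. It assumes that print() takes the target std::ostream& and is const-qualified; this page does not show print()'s actual signature. The namespace print_sketch and the class ReLUSketch are hypothetical, and ReLUSketch prints only a type line in place of the real data, gradient, and loss output.

```cpp
#include <ostream>
#include <sstream>
#include <type_traits>

namespace print_sketch {

// Hypothetical base class. Assumption: print() writes to a caller-supplied
// stream and is const, since operator<< receives the node by const reference.
class Node {
public:
    virtual ~Node() = default;
    virtual void print(std::ostream& os) const = 0;
};

// Hypothetical node; a single "Type" line stands in for the real output
// (type, data, gradient, and loss).
class ReLUSketch : public Node {
public:
    void print(std::ostream& os) const override { os << "Type: ReLU\n"; }
};

// Same shape as the documented overload: enabled only for types derived
// from Node, delegates formatting to print(), and returns os for chaining.
template <typename T>
std::enable_if_t<std::is_base_of_v<Node, T>, std::ostream&>
operator<<(std::ostream& os, const T& node) {
    node.print(os);
    return os;
}

}  // namespace print_sketch
```

With these definitions, writing two ReLUSketch objects into a std::ostringstream in one chained expression yields "Type: ReLU\n" twice, because each insertion returns the same stream. Keeping the formatting in a virtual print() means the operator itself never needs to change when new node types are added.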

Note
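  • Each concrete node class must provide a print() method, either implemented directly or inherited, that outputs the type, data, gradient, and loss for that class. Because node is received by const reference, print() must be a const member function.
  • The operator is constrained with std::enable_if_t and std::is_base_of_v, so it participates in overload resolution only for Node and classes derived from Node. A class that merely defines a print() method, without deriving from Node, cannot be printed with this operator.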
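The effect of the std::is_base_of_v constraint can be checked at compile time with the detection idiom. In the sketch below, the namespace constraint_sketch, the classes DerivedNode and Unrelated, and the is_streamable trait are all hypothetical, and the print(std::ostream&) const signature is an assumption; only the shape of the constrained operator<< mirrors this page. It requires C++17.

```cpp
#include <ostream>
#include <type_traits>
#include <utility>

namespace constraint_sketch {

// Hypothetical base class; print(std::ostream&) const is an assumed signature.
class Node {
public:
    virtual ~Node() = default;
    virtual void print(std::ostream& os) const = 0;
};

class DerivedNode : public Node {
public:
    void print(std::ostream& os) const override { os << "DerivedNode"; }
};

// Has a print() of the same shape but does NOT derive from Node.
struct Unrelated {
    void print(std::ostream& os) const { os << "Unrelated"; }
};

// Constrained exactly like the documented overload.
template <typename T>
std::enable_if_t<std::is_base_of_v<Node, T>, std::ostream&>
operator<<(std::ostream& os, const T& node) {
    node.print(os);
    return os;
}

}  // namespace constraint_sketch

// Detection idiom: true only if `os << value` is a valid expression for T.
template <typename T, typename = void>
struct is_streamable : std::false_type {};

template <typename T>
struct is_streamable<
    T, std::void_t<decltype(std::declval<std::ostream&>() << std::declval<const T&>())>>
    : std::true_type {};

static_assert(is_streamable<constraint_sketch::DerivedNode>::value,
              "Node subclasses can be streamed");
static_assert(!is_streamable<constraint_sketch::Unrelated>::value,
              "a print() method alone is not enough");
```

Unrelated has a print() with the same signature, but because it does not derive from Node, std::enable_if_t removes the template from overload resolution and the insertion expression is ill-formed.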
Parameters
os  The output stream (e.g., std::cout) to which the node's information will be printed.
node  The node object to be printed. It is passed as a const reference to ensure it is not modified.
Returns
The output stream os, which allows the operator to be used in chained expressions such as std::cout << node1 << node2;.
Author
Mgepahmge (https://github.com/Mgepahmge)
Date
2024/11/29

Definition at line 114 of file Nodes.cuh.