NeuZephyr
Simple DL Framework
Contains classes and functions for managing and executing computation graphs in deep learning workflows.
Classes

| Kind | Name | Description |
|---|---|---|
| class | ComputeGraph | Represents a computational graph, which manages nodes and the computation flow. |

Functions

| Return type | Function | Description |
|---|---|---|
| std::ostream & | operator<< (std::ostream &os, ComputeGraph &graph) | Overloads the stream insertion operator to print the details of the computational graph. |
| void | CreateNode (ComputeGraph *graph, const std::string &type, const std::string &name, std::vector< int > pre, const Tensor::shape_type &shape, float *data, const bool requires_grad, float *grad) | Creates and adds a node to the computational graph based on the specified type. |

The nz::graph namespace provides the core tools for creating, managing, and executing computation graphs in deep learning models. It supports building neural networks, running forward and backward propagation, computing gradients, and performing optimization steps. The namespace is central to the framework's workflow and is built for efficient execution on GPU devices.
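To make the forward/backward flow concrete, here is a minimal C++ sketch of a scalar computation graph. It is not NeuZephyr's ComputeGraph: the MiniGraph and MiniNode types, the "Input"/"Add"/"Mul" op names, and the use of creation order as a topological order are all assumptions made for illustration. The forward pass evaluates nodes in creation order; the backward pass seeds the output gradient with 1 and applies the chain rule in reverse order.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical scalar node; NeuZephyr's real nodes hold GPU tensors.
struct MiniNode {
    std::string op;                // "Input", "Add" or "Mul" (Add/Mul expect two inputs)
    std::vector<std::size_t> pre;  // indices of predecessor nodes
    float value = 0.0f;            // set by forward()
    float grad = 0.0f;             // d(output)/d(this node), set by backward()
};

struct MiniGraph {
    std::vector<MiniNode> nodes;   // creation order is a valid topological order

    // Appends a node; predecessors must already exist, so no cycles can form.
    std::size_t add(const std::string& op, std::vector<std::size_t> pre, float value = 0.0f) {
        for (std::size_t p : pre)
            if (p >= nodes.size()) throw std::runtime_error("predecessor does not exist yet");
        nodes.push_back(MiniNode{op, std::move(pre), value, 0.0f});
        return nodes.size() - 1;
    }

    // Forward pass: evaluate in creation order so every input is ready first.
    void forward() {
        for (MiniNode& n : nodes) {
            if (n.op == "Add") n.value = nodes[n.pre[0]].value + nodes[n.pre[1]].value;
            else if (n.op == "Mul") n.value = nodes[n.pre[0]].value * nodes[n.pre[1]].value;
        }
    }

    // Backward pass: seed d(output)/d(output) = 1, then apply the chain rule in
    // reverse creation order, accumulating gradients into predecessors.
    void backward(std::size_t output) {
        assert(output < nodes.size());
        for (MiniNode& n : nodes) n.grad = 0.0f;
        nodes[output].grad = 1.0f;
        for (std::size_t i = nodes.size(); i-- > 0;) {
            const MiniNode& n = nodes[i];
            if (n.op == "Add") {
                nodes[n.pre[0]].grad += n.grad;
                nodes[n.pre[1]].grad += n.grad;
            } else if (n.op == "Mul") {
                nodes[n.pre[0]].grad += n.grad * nodes[n.pre[1]].value;
                nodes[n.pre[1]].grad += n.grad * nodes[n.pre[0]].value;
            }
        }
    }
};
```

For y = (a + b) * b with a = 2 and b = 3, forward() gives y = 15, and backward(y) gives dy/da = 3 and dy/db = 8, because b contributes through both the sum and the product.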
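Key components within this namespace include the ComputeGraph class, the CreateNode function for adding typed nodes to a graph, and an operator<< overload for printing a graph.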
The nz::graph namespace is designed with performance in mind and uses CUDA to accelerate computation on GPUs.
void nz::graph::CreateNode(ComputeGraph *graph, const std::string &type, const std::string &name, std::vector< int > pre, const Tensor::shape_type &shape, float *data, const bool requires_grad, float *grad)
Creates and adds a node to the computational graph based on the specified type.
This function creates a node of the given type, initializes it with the specified shape, data, and (if needed) gradient information, and adds it to the ComputeGraph object. It also connects the new node to the previous nodes whose indices are listed in the pre vector.
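One plausible way to structure this kind of type-based creation is a lookup table from type strings to factory functions, followed by wiring predecessors by index. The sketch below is not the implementation in ComputeGraph.cu: SketchNode, SketchGraph, makeNode, and createNodeSketch are hypothetical names, and only three of the documented types are supported.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical node types; NeuZephyr's real node classes differ.
struct SketchNode {
    std::vector<SketchNode*> inputs;  // predecessors resolved from `pre`
    virtual ~SketchNode() = default;
    virtual std::string kind() const = 0;
};
struct InputSketch : SketchNode { std::string kind() const override { return "Input"; } };
struct AddSketch : SketchNode { std::string kind() const override { return "Add"; } };
struct ReLUSketch : SketchNode { std::string kind() const override { return "ReLU"; } };

struct SketchGraph {
    std::vector<std::unique_ptr<SketchNode>> nodes;       // owns every node
    std::unordered_map<std::string, SketchNode*> byName;  // name -> node lookup
};

// Factory helper: builds a node of concrete type T behind a base-class pointer.
template <typename T>
std::unique_ptr<SketchNode> makeNode() { return std::make_unique<T>(); }

// Creates a node from its type string, links it to the nodes listed in `pre`,
// and registers it in the graph under `name`.
void createNodeSketch(SketchGraph& graph, const std::string& type,
                      const std::string& name, const std::vector<int>& pre) {
    using Factory = std::unique_ptr<SketchNode> (*)();
    static const std::unordered_map<std::string, Factory> factories = {
        {"Input", &makeNode<InputSketch>},
        {"Add", &makeNode<AddSketch>},
        {"ReLU", &makeNode<ReLUSketch>},
    };

    auto it = factories.find(type);
    if (it == factories.end())
        throw std::runtime_error("Unsupported node type: " + type);

    std::unique_ptr<SketchNode> node = it->second();
    for (int idx : pre) {
        if (idx < 0 || static_cast<std::size_t>(idx) >= graph.nodes.size())
            throw std::runtime_error("Unknown predecessor index for node " + name);
        node->inputs.push_back(graph.nodes[static_cast<std::size_t>(idx)].get());
    }

    SketchNode* raw = node.get();
    graph.nodes.push_back(std::move(node));
    graph.byName[name] = raw;
}
```

A table of factories keeps the dispatch in one place, so supporting another type means adding one entry rather than another branch.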
Parameters

| Parameter | Description |
|---|---|
| graph | The ComputeGraph object to which the new node will be added. |
| type | A string naming the type of node to create. Supported types: "Input", "Output", "Add", "MatMul", "Sub", "ReLU", "Sigmoid", "Tanh", "LeakyReLU", "Swish", "ELU", "HardSigmoid", "HardSwish", "Softmax", "MeanSquaredError", "BinaryCrossEntropy". |
| name | The name under which the node is added to the graph. |
| pre | Indices of the previous (input) nodes that this node depends on. The required number of elements depends on the node type. |
| shape | A vector representing the shape of the node's output tensor. |
| data | A pointer to the data used to initialize the node's output tensor. |
| requires_grad | Whether the node requires gradients for backpropagation. |
| grad | A pointer to the gradient data for the node's output tensor, used if requires_grad is true. |

Exceptions

| Exception | Condition |
|---|---|
| std::runtime_error | Thrown if an unsupported node type is provided or if there is a mismatch in the node's dependencies. |

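
The shape, data, requires_grad, and grad parameters suggest the following initialization logic, sketched here on the host. This rests on assumptions the documentation does not confirm: the shape is modelled as std::vector<int>, buffers live in host memory rather than on the GPU, and a null pointer leaves a buffer zero-filled. HostTensorSketch, elementCount, and makeTensor are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <stdexcept>
#include <vector>

// Hypothetical host-side stand-in for a node's output tensor. The real Tensor
// keeps its buffers on the GPU, and Tensor::shape_type may not be std::vector<int>.
struct HostTensorSketch {
    std::vector<int> shape;
    std::vector<float> data;
    bool requiresGrad = false;
    std::vector<float> grad;  // stays empty unless requiresGrad is true
};

// Number of elements implied by a shape, e.g. {2, 3} -> 6.
std::size_t elementCount(const std::vector<int>& shape) {
    return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                           [](std::size_t acc, int dim) {
                               if (dim <= 0) throw std::runtime_error("shape dimensions must be positive");
                               return acc * static_cast<std::size_t>(dim);
                           });
}

// Sizes the buffers from `shape`, copies `data`, and copies `grad` only when
// requiresGrad is set. Non-null pointers must reference at least
// elementCount(shape) floats; null pointers leave a buffer zero-filled (assumed).
HostTensorSketch makeTensor(const std::vector<int>& shape, const float* data,
                            bool requiresGrad, const float* grad) {
    HostTensorSketch t;
    t.shape = shape;
    const std::size_t n = elementCount(shape);
    t.data.assign(n, 0.0f);
    if (data != nullptr) std::copy(data, data + n, t.data.begin());
    t.requiresGrad = requiresGrad;
    if (requiresGrad) {
        t.grad.assign(n, 0.0f);
        if (grad != nullptr) std::copy(grad, grad + n, t.grad.begin());
    }
    return t;
}
```

Deriving every buffer size from the shape means the caller only has to guarantee that data and grad point to enough elements.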
The CreateNode function handles node creation, connection to the previous nodes, and insertion of the new node into the graph. The pre vector specifies which nodes serve as inputs to the current node; its size may differ depending on the node type.
Definition at line 108 of file ComputeGraph.cu.
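Because the required length of pre depends on the node type, a dependency check might compare it against a per-type arity. The arity table below is an assumption chosen for illustration (for example, that activations take one input and that the loss nodes take a prediction and a target); expectedInputs and checkArity are hypothetical names, and the real rules are in ComputeGraph.cu.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// ASSUMED number of inputs per node type, for illustration only.
std::size_t expectedInputs(const std::string& type) {
    static const std::unordered_map<std::string, std::size_t> arity = {
        {"Input", 0},
        {"Output", 1}, {"ReLU", 1}, {"Sigmoid", 1}, {"Tanh", 1},
        {"LeakyReLU", 1}, {"Swish", 1}, {"ELU", 1}, {"HardSigmoid", 1},
        {"HardSwish", 1}, {"Softmax", 1},
        {"Add", 2}, {"Sub", 2}, {"MatMul", 2},
        {"MeanSquaredError", 2}, {"BinaryCrossEntropy", 2},
    };
    auto it = arity.find(type);
    if (it == arity.end()) throw std::runtime_error("Unsupported node type: " + type);
    return it->second;
}

// Throws std::runtime_error if `pre` does not list exactly as many inputs as
// the node type needs.
void checkArity(const std::string& type, const std::vector<int>& pre) {
    const std::size_t expected = expectedInputs(type);
    if (pre.size() != expected)
        throw std::runtime_error(type + " expects " + std::to_string(expected) +
                                 " input(s) but pre has " + std::to_string(pre.size()));
}
```

Running a check like this before any node is allocated keeps a malformed request from leaving a half-built node in the graph.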
std::ostream & nz::graph::operator<<(std::ostream &os, ComputeGraph &graph)
Overloads the stream insertion operator to print the details of the computational graph.
This function overloads the << operator to provide a simple way to print the details of a ComputeGraph object. It calls the print method of ComputeGraph to write the graph's nodes, their connections, data, gradients, and loss to the provided output stream.
Parameters

| Parameter | Description |
|---|---|
| os | The output stream to which the graph details will be printed (e.g., std::cout). |
| graph | The ComputeGraph object whose details will be printed. |

Definition at line 56 of file ComputeGraph.cu.
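The delegation described above takes only a few lines. In this sketch, a hypothetical GraphSketch class with a print(std::ostream&) member stands in for ComputeGraph and its print method; its output format is invented. Returning os is what allows chained insertions such as std::cout << graph << std::endl.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for ComputeGraph; only node names are tracked here.
class GraphSketch {
public:
    void addNode(const std::string& name) { names_.push_back(name); }

    // Writes one line per node. The documented ComputeGraph::print also
    // reports connections, data, gradients, and loss.
    void print(std::ostream& os) {
        for (std::size_t i = 0; i < names_.size(); ++i)
            os << "Node " << i << ": " << names_[i] << '\n';
    }

private:
    std::vector<std::string> names_;
};

// Mirrors the documented signature: takes a non-const graph reference and
// returns the stream so that insertions can be chained.
std::ostream& operator<<(std::ostream& os, GraphSketch& graph) {
    graph.print(os);
    return os;
}
```

Since the operator takes std::ostream&, the same overload works with std::cout, file streams, and string streams.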