Coverage for /usr/local/lib/python3.10/dist-packages/Adifpy/differentiate/node.py: 100%
162 statements
coverage.py v6.5.0, created at 2022-12-07 00:47 -0500
1from __future__ import annotations
3import numpy as np
class Node:
    """Object for storing weights and adjoint values during reverse mode AD

    This class is our implementation of nodes that will be passed through a
    user's input function in reverse mode AD.

    Every operation involving a node creates a new result node ``z`` and
    records the edge ``[weight, z]`` in the ``children`` list of each Node
    operand. Here ``weight`` is the local partial derivative of ``z`` with
    respect to that operand, evaluated at the operand values at construction
    time. :meth:`reverse_pass` later combines these weights with the adjoints
    of the result nodes.

    >>> foo = Node(3)
    >>> bar = Node(2)
    >>> foo * bar
    6.0, 0
    """

    # Scalar operand types accepted by the arithmetic operators. numpy integer
    # and floating scalars are included so that e.g. ``node * np.int64(2)``
    # works. ``np.float64`` already subclasses ``float``; ``np.int64`` does not
    # subclass ``int``.
    _SCALAR_TYPES = (int, float, np.integer, np.floating)

    def __init__(self, value: int | float):
        self.value = value
        # Edges ``[weight, result_node]`` to every node computed from this one.
        self.children = []
        # Cached adjoint; ``None`` means "not computed yet".
        self.node_adjoint = None

    def __str__(self):
        return f"Node object of value {float(self.value)}, with {len(self.children)} child nodes."

    def __repr__(self):
        return f"{float(self.value)}, {len(self.children)}"

    def reverse_pass(self):
        """Return this node's adjoint, computing and caching it if needed.

        The adjoint is ``sum(weight * child.node_adjoint)`` over ``children``.
        It is cached in ``node_adjoint`` for this node and for every
        downstream node that still had ``None``. A node without children and
        without a preset ``node_adjoint`` gets adjoint 0. Presumably the
        caller seeds the function's output node (e.g. ``node_adjoint = 1``)
        before calling this on the inputs; verify against the caller.

        The traversal uses an explicit stack instead of recursion, so long
        chains of operations do not hit Python's recursion limit.
        """
        if self.node_adjoint is not None:
            return self.node_adjoint

        # Each entry is (node, children_scheduled). A node is finalized the
        # second time it is popped, after all of its children were resolved.
        stack = [(self, False)]
        while stack:
            node, children_scheduled = stack.pop()
            if node.node_adjoint is not None:
                continue
            if children_scheduled:
                # The graph is acyclic: a result node is always created after
                # its operands. So every child's adjoint is known here.
                # Summation order matches the original recursive accumulation.
                node.node_adjoint = sum(
                    weight * child.node_adjoint for weight, child in node.children
                )
                continue
            stack.append((node, True))
            stack.extend(
                (child, False)
                for _, child in node.children
                if child.node_adjoint is None
            )
        return self.node_adjoint

    def _child(self, value, weight) -> Node:
        """Create a node holding ``value`` and record it as a child of ``self``
        with local derivative ``weight``."""
        z = Node(value)
        self.children.append([weight, z])
        return z

    def _binary_child(self, other: Node, value, self_weight, other_weight) -> Node:
        """Create a node holding ``value`` and record it as a child of both
        ``self`` (derivative ``self_weight``) and ``other`` (``other_weight``)."""
        z = self._child(value, self_weight)
        other.children.append([other_weight, z])
        return z

    def __add__(self, other: Node | float | int) -> Node:
        """Add a scalar or a node to a node"""
        if isinstance(other, Node):
            return self._binary_child(other, self.value + other.value, 1.0, 1.0)
        if isinstance(other, self._SCALAR_TYPES):
            return self._child(self.value + other, 1.0)
        raise TypeError("Operand must be of type int, float, or node.")

    def __radd__(self, other: float | int) -> Node:
        """Add a node to a scalar"""
        return self.__add__(other)

    def __sub__(self, other: Node | float | int) -> Node:
        """Subtract a scalar or node from a node"""
        if isinstance(other, Node):
            return self._binary_child(other, self.value - other.value, 1.0, -1.0)
        if isinstance(other, self._SCALAR_TYPES):
            return self._child(self.value - other, 1.0)
        raise TypeError("Operand must be of type int, float, or node.")

    def __rsub__(self, other: float | int) -> Node:
        """Subtract a node from a scalar"""
        if isinstance(other, self._SCALAR_TYPES):
            return self._child(other - self.value, -1.0)
        raise TypeError("Operand must be of type int, float, or node.")

    def __mul__(self, other: Node | float | int) -> Node:
        """Multiply a node by another node or a scalar"""
        if isinstance(other, Node):
            return self._binary_child(
                other, self.value * other.value, other.value, self.value
            )
        if isinstance(other, self._SCALAR_TYPES):
            return self._child(self.value * other, other)
        raise TypeError("Operand must be of type int, float, or node.")

    def __rmul__(self, other: float | int) -> Node:
        """Multiply a scalar by a node"""
        return self.__mul__(other)

    def __truediv__(self, other: Node | float | int) -> Node:
        """Divide a node by another node or a scalar"""
        if isinstance(other, Node):
            return self._binary_child(
                other,
                self.value / other.value,
                1 / other.value,
                -self.value / pow(other.value, 2),
            )
        if isinstance(other, self._SCALAR_TYPES):
            return self._child(self.value / other, 1 / other)
        raise TypeError("Operand must be of type int, float, or node.")

    def __rtruediv__(self, other: float | int) -> Node:
        """Divide a scalar by a node"""
        if isinstance(other, self._SCALAR_TYPES):
            return self._child(other / self.value, -other / pow(self.value, 2))
        raise TypeError("Operand must be of type int, float, or node.")

    def __pow__(self, other: Node | float | int) -> Node:
        """Raise a node to the power of another node or a scalar"""
        if isinstance(other, Node):
            value = pow(self.value, other.value)
            # d/dx x**y = y * x**(y - 1). For y == 0 this is 0 everywhere, and
            # evaluating pow(x, -1) would divide by zero at x == 0.
            if other.value == 0:
                base_weight = 0.0
            else:
                base_weight = other.value * pow(self.value, other.value - 1)
            # d/dy x**y = x**y * ln(x)
            return self._binary_child(other, value, base_weight, value * np.log(self.value))
        if isinstance(other, self._SCALAR_TYPES):
            # Same zero-exponent guard as above.
            if other == 0:
                weight = 0.0
            else:
                weight = other * pow(self.value, other - 1)
            return self._child(pow(self.value, other), weight)
        raise TypeError("Operand must be of type int, float, or node.")

    def __rpow__(self, other: float | int) -> Node:
        """Raise a scalar to the power of a node"""
        if isinstance(other, self._SCALAR_TYPES):
            value = pow(other, self.value)
            # d/dy a**y = a**y * ln(a)
            return self._child(value, value * np.log(other))
        raise TypeError("Operand must be of type int, float, or node.")

    def __neg__(self: Node) -> Node:
        """Negate a node"""
        return -1 * self

    # Other functions
    def exp(self):
        """Exponential e^x"""
        value = np.exp(self.value)
        # d/dx e^x = e^x, so the forward value doubles as the weight.
        return self._child(value, value)

    def sqrt(self):
        """Square root"""
        return self ** 0.5

    # Builtin log functions of multiple bases
    def log(self):
        """Natural log"""
        return self._child(np.log(self.value), 1.0 / self.value)

    def log10(self):
        """Log base 10"""
        return self._child(np.log10(self.value), 1.0 / self.value / np.log(10))

    def log2(self):
        """Log base 2"""
        return self._child(np.log2(self.value), 1.0 / self.value / np.log(2))

    # Trigonometric functions
    def sin(self):
        """Sine; derivative cos(x)"""
        return self._child(np.sin(self.value), np.cos(self.value))

    def cos(self):
        """Cosine; derivative -sin(x)"""
        return self._child(np.cos(self.value), -np.sin(self.value))

    def tan(self):
        """Tangent; derivative 1 / cos(x)**2"""
        return self._child(np.tan(self.value), 1 / pow(np.cos(self.value), 2))

    # Inverse trigonometric functions
    def arcsin(self):
        """Inverse sine; derivative (1 - x**2)**(-1/2)"""
        return self._child(np.arcsin(self.value), pow(1 - pow(self.value, 2), -1 / 2))

    def arccos(self):
        """Inverse cosine; derivative -(1 - x**2)**(-1/2)"""
        return self._child(np.arccos(self.value), -pow(1 - pow(self.value, 2), -1 / 2))

    def arctan(self):
        """Inverse tangent; derivative 1 / (1 + x**2)"""
        return self._child(np.arctan(self.value), pow(1 + pow(self.value, 2), -1))

    # Hyperbolic functions
    def sinh(self):
        """Hyperbolic sine; derivative cosh(x)"""
        return self._child(np.sinh(self.value), np.cosh(self.value))

    def cosh(self):
        """Hyperbolic cosine; derivative sinh(x)"""
        return self._child(np.cosh(self.value), np.sinh(self.value))

    def tanh(self):
        """Hyperbolic tangent; derivative cosh(x)**-2"""
        return self._child(np.tanh(self.value), pow(np.cosh(self.value), -2))

    # Inverse hyperbolic functions
    def arcsinh(self):
        """Inverse hyperbolic sine; derivative (x**2 + 1)**(-1/2)"""
        return self._child(np.arcsinh(self.value), pow(pow(self.value, 2) + 1, -1 / 2))

    def arccosh(self):
        """Inverse hyperbolic cosine; derivative (x**2 - 1)**(-1/2)"""
        return self._child(np.arccosh(self.value), pow(pow(self.value, 2) - 1, -1 / 2))

    def arctanh(self):
        """Inverse hyperbolic tangent; derivative 1 / (1 - x**2)"""
        return self._child(np.arctanh(self.value), pow(1 - pow(self.value, 2), -1))