Coverage for /usr/local/lib/python3.10/dist-packages/Adifpy/differentiate/node.py: 100%

162 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-07 00:47 -0500

1from __future__ import annotations 

2 

3import numpy as np 

4 

5 

class Node:
    """Object for storing weights and adjoint values during reverse mode AD

    This class is our implementation of nodes that will be passed through a
    user's input function in reverse mode AD.

    Every arithmetic operation or elementary function applied to a node
    creates a new result node ``z`` and appends ``[weight, z]`` to the
    ``children`` list of each Node operand, where ``weight`` is the local
    partial derivative of ``z.value`` with respect to that operand's value.
    ``reverse_pass`` later combines these weights with the children's
    adjoints (chain rule).

    >>> foo = Node(3)
    >>> bar = Node(2)
    >>> foo * bar
    6.0, 0
    """

    def __init__(self, value: int | float):
        self.value = value
        # [weight, child_node] pairs, one per operation that consumed this node.
        self.children = []
        # Cached adjoint; None until reverse_pass computes it (output nodes
        # are presumably seeded by the caller -- not visible here).
        self.node_adjoint = None

    def __str__(self):
        return f"Node object of value {float(self.value)}, with {len(self.children)} child nodes."

    def __repr__(self):
        return f"{float(self.value)}, {len(self.children)}"

    def reverse_pass(self):
        """Return this node's adjoint, computing and caching it if needed.

        The adjoint is the sum of ``weight * child.reverse_pass()`` over every
        ``[weight, child]`` pair in ``self.children``; the result is stored in
        ``self.node_adjoint``.

        NOTE: the "already visited" test is a truthiness check, so an adjoint
        of ``None`` *or* ``0`` is (re)computed from the children, while any
        non-zero stored value is returned unchanged. Recursion depth grows
        with the depth of the computational graph.
        """
        if not self.node_adjoint:
            self.node_adjoint = 0
            for weight, child in self.children:
                self.node_adjoint += weight * child.reverse_pass()
        return self.node_adjoint

    # Graph-building helpers
    def _unary(self, value, weight) -> Node:
        """Return a new node holding ``value``, recorded as a child of self.

        ``weight`` is d(value)/d(self.value).
        """
        z = Node(value)
        self.children.append([weight, z])
        return z

    def _binary(self, other: Node, value, self_weight, other_weight) -> Node:
        """Return a new node holding ``value``, recorded as a child of both operands.

        ``self_weight`` / ``other_weight`` are the partial derivatives of
        ``value`` w.r.t. ``self.value`` and ``other.value``. When ``other is
        self`` (e.g. ``x * x``) both entries land in the same list, so the two
        contributions are summed by reverse_pass.
        """
        z = Node(value)
        self.children.append([self_weight, z])
        other.children.append([other_weight, z])
        return z

    @staticmethod
    def _power_rule(base, exponent):
        """Return d(base ** exponent)/d(base) = exponent * base ** (exponent - 1).

        A zero exponent makes the result constant, so its derivative is 0.
        This is special-cased because ``pow(0, -1)`` would otherwise raise
        ZeroDivisionError (or give 0 * inf = nan for numpy floats).
        """
        if exponent == 0:
            return 0.0
        return exponent * pow(base, exponent - 1)

    def __add__(self, other: Node | float | int) -> Node:
        """Add a scalar or a node to a node"""
        if isinstance(other, Node):
            return self._binary(other, self.value + other.value, 1.0, 1.0)
        if isinstance(other, (int, float)):
            return self._unary(self.value + other, 1.0)
        raise TypeError("Operand must be of type int, float, or node.")

    def __radd__(self, other: float | int) -> Node:
        """Add a node to a scalar"""
        return self.__add__(other)

    def __sub__(self, other: Node | float | int) -> Node:
        """Subtract a scalar or node from a node"""
        if isinstance(other, Node):
            return self._binary(other, self.value - other.value, 1.0, -1.0)
        if isinstance(other, (int, float)):
            return self._unary(self.value - other, 1.0)
        raise TypeError("Operand must be of type int, float, or node.")

    def __rsub__(self, other: float | int) -> Node:
        """Subtract a node from a scalar"""
        if isinstance(other, (int, float)):
            return self._unary(other - self.value, -1.0)
        raise TypeError("Operand must be of type int, float, or node.")

    def __mul__(self, other: Node | float | int) -> Node:
        """Multiply a node by another node or a scalar"""
        if isinstance(other, Node):
            return self._binary(other, self.value * other.value, other.value, self.value)
        if isinstance(other, (int, float)):
            return self._unary(self.value * other, other)
        raise TypeError("Operand must be of type int, float, or node.")

    def __rmul__(self, other: float | int) -> Node:
        """Multiply a scalar by a node"""
        return self.__mul__(other)

    def __truediv__(self, other: Node | float | int) -> Node:
        """Divide a node by another node or a scalar"""
        if isinstance(other, Node):
            # The quotient is evaluated first, so a zero divisor raises
            # before any child edge is recorded.
            return self._binary(
                other,
                self.value / other.value,
                1 / other.value,
                -self.value / pow(other.value, 2),
            )
        if isinstance(other, (int, float)):
            return self._unary(self.value / other, 1 / other)
        raise TypeError("Operand must be of type int, float, or node.")

    def __rtruediv__(self, other: float | int) -> Node:
        """Divide a scalar by a node"""
        if isinstance(other, (int, float)):
            return self._unary(other / self.value, -other / pow(self.value, 2))
        raise TypeError("Operand must be of type int, float, or node.")

    def __pow__(self, other: Node | float | int) -> Node:
        """Raise a node to the power of another node or a scalar"""
        if isinstance(other, Node):
            value = pow(self.value, other.value)
            # d/d(exponent) of b ** e is b ** e * ln(b).
            return self._binary(
                other,
                value,
                self._power_rule(self.value, other.value),
                value * np.log(self.value),
            )
        if isinstance(other, (int, float)):
            return self._unary(pow(self.value, other), self._power_rule(self.value, other))
        raise TypeError("Operand must be of type int, float, or node.")

    def __rpow__(self, other: float | int) -> Node:
        """Raise a scalar to the power of a node"""
        if isinstance(other, (int, float)):
            value = pow(other, self.value)
            return self._unary(value, value * np.log(other))
        raise TypeError("Operand must be of type int, float, or node.")

    def __neg__(self: Node) -> Node:
        """Negate a node"""
        return -1 * self

    # Other functions
    def exp(self):
        """Exponential e^x"""
        value = np.exp(self.value)
        # d/dx e^x = e^x, so the value doubles as the weight.
        return self._unary(value, value)

    def sqrt(self):
        """Square root"""
        return self ** 0.5

    # Builtin log functions of multiple bases
    def log(self):
        """Natural log"""
        return self._unary(np.log(self.value), 1.0 / self.value)

    def log10(self):
        """Log base 10"""
        return self._unary(np.log10(self.value), 1.0 / self.value / np.log(10))

    def log2(self):
        """Log base 2"""
        return self._unary(np.log2(self.value), 1.0 / self.value / np.log(2))

    # Trigonometric functions
    def sin(self):
        """Sine; derivative cos(x)."""
        return self._unary(np.sin(self.value), np.cos(self.value))

    def cos(self):
        """Cosine; derivative -sin(x)."""
        return self._unary(np.cos(self.value), -np.sin(self.value))

    def tan(self):
        """Tangent; derivative 1 / cos(x)**2."""
        return self._unary(np.tan(self.value), 1 / pow(np.cos(self.value), 2))

    # Inverse trigonometric functions
    def arcsin(self):
        """Inverse sine; derivative 1 / sqrt(1 - x**2)."""
        return self._unary(np.arcsin(self.value), pow(1 - pow(self.value, 2), -1 / 2))

    def arccos(self):
        """Inverse cosine; derivative -1 / sqrt(1 - x**2)."""
        return self._unary(np.arccos(self.value), -pow(1 - pow(self.value, 2), -1 / 2))

    def arctan(self):
        """Inverse tangent; derivative 1 / (1 + x**2)."""
        return self._unary(np.arctan(self.value), pow(1 + pow(self.value, 2), -1))

    # Hyperbolic functions
    def sinh(self):
        """Hyperbolic sine; derivative cosh(x)."""
        return self._unary(np.sinh(self.value), np.cosh(self.value))

    def cosh(self):
        """Hyperbolic cosine; derivative sinh(x)."""
        return self._unary(np.cosh(self.value), np.sinh(self.value))

    def tanh(self):
        """Hyperbolic tangent; derivative 1 / cosh(x)**2."""
        return self._unary(np.tanh(self.value), pow(np.cosh(self.value), -2))

    # Inverse hyperbolic functions
    def arcsinh(self):
        """Inverse hyperbolic sine; derivative 1 / sqrt(x**2 + 1)."""
        return self._unary(np.arcsinh(self.value), pow(pow(self.value, 2) + 1, -1/2))

    def arccosh(self):
        """Inverse hyperbolic cosine; derivative 1 / sqrt(x**2 - 1)."""
        return self._unary(np.arccosh(self.value), pow(pow(self.value, 2) - 1, -1/2))

    def arctanh(self):
        """Inverse hyperbolic tangent; derivative 1 / (1 - x**2)."""
        return self._unary(np.arctanh(self.value), pow(1 - pow(self.value, 2), -1))