Coverage for Adifpy/differentiate/evaluator.py: 72%

29 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-22 19:44 -0500

1"""Automatic Differentiation object""" 

2 

3from typing import Callable 

4 

5import numpy as np 

6 

7from Adifpy.differentiate.forward_mode import forward_mode 

8from Adifpy.differentiate.reverse_mode import reverse_mode 

9 

10 

11class Evaluator: 

12 """AD evaluation object 

13 

14 >>> my_evaluator = Evaluator(lambda x: x*x) 

15 >>> my_evaluator.eval(1) 

16 (1, 2) 

17 >>> my_evaluator.eval(3) 

18 (9, 9) 

19 """ 

20 

21 def __init__(self, fn: Callable): 

22 self.fn = fn 

23 

24 def eval(self, pt, **kwargs): 

25 """Perform AD on this Evaluator's function, at this point 

26 

27 Args: 

28 pt (float | iterable): the point or vector at which to evaluate the function 

29 seed_vector (iterable, optional): the seed vector, if the function has vector input 

30 force_mode (str, optional): either 'forward' or 'reverse' for forcing AD mode 

31  

32 Returns: 

33 ADEvaluated: the evaluated AD object 

34 """ 

35 shape = np.shape(pt) 

36 self.input_dim = 1 if shape == () else shape[0] 

37 

38 # Ensure that a seed vector is provided for vector functions 

39 if self.input_dim != 1 and 'seed_vector' not in kwargs: 

40 raise AttributeError('For vector functions, `seed_vector` argument is required') 

41 elif 'seed_vector' not in kwargs: 

42 kwargs['seed_vector'] = [1] 

43 

44 # Set the output dimension (and ensure the function is valid) 

45 try: 

46 fn_output = self.fn(pt) 

47 

48 # TODO: Check for invalid functions (null returns, etc) 

49 

50 self.output_dim = 1 if type(fn_output) in [int, float] else len(fn_output) 

51 except Exception as error: 

52 raise RuntimeError('Evaluator function failed') from error 

53 

54 # Decide which AD mode to use, either depending on forced user input or optimized for performance 

55 if 'force_mode' in kwargs: 

56 match kwargs['force_mode']: 

57 case 'forward': 

58 differentiator = forward_mode 

59 case 'reverse': 

60 differentiator = reverse_mode 

61 case _: 

62 raise ValueError('`force_mode` argument must be either `forward` or `reverse`') 

63 else: 

64 differentiator = forward_mode if self.input_dim < self.output_dim else reverse_mode 

65 

66 return differentiator(func=self.fn, pt=pt, seed_vector=kwargs['seed_vector'])