Coverage for Users/jsd/Library/CloudStorage/OneDrive-SimonFraserUniversity(1sfu)/projects/thztools/tests/consistency_test.py: 100%

60 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-12 16:26 -0700

1import pathlib 

2 

3import h5py # type: ignore 

4import numpy as np 

5import pytest 

6 

7# import matplotlib.pyplot as plt 

8from thztools import ( 

9 costfunlsq, 

10 epswater, 

11 fftfreq, 

12 noiseamp, 

13 noisevar, 

14 shiftmtx, 

15 tdnll, 

16 tdnoisefit, 

17 tdtf, 

18 thzgen, 

19) 

20 

21 

def tdnll_alt(*args):
    """Call ``tdnll`` positionally with every ``fix_*`` option disabled.

    The positional arguments are forwarded unchanged; ``fix_logv``,
    ``fix_mu``, ``fix_a`` and ``fix_eta`` are all passed as ``False``.
    The return value of ``tdnll`` is returned as-is.
    """
    return tdnll(
        *args,
        fix_logv=False,
        fix_mu=False,
        fix_a=False,
        fix_eta=False,
    )

26 

27 

def tdnoisefit_alt(*args):
    """Call ``tdnoisefit`` positionally and return its fit values as a list.

    The positional arguments are forwarded unchanged together with a fixed
    set of keyword options. ``tdnoisefit`` is expected to return three
    values; only the first (a mapping) is used, and its ``"var"``, ``"mu"``,
    ``"a"`` and ``"eta"`` entries are returned, in that order.
    """
    options = {
        "fix_v": False,
        "fix_mu": False,
        "fix_a": False,
        "fix_eta": False,
        "ignore_a": True,
        "ignore_eta": False,
    }
    fit, _, _ = tdnoisefit(*args, **options)
    return [fit[key] for key in ("var", "mu", "a", "eta")]

34 

35 

# Map each function name used in the MAT-file onto the Python callable
# that the test should invoke for it.
FUNC_DICT = {
    # Library functions that are called directly
    "costfunlsq": costfunlsq,
    "epswater": epswater,
    "fftfreq": fftfreq,
    "noiseamp": noiseamp,
    "noisevar": noisevar,
    "shiftmtx": shiftmtx,
    "tdtf": tdtf,
    "thzgen": thzgen,
    # Library functions that are called through a local wrapper
    "tdnll": tdnll_alt,
    "tdnoisefit": tdnoisefit_alt,
}

49 

# Locate the MAT-file with the MATLAB reference data relative to this module
cur_path = pathlib.Path(__file__).parents[0].resolve()
f_path = cur_path.joinpath("matlab", "thztools_test_data.mat")

53 

54 

def tfun_test(_theta, _w):
    """Evaluate the test transfer function ``theta[0] * exp(-1j * w * theta[1])``.

    ``_theta`` is indexed at positions 0 and 1; ``_w`` may be anything that
    NumPy can multiply by a complex scalar.
    """
    scale = _theta[0]
    phase_factor = np.exp(-1j * _w * _theta[1])
    return scale * phase_factor

57 

58 

59# Read test array from MAT-file 

def _read_matlab_value(h5_file, ref):
    """Dereference one HDF5 object reference and return a NumPy value."""
    # MATLAB apparently writes 2D arrays to HDF5 files as transposed C-order
    # arrays, so transpose them back after dropping singleton dimensions.
    value = np.squeeze(h5_file[ref][()]).T
    # Convert scalar (zero-dimensional) arrays to scalars
    if value.shape == ():
        value = value[()]
    # The MAT-file stores complex numbers with a composite ("real", "imag")
    # dtype. Convert these to NumPy complex values.
    field_names = value.dtype.names
    if (
        field_names is not None
        and "real" in field_names
        and "imag" in field_names
    ):
        value = value["real"] + 1j * value["imag"]
    return value


def _read_matlab_values(h5_file, list_ref):
    """Return the decoded values referenced by the dataset behind *list_ref*."""
    value_refs = h5_file[list_ref][()].flatten()
    return [_read_matlab_value(h5_file, value_ref) for value_ref in value_refs]


def get_matlab_tests():
    """Read every MATLAB test configuration from the MAT-file at ``f_path``.

    The MAT-file stores a structure named "Set" with one field per function
    in the test set; the field names are also the function names. For each
    function, the MATLAB inputs and outputs of every test configuration are
    stored in the HDF5 dataset arrays "args" and "out", respectively.

    Inputs and outputs are decoded identically: singleton dimensions are
    removed, arrays are transposed back from MATLAB's layout, scalar arrays
    become scalars, and composite real/imag values become complex values.

    Returns
    -------
    list
        One ``[func_name, args_val_list, out_val_list]`` entry per test
        configuration.
    """
    test_list = []
    with h5py.File(f_path, "r") as f_obj:
        for func_name, func_group in f_obj["Set"].items():
            # The "[()]" index converts the HDF5 dataset arrays to NumPy
            # arrays for easier manipulation, such as flattening.
            arg_refs = func_group["args"][()].flatten()
            out_refs = func_group["out"][()].flatten()
            for arg_ref, out_ref in zip(arg_refs, out_refs):
                args_val_list = _read_matlab_values(f_obj, arg_ref)
                out_val_list = _read_matlab_values(f_obj, out_ref)
                test_list.append([func_name, args_val_list, out_val_list])
    return test_list

109 

110 

@pytest.fixture(params=get_matlab_tests())
def get_test(request):
    """Provide one entry of ``get_matlab_tests()`` per parametrized run."""
    test_case = request.param
    return test_case

114 

115 

def test_matlab_result(get_test):
    """Compare the Python result with the stored MATLAB result for one case.

    ``get_test`` is a ``[func_name, args, matlab_out]`` triple. The Python
    function is looked up in ``FUNC_DICT`` and called with the stored
    arguments, and its result is compared with the first MATLAB output.
    """
    func_name, args, matlab_out = get_test
    func = FUNC_DICT[func_name]

    # "tdtf" and "costfunlsq" take the transfer function as their first
    # argument, ahead of the stored MATLAB arguments.
    if func_name in ("tdtf", "costfunlsq"):
        call_args = (tfun_test, *args)
    else:
        call_args = tuple(args)
    python_out = func(*call_args)

    # Only the first output of these Python functions is compared
    if func_name in ("thzgen", "tdnll", "tdnoisefit"):
        python_out = python_out[0]

    if func_name == "tdnoisefit":
        atol = 1e-2
    else:
        # Set absolute tolerance equal to 2 * epsilon for the array dtype
        atol = 2 * np.finfo(python_out.dtype).eps

    np.testing.assert_allclose(matlab_out[0], python_out, atol=atol)