Coverage for Users/jsd/Library/CloudStorage/OneDrive-SimonFraserUniversity(1sfu)/projects/thztools/tests/consistency_test.py: 100%
60 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-12 16:26 -0700
1import pathlib
3import h5py # type: ignore
4import numpy as np
5import pytest
7# import matplotlib.pyplot as plt
8from thztools import (
9 costfunlsq,
10 epswater,
11 fftfreq,
12 noiseamp,
13 noisevar,
14 shiftmtx,
15 tdnll,
16 tdnoisefit,
17 tdtf,
18 thzgen,
19)
def tdnll_alt(*args):
    """Call ``tdnll`` with all of its ``fix_*`` flags turned off.

    Used in place of ``tdnll`` when comparing against the MATLAB reference
    outputs, so that the call signature is purely positional.

    Parameters
    ----------
    *args
        Positional arguments forwarded unchanged to ``tdnll``.

    Returns
    -------
    object
        Whatever ``tdnll`` returns for these arguments.
    """
    # Every fix_* keyword is forced to False (nothing is held fixed).
    free_params = dict.fromkeys(
        ("fix_logv", "fix_mu", "fix_a", "fix_eta"), False
    )
    return tdnll(*args, **free_params)
def tdnoisefit_alt(*args):
    """Run ``tdnoisefit`` with a fixed option set and return a flat list.

    Parameters
    ----------
    *args
        Positional arguments forwarded unchanged to ``tdnoisefit``.

    Returns
    -------
    list
        ``[var, mu, a, eta]`` read, in that order, from the first of the
        three values returned by ``tdnoisefit``. The other two returned
        values are discarded.
    """
    options = {
        "fix_v": False,
        "fix_mu": False,
        "fix_a": False,
        "fix_eta": False,
        # ``a`` is ignored while ``eta`` is not.
        "ignore_a": True,
        "ignore_eta": False,
    }
    # Three-way unpacking is kept deliberately: it fails loudly if
    # tdnoisefit ever returns a different number of values.
    estimates, _, _ = tdnoisefit(*args, **options)
    return [estimates[key] for key in ("var", "mu", "a", "eta")]
# Map every function name that can appear in the MAT-file's "Set" structure
# to the Python callable whose output is compared against MATLAB. ``tdnll``
# and ``tdnoisefit`` are routed through the ``*_alt`` wrappers defined above,
# which fix their keyword options.
FUNC_DICT = dict(
    [
        ("fftfreq", fftfreq),
        ("epswater", epswater),
        ("thzgen", thzgen),
        ("noisevar", noisevar),
        ("noiseamp", noiseamp),
        ("shiftmtx", shiftmtx),
        ("tdtf", tdtf),
        ("costfunlsq", costfunlsq),
        ("tdnll", tdnll_alt),
        ("tdnoisefit", tdnoisefit_alt),
    ]
)
# Locate the MAT-file of MATLAB reference results: it lives in a "matlab"
# subdirectory of the directory containing this test module.
cur_path = pathlib.Path(__file__).parent.resolve()
f_path = cur_path.joinpath("matlab", "thztools_test_data.mat")
def tfun_test(_theta, _w):
    """Transfer-function model passed to ``tdtf`` and ``costfunlsq``.

    Computes ``_theta[0] * exp(-1j * _w * _theta[1])``: an amplitude factor
    ``_theta[0]`` times a phase that is linear in ``_w`` with slope
    ``-_theta[1]``.

    Parameters
    ----------
    _theta : sequence
        Model parameters; only ``_theta[0]`` and ``_theta[1]`` are used.
    _w : scalar or array_like
        Angular-frequency-like argument; broadcast through NumPy.

    Returns
    -------
    complex or numpy.ndarray
        The model evaluated at ``_w``.
    """
    amplitude = _theta[0]
    phase_slope = _theta[1]
    phase = -1j * _w * phase_slope
    return amplitude * np.exp(phase)
def _read_matlab_value(f_obj, ref):
    """Dereference one stored MATLAB value and convert it to NumPy form.

    Parameters
    ----------
    f_obj : h5py.File
        The open MAT-file.
    ref
        HDF5 object reference pointing at the stored value.

    Returns
    -------
    numpy.ndarray or scalar
        The value with singleton dimensions squeezed out, transposed, reduced
        to a scalar if zero-dimensional, and converted to complex if it was
        stored as a ``real``/``imag`` record.
    """
    # MATLAB apparently writes 2D arrays to HDF5 files as transposed C-order
    # arrays, so we need to transpose them back after reading them in.
    value = np.squeeze(f_obj[ref][()]).T
    # Convert scalar arrays to scalars
    if value.shape == ():
        value = value[()]
    # The MAT-file stores complex numbers as tuples with a composite dtype.
    # Convert these to NumPy complex dtypes. ``getattr`` guards against a
    # scalar that has no dtype (e.g. a bare HDF5 object reference).
    field_names = getattr(getattr(value, "dtype", None), "names", None)
    if (
        field_names is not None
        and "real" in field_names
        and "imag" in field_names
    ):
        value = value["real"] + 1j * value["imag"]
    return value


def _read_matlab_list(f_obj, ref):
    """Dereference *ref* to an array of references and decode each element."""
    return [
        _read_matlab_value(f_obj, val_ref)
        for val_ref in f_obj[ref][()].flatten()
    ]


# Read test array from MAT-file
def get_matlab_tests():
    """Load every MATLAB reference test case from the MAT-file at ``f_path``.

    The MAT-file stores a structure named "Set" with a field for each
    function in the test set; the field names are also the function names.
    Each field holds HDF5 datasets "args" and "out" whose elements reference
    the MATLAB inputs and outputs for one test configuration.

    Inputs and outputs are decoded identically (squeeze, transpose,
    scalar and complex conversion), so complex-valued inputs are converted
    just like complex-valued outputs.

    Returns
    -------
    list
        One ``[func_name, args, outs]`` entry per test configuration, where
        ``args`` and ``outs`` are lists of decoded values.
    """
    test_list = []
    with h5py.File(f_path, "r") as f_obj:
        test_set = f_obj["Set"]
        for func_name in test_set.keys():
            # The "[()]" index converts the HDF5 dataset arrays to NumPy
            # arrays for easier manipulation, such as flattening.
            arg_refs = test_set[func_name]["args"][()].flatten()
            out_refs = test_set[func_name]["out"][()].flatten()
            for arg_ref, out_ref in zip(arg_refs, out_refs):
                args_val_list = _read_matlab_list(f_obj, arg_ref)
                out_val_list = _read_matlab_list(f_obj, out_ref)
                test_list.append([func_name, args_val_list, out_val_list])
    return test_list
@pytest.fixture(params=get_matlab_tests(), ids=lambda case: case[0])
def get_test(request):
    """Supply one MATLAB reference case per parametrized test run.

    Each parameter is a ``[func_name, args, matlab_out]`` list produced by
    ``get_matlab_tests``. The MATLAB function name is used as the pytest
    test ID, so a failing case identifies which function it exercised
    instead of carrying only a positional ``get_test<N>`` ID.
    """
    return request.param
def test_matlab_result(get_test):
    """Check a Python function's output against the stored MATLAB output.

    ``get_test`` supplies ``[func_name, args, matlab_out]``. The Python
    counterpart is looked up in ``FUNC_DICT`` and called with ``args``
    (preceded by ``tfun_test`` for the transfer-function routines). Only
    the first MATLAB output is compared.
    """
    func_name, args, matlab_out = get_test
    func = FUNC_DICT[func_name]
    if func_name in ("tdtf", "costfunlsq"):
        python_out = func(tfun_test, *args)
    else:
        python_out = func(*args)
    # These functions return a sequence of outputs; only the first element
    # is compared with the MATLAB reference.
    if func_name in ("thzgen", "tdnll", "tdnoisefit"):
        python_out = python_out[0]
    if func_name == "tdnoisefit":
        # tdnoisefit is checked with a fixed, much looser absolute tolerance.
        np.testing.assert_allclose(matlab_out[0], python_out, atol=1e-2)
        return
    # Set absolute tolerance equal to 2 * epsilon for the output dtype.
    # ``np.asarray`` also handles plain Python scalars (no ``.dtype``), and
    # non-floating results fall back to float64 because ``np.finfo`` only
    # accepts inexact dtypes.
    out_dtype = np.asarray(python_out).dtype
    if not np.issubdtype(out_dtype, np.inexact):
        out_dtype = np.dtype(np.float64)
    np.testing.assert_allclose(
        matlab_out[0], python_out, atol=2 * np.finfo(out_dtype).eps
    )