Coverage for sleapyfaces/experiment.py: 40%

92 statements  

« prev     ^ index     » next       coverage.py v7.0.2, created at 2023-01-03 12:07 -0800

1from dataclasses import dataclass 

2from sleapyfaces.io import SLEAPanalysis, BehMetadata, VideoMetadata, DAQData 

3from sleapyfaces.structs import FileConstructor, CustomColumn 

4 

5from sleapyfaces.utils import into_trial_format, reduce_daq, flatten_list 

6from sleapyfaces.normalize import mean_center, z_score 

7 

8import pandas as pd 

9import numpy as np 

10 

@dataclass(slots=True)
class Experiment:
    """Class constructor for the Experiment object.

    Args:
        name (str): The name of the experiment.
        files (FileConstructor): The FileConstructor object containing the paths to the experiment files.

    Attributes:
        name (str): The name of the experiment.
        files (FileConstructor): The FileConstructor object containing the paths to the experiment files.
        sleap (SLEAPanalysis): The SLEAPanalysis object containing the SLEAP data.
        beh (BehMetadata): The BehMetadata object containing the behavior metadata.
        video (VideoMetadata): The VideoMetadata object containing the video metadata.
        daq (DAQData): The DAQData object containing the DAQ data.
        numeric_columns (list[str]): A list of the titles of the numeric columns in the SLEAP data.
        rawData (pd.DataFrame): The value of ``sleap.tracks`` captured at construction
            (the same object, not a copy).
        data (pd.DataFrame): Property returning ``sleap.tracks``; assigning to it calls
            ``sleap.append`` with the assigned value.
        custom_columns (pd.DataFrame | None): The non-numeric columns built by ``buildData``
            (None until ``buildData`` runs).
        trialsList (list[pd.DataFrame] | None): The per-trial DataFrames built by ``buildTrials``
            (None until ``buildTrials`` runs).
        trialData (pd.DataFrame | None): All trials concatenated and keyed by trial number
            (None until ``buildTrials`` runs; replaced by ``normalizeTrials``).
    """

    name: str
    files: FileConstructor
    sleap: SLEAPanalysis
    beh: BehMetadata
    video: VideoMetadata
    daq: DAQData
    numeric_columns: list[str]
    rawData: pd.DataFrame
    # ``data`` is intentionally NOT declared as a field: it is the property defined
    # below. Declaring it made dataclass use the property object as the field's
    # default (so the following non-default fields raised TypeError at class
    # creation) and, with slots=True, would have replaced the property by a slot.
    custom_columns: pd.DataFrame | None
    trialsList: list[pd.DataFrame] | None
    trialData: pd.DataFrame | None

    def __init__(self, name: str, files: FileConstructor):
        self.name = name
        self.files = files
        self.sleap = SLEAPanalysis(self.files.sleap.file)
        self.beh = BehMetadata(self.files.beh.file)
        self.video = VideoMetadata(self.files.video.file)
        self.daq = DAQData(self.files.daq.file)
        self.numeric_columns = self.sleap.track_names
        # NOTE(review): this binds the same object as ``sleap.tracks``; whether it
        # stays "raw" depends on whether ``sleap.append`` mutates in place — confirm.
        self.rawData = self.data
        # With slots=True every attribute must be assigned before it is read;
        # these are filled in later by buildData / buildTrials.
        self.custom_columns = None
        self.trialsList = None
        self.trialData = None

    @property
    def data(self) -> pd.DataFrame:
        """The experiment's tracking data (``sleap.tracks``)."""
        return self.sleap.tracks

    @data.setter
    def data(self, value: pd.DataFrame) -> None:
        # Assignment does not replace the tracks; it forwards to ``sleap.append``.
        self.sleap.append(value)

    def buildData(self, CustomColumns: list[CustomColumn]):
        """Builds the data for the experiment.

        Each CustomColumn is built to the number of rows of the SLEAP tracks, a
        "Timestamps" column (milliseconds, ``frame * 1000 / video.fps``) and a
        "Frames" column (0-based frame number) are added, and the selected
        columns are appended to the SLEAP data through the ``data`` setter.

        Args:
            CustomColumns (list[CustomColumn]): A list of the CustomColumn objects to be added to the experiment.

        Raises:
            ValueError: If the columns cannot be appended to the SLEAP data.

        Returns:
            None

        Initializes attributes:
            sleap.tracks (pd.DataFrame): The SLEAP data.
            custom_columns (pd.DataFrame): The non-numeric columns.
        """
        n_frames = len(self.sleap.tracks.index)
        ms_per_frame = (self.video.fps**-1) * 1000

        col_names = []
        blocks = []
        for col in CustomColumns:
            col_names.append(col.ColumnTitle)
            col.buildColumn(n_frames)
            # Re-index by position so the axis=1 concat lines rows up by row number.
            # drop=True returns a new object instead of mutating col.Column in place.
            blocks.append(col.Column.reset_index(drop=True))

        # One vectorised frame instead of one single-row DataFrame per video frame.
        frame_numbers = np.arange(n_frames, dtype=np.int64)
        blocks.append(
            pd.DataFrame(
                {"Timestamps": frame_numbers * ms_per_frame, "Frames": frame_numbers},
                columns=["Timestamps", "Frames"],
            )
        )
        col_names += ["Timestamps", "Frames"]

        self.custom_columns = pd.concat(blocks, axis=1)
        self.data = self.custom_columns.loc[:, col_names]

    def buildTrials(
        self,
        TrackedData: list[str],
        Reduced: list[bool],
        start_buffer: int = 10000,
        end_buffer: int = 13000,
    ):
        """Converts the data into trial by trial format.

        Args:
            TrackedData (list[str]): the list of columns from the DAQ data that signify the START of each trial.
            Reduced (list[bool]): a boolean list with the same length as the TrackedData list that signifies the columns from the tracked data with quick TTL pulses that occur during the trial.
                (e.g. the LED TTL pulse may signify the beginning of a trial, but during the trial the LED turns on and off, so the LED TTL column should be marked as True)
            start_buffer (int, optional): The time in milliseconds you want to capture before the trial starts. Defaults to 10000 (i.e. 10 seconds).
            end_buffer (int, optional): The time in milliseconds you want to capture after the trial starts. Defaults to 13000 (i.e. 13 seconds).

        Raises:
            ValueError: if the length of the TrackedData and Reduced lists are not equal,
                if no non-zero event times are found in the TrackedData columns, or if
                no trial DataFrames could be built.
            TypeError: if buildData has not been run (no "Timestamps" column exists yet).

        Initializes attributes:
            trialsList (list[pd.DataFrame]): the DataFrames of the individual trials.
            trialData (pd.DataFrame): the trials concatenated along the rows, with the
                trial number (0..n-1) as the outer index level.
        """
        if len(Reduced) != len(TrackedData):
            raise ValueError(
                "The number of Reduced arguments must be equal to the number of TrackedData arguments. NOTE: If you do not want to reduce the data, pass in a list of False values."
            )
        if self.custom_columns is None:
            raise TypeError(
                "The timestamps have not been built yet. Make sure you first run the buildData method."
            )

        timestamps = self.custom_columns.loc[:, "Timestamps"].to_numpy(dtype=np.float64)

        starts: list[int] = []
        ends: list[int] = []
        for column, reduce_column in zip(TrackedData, Reduced):
            for event_time in self._event_times(column, reduce_column):
                # Window: frame nearest (t - start_buffer) up to one past the frame
                # nearest (t + end_buffer). NOTE(review): assumes DAQ times are in the
                # same millisecond scale as the frame timestamps — confirm.
                starts.append(self._nearest_index(timestamps, event_time - start_buffer))
                ends.append(self._nearest_index(timestamps, event_time + end_buffer) + 1)

        if not starts:
            raise ValueError(
                "No non-zero event times were found in the DAQ data for the given TrackedData columns."
            )

        # Deduplicate (start, end) PAIRS: calling np.unique on each side separately
        # can drop a different number of duplicates from each and misalign windows.
        windows = np.unique(np.array([starts, ends], dtype=np.int64).T, axis=0)
        start_indices = windows[:, 0].copy()
        end_indices = windows[:, 1].copy()

        trials = into_trial_format(
            self.sleap.tracks,
            self.beh.cache.loc[:, "trialArray"],
            start_indices,
            end_indices,
        )
        # NOTE(review): into_trial_format apparently may yield non-DataFrame entries
        # (e.g. for skipped trials); only DataFrames are kept — confirm.
        self.trialsList = [trial for trial in trials if isinstance(trial, pd.DataFrame)]
        if not self.trialsList:
            raise ValueError("No trials could be built from the tracked data.")
        self.trialData = pd.concat(
            self.trialsList, axis=0, keys=list(range(len(self.trialsList)))
        )

    def _event_times(self, column: str, reduce: bool) -> np.ndarray:
        """Return the non-zero event times recorded in DAQ ``column`` as float64.

        When ``reduce`` is true the times are passed through ``reduce_daq``
        (per the buildTrials docstring, meant for columns with repeated TTL
        pulses within a trial).
        """
        times = pd.Series(self.daq.cache.loc[:, column])
        times = times[times != 0]
        if reduce:
            times = np.array(reduce_daq(times.to_list()), dtype=np.float64)
        else:
            # Missing values become 0 here and are dropped by the filter below.
            times = times.to_numpy(dtype=np.float64, na_value=0)
        return times[times != 0]

    @staticmethod
    def _nearest_index(timestamps: np.ndarray, target: float) -> int:
        """Index of the timestamp closest to ``target`` (the first one on ties)."""
        return int(np.absolute(timestamps - target).argmin())

    def normalizeTrials(self):
        """Normalizes the trial data.

        Each DataFrame in ``trialsList`` is mean-centered on ``numeric_columns``
        (via ``mean_center``), the centered trials are concatenated with the
        trial number as the outer index level, and the result is z-scored on
        ``numeric_columns`` (via ``z_score``). ``trialsList`` is left unchanged.

        Returns:
            None

        Initializes attributes:
            trialData (pd.DataFrame): the normalized, concatenated trial data.

        Raises:
            TypeError: if ``trialsList`` is not a non-empty list of DataFrames
                (i.e. buildTrials has not been run successfully).
        """
        trials_list = self.trialsList
        if not (
            isinstance(trials_list, list)
            and trials_list
            and isinstance(trials_list[0], pd.DataFrame)
        ):
            raise TypeError("The data is not in the correct format. Make sure you first run the buildTrials method.")
        centered = [mean_center(trial, self.numeric_columns) for trial in trials_list]
        self.trialData = pd.concat(centered, axis=0, keys=list(range(len(centered))))
        self.trialData = z_score(self.trialData, self.numeric_columns)