Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

import collections 

import copy 

import functools 

import os 

 

from .cassette import Cassette 

from .serializers import yamlserializer, jsonserializer 

from . import matchers 

from . import filters 

 

 

class VCR(object): 

 

    def __init__(self, serializer='yaml', cassette_library_dir=None, 

                 record_mode="once", filter_headers=(), ignore_localhost=False, 

                 custom_patches=(), filter_query_parameters=(), 

                 filter_post_data_parameters=(), before_record_request=None, 

                 before_record_response=None, ignore_hosts=(), 

                 match_on=('method', 'scheme', 'host', 'port', 'path', 'query'), 

                 before_record=None, inject_cassette=False): 

        self.serializer = serializer 

        self.match_on = match_on 

        self.cassette_library_dir = cassette_library_dir 

        self.serializers = { 

            'yaml': yamlserializer, 

            'json': jsonserializer, 

        } 

        self.matchers = { 

            'method': matchers.method, 

            'uri': matchers.uri, 

            'url': matchers.uri,  # matcher for backwards compatibility 

            'scheme': matchers.scheme, 

            'host': matchers.host, 

            'port': matchers.port, 

            'path': matchers.path, 

            'query': matchers.query, 

            'headers': matchers.headers, 

            'body': matchers.body, 

        } 

        self.record_mode = record_mode 

        self.filter_headers = filter_headers 

        self.filter_query_parameters = filter_query_parameters 

        self.filter_post_data_parameters = filter_post_data_parameters 

        self.before_record_request = before_record_request or before_record 

        self.before_record_response = before_record_response 

        self.ignore_hosts = ignore_hosts 

        self.ignore_localhost = ignore_localhost 

        self.inject_cassette = inject_cassette 

        self._custom_patches = tuple(custom_patches) 

 

    def _get_serializer(self, serializer_name): 

        try: 

            serializer = self.serializers[serializer_name] 

        except KeyError: 

            print("Serializer {0} doesn't exist or isn't registered".format( 

                serializer_name 

            )) 

            raise KeyError 

        return serializer 

 

    def _get_matchers(self, matcher_names): 

        matchers = [] 

        try: 

            for m in matcher_names: 

                matchers.append(self.matchers[m]) 

        except KeyError: 

            raise KeyError( 

                "Matcher {0} doesn't exist or isn't registered".format(m) 

            ) 

        return matchers 

 

    def use_cassette(self, path, with_current_defaults=False, **kwargs): 

        if with_current_defaults: 

            path, config = self.get_path_and_merged_config(path, **kwargs) 

            return Cassette.use(path, **config) 

        # This is made a function that evaluates every time a cassette 

        # is made so that changes that are made to this VCR instance 

        # that occur AFTER the `use_cassette` decorator is applied 

        # still affect subsequent calls to the decorated function. 

        args_getter = functools.partial(self.get_path_and_merged_config, 

                                        path, **kwargs) 

        return Cassette.use_arg_getter(args_getter) 

 

    def get_path_and_merged_config(self, path, **kwargs): 

        serializer_name = kwargs.get('serializer', self.serializer) 

        matcher_names = kwargs.get('match_on', self.match_on) 

        cassette_library_dir = kwargs.get( 

            'cassette_library_dir', 

            self.cassette_library_dir 

        ) 

        if cassette_library_dir: 

            path = os.path.join(cassette_library_dir, path) 

 

        merged_config = { 

            'serializer': self._get_serializer(serializer_name), 

            'match_on': self._get_matchers(matcher_names), 

            'record_mode': kwargs.get('record_mode', self.record_mode), 

            'before_record_request': self._build_before_record_request(kwargs), 

            'before_record_response': self._build_before_record_response( 

                kwargs 

            ), 

            'custom_patches': self._custom_patches + kwargs.get( 

                'custom_patches', () 

            ), 

            'inject': kwargs.get('inject_cassette', self.inject_cassette) 

        } 

        return path, merged_config 

 

    def _build_before_record_response(self, options): 

        before_record_response = options.get( 

            'before_record_response', self.before_record_response 

        ) 

        filter_functions = [] 

        if before_record_response and not isinstance(before_record_response, 

                                                     collections.Iterable): 

            before_record_response = (before_record_response,) 

            for function in before_record_response: 

                filter_functions.append(function) 

        def before_record_response(response): 

            for function in filter_functions: 

                if response is None: 

                    break 

                response = function(response) 

            return response 

        return before_record_response 

 

    def _build_before_record_request(self, options): 

        filter_functions = [] 

        filter_headers = options.get( 

            'filter_headers', self.filter_headers 

        ) 

        filter_query_parameters = options.get( 

            'filter_query_parameters', self.filter_query_parameters 

        ) 

        filter_post_data_parameters = options.get( 

            'filter_post_data_parameters', self.filter_post_data_parameters 

        ) 

        before_record_request = options.get( 

            "before_record_request", options.get("before_record", self.before_record_request) 

        ) 

        ignore_hosts = options.get( 

            'ignore_hosts', self.ignore_hosts 

        ) 

        ignore_localhost = options.get( 

            'ignore_localhost', self.ignore_localhost 

        ) 

        if filter_headers: 

            filter_functions.append(functools.partial(filters.remove_headers, 

                                                      headers_to_remove=filter_headers)) 

        if filter_query_parameters: 

            filter_functions.append(functools.partial(filters.remove_query_parameters, 

                                                      query_parameters_to_remove=filter_query_parameters)) 

        if filter_post_data_parameters: 

            filter_functions.append(functools.partial(filters.remove_post_data_parameters, 

                                                      post_data_parameters_to_remove=filter_post_data_parameters)) 

 

        hosts_to_ignore = list(ignore_hosts) 

        if ignore_localhost: 

            hosts_to_ignore.extend(('localhost', '0.0.0.0', '127.0.0.1')) 

 

        if hosts_to_ignore: 

            hosts_to_ignore = set(hosts_to_ignore) 

            filter_functions.append(self._build_ignore_hosts(hosts_to_ignore)) 

 

        if before_record_request: 

            if not isinstance(before_record_request, collections.Iterable): 

                before_record_request = (before_record_request,) 

            for function in before_record_request: 

                filter_functions.append(function) 

        def before_record_request(request): 

            request = copy.copy(request) 

            for function in filter_functions: 

                if request is None: 

                    break 

                request = function(request) 

            return request 

 

        return before_record_request 

 

    @staticmethod 

    def _build_ignore_hosts(hosts_to_ignore): 

        def filter_ignored_hosts(request): 

            if hasattr(request, 'host') and request.host in hosts_to_ignore: 

                return 

            return request 

        return filter_ignored_hosts 

 

    def register_serializer(self, name, serializer): 

        self.serializers[name] = serializer 

 

    def register_matcher(self, name, matcher): 

        self.matchers[name] = matcher