Coverage for jstark/features/feature.py: 100%

85 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-02-25 20:09 +0000

1"""Feature abstract base class 

2 

3All feature classes are derived from here 

4""" 

5from abc import ABCMeta, abstractmethod 

6from datetime import date, timedelta, datetime 

7from typing import Callable, Dict 

8from dateutil.relativedelta import relativedelta 

9 

10 

11from pyspark.sql import Column 

12import pyspark.sql.functions as f 

13 

14from jstark.feature_period import FeaturePeriod, PeriodUnitOfMeasure 

15from jstark.features.first_and_last_date_of_period import FirstAndLastDateOfPeriod 

16from jstark.exceptions import AsAtIsNotADate 

17 

18 

class Feature(metaclass=ABCMeta):
    """Abstract base class from which all feature classes derive.

    A feature is evaluated relative to an ``as_at`` date and a
    ``FeaturePeriod``. Subclasses supply the column definition, the column
    expression, its default value and a description subject. This base class
    derives the feature name, the period's start/end dates and the metadata
    attached to the output column.
    """

    def __init__(self, as_at: date, feature_period: FeaturePeriod) -> None:
        self.feature_period = feature_period
        # Goes through the ``as_at`` property setter, which validates the type.
        self.as_at = as_at

    @property
    def feature_period(self) -> FeaturePeriod:
        """The period, relative to ``as_at``, that this feature covers."""
        return self.__feature_period

    @feature_period.setter
    def feature_period(self, value: FeaturePeriod) -> None:
        # Stored as-is; no validation is performed here.
        self.__feature_period = value

    @property
    def as_at(self) -> date:
        """The reference date from which the feature period is measured."""
        return self.__as_at

    @as_at.setter
    def as_at(self, value) -> None:
        """Set ``as_at``.

        Raises:
            AsAtIsNotADate: if ``value`` is not an instance of ``date``.
        """
        # NOTE(review): ``datetime`` is a subclass of ``date`` and therefore
        # passes this check. ``end_date`` compares ``as_at`` via min() with a
        # value from FirstAndLastDateOfPeriod, which raises TypeError if one
        # is a datetime and the other a plain date -- confirm callers only
        # pass plain dates.
        if not isinstance(value, date):
            raise AsAtIsNotADate
        self.__as_at = value

    @property
    def feature_name(self) -> str:
        """Name of the output column: ``<ClassName>_<period mnemonic>``."""
        return f"{type(self).__name__}_{self.feature_period.mnemonic}"

    @property
    @abstractmethod
    def column(self) -> Column:
        """Complete definition of the column returned by this feature,
        replete with feature period filtering, metadata, default value
        and alias"""

    @property
    @abstractmethod
    def description_subject(self) -> str:
        """Description of the feature that will be concatenated
        with an explanation of the feature period.
        """

    @property
    def commentary(self) -> str:
        """Free-text commentary stored in the column metadata.

        The base implementation returns a placeholder; subclasses may override.
        """
        return "No commentary supplied"

    @abstractmethod
    def default_value(self) -> Column:
        """Default value of the feature, typically used when zero rows match
        the feature's feature_period
        """

    @abstractmethod
    def column_expression(self) -> Column:
        """The expression that defines the feature"""

    def _n_periods_before_as_at(self, n: int) -> date:
        """Return ``as_at`` moved back by ``n`` units of the period's unit.

        Only the shift for the configured unit of measure is computed. Computing
        every unit's shift up front can raise ValueError for perfectly valid
        periods (e.g. ``relativedelta(years=n)`` for a large day count ``n``
        pushes the year below 1) even though that result would be discarded.

        Args:
            n: number of units (days, weeks, months, quarters or years) to go
                back from ``as_at``.
        """
        unit = self.feature_period.period_unit_of_measure
        if unit == PeriodUnitOfMeasure.DAY:
            return self.as_at - timedelta(days=n)
        if unit == PeriodUnitOfMeasure.WEEK:
            return self.as_at - timedelta(weeks=n)
        if unit == PeriodUnitOfMeasure.MONTH:
            return self.as_at - relativedelta(months=n)
        if unit == PeriodUnitOfMeasure.QUARTER:
            return self.as_at - relativedelta(months=n * 3)
        # Any unit not matched above is treated as YEAR.
        return self.as_at - relativedelta(years=n)

    @property
    def start_date(self) -> date:
        """First date covered by the feature period.

        For DAY periods this is ``as_at`` minus ``feature_period.start`` days.
        For other units, the date ``feature_period.start`` units before
        ``as_at`` is passed to FirstAndLastDateOfPeriod and its
        ``first_date_in_week`` / ``_month`` / ``_quarter`` / ``_year`` is
        returned (presumably the first day of the enclosing period).
        """
        unit = self.feature_period.period_unit_of_measure
        shifted = self._n_periods_before_as_at(self.feature_period.start)
        if unit == PeriodUnitOfMeasure.DAY:
            return shifted
        bounds = FirstAndLastDateOfPeriod(shifted)
        if unit == PeriodUnitOfMeasure.WEEK:
            return bounds.first_date_in_week
        if unit == PeriodUnitOfMeasure.MONTH:
            return bounds.first_date_in_month
        if unit == PeriodUnitOfMeasure.QUARTER:
            return bounds.first_date_in_quarter
        return bounds.first_date_in_year

    @property
    def end_date(self) -> date:
        """Last date covered by the feature period, never later than ``as_at``.

        For DAY periods this is ``as_at`` minus ``feature_period.end`` days.
        For other units, the date ``feature_period.end`` units before
        ``as_at`` is passed to FirstAndLastDateOfPeriod and its
        ``last_date_in_week`` / ``_month`` / ``_quarter`` / ``_year`` is
        used. The result is then capped at ``as_at``.
        """
        unit = self.feature_period.period_unit_of_measure
        shifted = self._n_periods_before_as_at(self.feature_period.end)
        if unit == PeriodUnitOfMeasure.DAY:
            last_day_of_period = shifted
        else:
            bounds = FirstAndLastDateOfPeriod(shifted)
            if unit == PeriodUnitOfMeasure.WEEK:
                last_day_of_period = bounds.last_date_in_week
            elif unit == PeriodUnitOfMeasure.MONTH:
                last_day_of_period = bounds.last_date_in_month
            elif unit == PeriodUnitOfMeasure.QUARTER:
                last_day_of_period = bounds.last_date_in_quarter
            else:
                last_day_of_period = bounds.last_date_in_year
        # min() is used to ensure we don't return a date later than self.as_at
        return min(last_day_of_period, self.as_at)

    @property
    def column_metadata(self) -> Dict[str, str]:
        """Metadata attached to the output column.

        Dates are formatted ``YYYY-MM-DD``. ``generated-at`` uses
        ``datetime.now()``, i.e. the local (naive) time of the machine
        building the column.
        """
        # Each of these properties recomputes the period bounds, so evaluate once.
        start = self.start_date.strftime("%Y-%m-%d")
        end = self.end_date.strftime("%Y-%m-%d")
        return {
            "created-with-love-by": "https://github.com/jamiekt/jstark",
            "start-date": start,
            "end-date": end,
            "description": f"{self.description_subject} between {start} and {end}",
            "generated-at": datetime.now().strftime("%Y-%m-%d"),
            "commentary": self.commentary,
            "name-stem": str(type(self).__name__),
        }

131 

132 

class DerivedFeature(Feature, metaclass=ABCMeta):
    """Feature calculated by combining data that has already been aggregated.

    For example, a derived feature called 'Average Gross Spend Per Basket'
    would be calculated by dividing the total GrossSpend by the number of
    baskets (BasketCount).
    """

    @property
    def column(self) -> Column:
        """``column_expression()`` with nulls replaced by ``default_value()``.

        The result is aliased to ``feature_name`` and carries
        ``column_metadata``. No date filtering is applied in this property.
        """
        expression = self.column_expression()
        fallback = self.default_value()
        with_default = f.coalesce(expression, fallback)
        return with_default.alias(self.feature_name, metadata=self.column_metadata)

145 

146 

class BaseFeature(Feature, metaclass=ABCMeta):
    """Feature calculated by aggregating raw source data.

    That data may have been cleaned and transformed in some way, but
    typically its grain is real occurrences of some activity, e.g. lists of
    grocery transactions, phone calls or journeys.
    """

    def sum_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``f.sum``."""
        return f.sum(column)

    def count_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``f.count``."""
        return f.count(column)

    def count_distinct_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``f.countDistinct``."""
        return f.countDistinct(column)

    def approx_count_distinct_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``f.approx_count_distinct``."""
        return f.approx_count_distinct(column)

    def max_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``f.max``."""
        return f.max(column)

    def min_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``f.min``."""
        return f.min(column)

    @abstractmethod
    def aggregator(self) -> Callable[[Column], Column]:
        """Aggregator function"""

    @property
    def column(self) -> Column:
        """Aggregate ``column_expression()`` over rows inside the feature period.

        A row is inside the period when ``to_date(Timestamp)`` lies between
        ``start_date`` and ``end_date`` inclusive (this assumes the input has
        a ``Timestamp`` column). A null aggregate is replaced by
        ``default_value()``, and the result is aliased to ``feature_name``
        with ``column_metadata`` attached.
        """
        aggregate = self.aggregator()
        event_date = f.to_date(f.col("Timestamp"))
        in_period = (event_date >= f.lit(self.start_date)) & (
            event_date <= f.lit(self.end_date)
        )
        # f.when without .otherwise() yields null outside the period; the
        # built-in Spark aggregators used above (sum, count, min, ...) skip nulls.
        value_in_period = f.when(in_period, self.column_expression())
        aggregated = aggregate(value_in_period)
        return f.coalesce(aggregated, self.default_value()).alias(
            self.feature_name, metadata=self.column_metadata
        )