Coverage for jstark/features/feature.py: 100%
85 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-25 20:09 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-25 20:09 +0000
1"""Feature abstract base class
3All feature classes are derived from here
4"""
5from abc import ABCMeta, abstractmethod
6from datetime import date, timedelta, datetime
7from typing import Callable, Dict
8from dateutil.relativedelta import relativedelta
11from pyspark.sql import Column
12import pyspark.sql.functions as f
14from jstark.feature_period import FeaturePeriod, PeriodUnitOfMeasure
15from jstark.features.first_and_last_date_of_period import FirstAndLastDateOfPeriod
16from jstark.exceptions import AsAtIsNotADate
class Feature(metaclass=ABCMeta):
    """Abstract base class from which every feature derives.

    A feature is defined relative to an ``as_at`` date and a ``FeaturePeriod``.
    The period's unit of measure and its ``start``/``end`` offsets (counted
    backwards from ``as_at``) determine the inclusive window
    ``start_date``..``end_date`` that the feature covers.
    """

    def __init__(self, as_at: date, feature_period: FeaturePeriod) -> None:
        self.feature_period = feature_period
        self.as_at = as_at

    @property
    def feature_period(self) -> FeaturePeriod:
        """Period (unit of measure and start/end offsets) this feature covers."""
        return self.__feature_period

    @feature_period.setter
    def feature_period(self, value) -> None:
        self.__feature_period = value

    @property
    def as_at(self) -> date:
        """Reference date from which the feature period is counted backwards."""
        return self.__as_at

    @as_at.setter
    def as_at(self, value) -> None:
        """Set the reference date.

        Raises:
            AsAtIsNotADate: if ``value`` is not a ``date`` instance. Note that
                ``datetime`` subclasses ``date`` and is therefore accepted.
        """
        if not isinstance(value, date):
            raise AsAtIsNotADate
        self.__as_at = value

    @property
    def feature_name(self) -> str:
        """Column name: the concrete class name plus the period's mnemonic."""
        return f"{type(self).__name__}_{self.feature_period.mnemonic}"

    @property
    @abstractmethod
    def column(self) -> Column:
        """Complete definition of the column returned by this feature,
        replete with feature period filtering, metadata, default value
        and alias"""

    @property
    @abstractmethod
    def description_subject(self) -> str:
        """Description of the feature that will be concatenated
        with an explanation of the feature period.
        """

    @property
    def commentary(self) -> str:
        """Free-text commentary for the column metadata; override to supply one."""
        return "No commentary supplied"

    @abstractmethod
    def default_value(self) -> Column:
        """Default value of the feature, typically used when zero rows match
        the feature's feature_period
        """

    @abstractmethod
    def column_expression(self) -> Column:
        """The expression that defines the feature"""

    def _period_boundary(self, periods_ago: int, *, first: bool) -> date:
        """Return a boundary date of the period ``periods_ago`` units before as_at.

        For DAY periods this is simply ``as_at`` minus ``periods_ago`` days.
        For other units, ``as_at`` is shifted back by ``periods_ago`` units and
        the first (``first=True``) or last (``first=False``) date of the
        enclosing week/month/quarter/year is returned. Any unit that is not
        DAY, WEEK, MONTH or QUARTER is treated as YEAR.

        Only the offset for the selected unit is computed. Computing every
        unit's offset up front could raise for a large-but-valid offset in a
        unit that is not even used.
        """
        unit = self.feature_period.period_unit_of_measure
        if unit == PeriodUnitOfMeasure.DAY:
            return self.as_at - timedelta(days=periods_ago)
        if unit == PeriodUnitOfMeasure.WEEK:
            shifted = self.as_at - timedelta(weeks=periods_ago)
            bounds = FirstAndLastDateOfPeriod(shifted)
            return bounds.first_date_in_week if first else bounds.last_date_in_week
        if unit == PeriodUnitOfMeasure.MONTH:
            shifted = self.as_at - relativedelta(months=periods_ago)
            bounds = FirstAndLastDateOfPeriod(shifted)
            return bounds.first_date_in_month if first else bounds.last_date_in_month
        if unit == PeriodUnitOfMeasure.QUARTER:
            shifted = self.as_at - relativedelta(months=periods_ago * 3)
            bounds = FirstAndLastDateOfPeriod(shifted)
            return (
                bounds.first_date_in_quarter if first else bounds.last_date_in_quarter
            )
        shifted = self.as_at - relativedelta(years=periods_ago)
        bounds = FirstAndLastDateOfPeriod(shifted)
        return bounds.first_date_in_year if first else bounds.last_date_in_year

    @property
    def start_date(self) -> date:
        """First date (inclusive) of the feature's window."""
        return self._period_boundary(self.feature_period.start, first=True)

    @property
    def end_date(self) -> date:
        """Last date (inclusive) of the feature's window, never later than as_at."""
        last_day_of_period = self._period_boundary(
            self.feature_period.end, first=False
        )
        # min() is used to ensure we don't return a date later than self.as_at
        return min(last_day_of_period, self.as_at)

    @property
    def column_metadata(self) -> Dict[str, str]:
        """Metadata describing the feature column (window, description, origin)."""
        # Evaluate each boundary once rather than recomputing it for every key
        # (and the description) that needs it.
        start = self.start_date.strftime("%Y-%m-%d")
        end = self.end_date.strftime("%Y-%m-%d")
        return {
            "created-with-love-by": "https://github.com/jamiekt/jstark",
            "start-date": start,
            "end-date": end,
            "description": f"{self.description_subject} between {start} and {end}",
            "generated-at": datetime.now().strftime("%Y-%m-%d"),
            "commentary": self.commentary,
            "name-stem": str(type(self).__name__),
        }
class DerivedFeature(Feature, metaclass=ABCMeta):
    """Feature built by combining values that are already aggregated.

    Rather than scanning raw rows, a derived feature operates on the
    outputs of other features. For instance, 'Average Gross Spend Per
    Basket' is obtained by dividing the total GrossSpend by the number of
    baskets (BasketCount).
    """

    @property
    def column(self) -> Column:
        # Substitute the feature's default wherever the expression yields
        # null, then name the column and attach its descriptive metadata.
        with_default = f.coalesce(self.column_expression(), self.default_value())
        return with_default.alias(self.feature_name, metadata=self.column_metadata)
class BaseFeature(Feature, metaclass=ABCMeta):
    """Feature computed by aggregating raw source data.

    The source may already have been cleaned or transformed, but its grain
    is usually individual real-world occurrences of some activity, e.g.
    grocery transactions, phone calls or journeys.
    """

    def sum_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``sum``."""
        return f.sum(column)

    def count_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``count``."""
        return f.count(column)

    def count_distinct_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with an exact distinct count."""
        return f.countDistinct(column)

    def approx_count_distinct_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with an approximate distinct count."""
        return f.approx_count_distinct(column)

    def max_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``max``."""
        return f.max(column)

    def min_aggregator(self, column: Column) -> Column:
        """Aggregate ``column`` with ``min``."""
        return f.min(column)

    @abstractmethod
    def aggregator(self) -> Callable[[Column], Column]:
        """Aggregator function"""

    @property
    def column(self) -> Column:
        aggregate = self.aggregator()
        # Only rows whose Timestamp falls on a date inside the inclusive
        # [start_date, end_date] window contribute; others become null.
        event_date = f.to_date(f.col("Timestamp"))
        in_period = (event_date >= f.lit(self.start_date)) & (
            event_date <= f.lit(self.end_date)
        )
        aggregated = aggregate(f.when(in_period, self.column_expression()))
        # Fall back to the default when nothing in the window was aggregated.
        return f.coalesce(aggregated, self.default_value()).alias(
            self.feature_name, metadata=self.column_metadata
        )