Coverage for jstark / feature_generator.py: 95%

100 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-23 22:34 +0000

1"""Base class for all feature generators""" 

2 

3from abc import ABCMeta 

4import re 

5from datetime import date 

6from typing import Self 

7 

8from pyspark.sql import Column, SparkSession 

9 

10from jstark.feature_period import FeaturePeriod, PeriodUnitOfMeasure 

11from jstark.exceptions import FeaturePeriodMnemonicIsInvalid 

12from jstark.features.feature import Feature 

13 

# Every form accepted for ``feature_periods``: FeaturePeriod instances and/or
# period mnemonic strings, held in a list or a set (mixed element types are
# allowed in either), or None to let the generator fall back to its default.
FeaturePeriodsType = (
    list[FeaturePeriod]
    | list[str]
    | list[FeaturePeriod | str]
    | set[FeaturePeriod]
    | set[str]
    | set[FeaturePeriod | str]
    | None
)

22 

23 

class FeatureGenerator(metaclass=ABCMeta):
    """Base class for all feature generators.

    A generator combines the feature classes in ``FEATURE_CLASSES``, a list of
    feature periods and an optional set of feature stems (feature class names)
    into a list of pyspark ``Column`` expressions, exposed via ``features``.

    The ``with_*`` / ``without_*`` methods mutate the generator in place and
    return ``self`` so calls can be chained.
    """

    # Feature classes this generator can produce. Empty on the base class;
    # presumably each concrete generator overrides it -- confirm in subclasses.
    FEATURE_CLASSES: set[type["Feature"]] = set[type["Feature"]]()

    def __init__(
        self,
        as_at: date | None = None,
        feature_periods: FeaturePeriodsType = None,
        feature_stems: set[str] | list[str] | None = None,
        first_day_of_week: str | None = None,
        use_absolute_periods: bool = False,
    ) -> None:
        """Create a generator.

        Args:
            as_at: Reference date passed to every feature; defaults to today.
            feature_periods: FeaturePeriods and/or period mnemonics; None
                selects the default period (see the ``feature_periods`` setter).
            feature_stems: Names of the feature classes to generate; empty or
                None means every class in ``FEATURE_CLASSES``.
            first_day_of_week: Passed through unchanged to every feature.
            use_absolute_periods: Passed through unchanged to every feature.

        Raises:
            FeaturePeriodMnemonicIsInvalid: if a period mnemonic is malformed.
        """
        self.as_at = date.today() if as_at is None else as_at
        self.feature_periods = feature_periods
        # Always take a private copy: with_feature_stem / without_feature_stem
        # mutate this set in place and must never alter the caller's collection.
        self.feature_stems = set[str](feature_stems or ())
        self.first_day_of_week = first_day_of_week
        self.use_absolute_periods = use_absolute_periods

    @property
    def as_at(self) -> date:
        """Reference date passed to every generated feature."""
        return self._as_at

    @as_at.setter
    def as_at(self, value: date) -> None:
        self._as_at = value

    @staticmethod
    def _parse_feature_period(feature_period: FeaturePeriod | str) -> FeaturePeriod:
        """Return ``feature_period`` as a FeaturePeriod.

        FeaturePeriod instances are returned unchanged. Strings are parsed as
        mnemonics of the form ``<digits><unit><digits>``, where ``<unit>`` is
        one of the PeriodUnitOfMeasure values (assumed to be single
        characters). The two integers are passed positionally to FeaturePeriod
        after the unit.

        Raises:
            FeaturePeriodMnemonicIsInvalid: if the string is not a well-formed
                mnemonic.
        """
        if isinstance(feature_period, FeaturePeriod):
            return feature_period
        units = "".join(e.value for e in PeriodUnitOfMeasure)
        # Based on https://regex101.com/r/Xvf3ey/1, but digits are now required
        # on both sides of the unit: with ``\d*`` a mnemonic such as "52w"
        # matched and then crashed in int("") with a bare ValueError.
        # fullmatch (unlike ``$``) also rejects a trailing newline.
        matches = re.fullmatch(
            r"(\d+)([" + re.escape(units) + r"])(\d+)", feature_period
        )
        if not matches:
            raise FeaturePeriodMnemonicIsInvalid
        return FeaturePeriod(
            PeriodUnitOfMeasure(matches[2]),
            int(matches[1]),
            int(matches[3]),
        )

    @property
    def feature_periods(self) -> list[FeaturePeriod]:
        """The feature periods, always normalised to FeaturePeriod instances."""
        return self.__feature_periods

    @feature_periods.setter
    def feature_periods(self, value: FeaturePeriodsType) -> None:
        # None selects the default: a single WEEK period built as
        # FeaturePeriod(PeriodUnitOfMeasure.WEEK, 52, 0).
        if value is None:
            value = [FeaturePeriod(PeriodUnitOfMeasure.WEEK, 52, 0)]
        # Parse everything before assigning, so an invalid mnemonic raises
        # FeaturePeriodMnemonicIsInvalid and leaves the current periods intact.
        self.__feature_periods = [self._parse_feature_period(fp) for fp in value]

    @property
    def features(self) -> list[Column]:
        """Build one Column per (feature class, feature period) pair.

        When ``feature_stems`` is non-empty, only classes in FEATURE_CLASSES
        whose ``__name__`` is in ``feature_stems`` are used. Otherwise every
        class in FEATURE_CLASSES is used.

        Raises:
            Exception: if any stem does not name a class in FEATURE_CLASSES;
                the message lists every unknown stem, sorted.
        """
        known_stems = {cls.__name__ for cls in self.FEATURE_CLASSES}
        missing_stems = self.feature_stems - known_stems
        if missing_stems:
            # Report every unknown stem, sorted so the message is deterministic.
            raise Exception(f"Feature(s) {sorted(missing_stems)} not found")
        desired_features = (
            [cls for cls in self.FEATURE_CLASSES if cls.__name__ in self.feature_stems]
            if self.feature_stems
            else self.FEATURE_CLASSES
        )
        return [
            cls(
                as_at=self.as_at,
                feature_period=fp,
                first_day_of_week=self.first_day_of_week,
                use_absolute_periods=self.use_absolute_periods,
            ).column
            for cls in desired_features
            for fp in self.feature_periods
        ]

    @property
    def references(self) -> dict[str, list[str]]:
        """Map each feature column to the attribute names it references.

        Keys come from ``node.name().head()`` of each column's underlying JVM
        expression node (presumably the column alias -- confirm). Values are
        ``parse_references`` applied to that node's string form.
        """
        # This function requires a SparkSession in order to do its thing.
        # In normal operation a SparkSession will probably already exist,
        # but in unit tests that might not be the case, so getOrCreate one.
        SparkSession.builder.getOrCreate()
        return {
            # pylint: disable=protected-access
            node.name().head(): self.parse_references(node.toString())
            for node in [c._jc.node() for c in self.features]  # type: ignore
        }

    @property
    def flattened_references(self) -> set[str]:
        """Union of all referenced names across every feature column."""
        return {item for sublist in self.references.values() for item in sublist}

    @staticmethod
    def parse_references(references: str) -> list[str]:
        """Return the sorted, de-duplicated contents of every
        ``UnresolvedAttribute(List(...))`` occurrence in ``references``.
        """
        return sorted(
            set(re.findall(r"UnresolvedAttribute\(List\(([^)]+)\)", references))
        )

    def with_feature_periods(self, feature_periods: FeaturePeriodsType) -> Self:
        """Replace all feature periods; returns self for chaining."""
        self.feature_periods = feature_periods
        return self

    def with_feature_period(self, feature_period: FeaturePeriod | str) -> Self:
        """Append one feature period (or mnemonic); returns self for chaining."""
        self.feature_periods = [*self.feature_periods, feature_period]
        return self

    def with_feature_stems(self, feature_stems: set[str] | list[str]) -> Self:
        """Replace all feature stems with a copy of ``feature_stems``."""
        self.feature_stems = set[str](feature_stems)
        return self

    def with_feature_stem(self, feature_stem: str) -> Self:
        """Add one feature stem; returns self for chaining."""
        self.feature_stems.add(feature_stem)
        return self

    def with_first_day_of_week(self, first_day_of_week: str) -> Self:
        """Set first_day_of_week; returns self for chaining."""
        self.first_day_of_week = first_day_of_week
        return self

    def with_as_at(self, as_at: date) -> Self:
        """Set as_at; returns self for chaining."""
        self.as_at = as_at
        return self

    def with_use_absolute_periods(self, use_absolute_periods: bool) -> Self:
        """Set use_absolute_periods; returns self for chaining."""
        self.use_absolute_periods = use_absolute_periods
        return self

    def without_feature_period(self, feature_period: FeaturePeriod | str) -> Self:
        """Remove every period equal to ``feature_period`` (a FeaturePeriod or
        a mnemonic); returns self for chaining.

        Raises:
            FeaturePeriodMnemonicIsInvalid: if a malformed mnemonic is given.
        """
        target = self._parse_feature_period(feature_period)
        self.feature_periods = [fp for fp in self.feature_periods if fp != target]
        return self

    def without_feature_stem(self, feature_stem: str) -> Self:
        """Remove a feature stem if present; returns self for chaining."""
        self.feature_stems.discard(feature_stem)
        return self

    def __repr__(self) -> str:
        periods = sorted(fp.mnemonic for fp in self.feature_periods)
        # feature_stems is a public attribute and could have been reassigned
        # to None directly, so guard before sorting.
        stems = sorted(self.feature_stems) if self.feature_stems is not None else None
        return (
            f"{self.__class__.__name__}"
            f"(as_at={self.as_at}"
            f", feature_periods={periods}"
            f", feature_stems={stems}"
            f", first_day_of_week={self.first_day_of_week})"
        )