Coverage for jstark / feature_generator.py: 95%
100 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-23 22:34 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-23 22:34 +0000
1"""Base class for all feature generators"""
3from abc import ABCMeta
4import re
5from datetime import date
6from typing import Self
8from pyspark.sql import Column, SparkSession
10from jstark.feature_period import FeaturePeriod, PeriodUnitOfMeasure
11from jstark.exceptions import FeaturePeriodMnemonicIsInvalid
12from jstark.features.feature import Feature
# Everything the ``feature_periods`` argument/setter accepts: a list or set of
# FeaturePeriod objects, of mnemonic strings, a list mixing both, or None
# (None makes the setter fall back to its default period).
_PeriodList = list[FeaturePeriod] | list[str] | list[FeaturePeriod | str]
_PeriodSet = set[FeaturePeriod] | set[str]
FeaturePeriodsType = _PeriodList | _PeriodSet | None
class FeatureGenerator(metaclass=ABCMeta):
    """Base class for all feature generators.

    A generator combines every class in ``FEATURE_CLASSES`` (optionally
    restricted by ``feature_stems``) with every configured feature period and
    exposes the resulting Spark columns through :attr:`features`.

    The ``with_*`` / ``without_*`` methods mutate the generator in place and
    return ``self`` so that calls can be chained.
    """

    # Feature classes this generator can build. Empty on the base class;
    # presumably overridden by concrete generators (not visible in this file).
    FEATURE_CLASSES: set[type["Feature"]] = set[type["Feature"]]()

    def __init__(
        self,
        as_at: date | None = None,
        feature_periods: FeaturePeriodsType = None,
        feature_stems: set[str] | list[str] | None = None,
        first_day_of_week: str | None = None,
        use_absolute_periods: bool = False,
    ) -> None:
        """Create a generator.

        Args:
            as_at: Date handed to every feature; defaults to ``date.today()``.
            feature_periods: FeaturePeriod objects and/or mnemonic strings.
                ``None`` selects the setter's default period.
            feature_stems: Names of the feature classes to generate. ``None``
                or empty means "all classes in ``FEATURE_CLASSES``".
            first_day_of_week: Forwarded unchanged to every feature class.
            use_absolute_periods: Forwarded unchanged to every feature class.

        Raises:
            FeaturePeriodMnemonicIsInvalid: If a mnemonic in
                ``feature_periods`` cannot be parsed.
        """
        self.as_at = date.today() if as_at is None else as_at
        self.feature_periods = feature_periods
        # Always copy: with_feature_stem()/without_feature_stem() mutate this
        # set in place and must never modify a collection owned by the caller.
        self.feature_stems: set[str] = (
            set() if feature_stems is None else set(feature_stems)
        )
        self.first_day_of_week = first_day_of_week
        self.use_absolute_periods = use_absolute_periods

    @property
    def as_at(self) -> date:
        """Date passed as ``as_at`` to every generated feature."""
        return self._as_at

    @as_at.setter
    def as_at(self, value: date) -> None:
        self._as_at = value

    @property
    def feature_periods(self) -> list[FeaturePeriod]:
        """Configured periods, always normalised to FeaturePeriod objects."""
        return self.__feature_periods

    @feature_periods.setter
    def feature_periods(self, value: FeaturePeriodsType) -> None:
        """Normalise ``value`` to a list of FeaturePeriod objects.

        ``None`` is replaced by a single default period,
        ``FeaturePeriod(PeriodUnitOfMeasure.WEEK, 52, 0)``.

        Raises:
            FeaturePeriodMnemonicIsInvalid: If a string entry is not a valid
                mnemonic. Nothing is assigned in that case, so the previous
                value is preserved.
        """
        if value is None:
            value = [FeaturePeriod(PeriodUnitOfMeasure.WEEK, 52, 0)]
        self.__feature_periods = [self._parse_feature_period(fp) for fp in value]

    @staticmethod
    def _parse_feature_period(feature_period: FeaturePeriod | str) -> FeaturePeriod:
        """Return ``feature_period`` as a FeaturePeriod, parsing mnemonic strings.

        A mnemonic is ``<digits><unit><digits>`` where ``<unit>`` is one of the
        ``PeriodUnitOfMeasure`` values; the two numbers are passed to
        ``FeaturePeriod`` as its second and third arguments respectively.

        Raises:
            FeaturePeriodMnemonicIsInvalid: If the string does not match.
        """
        if isinstance(feature_period, FeaturePeriod):
            return feature_period
        units = "".join(e.value for e in PeriodUnitOfMeasure)
        # Based on https://regex101.com/r/Xvf3ey/1, but both numbers are
        # mandatory (\d+ rather than \d*): an empty group would otherwise
        # reach int("") and surface as a bare ValueError instead of the
        # dedicated FeaturePeriodMnemonicIsInvalid.
        matches = re.match(r"^(\d+)([" + units + r"])(\d+)$", feature_period)
        if not matches:
            raise FeaturePeriodMnemonicIsInvalid
        return FeaturePeriod(
            PeriodUnitOfMeasure(matches[2]),
            int(matches[1]),
            int(matches[3]),
        )

    @property
    def features(self) -> list[Column]:
        """Build one column per (feature class, feature period) combination.

        Raises:
            Exception: If ``feature_stems`` names anything that is not the
                class name of a member of ``FEATURE_CLASSES``.
        """
        known_stems = {cls.__name__ for cls in self.FEATURE_CLASSES}
        missing_stems = self.feature_stems - known_stems
        if missing_stems:
            # Report every unknown stem at once (sorted for determinism).
            raise Exception(f"Feature(s) {sorted(missing_stems)} not found")
        # No stems requested means "generate every known feature".
        desired_features = (
            [fc for fc in self.FEATURE_CLASSES if fc.__name__ in self.feature_stems]
            if self.feature_stems
            else self.FEATURE_CLASSES
        )
        generated = [
            feature_class(
                as_at=self.as_at,
                feature_period=fp,
                first_day_of_week=self.first_day_of_week,
                use_absolute_periods=self.use_absolute_periods,
            )
            for feature_class in desired_features
            for fp in self.feature_periods
        ]
        return [feature.column for feature in generated]

    @property
    def references(self) -> dict[str, list[str]]:
        """Map each feature column to the attribute names it references.

        Keys come from ``node.name().head()`` on the column's underlying JVM
        node (presumably the column's output name - confirm against Spark);
        values are produced by :meth:`parse_references`.
        """
        # this function requires a SparkSession in order to do its thing.
        # In normal operation a SparkSession will probably already exist
        # but in unit tests that might not be the case, so getOrCreate one
        SparkSession.builder.getOrCreate()
        return {
            # pylint: disable=protected-access
            node.name().head(): self.parse_references(node.toString())
            for node in [c._jc.node() for c in self.features]  # type: ignore
        }

    @property
    def flattened_references(self) -> set[str]:
        """Union of all names that appear in any value of :attr:`references`."""
        return {item for sublist in self.references.values() for item in sublist}

    @staticmethod
    def parse_references(references: str) -> list[str]:
        """Extract the sorted, de-duplicated ``UnresolvedAttribute(List(...))``
        contents from a node's string representation."""
        return sorted(
            set(re.findall(r"UnresolvedAttribute\(List\(([^)]+)\)", references))
        )

    def with_feature_periods(self, feature_periods: FeaturePeriodsType) -> Self:
        """Replace all feature periods and return ``self``."""
        self.feature_periods = feature_periods
        return self

    def with_feature_period(self, feature_period: FeaturePeriod | str) -> Self:
        """Append one feature period (object or mnemonic) and return ``self``."""
        self.feature_periods = [*self.feature_periods, feature_period]
        return self

    def with_feature_stems(self, feature_stems: list[str]) -> Self:
        """Replace the feature stems with a copy of ``feature_stems``."""
        self.feature_stems = set(feature_stems)
        return self

    def with_feature_stem(self, feature_stem: str) -> Self:
        """Add one feature stem and return ``self``."""
        self.feature_stems.add(feature_stem)
        return self

    def with_first_day_of_week(self, first_day_of_week: str) -> Self:
        """Set ``first_day_of_week`` and return ``self``."""
        self.first_day_of_week = first_day_of_week
        return self

    def with_as_at(self, as_at: date) -> Self:
        """Set ``as_at`` and return ``self``."""
        self.as_at = as_at
        return self

    def with_use_absolute_periods(self, use_absolute_periods: bool) -> Self:
        """Set ``use_absolute_periods`` and return ``self``."""
        self.use_absolute_periods = use_absolute_periods
        return self

    def without_feature_period(self, feature_period: FeaturePeriod | str) -> Self:
        """Remove every period equal to ``feature_period`` and return ``self``.

        Raises:
            FeaturePeriodMnemonicIsInvalid: If a string argument is not a
                valid mnemonic.
        """
        # Parse directly instead of round-tripping through the setter, so the
        # configured periods are never temporarily overwritten.
        unwanted = self._parse_feature_period(feature_period)
        self.feature_periods = [fp for fp in self.feature_periods if fp != unwanted]
        return self

    def without_feature_stem(self, feature_stem: str) -> Self:
        """Remove a feature stem if present (no error if absent)."""
        self.feature_stems.discard(feature_stem)
        return self

    def __repr__(self) -> str:
        periods = sorted([fp.mnemonic for fp in self.feature_periods])
        stems = sorted(self.feature_stems) if self.feature_stems is not None else None
        return (
            f"{self.__class__.__name__}"
            f"(as_at={self.as_at}"
            f", feature_periods={periods}"
            f", feature_stems={stems}"
            f", first_day_of_week={self.first_day_of_week})"
        )