Coverage for jstark/feature_generator.py: 100%
44 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-25 20:09 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-25 20:09 +0000
1"""Base class for all feature generators
2"""
3from abc import ABCMeta
4import re
5from datetime import date
6from typing import List, Union, Dict
8from pyspark.sql import Column, SparkSession
10from jstark.feature_period import FeaturePeriod, PeriodUnitOfMeasure
11from jstark.exceptions import FeaturePeriodMnemonicIsInvalid
class FeatureGenerator(metaclass=ABCMeta):
    """Base class for all feature generators.

    Subclasses supply ``FEATURE_CLASSES``; this base class builds one
    feature instance per (feature class, feature period) combination.
    """

    def __init__(
        self,
        as_at: date,
        feature_periods: Union[List["FeaturePeriod"], List[str], None] = None,
    ) -> None:
        # sourcery skip: use-named-expression
        # walrus operator not supported until python3.8, we are still supporting 3.7
        """Initialise the generator.

        Args:
            as_at: Reference date that all feature periods are anchored to.
            feature_periods: FeaturePeriod objects and/or mnemonic strings
                of the form ``<digits><unit><digits>`` (e.g. ``"52w0"``).
                Defaults to a single 52-weeks-to-0 period.

        Raises:
            FeaturePeriodMnemonicIsInvalid: If a string mnemonic does not
                match ``<digits><unit><digits>``.
        """
        # Default is constructed here, not in the signature, to avoid a
        # shared mutable default argument.
        if feature_periods is None:
            feature_periods = [FeaturePeriod(PeriodUnitOfMeasure.WEEK, 52, 0)]
        self.as_at = as_at
        period_unit_of_measure_values = "".join(e.value for e in PeriodUnitOfMeasure)
        # \d+ (not \d*) so that mnemonics with a missing number (e.g. "w0"
        # or "52w") raise FeaturePeriodMnemonicIsInvalid below rather than
        # crashing with ValueError from int("").
        # https://regex101.com/r/Xvf3ey/1
        regex = r"^(\d+)([" + period_unit_of_measure_values + r"])(\d+)$"
        _feature_periods = []
        for fp in feature_periods:
            if isinstance(fp, FeaturePeriod):
                _feature_periods.append(fp)
            else:
                matches = re.match(regex, fp)
                if not matches:
                    raise FeaturePeriodMnemonicIsInvalid
                _feature_periods.append(
                    FeaturePeriod(
                        PeriodUnitOfMeasure(matches[2]),
                        int(matches[1]),
                        int(matches[3]),
                    )
                )
        self.feature_periods = _feature_periods

    # would prefer list[Type[Feature]] as type hint but
    # this only works on py3.10 and above
    FEATURE_CLASSES: list

    @property
    def as_at(self) -> date:
        """Date that all feature periods are measured relative to."""
        return self.__as_at

    @as_at.setter
    def as_at(self, value: date) -> None:
        self.__as_at = value

    @property
    def feature_periods(self) -> List["FeaturePeriod"]:
        """The periods that each feature class is generated for."""
        return self.__feature_periods

    @feature_periods.setter
    def feature_periods(self, value: List["FeaturePeriod"]) -> None:
        self.__feature_periods = value

    @property
    def features(self) -> List["Column"]:
        """One Column per (feature class, feature period) combination.

        Iterates feature classes in the outer loop and periods in the
        inner loop, preserving the historical ordering.
        """
        return [
            cls(as_at=self.as_at, feature_period=fp).column
            for cls in self.FEATURE_CLASSES
            for fp in self.feature_periods
        ]

    @property
    def references(self) -> Dict[str, List[str]]:
        """Map each feature's name to the sorted column names it references."""
        # this function requires a SparkSession in order to do its thing.
        # In normal operation a SparkSession will probably already exist
        # but in unit tests that might not be the case, so getOrCreate one
        SparkSession.builder.getOrCreate()
        return {
            expr.name(): self.parse_references(expr.references().toList().toString())
            # pylint: disable=protected-access
            for expr in [c._jc.expr() for c in self.features]  # type: ignore
        }

    @staticmethod
    def parse_references(references: str) -> List[str]:
        """Parse a Scala AttributeSet string, e.g. ``"List('a, 'b)"``,
        into a sorted list of column names."""
        cleaned = references.replace("'", "").replace("List(", "").replace(")", "")
        # strip ALL whitespace, then split on the commas that remain
        return sorted("".join(cleaned.split()).split(","))