Coverage for jstark/feature_generator.py: 100%

44 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-02-25 20:09 +0000

1"""Base class for all feature generators 

2""" 

3from abc import ABCMeta 

4import re 

5from datetime import date 

6from typing import List, Union, Dict 

7 

8from pyspark.sql import Column, SparkSession 

9 

10from jstark.feature_period import FeaturePeriod, PeriodUnitOfMeasure 

11from jstark.exceptions import FeaturePeriodMnemonicIsInvalid 

12 

13 

class FeatureGenerator(metaclass=ABCMeta):
    """Base class for all feature generators.

    Subclasses supply ``FEATURE_CLASSES``, a list of feature classes; this
    base class combines every feature class with every requested
    ``FeaturePeriod`` and exposes the resulting pyspark ``Column`` objects
    via the ``features`` property.
    """

    # would prefer list[Type[Feature]] as type hint but
    # this only works on py3.10 and above
    FEATURE_CLASSES: list

    def __init__(
        self,
        as_at: date,
        feature_periods: Union[List[FeaturePeriod], List[str], None] = None,
    ) -> None:
        """Store the as-at date and resolve the requested feature periods.

        Args:
            as_at: Reference date that all feature periods are relative to.
            feature_periods: ``FeaturePeriod`` objects and/or mnemonic
                strings of the form "<start><unit><end>" (e.g. "52w0").
                Defaults to a single period covering the most recent
                52 weeks.

        Raises:
            FeaturePeriodMnemonicIsInvalid: If a mnemonic string does not
                match the expected "<digits><unit><digits>" pattern.
        """
        # sourcery skip: use-named-expression
        # walrus operator not supported until python3.8, we are still
        # supporting 3.7
        if feature_periods is None:
            # Built per-call rather than as a default argument value so no
            # mutable default list is shared between instances (bugbear B006).
            feature_periods = [FeaturePeriod(PeriodUnitOfMeasure.WEEK, 52, 0)]
        self.as_at = as_at
        period_unit_of_measure_values = "".join(e.value for e in PeriodUnitOfMeasure)
        # Mnemonics look like "52w0": digits, a unit-of-measure letter, digits.
        # \d+ (not \d*) so that a mnemonic with missing digits (e.g. "w0")
        # fails the match and raises FeaturePeriodMnemonicIsInvalid instead
        # of a confusing ValueError from int("").
        # https://regex101.com/r/Xvf3ey/1
        regex = r"^(\d+)([" + period_unit_of_measure_values + r"])(\d+)$"
        _feature_periods = []
        for fp in feature_periods:
            if isinstance(fp, FeaturePeriod):
                _feature_periods.append(fp)
            else:
                matches = re.match(regex, fp)
                if not matches:
                    raise FeaturePeriodMnemonicIsInvalid
                _feature_periods.append(
                    FeaturePeriod(
                        PeriodUnitOfMeasure(matches[2]),
                        int(matches[1]),
                        int(matches[3]),
                    )
                )
        self.feature_periods = _feature_periods

    @property
    def as_at(self) -> date:
        """Date that all feature periods are calculated relative to."""
        return self.__as_at

    @as_at.setter
    def as_at(self, value: date) -> None:
        self.__as_at = value

    @property
    def feature_periods(self) -> List[FeaturePeriod]:
        """The periods for which features will be generated."""
        return self.__feature_periods

    @feature_periods.setter
    def feature_periods(self, value: List[FeaturePeriod]) -> None:
        self.__feature_periods = value

    @property
    def features(self) -> List[Column]:
        """One pyspark Column per (feature class, feature period) pair."""
        return [
            cls(as_at=self.as_at, feature_period=fp).column
            for cls in self.FEATURE_CLASSES
            for fp in self.feature_periods
        ]

    @property
    def references(self) -> Dict[str, List[str]]:
        """Map each feature column's name to the column names it references.

        This function requires a SparkSession in order to do its thing.
        In normal operation a SparkSession will probably already exist
        but in unit tests that might not be the case, so getOrCreate one.
        """
        SparkSession.builder.getOrCreate()
        return {
            expr.name(): self.parse_references(expr.references().toList().toString())
            # pylint: disable=protected-access
            for expr in [c._jc.expr() for c in self.features]  # type: ignore
        }

    @staticmethod
    def parse_references(references: str) -> List[str]:
        """Parse a Scala ``List(...)`` string into a sorted list of names.

        e.g. "List(quantity#916, price#917)" -> ["price#917", "quantity#916"]
        """
        return sorted(
            "".join(
                references.replace("'", "")
                .replace("List(", "")
                .replace(")", "")
                .split()
            ).split(",")
        )