Coverage for jstark/sample/transactions.py: 100%

48 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-02-25 20:09 +0000

1import random 

2import uuid 

3from datetime import date 

4from typing import Dict, Any, Iterable, Union, List 

5from decimal import Decimal 

6 

7from pyspark.sql import SparkSession, DataFrame 

8from pyspark.sql.types import ( 

9 DecimalType, 

10 IntegerType, 

11 StringType, 

12 StructField, 

13 StructType, 

14 TimestampType, 

15) 

16from faker import Faker 

17from faker.providers import DynamicProvider 

18 

19 

class FakeTransactions:
    """Build a Spark DataFrame of fake retail transactions.

    Baskets are generated with Faker and the ``random`` module; each basket
    holds one or more product line items and the resulting DataFrame has one
    row per line item.
    """

    @property
    def transactions_schema(self) -> StructType:
        """Schema applied to the DataFrame returned by :meth:`get_df`.

        All ten columns are declared nullable; the three spend columns are
        ``DecimalType(10, 2)``.
        """
        return StructType(
            [
                StructField("Timestamp", TimestampType(), True),
                StructField("Customer", StringType(), True),
                StructField("Store", StringType(), True),
                StructField("Channel", StringType(), True),
                StructField("Product", StringType(), True),
                StructField("Quantity", IntegerType(), True),
                StructField("Basket", StringType(), True),
                StructField("GrossSpend", DecimalType(10, 2), True),
                StructField("NetSpend", DecimalType(10, 2), True),
                StructField("Discount", DecimalType(10, 2), True),
            ]
        )

    @staticmethod
    def flatten_transactions(transactions: List[Any]) -> Iterable[Dict[str, Any]]:
        """Explode basket-level records into one dict per line item.

        Args:
            transactions: Basket records. Each must provide the keys
                ``Customer``, ``Store``, ``Basket``, ``Channel``,
                ``Timestamp`` and ``items`` (an iterable of dicts).

        Returns:
            A list with one dict per entry of every basket's ``items``. Each
            dict holds the five basket-level keys listed above plus all keys
            of the item; on a name clash the item's value wins. Baskets with
            no items contribute no rows, and any other basket-level keys are
            dropped.
        """
        return [
            {
                "Customer": d["Customer"],
                "Store": d["Store"],
                "Basket": d["Basket"],
                "Channel": d["Channel"],
                "Timestamp": d["Timestamp"],
                **d2,
            }
            for d in transactions
            for d2 in d["items"]
        ]

    def get_df(
        self, seed: Union[int, None] = None, number_of_baskets: int = 1000
    ) -> DataFrame:
        """Generate fake transactions and return them as a Spark DataFrame.

        Args:
            seed: Seed passed to both ``Faker.seed`` and ``random.seed``.
                ``None`` (the default) leaves both unseeded. Any integer,
                including ``0``, is used as a seed.
            number_of_baskets: Number of baskets to generate. Each basket
                yields between one and thirteen rows (one per product drawn).

        Returns:
            A DataFrame created with :attr:`transactions_schema`, one row per
            basket line item.

        Note:
            ``Basket`` ids come from ``uuid.uuid4()``, which this method does
            not seed, so they differ between runs even when ``seed`` is given.
        """
        stores_provider = DynamicProvider(
            provider_name="store",
            elements=["Hammersmith", "Ealing", "Richmond", "Twickenham", "Staines"],
        )
        # Each product is a (name, category, unit price) tuple.
        products_provider = DynamicProvider(
            provider_name="product",
            elements=[
                ("Custard Creams", "Ambient", 1.00),
                ("Carrots", "Fresh", 0.69),
                ("Cheddar", "Dairy", 3.43),
                ("Ice Cream", "Frozen", 5.32),
                ("Milk 1l", "Dairy", 1.70),
                ("Apples", "Fresh", 1.50),
                ("Beer", "BWS", 8.50),
                ("Wine", "BWS", 7.50),
                ("Yoghurt", "Dairy", 0.99),
                ("Bananas", "Fresh", 0.79),
                ("Nappies", "Baby", 15.00),
                ("Baby formula", "Baby", 5.99),
                ("Whisky", "BWS", 23.00),
            ],
        )
        channels_provider = DynamicProvider(
            provider_name="channel",
            elements=["Instore", "Online", "Click and Collect"],
        )

        fake = Faker()
        # Compare against None rather than relying on truthiness: 0 is a
        # perfectly valid seed and must not be treated as "unseeded".
        if seed is not None:
            Faker.seed(seed)
        fake.add_provider(stores_provider)
        fake.add_provider(channels_provider)

        # Separate Faker instance so that `.unique` tracking (and its
        # per-basket reset below) only concerns products.
        products_fake = Faker()
        products_fake.add_provider(products_provider)

        product_count = len(products_provider.elements)
        possible_quantities = [1, 2, 3, 4, 5]
        if seed is not None:
            random.seed(seed)
        # Pre-draw a pool of quantities, weighted towards small values.
        quantities = random.choices(
            possible_quantities,
            weights=[100, 40, 20, 10, 8],
            k=number_of_baskets * product_count,
        )

        transactions = []
        for basket in range(number_of_baskets):
            items = []
            # NOTE(review): re-seeding here resets the module-level random
            # state for every basket, so with a seed each basket draws the
            # same item count from randint below. Left unchanged because
            # altering it would change existing seeded output - confirm
            # whether that is intended before touching it.
            if seed is not None:
                random.seed(seed)
            for item in range(random.randint(1, product_count)):
                # `.unique` guarantees a product appears at most once per basket.
                p = products_fake.unique.product()
                # NOTE(review): the stride is len(possible_quantities), not
                # product_count, so neighbouring baskets can read overlapping
                # parts of the pool. The index stays in range because the
                # pool holds number_of_baskets * product_count entries.
                quantity = quantities[(basket * len(possible_quantities)) + item]
                gross_spend: float = round(p[2] * quantity, 2)
                net_spend = round((gross_spend * random.random()), 2)
                discount = round((gross_spend - net_spend) * random.random(), 2)
                items.append(
                    {
                        "Product": p[0],
                        "ProductCategory": p[1],
                        "Quantity": quantity,
                        "GrossSpend": Decimal(gross_spend),
                        "NetSpend": Decimal(net_spend),
                        "Discount": Decimal(discount),
                    }
                )
            transactions.append(
                {
                    "Customer": fake.name(),
                    "Homecity": fake.city(),
                    "Store": fake.store(),
                    "Timestamp": fake.date_time_between(
                        start_date=date(2021, 1, 1), end_date=date(2021, 12, 31)
                    ),
                    "Basket": str(uuid.uuid4()),
                    "Channel": fake.channel(),
                    "items": items,
                }
            )
            # Forget which products were used so the next basket can draw
            # from the full product list again.
            products_fake.unique.clear()

        flattened_transactions = self.flatten_transactions(transactions)
        spark = SparkSession.builder.getOrCreate()
        return spark.createDataFrame(
            flattened_transactions, schema=self.transactions_schema  # type: ignore
        )