Coverage for jstark / sample / transactions.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-23 22:34 +0000

1import random 

2from functools import cached_property 

3import uuid 

4from datetime import date 

5from typing import Any, Iterable 

6from decimal import Decimal 

7 

8from pyspark.sql import SparkSession, DataFrame 

9from pyspark.sql.types import ( 

10 DecimalType, 

11 IntegerType, 

12 StringType, 

13 StructField, 

14 StructType, 

15 TimestampType, 

16) 

17from faker import Faker 

18from faker.providers import DynamicProvider 

19 

20 

class FakeGroceryTransactions:
    """Generate a Spark DataFrame of synthetic grocery transactions.

    Args:
        seed: Optional seed. When it is not ``None`` (``0`` included), Faker
            and the private random generator used for basket contents are
            seeded with it, so the generated data is reproducible.
        number_of_baskets: Number of baskets (shopping trips) to generate.
    """

    def __init__(self, seed: int | None = None, number_of_baskets: int = 1000):
        self.seed = seed
        self.number_of_baskets = number_of_baskets

    @property
    def transactions_schema(self) -> StructType:
        """Return the schema of :attr:`df`: one row per product in a basket."""
        return StructType(
            [
                StructField("Timestamp", TimestampType(), True),
                StructField("Customer", StringType(), True),
                StructField("Store", StringType(), True),
                StructField("Channel", StringType(), True),
                StructField("Product", StringType(), True),
                StructField("Quantity", IntegerType(), True),
                StructField("Basket", StringType(), True),
                StructField("GrossSpend", DecimalType(10, 2), True),
                StructField("NetSpend", DecimalType(10, 2), True),
                StructField("Discount", DecimalType(10, 2), True),
            ]
        )

    @staticmethod
    def flatten_transactions(transactions: list[Any]) -> Iterable[dict[str, Any]]:
        """Expand each basket into one dict per item.

        Every output dict carries the basket-level ``Customer``, ``Store``,
        ``Basket``, ``Channel`` and ``Timestamp`` values merged with all keys
        of one entry of the basket's ``items`` list. Baskets with no items
        produce no rows.
        """
        return [
            {
                "Customer": basket["Customer"],
                "Store": basket["Store"],
                "Basket": basket["Basket"],
                "Channel": basket["Channel"],
                "Timestamp": basket["Timestamp"],
                **item,
            }
            for basket in transactions
            for item in basket["items"]
        ]

    @cached_property
    def df(self) -> DataFrame:
        """Build (once per instance) a DataFrame of fake transactions.

        Returns:
            A DataFrame with :attr:`transactions_schema`, containing one row
            per product line of every generated basket.
        """
        stores_provider = DynamicProvider(
            provider_name="store",
            elements=["Hammersmith", "Ealing", "Richmond", "Twickenham", "Staines"],
        )
        products_provider = DynamicProvider(
            provider_name="product",
            elements=[
                ("Custard Creams", "Ambient", 1.00),
                ("Carrots", "Fresh", 0.69),
                ("Cheddar", "Dairy", 3.43),
                ("Ice Cream", "Frozen", 5.32),
                ("Milk 1l", "Dairy", 1.70),
                ("Apples", "Fresh", 1.50),
                ("Beer", "BWS", 8.50),
                ("Wine", "BWS", 7.50),
                ("Yoghurt", "Dairy", 0.99),
                ("Bananas", "Fresh", 0.79),
                ("Nappies", "Baby", 15.00),
                ("Baby formula", "Baby", 5.99),
                ("Whisky", "BWS", 23.00),
            ],
        )
        channels_provider = DynamicProvider(
            provider_name="channel",
            elements=["Instore", "Online", "Click and Collect"],
        )

        seeded = self.seed is not None  # a seed of 0 is a valid seed

        fake = Faker()
        if seeded:
            # Faker.seed is a class-level call: it seeds the generator that
            # Faker instances share, including products_fake below.
            Faker.seed(self.seed)
        fake.add_provider(stores_provider)
        fake.add_provider(channels_provider)

        products_fake = Faker()
        products_fake.add_provider(products_provider)

        # Private generator, created and seeded exactly once: reproducible
        # when seeded, without touching the global ``random`` module state.
        # random.Random(None) seeds itself from the OS.
        rng = random.Random(self.seed)

        product_count = len(products_provider.elements)
        # Each basket owns a contiguous slice of ``product_count`` quantities,
        # enough for the largest possible basket.
        quantities = rng.choices(
            [1, 2, 3, 4, 5],
            weights=[100, 40, 20, 10, 8],
            k=self.number_of_baskets * product_count,
        )

        transactions: list[dict[str, Any]] = []
        for basket in range(self.number_of_baskets):
            items = []
            for item in range(rng.randint(1, product_count)):
                # ``unique`` guarantees no product repeats within a basket; the
                # basket can never exceed product_count items, so it can't run dry.
                name, category, unit_price = products_fake.unique.product()
                quantity = quantities[basket * product_count + item]
                gross_spend = round(unit_price * quantity, 2)
                net_spend = round(gross_spend * rng.random(), 2)
                discount = round((gross_spend - net_spend) * rng.random(), 2)
                items.append(
                    {
                        "Product": name,
                        "ProductCategory": category,
                        "Quantity": quantity,
                        # str() yields the exact 2dp value rather than the
                        # float's full binary expansion.
                        "GrossSpend": Decimal(str(gross_spend)),
                        "NetSpend": Decimal(str(net_spend)),
                        "Discount": Decimal(str(discount)),
                    }
                )
            transactions.append(
                {
                    "Customer": fake.name(),
                    "Homecity": fake.city(),
                    "Store": fake.store(),
                    "Timestamp": fake.date_time_between(
                        start_date=date(2021, 1, 1), end_date=date(2021, 12, 31)
                    ),
                    "Basket": str(self._basket_id(rng, seeded)),
                    "Channel": fake.channel(),
                    "items": items,
                }
            )
            # Allow products to reappear in the next basket.
            products_fake.unique.clear()

        flattened_transactions = self.flatten_transactions(transactions)
        spark = SparkSession.builder.getOrCreate()
        return spark.createDataFrame(
            flattened_transactions,
            schema=self.transactions_schema,  # type: ignore
        )

    @staticmethod
    def _basket_id(rng: random.Random, seeded: bool) -> uuid.UUID:
        """Return a version-4 UUID, drawn from ``rng`` when seeded so IDs repeat."""
        if seeded:
            return uuid.UUID(int=rng.getrandbits(128), version=4)
        return uuid.uuid4()