Coverage for jstark / sample / transactions.py: 100%
53 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-23 22:34 +0000
1import random
2from functools import cached_property
3import uuid
4from datetime import date
5from typing import Any, Iterable
6from decimal import Decimal
8from pyspark.sql import SparkSession, DataFrame
9from pyspark.sql.types import (
10 DecimalType,
11 IntegerType,
12 StringType,
13 StructField,
14 StructType,
15 TimestampType,
16)
17from faker import Faker
18from faker.providers import DynamicProvider
21class FakeGroceryTransactions:
22 def __init__(self, seed: int | None = None, number_of_baskets: int = 1000):
23 self.seed = seed
24 self.number_of_baskets = number_of_baskets
26 @property
27 def transactions_schema(self) -> StructType:
28 return StructType(
29 [
30 StructField("Timestamp", TimestampType(), True),
31 StructField("Customer", StringType(), True),
32 StructField("Store", StringType(), True),
33 StructField("Channel", StringType(), True),
34 StructField("Product", StringType(), True),
35 StructField("Quantity", IntegerType(), True),
36 StructField("Basket", StringType(), True),
37 StructField("GrossSpend", DecimalType(10, 2), True),
38 StructField("NetSpend", DecimalType(10, 2), True),
39 StructField("Discount", DecimalType(10, 2), True),
40 ]
41 )
43 @staticmethod
44 def flatten_transactions(transactions: list[Any]) -> Iterable[dict[str, Any]]:
45 return [
46 {
47 "Customer": d["Customer"],
48 "Store": d["Store"],
49 "Basket": d["Basket"],
50 "Channel": d["Channel"],
51 "Timestamp": d["Timestamp"],
52 **d2,
53 }
54 for d in transactions
55 for d2 in d["items"]
56 ]
58 @cached_property
59 def df(self) -> DataFrame:
60 stores_provider = DynamicProvider(
61 provider_name="store",
62 elements=["Hammersmith", "Ealing", "Richmond", "Twickenham", "Staines"],
63 )
64 products_provider = DynamicProvider(
65 provider_name="product",
66 elements=[
67 ("Custard Creams", "Ambient", 1.00),
68 ("Carrots", "Fresh", 0.69),
69 ("Cheddar", "Dairy", 3.43),
70 ("Ice Cream", "Frozen", 5.32),
71 ("Milk 1l", "Dairy", 1.70),
72 ("Apples", "Fresh", 1.50),
73 ("Beer", "BWS", 8.50),
74 ("Wine", "BWS", 7.50),
75 ("Yoghurt", "Dairy", 0.99),
76 ("Bananas", "Fresh", 0.79),
77 ("Nappies", "Baby", 15.00),
78 ("Baby formula", "Baby", 5.99),
79 ("Whisky", "BWS", 23.00),
80 ],
81 )
82 channels_provider = DynamicProvider(
83 provider_name="channel",
84 elements=["Instore", "Online", "Click and Collect"],
85 )
87 fake = Faker()
88 if self.seed:
89 Faker.seed(self.seed)
90 fake.add_provider(stores_provider)
91 fake.add_provider(channels_provider)
93 products_fake = Faker()
94 products_fake.add_provider(products_provider)
96 transactions = []
98 possible_quantities = [1, 2, 3, 4, 5]
99 if self.seed:
100 random.seed(self.seed)
101 quantities = random.choices(
102 possible_quantities,
103 weights=[100, 40, 20, 10, 8],
104 k=self.number_of_baskets * len(products_provider.elements),
105 )
106 for basket in range(self.number_of_baskets):
107 items = []
108 if self.seed:
109 random.seed(self.seed)
110 for item in range(random.randint(1, len(products_provider.elements))):
111 p = products_fake.unique.product()
112 quantity = quantities[(basket * len(possible_quantities)) + item]
113 gross_spend: float = round(p[2] * quantity, 2)
114 net_spend = round((gross_spend * random.random()), 2)
115 discount = round((gross_spend - net_spend) * random.random(), 2)
116 items.append(
117 {
118 "Product": p[0],
119 "ProductCategory": p[1],
120 "Quantity": quantity,
121 "GrossSpend": Decimal(gross_spend),
122 "NetSpend": Decimal(net_spend),
123 "Discount": Decimal(discount),
124 }
125 )
126 transactions.append(
127 {
128 "Customer": fake.name(),
129 "Homecity": fake.city(),
130 "Store": fake.store(),
131 "Timestamp": fake.date_time_between(
132 start_date=date(2021, 1, 1), end_date=date(2021, 12, 31)
133 ),
134 "Basket": str(uuid.uuid4()),
135 "Channel": fake.channel(),
136 "items": items,
137 }
138 )
139 products_fake.unique.clear()
141 flattened_transactions = self.flatten_transactions(transactions)
142 spark = SparkSession.builder.getOrCreate()
143 return spark.createDataFrame(
144 flattened_transactions,
145 schema=self.transactions_schema, # type: ignore
146 )