Coverage for jstark/sample/transactions.py: 100%
48 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-25 20:09 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-25 20:09 +0000
1import random
2import uuid
3from datetime import date
4from typing import Dict, Any, Iterable, Union, List
5from decimal import Decimal
7from pyspark.sql import SparkSession, DataFrame
8from pyspark.sql.types import (
9 DecimalType,
10 IntegerType,
11 StringType,
12 StructField,
13 StructType,
14 TimestampType,
15)
16from faker import Faker
17from faker.providers import DynamicProvider
class FakeTransactions:
    """Build a Spark DataFrame of randomly generated retail transactions.

    ``get_df`` fabricates baskets of products with Faker and the ``random``
    module, flattens them to one record per basket item and loads them into a
    DataFrame that uses ``transactions_schema``.
    """

    @property
    def transactions_schema(self) -> "StructType":
        """Schema used for the DataFrame returned by ``get_df``.

        Every field is declared nullable; the three spend columns are
        ``DecimalType(10, 2)``.
        """
        return StructType(
            [
                StructField("Timestamp", TimestampType(), True),
                StructField("Customer", StringType(), True),
                StructField("Store", StringType(), True),
                StructField("Channel", StringType(), True),
                StructField("Product", StringType(), True),
                StructField("Quantity", IntegerType(), True),
                StructField("Basket", StringType(), True),
                StructField("GrossSpend", DecimalType(10, 2), True),
                StructField("NetSpend", DecimalType(10, 2), True),
                StructField("Discount", DecimalType(10, 2), True),
            ]
        )

    @staticmethod
    def flatten_transactions(transactions: List[Any]) -> Iterable[Dict[str, Any]]:
        """Expand basket-level records into one flat record per basket item.

        Args:
            transactions: Basket dicts holding the keys ``Customer``,
                ``Store``, ``Basket``, ``Channel``, ``Timestamp`` and
                ``items`` (an iterable of per-item dicts).

        Returns:
            A list with one dict per item: the five basket-level keys named
            above followed by the item's own keys. Item keys win on a name
            clash, any other basket-level key is dropped, and a basket with
            no items contributes no rows.

        Raises:
            KeyError: If a basket has no ``items`` key, or a basket with at
                least one item lacks one of the other required keys.
        """
        basket_keys = ("Customer", "Store", "Basket", "Channel", "Timestamp")
        return [
            {**{key: basket[key] for key in basket_keys}, **item}
            for basket in transactions
            for item in basket["items"]
        ]

    def get_df(
        self, seed: Union[int, None] = None, number_of_baskets: int = 1000
    ) -> "DataFrame":
        """Generate fake transactions and return them as a Spark DataFrame.

        Args:
            seed: Seed passed to both ``Faker.seed`` and ``random.seed``. Any
                integer is honoured, including ``0``; ``None`` leaves both
                generators unseeded.
            number_of_baskets: Number of baskets to generate. Each basket
                holds between one and thirteen distinct products.

        Returns:
            A DataFrame created with ``transactions_schema``, one row per
            basket item. ``Homecity`` and ``ProductCategory`` are generated
            but are not fields of that schema. ``Basket`` comes from
            ``uuid.uuid4()``, which is not driven by ``seed``.
        """
        # Explicit None check: a plain truthiness test would treat seed=0 as
        # "no seed" and silently produce non-reproducible data.
        seeded = seed is not None

        stores_provider = DynamicProvider(
            provider_name="store",
            elements=["Hammersmith", "Ealing", "Richmond", "Twickenham", "Staines"],
        )
        # Each product is (name, category, unit price).
        products_provider = DynamicProvider(
            provider_name="product",
            elements=[
                ("Custard Creams", "Ambient", 1.00),
                ("Carrots", "Fresh", 0.69),
                ("Cheddar", "Dairy", 3.43),
                ("Ice Cream", "Frozen", 5.32),
                ("Milk 1l", "Dairy", 1.70),
                ("Apples", "Fresh", 1.50),
                ("Beer", "BWS", 8.50),
                ("Wine", "BWS", 7.50),
                ("Yoghurt", "Dairy", 0.99),
                ("Bananas", "Fresh", 0.79),
                ("Nappies", "Baby", 15.00),
                ("Baby formula", "Baby", 5.99),
                ("Whisky", "BWS", 23.00),
            ],
        )
        channels_provider = DynamicProvider(
            provider_name="channel",
            elements=["Instore", "Online", "Click and Collect"],
        )

        fake = Faker()
        if seeded:
            Faker.seed(seed)
        fake.add_provider(stores_provider)
        fake.add_provider(channels_provider)

        # Products come from their own Faker instance; its ``unique`` tracker is
        # cleared after every basket so products only need to be distinct
        # within a basket.
        products_fake = Faker()
        products_fake.add_provider(products_provider)
        product_count = len(products_provider.elements)

        transactions = []

        possible_quantities = [1, 2, 3, 4, 5]
        if seeded:
            random.seed(seed)
        # Small quantities are weighted far more heavily than large ones.
        quantities = random.choices(
            possible_quantities,
            weights=[100, 40, 20, 10, 8],
            k=number_of_baskets * product_count,
        )
        for basket in range(number_of_baskets):
            items = []
            # NOTE(review): re-seeding inside the loop restarts the ``random``
            # stream for every basket, which looks like it makes each basket
            # repeat the same item count and spend ratios. Left unchanged
            # because fixing it would alter the data produced for existing
            # seeds - confirm before changing.
            if seeded:
                random.seed(seed)
            for item in range(random.randint(1, product_count)):
                p = products_fake.unique.product()
                # NOTE(review): ``quantities`` holds product_count values per
                # basket but is indexed with a stride of
                # len(possible_quantities), so neighbouring baskets read
                # overlapping entries. Left unchanged for the same reason.
                quantity = quantities[(basket * len(possible_quantities)) + item]
                gross_spend: float = round(p[2] * quantity, 2)
                # Net spend is a random fraction of gross spend; discount is a
                # random fraction of the remainder.
                net_spend = round((gross_spend * random.random()), 2)
                discount = round((gross_spend - net_spend) * random.random(), 2)
                items.append(
                    {
                        "Product": p[0],
                        "ProductCategory": p[1],
                        "Quantity": quantity,
                        "GrossSpend": Decimal(gross_spend),
                        "NetSpend": Decimal(net_spend),
                        "Discount": Decimal(discount),
                    }
                )
            transactions.append(
                {
                    "Customer": fake.name(),
                    "Homecity": fake.city(),
                    "Store": fake.store(),
                    "Timestamp": fake.date_time_between(
                        start_date=date(2021, 1, 1), end_date=date(2021, 12, 31)
                    ),
                    "Basket": str(uuid.uuid4()),
                    "Channel": fake.channel(),
                    "items": items,
                }
            )
            products_fake.unique.clear()

        flattened_transactions = self.flatten_transactions(transactions)
        spark = SparkSession.builder.getOrCreate()
        return spark.createDataFrame(
            flattened_transactions, schema=self.transactions_schema  # type: ignore
        )