Skip to content

Commit

Permalink
Merge pull request #25 from RyanAugust/dev
Browse files Browse the repository at this point in the history
Visualization & Validation
  • Loading branch information
RyanAugust authored Feb 29, 2024
2 parents 1c48883 + 810bd76 commit 1b7bb37
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 4 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ description = "Python package for generating MMM (Marketing Mix Model) input dat
dependencies = [
"numpy",
"pandas",
"pyyaml"
"pyyaml",
"matplotlib"
]
requires-python = ">=3.8"
authors = [
Expand Down
38 changes: 38 additions & 0 deletions src/pysimmmulator/load_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,41 @@ def define_basic_params(
)

return my_basic_params

def validate_config(config_path: str, return_individual_results: bool = False):
cfg = load_config(config_path=config_path)
results = {}
overall = True
try:
define_basic_params(**cfg["basic_params"])
results.update({"basic_params":True})
except Exception as e:
results.update({"basic_params":e})
overall = False
try:
my_basic_params = define_basic_params(**cfg["basic_params"])
baseline_parameters(basic_params=my_basic_params, **cfg["baseline_params"])
results.update({"baseline_params":True})
except Exception as e:
results.update({"baseline_params":e})
overall = False

matched_validation = {
ad_spend_parameters:"ad_spend_params",
media_parameters:"media_params",
cvr_parameters:"cvr_params",
adstock_parameters:"adstock_params",
output_parameters:"output_params"
}
for handler, conf_name in matched_validation.items():
try:
handler(**cfg[conf_name])
results.update({conf_name: True})
except Exception as e:
results.update({conf_name: e})
overall = False

if return_individual_results:
return results
else:
return overall
7 changes: 4 additions & 3 deletions src/pysimmmulator/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,23 @@
output_parameters,
)

from .visualize import visualize

import numpy as np
import pandas as pd

import logging
import logging.config

logger = logging.getLogger(__name__)


class simmm:
class simmm(visualize):
"""Takes input of basic params and provies either piece meal or single shot
creation of MMM data using a config file,"""

def __init__(self, basic_params: basic_parameters = None, random_seed = None):
self.basic_params = basic_params
self.rng = self._create_random_factory(seed=random_seed)
super().__init__()

def _create_random_factory(self, seed: int) -> np.random.Generator:
rng = np.random.default_rng(seed=seed)
Expand Down
112 changes: 112 additions & 0 deletions src/pysimmmulator/visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import pandas as pd
import matplotlib.pyplot as plt


class visualize:
def __init__(self):
self._viz_available = True
self._valid_agg_levels = ['daily', 'weekly', 'monthly', 'yearly']

def plot_spend(self, agg: str = None):
"""Plot simulated spend data based on a passed date-wise aggregation
Args:
agg (str): pick from [{', '.join(self._valid_agg_levels)}] to aggregate simulated data by"""
assert agg in self._valid_agg_levels, f"Please select [{', '.join(self._valid_agg_levels)}] for your aggregation level. {agg} is an invalid selection."

# prepare frame and filter columns for plotting
self._plot_frame_overhead(agg_level=agg)
plot_cols = self._filter_columns(columns = self.plot_frame.columns.tolist(), filter_string = '_spend')

return self._plot_majors(columns = plot_cols)

def plot_impressions(self, agg: str = None):
"""Plot simulated impressions data based on a passed date-wise aggregation
Args:
agg (str): pick from [{', '.join(self._valid_agg_levels)}] to aggregate simulated data by"""
assert agg in self._valid_agg_levels, f"Please select [{', '.join(self._valid_agg_levels)}] for your aggregation level. {agg} is an invalid selection."

# prepare frame and filter columns for plotting
self._plot_frame_overhead(agg_level=agg)
plot_cols = self._filter_columns(columns = self.plot_frame.columns.tolist(), filter_string = '_impressions')

return self._plot_majors(columns = plot_cols)

def plot_clicks(self, agg: str = None):
"""Plot simulated clicks data based on a passed date-wise aggregation
Args:
agg (str): pick from [{', '.join(self._valid_agg_levels)}] to aggregate simulated data by"""
assert agg in self._valid_agg_levels, f"Please select [{', '.join(self._valid_agg_levels)}] for your aggregation level. {agg} is an invalid selection."

# prepare frame and filter columns for plotting
self._plot_frame_overhead(agg_level=agg)
plot_cols = self._filter_columns(columns = self.plot_frame.columns.tolist(), filter_string = '_clicks')

return self._plot_majors(columns = plot_cols)

def plot_revenue(self, agg: str = None):
"""Plot simulated revenue data based on a passed date-wise aggregation
Args:
agg (str): pick from [{', '.join(self._valid_agg_levels)}] to aggregate simulated data by"""
assert agg in self._valid_agg_levels, f"Please select [{', '.join(self._valid_agg_levels)}] for your aggregation level. {agg} is an invalid selection."

# prepare frame and filter columns for plotting
self._plot_frame_overhead(agg_level=agg)
plot_cols = self._filter_columns(columns = self.plot_frame.columns.tolist(), filter_string = 'total_revenue')

return self._plot_majors(columns = plot_cols)

def _filter_columns(self, columns: list, filter_string: str) -> list:
filtered_cols = []
[filtered_cols.append(col) for col in columns if filter_string in col]
return filtered_cols

def _plot_frame_overhead(self, agg_level: int = None) -> pd.DataFrame:
if agg_level is not None:
self.plot_frame = self.final_df.copy()
self.plot_frame.reset_index(inplace=True)
self._aggregator(agg_level)
else:
self.plot_frame = self.final_df.copy()

def _aggregator(self, agg_level: str):
if agg_level == 'daily':
self.plot_frame = self.plot_frame.groupby("date").sum()

elif agg_level == 'weekly':
self.plot_frame["week_start"] = self.plot_frame["date"] - pd.to_timedelta(
self.plot_frame["date"].apply(lambda x: x.weekday()), unit="d"
)
del self.plot_frame["date"]
self.plot_frame = self.plot_frame.groupby("week_start").sum()

elif agg_level == 'monthly':
self.plot_frame["month_start"] = self.plot_frame["date"] - pd.to_timedelta(
self.plot_frame["date"].apply(lambda x: x.day), unit="d"
)
del self.plot_frame["date"]
self.plot_frame = self.plot_frame.groupby("month_start").sum()

elif agg_level == 'yearly':
self.plot_frame["year_start"] = self.plot_frame["date"] - pd.to_timedelta(
self.plot_frame["date"].apply(lambda x: x.timetuple()[7]), unit="d"
)
del self.plot_frame["date"]
self.plot_frame = self.plot_frame.groupby("year_start").sum()


def _plot_majors(self, columns):
plot_subject = columns[-1].split('_')[1]
plot_subject = plot_subject[0].upper() + plot_subject[1:]

fig, ax = plt.subplots(1,1, figsize=(9,6), dpi=200)
for col in columns:
ax.plot(self.plot_frame.index, self.plot_frame[col], label=col.split('_')[0])
ax.set_xlabel("Date")
ax.set_ylabel(f"{plot_subject}")
ax.set_title(f"{plot_subject} by Channel")
fig.legend(loc="upper right")
plt.savefig(f'{plot_subject}_by_channel.png')
4 changes: 4 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ def test_adstock_cfg_check():
adstock_params = load_parameters.adstock_parameters(**cfg["adstock_params"])
adstock_params.check(basic_params=my_basic_params)
assert True == True

def test_validate_config():
overall = load_parameters.validate_config(config_path="./example_config.yaml")
assert overall == True
38 changes: 38 additions & 0 deletions tests/test_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pysimmmulator as pysimmm

def test_viz_clicks_daily():
cfg = pysimmm.load_parameters.load_config(config_path="example_config.yaml")
sim = pysimmm.simmm()
sim.run_with_config(config=cfg)
sim.plot_clicks(agg='daily')

def test_viz_clicks_weekly():
cfg = pysimmm.load_parameters.load_config(config_path="example_config.yaml")
sim = pysimmm.simmm()
sim.run_with_config(config=cfg)
sim.plot_clicks(agg='weekly')

def test_viz_clicks_monthly():
cfg = pysimmm.load_parameters.load_config(config_path="example_config.yaml")
sim = pysimmm.simmm()
sim.run_with_config(config=cfg)
sim.plot_clicks(agg='monthly')

def test_viz_clicks_yearly():
cfg = pysimmm.load_parameters.load_config(config_path="example_config.yaml")
sim = pysimmm.simmm()
sim.run_with_config(config=cfg)
sim.plot_clicks(agg='yearly')

def test_viz_impressions_daily():
cfg = pysimmm.load_parameters.load_config(config_path="example_config.yaml")
sim = pysimmm.simmm()
sim.run_with_config(config=cfg)
sim.plot_impressions(agg='daily')

def test_viz_spend_daily():
cfg = pysimmm.load_parameters.load_config(config_path="example_config.yaml")
sim = pysimmm.simmm()
sim.run_with_config(config=cfg)
sim.plot_spend(agg='daily')

0 comments on commit 1b7bb37

Please sign in to comment.