Use the config file instead of global variables.

This commit is contained in:
Oscar Plaisant 2024-06-27 17:24:04 +02:00
parent ff0f646d04
commit b13f8ab039

View File

@ -1,26 +1,36 @@
import sqlite3
import numpy as np
from tprint import tprint
from joblib import Memory # for persistent memoïzation
from query_generator import *
import orderankings as odrk
import kemeny_young as km
from joblib import Memory
import yaml # to load config file
from os import environ # access environment variables
# Module-level setup: persistent memoization cache, configuration loading and
# SQLite connection.
# NOTE(review): several names below (`memory`, `DATABASE_NAME`, `VERBOSE`,
# `DATABASE_FILE`) are assigned more than once in this view; read top to
# bottom, the last assignment of each name is the one in effect.

# persistent memoization
memory = Memory("cache")
memory = Memory("src/cache")
DATABASE_NAME = "flight_delay"
DATABASE_NAME = "SSB"
# `environ.get` returns None when VIRTUAL_ENV is not set; the string
# concatenation on the next line would then raise a TypeError.
VENV_PATH = environ.get('VIRTUAL_ENV')
# The configuration file is looked up at "<VIRTUAL_ENV>/src/config.yaml".
with open(VENV_PATH + "/src/config.yaml") as config_file:
# NOTE(review): `yaml.Loader` is PyYAML's full loader and can construct
# arbitrary Python objects; this is only acceptable for a trusted config file.
cfg = yaml.load(config_file, Loader=yaml.Loader)
# Keys read from the config: cfg["verbose"]["querying"] and cfg["database_name"].
VERBOSE = cfg["verbose"]["querying"]
DATABASE_NAME = cfg["database_name"]
if VERBOSE: print("using database", DATABASE_NAME)
################################################################################
# Connection to sqlite database
odrk.VERBOSE = False
# NOTE(review): this overrides the VERBOSE value read from the config above;
# presumably a leftover from the pre-config version - confirm.
VERBOSE = True
# initialize database connection
# The database path is derived from DATABASE_NAME:
# "<name>_dataset/<name>.db" (the first variant is prefixed with "../").
DATABASE_FILE = f"../{DATABASE_NAME}_dataset/{DATABASE_NAME}.db"
DATABASE_FILE = f"{DATABASE_NAME}_dataset/{DATABASE_NAME}.db"
if VERBOSE: print(f"connecting to {DATABASE_FILE}")
# Module-wide connection and cursor, created at import time.
CON = sqlite3.connect(DATABASE_FILE)
CUR = CON.cursor()
@ -39,10 +49,10 @@ def query(q: str) -> list[tuple]:
# Pick the query-generator classes that match the configured database.
# Any name other than "flight_delay" or "SSB" is rejected up front.
if DATABASE_NAME not in ("flight_delay", "SSB"):
    raise ValueError(f"Unknown database : {DATABASE_NAME}")

if DATABASE_NAME == "SSB":
    QUERY_PARAM_GB_FACTORY = QuerySSBWithParameterGroupedByCriteria
    QUERY_PARAM_FACTORY = QuerySSBWithParameter
else:
    # Only "flight_delay" can reach this branch.
    QUERY_PARAM_GB_FACTORY = QueryFlightWithParameterGroupedByCriteria
    QUERY_PARAM_FACTORY = QueryFlightWithParameter
################################################################################
# orderings extraction functions
@ -50,7 +60,7 @@ elif DATABASE_NAME == "SSB":
@memory.cache # persistent memoïzation
def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str, ...],
length: int,
authorized_parameter_values: list[str] =None
authorized_parameter_values: tuple[str, ...] | None =None
) -> list[list[str]]:
"""Gather the list of every ordering returned by queries using given values
of parameter, summed_attribute, and all given values of criterion.
@ -65,15 +75,10 @@ def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str,
"""
# instanciate the query generator
qg = QUERY_PARAM_GB_FACTORY(parameter=parameter,
authorized_parameter_values=authorized_parameter_values,
summed_attribute=summed_attribute,
criteria=None)
if authorized_parameter_values is None:
# reduce the number of compared parameter values
qg.authorized_parameter_values = qg.authorized_parameter_values#[:length]
else:
qg.authorized_parameter_values = authorized_parameter_values#[:length]
# set of all the orderings found
# the key is the value in the criteria column
orderings = list()
@ -104,145 +109,4 @@ def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str,
return correct_length_orderings
@memory.cache  # persistent memoization
def find_true_ordering_ranking(parameter: str,
                               summed_attribute: str,
                               length: int,
                               authorized_parameter_values: tuple[str, ...] | None = None
                               ) -> tuple[list[list[str]], list[list[int]]]:
    """Return the reference (ordering, ranking) computed on the whole data.

    A single query is built with ``QUERY_PARAM_FACTORY`` (no grouping by
    criterion, no rank aggregation), executed through ``query``, and its
    result table is converted with the ``odrk`` helpers.

    Args:
        parameter: forwarded to the query generator as ``parameter``.
        summed_attribute: forwarded to the query generator as
            ``summed_attribute``.
        length: not read by this function body; kept in the signature for
            callers that pass it.
        authorized_parameter_values: forwarded to the query generator only
            when it is not ``None``; otherwise the generator is built without
            that argument.

    Returns:
        A pair ``(ordering, ranking)`` where ``ordering`` is the value of
        ``odrk.get_orderings_from_table`` applied to the query result and
        ``ranking`` is the first element of
        ``odrk.rankings_from_orderings([ordering])``.
    """
    # Build the constructor arguments once, adding the optional one on demand.
    factory_kwargs = {
        "parameter": parameter,
        "summed_attribute": summed_attribute,
    }
    if authorized_parameter_values is not None:
        factory_kwargs["authorized_parameter_values"] = authorized_parameter_values
    generator = QUERY_PARAM_FACTORY(**factory_kwargs)

    # The generator's string form is the query text that gets executed.
    rows = query(str(generator))
    if VERBOSE:
        print(rows)

    ordering = odrk.get_orderings_from_table(rows)
    # rankings_from_orderings works on a list of orderings: wrap, then unwrap.
    rankings = odrk.rankings_from_orderings([ordering])
    return ordering, rankings[0]
################################################################################
def flight_delay_main():
    """Run the ordering-inference demo with the flight_delay settings.

    Prints, in this order: the ordering and ranking obtained on the whole
    data, the orderings found per criterion, and finally the ordering and
    ranking obtained by aggregating the per-criterion rankings with
    ``km.rank_aggregation``.
    """
    parameter = "departure_airport"
    summed_attribute = "nb_flights"
    length = 5

    # Reference result computed without any grouping.
    true_ordering, true_ranking = find_true_ordering_ranking(
        parameter=parameter,
        summed_attribute=summed_attribute,
        length=length,
    )
    print(true_ordering, true_ranking)

    # Only "day" is enabled; "airline", "departure_hour" and "month" were
    # left disabled in the original list.
    criteria = ["day"]
    # Shuffle the criteria in place before querying.
    np.random.default_rng().shuffle(criteria)

    orderings_by_group = find_orderings(
        parameter=parameter,
        summed_attribute=summed_attribute,
        criterion=criteria,
        length=length,
    )
    print(orderings_by_group)

    # Aggregate the per-group rankings into a single inferred ranking.
    rankings_by_group = odrk.rankings_from_orderings(orderings_by_group)
    _, aggregated = km.rank_aggregation(rankings_by_group)
    consensus_ranking = np.array(aggregated)
    # The first grouped ordering is passed as second argument to translate
    # the ranking back into an ordering.
    consensus_order = odrk.ordering_from_ranking(consensus_ranking,
                                                 orderings_by_group[0])

    print("inferred :")
    print(consensus_order, consensus_ranking)
################################################################################
def SSB_main():
    """Run the ordering-inference demo with the SSB settings.

    Prints, in this order: the ordering and ranking obtained on the whole
    data for the two hypothesis values, a table of the orderings found per
    criterion (via ``tprint``, limited to 20), and finally the ordering and
    ranking obtained by aggregating the per-criterion rankings with
    ``km.rank_aggregation``.
    """
    parameter = "p_color"
    # "lo_revenue" and "lo_extendedprice" were left disabled as alternatives.
    summed_attribute = "lo_quantity"
    length = 2

    customer_columns = ("c_region", "c_city", "c_nation")
    part_columns = ("p_category", "p_brand", "p_mfgr",
                    "p_color", "p_type", "p_container")
    supplier_columns = ("s_city", "s_nation", "s_region")
    # Order-date columns: only the week number is enabled. Disabled in the
    # original: D_DATE, D_DATEKEY, D_DAYOFWEEK, D_MONTH, D_YEAR,
    # D_YEARMONTHNUM, D_YEARMONTH, D_DAYNUMINWEEK, D_DAYNUMINMONTH,
    # D_DAYNUMINYEAR, D_MONTHNUMINYEAR, D_SELLINGSEASON, D_LASTDAYINWEEKFL,
    # D_LASTDAYINMONTHFL, D_HOLIDAYFL, D_WEEKDAYFL.
    date_columns = ("D_WEEKNUMINYEAR",)
    criteria = customer_columns + part_columns + supplier_columns + date_columns

    hypothesis_ordering = ("aquamarine", "dark")

    # Reference result computed without grouping, restricted to the two
    # hypothesis values.
    true_ordering, true_ranking = find_true_ordering_ranking(
        parameter=parameter,
        summed_attribute=summed_attribute,
        length=length,
        authorized_parameter_values=hypothesis_ordering,
    )
    print(true_ordering, true_ranking)

    orderings_by_group = find_orderings(
        parameter=parameter,
        summed_attribute=summed_attribute,
        criterion=criteria,
        length=length,
    )
    tprint(orderings_by_group, limit=20)

    # Aggregate the per-group rankings into a single inferred ranking.
    rankings_by_group = odrk.rankings_from_orderings(orderings_by_group)
    _, aggregated = km.rank_aggregation(rankings_by_group)
    consensus_ranking = np.array(aggregated)
    # The first grouped ordering is passed as second argument to translate
    # the ranking back into an ordering.
    consensus_order = odrk.ordering_from_ranking(consensus_ranking,
                                                 orderings_by_group[0])

    print("inferred :")
    print(consensus_order, consensus_ranking)
if __name__ == '__main__':
    # Dispatch on the configured database name; names without an entry point
    # fall through to a no-op, exactly like a missing if/elif branch.
    {
        "SSB": SSB_main,
        "flight_delay": flight_delay_main,
    }.get(DATABASE_NAME, lambda: None)()