Use the config file instead of global variables.
This commit is contained in:
parent
ff0f646d04
commit
b13f8ab039
180
src/querying.py
180
src/querying.py
@ -1,26 +1,36 @@
|
||||
import sqlite3
|
||||
import numpy as np
|
||||
from tprint import tprint
|
||||
|
||||
from joblib import Memory # for persistent memoïzation
|
||||
|
||||
from query_generator import *
|
||||
import orderankings as odrk
|
||||
import kemeny_young as km
|
||||
from joblib import Memory
|
||||
|
||||
import yaml # to load config file
|
||||
from os import environ # access environment variables
|
||||
|
||||
|
||||
# persistent memoïzation
|
||||
memory = Memory("cache")
|
||||
memory = Memory("src/cache")
|
||||
|
||||
DATABASE_NAME = "flight_delay"
|
||||
DATABASE_NAME = "SSB"
|
||||
VENV_PATH = environ.get('VIRTUAL_ENV')
|
||||
|
||||
with open(VENV_PATH + "/src/config.yaml") as config_file:
|
||||
cfg = yaml.load(config_file, Loader=yaml.Loader)
|
||||
|
||||
VERBOSE = cfg["verbose"]["querying"]
|
||||
|
||||
DATABASE_NAME = cfg["database_name"]
|
||||
if VERBOSE: print("using database", DATABASE_NAME)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Connexion to sqlite database
|
||||
|
||||
odrk.VERBOSE = False
|
||||
VERBOSE = True
|
||||
|
||||
# initialize database connection
|
||||
DATABASE_FILE = f"../{DATABASE_NAME}_dataset/{DATABASE_NAME}.db"
|
||||
DATABASE_FILE = f"{DATABASE_NAME}_dataset/{DATABASE_NAME}.db"
|
||||
if VERBOSE: print(f"connecting to {DATABASE_FILE}")
|
||||
CON = sqlite3.connect(DATABASE_FILE)
|
||||
CUR = CON.cursor()
|
||||
@ -39,10 +49,10 @@ def query(q: str) -> list[tuple]:
|
||||
|
||||
if DATABASE_NAME == "flight_delay":
|
||||
QUERY_PARAM_GB_FACTORY = QueryFlightWithParameterGroupedByCriteria
|
||||
QUERY_PARAM_FACTORY = QueryFlightWithParameter
|
||||
elif DATABASE_NAME == "SSB":
|
||||
QUERY_PARAM_GB_FACTORY = QuerySSBWithParameterGroupedByCriteria
|
||||
QUERY_PARAM_FACTORY = QuerySSBWithParameter
|
||||
else:
|
||||
raise ValueError(f"Unknown database : {DATABASE_NAME}")
|
||||
|
||||
################################################################################
|
||||
# orderings extraction functions
|
||||
@ -50,7 +60,7 @@ elif DATABASE_NAME == "SSB":
|
||||
@memory.cache # persistent memoïzation
|
||||
def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str, ...],
|
||||
length: int,
|
||||
authorized_parameter_values: list[str] =None
|
||||
authorized_parameter_values: tuple[str, ...] | None =None
|
||||
) -> list[list[str]]:
|
||||
"""Gather the list of every ordering returned by queries using given values
|
||||
of parameter, summed_attribute, and all given values of criterion.
|
||||
@ -65,15 +75,10 @@ def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str,
|
||||
"""
|
||||
# instanciate the query generator
|
||||
qg = QUERY_PARAM_GB_FACTORY(parameter=parameter,
|
||||
authorized_parameter_values=authorized_parameter_values,
|
||||
summed_attribute=summed_attribute,
|
||||
criteria=None)
|
||||
|
||||
if authorized_parameter_values is None:
|
||||
# reduce the number of compared parameter values
|
||||
qg.authorized_parameter_values = qg.authorized_parameter_values#[:length]
|
||||
else:
|
||||
qg.authorized_parameter_values = authorized_parameter_values#[:length]
|
||||
|
||||
# ensemble de tous les ordres trouvés
|
||||
# la clef est la valeur dans la colonne criteria
|
||||
orderings = list()
|
||||
@ -104,145 +109,4 @@ def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str,
|
||||
return correct_length_orderings
|
||||
|
||||
|
||||
@memory.cache # persistent memoïzation
|
||||
def find_true_ordering_ranking(parameter: str,
|
||||
summed_attribute: str,
|
||||
length: int,
|
||||
authorized_parameter_values: tuple[str,...]|None =None
|
||||
) -> tuple[list[list[str]], list[list[int]]]:
|
||||
"""Return the true (ordering, ranking), considering the data as a whole (no
|
||||
grouping by), and getting the true order (no rankings aggregation)."""
|
||||
if authorized_parameter_values is None:
|
||||
qg = QUERY_PARAM_FACTORY(parameter=parameter,
|
||||
summed_attribute=summed_attribute)
|
||||
else:
|
||||
qg = QUERY_PARAM_FACTORY(parameter=parameter,
|
||||
summed_attribute=summed_attribute,
|
||||
authorized_parameter_values=authorized_parameter_values)
|
||||
# qg.authorized_parameter_values = qg.authorized_parameter_values[:length]
|
||||
res = query(str(qg))
|
||||
if VERBOSE: print(res)
|
||||
ordering = odrk.get_orderings_from_table(res)
|
||||
ranking = odrk.rankings_from_orderings([ordering])[0]
|
||||
return ordering, ranking
|
||||
|
||||
################################################################################
|
||||
def flight_delay_main():
|
||||
PARAMETER = "departure_airport"
|
||||
SUMMED_ATTRIBUTE = "nb_flights"
|
||||
LENGTH = 5
|
||||
|
||||
ordering, ranking = find_true_ordering_ranking(parameter=PARAMETER,
|
||||
summed_attribute=SUMMED_ATTRIBUTE,
|
||||
length=LENGTH)
|
||||
print(ordering, ranking)
|
||||
|
||||
CRITERION = [
|
||||
# "airline",
|
||||
# "departure_hour",
|
||||
"day",
|
||||
# "month",
|
||||
]
|
||||
rng = np.random.default_rng()
|
||||
rng.shuffle(CRITERION)
|
||||
|
||||
grouped_orderings = find_orderings(parameter=PARAMETER,
|
||||
summed_attribute=SUMMED_ATTRIBUTE,
|
||||
criterion=CRITERION,
|
||||
length=LENGTH)
|
||||
# grouped_orderings = grouped_orderings[:5]
|
||||
# tprint(grouped_orderings, limit=20)
|
||||
print(grouped_orderings)
|
||||
# inferred_ordering = odrk.get_orderings_from_table(inferred_orderings_table)
|
||||
grouped_rankings = odrk.rankings_from_orderings(grouped_orderings)
|
||||
_, inferred_ranking = km.rank_aggregation(grouped_rankings)
|
||||
inferred_ranking = np.array(inferred_ranking)
|
||||
inferred_order = odrk.ordering_from_ranking(inferred_ranking,
|
||||
grouped_orderings[0])
|
||||
print("inferred :")
|
||||
print(inferred_order, inferred_ranking)
|
||||
|
||||
# print("distance =", km.kendall_tau_dist(ranking, inferred_ranking))
|
||||
|
||||
################################################################################
|
||||
def SSB_main():
|
||||
PARAMETER = "p_color"
|
||||
SUMMED_ATTRIBUTE = "lo_quantity"
|
||||
# SUMMED_ATTRIBUTE = "lo_revenue"
|
||||
# SUMMED_ATTRIBUTE = "lo_extendedprice"
|
||||
LENGTH = 2
|
||||
|
||||
CRITERION = (
|
||||
##### customer table
|
||||
"c_region",
|
||||
"c_city",
|
||||
"c_nation",
|
||||
|
||||
##### part table
|
||||
"p_category",
|
||||
"p_brand",
|
||||
"p_mfgr",
|
||||
"p_color",
|
||||
"p_type",
|
||||
"p_container",
|
||||
|
||||
##### supplier table
|
||||
"s_city",
|
||||
"s_nation",
|
||||
"s_region",
|
||||
|
||||
##### order date
|
||||
# "D_DATE",
|
||||
# "D_DATEKEY",
|
||||
# "D_DATE",
|
||||
# "D_DAYOFWEEK",
|
||||
# "D_MONTH",
|
||||
# "D_YEAR",
|
||||
# "D_YEARMONTHNUM",
|
||||
# "D_YEARMONTH",
|
||||
# "D_DAYNUMINWEEK"
|
||||
# "D_DAYNUMINMONTH",
|
||||
# "D_DAYNUMINYEAR",
|
||||
# "D_MONTHNUMINYEAR",
|
||||
"D_WEEKNUMINYEAR",
|
||||
# "D_SELLINGSEASON",
|
||||
# "D_LASTDAYINWEEKFL",
|
||||
# "D_LASTDAYINMONTHFL",
|
||||
# "D_HOLIDAYFL",
|
||||
# "D_WEEKDAYFL",
|
||||
)
|
||||
|
||||
HYPOTHESIS_ORDERING = ("aquamarine", "dark")
|
||||
|
||||
ordering, ranking = find_true_ordering_ranking(parameter=PARAMETER,
|
||||
summed_attribute=SUMMED_ATTRIBUTE,
|
||||
length=LENGTH,
|
||||
authorized_parameter_values=HYPOTHESIS_ORDERING)
|
||||
print(ordering, ranking)
|
||||
|
||||
grouped_orderings = find_orderings(parameter=PARAMETER,
|
||||
summed_attribute=SUMMED_ATTRIBUTE,
|
||||
criterion=CRITERION,
|
||||
length=LENGTH
|
||||
)
|
||||
|
||||
# grouped_orderings = grouped_orderings[:5]
|
||||
tprint(grouped_orderings, limit=20)
|
||||
# print(grouped_orderings)
|
||||
# inferred_ordering = odrk.get_orderings_from_table(inferred_orderings_table)
|
||||
grouped_rankings = odrk.rankings_from_orderings(grouped_orderings)
|
||||
_, inferred_ranking = km.rank_aggregation(grouped_rankings)
|
||||
inferred_ranking = np.array(inferred_ranking)
|
||||
inferred_order = odrk.ordering_from_ranking(inferred_ranking,
|
||||
grouped_orderings[0])
|
||||
print("inferred :")
|
||||
print(inferred_order, inferred_ranking)
|
||||
|
||||
# print("distance =", km.kendall_tau_dist(ranking, inferred_ranking))
|
||||
|
||||
if __name__ == '__main__':
|
||||
if DATABASE_NAME == "SSB":
|
||||
SSB_main()
|
||||
elif DATABASE_NAME == "flight_delay":
|
||||
flight_delay_main()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user