This commit is contained in:
Oscar Plaisant
2024-07-02 02:32:24 +02:00
parent f84aec4456
commit e3df57ccde
68 changed files with 315 additions and 158 deletions

View File

@@ -1,66 +1,101 @@
import sqlite3
import numpy as np
from random import choice
from tprint import tprint
from joblib import Memory # for persistent memoïzation
from query_generator import *
import orderankings as odrk
import kemeny_young as km
import yaml # to load config file
from os import environ # access environment variables
from config import CONFIG, DATABASE_CFG, VENV_HOME, DATABASE_FILE
############################ Persistent memoization ############################
# On-disk joblib cache used by the @memory.cache decorators below; it is
# stored under f"{VENV_HOME}/src/cache".
memory = Memory(f"{VENV_HOME}/src/cache")

################################ Configuration #################################
# NOTE(review): settings are read here directly from config.yaml, although
# CONFIG, DATABASE_CFG, VENV_HOME and DATABASE_FILE are also imported from the
# `config` module. The direct load is kept so that VENV_PATH, cfg and
# DATABASE_NAME stay defined for the code further down -- confirm which of the
# two mechanisms is meant to be authoritative.
VENV_PATH = environ.get('VIRTUAL_ENV')  # None when the variable is not set
with open(VENV_PATH + "/src/config.yaml") as config_file:
    # NOTE(review): yaml.Loader can build arbitrary Python objects; this is
    # only acceptable because the configuration file is trusted.
    cfg = yaml.load(config_file, Loader=yaml.Loader)
VERBOSE = cfg["verbose"]["querying"]
DATABASE_NAME = cfg["database_name"]
if VERBOSE:
    print("using database", DATABASE_NAME)

######################### Connection to sqlite database ########################
# NOTE(review): this rebinds the DATABASE_FILE name imported from `config`.
DATABASE_FILE = f"{DATABASE_NAME}_dataset/{DATABASE_NAME}.db"
if VERBOSE:
    print(f"connecting to {DATABASE_FILE}")
# Module-wide connection and cursor shared by every query.
CON = sqlite3.connect(DATABASE_FILE)
CUR = CON.cursor()
@memory.cache  # persistent memoization
def query(q: str) -> list[tuple]:
    """Execute a query on the module-level cursor and return all of its rows.

    The function is wrapped by ``memory.cache``.

    Args:
        q: The query to run; it is converted with ``str()`` before being
            handed to ``CUR.execute``.

    Returns:
        Every row of the result, as returned by ``fetchall()``.
    """
    # Each verbose message is printed exactly once.
    if VERBOSE:
        print(f'sending query : {q}')
    res = CUR.execute(str(q))
    if VERBOSE:
        print("got response", res)
    return res.fetchall()
################################################################################
# Choice of the right query generator
if DATABASE_NAME == "flight_delay":
QUERY_PARAM_GB_FACTORY = QueryFlightWithParameterGroupedByCriteria
elif DATABASE_NAME == "SSB":
QUERY_PARAM_GB_FACTORY = QuerySSBWithParameterGroupedByCriteria
else:
raise ValueError(f"Unknown database : {DATABASE_NAME}")
##################### Choice of the right query generator ######################
################################################################################
# orderings extraction functions
QUERY_PARAM_GB_CONSTRUCTOR = DATABASE_CFG["query_generator"]
######################## orderings extraction functions ########################
def random_query() -> list[tuple]:
    """Run the configured query for one randomly chosen criterion.

    The criterion is drawn with ``random.choice`` from
    ``DATABASE_CFG["criterion"]``; the query generator and its other
    arguments are read from ``DATABASE_CFG`` as well.

    Returns:
        The rows returned by ``query`` for the generated SQL.
    """
    picked_criteria = choice(DATABASE_CFG["criterion"])
    build_query = DATABASE_CFG["query_generator"]
    generated = build_query(
        parameter=DATABASE_CFG["parameter"],
        authorized_parameter_values=DATABASE_CFG["authorized_parameter_values"],
        criteria=picked_criteria,
        summed_attribute=DATABASE_CFG["summed_attribute"],
    )
    sql_text = str(generated)

    # show the query
    if VERBOSE:
        print("query :", sql_text, sep="\n")

    rows = query(sql_text)  # fetch the result from the database

    # show the result
    if VERBOSE:
        print("query result :")
        tprint(rows)
    return rows
def filter_correct_length_orderings(orderings: list[tuple], length: int) -> np.ndarray:
    """Bring every ordering to the requested length.

    Orderings shorter than ``length`` are dropped; the remaining ones are
    sliced to their first ``length`` elements.

    Args:
        orderings: The orderings to filter; each one must support ``len()``
            and slicing.
        length: The number of elements every kept ordering must have.

    Returns:
        A NumPy array built from the kept, truncated orderings.
    """
    kept = [ordering[:length] for ordering in orderings if len(ordering) >= length]
    correct_length_orderings = np.array(kept)
    if VERBOSE:
        print(f"found {len(correct_length_orderings)} orderings :")
        tprint(correct_length_orderings)
    return correct_length_orderings
def rankings_from_table(query_result: list[tuple]):
    """Convert the rows of a query result into rankings.

    The orderings are extracted with ``odrk.get_all_orderings_from_table``,
    normalised to ``DATABASE_CFG["orders_length"]`` elements by
    ``filter_correct_length_orderings``, then passed to
    ``odrk.rankings_from_orderings``.

    Args:
        query_result: The rows to extract orderings from.

    Returns:
        Whatever ``odrk.rankings_from_orderings`` returns for the kept
        orderings.
    """
    all_orderings = odrk.get_all_orderings_from_table(query_result).values()
    kept_orderings = filter_correct_length_orderings(
        all_orderings, DATABASE_CFG["orders_length"])
    if VERBOSE:
        print(kept_orderings)
    return odrk.rankings_from_orderings(kept_orderings)
@memory.cache # persistent memoïzation
def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str, ...],
length: int,
authorized_parameter_values: tuple[str, ...] | None =None
authorized_parameter_values: tuple[str, ...] | None = None
) -> list[list[str]]:
"""Gather the list of every ordering returned by queries using given values
of parameter, summed_attribute, and all given values of criterion.
@@ -73,11 +108,13 @@ def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str,
Returns:
list[list]: The list of all found orderings.
"""
# instantiate the query generator
qg = QUERY_PARAM_GB_FACTORY(parameter=parameter,
authorized_parameter_values=authorized_parameter_values,
summed_attribute=summed_attribute,
criteria=None)
qg = DATABASE_CFG["query_generator"](
parameter=parameter,
authorized_parameter_values=authorized_parameter_values,
summed_attribute=summed_attribute,
criteria=None)
# set of all the orderings found
# the key is the value in the criteria column
@@ -95,18 +132,6 @@ def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str,
# update the global list of all found orders
orderings.extend(table_orders.values())
# keep only orders that are of the specified length
# that means removing too short ones, and slicing too long ones
correct_length_orderings = np.array(
[ordrng[:length] for ordrng in orderings if len(ordrng) >= length]
)
if VERBOSE:
print(f"found {len(correct_length_orderings)} orderings :")
print(correct_length_orderings)
# tprint(correct_length_orderings)
correct_length_orderings = filter_correct_length_orderings(orderings, length)
return correct_length_orderings