update
This commit is contained in:
123
src/querying.py
123
src/querying.py
@@ -1,66 +1,101 @@
|
||||
import sqlite3
|
||||
import numpy as np
|
||||
from random import choice
|
||||
from tprint import tprint
|
||||
|
||||
from joblib import Memory # for persistent memoïzation
|
||||
|
||||
from query_generator import *
|
||||
import orderankings as odrk
|
||||
import kemeny_young as km
|
||||
|
||||
import yaml # to load config file
|
||||
from os import environ # access environment variables
|
||||
|
||||
from config import CONFIG, DATABASE_CFG, VENV_HOME, DATABASE_FILE
|
||||
|
||||
# persistent memoïzation
|
||||
memory = Memory("src/cache")
|
||||
memory = Memory(f"{VENV_HOME}/src/cache")
|
||||
|
||||
VENV_PATH = environ.get('VIRTUAL_ENV')
|
||||
VERBOSE = CONFIG["verbose"]["querying"]
|
||||
|
||||
with open(VENV_PATH + "/src/config.yaml") as config_file:
|
||||
cfg = yaml.load(config_file, Loader=yaml.Loader)
|
||||
|
||||
VERBOSE = cfg["verbose"]["querying"]
|
||||
|
||||
DATABASE_NAME = cfg["database_name"]
|
||||
if VERBOSE: print("using database", DATABASE_NAME)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Connexion to sqlite database
|
||||
######################### Connexion to sqlite database #########################
|
||||
|
||||
# initialize database connection
|
||||
DATABASE_FILE = f"{DATABASE_NAME}_dataset/{DATABASE_NAME}.db"
|
||||
if VERBOSE: print(f"connecting to {DATABASE_FILE}")
|
||||
if VERBOSE:
|
||||
print(f"connecting to {DATABASE_FILE}")
|
||||
|
||||
CON = sqlite3.connect(DATABASE_FILE)
|
||||
CUR = CON.cursor()
|
||||
|
||||
|
||||
|
||||
@memory.cache # persistent memoïzation
|
||||
def query(q: str) -> list[tuple]:
|
||||
"""Execute a given query and reture the result in a python list[tuple]."""
|
||||
if VERBOSE: print(f'sending query : {q}')
|
||||
if VERBOSE:
|
||||
print(f'sending query : {q}')
|
||||
res = CUR.execute(str(q))
|
||||
if VERBOSE: print("got response", res)
|
||||
if VERBOSE:
|
||||
print("got response", res)
|
||||
return res.fetchall()
|
||||
|
||||
################################################################################
|
||||
# Choice of the right query generator
|
||||
|
||||
if DATABASE_NAME == "flight_delay":
|
||||
QUERY_PARAM_GB_FACTORY = QueryFlightWithParameterGroupedByCriteria
|
||||
elif DATABASE_NAME == "SSB":
|
||||
QUERY_PARAM_GB_FACTORY = QuerySSBWithParameterGroupedByCriteria
|
||||
else:
|
||||
raise ValueError(f"Unknown database : {DATABASE_NAME}")
|
||||
##################### Choice of the right query generator ######################
|
||||
|
||||
################################################################################
|
||||
# orderings extraction functions
|
||||
|
||||
QUERY_PARAM_GB_CONSTRUCTOR = DATABASE_CFG["query_generator"]
|
||||
|
||||
|
||||
######################## orderings extraction functions ########################
|
||||
|
||||
def random_query() -> list[tuple]:
|
||||
random_criteria = choice(DATABASE_CFG["criterion"])
|
||||
|
||||
qg_constructor = DATABASE_CFG["query_generator"]
|
||||
sql_query = qg_constructor(
|
||||
parameter=DATABASE_CFG["parameter"],
|
||||
authorized_parameter_values=DATABASE_CFG["authorized_parameter_values"],
|
||||
criteria=random_criteria,
|
||||
summed_attribute=DATABASE_CFG["summed_attribute"])
|
||||
|
||||
# print the query
|
||||
if VERBOSE: print("query :", str(sql_query), sep="\n")
|
||||
|
||||
result = query(str(sql_query)) # get result from database
|
||||
|
||||
if VERBOSE: # print the result
|
||||
print("query result :")
|
||||
tprint(result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def filter_correct_length_orderings(orderings: list[tuple], length: int) -> list[tuple]:
|
||||
"""Keep only orders that are of the specified length that means removing
|
||||
too short ones, and slicing too long ones."""
|
||||
correct_length_orderings = np.array(
|
||||
[ordrng[:length] for ordrng in orderings if len(ordrng) >= length]
|
||||
)
|
||||
|
||||
if VERBOSE:
|
||||
print(f"found {len(correct_length_orderings)} orderings :")
|
||||
# print(correct_length_orderings)
|
||||
tprint(correct_length_orderings)
|
||||
return correct_length_orderings
|
||||
|
||||
|
||||
def rankings_from_table(query_result: list[tuple]):
|
||||
orderings_dict = odrk.get_all_orderings_from_table(query_result)
|
||||
orderings = orderings_dict.values()
|
||||
orderings = filter_correct_length_orderings(
|
||||
orderings,
|
||||
DATABASE_CFG["orders_length"])
|
||||
if VERBOSE:
|
||||
print(orderings)
|
||||
rankings = odrk.rankings_from_orderings(orderings)
|
||||
return rankings
|
||||
|
||||
@memory.cache # persistent memoïzation
|
||||
def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str, ...],
|
||||
length: int,
|
||||
authorized_parameter_values: tuple[str, ...] | None =None
|
||||
authorized_parameter_values: tuple[str, ...] | None = None
|
||||
) -> list[list[str]]:
|
||||
"""Gather the list of every ordering returned by queries using given values
|
||||
of parameter, summed_attribute, and all given values of criterion.
|
||||
@@ -73,11 +108,13 @@ def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str,
|
||||
Returns:
|
||||
list[list]: The list of all found orderings.
|
||||
"""
|
||||
|
||||
# instanciate the query generator
|
||||
qg = QUERY_PARAM_GB_FACTORY(parameter=parameter,
|
||||
authorized_parameter_values=authorized_parameter_values,
|
||||
summed_attribute=summed_attribute,
|
||||
criteria=None)
|
||||
qg = DATABASE_CFG["query_generator"](
|
||||
parameter=parameter,
|
||||
authorized_parameter_values=authorized_parameter_values,
|
||||
summed_attribute=summed_attribute,
|
||||
criteria=None)
|
||||
|
||||
# ensemble de tous les ordres trouvés
|
||||
# la clef est la valeur dans la colonne criteria
|
||||
@@ -95,18 +132,6 @@ def find_orderings(parameter: str, summed_attribute: str, criterion: tuple[str,
|
||||
# update the global list of all found orders
|
||||
orderings.extend(table_orders.values())
|
||||
|
||||
# keep only orders that are of the specified length
|
||||
# that means removing too short ones, and slicing too long ones
|
||||
correct_length_orderings = np.array(
|
||||
[ordrng[:length] for ordrng in orderings if len(ordrng) >= length]
|
||||
)
|
||||
|
||||
if VERBOSE:
|
||||
print(f"found {len(correct_length_orderings)} orderings :")
|
||||
print(correct_length_orderings)
|
||||
# tprint(correct_length_orderings)
|
||||
correct_length_orderings = filter_correct_length_orderings(orderings, length)
|
||||
|
||||
return correct_length_orderings
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user