From e381a1e64ddb9e3bacba25389214cfd4756bba06 Mon Sep 17 00:00:00 2001 From: "Dobbertin, Niclas" Date: Tue, 28 Nov 2023 12:00:55 +0100 Subject: add jupyter notebook, RTsum plot --- experiment/analysis/tools.py | 72 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) (limited to 'experiment/analysis/tools.py') diff --git a/experiment/analysis/tools.py b/experiment/analysis/tools.py index d32ccd3..1dffc9a 100644 --- a/experiment/analysis/tools.py +++ b/experiment/analysis/tools.py @@ -2,6 +2,9 @@ import pickle from copy import deepcopy +import matplotlib.pyplot as plt +import numpy as np + def unpickle(pkl): with open(pkl, "rb") as f: @@ -9,6 +12,73 @@ def unpickle(pkl): return data +def fix_vp(data, procedures): + procs = deepcopy(procedures) + if data["train_0"]["procedure_order"] == data["test_0"]["procedure_order"]: + keys = list(data["train_0"].keys()) + keys.remove("procedure_order") + keys.remove("water_sample") + for key in keys: + procs.remove(key) + proc_from = keys[2] + proc_to = procs[0] + for train in [x for x in data.keys() if x.startswith("train")]: + vp = deepcopy(data[train]) + vp[proc_to] = vp.pop(proc_from) + data[train] = vp + return data + +def block_vps(data, condition): + blocked_vps = {} + for vp in data[condition].keys(): + blocked_vps[vp] = blocked_time(data[condition][vp]) + return blocked_vps + + +def blocked_time(vp): + key_stem = list(vp.keys())[0].split("_")[0] + + trial_count = len(vp.keys()) + block_size = 5 + block_count = trial_count / block_size + + result = {} + sum_time = 0 + block_i = 0 + for trial in range(trial_count): + if trial % 5 == 0: + sum_time = 0 + block_i += 1 + sum_time += sum_time_over_trial(vp[f"{key_stem}_{trial}"]) + result[block_i] = sum_time + return result + + +def sum_time_over_trial(trial): + total_time = 0 + for proc in trial.keys(): + if proc != "procedure_order" and proc != "water_sample": + total_time += trial[proc]["time"] + return total_time + + +def plot_vp(ax, data_dict): + x = data_dict.keys() + y = data_dict.values() + + ax.scatter(x, y) + + +def plot_average_vps(ax, label, blocked_vps): + xlist = [list(blocked_vps[x].keys()) for x in blocked_vps] + ylist = [list(blocked_vps[x].values()) for x in blocked_vps] + x = xlist[0] + yarray = np.array(ylist) + y = np.average(yarray, axis=0) + + ax.scatter(x, y, label=label) + + def count_correct(vp, trials, procedures): trials_correct = {} for proc in procedures: @@ -52,9 +122,11 @@ def train_test_split(data): if string in trial: new_dict[cond][vp][trial] = data[cond][vp][trial] return new_dict + data_train = delete_trials(data, "train") data_test = delete_trials(data, "test") return data_train, data_test + print("imported tools") -- cgit v1.2.3