#!/usr/bin/env python3
"""Helpers for loading, cleaning, scoring and plotting nested trial data.

As far as this module shows, a "vp" (presumably a participant) is a dict
mapping trial keys such as ``"train_3"`` or ``"test_0"`` to trial dicts.
Each trial dict holds a ``"procedure_order"`` list, a ``"water_sample"``
mapping, and one entry per procedure with ``"time"`` and ``"answer"`` fields.
"""
import pickle
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np

# Trial-dict entries that are metadata rather than per-procedure results.
_META_KEYS = ("procedure_order", "water_sample")


def unpickle(pkl):
    """Load and return the object stored in the pickle file at path *pkl*.

    Security: ``pickle.load`` can execute arbitrary code embedded in the
    file, so only call this on files you produced or otherwise trust.
    """
    with open(pkl, "rb") as f:
        return pickle.load(f)


def fix_vp(data, procedures):
    """Rename a procedure key in every ``train*`` trial of one vp.

    The rename happens only when ``train_0`` and ``test_0`` have the same
    ``procedure_order``. The third non-metadata key of ``train_0`` is then
    renamed to the first entry of *procedures* that is not among
    ``train_0``'s non-metadata keys.

    NOTE(review): the ``keys[2]`` choice is kept from the existing logic.
    Confirm it matches the recording error being repaired.
    NOTE(review): ``procedure_order`` and ``water_sample`` are not updated,
    so they may still name the old procedure. Confirm this is intended.

    Args:
        data: one vp's trials keyed by trial name. Modified in place: each
            ``train*`` entry is replaced by a repaired copy.
        procedures: all procedure names; not modified.

    Returns:
        The same *data* object.

    Raises:
        ValueError: if a result key of ``train_0`` is not in *procedures*.
    """
    if data["train_0"]["procedure_order"] != data["test_0"]["procedure_order"]:
        return data

    result_keys = [key for key in data["train_0"] if key not in _META_KEYS]
    missing = list(procedures)
    for key in result_keys:
        missing.remove(key)
    proc_from = result_keys[2]
    proc_to = missing[0]

    # Materialize the key list: entries of ``data`` are reassigned below.
    for trial_key in [k for k in data if k.startswith("train")]:
        repaired = deepcopy(data[trial_key])
        repaired[proc_to] = repaired.pop(proc_from)
        data[trial_key] = repaired
    return data


def block_vps(data, condition, block_size=5):
    """Return ``{vp: blocked_time(trials, block_size)}`` for every vp in *condition*."""
    return {
        vp: blocked_time(trials, block_size)
        for vp, trials in data[condition].items()
    }


def blocked_time(vp, block_size=5):
    """Sum per-trial time over consecutive blocks of *block_size* trials.

    The key stem (e.g. ``"train"``) is taken from the first key of *vp*.
    Trials ``stem_1`` up to the highest numeric index present for that stem
    are visited in order. Trials 1..block_size form block 1, the next
    block_size trials form block 2, and so on. A trailing partial block is
    still reported.

    Index 0 is skipped; ``train_test_split`` also drops ``train_0``.
    NOTE(review): for test data this skips ``test_0`` as well. Confirm.

    Args:
        vp: one vp's trials for a single stem (as produced by
            ``train_test_split``).
        block_size: number of trials per block; must be >= 1.

    Returns:
        dict mapping 1-based block number to summed time. Empty if *vp* is
        empty.

    Raises:
        ValueError: if *block_size* < 1.
        KeyError: if a trial index between 1 and the highest index is
            missing.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")
    if not vp:
        return {}

    key_stem = next(iter(vp)).split("_")[0]
    prefix = f"{key_stem}_"
    last_trial = max(
        (
            int(key[len(prefix):])
            for key in vp
            if key.startswith(prefix) and key[len(prefix):].isdecimal()
        ),
        default=0,
    )

    result = {}
    for trial in range(1, last_trial + 1):
        block_i = (trial - 1) // block_size + 1
        trial_time = sum_time_over_trial(vp[f"{key_stem}_{trial}"])
        result[block_i] = result.get(block_i, 0) + trial_time
    return result


def sum_time_over_trial(trial):
    """Return the summed ``"time"`` of every non-metadata entry of *trial*."""
    return sum(
        value["time"] for key, value in trial.items() if key not in _META_KEYS
    )


def plot_vp(ax, data_dict):
    """Scatter *data_dict*'s keys (x) against its values (y) on axes *ax*."""
    # Pass lists, not dict views: NumPy converts a bare dict view into a
    # 0-d object array rather than a 1-d sequence.
    ax.scatter(list(data_dict.keys()), list(data_dict.values()))


def plot_average_vps(ax, label, blocked_vps):
    """Scatter the per-block mean across vps on *ax*, tagged with *label*.

    *blocked_vps* maps vp -> {block: value}, as returned by ``block_vps``.
    The x positions come from the first vp's block keys. Every vp is assumed
    to have the same blocks in the same order; ragged inputs make NumPy
    raise.

    Raises:
        ValueError: if *blocked_vps* is empty.
    """
    per_vp = list(blocked_vps.values())
    if not per_vp:
        raise ValueError("blocked_vps is empty; nothing to average")
    x = list(per_vp[0].keys())
    y = np.average(np.array([list(blocks.values()) for blocks in per_vp]), axis=0)
    ax.scatter(x, y, label=label)


def count_correct(vp, trials, procedures):
    """Count correct answers per procedure over the trial keys in *trials*.

    For each procedure in a trial's ``procedure_order``, the ``"answer"``
    string is reduced to its decimal digits. The result is compared as an
    int with ``water_sample[proc][0]``. Answers containing no digits count
    as incorrect.

    Returns:
        dict procedure -> number of correct answers. Every entry of
        *procedures* is present, with a count of 0 if never correct.
    """
    trials_correct = dict.fromkeys(procedures, 0)
    for sample in trials:
        trial = vp[sample]
        for proc in trial["procedure_order"]:
            # isdecimal (not isdigit) keeps exactly the characters int() accepts.
            digits = "".join(c for c in trial[proc]["answer"] if c.isdecimal())
            if digits and int(digits) == trial["water_sample"][proc][0]:
                trials_correct[proc] += 1
    return trials_correct


def total_accuracy(vp, procedures):
    """Return ``(train_accuracy, test_accuracy)`` for one vp.

    Accuracy is the number of correct answers divided by the number of
    procedures actually presented. Both are summed over every trial whose
    key contains ``"train"`` (or ``"test"``). A split with no presented
    procedures yields ``nan``.
    """

    def accuracy(trial_keys):
        # Per-trial lengths: trials need not all present the same procedures.
        total = sum(len(vp[key]["procedure_order"]) for key in trial_keys)
        if total == 0:
            return float("nan")
        return sum(count_correct(vp, trial_keys, procedures).values()) / total

    train = [key for key in vp if "train" in key]
    test = [key for key in vp if "test" in key]
    return accuracy(train), accuracy(test)


def train_test_split(data):
    """Split condition -> vp -> trial data into train and test views.

    Returns:
        ``(data_train, data_test)`` with the same condition/vp nesting.
        ``data_train`` keeps trial keys containing ``"train"``, except
        ``"train_0"``; ``data_test`` keeps those containing ``"test"``.
        Trial dicts are shared with *data*, not copied.
    """

    def keep_trials(string):
        return {
            cond: {
                vp: {
                    trial: value
                    for trial, value in trials.items()
                    if string in trial and trial != "train_0"
                }
                for vp, trials in vps.items()
            }
            for cond, vps in data.items()
        }

    return keep_trials("train"), keep_trials("test")


print("imported tools")