From 4c71eec3cd5f5f36c1cdc6d2284f6dd93facc193 Mon Sep 17 00:00:00 2001 From: Niclas Dobbertin Date: Wed, 1 Nov 2023 11:20:45 +0100 Subject: put analysis functions into own py file --- experiment/analysis/tools.py | 60 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 experiment/analysis/tools.py (limited to 'experiment/analysis/tools.py') diff --git a/experiment/analysis/tools.py b/experiment/analysis/tools.py new file mode 100644 index 0000000..d32ccd3 --- /dev/null +++ b/experiment/analysis/tools.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +import pickle +from copy import deepcopy + +def unpickle(pkl): + with open(pkl, "rb") as f: + data = pickle.load(f) + return data + + +def count_correct(vp, trials, procedures): + trials_correct = {} + for proc in procedures: + trials_correct[proc] = 0 + for sample in trials: + for proc in vp[sample]["procedure_order"]: + vp_ans = vp[sample][proc]["answer"] + for c in vp_ans: + if not c.isdigit(): + vp_ans = vp_ans.replace(c, "") + vp_ans = int(vp_ans) + if vp_ans == vp[sample]["water_sample"][proc][0]: + trials_correct[proc] += 1 + return trials_correct + + +def total_accuracy(vp, procedures): + train = [x for x in vp.keys() if "train" in x] + test = [x for x in vp.keys() if "test" in x] + + train_total = len(train) * len(vp[train[0]]["procedure_order"]) + test_total = len(test) * len(vp[test[0]]["procedure_order"]) + + acc_train = count_correct(vp, train, procedures) + acc_test = count_correct(vp, test, procedures) + + acc_train = sum([acc_train[x] for x in acc_train.keys()]) / train_total + acc_test = sum([acc_test[x] for x in acc_test.keys()]) / test_total + + return acc_train, acc_test + + +def train_test_split(data): + def delete_trials(data, string): + new_dict = {} + for cond in data.keys(): + new_dict[cond] = {} + for vp in data[cond].keys(): + new_dict[cond][vp] = {} + for trial in data[cond][vp].keys(): + if string in trial: + new_dict[cond][vp][trial] = data[cond][vp][trial] + return new_dict + data_train = delete_trials(data, "train") + data_test = delete_trials(data, "test") + + return data_train, data_test + +print("imported tools") -- cgit v1.2.3