summaryrefslogtreecommitdiff
path: root/experiment/analysis/tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiment/analysis/tools.py')
-rw-r--r--experiment/analysis/tools.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/experiment/analysis/tools.py b/experiment/analysis/tools.py
new file mode 100644
index 0000000..d32ccd3
--- /dev/null
+++ b/experiment/analysis/tools.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+
+import pickle
+from copy import deepcopy
+
+def unpickle(pkl):
+ with open(pkl, "rb") as f:
+ data = pickle.load(f)
+ return data
+
+
+def count_correct(vp, trials, procedures):
+ trials_correct = {}
+ for proc in procedures:
+ trials_correct[proc] = 0
+ for sample in trials:
+ for proc in vp[sample]["procedure_order"]:
+ vp_ans = vp[sample][proc]["answer"]
+ for c in vp_ans:
+ if not c.isdigit():
+ vp_ans = vp_ans.replace(c, "")
+ vp_ans = int(vp_ans)
+ if vp_ans == vp[sample]["water_sample"][proc][0]:
+ trials_correct[proc] += 1
+ return trials_correct
+
+
+def total_accuracy(vp, procedures):
+ train = [x for x in vp.keys() if "train" in x]
+ test = [x for x in vp.keys() if "test" in x]
+
+ train_total = len(train) * len(vp[train[0]]["procedure_order"])
+ test_total = len(test) * len(vp[test[0]]["procedure_order"])
+
+ acc_train = count_correct(vp, train, procedures)
+ acc_test = count_correct(vp, test, procedures)
+
+ acc_train = sum([acc_train[x] for x in acc_train.keys()]) / train_total
+ acc_test = sum([acc_test[x] for x in acc_test.keys()]) / test_total
+
+ return acc_train, acc_test
+
+
+def train_test_split(data):
+ def delete_trials(data, string):
+ new_dict = {}
+ for cond in data.keys():
+ new_dict[cond] = {}
+ for vp in data[cond].keys():
+ new_dict[cond][vp] = {}
+ for trial in data[cond][vp].keys():
+ if string in trial:
+ new_dict[cond][vp][trial] = data[cond][vp][trial]
+ return new_dict
+ data_train = delete_trials(data, "train")
+ data_test = delete_trials(data, "test")
+
+ return data_train, data_test
+
+print("imported tools")