summaryrefslogtreecommitdiff
path: root/experiment/analysis/tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiment/analysis/tools.py')
-rw-r--r--experiment/analysis/tools.py72
1 files changed, 72 insertions, 0 deletions
diff --git a/experiment/analysis/tools.py b/experiment/analysis/tools.py
index d32ccd3..1dffc9a 100644
--- a/experiment/analysis/tools.py
+++ b/experiment/analysis/tools.py
@@ -2,6 +2,9 @@
import pickle
from copy import deepcopy
+import matplotlib.pyplot as plt
+import numpy as np
+
def unpickle(pkl):
with open(pkl, "rb") as f:
@@ -9,6 +12,73 @@ def unpickle(pkl):
return data
+def fix_vp(data, procedures):
+ procs = deepcopy(procedures)
+ if data["train_0"]["procedure_order"] == data["test_0"]["procedure_order"]:
+ keys = list(data["train_0"].keys())
+ keys.remove("procedure_order")
+ keys.remove("water_sample")
+ for key in keys:
+ procs.remove(key)
+ proc_from = keys[2]
+ proc_to = procs[0]
+ for train in [x for x in data.keys() if x.startswith("train")]:
+ vp = deepcopy(data[train])
+ vp[proc_to] = vp.pop(proc_from)
+ data[train] = vp
+ return data
+
+def block_vps(data, condition):
+ blocked_vps = {}
+ for vp in data[condition].keys():
+ blocked_vps[vp] = blocked_time(data[condition][vp])
+ return blocked_vps
+
+
+def blocked_time(vp):
+ key_stem = list(vp.keys())[0].split("_")[0]
+
+ trial_count = len(vp.keys())
+ block_size = 5
+ block_count = trial_count / block_size
+
+ result = {}
+ sum_time = 0
+ block_i = 0
+ for trial in range(trial_count):
+ if trial % 5 == 0:
+ sum_time = 0
+ block_i += 1
+ sum_time += sum_time_over_trial(vp[f"{key_stem}_{trial}"])
+ result[block_i] = sum_time
+ return result
+
+
+def sum_time_over_trial(trial):
+ total_time = 0
+ for proc in trial.keys():
+ if proc != "procedure_order" and proc != "water_sample":
+ total_time += trial[proc]["time"]
+ return total_time
+
+
+def plot_vp(ax, data_dict):
+ x = data_dict.keys()
+ y = data_dict.values()
+
+ ax.scatter(x, y)
+
+
+def plot_average_vps(ax, label, blocked_vps):
+ xlist = [list(blocked_vps[x].keys()) for x in blocked_vps]
+ ylist = [list(blocked_vps[x].values()) for x in blocked_vps]
+ x = xlist[0]
+ yarray = np.array(ylist)
+ y = np.average(yarray, axis=0)
+
+ ax.scatter(x, y, label=label)
+
+
def count_correct(vp, trials, procedures):
trials_correct = {}
for proc in procedures:
@@ -52,9 +122,11 @@ def train_test_split(data):
if string in trial:
new_dict[cond][vp][trial] = data[cond][vp][trial]
return new_dict
+
data_train = delete_trials(data, "train")
data_test = delete_trials(data, "test")
return data_train, data_test
+
print("imported tools")