summaryrefslogtreecommitdiff
path: root/experiment/analysis/tools.py
blob: cde322f575cb340068271feee25701f74663844c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3

import pickle
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np


def unpickle(pkl):
    """Load and return the object stored in the pickle file at path ``pkl``.

    The file is opened in binary mode and is closed again before returning.
    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(pkl, "rb") as handle:
        return pickle.load(handle)


def fix_vp(data, procedures):
    """Rename one procedure entry in every training trial of a participant.

    Nothing is changed unless the ``"procedure_order"`` of ``"train_0"``
    equals that of ``"test_0"``.  In that case the third entry of
    ``"train_0"`` (not counting ``"procedure_order"`` and ``"water_sample"``)
    is re-keyed, in every trial whose name starts with ``"train"``, to the
    first name that remains in a copy of ``procedures`` after the entries of
    ``"train_0"`` have been removed from it.

    ``procedures`` itself is left untouched.  ``data`` is updated in place
    (each affected trial is replaced by a modified deep copy) and returned.
    """
    unused = deepcopy(procedures)
    if data["train_0"]["procedure_order"] != data["test_0"]["procedure_order"]:
        return data

    # Entries of the first training trial, minus the two bookkeeping keys.
    present = list(data["train_0"])
    present.remove("procedure_order")
    present.remove("water_sample")
    for name in present:
        unused.remove(name)

    old_name, new_name = present[2], unused[0]
    train_trials = [trial for trial in data if trial.startswith("train")]
    for trial in train_trials:
        fixed = deepcopy(data[trial])
        # The renamed entry ends up as the last item of the trial dict.
        fixed[new_name] = fixed.pop(old_name)
        data[trial] = fixed
    return data

def block_vps(data, condition):
    """Compute the blocked times of every participant of one condition.

    Returns a dict that maps each participant key found in
    ``data[condition]`` to the result of ``blocked_time`` for that
    participant's trials.
    """
    participants = data[condition]
    return {
        name: blocked_time(trials) for name, trials in participants.items()
    }


def blocked_time(vp, block_size=5):
    """Sum the time per trial into consecutive blocks for one participant.

    Trials are looked up as ``"<stem>_<i>"`` for ``i = 1 .. len(vp) - 1``,
    where ``<stem>`` is the part of the first key of ``vp`` before its first
    underscore.  A new block begins whenever ``i`` is a multiple of
    ``block_size``, so the first block holds trials ``1 .. block_size - 1``.

    Args:
        vp: dict with the trials of one participant, keyed
            ``"<stem>_<index>"``.
        block_size: trial-index period at which a new block starts.
            Defaults to 5, the value that used to be hard-coded.

    Returns:
        dict mapping the 1-based block number to the summed time (as given
        by ``sum_time_over_trial``) of the trials in that block.

    Raises:
        IndexError: if ``vp`` is empty.
        KeyError: if an expected ``"<stem>_<i>"`` key is missing.
        ZeroDivisionError: if ``block_size`` is 0.
    """
    key_stem = list(vp.keys())[0].split("_")[0]
    trial_count = len(vp)

    result = {}
    sum_time = 0
    block_i = 1
    # NOTE(review): indices 0 and len(vp) are never read.  If the keys are
    # numbered 1..len(vp) (train_test_split drops "train_0"), the last trial
    # is skipped -- confirm that this range is intended.
    for trial in range(1, trial_count):
        if trial % block_size == 0:
            # Trial index hit a block boundary: start accumulating afresh.
            sum_time = 0
            block_i += 1
        sum_time += sum_time_over_trial(vp[f"{key_stem}_{trial}"])
        result[block_i] = sum_time
    return result


def sum_time_over_trial(trial):
    """Return the summed ``"time"`` of all procedure entries of one trial.

    The bookkeeping entries ``"procedure_order"`` and ``"water_sample"`` are
    skipped; every other value must provide a ``"time"`` item.  A trial
    without procedure entries yields ``0``.
    """
    skipped = ("procedure_order", "water_sample")
    return sum(
        entry["time"] for name, entry in trial.items() if name not in skipped
    )


def plot_vp(ax, data_dict):
    """Scatter-plot the values of ``data_dict`` against its keys on ``ax``.

    Args:
        ax: object providing a matplotlib-style ``scatter(x, y)`` method.
        data_dict: mapping of x values to y values.
    """
    # Materialise the dict views as lists, as plot_average_vps does: views
    # are not sequences, so array conversion would wrap the whole view in a
    # single object instead of reading its elements.
    x = list(data_dict.keys())
    y = list(data_dict.values())

    ax.scatter(x, y)


def plot_average_vps(ax, label, blocked_vps):
    """Scatter-plot the per-block average over all participants on ``ax``.

    ``blocked_vps`` maps participants to ``{block: value}`` dicts.  The block
    keys of the first participant serve as x values; the y values are the
    element-wise average of all participants' values.  The points are drawn
    with ``label`` as their legend label.
    """
    per_vp_blocks = []
    per_vp_values = []
    for blocks in blocked_vps.values():
        per_vp_blocks.append(list(blocks))
        per_vp_values.append(list(blocks.values()))

    block_numbers = per_vp_blocks[0]
    mean_values = np.average(np.array(per_vp_values), axis=0)

    ax.scatter(block_numbers, mean_values, label=label)


def count_correct(vp, trials, procedures):
    """Count, per procedure, in how many trials the answer was correct.

    Each answer is reduced to its digit characters, converted to ``int`` and
    compared with the first element of the trial's ``"water_sample"`` entry
    for that procedure.  Answers that do not yield an integer (for example
    an empty answer) count as incorrect instead of aborting the whole count.

    Args:
        vp: dict with the trials of one participant.
        trials: names of the trials in ``vp`` to evaluate.
        procedures: procedure names to report; each count starts at 0.

    Returns:
        dict mapping each procedure name to its number of correct trials.

    Raises:
        KeyError: if a trial lists a procedure that is not in ``procedures``
            and the answer for it is correct.
    """
    trials_correct = {proc: 0 for proc in procedures}
    for sample in trials:
        trial = vp[sample]
        for proc in trial["procedure_order"]:
            # Strip units and stray characters, keeping only the digits.
            digits = "".join(c for c in trial[proc]["answer"] if c.isdigit())
            try:
                answer = int(digits)
            except ValueError:
                # No usable number was entered, so it cannot be correct.
                continue
            if answer == trial["water_sample"][proc][0]:
                trials_correct[proc] += 1
    return trials_correct


def total_accuracy(vp, procedures):
    """Return the fraction of correct answers as ``(train, test)``.

    Trials of ``vp`` are assigned to the training or test set according to
    whether their name contains ``"train"`` or ``"test"``.  For each set the
    number of correct answers (from ``count_correct``) is divided by the
    number of trials times the length of the first trial's
    ``"procedure_order"``.
    """
    train_trials = [name for name in vp if "train" in name]
    test_trials = [name for name in vp if "test" in name]

    # Answers given per set: trial count times procedures per trial.
    train_answers = len(train_trials) * len(
        vp[train_trials[0]]["procedure_order"]
    )
    test_answers = len(test_trials) * len(
        vp[test_trials[0]]["procedure_order"]
    )

    train_correct = count_correct(vp, train_trials, procedures)
    test_correct = count_correct(vp, test_trials, procedures)

    return (
        sum(train_correct.values()) / train_answers,
        sum(test_correct.values()) / test_answers,
    )


def train_test_split(data):
    """Split ``data`` into its training trials and its test trials.

    ``data`` is nested as ``{condition: {participant: {trial: value}}}``.
    Two dicts with the same condition/participant layout are returned: the
    first keeps only trials whose name contains ``"train"`` (``"train_0"``
    is always dropped), the second only trials whose name contains
    ``"test"``.  Trial values are shared with ``data``, not copied.
    """

    def keep_matching(marker):
        # Rebuild the nesting, keeping trials whose name contains ``marker``.
        return {
            cond: {
                vp: {
                    name: value
                    for name, value in trials.items()
                    if marker in name and name != "train_0"
                }
                for vp, trials in vps.items()
            }
            for cond, vps in data.items()
        }

    return keep_matching("train"), keep_matching("test")


# Runs at import time (there is no ``__main__`` guard), so this message is
# printed whenever the module is imported or executed.
print("imported tools")