summaryrefslogtreecommitdiff
path: root/experiment/analysis/tools.py
blob: d32ccd3b9709160cfa3ddd94330dd43418135a97 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env python3

import pickle
from copy import deepcopy

def unpickle(pkl):
    """Load and return the object stored in the pickle file at path ``pkl``.

    The file is opened in binary mode and closed again before returning.

    Security note: ``pickle.load`` can execute arbitrary code while
    deserializing, so only pass files that come from a trusted source.
    """
    with open(pkl, mode="rb") as handle:
        return pickle.load(handle)


def count_correct(vp, trials, procedures):
    """Count correct answers per procedure over the given trials.

    For every trial key in ``trials``, each procedure listed in
    ``vp[trial]["procedure_order"]`` is scored. The answer string
    ``vp[trial][proc]["answer"]`` is reduced to its decimal digits, parsed
    as an int and compared with ``vp[trial]["water_sample"][proc][0]``.
    An answer that contains no decimal digits at all (e.g. an empty answer)
    is counted as incorrect.

    Args:
        vp: Mapping of trial key -> trial record.
        trials: Iterable of trial keys in ``vp`` to score.
        procedures: Procedure names to report; each count starts at 0.

    Returns:
        dict mapping each name in ``procedures`` to its number of correct
        answers.

    Raises:
        KeyError: if a correctly answered procedure is not in ``procedures``.
    """
    trials_correct = {proc: 0 for proc in procedures}
    for trial in trials:
        record = vp[trial]
        for proc in record["procedure_order"]:
            # Keep only decimal digits (drops units, whitespace, signs, ...).
            # isdecimal() rather than isdigit(): int() rejects characters such
            # as superscript digits that isdigit() would keep.
            digits = "".join(
                ch for ch in record[proc]["answer"] if ch.isdecimal()
            )
            if not digits:
                # No numeric answer: treat as incorrect instead of letting
                # int("") raise ValueError.
                continue
            if int(digits) == record["water_sample"][proc][0]:
                trials_correct[proc] += 1
    return trials_correct


def total_accuracy(vp, procedures):
    """Return the overall fraction of correct answers for train and test trials.

    Trials are selected by substring: keys of ``vp`` containing "train" form
    the training set and keys containing "test" form the test set. Each set's
    denominator is its total number of scored procedures, i.e. the sum of
    ``len(vp[trial]["procedure_order"])`` over its trials. This matches what
    ``count_correct`` iterates, so trials with differing numbers of
    procedures are handled correctly.

    Args:
        vp: Mapping of trial key -> trial record.
        procedures: Procedure names passed through to ``count_correct``.

    Returns:
        Tuple ``(acc_train, acc_test)`` of floats. A set with no scored
        procedures (e.g. no matching trials) yields ``float("nan")`` instead
        of raising.
    """
    def _accuracy(trials):
        # Total number of (trial, procedure) answers that count_correct scores.
        total = sum(len(vp[trial]["procedure_order"]) for trial in trials)
        if total == 0:
            return float("nan")
        return sum(count_correct(vp, trials, procedures).values()) / total

    train = [key for key in vp if "train" in key]
    test = [key for key in vp if "test" in key]
    return _accuracy(train), _accuracy(test)


def train_test_split(data):
    """Split nested trial data into a train subset and a test subset.

    ``data`` is a two-level mapping ``condition -> participant -> trials``,
    where ``trials`` maps trial keys to trial records. Both results keep
    every condition and participant key, including participants left with
    no matching trials. For each participant, only the trials whose key
    contains "train" (first result) or "test" (second result) are kept.
    Trial records are shared with ``data``, not copied.

    Returns:
        Tuple ``(data_train, data_test)``.
    """
    def _filter_by(substring):
        return {
            condition: {
                participant: {
                    key: record
                    for key, record in trials.items()
                    if substring in key
                }
                for participant, trials in participants.items()
            }
            for condition, participants in data.items()
        }

    return _filter_by("train"), _filter_by("test")

# Import-time notice: prints this confirmation every time the module is imported.
# NOTE(review): module-level side effect; consider removing it or switching to logging.
print("imported tools")