{ "cells": [
{ "cell_type": "code", "execution_count": 1, "id": "d372d82b-0842-4c24-86d1-b75c6637d2a3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "imported tools\n" ] } ], "source": [
 "import os\n",
 "from pathlib import Path\n",
 "from pprint import pprint\n",
 "\n",
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "import pandas as pd\n",
 "\n",
 "import tools\n",
 "\n",
 "plt.rcParams[\"axes.prop_cycle\"] = plt.cycler(\"color\", plt.cm.tab10.colors)"
] },
{ "cell_type": "code", "execution_count": 2, "id": "3f7c451c-6afb-439d-8bfe-4e545c4f7992", "metadata": {}, "outputs": [], "source": [
 "# Root of the experiment data: one sub-directory per condition, each holding one sub-directory\n",
 "# per vp (presumably a participant / Versuchsperson) that contains a vp.pkl file.\n",
 "# Set THESIS_DATA_PATH to run this notebook elsewhere; the fallback is the original local checkout.\n",
 "data_path = Path(os.environ.get(\"THESIS_DATA_PATH\", \"/home/niclas/repos/uni/thesis/experiment/data\"))\n",
 "\n",
 "procedures = [\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"overall\"]\n"
] },
{ "cell_type": "code", "execution_count": null, "id": "cffed6cb-ed44-4f8d-92dc-760b752d4302", "metadata": {}, "outputs": [], "source": [
 "# Path.iterdir() yields entries in arbitrary order; sort so results are reproducible across machines.\n",
 "conditions = sorted(x.stem for x in data_path.iterdir() if x.is_dir())\n",
 "conditions\n"
] },
{ "cell_type": "code", "execution_count": 4, "id": "511c7800-8ca9-457b-90b0-d3c3302f6ef0", "metadata": {}, "outputs": [], "source": [
 "# data[condition][vp] -> record loaded from <condition>/<vp>/vp.pkl and passed through tools.fix_vp.\n",
 "# NOTE(review): tools.unpickle presumably wraps pickle.load, which can execute arbitrary code;\n",
 "# only point data_path at trusted data.\n",
 "data = {}\n",
 "for condition in conditions:\n",
 "    data[condition] = {}\n",
 "    for vp in sorted((data_path / condition).iterdir()):\n",
 "        if not vp.is_dir():\n",
 "            continue  # ignore stray files (e.g. .DS_Store) next to the vp folders\n",
 "        data[condition][vp.stem] = tools.fix_vp(tools.unpickle(vp / \"vp.pkl\"), procedures)\n",
 "\n",
 "data_train, data_test = tools.train_test_split(data)\n"
] },
{ "cell_type": "code", "execution_count": 6, "id": "8f79bdbb-ac99-4577-866a-d566eee527ed", "metadata": {}, "outputs": [], "source": [
 "# Apply tools.block_vps to every condition, separately for the train and test splits.\n",
 "train_blocked_fixed = tools.block_vps(data_train, \"fixed\")\n",
 "train_blocked_random = tools.block_vps(data_train, \"random\")\n",
 "train_blocked_blocked = tools.block_vps(data_train, \"blocked\")\n",
 "\n",
 "test_blocked_fixed = tools.block_vps(data_test, \"fixed\")\n",
 "test_blocked_random = tools.block_vps(data_test, \"random\")\n",
 "test_blocked_blocked = tools.block_vps(data_test, \"blocked\")"
] },
{ "cell_type": "code", "execution_count": null, "id": "eb3f2e96-2246-4b08-a7d1-999161ab3fd3", "metadata": {}, "outputs": [], "source": [
 "# RTsum per block for each condition (drawn by tools.plot_average_vps): train split on the left,\n",
 "# test split (\"Transfer\") on the right. Everything goes through the explicit Axes/Figure API so the\n",
 "# labels, the legend and the saved PNG do not depend on which Axes pyplot treats as current.\n",
 "fig, (train_ax, test_ax) = plt.subplots(ncols=2, nrows=1, figsize=(15, 5))\n",
 "\n",
 "train_blocked = {\"fixed\": train_blocked_fixed, \"random\": train_blocked_random, \"blocked\": train_blocked_blocked}\n",
 "test_blocked = {\"fixed\": test_blocked_fixed, \"random\": test_blocked_random, \"blocked\": test_blocked_blocked}\n",
 "for ax, blocked_by_condition in [(train_ax, train_blocked), (test_ax, test_blocked)]:\n",
 "    for name, blocked in blocked_by_condition.items():\n",
 "        tools.plot_average_vps(ax, name, blocked)\n",
 "\n",
 "train_ax.set(title=\"Train\", xlabel=\"Block\", ylabel=\"RTsum\")\n",
 "test_ax.set(title=\"Transfer\", xlabel=\"Block\", ylabel=\"RTsum\")\n",
 "test_ax.legend()\n",
 "fig.tight_layout()\n",
 "fig.savefig(\"RT.png\")\n",
 "plt.show()"
] },
{ "cell_type": "code", "execution_count": 10, "id": "497bd4dc-943a-41f3-a694-3f4b8f049dee", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
traintest
vp140.9822220.986667
vp180.9622220.970000
vp150.9733330.980000
vp200.9066670.980000
vp100.9244440.943333
vp130.8577780.946667
vp170.9111110.960000
vp120.8222220.820000
vp190.9666670.800000
vp160.9577780.926667
\n", "
" ], "text/plain": [ " train test\n", "vp14 0.982222 0.986667\n", "vp18 0.962222 0.970000\n", "vp15 0.973333 0.980000\n", "vp20 0.906667 0.980000\n", "vp10 0.924444 0.943333\n", "vp13 0.857778 0.946667\n", "vp17 0.911111 0.960000\n", "vp12 0.822222 0.820000\n", "vp19 0.966667 0.800000\n", "vp16 0.957778 0.926667" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# One row per vp in the chosen condition; the train/test columns hold whatever pair\n",
 "# tools.total_accuracy returns for that vp (presumably train- and test-phase accuracy).\n",
 "condition = \"random\"\n",
 "df = pd.DataFrame(\n",
 "    [tools.total_accuracy(data[condition][vp], procedures) for vp in data[condition]],\n",
 "    index=list(data[condition]),\n",
 "    columns=[\"train\", \"test\"],\n",
 ")\n",
 "df\n"
] },
{ "cell_type": "code", "execution_count": 11, "id": "143f7497-2c6c-492c-85ab-da3d2cf2a828", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
123456overall
vp140.9920.9760.9920.9760.4000.6000.968
vp180.9760.9760.9600.3920.6000.9840.904
vp150.9920.9920.9600.3920.5921.0000.928
vp200.9920.3760.9520.9760.9760.5600.784
vp100.9680.3600.5920.9840.9840.9920.712
vp130.3840.9600.9280.5600.9920.9680.568
vp170.3920.9680.5841.0001.0000.9920.648
vp120.9920.5920.3920.9760.9601.0000.016
vp191.0000.9920.0000.5760.9920.9920.848
vp160.9760.6000.3760.9760.9921.0000.752
\n", "
" ], "text/plain": [ " 1 2 3 4 5 6 overall\n", "vp14 0.992 0.976 0.992 0.976 0.400 0.600 0.968\n", "vp18 0.976 0.976 0.960 0.392 0.600 0.984 0.904\n", "vp15 0.992 0.992 0.960 0.392 0.592 1.000 0.928\n", "vp20 0.992 0.376 0.952 0.976 0.976 0.560 0.784\n", "vp10 0.968 0.360 0.592 0.984 0.984 0.992 0.712\n", "vp13 0.384 0.960 0.928 0.560 0.992 0.968 0.568\n", "vp17 0.392 0.968 0.584 1.000 1.000 0.992 0.648\n", "vp12 0.992 0.592 0.392 0.976 0.960 1.000 0.016\n", "vp19 1.000 0.992 0.000 0.576 0.992 0.992 0.848\n", "vp16 0.976 0.600 0.376 0.976 0.992 1.000 0.752" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Per-vp, per-procedure accuracy: the counts from tools.count_correct divided by the number of\n",
 "# entries in *that* vp's own record. Normalising every row by one shared count (e.g. the first\n",
 "# vp's) would silently skew rows whenever vps differ in length.\n",
 "condition = \"random\"\n",
 "proc_accs = []\n",
 "for vp_data in data[condition].values():\n",
 "    correct_counts = tools.count_correct(vp_data, vp_data.keys(), procedures)\n",
 "    n_entries = len(vp_data.keys())\n",
 "    proc_accs.append({proc: count / n_entries for proc, count in correct_counts.items()})\n",
 "df = pd.DataFrame(proc_accs, index=list(data[condition]))\n",
 "df\n"
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }