|
9 | 9 | "source": [ |
10 | 10 | "from brian2 import *\n", |
11 | 11 | "from matplotlib import pyplot as plt\n", |
| 12 | + "from matplotlib.transforms import Affine2D\n", |
12 | 13 | "import pandas as pd\n", |
13 | 14 | "import numpy as np\n", |
14 | 15 | "from src.nb_helpers import *\n", |
|
803 | 804 | "fig, axs = plt.subplots(1, 3, figsize=(26, 6))\n", |
804 | 805 | "xticks = np.linspace(0, len(cc_cs_weights)-1, len(cc_cs_weights))\n", |
805 | 806 | "\n", |
| 807 | + "trans1 = Affine2D().translate(-0.2, 0.0)\n", |
| 808 | + "trans2 = Affine2D().translate(-0.1, 0.0)\n", |
| 809 | + "trans3 = Affine2D().translate(+0.0, 0.0)\n", |
| 810 | + "trans4 = Affine2D().translate(+0.1, 0.0)\n", |
| 811 | + "\n", |
806 | 812 | "os_cs_mean = [os_record[0] for os_record in os_mean_data]\n", |
807 | 813 | "os_cs_std = [os_record[0] for os_record in os_std_data]\n", |
808 | 814 | "os_cc_mean = [os_record[1] for os_record in os_mean_data]\n", |
|
813 | 819 | "os_pv_std = [os_record[3] for os_record in os_std_data]\n", |
814 | 820 | "\n", |
815 | 821 | "# Plot orientation selectivity\n", |
816 | | - "axs[0].errorbar(xticks, os_cs_mean, fmt='b', yerr=os_cs_std, label='CS')\n", |
817 | | - "axs[0].errorbar(xticks, os_cc_mean, fmt='r', yerr=os_cc_std, label='CC')\n", |
818 | | - "axs[0].errorbar(xticks, os_sst_mean, fmt='g', yerr=os_sst_std, label='SST')\n", |
819 | | - "axs[0].errorbar(xticks, os_pv_mean, fmt='y', yerr=os_pv_std, label='PV')\n", |
| 822 | + "axs[0].errorbar(xticks, os_cs_mean, fmt='b', yerr=os_cs_std, label='CS', transform=trans1+axs[0].transData, elinewidth=2)\n", |
| 823 | + "axs[0].errorbar(xticks, os_cc_mean, fmt='r', yerr=os_cc_std, label='CC', transform=trans2+axs[0].transData, elinewidth=2)\n", |
| 824 | + "axs[0].errorbar(xticks, os_sst_mean, fmt='g', yerr=os_sst_std, label='SST', transform=trans3+axs[0].transData, elinewidth=2)\n", |
| 825 | + "axs[0].errorbar(xticks, os_pv_mean, fmt='y', yerr=os_pv_std, label='PV', transform=trans4+axs[0].transData, elinewidth=2)\n", |
820 | 826 | "axs[0].set_xticks(xticks)\n", |
821 | 827 | "axs[0].set_xticklabels(cc_cs_weights)\n", |
822 | 828 | "axs[0].set_xlabel('Scalar of CC_CS connection')\n", |
|
834 | 840 | "os_paper_pv_std = [os_record[3] for os_record in os_paper_std_data]\n", |
835 | 841 | "\n", |
836 | 842 | "# Plot orientation selectivity (paper)\n", |
837 | | - "axs[1].errorbar(xticks, os_paper_cs_mean, fmt='b', yerr=os_paper_cs_std, label='CS')\n", |
838 | | - "axs[1].errorbar(xticks, os_paper_cc_mean, fmt='r', yerr=os_paper_cc_std, label='CC')\n", |
839 | | - "axs[1].errorbar(xticks, os_paper_sst_mean, fmt='g', yerr=os_paper_sst_std, label='SST')\n", |
840 | | - "axs[1].errorbar(xticks, os_paper_pv_mean, fmt='y', yerr=os_paper_pv_std, label='PV')\n", |
| 843 | + "axs[1].errorbar(xticks, os_paper_cs_mean, fmt='b', yerr=os_paper_cs_std, label='CS', transform=trans1+axs[1].transData, elinewidth=2)\n", |
| 844 | + "axs[1].errorbar(xticks, os_paper_cc_mean, fmt='r', yerr=os_paper_cc_std, label='CC', transform=trans2+axs[1].transData, elinewidth=2)\n", |
| 845 | + "axs[1].errorbar(xticks, os_paper_sst_mean, fmt='g', yerr=os_paper_sst_std, label='SST', transform=trans3+axs[1].transData, elinewidth=2)\n", |
| 846 | + "axs[1].errorbar(xticks, os_paper_pv_mean, fmt='y', yerr=os_paper_pv_std, label='PV', transform=trans4+axs[1].transData, elinewidth=2)\n", |
841 | 847 | "axs[1].set_xticks(xticks)\n", |
842 | 848 | "axs[1].set_xticklabels(cc_cs_weights)\n", |
843 | 849 | "axs[1].set_xlabel('Scalar of CC_CS connection')\n", |
|
855 | 861 | "ds_pv_std = [ds_record[3] for ds_record in ds_std_data]\n", |
856 | 862 | "\n", |
857 | 863 | "# Plot direction selectivity\n", |
858 | | - "axs[2].errorbar(xticks, ds_cs_mean, fmt='b', yerr=ds_cs_std, label='CS')\n", |
859 | | - "axs[2].errorbar(xticks, ds_cc_mean, fmt='r', yerr=ds_cc_std, label='CC')\n", |
860 | | - "axs[2].errorbar(xticks, ds_sst_mean, fmt='g', yerr=ds_sst_std, label='SST')\n", |
861 | | - "axs[2].errorbar(xticks, ds_pv_mean, fmt='y', yerr=ds_pv_std, label='PV')\n", |
| 864 | + "axs[2].errorbar(xticks, ds_cs_mean, fmt='b', yerr=ds_cs_std, label='CS', transform=trans1+axs[2].transData, elinewidth=2)\n", |
| 865 | + "axs[2].errorbar(xticks, ds_cc_mean, fmt='r', yerr=ds_cc_std, label='CC', transform=trans2+axs[2].transData, elinewidth=2)\n", |
| 866 | + "axs[2].errorbar(xticks, ds_sst_mean, fmt='g', yerr=ds_sst_std, label='SST', transform=trans3+axs[2].transData, elinewidth=2)\n", |
| 867 | + "axs[2].errorbar(xticks, ds_pv_mean, fmt='y', yerr=ds_pv_std, label='PV', transform=trans4+axs[2].transData, elinewidth=2)\n", |
862 | 868 | "axs[2].set_xticks(xticks)\n", |
863 | 869 | "axs[2].set_xticklabels(cc_cs_weights)\n", |
864 | 870 | "axs[2].set_xlabel('Scalar of CC_CS connection')\n", |
|
877 | 883 | "fig, axs = plt.subplots(1, 2, figsize=(18, 6))\n", |
878 | 884 | "xticks = np.linspace(0, len(cc_cs_weights)-1, len(cc_cs_weights))\n", |
879 | 885 | "\n", |
| 886 | + "trans1 = Affine2D().translate(-0.05, 0.0)\n", |
| 887 | + "trans2 = Affine2D().translate(+0.0, 0.0)\n", |
| 888 | + "\n", |
880 | 889 | "firing_rates_mean_over_input_degrees = np.mean(firing_rates_mean_over_simulations, axis=0)\n", |
881 | 890 | "firing_rates_std_over_input_degrees = np.mean(firing_rates_std_over_simulations, axis=0)\n", |
882 | 891 | "\n", |
|
890 | 899 | "fire_rate_pv_std = [fire_rate_std_by_weight[3] for fire_rate_std_by_weight in firing_rates_std_over_input_degrees]\n", |
891 | 900 | "\n", |
892 | 901 | "# Plot firing rate\n", |
893 | | - "axs[0].errorbar(xticks, fire_rate_cs_mean, fmt='b', yerr=fire_rate_cs_std, label='CS')\n", |
894 | | - "axs[0].errorbar(xticks, fire_rate_cc_mean, fmt='r', yerr=fire_rate_cc_std, label='CC')\n", |
895 | | - "axs[0].errorbar(xticks, fire_rate_sst_mean, fmt='g', yerr=fire_rate_sst_std, label='SST')\n", |
896 | | - "axs[0].errorbar(xticks, fire_rate_pv_mean, fmt='y', yerr=fire_rate_pv_std, label='PV')\n", |
| 902 | + "axs[0].errorbar(xticks, fire_rate_cs_mean, fmt='b', yerr=fire_rate_cs_std, label='CS', elinewidth=2)\n", |
| 903 | + "axs[0].errorbar(xticks, fire_rate_cc_mean, fmt='r', yerr=fire_rate_cc_std, label='CC', elinewidth=2)\n", |
| 904 | + "axs[0].errorbar(xticks, fire_rate_sst_mean, fmt='g', yerr=fire_rate_sst_std, label='SST', elinewidth=2)\n", |
| 905 | + "axs[0].errorbar(xticks, fire_rate_pv_mean, fmt='y', yerr=fire_rate_pv_std, label='PV', elinewidth=2)\n", |
897 | 906 | "axs[0].set_xticks(xticks)\n", |
898 | 907 | "axs[0].set_xticklabels(cc_cs_weights)\n", |
899 | 908 | "axs[0].set_xlabel('Scalar of CC_CS connection')\n", |
|
902 | 911 | "axs[0].legend(loc='best')\n", |
903 | 912 | "\n", |
904 | 913 | "# Plot os_rel, os_paper_rel and ds_rel\n", |
905 | | - "axs[1].errorbar(xticks, os_rel_mean_data, yerr=os_rel_std_data, label='OS rel')\n", |
| 914 | + "axs[1].errorbar(xticks, os_rel_mean_data, yerr=os_rel_std_data, label='OS rel', transform=trans1+axs[1].transData, elinewidth=2)\n", |
906 | 915 | "# axs[1].errorbar(xticks, os_paper_rel_mean_data, yerr=os_paper_rel_std_data, label='OS Paper rel')\n", |
907 | | - "axs[1].errorbar(xticks, ds_rel_mean_data, yerr=ds_rel_std_data, label='DS rel')\n", |
| 916 | + "axs[1].errorbar(xticks, ds_rel_mean_data, yerr=ds_rel_std_data, label='DS rel', transform=trans2+axs[1].transData, elinewidth=2)\n", |
908 | 917 | "axs[1].set_xticks(xticks)\n", |
909 | 918 | "axs[1].set_xticklabels(cc_cs_weights)\n", |
910 | 919 | "axs[1].set_xlabel('Scalar of CC_CS connection')\n", |
|
0 commit comments