diff --git a/split_mnist.py b/split_mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd108f10239a1544c921f1df9115f08cd5c5fa47
--- /dev/null
+++ b/split_mnist.py
@@ -0,0 +1,107 @@
+from keras.models import Model
+from keras.layers import Dense, Input, Conv2D, Flatten, MaxPooling2D
+from configuration import conf
+from utils.dataloader import Sequential_loader
+import numpy as np
+from utils.model_utils import mask_layer_by_task
+from utils.layers import Probability_CLF_Mul_by_task
+from utils.train_utils import train_with_task
+from utils.predict_utils import get_task_likelihood, get_test_acc
+
+PATH = './results/%s/' % conf.dataset_name
+
+epochs = 50
+latent_dim = 250
+output_dim = 10
+verbose = 0
+
+data_loader = Sequential_loader()
+
+inputs = Input(shape=(784,))
+task_input = Input(shape=(5,))
+archi = Dense(1000, activation='relu')(inputs)
+archi = mask_layer_by_task(task_input, archi)
+archi = Dense(1000, activation='relu')(archi)
+archi = mask_layer_by_task(task_input, archi)
+
+task_output = Probability_CLF_Mul_by_task(conf.num_tasks, num_centers=output_dim // conf.num_tasks)(
+    [task_input, archi])
+task_output = mask_layer_by_task(task_input, task_output, 'task_out')
+clf = Dense(output_dim, activation='softmax')(archi)
+clf = mask_layer_by_task(task_input, clf, 'clf_out')
+model = Model(inputs=[inputs, task_input], outputs=[clf, task_output])
+model.compile(loss=['categorical_crossentropy', 'mse'], optimizer='adam', metrics=['accuracy', 'mse'],
+              loss_weights=[1, 4])
+
+tlh = [] # Task Likelihood
+tlh_std = [] # Standard Deviation of Task Likelihood
+test_acc = []
+for task_idx in range(conf.num_tasks):
+    # Learn a new task
+    train_with_task(model, task_idx=task_idx, data_loader=data_loader)
+    # Get the likelihood of the current task
+    mean, std = get_task_likelihood(model, learned_task=task_idx, test_task=task_idx, data_loader=data_loader)
+    tlh.append(mean)
+    tlh_std.append(std)
+    # Get the likelihood of the next task
+    if task_idx < conf.num_tasks - 1:
+        mean, std = get_task_likelihood(model, learned_task=task_idx, test_task=task_idx+1, data_loader=data_loader)
+        tlh.append(mean)
+        tlh_std.append(std)
+    # Run 200 times to get the test accuracy (for drawing the figure)
+    for _ in range(conf.num_runs):
+        test_acc.append(get_test_acc(model,data_loader,test_on_whole_set=False))
+    # Print the average test accuracy across all the tasks
+    print('Learned %dth Task, Average test accuracy on all the task : %.3f'%(task_idx,get_test_acc(model, data_loader, test_on_whole_set=True)))
+
+
+def paa(sample, w=None):
+    w = sample.shape[0] // 20 if w is None else w
+    l = len(sample)
+    stepfloat = l / w
+    step = int(np.ceil(stepfloat))
+    start = 0
+    j = 1
+    paa = []
+    while start <= (l - step):
+        section = sample[start:start + step]
+        paa.append(np.mean(section))
+        start = int(j * stepfloat)
+        j += 1
+    return paa
+
+
+tlh_s = []
+for i in tlh:
+    tlh_s += i.tolist()
+tlh_s = np.array(tlh_s)
+
+tlh_std_s = []
+for i in tlh_std:
+    tlh_std_s += i.tolist()
+tlh_std_s = np.array(tlh_std_s)
+
+test_acc_s = np.array(test_acc).reshape(-1)
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+sns.set()
+
+tlh = np.array(paa(tlh_s))
+tlh_std = np.array(paa(tlh_std_s))
+test_acc = np.array(paa(test_acc_s, tlh.shape[0]))
+
+fig = sns.lineplot(np.arange(len(tlh)), tlh, label='Task Likelihood')
+fig.fill_between(np.arange(len(tlh)), tlh - tlh_std, tlh + tlh_std, alpha=0.3)
+fig = sns.lineplot(np.arange(len(tlh)), test_acc, label='Test Accuracy')
+a = [10, 30, 50, 70]
+for i in a:
+    fig.fill_between(np.arange(i, i + 10 + 1), 0, 0.1, alpha=0.1, color='red')
+    fig.fill_between(np.arange(i - 10, i + 1), 0, 0.1, alpha=0.1, color='green')
+fig.fill_between(np.arange(90 - 10, 90), 0, 0.1, alpha=0.1, color='green')
+# a = fig.get_xticklabels()
+fig.set_xticklabels(['', 'Task 1', 'Task 2', 'Task 3', 'Task 4', 'Task 5'])
+plt.legend(loc='center right')
+plt.savefig(PATH + 'result')
+plt.show()