diff --git a/split_mnist.py b/split_mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..3da899707c894fb003263743d68eaea2b0f986c3
--- /dev/null
+++ b/split_mnist.py
@@ -0,0 +1,117 @@
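+"""Split-MNIST continual learning with task-masked layers.
+
+A one-hot task indicator gates each hidden layer so every task trains its own
+slice of units; at test time the task identity is inferred from the
+block-averaged outputs of a per-task likelihood head.
+"""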
+import numpy as np
+import tensorflow as tf
+from keras.models import Model
+from keras.layers import Dense, Flatten, Input, Lambda
+
+from configuration import conf
+from utils.dataloader import Sequential_loader
+from utils.layers import Probability_CLF_Mul_by_task
+
+
+def mask_layer_by_task(task_input, input_tensor, name=None, return_mask=False):
+    """Zero out every unit of `input_tensor` except the contiguous slice of
+    input_tensor.shape[1] // conf.num_tasks units selected by the one-hot
+    `task_input`."""
+    mask = tf.expand_dims(task_input, axis=-1)
+    mask = tf.tile(mask, multiples=[1, 1, input_tensor.shape[1] // conf.num_tasks])
+    mask = Flatten()(mask)
+    out = Lambda(lambda x: x * mask, name=name)(input_tensor)
+    if return_mask:
+        return out, mask
+    return out
+
+
+def get_model_keras_mask(output_dim):
+    """Build the task-masked classifier; returns the trainable model plus a
+    companion model that exposes the shared latent features."""
+    inputs = Input(shape=(784,))
+    task_input = Input(shape=(conf.num_tasks,))
+    archi = Dense(1000, activation='relu')(inputs)
+    archi = mask_layer_by_task(task_input, archi)
+    archi = Dense(1000, activation='relu')(archi)
+    archi = mask_layer_by_task(task_input, archi)
+
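+    # Two heads: a task-likelihood estimate trained to match the one-hot task
+    # input, and a masked softmax classifier over all output_dim classes.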
+    task_output = Probability_CLF_Mul_by_task(conf.num_tasks, num_centers=output_dim // conf.num_tasks)(
+        [task_input, archi])
+    task_output = mask_layer_by_task(task_input, task_output, name='task_out')
+    clf = Dense(output_dim, activation='softmax')(archi)
+    clf = mask_layer_by_task(task_input, clf, name='clf_out')
+    model = Model(inputs=[inputs, task_input], outputs=[clf, task_output])
+    # archi depends on task_input through the masks, so both inputs are needed.
+    model_latent = Model(inputs=[inputs, task_input], outputs=archi)
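+    # Joint objective: cross-entropy on the classifier plus MSE (weighted 4x)
+    # pulling the task head toward the one-hot task indicator.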
+    model.compile(loss=['categorical_crossentropy', 'mse'], optimizer='adam', metrics=['accuracy', 'mse'],
+                  loss_weights=[1, 4])
+
+    return model, model_latent
+
+
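+# Train sequentially, one task at a time; the one-hot task id both selects the
+# mask and serves as the regression target for the task-likelihood head.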
+data_loader = Sequential_loader()
+
+model, model_latent = get_model_keras_mask(10)
+for task_idx in range(conf.num_tasks):
+    x, y = data_loader.sample(task_idx=task_idx, whole_set=True)
+    task_input = np.zeros([y.shape[0], conf.num_tasks])
+    task_input[:, task_idx] = 1
+    model.fit([x, task_input], [y, task_input], epochs=10, batch_size=conf.batch_size, verbose=0)
+    if task_idx == 0:
+        # After the first task, freeze the layer at index 1 and recompile so the
+        # freeze takes effect, preserving those weights for the remaining tasks.
+        model.layers[1].trainable = False
+        model.compile(loss=['categorical_crossentropy', 'mse'], optimizer='adam',
+                      metrics=['accuracy', 'mse'], loss_weights=[1, 4])
+
+
+block_size = conf.test_batch_size  # size of the blocks over which task likelihoods are averaged
+
+
+def block_likelihood(res):
+    """Smooth each head's per-sample likelihoods by averaging over consecutive
+    blocks of `block_size` samples."""
+    smoothed = []
+    for r in res:
+        extra_index = r.shape[0] % block_size
+        # Guard the divisible case: r[:-0] would be empty and r[-0:] the whole array.
+        main, extra = (r[:-extra_index], r[-extra_index:]) if extra_index else (r, r[:0])
+        main = main.reshape(-1, block_size).mean(axis=1, keepdims=True)
+        main = np.repeat(main, block_size, axis=1).reshape(-1)
+        if extra.size:
+            main = np.append(main, np.repeat(extra.mean(), extra.size))
+        smoothed.append(main)
+    return smoothed
+
+
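+# Evaluate: for each true task, query every task head, infer the task from the
+# block-smoothed likelihoods, and score that head's class prediction.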
+test_acc = []
+for task_idx in range(conf.num_tasks):
+    x, y = data_loader.sample(task_idx, whole_set=True, dataset='test')
+    res = []
+    pred = []
+    for test_idx in range(conf.num_tasks):
+        task_input = np.zeros([y.shape[0], conf.num_tasks])
+        task_input[:, test_idx] = 1
+        prediction = model.predict([x, task_input])
+        res.append(np.max(prediction[1], axis=1))
+        pred.append(np.argmax(prediction[0], axis=1))
+
+    res = block_likelihood(res)
+    pred = np.array(pred)
+    # For each sample, take the class prediction from the head with the highest
+    # smoothed task likelihood and compare it with the true label.
+    chosen = np.argmax(res, axis=0)
+    acc = np.mean(pred[chosen, np.arange(pred.shape[1])] == np.argmax(y, axis=1))
+    print('Task %d, Accuracy %.3f' % (task_idx, acc))
+    test_acc.append(acc)
+print('Average test accuracy: %.3f' % np.mean(test_acc))