diff --git a/results/mnist/result.png b/results/mnist/result.png
index c6c2081b3c0dbca771c81f6d560fc2251926ce30..5daa9a4e93ba06f6900e491b45b285d56a8bebb1 100644
Binary files a/results/mnist/result.png and b/results/mnist/result.png differ
diff --git a/split_mnist.py b/split_mnist.py
index bd108f10239a1544c921f1df9115f08cd5c5fa47..0e7d1510ac949129a9d7416ad184c82194f904db 100644
--- a/split_mnist.py
+++ b/split_mnist.py
@@ -10,8 +10,7 @@ from utils.predict_utils import get_task_likelihood, get_test_acc
 
 PATH = './results/%s/' % conf.dataset_name
 
-epochs = 50
-latent_dim = 250
+epochs = 10
 output_dim = 10
 verbose = 0
 
@@ -33,8 +32,8 @@ model = Model(inputs=[inputs, task_input], outputs=[clf, task_output])
 model.compile(loss=['categorical_crossentropy', 'mse'], optimizer='adam', metrics=['accuracy', 'mse'],
               loss_weights=[1, 4])
 
-tlh = [] # Task Likelihood
-tlh_std = [] # Standard Deviation of Task Likelihood
+tlh = []  # Task Likelihood
+tlh_std = []  # Standard Deviation of Task Likelihood
 test_acc = []
 for task_idx in range(conf.num_tasks):
     # Learn a new task
@@ -45,14 +44,15 @@ for task_idx in range(conf.num_tasks):
     tlh_std.append(std)
     # Get the likelihood of the next task
     if task_idx < conf.num_tasks - 1:
-        mean, std = get_task_likelihood(model, learned_task=task_idx, test_task=task_idx+1, data_loader=data_loader)
+        mean, std = get_task_likelihood(model, learned_task=task_idx, test_task=task_idx + 1, data_loader=data_loader)
         tlh.append(mean)
         tlh_std.append(std)
     # Run 200 times to get the test accuracy (for drawing the figure)
     for _ in range(conf.num_runs):
-        test_acc.append(get_test_acc(model,data_loader,test_on_whole_set=False))
+        test_acc.append(get_test_acc(model, learned_task=task_idx, data_loader=data_loader, test_on_whole_set=False))
     # Print the average test accuracy across all the tasks
-    print('Learned %dth Task, Average test accuracy on all the task : %.3f'%(task_idx,get_test_acc(model, data_loader, test_on_whole_set=True)))
+    print('Learned %dth Task, Average test accuracy on all the tasks : %.3f' % (
+        task_idx, get_test_acc(model, learned_task=task_idx, data_loader=data_loader, test_on_whole_set=True)))
 
 
 def paa(sample, w=None):
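
Note on the call-site changes above: get_test_acc now requires learned_task, so every caller in split_mnist.py passes the index of the most recently learned task. A minimal usage sketch (not part of this patch; it assumes the model, data_loader, and conf objects already defined in split_mnist.py) of a final whole-test-set evaluation after the loop:

# Hypothetical follow-up call, not in the diff: with all tasks learned,
# learned_task is the last task index and verbose=1 prints the per-task accuracies.
final_acc = get_test_acc(model, learned_task=conf.num_tasks - 1,
                         data_loader=data_loader, test_on_whole_set=True, verbose=1)
print('Final average test accuracy over all tasks: %.3f' % final_acc)
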
diff --git a/utils/__pycache__/predict_utils.cpython-36.pyc b/utils/__pycache__/predict_utils.cpython-36.pyc
index f37f7415c433357c3ef575205da3c3e5053e97e8..ede71aa1a0f772c2a971eb43ca801e7a7a7c968b 100644
Binary files a/utils/__pycache__/predict_utils.cpython-36.pyc and b/utils/__pycache__/predict_utils.cpython-36.pyc differ
diff --git a/utils/predict_utils.py b/utils/predict_utils.py
index 0c254a9af23521671822371bef2e6321284fa8f5..1b2e82c8ab3426fc72a1803bbc06fd0ea1bdd65c 100644
--- a/utils/predict_utils.py
+++ b/utils/predict_utils.py
@@ -45,7 +45,7 @@ def get_task_likelihood(model, learned_task, test_task, data_loader):
            task_likelihood_var[np.argmax(task_likelihood, axis=0), np.arange(task_likelihood_var.shape[1])]
 
 
-def get_test_acc(model, data_loader, test_on_whole_set=True):
+def get_test_acc(model, learned_task, data_loader, test_on_whole_set=True, verbose=0):
     test_acc = []
     for task_idx in range(conf.num_tasks):
         if test_on_whole_set:
@@ -54,7 +54,7 @@ def get_test_acc(model, data_loader, test_on_whole_set=True):
             x, y = data_loader.sample(task_idx, batch_size=conf.test_batch_size, dataset='test')
         res = []
         pred = []
-        for test_idx in range(conf.num_tasks):
+        for test_idx in range(learned_task + 1):
             task_input = np.zeros([y.shape[0], conf.num_tasks])
             task_input[:, test_idx] = 1
             prediction = model.predict([x, task_input])
@@ -66,4 +66,7 @@ def get_test_acc(model, data_loader, test_on_whole_set=True):
         acc = np.sum(pred[np.argmax(res, axis=0), np.arange(pred.shape[1])] == np.argmax(y, axis=1)) / y.shape[0]
         test_acc.append(acc)
 
+    if verbose:
+        print('Test accuracy on all the tasks : ', test_acc)
+
     return np.mean(test_acc)
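
For reference, the selection step visible in the last hunk, pred[np.argmax(res, axis=0), np.arange(pred.shape[1])], picks for every test sample the class prediction that was made under the most likely of the learned tasks. The toy example below uses made-up numbers purely to illustrate that indexing; it is not code from predict_utils.py.

import numpy as np

# res[i, j]: likelihood that sample j belongs to learned task i (illustrative values)
# pred[i, j]: class predicted for sample j when task i is assumed as the task input
res = np.array([[0.9, 0.2, 0.1],
                [0.1, 0.7, 0.3]])
pred = np.array([[3, 1, 4],
                 [2, 1, 0]])
true_labels = np.array([3, 1, 0])

best_task = np.argmax(res, axis=0)                  # most likely task per sample -> [0, 1, 1]
chosen = pred[best_task, np.arange(pred.shape[1])]  # prediction under that task  -> [3, 1, 0]
acc = np.mean(chosen == true_labels)                # fraction correct, 1.0 here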