diff --git a/astro_lab.py b/astro_lab.py
index 3d2efcde6943dccda79d5371f2458eb98de58207..bd56b258e123fd35f26366ccc5d4cf6001ffbc95 100644
--- a/astro_lab.py
+++ b/astro_lab.py
@@ -82,11 +82,14 @@ def plot_data(zscale=True):
     plt.imshow(data, cmap='Greys_r', vmin=lims[0], vmax=lims[1], origin='lower')
     plt.show()
 
-def subtract_background():
+def subtract_background(plot=False):
     global data, error, bg_sub
     sigma_clip = SigmaClip(sigma=sclip)
     bkg_estimator = MedianBackground()
     bkg = Background2D(data, (bg_wsize, bg_wsize), sigma_clip=sigma_clip, bkg_estimator=bkg_estimator)
+    if plot:
+        plt.imshow(bkg.background, origin='lower', cmap='Greys_r')
+        plt.show()
     error = calc_total_error(data, bkg.background_rms, eff_gain)
     data = data - bkg.background
     bg_sub = True