From fb6adb84f02ab29e48223509245f08863a38dbbd Mon Sep 17 00:00:00 2001
From: "Li, Honglin (PG/R - Elec Electronic Eng)" <h.li@surrey.ac.uk>
Date: Thu, 14 May 2020 19:43:42 +0100
Subject: [PATCH] Update csv_to_npy.py, utils.py files

---
 csv_to_npy.py | 34 +++++++++++++++++++++++++---------
 utils.py      | 19 +++++++++++++++++++
 2 files changed, 44 insertions(+), 9 deletions(-)
 create mode 100644 utils.py

diff --git a/csv_to_npy.py b/csv_to_npy.py
index 45a5c07..17ef442 100644
--- a/csv_to_npy.py
+++ b/csv_to_npy.py
@@ -1,5 +1,3 @@
-import os
-import csv
 import numpy as np
 import pandas as pd
 from abc import abstractmethod
@@ -7,6 +5,8 @@ import datetime
 import sys
 from configuration import Conf
 import argparse
+import os
+from utils import save_mkdir, save_obj, load_obj
 
 
 def get_args(argv):
@@ -19,6 +19,7 @@ def get_args(argv):
     parser.add_argument('--test_date', type=str, default=None, nargs="+", help='infection date')
     parser.add_argument('--verbose', type=bool, default=False, help='insert the patient id and date into the data')
     parser.add_argument('--save_per_patient', type=bool, default=False, help='save the data per patient')
+    parser.add_argument('--extract_incident', type=bool, default=False, help='extract incident only')
     args = parser.parse_args(argv)
     return args
 
@@ -34,6 +35,7 @@ class Data_loader(object):
         self.data = {}
         self.verbose = args.verbose
         self.save_per_patient = args.save_per_patient
+        self.extract_incident = args.extract_incident
         if self.patient_id is not None and self.test_date is None:
             raise ValueError('test date must be provided')
         self.env_feat_list = {
@@ -125,6 +127,9 @@ class Data_loader(object):
             bt_data = self.load_body_temp(f, date_his)
             if self.verbose:
                 data = [data, date_his, int(f.split('_')[0])]
+            elif self.extract_incident:
+                label, incident_info = self.load_label(f, date_his)
+                data = [data, incident_info]
 
             if self.patient_id is not None:
                 test_id = int(f.split('_')[0])
@@ -142,7 +147,7 @@ class Data_loader(object):
                 bodytemp.append(bt_data)
 
         if self.save_per_patient or self.patient_id is not None:
-            self.save_data()
+            self.save_data('patient_data')
         elif self.patient_id is None:
             self.data['env_data'] = np.concatenate(result)
             self.data['_label'] = np.concatenate(label)
@@ -152,10 +157,17 @@ class Data_loader(object):
                 pass
             self.split_label_unlabel()
 
-    def save_data(self):
+    def save_data(self, sub_folder=None):
         for key, value in self.data.items():
             if key not in ['env_data', 'bodytemp', '_label']:
-                np.save(self.conf.npy_data + '/' + str(key) + '.npy', value)
+                # np.save(self.conf.npy_data + '/' + str(key) + '.npy', value)
+                path = self.conf.npy_data
+                if sub_folder is not None:
+                    path = path + '/' + sub_folder
+                    save_mkdir(path)
+                else:
+                    pass
+                save_obj(value, path + '/' + str(key))
 
     def _iter_directory(self, directory):
         file_list = []
@@ -180,13 +192,13 @@ class Data_loader(object):
     def split_label_unlabel(self):
         pass
 
+
 class Env_loader(Data_loader):
     """docstring for DRI_dataloader	"""
 
     def __init__(self, args):
         super(Env_loader, self).__init__(args)
 
-
     def load_label(self, file, date_his):
         """
         0 - False
@@ -195,12 +207,12 @@ class Env_loader(Data_loader):
         3 - Test samples
         """
         filename = list(file)
-
         filename = ''.join(filename)
         filename = filename.split('_')[0] + '_flags.csv'
         sub_key = 'datetimeObserved'
         label_df = pd.read_csv(self.conf.data_path['flag'] + filename)
         label = np.zeros(len(date_his)) + 2
+        incident_info = []
         indices = label_df['element'].isin(self.incident)
         if len(indices) > 0:
             sub_df = label_df[indices]
@@ -218,12 +230,14 @@ class Env_loader(Data_loader):
                 try:
                     if valid[d] == 'False' or valid[d] is False:
                         label[idx] = 0
+                        incident_info.append([dates[d], sub_df['element'][d], False, int(file.split('_')[0])])
                         if self.incident == ['UTI symptoms']:
                             for new_day in self.find_previous_day(dates[d], 2):
                                 new_idx = date_his.index(new_day)
                                 label[new_idx] = 0
                     elif valid[d] == 'True' or valid[d] is True:
                         label[idx] = 1
+                        incident_info.append([dates[d], sub_df['element'][d], False, int(file.split('_')[0])])
                         if self.incident == ['UTI symptoms']:
                             for new_day in self.find_previous_day(dates[d], 1):
                                 new_idx = date_his.index(new_day)
@@ -232,6 +246,8 @@ class Env_loader(Data_loader):
                         label[idx] = 2
                 except KeyError:
                     pass
+        if self.extract_incident:
+            return label, incident_info
         return label
 
     def split_label_unlabel(self):
@@ -251,9 +267,9 @@ class Env_loader(Data_loader):
         self.data['label'] = self.data['_label'][indices]
 
 
-
 if __name__ == '__main__':
     args = get_args(sys.argv[1:])
     dataloader = Env_loader(args)
     dataloader.load_env()
-    dataloader.save_data()
+    if args.save_per_patient is False and args.patient_id is None:
+        dataloader.save_data()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..bfd518d
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,19 @@
+import os
+import pickle
+
+
+def save_obj(obj, name):
+    with open(name + '.pkl', 'wb') as f:
+        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
+
+def load_obj(name):
+    with open(name, 'rb') as f:
+        return pickle.load(f)
+
+
+def save_mkdir(path):
+    try:
+        os.stat(path)
+    except:
+        os.mkdir(path)
+
-- 
GitLab