From 74bbc0bd6809ea1d57988342172660a04bff1cd6 Mon Sep 17 00:00:00 2001
From: "Li, Honglin (PG/R - Elec Electronic Eng)" <h.li@surrey.ac.uk>
Date: Fri, 15 May 2020 12:39:31 +0100
Subject: [PATCH] Update csv_to_npy.py

---
 csv_to_npy.py | 35 +++++++++++++++++++++++------------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/csv_to_npy.py b/csv_to_npy.py
index d8d521b..5d7136c 100644
--- a/csv_to_npy.py
+++ b/csv_to_npy.py
@@ -21,6 +21,7 @@ def get_args(argv):
     parser.add_argument('--save_per_patient', type=bool, default=False, help='save the data per patient')
     parser.add_argument('--extract_incident', type=bool, default=False, help='extract incident only')
     parser.add_argument('--save_dir', type=str, default=None, help='folder to save the data')
+    parser.add_argument('--label_previous_day', type=bool, default=False, help='label previous day as UTI infection or not')
     args = parser.parse_args(argv)
     return args
 
@@ -38,6 +39,7 @@ class Data_loader(object):
         self.save_per_patient = args.save_per_patient
         self.extract_incident = args.extract_incident
         self.save_dir = args.save_dir
+        self.label_previous_day = args.label_previous_day
         if self.patient_id is not None and self.test_date is None:
             raise ValueError('test date must be provided')
         self.env_feat_list = {
@@ -129,9 +131,6 @@ class Data_loader(object):
             bt_data = self.load_body_temp(f, date_his)
             if self.verbose:
                 data = [data, date_his, int(f.split('_')[0])]
-            elif self.extract_incident:
-                label, incident_info = self.load_label(f, date_his)
-                data = [data, incident_info]
 
             if self.patient_id is not None:
                 test_id = int(f.split('_')[0])
@@ -142,7 +141,14 @@ class Data_loader(object):
                     self.data[test_id].append((data[day], bt_data[day]))
             elif self.save_per_patient:
                 test_id = int(f.split('_')[0])
-                self.data[test_id] = [data, bt_data]
+                if self.extract_incident:
+                    label, incident_info = self.load_label(f, date_his)
+                    data = data[label < 2]
+                    if np.sum(label < 2) > 0:
+                        incident_info = incident_info[label < 2]
+                        self.data[test_id] = [data, bt_data, incident_info]
+                else:
+                    self.data[test_id] = [data, bt_data]
             else:
                 result.append(data)
                 label.append(self.load_label(f, date_his))
@@ -159,11 +165,13 @@ class Data_loader(object):
                 pass
             self.split_label_unlabel()
 
-    def save_data(self, sub_folder=None):
+    def save_data(self):
         for key, value in self.data.items():
             if key not in ['env_data', 'bodytemp', '_label']:
                 # np.save(self.conf.npy_data + '/' + str(key) + '.npy', value)
-                path = self.conf.npy_data if self.save_dir is None else self.save_dir
+                path = self.conf.npy_data
+                if self.save_dir is not None:
+                    path = path + '/' + self.save_dir
                 save_mkdir(path)
                 save_obj(value, path + '/' + str(key))
 
@@ -210,7 +218,7 @@ class Env_loader(Data_loader):
         sub_key = 'datetimeObserved'
         label_df = pd.read_csv(self.conf.data_path['flag'] + filename)
         label = np.zeros(len(date_his)) + 2
-        incident_info = []
+        incident_info = [[None, None, None]] * len(label)
         indices = label_df['element'].isin(self.incident)
         if len(indices) > 0:
             sub_df = label_df[indices]
@@ -228,24 +236,27 @@ class Env_loader(Data_loader):
                 try:
                     if valid[d] == 'False' or valid[d] is False:
                         label[idx] = 0
-                        incident_info.append([dates[d], sub_df['element'][d], False, int(file.split('_')[0])])
-                        if self.incident == ['UTI symptoms']:
+                        incident_info[idx] = [dates[d], sub_df['element'][d], False, int(file.split('_')[0])]
+                        if self.incident == ['UTI symptoms'] and self.label_previous_day:
                             for new_day in self.find_previous_day(dates[d], 2):
                                 new_idx = date_his.index(new_day)
                                 label[new_idx] = 0
+                                incident_info[new_idx] = [dates[d], sub_df['element'][d], False, int(file.split('_')[0])]
                     elif valid[d] == 'True' or valid[d] is True:
                         label[idx] = 1
-                        incident_info.append([dates[d], sub_df['element'][d], False, int(file.split('_')[0])])
-                        if self.incident == ['UTI symptoms']:
+                        incident_info[idx] = [dates[d], sub_df['element'][d], True, int(file.split('_')[0])]
+                        if self.incident == ['UTI symptoms'] and self.label_previous_day:
                             for new_day in self.find_previous_day(dates[d], 1):
                                 new_idx = date_his.index(new_day)
                                 label[new_idx] = 1
+                                incident_info[new_idx] = [dates[d], sub_df['element'][d], True, int(file.split('_')[0])]
                     else:
                         label[idx] = 2
+                        incident_info[idx] = None
                 except KeyError:
                     pass
         if self.extract_incident:
-            return label, incident_info
+            return label, np.array(incident_info)
         return label
 
     def split_label_unlabel(self):
-- 
GitLab