From fb6adb84f02ab29e48223509245f08863a38dbbd Mon Sep 17 00:00:00 2001 From: "Li, Honglin (PG/R - Elec Electronic Eng)" <h.li@surrey.ac.uk> Date: Thu, 14 May 2020 19:43:42 +0100 Subject: [PATCH] Update csv_to_npy.py, utils.py files --- csv_to_npy.py | 34 +++++++++++++++++++++++++--------- utils.py | 19 +++++++++++++++++++ 2 files changed, 44 insertions(+), 9 deletions(-) create mode 100644 utils.py diff --git a/csv_to_npy.py b/csv_to_npy.py index 45a5c07..17ef442 100644 --- a/csv_to_npy.py +++ b/csv_to_npy.py @@ -1,5 +1,3 @@ -import os -import csv import numpy as np import pandas as pd from abc import abstractmethod @@ -7,6 +5,8 @@ import datetime import sys from configuration import Conf import argparse +import os +from utils import save_mkdir, save_obj, load_obj def get_args(argv): @@ -19,6 +19,7 @@ def get_args(argv): parser.add_argument('--test_date', type=str, default=None, nargs="+", help='infection date') parser.add_argument('--verbose', type=bool, default=False, help='insert the patient id and date into the data') parser.add_argument('--save_per_patient', type=bool, default=False, help='save the data per patient') + parser.add_argument('--extract_incident', type=bool, default=False, help='extract incident only') args = parser.parse_args(argv) return args @@ -34,6 +35,7 @@ class Data_loader(object): self.data = {} self.verbose = args.verbose self.save_per_patient = args.save_per_patient + self.extract_incident = args.extract_incident if self.patient_id is not None and self.test_date is None: raise ValueError('test date must be provided') self.env_feat_list = { @@ -125,6 +127,9 @@ class Data_loader(object): bt_data = self.load_body_temp(f, date_his) if self.verbose: data = [data, date_his, int(f.split('_')[0])] + elif self.extract_incident: + label, incident_info = self.load_label(f, date_his) + data = [data, incident_info] if self.patient_id is not None: test_id = int(f.split('_')[0]) @@ -142,7 +147,7 @@ class Data_loader(object): bodytemp.append(bt_data) if self.save_per_patient or self.patient_id is not None: - self.save_data() + self.save_data('patient_data') elif self.patient_id is None: self.data['env_data'] = np.concatenate(result) self.data['_label'] = np.concatenate(label) @@ -152,10 +157,17 @@ class Data_loader(object): pass self.split_label_unlabel() - def save_data(self): + def save_data(self, sub_folder=None): for key, value in self.data.items(): if key not in ['env_data', 'bodytemp', '_label']: - np.save(self.conf.npy_data + '/' + str(key) + '.npy', value) + # np.save(self.conf.npy_data + '/' + str(key) + '.npy', value) + path = self.conf.npy_data + if sub_folder is not None: + path = path + '/' + sub_folder + save_mkdir(path) + else: + pass + save_obj(value, path + '/' + str(key)) def _iter_directory(self, directory): file_list = [] @@ -180,13 +192,13 @@ class Data_loader(object): def split_label_unlabel(self): pass + class Env_loader(Data_loader): """docstring for DRI_dataloader """ def __init__(self, args): super(Env_loader, self).__init__(args) - def load_label(self, file, date_his): """ 0 - False @@ -195,12 +207,12 @@ class Env_loader(Data_loader): 3 - Test samples """ filename = list(file) - filename = ''.join(filename) filename = filename.split('_')[0] + '_flags.csv' sub_key = 'datetimeObserved' label_df = pd.read_csv(self.conf.data_path['flag'] + filename) label = np.zeros(len(date_his)) + 2 + incident_info = [] indices = label_df['element'].isin(self.incident) if len(indices) > 0: sub_df = label_df[indices] @@ -218,12 +230,14 @@ class Env_loader(Data_loader): try: if valid[d] == 'False' or valid[d] is False: label[idx] = 0 + incident_info.append([dates[d], sub_df['element'][d], False, int(file.split('_')[0])]) if self.incident == ['UTI symptoms']: for new_day in self.find_previous_day(dates[d], 2): new_idx = date_his.index(new_day) label[new_idx] = 0 elif valid[d] == 'True' or valid[d] is True: label[idx] = 1 + incident_info.append([dates[d], sub_df['element'][d], False, int(file.split('_')[0])]) if self.incident == ['UTI symptoms']: for new_day in self.find_previous_day(dates[d], 1): new_idx = date_his.index(new_day) @@ -232,6 +246,8 @@ class Env_loader(Data_loader): label[idx] = 2 except KeyError: pass + if self.extract_incident: + return label, incident_info return label def split_label_unlabel(self): @@ -251,9 +267,9 @@ class Env_loader(Data_loader): self.data['label'] = self.data['_label'][indices] - if __name__ == '__main__': args = get_args(sys.argv[1:]) dataloader = Env_loader(args) dataloader.load_env() - dataloader.save_data() + if args.save_per_patient is False and args.patient_id is None: + dataloader.save_data() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..bfd518d --- /dev/null +++ b/utils.py @@ -0,0 +1,19 @@ +import os +import pickle + + +def save_obj(obj, name): + with open(name + '.pkl', 'wb') as f: + pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) + +def load_obj(name): + with open(name, 'rb') as f: + return pickle.load(f) + + +def save_mkdir(path): + try: + os.stat(path) + except: + os.mkdir(path) + -- GitLab