diff --git a/README.md b/README.md
index bddd6d221bfe8a9571c46b49b59ce935c89359d0..83ad8a62185f13252135de78006fa801a0c98d0c 100644
--- a/README.md
+++ b/README.md
@@ -25,9 +25,13 @@ roscore
 roslaunch open_manipulator_gazebo open_manipulator_gazebo.launch
 roslaunch open_manipulator_controller open_manipulator_controller.launch use_platform:=false
 ```
-It is important to run the two roslaunch commands as the former runs Gazebo simulation with the robotic arm, whereas the latter allows for control of the arm.
+It is important to run both roslaunch commands: the former starts the Gazebo simulation with the robotic arm, while the latter enables control of the arm.
 **IMPORTANT**: Before running the following scripts, it is important to press play at the bottom of Gazebo, otherwise the arm will not move.
 
+## Roslaunch
+There are currently two `roslaunch` files. The first, `keypoint_collection.launch`, collects keypoints from pose estimation and saves them to a pickle file; this data is later used to train the model. The second, `training_data.launch`, takes the output of the first launch file and converts the keypoints into a training dataset, with labels obtained by computing the inverse kinematics of the end-effector. It can also be used to visualize and test the data in RVIZ and Gazebo.
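+
+For example, assuming the catkin workspace has been built and sourced, the two launch files can be run with their default arguments as follows:
+```
+# Collect keypoints from the camera and save them to data/collected_data.pickle
+roslaunch pose_estimation keypoint_collection.launch output_file:=collected_data.pickle
+
+# Convert the collected keypoints into a training dataset (labels via inverse kinematics)
+roslaunch pose_estimation training_data.launch input_keypoint_file:=collected_data.pickle training_file:=keypoint_dataset.pickle
+```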
+
+## Files
 At the moment, there are currently three files that can be run - in the future the functionality will be combined. The first one (1) runs the video capture of the keypoints in real time and moving the robotic arm in Gazebo [*It is currently being debugged due to new functionality being added*]. The second one (2) is a keypoint capture program which saves the keypoints into a pickle file. Finally, (3) loads the saved keypoints (hence 2 and 3 have to be run in order for 3 to work) and visualizes them in RVIZ and calculates the manipulator keypoints. Currently, 1 and 3 use different approaches. The accuracy is currently being tested.
 
 1. Arm Keypoint Capture + Pose Estimation Control
diff --git a/pose_estimation/data/collected_data.pickle b/pose_estimation/data/collected_data.pickle
new file mode 100644
index 0000000000000000000000000000000000000000..d0224b1b532ce84a296691231b5a1a718e3c725a
Binary files /dev/null and b/pose_estimation/data/collected_data.pickle differ
diff --git a/pose_estimation/data/gesture_recognizer.task b/pose_estimation/data/gesture_recognizer.task
new file mode 100644
index 0000000000000000000000000000000000000000..1c6adc8b497d3e6d6603dfc7c864136d75f66e88
Binary files /dev/null and b/pose_estimation/data/gesture_recognizer.task differ
diff --git a/pose_estimation/data/keypoint_dataset.pickle b/pose_estimation/data/keypoint_dataset.pickle
new file mode 100644
index 0000000000000000000000000000000000000000..7f5559578e526810010f2f533a970ed0b411b88c
Binary files /dev/null and b/pose_estimation/data/keypoint_dataset.pickle differ
diff --git a/pose_estimation/data/models/joint_predict.pt b/pose_estimation/data/models/joint_predict.pt
new file mode 100644
index 0000000000000000000000000000000000000000..13efa4648168af601e5153659e45ce22cafc5e7f
Binary files /dev/null and b/pose_estimation/data/models/joint_predict.pt differ
diff --git a/pose_estimation/include/pose_estimation/kinematics_solver.h b/pose_estimation/include/pose_estimation/kinematics_solver.h
new file mode 100644
index 0000000000000000000000000000000000000000..c62bdc490462ebd12ae4e9e2f9372a614cd6a27d
--- /dev/null
+++ b/pose_estimation/include/pose_estimation/kinematics_solver.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include "open_manipulator_libs/kinematics.h"
+#include "open_manipulator_libs/open_manipulator.h"
+#include "open_manipulator_msgs/GetKinematicsPose.h"
+#include "open_manipulator_msgs/SetKinematicsPose.h"
+#include "pose_estimation/JointPositions.h"
+#include "pose_estimation/Joints.h"
+#include <robotis_manipulator/robotis_manipulator.h>
+#include "ros/ros.h"
+   
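+// Wraps the OpenManipulator kinematics library: keypointsInverseCallback solves inverse
+// kinematics for a requested end-effector pose and keypointsForwardCallback solves forward
+// kinematics for a set of joint angles; the results are published on ik_pub_ / fk_pub_.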
+class KinematicsSolver
+{
+public:
+  KinematicsSolver(bool using_platform, std::string usb_port, std::string baud_rate, double control_period);
+   
+  void keypointsInverseCallback(const open_manipulator_msgs::KinematicsPose &msg);
+  void keypointsForwardCallback(const pose_estimation::Joints &msg);
+   
+  ros::NodeHandle getNodeHandle() const { return n_; } 
+
+private:
+  OpenManipulator open_manipulator_;
+  std::vector<JointValue>* goal_joint_value_;
+   
+  ros::NodeHandle n_;
+  ros::Publisher ik_pub_;
+  ros::Publisher fk_pub_;
+  robotis_manipulator::Kinematics *kinematics_;
+
+  void solveIK(Pose target_pose, const open_manipulator_msgs::KinematicsPose& manipulator_pose);
+  void solveFK(const pose_estimation::Joints &msg);
+};
\ No newline at end of file
diff --git a/pose_estimation/launch/keypoint_collection.launch b/pose_estimation/launch/keypoint_collection.launch
new file mode 100644
index 0000000000000000000000000000000000000000..4bd9a1c312f3086e82b03834879196a9cdb485ba
--- /dev/null
+++ b/pose_estimation/launch/keypoint_collection.launch
@@ -0,0 +1,36 @@
+<launch>
+    <arg name="gesture_model" doc="Name of the gesture recognition model. Located in `data/`" default="gesture_recognizer.task" />
+    <arg name="output_file" doc="Name of the file where the data will be saved" default="collected_data.pickle" />
+    <arg name="video_save" doc="Save video in the output file? ('true' or 'false')" default="true" />
+    
+    <arg name="manipulator_control" doc="Control or Collection?" default="false" />
+    <arg name="ik_model" doc="Name trained model. Located in `data/models`" default="joint_predict.pt" />
+
+    <group if="$(eval arg('manipulator_control') == true)">
+        <!-- Running Gazebo Simulation with Open Manipulator-X if control is on -->
+        <include file="$(find open_manipulator_gazebo)/launch/open_manipulator_gazebo.launch">
+            <arg name="paused" value="false"/>
+        </include>
+
+        <arg name="usb_port"               default="/dev/ttyUSB0"/>
+        <arg name="baud_rate"              default="1000000"/>
+        <arg name="control_period"         default="0.010"/>
+        <arg name="use_platform"           default="false"/>
+        
+        <!-- Running the Manipulator Controller if control is on -->
+        <node name="open_manipulator_controller" pkg="open_manipulator_controller" type="open_manipulator_controller" 
+            output="log" args="$(arg usb_port) $(arg baud_rate)">
+            <param name="control_period"       value="$(arg control_period)"/>
+            <param name="using_platform"       value="$(arg use_platform)"/>
+        </node>
+    </group>
+
+    <node name="keypoint_collection" pkg="pose_estimation" type="data_collection.py" output="screen">
+        <param name="gesture_model"         value="$(arg gesture_model)"/>
+        <param name="filename"              value="$(arg output_file)"/>
+        <param name="video_save"            value="$(arg video_save)"/>
+        <param name="ik_model"              value="$(arg ik_model)"/>
+        <param name="manipulator_control"   value="$(arg manipulator_control)"/>
+    </node>
+    
+</launch>
diff --git a/pose_estimation/launch/train_evaluate.launch b/pose_estimation/launch/train_evaluate.launch
new file mode 100644
index 0000000000000000000000000000000000000000..00ebce37d729809bc53b138f6a05e76e87eb3463
--- /dev/null
+++ b/pose_estimation/launch/train_evaluate.launch
@@ -0,0 +1,35 @@
+<launch>
+    <arg name = "evaluate" doc = "Run evaluation {boolean}" default = "true" />
+    <arg name = "plot" doc = "Plot training loss {boolean}" default = "false" />
+    
+    <arg name = "model" doc = "The name of the model to be saved in data/models/ {string}" default = "joint_predict.pt" />
+    <arg name = "dataset" doc = "The name of the dataset saved in data/ {string}" default = "keypoint_dataset.pickle" />
+    <arg name = "loss" doc = "[huber (default), mse] {string}" default = "huber" />
+    <arg name = "scheduler" doc = "[None (default)] {string}" default = "None" />
+    <arg name = "initial_lr" doc = "Initial learning rate for training the model {float}" default = "1e-1" />
+    <arg name = "epochs" doc = "Number of epochs the model will be trained for {integer}" default = "100" />
+    <arg name = "validate" doc = "Run validation with positional loss {boolean}" default = "false" />
+    
+    <arg name = "NAS" doc = "Run Neural Architecture Search {boolean}" default = "false" />
+    <arg name = "NAS_iter" doc = "Number of variations to run NAS [10] {int}" default = "10"/>
+
+    <!-- Running Kinematics Solver Node -->
+    <node name="kinematics_solver" pkg="pose_estimation" type="kinematics_solver" output="screen" />
+
+    <!-- Running the training loop -->
+    <node name="training" pkg="pose_estimation" type="training.py" output="screen">
+        <param name = "model"         value = "$(arg model)"/>
+        <param name = "dataset"       value = "$(arg dataset)"/>
+        <param name = "loss"          value = "$(arg loss)"/>
+        <param name = "scheduler"     value = "$(arg scheduler)"/>
+        <param name = "plot"          value = "$(arg plot)"/>
+        <param name = "initial_lr"    value = "$(arg initial_lr)"/>
+        <param name = "plot"          value = "$(arg plot)"/>
+        <param name = "evaluate"      value = "$(arg evaluate)"/>
+        <param name = "epochs"        value = "$(arg epochs)"/>
+        <param name = "validate"      value = "$(arg validate)"/>
+        <param name = "NAS"           value = "$(arg NAS)"/>
+        <param name = "NAS_iter"      value = "$(arg NAS_iter)"/>
+    </node>
+
+</launch>
diff --git a/pose_estimation/launch/training_data.launch b/pose_estimation/launch/training_data.launch
new file mode 100644
index 0000000000000000000000000000000000000000..28dac4b2ded3a9790890220d5636e2436eb25298
--- /dev/null
+++ b/pose_estimation/launch/training_data.launch
@@ -0,0 +1,37 @@
+<launch>
+    <arg name="input_keypoint_file" doc="Name of the input file with keypoints (from data_collection)" default="collected_data.pickle" />
+    <arg name="training_file" doc="Name of the file where the training data will be saved" default="keypoint_dataset.pickle" />
+    <arg name="simulate" doc="Simulate the motion in Gazebo?" default="false"/>
+    <arg name="show_video" doc="Show the input video" default="false" />
+
+    <!-- Running Gazebo Simulation with Open Manipulator-X -->
+    <!-- <include file="$(find open_manipulator_gazebo)/launch/open_manipulator_gazebo.launch" if="$(eval arg('simulate') == 'true')">
+        <arg name="paused" value="false"/>
+    </include> -->
+
+    <!-- <arg name="usb_port"               default="/dev/ttyUSB0"/>
+    <arg name="baud_rate"              default="1000000"/>
+    <arg name="control_period"         default="0.010"/>
+    <arg name="use_platform"           default="false"/>
+    
+    <node name="open_manipulator_controller" pkg="open_manipulator_controller" type="open_manipulator_controller" 
+        output="log" args="$(arg usb_port) $(arg baud_rate)" if="$(eval arg('simulate') == 'true')">
+        <param name="control_period"       value="$(arg control_period)"/>
+        <param name="using_platform"       value="$(arg use_platform)"/>
+    </node> -->
+
+
+    <!-- Running rviz with the TF graph setup -->
+    <node type="rviz" name="rviz" pkg="rviz" args="-d $(find pose_estimation)/rviz/keypoint_tf_frame.rviz" />
+
+    <!-- Running Inverse Kinematics Solver Node -->
+    <node name="kinematics_solver" pkg="pose_estimation" type="kinematics_solver" output="screen" />
+
+    <node name="training_data" pkg="pose_estimation" type="training_data.py" output="screen">
+        <param name="input_keypoint_file" value="$(arg input_keypoint_file)"/>
+        <param name="training_file"       value="$(arg training_file)"/>
+        <param name="simulate"            value="$(arg simulate)"/>
+        <param name="show_video"          value="$(arg show_video)"/>
+    </node>
+
+</launch>
diff --git a/pose_estimation/msg/JointAngle.msg b/pose_estimation/msg/JointAngle.msg
new file mode 100644
index 0000000000000000000000000000000000000000..27a75a0b34f1a8526468db736b25156a1f9cdb6e
--- /dev/null
+++ b/pose_estimation/msg/JointAngle.msg
@@ -0,0 +1,2 @@
+string name
+float64 angle
\ No newline at end of file
diff --git a/pose_estimation/msg/JointPositions.msg b/pose_estimation/msg/JointPositions.msg
new file mode 100644
index 0000000000000000000000000000000000000000..b76eaa7071c1ba2a35cc05625c4dd92b9a535e76
--- /dev/null
+++ b/pose_estimation/msg/JointPositions.msg
@@ -0,0 +1,4 @@
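+# Joint angles for the manipulator, the corresponding end-effector pose, and a success flag.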
+JointAngle[] jointPositions
+open_manipulator_msgs/KinematicsPose manipulatorPose
+
+bool success
\ No newline at end of file
diff --git a/pose_estimation/msg/Joints.msg b/pose_estimation/msg/Joints.msg
new file mode 100644
index 0000000000000000000000000000000000000000..17883e659ea1f87f1de8bf5891e80097d6b50a38
--- /dev/null
+++ b/pose_estimation/msg/Joints.msg
@@ -0,0 +1 @@
+float64[] angles
\ No newline at end of file
diff --git a/pose_estimation/rviz/keypoint_tf_frame.rviz b/pose_estimation/rviz/keypoint_tf_frame.rviz
new file mode 100644
index 0000000000000000000000000000000000000000..3e2b384e83c24909dcd53d9be5ae160a1c14c8a9
--- /dev/null
+++ b/pose_estimation/rviz/keypoint_tf_frame.rviz
@@ -0,0 +1,135 @@
+Panels:
+  - Class: rviz/Displays
+    Help Height: 78
+    Name: Displays
+    Property Tree Widget:
+      Expanded:
+        - /Global Options1
+        - /Status1
+        - /TF1
+      Splitter Ratio: 0.5
+    Tree Height: 549
+  - Class: rviz/Selection
+    Name: Selection
+  - Class: rviz/Tool Properties
+    Expanded:
+      - /2D Pose Estimate1
+      - /2D Nav Goal1
+      - /Publish Point1
+    Name: Tool Properties
+    Splitter Ratio: 0.5886790156364441
+  - Class: rviz/Views
+    Expanded:
+      - /Current View1
+    Name: Views
+    Splitter Ratio: 0.5
+  - Class: rviz/Time
+    Name: Time
+    SyncMode: 0
+    SyncSource: ""
+Preferences:
+  PromptSaveOnExit: true
+Toolbars:
+  toolButtonStyle: 2
+Visualization Manager:
+  Class: ""
+  Displays:
+    - Alpha: 0.5
+      Cell Size: 1
+      Class: rviz/Grid
+      Color: 160; 160; 164
+      Enabled: true
+      Line Style:
+        Line Width: 0.029999999329447746
+        Value: Lines
+      Name: Grid
+      Normal Cell Count: 0
+      Offset:
+        X: 0
+        Y: 0
+        Z: 0
+      Plane: XY
+      Plane Cell Count: 10
+      Reference Frame: <Fixed Frame>
+      Value: true
+    - Class: rviz/TF
+      Enabled: true
+      Frame Timeout: 15
+      Frames:
+        All Enabled: true
+      Marker Alpha: 1
+      Marker Scale: 1
+      Name: TF
+      Show Arrows: true
+      Show Axes: true
+      Show Names: true
+      Tree:
+        {}
+      Update Interval: 0
+      Value: true
+  Enabled: true
+  Global Options:
+    Background Color: 48; 48; 48
+    Default Light: true
+    Fixed Frame: keypoint_frame
+    Frame Rate: 30
+  Name: root
+  Tools:
+    - Class: rviz/Interact
+      Hide Inactive Objects: true
+    - Class: rviz/MoveCamera
+    - Class: rviz/Select
+    - Class: rviz/FocusCamera
+    - Class: rviz/Measure
+    - Class: rviz/SetInitialPose
+      Theta std deviation: 0.2617993950843811
+      Topic: /initialpose
+      X std deviation: 0.5
+      Y std deviation: 0.5
+    - Class: rviz/SetGoal
+      Topic: /move_base_simple/goal
+    - Class: rviz/PublishPoint
+      Single click: true
+      Topic: /clicked_point
+  Value: true
+  Views:
+    Current:
+      Class: rviz/Orbit
+      Distance: 10
+      Enable Stereo Rendering:
+        Stereo Eye Separation: 0.05999999865889549
+        Stereo Focal Distance: 1
+        Swap Stereo Eyes: false
+        Value: false
+      Field of View: 0.7853981852531433
+      Focal Point:
+        X: 0
+        Y: 0
+        Z: 0
+      Focal Shape Fixed Size: true
+      Focal Shape Size: 0.05000000074505806
+      Invert Z Axis: false
+      Name: Current View
+      Near Clip Distance: 0.009999999776482582
+      Pitch: 0.785398006439209
+      Target Frame: <Fixed Frame>
+      Yaw: 0.785398006439209
+    Saved: ~
+Window Geometry:
+  Displays:
+    collapsed: false
+  Height: 846
+  Hide Left Dock: false
+  Hide Right Dock: false
+  QMainWindow State: 000000ff00000000fd000000040000000000000156000002b0fc0200000008fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d000002b0000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261000000010000010f000002b0fc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073010000003d000002b0000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000004b00000003efc0100000002fb0000000800540069006d00650100000000000004b0000003bc00fffffffb0000000800540069006d006501000000000000045000000000000000000000023f000002b000000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
+  Selection:
+    collapsed: false
+  Time:
+    collapsed: false
+  Tool Properties:
+    collapsed: false
+  Views:
+    collapsed: false
+  Width: 1200
+  X: 677
+  Y: 105
diff --git a/pose_estimation/scripts/arm_keypoint_capture.py b/pose_estimation/scripts/arm_keypoint_capture.py
deleted file mode 100755
index 76d00c69c229bcbc98078c1ccf75a5e4f9803080..0000000000000000000000000000000000000000
--- a/pose_estimation/scripts/arm_keypoint_capture.py
+++ /dev/null
@@ -1,263 +0,0 @@
-
-import mediapipe as mp
-import numpy as np
-import rospy
-import math
-import cv2
-import vg
-import os
-
-from mediapipe.framework.formats import landmark_pb2
-
-### Importing messages
-from pose_estimation.msg import SphericalCoordinates
-
-### Replaying a pre-recorded mp4 video or using the webcam
-VIDEO_REPLAY = True
-test_video = "test_movement_2.mp4"
-VIDEO = f"{os.getcwd()}/../catkin_ws/src/pose_estimation/test/{test_video}"
-
-### Mediapipe drawing
-mp_drawing = mp.solutions.drawing_utils 
-mp_drawing_styles = mp.solutions.drawing_styles
-mp_ds = mp_drawing_styles.get_default_pose_landmarks_style()
-
-### Mediapipe Pose Estimation
-mp_pose = mp.solutions.pose
-
-# TODO: Change to identity
-### Unit axis vectors
-x_axis = np.array([1, 0, 0])
-y_axis = np.array([0, 1, 0])
-z_axis = np.array([0, 0, 1])
-
-# TODO: Define in the other file (prove mathematically)
-### Ranges 
-mp_distance_range = (0.0,0.53)
-robot_distance_range = (0.14, 0.35)
-
-### Threshold
-VISIBILITY_THRESHOLD = 0.1
-
-### Arm Landmarks
-wrist_keypoint = mp_pose.PoseLandmark.RIGHT_WRIST.value
-elbow_keypoint = mp_pose.PoseLandmark.RIGHT_ELBOW.value
-shoulder_keypoint = mp_pose.PoseLandmark.RIGHT_SHOULDER.value
-
-arm_landmarks = [
-    shoulder_keypoint, 
-    elbow_keypoint,
-    wrist_keypoint,
-]
-
-### Checks if the keypoint is relatively visible
-def verify_visibility(keypoint, visibility_threshold):
-    if keypoint.visibility <= visibility_threshold:
-        return False
-
-    return True
-
-### Adjusts values to a new range
-def adjust_range(old_min, old_max, new_min, new_max, old_value):
-    old_range = old_max - old_min
-    new_range = new_max - new_min
-
-    return (((old_value - old_min) * new_range) / old_range) + new_min
-
-
-def euclid_dist(vec1, vec2):
-  if vec1.size != vec2.size:
-    print("Size of the two points doesn't match")
-    return None
-  
-  sum_ = 0
-  for i in range(vec1.size):
-    squared_difference = (vec1[i] - vec2[i])**2
-    sum_ += squared_difference
-  
-  return math.sqrt(sum_)
-
-
-### Convert keypoint locations to a numpy array
-def get_keypoint_as_np(keypoint):
-    return np.array([keypoint.x, keypoint.y, keypoint.z])
-
-### Method used for publishing the spherical coordinates 
-def publish_keypoint(pub, results):
-
-    def zero_out(arr, n):
-        new_arr = np.copy(arr)
-        new_arr[n] = 0
-        return new_arr
-
-    shoulder = np.array([])
-    wrist = np.array([])
-    spherical_coordinates = None
-
-    for idx, keypoint in enumerate(results.pose_world_landmarks.landmark):
-        if idx == shoulder_keypoint:
-            if not verify_visibility(keypoint=keypoint, visibility_threshold = VISIBILITY_THRESHOLD):
-                rospy.logwarn("Shoulder keypoint not visible")
-                break
-
-            shoulder = get_keypoint_as_np(keypoint)
-
-        if idx == elbow_keypoint:
-            if not verify_visibility(keypoint=keypoint, visibility_threshold = VISIBILITY_THRESHOLD):
-                rospy.logwarn("Wrist keypoint not visible")
-                break
-
-            elbow = get_keypoint_as_np(keypoint)
-
-
-        if idx == wrist_keypoint and shoulder.size != 0:
-            if not verify_visibility(keypoint=keypoint, visibility_threshold = VISIBILITY_THRESHOLD):
-                rospy.logwarn("Wrist keypoint not visible")
-                break
-
-            wrist = get_keypoint_as_np(keypoint)
-
-            shifted_keypoint = wrist - shoulder
-            # Checking if the keypoint is to the left side of the shoulder, i.e. if x is positive (because the image is inverted, otherwise it should be positive)
-            if shifted_keypoint[0] > 0:
-                rospy.logerr("Wrist has to be to the right of the shoulder")    
-                break
-                        
-            # TODO: THETA
-            # Calculating the angle between the keypoint and the z-axis (depth)
-            theta_keypoints = zero_out(shifted_keypoint, 1)
-            theta = vg.angle(theta_keypoints, z_axis)
-
-        
-            # Calculating the angle between the keypoint and the y-axis (height)
-            phi_keypoints = zero_out(shifted_keypoint, 2)
-            phi = vg.angle(phi_keypoints, y_axis)
-
-            if (phi < 20 or phi > 150):
-                rospy.logerr("Out of range")
-                return
-
-            distance = euclid_dist(np.array([0,0,0]), shifted_keypoint)
-
-            # TODO: Check maximum available distance
-            # Scaling the range of the distance from mediapipe to manipulator
-            dist = adjust_range(mp_distance_range[0], mp_distance_range[1], 
-                                robot_distance_range[0], robot_distance_range[1], 
-                                distance)
-
-            # TODO:
-            wrist_vec = wrist - elbow
-            wrist_vec = zero_out(wrist_vec, 2)
-
-            wrist_angle = vg.angle(wrist_vec, -x_axis, units="rad")
-            if wrist[1] < elbow[1]:
-                wrist_angle = -wrist_angle
-        
-            spherical_coordinates = SphericalCoordinates()
-            spherical_coordinates.theta = theta
-            spherical_coordinates.phi = phi
-            spherical_coordinates.distance = dist
-            spherical_coordinates.wrist_angle = wrist_angle
-
-        elif idx == wrist_keypoint and shoulder.size == 0:
-            rospy.logwarn("Shoulder keypoint missing")
-
-    # If spherical coordinates are defined, they are published
-    if spherical_coordinates:
-        pub.publish(spherical_coordinates)
-
-
-def skeleton_estimation_pose(pub, img, pose, data_collection = True):
-    # Running mediapipe to extract the keypoints
-    results = pose.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-
-    # Getting the pose landmarks from the resulting pose
-    pose_landmarks = results.pose_landmarks
-    try:
-        # Getting the right arm keypoints
-        arm = [results.pose_landmarks.landmark[arm_keypoint] for arm_keypoint in arm_landmarks]
-        pose_landmarks = landmark_pb2.NormalizedLandmarkList(landmark = arm)
-    except Exception as e:
-        pass
-
-    # Drawing the landmarks on the frame of the video (i.e. keypoints)
-    mp_drawing.draw_landmarks(
-        img,
-        pose_landmarks,
-        landmark_drawing_spec=mp_ds
-    )
-
-    # Displaying the current frame of the video
-    cv2.imshow('Pose', cv2.flip(img,1))
-
-    # Publishing the keypoints only if data collection has been set to true and there are results available
-    if data_collection and results.pose_world_landmarks:
-        publish_keypoint(pub=pub, results=results)
-
-        
-def video_capture():
-    # Defining the current node
-    node = 'arm_keypoint_capture'
-    rospy.init_node(node)
-
-    # Creating a publisher of spherical coordinates
-    pub = rospy.Publisher(node, SphericalCoordinates, queue_size=50)
-
-    # Capturing the video with a webcam or with an mp4 video
-    if (VIDEO_REPLAY):
-        video = cv2.VideoCapture(VIDEO)
-    else:
-        video = cv2.VideoCapture(0)
-
-    # Initially the data is not published
-    data_collection = False
-
-    # Check if camera opened successfully
-    if (not video.isOpened()): 
-        print("Error opening video stream or file")
-        return
-
-    # If video is None the program will exit
-    if (not video):
-        print("Error - video is None")
-        return
-
-    # Mediapipe pose estimation
-    with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5, model_complexity=2, enable_segmentation=True) as pose:
-        
-        print("Press 'E' to capture the keypoints")
-
-        while not rospy.is_shutdown():
-            # Read the captured video 
-            _, capture = video.read()
-            
-            try:   
-                # Running pose estimation on the video and publishing the coordinates
-                skeleton_estimation_pose(pub, capture, pose, data_collection)
-            except Exception as e:
-                # If there was an exception, the program should exit
-                rospy.logwarn(f"Finished with: {e}")
-                break
-
-            # Initialize data collection with an 'E' key
-            key = cv2.waitKey(1) 
-            if key == ord('e'):
-                data_collection = not data_collection
-                if data_collection:
-                    rospy.loginfo("Publishing the data on 'arm_keypoint_capture'")
-                else:
-                    rospy.loginfo("Publishing of data stopped")
-
-            # Exit the program with a 'Q' key press
-            if key == ord('q'):
-                rospy.logwarn("Closing the program")
-                break
-        # Release the video from memory and clear all the openCV windows
-        video.release()
-        cv2.destroyAllWindows()
-
-if __name__ == "__main__":
-    try:
-        video_capture()
-    except rospy.ROSInterruptException as e:
-        rospy.logfatal(f"Following exception caused the program to stop: {e}")
diff --git a/pose_estimation/scripts/capture_keypoints.py b/pose_estimation/scripts/capture_keypoints.py
deleted file mode 100644
index dcc5babff16b5a431f3797b1098e8b061c3b79e9..0000000000000000000000000000000000000000
--- a/pose_estimation/scripts/capture_keypoints.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import mediapipe as mp
-import cv2
-import os
-import pickle
-
-from mediapipe.framework.formats import landmark_pb2
-
-### Replaying a pre-recorded mp4 video or using the webcam
-VIDEO_REPLAY = True
-test_video = "test_movement_2.mp4"
-VIDEO = f"{os.getcwd()}/../test/{test_video}"
-
-FILENAME = 'keypoints.pickle'
-
-### Mediapipe drawing
-mp_drawing = mp.solutions.drawing_utils 
-mp_drawing_styles = mp.solutions.drawing_styles
-mp_ds = mp_drawing_styles.get_default_pose_landmarks_style()
-
-### Mediapipe Pose Estimation
-mp_pose = mp.solutions.pose
-
-### Arm Landmarks
-arm_landmarks = [
-    mp_pose.PoseLandmark.RIGHT_SHOULDER.value, # 12
-    mp_pose.PoseLandmark.RIGHT_ELBOW.value, # 14
-    mp_pose.PoseLandmark.RIGHT_WRIST.value, # 16
-]
-
-def generate_keypoints(img, pose, keypoint_dict, dict_idx):
-    # Running mediapipe to extract the keypoints
-    results = pose.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-
-    # Getting the pose landmarks from the resulting pose
-    pose_landmarks = results.pose_landmarks
-    try:
-        # Getting the right arm keypoints
-        arm = [results.pose_landmarks.landmark[arm_keypoint] for arm_keypoint in arm_landmarks]
-        pose_landmarks = landmark_pb2.NormalizedLandmarkList(landmark = arm)
-    except Exception as e:
-        pass
-
-    # Drawing the landmarks on the frame of the video (i.e. keypoints)
-    mp_drawing.draw_landmarks(
-        img,
-        pose_landmarks,
-        landmark_drawing_spec=mp_ds
-    )
-
-    keypoint_dict[dict_idx] = {}
-
-    if results.pose_world_landmarks:
-        for idx, keypoint in enumerate(results.pose_world_landmarks.landmark):
-            if idx in arm_landmarks:
-                coordinates = (keypoint.x, keypoint.z, keypoint.y)
-                keypoint_dict[dict_idx][idx] = coordinates
-
-    return keypoint_dict
-
-def capture():
-    # Capturing the video with a webcam or with an mp4 video
-    if (VIDEO_REPLAY):
-        video = cv2.VideoCapture(VIDEO)
-    else:
-        video = cv2.VideoCapture(0)
-
-    keypoints_dict = {}
-
-    # Mediapipe pose estimation
-    with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5, model_complexity=2, enable_segmentation=True) as pose:
-        
-        idx = 0
-        while(video.isOpened()):
-            # Read the captured video 
-            _, capture = video.read()
-            
-            try:
-                # Running pose estimation on the video and publishing the coordinates
-                keypoints_dict = generate_keypoints(capture, pose, keypoints_dict, idx)
-            except Exception as e:
-                break
-            
-            idx += 1
-            
-    # Release the video from memory and clear all the openCV windows
-    video.release()
-    cv2.destroyAllWindows()
-    
-    # Saving the keypoints with time stamps into a pickle file
-    with open(FILENAME, 'wb') as handle:
-        pickle.dump(keypoints_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)     
-
-if __name__ == "__main__":
-    capture()
-
diff --git a/pose_estimation/scripts/data_collection.py b/pose_estimation/scripts/data_collection.py
new file mode 100755
index 0000000000000000000000000000000000000000..92cc47ad5c025663e63c9cb69a71eee3d72ee622
--- /dev/null
+++ b/pose_estimation/scripts/data_collection.py
@@ -0,0 +1,466 @@
+import numpy as np
+import mediapipe as mp
+
+from PIL import Image
+import rospy
+import cv2
+import os
+import pickle
+import math
+import vg
+import torch
+import tf
+
+from mediapipe.framework.formats import landmark_pb2
+
+from helpers.training_help import Net
+
+CONTROL = rospy.get_param('/keypoint_collection/manipulator_control')
+
+if not CONTROL:
+    FILENAME = rospy.get_param('/keypoint_collection/filename')
+    VIDEO_SAVE = rospy.get_param('/keypoint_collection/video_save') 
+    save_file = os.path.expanduser(f"~/catkin_ws/src/pose_estimation/data/{FILENAME}")
+
+else: 
+    VIDEO_SAVE = False
+    ik_model = rospy.get_param('/keypoint_collection/ik_model')
+    CONTROL_MODEL = os.path.expanduser(f"~/catkin_ws/src/pose_estimation/data/models/{ik_model}")
+
+MODEL = rospy.get_param('/keypoint_collection/gesture_model')
+model_file = os.path.expanduser(f"~/catkin_ws/src/pose_estimation/data/{MODEL}")
+
+
+# Dictionary of different modes/stages of data collection
+COLLECTION_MODES = {
+    "Setup": 0,
+    "Calibration": 1,
+    "Collection": 2,
+    "Terminate": 3,
+    "Finished": 4,
+    "Control": 5,
+}
+
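+# MANIPULATOR_MAX_DIST is used during calibration to scale the user's measured arm
+# length to the manipulator's reach; CONFIDENCE_SCORE is the minimum gesture-recognition
+# confidence required before a detected gesture is accepted.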
+MANIPULATOR_MAX_DIST = 0.38
+CONFIDENCE_SCORE = 0.6
+
+### Mediapipe Gesture Recognition
+BaseOptions = mp.tasks.BaseOptions
+GestureRecognizer = mp.tasks.vision.GestureRecognizer
+GestureRecognizerOptions = mp.tasks.vision.GestureRecognizerOptions
+VisionRunningMode = mp.tasks.vision.RunningMode
+
+### Mediapipe pose estimation
+mp_pose = mp.solutions.pose
+
+### Mediapipe drawing
+mp_drawing = mp.solutions.drawing_utils 
+mp_drawing_styles = mp.solutions.drawing_styles
+mp_ds = mp_drawing_styles.get_default_pose_landmarks_style()
+
+### Arm Landmarks
+wrist_keypoint = mp_pose.PoseLandmark.RIGHT_WRIST.value
+elbow_keypoint = mp_pose.PoseLandmark.RIGHT_ELBOW.value
+shoulder_keypoint = mp_pose.PoseLandmark.RIGHT_SHOULDER.value
+
+arm_landmarks = [shoulder_keypoint, elbow_keypoint, wrist_keypoint]
+
+def create_file():
+    rospy.loginfo(f"Creating file: {save_file}")
+    with open(save_file, 'wb') as handle:
+        pickle.dump({}, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+def save_dataset(data):
+    rospy.loginfo("Saving the dataset...")
+    if not os.path.exists(save_file):
+        create_file()
+
+    # Read the file and append the input to it
+    captured_keypoints = None
+    with open(save_file, 'rb') as handle:
+        captured_keypoints = pickle.load(handle)
+        keys = list(captured_keypoints.keys())
+        idx = keys[-1] + 1 if len(keys) != 0 else 0
+        
+        captured_keypoints[idx] = data
+    
+    if captured_keypoints:
+        with open(save_file, 'wb') as handle:
+            rospy.loginfo(f"Saving a video with index: {idx}")
+            pickle.dump(captured_keypoints, handle, protocol=pickle.HIGHEST_PROTOCOL)
+            rospy.loginfo(f"The dataset has been saved correctly at '{save_file}'")
+
+def verify_gesture(gesture, expected_gesture, score):
+    # Checking that the gesture matches the expected gesture and that the confidence is high enough
+    return gesture == expected_gesture and score >= CONFIDENCE_SCORE
+
+def gesture_recognition(recognizer, capture, frame_timestamp_ms, mode):
+    # Converting the image to Mediapipe image type
+    pil_img = Image.fromarray(capture) 
+    mp_image = mp.Image(image_format = mp.ImageFormat.SRGB, data = np.asarray(pil_img))
+    
+    # Using the model to get the gesture recognition from video
+    gesture_recognition_result = recognizer.recognize_for_video(mp_image, frame_timestamp_ms)
+
+    if len(gesture_recognition_result.gestures) != 0:
+        # Getting the highest ranked gesture and its confidence score
+        gesture = gesture_recognition_result.gestures[0][0].category_name
+        score = gesture_recognition_result.gestures[0][0].score
+        
+        # If the mode is set to setup, searching for a Thumbs up to start calibration
+        if mode == COLLECTION_MODES["Setup"]:
+            if verify_gesture(gesture, "Thumb_Up", score):
+                rospy.loginfo("\n*** CALIBRATION MODE ***")
+                return COLLECTION_MODES["Calibration"]
+        # If the mode is set to collection, searching for a Thumbs Up or Down to stop the collection
+        elif mode == COLLECTION_MODES["Collection"]:
+            if verify_gesture(gesture, "Thumb_Up", score):
+                rospy.loginfo("\n*** Collection Over ***")
+                return COLLECTION_MODES["Finished"]
+            if verify_gesture(gesture, "Thumb_Down", score):
+                rospy.logwarn("\n***Program Terminated***")
+                return COLLECTION_MODES["Terminate"]
+        elif mode == COLLECTION_MODES["Control"]:
+            if verify_gesture(gesture, "Thumb_Down", score):
+                rospy.logwarn("\n***Program Terminated***")
+                return COLLECTION_MODES["Terminate"]
+
+    # Returning the mode
+    return mode
+
+
+def show_image(img, pose_landmarks = None):
+    # Drawing the landmarks on the frame of the video (i.e. keypoints)
+    if pose_landmarks:
+        mp_drawing.draw_landmarks(
+            img,
+            pose_landmarks,
+            landmark_drawing_spec=mp_ds
+        )
+    
+    # Displaying the current frame of the video
+    cv2.imshow("Pose Video", cv2.flip(img, 1))
+
+# TODO: Possibly add Visibility
+def get_landmarks(pose_result):
+
+    if not pose_result.pose_world_landmarks or not pose_result.pose_world_landmarks.landmark: return None
+
+    landmarks = {}
+    angle = None
+    for idx, keypoint in enumerate(pose_result.pose_world_landmarks.landmark):
+        if idx in arm_landmarks:
+            coordinates = (keypoint.x, keypoint.z, keypoint.y)
+            landmarks[idx] = coordinates
+    
+    elbow, wrist = np.array(landmarks[elbow_keypoint]), np.array(landmarks[wrist_keypoint])
+    wrist_vec = zero_out(wrist - elbow, 2)
+
+    angle = vg.angle(wrist_vec, np.array([-1, 0, 0]), units="rad")
+    if wrist[1] < elbow[1]:
+        angle = -angle
+    else:
+        angle = 0.0
+    
+    landmarks['angle'] = angle
+    return landmarks
+
+def test_path(model_path):
+    return os.path.exists(model_path)
+
+############# HELPER FUNCTIONS FOR KEYPOINT MANIPULATION #############
+ 
+### Check the visibility of a keypoint
+def verify_visibility(keypoint, visibility_threshold):
+    if keypoint.visibility <= visibility_threshold:
+        return False
+
+    return True
+
+### Convert keypoint locations to a numpy array
+def get_keypoint_as_np(keypoint):
+    return np.array([keypoint.x, keypoint.y, keypoint.z])
+
+### Takes in an array and zeros out a specified index
+def zero_out(arr, n):
+    new_arr = np.copy(arr)
+    new_arr[n] = 0
+    return new_arr
+
+### Euclidean distance between two vectors
+def euclid_dist(vec1, vec2):
+  if vec1.size != vec2.size:
+    rospy.logerr("Size of the two points doesn't match")
+    return None
+  
+  sum_ = 0
+  for i in range(vec1.size):
+    squared_difference = (vec1[i] - vec2[i])**2
+    sum_ += squared_difference
+  
+  return math.sqrt(sum_)
+
+########################################################
+
+def skeleton_estimation_pose(img, pose):
+    # Running mediapipe to extract the keypoints
+    results = pose.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+
+    # Getting the pose landmarks from the resulting pose
+    pose_landmarks = results.pose_landmarks
+    
+    if not pose_landmarks: 
+        show_image(img)
+        return None
+
+    try:
+        # Getting the right arm keypoints
+        arm = [pose_landmarks.landmark[arm_keypoint] for arm_keypoint in arm_landmarks]
+        pose_landmarks = landmark_pb2.NormalizedLandmarkList(landmark = arm)
+    except Exception as e:
+        rospy.logerr(f"Exception: {e}")
+
+    # Displaying the video feed and the landmarks
+    # show_image(img, pose_landmarks)
+    
+    return results
+
+
+def calculate_calibration(pose_result):
+    # Distance between the shoulder and the wrist
+    landmarks = get_landmarks(pose_result)
+
+    shoulder, wrist = landmarks[shoulder_keypoint], landmarks[wrist_keypoint]
+    shoulder, wrist = np.array(shoulder), np.array(wrist)
+
+    shifted_wrist = wrist - shoulder
+
+    # Checking if the keypoint is to the left side of the shoulder, i.e. if x is positive (because the image is inverted, otherwise it should be positive)
+    if shifted_wrist[0] > 0:
+        rospy.logerr("Wrist has to be to the right of the shoulder")    
+        return None
+
+    distance = euclid_dist(np.array([0,0,0]), shifted_wrist)
+
+    return MANIPULATOR_MAX_DIST / distance
+  
+def get_model(path):
+    model = Net(7, 4)
+    model.load_state_dict(torch.load(path)['model_state_dict'])
+    model.eval()
+    return model
+
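+# Main capture loop, driven as a small state machine over COLLECTION_MODES:
+# Setup -> Calibration -> Collection (or Control) -> Finished / Terminate.
+# Returns the collected data dictionary, or None if the run is terminated early.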
+def video_capture(control = False):
+    video = cv2.VideoCapture(0)
+    video_fps = video.get(cv2.CAP_PROP_FPS)
+
+    counter, countdown = 0, -1
+    threshold = (video_fps * 100) + 100
+    calibration_results = []
+
+    MODE = COLLECTION_MODES["Setup"]
+
+    if VIDEO_SAVE is not None and type(VIDEO_SAVE) != bool:
+        rospy.logfatal("video_save must be a boolean")
+        return None
+
+    if VIDEO_SAVE:
+        data = {"points": [], "video":[], "calib": None}  
+    else:
+        data = {"points": [], "calib": None}
+
+    if not control:
+        rospy.logwarn("Important!")
+        setup_info = """"\n\n*** Setup mode ***
+To calibrate the arm length with the manipulator, face the camera, 
+extend your **right** arm fully to the side (parallel to the ground) and gesture **Thumbs Up**.
+Once the gesture is recognized, a countdown will initiate and at 0 it will calibrate the distance.
+Important: Try moving the arm up and down whilst fully extended to capture the most accurate max distance.\n\n"""
+        rospy.loginfo(setup_info)
+    else:
+        model = get_model(CONTROL_MODEL)
+
+    # Check if camera opened successfully
+    if (not video.isOpened() or not video): 
+        rospy.logerr("Error opening video stream or file")
+        return
+    
+    options = GestureRecognizerOptions(
+        base_options = BaseOptions(model_asset_path = model_file), 
+        running_mode = VisionRunningMode.VIDEO
+    )
+    
+    with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5, model_complexity=2, enable_segmentation=True) as pose:    
+        # Create a gesture recognizer instance with the video mode:
+
+        with GestureRecognizer.create_from_options(options) as recognizer:
+            frame_idx = 0
+
+            # Running while the user does not signify that the data collection is finished
+            while MODE != COLLECTION_MODES['Finished']:
+                
+                frame_idx += 1
+
+                # If the program is shutdown or the user sets it to terminate, the data will not be saved 
+                if MODE == COLLECTION_MODES['Terminate'] or cv2.waitKey(1) & 0xFF == ord('q') or rospy.is_shutdown(): 
+                    # Release the video from memory and clear all the openCV windows
+                    video.release()
+                    cv2.destroyAllWindows()
+                  
+                    rospy.logwarn("Terminated the program, no data is saved")
+                  
+                    return None
+
+                _, capture = video.read()
+                
+                try:   
+                    frame_ms = int((1000 * frame_idx) / video_fps)
+                    MODE = gesture_recognition(recognizer, capture, frame_ms, MODE)
+                    
+                    if MODE == COLLECTION_MODES["Setup"]:
+                        show_image(capture)
+
+                    ### TODO: Max calibration rather than last
+                    elif MODE == COLLECTION_MODES['Calibration']:
+                        counter += (1000 / video_fps)
+                        if counter <= threshold:
+                            results = skeleton_estimation_pose(capture, pose)
+                            if results is not None:
+                                calibration = calculate_calibration(results)
+                                if calibration is not None: calibration_results.append(calibration)
+
+                            prev_count = countdown
+                            countdown = int((counter)/1000)
+                            
+                            if prev_count != countdown:
+                                print(3-countdown)
+                    
+                        else:
+                            if len(calibration_results) == 0:
+                                rospy.logerr("No results found for calibration, Terminating...") 
+                                MODE = COLLECTION_MODES["Terminate"]
+                                continue
+
+                            data['calib'] = max(calibration_results)
+                            MODE = COLLECTION_MODES["Collection"] if not control else COLLECTION_MODES["Control"]
+                            
+                            if MODE == COLLECTION_MODES["Collection"]:
+                                rospy.loginfo("\n***Dataset Collection Started***\n\n-> Terminate: Thumbs Down (or CTRL+C)\n-> Finish and Save: Thumbs Up")
+                            else:
+                                rospy.loginfo("\n***Manipulation Started***\n\n-> Terminate: Thumbs Down (or CTRL+C)\n")
+
+                    elif MODE == COLLECTION_MODES["Collection"]:
+                        if VIDEO_SAVE:
+                            data['video'].append(cv2.flip(capture, 1))
+                        
+                        pose_results = skeleton_estimation_pose(capture, pose)
+                        
+                        if not pose_results: continue
+
+                        landmarks = get_landmarks(pose_results)
+                        # Checking if landmarks exist and if all the three keypoints have been found
+                        if landmarks:
+                            data['points'].append(landmarks)
+
+                            # Log a running count of collected landmark frames
+                            print(len(data['points']))
+
+                    elif MODE == COLLECTION_MODES["Control"]:
+                        pose_results = skeleton_estimation_pose(capture, pose)
+                        if not pose_results: continue
+                        landmarks = get_landmarks(pose_results)
+                    
+                        control_manipulator(landmarks, model, data['calib'])
+                        
+
+                except Exception as e:
+                        # If there was an exception, the program should exit
+                        rospy.logwarn(f"Exception thrown: {e}")
+                        MODE = COLLECTION_MODES['Terminate']
+
+    # Release the video from memory and clear all the openCV windows
+    video.release()
+    cv2.destroyAllWindows()
+
+    return data
+
+####################### TODO: CODE FROM training_data.py ########################
+from open_manipulator_msgs.srv import SetJointPosition, SetJointPositionRequest
+
+PATH_TIME = 0.5
+SHOULDER = mp_pose.PoseLandmark.RIGHT_SHOULDER.value
+WRIST = mp_pose.PoseLandmark.RIGHT_WRIST.value
+
+def set_wrist_angle(joint_angles):
+    rospy.wait_for_service('goal_joint_space_path', 2)
+    set_joints = rospy.ServiceProxy('goal_joint_space_path', SetJointPosition)
+
+    goal_request = SetJointPositionRequest()
+    goal_request.joint_position.joint_name = ["joint1", "joint2", "joint3", "joint4"]
+    goal_request.joint_position.position = joint_angles
+    goal_request.joint_position.max_accelerations_scaling_factor = 0.0
+    goal_request.joint_position.max_velocity_scaling_factor = 0.0
+    goal_request.path_time = PATH_TIME
+    resp = set_joints(goal_request)
+
+    rospy.sleep(PATH_TIME)
+
+    if not resp.is_planned:
+        rospy.logerr("Failed to solve IK equation for wrist")
+
+    return resp.is_planned
+
+
+def shift_keypoints(original, offset):
+    return (np.array(original) - np.array(offset))
+#############################################################################################################
+
+### TODO: Testing
+def control_manipulator(landmarks, model, calib):
+    print("Controlling manipulator...")
+    ## Shifting the wrist to be relative to the shoulder at origin
+    wrist_point = shift_keypoints(landmarks[WRIST], landmarks[SHOULDER])
+    wrist_point[2] = -wrist_point[2]
+
+    ## Calculating the calibrated point
+    wrist_point = wrist_point * calib    
+
+    ## Converting the euler angle to a quaternion
+    angle = tf.transformations.quaternion_from_euler(0, landmarks['angle'], 0)
+
+    ## Concatenating the wrist point and angle to match the model input
+    model_input = np.concatenate((wrist_point, angle), axis = None)
+    
+    ## Predicting the joints based on the wrist point and angle
+    control_outputs = model(torch.Tensor(model_input))
+
+    ## Setting the manipulator's wrist angle
+    success = set_wrist_angle(control_outputs.tolist())
+
+def verify_params():
+    verified = test_path(model_file)
+    if CONTROL:
+        verified = verified and test_path(CONTROL_MODEL)
+
+        if verified:
+            ### Sleeping to let Gazebo launch
+            rospy.sleep(3)
+    return verified
+
+if __name__ == '__main__':
+    rospy.init_node("keypoint_collection")
+
+    print(CONTROL)
+    print(type(CONTROL))
+
+    if verify_params():
+        rospy.loginfo("All the parameters were verified")
+
+        data = video_capture(control = CONTROL)
+        
+        if data: 
+            save_dataset(data = data)
+    else:
+        rospy.logfatal(f"Check if all the provided model names are correct or if they are in the correct directories")
+    
+    rospy.loginfo("Finished")
\ No newline at end of file
diff --git a/pose_estimation/scripts/evaluate.py b/pose_estimation/scripts/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3552206e20833703487661666903a16fb3d58c6
--- /dev/null
+++ b/pose_estimation/scripts/evaluate.py
@@ -0,0 +1,198 @@
+import torch
+import os
+import pickle
+import threading
+import rospy 
+
+import numpy as np
+import matplotlib.pyplot as plt
+import torch.nn as nn
+import torch.nn.functional as F
+
+from open_manipulator_msgs.msg import KinematicsPose
+from pose_estimation.msg import Joints
+# from training import Net, process_dataset, list_to_tensor
+
+
+#### TODO: Repeated code ####
+
+class Net(nn.Module):
+
+  def __init__(self, input_num, output_num):
+    super(Net, self).__init__()
+    self.l1 = nn.Linear(input_num, 32)
+    # self.l2 = nn.Linear(32, 32)
+    self.output = nn.Linear(32, output_num)
+    # self.dropout = nn.Dropout(0.2)
+
+  def forward(self, x):
+    x = F.tanh(self.l1(x))
+    # x = F.tanh(self.dropout(self.l2(x)))
+    return self.output(x)
+
+def list_to_tensor(arr):
+    return torch.tensor(arr)
+
+def process_dataset(dataset):
+    inputs = []
+    labels = []
+
+    for data in dataset:
+        actual_values = []
+        for joint in data['jointPositions']:
+            actual_values.append(joint.angle)
+        
+        labels.append(actual_values)
+
+        man_pos = data['manipulatorPositions']
+        man_angle = data['angle']
+        x = [man_pos.x, man_pos.y, man_pos.z, man_angle.x, man_angle.y, man_angle.z, man_angle.w]
+
+        inputs.append(x)
+
+    inputs, labels = remove_duplicates(inputs, labels)
+    
+    inputs = list_to_tensor(np.array(inputs).astype(np.float32))
+    labels = list_to_tensor(labels)
+
+    return (inputs, labels)
+
+def remove_duplicates(inputs, labels):
+    processed_inputs, processed_labels = [], []
+    
+    for i in range(len(inputs)):
+        
+        duplicates = []
+        for j in range(i+1, len(inputs)):
+            duplicates.append(inputs[i] == inputs[j])    
+        
+        if not np.array(duplicates).any(): 
+            processed_inputs.append(inputs[i])
+            processed_labels.append(labels[i])
+
+    print(f"{len(inputs) - len(processed_inputs)} duplicates have been removed")
+    return processed_inputs, processed_labels
+
+
+############################################
+
+
+
+
+model_path = os.path.expanduser("~/catkin_ws/src/pose_estimation/data/models/joint_predict_all.pt")
+dataset_path = os.path.expanduser("~/catkin_ws/src/pose_estimation/data/keypoint_dataset.pickle")
+
+def plot_history(history):
+    if history:
+        plt.plot(history)
+        plt.show()
+
+def load_model(model_path, shape):
+    model_dict = torch.load(model_path)
+
+    model = Net(shape[0], shape[1])
+    model.load_state_dict(model_dict['model_state_dict'])
+    model.eval()
+    
+    loss = model_dict['loss']
+
+    print("Loading the model...")
+    print(f"Loaded model error: {loss}")
+
+    history = None
+    if 'history' in model_dict:
+        history = model_dict['history']
+
+    return model, history
+
+def predict(model, inputs):
+    predictions = []
+
+    with torch.no_grad():
+        for input in inputs:
+            pred = model(input)
+            predictions.append(pred.numpy())
+
+    return list_to_tensor(np.array(predictions).astype(np.float32))
+
+def per_joint_error(labels, predictions, loss_fn = nn.MSELoss()):
+    errors = {}
+    joint_labels, joint_predictions = labels.transpose(1,0), predictions.transpose(1,0)
+
+    for joint in range(joint_labels.shape[0]):
+        error = loss_fn(joint_labels[joint], joint_predictions[joint])
+        errors[f"joint{joint+1}"] = error.item()
+    
+    print(f"\n-------------------------------\nPer Joint Error: \n{errors}")
+    print("-------------------------------\n")
+
+
+from tqdm import tqdm
+
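+# Publishes each predicted joint configuration on the 'evaluation' topic so that the
+# kinematics_solver node can compute forward kinematics; the resulting end-effector
+# poses arrive on 'forward_kinematics_keypoints' and are appended to PREDICTED_POSITIONS.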
+def calculate_fk(pub, predictions):
+
+    for prediction in tqdm(predictions):
+        msg = Joints()
+        joints = [val.item() for val in prediction]
+        msg.angles = joints
+        
+        # lock.acquire()
+        pub.publish(msg)
+        rospy.sleep(0.2)
+    
+    rospy.sleep(5)
+
+def positional_error(actual, predicted):
+    ## L2 loss between the actual and predicted gripper positions
+    dist = 0.0
+    count = 0.0
+    for l_a, l_p in zip(actual, predicted):
+        dist += np.linalg.norm(l_a[:3] - l_p)
+        count += 1.0
+
+    if count:
+        print(f"Average positional error (L2 loss between predicted and actual gripper position): {dist / count}")
+
+
+def process_manipulator_positions(manipulator_pose):
+    position = manipulator_pose.pose.position
+    PREDICTED_POSITIONS.append([position.x, position.y, position.z])
+    
+
+if __name__ == '__main__':   
+    PREDICTED_POSITIONS = []
+    # lock = threading.Lock()
+    rospy.init_node("evaluation")
+    rospy.Subscriber("forward_kinematics_keypoints", KinematicsPose, process_manipulator_positions)
+
+    rospy.sleep(1)
+
+    kinematics_pub = rospy.Publisher("evaluation", Joints, queue_size = 10)
+
+    dataset = {}
+    with open(dataset_path, "rb") as input_file:
+        dataset = pickle.load(input_file)
+    
+    inputs, labels = process_dataset(dataset)
+    model, history = load_model(model_path, shape = (len(inputs[0]), len(labels[0])))
+    predictions = predict(model, inputs)
+
+    ## Plot Overall Loss Graph
+    plot_history(history)
+
+    ## Calculate error for each joint
+    loss_fn = nn.HuberLoss()
+    per_joint_error(labels, predictions, loss_fn)
+
+    ## Calculate error between the predicted and final position with the use of forward kinematics  
+    calculate_fk(kinematics_pub, predictions)
+    positional_error(inputs.numpy(), np.array(PREDICTED_POSITIONS))
+
+
+
+    
+
+    
\ No newline at end of file
diff --git a/pose_estimation/scripts/helpers/training_help.py b/pose_estimation/scripts/helpers/training_help.py
new file mode 100755
index 0000000000000000000000000000000000000000..b54b1e41b16f0860d5182bb5fc5b810432ef5a26
--- /dev/null
+++ b/pose_estimation/scripts/helpers/training_help.py
@@ -0,0 +1,15 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
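+# Small fully-connected network with a single 32-unit hidden layer and tanh activation.
+# In this package it is instantiated as Net(7, 4): a 7-dimensional end-effector pose
+# (position + quaternion) is mapped to 4 joint angles.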
+class Net(nn.Module):
+  def __init__(self, input_num, output_num):
+    super(Net, self).__init__()
+    self.l1 = nn.Linear(input_num, 32)
+    # self.l2 = nn.Linear(32, 32)
+    self.output = nn.Linear(32, output_num)
+    # self.dropout = nn.Dropout(0.2)
+
+  def forward(self, x):
+    x = F.tanh(self.l1(x))
+    # x = F.tanh(self.dropout(self.l2(x)))
+    return self.output(x)
\ No newline at end of file
diff --git a/pose_estimation/scripts/keypoints.pickle b/pose_estimation/scripts/keypoints.pickle
deleted file mode 100644
index f49cda91bc0c17a30b4baaf497b78f13b646b887..0000000000000000000000000000000000000000
Binary files a/pose_estimation/scripts/keypoints.pickle and /dev/null differ
diff --git a/pose_estimation/scripts/pose_estimation_control.py b/pose_estimation/scripts/pose_estimation_control.py
deleted file mode 100755
index 5218bfbd7a065af855c6f57dd5cbd960e4ba5f9d..0000000000000000000000000000000000000000
--- a/pose_estimation/scripts/pose_estimation_control.py
+++ /dev/null
@@ -1,227 +0,0 @@
-#############################################
-#                     TODO                  #
-# Debug: Show Manipulator                   # 
-# Test Motion                               #
-# Vector from elbow                         #                  
-#############################################
-# Remove DEBUG - Incorporate with visualize #
-
-
-import rospy
-import math
-import itertools
-import numpy as np
-
-from math import sin, cos
-
-from geometry_msgs.msg import Pose
-from geometry_msgs.msg import PointStamped
-
-from open_manipulator_msgs.srv import SetKinematicsPose, SetJointPosition, GetJointPosition, SetJointPositionRequest
-from open_manipulator_msgs.msg import KinematicsPose, JointPosition
-from sensor_msgs.msg import JointState
-
-### Importing the messages for Spherical Coordinates
-from pose_estimation.msg import SphericalCoordinates
-
-from visualization_msgs.msg import Marker
-
-
-### Defining available joints by name
-JOINT_NAMES = ["joint1", "joint2", "joint3", "joint4"]
-
-### Variables for moving the gripper
-END_EFFECTOR_NAME = "gripper"
-PATH_TIME = 1
-
-### Starting joint positions
-START_LOCATION = {
-    'x': 0.286,
-    'y': 0.0,
-    'z': 0.204,
-}
-
-### If set to true will publish the coordinates as PointStamped msg which can be displayed in RVIZ
-DEBUG = False
-
-# Calculating the distance between two points
-def dis(point1, point2):
-    return math.sqrt((point2[0] - point1[0])**2 + (point2[1] - point1[1])**2)
-
-# USE Cloud Point
-def create_vis_marker(coordinates):
-    # Creating a marker which can be visualized in RVIZ
-    msg = Marker()
-
-    # Creating a custom frame and setting the timestamp
-    msg.header.frame_id = "my_frame"
-    msg.header.stamp = rospy.Time.now()
-
-    # Setting the position of the marker
-    msg.pose.position.x = coordinates['x'] * 10
-    msg.pose.position.y = coordinates['y'] * 10
-    msg.pose.position.z = coordinates['z'] * 10   
-
-    msg.pose.orientation.x = 0
-    msg.pose.orientation.y = 0
-    msg.pose.orientation.z = 0
-    msg.pose.orientation.w = 0
-
-    # Setting the scale
-    msg.scale.x = 0.2 * 10
-    msg.scale.y = 0.01 * 10
-    msg.scale.z = 0.05 * 10
-
-    # Making the marker visible and pink
-    msg.color.r = 199
-    msg.color.g = 21
-    msg.color.b = 133
-    msg.color.a = 1.0 
-
-    return msg
-
-
-### TODO: Refactor
-def adjust_range(old_min, old_max, new_min, new_max, old_value):
-    old_range = old_max - old_min
-    new_range = new_max - new_min
-
-    return (((old_value - old_min) * new_range) / old_range) + new_min
-
-def spherical_to_cartesian(dist, theta, phi):
-    global debug_pub
-
-    def invert_sign(val):
-        return -val
-
-    print(f"Theta: {theta}")
-
-    phi = math.radians(adjust_range(30, 150, 90, 180, phi))
-    theta = math.radians(theta)
-
-    x = dist * sin(phi)
-    y = dist * cos(theta)
-    z = dist * invert_sign(cos(phi))
-
-    coordinates = {
-        "x" : x, 
-        "y" : y,
-        "z" : z
-    }
-    
-    if (DEBUG):
-        print(f"Distance {dist}")
-        print(f"Coordinates: {coordinates}")
-        print()
-
-        msg = create_vis_marker(coordinates)
-        debug_pub.publish(msg)
-
-    return coordinates
-
-
-### TODO: Joint Angle
-def set_wrist_angle(pos):
-    pass
-    # TODO: Exception and stop program if connections fails
-    
-    # joint_position = JointPosition()
-    # joint_position.joint_name = JOINT_NAMES
-    # joint_position.position = [1.0, 1.0, -1.0, 0.50]
-    
-    # resp = set_joints('', joint_position, PATH_TIME)
-
-    rospy.wait_for_service('goal_joint_space_path', 2)
-    set_joints = rospy.ServiceProxy('goal_joint_space_path', SetJointPosition)
-
-    goal_request= SetJointPositionRequest()
-    goal_request.joint_position.joint_name = JOINT_NAMES
-    goal_request.joint_position.position = [1.0, 1.0, -1.0, 0.50]
-    goal_request.joint_position.max_accelerations_scaling_factor = 0.0
-    goal_request.joint_position.max_velocity_scaling_factor = 0.0
-    goal_request.path_time = 2.0
-
-    resp = set_joints(goal_request)
-
-    if not resp.is_planned:
-        rospy.logerr("Failed to solve IK equation for wrist")
-
-def set_position(target_location):
-    # TODO: Exception and stop program if connections fails
-    rospy.wait_for_service('goal_task_space_path_position_only', 2)
-
-    set_kinematics_pose = rospy.ServiceProxy('goal_task_space_path_position_only', SetKinematicsPose)
-
-    pose = Pose()
-    pose.position.x = target_location['x']
-    pose.position.y = target_location['y']
-    pose.position.z = target_location['z']
-
-    kinematics_pose = KinematicsPose()
-    kinematics_pose.pose = pose
-
-    resp = set_kinematics_pose(end_effector_name = END_EFFECTOR_NAME, kinematics_pose=kinematics_pose, path_time = PATH_TIME)
-
-    if not resp.is_planned:
-        rospy.logerr("Failed to solve IK equation")
-        print(f"X: {target_location['x']}; Y: {target_location['y']}; Z: {target_location['z']}")
-
-
-def video_capture(coordinate):
-    # Converting the spherical coordiantes received from pose estimation to cartesian coordinates
-    target_location = spherical_to_cartesian(coordinate.distance, coordinate.theta, coordinate.phi)
-    
-    
-    if(not DEBUG):
-        try:
-            # TODO: SET Wrist
-            set_wrist_angle(coordinate.wrist_angle)
-
-            # Setting the location of the robotic arm
-            # set_position(target_location)
-
-        except rospy.ServiceException as e:
-            rospy.logerr(f"Service call failed {e}")
-
-        # Sleep equal to the time needed for the arm to move to the desired location
-        rospy.sleep(PATH_TIME)
-
-JOINT_POSITION_TMP = -100
-
-def joint_state(msg):
-    global JOINT_POSITION_TMP
-
-    angle = round(msg.position[-1], 0)
-    if JOINT_POSITION_TMP != angle:
-        print(f"New angle position: {angle}")
-        JOINT_POSITION_TMP = angle
-
-def pose_estimation():
-    # Moving the arm to the starting position
-    rospy.loginfo("Moving the arm to the starting location")
-
-    print("Subscribing to `/joint_states`")
-    rospy.Subscriber("/joint_states", JointState, joint_state, queue_size = 1)
-    
-    if (not DEBUG):
-        set_wrist_angle(2)
-        set_position(START_LOCATION)
-
-    return
-
-    print("Subscribing to `arm_keypoint_capture`")
-    rospy.Subscriber("arm_keypoint_capture", SphericalCoordinates, video_capture, queue_size = 1)
-
-    rospy.spin()
-
-
-if __name__ == '__main__':
-    node = 'pose_estimation_control'
-    rospy.init_node(node)
-
-    debug_pub = rospy.Publisher(node, Marker, queue_size=50)
-    
-    try:
-        pose_estimation()
-    except rospy.ROSInterruptException as e:
-        print("Exception: " + e)
diff --git a/pose_estimation/scripts/training.py b/pose_estimation/scripts/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fe00bfd28f028b7dcce7b7ebbdfe0f9a1fddc91
--- /dev/null
+++ b/pose_estimation/scripts/training.py
@@ -0,0 +1,449 @@
+import pickle
+import rospy
+import os
+import pypareto
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+import matplotlib.pyplot as plt
+
+from tqdm import tqdm
+from random import choice, random, uniform
+from open_manipulator_msgs.msg import KinematicsPose
+from pose_estimation.msg import Joints
+
+from helpers.training_help import Net
+
+
+dataset_name = rospy.get_param('/training/dataset')
+DATASET_PATH = os.path.expanduser(f"~/catkin_ws/src/pose_estimation/data/{dataset_name}")
+
+model_name = rospy.get_param('/training/model')
+MODEL_SAVE_PATH = os.path.expanduser(f"~/catkin_ws/src/pose_estimation/data/models/{model_name}")
+
+criterion = rospy.get_param('/training/loss')
+scheduler = rospy.get_param('/training/scheduler')
+learning_rate = rospy.get_param('/training/initial_lr')
+EPOCHS = rospy.get_param('/training/epochs')
+
+PLOT = rospy.get_param('/training/plot')
+EVALUATE = rospy.get_param('/training/evaluate')
+VALIDATE = rospy.get_param('/training/validate')
+
+RUN_NAS = rospy.get_param('/training/NAS')
+NAS_iter = rospy.get_param('/training/NAS_iter')
+
+LEARNING_RATE = 1e-1
+CRITERION = None
+SCHEDULER = None
+CHECKPOINT_LOG = 1
+
+def validate_params():  
+    global CRITERION, SCHEDULER, LEARNING_RATE, CHECKPOINT_LOG
+
+    if not isinstance(EPOCHS, int):
+        print("Epochs must be an integer")
+        return False
+
+    if not (isinstance(EVALUATE, bool) and isinstance(PLOT, bool) and isinstance(VALIDATE, bool)):
+        print("`evaluate`, `validate` and `plot` flags must be booleans")
+        return False
+    try:
+        LEARNING_RATE = float(learning_rate)
+    except Exception as conversion_exp:
+        print(f"Learning rate incorrect. Exception when converting: {conversion_exp}")
+        return False
+
+    if criterion == "mse":
+        CRITERION = nn.MSELoss()
+    elif criterion == "huber":
+        CRITERION = nn.HuberLoss()
+    else:
+        print(f"Given loss [{criterion}] not valid")
+        return False
+    
+    if VALIDATE:
+        CHECKPOINT_LOG = 1
+    else:
+        CHECKPOINT_LOG = 5
+
+    if scheduler == "None":
+        SCHEDULER = None
+        # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=9e-1)
+    else:
+        print(f"Given scheduler [{scheduler}] not valid")
+        return False
+    
+    return True
+
+
+def list_to_tensor(arr):
+    return torch.tensor(arr)
+
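+# Builds the training tensors: each input is the manipulator position (x, y, z)
+# plus its orientation quaternion, and each label is the corresponding set of
+# joint angles returned by the IK solver.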
+def process_dataset(dataset):
+    inputs = []
+    labels = []
+
+    for data in dataset:
+        actual_values = []
+        for joint in data['jointPositions']:
+            actual_values.append(joint.angle)
+        
+        labels.append(actual_values)
+
+        man_pos = data['manipulatorPositions']
+        man_angle = data['angle']
+        x = [man_pos.x, man_pos.y, man_pos.z, man_angle.x, man_angle.y, man_angle.z, man_angle.w]
+
+        inputs.append(x)
+
+    inputs, labels = remove_duplicates(inputs, labels)
+    
+    inputs = list_to_tensor(np.array(inputs).astype(np.float32))
+    labels = list_to_tensor(labels)
+
+    return (inputs, labels)
+
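+# Removes samples whose input row appears again later in the list, keeping only
+# the last occurrence of each duplicate pose.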
+def remove_duplicates(inputs, labels):
+    processed_inputs, processed_labels = [], []
+    
+    for i in range(len(inputs)):
+        
+        duplicates = []
+        for j in range(i+1, len(inputs)):
+            duplicates.append(inputs[i] == inputs[j])    
+        
+        if not np.any(duplicates):
+            processed_inputs.append(inputs[i])
+            processed_labels.append(labels[i])
+
+    print(f"{len(inputs) - len(processed_inputs)} duplicates have been removed")
+    return processed_inputs, processed_labels
+
+
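+# Runs one epoch over the dataset, one sample at a time (no mini-batching), and
+# optionally accumulates the forward-kinematics positional loss when VALIDATE is set.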
+def train(inputs, labels, model, optimizer):
+    running_loss = 0.0
+    running_pos_loss = 0.0
+    count = 0
+
+    model.train()
+
+    for i, data in enumerate(zip(inputs, labels)):
+        # get the inputs; data is a single (input, label) pair
+        input, label = data
+
+        # zero the parameter gradients
+        optimizer.zero_grad()
+
+        # forward + backward + optimize
+        outputs = model(input)
+        loss = CRITERION(outputs, label)
+        loss.backward()
+        optimizer.step()
+        
+        running_loss += loss.item()
+        
+        if VALIDATE:
+            running_pos_loss += positional_loss(outputs, input).item()
+
+        count += 1
+
+    train_loss = running_loss / count
+    valid_loss = running_pos_loss / count
+
+    return train_loss, valid_loss
+
+def save_model(model_dict):
+    torch.save(model_dict, MODEL_SAVE_PATH)
+    print(f"Best model has been saved: {MODEL_SAVE_PATH} \n")
+    print(f"    -> Best Loss: {model_dict['train_loss']}")
+
+def per_joint_error(labels, predictions, loss_fn = nn.MSELoss()):
+    errors = {}
+    joint_labels, joint_predictions = labels.transpose(1,0), predictions.transpose(1,0)
+
+    for joint in range(joint_labels.shape[0]):
+        error = loss_fn(joint_labels[joint], joint_predictions[joint])
+        errors[f"joint{joint+1}"] = error.item()
+    
+    print(f"\n-------------------------------\nPer Joint Error: \n{errors}")
+    print("-------------------------------\n")
+
+def plot_history(history):
+    train_loss = history['train_loss']
+    plt.plot(train_loss)
+    plt.show()
+    
+    if VALIDATE:
+        valid_loss = history['valid_loss']
+        plt.plot(valid_loss)
+        plt.show()
+
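+# Runs the model over all inputs without tracking gradients and returns the
+# predicted joint angles as a single tensor.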
+def predict(model, inputs):
+    predictions = []
+
+    with torch.no_grad():
+        for input in inputs:
+            pred = model(input)
+            predictions.append(pred.numpy())
+
+    return list_to_tensor(np.array(predictions).astype(np.float32))
+
+def dataset_pos_error(predictions, inputs):
+    error = 0.0
+    count = 0.0
+    for prediction, input in zip(predictions, inputs):
+        error += positional_loss(prediction, input).item()
+        count += 1
+
+    print(f"\n-------------------------------\nPositional Mean Error: {error / count}")
+    print("-------------------------------\n")
+    return(error / count)
+
+def evaluate_model(model_dict, dataset):
+    inputs, labels = process_dataset(dataset)
+
+    print("Loading the best model...") 
+    # model = Net(len(inputs[0]), len(labels[0]))
+    # model.load_state_dict(model_dict['model_state_dict'])
+    
+    model = model_dict['model']
+    model.eval()
+
+    with torch.no_grad():
+        predictions = predict(model, inputs)
+
+        valid_loss = ""
+        if VALIDATE:
+            valid_loss = f"-> {model_dict['valid_loss']} positional error"
+        
+        print(f"""Best model was trained with: 
+        -> {model_dict['train_loss']} per joint error
+        {valid_loss}""")
+        
+        history = model_dict['history']
+
+        if PLOT: plot_history(history)
+
+        per_joint_error(labels, predictions, CRITERION)
+
+        if not VALIDATE:
+            dataset_pos_error(predictions, inputs)
+
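+# Publishes the predicted joint angles to the kinematics solver node and waits
+# for the end-effector position it computes via forward kinematics.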
+def solve_fk(prediction):
+    global kinematics_pub
+
+    msg = Joints()
+    joints = [val.item() for val in prediction]
+    msg.angles = joints
+
+    kinematics_pub.publish(msg)
+    return_msg = rospy.wait_for_message('/forward_kinematics_keypoints', KinematicsPose)
+    position = return_msg.pose.position
+
+    return torch.tensor([position.x, position.y, position.z])
+
+
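+# Compares the end-effector position reached by the predicted joints (via FK)
+# against the target position stored in the first three input values.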
+def positional_loss(prediction, inputs):
+    predicted_position = solve_fk(prediction)
+    actual_position = inputs[:3]
+
+    cost_fn = nn.HuberLoss()
+    cost = cost_fn(actual_position, predicted_position)
+
+    return cost
+
+
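+# Samples a random candidate architecture for the NAS search: up to three hidden
+# layers with randomly chosen widths and dropout rates, plus a shared activation.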
+def create_model(input_size, output_size):
+    m_choices = [8, 16, 32, 64, 128]
+    n_choices = [8, 16, 32, 64, 128, 256]
+    l_choices = [8, 16, 32, 64, 128]
+    layer_prob = [0.9, 0.6, 0.6]
+
+    activation_choices = [nn.Sigmoid(), nn.Tanh(), nn.ReLU(), nn.ELU(), None]
+    dropout_range = (0.0, 0.5)
+  
+    layers = [choice(m_choices), choice(n_choices), choice(l_choices)]
+    act_f = choice(activation_choices)
+
+    modules = []
+    prev = input_size
+
+    for layer_num, layer in enumerate(layers):
+
+        if random() <= layer_prob[layer_num]:
+            modules.append(nn.Linear(prev, layer))
+
+            dropout = round(uniform(dropout_range[0], dropout_range[1]), 2)
+            modules.append(nn.Dropout(dropout))
+            if act_f is not None:
+                modules.append(act_f)
+            prev = layer 
+
+    modules.append(nn.Linear(prev, output_size))
+    sequential = nn.Sequential(*modules)
+
+    return sequential
+
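+# Builds the Pareto front over (joint loss, positional error) for all NAS
+# candidates and returns the model in the middle of the front as a compromise
+# between the two objectives.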
+def optimal_pareto(history):
+    values = []
+
+    for idx, his in enumerate(history):
+        value = his['min_loss'], his['pos_err']
+        values.append(value)
+
+    chain = pypareto.Comparison(pypareto.by_value, pypareto.MaxMinList(pypareto.MaxMin.MIN, pypareto.MaxMin.MIN)).as_chain()
+    pareto_front = chain.split_by_pareto(values)[0]
+
+    pareto_front_idx = [idx for idx, value in enumerate(values) if value in pareto_front]
+    mid = pareto_front_idx[int(round(len(pareto_front_idx) / 2, 0))]
+
+    for idx in pareto_front_idx:
+        print(history[idx])
+
+    best_model = history[mid]
+
+    if PLOT:
+        plt.scatter([value[0] for value in values], 
+                    [value[1] for value in values], 
+                    c = ["red" if value in pareto_front else "blue" for value in values])
+
+        plt.xlabel("Joint Loss")
+        plt.ylabel("Positional Loss")
+
+        plt.show()
+
+    return best_model
+
+    
+
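+# Random-search NAS: each iteration trains a freshly sampled architecture for
+# every candidate learning rate, keeps the best run, and the final model is
+# picked from the Pareto front of joint and positional errors.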
+def run_nas(inputs, labels):
+    global LEARNING_RATE
+
+    HISTORY = []
+
+    lr_range = [1e-4, 5e-3, 1e-1]
+
+    for iter in tqdm(range(NAS_iter)):
+
+        best_model_iter = None
+        for initial_lr in lr_range:
+            model = create_model(inputs.shape[1], labels.shape[1])
+            
+            local_history = []
+            optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
+
+            for epoch in range(EPOCHS):
+                loss, _ = train(inputs, labels, model, optimizer)
+                local_history.append(loss)
+
+
+            min_idx = 0
+            for i in range(len(local_history)):
+                if local_history[i] <= local_history[min_idx]:
+                    min_idx = i
+
+            if best_model_iter is None or best_model_iter['min_loss'] >= local_history[min_idx]:
+                best_model_iter = {
+                    "lr": initial_lr,
+                    "min_loss": local_history[min_idx],
+                    "best_epoch": min_idx,
+                    "model": model
+                }
+
+        ### Evaluate the positional loss of the best model found in this iteration
+        predictions = predict(best_model_iter['model'], inputs)
+        pos_err = dataset_pos_error(predictions, inputs)
+
+        best_model_iter['pos_err'] = pos_err
+        HISTORY.append(best_model_iter)
+
+    best_model = optimal_pareto(HISTORY)
+    LEARNING_RATE = best_model['lr']
+    return best_model['model']
+
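+# Full training loop: builds the dataset tensors, optionally runs NAS to choose
+# the architecture, trains with Adam, tracks the best checkpoint and saves it.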
+def training_loop(dataset):
+    inputs, labels = process_dataset(dataset)
+
+    if RUN_NAS:
+        model = run_nas(inputs, labels)
+    else:
+        model = Net(len(inputs[0]), len(labels[0]))
+
+    print("\nTraining Hyperparameters:")
+    print(f"-> Model create with {len(inputs[0])} inputs and {len(labels[0])} outputs")
+    print(f"-> Training Epochs: {EPOCHS}")
+    print(f"-> Adam Optimizer (initial learning rate of {LEARNING_RATE})")
+    print(f"-> Cost Function based on {criterion} loss")
+    print(f"-> Scheduler: {SCHEDULER}\n")
+
+    optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
+    
+    best_model = None
+    history = {"train_loss": []}
+    
+    if VALIDATE:
+        history['valid_loss'] = []
+
+    for epoch in range(EPOCHS):
+
+        train_loss, val_loss = train(inputs, labels, model, optimizer)
+
+        if epoch % CHECKPOINT_LOG == 0:
+            print(f'Epoch {epoch} loss: {train_loss}')
+            print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
+            if VALIDATE:
+                print(f"Positional Validation Loss: {val_loss}")
+            print("-----------------------------------------\n")
+
+        if SCHEDULER:
+            SCHEDULER.step()
+
+        loss_tag = "valid_loss" if VALIDATE else "train_loss"
+        loss = train_loss if not VALIDATE else val_loss
+        if best_model is None or loss <= best_model[loss_tag]:
+            best_model = {
+                'epoch': epoch,
+                'train_loss': train_loss,
+                'model': model,
+                # 'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict()
+            }
+            if VALIDATE: 
+                best_model['valid_loss'] = val_loss
+         
+        history['train_loss'].append(train_loss)
+        
+        if VALIDATE: 
+            history['valid_loss'].append(val_loss)
+
+    best_model['history'] = history
+
+    print('Finished Training')
+    save_model(best_model)
+
+    return best_model
+
+
+if __name__ == '__main__':
+    rospy.init_node("training")
+    kinematics_pub = rospy.Publisher("evaluation", Joints, queue_size = 10)
+    rospy.sleep(1)
+
+    try:
+        if validate_params():
+            dataset = {}
+            with open(DATASET_PATH, "rb") as input_file:
+                dataset = pickle.load(input_file)
+                print(f"There are {len(dataset)} points in the dataset")
+            
+            print("Training starting...")
+            model_dict = training_loop(dataset)
+
+            if EVALUATE:
+                print("Evaluation starting...")
+                evaluate_model(model_dict, dataset)
+
+    except Exception as e:
+        rospy.logwarn(f"Exception thrown: {e}")
diff --git a/pose_estimation/scripts/training_data.py b/pose_estimation/scripts/training_data.py
new file mode 100755
index 0000000000000000000000000000000000000000..40f85697165cd71b7201eccd076bb913cdf87532
--- /dev/null
+++ b/pose_estimation/scripts/training_data.py
@@ -0,0 +1,416 @@
+import pickle
+import rospy
+import tf
+import cv2
+import math
+import os
+import threading
+
+import numpy as np
+import mediapipe as mp
+from scipy.spatial.transform import Rotation as R
+from tqdm import tqdm
+
+from tf.transformations import quaternion_matrix, translation_matrix
+from geometry_msgs.msg import Point32
+from geometry_msgs.msg import Pose
+from pose_estimation.msg import JointPositions
+from sensor_msgs.msg import PointCloud
+from open_manipulator_msgs.msg import KinematicsPose
+from open_manipulator_msgs.srv import SetKinematicsPose, SetJointPosition, SetJointPositionRequest
+
+SIMULATE = rospy.get_param('/training_data/simulate')
+VIDEO_SHOW = rospy.get_param('/training_data/show_video')
+
+FILENAME = rospy.get_param('/training_data/input_keypoint_file')
+TRAINING_FILE = rospy.get_param('/training_data/training_file')
+
+OUT_OF_RANGE_THRESHOLD = 0.1
+THRESHOLD = 2
+NOISE = 0.03
+
+
+input_file = os.path.expanduser(f"~/catkin_ws/src/pose_estimation/data/{FILENAME}")
+training_file = os.path.expanduser(f"~/catkin_ws/src/pose_estimation/data/{TRAINING_FILE}")
+
+mp_pose = mp.solutions.pose
+
+# Defining ROS Node to publish the messages
+frame = 'keypoint_frame'
+node = 'keypoint_capture'
+
+SHOULDER = mp_pose.PoseLandmark.RIGHT_SHOULDER.value
+pub_shoulder = rospy.Publisher(f"{node}/shoulder", PointCloud, queue_size=50)
+
+WRIST = mp_pose.PoseLandmark.RIGHT_WRIST.value
+pub_wrist = rospy.Publisher(f"{node}/wrist", PointCloud, queue_size=50)
+
+ELBOW = mp_pose.PoseLandmark.RIGHT_ELBOW.value
+pub_elbow = rospy.Publisher(f"{node}/elbow", PointCloud, queue_size=50)
+
+### Variables for moving the gripper
+END_EFFECTOR_NAME = "gripper"
+PATH_TIME = 1
+
+### Starting joint positions and angles
+START_LOCATION = (0.286, 0.0, 0.204)
+START_ANGLE = 0.0
+
+def shift_keypoints(original, offset):
+    return (np.array(original) - np.array(offset))
+
+def publish_point(keypoint, offset, timestamp, pub):
+    point = Point32()
+    shifted_keypoint = shift_keypoints(keypoint, offset)
+    point.x, point.y, point.z = shifted_keypoint
+    point.z = -point.z
+
+    msg = PointCloud()
+    msg.header.frame_id = frame
+    msg.header.stamp = timestamp
+    msg.points = [point]
+
+    pub.publish(msg)
+
+    return point
+
+def set_manipulator(manipulator_pos):
+    rospy.wait_for_service('goal_task_space_path_position_only', 2)
+    
+    set_kinematics_pose = rospy.ServiceProxy('goal_task_space_path_position_only', SetKinematicsPose)
+
+    pose = Pose()
+    pose.position.x = manipulator_pos[0]
+    pose.position.y = manipulator_pos[1]
+    pose.position.z = manipulator_pos[2]
+
+    kinematics_pose = KinematicsPose()
+    kinematics_pose.pose = pose
+
+    resp = set_kinematics_pose(planning_group = 'arm', end_effector_name = END_EFFECTOR_NAME, kinematics_pose = kinematics_pose, path_time = PATH_TIME)
+    
+    if resp.is_planned:
+        rospy.sleep(PATH_TIME)
+    
+    return resp.is_planned
+
+def set_wrist_angle(angle, starting = False):
+    if starting:
+        rospy.wait_for_service('goal_joint_space_path', 2)
+        set_joints = rospy.ServiceProxy('goal_joint_space_path', SetJointPosition)
+    else: 
+        rospy.wait_for_service('goal_joint_space_path_from_present', 2)
+        set_joints = rospy.ServiceProxy('goal_joint_space_path_from_present', SetJointPosition)
+
+    goal_request = SetJointPositionRequest()
+    goal_request.joint_position.joint_name = ["joint1", "joint2", "joint3", "joint4"]
+    goal_request.joint_position.position = [0.0, 0.0, 0.0, angle]
+    goal_request.joint_position.max_accelerations_scaling_factor = 0.0
+    goal_request.joint_position.max_velocity_scaling_factor = 0.0
+    goal_request.path_time = PATH_TIME
+    resp = set_joints(goal_request)
+
+    rospy.sleep(PATH_TIME)
+
+    if not resp.is_planned:
+        rospy.logerr("Failed to solve IK equation for wrist")
+
+    return resp.is_planned
+
+def point_to_np(point):
+    return np.array([point.x, point.y, point.z])
+
+def np_to_point(arr):
+    point = Point32()
+    point.x, point.y, point.z = arr[0], arr[1], arr[2]
+    return point
+
+def homogenous_transformation(old_position, transform_matrix):
+    # Converting the position to numpy array and appending one to it to make it a 4x1 vector
+    old_position_vect = point_to_np(old_position)
+    old_position_vect = np.append(old_position_vect, [1])
+    old_position_vect = old_position_vect.reshape((4,1))
+
+    # Multiplying the old point with the transform matrix to get the point in the other coordinate frame
+    transformed_point = np.matmul(transform_matrix, old_position_vect)
+
+    return np_to_point(transformed_point)
+
+def create_matrix(rot, trans):
+    # Converting the quaternion and position into matrices
+    rot_matrix = quaternion_matrix(rot)
+    trans_matrix = translation_matrix(trans)
+
+    # Combining the two matrices to create a 4x4 homogenous transformation matrix
+    rot_matrix = rot_matrix[:,:3]
+    trans_matrix = np.array([trans_matrix[:,3]])
+    transform_matrix = np.concatenate((rot_matrix, trans_matrix.T), axis=1)
+
+    return transform_matrix
+
+def broadcast(point, node_from, node_to, time, orientation = (0,0,0)):
+    br = tf.TransformBroadcaster()
+
+    br.sendTransform(
+        (point.x, point.y, point.z),
+        tf.transformations.quaternion_from_euler(orientation[0], orientation[1], orientation[2]),
+        time,
+        node_from,
+        node_to
+    )
+
+def tf_listener(old_frame, new_frame, time):
+    trans, rot = None, None
+    listener = tf.TransformListener()
+
+    try: 
+        listener.waitForTransform(new_frame, old_frame, time, rospy.Duration(1.5))
+        (trans, rot) = listener.lookupTransform(new_frame, old_frame,  time)
+    except Exception as e:
+        rospy.logwarn(f'Exception: {e}')
+
+    return trans, rot
+
+def create_transform_matrix(trans, rot):
+    if not trans or not rot:
+        return None
+    
+    return create_matrix(rot, trans)
+
+def show_video(video):
+    # Displaying the captured video frames 
+    cv2.imshow("Test video", video)
+
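+# Moves the arm towards the desired location, falling back to the midpoint
+# between the current and desired positions (and backtracking towards the base)
+# whenever the IK solver fails, until it succeeds or the failure threshold is hit.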
+def iterative_movement(current_loc, desired_loc):
+    # Checks how many times the IK failed to be solved
+    failure = 0
+    backtracked = False
+    # Iterating while we reach the set limit of failures
+    while failure < (THRESHOLD+1): 
+        # If the previous movement was successful, 
+        # we attempt to move the arm to the desired location
+        if failure == 0 or backtracked:
+            # Trying to move to the desired location
+            success = set_manipulator(desired_loc)
+            
+            # If desired location reached, finish
+            if success: break
+
+        # Calculating the midpoint between the desired location
+        # and current location, and then trying to move there
+        midpoint = (desired_loc + current_loc) / 2.
+        success = set_manipulator(midpoint)
+
+        # If failed to move to the midpoint, increase failure count
+        if not success:
+            failure += 1
+            
+            # Trying to move out, in case it is stuck
+            if failure == THRESHOLD:
+                
+                if backtracked: break
+
+                midpoint = (current_loc * 0.9)
+                success = set_manipulator(midpoint)
+
+                if not success:
+                    failure += 1
+                else:
+                    backtracked = True
+                    failure -= 1
+
+        # Else change the current location to the midpoint and reset failures
+        if success:
+            current_loc = midpoint
+            failure = 0
+    
+    # If moved to the desired location return True and the current location
+    if failure == 0:
+        rospy.logdebug(f"-> *** Reached Desired Location")
+        return True, desired_loc
+    
+    # Else return False and the closest possible location
+    rospy.logdebug(f"-> Failed to Reach the Destination - Final Location: {current_loc}\n")
+    return False, current_loc
+
+def euclid_dist(vec1, vec2):
+    if vec1.size != vec2.size:
+        print("Size of the two points doesn't match")
+        return None
+
+    sum_ = 0
+    for i in range(vec1.size):
+        squared_difference = (vec1[i] - vec2[i])**2
+        sum_ += squared_difference
+
+    return math.sqrt(sum_)
+
+def range_convert(old_range, new_range, old_value):
+    old_diff = (old_range[1] - old_range[0])
+    new_diff = (new_range[1] - new_range[0])
+
+    new_value = (((old_value - old_range[0]) * new_diff) / old_diff) + new_range[0]
+    return new_value
+
+### Expressing a point given in frame1 in the coordinates of frame2 (rotation can be ignored with rotation=False)
+def get_relative_point(frame1, frame2, ros_time, point, rotation = True):
+
+    trans, rot = tf_listener(frame1, frame2, ros_time)
+
+    if rotation == False:
+        rot = [0.0, 0.0, 0.0, 1.0]
+
+    transform = create_transform_matrix(trans, rot)
+
+    if transform is None:
+        return None
+    
+    return homogenous_transformation(point, transform)
+
+def save_point(dataset, input_point, output_point):
+    dataset['x'].append(input_point)
+    dataset['y'].append(output_point)
+    return dataset
+
+def verify_inputs():
+    verify = True
+
+    if type(SIMULATE) != bool or type(VIDEO_SHOW) != bool:
+        rospy.logfatal("'simulate' and 'video_show' arguments must be booleans")
+        verify = False
+
+    if not os.path.exists(input_file):
+        rospy.logfatal(f"The following file '{input_file}' does not exist")
+        verify = False
+
+    return verify
+
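+# Replays the recorded keypoints: each wrist position is expressed relative to
+# the shoulder, scaled by the calibration factor, converted to a manipulator
+# pose and published to the IK solver, whose response provides the labels.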
+def data_gather():
+    # Reading the pickle file and saving its content as a dictionary
+    ### TODO: Function
+    rospy.loginfo("Reading the file...")
+    with open(input_file, 'rb') as handle:
+        captured_keypoints = pickle.load(handle)
+    rospy.loginfo("File read successfully")
+    
+    kinematics_pub = rospy.Publisher('captured_keypoints', KinematicsPose, queue_size = 10)
+
+    for idx in captured_keypoints.keys():
+        rospy.loginfo(f"Processing video with {idx}")
+        points = captured_keypoints[idx]['points']
+
+
+        # TODO: Correct calibration ? Noise?
+        calib = 1.0 - captured_keypoints[idx]['calib'] + 0.025
+
+        for point_idx, point in enumerate(points):
+            if VIDEO_SHOW:
+                if cv2.waitKey(1) & 0xFF == ord('q'):
+                    rospy.logwarn("Closing the video")
+                    return
+                
+                video = captured_keypoints[idx]['video'][point_idx]
+                show_video(video)
+
+
+            if rospy.is_shutdown(): return
+    
+            angle = point['angle']
+
+            broadcast_time_1 = rospy.Time.now()
+            shoulder_point = publish_point(point[SHOULDER], point[SHOULDER], broadcast_time_1, pub_shoulder)
+            wrist_point = publish_point(point[WRIST], point[SHOULDER], broadcast_time_1, pub_wrist)
+            broadcast(shoulder_point, f"{node}/shoulder", frame, broadcast_time_1, orientation=(np.pi, 0.0, 0.0))
+            relative_wrist = get_relative_point(frame, f"{node}/shoulder", broadcast_time_1, wrist_point)
+            
+
+            if relative_wrist is None:
+                continue
+        
+            broadcast_time_2 = rospy.Time.now()
+            broadcast(relative_wrist, f"{node}/wrist", f"{node}/shoulder", broadcast_time_2)
+            manipulator_position = np_to_point(point_to_np(relative_wrist) * calib)
+            relative_manipulator = get_relative_point(f"{node}/shoulder", f"{node}/wrist", broadcast_time_2, manipulator_position)
+            elbow_point = publish_point(point[ELBOW], point[WRIST], broadcast_time_2, pub_elbow)
+
+            if relative_manipulator is None:
+                continue
+            
+            broadcast_time_3 = rospy.Time.now()
+            broadcast(relative_manipulator, "/gripper/kinematics_pose", f"{node}/wrist", broadcast_time_3)
+            broadcast(np_to_point(np.zeros((3,))), "/gripper/wrist", "/gripper/kinematics_pose", broadcast_time_3, orientation=(np.pi, angle, np.pi))
+
+            # Getting the desired location
+            desired_loc = point_to_np(relative_manipulator).flatten()
+
+            pose = Pose() 
+            pose.position.x = desired_loc[0]
+            pose.position.y = desired_loc[1] 
+            pose.position.z = desired_loc[2]
+            
+
+            orientation = tf.transformations.quaternion_from_euler(0, angle, 0)
+            pose.orientation.x = orientation[0]
+            pose.orientation.y = orientation[1]
+            pose.orientation.z = orientation[2]
+            pose.orientation.w = orientation[3]
+
+
+            kinematics_pose = KinematicsPose()
+            kinematics_pose.pose = pose
+
+            # Acquire the lock before publishing; it is only released in process_joint_positions once the IK response arrives
+            lock.acquire()
+            kinematics_pub.publish(kinematics_pose)
+            ###########################################
+
+def vector_angle(v1, v2):
+    """ Angle between two 2D vectors """
+    unit_v1 = v1 / np.linalg.norm(v1)
+    unit_v2 = v2 / np.linalg.norm(v2)
+    dot_product = np.dot(unit_v1, unit_v2)
+    angle = np.arccos(dot_product)
+
+    if v2[1] < 0:
+        angle = -angle
+
+    return angle
+
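+# Callback for the IK solver response: stores the joint angles together with the
+# manipulator pose (y is replaced by the planar angle of the target) and releases
+# the lock so the next keypoint can be published.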
+def process_joint_positions(joint_position):
+    save = joint_position.success or len(joint_position.jointPositions) != 0
+    
+    if save:
+        position = joint_position.manipulatorPose.pose.position
+        position.y = vector_angle(np.array([1,0]), np.array([position.x, position.y]))
+
+        TRAINING_DATA.append({
+            "jointPositions": joint_position.jointPositions,
+            "manipulatorPositions": position,
+            "angle": joint_position.manipulatorPose.pose.orientation
+        })
+    else:
+        rospy.logerr(f"Point out of range: {joint_position.manipulatorPose.pose.position}")  
+    lock.release()
+
+
+if __name__ == '__main__':
+    lock = threading.Lock()
+    TRAINING_DATA = []
+
+    rospy.init_node(node)
+    rospy.Subscriber("inverse_kinematics_keypoints", JointPositions, process_joint_positions)
+
+    try:  
+        # rospy.loginfo("Sleeping for 1 second")
+        # rospy.sleep(1)
+        if verify_inputs(): data_gather()
+        
+    except rospy.ROSInterruptException as e:
+        rospy.logwarn(f"Exception: {e}")
+    
+    lock.acquire()
+    rospy.loginfo("Saving training dataset")
+    with open(training_file, 'wb') as handle:
+        pickle.dump(TRAINING_DATA, handle, protocol=pickle.HIGHEST_PROTOCOL)
+    lock.release()
+
diff --git a/pose_estimation/scripts/visualize_keypoints.py b/pose_estimation/scripts/visualize_keypoints.py
deleted file mode 100644
index 73ee6dcecd6b526d65999c16e4756dc3bce9302a..0000000000000000000000000000000000000000
--- a/pose_estimation/scripts/visualize_keypoints.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Distance of the Manipulator -> 0.36
-# Max distance of the Arm -> 0.53
-# Scale = 0.36/0.53
-
-import pickle
-import time
-import rospy
-import tf
-
-import numpy as np
-import mediapipe as mp
-
-from tf.transformations import quaternion_matrix, translation_matrix
-
-from geometry_msgs.msg import Point32
-from geometry_msgs.msg import Pose
-from sensor_msgs.msg import PointCloud
-from open_manipulator_msgs.msg import KinematicsPose
-from open_manipulator_msgs.srv import SetKinematicsPose
-
-mp_pose = mp.solutions.pose
-SHOULDER = mp_pose.PoseLandmark.RIGHT_SHOULDER.value
-WRIST = mp_pose.PoseLandmark.RIGHT_WRIST.value
-
-# Pickle file with the keypoints
-FILENAME = 'keypoints.pickle'
-
-# Defining ROS Node to publish the messages
-frame = 'keypoint_frame'
-node = 'keypoint_capture'
-
-pub_shoulder = rospy.Publisher(f"{node}/shoulder", PointCloud, queue_size=50)
-pub_wrist = rospy.Publisher(f"{node}/wrist", PointCloud, queue_size=50)
-
-
-# TODO: Remove if not using
-manipulator_state = None
-
-def joint_manipulation(joint_state):
-    global manipulator_state
-    manipulator_state = joint_state
-
-
-def shift_keypoints(original, offset):
-    return (np.array(original) - np.array(offset))
-
-def publish_point(keypoint, offset, timestamp, pub):
-    point = Point32()
-    shifted_keypoint = shift_keypoints(keypoint, offset)
-    point.x, point.y, point.z = shifted_keypoint
-    point.z = -point.z
-
-    msg = PointCloud()
-    msg.header.frame_id = frame
-    msg.header.stamp = timestamp
-    msg.points = [point]
-
-    pub.publish(msg)
-
-    return point
-
-
-def broadcast(point, node_from, node_to, time, orientation = (0,0,0)):
-    br = tf.TransformBroadcaster()
-
-    br.sendTransform((point.x, point.y, point.z),
-                      tf.transformations.quaternion_from_euler(orientation[0], orientation[1], orientation[2]),
-                      time,
-                      node_from,
-                      node_to)
-
-def point_to_np(point):
-    return np.array([point.x, point.y, point.z])
-
-def np_to_point(arr):
-    point = Point32()
-    point.x, point.y, point.z = arr[0], arr[1], arr[2]
-    return point
-
-def homogenous_transformation(old_position, transform_matrix):
-    # Converting the position to numpy array and appending one to it to make it a 4x1 vector
-    old_position_vect = point_to_np(old_position)
-    old_position_vect = np.append(old_position_vect, [1])
-    old_position_vect = old_position_vect.reshape((4,1))
-
-    # Multiplying the old point with the transform matrix to get the point in the other coordinate frame
-    transformed_point = np.matmul(transform_matrix, old_position_vect)
-
-    return np_to_point(transformed_point)
-
-def create_matrix(rot, trans):
-    # Converting the quaternion and position into matrices
-    rot_matrix = quaternion_matrix(rot)
-    trans_matrix = translation_matrix(trans)
-
-    # Combining the two matrices to create a 4x4 homogenous transformation matrix
-    rot_matrix = rot_matrix[:,:3]
-    trans_matrix = np.array([trans_matrix[:,3]])
-    transform_matrix = np.concatenate((rot_matrix, trans_matrix.T), axis=1)
-
-    return transform_matrix
-
-
-def create_transform_matrix(old_frame, new_frame, time):
-    transform = None
-    listener = tf.TransformListener()
-    
-    # TODO: Why does it fail at times? Time?
-    try: 
-        listener.waitForTransform(new_frame, old_frame, time, rospy.Duration(1.0))
-        (trans, rot) = listener.lookupTransform(new_frame, old_frame,  time)
-
-        transform = create_matrix(rot, trans)
-    except Exception as e:
-        rospy.logwarn(e)
-        
-    
-    return transform
-
-def visualization():
-    gripper_topic = "/gripper/kinematics_pose"
-    print(f"Subscribing: {gripper_topic}")
-
-    rospy.Subscriber(gripper_topic, KinematicsPose, joint_manipulation, queue_size = 1)
-    
-    while not manipulator_state:
-        continue
-    
-    print("Initial manipulator location has been received")
-    
-    # Reading the pickle file and saving its content as a dictionary
-    with open(FILENAME, 'rb') as handle:
-        captured_keypoints = pickle.load(handle)
-
-    # Iterating through all of the frames with keypoints
-    for timestamp, keypoints in captured_keypoints.items():
-        if rospy.is_shutdown():
-            break
-        
-        if not keypoints:
-            continue
-
-        if SHOULDER not in keypoints or WRIST not in keypoints:
-            continue
-
-        ros_time = rospy.Time.now()
-
-        timestamp = rospy.Time.now()
-        offset = keypoints[SHOULDER]
-
-        shoulder = keypoints[SHOULDER]
-        shoulder_point = publish_point(shoulder, offset, timestamp, pub_shoulder)
-
-        wrist = keypoints[WRIST]
-        wrist_point = publish_point(wrist, offset, timestamp, pub_wrist)
-
-        broadcast(shoulder_point, f"{node}/shoulder", frame, ros_time, orientation=(0, 0, np.pi))
-        broadcast(wrist_point, f"{node}/wrist", f"{node}/shoulder", ros_time)
-
-        # TODO: Rotation; Scale (Calibration)
-        manipulator_position = np_to_point((point_to_np(wrist_point) * (0.36/0.53)))
-
-        # TODO:
-        # manipulator_position = manipulator_state.pose.position
-        transform = create_transform_matrix(f"{node}/shoulder", f"{node}/wrist", ros_time)
-        
-        if transform is None:
-            continue
-
-        relative_manipulator = homogenous_transformation(manipulator_position, transform, )
-        broadcast(relative_manipulator, gripper_topic, f"{node}/wrist", ros_time)
-
-
-        # Sleep
-        time.sleep(0.5)
-
-
-if __name__ == '__main__':
-    rospy.init_node(node)
-
-    try:
-        visualization()
-    except rospy.ROSInterruptException as e:
-        print("Exception: " + e)
diff --git a/pose_estimation/setup.py b/pose_estimation/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d628ffed4cacf60d810542fead83197319bf550d
--- /dev/null
+++ b/pose_estimation/setup.py
@@ -0,0 +1,10 @@
+## ! DO NOT MANUALLY INVOKE THIS setup.py, USE CATKIN INSTEAD
+from distutils.core import setup
+from catkin_pkg.python_setup import generate_distutils_setup
+
+# fetch values from package.xml
+setup_args = generate_distutils_setup(
+    packages=['helpers'],
+    package_dir={'': 'scripts'},
+)
+setup(**setup_args)
diff --git a/pose_estimation/src/kinematics_solver.cpp b/pose_estimation/src/kinematics_solver.cpp
new file mode 100755
index 0000000000000000000000000000000000000000..b7e1a6db372c0218798a4cb633010bc6ca6e4107
--- /dev/null
+++ b/pose_estimation/src/kinematics_solver.cpp
@@ -0,0 +1,102 @@
+
+#include "pose_estimation/kinematics_solver.h"
+
+KinematicsSolver::KinematicsSolver(bool using_platform, std::string usb_port, std::string baud_rate, double control_period)
+{
+    log::info("Setting up the IK Solver for Open Manipulator");
+
+    open_manipulator_.initOpenManipulator(using_platform, usb_port, baud_rate, control_period);
+    goal_joint_value_ = new std::vector<JointValue>();
+
+    kinematics_ = new kinematics::SolverUsingCRAndSRPositionOnlyJacobian();
+    open_manipulator_.addKinematics(kinematics_);
+    log::info("Kinematics Solver Set 'SolverUsingCRandSRPoisionOnlyJacobian'");
+
+    ik_pub_ = n_.advertise<pose_estimation::JointPositions>("inverse_kinematics_keypoints", 1000);
+    fk_pub_ = n_.advertise<open_manipulator_msgs::KinematicsPose>("forward_kinematics_keypoints", 1000);
+
+    log::info("Completed setting up the Kinematics Solver");
+}
+
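+// Solves inverse kinematics for the target pose and publishes the resulting
+// joint angles, together with the original manipulator pose, on
+// "inverse_kinematics_keypoints".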
+void KinematicsSolver::solveIK(Pose target_pose, const open_manipulator_msgs::KinematicsPose &manipulator_pose)
+{
+    pose_estimation::JointPositions msg;
+    bool solved = open_manipulator_.solveInverseKinematics("gripper", target_pose, goal_joint_value_);
+
+    int idx = 0;
+    auto names = open_manipulator_.getManipulator()->getAllActiveJointComponentName();
+
+    std::cout << "Computed Joints:" << std::endl;
+
+    for (auto &point : *goal_joint_value_)
+    {
+        pose_estimation::JointAngle joint;
+        joint.angle = point.position;
+        joint.name = names.at(idx);
+        msg.jointPositions.push_back(joint);
+        
+        std::cout << joint.angle << std::endl;
+
+        idx++;
+    }
+
+    std::cout << "--------------------------" << std::endl;
+
+    msg.success = solved;
+
+    msg.manipulatorPose = manipulator_pose;
+    ik_pub_.publish(msg);
+    log::info("Published the point");
+}
+
+void KinematicsSolver::keypointsInverseCallback(const open_manipulator_msgs::KinematicsPose &msg)
+{
+    Eigen::Vector3d position;
+    position[0] = msg.pose.position.x;
+    position[1] = 0.0; // TODO: msg.pose.position.y; ??
+    position[2] = msg.pose.position.z;
+
+    std::cout << "Received a position:" << std::endl;
+    std::cout << msg.pose.position.x << std::endl;
+    std::cout << msg.pose.position.y << std::endl;
+    std::cout << msg.pose.position.z << std::endl;
+
+    Pose target_pose = {position};
+    solveIK(target_pose, msg);
+}
+
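+// Applies the given joint angles to the manipulator model, solves forward
+// kinematics and publishes the resulting gripper position on
+// "forward_kinematics_keypoints".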
+void KinematicsSolver::solveFK(const pose_estimation::Joints &msg)
+{
+    open_manipulator_msgs::KinematicsPose pub_msg;
+
+    open_manipulator_.getManipulator()->setAllActiveJointPosition(msg.angles);
+    open_manipulator_.solveForwardKinematics();
+
+    auto gripper_position = open_manipulator_.getPose("gripper");
+    pub_msg.pose.position.x = gripper_position.kinematic.position[0];
+    pub_msg.pose.position.y = gripper_position.kinematic.position[1];
+    pub_msg.pose.position.z = gripper_position.kinematic.position[2];
+
+    fk_pub_.publish(pub_msg);
+}
+
+void KinematicsSolver::keypointsForwardCallback(const pose_estimation::Joints &msg)
+{
+    ros::Duration(0.1).sleep();
+    solveFK(msg);
+}
+
+int main(int argc, char **argv)
+{
+    ros::init(argc, argv, "kinematics_solver");
+
+    KinematicsSolver kinematicsSolver(false, "/dev/ttyUSB0", "1000000", 0.010);
+    auto n = kinematicsSolver.getNodeHandle();
+    
+    ros::Duration(1.0).sleep();
+    ros::Subscriber sub_fk = n.subscribe("/evaluation", 100, &KinematicsSolver::keypointsForwardCallback, &kinematicsSolver);
+    ros::Subscriber sub_ik = n.subscribe("/captured_keypoints", 100, &KinematicsSolver::keypointsInverseCallback, &kinematicsSolver);
+
+    ros::spin();
+    return 0;
+}
\ No newline at end of file