Raspberry Pi Notes

Companion Computer Setup for Flight Controller 2024.pdf

Raspberry-Pi-5-Pinout--1210x642.jpg

Wiring RPI 5 x Pixhawk

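The wiring images are not reproduced here. For reference only (an assumption to verify against the pinout image above): Pixhawk TELEM2 TX → Pi GPIO15/RXD (physical pin 10), TELEM2 RX → Pi GPIO14/TXD (physical pin 8), GND → GND. This UART is the /dev/ttyAMA0 device used in the MAVProxy command below.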

# HuskyLens-based tracking sketch: huskyLens and send_ned_velocity are assumed to be
# initialised elsewhere; the 155-165 / 115-125 bands bracket the centre (160, 120)
# of the HuskyLens 320x240 frame.
def tracking():
    while True:
        objects = huskyLens.requestAll()
        for item in objects:
            x = item['x']
            y = item['y']
            
            # Initialize velocity components
            vx = 0
            vy = 0
            vz = 0

            # Adjust horizontal velocity based on x position
            if x > 165:
                vy = 1  # Move RIGHT
            elif x < 155:
                vy = -1  # Move LEFT

            # Adjust forward/backward velocity based on y position
            if y < 115:
                vx = 1  # Move FORWARD
            elif y > 125:
                vx = -1  # Move BACKWARD

            # If the object is within the desired range, stop the vehicle
            if 155 <= x <= 165 and 115 <= y <= 125:
                vx = 0
                vy = 0
                print("Target reached")

            # Send the velocity command
            send_ned_velocity(vx, vy, vz)

Update dronekit __init__.py: on Python 3.10+ collections.MutableMapping no longer exists, so the reference must be changed to collections.abc.MutableMapping.
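A sketch of the usual one-line fix (the exact line differs between DroneKit versions):

# in dronekit/__init__.py
# before:
class Parameters(collections.MutableMapping, HasObservers):
# after:
class Parameters(collections.abc.MutableMapping, HasObservers):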

Drone Python File

mavproxy.py --master=/dev/ttyAMA0 --baudrate 921600 --out 127.0.0.1:14550
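This bridges the Pixhawk on the Pi UART (/dev/ttyAMA0) to a local UDP endpoint, which is the 127.0.0.1:14550 address the Python scripts below connect to. MAVProxy accepts multiple --out flags, so a variant that also forwards telemetry to a ground station (the 192.168.1.10 address is a placeholder) could be:

mavproxy.py --master=/dev/ttyAMA0 --baudrate 921600 --out 127.0.0.1:14550 --out 192.168.1.10:14551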

Connect

SITL Simulator
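For bench testing without the real flight controller, a typical ArduPilot SITL launch (assuming ArduPilot and sim_vehicle.py are installed; options may vary by version) exposes the same UDP endpoint the scripts connect to:

sim_vehicle.py -v ArduCopter --console --map --out 127.0.0.1:14550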

SITL with Gazebo

Records

2024-09-12 15-49-10.mp4

2024-09-12 16-28-37.mp4

2024-09-12 16-00-33.mp4

TRACKING_VICTIM

Version 1

# Description: This code is used to detect the victim using 
# the TFLite model and stream the video with the bounding box 
# around the victim. The drone will follow the victim by
# moving towards the victim's center.

## IMPORT LIBRARY
# FLASK SERVER TO STREAM VIDEO
from flask import Flask, Response, render_template_string
# LIBRARY TO RECORD VIDEO
from picamera2 import Picamera2, Preview
from libcamera import controls
from PIL import Image
# LIBRARY TO GIVE COORDINATES
import numpy as np
# LIBRARY TO DRAW BOUNDING BOXES
import cv2
# LIBRARY TO LOAD TFLITE MODEL
import tflite_runtime.interpreter as tflite
# DRONEKIT LIBRARY TO CONNECT TO DRONE
from dronekit import connect, VehicleMode, LocationGlobalRelative
from pymavlink import mavutil
# OTHER NECESSARY LIBRARIES
import logging
import os
import time
import signal
import sys

# CONNECT TO DRONE VEHICLE
# The UDP endpoint must match the MAVProxy --out address; baud only matters for direct serial links
vehicle = connect('127.0.0.1:14550', baud=921600, wait_ready=True)
# Precision-landing parameters: enable PLND, source = companion computer, raw-sensor estimator
vehicle.parameters.set('PLND_ENABLED', 1)
vehicle.parameters.set('PLND_TYPE', 1)
vehicle.parameters.set('PLND_EST_TYPE', 0)

# ARM AND TAKEOFF FUNCTION
def arm_and_takeoff(altitude):

   while not vehicle.is_armable:
      print("waiting to be armable")
      time.sleep(1)

   print("Arming motors")
   vehicle.mode = VehicleMode("GUIDED")
   vehicle.armed = True

   while not vehicle.armed: time.sleep(1)

   print("Taking Off")
   vehicle.simple_takeoff(altitude)

   while True:
      v_alt = vehicle.location.global_relative_frame.alt
      print(">> Altitude = %.1f m"%v_alt)
      if v_alt >= altitude - 1.0:
          print("Target altitude reached")
          break
      time.sleep(1)

# FUNCTION TO MOVE THE DRONE AT VELOCITY
def send_ned_velocity(vehicle, vx, vy, vz, duration, MV_Status):
    vehicle.mode = VehicleMode("GUIDED")
    msg = vehicle.message_factory.set_position_target_local_ned_encode(
        0,                  # time_boot_ms (not used)
        0, 0,               # target system, target component
        9,                  # frame: MAV_FRAME_BODY_OFFSET_NED
        1479,               # type_mask: only velocity (and yaw rate) enabled
        0, 0, 0,            # POSITION m
        vx, vy, vz,         # VELOCITY m/s
        0, 0, 0,            # ACCELERATIONS m/s^2
        0, 0                # yaw, yaw_rate
    )
    for x in range(0, duration):
        vehicle.send_mavlink(msg)
        print(MV_Status)
        time.sleep(1)

def condition_yaw(vehicle, radian, duration, MV_Status):
    vehicle.mode = VehicleMode("GUIDED")
    msg = vehicle.message_factory.set_position_target_local_ned_encode(
        0,                  # time_boot_ms (not used)
        0, 0,               # target system, target component
        9,                  # frame: MAV_FRAME_BODY_OFFSET_NED
        2503,               # type_mask: only velocity and yaw enabled
        0, 0, 0,            # POSITION m
        0, 0, 0,            # VELOCITY m/s
        0, 0, 0,            # ACCELERATIONS m/s^2
        radian, 0           # yaw (rad), yaw_rate
    )
    for x in range(0, duration):
        vehicle.send_mavlink(msg)
        print(MV_Status)
        time.sleep(1)

def square_movement():
    """Move the drone in a square pattern."""
    print("Executing square movement pattern.")
    for _ in range(4):  # Repeat 4 times for a square
        # Move forward
        send_ned_velocity(vehicle, 1, 0, 0, duration=3, MV_Status="MOVING FORWARD")  # Move forward for 3 seconds
        # Turn right (yaw)
        radian = 1.5708
        condition_yaw(vehicle, radian, 1, MV_Status="YAW TO THE RIGHT")  # Rotate by 90 degrees TO RIGHT
        time.sleep(5)  # Give the yaw time to complete

def close_drone_connection(signum, frame):
    """Handles closing the drone connection gracefully."""
    vehicle.mode = VehicleMode("RTL")
    print("Closing drone connection...")
    if vehicle.armed:
        print("Disarming vehicle...")
        vehicle.armed = False
        while vehicle.armed:
            time.sleep(1)
    vehicle.close()
    print("Drone connection closed.")
    sys.exit(0)

app = Flask(__name__)

# Load TensorFlow Lite model with specified number of threads
print("AI MODEL IS INITIATING")
model_path = "/home/sedna/Downloads/PROJECT_IRIS_V1/IRIS.tflite" # DIRECTORY PATH OF TFLITE MODELS
label_path = "/home/sedna/Downloads/PROJECT_IRIS_V1/labelmap.txt" # DIRECTORY PATH OF LABELMAP FILE
num_threads = 4  # Modify this value to control the number of threads
print("Loading model:", model_path)
print("Using", num_threads, "threads")
print("Loading labels:", label_path)
print("Using", num_threads, "threads")

if not os.path.isfile(model_path):
    print("Model not found!")
    sys.exit(1)
elif not os.path.isfile(label_path):
    print("Label map not found!")
    sys.exit(1)
else:
    print("AI MODEL IS SUCCESSFULLY INITIATED")

# Use 'Interpreter' with 'num_threads' option
interpreter = tflite.Interpreter(model_path=model_path, num_threads=num_threads)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Load labels
with open(label_path, 'r') as f:
    labels = [line.strip() for line in f.readlines()]

# Setup PiCamera2 with autofocus enabled
print("CAMERA IS INITIATING")
picam2 = Picamera2()
picam2.set_controls({"AfMode":controls.AfModeEnum.Continuous,"AfSpeed":controls.AfSpeedEnum.Fast,"AfRange":controls.AfRangeEnum.Full,"AeExposureMode":controls.AeExposureModeEnum.Short})
camera_config = picam2.create_preview_configuration(main={"format": "RGB888", "size": (1280, 720)})
picam2.configure(camera_config)
picam2.start()
print("CAMERA IS SUCCESSFULLY INITIATED")

# ARM AND TAKEOFF THE DRONE
print("DRONE IS ARMING AND TAKING OFF")
arm_and_takeoff(10)
vehicle.flush()
'''
### THIS CODE IS READY FOR MOVING TO GPS POINT
print("Going towards first point for 30 seconds ...")
point1 = LocationGlobalRelative(5.147264589053249, 100.49392067963265, 20)
vehicle.simple_goto(point1)
'''
# Register the signal handler for keyboard interrupt
signal.signal(signal.SIGINT, close_drone_connection)

# Set verbosity
verbose = True  # Change this to False to reduce logging output
if verbose:
    logging.basicConfig(level=logging.INFO)
else:
    logging.basicConfig(level=logging.WARNING)

def preprocess_image(image):
    """Resize and preprocess image for the TFLite model."""
    # Ensure the input shape matches the model's input dimensions
    input_shape = input_details[0]['shape'][1:3]  # [300, 300]

    # Convert the image to RGB format (Pillow's default is RGB)
    image = Image.fromarray(image).convert('RGB')

    # Resize the image to the expected input size
    image = image.resize(input_shape, Image.LANCZOS)  # Pillow 10 uses LANCZOS for high-quality downsampling

    # Convert image to numpy array and expand dimensions to match [1, 300, 300, 3]
    input_tensor = np.expand_dims(np.array(image), axis=0)

    # Ensure the image is in uint8 format as required by the model
    #input_tensor = input_tensor.astype(np.uint8)
    input_tensor = input_tensor.astype(np.float32)

    return input_tensor

# Configure logging
logging.basicConfig(level=logging.DEBUG)

def detect_objects(image, overlap_thresh=0.5):
    """Run object detection on the input image with thresholding and NMS."""
    input_tensor = preprocess_image(image)
    interpreter.set_tensor(input_details[0]['index'], input_tensor)
    interpreter.invoke()
    
    # Check output layer name to determine if this model was created with TF2 or TF1,
    # because outputs are ordered differently for TF2 and TF1 models
    outname = output_details[0]['name']

    if ('StatefulPartitionedCall' in outname):  # This is a TF2 model
        boxes_idx, classes_idx, scores_idx = 1, 3, 0
    else:  # This is a TF1 model
        boxes_idx, classes_idx, scores_idx = 0, 1, 2

    # Retrieve and copy the output tensors immediately to avoid the reference issue
    boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0].copy()  # Bounding boxes
    classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0].copy()  # Class IDs
    scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0].copy()  # Confidence scores
    
    # Set a confidence threshold
    confidence_threshold = 0.3
    selected_indices = np.where(scores > confidence_threshold)

    # If no objects exceed the confidence threshold, skip detection
    if len(selected_indices[0]) == 0:
        return [], [], []

    # Filter boxes, scores, and classes
    boxes = boxes[selected_indices]
    scores = scores[selected_indices]
    classes = classes[selected_indices]

    # Convert bounding boxes to pixel coordinates
    bounding_boxes = []
    for box in boxes:
        ymin, xmin, ymax, xmax = box
        x = int(xmin * image.shape[1])
        y = int(ymin * image.shape[0])
        width = int((xmax - xmin) * image.shape[1])
        height = int((ymax - ymin) * image.shape[0])
        bounding_boxes.append([x, y, width, height])

    # Apply NMS
    indices = cv2.dnn.NMSBoxes(
        bounding_boxes, scores.tolist(),
        confidence_threshold, overlap_thresh
    )
    if len(indices) > 0:
        indices = indices.flatten()
        bounding_boxes = np.array([bounding_boxes[i] for i in indices])
        scores = np.array([scores[i] for i in indices])
        classes = np.array([classes[i] for i in indices])
    
    return bounding_boxes, scores, classes

def annotate_image(image, boxes, scores, classes):
    """Annotate the image with detection results."""
    centers = []
    highest_score = -1
    highest_score_index = -1
    for i in range(len(boxes)):
        x, y, w, h = boxes[i]
        start_point = (x, y)
        end_point = (x + w, y + h)
        
        # Draw bounding boxes for all detections
        cv2.rectangle(image, start_point, end_point, (0, 255, 0), 2)
        object_name = labels[int(classes[i])]  # Look up object name from "labels" array using class index
        label = f"{object_name}: {scores[i]:.2f}"
        cv2.putText(image, label, (start_point[0], start_point[1] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # Calculate the center of the bounding box
        center_x = x + w // 2
        center_y = y + h // 2
        centers.append((center_x, center_y))

        # Track the highest scoring detection
        if scores[i] > highest_score:
            highest_score = scores[i]
            highest_score_index = i
    
    # Return the modified image, all centers, and the center of the highest-scoring detection
    return image, centers, centers[highest_score_index] if highest_score_index != -1 else None

# Define the target framerate (e.g., 10 frames per second)
target_fps = 10
frame_time = 1.0 / target_fps

# Global variable to store the latest coordinates
latest_centers = []

def follow_target(centers, scores, last_detection_time):
    """Function to follow the target with the highest score."""
    current_time = time.time()

    if len(centers) > 0:
        # RESET THE DETECTION TIME SINCE WE HAVE A DETECTION
        # Get the center coordinates of the highest-scoring detection
        # Find the index of the highest-scoring detection
        max_score_idx = np.argmax(scores)  # Get the index of the highest score
        center_x, center_y = centers[max_score_idx]  # Use the center of the highest-scoring target

        # Set default velocities
        vx = 0
        vy = 0
        vz = 0

        # DEFINE THE IMAGE WIDTH AND HEIGHT
        image_width = 1280
        image_height = 720

        # DEFINE THE CENTER TOLERANCE RANGES FOR X AND Y
        x_min = image_width // 2 - 10  
        x_max = image_width // 2 + 10  
        y_min = image_height // 2 - 10  
        y_max = image_height // 2 + 10  

        '''center_x is the x-coordinate of the center of the bounding box,
        center_y is the y-coordinate of the center of the bounding box'''
        # Adjust horizontal velocity based on the x position (left/right)
        # MOVE RIGHT
        if center_x > x_max:
            vy = 1
            vz = 0
        # MOVE LEFT
        elif center_x < x_min:
            vy = -1
            vz = 0
        # Adjust forward/backward velocity based on the y position
        # MOVE FORWARD
        if center_y < y_min:
            vx = 1
            vz = 0
        # MOVE BACKWARD
        elif center_y > y_max:
            vx = -1
            vz = 0
        if (center_x > x_min and center_x < x_max) and (center_y > y_min and center_y < y_max):
            vx = 0
            vy = 0
            vz = 0
            vehicle.mode = VehicleMode("LOITER")

        # Send the velocity command to follow the target
        send_ned_velocity(vehicle, vx, vy, vz, duration=1, MV_Status="Following Target")
    '''
    else:
        # IF NO TARGET IS DETECTED FOR 15 SECONDS, SWITCH TO SQUARE MOVEMENT
        if current_time - last_detection_time > 15:
            square_movement()
        else:
            # If no target is detected, switch the drone to Loiter mode
            print("No victim detected, switching to Loiter mode.")
            vehicle.mode = VehicleMode("LOITER")
    '''
    return last_detection_time

def generate_frames():
    """Generate frames from the camera and run object detection."""
    global latest_centers  # Track the latest centers for coordinate updates
    last_detection_time = time.time()  # Initialize detection time to current time

    while True:
        frame = picam2.capture_array()

        # Detect objects in the frame
        boxes, scores, classes = detect_objects(frame)

        # Annotate the frame with detection results
        annotated_frame, centers, _ = annotate_image(frame, boxes, scores, classes)

        # Update the global centers list for coordinate display
        latest_centers = centers

        # Follow the highest-scoring target or execute square movement
        last_detection_time = follow_target(centers, scores, last_detection_time)

        # Encode the frame in JPEG format
        ret, buffer = cv2.imencode('.jpg', annotated_frame)
        if not ret:
            continue  # Skip this frame if encoding fails

        frame = buffer.tobytes()

        # Concatenate frame to a streaming-compatible format
        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n')

HTML_TEMPLATE = "/home/sedna/Drone_code/HTML_TEMPLATE.txt"

@app.route('/video_feed')
def video_feed():
    return render_template_string(HTML_TEMPLATE)

@app.route('/video_feed_stream')
def video_feed_stream():
    return Response(generate_frames(), mimetype='multipart/x-mixed-replace; boundary=frame')

@app.route('/')
def index():
    return "Video Streaming with Object Detection. Go to /video_feed to view the stream."

@app.route('/get_coordinates')
def get_coordinates():
    global latest_centers
    centers_str = ', '.join([f"({x},{y})" for x, y in latest_centers])
    return centers_str

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=False)

1. High-level Functionality:

2. Strengths of the Code:

3. Potential Issues:

A. Movement Logic (Bounding Box Centering)

B. Time-based Square Movement

C. Bounding Box Center Handling

D. Model Confidence Threshold and Non-Max Suppression (NMS)

E. Signal Handling (Graceful Shutdown)

F. Resource Usage

4. DroneKit Interaction Concerns:

5. Suggested Improvements:

6. Conclusion:

The code is well-structured and should function as expected, given that the correct model and labels are provided. However, improvements in movement control, mode handling, error handling, and multi-threading would enhance reliability and performance.

Version 2

## IMPORT LIBRARY
from flask import Flask, Response, render_template_string
from picamera2 import Picamera2, Preview
from libcamera import controls
from PIL import Image
import numpy as np
import cv2
import tflite_runtime.interpreter as tflite
from dronekit import connect, VehicleMode, LocationGlobalRelative
from pymavlink import mavutil
import logging
import time
import signal
import sys
import threading
import queue
import keyboard  # For detecting emergency stop (usually requires root privileges on Linux)

# Set up logging with timestamps for better debugging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)s] %(message)s')

# GLOBAL VARIABLES
vehicle = None
latest_centers = []
emergency_stop = False  # Global flag for emergency stop

# THREAD-SAFE QUEUE FOR FRAME PROCESSING
frame_queue = queue.Queue(maxsize=5)

# FLASK APP HTML TEMPLATE
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>SKYSCANNER</title>
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700&display=swap');

        body {
            background: url('https://wallpapers.com/images/high/4k-anime-space-jrdma2osexe3xo8v.webp') no-repeat center center fixed;
            background-size: cover;
            color: #00ff00;
            font-family: 'Orbitron', sans-serif;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            height: 100vh;
            margin: 0;
            overflow: hidden;
        }
        h1 {
            font-size: 3em;
            margin-bottom: 20px;
            z-index: 1;
            text-shadow: 2px 2px 4px #000000;
        }
        .container {
            display: flex;
            flex-direction: row;
            z-index: 1;
        }
        .video-container {
            border: 5px solid #00ff00;
            padding: 10px;
            background-color: rgba(0, 0, 0, 0.5);
        }
        .video-container img {
            width: 100%;
            height: auto;
        }
        .coordinates {
            margin-left: 20px;
            color: #00ff00;
            font-size: 1.5em;
            background-color: rgba(0, 0, 0, 0.7);
            padding: 10px;
            border-radius: 10px;
            overflow-y: auto;
            max-height: 500px; /* Adjust as needed */
        }
        .footer {
            margin-top: 20px;
            font-size: 1.2em;
            text-shadow: 2px 2px 4px #000000;
        }
    </style>
</head>
<body>
    <h1>SKYSCANNER LIVE STREAM</h1>
    <div class="container">
        <div class="video-container">
            <img src="{{ url_for('video_feed_stream') }}" alt="Video Stream">
        </div>
        <div class="coordinates" id="coordinates"></div>
    </div>
    <div class="footer">
        <p>Video Streaming with Victim Detection</p>
    </div>

    <script>
        function fetchCoordinates() {
            fetch('{{ url_for('get_coordinates') }}')
                .then(response => response.text())
                .then(data => {
                    document.getElementById('coordinates').innerText = data;
                });
        }

        setInterval(fetchCoordinates, 1000); // Fetch coordinates every 1 second
    </script>
</body>
</html>
"""

# CONNECT TO DRONE VEHICLE (WITH ERROR HANDLING)
def connect_to_vehicle():
    global vehicle
    try:
        vehicle = connect('127.0.0.1:14550', baud=921600, wait_ready=True)
        vehicle.parameters.set('PLND_ENABLED', 1)
        vehicle.parameters.set('PLND_TYPE', 1)
        vehicle.parameters.set('PLND_EST_TYPE', 0)
        logging.info("Drone connected successfully.")
    except Exception as e:
        logging.error(f"Failed to connect to vehicle: {e}")
        sys.exit(1)

# ARM AND TAKEOFF FUNCTION (WITH ERROR HANDLING)
def arm_and_takeoff(altitude):
    try:
        while not vehicle.is_armable:
            logging.info("Waiting for vehicle to become armable...")
            time.sleep(1)

        logging.info("Arming motors")
        vehicle.mode = VehicleMode("GUIDED")
        vehicle.armed = True

        while not vehicle.armed:
            time.sleep(1)

        logging.info("Taking off")
        vehicle.simple_takeoff(altitude)

        while True:
            v_alt = vehicle.location.global_relative_frame.alt
            logging.info(f">> Altitude = {v_alt:.1f} m")
            if v_alt >= altitude - 1.0:
                logging.info("Target altitude reached")
                break
            time.sleep(1)
    except Exception as e:
        logging.error(f"Error during takeoff: {e}")

# FUNCTION TO HANDLE RTL MODE IN CASE OF EMERGENCY
def emergency_rtl():
    global emergency_stop
    logging.info("Emergency RTL activated")
    emergency_stop = True
    vehicle.mode = VehicleMode("RTL")
    while vehicle.armed:
        vehicle.armed = False
        time.sleep(1)
    logging.info("Drone safely landed and disarmed.")

# DETECT EMERGENCY KEY PRESS ('r' FOR RETURN-TO-LAUNCH)
def check_emergency_key():
    global emergency_stop
    while True:
        if keyboard.is_pressed('r'):
            logging.info("Emergency stop triggered (Key 'r' detected). Switching to RTL.")
            emergency_rtl()
            break
        time.sleep(0.1)

# FUNCTION TO MOVE THE DRONE (WITH ERROR HANDLING AND THREADING SUPPORT)
def send_ned_velocity(vehicle, vx, vy, vz, duration, MV_Status):
    if emergency_stop:
        return  # Stop all movement if emergency is triggered

    try:
        vehicle.mode = VehicleMode("GUIDED")
        msg = vehicle.message_factory.set_position_target_local_ned_encode(
            0,                  # time_boot_ms (not used)
            0, 0,               # target system, target component
            9,                  # frame: MAV_FRAME_BODY_OFFSET_NED
            1479,               # type_mask: only velocity (and yaw rate) enabled
            0, 0, 0,            # POSITION m
            vx, vy, vz,         # VELOCITY m/s
            0, 0, 0,            # ACCELERATIONS m/s^2
            0, 0                # yaw, yaw_rate
        )
        for _ in range(duration):
            if emergency_stop:
                break  # Stop movement if emergency occurs
            vehicle.send_mavlink(msg)
            logging.info(MV_Status)
            time.sleep(1)
    except Exception as e:
        logging.error(f"Error sending NED velocity: {e}")

# CONTROL YAW FUNCTION (WITH ERROR HANDLING)
def condition_yaw(vehicle, radian, duration, MV_Status):
    vehicle.mode = VehicleMode("GUIDED")
    msg = vehicle.message_factory.set_position_target_local_ned_encode(
        0,
        0, 0,
        9,
        2503,
        0, 0, 0, # POSITION m
        0, 0, 0, # VELOCITY m/s
        0, 0, 0, # ACCELERATIONS m/s^2
        radian, 0
    )
    for x in range(0, duration):
        vehicle.send_mavlink(msg)
        print(MV_Status)
        time.sleep(1)

# SQUARE MOVEMENT FUNCTION (FOR NO TARGET DETECTION)
def square_movement():
    """Move the drone in a square pattern."""
    print("Executing square movement pattern.")
    for _ in range(4):  # Repeat 4 times for a square
        # Move forward
        send_ned_velocity(vehicle, 1, 0, 0, duration=3, MV_Status="MOVING FORWARD")  # Move forward for 3 seconds
        # Turn right (yaw)
        radian = 1.5708
        condition_yaw(vehicle, radian, 1, MV_Status="YAW TO THE RIGHT")  # Rotate by 90 degrees TO RIGHT
        time.sleep(5)  # Give the yaw time to complete

# CAMERA INITIALIZATION FUNCTION (WITH ERROR HANDLING)
def start_camera():
    try:
        picam2 = Picamera2()
        picam2.set_controls({"AfMode": controls.AfModeEnum.Continuous, "AfSpeed": controls.AfSpeedEnum.Fast})
        camera_config = picam2.create_preview_configuration(main={"format": "RGB888", "size": (1280, 720)})
        picam2.configure(camera_config)
        picam2.start()
        return picam2
    except Exception as e:
        logging.error(f"Failed to initialize camera: {e}")
        sys.exit(1)

# FRAME PROCESSING THREAD FUNCTION (PUT FRAMES INTO QUEUE)
def frame_producer(picam2):
    while True:
        if emergency_stop:
            break
        frame = picam2.capture_array()
        if not frame_queue.full():
            frame_queue.put(frame)

def preprocess_image(image):
    """Resize and preprocess image for the TFLite model."""
    # Ensure the input shape matches the model's input dimensions
    input_shape = input_details[0]['shape'][1:3]  # [300, 300]

    # Convert the image to RGB format (Pillow's default is RGB)
    image = Image.fromarray(image).convert('RGB')

    # Resize the image to the expected input size
    image = image.resize(input_shape, Image.LANCZOS)  # Pillow 10 uses LANCZOS for high-quality downsampling

    # Convert image to numpy array and expand dimensions to match [1, 300, 300, 3]
    input_tensor = np.expand_dims(np.array(image), axis=0)

    # Ensure the image is in uint8 format as required by the model
    #input_tensor = input_tensor.astype(np.uint8)
    input_tensor = input_tensor.astype(np.float32)

    return input_tensor

# DETECT OBJECTS FUNCTION
def detect_objects(image, overlap_thresh=0.5):
    try:
        input_tensor = preprocess_image(image)
        interpreter.set_tensor(input_details[0]['index'], input_tensor)
        interpreter.invoke()

        # Check output layer name to determine if this model was created with TF2 or TF1,
        # because outputs are ordered differently for TF2 and TF1 models
        outname = output_details[0]['name']

        if ('StatefulPartitionedCall' in outname):  # This is a TF2 model
            boxes_idx, classes_idx, scores_idx = 1, 3, 0
        else:  # This is a TF1 model
            boxes_idx, classes_idx, scores_idx = 0, 1, 2

        boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0]
        classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0]
        scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0]
        
        # Set a confidence threshold and keep only detections above it
        confidence_threshold = 0.5
        selected_indices = np.where(scores > confidence_threshold)
        if len(selected_indices[0]) == 0:
            return [], [], []

        boxes = boxes[selected_indices]
        scores = scores[selected_indices]
        classes = classes[selected_indices]

        # Convert bounding boxes to pixel coordinates
        bounding_boxes = []
        for box in boxes:
            ymin, xmin, ymax, xmax = box
            x = int(xmin * image.shape[1])
            y = int(ymin * image.shape[0])
            width = int((xmax - xmin) * image.shape[1])
            height = int((ymax - ymin) * image.shape[0])
            bounding_boxes.append([x, y, width, height])

        # Apply NMS
        indices = cv2.dnn.NMSBoxes(
            bounding_boxes, scores.tolist(),
            confidence_threshold, overlap_thresh
        )
        if len(indices) > 0:
            indices = indices.flatten()
            bounding_boxes = np.array([bounding_boxes[i] for i in indices])
            scores = np.array([scores[i] for i in indices])
            classes = np.array([classes[i] for i in indices])

        # Return pixel-coordinate boxes so annotate_image can draw them directly
        return bounding_boxes, scores, classes
    except Exception as e:
        logging.error(f"Error during object detection: {e}")
        return [], [], []

def annotate_image(image, boxes, scores, classes):
    """Annotate the image with detection results."""
    centers = []
    highest_score = -1
    highest_score_index = -1
    for i in range(len(boxes)):
        x, y, w, h = boxes[i]
        start_point = (x, y)
        end_point = (x + w, y + h)
        
        # Draw bounding boxes for all detections
        cv2.rectangle(image, start_point, end_point, (0, 255, 0), 2)
        object_name = labels[int(classes[i])]  # Look up object name from "labels" array using class index
        label = f"{object_name}: {scores[i]:.2f}"
        cv2.putText(image, label, (start_point[0], start_point[1] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        # Calculate the center of the bounding box
        center_x = x + w // 2
        center_y = y + h // 2
        centers.append((center_x, center_y))

        # Track the highest scoring detection
        if scores[i] > highest_score:
            highest_score = scores[i]
            highest_score_index = i
    
    # Return the modified image, all centers, and the center of the highest-scoring detection
    return image, centers, centers[highest_score_index] if highest_score_index != -1 else None

# FRAME CONSUMER THREAD (DETECT OBJECTS AND FOLLOW TARGET)
def frame_consumer():
    global latest_centers
    last_detection_time = time.time()

    while True:
        if emergency_stop:
            break

        if not frame_queue.empty():
            frame = frame_queue.get()

            boxes, scores, classes = detect_objects(frame)
            annotated_frame, centers, _ = annotate_image(frame, boxes, scores, classes)
            latest_centers = centers  # Update global center coordinates

            last_detection_time = follow_target(centers, scores, last_detection_time)
            time.sleep(1.0 / 10)  # Limit FPS

# MAIN FLASK APP FOR STREAMING VIDEO
app = Flask(__name__)

@app.route('/video_feed')
def video_feed():
    return render_template_string(HTML_TEMPLATE)

@app.route('/video_feed_stream')
def video_feed_stream():
    return Response(generate_frames(), mimetype='multipart/x-mixed-replace; boundary=frame')

@app.route('/')
def index():
    return "Video Streaming with Object Detection. Go to /video_feed to view the stream."

@app.route('/get_coordinates')
def get_coordinates():
    global latest_centers
    centers_str = ', '.join([f"({x},{y})" for x, y in latest_centers])
    return centers_str

def follow_target(centers, scores, last_detection_time):
    """Function to follow the target with the highest score."""
    current_time = time.time()
    follow_duration = 10  # Duration to follow the target
    if len(centers) > 0:
        # RESET THE DETECTION TIME SINCE WE HAVE A DETECTION
        # Get the center coordinates of the highest-scoring detection
        # Find the index of the highest-scoring detection
        max_score_idx = np.argmax(scores)  # Get the index of the highest score
        center_x, center_y = centers[max_score_idx]  # Use the center of the highest-scoring target

        # Set default velocities
        vx = 0
        vy = 0
        vz = 0

        # DEFINE THE IMAGE WIDTH AND HEIGHT
        image_width = 1280
        image_height = 720

        # DEFINE THE CENTER TOLERANCE RANGES FOR X AND Y
        x_min = image_width // 2 - 10  
        x_max = image_width // 2 + 10  
        y_min = image_height // 2 - 10  
        y_max = image_height // 2 + 10  

        '''center_x is the x-coordinate of the center of the bounding box,
        center_y is the y-coordinate of the center of the bounding box'''
        # Adjust horizontal velocity based on the x position (left/right)
        # MOVE RIGHT
        if center_x > x_max:
            vy = 1
            vz = 0
        # MOVE LEFT
        elif center_x < x_min:
            vy = -1
            vz = 0
        # Adjust forward/backward velocity based on the y position
        # MOVE FORWARD (target is above the image centre)
        if center_y < y_min:
            vx = 1
            vz = 0
        # MOVE BACKWARD (target is below the image centre)
        elif center_y > y_max:
            vx = -1
            vz = 0
        if (center_x > x_min and center_x < x_max) and (center_y > y_min and center_y < y_max):
            vx = 0
            vy = 0
            vz = 0
            vehicle.mode = VehicleMode("LOITER")

        # Send the velocity command to follow the target
        send_ned_velocity(vehicle, vx, vy, vz, duration=1, MV_Status="Following Target")
    
        # CHECK IF THE FOLLOW DURATION HAS BEEN EXCEEDED
        if current_time - last_detection_time >= follow_duration:
            vehicle.mode = VehicleMode("LOITER")
            print("Target lost, switching to Loiter mode.")
            # vehicle.close()  # NOTE: closing the connection here would break all later MAVLink commands
            return current_time

    '''
    else:
        # IF NO TARGET IS DETECTED FOR 15 SECONDS, SWITCH TO SQUARE MOVEMENT
        if current_time - last_detection_time > 15:
            square_movement()
        else:
            # If no target is detected, switch the drone to Loiter mode
            print("No victim detected, switching to Loiter mode.")
            vehicle.mode = VehicleMode("LOITER")
    '''
    return last_detection_time

# STREAM VIDEO FRAMES TO FLASK (UPDATED)
def generate_frames():
    """Generate frames from the camera and run object detection."""
    global latest_centers  # Track the latest centers for coordinate updates
    last_detection_time = time.time()  # Initialize detection time to current time

    while True:
        frame = picam2.capture_array()

        # Detect objects in the frame
        boxes, scores, classes = detect_objects(frame)

        # Annotate the frame with detection results
        annotated_frame, centers, _ = annotate_image(frame, boxes, scores, classes)

        # Update the global centers list for coordinate display
        latest_centers = centers

        # Follow the highest-scoring target or execute square movement
        last_detection_time = follow_target(centers, scores, last_detection_time)

        # Encode the frame in JPEG format
        ret, buffer = cv2.imencode('.jpg', annotated_frame)
        if not ret:
            continue  # Skip this frame if encoding fails

        frame = buffer.tobytes()

        # Concatenate frame to a streaming-compatible format
        yield (b'--frame\r\n'
               b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n')

def run_flask():
    app.run(host='0.0.0.0', port=5000, debug=False)

# MAIN EXECUTION FUNCTION
if __name__ == '__main__':
    # Load the TFLite model
    model_path = "/home/sedna/Downloads/PROJECT_IRIS_V1/IRIS.tflite"
    label_path = "/home/sedna/Downloads/PROJECT_IRIS_V1/labelmap.txt" # DIRECTORY PATH OF LABELMAP FILE
    interpreter = tflite.Interpreter(model_path=model_path, num_threads=4)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Load labels
    with open(label_path, 'r') as f:
        labels = [line.strip() for line in f.readlines()]

    # Start the camera
    picam2 = start_camera()

    # Start frame processing threads
    threading.Thread(target=frame_producer, args=(picam2,), daemon=True).start()
    threading.Thread(target=frame_consumer, daemon=True).start()
    threading.Thread(target=check_emergency_key, daemon=True).start()  # Start listening for emergency key press
    print("PRESS ('r' FOR RETURN-TO-LAUNCH)")

    # Start the Flask app in a background thread: app.run() blocks, so running it on
    # the main thread would prevent the vehicle connection and takeoff below
    threading.Thread(target=run_flask, daemon=True).start()

    time.sleep(15)
    connect_to_vehicle()
    arm_and_takeoff(10)

    # Keep the main thread alive so the daemon threads keep running
    while not emergency_stop:
        time.sleep(1)