from tkinter import StringVar, TOP
from tkinter.tix import Balloon
from tkinterdnd2 import TkinterDnD, DND_ALL
import customtkinter as ctk
import tkinter as tk
from ultralytics import YOLO
from PIL import Image, ImageTk, ImageDraw
import numpy as np
import os
import re
import threading
import cv2
import tempfile
class Tk(ctk.CTk, TkinterDnD.DnDWrapper):
    """CustomTkinter root window with tkinterdnd2 drag-and-drop mixed in.

    Widgets in this file call ``drop_target_register`` / ``dnd_bind``; those
    rely on the tkdnd Tcl package, which ``TkinterDnD._require`` loads into
    this window's interpreter.
    """

    def __init__(self, *args, **kwargs):
        super(Tk, self).__init__(*args, **kwargs)
        # Keep the value returned by _require under the attribute name
        # TkdndVersion.
        self.TkdndVersion = TkinterDnD._require(self)
# Global CustomTkinter look, configured before the root window is created.
ctk.set_appearance_mode("dark")
ctk.set_default_color_theme("blue")

# Current prediction options: passed to the model as keyword arguments and
# replaced wholesale by the settings panel's "Apply" button.
predict_settings = dict(conf=0.35, hide_labels=False, hide_conf=False, iou=0.7, max_det=300)

# Relative path: resolved against the current working directory.
model = YOLO('trained_model.pt')

# Fixed-size main window (drag-and-drop capable root).
root = Tk()
root.geometry("600x500")
root.title("SpaceFinder AI")
root.resizable(False, False)

# NOTE(review): not referenced anywhere else in this file.
nameVar = StringVar()
def get_paths(event, file_entries_frame):
    """Split the data of a tkinterdnd2 ``<<Drop>>`` event into file paths.

    A ``{...}`` group is treated as a single path (Tk braces list elements
    that contain spaces); everything else is split on spaces. Surrounding
    braces are stripped and the resulting paths go to ``filter_paths``
    together with the target frame.
    """
    tokens = re.findall(r'\{.*?\}|[^ ]+', event.data)
    filter_paths([token.strip('{}') for token in tokens], file_entries_frame)
def filter_paths(paths, file_entries_frame):
for path in paths.copy():
try:
ext = os.path.splitext(path)[1].lower()
if ext not in ['.jpg', '.png', '.jpeg', '.dng', '.bmp', '.mpo', '.tif', '.tiff', '.webp', '.pfm'] and \
ext not in ['.mp4', '.asf', '.avi', '.gif', '.m4v', '.mkv', '.mov', '.mpeg', '.mpg', '.ts', '.wmv', '.webm']:
paths.remove(path)
except Exception as e:
print(e)
if paths:
process_image(paths[0], file_entries_frame)
# Serialises use of the shared YOLO model: each prediction runs on its own
# worker thread, and a new drop can start one before the previous finishes.
_predict_lock = threading.Lock()


def process_image(path: str, file_entries_frame: ctk.CTkFrame):
    """Display *path* in *file_entries_frame* and run YOLO prediction on it.

    The frame's content is replaced by the (rounded) source image plus a
    placeholder entry. Inference and result rendering run on a daemon worker
    thread. The Tk thread polls for completion, then swaps the placeholder for
    the annotated image and "Save image" / "Save annotation" buttons. All
    widget calls stay on the Tk thread.

    If *path* cannot be opened as an image, the error is printed and the frame
    is left as it was. If prediction fails, the error is printed and the
    placeholder stays. A result that arrives after the view has been replaced
    (new drop or "Clear") is discarded.
    """
    from tkinter import filedialog

    def round_image(img):
        """Give *img* rounded corners via its alpha channel (modifies *img*) and return it."""
        mask = Image.new("L", img.size, 0)
        draw = ImageDraw.Draw(mask)
        radius = 30
        draw.rounded_rectangle((0, 0) + img.size, fill=255, radius=radius)
        img.putalpha(mask)
        return img

    def save_txt_file(result_obj):
        """Ask for a destination and write the prediction's annotation text there."""
        file_path = filedialog.asksaveasfilename(defaultextension=".txt", initialfile="annotation.txt",
                                                 filetypes=(("Text files", "*.txt"), ("All files", "*.*")))
        if not file_path:
            return
        # Render into a private temp dir so no existing file named
        # annotation.txt (possibly the chosen destination) is overwritten
        # or deleted.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, 'annotation.txt')
            result_obj.save_txt(tmp_path)
            content = ''
            # save_txt may not create the file at all (e.g. nothing detected).
            if os.path.exists(tmp_path):
                with open(tmp_path, encoding='utf8') as tmp_file:
                    content = tmp_file.read()
        with open(file_path, 'w', encoding='utf8') as out_file:
            out_file.write(content)

    def save_image_file(pil_img):
        """Ask for a destination and save *pil_img* there as an RGB image."""
        file_path = filedialog.asksaveasfilename(defaultextension=".jpg", initialfile="prediction.jpg",
                                                 filetypes=(("JPEG files", "*.jpg"), ("All files", "*.*")))
        if file_path:
            # JPEG has no alpha channel; drop the one added by round_image().
            pil_img.convert('RGB').save(file_path)

    # Open and fully load the image before touching the UI, so an unreadable
    # file leaves the current view intact; the with-block closes the handle.
    try:
        with Image.open(path) as opened:
            source_img = opened.copy()
    except OSError as e:
        print(e)
        return

    for widget in file_entries_frame.winfo_children():
        widget.destroy()

    # Fixed 340 px display width; height keeps the source aspect ratio.
    img_sizes = (340, (source_img.height * 340) / source_img.width)
    image_label = ctk.CTkLabel(file_entries_frame, text='',
                               image=ctk.CTkImage(round_image(source_img), size=img_sizes))
    image_label.grid(row=0, column=0, padx=47, pady=10)
    image_label.drop_target_register(DND_ALL)
    image_label.dnd_bind("<<Drop>>", lambda event: get_paths(event, file_entries_frame))
    file_result_entry = ctk.CTkEntry(file_entries_frame, state='disabled', width=340, height=200)
    file_result_entry.grid(row=1, column=0, padx=47, pady=40)

    # Snapshot so a concurrent "Apply" cannot change the options mid-run.
    settings = dict(predict_settings)
    outcome = {}

    def run_predict():
        """Worker thread: inference and rendering only -- no Tk calls here."""
        try:
            with _predict_lock:
                predict_result = model(path, **settings)[0]
            # Apply the hide_* switches to the rendered image as well.
            plotted = predict_result.plot(labels=not settings.get('hide_labels', False),
                                          conf=not settings.get('hide_conf', False))
            # plot() yields BGR channel order; PIL expects RGB.
            result_np_array = np.uint8(cv2.cvtColor(plotted, cv2.COLOR_BGR2RGB))
            outcome['result'] = predict_result
            outcome['image'] = Image.fromarray(result_np_array)
        except Exception as e:  # thread boundary: reported from the Tk thread
            outcome['error'] = e

    def show_result():
        """Tk thread: replace the placeholder with the result image and save buttons."""
        if 'error' in outcome:
            print(outcome['error'])
            return
        predict_result = outcome['result']
        result_img = outcome['image']
        file_entries_frame.grid_rowconfigure(2, minsize=30)
        result_image_label = ctk.CTkLabel(file_entries_frame, text='',
                                          image=ctk.CTkImage(round_image(result_img), size=img_sizes))
        result_image_label.grid(row=3, column=0, padx=47, pady=0)
        result_image_label.drop_target_register(DND_ALL)
        result_image_label.dnd_bind("<<Drop>>", lambda event: get_paths(event, file_entries_frame))
        save_img_but = ctk.CTkButton(file_entries_frame, text='Save image', font=ctk.CTkFont('Trebuchet MS', 14),
                                     fg_color='#404040', hover_color='#363636', text_color='#cfcfcf', width=100,
                                     command=lambda: save_image_file(result_img))
        save_img_but.grid(row=4, column=0, padx=95, pady=15, sticky='E')
        save_txt_but = ctk.CTkButton(file_entries_frame, text='Save annotation',
                                     font=ctk.CTkFont('Trebuchet MS', 14),
                                     fg_color='#404040', hover_color='#363636', text_color='#cfcfcf', width=100,
                                     command=lambda: save_txt_file(predict_result))
        save_txt_but.grid(row=4, column=0, padx=95, pady=15, sticky='W')
        file_result_entry.destroy()

    def poll():
        """Tk thread: wait for the worker without blocking the event loop."""
        if not image_label.winfo_exists():
            return  # a newer drop or "Clear" replaced this view; discard the result
        if predict_thread.is_alive():
            root.after(100, poll)
        else:
            show_result()

    # Daemon so an in-flight prediction does not keep the process alive after
    # the window is closed.
    predict_thread = threading.Thread(target=run_predict, daemon=True)
    predict_thread.start()
    # Scheduled on root (not the frame) so a destroyed frame cannot leave a
    # dangling after-callback.
    root.after(100, poll)
def reg_settings():
    """Build the left-hand settings panel for the global ``predict_settings``.

    Every control is initialised from the current ``predict_settings`` instead
    of duplicated literals, so the panel always shows the values prediction
    actually uses. Edits take effect only when "Apply" is pressed, which
    replaces ``predict_settings`` with a new dict built from the controls.
    """
    initial = dict(predict_settings)
    # Used when the max-detections entry is left empty.
    default_max_det = initial['max_det']

    def update_label(label_to_update, value, text):
        # Sliders run 0..100; the settings hold the value divided by 100.
        label_to_update.configure(text=f'{text}: {value/100}')

    def apply_button_func():
        global predict_settings
        predict_settings = {
            'conf': confidence_scrollbar.get() / 100,
            'hide_labels': bool(hide_labels_switch.get()),
            'iou': iou_scrollbar.get() / 100,
            'hide_conf': bool(hide_conf_switch.get()),
            'max_det': int(max_det_entry.get() or default_max_det),
        }

    settings_menu_frame = ctk.CTkScrollableFrame(root, width=100, height=400, corner_radius=0)
    settings_menu_frame.place(relx=0.1, rely=0.499, anchor=tk.CENTER)

    apply_button = ctk.CTkButton(settings_menu_frame, text='Apply', font=ctk.CTkFont('Trebuchet MS', 14),
                                 fg_color='#404040', hover_color='#363636', text_color='#cfcfcf', width=85,
                                 command=apply_button_func)
    apply_button.grid(row=0, column=0, padx=10, pady=5)

    # Confidence threshold.
    confidence_label = ctk.CTkLabel(settings_menu_frame, text=f"conf: {initial['conf']}", text_color='#cfcfcf',
                                    font=ctk.CTkFont('Trebuchet MS', 12), width=95)
    confidence_label.grid(row=1, column=0, padx=5, pady=0, sticky='W')
    confidence_scrollbar = ctk.CTkSlider(settings_menu_frame, width=90, button_color='#636363',
                                         button_hover_color='#737373', from_=0, to=100, number_of_steps=100,
                                         command=lambda value: update_label(confidence_label, value, 'conf'),
                                         button_length=1)
    confidence_scrollbar.set(round(initial['conf'] * 100))
    confidence_scrollbar.grid(row=2, column=0, padx=10, pady=0, sticky='W')

    # Hide-labels switch.
    hide_labels_label = ctk.CTkLabel(settings_menu_frame, text='hide labels', text_color='#cfcfcf',
                                     font=ctk.CTkFont('Trebuchet MS', 12))
    hide_labels_label.grid(row=3, column=0, padx=2, pady=0)
    hide_labels_switch = ctk.CTkSwitch(settings_menu_frame, width=35, text='', progress_color='#a8a8a8',
                                       button_color='#636363', button_hover_color='#737373')
    if initial['hide_labels']:
        hide_labels_switch.select()
    hide_labels_switch.grid(row=4, column=0, padx=10, pady=0)

    # Hide-confidence switch.
    hide_conf_label = ctk.CTkLabel(settings_menu_frame, text='hide conf', text_color='#cfcfcf',
                                   font=ctk.CTkFont('Trebuchet MS', 12))
    hide_conf_label.grid(row=5, column=0, padx=2, pady=0)
    hide_conf_switch = ctk.CTkSwitch(settings_menu_frame, width=35, text='', progress_color='#a8a8a8',
                                     button_color='#636363', button_hover_color='#737373')
    if initial['hide_conf']:
        hide_conf_switch.select()
    hide_conf_switch.grid(row=6, column=0, padx=10, pady=0)

    # IoU threshold.
    iou_label = ctk.CTkLabel(settings_menu_frame, text=f"iou: {initial['iou']}", text_color='#cfcfcf',
                             font=ctk.CTkFont('Trebuchet MS', 12), width=95)
    iou_label.grid(row=7, column=0, padx=5, pady=0, sticky='W')
    iou_scrollbar = ctk.CTkSlider(settings_menu_frame, width=90, button_color='#636363',
                                  button_hover_color='#737373', from_=0, to=100, number_of_steps=100,
                                  command=lambda value: update_label(iou_label, value, 'iou'),
                                  button_length=1)
    iou_scrollbar.set(round(initial['iou'] * 100))
    iou_scrollbar.grid(row=8, column=0, padx=10, pady=0, sticky='W')

    # Maximum number of detections; the entry accepts digits only.
    max_det_label = ctk.CTkLabel(settings_menu_frame, text='max detections', text_color='#cfcfcf',
                                 font=ctk.CTkFont('Trebuchet MS', 12), width=95)
    max_det_label.grid(row=9, column=0, padx=5, pady=0, sticky='W')
    vcmd = root.register(lambda num: num.isdigit() or num == '')
    max_det_entry = ctk.CTkEntry(settings_menu_frame, width=70, justify='center',
                                 placeholder_text=str(default_max_det),
                                 validate="key", validatecommand=(vcmd, '%P'))
    max_det_entry.grid(row=10, column=0, padx=15, pady=0, sticky='W')
def reg_clear(predict_frame):
    """Place the "Clear" button at the bottom-left of the window.

    Each click destroys the prediction panel currently on screen and builds a
    fresh one with ``reg_file_entries``.

    Args:
        predict_frame: the prediction panel on screen when the button is created.
    """
    def clear_workflow():
        # Rebind to the replacement panel. Otherwise every later click would
        # target the already-destroyed original, stack yet another panel on
        # top, and leak the previous replacement.
        nonlocal predict_frame
        predict_frame.destroy()
        predict_frame = reg_file_entries()

    clear_button_frame = ctk.CTkFrame(root, width=117, height=50, corner_radius=0)
    clear_button_frame.place(relx=0.1, rely=0.95, anchor=tk.CENTER)
    clear_button = ctk.CTkButton(clear_button_frame, text='Clear', font=ctk.CTkFont('Trebuchet MS', 14),
                                 fg_color='#404040', hover_color='#363636', text_color='#cfcfcf', width=110,
                                 command=clear_workflow)
    clear_button.place(relx=0.5, rely=0.5, anchor=tk.CENTER)
def reg_version_label():
    """Show the application name and version in a small frame at the top-left."""
    frame = ctk.CTkFrame(root, width=117, height=50, corner_radius=0)
    frame.place(relx=0.1, rely=0.05, anchor=tk.CENTER)
    label_font = ctk.CTkFont('Trebuchet MS', 13)
    label = ctk.CTkLabel(frame, text='SpaceFinder v0.0.3', font=label_font, text_color='#cfcfcf')
    # Centre the label inside its frame.
    label.place(relx=0.5, rely=0.5, anchor=tk.CENTER)
def reg_file_entries():
    """Build the right-hand prediction panel and return its frame.

    The panel is a scrollable frame holding a "Choose file" button drawn over
    a disabled placeholder entry (both in row 0), a spacer row, and a second
    placeholder entry for the result (row 2). A file can be chosen with the
    button or dropped onto the frame, the entry or the button. Either way it is
    passed to ``filter_paths`` with this frame as the render target.
    """
    def choose_file():
        """Open a file dialog and process the selected file, if any."""
        from tkinter import filedialog
        file_path = filedialog.askopenfilename(
            # Each glob is its own element: outside Windows, Tk treats a single
            # "*.a;*.b" string as one literal pattern that matches nothing.
            filetypes=(("Image Files", ("*.bmp", "*.dng", "*.jpeg", "*.jpg", "*.mpo", "*.png",
                                        "*.tif", "*.tiff", "*.webp", "*.pfm")),
                       ("Video Files", ("*.asf", "*.avi", "*.gif", "*.m4v", "*.mkv", "*.mov",
                                        "*.mp4", "*.mpeg", "*.mpg", "*.ts", "*.wmv", "*.webm")),
                       ("All Files", "*.*")))
        if file_path:
            filter_paths([file_path], predict_frame)

    predict_frame = ctk.CTkScrollableFrame(root, width=430, height=500, corner_radius=0)
    predict_frame.place(relx=0.625, rely=0.5, anchor=tk.CENTER)
    predict_frame.drop_target_register(DND_ALL)
    predict_frame.dnd_bind("<<Drop>>", lambda event: get_paths(event, predict_frame))

    file_entry = ctk.CTkEntry(predict_frame, state='disabled', width=340, height=200)
    file_entry.grid(row=0, column=0, padx=47, pady=10)
    file_entry.drop_target_register(DND_ALL)
    file_entry.dnd_bind("<<Drop>>", lambda event: get_paths(event, predict_frame))

    # Gridded into the same cell as file_entry, so it sits centred on top of it.
    file_entry_button = ctk.CTkButton(predict_frame, text='Choose file', font=ctk.CTkFont('Trebuchet MS', 14),
                                      fg_color='#404040', hover_color='#303030', text_color='#cfcfcf', width=110,
                                      bg_color='#343638', corner_radius=10, command=choose_file)
    file_entry_button.grid(row=0, column=0, padx=47, pady=10)
    file_entry_button.drop_target_register(DND_ALL)
    file_entry_button.dnd_bind("<<Drop>>", lambda event: get_paths(event, predict_frame))

    # Empty spacer row between the two placeholders.
    predict_frame.grid_rowconfigure(1, minsize=30)
    file_result_entry = ctk.CTkEntry(predict_frame, state='disabled', width=340, height=200)
    file_result_entry.grid(row=2, column=0, padx=47, pady=0)
    return predict_frame
if __name__ == '__main__':
    # Build the UI, then hand control to Tk. Each reg_* call creates its panel
    # as a direct child of root positioned with place(). Sibling widgets created
    # later are drawn above earlier ones, so this call order decides which panel
    # wins where they overlap.
    reg_settings()
    predict_frame = reg_file_entries()
    reg_version_label()
    # The Clear button needs the initial prediction panel it will replace.
    reg_clear(predict_frame)
    root.mainloop()