113 lines
4.4 KiB
Python
113 lines
4.4 KiB
Python
import tensorflow as tf
|
|
from tensorflow.keras import layers, models, Sequential
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import os
|
|
from PyQt6.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout, QPushButton, QLabel
|
|
from PyQt6.QtGui import QPainter, QPen, QImage
|
|
from PyQt6.QtCore import Qt, QPoint
|
|
import sys
|
|
|
|
# Dataset root from the training setup. Kept as an absolute path because the
# author ran into problems with a relative one; nothing else in this script
# currently reads it.
data_dir = "/home/mia/Schule/KISY/schrifterkennung/"

# Saved model file. ".keras" is the native Keras format written by
# model.save("<name>.keras") and read by tf.keras.models.load_model().
model_file = "model.keras"

# A bare relative path is resolved against the current working directory, so
# starting the script from another directory would not find the model.
# Keep the working-directory lookup first (original behaviour) and fall back
# to the file sitting next to this script.
if not os.path.exists(model_file):
    _model_next_to_script = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), model_file
    )
    if os.path.exists(_model_next_to_script):
        model_file = _model_next_to_script

print("We have done training already so we load this to not waste very precious cpu :)")

# Loaded once at import time; MainWindow.predict_image uses this global.
model = tf.keras.models.load_model(model_file)
|
|
|
|
#for images, labels in val_ds.take(10):
|
|
# preds = model.predict(images)
|
|
# print(f"Prediction: {class_names[np.argmax(preds[0])]}")
|
|
# print(f"Label: {class_names[labels[0].numpy().astype(int)]}")
|
|
# plt.imshow(images[0].numpy().squeeze(), cmap='gray')
|
|
# plt.title(f"Pred: {class_names[np.argmax(preds[0])]}")
|
|
# plt.show()
|
|
|
|
#### DISCLAIMER: This was written by AI; I hate GUI stuff ####
|
|
|
|
class DrawingCanvas(QWidget):
    """Fixed-size widget the user draws on while holding the left mouse button.

    The drawing lives in ``self.image``, a grayscale QImage the same size as
    the widget. It starts white and strokes are painted in black.
    ``clear()`` resets it to white.
    """

    # Stroke width in widget pixels. The 320x320 canvas is 10x a 32x32 input,
    # so after downscaling a stroke is roughly 2 pixels wide.
    PEN_WIDTH = 18

    def __init__(self):
        super().__init__()
        self.setFixedSize(320, 320)  # 10x the model input size for easier drawing
        self.image = QImage(self.size(), QImage.Format.Format_Grayscale8)
        self.image.fill(Qt.GlobalColor.white)
        self.drawing = False  # True between a left press and its release
        self.last_point = QPoint()  # end of the previous stroke segment

    def paintEvent(self, event):
        """Copy the off-screen drawing onto the widget."""
        painter = QPainter(self)
        try:
            painter.drawImage(0, 0, self.image)
        finally:
            # End explicitly rather than relying on garbage collection: a
            # painter that stays active on a device blocks later painting on it.
            painter.end()

    def mousePressEvent(self, event):
        """Start a stroke at the cursor when the left button goes down."""
        if event.button() == Qt.MouseButton.LeftButton:
            self.drawing = True
            self.last_point = event.position().toPoint()

    def mouseMoveEvent(self, event):
        """Extend the current stroke to the cursor while the left button is held."""
        if (event.buttons() & Qt.MouseButton.LeftButton) and self.drawing:
            current_point = event.position().toPoint()
            painter = QPainter(self.image)
            try:
                # Black round-capped strokes on the white background.
                # NOTE(review): an earlier comment claimed a *white* pen. Confirm
                # the polarity of the training images; if they were white-on-black,
                # the canvas must be inverted before prediction.
                painter.setPen(QPen(Qt.GlobalColor.black, self.PEN_WIDTH,
                                    Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap))
                painter.drawLine(self.last_point, current_point)
            finally:
                # Release the image before anything else reads or paints it.
                painter.end()
            self.last_point = current_point
            self.update()  # schedule paintEvent to show the new segment

    def mouseReleaseEvent(self, event):
        """Finish the current stroke when the left button is released."""
        if event.button() == Qt.MouseButton.LeftButton:
            self.drawing = False

    def clear(self):
        """Erase the drawing back to a white canvas and repaint."""
        self.image.fill(Qt.GlobalColor.white)
        self.update()
|
|
|
|
|
|
class MainWindow(QMainWindow):
    """Main window: a drawing canvas, a result label, and Predict/Clear buttons."""

    # Side length of the square image fed to the model (the canvas is downscaled to this).
    INPUT_SIZE = 32

    # Model output index -> letter. Assumes the model's outputs are A..Z in order.
    CLASS_NAMES = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Handwriting Recognition")

        main_layout = QVBoxLayout()
        self.canvas = DrawingCanvas()
        self.result_label = QLabel("Draw something and click Predict. (Need to fill entire space or model has stroke)")
        self.result_label.setAlignment(Qt.AlignmentFlag.AlignCenter)

        predict_btn = QPushButton("Predict")
        predict_btn.clicked.connect(self.predict_image)

        clear_btn = QPushButton("Clear Canvas")
        clear_btn.clicked.connect(self.canvas.clear)

        main_layout.addWidget(self.canvas)
        main_layout.addWidget(self.result_label)
        main_layout.addWidget(predict_btn)
        main_layout.addWidget(clear_btn)

        container = QWidget()
        container.setLayout(main_layout)
        self.setCentralWidget(container)

    def _canvas_to_array(self):
        """Return the canvas downscaled to an (INPUT_SIZE, INPUT_SIZE, 1) uint8 array."""
        size = self.INPUT_SIZE
        scaled_img = self.canvas.image.scaled(size, size, Qt.AspectRatioMode.IgnoreAspectRatio,
                                              Qt.TransformationMode.SmoothTransformation)
        # Smooth scaling is not guaranteed to keep the 8-bit grayscale format.
        # Force it so every pixel is exactly one byte (a no-op if unchanged).
        scaled_img = scaled_img.convertToFormat(QImage.Format.Format_Grayscale8)

        # QImage rows may be padded to bytesPerLine(). Read whole rows and then
        # drop the padding instead of assuming width == stride.
        stride = scaled_img.bytesPerLine()
        ptr = scaled_img.bits()
        ptr.setsize(stride * size)
        rows = np.frombuffer(ptr, np.uint8).reshape(size, stride)
        # Copy so the result no longer points into memory owned by scaled_img.
        return rows[:, :size].copy().reshape(size, size, 1)

    def predict_image(self):
        """Classify the current drawing and show the top letter with its score."""
        arr = self._canvas_to_array()

        # Add a batch dimension -> (1, INPUT_SIZE, INPUT_SIZE, 1).
        # No manual 1/255 rescaling here: per the original author the model
        # contains a Rescaling layer (not verifiable from this file).
        img_batch = np.expand_dims(arr, axis=0)
        prediction = model.predict(img_batch, verbose=0)
        print(type(prediction), prediction)

        scores = prediction[0]
        print([x for x in zip(self.CLASS_NAMES, scores)])

        if len(scores) != len(self.CLASS_NAMES):
            # A mismatched model would otherwise mislabel silently or raise IndexError.
            self.result_label.setText(
                f"Unexpected model output: {len(scores)} scores for {len(self.CLASS_NAMES)} classes")
            return

        best = int(np.argmax(scores))
        result = self.CLASS_NAMES[best]
        # Treats the scores as probabilities (softmax output). If the model emits
        # logits instead, this "confidence" is not a real percentage.
        confidence = scores[best] * 100
        self.result_label.setText(f"Prediction: {result} ({confidence:.1f}%)")
|
|
|
|
|
|
if __name__ == "__main__":
    # One QApplication must exist before any widget is created.
    qt_app = QApplication(sys.argv)

    # Ask for KDE's "Breeze" widget style explicitly. If no style with that
    # name is available, Qt returns None and keeps the current style.
    qt_app.setStyle("Breeze")

    main_window = MainWindow()
    main_window.show()

    # Block in the Qt event loop; its return value becomes the process exit code.
    exit_code = qt_app.exec()
    sys.exit(exit_code)
|