# Schrifterkennung/app.py — handwriting-recognition GUI (113 lines, 4.4 KiB, Python)

import tensorflow as tf
from tensorflow.keras import layers, models, Sequential
import numpy as np
import matplotlib.pyplot as plt
import os
from PyQt6.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout, QPushButton, QLabel
from PyQt6.QtGui import QPainter, QPen, QImage
from PyQt6.QtCore import Qt, QPoint
import sys
# --- Module-level setup: load the pre-trained model once at import time ---
data_dir = "/home/mia/Schule/KISY/schrifterkennung/" # Dataset directory (absolute path to avoid a cwd issue); not used elsewhere in this file
model_file = "model.keras" # Pre-trained model save file; .keras is the native Keras v3 format
print("We have done training already so we load this to not waste very precious cpu :)")
# Load the already-trained model from disk instead of retraining on every launch.
model = tf.keras.models.load_model(model_file)
# NOTE(review): dead validation/visualization snippet kept for reference;
# `val_ds` and `class_names` are not defined anywhere in this file.
#for images, labels in val_ds.take(10):
# preds = model.predict(images)
# print(f"Prediction: {class_names[np.argmax(preds[0])]}")
# print(f"Label: {class_names[labels[0].numpy().astype(int)]}")
# plt.imshow(images[0].numpy().squeeze(), cmap='gray')
# plt.title(f"Pred: {class_names[np.argmax(preds[0])]}")
# plt.show()
#### DISCLAIMER: This was written by AI; I hate GUI stuff ####
class DrawingCanvas(QWidget):
    """Fixed 320x320 grayscale canvas the user draws on with the left mouse button."""

    def __init__(self):
        super().__init__()
        # 10x the model's 32x32 input size, so drawing is comfortable.
        self.setFixedSize(320, 320)
        self.image = QImage(self.size(), QImage.Format.Format_Grayscale8)
        self.image.fill(Qt.GlobalColor.white)
        self.drawing = False
        self.last_point = QPoint()

    def paintEvent(self, event):
        # Blit the backing image onto the widget.
        QPainter(self).drawImage(0, 0, self.image)

    def mousePressEvent(self, event):
        if event.button() != Qt.MouseButton.LeftButton:
            return
        self.drawing = True
        self.last_point = event.position().toPoint()

    def mouseMoveEvent(self, event):
        if not (self.drawing and (event.buttons() & Qt.MouseButton.LeftButton)):
            return
        current = event.position().toPoint()
        painter = QPainter(self.image)
        # Thick, round-capped BLACK pen on the white canvas; the model consumes grayscale.
        pen = QPen(Qt.GlobalColor.black, 18, Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap)
        painter.setPen(pen)
        painter.drawLine(self.last_point, current)
        self.last_point = current
        self.update()

    def clear(self):
        """Reset the canvas to a blank white image and repaint."""
        self.image.fill(Qt.GlobalColor.white)
        self.update()
class MainWindow(QMainWindow):
    """Main window: drawing canvas, result label, and Predict / Clear buttons."""

    # Output-unit order the model was trained with (26 uppercase letters).
    # NOTE(review): assumed from the original inline list — confirm against training code.
    CLASS_NAMES = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Handwriting Recognition")
        main_layout = QVBoxLayout()
        self.canvas = DrawingCanvas()
        # BUGFIX: the hint text was missing its closing parenthesis.
        self.result_label = QLabel("Draw something and click Predict. (Need to fill entire space or model has stroke)")
        self.result_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        predict_btn = QPushButton("Predict")
        predict_btn.clicked.connect(self.predict_image)
        clear_btn = QPushButton("Clear Canvas")
        clear_btn.clicked.connect(self.canvas.clear)
        main_layout.addWidget(self.canvas)
        main_layout.addWidget(self.result_label)
        main_layout.addWidget(predict_btn)
        main_layout.addWidget(clear_btn)
        container = QWidget()
        container.setLayout(main_layout)
        self.setCentralWidget(container)

    def predict_image(self):
        """Downscale the drawing to the model's 32x32 input, run inference, show the result."""
        # 1. Resize the drawing to 32x32 to match the model input.
        scaled_img = self.canvas.image.scaled(
            32, 32,
            Qt.AspectRatioMode.IgnoreAspectRatio,
            Qt.TransformationMode.SmoothTransformation,
        )
        # 2. Convert QImage -> numpy. BUGFIX: read the full buffer and slice each row
        #    by bytesPerLine() — QImage pads scanlines to 4-byte boundaries, so hard-
        #    coding 32*32 bytes silently breaks for widths not divisible by 4.
        ptr = scaled_img.bits()
        ptr.setsize(scaled_img.sizeInBytes())
        stride = scaled_img.bytesPerLine()
        raw = np.frombuffer(ptr, np.uint8).reshape(scaled_img.height(), stride)
        arr = raw[:, :32].reshape(32, 32, 1)
        # 3. Add batch dimension and predict.
        # No manual 1/255 rescale here: the model has a Rescaling layer built in.
        img_batch = np.expand_dims(arr, axis=0)
        prediction = model.predict(img_batch, verbose=0)
        result = self.CLASS_NAMES[int(np.argmax(prediction))]
        confidence = float(np.max(prediction)) * 100
        self.result_label.setText(f"Prediction: {result} ({confidence:.1f}%)")
if __name__ == "__main__":
    # Launch the Qt application and block until the window is closed.
    qt_app = QApplication(sys.argv)
    qt_app.setStyle("Breeze")  # Use system theme so it looks nice on linux
    main_window = MainWindow()
    main_window.show()
    sys.exit(qt_app.exec())