Initial (and hopefully last) commit
This commit is contained in:
commit
06374b2608
8 changed files with 326 additions and 0 deletions
113
app.py
Normal file
113
app.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow.keras import layers, models, Sequential
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
from PyQt6.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout, QPushButton, QLabel
|
||||
from PyQt6.QtGui import QPainter, QPen, QImage
|
||||
from PyQt6.QtCore import Qt, QPoint
|
||||
import sys
|
||||
|
||||
# NOTE(review): absolute dev-machine path; it is never used anywhere below.
# Kept only so external readers of this module still find the name.
data_dir = "/home/mia/Schule/KISY/schrifterkennung/"  # Ignore full path, had some weird problem otherwise

# File the trained model was saved to. The .keras extension is the Keras v3
# native save format, so loading it with load_model below is correct.
model_file = "model.keras"

# Training happened in a separate run; load the saved weights instead of
# retraining on every launch. The module-level `model` is read later by
# MainWindow.predict_image().
print("We have done training already so we load this to not waste very precious cpu :)")
model = tf.keras.models.load_model(model_file)

# (Removed: a commented-out validation/visualization loop that referenced
# `val_ds` and `class_names`, neither of which is defined in this file.)

#### DISCLAIMER: This was written by AI; I hate GUI stuff ####
|
||||
|
||||
class DrawingCanvas(QWidget):
    """Fixed-size widget the user draws on.

    Strokes are rasterized into a grayscale QImage backing store, which
    MainWindow later downscales and feeds to the model.
    """

    def __init__(self):
        super().__init__()
        # Canvas is 10x the 32x32 model input size so drawing is comfortable.
        self.setFixedSize(320, 320)
        # Grayscale backing image, cleared to white.
        self.image = QImage(self.size(), QImage.Format.Format_Grayscale8)
        self.image.fill(Qt.GlobalColor.white)
        self.drawing = False
        self.last_point = QPoint()

    def paintEvent(self, event):
        # Blit the backing image onto the widget.
        widget_painter = QPainter(self)
        widget_painter.drawImage(0, 0, self.image)

    def mousePressEvent(self, event):
        # Start a stroke on left-button press.
        if event.button() == Qt.MouseButton.LeftButton:
            self.drawing = True
            self.last_point = event.position().toPoint()

    def mouseMoveEvent(self, event):
        # Extend the stroke while the left button is held. The buttons()
        # check makes an explicit mouseReleaseEvent handler unnecessary:
        # once the button is released, this branch no longer fires.
        if (event.buttons() & Qt.MouseButton.LeftButton) and self.drawing:
            stroke_painter = QPainter(self.image)
            # Thick black round-capped pen on the white canvas so the
            # downscaled 32x32 image has grayscale strokes like the
            # model's training data.
            pen = QPen(Qt.GlobalColor.black, 18, Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap)
            stroke_painter.setPen(pen)
            current_point = event.position().toPoint()
            stroke_painter.drawLine(self.last_point, current_point)
            self.last_point = current_point
            self.update()

    def clear(self):
        # Wipe the backing image back to solid white and repaint.
        self.image.fill(Qt.GlobalColor.white)
        self.update()
|
||||
|
||||
|
||||
class MainWindow(QMainWindow):
    """Main window: a drawing canvas, a result label, and Predict/Clear buttons."""

    # Output classes, index-aligned with the model's final softmax layer.
    # Hoisted to a class constant so predict_image() does not rebuild the
    # list on every click.
    CLASS_NAMES = tuple("ABCDEFGHIJKLMNOPQRSTUVWXYZ")

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Handwriting Recognition")

        main_layout = QVBoxLayout()
        self.canvas = DrawingCanvas()
        # FIX: the original hint text had an unbalanced parenthesis.
        self.result_label = QLabel("Draw something and click Predict. (Need to fill entire space or model has stroke)")
        self.result_label.setAlignment(Qt.AlignmentFlag.AlignCenter)

        predict_btn = QPushButton("Predict")
        predict_btn.clicked.connect(self.predict_image)

        clear_btn = QPushButton("Clear Canvas")
        clear_btn.clicked.connect(self.canvas.clear)

        main_layout.addWidget(self.canvas)
        main_layout.addWidget(self.result_label)
        main_layout.addWidget(predict_btn)
        main_layout.addWidget(clear_btn)

        container = QWidget()
        container.setLayout(main_layout)
        self.setCentralWidget(container)

    def predict_image(self):
        """Downscale the canvas to the model's 32x32 input, run inference,
        and show the top class with its confidence in the result label."""
        # 1. Resize the drawing to 32x32 to match the model input.
        scaled_img = self.canvas.image.scaled(
            32, 32,
            Qt.AspectRatioMode.IgnoreAspectRatio,
            Qt.TransformationMode.SmoothTransformation,
        )

        # 2. Convert the QImage to a (32, 32, 1) uint8 numpy array.
        # FIX: use sizeInBytes()/bytesPerLine() instead of assuming 32*32 —
        # QImage scanlines may be padded, so slice each row to its first
        # 32 pixels before reshaping.
        ptr = scaled_img.bits()
        ptr.setsize(scaled_img.sizeInBytes())
        stride = scaled_img.bytesPerLine()
        arr = np.frombuffer(ptr, np.uint8).reshape(32, stride)[:, :32].reshape(32, 32, 1)

        # 3. Add a batch dimension and predict.
        # No manual 1/255 rescaling here: per the original author's note the
        # model has a Rescaling layer built in — TODO confirm against the
        # training script.
        img_batch = np.expand_dims(arr, axis=0)
        prediction = model.predict(img_batch, verbose=0)

        # (Removed: debug prints of the raw prediction vector.)
        result = self.CLASS_NAMES[int(np.argmax(prediction))]
        confidence = float(np.max(prediction)) * 100
        self.result_label.setText(f"Prediction: {result} ({confidence:.1f}%)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standard Qt bootstrap: build the application, show the main window,
    # and hand control to the event loop until the user quits.
    qt_app = QApplication(sys.argv)
    qt_app.setStyle("Breeze")  # Use system theme so it looks nice on linux
    main_window = MainWindow()
    main_window.show()
    sys.exit(qt_app.exec())
|
||||
Loading…
Add table
Add a link
Reference in a new issue