import cv2
from picamera2 import Picamera2, Preview
from libcamera import controls
import text_recognition as recognizer
import numpy as np
from spislave.spidevice import SpiDevice
from spislave.protocol import SSProtocol
from spislave.rpigpioaccess import RPiGPIOAccess
import queue as Queue
import traceback

# Command bytes the SPI master sends (handled in SpiSlave.dataReceived).
COMMAND_GET_BYTE = 0x01                 # fetch the next byte of the recognition result
COMMAND_COMMS_CHECK = 0x02              # link test, answered with VALUE_COMMS_OK
COMMAND_RECOGNIZE_SERIAL_NUMBER = 0x03  # request a new serial-number recognition
COMMAND_CHECK_EXECUTION = 0x04          # poll whether the requested recognition finished

# Reply bytes staged in SpiSlave.nextData for the master to read.
VALUE_COMMS_OK = 0x5A
VALUE_NOT_EXECUTED = 0xAA
VALUE_EXECUTED = 0x55
VALUE_NODATA = 0xFE                     # every result byte has already been handed out

class SpiSlave(SpiDevice):
	"""SPI slave that serves the serial-number recognizer to an SPI master.

	The master sends one command byte per transfer (the ``COMMAND_*``
	constants). The reply is staged in ``nextData`` and copied into
	``sendBuffer`` by :meth:`prepareData`.
	NOTE(review): this suggests a reply goes out on the transfer *after* the
	command that produced it — confirm against SSProtocol.

	State shared with the main capture loop:
		imageCaptureRequest: set to 1 when the master requests a recognition;
			the main loop polls it, resets it to 0 and runs the recognizer.
		commandExecuted: reset to 0 when a recognition is requested; the main
			loop sets it to 1 once ``returnMultibyte`` holds the new result.
		returnMultibyte: result bytes, handed out one per COMMAND_GET_BYTE.
		returnCounter: index of the next ``returnMultibyte`` entry to send;
			rewound to 0 when COMMAND_CHECK_EXECUTION reports completion.
	"""

	def __init__(self):
		SpiDevice.__init__(self)
		self.dataAccess = RPiGPIOAccess()
		self.protocol = SSProtocol( self.dataAccess, self )
		self.nextData = 0  # reply byte staged for the next transfer
		self.commandExecuted = 0
		self.imageCaptureRequest = 0
		self.returnMultibyte = []
		self.returnCounter = 0

	def prepareData(self):
		"""Copy the staged reply byte (``nextData``) into ``sendBuffer``.

		NOTE(review): presumably called by SSProtocol before each transfer — confirm.
		"""
		self.sendBuffer = self.nextData

	def dataReceived(self):
		"""Handle one command byte from the master, read from ``receiveBuffer``.

		Sets ``nextData`` to the reply for the command. Two cases leave
		``nextData`` unchanged: COMMAND_RECOGNIZE_SERIAL_NUMBER and any
		unrecognized byte.
		"""
		receivedByte = self.receiveBuffer
		print("SPI received: " + str(receivedByte))

		if receivedByte == COMMAND_GET_BYTE:
			# EAFP rather than len()-then-index: the main loop rebuilds
			# returnMultibyte while this handler may run (presumably from the
			# GPIO callback context). The list can shrink between a length check
			# and the subscript, which would raise IndexError inside the handler.
			try:
				self.nextData = self.returnMultibyte[self.returnCounter]
			except IndexError:
				self.nextData = VALUE_NODATA
			else:
				self.returnCounter += 1

		elif receivedByte == COMMAND_COMMS_CHECK:
			self.nextData = VALUE_COMMS_OK

		elif receivedByte == COMMAND_RECOGNIZE_SERIAL_NUMBER:
			# Hand the work to the main loop. Clear the completion flag so that
			# COMMAND_CHECK_EXECUTION reports "not executed" until the result is ready.
			self.imageCaptureRequest = 1
			self.commandExecuted = 0

		elif receivedByte == COMMAND_CHECK_EXECUTION:
			if self.commandExecuted:
				self.nextData = VALUE_EXECUTED
				# Rewind so the following COMMAND_GET_BYTEs start at the first result byte.
				self.returnCounter = 0
			else:
				self.nextData = VALUE_NOT_EXECUTED

		print("SPI next data: " + str(self.nextData))

	def start(self):
		"""Start the underlying RPiGPIOAccess layer."""
		self.dataAccess.start()

def _tile_character_images(char_images):
	"""Place the per-character CNN input images side by side for display.

	Consecutive images are separated by a 10-pixel-wide black uint8 spacer
	whose height matches the first image. An empty input yields a 100x100
	black placeholder, so the preview window always has something to show.
	"""
	# len() rather than truthiness: the input may be a numpy array.
	if len(char_images) == 0:
		return np.zeros((100, 100), dtype=np.dtype('uint8'))

	# NOTE(review): cv2.hconcat needs equal heights and dtypes; this assumes
	# every CNN input is uint8 and as tall as the first one — TODO confirm.
	spacer = np.zeros((char_images[0].shape[0], 10), dtype=np.dtype('uint8'))
	tiled = char_images[0]
	for char_image in char_images[1:]:
		tiled = cv2.hconcat([tiled, spacer])
		tiled = cv2.hconcat([tiled, char_image])
	return tiled


spi = SpiSlave()
spi.start()

# NOTE(review): resolution, pixel format and the fixed manual lens position
# are presumably tuned for this camera rig.
camera = Picamera2()
camera.configure(camera.create_video_configuration(main={"size": (1300, 648), "format": "RGB888" }))
camera.set_controls({"AfMode": controls.AfModeEnum.Manual, "LensPosition": 32.0})
camera.start()

try:
	while True:
		# Rotate 180 degrees (presumably the camera is mounted upside down), then
		# crop the fixed region where the serial number is expected.
		frame = cv2.rotate(camera.capture_array(), cv2.ROTATE_180)
		image = frame[530:600, 500:710]
		cv2.imshow("Frame", image)

		# A recognition is triggered by Enter in the preview window or by
		# COMMAND_RECOGNIZE_SERIAL_NUMBER from the SPI master.
		key = cv2.waitKey(10) & 0xFF
		if key == ord("\r") or spi.imageCaptureRequest:
			spi.imageCaptureRequest = 0

			cnn_in, recognized = recognizer.recognize(image)

			# Send every recognized character, starting at index 0; the result
			# is assumed to be aligned with cnn_in (TODO confirm).
			# The list is built first and published in one assignment, so the
			# SPI handler never sees a half-cleared or half-filled list.
			spi.returnMultibyte = list(recognized[:len(cnn_in)])

			cv2.imshow("Frame1", _tile_character_images(cnn_in))

			predicted_number = recognizer.get_serial_number(image, recognized)
			print("serial number prediction: " + str(predicted_number))
			print("result length: " + str(len(spi.returnMultibyte)))

			# Lets COMMAND_CHECK_EXECUTION report VALUE_EXECUTED to the master.
			spi.commandExecuted = 1
finally:
	# Release the camera and close the preview windows on any exit (e.g. Ctrl-C).
	camera.stop()
	cv2.destroyAllWindows()

