diff --git a/cflib/cpx/transports.py b/cflib/cpx/transports.py index 830c46b9b..ac2ed09ca 100644 --- a/cflib/cpx/transports.py +++ b/cflib/cpx/transports.py @@ -1,5 +1,6 @@ import socket import struct +from threading import Event from threading import Lock from . import CPXPacket @@ -80,8 +81,9 @@ def __init__(self, device, baudrate): self._device = device self._baudrate = baudrate self._serial = None - self._cts = False - self._lock = Lock() + self._tx_ready = Event() + self._tx_lock = Lock() + self._serial_write_lock = Lock() self.connect() @@ -100,10 +102,10 @@ def connect(self): print(size) if size == 0x00: isInSync = True - self.cts = True - # Send back sync - self._serial.write([0xFF, 0x00]) + # Send back sync / clear-to-receive + self._write_raw([0xFF, 0x00]) + self._tx_ready.set() print('Connected') @@ -113,44 +115,55 @@ def _calcXORchecksum(self, data): checksum ^= i return checksum + def _write_raw(self, data): + with self._serial_write_lock: + self._serial.write(data) + def disconnect(self): print('Closing transport') self._serial.close() self._serial = None def writePacket(self, packet): - self._lock.acquire() data = packet.wireData if len(data) > 100: - raise 'Packet too large!' + raise Exception('Packet too large!') buff = bytearray([0xFF, len(data)]) buff.extend(data) buff.extend([self._calcXORchecksum(buff)]) - self._serial.write(buff) + + with self._tx_lock: + self._tx_ready.wait() + self._tx_ready.clear() + self._write_raw(buff) def readPacket(self): - size = 0 - while size == 0: + while True: start = self._serial.read(1)[0] - if start == 0xFF: - size = self._serial.read(1)[0] - if size == 0: - self._lock.release() - else: - data = self._serial.read(size) # Size is excluding start (0xFF) and checksum at end - crc = self._serial.read(1) - # CRC includes start and size - calculated_crc = self._calcXORchecksum(bytes([start, size]) + data) - if calculated_crc != ord(crc): - print('CRC error!') - # Send CTS - self._serial.write([0xFF, 0x00]) - - packet = CPXPacket() - packet.wireData = data - - return packet + if start != 0xFF: + continue + + size = self._serial.read(1)[0] + if size == 0: + self._tx_ready.set() + continue + + data = self._serial.read(size) # Size is excluding start (0xFF) and checksum at end + crc = self._serial.read(1)[0] + # CRC includes start and size + calculated_crc = self._calcXORchecksum(bytes([start, size]) + data) + if calculated_crc != crc: + print('CRC error!') + self._write_raw([0xFF, 0x00]) + continue + + # Send CTS + self._write_raw([0xFF, 0x00]) + + packet = CPXPacket() + packet.wireData = data + return packet class CRTPTransport(CPXTransport): diff --git a/test/cpx/__init__.py b/test/cpx/__init__.py new file mode 100644 index 000000000..40a96afc6 --- /dev/null +++ b/test/cpx/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/test/cpx/test_uart_transport.py b/test/cpx/test_uart_transport.py new file mode 100644 index 000000000..020c0641a --- /dev/null +++ b/test/cpx/test_uart_transport.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +import threading +import time +import types +import unittest + +from cflib.cpx import CPXFunction +from cflib.cpx import CPXPacket +from cflib.cpx import CPXTarget +from cflib.cpx import transports +from cflib.cpx.transports import UARTTransport + + +class FakeSerial: + def __init__(self, initial_read_data=b''): + self._read_data = bytearray(initial_read_data) + self._condition = threading.Condition() + self.writes = [] + self.closed = False + + def append_read_data(self, data): + with self._condition: + self._read_data.extend(data) + self._condition.notify_all() + + def read(self, size): + with self._condition: + end_time = time.monotonic() + 1.0 + while len(self._read_data) < size: + remaining = end_time - time.monotonic() + if remaining <= 0: + raise TimeoutError('Timed out waiting for fake serial data') + self._condition.wait(remaining) + + result = self._read_data[:size] + del self._read_data[:size] + return bytes(result) + + def write(self, data): + self.writes.append(bytes(data)) + return len(data) + + def close(self): + self.closed = True + + +def checksum(data): + result = 0 + for byte in data: + result ^= byte + return result + + +def uart_frame(packet): + data = packet.wireData + frame = bytearray([0xFF, len(data)]) + frame.extend(data) + frame.append(checksum(frame)) + return bytes(frame) + + +class UARTTransportTest(unittest.TestCase): + def setUp(self): + self.fake_serial = FakeSerial(b'\xff\x00') + self.original_serial = getattr(transports, 'serial', None) + transports.serial = types.SimpleNamespace( + Serial=lambda device, baudrate, timeout=None: self.fake_serial + ) + + def tearDown(self): + if self.original_serial is None: + del transports.serial + else: + transports.serial = self.original_serial + + def _transport(self): + return UARTTransport('/dev/fake', 576000) + + def _inbound_packet(self, data=b'payload'): + return CPXPacket( + source=CPXTarget.STM32, + destination=CPXTarget.HOST, + function=CPXFunction.CRTP, + data=bytearray(data), + ) + + def test_unsolicited_cts_frames_are_ignored_until_data_packet(self): + transport = self._transport() + expected = self._inbound_packet(b'abc') + + self.fake_serial.append_read_data(b'\xff\x00') + self.fake_serial.append_read_data(b'\xff\x00') + self.fake_serial.append_read_data(uart_frame(expected)) + + actual = transport.readPacket() + + self.assertEqual(expected.wireData, actual.wireData) + self.assertEqual([b'\xff\x00', b'\xff\x00'], self.fake_serial.writes) + + def test_crc_error_is_discarded_and_next_valid_packet_is_returned(self): + transport = self._transport() + bad_packet = self._inbound_packet(b'bad') + bad_frame = bytearray(uart_frame(bad_packet)) + bad_frame[-1] ^= 0x01 + expected = self._inbound_packet(b'good') + + self.fake_serial.append_read_data(bytes(bad_frame)) + self.fake_serial.append_read_data(uart_frame(expected)) + + actual = transport.readPacket() + + self.assertEqual(expected.wireData, actual.wireData) + self.assertEqual([b'\xff\x00', b'\xff\x00', b'\xff\x00'], self.fake_serial.writes) + + def test_write_packet_waits_for_cts_before_writing_next_frame(self): + transport = self._transport() + first_outbound = CPXPacket( + destination=CPXTarget.STM32, + function=CPXFunction.SYSTEM, + data=bytearray([0x21, 0x01]), + ) + blocked_outbound = CPXPacket( + destination=CPXTarget.STM32, + function=CPXFunction.SYSTEM, + data=bytearray([0x20, 0x01]), + ) + inbound = self._inbound_packet(b'unblock') + + transport.writePacket(first_outbound) + first_data_write_count = len(self.fake_serial.writes) + + writer_error = [] + writer_done = threading.Event() + + def writer(): + try: + transport.writePacket(blocked_outbound) + except Exception as exc: + writer_error.append(exc) + finally: + writer_done.set() + + thread = threading.Thread(target=writer) + thread.start() + + self.assertFalse(writer_done.wait(0.05)) + self.assertEqual(first_data_write_count, len(self.fake_serial.writes)) + + self.fake_serial.append_read_data(b'\xff\x00') + self.fake_serial.append_read_data(uart_frame(inbound)) + actual = transport.readPacket() + + self.assertEqual(inbound.wireData, actual.wireData) + self.assertTrue(writer_done.wait(1.0)) + thread.join(1.0) + self.assertEqual([], writer_error) + + expected_blocked_frame = uart_frame(blocked_outbound) + self.assertIn(expected_blocked_frame, self.fake_serial.writes) + + +if __name__ == '__main__': + unittest.main()