Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 41 additions & 28 deletions cflib/cpx/transports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import socket
import struct
from threading import Event
from threading import Lock

from . import CPXPacket
Expand Down Expand Up @@ -80,8 +81,9 @@ def __init__(self, device, baudrate):
self._device = device
self._baudrate = baudrate
self._serial = None
self._cts = False
self._lock = Lock()
self._tx_ready = Event()
self._tx_lock = Lock()
self._serial_write_lock = Lock()

self.connect()

Expand All @@ -100,10 +102,10 @@ def connect(self):
print(size)
if size == 0x00:
isInSync = True
self.cts = True

# Send back sync
self._serial.write([0xFF, 0x00])
# Send back sync / clear-to-receive
self._write_raw([0xFF, 0x00])
self._tx_ready.set()

print('Connected')

Expand All @@ -113,44 +115,55 @@ def _calcXORchecksum(self, data):
checksum ^= i
return checksum

def _write_raw(self, data):
with self._serial_write_lock:
self._serial.write(data)

def disconnect(self):
print('Closing transport')
self._serial.close()
self._serial = None

def writePacket(self, packet):
self._lock.acquire()
data = packet.wireData
if len(data) > 100:
raise 'Packet too large!'
raise Exception('Packet too large!')

buff = bytearray([0xFF, len(data)])
buff.extend(data)
buff.extend([self._calcXORchecksum(buff)])
self._serial.write(buff)

with self._tx_lock:
self._tx_ready.wait()
self._tx_ready.clear()
self._write_raw(buff)

def readPacket(self):
size = 0
while size == 0:
while True:
start = self._serial.read(1)[0]
if start == 0xFF:
size = self._serial.read(1)[0]
if size == 0:
self._lock.release()
else:
data = self._serial.read(size) # Size is excluding start (0xFF) and checksum at end
crc = self._serial.read(1)
# CRC includes start and size
calculated_crc = self._calcXORchecksum(bytes([start, size]) + data)
if calculated_crc != ord(crc):
print('CRC error!')
# Send CTS
self._serial.write([0xFF, 0x00])

packet = CPXPacket()
packet.wireData = data

return packet
if start != 0xFF:
continue

size = self._serial.read(1)[0]
if size == 0:
self._tx_ready.set()
continue

data = self._serial.read(size) # Size is excluding start (0xFF) and checksum at end
crc = self._serial.read(1)[0]
# CRC includes start and size
calculated_crc = self._calcXORchecksum(bytes([start, size]) + data)
if calculated_crc != crc:
print('CRC error!')
self._write_raw([0xFF, 0x00])
continue

# Send CTS
self._write_raw([0xFF, 0x00])

packet = CPXPacket()
packet.wireData = data
return packet


class CRTPTransport(CPXTransport):
Expand Down
1 change: 1 addition & 0 deletions test/cpx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
163 changes: 163 additions & 0 deletions test/cpx/test_uart_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# -*- coding: utf-8 -*-
import threading
import time
import types
import unittest

from cflib.cpx import CPXFunction
from cflib.cpx import CPXPacket
from cflib.cpx import CPXTarget
from cflib.cpx import transports
from cflib.cpx.transports import UARTTransport


class FakeSerial:
def __init__(self, initial_read_data=b''):
self._read_data = bytearray(initial_read_data)
self._condition = threading.Condition()
self.writes = []
self.closed = False

def append_read_data(self, data):
with self._condition:
self._read_data.extend(data)
self._condition.notify_all()

def read(self, size):
with self._condition:
end_time = time.monotonic() + 1.0
while len(self._read_data) < size:
remaining = end_time - time.monotonic()
if remaining <= 0:
raise TimeoutError('Timed out waiting for fake serial data')
self._condition.wait(remaining)

result = self._read_data[:size]
del self._read_data[:size]
return bytes(result)

def write(self, data):
self.writes.append(bytes(data))
return len(data)

def close(self):
self.closed = True


def checksum(data):
result = 0
for byte in data:
result ^= byte
return result


def uart_frame(packet):
data = packet.wireData
frame = bytearray([0xFF, len(data)])
frame.extend(data)
frame.append(checksum(frame))
return bytes(frame)


class UARTTransportTest(unittest.TestCase):
def setUp(self):
self.fake_serial = FakeSerial(b'\xff\x00')
self.original_serial = getattr(transports, 'serial', None)
transports.serial = types.SimpleNamespace(
Serial=lambda device, baudrate, timeout=None: self.fake_serial
)

def tearDown(self):
if self.original_serial is None:
del transports.serial
else:
transports.serial = self.original_serial

def _transport(self):
return UARTTransport('/dev/fake', 576000)

def _inbound_packet(self, data=b'payload'):
return CPXPacket(
source=CPXTarget.STM32,
destination=CPXTarget.HOST,
function=CPXFunction.CRTP,
data=bytearray(data),
)

def test_unsolicited_cts_frames_are_ignored_until_data_packet(self):
transport = self._transport()
expected = self._inbound_packet(b'abc')

self.fake_serial.append_read_data(b'\xff\x00')
self.fake_serial.append_read_data(b'\xff\x00')
self.fake_serial.append_read_data(uart_frame(expected))

actual = transport.readPacket()

self.assertEqual(expected.wireData, actual.wireData)
self.assertEqual([b'\xff\x00', b'\xff\x00'], self.fake_serial.writes)

def test_crc_error_is_discarded_and_next_valid_packet_is_returned(self):
transport = self._transport()
bad_packet = self._inbound_packet(b'bad')
bad_frame = bytearray(uart_frame(bad_packet))
bad_frame[-1] ^= 0x01
expected = self._inbound_packet(b'good')

self.fake_serial.append_read_data(bytes(bad_frame))
self.fake_serial.append_read_data(uart_frame(expected))

actual = transport.readPacket()

self.assertEqual(expected.wireData, actual.wireData)
self.assertEqual([b'\xff\x00', b'\xff\x00', b'\xff\x00'], self.fake_serial.writes)

def test_write_packet_waits_for_cts_before_writing_next_frame(self):
transport = self._transport()
first_outbound = CPXPacket(
destination=CPXTarget.STM32,
function=CPXFunction.SYSTEM,
data=bytearray([0x21, 0x01]),
)
blocked_outbound = CPXPacket(
destination=CPXTarget.STM32,
function=CPXFunction.SYSTEM,
data=bytearray([0x20, 0x01]),
)
inbound = self._inbound_packet(b'unblock')

transport.writePacket(first_outbound)
first_data_write_count = len(self.fake_serial.writes)

writer_error = []
writer_done = threading.Event()

def writer():
try:
transport.writePacket(blocked_outbound)
except Exception as exc:
writer_error.append(exc)
finally:
writer_done.set()

thread = threading.Thread(target=writer)
thread.start()

self.assertFalse(writer_done.wait(0.05))
self.assertEqual(first_data_write_count, len(self.fake_serial.writes))

self.fake_serial.append_read_data(b'\xff\x00')
self.fake_serial.append_read_data(uart_frame(inbound))
actual = transport.readPacket()

self.assertEqual(inbound.wireData, actual.wireData)
self.assertTrue(writer_done.wait(1.0))
thread.join(1.0)
self.assertEqual([], writer_error)

expected_blocked_frame = uart_frame(blocked_outbound)
self.assertIn(expected_blocked_frame, self.fake_serial.writes)


if __name__ == '__main__':
unittest.main()
Loading