Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions duo_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,26 @@ def __init__(self, ikey, skey, host,
paging_limit=100,
digestmod=hashlib.sha512,
sig_version=None,
port=None
port=None,
disable_ca_pinning=False
):
"""
ca_certs - Path to CA pem file.
disable_ca_pinning - If True, uses the system's default trusted CA
certificates instead of Duo's bundled CA certificates. TLS
verification remains active. Cannot be used together with a
custom ca_certs path.
"""
self.ikey = ikey
self.skey = skey
self.host = host
self.port = port
self.sig_timezone = sig_timezone
if disable_ca_pinning and ca_certs not in (None, DEFAULT_CA_CERTS):
raise ValueError(
"Cannot both disable CA pinning and provide custom CA certificates"
)
self.disable_ca_pinning = disable_ca_pinning
if ca_certs is None:
ca_certs = DEFAULT_CA_CERTS
self.ca_certs = ca_certs
Expand Down Expand Up @@ -382,7 +392,10 @@ def _connect(self):
raise NotImplementedError('proxy_type=%s' % (self.proxy_type,))

# Create outer HTTP(S) connection.
if self.ca_certs == 'HTTP':
if self.disable_ca_pinning:
context = ssl.create_default_context()
conn = http.client.HTTPSConnection(host, port, context=context)
elif self.ca_certs == 'HTTP':
conn = http.client.HTTPConnection(host, port)
elif self.ca_certs == 'DISABLE':
kwargs = {}
Expand Down Expand Up @@ -634,6 +647,7 @@ def main():
parser.add_argument('--path', required=True,
help='API endpoint path')
parser.add_argument('--ca', default=DEFAULT_CA_CERTS)
parser.add_argument('--disable-ca-pinning', default=False)
parser.add_argument('--sig-version', type=int, default=2)
parser.add_argument('--sig-timezone', default='UTC')
parser.add_argument(
Expand All @@ -655,6 +669,7 @@ def main():
ca_certs=args.ca,
sig_version=args.sig_version,
sig_timezone=args.sig_timezone,
disable_ca_pinning=args.disable_ca_pinning,
)

params = collections.defaultdict(list)
Expand Down
84 changes: 84 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import hashlib
import http.client
import ssl
from unittest import mock
import unittest
import duo_client.client
import duo_client.https_wrapper
from . import util
import base64
import collections
Expand Down Expand Up @@ -760,5 +763,86 @@ def test_sig_version_3_raises_exception(self):
'test_ikey', 'test_akey', 'example.com', sig_timezone='America/Detroit',
sig_version=3)


class TestDisableCaPinningInit(unittest.TestCase):
"""Tests for the disable_ca_pinning parameter on Client.__init__."""

def test_default_is_pinning_enabled(self):
client = duo_client.client.Client('ikey', 'skey', 'host.example.com')
self.assertFalse(client.disable_ca_pinning)
self.assertEqual(client.ca_certs, duo_client.client.DEFAULT_CA_CERTS)

def test_disable_ca_pinning_true(self):
client = duo_client.client.Client('ikey', 'skey', 'host.example.com',
disable_ca_pinning=True)
self.assertTrue(client.disable_ca_pinning)

def test_disable_ca_pinning_with_default_ca_certs(self):
client = duo_client.client.Client('ikey', 'skey', 'host.example.com',
ca_certs=duo_client.client.DEFAULT_CA_CERTS, disable_ca_pinning=True)
self.assertTrue(client.disable_ca_pinning)

def test_disable_ca_pinning_with_none_ca_certs(self):
client = duo_client.client.Client('ikey', 'skey', 'host.example.com',
ca_certs=None, disable_ca_pinning=True)
self.assertTrue(client.disable_ca_pinning)

def test_disable_ca_pinning_with_custom_ca_certs_raises(self):
with self.assertRaises(ValueError) as ctx:
duo_client.client.Client('ikey', 'skey', 'host.example.com',
ca_certs='/path/to/custom.pem', disable_ca_pinning=True)
self.assertIn("Cannot both disable CA pinning", str(ctx.exception))

def test_disable_ca_pinning_with_http_ca_certs_raises(self):
with self.assertRaises(ValueError) as ctx:
duo_client.client.Client('ikey', 'skey', 'host.example.com',
ca_certs='HTTP', disable_ca_pinning=True)
self.assertIn("Cannot both disable CA pinning", str(ctx.exception))

def test_disable_ca_pinning_with_disable_ca_certs_raises(self):
with self.assertRaises(ValueError) as ctx:
duo_client.client.Client('ikey', 'skey', 'host.example.com',
ca_certs='DISABLE', disable_ca_pinning=True)
self.assertIn("Cannot both disable CA pinning", str(ctx.exception))


class TestDisableCaPinningConnect(unittest.TestCase):
"""Tests that _connect() uses the correct connection type."""

def test_connect_with_pinning_uses_cert_validating(self):
client = duo_client.client.Client('ikey', 'skey', 'host.example.com')
conn = client._connect()
self.assertIsInstance(conn, duo_client.https_wrapper.CertValidatingHTTPSConnection)

def test_connect_with_pinning_disabled_uses_https_connection(self):
client = duo_client.client.Client('ikey', 'skey', 'host.example.com',
disable_ca_pinning=True)
conn = client._connect()
self.assertIsInstance(conn, http.client.HTTPSConnection)
self.assertNotIsInstance(conn, duo_client.https_wrapper.CertValidatingHTTPSConnection)

def test_connect_with_pinning_disabled_has_verification_enabled(self):
client = duo_client.client.Client('ikey', 'skey', 'host.example.com',
disable_ca_pinning=True)
conn = client._connect()
self.assertEqual(conn._context.verify_mode, ssl.CERT_REQUIRED)
self.assertTrue(conn._context.check_hostname)

def test_connect_with_pinning_disabled_uses_system_ca(self):
client = duo_client.client.Client('ikey', 'skey', 'host.example.com',
disable_ca_pinning=True)
conn = client._connect()
default_ctx = ssl.create_default_context()
self.assertEqual(
conn._context.verify_mode, default_ctx.verify_mode)
self.assertEqual(
conn._context.check_hostname, default_ctx.check_hostname)

def test_connect_with_pinning_enabled_default(self):
client = duo_client.client.Client('ikey', 'skey', 'host.example.com')
conn = client._connect()
self.assertIsInstance(conn, duo_client.https_wrapper.CertValidatingHTTPSConnection)
self.assertEqual(conn.default_ssl_context.verify_mode, ssl.CERT_REQUIRED)

if __name__ == '__main__':
unittest.main()
Loading