/*
* CDDL HEADER START
*
* The contents of this file are subject to the terms of the
* Common Development and Distribution License (the "License").
* You may not use this file except in compliance with the License.
*
* You can obtain a copy of the license at usr/src/OPENSOLARIS.LICENSE
* or http://www.opensolaris.org/os/licensing.
* See the License for the specific language governing permissions
* and limitations under the License.
*
* When distributing Covered Code, include this CDDL HEADER in each
* file and include the License file at usr/src/OPENSOLARIS.LICENSE.
* If applicable, add the following below this CDDL HEADER, with the
* fields enclosed by brackets "[]" replaced with your own identifying
* information: Portions Copyright [yyyy] [name of copyright owner]
*
* CDDL HEADER END
*/
/*
* Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
*/
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <sys/time.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include <stdarg.h>
#include <stddef.h>
#include <errno.h>
#include <limits.h>
#include <unistd.h>
#include <fcntl.h>
#include <pthread.h>
#include <poll.h>
#include "adr_stream.h"
/*
* Common adr_stream implementation
*/
/*
 * A stream is an opaque data pointer plus the callbacks that operate on
 * it; the adr_stream_* entry points below simply forward to them.
 */
struct adr_stream {
	void *astr_data;	/* passed as the first argument to every callback */
	ssize_t (*astr_read)(void *, char *, size_t);	/* invoked by adr_stream_read() */
	ssize_t (*astr_write)(void *, const char *, size_t);	/* invoked by adr_stream_write() */
	void (*astr_close)(void *);	/* invoked unconditionally by adr_stream_close() */
	void (*astr_free)(void *);	/* invoked by adr_stream_free() only if non-NULL */
};
/*
 * Create a generic stream from a set of I/O callbacks and an opaque data
 * pointer that is handed to each of them.
 *
 * readf, writef and closef are invoked unconditionally by
 * adr_stream_read(), adr_stream_write() and adr_stream_close(), so they
 * must be non-NULL.  freef is optional: adr_stream_free() skips it when
 * it is NULL.
 *
 * Ownership of data passes to the stream.  If the stream itself cannot
 * be allocated, data is closed (and freed, when freef is non-NULL)
 * before NULL is returned, so callers never need to clean it up.
 */
adr_stream_t *
adr_stream_create(ssize_t (*readf)(void *, char *, size_t),
    ssize_t (*writef)(void *, const char *, size_t),
    void (*closef)(void *), void (*freef)(void *), void *data)
{
	adr_stream_t *result = malloc(sizeof (*result));

	if (result == NULL) {
		closef(data);
		/*
		 * freef is optional (adr_stream_free() tolerates NULL), so
		 * the failure path must not call through a NULL pointer.
		 */
		if (freef != NULL)
			freef(data);
		return (NULL);
	}

	result->astr_data = data;
	result->astr_read = readf;
	result->astr_write = writef;
	result->astr_close = closef;
	result->astr_free = freef;
	return (result);
}
/*
 * Read up to len bytes into buf by dispatching to the stream's read
 * callback; the callback's return value is passed through unchanged.
 */
ssize_t
adr_stream_read(adr_stream_t *str, char *buf, size_t len)
{
	ssize_t (*readf)(void *, char *, size_t) = str->astr_read;

	return (readf(str->astr_data, buf, len));
}
/*
 * Write up to len bytes from buf by dispatching to the stream's write
 * callback; the callback's return value is passed through unchanged.
 */
ssize_t
adr_stream_write(adr_stream_t *str, const char *buf, size_t len)
{
	ssize_t (*writef)(void *, const char *, size_t) = str->astr_write;

	return (writef(str->astr_data, buf, len));
}
/*
 * Close the stream via its close callback.  The stream object itself
 * stays allocated; release it with adr_stream_free().
 */
void
adr_stream_close(adr_stream_t *str)
{
	void (*closef)(void *) = str->astr_close;

	closef(str->astr_data);
}
/*
 * Release the stream: run the optional free callback on the stream's
 * data, then free the stream object itself.
 */
void
adr_stream_free(adr_stream_t *str)
{
	void (*freef)(void *) = str->astr_free;

	if (freef != NULL)
		freef(str->astr_data);
	free(str);
}
/*
* File descriptor stream implementation.
*/
/*
 * Per-stream state for the plain file-descriptor stream.  infd and outfd
 * may be the same descriptor (see adr_stream_create_fd()).
 */
typedef struct adr_fdstream {
	int infd;		/* descriptor read by adr_fdstream_read() */
	int outfd;		/* descriptor written by adr_fdstream_write() */
	boolean_t closed;	/* set by adr_fdstream_close(); not read by the fd read/write paths */
} adr_fdstream_t;
/*
 * Read callback for fd streams: a direct read(2) on the input
 * descriptor, with read(2)'s return value and errno semantics.
 */
static ssize_t
adr_fdstream_read(void *data, char *buf, size_t len)
{
	const adr_fdstream_t *self = data;
	int fd = self->infd;

	return (read(fd, buf, len));
}
/*
 * Write callback for fd streams: a direct write(2) on the output
 * descriptor, with write(2)'s return value and errno semantics.
 */
static ssize_t
adr_fdstream_write(void *data, const char *buf, size_t len)
{
	const adr_fdstream_t *self = data;
	int fd = self->outfd;

	return (write(fd, buf, len));
}
/*
 * Close callback for fd streams.  Instead of close(2)-ing the
 * descriptors, /dev/null is dup2()'d over them, so later reads see EOF
 * and later writes are discarded.  The descriptor numbers themselves are
 * only released in adr_fdstream_free() (presumably so they cannot be
 * reused while another thread still holds them -- TODO confirm).
 * If /dev/null cannot be opened, the descriptors are left untouched.
 */
static void
adr_fdstream_close(void *data)
{
	adr_fdstream_t *self = data;
	int devnull = open("/dev/null", O_RDWR);

	if (devnull != -1) {
		(void) dup2(devnull, self->infd);
		/* Don't redirect the same descriptor twice. */
		if (self->outfd != self->infd)
			(void) dup2(devnull, self->outfd);
		(void) close(devnull);
	}

	self->closed = B_TRUE;
}
/*
 * Free callback for fd streams: close both descriptors (once, if they
 * are the same) and release the per-stream state.
 */
static void
adr_fdstream_free(void *data)
{
	adr_fdstream_t *self = data;
	int in = self->infd;
	int out = self->outfd;

	(void) close(in);
	if (out != in)
		(void) close(out);
	free(self);
}
/*
 * Create a stream that reads from infd and writes to outfd.  Ownership
 * of both descriptors passes to the stream; if creation fails they are
 * closed and NULL is returned.
 */
adr_stream_t *
adr_stream_create_fds(int infd, int outfd)
{
	adr_fdstream_t *fdstr = malloc(sizeof (*fdstr));

	if (fdstr != NULL) {
		fdstr->infd = infd;
		fdstr->outfd = outfd;
		fdstr->closed = B_FALSE;
		/*
		 * If this fails, adr_stream_create() runs the close and
		 * free callbacks itself, which closes the fds and frees
		 * fdstr.
		 */
		return (adr_stream_create(adr_fdstream_read,
		    adr_fdstream_write, adr_fdstream_close, adr_fdstream_free,
		    fdstr));
	}

	(void) close(infd);
	if (outfd != infd)
		(void) close(outfd);
	return (NULL);
}
/*
 * Create a stream that both reads from and writes to the single
 * descriptor fd.  Ownership of fd passes to the stream.
 */
adr_stream_t *
adr_stream_create_fd(int fd)
{
	const int bidir = fd;

	return (adr_stream_create_fds(bidir, bidir));
}
/*
 * Array of CRYPTO_num_locks() mutexes, allocated and initialized by
 * adr_ssl_init() and locked/unlocked on OpenSSL's behalf by
 * adr_ssl_locking_function().
 */
static pthread_mutex_t *crypto_locks;
/*
 * OpenSSL thread-id callback (installed by adr_ssl_init()): identifies
 * the calling thread by its pthread_t converted to unsigned long.
 */
static unsigned long
adr_ssl_id_function(void)
{
	pthread_t self = pthread_self();

	return ((unsigned long)self);
}
/* ARGSUSED */
/*
 * OpenSSL locking callback (installed by adr_ssl_init()): lock mutex n
 * when CRYPTO_LOCK is set in mode, otherwise unlock it.  file and line
 * are unused.
 */
static void
adr_ssl_locking_function(int mode, int n, const char *file, int line)
{
	pthread_mutex_t *mtx = &crypto_locks[n];

	if ((mode & CRYPTO_LOCK) == 0)
		(void) pthread_mutex_unlock(mtx);
	else
		(void) pthread_mutex_lock(mtx);
}
/*
 * Set up OpenSSL's thread support: allocate and initialize one mutex per
 * OpenSSL lock, then install the locking and thread-id callbacks.
 * Aborts the process if the mutex array cannot be allocated.
 */
void
adr_ssl_init(void)
{
	int nlocks = CRYPTO_num_locks();
	pthread_mutex_t *locks = malloc(nlocks * sizeof (pthread_mutex_t));

	if (locks == NULL)
		abort();

	for (int i = 0; i < nlocks; i++)
		(void) pthread_mutex_init(&locks[i], NULL);
	crypto_locks = locks;

	CRYPTO_set_locking_callback(adr_ssl_locking_function);
	CRYPTO_set_id_callback(adr_ssl_id_function);
}
/*
* SSL stream implementation
*
* Neither documentation nor internet chatter is clear on whether it is
* legal to call SSL_write after an SSL_read that returns
* SSL_ERROR_WANT_* before completing the SSL_read (and vice versa),
* but I don't see how the interfaces are otherwise usable without
* introducing cross-connection deadlocks.
*
* To avoid the single-side deadlock where one thread consumes socket
* data in the window between another thread dropping its lock and
* starting to poll for that same data, we call poll with a 5 second
* timeout. This bounds how long a thread can wait on data that has
* already been consumed and lets it recheck every once in a while in
* case this rare situation arises, all without the per-connection
* overhead that more precise solutions would require.
*/
/*
 * Per-stream state for the SSL stream.
 */
typedef struct adr_sslstream {
	SSL *ssl;		/* session; SSL_free()d by adr_sslstream_free() */
	int fd;			/* underlying socket, switched to O_NONBLOCK at creation */
	boolean_t closed;	/* once set, read/write return 0 without touching ssl */
	pthread_mutex_t lock;	/* held around SSL_read/SSL_write and guards closed */
} adr_sslstream_t;
/*
 * Mark the SSL stream closed, redirecting its socket descriptor to
 * /dev/null (the descriptor number itself is released later by
 * adr_sslstream_free()).  Idempotent.  Callers hold sslstr->lock.
 */
static void
adr_sslstream_close_locked(adr_sslstream_t *sslstr)
{
	int devnull;

	if (sslstr->closed)
		return;

	devnull = open("/dev/null", O_RDWR);
	if (devnull != -1) {
		(void) dup2(devnull, sslstr->fd);
		(void) close(devnull);
	}
	sslstr->closed = B_TRUE;
}
/*
 * Handle a non-positive result res from SSL_read()/SSL_write().
 *
 * Must be called with sslstr->lock held; the lock is always released
 * before this function returns.
 *
 * If OpenSSL needs the socket to become readable or writable
 * (SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE), the lock is dropped, the
 * socket is polled for that condition for up to 5 seconds, and B_TRUE is
 * returned so the caller retries the same call.  For any other error, or
 * if poll() fails with a non-transient error, the stream is marked closed
 * and B_FALSE is returned.
 */
static boolean_t
adr_sslstream_wait(adr_sslstream_t *sslstr, int res)
{
	/*
	 * OpenSSL requires SSL_get_error() to be consulted for every
	 * result <= 0, not just -1.
	 */
	if (res <= 0) {
		int err = SSL_get_error(sslstr->ssl, res);
		struct pollfd pfd;

		pfd.fd = sslstr->fd;
		if (err == SSL_ERROR_WANT_READ)
			pfd.events = POLLIN | POLLRDNORM;
		else if (err == SSL_ERROR_WANT_WRITE)
			pfd.events = POLLOUT | POLLWRNORM;
		else
			goto out;

		/*
		 * We weren't able to read/write everything required.
		 * Poll for the ability to do so, and retry.
		 *
		 * Use a timeout in case a renegotiation caused another
		 * thread to read/write the data we were polling for.
		 * Much simpler than implementing a precise wakeup.
		 */
		(void) pthread_mutex_unlock(&sslstr->lock);
		if (poll(&pfd, 1, 5 * MILLISEC) >= 0)
			return (B_TRUE);
		/*
		 * An interrupting signal (EINTR) or a transient resource
		 * shortage (EAGAIN) is no reason to tear down the
		 * connection; let the caller retry.
		 */
		if (errno == EINTR || errno == EAGAIN)
			return (B_TRUE);
		(void) pthread_mutex_lock(&sslstr->lock);
	}
out:
	adr_sslstream_close_locked(sslstr);
	(void) pthread_mutex_unlock(&sslstr->lock);
	return (B_FALSE);
}
/*
 * Read callback for SSL streams.
 *
 * Returns the number of bytes read (> 0), 0 if the stream is closed or
 * len is 0, or the non-positive SSL_read() result after
 * adr_sslstream_wait() has closed the stream.  WANT_READ/WANT_WRITE
 * conditions are retried transparently with the lock dropped.
 *
 * SSL_read() takes an int length, so requests larger than INT_MAX are
 * clamped and complete as a short read, like read(2).
 */
static ssize_t
adr_sslstream_read(void *data, char *buf, size_t len)
{
	adr_sslstream_t *sslstr = data;
	int n = (len > INT_MAX) ? INT_MAX : (int)len;
	int res;

	/* Like read(2), a zero-length request does nothing. */
	if (n == 0)
		return (0);

	do {
		(void) pthread_mutex_lock(&sslstr->lock);
		/* Once closed, report EOF without touching the session. */
		if (sslstr->closed) {
			(void) pthread_mutex_unlock(&sslstr->lock);
			return (0);
		}
		/* SSL_get_error() depends on a clean per-thread error queue. */
		ERR_clear_error();
		/* Retries must repeat the call with the same buf and n. */
		res = SSL_read(sslstr->ssl, buf, n);
		if (res > 0) {
			(void) pthread_mutex_unlock(&sslstr->lock);
			return (res);
		}
		/* adr_sslstream_wait() always drops the lock. */
	} while (adr_sslstream_wait(sslstr, res));

	return (res);
}
/*
 * Write callback for SSL streams.
 *
 * Returns the number of bytes written (> 0), 0 if the stream is closed
 * or len is 0, or the non-positive SSL_write() result after
 * adr_sslstream_wait() has closed the stream.  WANT_READ/WANT_WRITE
 * conditions are retried transparently with the lock dropped.
 *
 * SSL_write() takes an int length, so requests larger than INT_MAX are
 * clamped and complete as a short write, like write(2).  Zero-length
 * writes are answered locally because OpenSSL does not define
 * SSL_write() with num == 0.
 */
static ssize_t
adr_sslstream_write(void *data, const char *buf, size_t len)
{
	adr_sslstream_t *sslstr = data;
	int n = (len > INT_MAX) ? INT_MAX : (int)len;
	int res;

	if (n == 0)
		return (0);

	do {
		(void) pthread_mutex_lock(&sslstr->lock);
		/* Once closed, report nothing written. */
		if (sslstr->closed) {
			(void) pthread_mutex_unlock(&sslstr->lock);
			return (0);
		}
		/* SSL_get_error() depends on a clean per-thread error queue. */
		ERR_clear_error();
		/* Retries must repeat the call with the same buf and n. */
		res = SSL_write(sslstr->ssl, buf, n);
		if (res > 0) {
			(void) pthread_mutex_unlock(&sslstr->lock);
			return (res);
		}
		/* adr_sslstream_wait() always drops the lock. */
	} while (adr_sslstream_wait(sslstr, res));

	return (res);
}
/*
 * Close callback for SSL streams: mark the stream closed under its lock
 * so concurrent readers/writers observe it consistently.
 */
static void
adr_sslstream_close(void *data)
{
	adr_sslstream_t *self = data;
	pthread_mutex_t *lk = &self->lock;

	(void) pthread_mutex_lock(lk);
	adr_sslstream_close_locked(self);
	(void) pthread_mutex_unlock(lk);
}
static void
adr_sslstream_free(void *data)
{
adr_sslstream_t *sslstr = data;
SSL_free(sslstr->ssl);
(void) close(sslstr->fd);
free(data);
}
/*
 * Wrap an established SSL session and its underlying socket fd in an
 * adr_stream.  Ownership of both ssl and fd passes to the stream: on any
 * failure they are released (SSL_free()/close()) and NULL is returned.
 *
 * The socket is switched to non-blocking mode because the read/write
 * callbacks call SSL_read()/SSL_write() with the per-stream lock held and
 * rely on SSL_ERROR_WANT_* to drop that lock and poll.
 */
adr_stream_t *
adr_stream_create_ssl(SSL *ssl, int fd)
{
	adr_sslstream_t *sslstr;
	int flags;

	/*
	 * Fail creation rather than ignore errors here: an unchecked
	 * F_GETFL failure (-1) would turn into F_SETFL with every flag bit
	 * set, and a socket left blocking could stall inside SSL_read() or
	 * SSL_write() while holding the stream lock.
	 */
	if ((flags = fcntl(fd, F_GETFL, 0)) == -1 ||
	    fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1)
		goto fail;

	if ((sslstr = malloc(sizeof (*sslstr))) == NULL)
		goto fail;

	if (pthread_mutex_init(&sslstr->lock, NULL) != 0) {
		free(sslstr);
		goto fail;
	}
	sslstr->ssl = ssl;
	sslstr->fd = fd;
	sslstr->closed = B_FALSE;

	/*
	 * If this fails, adr_stream_create() runs the close and free
	 * callbacks itself, which release ssl, fd and sslstr.
	 */
	return (adr_stream_create(adr_sslstream_read, adr_sslstream_write,
	    adr_sslstream_close, adr_sslstream_free, sslstr));

fail:
	SSL_free(ssl);
	(void) close(fd);
	return (NULL);
}