-
Notifications
You must be signed in to change notification settings - Fork 245
chore: port browser-proxy to pg-gateway Web API #100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,34 @@ | ||
import * as nodeNet from 'node:net' | ||
import * as https from 'node:https' | ||
import { PostgresConnection } from 'pg-gateway' | ||
import { BackendError, PostgresConnection } from 'pg-gateway' | ||
import { fromNodeSocket } from 'pg-gateway/node' | ||
import { WebSocketServer, type WebSocket } from 'ws' | ||
import makeDebug from 'debug' | ||
import * as tls from 'node:tls' | ||
import { extractDatabaseId, isValidServername } from './servername.ts' | ||
import { getTls } from './tls.ts' | ||
import { getTls, setSecureContext } from './tls.ts' | ||
import { createStartupMessage } from './create-message.ts' | ||
import { extractIP } from './extract-ip.ts' | ||
|
||
const debug = makeDebug('browser-proxy') | ||
|
||
const tcpConnections = new Map<string, nodeNet.Socket>() | ||
const tcpConnections = new Map<string, PostgresConnection>() | ||
const websocketConnections = new Map<string, WebSocket>() | ||
|
||
let tlsOptions = await getTls() | ||
|
||
// refresh the TLS certificate every week | ||
setInterval( | ||
async () => { | ||
tlsOptions = await getTls() | ||
httpsServer.setSecureContext(tlsOptions) | ||
}, | ||
1000 * 60 * 60 * 24 * 7 | ||
) | ||
|
||
const httpsServer = https.createServer({ | ||
...tlsOptions, | ||
SNICallback: (servername, callback) => { | ||
debug('SNICallback', servername) | ||
if (isValidServername(servername)) { | ||
debug('SNICallback', 'valid') | ||
callback(null, tls.createSecureContext(tlsOptions)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in the end we don't need that, the default secureContext is used. |
||
callback(null) | ||
} else { | ||
debug('SNICallback', 'invalid') | ||
callback(new Error('invalid SNI')) | ||
} | ||
}, | ||
}) | ||
await setSecureContext(httpsServer) | ||
// reset the secure context every week to pick up any new TLS certificates | ||
setInterval(() => setSecureContext(httpsServer), 1000 * 60 * 60 * 24 * 7) | ||
|
||
const websocketServer = new WebSocketServer({ | ||
server: httpsServer, | ||
|
@@ -70,8 +61,8 @@ websocketServer.on('connection', (socket, request) => { | |
|
||
socket.on('message', (data: Buffer) => { | ||
debug('websocket message', data.toString('hex')) | ||
const tcpSocket = tcpConnections.get(databaseId) | ||
tcpSocket?.write(data) | ||
const tcpConnection = tcpConnections.get(databaseId) | ||
tcpConnection?.streamWriter?.write(data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here is why I needed to add the |
||
}) | ||
|
||
socket.on('close', () => { | ||
|
@@ -86,50 +77,41 @@ const net = ( | |
|
||
const tcpServer = net.createServer() | ||
|
||
tcpServer.on('connection', (socket) => { | ||
tcpServer.on('connection', async (socket) => { | ||
let databaseId: string | undefined | ||
|
||
const connection = new PostgresConnection(socket, { | ||
tls: tlsOptions, | ||
const connection = await fromNodeSocket(socket, { | ||
tls: getTls, | ||
Comment on lines
-93
to
+84
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. now the TLS certificates should be correctly refreshed after some time. 😅 |
||
onTlsUpgrade(state) { | ||
if (!state.tlsInfo?.sniServerName || !isValidServername(state.tlsInfo.sniServerName)) { | ||
// connection.detach() | ||
connection.sendError({ | ||
if (!state.tlsInfo?.serverName || !isValidServername(state.tlsInfo.serverName)) { | ||
throw BackendError.create({ | ||
code: '08006', | ||
message: 'invalid SNI', | ||
severity: 'FATAL', | ||
}) | ||
connection.end() | ||
return | ||
} | ||
|
||
const _databaseId = extractDatabaseId(state.tlsInfo.sniServerName!) | ||
const _databaseId = extractDatabaseId(state.tlsInfo.serverName!) | ||
|
||
if (!websocketConnections.has(_databaseId!)) { | ||
// connection.detach() | ||
connection.sendError({ | ||
throw BackendError.create({ | ||
code: 'XX000', | ||
message: 'the browser is not sharing the database', | ||
severity: 'FATAL', | ||
}) | ||
connection.end() | ||
return | ||
} | ||
|
||
if (tcpConnections.has(_databaseId)) { | ||
// connection.detach() | ||
connection.sendError({ | ||
throw BackendError.create({ | ||
code: '53300', | ||
message: 'sorry, too many clients already', | ||
severity: 'FATAL', | ||
}) | ||
connection.end() | ||
return | ||
} | ||
|
||
// only set the databaseId after we've verified the connection | ||
databaseId = _databaseId | ||
tcpConnections.set(databaseId!, connection.socket) | ||
tcpConnections.set(databaseId!, connection) | ||
}, | ||
serverVersion() { | ||
return '16.3' | ||
|
@@ -138,13 +120,11 @@ tcpServer.on('connection', (socket) => { | |
const websocket = websocketConnections.get(databaseId!) | ||
|
||
if (!websocket) { | ||
connection.sendError({ | ||
throw BackendError.create({ | ||
code: 'XX000', | ||
message: 'the browser is not sharing the database', | ||
severity: 'FATAL', | ||
}) | ||
connection.end() | ||
return | ||
} | ||
|
||
const clientIpMessage = createStartupMessage('postgres', 'postgres', { | ||
|
@@ -160,13 +140,11 @@ tcpServer.on('connection', (socket) => { | |
const websocket = websocketConnections.get(databaseId!) | ||
|
||
if (!websocket) { | ||
connection.sendError({ | ||
throw BackendError.create({ | ||
code: 'XX000', | ||
message: 'the browser is not sharing the database', | ||
severity: 'FATAL', | ||
}) | ||
connection.end() | ||
return | ||
} | ||
|
||
debug('tcp message', { message }) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,12 @@ | ||
import { Buffer } from 'node:buffer' | ||
import { GetObjectCommand, S3Client } from '@aws-sdk/client-s3' | ||
import pMemoize from 'p-memoize' | ||
import ExpiryMap from 'expiry-map' | ||
import type { Server } from 'node:https' | ||
|
||
const s3Client = new S3Client({ forcePathStyle: true }) | ||
|
||
export async function getTls() { | ||
async function _getTls() { | ||
const cert = await s3Client | ||
.send( | ||
new GetObjectCommand({ | ||
|
@@ -31,3 +34,12 @@ export async function getTls() { | |
key: Buffer.from(key), | ||
} | ||
} | ||
|
||
// cache the TLS certificate for 1 week | ||
const cache = new ExpiryMap(1000 * 60 * 60 * 24 * 7) | ||
Comment on lines
+38
to
+39
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if this server restarted 1 day before renewal? I suppose the old cert is still good for another week, so this should be safe? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I need to check but I think we renew the certificate monthly while the validity is 3 months. So in any cases with 2 week refresh we should be safe. |
||
export const getTls = pMemoize(_getTls, { cache }) | ||
|
||
export async function setSecureContext(httpsServer: Server) { | ||
const tlsOptions = await getTls() | ||
httpsServer.setSecureContext(tlsOptions) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,7 +138,17 @@ export default function AppProvider({ children }: AppProps) { | |
if (isStartupMessage(message)) { | ||
const parameters = parseStartupMessage(message) | ||
if ('client_ip' in parameters) { | ||
setConnectedClientIp(parameters.client_ip === '' ? null : parameters.client_ip) | ||
// client disconnected | ||
if (parameters.client_ip === '') { | ||
setConnectedClientIp(null) | ||
// we ensure we're not in a transaction block first | ||
await db.sql`rollback;`.catch() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How bizarre! |
||
// we clean the session state, see: https://www.pgbouncer.org/faq.html#how-to-use-prepared-statements-with-session-pooling | ||
// we do this to avoid having old prepared statements in the session | ||
await db.sql`discard all;` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added after working on making There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose we should mention |
||
} else { | ||
setConnectedClientIp(parameters.client_ip) | ||
} | ||
} | ||
return | ||
} | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.