NativeLambdaApi.kt
/*
* This file is part of the pl.wrzasq.commons.
*
* @license http://mit-license.org/ The MIT license
* @copyright 2021 - 2022 © by Rafał Wrzeszcz - Wrzasq.pl.
*/
package pl.wrzasq.commons.aws.runtime
import com.amazonaws.services.lambda.runtime.Context
import com.amazonaws.services.lambda.runtime.LambdaRuntime
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.encodeToStream
import pl.wrzasq.commons.aws.runtime.config.EnvironmentConfig
import pl.wrzasq.commons.aws.runtime.config.LambdaRuntimeConfig
import pl.wrzasq.commons.aws.runtime.config.ResourcesFactory
import pl.wrzasq.commons.aws.runtime.model.JsonClientContext
import pl.wrzasq.commons.aws.runtime.model.JsonCognitoIdentity
import pl.wrzasq.commons.aws.runtime.model.LambdaRuntimeError
import java.io.InputStream
import java.io.OutputStream
import java.net.HttpURLConnection
import kotlin.reflect.full.createInstance
/**
* Header name for request ID.
*/
const val HEADER_NAME_AWS_REQUEST_ID = "Lambda-Runtime-Aws-Request-Id"
/**
* Header name for X-Ray trace ID.
*/
const val HEADER_NAME_TRACE_ID = "Lambda-Runtime-Trace-Id"
/**
* Header name for function ARN.
*/
const val HEADER_NAME_INVOKED_FUNCTION_ARN = "Lambda-Runtime-Invoked-Function-Arn"
/**
* Header name for serialized Cognito identity.
*/
const val HEADER_NAME_COGNITO_IDENTITY = "Lambda-Runtime-Cognito-Identity"
/**
* Header name for serialized client context.
*/
const val HEADER_NAME_CLIENT_CONTEXT = "Lambda-Runtime-Client-Context"
/**
* Header name for running deadline.
*/
const val HEADER_NAME_DEADLINE_MS = "Lambda-Runtime-Deadline-Ms"
/**
* System property under which the X-Ray trace ID is propagated.
*/
const val PROPERTY_TRACE_ID = "com.amazonaws.xray.traceHeader"
/**
* Lambda handler callback.
*/
typealias LambdaCallback = (InputStream, OutputStream, Context) -> Unit
/**
* Native ("provided") Lambda API handler.
*
* @param json JSON serialization handler.
* @param config Function configuration.
*/
@ExperimentalSerializationApi
class NativeLambdaApi(
private val json: Json,
private val config: LambdaRuntimeConfig = EnvironmentConfig()
) {
/**
* Runs Lambda native handler.
*
* @param handler Lambda entry point.
*/
fun run(handler: LambdaCallback) = runBlocking {
try {
while (true) {
val requestConnection = config.connectionFactory("${config.baseUrl}invocation/next")
withContext(Dispatchers.IO) {
requestConnection.getInputStream()
}.use {
val requestId = requestConnection.getHeaderField(HEADER_NAME_AWS_REQUEST_ID)
// we handle both cases to clean up previous execution
val traceId = requestConnection.getHeaderField(HEADER_NAME_TRACE_ID)
if (traceId.isNullOrEmpty()) {
System.clearProperty(PROPERTY_TRACE_ID)
} else {
System.setProperty(PROPERTY_TRACE_ID, traceId)
}
try {
sendResponse("${config.baseUrl}invocation/${requestId}/response") { output ->
handler(it, output, buildContext(requestId, requestConnection as HttpURLConnection))
}
} catch (error: Exception) {
// TODO: handle case when Lambda handler itself returns LambdaRuntimeError
sendErrorResponse(
"${config.baseUrl}invocation/${requestId}/error",
"Failed to run Lambda.",
error
)
}
}
}
} catch (error: Exception) {
try {
sendErrorResponse("${config.baseUrl}init/error", "Failed to initialize Lambda.", error)
} catch (innerError: Exception) {
logError("Failed to report init error.", innerError)
}
}
}
private fun sendResponse(url: String, handler: (OutputStream) -> Unit) {
val responseConnection = config.connectionFactory(url) as HttpURLConnection
responseConnection.doOutput = true
responseConnection.requestMethod = "POST"
handler(responseConnection.outputStream)
while (responseConnection.inputStream.read() != -1) {
// drain
}
}
private fun sendErrorResponse(url: String, message: String, error: Exception) {
logError(message, error)
sendResponse(url) {
json.encodeToStream(
LambdaRuntimeError.serializer(),
LambdaRuntimeError(
error.javaClass.name,
error.message ?: message,
error.stackTrace.map(StackTraceElement::toString)
),
it
)
}
}
private fun logError(message: String, error: Exception) {
config.errorLogger(message)
config.errorLogger(error.stackTraceToString())
}
private fun buildContext(awsRequestId: String, request: HttpURLConnection): Context = StaticContext(
awsRequestId = awsRequestId,
logGroupName = config.logGroupName,
logStreamName = config.logStreamName,
functionName = config.functionName,
functionVersion = config.functionVersion,
invokedFunctionArn = request.getHeaderField(HEADER_NAME_INVOKED_FUNCTION_ARN),
cognitoIdentity = request.getHeaderField(HEADER_NAME_COGNITO_IDENTITY)
?.let<String, JsonCognitoIdentity>(json::decodeFromString),
clientContext = request.getHeaderField(HEADER_NAME_CLIENT_CONTEXT)
?.let<String, JsonClientContext>(json::decodeFromString),
runtimeDeadlineMs = request.getHeaderField(HEADER_NAME_DEADLINE_MS)?.let(String::toLong) ?: 0,
memoryLimitInMB = config.memoryLimit,
logger = LambdaRuntime.getLogger()
)
companion object {
/**
* Reflection entry point.
*
* @param handler Factory type name.
*/
fun runFromFactory(handler: String) {
val factory = Class.forName(handler).kotlin.createInstance()
if (factory is ResourcesFactory) {
factory.lambdaApi.run(factory.lambdaCallback)
} else {
throw RuntimeException("$handler is not a factory for Lambda resources")
}
}
}
}