package io.ktor.swagger.experimental

import com.fasterxml.jackson.module.kotlin.*
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
import kotlinx.coroutines.*
import io.ktor.application.*
import io.ktor.auth.authenticate
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.response.*
import io.ktor.content.*
import io.ktor.http.*
import io.ktor.request.*
import io.ktor.response.*
import io.ktor.routing.*
import io.ktor.util.*
import java.lang.reflect.*
import java.lang.reflect.Type

/**
 * Exception carrying an HTTP status [code] and a human readable [description]
 * (which is also used as the exception message).
 */
class HttpException(val code: HttpStatusCode, val description: String = code.description) :
    RuntimeException(description)

/** Throws an [HttpException] for the given status [code]; never returns. */
fun httpException(code: HttpStatusCode, message: String = code.description): Nothing =
    throw HttpException(code, message)

/** Throws an [HttpException] built from a numeric status [code]; never returns. */
fun httpException(code: Int, message: String = "Error $code"): Nothing =
    throw HttpException(HttpStatusCode(code, message))

/**
 * Returns the receiver if [callback] accepts it.
 *
 * @throws IllegalArgumentException with [name] as its message when [callback] returns `false`.
 */
@Suppress("unused")
inline fun <T> T.verifyParam(name: String, callback: (T) -> Boolean): T {
    if (!callback(this)) throw IllegalArgumentException(name)
    return this
}

/**
 * Fails with a `400 Bad Request` [HttpException] when [cond] is `false`.
 * [callback] lazily supplies the error message and is only invoked on failure.
 */
inline fun <T> T.checkRequest(cond: Boolean, callback: () -> String) {
    if (!cond) httpException(HttpStatusCode.BadRequest, callback())
}

/** Marker interface for API definitions that [createClient] can proxy. */
interface SwaggerBaseApi {
}

/** Marker interface for server implementations accepted by [registerRoutes]. */
interface SwaggerBaseServer {
}

/**
 * Coroutine context element exposing the [ApplicationCall] being handled, so that
 * server methods can retrieve it through [call].
 */
class ApplicationCallContext(val call: ApplicationCall) : CoroutineContext.Element {
    object KEY : CoroutineContext.Key<ApplicationCallContext>

    override val key: CoroutineContext.Key<*> = KEY
}

/**
 * Returns the [ApplicationCall] stored in the current coroutine context.
 *
 * @throws IllegalStateException when no [ApplicationCallContext] is present in the context.
 */
@Suppress("unused")
suspend fun SwaggerBaseServer.call(): ApplicationCall =
    coroutineContext[ApplicationCallContext.KEY]?.call ?: error("ApplicationCall not available")

annotation class Method(val method: String)
annotation class Body(val name: String)
annotation class Header(val name: String)
annotation class Query(val name: String)
annotation class Path(val name: String) // Reused
annotation class FormData(val name: String)
annotation class Auth(vararg val auths: String)

//interface FeatureClass
//annotation class Feature(val clazz: KClass<out FeatureClass>)

/** Reified convenience overload of [createClient]. */
inline fun <reified T : SwaggerBaseApi> createClient(client: HttpClient, rootUrl: String): T =
    createClient(T::class.java, client, rootUrl)

/**
 * Creates a dynamic proxy implementing [clazz] whose suspend methods perform HTTP requests
 * through [client] against [rootUrl].
 *
 * For each invocation the annotated parameters are dispatched by their [Source]: query and header
 * values are stringified, body values are sent together as one JSON object, form-data values as a
 * url-encoded form, and path values are substituted into the method's path pattern.
 * Responses with a status below 400 are parsed as JSON into the method's declared return type;
 * anything else resumes the caller with [HttpExceptionWithContent].
 *
 * Only suspend functions are supported; other methods fail with a [RuntimeException].
 *
 * NOTE(review): the request runs in [GlobalScope] and the caller's continuation is resumed
 * directly from that coroutine, so cancelling the caller does not cancel the request.
 */
@Suppress("UNCHECKED_CAST")
fun <T : SwaggerBaseApi> createClient(clazz: Class<T>, client: HttpClient, rootUrl: String): T {
    val rootUrlTrim = rootUrl.trimEnd('/')
    val apiClass = ApiClass.parse(clazz)

    return Proxy.newProxyInstance(clazz.classLoader, arrayOf(clazz)) { _, method, args ->
        val info = apiClass.getInfo(method) ?: error("Can't find method $method")

        // InvocationHandler receives `null` instead of an empty array for zero-argument methods.
        val callArgs: Array<out Any?> = args ?: emptyArray<Any?>()

        // `info.params` excludes the trailing continuation, so zip stops right before it.
        val rparams = info.params.zip(callArgs)
            .map { (paramInfo, value) ->
                ApiClass.ApiParamInfoValue(paramInfo as ApiClass.ApiParamInfo<Any?>, value)
            }
            .associateBy { it.name }

        val cont = callArgs.lastOrNull() as? Continuation<Any?>
            ?: throw RuntimeException("Just implemented suspend functions")

        // The continuation is the LAST parameter of a compiled suspend function; its type argument
        // carries the declared return type.
        val continuationReturnType = method.genericParameterTypes.lastOrNull()?.extractFirstGenericType()
        val realReturnType = continuationReturnType ?: method.returnType

        val pathPattern = info.path
        val pathReplaced = pathPattern.replace { "${rparams[it]?.value}" }

        GlobalScope.launch {
            try {
                val fullUrl = "$rootUrlTrim/$pathReplaced"
                val res = client.call(fullUrl) {
                    this.method = HttpMethod(info.httpMethod)
                    val body = linkedMapOf<String, Any?>()
                    val formData = linkedMapOf<String, Any?>()
                    for (param in rparams.values) {
                        when (param.source) {
                            Source.QUERY -> parameter(param.name, "${param.value}")
                            Source.HEADER -> header(param.name, "${param.value}")
                            Source.BODY -> body[param.name] = param.value
                            Source.FORM_DATA -> formData[param.name] = param.value
                            // Path values were already substituted into the URL above.
                            Source.PATH -> Unit
                        }
                    }
                    if (body.isNotEmpty()) {
                        this.contentType(io.ktor.http.ContentType.Application.Json)
                        this.body = ByteArrayContent(Json.stringify(body).toByteArray(Charsets.UTF_8))
                    }
                    if (formData.isNotEmpty()) {
                        this.contentType(io.ktor.http.ContentType.Application.FormUrlEncoded)
                        this.body = ByteArrayContent(
                            formData.map { it.key to it.value.toString() }
                                .formUrlEncode()
                                .toByteArray(Charsets.UTF_8)
                        )
                    }
                }
                if (res.response.status.value < 400) {
                    cont.resume(Json.parse(res.response.readText(), realReturnType))
                } else {
                    throw HttpExceptionWithContent(res.response.status, res.response.readText())
                }
            } catch (e: Throwable) {
                cont.resumeWithException(e)
            }
        }
        // Tell the calling state machine that the result will be delivered through `cont`.
        COROUTINE_SUSPENDED
    } as T
}

/** Raised by clients built with [createClient] when the server answers with a status >= 400. */
class HttpExceptionWithContent(val code: HttpStatusCode, val content: String) :
    RuntimeException("HTTP ERROR $code : $content")

/**
 * Registers one route per [Path]-annotated method of [server].
 *
 * Each handler resolves the method arguments from the call, invokes the method with an
 * [ApplicationCallContext] installed in the coroutine context, and responds with the result
 * serialized as JSON (a fresh `Any()` is serialized when the result is `null`).
 */
fun Routing.registerRoutes(server: SwaggerBaseServer) {
    val clazz = ApiClass.parse(server::class.java)
    for (method in clazz.methods) {
        authenticateIfNotEmpty(method.auths) {
            route(method.path.pathPattern, HttpMethod(method.httpMethod)) {
                handle {
                    val args = method.params.map { it.get(call) }
                    withContext(ApplicationCallContext(call)) {
                        val result = method.method.invokeSuspend(server, args)
                        call.respondText(Json.stringify(result ?: Any()), ContentType.Application.Json)
                    }
                }
            }
        }
    }
}

/** Where a method parameter is read from (server) or written to (client). */
enum class Source {
    BODY, QUERY, FORM_DATA, HEADER, PATH
}

/**
 * Runs [build] inside an `authenticate` block for [configurations], or directly on this route
 * when [configurations] is empty.
 */
fun Route.authenticateIfNotEmpty(
    configurations: List<String>,
    optional: Boolean = false,
    build: Route.() -> Unit
): Route {
    return if (configurations.isEmpty()) {
        build()
        this
    } else {
        authenticate(*configurations.toTypedArray(), optional = optional, build = build)
    }
}

/**
 * Reflection model of an API interface/class: the list of its [Path]-annotated methods,
 * indexed by [MethodSignature].
 */
class ApiClass(val clazz: Class<*>, val methods: List<ApiMethodInfo>) {
    val methodsBySignature = methods.associateBy { it.methodSignature }

    /** Returns the parsed info for [method], or `null` if it is not a known API method. */
    fun getInfo(method: java.lang.reflect.Method) = methodsBySignature[method.signature]

    /** A single API method together with its routing metadata. */
    class ApiMethodInfo(
        val method: java.lang.reflect.Method,
        val path: PathPattern,
        val httpMethod: String,
        val auths: List<String>,
        val params: List<ApiParamInfo<*>>
    ) {
        val methodSignature = method.signature
    }

    /** Describes one (non-continuation) parameter of an API method. */
    data class ApiParamInfo<T>(val source: Source, val name: String, val type: Class<T>) {
        /** Reads and converts this parameter from [call]. */
        suspend fun get(call: ApplicationCall): T {
            return call.getTyped(source, name, type)
        }
    }

    /** An [ApiParamInfo] paired with the concrete [value] supplied for it. */
    data class ApiParamInfoValue<T>(val info: ApiParamInfo<T>, val value: T) {
        val source get() = info.source
        val type get() = info.type
        val name get() = info.name
    }

    /** A path template such as `users/{id}` whose `{name}` placeholders can be replaced or matched. */
    class PathPattern(val pathPattern: String) {
        companion object {
            val PARAM_REGEX = Regex("\\{(\\w*)\\}")
        }

        /** Placeholder names in order of appearance. */
        val pathNames by lazy { PARAM_REGEX.findAll(pathPattern).map { it.groupValues[1] }.toList() }

        /** Regex obtained by turning every placeholder into a `(\w+)` capture group. */
        val pathRegex by lazy { Regex(replace { "(\\w+)" }) }

        /** Returns the pattern with each `{name}` substituted by `replacer(name)`. */
        fun replace(replacer: (name: String) -> String): String {
            return pathPattern.replace(PARAM_REGEX) { mr -> replacer(mr.groupValues[1]) }
        }

        /** Returns the placeholder values found in [path], or an empty list when it does not match. */
        fun extract(path: String): List<String> {
            return pathRegex.find(path)?.groupValues?.drop(1) ?: listOf()
        }
    }

    companion object {
        /**
         * Builds an [ApiClass] from the public methods of [clazz] that carry a [Path] annotation
         * (looked up on the method itself or its ancestors). The HTTP method defaults to `GET`.
         * Parameters without a recognised annotation are treated as query parameters named `unknown`.
         */
        fun parse(clazz: Class<*>): ApiClass {
            val imethods = arrayListOf<ApiMethodInfo>()
            for (method in clazz.methods) {
                val path = method.getAnnotationInAncestors(Path::class.java)?.name ?: continue
                val httpMethod = method.getAnnotationInAncestors(Method::class.java)?.method ?: "GET"

                val params = arrayListOf<ApiParamInfo<*>>()
                for ((ptype, annotations) in method.parameterTypes.zip(method.parameterAnnotationsInAncestors)) {
                    // Skip the continuation last argument!
                    // (Checked in this direction so that parameters typed `Any` are not skipped too.)
                    if (Continuation::class.java.isAssignableFrom(ptype)) continue

                    val body = annotations.filterIsInstance<Body>().firstOrNull()?.name
                    val query = annotations.filterIsInstance<Query>().firstOrNull()?.name
                    val formData = annotations.filterIsInstance<FormData>().firstOrNull()?.name
                    val header = annotations.filterIsInstance<Header>().firstOrNull()?.name
                    val ppath = annotations.filterIsInstance<Path>().firstOrNull()?.name

                    val source = when {
                        body != null -> Source.BODY
                        query != null -> Source.QUERY
                        formData != null -> Source.FORM_DATA
                        header != null -> Source.HEADER
                        ppath != null -> Source.PATH
                        else -> Source.QUERY
                    }
                    val rname = body ?: query ?: formData ?: header ?: ppath ?: "unknown"

                    params += ApiParamInfo(source, rname, ptype)
                }

                val auths = method.getAnnotationInAncestors(Auth::class.java)?.auths?.toList() ?: listOf()
                imethods += ApiMethodInfo(method, PathPattern(path.trim('/')), httpMethod, auths, params)
            }
            return ApiClass(clazz, imethods)
        }
    }
}

/** Method identity based on its name and raw parameter types. */
data class MethodSignature(val name: String, val types: List<Class<*>>)

val java.lang.reflect.Method.signature get() = MethodSignature(name, parameterTypes.toList())

/** Thin wrapper around a shared Jackson object mapper. */
object Json {
    @PublishedApi
    internal val objectMapper = jacksonObjectMapper()

    fun <T> convert(value: Any?, clazz: Class<T>): T = objectMapper.convertValue(value, clazz)
    fun <T> parse(str: String, clazz: Class<T>): T = objectMapper.readValue(str, clazz)
    fun <T> stringify(value: T): String = objectMapper.writeValueAsString(value)
}

/** Reads the raw value for [name] from [source] and converts it to [clazz]. */
suspend fun <T> ApplicationCall.getTyped(source: Source, name: String, clazz: Class<T>): T {
    return Json.convert(getRaw(source, name), clazz)
}

/** Like [getTyped], but returns `null` when the raw value is absent. */
suspend fun <T> ApplicationCall.getTypedOrNull(source: Source, name: String, clazz: Class<T>): T? =
    getRaw(source, name)?.let { Json.convert(it, clazz) }

/**
 * Returns the unconverted value of [name] from the given [source] of this call.
 * Body values come from the request body parsed once as a JSON object.
 * `Source.FORM_DATA` is not implemented and throws [NotImplementedError].
 */
suspend fun ApplicationCall.getRaw(source: Source, name: String): Any? {
    return when (source) {
        Source.PATH -> this.parameters.get(name)
        Source.QUERY -> this.request.queryParameters.get(name)
        Source.BODY -> this.getCachedUntypedBody()[name]
        Source.FORM_DATA -> TODO()
        Source.HEADER -> this.request.header(name)
    }
}

val CACHED_BODY_KEY = AttributeKey<Any>("CACHED_BODY_KEY")

/** Returns the attribute stored under [key], computing and storing it with [block] if absent. */
inline fun <T : Any> Attributes.computeIfAbsentInline(key: AttributeKey<T>, block: () -> T): T {
    if (!this.contains(key)) {
        this.put(key, block())
    }
    return this[key]
}

/** Receives the request body as text once, parses it as a JSON map and caches it in the call attributes. */
@Suppress("UNCHECKED_CAST")
private suspend fun ApplicationCall.getCachedUntypedBody(): Map<String, Any?> {
    return attributes.computeIfAbsentInline(CACHED_BODY_KEY) {
        Json.parse(this@getCachedUntypedBody.receive<String>(), HashMap::class.java)
    } as Map<String, Any?>
}

///////////////////////////////////////////////////
// Reflection Tools
///////////////////////////////////////////////////

/** This class plus every superclass and interface reachable from it. */
val java.lang.Class<*>.allTypes: Set<Class<*>>
    get() {
        val types = LinkedHashSet<Class<*>>()
        val explore = arrayListOf<Class<*>>(this)
        while (explore.isNotEmpty()) {
            val item = explore.removeAt(explore.size - 1)
            // Already visited (e.g. an interface reachable through several paths).
            if (!types.add(item)) continue
            item.superclass?.let { explore += it }
            explore += item.interfaces
        }
        return types
    }

/**
 * For each parameter position, the annotations declared on that parameter in this method's
 * declaring class and in every ancestor type that declares a method with the same name and
 * parameter types.
 */
val java.lang.reflect.Method.parameterAnnotationsInAncestors: List<List<Annotation>>
    get() {
        val allMethods = this.declaringClass.allTypes.mapNotNull {
            try {
                it.getDeclaredMethod(name, *parameterTypes)
            } catch (e: NoSuchMethodException) {
                null
            }
        }
        val out = List(parameterTypes.size) { arrayListOf<Annotation>() }
        for (method in allMethods) {
            for ((index, annotations) in method.parameterAnnotations.withIndex()) {
                out[index] += annotations
            }
        }
        return out
    }

/**
 * Invokes this method reflectively from a coroutine.
 *
 * When the method's last parameter is a [Continuation] (i.e. it is a compiled suspend function),
 * the current continuation is appended to [args]; if the call then returns [COROUTINE_SUSPENDED]
 * the callee is responsible for resuming it. Exceptions thrown by the target are unwrapped from
 * [InvocationTargetException].
 */
suspend fun java.lang.reflect.Method.invokeSuspend(obj: Any?, args: List<Any?>): Any? = suspendCoroutine { c ->
    val method = this@invokeSuspend

    val lastParam = method.parameterTypes.lastOrNull()
    val margs: MutableList<Any?> = args.toMutableList()

    if (lastParam != null && Continuation::class.java.isAssignableFrom(lastParam)) {
        margs += c
    }
    try {
        val result = method.invoke(obj, *margs.toTypedArray())
        if (result != COROUTINE_SUSPENDED) {
            c.resume(result)
        }
    } catch (e: InvocationTargetException) {
        c.resumeWithException(e.targetException)
    } catch (e: Throwable) {
        c.resumeWithException(e)
    }
}

/**
 * Looks for an annotation of type [clazz] on this method, then on the same-signature method of
 * the declaring class's interfaces, then on the superclass chain. Lookup failures are ignored.
 */
fun <T : Annotation> java.lang.reflect.Method.getAnnotationInAncestors(clazz: Class<T>): T? {
    val res = this.getAnnotation(clazz) ?: this.getDeclaredAnnotation(clazz)
    if (res != null) return res

    // Try interfaces
    for (ifc in this.declaringClass.interfaces) {
        val found = ignoreErrors { ifc?.getMethod(name, *parameterTypes)?.getAnnotationInAncestors(clazz) }
        if (found != null) return found
    }

    // Try ancestor
    return ignoreErrors { this.declaringClass.superclass?.getMethod(name, *parameterTypes) }
        ?.getAnnotationInAncestors(clazz)
}

/** Runs [callback], returning `null` instead of propagating anything it throws. */
inline fun <T> ignoreErrors(callback: () -> T): T? = try {
    callback()
} catch (e: Throwable) {
    null
}

/**
 * Resolves this type to a raw [Class] by repeatedly taking the first type argument of
 * parameterized types and the first lower (else upper) bound of wildcards.
 *
 * NOTE(review): because parameterized types are unwrapped too, something like `List<User>`
 * resolves to `User`, not `List` — confirm this is intended by callers.
 */
fun Type.extractFirstGenericType(): Class<*> = when (this) {
    is ParameterizedType -> this.actualTypeArguments.first().extractFirstGenericType()
    is WildcardType -> {
        val tt = this.lowerBounds.firstOrNull()
            ?: this.upperBounds.firstOrNull()
            ?: error("WildcardType without lower/upper bounds")
        tt.extractFirstGenericType()
    }
    is Class<*> -> this
    else -> error("Couldn't find right generic type")
}

/** Reads body field [name] as [T], falling back to [default] (which fails by default) when absent. */
suspend inline fun <reified T : Any> ApplicationCall.getBodyParam(
    name: String,
    noinline default: () -> T = { error("mandatory $name") }
): T = getTypedOrNull(Source.BODY, name, T::class.java) ?: default()

/** Reads path parameter [name] as [T], falling back to [default] (which fails by default) when absent. */
suspend inline fun <reified T : Any> ApplicationCall.getPath(
    name: String,
    noinline default: () -> T = { error("mandatory $name") }
): T = getTypedOrNull(Source.PATH, name, T::class.java) ?: default()

/** Reads query parameter [name] as [T], falling back to [default] (which fails by default) when absent. */
suspend inline fun <reified T : Any> ApplicationCall.getQuery(
    name: String,
    noinline default: () -> T = { error("mandatory $name") }
): T = getTypedOrNull(Source.QUERY, name, T::class.java) ?: default()