Commit b3c6787e authored by Him188's avatar Him188

Introduce MultiPacket

parent 8fb20f4f
...@@ -8,8 +8,8 @@ import net.mamoe.mirai.data.ImageLink ...@@ -8,8 +8,8 @@ import net.mamoe.mirai.data.ImageLink
import net.mamoe.mirai.message.data.Image import net.mamoe.mirai.message.data.Image
import net.mamoe.mirai.qqandroid.network.QQAndroidBotNetworkHandler import net.mamoe.mirai.qqandroid.network.QQAndroidBotNetworkHandler
import net.mamoe.mirai.qqandroid.network.QQAndroidClient import net.mamoe.mirai.qqandroid.network.QQAndroidClient
import net.mamoe.mirai.qqandroid.network.protocol.packet.chat.receive.ImageIdQQA
import net.mamoe.mirai.qqandroid.utils.Context import net.mamoe.mirai.qqandroid.utils.Context
import net.mamoe.mirai.qqandroid.utils.ImageIdQQA
import net.mamoe.mirai.utils.BotConfiguration import net.mamoe.mirai.utils.BotConfiguration
import net.mamoe.mirai.utils.LockFreeLinkedList import net.mamoe.mirai.utils.LockFreeLinkedList
import net.mamoe.mirai.utils.MiraiInternalAPI import net.mamoe.mirai.utils.MiraiInternalAPI
...@@ -28,7 +28,7 @@ internal abstract class QQAndroidBotBase constructor( ...@@ -28,7 +28,7 @@ internal abstract class QQAndroidBotBase constructor(
configuration: BotConfiguration configuration: BotConfiguration
) : BotImpl<QQAndroidBotNetworkHandler>(account, configuration) { ) : BotImpl<QQAndroidBotNetworkHandler>(account, configuration) {
val client: QQAndroidClient = QQAndroidClient(context, account, bot = @Suppress("LeakingThis") this as QQAndroidBot) val client: QQAndroidClient = QQAndroidClient(context, account, bot = @Suppress("LeakingThis") this as QQAndroidBot)
override val uin: Long get() = client.uin
override val qqs: ContactList<QQ> = ContactList(LockFreeLinkedList()) override val qqs: ContactList<QQ> = ContactList(LockFreeLinkedList())
override fun getQQ(id: Long): QQ { override fun getQQ(id: Long): QQ {
......
package net.mamoe.mirai.qqandroid.network package net.mamoe.mirai.qqandroid.network
import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.io.core.* import kotlinx.io.core.*
import kotlinx.io.pool.ObjectPool import kotlinx.io.pool.ObjectPool
import net.mamoe.mirai.data.MultiPacket
import net.mamoe.mirai.data.Packet import net.mamoe.mirai.data.Packet
import net.mamoe.mirai.event.BroadcastControllable import net.mamoe.mirai.event.BroadcastControllable
import net.mamoe.mirai.event.Cancellable import net.mamoe.mirai.event.Cancellable
...@@ -14,6 +17,7 @@ import net.mamoe.mirai.qqandroid.event.PacketReceivedEvent ...@@ -14,6 +17,7 @@ import net.mamoe.mirai.qqandroid.event.PacketReceivedEvent
import net.mamoe.mirai.qqandroid.network.protocol.packet.KnownPacketFactories import net.mamoe.mirai.qqandroid.network.protocol.packet.KnownPacketFactories
import net.mamoe.mirai.qqandroid.network.protocol.packet.OutgoingPacket import net.mamoe.mirai.qqandroid.network.protocol.packet.OutgoingPacket
import net.mamoe.mirai.qqandroid.network.protocol.packet.PacketFactory import net.mamoe.mirai.qqandroid.network.protocol.packet.PacketFactory
import net.mamoe.mirai.qqandroid.network.protocol.packet.PacketLogger
import net.mamoe.mirai.qqandroid.network.protocol.packet.login.LoginPacket import net.mamoe.mirai.qqandroid.network.protocol.packet.login.LoginPacket
import net.mamoe.mirai.qqandroid.network.protocol.packet.login.LoginPacket.LoginPacketResponse.* import net.mamoe.mirai.qqandroid.network.protocol.packet.login.LoginPacket.LoginPacketResponse.*
import net.mamoe.mirai.qqandroid.network.protocol.packet.login.StatSvc import net.mamoe.mirai.qqandroid.network.protocol.packet.login.StatSvc
...@@ -40,9 +44,9 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler ...@@ -40,9 +44,9 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler
when (response) { when (response) {
is UnsafeLogin -> { is UnsafeLogin -> {
bot.logger.info("Login unsuccessful, device auth is needed") bot.logger.info("Login unsuccessful, device auth is needed")
bot.logger.info("登陆失败, 原因为非常用设备登陆") bot.logger.info("登录失败, 原因为非常用设备登录")
bot.logger.info("Open the following URL in QQ browser and complete the verification") bot.logger.info("Open the following URL in QQ browser and complete the verification")
bot.logger.info("将下面这个链接在QQ浏览器中打开并完成认证后尝试再次登") bot.logger.info("将下面这个链接在QQ浏览器中打开并完成认证后尝试再次登")
bot.logger.info(response.url) bot.logger.info(response.url)
return return
} }
...@@ -101,10 +105,14 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler ...@@ -101,10 +105,14 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler
@Suppress("PrivatePropertyName") @Suppress("PrivatePropertyName")
private val PacketProcessDispatcher = newCoroutineDispatcher(1) private val PacketProcessDispatcher = newCoroutineDispatcher(1)
/**
* 缓存超时处理的 [Job]. 超时后将清空缓存, 以免阻碍后续包的处理
*/
private var cachedPacketTimeoutJob: Job? = null
/** /**
* 缓存的包 * 缓存的包
*/ */
private var cachedPacket: ByteReadPacket? = null private val cachedPacket: AtomicRef<ByteReadPacket?> = atomic(null)
/** /**
* 缓存的包还差多少长度 * 缓存的包还差多少长度
*/ */
...@@ -146,8 +154,29 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler ...@@ -146,8 +154,29 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler
* @param input 一个完整的包的内容, 去掉开头的 int 包长度 * @param input 一个完整的包的内容, 去掉开头的 int 包长度
*/ */
suspend fun parsePacket(input: Input) { suspend fun parsePacket(input: Input) {
generifiedParsePacket<Packet>(input)
}
private suspend inline fun <P : Packet> generifiedParsePacket(input: Input) {
try { try {
KnownPacketFactories.parseIncomingPacket(bot, input) { packetFactory: PacketFactory<Packet>, packet: Packet, commandName: String, sequenceId: Int -> KnownPacketFactories.parseIncomingPacket(bot, input) { packetFactory: PacketFactory<P>, packet: P, commandName: String, sequenceId: Int ->
handlePacket(packetFactory, packet, commandName, sequenceId)
if (packet is MultiPacket<*>) {
packet.forEach {
handlePacket(null, it, commandName, sequenceId)
}
}
}
} finally {
println()
println() // separate for debugging
}
}
/**
* 处理解析完成的包.
*/
suspend fun <P : Packet> handlePacket(packetFactory: PacketFactory<P>?, packet: P, commandName: String, sequenceId: Int) {
// highest priority: pass to listeners (attached by sendAndExpect). // highest priority: pass to listeners (attached by sendAndExpect).
packetListeners.forEach { listener -> packetListeners.forEach { listener ->
if (listener.filter(commandName, sequenceId) && packetListeners.remove(listener)) { if (listener.filter(commandName, sequenceId) && packetListeners.remove(listener)) {
...@@ -157,7 +186,7 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler ...@@ -157,7 +186,7 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler
// check top-level cancelling // check top-level cancelling
if (PacketReceivedEvent(packet).broadcast().cancelled) { if (PacketReceivedEvent(packet).broadcast().cancelled) {
return@parseIncomingPacket return
} }
...@@ -169,72 +198,87 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler ...@@ -169,72 +198,87 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler
packet.broadcast() packet.broadcast()
} }
if (packet is Cancellable && packet.cancelled) return@parseIncomingPacket if (packet is Cancellable && packet.cancelled) return
} }
packetFactory.run { packet.handle(bot) } packetFactory?.run {
bot.handle(packet)
}
bot.logger.info(packet) bot.logger.info(packet)
} }
} finally {
println()
println() // separate for debugging
}
}
/** /**
* 处理从服务器接收过来的包. 这些包可能是粘在一起的, 也可能是不完整的. 将会自动处理. * 处理从服务器接收过来的包. 这些包可能是粘在一起的, 也可能是不完整的. 将会自动处理.
* 处理后的包会调用 [parsePacketAsync] * 处理后的包会调用 [parsePacketAsync]
*/ */
@UseExperimental(ExperimentalCoroutinesApi::class) @UseExperimental(ExperimentalCoroutinesApi::class)
internal fun processPacket(rawInput: ByteReadPacket): Unit = rawInput.debugPrint("Received").let { input: ByteReadPacket -> internal fun processPacket(rawInput: ByteReadPacket) {
if (input.remaining == 0L) { if (rawInput.remaining == 0L) {
return return
} }
if (cachedPacket == null) { val cache = cachedPacket.value
if (cache == null) {
// 没有缓存 // 没有缓存
var length: Int = input.readInt() - 4 var length: Int = rawInput.readInt() - 4
if (input.remaining == length.toLong()) { if (rawInput.remaining == length.toLong()) {
// 捷径: 当包长度正好, 直接传递剩余数据. // 捷径: 当包长度正好, 直接传递剩余数据.
parsePacketAsync(input) cachedPacketTimeoutJob?.cancel()
parsePacketAsync(rawInput)
return return
} }
// 循环所有完整的包 // 循环所有完整的包
while (input.remaining > length) { while (rawInput.remaining > length) {
parsePacketAsync(input.readIoBuffer(length)) parsePacketAsync(rawInput.readIoBuffer(length))
length = input.readInt() - 4 length = rawInput.readInt() - 4
} }
if (input.remaining != 0L) { if (rawInput.remaining != 0L) {
// 剩余的包长度不够, 缓存后接收下一个包 // 剩余的包长度不够, 缓存后接收下一个包
expectingRemainingLength = length - input.remaining expectingRemainingLength = length - rawInput.remaining
cachedPacket = input cachedPacket.value = rawInput
} else { } else {
cachedPacket = null // 表示包长度正好 cachedPacket.value = null // 表示包长度正好
cachedPacketTimeoutJob?.cancel()
return
} }
} else { } else {
// 有缓存 // 有缓存
if (input.remaining >= expectingRemainingLength) { if (rawInput.remaining >= expectingRemainingLength) {
// 剩余长度够, 连接上去, 处理这个包. // 剩余长度够, 连接上去, 处理这个包.
parsePacketAsync(buildPacket { parsePacketAsync(buildPacket {
writePacket(cachedPacket!!) writePacket(cache)
writePacket(input, expectingRemainingLength) writePacket(rawInput, expectingRemainingLength)
}) })
cachedPacket = null // 缺少的长度已经给上了. cachedPacket.value = null // 缺少的长度已经给上了.
if (input.remaining != 0L) { if (rawInput.remaining != 0L) {
processPacket(input) // 继续处理剩下内容 return processPacket(rawInput) // 继续处理剩下内容
} else {
// 处理好了.
cachedPacketTimeoutJob?.cancel()
return
} }
} else { } else {
// 剩余不够, 连接上去 // 剩余不够, 连接上去
expectingRemainingLength -= input.remaining expectingRemainingLength -= rawInput.remaining
cachedPacket = buildPacket { // do not inline `packet`. atomicfu unsupported
writePacket(cachedPacket!!) val packet = buildPacket {
writePacket(input) writePacket(cache)
writePacket(rawInput)
}
cachedPacket.value = packet
}
} }
cachedPacketTimeoutJob?.cancel()
cachedPacketTimeoutJob = launch {
delay(1000)
if (cachedPacketTimeoutJob == this.coroutineContext[Job] && cachedPacket.getAndSet(null) != null) {
PacketLogger.verbose("等待另一部分包时超时. 将舍弃已接收的半个包")
} }
} }
} }
...@@ -247,10 +291,13 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler ...@@ -247,10 +291,13 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler
channel.read() channel.read()
} catch (e: ClosedChannelException) { } catch (e: ClosedChannelException) {
dispose() dispose()
bot.tryReinitializeNetworkHandler(e)
return return
} catch (e: ReadPacketInternalException) { } catch (e: ReadPacketInternalException) {
bot.logger.error("Socket channel read failed: ${e.message}") bot.logger.error("Socket channel read failed: ${e.message}")
continue dispose()
bot.tryReinitializeNetworkHandler(e)
return
} catch (e: CancellationException) { } catch (e: CancellationException) {
return return
} catch (e: Throwable) { } catch (e: Throwable) {
...@@ -263,19 +310,24 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler ...@@ -263,19 +310,24 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler
} }
} }
/**
* 发送一个包, 并挂起直到接收到指定的返回包或超时(3000ms)
*/
suspend fun <E : Packet> OutgoingPacket.sendAndExpect(): E { suspend fun <E : Packet> OutgoingPacket.sendAndExpect(): E {
val handler = PacketListener(commandName = commandName, sequenceId = sequenceId) val handler = PacketListener(commandName = commandName, sequenceId = sequenceId)
packetListeners.addLast(handler) packetListeners.addLast(handler)
channel.send(delegate) channel.send(delegate)
return withTimeout(3000) {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
return handler.await() as E handler.await() as E
}
} }
@PublishedApi @PublishedApi
internal val packetListeners = LockFreeLinkedList<PacketListener>() internal val packetListeners = LockFreeLinkedList<PacketListener>()
@PublishedApi @PublishedApi
internal inner class PacketListener( internal inner class PacketListener( // callback
val commandName: String, val commandName: String,
val sequenceId: Int val sequenceId: Int
) : CompletableDeferred<Packet> by CompletableDeferred(supervisor) { ) : CompletableDeferred<Packet> by CompletableDeferred(supervisor) {
...@@ -284,10 +336,5 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler ...@@ -284,10 +336,5 @@ internal class QQAndroidBotNetworkHandler(bot: QQAndroidBot) : BotNetworkHandler
override suspend fun awaitDisconnection() = supervisor.join() override suspend fun awaitDisconnection() = supervisor.join()
override fun dispose(cause: Throwable?) {
println("Closed")
super.dispose(cause)
}
override val coroutineContext: CoroutineContext = bot.coroutineContext override val coroutineContext: CoroutineContext = bot.coroutineContext
} }
\ No newline at end of file
...@@ -14,6 +14,7 @@ import net.mamoe.mirai.qqandroid.network.protocol.packet.login.data.RequestPacke ...@@ -14,6 +14,7 @@ import net.mamoe.mirai.qqandroid.network.protocol.packet.login.data.RequestPacke
import net.mamoe.mirai.utils.DefaultLogger import net.mamoe.mirai.utils.DefaultLogger
import net.mamoe.mirai.utils.MiraiLogger import net.mamoe.mirai.utils.MiraiLogger
import net.mamoe.mirai.utils.cryptor.adjustToPublicKey import net.mamoe.mirai.utils.cryptor.adjustToPublicKey
import net.mamoe.mirai.utils.cryptor.contentToString
import net.mamoe.mirai.utils.cryptor.decryptBy import net.mamoe.mirai.utils.cryptor.decryptBy
import net.mamoe.mirai.utils.io.* import net.mamoe.mirai.utils.io.*
import kotlin.contracts.ExperimentalContracts import kotlin.contracts.ExperimentalContracts
...@@ -27,7 +28,7 @@ import kotlin.jvm.JvmName ...@@ -27,7 +28,7 @@ import kotlin.jvm.JvmName
* @param TPacket 服务器回复包解析结果 * @param TPacket 服务器回复包解析结果
*/ */
@UseExperimental(ExperimentalUnsignedTypes::class) @UseExperimental(ExperimentalUnsignedTypes::class)
internal abstract class PacketFactory<out TPacket : Packet>( internal abstract class PacketFactory<TPacket : Packet>(
/** /**
* 命令名. 如 `wtlogin.login`, `ConfigPushSvc.PushDomain` * 命令名. 如 `wtlogin.login`, `ConfigPushSvc.PushDomain`
*/ */
...@@ -41,7 +42,7 @@ internal abstract class PacketFactory<out TPacket : Packet>( ...@@ -41,7 +42,7 @@ internal abstract class PacketFactory<out TPacket : Packet>(
/** /**
* 可选的处理这个包. 可以在这里面发新的包. * 可选的处理这个包. 可以在这里面发新的包.
*/ */
open suspend fun @UnsafeVariance TPacket.handle(bot: QQAndroidBot) {} open suspend fun QQAndroidBot.handle(packet: TPacket) {}
} }
@JvmName("decode0") @JvmName("decode0")
...@@ -59,7 +60,8 @@ internal object KnownPacketFactories : List<PacketFactory<*>> by mutableListOf( ...@@ -59,7 +60,8 @@ internal object KnownPacketFactories : List<PacketFactory<*>> by mutableListOf(
LoginPacket, LoginPacket,
StatSvc.Register, StatSvc.Register,
OnlinePush.PbPushGroupMsg, OnlinePush.PbPushGroupMsg,
MessageSvc.PushNotify MessageSvc.PushNotify,
MessageSvc.PbGetMsg
) { ) {
fun findPacketFactory(commandName: String): PacketFactory<*>? = this.firstOrNull { it.commandName == commandName } fun findPacketFactory(commandName: String): PacketFactory<*>? = this.firstOrNull { it.commandName == commandName }
...@@ -194,7 +196,15 @@ internal object KnownPacketFactories : List<PacketFactory<*>> by mutableListOf( ...@@ -194,7 +196,15 @@ internal object KnownPacketFactories : List<PacketFactory<*>> by mutableListOf(
val unknown = readBytes(readInt() - 4) val unknown = readBytes(readInt() - 4)
if (unknown.toInt() != 0x02B05B8B) DebugLogger.debug("got new unknown: ${unknown.toUHexString()}") if (unknown.toInt() != 0x02B05B8B) DebugLogger.debug("got new unknown: ${unknown.toUHexString()}")
check(readInt() == 0) readInt().let {
if (it != 0) {
DebugLogger.debug("!! 得到一个原本是 0, 现在是 ${it.contentToString()}")
if (it == 1){
PacketLogger.info("无法处理的数据 = ${input.readBytes().toUHexString()}")
return IncomingPacket(null, ssoSequenceId, input)
}
}
}
} }
// body // body
......
...@@ -4,3 +4,12 @@ package net.mamoe.mirai.data ...@@ -4,3 +4,12 @@ package net.mamoe.mirai.data
* 从服务器收到的包解析之后的结构化数据. * 从服务器收到的包解析之后的结构化数据.
*/ */
interface Packet interface Packet
/**
* PacketFactory 可以一次解析多个包出来. 它们将会被分别广播.
*/
class MultiPacket<P : Packet>(delegate: List<P>) : List<P> by delegate, Packet {
override fun toString(): String {
return "MultiPacket<${this.firstOrNull()?.let { it::class.simpleName }?: "?"}>"
}
}
\ No newline at end of file
...@@ -36,7 +36,11 @@ actual class PlatformSocket : Closeable { ...@@ -36,7 +36,11 @@ actual class PlatformSocket : Closeable {
* @throws SendPacketInternalException * @throws SendPacketInternalException
*/ */
actual suspend inline fun send(packet: ByteReadPacket) { actual suspend inline fun send(packet: ByteReadPacket) {
try {
writeChannel.writePacket(packet) writeChannel.writePacket(packet)
} catch (e: Exception) {
throw SendPacketInternalException(e)
}
} }
/** /**
...@@ -45,7 +49,11 @@ actual class PlatformSocket : Closeable { ...@@ -45,7 +49,11 @@ actual class PlatformSocket : Closeable {
actual suspend inline fun read(): ByteReadPacket { actual suspend inline fun read(): ByteReadPacket {
// do not use readChannel.readRemaining() !!! this function never returns // do not use readChannel.readRemaining() !!! this function never returns
ByteArrayPool.useInstance { buffer -> ByteArrayPool.useInstance { buffer ->
val count = readChannel.readAvailable(buffer) val count = try {
readChannel.readAvailable(buffer)
} catch (e: Exception) {
throw ReadPacketInternalException(e)
}
return buffer.toReadPacket(0, count) return buffer.toReadPacket(0, count)
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment