Commit 26b9aa10 authored by Him188's avatar Him188

Jce head fix

parent e41c9231
......@@ -8,12 +8,8 @@ import kotlinx.serialization.modules.EmptyModule
import kotlinx.serialization.modules.SerialModule
import net.mamoe.mirai.qqandroid.io.JceStruct
import net.mamoe.mirai.qqandroid.io.ProtoBuf
import net.mamoe.mirai.qqandroid.network.protocol.packet.withUse
import net.mamoe.mirai.utils.io.readIoBuffer
import net.mamoe.mirai.utils.io.readString
import net.mamoe.mirai.utils.io.toIoBuffer
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import net.mamoe.mirai.utils.io.toReadPacket
@PublishedApi
internal val CharsetGBK = Charset.forName("GBK")
......@@ -383,11 +379,10 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
}
}
if (!input.input.endOfInput) {
val tag = currentTagOrNull
if (tag != null && input.peakHead().tag > tag) {
return NullReader(this.input)
}
val tag = currentTagOrNull
val jceHead = input.peakHeadOrNull()
if (tag != null && (jceHead == null || jceHead.tag > tag)) {
return NullReader(this.input)
}
return super.beginStructure(desc, *typeParams)
......@@ -402,7 +397,8 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
}
fun isTagOptional(tag: Int): Boolean {
return input.input.endOfInput || input.peakHead().tag > tag
val head = input.peakHeadOrNull()
return input.isEndOfInput || head == null || head.tag > tag
}
@Suppress("UNCHECKED_CAST")
......@@ -492,7 +488,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
@Suppress("UNCHECKED_CAST")
override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
return decodeNullableSerializableValue(deserializer as DeserializationStrategy<Any?>) as? T
?: error("value with tag $currentTagOrNull(by ${deserializer.getClassName()}) is not optional but cannot find")
?: error("value with tag $currentTagOrNull(by ${deserializer.getClassName()}) is not optional but cannot find. currentJceHead = ${input.currentJceHead}")
}
}
......@@ -500,33 +496,48 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
@UseExperimental(ExperimentalUnsignedTypes::class)
internal inner class JceInput(
@PublishedApi
internal val input: IoBuffer
internal val input: ByteReadPacket,
maxReadSize: Long = input.remaining
) : Closeable {
override fun close() = IoBuffer.Pool.recycle(input)
internal val leastRemaining = input.remaining - maxReadSize
internal val isEndOfInput: Boolean get() = input.remaining <= leastRemaining
@PublishedApi
internal fun readHead(): JceHead = input.readHead() ?: error("no enough data to read head")
internal var currentJceHead: JceHead? = input.doReadHead().also { println("first jce head = $it") }
@PublishedApi
internal fun readHeadOrNull(): JceHead? = input.readHead()
override fun close() = input.close()
internal fun peakHeadOrNull(): JceHead? = currentJceHead ?: readHeadOrNull()
internal fun peakHead(): JceHead = peakHeadOrNull() ?: error("no enough data to read head")
@PublishedApi
internal fun peakHead(): JceHead = input.makeView().readHead() ?: error("no enough data to read head")
internal fun readHead(): JceHead = readHeadOrNull() ?: error("no enough data to read head")
@PublishedApi
internal fun peakHeadOrNull(): JceHead? = input.makeView().readHead()
internal fun readHeadOrNull(): JceHead? = input.doReadHead()
@Suppress("NOTHING_TO_INLINE") // 避免 stacktrace 出现两个 readHead
private inline fun IoBuffer.readHead(): JceHead? {
if (endOfInput) return null
/**
* 读取下一个 head 存储到 [currentJceHead]
*/
private fun ByteReadPacket.doReadHead(): JceHead? {
if (isEndOfInput) {
currentJceHead = null
println("doReadHead: endOfInput")
return null
}
val var2 = readUByte()
val type = var2 and 15u
var tag = var2.toUInt() shr 4
if (tag == 15u) {
if (endOfInput) return null
if (isEndOfInput) {
currentJceHead = null
println("doReadHead: endOfInput2")
return null
}
tag = readUByte().toUInt()
}
return JceHead(tag = tag.toInt(), type = type.toByte())
currentJceHead = JceHead(tag = tag.toInt(), type = type.toByte())
println("doReadHead: $currentJceHead")
return currentJceHead
}
fun readBoolean(tag: Int): Boolean = readBooleanOrNull(tag) ?: error("cannot find tag $tag")
......@@ -583,8 +594,9 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
LIST -> ByteArray(readInt(0)) { readByte(0) }
SIMPLE_LIST -> {
val head = readHead()
readHead()
check(head.type.toInt() == 0) { "type mismatch" }
input.readBytes(readInt(0))
input.readBytes(readInt(0).also { println("list size=$it") })
}
else -> error("type mismatch")
}
......@@ -610,7 +622,7 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
input.readUInt().toInt().also { require(it in 1 until 104857600) { "bad string length: $it" } },
charset = charset.kotlinCharset
)
else -> error("type mismatch: ${head.type}")
else -> error("type mismatch: ${head.type}, expecting 6 or 7 (for string)")
}
}
......@@ -762,40 +774,51 @@ class Jce private constructor(private val charset: JceCharset, context: SerialMo
return dumpAsPacket(serializer, obj).readBytes()
}
/**
* 注意 close [packet]!!
*/
fun <T> load(deserializer: DeserializationStrategy<T>, packet: ByteReadPacket, length: Int = packet.remaining.toInt()): T {
packet.readIoBuffer(n = length).withUse {
val decoder = JceDecoder(JceInput(this))
return decoder.decode(deserializer)
}
return JceDecoder(JceInput(packet, length.toLong())).decode(deserializer)
}
override fun <T> load(deserializer: DeserializationStrategy<T>, bytes: ByteArray): T {
return bytes.toIoBuffer().withUse {
val decoder = JceDecoder(JceInput(this))
return bytes.toReadPacket().use {
val decoder = JceDecoder(JceInput(it))
decoder.decode(deserializer)
}
}
}
@UseExperimental(ExperimentalContracts::class)
internal inline fun <R> Jce.JceInput.skipToTagOrNull(tag: Int, block: (JceHead) -> R): R? {
contract {
callsInPlace(block, kotlin.contracts.InvocationKind.UNKNOWN)
}
println("skipping to $tag start")
while (true) {
if (this.input.endOfInput) {
if (isEndOfInput) { // 读不了了
currentJceHead = null
println("skipping to $tag: endOfInput")
return null
}
val head = peakHead()
var head = currentJceHead
if (head == null) { // 没有新的 head 了
head = readHeadOrNull() ?: return null
}
if (head.tag > tag) {
println("skipping to $tag: head.tag > tag")
return null
}
readHead()
// readHead()
if (head.tag == tag) {
// readHeadOrNull()
currentJceHead = null
println("skipping to $tag: run block")
return block(head)
} else {
println("skipping to $tag: tag not matching")
}
println("skipping to $tag: skipField")
this.skipField(head.type)
currentJceHead = readHeadOrNull()
}
}
......
......@@ -20,7 +20,9 @@ class JceDecoderTest {
@SerialId(3) val int: Int = 123,
@SerialId(4) val long: Long = 123,
@SerialId(5) val float: Float = 123f,
@SerialId(6) val double: Double = 123.0
@SerialId(6) val double: Double = 123.0,
@SerialId(7) val byteArray: ByteArray = byteArrayOf(1, 2, 3),
@SerialId(8) val byteArray2: ByteArray = byteArrayOf(1, 2, 3)
) : JceStruct {
override fun writeTo(output: JceOutput) = output.run {
writeString(string, 0)
......@@ -30,10 +32,81 @@ class JceDecoderTest {
writeLong(long, 4)
writeFloat(float, 5)
writeDouble(double, 6)
writeFully(byteArray, 7)
writeFully(byteArray2, 8)
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as TestSimpleJceStruct
if (string != other.string) return false
if (byte != other.byte) return false
if (short != other.short) return false
if (int != other.int) return false
if (long != other.long) return false
if (float != other.float) return false
if (double != other.double) return false
if (!byteArray.contentEquals(other.byteArray)) return false
if (!byteArray2.contentEquals(other.byteArray2)) return false
return true
}
override fun hashCode(): Int {
var result = string.hashCode()
result = 31 * result + byte
result = 31 * result + short
result = 31 * result + int
result = 31 * result + long.hashCode()
result = 31 * result + float.hashCode()
result = 31 * result + double.hashCode()
result = 31 * result + byteArray.contentHashCode()
result = 31 * result + byteArray2.contentHashCode()
return result
}
}
@Test
fun testByteArray() {
@Serializable
data class TestByteArray(
@SerialId(0) val byteArray: ByteArray = byteArrayOf(1, 2, 3)
) : JceStruct {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as TestByteArray
if (!byteArray.contentEquals(other.byteArray)) return false
return true
}
override fun hashCode(): Int {
return byteArray.contentHashCode()
}
}
assertEquals(
TestByteArray(),
TestByteArray().toByteArray(TestByteArray.serializer()).loadAs(TestByteArray.serializer())
)
}
@Test
fun testSimpleStruct() {
assertEquals(
TestSimpleJceStruct(),
TestSimpleJceStruct().toByteArray(TestSimpleJceStruct.serializer()).loadAs(TestSimpleJceStruct.serializer())
)
}
@Serializable
class TestComplexJceStruct(
@SerialId(6) val string: String = "haha",
......@@ -77,7 +150,7 @@ class JceDecoderTest {
@Test
fun testNestedList() {
@Serializable
class TestNestedList(
data class TestNestedList(
@SerialId(7) val array: List<List<Int>> = listOf(listOf(1, 2, 3), listOf(1, 2, 3), listOf(1, 2, 3))
) : JceStruct
......@@ -133,6 +206,28 @@ class JceDecoderTest {
}.readBytes().loadAs(TestNestedMap.serializer()).map.entries.first().value.contentToString(), "{01=[0x0002(2)]}")
}
@Test
fun testMap3() {
@Serializable
class TestNestedMap(
@SerialId(7) val map: Map<Byte, ShortArray> = mapOf(1.toByte() to shortArrayOf(2))
) : JceStruct
assertEquals("{0x01(1)=[0x0002(2)]}", buildJcePacket {
writeMap(mapOf(1.toByte() to shortArrayOf(2)), 7)
}.readBytes().loadAs(TestNestedMap.serializer()).map.contentToString())
}
@Test
fun testNestedMap2() {
@Serializable
class TestNestedMap(
@SerialId(7) val map: Map<Int, Map<Byte, ShortArray>> = mapOf(1 to mapOf(1.toByte() to shortArrayOf(2)))
) : JceStruct
assertEquals(buildJcePacket {
writeMap(mapOf(1 to mapOf(1.toByte() to shortArrayOf(2))), 7)
}.readBytes().loadAs(TestNestedMap.serializer()).map.entries.first().value.contentToString(), "{0x01(1)=[0x0002(2)]}")
}
@Test
fun testNullableEncode() {
......@@ -186,6 +281,9 @@ class JceDecoderTest {
@SerialId(0) val innerStructList: List<TestSimpleJceStruct>
) : JceStruct
println(buildJcePacket {
writeCollection(listOf(TestSimpleJceStruct(), TestSimpleJceStruct()), 0)
}.readBytes().loadAs(OuterStruct.serializer()).innerStructList.toString())
assertEquals(
buildJcePacket {
writeCollection(listOf(TestSimpleJceStruct(), TestSimpleJceStruct()), 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment