More safe API to manage running fibers and event loops

This commit is contained in:
Andrew Golovashevich 2025-09-10 08:12:57 +03:00
parent 0d5f9775dd
commit 538ede8206
7 changed files with 91 additions and 86 deletions

View File

@ -0,0 +1,11 @@
package ru.landgrafhomyak.multitasking_0.threads
import ru.landgrafhomyak.multitasking_0.WrongCallerThreadException
import ru.landgrafhomyak.multitasking_0.fibers.Fiber
/**
* @see ThreadLocalMethods.setEventLoop
*/
public interface EventLoopClearer {
public fun clearEventLoop()
}

View File

@ -0,0 +1,19 @@
package ru.landgrafhomyak.multitasking_0.threads
import ru.landgrafhomyak.multitasking_0.WrongCallerThreadException
import ru.landgrafhomyak.multitasking_0.fibers.Fiber
/**
* @see ThreadLocalMethods.enterFiber
*/
public interface FiberExiter {
public val callerFiber: Fiber?
/**
* Informs thread that execution flow was returned from [fiber][Fiber].
*
* @throws IllegalStateException if [fiberToExit] doesn't match [current running fiber][ThreadLocalMethods.runningFiber]
* or this fiber isn't on top of fibers stack.
* @throws WrongCallerThreadException if called on thread different that caller of [Thread.current].
*/
public fun exitFiber()
}

View File

@ -29,30 +29,13 @@ public expect sealed interface ThreadLocalMethods {
* Informs thread that execution flow was switched to [fiber][Fiber]. * Informs thread that execution flow was switched to [fiber][Fiber].
* *
* @param fiberToEnter descriptor of fiber that is [going to run][Fiber.resume]. * @param fiberToEnter descriptor of fiber that is [going to run][Fiber.resume].
* @param privateToken secret reference that should be used to call [exitFiber()][ThreadLocalMethods.exitFiber]. * @return Callback object to rollback value of [Thread.current.runningFiber][ThreadLocalMethods.runningFiber]
* Need to avoid clearing [information about running fiber][ThreadLocalMethods.runningFiber] * to state before this call.
* by anyone except implementation of this fiber.
* @return [fiber][Fiber] in which this function was called or `null` if it was called in thread.
* *
* @throws WrongCallerThreadException if called on thread different that caller of [Thread.current]. * @throws WrongCallerThreadException if called on thread different that caller of [Thread.current].
*/ */
public fun enterFiber(fiberToEnter: Fiber, privateToken: Any): Fiber? public fun enterFiber(fiberToEnter: Fiber): FiberExiter
/**
* Informs thread that execution flow was returned from [fiber][Fiber].
*
* @param fiberToExit descriptor of fiber that is [going to return execution flow][Fiber.yield].
* @param privateToken secret reference that was passed to [enterFiber()][ThreadLocalMethods.exitFiber].
* Used to check that this function called by same caller as [enterFiber()][ThreadLocalMethods.exitFiber].
* @param fiberToRestore return value of corresponding [enterFiber()][ThreadLocalMethods.exitFiber] call.
*
* @throws IllegalStateException if [fiberToExit] or [fiberToRestore] doesn't match.
* @throws IllegalArgumentException if [privateToken] doesn't match.
* @throws WrongCallerThreadException if called on thread different that caller of [Thread.current].
*/
public fun exitFiber(fiberToExit: Fiber, privateToken: Any, fiberToRestore: Fiber?)
public val runningEventLoop: SingleThreadEventLoop? public val runningEventLoop: SingleThreadEventLoop?
public fun setEventLoop(eventLoop: SingleThreadEventLoop, privateToken: Any) public fun setEventLoop(eventLoop: SingleThreadEventLoop): EventLoopClearer
public fun clearEventLoop(eventLoop: SingleThreadEventLoop, privateToken: Any)
} }

View File

@ -1,21 +0,0 @@
package ru.landgrafhomyak.multitasking_0.threads
import kotlin.jvm.JvmField
import ru.landgrafhomyak.multitasking_0.fibers.Fiber
internal class _FibersStackNode(
@JvmField
val caller: _FibersStackNode?,
@JvmField
val fiber: Fiber,
@JvmField
val privateToken: Any,
) {
fun assertExit(fiberToExit: Fiber, privateToken: Any, fiberToRestore: Fiber?) {
if (this.fiber !== fiberToExit || this.caller?.fiber !== fiberToRestore)
throw IllegalStateException("fiberToExit or fiberToRestore doesn't match")
if (this.privateToken !== privateToken)
throw IllegalArgumentException("privateToken doesn't match")
}
}

View File

@ -1,4 +1,4 @@
package ru.landgrafhomyak.multitasking_0.threads.impl.java_virtual_threads package ru.landgrafhomyak.multitasking_0.impl.java_virtual_threads
import java.lang.Object as jObject import java.lang.Object as jObject
import java.lang.Thread as jThread import java.lang.Thread as jThread
@ -9,6 +9,7 @@ import ru.landgrafhomyak.multitasking_0.fibers.Fiber
import ru.landgrafhomyak.multitasking_0.fibers.FiberRoutine import ru.landgrafhomyak.multitasking_0.fibers.FiberRoutine
import ru.landgrafhomyak.multitasking_0.WrongCallerThreadException import ru.landgrafhomyak.multitasking_0.WrongCallerThreadException
import ru.landgrafhomyak.multitasking_0.fibers.FiberInterruptedException import ru.landgrafhomyak.multitasking_0.fibers.FiberInterruptedException
import ru.landgrafhomyak.multitasking_0.threads.FiberExiter
import ru.landgrafhomyak.multitasking_0.threads.Thread as wThread import ru.landgrafhomyak.multitasking_0.threads.Thread as wThread
public class JavaVirtualThreadFiber : Fiber { public class JavaVirtualThreadFiber : Fiber {
@ -19,7 +20,7 @@ public class JavaVirtualThreadFiber : Fiber {
private var _uncaughtException: Throwable? private var _uncaughtException: Throwable?
override val name: String override val name: String
private var _resumedOnThread: wThread? private var _resumedOnThread: wThread?
private var _resumedOnFiber: Fiber? private var _fiberExiter: FiberExiter?
private var _interruptionData: InterruptionData? private var _interruptionData: InterruptionData?
public constructor(name: String, routine: FiberRoutine) { public constructor(name: String, routine: FiberRoutine) {
@ -29,7 +30,7 @@ public class JavaVirtualThreadFiber : Fiber {
this._state = Fiber.State.CREATED this._state = Fiber.State.CREATED
this._uncaughtException = null this._uncaughtException = null
this._resumedOnThread = null this._resumedOnThread = null
this._resumedOnFiber = null this._fiberExiter = null
this._interruptionData = null this._interruptionData = null
this._syncLock.withLock { this._syncLock.withLock {
@ -84,7 +85,7 @@ public class JavaVirtualThreadFiber : Fiber {
get() = this._syncLock.withLock { get() = this._syncLock.withLock {
if (this._state == Fiber.State.DESTROYED) if (this._state == Fiber.State.DESTROYED)
throw IllegalStateException("Fiber destroyed") throw IllegalStateException("Fiber destroyed")
return@withLock this._resumedOnFiber return@withLock this._fiberExiter?.callerFiber
} }
override fun yield() { override fun yield() {
@ -103,7 +104,8 @@ public class JavaVirtualThreadFiber : Fiber {
Fiber.State.DESTROYED -> throw IllegalStateException("Fiber destroyed") Fiber.State.DESTROYED -> throw IllegalStateException("Fiber destroyed")
} }
this._fiberExiter!!.exitFiber()
this._fiberExiter = null
wThread._tl_currentThread.set(null) wThread._tl_currentThread.set(null)
this._syncCond.signal() this._syncCond.signal()
this._syncCond.await() this._syncCond.await()
@ -137,18 +139,15 @@ public class JavaVirtualThreadFiber : Fiber {
Fiber.State.DESTROYED -> throw IllegalStateException("Fiber destroyed") Fiber.State.DESTROYED -> throw IllegalStateException("Fiber destroyed")
} }
val token = jObject()
val currentThread = wThread.current val currentThread = wThread.current
this._resumedOnThread = currentThread.get() this._resumedOnThread = currentThread.get()
this._resumedOnFiber = currentThread.enterFiber(this, token) this._fiberExiter = currentThread.enterFiber(this)
this._interruptionData = id this._interruptionData = id
this._syncCond.signal() this._syncCond.signal()
this._syncCond.await() this._syncCond.await()
currentThread.exitFiber(this, token, this.resumedOnFiber)
this._resumedOnThread = null this._resumedOnThread = null
this._resumedOnFiber = null
} }
} }

View File

@ -1,5 +1,6 @@
package ru.landgrafhomyak.multitasking_0.threads package ru.landgrafhomyak.multitasking_0.threads
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
import java.lang.Thread as jThread import java.lang.Thread as jThread
@ -159,7 +160,33 @@ public actual sealed class Thread {
return this@Thread return this@Thread
} }
private var _fibersStack: _FibersStackNode? = null private inner class FibersStackNode(
@JvmField
val caller: FibersStackNode?,
@JvmField
val fiber: Fiber,
) : FiberExiter {
private val _isClosed = AtomicBoolean(false)
override val callerFiber: Fiber?
get() {
if (this._isClosed.get()) throw IllegalStateException("Fiber already exited and value in this property isn't actual")
return this.caller?.fiber
}
override fun exitFiber() {
if (this._isClosed.compareAndExchange(false, true))
throw IllegalStateException("Fiber already exited")
if (this@ThreadLocalMethodsImpl._fibersStack !== this) {
this._isClosed.set(false)
throw IllegalStateException("Fiber being exited isn't on top of fibers stack")
}
this@ThreadLocalMethodsImpl._fibersStack = this.caller
}
}
private var _fibersStack: FibersStackNode? = null
override val runningFiber: Fiber? override val runningFiber: Fiber?
get() { get() {
@ -167,27 +194,14 @@ public actual sealed class Thread {
return this._fibersStack?.fiber return this._fibersStack?.fiber
} }
override fun enterFiber(fiberToEnter: Fiber, privateToken: Any): Fiber? { override fun enterFiber(fiberToEnter: Fiber): FiberExiter {
this._assertThread() this._assertThread()
val caller = this._fibersStack val node = this.FibersStackNode(this._fibersStack, fiberToEnter)
this._fibersStack = _FibersStackNode(caller, fiberToEnter, privateToken) this._fibersStack = node
return caller?.fiber return node
}
override fun exitFiber(fiberToExit: Fiber, privateToken: Any, fiberToRestore: Fiber?) {
this._assertThread()
val top = this._fibersStack
if (top == null)
throw IllegalStateException("There is no running fiber to exit")
top.assertExit(fiberToExit, privateToken, fiberToRestore)
this._fibersStack = top.caller
} }
private var _runningEventLoop: SingleThreadEventLoop? = null private var _runningEventLoop: SingleThreadEventLoop? = null
private var _runningEventLoopToken: Any? = null
override val runningEventLoop: SingleThreadEventLoop? override val runningEventLoop: SingleThreadEventLoop?
get() { get() {
@ -196,23 +210,25 @@ public actual sealed class Thread {
} }
override fun setEventLoop(eventLoop: SingleThreadEventLoop, privateToken: Any) { private inner class EventLoopClearerImpl(
) : EventLoopClearer {
private val _isClosed = AtomicBoolean(false)
override fun clearEventLoop() {
if (this._isClosed.compareAndExchange(false, true))
throw IllegalStateException("Event loop already cleared")
this@ThreadLocalMethodsImpl._runningEventLoop = null
}
}
override fun setEventLoop(eventLoop: SingleThreadEventLoop): EventLoopClearer {
this._assertThread() this._assertThread()
if (this._runningEventLoop != null) if (this._runningEventLoop != null)
throw IllegalStateException("There is already a running event loop here") throw IllegalStateException("There is already a running event loop here")
this._runningEventLoop = eventLoop this._runningEventLoop = eventLoop
this._runningEventLoopToken = privateToken return this.EventLoopClearerImpl()
}
override fun clearEventLoop(eventLoop: SingleThreadEventLoop, privateToken: Any) {
this._assertThread()
if (this._runningEventLoop !== eventLoop)
throw IllegalStateException("Currently another event loop is running")
if (this._runningEventLoopToken !== privateToken)
throw IllegalArgumentException("privateToken doesn't match")
this._runningEventLoop = null
this._runningEventLoopToken = null
} }
} }
} }

View File

@ -7,10 +7,8 @@ public actual sealed interface ThreadLocalMethods {
public actual fun get(): Thread public actual fun get(): Thread
public actual val runningFiber: Fiber? public actual val runningFiber: Fiber?
public actual fun enterFiber(fiberToEnter: Fiber, privateToken: Any): Fiber? public actual fun enterFiber(fiberToEnter: Fiber): FiberExiter
public actual fun exitFiber(fiberToExit: Fiber, privateToken: Any, fiberToRestore: Fiber?)
public actual val runningEventLoop: SingleThreadEventLoop? public actual val runningEventLoop: SingleThreadEventLoop?
public actual fun setEventLoop(eventLoop: SingleThreadEventLoop, privateToken: Any) public actual fun setEventLoop(eventLoop: SingleThreadEventLoop): EventLoopClearer
public actual fun clearEventLoop(eventLoop: SingleThreadEventLoop, privateToken: Any)
} }