Skip to content
81 changes: 65 additions & 16 deletions ultraplot/internals/rcsetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import functools
import re, matplotlib as mpl
import threading
from collections.abc import MutableMapping
from numbers import Integral, Real

Expand Down Expand Up @@ -562,16 +563,42 @@ def _yaml_table(rcdict, comment=True, description=False):

class _RcParams(MutableMapping, dict):
"""
A simple dictionary with locked inputs and validated assignments.
A thread-safe dictionary with validated assignments and thread-local storage used to store the configuration of UltraPlot.

It uses reentrant locks (RLock) to ensure that multiple threads can safely read and write to the configuration without causing data corruption.

Example
-------
>>> with rc_params:
... rc_params['key'] = 'value' # Thread-local change
... # Changes are automatically cleaned up when exiting the context
"""

# NOTE: By omitting __delitem__ in MutableMapping we effectively
# disable mutability. Also disables deleting items with pop().
def __init__(self, source, validate):
self._validate = validate
self._lock = threading.RLock()
self._local = threading.local()
self._local.changes = {} # Initialize thread-local storage
# Register all initial keys in the validation dictionary
for key in source:
if key not in validate:
validate[key] = lambda x: x # Default validator
for key, value in source.items():
self.__setitem__(key, value) # trigger validation

def __enter__(self):
"""Context manager entry - initialize thread-local storage if needed."""
if not hasattr(self._local, "changes"):
self._local.changes = {}
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit - clean up thread-local storage."""
if hasattr(self._local, "changes"):
del self._local.changes

def __repr__(self):
return RcParams.__repr__(self)

Expand All @@ -587,22 +614,33 @@ def __iter__(self):
yield from sorted(dict.__iter__(self))

def __getitem__(self, key):
key, _ = self._check_key(key)
return dict.__getitem__(self, key)
with self._lock:
key, _ = self._check_key(key)
# Check thread-local storage first
if key in self._local.changes:
return self._local.changes[key]
# Check global dictionary (will raise KeyError if not found)
return dict.__getitem__(self, key)

def __setitem__(self, key, value):
key, value = self._check_key(key, value)
if key not in self._validate:
raise KeyError(f"Invalid rc key {key!r}.")
try:
value = self._validate[key](value)
except (ValueError, TypeError) as error:
raise ValueError(f"Key {key}: {error}") from None
if key is not None:
dict.__setitem__(self, key, value)

@staticmethod
def _check_key(key, value=None):
with self._lock:
key, value = self._check_key(key, value)
# Validate the value
try:
value = self._validate[key](value)
except KeyError:
# If key doesn't exist in validation, add it with default validator
self._validate[key] = lambda x: x
# Re-validate with new validator
value = self._validate[key](value)
except (ValueError, TypeError) as error:
raise ValueError(f"Key {key}: {error}") from None
if key is not None:
# Store in both thread-local storage and main dictionary
self._local.changes[key] = value
dict.__setitem__(self, key, value)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block doesn't make sense to me. I think we need to test if we're in the context manager block and then only make a local change if we are. Otherwise threads will make changes to the global settings which could interfere with each other.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am leaning towards moving this logic (thread safety) inside the Configurator we both need to control ultraplot's rc and matplotlib's rc making this PR a bit moot and #384 an essential step in this direction.


def _check_key(self, key, value=None):
# NOTE: If we assigned from the Configurator then the deprecated key will
# still propagate to the same 'children' as the new key.
# NOTE: This also translates values for special cases of renamed keys.
Expand All @@ -624,10 +662,21 @@ def _check_key(key, value=None):
f"The rc setting {key!r} was removed in version {version}."
+ (info and " " + info)
)
# Register new keys in the validation dictionary
if key not in self._validate:
self._validate[key] = lambda x: x # Default validator
return key, value

def copy(self):
source = {key: dict.__getitem__(self, key) for key in self}
with self._lock:
# Create a copy that includes both global and thread-local changes
source = {}
# Start with global values
for key in self:
if key not in self._local.changes:
source[key] = dict.__getitem__(self, key)
# Add thread-local changes
source.update(self._local.changes)
return _RcParams(source, self._validate)


Expand Down
197 changes: 178 additions & 19 deletions ultraplot/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,35 @@
import ultraplot as uplt, pytest
import importlib
import threading
import time


def test_wrong_keyword_reset():
"""
The context should reset after a failed attempt.
"""
# Init context
uplt.rc.context()
config = uplt.rc
# Set a wrong key
with pytest.raises(KeyError):
config._get_item_dicts("non_existing_key", "non_existing_value")
# Set a known good value
config._get_item_dicts("coastcolor", "black")
# Confirm we can still plot
fig, ax = uplt.subplots(proj="cyl")
ax.format(coastcolor="black")
fig.canvas.draw()
# Use context manager for temporary rc changes
# Use context manager with direct value setting
with uplt.rc.context(coastcolor="black"):
# Set a wrong key
with pytest.raises(KeyError):
uplt.rc._get_item_dicts("non_existing_key", "non_existing_value")
# Confirm we can still plot
fig, ax = uplt.subplots(proj="cyl")
ax.format(coastcolor="black")
fig.canvas.draw()


def test_cycle_in_rc_file(tmp_path):
"""
Test that loading an rc file correctly overwrites the cycle setting.
"""
rc = uplt.config.Configurator()
rc_content = "cycle: colorblind"
rc_file = tmp_path / "test.rc"
rc_file.write_text(rc_content)

# Load the file directly. This should overwrite any existing settings.
uplt.rc.load(str(rc_file))

assert uplt.rc["cycle"] == "colorblind"
rc.load(str(rc_file))
assert rc["cycle"] == "colorblind"


import io
Expand Down Expand Up @@ -96,6 +94,95 @@ def test_dev_version_skipped(mock_urlopen, mock_version, mock_print):
mock_print.assert_not_called()


def test_rcparams_thread_safety():
"""
Test that _RcParams is thread-safe when accessed concurrently.
Each thread works with its own unique key to verify proper isolation.
Thread-local changes are properly managed with context manager.
"""
# Create a new _RcParams instance for testing
from ultraplot.internals.rcsetup import _RcParams

# Initialize with base keys
base_keys = {f"base_key_{i}": f"base_value_{i}" for i in range(3)}
rc_params = _RcParams(base_keys, {k: lambda x: x for k in base_keys})

# Number of threads and operations per thread
num_threads = 5
operations_per_thread = 20

# Each thread will work with its own unique key
thread_keys = {}

def worker(thread_id):
"""Thread function that works with its own unique key using context manager."""
# Each thread gets its own unique key
thread_key = f"thread_{thread_id}_key"
thread_keys[thread_id] = thread_key

# Use context manager to ensure proper thread-local cleanup
with rc_params:
# Initialize the key with a base value
rc_params[thread_key] = f"initial_{thread_id}"

# Perform operations
for i in range(operations_per_thread):
try:
# Read the current value
current = rc_params[thread_key]

# Update with new value
new_value = f"thread_{thread_id}_value_{i}"
rc_params[thread_key] = new_value

# Verify the update worked
assert rc_params[thread_key] == new_value

# Also read some base keys to test mixed access
if i % 5 == 0:
base_key = f"base_key_{i % 3}"
base_value = rc_params[base_key]
assert isinstance(base_value, str)

except Exception as e:
raise AssertionError(f"Thread {thread_id} failed: {str(e)}")

# Create and start threads
threads = []
for i in range(num_threads):
t = threading.Thread(target=worker, args=(i,))
threads.append(t)
t.start()

# Wait for all threads to complete
for t in threads:
t.join()

# Verify each thread's key exists and has the expected final value
for thread_id in range(num_threads):
thread_key = thread_keys[thread_id]
assert thread_key in rc_params, f"Thread {thread_id}'s key was lost"
final_value = rc_params[thread_key]
assert final_value == f"thread_{thread_id}_value_{operations_per_thread - 1}"

# Verify base keys are still intact
for key, expected_value in base_keys.items():
assert key in rc_params, f"Base key {key} was lost"
assert rc_params[key] == expected_value, f"Base key {key} value was corrupted"

# Verify that thread-local changes are properly merged
# Create a copy to verify the copy includes thread-local changes
rc_copy = rc_params.copy()
assert len(rc_copy) == len(base_keys) + num_threads, "Copy doesn't include all keys"

# Verify all keys are in the copy
for key in base_keys:
assert key in rc_copy, f"Base key {key} missing from copy"
for thread_id in range(num_threads):
thread_key = thread_keys[thread_id]
assert thread_key in rc_copy, f"Thread {thread_id}'s key missing from copy"


@pytest.mark.parametrize(
"cycle, raises_error",
[
Expand All @@ -117,6 +204,78 @@ def test_cycle_rc_setting(cycle, raises_error):
"""
if raises_error:
with pytest.raises(ValueError):
uplt.rc["cycle"] = cycle
with uplt.rc.context(cycle=cycle):
pass
else:
uplt.rc["cycle"] = cycle
with uplt.rc.context(cycle=cycle):
pass


def test_rc_check_key():
"""
Test the _check_key method in _RcParams
"""
from ultraplot.internals.rcsetup import _RcParams

# Create a test instance
rc_params = _RcParams({"test_key": "test_value"}, {"test_key": lambda x: x})

# Test valid key
key, value = rc_params._check_key("test_key", "new_value")
assert key == "test_key"
assert value == "new_value"

# Test new key (should be registered with default validator)
key, value = rc_params._check_key("new_key", "new_value")
assert key == "new_key"
assert value == "new_value"
assert "new_key" in rc_params._validate


def test_rc_repr():
"""
Test the __repr__ method in _RcParams
"""
from ultraplot.internals.rcsetup import _RcParams

# Create a test instance
rc_params = _RcParams({"test_key": "test_value"}, {"test_key": lambda x: x})

# Test __repr__
repr_str = repr(rc_params)
assert "RcParams" in repr_str
assert "test_key" in repr_str


def test_rc_validators():
"""
Test validators in _RcParams
"""
from ultraplot.internals.rcsetup import _RcParams

# Create a test instance with various validators
validators = {
"int_val": lambda x: int(x),
"float_val": lambda x: float(x),
"str_val": lambda x: str(x),
}
rc_params = _RcParams(
{"int_val": 1, "float_val": 1.0, "str_val": "test"}, validators
)

# Test valid values
rc_params["int_val"] = 2
assert rc_params["int_val"] == 2

rc_params["float_val"] = 2.5
assert rc_params["float_val"] == 2.5

rc_params["str_val"] = "new_value"
assert rc_params["str_val"] == "new_value"

# Test invalid values
with pytest.raises(ValueError):
rc_params["int_val"] = "not_an_int"

with pytest.raises(ValueError):
rc_params["float_val"] = "not_a_float"