[FIX] Pooled global/groupby lag transforms to use RANGE semantics #641
Introduce PooledState for global/groupby transform state, compute pooled features from raw bucketed observations instead of summed timestamps, and add update/new-group/categorical coverage.
|
Will this help finding a smarter solution for the partition_by problem? As I just mentioned on the issue itself, I still like option B (from the issue) as I see reasons why we would first do the sum over ds before rolling window functions. For example, check out this entry in discussion: |
|
Thanks Jan, yes I believe this approach would help bridge the gap between I'll try and wrap up soon with the new PR which branches off of this branch and includes the Regarding the issue just raised, I believe this would work nicely with either implementation, and in my use cases I would see it more beneficial to go with Option A - using |
|
Codex Review: Didn't find any major issues. More of your lovely PRs please. ℹ️ About Codex in GitHubYour team has set up Codex to review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. Codex can also answer questions or update the PR. Try commenting "@codex address that feedback". |
|
One thing worth noting here which wasn't included in the PR description: the new logic treats the So now if we perform a Now this behavior diverges from the single series implementation, where min_samples is always capped at window_size if the user passes a higher value. I'm going to add better docs to this branch and explain this behavior more explicitly, pending discussion. In my opinion we could add in a separate feature branch an optional parameter |
I realised this and had to think about this for a minute. I think totally fine. If people do a detailed deep dive into preprocess and inspect data they might realise the difference, but it is logically consistent and in line with the sql notation. Making a few more doc strings will do the job for me! |
|
Thanks for taking the time to review and for the feedback, really appreciate it! @janrth
|
In local (per-series) mode coreforecast caps min_samples at window_size, but in pooled mode (global_/groupby) min_samples counts total non-NaN observations across all series in the bucket with no capping. This documents the divergence in the _RollingBase and _Seasonal_RollingBase docstrings. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pooled (global/groupby) lag transforms assume a continuous, gap-free time grid. Emit a UserWarning in preprocess() when the user disables data validation so gaps don't silently produce incorrect feature values. The warning is suppressed for cross_validation's internal fit calls since that path validates the full dataset upfront. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce _TimestampAggregates on PooledState that pre-computes per-bucket sums/counts/n_rows by timestamp. RollingMean uses the compact T-length arrays instead of the full (n_series * T)-length row arrays, reducing the working set from O(n_series * T) to O(T). Other transforms keep the existing row-level approach and are marked with TODOs for future migration. The cache is built in from_global/ from_groupby, updated incrementally in append_predictions, and rebuilt in append_observations. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
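To make the per-timestamp aggregate idea concrete, here is a minimal, dependency-free sketch of a pooled rolling mean computed from T-length sums and counts using prefix sums and a binary search. It is not the code in this branch: the function name, its signature, the integer time grid, and the `min_samples=1` default are assumptions for illustration, with `itertools.accumulate` and `bisect` standing in for `np.cumsum` and `np.searchsorted`.

```python
from bisect import bisect_left, bisect_right
from itertools import accumulate


def rolling_mean_from_aggs(times, sums, counts, window_size, lag, min_samples=1):
    """Pooled rolling mean per timestamp from per-timestamp sums and counts.

    `times` must be sorted and unique. The window for timestamp t covers
    [t - lag - window_size + 1, t - lag], both bounds inclusive.
    """
    # Prefix sums with a leading zero: prefix[j] - prefix[i] covers entries i..j-1.
    csum = [0.0, *accumulate(sums)]
    ccnt = [0, *accumulate(counts)]
    out = []
    for t in times:
        lo = bisect_left(times, t - lag - window_size + 1)  # first time >= lower bound
        hi = bisect_right(times, t - lag)                   # first time > upper bound
        win_cnt = ccnt[hi] - ccnt[lo]
        win_sum = csum[hi] - csum[lo]
        # Empty windows (or too few observations) yield NaN, never a division by zero.
        if win_cnt > 0 and win_cnt >= min_samples:
            out.append(win_sum / win_cnt)
        else:
            out.append(float("nan"))
    return out


# Two aligned series, (1, 2, 3, 4) and (10, 20, 30, 40), pooled per timestamp.
times = [1, 2, 3, 4]
sums = [11.0, 22.0, 33.0, 44.0]
counts = [2, 2, 2, 2]
out = rolling_mean_from_aggs(times, sums, counts, window_size=2, lag=1)
```

The work per timestamp no longer depends on the number of series in the bucket, which is the point of keeping T-length arrays. Note that `min_samples` here counts observations, so a single timestamp with two rows already satisfies `min_samples=2`.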
During recursive prediction, RollingMean now computes the latest
timestamp's feature value directly from cached _TimestampAggregates
via _compute_latest_from_aggs, avoiding the O(n_series * T) query
array construction. Transforms that don't support the fast path
(all others currently) fall back to the existing build_query_arrays
path.
Benchmark (10k series, 100 timestamps, RollingMean(28), 20 recursive
steps, 3 repeats — timing isolates _update_features + _update_y):
global_=True:
previous: 1.705s (85.25ms/step) → current: 0.138s (6.92ms/step) — 12.3x
groupby=["brand"], 100 groups × 100 series:
previous: 5.672s (283.59ms/step) → current: 0.330s (16.51ms/step) — 17.2x
fit_transform time unchanged (~0.52s global, ~1.10s groupby). Checksum
comparison of the first 3 recursive steps confirmed identical sums and
NaN counts for both global and groupby cases.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
I've added four commits addressing the comments you raised:
Here are the specifics of the benchmarks: Setup: 10,000 series, 100 timestamps each,
I also ran a short checksum comparison for the first 3 recursive steps on both implementations; sums and NaN counts matched for both global and groupby cases. |
…) helper that extracts the shared cumsum/searchsorted computation. Refactored _compute_from_aggregates() to call it. Added _compute_ts_level_from_aggs() on _BaseLagTransform (returns None), RollingMean (calls helper per bucket), Offset (delegates), and Combine (combines element-wise). core.py: In _transform(), the global block now tries _compute_ts_level_from_aggs first. Transforms that support it get mapped directly to df_sorted rows via np.searchsorted on unique timestamps, bypassing the pandas merge in _join_bucket_features. Unsupported transforms fall through to the existing path.
When min_samples=0, empty windows caused ZeroDivisionError in _compute_latest_from_aggs (predict path) and silently returned 0.0 instead of NaN in _rolling_mean_from_agg (preprocess path). Add win_cnt > 0 guards to all five computation paths (_rolling_mean_from_agg, _compute_latest_from_aggs, _compute_row_level, _RollingBase, and _Seasonal_RollingBase) so empty windows consistently produce NaN. Warn at init time when min_samples=0 is used with global_/groupby. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…paths Add _expanding_mean_from_agg and _ewm_from_agg helpers that compute features from cached per-timestamp aggregates instead of the O(M*N_b) per-timestamp Python loop. ExpandingMean uses cumsum of sums/counts with searchsorted boundary lookup. EWM replaces the two-pass approach (per-timestamp mean loop + sequential scan) with a single pass over pre-computed sums/counts. Both transforms now support _compute_from_aggregates (fit), _compute_latest_from_aggs (predict), and _compute_ts_level_from_aggs (preprocess) fast paths. Offset and Combine delegate _compute_latest_from_aggs to their wrapped transforms. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace Python loops over series_bucket_id with vectorized numpy lookup arrays for both the fast path (_compute_latest_from_aggs) and slow path (query_arrays) in _update_features groupby handling. Sizes lookup array to cover all bucket IDs from both the result dict and series_bucket_id to prevent out-of-bounds indexing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add per-timestamp sum-of-squares, minimum, and maximum to the cached aggregates. sum_sq enables RollingStd/ExpandingStd fast paths via Bessel-corrected variance from prefix sums. mins/maxs enable RollingMin/Max/ExpandingMin/Max via sparse table or prefix reduction. _build_ts_aggs computes mins/maxs using np.minimum.at/np.maximum.at for O(N_b) vectorized construction. append_predictions updates all three new fields incrementally for both global and groupby buckets. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
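As a rough sketch of why a cached sum of squares is enough for a standard deviation fast path, the snippet below builds per-timestamp sum, sum-of-squares, and count from raw rows, then derives a Bessel-corrected rolling standard deviation from prefix sums. The helper names, the `(ds, y)` row layout, and the integer time grid are assumptions made for this example, not the branch's `_build_ts_aggs`.

```python
import math
from bisect import bisect_left, bisect_right
from itertools import accumulate


def build_aggs(rows):
    """Collapse (ds, y) rows into per-timestamp sum, sum of squares and count."""
    aggs = {}
    for ds, y in rows:
        s, sq, n = aggs.get(ds, (0.0, 0.0, 0))
        aggs[ds] = (s + y, sq + y * y, n + 1)
    times = sorted(aggs)
    sums, sum_sqs, counts = zip(*(aggs[t] for t in times))
    return times, sums, sum_sqs, counts


def rolling_std_from_aggs(times, sums, sum_sqs, counts, window_size, lag):
    """Sample standard deviation over all observations in each time window."""
    csum = [0.0, *accumulate(sums)]
    csq = [0.0, *accumulate(sum_sqs)]
    ccnt = [0, *accumulate(counts)]
    out = []
    for t in times:
        lo = bisect_left(times, t - lag - window_size + 1)
        hi = bisect_right(times, t - lag)
        n = ccnt[hi] - ccnt[lo]
        if n < 2:  # a sample variance needs at least two observations
            out.append(float("nan"))
            continue
        s = csum[hi] - csum[lo]
        sq = csq[hi] - csq[lo]
        # Bessel-corrected variance; clamp tiny negative values from round-off.
        var = max((sq - s * s / n) / (n - 1), 0.0)
        out.append(math.sqrt(var))
    return out


rows = [(1, 1.0), (1, 10.0), (2, 2.0), (2, 20.0),
        (3, 3.0), (3, 30.0), (4, 4.0), (4, 40.0)]
times, sums, sum_sqs, counts = build_aggs(rows)
std = rolling_std_from_aggs(times, sums, sum_sqs, counts, window_size=2, lag=1)
```

One caveat with this formula: subtracting two large, nearly equal quantities can lose precision when values have a large mean relative to their spread, so comparing against a row-level computation on realistic data is worthwhile.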
Great work!
I was wondering if we could use the same timestamp-level aggregate fast paths for min, max and std. Currently they are using the generic pooled masking path, which is slower.
I know it is annoying and honestly I did not have that in mind earlier. But maybe we can have one more check for min, max and std.
I think for quantile rolling we would need a more complex data structure as it needs the full distribution in the rolling window. So let's leave this out for now, but would be great if you could have a final look at min, max, std and see if they can use the fast timestamp path.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e2620175c7
Thanks @janrth ! I already started implementing those and yes, for quantile we need to fall back on the "slow" implementation but for everything else it should still be solid |
The pooled EWM accumulator was consuming timestamps 0..k-1 before emitting at timestamp k, ignoring the lag parameter. With lag=L, the output at timestamp k should only reflect timestamps 0..k-L. Fix all three code paths (_ewm_from_agg, slow path, _compute_latest_from_aggs) to use a two-pointer approach where consume_idx only advances up to unique_times[k] - lag. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
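The two-pointer idea can be sketched as follows. This is an illustration under stated assumptions rather than the branch's `_ewm_from_agg`: it assumes the pooled EWM is a plain recursion over per-timestamp means (`state = alpha * mean + (1 - alpha) * state`, seeded with the first mean) on an integer time grid, and the function name is hypothetical. The part that mirrors the fix is that `consume_idx` only advances up to `t - lag`.

```python
def pooled_ewm(times, sums, counts, alpha, lag):
    """EWM over per-timestamp means, emitting at t only what is known at t - lag."""
    out = []
    state = None     # running EWM over the timestamps consumed so far
    consume_idx = 0  # index of the next timestamp not yet folded into `state`
    for t in times:
        # Fold in every timestamp at or before t - lag, and nothing later.
        while consume_idx < len(times) and times[consume_idx] <= t - lag:
            if counts[consume_idx] > 0:
                mean = sums[consume_idx] / counts[consume_idx]
                state = mean if state is None else alpha * mean + (1 - alpha) * state
            consume_idx += 1
        out.append(float("nan") if state is None else state)
    return out


times = [1, 2, 3, 4]
sums = [2.0, 4.0, 6.0, 8.0]   # two observations per timestamp, means 1, 2, 3, 4
counts = [2, 2, 2, 2]
ewm_lag1 = pooled_ewm(times, sums, counts, alpha=0.5, lag=1)
ewm_lag2 = pooled_ewm(times, sums, counts, alpha=0.5, lag=2)
```

With `lag=2` the first two outputs are NaN because nothing at or before `t - 2` exists yet, and each later output trails the `lag=1` result by one step.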
…ingStd, ExpandingMin/Max
Extend the per-timestamp aggregate optimization to all remaining decomposable transforms:
- RollingStd/ExpandingStd: cumsum of sums, counts, sum_sq with Bessel-corrected variance formula
- RollingMin/Max: sparse table (O(n log n) build, O(1) query) over per-timestamp mins/maxs
- ExpandingMin/Max: prefix min/max via np.fmin/fmax.accumulate
Each transform gets _compute_bucket_feature (fit), _compute_ts_level_from_aggs (preprocess), and _compute_latest_from_aggs (predict) fast paths.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
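For readers unfamiliar with sparse tables, this is a generic, minimal sketch of the structure behind the O(n log n) build and O(1) range-minimum query mentioned above. It is a textbook version written in plain Python, not the implementation in this branch; the function names are made up for the example, and the input is assumed to be NaN-free (plain `min` does not skip NaN the way `np.fmin` does).

```python
from itertools import accumulate


def build_sparse_min(values):
    """table[j][i] holds min(values[i : i + 2**j])."""
    n = len(values)
    table = [list(values)]
    j = 1
    while (1 << j) <= n:
        prev = table[-1]
        half = 1 << (j - 1)
        # A block of length 2**j is the min of its two halves of length 2**(j-1).
        table.append([min(prev[i], prev[i + half]) for i in range(n - (1 << j) + 1)])
        j += 1
    return table


def range_min(table, lo, hi):
    """Minimum of values[lo:hi] (hi exclusive); requires lo < hi."""
    j = (hi - lo).bit_length() - 1  # largest power of two that fits in the range
    # Two overlapping blocks of length 2**j cover the whole range.
    return min(table[j][lo], table[j][hi - (1 << j)])


per_ts_mins = [5.0, 2.0, 7.0, 1.0, 9.0]
table = build_sparse_min(per_ts_mins)
window_min = range_min(table, 0, 3)                # min over the first three timestamps
expanding_min = list(accumulate(per_ts_mins, min)) # prefix minimum for the expanding case
```

The overlap between the two blocks is harmless because `min` is idempotent, which is why the same trick works for max but not for sums.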
Add _idsorted_to_bucket_pos permutation to PooledState, built once at fit time using ufp.sort for categorical-safe ordering. During preprocess, groupby transforms with _compute_ts_level_from_aggs now map results directly via the permutation instead of going through _join_bucket_features (which does an O(n log n) pandas/polars merge on every call). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Parameterize all existing pooled tests over lag=[1, 3] so structural and numerical assertions run at higher lags. Add targeted tests:
- test_ewm_lag_semantics: hand-computed EWM values at lag=2 verifying the two-pointer consumption fix (global + groupby)
- test_pooled_transforms_lag3_global: all 8 decomposable transforms with lag=3, checking preprocess + predict against expected values
- test_pooled_transforms_lag2_groupby: 7 transforms in groupby mode
- test_fast_vs_slow_equivalence: parameterized over all 9 transforms × lag=[1,3], exercises fit (_compute_from_aggregates), preprocess (_compute_ts_level_from_aggs), and predict (_compute_latest_from_aggs) paths by comparing aggregate fast path vs row-level slow path
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Independent verification of lag-transforms using global_ or groupby mode with RANGE BETWEEN semantics using SQLite window functions. Covers 8 transforms × 3 lags × global/groupby, multi-column groupby, custom min_samples, staggered starts, and random stress tests (61 cases total). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
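A small example of what such a SQL cross-check can look like, using Python's built-in `sqlite3`. The table name, column names, and series ids are assumptions for illustration and are not taken from the test suite; the query expresses `RollingMean(window_size=2, lag=1, global_=True)` as a RANGE frame over `ds`. Numeric RANGE offsets need a reasonably recent SQLite (3.28 or later, to my knowledge).

```python
import sqlite3

# Two aligned series on an integer time grid.
rows = [
    ("a", 1, 1.0), ("a", 2, 2.0), ("a", 3, 3.0), ("a", 4, 4.0),
    ("b", 1, 10.0), ("b", 2, 20.0), ("b", 3, 30.0), ("b", 4, 40.0),
]

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE obs (unique_id TEXT, ds INTEGER, y REAL)")
con.executemany("INSERT INTO obs VALUES (?, ?, ?)", rows)

# No PARTITION BY: the frame pools every row whose ds lies in [ds - 2, ds - 1].
query = """
SELECT unique_id, ds,
       AVG(y) OVER (ORDER BY ds RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING) AS feat
FROM obs
ORDER BY unique_id, ds
"""
result = con.execute(query).fetchall()
con.close()

# Feature values in (unique_id, ds) order; an empty frame gives NULL, i.e. None.
features = [feat for _, _, feat in result]
```

Adding `PARTITION BY brand` inside the `OVER` clause would give the groupby flavour, and `PARTITION BY unique_id` the per-series one.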
|
I've added another sanity check by comparing results of the transforms with SQL equivalent semantics, on top of the hard-coded values in the other tests. Now all transforms that support groupby/global mode have been reimplemented and, where possible, optimized for efficiency by avoiding joins and broadcasting values instead. Let me know what you think, I'd love to see this merged soon and start implementing the It would also be great to have it be part of the milestones in the next release, what do you think? |
Give me a few days, but I will look into your latest changes once I have a bit of time. I have the feeling you have gotten pretty far and we are hopefully close to merging :) |
…slow-path implementation
|
Hey @simonez-tuidi this is a great job! I was ooo but great to see your work. Shout-out to @janrth for diving in with you. In order to merge I would argue that we need to work on:
If you're okay with it, I can finish these tasks for merging. |
|
@nasaul thanks for that, absolutely go ahead, totally agree with switching to Narwhals to make it more maintainable On adding I can see other transforms being beneficial - like Counts - and we would also need to find the right design (e.g. do we sum all observations or is it intended to sum per-timestamp and then average?) Another issue that was raised in #644 would be nice to see merged in its own PR - allowing non-local and local transforms to be used with Finally, I was also wondering if these changes will eventually end up being migrated to coreforecast instead, but I would say it'd be great to start merging in this repo and port them later once stable |
|
Okay, I agree that it should be another PR. Going into |
|
nasaul
left a comment
Thanks for this PR, great work!
PR Description
Summary
This PR reworks global_ and groupby lag transforms so pooled features are computed over the underlying observations in each time range, matching SQL-style RANGE BETWEEN ... PRECEDING semantics.
This implements Option A from issue #640: change the default global_/groupby behavior to RANGE semantics instead of preserving the current sum-then-roll behavior.
Previously, global and grouped transforms were backed by separate ad hoc state paths that aggregated each timestamp using sum before applying the lag transform. That made transforms such as:
behave like transforms over per-timestamp sums rather than transforms over all rows in the relevant time window. This branch introduces a shared pooled state representation and computes pooled transforms directly from bucketed observation arrays.
Problem Addressed
When multiple series share a pooled bucket, the old implementation first collapsed the data by timestamp:
Then RollingMean(window_size=2, lag=1, global_=True) operated on [11, 22, 33, 44], producing values such as mean(11, 22) = 16.5.
That has two practical problems:
RollingMean scales with the number of series in the group, because it is effectively averaging sums.
With this PR, the same transform operates over the individual observations in the RANGE window. For example, at ds=3 the window contains [1, 10, 2, 20], so the mean is 8.25.
One detail worth making explicit: min_samples is still evaluated over observations. With multiple series in a bucket, a window containing one timestamp can satisfy min_samples=2 if that timestamp has two observed rows.
What Changed
Added PooledState
Added mlforecast/pooled.py with a PooledState object that owns the state needed by pooled transforms:
- GroupedArray for existing transform state initialization
- groupby transforms
This replaces the previous separate _global_ga/_global_times and _group_states code paths with a single state model.
Compute Pooled Features Directly
Lag transforms already expose _compute_bucket_feature(...) (added in earlier commits on this branch) for rolling, seasonal, expanding, EWM, Offset, and Combine transforms. This PR routes all pooled computation through those methods via a new compute_pooled_features() function, and removes the previous silent fallback to positional GroupedArray behavior.
The old fallback was problematic: when a transform did not implement _compute_bucket_feature, the code silently fell back to GA positional semantics, which produced incorrect results under RANGE window bounds. Unsupported pooled transforms now raise a clear NotImplementedError instead.
SQL-Like Range Semantics
Pooled transforms now use a per-bucket time_index derived from the validated regular time grid. For global and groupby transforms, this gives interval-style window bounds while preserving the existing codebase assumption that non-partitioned series do not contain gaps.
This means a grouped feature behaves like:
rather than first aggregating y by (brand, timestamp).
The intended equivalence model is:
| Transform | SQL window equivalent |
| --- | --- |
| RollingMean(w, lag=l) | AVG(y) OVER (PARTITION BY unique_id RANGE BETWEEN ...) |
| RollingMean(w, lag=l, global_=True) | AVG(y) OVER (RANGE BETWEEN ...) |
| RollingMean(w, lag=l, groupby=["brand"]) | AVG(y) OVER (PARTITION BY brand RANGE BETWEEN ...) |
Update Path Fixes
The pooled state update path now keeps all related arrays and metadata in sync:
- bucket_df
- update()
This also fixes a pre-existing bug in TimeSeries.update() where new series received wrong static features: ufp.take_rows(df, ...) was indexing into the full DataFrame instead of the new-series subset, causing incorrect bucket assignments for series introduced via update().
Categorical Group Key Support
Grouped buckets are represented internally by numeric _bucket_ids, but public group keys such as brand or subcategory may be categorical. The new helpers reconcile pandas and Polars categoricals before joins and concatenations, including when updates introduce a new group value.
Tests
Added tests/test_pooled.py covering:
Updated existing core tests to assert the new RANGE-style semantics for global and grouped rolling/expanding transforms.
The previous test_group_lag_transform used one series per group, which made sum-by-timestamp and RANGE semantics indistinguishable. The updated tests include multiple series in the same group so this behavior is covered directly.
Verification
Ran:
The full branch suite was also verified with:
Result:
mlforecast/pooled.py at 98% coverage
Compatibility / Breaking Change
The public transform API is unchanged. Existing global_ and groupby arguments continue to be used.
This is nevertheless a necessary breaking change for users who already rely on global_ or groupby transforms with more than one series in a pooled bucket. The numeric output changes from "sum by timestamp, then apply the transform" to "apply the transform over all observations in the RANGE window".
For example, RollingMean(window_size=2, lag=1, global_=True) over two aligned series changes from:
to:
This change is intentional because the previous behavior made means scale with the number of series in the group and diverged from the SQL RANGE mental model used by the rest of the pooled/partitioned transform design. Preserving the old behavior would require adding a separate aggregation mode (for example, "sum by timestamp before transforming"), which would keep the incorrect default and add API complexity. This PR chooses correctness and consistency instead.
Users who depended on the old sum-then-transform behavior will need to reproduce that aggregation explicitly before fitting or use a future explicit aggregation option if one is added.
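For anyone who needs the old numbers, the following dependency-free sketch shows one way to reproduce "sum by timestamp, then roll" outside the library. The row layout, series ids, and the helper `rolling_mean_positional` are assumptions for illustration, and the helper requires a full window (as if min_samples equaled window_size); in practice a groupby-sum on the input frame followed by a shifted rolling mean would do the same job.

```python
from collections import defaultdict

# (unique_id, ds, y) rows for two aligned series.
rows = [
    ("a", 1, 1.0), ("a", 2, 2.0), ("a", 3, 3.0), ("a", 4, 4.0),
    ("b", 1, 10.0), ("b", 2, 20.0), ("b", 3, 30.0), ("b", 4, 40.0),
]

# Step 1: collapse all series to one summed value per timestamp.
totals = defaultdict(float)
for _, ds, y in rows:
    totals[ds] += y
ds_sorted = sorted(totals)
summed = [totals[ds] for ds in ds_sorted]


def rolling_mean_positional(values, window_size, lag):
    """Mean of the `window_size` values ending `lag` positions before each index."""
    out = []
    for i in range(len(values)):
        end = i - lag + 1           # exclusive end of the window
        start = end - window_size
        if start >= 0:
            out.append(sum(values[start:end]) / window_size)
        else:
            out.append(float("nan"))  # not enough history for a full window
    return out


# Step 2: roll over the per-timestamp sums.
old_feature = rolling_mean_positional(summed, window_size=2, lag=1)
old_by_ds = dict(zip(ds_sorted, old_feature))
```

The resulting per-timestamp values can then be joined back onto the training frame by ds and passed as an ordinary exogenous column.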
The internal fitted TimeSeries state shape changed, so previously pickled fitted TimeSeries objects that depend on the old private pooled state are not expected to be compatible.
Linking issue
Closes #640