aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSam Anthony <sam@samanthony.xyz>2026-03-11 10:24:30 -0400
committerSam Anthony <sam@samanthony.xyz>2026-03-11 10:24:30 -0400
commit21732d319a2dfd061d798f3e5bfa6fdacc28be38 (patch)
tree89aa30991cb613f9fcf44e4598f9f09db39b49f2
parent9a2c345cfd3697a5de3f5fd799f682ab822e3756 (diff)
downloadgui-21732d319a2dfd061d798f3e5bfa6fdacc28be38.zip
lay/strain: replace Solver mutex with channels
Partially reverts 483236742ddcd7883b5f9cff92244129274aa79c. Solver still closes itself automatically, but reverted to using request/reply channels. Mutex was becoming annoying to manage, as expected. TODO: TestRows fails intermittently due to constraint channels not being flushed when Solve() is called.
-rw-r--r--internal/tag/tag.go16
-rw-r--r--lay/strain/solve.go247
-rw-r--r--lay/strain/solve_test.go85
3 files changed, 239 insertions, 109 deletions
diff --git a/internal/tag/tag.go b/internal/tag/tag.go
index 197d2d6..b4b627d 100644
--- a/internal/tag/tag.go
+++ b/internal/tag/tag.go
@@ -1,12 +1,22 @@
package tag
+import "context"
+
type Tagged[V, T any] struct {
Val V
Tag T
}
-func Tag[V, T any](out chan<- Tagged[V, T], in <-chan V, f func(V) T) {
- for val := range in {
- out <- Tagged[V, T]{val, f(val)}
+func Tag[V, T any](ctx context.Context, out chan<- Tagged[V, T], in <-chan V, f func(V) T) {
+ for {
+ select {
+ case val, ok := <-in:
+ if !ok {
+ return
+ }
+ out <- Tagged[V, T]{val, f(val)}
+ case <-ctx.Done():
+ return
+ }
}
}
diff --git a/lay/strain/solve.go b/lay/strain/solve.go
index cfb1f41..41cabc5 100644
--- a/lay/strain/solve.go
+++ b/lay/strain/solve.go
@@ -1,8 +1,10 @@
package strain
import (
+ "context"
"fmt"
"image"
+ "math"
"sync"
"github.com/lithdew/casso"
@@ -13,8 +15,9 @@ import (
)
const (
- fieldConstraintPriority = casso.Medium
- layoutConstraintPriority = casso.Medium
+ solverPriority = casso.Medium
+ layoutPriority = casso.Medium
+ fieldPriority = casso.Medium
)
// Solver uses the Cassowary algorithm to partition a rectangle among
@@ -27,18 +30,13 @@ const (
// and size of fields within the container, while the field
// constraints control the width and height of each field.
type Solver struct {
- solver *casso.Solver
- style *style.Style
- fieldConstrs chan tag.Tagged[Constraint, fieldIndex] // incoming constraints from fields
-
// External symbols
container SymRect // position and size of container
fields []SymRect // position and size of each field
- // Constraint ID symbols
- fieldSizeConstrs []sizeConstraintSymbols
-
- mu sync.Mutex
+ fieldConstrs chan tag.Tagged[Constraint, fieldIndex]
+ layoutConstrs chan constrainRequest
+ solveReqs chan solveRequest
}
type fieldIndex int
@@ -56,6 +54,24 @@ type sizeConstraintSymbols struct {
// Nil means no constraint.
}
+type constrainRequest struct {
+ casso.Priority
+ casso.Op
+ constant float64
+ terms []casso.Term
+ res chan<- error
+}
+
+type solveRequest struct {
+ container image.Rectangle
+ res chan<- solveResponse
+}
+
+type solveResponse struct {
+ fields []image.Rectangle
+ error
+}
+
// NewSolver creates a Solver that can be used to resolve constraints
// received from the given channels: one channel per field in the
// layout. These are generally the receiving side of some Envs'
@@ -66,100 +82,114 @@ func NewSolver(styl *style.Style, constraints []<-chan Constraint) (*Solver, err
fieldConstrs := make(chan tag.Tagged[Constraint, fieldIndex])
var wg sync.WaitGroup
wg.Add(nfields)
- for i, cs := range constraints {
+ ctx, cancel := context.WithCancel(context.Background())
+ for i := range fields {
fields[i] = NewSymRect()
go func() {
// Tag incoming field constraints by field index and multiplex them into fieldConstrs
- tag.Tag(fieldConstrs, cs, func(c Constraint) fieldIndex {
+ tag.Tag(ctx, fieldConstrs, constraints[i], func(c Constraint) fieldIndex {
return fieldIndex(i)
})
wg.Done()
}()
}
-
go func() {
wg.Wait() // once all fields close their constraints channels,
- close(fieldConstrs)
+ cancel() // destroy the solver
}()
solver := casso.NewSolver()
container := NewSymRect()
if err := editRect(solver, container, casso.Strong); err != nil {
+ cancel()
return nil, fmt.Errorf("error marking container symbol as editable: %w", err)
}
s := &Solver{
- solver,
- styl,
- fieldConstrs,
container,
fields,
- make([]sizeConstraintSymbols, nfields),
- sync.Mutex{},
+ fieldConstrs,
+ make(chan constrainRequest),
+ make(chan solveRequest),
+ }
+ go s.run(ctx, solver, styl)
+ if err := s.addDefaultConstraints(); err != nil {
+ cancel()
+ return nil, fmt.Errorf("error adding default constraint: %w", err)
}
+ return s, nil
+}
- go func() {
- for tc := range fieldConstrs {
- constr, fieldIdx := tc.Val, tc.Tag
- s.mu.Lock()
- err := s.addFieldSizeConstraint(constr, fieldIdx)
- s.mu.Unlock()
+func (s *Solver) run(ctx context.Context, solver *casso.Solver, styl *style.Style) {
+ defer close(s.fieldConstrs)
+ defer close(s.layoutConstrs)
+ defer close(s.solveReqs)
+
+ fieldSizeConstrs := make([]sizeConstraintSymbols, len(s.fields))
+ for {
+ select {
+ case tc, ok := <-s.fieldConstrs:
+ if !ok {
+ return
+ }
+ constr, i := tc.Val, tc.Tag
+ err := addFieldSizeConstraint(solver, styl, constr, s.fields[i], &fieldSizeConstrs[i])
if err != nil {
log.Err.Printf("error adding layout constraint %#v from field %d: %v\n",
- constr, fieldIdx, err)
+ constr, i, err)
}
- }
- }()
- if err := s.addDefaultConstraints(); err != nil {
- return nil, fmt.Errorf("error adding default constraint: %w", err)
- }
+ case req := <-s.layoutConstrs:
+ _, err := addConstraint(solver, req.Priority, req.Op, req.constant, req.terms...)
+ req.res <- err
- return s, nil
+ case req := <-s.solveReqs:
+ fields, err := s.solve(solver, req.container)
+ req.res <- solveResponse{fields, err}
+
+ case <-ctx.Done():
+ return
+ }
+ }
}
func (s *Solver) addDefaultConstraints() error {
for _, field := range s.fields {
- if err := s.AddConstraintPt(casso.GTE, field.Origin, s.container.Origin); err != nil {
+ if err := s.addConstraintPt(solverPriority, casso.GTE, field.Origin, s.container.Origin); err != nil {
return err
}
- if err := s.AddConstraintPt(casso.LTE, field.Size, s.container.Size); err != nil {
+ if err := s.addConstraintPt(solverPriority, casso.LTE, field.Size, s.container.Size); err != nil {
return err
}
}
return nil
}
-// addSizeConstraint adds or modifies a constraint on the size of a
-// field, removing mutually exclusive constraints.
-//
-// Solver.mu must be held.
-func (s *Solver) addFieldSizeConstraint(constr Constraint, i fieldIndex) error {
- fieldSize := s.fields[i].Size
- fieldConstrs := &s.fieldSizeConstrs[i]
-
+// addFieldSizeConstraint adds or modifies a constraint on the size of
+// a field, removing mutually exclusive constraints.
+func addFieldSizeConstraint(s *casso.Solver, styl *style.Style, constr Constraint, field SymRect, fieldConstrs *sizeConstraintSymbols) error {
// Clear mutually exclusive constraints and replace with new one
switch constr.Dimension {
case Width:
- width := float64(s.style.Pixels(constr.Value).Round())
+ width := float64(styl.Pixels(constr.Value).Round())
switch constr.Op {
case casso.EQ:
- s.removeConstraints(fieldConstrs.widthEq, fieldConstrs.widthGte, fieldConstrs.widthLte)
- c, err := s.addFieldConstraint(constr.Op, -width, fieldSize.X.T(1.0))
+ removeConstraints(s, fieldConstrs.widthEq, fieldConstrs.widthGte, fieldConstrs.widthLte)
+ c, err := addConstraint(s, fieldPriority, constr.Op, -width, field.Size.X.T(1.0))
if err != nil {
return err
}
fieldConstrs.widthEq = &c
case casso.GTE:
- s.removeConstraints(fieldConstrs.widthEq, fieldConstrs.widthGte)
- c, err := s.addFieldConstraint(constr.Op, -width, fieldSize.X.T(1.0))
+ removeConstraints(s, fieldConstrs.widthEq, fieldConstrs.widthGte)
+ c, err := addConstraint(s, fieldPriority, constr.Op, -width, field.Size.X.T(1.0))
if err != nil {
return err
}
fieldConstrs.widthGte = &c
case casso.LTE:
- s.removeConstraints(fieldConstrs.widthEq, fieldConstrs.widthLte)
- c, err := s.addFieldConstraint(constr.Op, -width, fieldSize.X.T(1.0))
+ removeConstraints(s, fieldConstrs.widthEq, fieldConstrs.widthLte)
+ c, err := addConstraint(s, fieldPriority, constr.Op, -width, field.Size.X.T(1.0))
if err != nil {
return err
}
@@ -168,25 +198,25 @@ func (s *Solver) addFieldSizeConstraint(constr Constraint, i fieldIndex) error {
panic(fmt.Sprintf("unreachable: impossible %T: %v", constr.Op, constr.Op))
}
case Height:
- height := float64(s.style.Pixels(constr.Value).Round())
+ height := float64(styl.Pixels(constr.Value).Round())
switch constr.Op {
case casso.EQ:
- s.removeConstraints(fieldConstrs.heightEq, fieldConstrs.heightGte, fieldConstrs.heightLte)
- c, err := s.addFieldConstraint(constr.Op, -height, fieldSize.Y.T(1.0))
+ removeConstraints(s, fieldConstrs.heightEq, fieldConstrs.heightGte, fieldConstrs.heightLte)
+ c, err := addConstraint(s, fieldPriority, constr.Op, -height, field.Size.Y.T(1.0))
if err != nil {
return err
}
fieldConstrs.heightEq = &c
case casso.GTE:
- s.removeConstraints(fieldConstrs.heightEq, fieldConstrs.heightGte)
- c, err := s.addFieldConstraint(constr.Op, -height, fieldSize.Y.T(1.0))
+ removeConstraints(s, fieldConstrs.heightEq, fieldConstrs.heightGte)
+ c, err := addConstraint(s, fieldPriority, constr.Op, -height, field.Size.Y.T(1.0))
if err != nil {
return err
}
fieldConstrs.heightGte = &c
case casso.LTE:
- s.removeConstraints(fieldConstrs.heightEq, fieldConstrs.heightLte)
- c, err := s.addFieldConstraint(constr.Op, -height, fieldSize.Y.T(1.0))
+ removeConstraints(s, fieldConstrs.heightEq, fieldConstrs.heightLte)
+ c, err := addConstraint(s, fieldPriority, constr.Op, -height, field.Size.Y.T(1.0))
if err != nil {
return err
}
@@ -201,11 +231,10 @@ func (s *Solver) addFieldSizeConstraint(constr Constraint, i fieldIndex) error {
return nil
}
-// Solver.mu must be held.
-func (s *Solver) removeConstraints(constrs ...*casso.Symbol) error {
+func removeConstraints(s *casso.Solver, constrs ...*casso.Symbol) error {
for _, constr := range constrs {
if constr != nil {
- if err := s.solver.RemoveConstraint(*constr); err != nil {
+ if err := s.RemoveConstraint(*constr); err != nil {
return err
}
}
@@ -213,11 +242,30 @@ func (s *Solver) removeConstraints(constrs ...*casso.Symbol) error {
return nil
}
-// Solver.mu must be held.
-func (s *Solver) addFieldConstraint(op casso.Op, constant float64, terms ...casso.Term) (casso.Symbol, error) {
- return s.solver.AddConstraintWithPriority(
- fieldConstraintPriority,
- casso.NewConstraint(op, constant, terms...))
+func addConstraint(s *casso.Solver, priority casso.Priority, op casso.Op, constant float64, terms ...casso.Term) (casso.Symbol, error) {
+ return s.AddConstraintWithPriority(priority, casso.NewConstraint(op, constant, terms...))
+}
+
+func (s *Solver) solve(solver *casso.Solver, container image.Rectangle) (fields []image.Rectangle, err error) {
+ if err := suggestRect(solver, s.container, container); err != nil {
+ return nil, err
+ }
+
+ fields = make([]image.Rectangle, len(s.fields))
+ for i, field := range s.fields {
+ min := image.Pt(
+ s.val(solver, field.Origin.X),
+ s.val(solver, field.Origin.Y))
+ max := min.Add(image.Pt(
+ s.val(solver, field.Size.X),
+ s.val(solver, field.Size.Y)))
+ fields[i] = image.Rectangle{min, max}
+ }
+ return fields, nil
+}
+
+func (s *Solver) val(solver *casso.Solver, sym casso.Symbol) int {
+ return int(math.Round(solver.Val(sym)))
}
// Container returns the Cassowary symbols representing the layout
@@ -228,24 +276,48 @@ func (s *Solver) Container() SymRect { return s.container }
// position and size.
func (s *Solver) Field(i int) SymRect { return s.fields[i] }
-// AddConstraint imposes a constraint between two symbols. The symbols
-// may be aspects of the Container() or Field()s.
+// AddConstraint imposes a constraint on a set of terms:
+//
+// constant + ∑terms op 0
+//
+// The terms can be obtained by calling the T(coeff) method of the
+// symbols returned by Container() or Field(n). Think of terms with
+// positive coefficients as being on the LHS and terms with negative
+// coefficients as being on the RHS.
//
// Once added, a constraint cannot be removed or modified.
-func (s *Solver) AddConstraint(op casso.Op, lhs, rhs casso.Symbol) error {
- s.mu.Lock()
- defer s.mu.Unlock()
- _, err := s.solver.AddConstraintWithPriority(layoutConstraintPriority, casso.NewConstraint(op, 0, lhs.T(1.0), rhs.T(-1.0)))
- return err
+func (s *Solver) AddConstraint(op casso.Op, constant float64, terms ...casso.Term) error {
+ return s.addConstraint(layoutPriority, op, constant, terms...)
}
// AddConstraintPt imposes a constraint between two point symbols:
// (lhs.X op rhs.X) and (lhs.Y op rhs.Y).
func (s *Solver) AddConstraintPt(op casso.Op, lhs, rhs SymPt) error {
- if err := s.AddConstraint(op, lhs.X, rhs.X); err != nil {
+ return s.addConstraintPt(layoutPriority, op, lhs, rhs)
+}
+
+func (s *Solver) addConstraint(priority casso.Priority, op casso.Op, constant float64, terms ...casso.Term) error {
+ resc := make(chan error)
+ defer close(resc)
+ s.layoutConstrs <- constrainRequest{priority, op, constant, terms, resc}
+ return <-resc
+}
+
+func (s *Solver) addConstraintPt(priority casso.Priority, op casso.Op, lhs, rhs SymPt) error {
+ if err := s.addConstraint(priority, op, 0, lhs.X.T(1.0), rhs.X.T(-1.0)); err != nil {
return err
}
- if err := s.AddConstraint(op, lhs.Y, rhs.Y); err != nil {
+ if err := s.addConstraint(priority, op, 0, lhs.Y.T(1.0), rhs.Y.T(-1.0)); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *Solver) addSizeConstraint(priority casso.Priority, op casso.Op, lhs SymPt, rhs image.Point) error {
+ if err := s.addConstraint(priority, op, -float64(rhs.X), lhs.X.T(1.0)); err != nil {
+ return err
+ }
+ if err := s.addConstraint(priority, op, -float64(rhs.Y), lhs.Y.T(1.0)); err != nil {
return err
}
return nil
@@ -257,27 +329,10 @@ func (s *Solver) AddConstraintPt(op casso.Op, lhs, rhs SymPt) error {
// It returns a slice of Rectangles, one per field. The slice has the
// same length as the slice of Constraint channels that were passed to
// NewSolver().
-func (s *Solver) Solve(container image.Rectangle) ([]image.Rectangle, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- if err := suggestRect(s.solver, s.container, container); err != nil {
- return nil, err
- }
-
- fields := make([]image.Rectangle, len(s.fields))
- for i := range fields {
- field := s.fields[i]
- min := image.Pt(
- s.intVal(field.Origin.X),
- s.intVal(field.Origin.Y))
- max := min.Add(image.Pt(
- s.intVal(field.Size.X),
- s.intVal(field.Size.Y)))
- fields[i] = image.Rectangle{min, max}
- }
- return fields, nil
+func (s *Solver) Solve(container image.Rectangle) (fields []image.Rectangle, err error) {
+ resc := make(chan solveResponse)
+ defer close(resc)
+ s.solveReqs <- solveRequest{container, resc}
+ res := <-resc
+ return res.fields, res.error
}
-
-// Solver.mu must be held.
-func (s *Solver) intVal(sym casso.Symbol) int { return int(s.solver.Val(sym)) }
diff --git a/lay/strain/solve_test.go b/lay/strain/solve_test.go
index 6281124..4c0cc83 100644
--- a/lay/strain/solve_test.go
+++ b/lay/strain/solve_test.go
@@ -80,8 +80,8 @@ func TestSingleField(t *testing.T) {
// Setup
constraints := make(chan strain.Constraint)
st := newSolverTest(t, []<-chan strain.Constraint{constraints})
- defer st.Close()
defer close(constraints)
+ defer st.Close()
// Add layout constraints
container := st.Solver.Container()
@@ -107,10 +107,10 @@ func TestFieldMinSize(t *testing.T) {
// Setup
constraints := make(chan strain.Constraint)
st := newSolverTest(t, []<-chan strain.Constraint{constraints})
- defer st.Close()
defer close(constraints)
+ defer st.Close()
- // Add widget constraints
+ // Add field constraints
minWidth := unit.Value{32, unit.Ch}
minHeight := unit.Value{1.5, unit.Em}
constraints <- strain.Constraint{strain.Width, casso.GTE, minWidth}
@@ -140,26 +140,91 @@ func TestFieldMinSizeLargerThanContainer(t *testing.T) {
// Setup
constraints := make(chan strain.Constraint)
st := newSolverTest(t, []<-chan strain.Constraint{constraints})
- defer st.Close()
defer close(constraints)
+ defer st.Close()
- // Add widget constraints
+ // Add field constraints
constraints <- strain.Constraint{strain.Width, casso.GTE, unit.Value{200, unit.Px}}
constraints <- strain.Constraint{strain.Height, casso.GTE, unit.Value{300, unit.Px}}
- synctest.Wait()
// Solve
container := image.Rect(12, 34, 100, 200)
+ synctest.Wait()
st.solve(container, validateEq([]image.Rectangle{container}))
})
}
-// Solver with only layout constaints, no field constraints.
-func TestLayConstrs(t *testing.T) {
+// Fields arranged as rows.
+func TestRows(t *testing.T) {
t.Parallel()
+ synctest.Test(t, func(t *testing.T) {
+ // Setup
+ nrows := 8
+ constraintss := make([]chan strain.Constraint, nrows)
+ for i := range constraintss {
+ constraintss[i] = make(chan strain.Constraint)
+ }
+ st := newSolverTest(t, castRx(constraintss))
+ defer func() {
+ for _, c := range constraintss {
+ close(c)
+ }
+ st.Close()
+ }()
+
+ // Add layout constraints
+ require.NoError(t, st.Solver.AddConstraintPt(casso.EQ, st.Solver.Field(0).Origin, st.Solver.Container().Origin)) // start from top left corner
+ fieldHeights := make([]casso.Term, nrows)
+ for i := 0; i < nrows; i++ {
+ fieldHeights[i] = st.Field(i).Size.Y.T(1)
+ container := st.Solver.Container()
+ f := st.Solver.Field(i)
+ require.NoError(t, st.Solver.AddConstraint(casso.EQ, 0, f.Size.X.T(1), container.Size.X.T(-1))) // span full width
+ }
+ terms := append(fieldHeights, st.Solver.Container().Size.Y.T(-1)) // ∑field[i].height <= container.height
+ require.NoError(t, st.Solver.AddConstraint(casso.LTE, 0, terms...))
+ for i := 1; i < nrows; i++ {
+ f0, f1 := st.Solver.Field(i-1), st.Solver.Field(i)
+ require.NoError(t, st.Solver.AddConstraint(casso.EQ, 0, f1.Origin.Y.T(1), f0.Origin.Y.T(-1), f0.Size.Y.T(-1))) // in order
+ require.NoError(t, st.Solver.AddConstraintPt(casso.EQ, f1.Size, f0.Size)) // same size
+ }
- st := newSolverTest(t, nil)
- defer st.Close()
+ // Add field constraints
+ var rowHeight int
+ for i, c := range constraintss {
+ rowHeight = 2 * i
+ c <- strain.Constraint{strain.Width, casso.GTE, unit.Value{float64(16 + i), unit.Ch}}
+ c <- strain.Constraint{strain.Height, casso.GTE, unit.Value{float64(rowHeight), unit.Px}}
+ }
+
+ // Solve
+ container := image.Rect(123, 234, 567, 678)
+ synctest.Wait()
+ fields, err := st.Solver.Solve(container)
+ if err != nil {
+ t.Fatal(err)
+ }
+ require.EqualValues(t, nrows, len(fields), "wrong number of fields")
+ for _, field := range fields {
+ t.Logf("%v (%d %d)\n", field, field.Dx(), field.Dy())
+ }
+ for i, field := range fields {
+ require.Equal(t, container.Min.X, field.Min.X, "not left-aligned with container")
+ require.Equal(t, container.Min.Y+i*rowHeight, field.Min.Y, "wrong y position")
+ require.Equal(t, container.Dx(), field.Dx(), "wrong width")
+ require.Equal(t, rowHeight, field.Dy(), "wrong height")
+ }
+ })
+}
+func TestSolver(t *testing.T) {
t.Fail() // TODO: more tests
}
+
+func castRx[T any](cs []chan T) []<-chan T {
+ rcs := make([]<-chan T, len(cs))
+ for i, c := range cs {
+ rcs[i] = c
+ }
+ return rcs
+}