Unit Testing with Channels in Go

One aspect of using multiple goroutines, with channels to communicate between them, that is easy to overlook is how to unit test the resulting functions. Fortunately, it's not especially difficult or complicated once you've wrapped your head around it. Much like using context.Context in tests, it felt like a worthwhile opportunity to document some simple examples.

NOTE: In the examples below, I've distinguished between receive-only and send-only channels (i.e., <-chan bool and chan<- bool respectively), but the approach also works with functions whose channel arguments are bidirectional (i.e., simply chan bool). If the channel direction is not specified, just use whichever pattern below matches how your function actually uses the channel.
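To make the direction syntax concrete, here's a small sketch using a hypothetical sumInts worker and runSum driver (neither appears in the examples below). The worker's parameters are directional, while the caller creates ordinary chan int values and passes them in as-is, because Go implicitly converts a bidirectional channel to either directional type.

```go
package main

// sumInts is a hypothetical worker with directional parameters: it may only
// receive from in and only send on out. Sending on in or receiving from out
// inside this function would be a compile-time error.
func sumInts(in <-chan int, out chan<- int) {
	total := 0
	for v := range in {
		total += v
	}
	out <- total
}

// runSum creates plain bidirectional channels; each one is converted to the
// directional type sumInts expects when it is passed as an argument.
func runSum(values []int) int {
	in := make(chan int)
	out := make(chan int)
	go sumInts(in, out)
	for _, v := range values {
		in <- v
	}
	close(in) // ends the range loop in sumInts
	return <-out
}
```

Calling runSum([]int{1, 2, 3}) returns 6.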

Case # 1: A function which takes a receive and a send channel as arguments

type order struct {
	sum int
}

func (o *order) processData(values <-chan int, done chan<- struct{}) {
	for v := range values {
		o.sum += v
	}
	done <- struct{}{}
	close(done)
}

A simple bit of code that keeps a running total of the values passed in on the given channel. Because we will invoke this function on its own goroutine, we also pass it a done channel, which it uses to tell us that it has finished processing.
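Note that processData both sends a value on done and then closes it. The sketch below, built around a hypothetical runAndWaitTwice driver, receives from done twice to show the effect: the first receive gets the explicit signal, and the second returns immediately because the channel is closed, so an extra waiter would never block.

```go
package main

type order struct {
	sum int
}

// processData is copied from the example above.
func (o *order) processData(values <-chan int, done chan<- struct{}) {
	for v := range values {
		o.sum += v
	}
	done <- struct{}{}
	close(done)
}

// runAndWaitTwice is a hypothetical driver: the first receive from done gets
// the value processData sends (ok == true); the second sees the closed
// channel and returns the zero value with ok == false.
func runAndWaitTwice(values []int) (sum int, firstOK, secondOK bool) {
	o := order{}
	in := make(chan int)
	done := make(chan struct{})
	go o.processData(in, done)
	for _, v := range values {
		in <- v
	}
	close(in)
	_, firstOK = <-done
	_, secondOK = <-done
	// every write to o.sum happens before the send on done, so this read is safe
	return o.sum, firstOK, secondOK
}
```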

Here's a basic test (that won't pass) with all of the non-channel bits:

func Test_order_processData(t *testing.T) {
	tests := []struct {
		name        string
		inputValues []int
		wantSum     int
	}{
		{
			name:        "base case",
			inputValues: []int{0, 1, 2, 3},
			wantSum:     6,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// setup
			o := order{}

			// execute

			// assert
			if o.sum != tt.wantSum {
				t.Errorf("unexpected sum, got %d, wanted %d", o.sum, tt.wantSum)
			}
		})
	}
}

Note something particularly important: my test case struct includes an inputValues field of type []int. Typically, when I'm unit testing a function, I'd have a property in my test case struct whose type matches the input parameter of the function I'm testing. In this case, however, I'm less interested in the channel than in the data that will be passed into that channel, so I've set up a property that I can use to control what data my function processes.
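One way to turn that []int into channel input is a small helper. The sketch below uses a hypothetical feedInts function that copies the slice into a channel buffered to exactly its length and then closes it. With a pre-filled, closed input channel and a done channel that has room for one value, processData can even be called without a separate goroutine, as the hypothetical sumOf shows.

```go
package main

type order struct {
	sum int
}

// processData is copied from the example above.
func (o *order) processData(values <-chan int, done chan<- struct{}) {
	for v := range values {
		o.sum += v
	}
	done <- struct{}{}
	close(done)
}

// feedInts is a hypothetical helper: it copies a test case's []int into a
// channel buffered to len(values) and closes it, so a receiver's range loop
// drains every value and then exits.
func feedInts(values []int) <-chan int {
	ch := make(chan int, len(values))
	for _, v := range values {
		ch <- v // never blocks: the buffer has room for every value
	}
	close(ch)
	return ch
}

// sumOf calls processData on the current goroutine. That only works because
// the input channel is already filled and closed, and done has a buffer slot
// for the single value processData sends before closing it.
func sumOf(values []int) int {
	o := order{}
	done := make(chan struct{}, 1)
	o.processData(feedInts(values), done)
	return o.sum
}
```

This shortcut skips the concurrent handoff that running processData on its own goroutine exercises, so I'd treat it as a convenience for checking the arithmetic rather than a replacement for the approach in this post.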

To build out the execute step in my test, I need to do 5 things:

  1. create channels of the appropriate type for my function under test
  2. pass those channels into o.processData() and run the function on a separate goroutine
  3. create a loop that will send data to the values channel
  4. close the channel so that o.processData() returns
  5. listen to the done channel so I know the function has finished processing the last value before I make any assertions

So here's what I'll add around the execute comment from above:

	// create channels
	testChan := make(chan int)
	doneChan := make(chan struct{})

	// execute
	// execute my function on a separate goroutine, so that I can move on
	// and send my data to the input channel
	go o.processData(testChan, doneChan)

	// create a loop
	for _, v := range tt.inputValues {
		testChan <- v
	}

	// close channel
	close(testChan)

	// wait for done signal
	<-doneChan

Just like in production code, the order here is critical. If we write the loop that sends values to testChan before we start go o.processData, we'll block on the very first send, because a send on an unbuffered channel waits for a receiver and nothing is receiving yet. Also critical is that we close testChan, because that is the condition that causes the for loop inside o.processData to exit.
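The blocking behavior is easy to observe without hanging anything. This sketch uses a hypothetical trySend probe that attempts a send inside a select with a default case, so it reports failure instead of waiting:

```go
package main

// trySend is a hypothetical probe: it attempts a send without blocking and
// reports whether it succeeded. On an unbuffered channel the send can only
// complete if another goroutine is already waiting to receive.
func trySend(ch chan<- int, v int) bool {
	select {
	case ch <- v:
		return true
	default:
		return false // no receiver ready (or buffer full): a plain send would block here
	}
}
```

Before processData is started, nothing is receiving from testChan, so trySend(testChan, v) would report false; the real `testChan <- v` in the loop simply waits forever instead.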

Case # 2: A function which takes multiple send channels as arguments

In this example, let's consider another common pattern with channels. Imagine a function similar to the one above, except that instead of reading ints off of a channel, it reads strings. First, let's compare how these would look without channels involved:

// a function which takes int as input;
// since nothing can go wrong, I don't need to return an error
func (o *order) addInt(v int) {
	o.sum += v
}

// a function which takes string as input;
// since the string could be non-numeric, I need to return an error
func (o *order) addString(s string) error {
	v, err := strconv.Atoi(s)
	if err != nil {
		return err
	}
	o.sum += v
	return nil
}
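Without goroutines, each error simply comes back as a return value. As a baseline, here's a sketch of a hypothetical synchronous tally driver around addString; the success and failure counts it produces are the same ones the channel-based version will have to report through separate channels.

```go
package main

import "strconv"

type order struct {
	sum int
}

// addString is copied from the example above.
func (o *order) addString(s string) error {
	v, err := strconv.Atoi(s)
	if err != nil {
		return err
	}
	o.sum += v
	return nil
}

// tally is a hypothetical synchronous driver: it calls addString on each
// input in turn and counts successes and failures.
func tally(inputs []string) (okCount, errCount, sum int) {
	o := order{}
	for _, s := range inputs {
		if err := o.addString(s); err != nil {
			errCount++
			continue
		}
		okCount++
	}
	return okCount, errCount, o.sum
}
```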

When we execute this on a separate goroutine, however, we cannot return an error to the caller. Instead, I'll pass in three channels:

  • an input channel much like the previous example
  • an ok channel - my function will write to this channel if everything works
  • an errs channel - my function will write to this channel if an error is generated

Here is my function now:

func (o *order) processStringData(values <-chan string, ok chan<- struct{}, errs chan<- error) {
	defer close(ok)
	defer close(errs)
	for s := range values {
		v, err := strconv.Atoi(s)
		if err != nil {
			errs <- err
			continue
		}
		o.sum += v
		ok <- struct{}{}
	}
}
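Because the errs channel carries the actual error value, a test can check more than a count. This sketch, assuming Go 1.14 or later (where *strconv.NumError supports errors.Is through its Unwrap method), uses a hypothetical classify helper that sends one string through processStringData and names the outcome:

```go
package main

import (
	"errors"
	"strconv"
)

type order struct {
	sum int
}

// processStringData is copied from the example above.
func (o *order) processStringData(values <-chan string, ok chan<- struct{}, errs chan<- error) {
	defer close(ok)
	defer close(errs)
	for s := range values {
		v, err := strconv.Atoi(s)
		if err != nil {
			errs <- err
			continue
		}
		o.sum += v
		ok <- struct{}{}
	}
}

// classify is a hypothetical helper: it sends a single string through
// processStringData and reports "ok", "syntax", "range", or "other".
func classify(s string) string {
	o := order{}
	values := make(chan string)
	okChan := make(chan struct{})
	errChan := make(chan error)
	go o.processStringData(values, okChan, errChan)
	defer close(values) // lets the worker's range loop end when we return

	values <- s
	select {
	case <-okChan:
		return "ok"
	case err := <-errChan:
		switch {
		case errors.Is(err, strconv.ErrSyntax):
			return "syntax"
		case errors.Is(err, strconv.ErrRange):
			return "range"
		default:
			return "other"
		}
	}
}
```

For example, "25" classifies as "ok", "two" as "syntax", and a value too large for int such as "99999999999999999999" as "range".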

Testing this is very similar to the above, except that now we need to select on the two channels where the function may write output. We also want cases that cover multiple ok inputs, multiple error inputs, and multiple inputs with mixed results. Here's a basic start of a unit test for this function:

func Test_order_processStringData(t *testing.T) {
	tests := []struct {
		name         string
		inputValues  []string
		wantOKCount  int
		wantErrCount int
		wantSum      int
	}{
		{
			name:         "multiple success",
			inputValues:  []string{"2", "25"},
			wantOKCount:  2,
			wantErrCount: 0,
			wantSum:      27,
		},
		{
			name:         "multiple error",
			inputValues:  []string{"two", "25 "},
			wantOKCount:  0,
			wantErrCount: 2,
			wantSum:      0,
		},
		{
			name:         "multiple mixed",
			inputValues:  []string{"2", "two hundred", "25"},
			wantOKCount:  2,
			wantErrCount: 1,
			wantSum:      27,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// setup
			o := order{}
			values := make(chan string)
			okChan := make(chan struct{})
			errChan := make(chan error)
			var gotOK, gotErr int

			// execute
			go o.processStringData(values, okChan, errChan)

			// create a loop to populate values as above
			for _, s := range tt.inputValues {
				values <- s

				// add a select here; in other words, we will hold execution
				// for each value passed into the channel to see the result
				// and log it for later assertions
				select {
				case <-okChan:
					gotOK++
				case <-errChan:
					gotErr++
				}
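			}

			// close values so the range loop in processStringData exits
			// and its goroutine doesn't outlive the test
			close(values)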

			// assert
			if gotOK != tt.wantOKCount {
				t.Errorf("unexpected OK count, got %d, wanted %d", gotOK, tt.wantOKCount)
			}
			if gotErr != tt.wantErrCount {
				t.Errorf("unexpected error count, got %d, wanted %d", gotErr, tt.wantErrCount)
			}
			if o.sum != tt.wantSum {
				t.Errorf("unexpected final sum, got %d, wanted %d", o.sum, tt.wantSum)
			}
		})
	}
}
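The test above steps through one value at a time. An alternative worth knowing, sketched here with a hypothetical collect driver, relies on the fact that processStringData closes both ok and errs once values is closed: a separate goroutine sends every input, and the caller drains both outputs until they are closed, using the nil-channel idiom to retire each case from the select.

```go
package main

import "strconv"

type order struct {
	sum int
}

// processStringData is copied from the example above; it closes both ok
// and errs once values is closed and drained.
func (o *order) processStringData(values <-chan string, ok chan<- struct{}, errs chan<- error) {
	defer close(ok)
	defer close(errs)
	for s := range values {
		v, err := strconv.Atoi(s)
		if err != nil {
			errs <- err
			continue
		}
		o.sum += v
		ok <- struct{}{}
	}
}

// collect is a hypothetical alternative driver: a sender goroutine feeds
// every input and closes values, while the caller counts results until
// processStringData has closed both output channels.
func collect(inputs []string) (gotOK, gotErr, sum int) {
	o := order{}
	values := make(chan string)
	okChan := make(chan struct{})
	errChan := make(chan error)

	go o.processStringData(values, okChan, errChan)

	go func() {
		for _, s := range inputs {
			values <- s
		}
		close(values) // lets processStringData return and close its outputs
	}()

	// Receiving from a nil channel blocks forever, so setting a closed
	// channel's variable to nil effectively removes its case from the select.
	for okChan != nil || errChan != nil {
		select {
		case _, open := <-okChan:
			if !open {
				okChan = nil
				continue
			}
			gotOK++
		case _, open := <-errChan:
			if !open {
				errChan = nil
				continue
			}
			gotErr++
		}
	}
	// Both closes happen after the last write to o.sum, so this read is race-free.
	return gotOK, gotErr, o.sum
}
```

The trade-off is that results are no longer paired with the input that produced them, so this suits count-style assertions like the ones in the table above, while the per-value select is the better fit when you need to know which input failed.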